// zoh_utils.cpp
  1. /*
  2. Copyright 2007 nVidia, Inc.
  3. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
  5. Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS,
  6. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  7. See the License for the specific language governing permissions and limitations under the License.
  8. */
  9. // Utility and common routines
  10. #include "zoh_utils.h"
  11. #include "nvmath/vector.inl"
  12. #include <math.h>
  13. using namespace nv;
  14. using namespace ZOH;
  15. static const int denom7_weights_64[] = {0, 9, 18, 27, 37, 46, 55, 64}; // divided by 64
  16. static const int denom15_weights_64[] = {0, 4, 9, 13, 17, 21, 26, 30, 34, 38, 43, 47, 51, 55, 60, 64}; // divided by 64
  17. /*static*/ Format Utils::FORMAT;
  18. int Utils::lerp(int a, int b, int i, int denom)
  19. {
  20. nvDebugCheck (denom == 3 || denom == 7 || denom == 15);
  21. nvDebugCheck (i >= 0 && i <= denom);
  22. int round = 32, shift = 6;
  23. const int *weights;
  24. switch(denom)
  25. {
  26. case 3: denom *= 5; i *= 5; // fall through to case 15
  27. case 15: weights = denom15_weights_64; break;
  28. case 7: weights = denom7_weights_64; break;
  29. default: nvUnreachable();
  30. }
  31. return (a*weights[denom-i] +b*weights[i] + round) >> shift;
  32. }
  33. Vector3 Utils::lerp(const Vector3& a, const Vector3 &b, int i, int denom)
  34. {
  35. nvDebugCheck (denom == 3 || denom == 7 || denom == 15);
  36. nvDebugCheck (i >= 0 && i <= denom);
  37. int shift = 6;
  38. const int *weights;
  39. switch(denom)
  40. {
  41. case 3: denom *= 5; i *= 5; // fall through to case 15
  42. case 15: weights = denom15_weights_64; break;
  43. case 7: weights = denom7_weights_64; break;
  44. default: nvUnreachable();
  45. }
  46. // no need to round these as this is an exact division
  47. return (a*float(weights[denom-i]) +b*float(weights[i])) / float(1 << shift);
  48. }
  /*
  For unsigned f16, clamp the input to [0,F16MAX]. Thus u15.
  For signed f16, clamp the input to [-F16MAX,F16MAX]. Thus s16.
  The conversions proceed as follows:
  unsigned f16: get bits. If the high (sign) bit is set, clamp to 0; otherwise clamp to at most F16MAX.
  signed f16: get bits. Extract exp+mantissa and clamp to F16MAX. Return -value if the sign bit was set, else value.
  unsigned int: get bits. Return as a positive value.
  signed int: get bits. Return as a value in -32768..32767.
  The inverse conversions are just the inverse of the above.
  */
  59. // clamp the 3 channels of the input vector to the allowable range based on FORMAT
  60. // note that each channel is a float storing the allowable range as a bit pattern converted to float
  61. // that is, for unsigned f16 say, we would clamp each channel to the range [0, F16MAX]
  62. void Utils::clamp(Vector3 &v)
  63. {
  64. for (int i=0; i<3; ++i)
  65. {
  66. switch(Utils::FORMAT)
  67. {
  68. case UNSIGNED_F16:
  69. if (v.component[i] < 0.0) v.component[i] = 0;
  70. else if (v.component[i] > F16MAX) v.component[i] = F16MAX;
  71. break;
  72. case SIGNED_F16:
  73. if (v.component[i] < -F16MAX) v.component[i] = -F16MAX;
  74. else if (v.component[i] > F16MAX) v.component[i] = F16MAX;
  75. break;
  76. default:
  77. nvUnreachable();
  78. }
  79. }
  80. }
  81. // convert a u16 value to s17 (represented as an int) based on the format expected
  82. int Utils::ushort_to_format(unsigned short input)
  83. {
  84. int out, s;
  85. // clamp to the valid range we are expecting
  86. switch (Utils::FORMAT)
  87. {
  88. case UNSIGNED_F16:
  89. if (input & F16S_MASK) out = 0;
  90. else if (input > F16MAX) out = F16MAX;
  91. else out = input;
  92. break;
  93. case SIGNED_F16:
  94. s = input & F16S_MASK;
  95. input &= F16EM_MASK;
  96. if (input > F16MAX) out = F16MAX;
  97. else out = input;
  98. out = s ? -out : out;
  99. break;
  100. }
  101. return out;
  102. }
  103. // convert a s17 value to u16 based on the format expected
  104. unsigned short Utils::format_to_ushort(int input)
  105. {
  106. unsigned short out;
  107. // clamp to the valid range we are expecting
  108. switch (Utils::FORMAT)
  109. {
  110. case UNSIGNED_F16:
  111. nvDebugCheck (input >= 0 && input <= F16MAX);
  112. out = input;
  113. break;
  114. case SIGNED_F16:
  115. nvDebugCheck (input >= -F16MAX && input <= F16MAX);
  116. // convert to sign-magnitude
  117. int s;
  118. if (input < 0) { s = F16S_MASK; input = -input; }
  119. else { s = 0; }
  120. out = s | input;
  121. break;
  122. }
  123. return out;
  124. }
  125. // quantize the input range into equal-sized bins
  126. int Utils::quantize(float value, int prec)
  127. {
  128. int q, ivalue, s;
  129. nvDebugCheck (prec > 1); // didn't bother to make it work for 1
  130. value = (float)floor(value + 0.5);
  131. int bias = (prec > 10) ? ((1<<(prec-1))-1) : 0; // bias precisions 11..16 to get a more accurate quantization
  132. switch (Utils::FORMAT)
  133. {
  134. case UNSIGNED_F16:
  135. nvDebugCheck (value >= 0 && value <= F16MAX);
  136. ivalue = (int)value;
  137. q = ((ivalue << prec) + bias) / (F16MAX+1);
  138. nvDebugCheck (q >= 0 && q < (1 << prec));
  139. break;
  140. case SIGNED_F16:
  141. nvDebugCheck (value >= -F16MAX && value <= F16MAX);
  142. // convert to sign-magnitude
  143. ivalue = (int)value;
  144. if (ivalue < 0) { s = 1; ivalue = -ivalue; } else s = 0;
  145. q = ((ivalue << (prec-1)) + bias) / (F16MAX+1);
  146. if (s)
  147. q = -q;
  148. nvDebugCheck (q > -(1 << (prec-1)) && q < (1 << (prec-1)));
  149. break;
  150. }
  151. return q;
  152. }
  153. int Utils::finish_unquantize(int q, int prec)
  154. {
  155. if (Utils::FORMAT == UNSIGNED_F16)
  156. return (q * 31) >> 6; // scale the magnitude by 31/64
  157. else if (Utils::FORMAT == SIGNED_F16)
  158. return (q < 0) ? -(((-q) * 31) >> 5) : (q * 31) >> 5; // scale the magnitude by 31/32
  159. else
  160. return q;
  161. }
  162. // unquantize each bin to midpoint of original bin range, except
  163. // for the end bins which we push to an endpoint of the bin range.
  164. // we do this to ensure we can represent all possible original values.
  165. // the asymmetric end bins do not affect PSNR for the test images.
  166. //
  167. // code this function assuming an arbitrary bit pattern as the encoded block
  168. int Utils::unquantize(int q, int prec)
  169. {
  170. int unq, s;
  171. nvDebugCheck (prec > 1); // not implemented for prec 1
  172. switch (Utils::FORMAT)
  173. {
  174. // modify this case to move the multiplication by 31 after interpolation.
  175. // Need to use finish_unquantize.
  176. // since we have 16 bits available, let's unquantize this to 16 bits unsigned
  177. // thus the scale factor is [0-7c00)/[0-10000) = 31/64
  178. case UNSIGNED_F16:
  179. if (prec >= 15)
  180. unq = q;
  181. else if (q == 0)
  182. unq = 0;
  183. else if (q == ((1<<prec)-1))
  184. unq = U16MAX;
  185. else
  186. unq = (q * (U16MAX+1) + (U16MAX+1)/2) >> prec;
  187. break;
  188. // here, let's stick with S16 (no apparent quality benefit from going to S17)
  189. // range is (-7c00..7c00)/(-8000..8000) = 31/32
  190. case SIGNED_F16:
  191. // don't remove this test even though it appears equivalent to the code below
  192. // as it isn't -- the code below can overflow for prec = 16
  193. if (prec >= 16)
  194. unq = q;
  195. else
  196. {
  197. if (q < 0) { s = 1; q = -q; } else s = 0;
  198. if (q == 0)
  199. unq = 0;
  200. else if (q >= ((1<<(prec-1))-1))
  201. unq = s ? -S16MAX : S16MAX;
  202. else
  203. {
  204. unq = (q * (S16MAX+1) + (S16MAX+1)/2) >> (prec-1);
  205. if (s)
  206. unq = -unq;
  207. }
  208. }
  209. break;
  210. }
  211. return unq;
  212. }
  213. // pick a norm!
  214. #define NORM_EUCLIDEAN 1
  215. float Utils::norm(const Vector3 &a, const Vector3 &b)
  216. {
  217. #ifdef NORM_EUCLIDEAN
  218. return lengthSquared(a - b);
  219. #endif
  220. #ifdef NORM_ABS
  221. Vector3 err = a - b;
  222. return fabs(err.x) + fabs(err.y) + fabs(err.z);
  223. #endif
  224. }
  225. // parse <name>[<start>{:<end>}]{,}
  226. // the pointer starts here ^
  227. // name is 1 or 2 chars and matches field names. start and end are decimal numbers
  228. void Utils::parse(const char *encoding, int &ptr, Field &field, int &endbit, int &len)
  229. {
  230. if (ptr <= 0) return;
  231. --ptr;
  232. if (encoding[ptr] == ',') --ptr;
  233. nvDebugCheck (encoding[ptr] == ']');
  234. --ptr;
  235. endbit = 0;
  236. int scale = 1;
  237. while (encoding[ptr] != ':' && encoding[ptr] != '[')
  238. {
  239. nvDebugCheck(encoding[ptr] >= '0' && encoding[ptr] <= '9');
  240. endbit += (encoding[ptr--] - '0') * scale;
  241. scale *= 10;
  242. }
  243. int startbit = 0; scale = 1;
  244. if (encoding[ptr] == '[')
  245. startbit = endbit;
  246. else
  247. {
  248. ptr--;
  249. while (encoding[ptr] != '[')
  250. {
  251. nvDebugCheck(encoding[ptr] >= '0' && encoding[ptr] <= '9');
  252. startbit += (encoding[ptr--] - '0') * scale;
  253. scale *= 10;
  254. }
  255. }
  256. len = startbit - endbit + 1; // startbit>=endbit note
  257. --ptr;
  258. if (encoding[ptr] == 'm') field = FIELD_M;
  259. else if (encoding[ptr] == 'd') field = FIELD_D;
  260. else {
  261. // it's wxyz
  262. nvDebugCheck (encoding[ptr] >= 'w' && encoding[ptr] <= 'z');
  263. int foo = encoding[ptr--] - 'w';
  264. // now it is r g or b
  265. if (encoding[ptr] == 'r') foo += 10;
  266. else if (encoding[ptr] == 'g') foo += 20;
  267. else if (encoding[ptr] == 'b') foo += 30;
  268. else nvDebugCheck(0);
  269. field = (Field) foo;
  270. }
  271. }