cppspmd_math.h 21 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725
  1. // Do not include this header directly.
  2. //
  3. // Copyright 2020-2021 Binomial LLC
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. // The general goal of these vectorized estimated math functions is scalability/performance.
  17. // There are explicitly no checks for NaN's/Inf's on the input arguments. There are no assertions either.
  18. // These are fast estimate functions - if you need more than that, use stdlib. Please do a proper
  19. // engineering analysis before relying on them.
  20. // I have chosen functions written by others, ported them to CppSPMD, then measured their abs/rel errors.
  21. // I compared each to the ones in DirectXMath and stdlib's for accuracy/performance.
  22. CPPSPMD_FORCE_INLINE vfloat fmod_inv(const vfloat& a, const vfloat& b, const vfloat& b_inv)
  23. {
  24. vfloat c = frac(abs(a * b_inv)) * abs(b);
  25. return spmd_ternaryf(a < 0, -c, c);
  26. }
  27. CPPSPMD_FORCE_INLINE vfloat fmod_inv_p(const vfloat& a, const vfloat& b, const vfloat& b_inv)
  28. {
  29. return frac(a * b_inv) * b;
  30. }
  31. // Avoids dividing by zero or very small values.
  32. CPPSPMD_FORCE_INLINE vfloat safe_div(vfloat a, vfloat b, float fDivThresh = 1e-7f)
  33. {
  34. return a / spmd_ternaryf( abs(b) > fDivThresh, b, spmd_ternaryf(b < 0.0f, -fDivThresh, fDivThresh) );
  35. }
  36. /*
  37. clang 9.0.0 for win /fp:precise release
  38. f range: 0.0000000000001250 10000000000.0000000000000000, vals: 1073741824
  39. log2_est():
  40. max abs err: 0.0000023076808731
  41. max rel err: 0.0000000756678881
  42. avg abs err: 0.0000007535452724
  43. avg rel err: 0.0000000235117843
  44. XMVectorLog2():
  45. max abs err: 0.0000023329709933
  46. max rel err: 0.0000000826961046
  47. avg abs err: 0.0000007564889684
  48. avg rel err: 0.0000000236051899
  49. std::log2f():
  50. max abs err: 0.0000020265979401
  51. max rel err: 0.0000000626647654
  52. avg abs err: 0.0000007494445227
  53. avg rel err: 0.0000000233800985
  54. */
  55. // See https://tech.ebayinc.com/engineering/fast-approximate-logarithms-part-iii-the-formulas/
  56. inline vfloat spmd_kernel::log2_est(vfloat v)
  57. {
  58. vfloat signif, fexp;
  59. // Just clamp to a very small value, instead of checking for invalid inputs.
  60. vfloat x = max(v, 2.2e-38f);
  61. /*
  62. * Assume IEEE representation, which is sgn(1):exp(8):frac(23)
  63. * representing (1+frac)*2^(exp-127). Call 1+frac the significand
  64. */
  65. // get exponent
  66. vint ux1_i = cast_vfloat_to_vint(x);
  67. vint exp = VUINT_SHIFT_RIGHT(ux1_i & 0x7F800000, 23);
  68. // actual exponent is exp-127, will subtract 127 later
  69. vint ux2_i;
  70. vfloat ux2_f;
  71. vint greater = ux1_i & 0x00400000; // true if signif > 1.5
  72. SPMD_SIF(greater != 0)
  73. {
  74. // signif >= 1.5 so need to divide by 2. Accomplish this by stuffing exp = 126 which corresponds to an exponent of -1
  75. store_all(ux2_i, (ux1_i & 0x007FFFFF) | 0x3f000000);
  76. store_all(ux2_f, cast_vint_to_vfloat(ux2_i));
  77. // 126 instead of 127 compensates for division by 2
  78. store_all(fexp, vfloat(exp - 126));
  79. }
  80. SPMD_SELSE(greater != 0)
  81. {
  82. // get signif by stuffing exp = 127 which corresponds to an exponent of 0
  83. store(ux2_i, (ux1_i & 0x007FFFFF) | 0x3f800000);
  84. store(ux2_f, cast_vint_to_vfloat(ux2_i));
  85. store(fexp, vfloat(exp - 127));
  86. }
  87. SPMD_SENDIF
  88. store_all(signif, ux2_f);
  89. store_all(signif, signif - 1.0f);
  90. const float a = 0.1501692f, b = 3.4226132f, c = 5.0225057f, d = 4.1130283f, e = 3.4813372f;
  91. vfloat xm1 = signif;
  92. vfloat xm1sqr = xm1 * xm1;
  93. return fexp + ((a * (xm1sqr * xm1) + b * xm1sqr + c * xm1) / (xm1sqr + d * xm1 + e));
  94. // fma lowers accuracy for SSE4.1 - no idea why (compiler reordering?)
  95. //return fexp + ((vfma(a, (xm1sqr * xm1), vfma(b, xm1sqr, c * xm1))) / (xm1sqr + vfma(d, xm1, e)));
  96. }
  97. // Uses log2_est(), so this function must be <= the precision of that.
  98. inline vfloat spmd_kernel::log_est(vfloat v)
  99. {
  100. return log2_est(v) * 0.693147181f;
  101. }
  102. CPPSPMD_FORCE_INLINE void spmd_kernel::reduce_expb(vfloat& arg, vfloat& two_int_a, vint& adjustment)
  103. {
  104. // Assume we're using equation (2)
  105. store_all(adjustment, 0);
  106. // integer part of the input argument
  107. vint int_arg = (vint)arg;
  108. // if frac(arg) is in [0.5, 1.0]...
  109. SPMD_SIF((arg - int_arg) > 0.5f)
  110. {
  111. store(adjustment, 1);
  112. // then change it to [0.0, 0.5]
  113. store(arg, arg - 0.5f);
  114. }
  115. SPMD_SENDIF
  116. // arg == just the fractional part
  117. store_all(arg, arg - (vfloat)int_arg);
  118. // Now compute 2** (int) arg.
  119. store_all(int_arg, min(int_arg + 127, 254));
  120. store_all(two_int_a, cast_vint_to_vfloat(VINT_SHIFT_LEFT(int_arg, 23)));
  121. }
  122. /*
  123. clang 9.0.0 for win /fp:precise release
  124. f range : -50.0000000000000000 49.9999940395355225, vals : 16777216
  125. exp2_est():
  126. Total passed near - zero check : 16777216
  127. Total sign diffs : 0
  128. max abs err: 1668910609.7500000000000000
  129. max rel err: 0.0000015642030031
  130. avg abs err: 10793794.4007573910057545
  131. avg rel err: 0.0000003890893282
  132. XMVectorExp2():
  133. Total passed near-zero check: 16777216
  134. Total sign diffs: 0
  135. max abs err: 1665552836.8750000000000000
  136. max rel err: 0.0000114674862370
  137. avg abs err: 10771868.2627860084176064
  138. avg rel err: 0.0000011218880770
  139. std::exp2f():
  140. Total passed near-zero check: 16777216
  141. Total sign diffs: 0
  142. max abs err: 1591636585.6250000000000000
  143. max rel err: 0.0000014849731018
  144. avg abs err: 10775800.3204844966530800
  145. avg rel err: 0.0000003851496422
  146. */
  147. // http://www.ganssle.com/item/approximations-c-code-exponentiation-log.htm
  148. inline vfloat spmd_kernel::exp2_est(vfloat arg)
  149. {
  150. SPMD_BEGIN_CALL
  151. const vfloat P00 = +7.2152891521493f;
  152. const vfloat P01 = +0.0576900723731f;
  153. const vfloat Q00 = +20.8189237930062f;
  154. const vfloat Q01 = +1.0f;
  155. const vfloat sqrt2 = 1.4142135623730950488f; // sqrt(2) for scaling
  156. vfloat result = 0.0f;
  157. // Return 0 if arg is too large.
  158. // We're not introducing inf/nan's into calculations, or risk doing so by returning huge default values.
  159. SPMD_IF(abs(arg) > 126.0f)
  160. {
  161. spmd_return();
  162. }
  163. SPMD_END_IF
  164. // 2**(int(a))
  165. vfloat two_int_a;
  166. // set to 1 by reduce_expb
  167. vint adjustment;
  168. // 0 if arg is +; 1 if negative
  169. vint negative = 0;
  170. // If the input is negative, invert it. At the end we'll take the reciprocal, since n**(-1) = 1/(n**x).
  171. SPMD_SIF(arg < 0.0f)
  172. {
  173. store(arg, -arg);
  174. store(negative, 1);
  175. }
  176. SPMD_SENDIF
  177. store_all(arg, min(arg, 126.0f));
  178. // reduce to [0.0, 0.5]
  179. reduce_expb(arg, two_int_a, adjustment);
  180. // The format of the polynomial is:
  181. // answer=(Q(x**2) + x*P(x**2))/(Q(x**2) - x*P(x**2))
  182. //
  183. // The following computes the polynomial in several steps:
  184. // Q(x**2)
  185. vfloat Q = vfma(Q01, (arg * arg), Q00);
  186. // x*P(x**2)
  187. vfloat x_P = arg * (vfma(P01, arg * arg, P00));
  188. vfloat answer = (Q + x_P) / (Q - x_P);
  189. // Now correct for the scaling factor of 2**(int(a))
  190. store_all(answer, answer * two_int_a);
  191. // If the result had a fractional part > 0.5, correct for that
  192. store_all(answer, spmd_ternaryf(adjustment != 0, answer * sqrt2, answer));
  193. // Correct for a negative input
  194. SPMD_SIF(negative != 0)
  195. {
  196. store(answer, 1.0f / answer);
  197. }
  198. SPMD_SENDIF
  199. store(result, answer);
  200. return result;
  201. }
  202. inline vfloat spmd_kernel::exp_est(vfloat arg)
  203. {
  204. // e^x = exp2(x / log_base_e(2))
  205. // constant is 1.0/(log(2)/log(e)) or 1/log(2)
  206. return exp2_est(arg * 1.44269504f);
  207. }
  208. inline vfloat spmd_kernel::pow_est(vfloat arg1, vfloat arg2)
  209. {
  210. return exp_est(log_est(arg1) * arg2);
  211. }
  212. /*
  213. clang 9.0.0 for win /fp:precise release
  214. Total near-zero: 144, output above near-zero tresh: 30
  215. Total near-zero avg: 0.0000067941016621 max: 0.0000134706497192
  216. Total near-zero sign diffs: 5
  217. Total passed near-zero check: 16777072
  218. Total sign diffs: 5
  219. max abs err: 0.0000031375306036
  220. max rel err: 0.1140846017075028
  221. avg abs err: 0.0000003026226621
  222. avg rel err: 0.0000033564977623
  223. */
  224. // Math from this web page: http://developer.download.nvidia.com/cg/sin.html
  225. // This is ~2x slower than sin_est() or cos_est(), and less accurate, but I'm keeping it here for comparison purposes to help validate/sanity check sin_est() and cos_est().
  226. inline vfloat spmd_kernel::sincos_est_a(vfloat a, bool sin_flag)
  227. {
  228. const float c0_x = 0.0f, c0_y = 0.5f, c0_z = 1.0f;
  229. const float c1_x = 0.25f, c1_y = -9.0f, c1_z = 0.75f, c1_w = 0.159154943091f;
  230. const float c2_x = 24.9808039603f, c2_y = -24.9808039603f, c2_z = -60.1458091736f, c2_w = 60.1458091736f;
  231. const float c3_x = 85.4537887573f, c3_y = -85.4537887573f, c3_z = -64.9393539429f, c3_w = 64.9393539429f;
  232. const float c4_x = 19.7392082214f, c4_y = -19.7392082214f, c4_z = -1.0f, c4_w = 1.0f;
  233. vfloat r0_x, r0_y, r0_z, r1_x, r1_y, r1_z, r2_x, r2_y, r2_z;
  234. store_all(r1_x, sin_flag ? vfms(c1_w, a, c1_x) : c1_w * a);
  235. store_all(r1_y, frac(r1_x));
  236. store_all(r2_x, (vfloat)(r1_y < c1_x));
  237. store_all(r2_y, (vfloat)(r1_y >= c1_y));
  238. store_all(r2_z, (vfloat)(r1_y >= c1_z));
  239. store_all(r2_y, vfma(r2_x, c4_z, vfma(r2_y, c4_w, r2_z * c4_z)));
  240. store_all(r0_x, c0_x - r1_y);
  241. store_all(r0_y, c0_y - r1_y);
  242. store_all(r0_z, c0_z - r1_y);
  243. store_all(r0_x, r0_x * r0_x);
  244. store_all(r0_y, r0_y * r0_y);
  245. store_all(r0_z, r0_z * r0_z);
  246. store_all(r1_x, vfma(c2_x, r0_x, c2_z));
  247. store_all(r1_y, vfma(c2_y, r0_y, c2_w));
  248. store_all(r1_z, vfma(c2_x, r0_z, c2_z));
  249. store_all(r1_x, vfma(r1_x, r0_x, c3_x));
  250. store_all(r1_y, vfma(r1_y, r0_y, c3_y));
  251. store_all(r1_z, vfma(r1_z, r0_z, c3_x));
  252. store_all(r1_x, vfma(r1_x, r0_x, c3_z));
  253. store_all(r1_y, vfma(r1_y, r0_y, c3_w));
  254. store_all(r1_z, vfma(r1_z, r0_z, c3_z));
  255. store_all(r1_x, vfma(r1_x, r0_x, c4_x));
  256. store_all(r1_y, vfma(r1_y, r0_y, c4_y));
  257. store_all(r1_z, vfma(r1_z, r0_z, c4_x));
  258. store_all(r1_x, vfma(r1_x, r0_x, c4_z));
  259. store_all(r1_y, vfma(r1_y, r0_y, c4_w));
  260. store_all(r1_z, vfma(r1_z, r0_z, c4_z));
  261. store_all(r0_x, vfnma(r1_x, r2_x, vfnma(r1_y, r2_y, r1_z * -r2_z)));
  262. return r0_x;
  263. }
  264. // positive values only
  265. CPPSPMD_FORCE_INLINE vfloat spmd_kernel::recip_est1(const vfloat& q)
  266. {
  267. //const int mag = 0x7EF312AC; // 2 NR iters, 3 is 0x7EEEEBB3
  268. const int mag = 0x7EF311C3;
  269. const float fMinThresh = .0000125f;
  270. vfloat l = spmd_ternaryf(q >= fMinThresh, q, cast_vint_to_vfloat(vint(mag)));
  271. vint x_l = vint(mag) - cast_vfloat_to_vint(l);
  272. vfloat rcp_l = cast_vint_to_vfloat(x_l);
  273. return rcp_l * vfnma(rcp_l, q, 2.0f);
  274. }
  275. CPPSPMD_FORCE_INLINE vfloat spmd_kernel::recip_est1_pn(const vfloat& t)
  276. {
  277. //const int mag = 0x7EF312AC; // 2 NR iters, 3 is 0x7EEEEBB3
  278. const int mag = 0x7EF311C3;
  279. const float fMinThresh = .0000125f;
  280. vfloat s = sign(t);
  281. vfloat q = abs(t);
  282. vfloat l = spmd_ternaryf(q >= fMinThresh, q, cast_vint_to_vfloat(vint(mag)));
  283. vint x_l = vint(mag) - cast_vfloat_to_vint(l);
  284. vfloat rcp_l = cast_vint_to_vfloat(x_l);
  285. return rcp_l * vfnma(rcp_l, q, 2.0f) * s;
  286. }
  287. // https://basesandframes.files.wordpress.com/2020/04/even_faster_math_functions_green_2020.pdf
  288. // https://github.com/hcs0/Hackers-Delight/blob/master/rsqrt.c.txt
  289. CPPSPMD_FORCE_INLINE vfloat spmd_kernel::rsqrt_est1(vfloat x0)
  290. {
  291. vfloat xhalf = 0.5f * x0;
  292. vfloat x = cast_vint_to_vfloat(vint(0x5F375A82) - (VINT_SHIFT_RIGHT(cast_vfloat_to_vint(x0), 1)));
  293. return x * vfnma(xhalf * x, x, 1.5008909f);
  294. }
  295. CPPSPMD_FORCE_INLINE vfloat spmd_kernel::rsqrt_est2(vfloat x0)
  296. {
  297. vfloat xhalf = 0.5f * x0;
  298. vfloat x = cast_vint_to_vfloat(vint(0x5F37599E) - (VINT_SHIFT_RIGHT(cast_vfloat_to_vint(x0), 1)));
  299. vfloat x1 = x * vfnma(xhalf * x, x, 1.5);
  300. vfloat x2 = x1 * vfnma(xhalf * x1, x1, 1.5);
  301. return x2;
  302. }
  303. // Math from: http://developer.download.nvidia.com/cg/atan2.html
  304. // TODO: Needs more validation, parameter checking.
  305. CPPSPMD_FORCE_INLINE vfloat spmd_kernel::atan2_est(vfloat y, vfloat x)
  306. {
  307. vfloat t1 = abs(y);
  308. vfloat t3 = abs(x);
  309. vfloat t0 = max(t3, t1);
  310. store_all(t1, min(t3, t1));
  311. store_all(t3, t1 / t0);
  312. vfloat t4 = t3 * t3;
  313. store_all(t0, vfma(-0.013480470f, t4, 0.057477314f));
  314. store_all(t0, vfms(t0, t4, 0.121239071f));
  315. store_all(t0, vfma(t0, t4, 0.195635925f));
  316. store_all(t0, vfms(t0, t4, 0.332994597f));
  317. store_all(t0, vfma(t0, t4, 0.999995630f));
  318. store_all(t3, t0 * t3);
  319. store_all(t3, spmd_ternaryf(abs(y) > abs(x), vfloat(1.570796327f) - t3, t3));
  320. store_all(t3, spmd_ternaryf(x < 0.0f, vfloat(3.141592654f) - t3, t3));
  321. store_all(t3, spmd_ternaryf(y < 0.0f, -t3, t3));
  322. return t3;
  323. }
  324. /*
  325. clang 9.0.0 for win /fp:precise release
  326. Tested range: -25.1327412287183449 25.1327382326621169, vals : 16777216
  327. Skipped angles near 90/270 within +- .001 radians.
  328. Near-zero threshold: .0000125f
  329. Near-zero output above check threshold: 1e-6f
  330. Total near-zero: 144, output above near-zero tresh: 20
  331. Total near-zero avg: 0.0000067510751968 max: 0.0000133514404297
  332. Total near-zero sign diffs: 5
  333. Total passed near-zero check: 16766400
  334. Total sign diffs: 5
  335. max abs err: 1.4982600811139264
  336. max rel err: 0.1459155900188041
  337. avg rel err: 0.0000054659502568
  338. XMVectorTan() precise:
  339. Total near-zero: 144, output above near-zero tresh: 18
  340. Total near-zero avg: 0.0000067641216186 max: 0.0000133524126795
  341. Total near-zero sign diffs: 0
  342. Total passed near-zero check: 16766400
  343. Total sign diffs: 0
  344. max abs err: 1.9883573246424930
  345. max rel err: 0.1459724171926864
  346. avg rel err: 0.0000054965766843
  347. std::tanf():
  348. Total near-zero: 144, output above near-zero tresh: 0
  349. Total near-zero avg: 0.0000067116930779 max: 0.0000127713074107
  350. Total near-zero sign diffs: 11
  351. Total passed near-zero check: 16766400
  352. Total sign diffs: 11
  353. max abs err: 0.8989131818294709
  354. max rel err: 0.0573181403173166
  355. avg rel err: 0.0000030791301203
  356. Originally from:
  357. http://www.ganssle.com/approx.htm
  358. */
  359. CPPSPMD_FORCE_INLINE vfloat spmd_kernel::tan82(vfloat x)
  360. {
  361. // Original double version was 8.2 digits
  362. //double c1 = 211.849369664121f, c2 = -12.5288887278448f, c3 = 269.7350131214121f, c4 = -71.4145309347748f;
  363. // Tuned float constants for lower avg rel error (without using FMA3):
  364. const float c1 = 211.849350f, c2 = -12.5288887f, c3 = 269.734985f, c4 = -71.4145203f;
  365. vfloat x2 = x * x;
  366. return (x * (vfma(c2, x2, c1)) / (vfma(x2, (c4 + x2), c3)));
  367. }
  368. // Don't call this for angles close to 90/270!.
  369. inline vfloat spmd_kernel::tan_est(vfloat x)
  370. {
  371. const float fPi = 3.141592653589793f, fOneOverPi = 0.3183098861837907f;
  372. CPPSPMD_DECL(const uint8_t, s_table0[16]) = { 128 + 0, 128 + 2, 128 + -2, 128 + 4, 128 + 0, 128 + 2, 128 + -2, 128 + 4, 128 + 0, 128 + 2, 128 + -2, 128 + 4, 128 + 0, 128 + 2, 128 + -2, 128 + 4 };
  373. vint table = init_lookup4(s_table0); // a load
  374. vint sgn = cast_vfloat_to_vint(x) & 0x80000000;
  375. store_all(x, abs(x));
  376. vfloat orig_x = x;
  377. vfloat q = x * fOneOverPi;
  378. store_all(x, q - floor(q));
  379. vfloat x4 = x * 4.0f;
  380. vint octant = (vint)(x4);
  381. vfloat x0 = spmd_ternaryf((octant & 1) != 0, -x4, x4);
  382. vint k = table_lookup4_8(octant, table) & 0xFF; // a shuffle
  383. vfloat bias = (vfloat)k + -128.0f;
  384. vfloat y = x0 + bias;
  385. vfloat z = tan82(y);
  386. vfloat r;
  387. vbool octant_one_or_two = (octant == 1) || (octant == 2);
  388. // SPMD optimization - skip costly divide if we can
  389. if (spmd_any(octant_one_or_two))
  390. {
  391. const float fDivThresh = .4371e-7f;
  392. vfloat one_over_z = 1.0f / spmd_ternaryf(abs(z) > fDivThresh, z, spmd_ternaryf(z < 0.0f, -fDivThresh, fDivThresh));
  393. vfloat b = spmd_ternaryf(octant_one_or_two, one_over_z, z);
  394. store_all(r, spmd_ternaryf((octant & 2) != 0, -b, b));
  395. }
  396. else
  397. {
  398. store_all(r, spmd_ternaryf(octant == 0, z, -z));
  399. }
  400. // Small angle approximation, to decrease the max rel error near Pi.
  401. SPMD_SIF(x >= (1.0f - .0003125f*4.0f))
  402. {
  403. store(r, vfnma(floor(q) + 1.0f, fPi, orig_x));
  404. }
  405. SPMD_SENDIF
  406. return cast_vint_to_vfloat(cast_vfloat_to_vint(r) ^ sgn);
  407. }
  408. inline void spmd_kernel::seed_rand(rand_context& x, vint seed)
  409. {
  410. store(x.a, 0xf1ea5eed);
  411. store(x.b, seed ^ 0xd8487b1f);
  412. store(x.c, seed ^ 0xdbadef9a);
  413. store(x.d, seed);
  414. for (int i = 0; i < 20; ++i)
  415. (void)get_randu(x);
  416. }
  417. // https://burtleburtle.net/bob/rand/smallprng.html
  418. // Returns 32-bit unsigned random numbers.
  419. inline vint spmd_kernel::get_randu(rand_context& x)
  420. {
  421. vint e = x.a - VINT_ROT(x.b, 27);
  422. store(x.a, x.b ^ VINT_ROT(x.c, 17));
  423. store(x.b, x.c + x.d);
  424. store(x.c, x.d + e);
  425. store(x.d, e + x.a);
  426. return x.d;
  427. }
  428. // Returns random numbers between [low, high), or low if low >= high
  429. inline vint spmd_kernel::get_randi(rand_context& x, vint low, vint high)
  430. {
  431. vint rnd = get_randu(x);
  432. vint range = high - low;
  433. vint rnd_range = mulhiu(rnd, range);
  434. return spmd_ternaryi(low < high, low + rnd_range, low);
  435. }
  436. // Returns random numbers between [low, high), or low if low >= high
  437. inline vfloat spmd_kernel::get_randf(rand_context& x, vfloat low, vfloat high)
  438. {
  439. vint rndi = get_randu(x) & 0x7fffff;
  440. vfloat rnd = (vfloat)(rndi) * (1.0f / 8388608.0f);
  441. return spmd_ternaryf(low < high, vfma(high - low, rnd, low), low);
  442. }
  443. CPPSPMD_FORCE_INLINE void spmd_kernel::init_reverse_bits(vint& tab1, vint& tab2)
  444. {
  445. const uint8_t tab1_bytes[16] = { 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15 };
  446. const uint8_t tab2_bytes[16] = { 0, 8 << 4, 4 << 4, 12 << 4, 2 << 4, 10 << 4, 6 << 4, 14 << 4, 1 << 4, 9 << 4, 5 << 4, 13 << 4, 3 << 4, 11 << 4, 7 << 4, 15 << 4 };
  447. store_all(tab1, init_lookup4(tab1_bytes));
  448. store_all(tab2, init_lookup4(tab2_bytes));
  449. }
  450. CPPSPMD_FORCE_INLINE vint spmd_kernel::reverse_bits(vint k, vint tab1, vint tab2)
  451. {
  452. vint r0 = table_lookup4_8(k & 0x7F7F7F7F, tab2);
  453. vint r1 = table_lookup4_8(VUINT_SHIFT_RIGHT(k, 4) & 0x7F7F7F7F, tab1);
  454. vint r3 = r0 | r1;
  455. return byteswap(r3);
  456. }
  457. CPPSPMD_FORCE_INLINE vint spmd_kernel::count_leading_zeros(vint x)
  458. {
  459. CPPSPMD_DECL(const uint8_t, s_tab[16]) = { 0, 3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 };
  460. vint tab = init_lookup4(s_tab);
  461. //x <= 0x0000ffff
  462. vbool c0 = (x & 0xFFFF0000) == 0;
  463. vint n0 = spmd_ternaryi(c0, 16, 0);
  464. vint x0 = spmd_ternaryi(c0, VINT_SHIFT_LEFT(x, 16), x);
  465. //x <= 0x00ffffff
  466. vbool c1 = (x0 & 0xFF000000) == 0;
  467. vint n1 = spmd_ternaryi(c1, n0 + 8, n0);
  468. vint x1 = spmd_ternaryi(c1, VINT_SHIFT_LEFT(x0, 8), x0);
  469. //x <= 0x0fffffff
  470. vbool c2 = (x1 & 0xF0000000) == 0;
  471. vint n2 = spmd_ternaryi(c2, n1 + 4, n1);
  472. vint x2 = spmd_ternaryi(c2, VINT_SHIFT_LEFT(x1, 4), x1);
  473. return table_lookup4_8(VUINT_SHIFT_RIGHT(x2, 28), tab) + n2;
  474. }
  475. CPPSPMD_FORCE_INLINE vint spmd_kernel::count_leading_zeros_alt(vint x)
  476. {
  477. //x <= 0x0000ffff
  478. vbool c0 = (x & 0xFFFF0000) == 0;
  479. vint n0 = spmd_ternaryi(c0, 16, 0);
  480. vint x0 = spmd_ternaryi(c0, VINT_SHIFT_LEFT(x, 16), x);
  481. //x <= 0x00ffffff
  482. vbool c1 = (x0 & 0xFF000000) == 0;
  483. vint n1 = spmd_ternaryi(c1, n0 + 8, n0);
  484. vint x1 = spmd_ternaryi(c1, VINT_SHIFT_LEFT(x0, 8), x0);
  485. //x <= 0x0fffffff
  486. vbool c2 = (x1 & 0xF0000000) == 0;
  487. vint n2 = spmd_ternaryi(c2, n1 + 4, n1);
  488. vint x2 = spmd_ternaryi(c2, VINT_SHIFT_LEFT(x1, 4), x1);
  489. // x <= 0x3fffffff
  490. vbool c3 = (x2 & 0xC0000000) == 0;
  491. vint n3 = spmd_ternaryi(c3, n2 + 2, n2);
  492. vint x3 = spmd_ternaryi(c3, VINT_SHIFT_LEFT(x2, 2), x2);
  493. // x <= 0x7fffffff
  494. vbool c4 = (x3 & 0x80000000) == 0;
  495. return spmd_ternaryi(c4, n3 + 1, n3);
  496. }
  497. CPPSPMD_FORCE_INLINE vint spmd_kernel::count_trailing_zeros(vint x)
  498. {
  499. // cast the least significant bit in v to a float
  500. vfloat f = (vfloat)(x & -x);
  501. // extract exponent and adjust
  502. return VUINT_SHIFT_RIGHT(cast_vfloat_to_vint(f), 23) - 0x7F;
  503. }
  504. CPPSPMD_FORCE_INLINE vint spmd_kernel::count_set_bits(vint x)
  505. {
  506. vint v = x - (VUINT_SHIFT_RIGHT(x, 1) & 0x55555555);
  507. vint v1 = (v & 0x33333333) + (VUINT_SHIFT_RIGHT(v, 2) & 0x33333333);
  508. return VUINT_SHIFT_RIGHT(((v1 + VUINT_SHIFT_RIGHT(v1, 4) & 0xF0F0F0F) * 0x1010101), 24);
  509. }
  510. CPPSPMD_FORCE_INLINE vint cmple_epu16(const vint &a, const vint &b)
  511. {
  512. return cmpeq_epi16(subs_epu16(a, b), vint(0));
  513. }
  514. CPPSPMD_FORCE_INLINE vint cmpge_epu16(const vint &a, const vint &b)
  515. {
  516. return cmple_epu16(b, a);
  517. }
  518. CPPSPMD_FORCE_INLINE vint cmpgt_epu16(const vint &a, const vint &b)
  519. {
  520. return andnot(cmpeq_epi16(a, b), cmple_epu16(b, a));
  521. }
  522. CPPSPMD_FORCE_INLINE vint cmplt_epu16(const vint &a, const vint &b)
  523. {
  524. return cmpgt_epu16(b, a);
  525. }
  526. CPPSPMD_FORCE_INLINE vint cmpge_epi16(const vint &a, const vint &b)
  527. {
  528. return cmpeq_epi16(a, b) | cmpgt_epi16(a, b);
  529. }
  530. CPPSPMD_FORCE_INLINE vint cmple_epi16(const vint &a, const vint &b)
  531. {
  532. return cmpge_epi16(b, a);
  533. }
  534. void spmd_kernel::print_vint(vint v)
  535. {
  536. for (uint32_t i = 0; i < PROGRAM_COUNT; i++)
  537. printf("%i ", extract(v, i));
  538. printf("\n");
  539. }
  540. void spmd_kernel::print_vbool(vbool v)
  541. {
  542. for (uint32_t i = 0; i < PROGRAM_COUNT; i++)
  543. printf("%i ", extract(v, i) ? 1 : 0);
  544. printf("\n");
  545. }
  546. void spmd_kernel::print_vint_hex(vint v)
  547. {
  548. for (uint32_t i = 0; i < PROGRAM_COUNT; i++)
  549. printf("0x%X ", extract(v, i));
  550. printf("\n");
  551. }
  552. void spmd_kernel::print_active_lanes(const char *pPrefix)
  553. {
  554. CPPSPMD_DECL(int, flags[PROGRAM_COUNT]);
  555. memset(flags, 0, sizeof(flags));
  556. storeu_linear(flags, vint(1));
  557. if (pPrefix)
  558. printf("%s", pPrefix);
  559. for (uint32_t i = 0; i < PROGRAM_COUNT; i++)
  560. {
  561. if (flags[i])
  562. printf("%u ", i);
  563. }
  564. printf("\n");
  565. }
  566. void spmd_kernel::print_vfloat(vfloat v)
  567. {
  568. for (uint32_t i = 0; i < PROGRAM_COUNT; i++)
  569. printf("%f ", extract(v, i));
  570. printf("\n");
  571. }