vfloat16_avx512.h
// ======================================================================== //
// Copyright 2009-2017 Intel Corporation                                    //
//                                                                          //
// Licensed under the Apache License, Version 2.0 (the "License");          //
// you may not use this file except in compliance with the License.         //
// You may obtain a copy of the License at                                  //
//                                                                          //
//     http://www.apache.org/licenses/LICENSE-2.0                           //
//                                                                          //
// Unless required by applicable law or agreed to in writing, software      //
// distributed under the License is distributed on an "AS IS" BASIS,        //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and      //
// limitations under the License.                                           //
// ======================================================================== //

#pragma once
namespace embree
{
  /* 16-wide AVX-512 float type */
  template<>
  struct vfloat<16>
  {
    typedef vboolf16 Bool;
    typedef vint16   Int;
    typedef vfloat16 Float;

    enum { size = 16 }; // number of SIMD elements

    union { // data
      __m512 v;
      float f[16];
      int i[16];
    };

    ////////////////////////////////////////////////////////////////////////////////
    /// Constructors, Assignment & Cast Operators
    ////////////////////////////////////////////////////////////////////////////////

    __forceinline vfloat() {}
    __forceinline vfloat(const vfloat16& t) { v = t; }
    __forceinline vfloat16& operator=(const vfloat16& f) { v = f.v; return *this; }

    __forceinline vfloat(const __m512& t) { v = t; }
    __forceinline operator __m512() const { return v; }
    __forceinline operator __m256() const { return _mm512_castps512_ps256(v); }

    __forceinline vfloat(const float& f) {
      v = _mm512_set1_ps(f);
    }

    __forceinline vfloat(const float& a, const float& b, const float& c, const float& d) {
      v = _mm512_set4_ps(a,b,c,d);
    }

    __forceinline vfloat(const vfloat4& i) {
      v = _mm512_broadcast_f32x4(i);
    }

    __forceinline vfloat(const vfloat4& a, const vfloat4& b, const vfloat4& c, const vfloat4& d) {
      v = _mm512_broadcast_f32x4(a);
      v = _mm512_insertf32x4(v, b, 1);
      v = _mm512_insertf32x4(v, c, 2);
      v = _mm512_insertf32x4(v, d, 3);
    }

    __forceinline vfloat(const vboolf16& mask, const vfloat4& a, const vfloat4& b) {
      v = _mm512_broadcast_f32x4(a);
      v = _mm512_mask_broadcast_f32x4(v,mask,b);
    }

    __forceinline vfloat(const vfloat8& i) {
      v = _mm512_castpd_ps(_mm512_broadcast_f64x4(_mm256_castps_pd(i)));
    }

    __forceinline vfloat(const vfloat8& a, const vfloat8& b) { // FIXME: optimize
      const vfloat aa = _mm512_castpd_ps(_mm512_broadcast_f64x4(_mm256_castps_pd(a)));
      const vfloat bb = _mm512_castpd_ps(_mm512_broadcast_f64x4(_mm256_castps_pd(b)));
      v = _mm512_mask_blend_ps(0xff, bb, aa);
    }

    /* WARNING: due to f64x4 the mask is considered as an 8bit mask */
    __forceinline vfloat(const vboolf16& mask, const vfloat8& a, const vfloat8& b) {
      __m512d aa = _mm512_broadcast_f64x4(_mm256_castps_pd(a));
      aa = _mm512_mask_broadcast_f64x4(aa,mask,_mm256_castps_pd(b));
      v = _mm512_castpd_ps(aa);
    }
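
    // Illustrative note on the constructor above: because the broadcast works
    // on f64x4 lanes, each mask bit selects a 64-bit *pair* of floats, not a
    // single float. E.g. a mask of 0x0f (four bits set) replaces the lower
    // eight floats with b, where a caller thinking in float lanes might have
    // expected to pass 0x00ff.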
    __forceinline explicit vfloat(const __m512i& a) {
      v = _mm512_cvtepi32_ps(a);
    }

    ////////////////////////////////////////////////////////////////////////////////
    /// Loads and Stores
    ////////////////////////////////////////////////////////////////////////////////

    static __forceinline vfloat16 load (const void* const ptr) { return _mm512_load_ps((float*)ptr);  }
    static __forceinline vfloat16 loadu(const void* const ptr) { return _mm512_loadu_ps((float*)ptr); }

    static __forceinline vfloat16 load (const vboolf16& mask, const void* const ptr) { return _mm512_mask_load_ps (_mm512_setzero_ps(),mask,(float*)ptr); }
    static __forceinline vfloat16 loadu(const vboolf16& mask, const void* const ptr) { return _mm512_mask_loadu_ps(_mm512_setzero_ps(),mask,(float*)ptr); }

    static __forceinline void store (void* const ptr, const vfloat16& v) { _mm512_store_ps ((float*)ptr,v); }
    static __forceinline void storeu(void* const ptr, const vfloat16& v) { _mm512_storeu_ps((float*)ptr,v); }

    static __forceinline void store (const vboolf16& mask, void* ptr, const vfloat16& v) { _mm512_mask_store_ps ((float*)ptr,mask,v); }
    static __forceinline void storeu(const vboolf16& mask, void* ptr, const vfloat16& v) { _mm512_mask_storeu_ps((float*)ptr,mask,v); }

    static __forceinline void store_nt(void* __restrict__ ptr, const vfloat16& a) {
      _mm512_stream_ps((float*)ptr,a);
    }

    static __forceinline vfloat16 broadcast(const float* const f) {
      return _mm512_set1_ps(*f);
    }

    static __forceinline vfloat16 compact(const vboolf16& mask, vfloat16& v) {
      return _mm512_mask_compress_ps(v,mask,v);
    }
    static __forceinline vfloat16 compact(const vboolf16& mask, const vfloat16& a, vfloat16& b) {
      return _mm512_mask_compress_ps(a,mask,b);
    }

    static __forceinline vfloat16 loadu_compact(const vboolf16& mask, const void* const ptr) {
      return _mm512_mask_expandloadu_ps(vfloat16::undefined(),mask,(float*)ptr);
    }

    static __forceinline void storeu_compact(const vboolf16& mask, float* addr, const vfloat16 reg) {
      _mm512_mask_compressstoreu_ps(addr,mask,reg);
    }

    static __forceinline void storeu_compact_single(const vboolf16& mask, float* addr, const vfloat16& reg) {
      //_mm512_mask_compressstoreu_ps(addr,mask,reg);
      *addr = mm512_cvtss_f32(_mm512_mask_compress_ps(reg,mask,reg));
    }
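
    // Illustrative semantics of the compact/expand helpers above: compress
    // packs the active lanes contiguously toward lane 0, and the expand-load
    // is its inverse. E.g. with mask 0b0110, storeu_compact writes lanes 1
    // and 2 to addr[0] and addr[1], while storeu_compact_single stores only
    // the first active lane (here lane 1) to *addr.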
    ////////////////////////////////////////////////////////////////////////////////
    /// Constants
    ////////////////////////////////////////////////////////////////////////////////

    __forceinline vfloat( ZeroTy   ) : v(_mm512_setzero_ps()) {}
    __forceinline vfloat( OneTy    ) : v(_mm512_set1_ps(1.0f)) {}
    __forceinline vfloat( PosInfTy ) : v(_mm512_set1_ps(pos_inf)) {}
    __forceinline vfloat( NegInfTy ) : v(_mm512_set1_ps(neg_inf)) {}
    __forceinline vfloat( StepTy   ) : v(_mm512_set_ps(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)) {}
    __forceinline vfloat( NaNTy    ) : v(_mm512_set1_ps(nan)) {}

    __forceinline static vfloat16 undefined() { return _mm512_undefined(); }
    __forceinline static vfloat16 zero     () { return _mm512_setzero_ps(); }
    __forceinline static vfloat16 one      () { return _mm512_set1_ps(1.0f); }
    __forceinline static vfloat16 ulp      () { return _mm512_set1_ps(embree::ulp); }
    __forceinline static vfloat16 inf      () { return _mm512_set1_ps((float)pos_inf); }
    __forceinline static vfloat16 minus_inf() { return _mm512_set1_ps((float)neg_inf); }

    ////////////////////////////////////////////////////////////////////////////////
    /// Array Access
    ////////////////////////////////////////////////////////////////////////////////

    __forceinline       float& operator [](const size_t index)       { assert(index < 16); return f[index]; }
    __forceinline const float& operator [](const size_t index) const { assert(index < 16); return f[index]; }
  };
  ////////////////////////////////////////////////////////////////////////////////
  /// Unary Operators
  ////////////////////////////////////////////////////////////////////////////////

  __forceinline const vfloat16 asFloat   ( const __m512i&  a ) { return _mm512_castsi512_ps(a); }
  __forceinline const vfloat16 operator +( const vfloat16& a ) { return a; }
  __forceinline const vfloat16 operator -( const vfloat16& a ) { return _mm512_mul_ps(a,vfloat16(-1)); }
  __forceinline const vfloat16 abs       ( const vfloat16& a ) { return _mm512_abs_ps(a); }
  __forceinline const vfloat16 signmsk   ( const vfloat16& a ) { return _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(a),_mm512_set1_epi32(0x80000000))); }

  __forceinline const vfloat16 rcp(const vfloat16& a) {
#if defined(__AVX512ER__)
    return _mm512_rcp28_ps(a);
#else
    const vfloat16 r = _mm512_rcp14_ps(a.v);
    return _mm512_mul_ps(r, _mm512_fnmadd_ps(r, a, vfloat16(2.0f)));
#endif
  }
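
  // Note on the non-ER path above: rcp14 yields an estimate r with roughly 14
  // bits of precision; the follow-up r*(2 - a*r) is one Newton-Raphson step,
  // which approximately doubles the number of correct bits. AVX-512ER (KNL)
  // provides the more precise rcp28 directly, so no refinement is done there.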
  __forceinline const vfloat16 sqr  ( const vfloat16& a ) { return _mm512_mul_ps(a,a); }
  __forceinline const vfloat16 sqrt ( const vfloat16& a ) { return _mm512_sqrt_ps(a); }

  __forceinline const vfloat16 rsqrt( const vfloat16& a )
  {
#if defined(__AVX512VL__)
    const vfloat16 r = _mm512_rsqrt14_ps(a.v);
    return _mm512_fmadd_ps(_mm512_set1_ps(1.5f), r,
                           _mm512_mul_ps(_mm512_mul_ps(_mm512_mul_ps(a, _mm512_set1_ps(-0.5f)), r), _mm512_mul_ps(r, r)));
#else
    return _mm512_rsqrt28_ps(a.v);
#endif
  }
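
  // Analogous to rcp: on AVX512VL targets the rsqrt14 estimate is refined by
  // one Newton-Raphson step, r' = 1.5*r - 0.5*a*r^3, which is exactly the
  // fmadd expression above; AVX-512ER targets use the 28-bit rsqrt28 estimate
  // as-is.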
  ////////////////////////////////////////////////////////////////////////////////
  /// Binary Operators
  ////////////////////////////////////////////////////////////////////////////////

  __forceinline const vfloat16 operator +( const vfloat16& a, const vfloat16& b ) { return _mm512_add_ps(a, b); }
  __forceinline const vfloat16 operator +( const vfloat16& a, const float&    b ) { return a + vfloat16(b); }
  __forceinline const vfloat16 operator +( const float&    a, const vfloat16& b ) { return vfloat16(a) + b; }

  __forceinline const vfloat16 operator -( const vfloat16& a, const vfloat16& b ) { return _mm512_sub_ps(a, b); }
  __forceinline const vfloat16 operator -( const vfloat16& a, const float&    b ) { return a - vfloat16(b); }
  __forceinline const vfloat16 operator -( const float&    a, const vfloat16& b ) { return vfloat16(a) - b; }

  __forceinline const vfloat16 operator *( const vfloat16& a, const vfloat16& b ) { return _mm512_mul_ps(a, b); }
  __forceinline const vfloat16 operator *( const vfloat16& a, const float&    b ) { return a * vfloat16(b); }
  __forceinline const vfloat16 operator *( const float&    a, const vfloat16& b ) { return vfloat16(a) * b; }

  __forceinline const vfloat16 operator /( const vfloat16& a, const vfloat16& b ) { return _mm512_div_ps(a,b); }
  __forceinline const vfloat16 operator /( const vfloat16& a, const float&    b ) { return a/vfloat16(b); }
  __forceinline const vfloat16 operator /( const float&    a, const vfloat16& b ) { return vfloat16(a)/b; }

  __forceinline const vfloat16 operator ^( const vfloat16& a, const vfloat16& b ) {
    return _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(a),_mm512_castps_si512(b)));
  }

  __forceinline const vfloat16 min( const vfloat16& a, const vfloat16& b ) {
    return _mm512_min_ps(a,b);
  }
  __forceinline const vfloat16 min( const vfloat16& a, const float& b ) {
    return _mm512_min_ps(a,vfloat16(b));
  }
  __forceinline const vfloat16 min( const float& a, const vfloat16& b ) {
    return _mm512_min_ps(vfloat16(a),b);
  }

  __forceinline const vfloat16 max( const vfloat16& a, const vfloat16& b ) {
    return _mm512_max_ps(a,b);
  }
  __forceinline const vfloat16 max( const vfloat16& a, const float& b ) {
    return _mm512_max_ps(a,vfloat16(b));
  }
  __forceinline const vfloat16 max( const float& a, const vfloat16& b ) {
    return _mm512_max_ps(vfloat16(a),b);
  }

  __forceinline vfloat16 mask_add(const vboolf16& mask, const vfloat16& c, const vfloat16& a, const vfloat16& b) { return _mm512_mask_add_ps(c,mask,a,b); }
  __forceinline vfloat16 mask_min(const vboolf16& mask, const vfloat16& c, const vfloat16& a, const vfloat16& b) {
    return _mm512_mask_min_ps(c,mask,a,b);
  }
  __forceinline vfloat16 mask_max(const vboolf16& mask, const vfloat16& c, const vfloat16& a, const vfloat16& b) {
    return _mm512_mask_max_ps(c,mask,a,b);
  }

  __forceinline vfloat16 mini(const vfloat16& a, const vfloat16& b) {
#if !defined(__AVX512ER__) // SKX
    const vint16 ai = _mm512_castps_si512(a);
    const vint16 bi = _mm512_castps_si512(b);
    const vint16 ci = _mm512_min_epi32(ai,bi);
    return _mm512_castsi512_ps(ci);
#else // KNL
    return min(a,b);
#endif
  }

  __forceinline vfloat16 maxi(const vfloat16& a, const vfloat16& b) {
#if !defined(__AVX512ER__) // SKX
    const vint16 ai = _mm512_castps_si512(a);
    const vint16 bi = _mm512_castps_si512(b);
    const vint16 ci = _mm512_max_epi32(ai,bi);
    return _mm512_castsi512_ps(ci);
#else // KNL
    return max(a,b);
#endif
  }
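
  // Why the integer min/max in mini/maxi is valid: for IEEE-754 floats with
  // the sign bit clear, the bit patterns order the same way as the values, so
  // min_epi32/max_epi32 produce the float result while keeping the work on
  // the integer port. The trick only holds when both inputs are non-negative
  // (sign bit clear); callers must guarantee that.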
  ////////////////////////////////////////////////////////////////////////////////
  /// Ternary Operators
  ////////////////////////////////////////////////////////////////////////////////

  __forceinline vfloat16 madd (const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fmadd_ps(a,b,c); }
  __forceinline vfloat16 msub (const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fmsub_ps(a,b,c); }
  __forceinline vfloat16 nmadd(const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fnmadd_ps(a,b,c); }
  __forceinline vfloat16 nmsub(const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fnmsub_ps(a,b,c); }

  __forceinline vfloat16 mask_msub(const vboolf16& mask, const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_mask_fmsub_ps(a,mask,b,c); }

  __forceinline vfloat16 madd231 (const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fmadd_ps(c,b,a); }
  __forceinline vfloat16 msub213 (const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fmsub_ps(a,b,c); }
  __forceinline vfloat16 msub231 (const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fmsub_ps(c,b,a); }
  __forceinline vfloat16 msubr231(const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fnmadd_ps(c,b,a); }
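
  // The numeric suffixes mirror the x86 FMA operand orders. Concretely, from
  // the definitions above: madd231(a,b,c) computes a + b*c (the accumulator
  // comes first), msub213(a,b,c) computes a*b - c, msub231(a,b,c) computes
  // b*c - a, and msubr231(a,b,c) computes a - b*c.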
  ////////////////////////////////////////////////////////////////////////////////
  /// Operators with rounding
  ////////////////////////////////////////////////////////////////////////////////

  __forceinline vfloat16 madd_round_down(const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fmadd_round_ps(a,b,c,_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); }
  __forceinline vfloat16 madd_round_up  (const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_fmadd_round_ps(a,b,c,_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); }

  __forceinline vfloat16 mul_round_down(const vfloat16& a, const vfloat16& b) { return _mm512_mul_round_ps(a,b,_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); }
  __forceinline vfloat16 mul_round_up  (const vfloat16& a, const vfloat16& b) { return _mm512_mul_round_ps(a,b,_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); }

  __forceinline vfloat16 add_round_down(const vfloat16& a, const vfloat16& b) { return _mm512_add_round_ps(a,b,_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); }
  __forceinline vfloat16 add_round_up  (const vfloat16& a, const vfloat16& b) { return _mm512_add_round_ps(a,b,_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); }

  __forceinline vfloat16 sub_round_down(const vfloat16& a, const vfloat16& b) { return _mm512_sub_round_ps(a,b,_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); }
  __forceinline vfloat16 sub_round_up  (const vfloat16& a, const vfloat16& b) { return _mm512_sub_round_ps(a,b,_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); }

  __forceinline vfloat16 div_round_down(const vfloat16& a, const vfloat16& b) { return _mm512_div_round_ps(a,b,_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); }
  __forceinline vfloat16 div_round_up  (const vfloat16& a, const vfloat16& b) { return _mm512_div_round_ps(a,b,_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); }

  __forceinline vfloat16 mask_msub_round_down(const vboolf16& mask, const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_mask_fmsub_round_ps(a,mask,b,c,_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); }
  __forceinline vfloat16 mask_msub_round_up  (const vboolf16& mask, const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_mask_fmsub_round_ps(a,mask,b,c,_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); }

  __forceinline vfloat16 mask_mul_round_down(const vboolf16& mask, const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_mask_mul_round_ps(a,mask,b,c,_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); }
  __forceinline vfloat16 mask_mul_round_up  (const vboolf16& mask, const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_mask_mul_round_ps(a,mask,b,c,_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); }

  __forceinline vfloat16 mask_sub_round_down(const vboolf16& mask, const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_mask_sub_round_ps(a,mask,b,c,_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); }
  __forceinline vfloat16 mask_sub_round_up  (const vboolf16& mask, const vfloat16& a, const vfloat16& b, const vfloat16& c) { return _mm512_mask_sub_round_ps(a,mask,b,c,_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); }
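
  // These overloads use AVX-512's per-instruction rounding override, so no
  // MXCSR round-mode switch is needed. A typical use is conservative interval
  // arithmetic: computing both mul_round_down(a,b) and mul_round_up(a,b)
  // brackets the exact product, which lets derived bounds be made watertight.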
  ////////////////////////////////////////////////////////////////////////////////
  /// Assignment Operators
  ////////////////////////////////////////////////////////////////////////////////

  __forceinline vfloat16& operator +=( vfloat16& a, const vfloat16& b ) { return a = a + b; }
  __forceinline vfloat16& operator +=( vfloat16& a, const float&    b ) { return a = a + b; }

  __forceinline vfloat16& operator -=( vfloat16& a, const vfloat16& b ) { return a = a - b; }
  __forceinline vfloat16& operator -=( vfloat16& a, const float&    b ) { return a = a - b; }

  __forceinline vfloat16& operator *=( vfloat16& a, const vfloat16& b ) { return a = a * b; }
  __forceinline vfloat16& operator *=( vfloat16& a, const float&    b ) { return a = a * b; }

  __forceinline vfloat16& operator /=( vfloat16& a, const vfloat16& b ) { return a = a / b; }
  __forceinline vfloat16& operator /=( vfloat16& a, const float&    b ) { return a = a / b; }

  ////////////////////////////////////////////////////////////////////////////////
  /// Comparison Operators + Select
  ////////////////////////////////////////////////////////////////////////////////

  __forceinline const vboolf16 operator ==( const vfloat16& a, const vfloat16& b ) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_EQ); }
  __forceinline const vboolf16 operator ==( const vfloat16& a, const float&    b ) { return a == vfloat16(b); }
  __forceinline const vboolf16 operator ==( const float&    a, const vfloat16& b ) { return vfloat16(a) == b; }

  __forceinline const vboolf16 operator !=( const vfloat16& a, const vfloat16& b ) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_NE); }
  __forceinline const vboolf16 operator !=( const vfloat16& a, const float&    b ) { return a != vfloat16(b); }
  __forceinline const vboolf16 operator !=( const float&    a, const vfloat16& b ) { return vfloat16(a) != b; }

  __forceinline const vboolf16 operator < ( const vfloat16& a, const vfloat16& b ) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_LT); }
  __forceinline const vboolf16 operator < ( const vfloat16& a, const float&    b ) { return a < vfloat16(b); }
  __forceinline const vboolf16 operator < ( const float&    a, const vfloat16& b ) { return vfloat16(a) < b; }

  __forceinline const vboolf16 operator >=( const vfloat16& a, const vfloat16& b ) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_GE); }
  __forceinline const vboolf16 operator >=( const vfloat16& a, const float&    b ) { return a >= vfloat16(b); }
  __forceinline const vboolf16 operator >=( const float&    a, const vfloat16& b ) { return vfloat16(a) >= b; }

  __forceinline const vboolf16 operator > ( const vfloat16& a, const vfloat16& b ) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_GT); }
  __forceinline const vboolf16 operator > ( const vfloat16& a, const float&    b ) { return a > vfloat16(b); }
  __forceinline const vboolf16 operator > ( const float&    a, const vfloat16& b ) { return vfloat16(a) > b; }

  __forceinline const vboolf16 operator <=( const vfloat16& a, const vfloat16& b ) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_LE); }
  __forceinline const vboolf16 operator <=( const vfloat16& a, const float&    b ) { return a <= vfloat16(b); }
  __forceinline const vboolf16 operator <=( const float&    a, const vfloat16& b ) { return vfloat16(a) <= b; }

  __forceinline vboolf16 eq(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_EQ); }
  __forceinline vboolf16 ne(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_NE); }
  __forceinline vboolf16 lt(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_LT); }
  __forceinline vboolf16 ge(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_GE); }
  __forceinline vboolf16 gt(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_GT); }
  __forceinline vboolf16 le(const vfloat16& a, const vfloat16& b) { return _mm512_cmp_ps_mask(a,b,_MM_CMPINT_LE); }

  __forceinline vboolf16 eq(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_EQ); }
  __forceinline vboolf16 ne(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_NE); }
  __forceinline vboolf16 lt(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_LT); }
  __forceinline vboolf16 ge(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_GE); }
  __forceinline vboolf16 gt(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_GT); }
  __forceinline vboolf16 le(const vboolf16& mask, const vfloat16& a, const vfloat16& b) { return _mm512_mask_cmp_ps_mask(mask,a,b,_MM_CMPINT_LE); }

  __forceinline const vfloat16 select( const vboolf16& s, const vfloat16& t, const vfloat16& f ) {
    return _mm512_mask_blend_ps(s, f, t);
  }

  __forceinline vfloat16 lerp(const vfloat16& a, const vfloat16& b, const vfloat16& t) {
    return madd(t,b-a,a);
  }

  __forceinline void xchg(vboolf16 m, vfloat16& a, vfloat16& b)
  {
    vfloat16 c = a;
    a = select(m,b,a);
    b = select(m,c,b);
  }
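
  // select(s,t,f) returns t in lanes where the mask bit is set and f
  // elsewhere; e.g. select(a < b, a, b) is an explicit per-lane min, and
  // xchg(m,a,b) swaps a and b only in the masked lanes.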
  ////////////////////////////////////////////////////////////////////////////////
  /// Rounding Functions
  ////////////////////////////////////////////////////////////////////////////////

  __forceinline vfloat16 floor(const vfloat16& a) {
    return _mm512_floor_ps(a);
  }
  __forceinline vfloat16 ceil(const vfloat16& a) {
    return _mm512_ceil_ps(a);
  }
#if !defined(__clang__) // FIXME: not yet supported in clang v4.0.0
  __forceinline vfloat16 trunc(const vfloat16& a) {
    return _mm512_trunc_ps(a);
  }
  __forceinline vfloat16 frac(const vfloat16& a) {
    return a-trunc(a);
  }
#endif
  __forceinline vint16 floori(const vfloat16& a) {
    return _mm512_cvt_roundps_epi32(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC);
  }

  ////////////////////////////////////////////////////////////////////////////////
  /// Movement/Shifting/Shuffling Functions
  ////////////////////////////////////////////////////////////////////////////////

  template<size_t i>
  __forceinline const vfloat16 shuffle(const vfloat16& a) {
    return _mm512_permute_ps(a, _MM_SHUFFLE(i, i, i, i));
  }

  template<int A, int B, int C, int D>
  __forceinline vfloat16 shuffle(const vfloat16& v) {
    return _mm512_permute_ps(v,_MM_SHUFFLE(D,C,B,A));
  }

  template<int i>
  __forceinline vfloat16 shuffle4(const vfloat16& x) {
    return _mm512_shuffle_f32x4(x,x,_MM_SHUFFLE(i,i,i,i));
  }

  template<int A, int B, int C, int D>
  __forceinline vfloat16 shuffle4(const vfloat16& x) {
    return _mm512_shuffle_f32x4(x,x,_MM_SHUFFLE(D,C,B,A));
  }
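
  // shuffle<A,B,C,D> permutes the four floats *within* each 128-bit lane (the
  // same pattern repeated four times), while shuffle4<A,B,C,D> permutes whole
  // 128-bit lanes. E.g. shuffle<1,0,3,2> swaps neighboring floats, and
  // shuffle4<2,3,0,1> swaps the lower and upper 256-bit halves; the vreduce_*
  // functions below combine both forms to reduce across all 16 lanes.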
  __forceinline vfloat16 permute(vfloat16 v, __m512i index)
  {
    return _mm512_castsi512_ps(_mm512_permutexvar_epi32(index,_mm512_castps_si512(v)));
  }

  __forceinline vfloat16 reverse(const vfloat16& a)
  {
    return permute(a,_mm512_setr_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0));
  }

  template<int i>
  __forceinline vfloat16 align_shift_right(const vfloat16& a, const vfloat16& b)
  {
    return _mm512_castsi512_ps(_mm512_alignr_epi32(_mm512_castps_si512(a),_mm512_castps_si512(b),i));
  }

  template<int i>
  __forceinline vfloat16 mask_align_shift_right(const vboolf16& mask, vfloat16& c, const vfloat16& a, const vfloat16& b)
  {
    return _mm512_castsi512_ps(_mm512_mask_alignr_epi32(_mm512_castps_si512(c),mask,_mm512_castps_si512(a),_mm512_castps_si512(b),i));
  }

  __forceinline vfloat16 shift_left_1(const vfloat16& a) {
    vfloat16 z = vfloat16::zero();
    return mask_align_shift_right<15>(0xfffe,z,a,a);
  }

  __forceinline vfloat16 shift_right_1(const vfloat16& x) {
    return align_shift_right<1>(zero,x);
  }
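
  // align_shift_right<i>(a,b) treats b:a as one 32-element sequence (b in the
  // low half) and returns elements i..i+15, i.e. it shifts the lanes of b
  // toward lane 0 and fills the top from a. Hence shift_right_1 moves the
  // value in lane k to lane k-1 (0.0f enters at lane 15), and shift_left_1
  // moves lane k to lane k+1 with 0.0f shifted into lane 0.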
  __forceinline float toScalar(const vfloat16& a) { return mm512_cvtss_f32(a); }

  template<int i> __forceinline const vfloat16 insert4(const vfloat16& a, const vfloat4& b) { return _mm512_insertf32x4(a, b, i); }

  template<int N, int i>
  vfloat<N> extractN(const vfloat16& v);

  template<> __forceinline vfloat4 extractN<4,0>(const vfloat16& v) { return _mm512_castps512_ps128(v); }
  template<> __forceinline vfloat4 extractN<4,1>(const vfloat16& v) { return _mm512_extractf32x4_ps(v, 1); }
  template<> __forceinline vfloat4 extractN<4,2>(const vfloat16& v) { return _mm512_extractf32x4_ps(v, 2); }
  template<> __forceinline vfloat4 extractN<4,3>(const vfloat16& v) { return _mm512_extractf32x4_ps(v, 3); }

  template<> __forceinline vfloat8 extractN<8,0>(const vfloat16& v) { return _mm512_castps512_ps256(v); }
  template<> __forceinline vfloat8 extractN<8,1>(const vfloat16& v) { return _mm512_extractf32x8_ps(v, 1); }

  template<int i> __forceinline vfloat4 extract4   (const vfloat16& v) { return _mm512_extractf32x4_ps(v, i); }
  template<>      __forceinline vfloat4 extract4<0>(const vfloat16& v) { return _mm512_castps512_ps128(v); }

  template<int i> __forceinline vfloat8 extract8   (const vfloat16& v) { return _mm512_extractf32x8_ps(v, i); }
  template<>      __forceinline vfloat8 extract8<0>(const vfloat16& v) { return _mm512_castps512_ps256(v); }
  ////////////////////////////////////////////////////////////////////////////////
  /// Reductions
  ////////////////////////////////////////////////////////////////////////////////

  __forceinline float reduce_add(const vfloat16& a) { return _mm512_reduce_add_ps(a); }
  __forceinline float reduce_mul(const vfloat16& a) { return _mm512_reduce_mul_ps(a); }
  __forceinline float reduce_min(const vfloat16& a) { return _mm512_reduce_min_ps(a); }
  __forceinline float reduce_max(const vfloat16& a) { return _mm512_reduce_max_ps(a); }

  __forceinline vfloat16 vreduce_add2(vfloat16 x) {                      return x + shuffle<1,0,3,2>(x); }
  __forceinline vfloat16 vreduce_add4(vfloat16 x) { x = vreduce_add2(x); return x + shuffle<2,3,0,1>(x); }
  __forceinline vfloat16 vreduce_add8(vfloat16 x) { x = vreduce_add4(x); return x + shuffle4<1,0,3,2>(x); }
  __forceinline vfloat16 vreduce_add (vfloat16 x) { x = vreduce_add8(x); return x + shuffle4<2,3,0,1>(x); }

  __forceinline vfloat16 vreduce_min2(vfloat16 x) {                      return min(x,shuffle<1,0,3,2>(x)); }
  __forceinline vfloat16 vreduce_min4(vfloat16 x) { x = vreduce_min2(x); return min(x,shuffle<2,3,0,1>(x)); }
  __forceinline vfloat16 vreduce_min8(vfloat16 x) { x = vreduce_min4(x); return min(x,shuffle4<1,0,3,2>(x)); }
  __forceinline vfloat16 vreduce_min (vfloat16 x) { x = vreduce_min8(x); return min(x,shuffle4<2,3,0,1>(x)); }

  __forceinline vfloat16 vreduce_max2(vfloat16 x) {                      return max(x,shuffle<1,0,3,2>(x)); }
  __forceinline vfloat16 vreduce_max4(vfloat16 x) { x = vreduce_max2(x); return max(x,shuffle<2,3,0,1>(x)); }
  __forceinline vfloat16 vreduce_max8(vfloat16 x) { x = vreduce_max4(x); return max(x,shuffle4<1,0,3,2>(x)); }
  __forceinline vfloat16 vreduce_max (vfloat16 x) { x = vreduce_max8(x); return max(x,shuffle4<2,3,0,1>(x)); }
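
  // The vreduce_* helpers implement a log2(16) = 4-step butterfly reduction:
  // the steps combine lanes at distance 1, 2, 4 and then 8, so the final
  // vector holds the reduction result broadcast to all 16 lanes (unlike the
  // reduce_* functions, which return a scalar). select_min/select_max below
  // rely on that broadcast to locate a lane equal to the extremum.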
  __forceinline size_t select_min(const vfloat16& v) {
    return __bsf(_mm512_kmov(_mm512_cmp_epi32_mask(_mm512_castps_si512(v),_mm512_castps_si512(vreduce_min(v)),_MM_CMPINT_EQ)));
  }

  __forceinline size_t select_max(const vfloat16& v) {
    return __bsf(_mm512_kmov(_mm512_cmp_epi32_mask(_mm512_castps_si512(v),_mm512_castps_si512(vreduce_max(v)),_MM_CMPINT_EQ)));
  }

  __forceinline size_t select_min(const vboolf16& valid, const vfloat16& v)
  {
    const vfloat16 a = select(valid,v,vfloat16(pos_inf));
    const vbool16 valid_min = valid & (a == vreduce_min(a));
    return __bsf(movemask(any(valid_min) ? valid_min : valid));
  }

  __forceinline size_t select_max(const vboolf16& valid, const vfloat16& v)
  {
    const vfloat16 a = select(valid,v,vfloat16(neg_inf));
    const vbool16 valid_max = valid & (a == vreduce_max(a));
    return __bsf(movemask(any(valid_max) ? valid_max : valid));
  }

  __forceinline vfloat16 prefix_sum(const vfloat16& a)
  {
    const vfloat16 z(zero);
    vfloat16 v = a;
    v = v + align_shift_right<16-1>(v,z);
    v = v + align_shift_right<16-2>(v,z);
    v = v + align_shift_right<16-4>(v,z);
    v = v + align_shift_right<16-8>(v,z);
    return v;
  }
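
  // prefix_sum is a Hillis-Steele style inclusive scan: each pass adds a copy
  // of the vector shifted up by 1, 2, 4 and then 8 lanes (zeros shifted in),
  // so lane k ends up holding a[0] + ... + a[k]. For example, the input
  // (1,1,...,1) yields (1,2,3,...,16). reverse_prefix_sum and the
  // prefix/reverse-prefix min/max variants below follow the same scheme with
  // the shift direction or the combining operator changed (and +inf/-inf as
  // the identity element instead of zero for min/max).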
  __forceinline vfloat16 reverse_prefix_sum(const vfloat16& a)
  {
    const vfloat16 z(zero);
    vfloat16 v = a;
    v = v + align_shift_right<1>(z,v);
    v = v + align_shift_right<2>(z,v);
    v = v + align_shift_right<4>(z,v);
    v = v + align_shift_right<8>(z,v);
    return v;
  }

  __forceinline vfloat16 prefix_min(const vfloat16& a)
  {
    const vfloat16 z(pos_inf);
    vfloat16 v = a;
    v = min(v,align_shift_right<16-1>(v,z));
    v = min(v,align_shift_right<16-2>(v,z));
    v = min(v,align_shift_right<16-4>(v,z));
    v = min(v,align_shift_right<16-8>(v,z));
    return v;
  }

  __forceinline vfloat16 prefix_max(const vfloat16& a)
  {
    const vfloat16 z(neg_inf);
    vfloat16 v = a;
    v = max(v,align_shift_right<16-1>(v,z));
    v = max(v,align_shift_right<16-2>(v,z));
    v = max(v,align_shift_right<16-4>(v,z));
    v = max(v,align_shift_right<16-8>(v,z));
    return v;
  }

  __forceinline vfloat16 reverse_prefix_min(const vfloat16& a)
  {
    const vfloat16 z(pos_inf);
    vfloat16 v = a;
    v = min(v,align_shift_right<1>(z,v));
    v = min(v,align_shift_right<2>(z,v));
    v = min(v,align_shift_right<4>(z,v));
    v = min(v,align_shift_right<8>(z,v));
    return v;
  }

  __forceinline vfloat16 reverse_prefix_max(const vfloat16& a)
  {
    const vfloat16 z(neg_inf);
    vfloat16 v = a;
    v = max(v,align_shift_right<1>(z,v));
    v = max(v,align_shift_right<2>(z,v));
    v = max(v,align_shift_right<4>(z,v));
    v = max(v,align_shift_right<8>(z,v));
    return v;
  }
  ////////////////////////////////////////////////////////////////////////////////
  /// Memory load and store operations
  ////////////////////////////////////////////////////////////////////////////////

  __forceinline void compactustore16f_low(const vboolf16& mask, float* addr, const vfloat16& reg) {
    _mm512_mask_compressstoreu_ps(addr,mask,reg);
  }

  template<int scale = 4>
  __forceinline vfloat16 gather16f(const vboolf16& mask, const float* const ptr, __m512i index) {
    vfloat16 r = vfloat16::undefined();
    return _mm512_mask_i32gather_ps(r,mask,index,ptr,scale);
  }

  template<int scale = 4>
  __forceinline void scatter16f(const vboolf16& mask, const float* const ptr, const __m512i index, const vfloat16 v) {
    _mm512_mask_i32scatter_ps((void*)ptr,mask,index,v,scale);
  }

  __forceinline vfloat16 loadAOS4to16f(const float& x, const float& y, const float& z)
  {
    vfloat16 f = vfloat16::zero();
    f = select(0x1111,vfloat16::broadcast(&x),f);
    f = select(0x2222,vfloat16::broadcast(&y),f);
    f = select(0x4444,vfloat16::broadcast(&z),f);
    return f;
  }
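
  // The 0x1111/0x2222/0x4444 masks interleave the three components into an
  // AOS-in-register layout: the vector is viewed as four xyzw slots, giving
  // (x,y,z,0, x,y,z,0, x,y,z,0, x,y,z,0). The overloads below do the same but
  // pick component `index` out of three SIMD registers, with the fourth slot
  // either zero or taken from `fill`.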
  __forceinline vfloat16 loadAOS4to16f(const unsigned int index,
                                       const vfloat16& x,
                                       const vfloat16& y,
                                       const vfloat16& z)
  {
    vfloat16 f = vfloat16::zero();
    f = select(0x1111,vfloat16::broadcast((float*)&x + index),f);
    f = select(0x2222,vfloat16::broadcast((float*)&y + index),f);
    f = select(0x4444,vfloat16::broadcast((float*)&z + index),f);
    return f;
  }

  __forceinline vfloat16 loadAOS4to16f(const unsigned int index,
                                       const vfloat16& x,
                                       const vfloat16& y,
                                       const vfloat16& z,
                                       const vfloat16& fill)
  {
    vfloat16 f = fill;
    f = select(0x1111,vfloat16::broadcast((float*)&x + index),f);
    f = select(0x2222,vfloat16::broadcast((float*)&y + index),f);
    f = select(0x4444,vfloat16::broadcast((float*)&z + index),f);
    return f;
  }

  __forceinline vfloat16 rcp_safe(const vfloat16& a) {
    return rcp(select(a != vfloat16::zero(),a,vfloat16(min_rcp_input)));
  }
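
  // rcp_safe guards rcp against zero inputs: lanes that compare equal to 0.0f
  // are replaced by the small positive constant min_rcp_input before taking
  // the reciprocal, so those lanes yield a large finite value rather than an
  // infinity or NaN propagating through later arithmetic.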
  ////////////////////////////////////////////////////////////////////////////////
  /// Output Operators
  ////////////////////////////////////////////////////////////////////////////////

  __forceinline std::ostream& operator<<(std::ostream& cout, const vfloat16& v)
  {
    cout << "<" << v[0];
    for (int i=1; i<16; i++) cout << ", " << v[i];
    cout << ">";
    return cout;
  }
}