Browse Source

Add AVX512 acceleration for double floating-point types (#524)

* Add AVX512 acceleration for DVec3::IsNaN
* Add AVX512 acceleration for DVec3::GetSign
* Add AVX512 acceleration for DVec3::PrepareRoundToInf
Wunk 2 years ago
parent
commit
82c3f7d4f8
1 changed files with 12 additions and 3 deletions
  1. 12 3
      Jolt/Math/DVec3.inl

+ 12 - 3
Jolt/Math/DVec3.inl

@@ -797,7 +797,9 @@ bool DVec3::IsNormalized(double inTolerance) const
 
 bool DVec3::IsNaN() const
 {
-#if defined(JPH_USE_AVX)
+#if defined(JPH_USE_AVX512)
+	return (_mm256_fpclass_pd_mask(mValue, 0b10000001) & 0x7) != 0;
+#elif defined(JPH_USE_AVX)
 	return (_mm256_movemask_pd(_mm256_cmp_pd(mValue, mValue, _CMP_UNORD_Q)) & 0x7) != 0;
 #elif defined(JPH_USE_SSE)
 	return ((_mm_movemask_pd(_mm_cmpunord_pd(mValue.mLow, mValue.mLow)) + (_mm_movemask_pd(_mm_cmpunord_pd(mValue.mHigh, mValue.mHigh)) << 2)) & 0x7) != 0;
@@ -808,7 +810,9 @@ bool DVec3::IsNaN() const
 
 DVec3 DVec3::GetSign() const
 {
-#if defined(JPH_USE_AVX)
+#if defined(JPH_USE_AVX512)
+	return _mm256_fixupimm_pd(mValue, mValue, _mm256_set1_epi32(0xA9A90A00), 0);
+#elif defined(JPH_USE_AVX)
 	__m256d minus_one = _mm256_set1_pd(-1.0);
 	__m256d one = _mm256_set1_pd(1.0);
 	return _mm256_or_pd(_mm256_and_pd(mValue, minus_one), one);
@@ -854,7 +858,12 @@ DVec3 DVec3::PrepareRoundToInf() const
 	// Float has 23 bit mantissa, double 52 bit mantissa => we lose 29 bits when converting from double to float
 	constexpr uint64 cDoubleToFloatMantissaLoss = (1U << 29) - 1;
 
-#if defined(JPH_USE_AVX)
+#if defined(JPH_USE_AVX512)
+	__m256i mantissa_loss = _mm256_set1_epi64x(cDoubleToFloatMantissaLoss);
+	__mmask8 is_zero = _mm256_testn_epi64_mask(_mm256_castpd_si256(mValue), mantissa_loss);
+	__m256d value_or_mantissa_loss = _mm256_or_pd(mValue, _mm256_castsi256_pd(mantissa_loss));
+	return _mm256_mask_blend_pd(is_zero, value_or_mantissa_loss, mValue);
+#elif defined(JPH_USE_AVX)
 	__m256i mantissa_loss = _mm256_set1_epi64x(cDoubleToFloatMantissaLoss);
 	__m256d value_and_mantissa_loss = _mm256_and_pd(mValue, _mm256_castsi256_pd(mantissa_loss));
 	__m256d is_zero = _mm256_cmp_pd(value_and_mantissa_loss, _mm256_setzero_pd(), _CMP_EQ_OQ);