Browse Source

Fix potential NaNs/INFs from Vec3::NormalizedOr

Clang with '-ffast-math' (which you should not use!) can generate _mm_rsqrt_ps instructions which produce INFs/NaNs when they get a denormal float as input. We therefore treat denormals as zero in Vec3::NormalizedOr.
Jorrit Rouwe 8 months ago
parent
commit
232fed57a5
2 changed files with 10 additions and 6 deletions
  1. 8 6
      Jolt/Math/Vec3.inl
  2. 2 0
      UnitTests/Math/Vec3Tests.cpp

+ 8 - 6
Jolt/Math/Vec3.inl

@@ -721,7 +721,10 @@ Vec3 Vec3::NormalizedOr(Vec3Arg inZeroValue) const
 {
 {
 #if defined(JPH_USE_SSE4_1) && !defined(JPH_PLATFORM_WASM) // _mm_blendv_ps has problems on FireFox
 #if defined(JPH_USE_SSE4_1) && !defined(JPH_PLATFORM_WASM) // _mm_blendv_ps has problems on FireFox
 	Type len_sq = _mm_dp_ps(mValue, mValue, 0x7f);
 	Type len_sq = _mm_dp_ps(mValue, mValue, 0x7f);
-	Type is_zero = _mm_cmpeq_ps(len_sq, _mm_setzero_ps());
+	// clang with '-ffast-math' (which you should not use!) can generate _mm_rsqrt_ps
+	// instructions which produce INFs/NaNs when they get a denormal float as input.
+	// We therefore treat denormals as zero here.
+	Type is_zero = _mm_cmple_ps(len_sq, _mm_set1_ps(FLT_MIN));
 #ifdef JPH_FLOATING_POINT_EXCEPTIONS_ENABLED
 #ifdef JPH_FLOATING_POINT_EXCEPTIONS_ENABLED
 	if (_mm_movemask_ps(is_zero) == 0xf)
 	if (_mm_movemask_ps(is_zero) == 0xf)
 		return inZeroValue;
 		return inZeroValue;
@@ -733,13 +736,12 @@ Vec3 Vec3::NormalizedOr(Vec3Arg inZeroValue) const
 #elif defined(JPH_USE_NEON)
 #elif defined(JPH_USE_NEON)
 	float32x4_t mul = vmulq_f32(mValue, mValue);
 	float32x4_t mul = vmulq_f32(mValue, mValue);
 	mul = vsetq_lane_f32(0, mul, 3);
 	mul = vsetq_lane_f32(0, mul, 3);
-	float32x4_t sum = vdupq_n_f32(vaddvq_f32(mul));
-	float32x4_t len = vsqrtq_f32(sum);
-	uint32x4_t is_zero = vceqq_f32(len, vdupq_n_f32(0));
-	return vbslq_f32(is_zero, inZeroValue.mValue, vdivq_f32(mValue, len));
+	float32x4_t len_sq = vdupq_n_f32(vaddvq_f32(mul));
+	uint32x4_t is_zero = vcleq_f32(len_sq, vdupq_n_f32(FLT_MIN));
+	return vbslq_f32(is_zero, inZeroValue.mValue, vdivq_f32(mValue, vsqrtq_f32(len_sq)));
 #else
 #else
 	float len_sq = LengthSq();
 	float len_sq = LengthSq();
-	if (len_sq == 0.0f)
+	if (len_sq <= FLT_MIN)
 		return inZeroValue;
 		return inZeroValue;
 	else
 	else
 		return *this / sqrt(len_sq);
 		return *this / sqrt(len_sq);

+ 2 - 0
UnitTests/Math/Vec3Tests.cpp

@@ -274,6 +274,8 @@ TEST_SUITE("Vec3Tests")
 		CHECK(Vec3(3, 2, 1).Normalized() == Vec3(3, 2, 1) / sqrt(9.0f + 4.0f + 1.0f));
 		CHECK(Vec3(3, 2, 1).Normalized() == Vec3(3, 2, 1) / sqrt(9.0f + 4.0f + 1.0f));
 		CHECK(Vec3(3, 2, 1).NormalizedOr(Vec3(1, 2, 3)) == Vec3(3, 2, 1) / sqrt(9.0f + 4.0f + 1.0f));
 		CHECK(Vec3(3, 2, 1).NormalizedOr(Vec3(1, 2, 3)) == Vec3(3, 2, 1) / sqrt(9.0f + 4.0f + 1.0f));
 		CHECK(Vec3::sZero().NormalizedOr(Vec3(1, 2, 3)) == Vec3(1, 2, 3));
 		CHECK(Vec3::sZero().NormalizedOr(Vec3(1, 2, 3)) == Vec3(1, 2, 3));
+		CHECK(Vec3(0.999f * sqrt(FLT_MIN), 0, 0).NormalizedOr(Vec3(1, 2, 3)) == Vec3(1, 2, 3)); // A vector that has a squared length that is denormal should also be treated as zero
+		CHECK_APPROX_EQUAL(Vec3(1.001f * sqrt(FLT_MIN), 0, 0).NormalizedOr(Vec3(1, 2, 3)), Vec3(1, 0, 0)); // A value that is just above being denormal should work normally
 	}
 	}
 
 
 	TEST_CASE("TestVec3Cast")
 	TEST_CASE("TestVec3Cast")