Browse Source

Fixed rsqrt and sqrt.

Branimir Karadžić 2 years ago
parent
commit
c5593ad749
2 changed files with 123 additions and 80 deletions
  1. 34 9
      include/bx/inline/math.inl
  2. 89 71
      tests/math_test.cpp

+ 34 - 9
include/bx/inline/math.inl

@@ -220,27 +220,52 @@ namespace bx
 		return pow(_a, -0.5f);
 		return pow(_a, -0.5f);
 	}
 	}
 
 
-	inline BX_CONST_FUNC float sqrtRef(float _a)
-	{
-		return _a*pow(_a, -0.5f);
-	}
-
 	inline BX_CONST_FUNC float rsqrtSimd(float _a)
 	inline BX_CONST_FUNC float rsqrtSimd(float _a)
 	{
 	{
-		const simd128_t aa     = simd_splat(_a);
+		if (_a < kNearZero)
+		{
+			return kFloatInfinity;
+		}
+
+		const simd128_t aa = simd_splat(_a);
+#if BX_SIMD_NEON
 		const simd128_t rsqrta = simd_rsqrt_nr(aa);
 		const simd128_t rsqrta = simd_rsqrt_nr(aa);
+#else
+		const simd128_t rsqrta = simd_rsqrt_ni(aa);
+#endif // BX_SIMD_NEON
+
 		float result;
 		float result;
 		simd_stx(&result, rsqrta);
 		simd_stx(&result, rsqrta);
 
 
 		return result;
 		return result;
 	}
 	}
 
 
+	inline BX_CONST_FUNC float sqrtRef(float _a)
+	{
+		if (_a < 0.0F)
+		{
+			return bitsToFloat(kFloatExponentMask | kFloatMantissaMask);
+		}
+
+		return _a * pow(_a, -0.5f);
+	}
+
 	inline BX_CONST_FUNC float sqrtSimd(float _a)
 	inline BX_CONST_FUNC float sqrtSimd(float _a)
 	{
 	{
-		const simd128_t aa    = simd_splat(_a);
-		const simd128_t sqrta = simd_sqrt(aa);
+		if (_a < 0.0F)
+		{
+			return bitsToFloat(kFloatExponentMask | kFloatMantissaMask);
+		}
+		else if (_a < kNearZero)
+		{
+			return 0.0f;
+		}
+
+		const simd128_t aa   = simd_splat(_a);
+		const simd128_t sqrt = simd_sqrt(aa);
+
 		float result;
 		float result;
-		simd_stx(&result, sqrta);
+		simd_stx(&result, sqrt);
 
 
 		return result;
 		return result;
 	}
 	}

+ 89 - 71
tests/math_test.cpp

@@ -72,117 +72,135 @@ TEST_CASE("log2", "")
 	REQUIRE(8 == bx::log2(256) );
 	REQUIRE(8 == bx::log2(256) );
 }
 }
 
 
-TEST_CASE("libm", "")
+BX_PRAGMA_DIAGNOSTIC_PUSH();
+BX_PRAGMA_DIAGNOSTIC_IGNORED_MSVC(4723) // potential divide by 0
+
+TEST_CASE("libm sqrt", "")
 {
 {
 	bx::WriterI* writer = bx::getNullOut();
 	bx::WriterI* writer = bx::getNullOut();
 	bx::Error err;
 	bx::Error err;
 
 
-	REQUIRE(1389.0f == bx::abs(-1389.0f) );
-	REQUIRE(1389.0f == bx::abs( 1389.0f) );
-	REQUIRE(   0.0f == bx::abs(-0.0f) );
-	REQUIRE(   0.0f == bx::abs( 0.0f) );
-
-	REQUIRE(389.0f == bx::mod(1389.0f, 1000.0f) );
-
-	REQUIRE( 13.0f == bx::floor( 13.89f) );
-	REQUIRE(-14.0f == bx::floor(-13.89f) );
-	REQUIRE( 14.0f == bx::ceil(  13.89f) );
-	REQUIRE(-13.0f == bx::ceil( -13.89f) );
-
-	REQUIRE( 13.0f == bx::trunc( 13.89f) );
-	REQUIRE(-13.0f == bx::trunc(-13.89f) );
-	REQUIRE(bx::isEqual( 0.89f, bx::fract( 13.89f), 0.000001f) );
-	REQUIRE(bx::isEqual(-0.89f, bx::fract(-13.89f), 0.000001f) );
+	// rsqrtRef
+	REQUIRE(bx::isInfinite(bx::rsqrtRef(0.0f)));
 
 
-	for (int32_t yy = -10; yy < 10; ++yy)
+	for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f)
 	{
 	{
-		for (float xx = -100.0f; xx < 100.0f; xx += 0.1f)
-		{
-			bx::write(writer, &err, "ldexp(%f, %d) == %f (expected: %f)\n", xx, yy, bx::ldexp(xx, yy), ::ldexpf(xx, yy) );
-			REQUIRE(bx::isEqual(bx::ldexp(xx, yy), ::ldexpf(xx, yy), 0.00001f) );
-		}
+		bx::write(writer, &err, "rsqrtRef(%f) == %f (expected: %f)\n", xx, bx::rsqrtRef(xx), 1.0f / ::sqrtf(xx));
+		REQUIRE(err.isOk());
+		REQUIRE(bx::isEqual(bx::rsqrtRef(xx), 1.0f / ::sqrtf(xx), 0.00001f));
 	}
 	}
 
 
-	for (float xx = -80.0f; xx < 80.0f; xx += 0.1f)
+	// rsqrtSimd
+	REQUIRE(bx::isInfinite(bx::rsqrtSimd(0.0f)));
+
+	for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f)
 	{
 	{
-		bx::write(writer, &err, "exp(%f) == %f (expected: %f)\n", xx, bx::exp(xx), ::expf(xx) );
-		REQUIRE(err.isOk() );
-		REQUIRE(bx::isEqual(bx::exp(xx), ::expf(xx), 0.00001f) );
+		bx::write(writer, &err, "rsqrtSimd(%f) == %f (expected: %f)\n", xx, bx::rsqrtSimd(xx), 1.0f / ::sqrtf(xx));
+		REQUIRE(err.isOk());
+		REQUIRE(bx::isEqual(bx::rsqrtSimd(xx), 1.0f / ::sqrtf(xx), 0.00001f));
 	}
 	}
 
 
 	// rsqrt
 	// rsqrt
-	REQUIRE(bx::isInfinite(1.0f/::sqrtf(0.0f) ) );
-	REQUIRE(bx::isInfinite(bx::rsqrt(0.0f) ) );
+	REQUIRE(bx::isInfinite(1.0f / ::sqrtf(0.0f)));
+	REQUIRE(bx::isInfinite(bx::rsqrt(0.0f)));
 
 
 	for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f)
 	for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f)
 	{
 	{
-		bx::write(writer, &err, "rsqrt(%f) == %f (expected: %f)\n", xx, bx::rsqrt(xx), 1.0f/::sqrtf(xx) );
-		REQUIRE(err.isOk() );
-		REQUIRE(bx::isEqual(bx::rsqrt(xx), 1.0f/::sqrtf(xx), 0.00001f) );
+		bx::write(writer, &err, "rsqrt(%f) == %f (expected: %f)\n", xx, bx::rsqrt(xx), 1.0f / ::sqrtf(xx));
+		REQUIRE(err.isOk());
+		REQUIRE(bx::isEqual(bx::rsqrt(xx), 1.0f / ::sqrtf(xx), 0.00001f));
 	}
 	}
 
 
-	// rsqrtRef
-	REQUIRE(bx::isInfinite(bx::rsqrtRef(0.0f) ) );
+	// sqrtRef
+	REQUIRE(bx::isNan(bx::sqrtRef(-1.0f)));
+	REQUIRE(bx::isEqual(bx::sqrtRef(0.0f), ::sqrtf(0.0f), 0.0f));
+	REQUIRE(bx::isEqual(bx::sqrtRef(1.0f), ::sqrtf(1.0f), 0.0f));
 
 
-	for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f)
+	for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f)
 	{
 	{
-		bx::write(writer, &err, "rsqrtRef(%f) == %f (expected: %f)\n", xx, bx::rsqrtRef(xx), 1.0f/::sqrtf(xx) );
-		REQUIRE(err.isOk() );
-		REQUIRE(bx::isEqual(bx::rsqrtRef(xx), 1.0f/::sqrtf(xx), 0.00001f) );
+		bx::write(writer, &err, "sqrtRef(%f) == %f (expected: %f)\n", xx, bx::sqrtRef(xx), ::sqrtf(xx));
+		REQUIRE(err.isOk());
+		REQUIRE(bx::isEqual(bx::sqrtRef(xx), ::sqrtf(xx), 0.00001f));
 	}
 	}
 
 
-	// rsqrtSimd
-	REQUIRE(bx::isInfinite(bx::rsqrtSimd(0.0f) ) );
+	// sqrtSimd
+	REQUIRE(bx::isNan(bx::sqrtSimd(-1.0f)));
+	REQUIRE(bx::isEqual(bx::sqrtSimd(0.0f), ::sqrtf(0.0f), 0.0f));
+	REQUIRE(bx::isEqual(bx::sqrtSimd(1.0f), ::sqrtf(1.0f), 0.0f));
 
 
-	for (float xx = bx::kNearZero; xx < 100.0f; xx += 0.1f)
+	for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f)
 	{
 	{
-		bx::write(writer, &err, "rsqrtSimd(%f) == %f (expected: %f)\n", xx, bx::rsqrtSimd(xx), 1.0f/::sqrtf(xx) );
-		REQUIRE(err.isOk() );
-		REQUIRE(bx::isEqual(bx::rsqrtSimd(xx), 1.0f/::sqrtf(xx), 0.00001f) );
+		bx::write(writer, &err, "sqrtSimd(%f) == %f (expected: %f)\n", xx, bx::sqrtSimd(xx), ::sqrtf(xx));
+		REQUIRE(err.isOk());
+		REQUIRE(bx::isEqual(bx::sqrtSimd(xx), ::sqrtf(xx), 0.00001f));
+	}
+
+	for (float xx = 0.0f; xx < 100.0f; xx += 0.1f)
+	{
+		bx::write(writer, &err, "sqrt(%f) == %f (expected: %f)\n", xx, bx::sqrt(xx), ::sqrtf(xx));
+		REQUIRE(err.isOk());
+		REQUIRE(bx::isEqual(bx::sqrt(xx), ::sqrtf(xx), 0.00001f));
 	}
 	}
 
 
 	// sqrt
 	// sqrt
-	REQUIRE(bx::isNan(::sqrtf(-1.0f) ) );
-	REQUIRE(bx::isNan(bx::sqrt(-1.0f) ) );
-	REQUIRE(bx::isEqual(bx::sqrt(0.0f), ::sqrtf(0.0f), 0.0f) );
-	REQUIRE(bx::isEqual(bx::sqrt(1.0f), ::sqrtf(1.0f), 0.0f) );
+	REQUIRE(bx::isNan(::sqrtf(-1.0f)));
+	REQUIRE(bx::isNan(bx::sqrt(-1.0f)));
+	REQUIRE(bx::isEqual(bx::sqrt(0.0f), ::sqrtf(0.0f), 0.0f));
+	REQUIRE(bx::isEqual(bx::sqrt(1.0f), ::sqrtf(1.0f), 0.0f));
 
 
 	for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f)
 	for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f)
 	{
 	{
-		bx::write(writer, &err, "sqrt(%f) == %f (expected: %f)\n", xx, bx::sqrt(xx), ::sqrtf(xx) );
-		REQUIRE(err.isOk() );
-		REQUIRE(bx::isEqual(bx::sqrt(xx), ::sqrtf(xx), 0.00001f) );
+		bx::write(writer, &err, "sqrt(%f) == %f (expected: %f)\n", xx, bx::sqrt(xx), ::sqrtf(xx));
+		REQUIRE(err.isOk());
+		REQUIRE(bx::isEqual(bx::sqrt(xx), ::sqrtf(xx), 0.00001f));
 	}
 	}
 
 
-	// sqrtRef
-	REQUIRE(bx::isNan(bx::sqrtRef(-1.0f) ) );
-	REQUIRE(bx::isEqual(bx::sqrtRef(0.0f), ::sqrtf(0.0f), 0.0f) );
-	REQUIRE(bx::isEqual(bx::sqrtRef(1.0f), ::sqrtf(1.0f), 0.0f) );
-
-	for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f)
+	for (float xx = 0.0f; xx < 100.0f; xx += 0.1f)
 	{
 	{
-		bx::write(writer, &err, "sqrtRef(%f) == %f (expected: %f)\n", xx, bx::sqrtRef(xx), ::sqrtf(xx) );
-		REQUIRE(err.isOk() );
-		REQUIRE(bx::isEqual(bx::sqrtRef(xx), ::sqrtf(xx), 0.00001f) );
+		bx::write(writer, &err, "sqrt(%f) == %f (expected: %f)\n", xx, bx::sqrt(xx), ::sqrtf(xx));
+		REQUIRE(err.isOk());
+		REQUIRE(bx::isEqual(bx::sqrt(xx), ::sqrtf(xx), 0.00001f));
 	}
 	}
+}
 
 
-	// sqrtSimd
-	REQUIRE(bx::isNan(bx::sqrtSimd(-1.0f) ) );
-	REQUIRE(bx::isEqual(bx::sqrtSimd(0.0f), ::sqrtf(0.0f), 0.0f) );
-	REQUIRE(bx::isEqual(bx::sqrtSimd(1.0f), ::sqrtf(1.0f), 0.0f) );
+BX_PRAGMA_DIAGNOSTIC_POP();
 
 
-	for (float xx = 0.0f; xx < 1000000.0f; xx += 1000.f)
+TEST_CASE("libm", "")
+{
+	bx::WriterI* writer = bx::getNullOut();
+	bx::Error err;
+
+	REQUIRE(1389.0f == bx::abs(-1389.0f) );
+	REQUIRE(1389.0f == bx::abs( 1389.0f) );
+	REQUIRE(   0.0f == bx::abs(-0.0f) );
+	REQUIRE(   0.0f == bx::abs( 0.0f) );
+
+	REQUIRE(389.0f == bx::mod(1389.0f, 1000.0f) );
+
+	REQUIRE( 13.0f == bx::floor( 13.89f) );
+	REQUIRE(-14.0f == bx::floor(-13.89f) );
+	REQUIRE( 14.0f == bx::ceil(  13.89f) );
+	REQUIRE(-13.0f == bx::ceil( -13.89f) );
+
+	REQUIRE( 13.0f == bx::trunc( 13.89f) );
+	REQUIRE(-13.0f == bx::trunc(-13.89f) );
+	REQUIRE(bx::isEqual( 0.89f, bx::fract( 13.89f), 0.000001f) );
+	REQUIRE(bx::isEqual(-0.89f, bx::fract(-13.89f), 0.000001f) );
+
+	for (int32_t yy = -10; yy < 10; ++yy)
 	{
 	{
-		bx::write(writer, &err, "sqrtSimd(%f) == %f (expected: %f)\n", xx, bx::sqrtSimd(xx), ::sqrtf(xx) );
-		REQUIRE(err.isOk() );
-		REQUIRE(bx::isEqual(bx::sqrtSimd(xx), ::sqrtf(xx), 0.00001f) );
+		for (float xx = -100.0f; xx < 100.0f; xx += 0.1f)
+		{
+			bx::write(writer, &err, "ldexp(%f, %d) == %f (expected: %f)\n", xx, yy, bx::ldexp(xx, yy), ::ldexpf(xx, yy) );
+			REQUIRE(bx::isEqual(bx::ldexp(xx, yy), ::ldexpf(xx, yy), 0.00001f) );
+		}
 	}
 	}
 
 
-	for (float xx = 0.0f; xx < 100.0f; xx += 0.1f)
+	for (float xx = -80.0f; xx < 80.0f; xx += 0.1f)
 	{
 	{
-		bx::write(writer, &err, "sqrt(%f) == %f (expected: %f)\n", xx, bx::sqrt(xx), ::sqrtf(xx) );
+		bx::write(writer, &err, "exp(%f) == %f (expected: %f)\n", xx, bx::exp(xx), ::expf(xx) );
 		REQUIRE(err.isOk() );
 		REQUIRE(err.isOk() );
-		REQUIRE(bx::isEqual(bx::sqrt(xx), ::sqrtf(xx), 0.00001f) );
+		REQUIRE(bx::isEqual(bx::exp(xx), ::expf(xx), 0.00001f) );
 	}
 	}
 
 
 	for (float xx = -100.0f; xx < 100.0f; xx += 0.1f)
 	for (float xx = -100.0f; xx < 100.0f; xx += 0.1f)