Forráskód Böngészése

Added SSE2 implementation of Vec3::sSelect, Vec4::sSelect and UVec4::sSelect (#1314)

* This works around an issue where FireFox has problems with the _mm_blendv_ps intrinsic when compiling to WASM. See: https://x.com/fforw/status/1848540672481214765.
* Also made DVec3::sSelect more consistent.
Jorrit Rouwe 9 hónapja
szülő
commit
e4debe8683

+ 2 - 2
Jolt/Math/DVec3.h

@@ -109,8 +109,8 @@ public:
 	/// Calculates inMul1 * inMul2 + inAdd
 	static JPH_INLINE DVec3		sFusedMultiplyAdd(DVec3Arg inMul1, DVec3Arg inMul2, DVec3Arg inAdd);
 
-	/// Component wise select, returns inV1 when highest bit of inControl = 0 and inV2 when highest bit of inControl = 1
-	static JPH_INLINE DVec3		sSelect(DVec3Arg inV1, DVec3Arg inV2, DVec3Arg inControl);
+	/// Component wise select, returns inNotSet when highest bit of inControl = 0 and inSet when highest bit of inControl = 1
+	static JPH_INLINE DVec3		sSelect(DVec3Arg inNotSet, DVec3Arg inSet, DVec3Arg inControl);
 
 	/// Logical or (component wise)
 	static JPH_INLINE DVec3		sOr(DVec3Arg inV1, DVec3Arg inV2);

+ 6 - 6
Jolt/Math/DVec3.inl

@@ -315,21 +315,21 @@ DVec3 DVec3::sFusedMultiplyAdd(DVec3Arg inMul1, DVec3Arg inMul2, DVec3Arg inAdd)
 #endif
 }
 
-DVec3 DVec3::sSelect(DVec3Arg inV1, DVec3Arg inV2, DVec3Arg inControl)
+DVec3 DVec3::sSelect(DVec3Arg inNotSet, DVec3Arg inSet, DVec3Arg inControl)
 {
 #if defined(JPH_USE_AVX)
-	return _mm256_blendv_pd(inV1.mValue, inV2.mValue, inControl.mValue);
+	return _mm256_blendv_pd(inNotSet.mValue, inSet.mValue, inControl.mValue);
 #elif defined(JPH_USE_SSE4_1)
-	Type v = { _mm_blendv_pd(inV1.mValue.mLow, inV2.mValue.mLow, inControl.mValue.mLow), _mm_blendv_pd(inV1.mValue.mHigh, inV2.mValue.mHigh, inControl.mValue.mHigh) };
+	Type v = { _mm_blendv_pd(inNotSet.mValue.mLow, inSet.mValue.mLow, inControl.mValue.mLow), _mm_blendv_pd(inNotSet.mValue.mHigh, inSet.mValue.mHigh, inControl.mValue.mHigh) };
 	return sFixW(v);
 #elif defined(JPH_USE_NEON)
-	Type v = { vbslq_f64(vreinterpretq_u64_s64(vshrq_n_s64(vreinterpretq_s64_f64(inControl.mValue.val[0]), 63)), inV2.mValue.val[0], inV1.mValue.val[0]),
-			   vbslq_f64(vreinterpretq_u64_s64(vshrq_n_s64(vreinterpretq_s64_f64(inControl.mValue.val[1]), 63)), inV2.mValue.val[1], inV1.mValue.val[1]) };
+	Type v = { vbslq_f64(vreinterpretq_u64_s64(vshrq_n_s64(vreinterpretq_s64_f64(inControl.mValue.val[0]), 63)), inSet.mValue.val[0], inNotSet.mValue.val[0]),
+			   vbslq_f64(vreinterpretq_u64_s64(vshrq_n_s64(vreinterpretq_s64_f64(inControl.mValue.val[1]), 63)), inSet.mValue.val[1], inNotSet.mValue.val[1]) };
 	return sFixW(v);
 #else
 	DVec3 result;
 	for (int i = 0; i < 3; i++)
-		result.mF64[i] = BitCast<uint64>(inControl.mF64[i])? inV2.mF64[i] : inV1.mF64[i];
+		result.mF64[i] = (BitCast<uint64>(inControl.mF64[i]) & (uint64(1) << 63))? inSet.mF64[i] : inNotSet.mF64[i];
 #ifdef JPH_FLOATING_POINT_EXCEPTIONS_ENABLED
 	result.mF64[3] = result.mF64[2];
 #endif // JPH_FLOATING_POINT_EXCEPTIONS_ENABLED

+ 2 - 2
Jolt/Math/UVec4.h

@@ -67,8 +67,8 @@ public:
 	/// Equals (component wise)
 	static JPH_INLINE UVec4		sEquals(UVec4Arg inV1, UVec4Arg inV2);
 
-	/// Component wise select, returns inV1 when highest bit of inControl = 0 and inV2 when highest bit of inControl = 1
-	static JPH_INLINE UVec4		sSelect(UVec4Arg inV1, UVec4Arg inV2, UVec4Arg inControl);
+	/// Component wise select, returns inNotSet when highest bit of inControl = 0 and inSet when highest bit of inControl = 1
+	static JPH_INLINE UVec4		sSelect(UVec4Arg inNotSet, UVec4Arg inSet, UVec4Arg inControl);
 
 	/// Logical or (component wise)
 	static JPH_INLINE UVec4		sOr(UVec4Arg inV1, UVec4Arg inV2);

+ 8 - 5
Jolt/Math/UVec4.inl

@@ -154,16 +154,19 @@ UVec4 UVec4::sEquals(UVec4Arg inV1, UVec4Arg inV2)
 #endif
 }
 
-UVec4 UVec4::sSelect(UVec4Arg inV1, UVec4Arg inV2, UVec4Arg inControl)
+UVec4 UVec4::sSelect(UVec4Arg inNotSet, UVec4Arg inSet, UVec4Arg inControl)
 {
-#if defined(JPH_USE_SSE4_1)
-	return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(inV1.mValue), _mm_castsi128_ps(inV2.mValue), _mm_castsi128_ps(inControl.mValue)));
+#if defined(JPH_USE_SSE4_1) && !defined(JPH_PLATFORM_WASM) // _mm_blendv_ps has problems on FireFox
+	return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(inNotSet.mValue), _mm_castsi128_ps(inSet.mValue), _mm_castsi128_ps(inControl.mValue)));
+#elif defined(JPH_USE_SSE)
+	__m128 is_set = _mm_castsi128_ps(_mm_srai_epi32(inControl.mValue, 31));
+	return _mm_castps_si128(_mm_or_ps(_mm_and_ps(is_set, _mm_castsi128_ps(inSet.mValue)), _mm_andnot_ps(is_set, _mm_castsi128_ps(inNotSet.mValue))));
 #elif defined(JPH_USE_NEON)
-	return vbslq_u32(vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(inControl.mValue), 31)), inV2.mValue, inV1.mValue);
+	return vbslq_u32(vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(inControl.mValue), 31)), inSet.mValue, inNotSet.mValue);
 #else
 	UVec4 result;
 	for (int i = 0; i < 4; i++)
-		result.mU32[i] = inControl.mU32[i] ? inV2.mU32[i] : inV1.mU32[i];
+		result.mU32[i] = (inControl.mU32[i] & 0x80000000u) ? inSet.mU32[i] : inNotSet.mU32[i];
 	return result;
 #endif
 }

+ 2 - 2
Jolt/Math/Vec3.h

@@ -87,8 +87,8 @@ public:
 	/// Calculates inMul1 * inMul2 + inAdd
 	static JPH_INLINE Vec3		sFusedMultiplyAdd(Vec3Arg inMul1, Vec3Arg inMul2, Vec3Arg inAdd);
 
-	/// Component wise select, returns inV1 when highest bit of inControl = 0 and inV2 when highest bit of inControl = 1
-	static JPH_INLINE Vec3		sSelect(Vec3Arg inV1, Vec3Arg inV2, UVec4Arg inControl);
+	/// Component wise select, returns inNotSet when highest bit of inControl = 0 and inSet when highest bit of inControl = 1
+	static JPH_INLINE Vec3		sSelect(Vec3Arg inNotSet, Vec3Arg inSet, UVec4Arg inControl);
 
 	/// Logical or (component wise)
 	static JPH_INLINE Vec3		sOr(Vec3Arg inV1, Vec3Arg inV2);

+ 10 - 6
Jolt/Math/Vec3.inl

@@ -266,18 +266,22 @@ Vec3 Vec3::sFusedMultiplyAdd(Vec3Arg inMul1, Vec3Arg inMul2, Vec3Arg inAdd)
 #endif
 }
 
-Vec3 Vec3::sSelect(Vec3Arg inV1, Vec3Arg inV2, UVec4Arg inControl)
+Vec3 Vec3::sSelect(Vec3Arg inNotSet, Vec3Arg inSet, UVec4Arg inControl)
 {
-#if defined(JPH_USE_SSE4_1)
-	Type v = _mm_blendv_ps(inV1.mValue, inV2.mValue, _mm_castsi128_ps(inControl.mValue));
+#if defined(JPH_USE_SSE4_1) && !defined(JPH_PLATFORM_WASM) // _mm_blendv_ps has problems on FireFox
+	Type v = _mm_blendv_ps(inNotSet.mValue, inSet.mValue, _mm_castsi128_ps(inControl.mValue));
+	return sFixW(v);
+#elif defined(JPH_USE_SSE)
+	__m128 is_set = _mm_castsi128_ps(_mm_srai_epi32(inControl.mValue, 31));
+	Type v = _mm_or_ps(_mm_and_ps(is_set, inSet.mValue), _mm_andnot_ps(is_set, inNotSet.mValue));
 	return sFixW(v);
 #elif defined(JPH_USE_NEON)
-	Type v = vbslq_f32(vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(inControl.mValue), 31)), inV2.mValue, inV1.mValue);
+	Type v = vbslq_f32(vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(inControl.mValue), 31)), inSet.mValue, inNotSet.mValue);
 	return sFixW(v);
 #else
 	Vec3 result;
 	for (int i = 0; i < 3; i++)
-		result.mF32[i] = inControl.mU32[i] ? inV2.mF32[i] : inV1.mF32[i];
+		result.mF32[i] = (inControl.mU32[i] & 0x80000000u) ? inSet.mF32[i] : inNotSet.mF32[i];
 #ifdef JPH_FLOATING_POINT_EXCEPTIONS_ENABLED
 	result.mF32[3] = result.mF32[2];
 #endif // JPH_FLOATING_POINT_EXCEPTIONS_ENABLED
@@ -715,7 +719,7 @@ Vec3 Vec3::Normalized() const
 
 Vec3 Vec3::NormalizedOr(Vec3Arg inZeroValue) const
 {
-#if defined(JPH_USE_SSE4_1)
+#if defined(JPH_USE_SSE4_1) && !defined(JPH_PLATFORM_WASM) // _mm_blendv_ps has problems on FireFox
 	Type len_sq = _mm_dp_ps(mValue, mValue, 0x7f);
 	Type is_zero = _mm_cmpeq_ps(len_sq, _mm_setzero_ps());
 #ifdef JPH_FLOATING_POINT_EXCEPTIONS_ENABLED

+ 2 - 2
Jolt/Math/Vec4.h

@@ -78,8 +78,8 @@ public:
 	/// Calculates inMul1 * inMul2 + inAdd
 	static JPH_INLINE Vec4		sFusedMultiplyAdd(Vec4Arg inMul1, Vec4Arg inMul2, Vec4Arg inAdd);
 
-	/// Component wise select, returns inV1 when highest bit of inControl = 0 and inV2 when highest bit of inControl = 1
-	static JPH_INLINE Vec4		sSelect(Vec4Arg inV1, Vec4Arg inV2, UVec4Arg inControl);
+	/// Component wise select, returns inNotSet when highest bit of inControl = 0 and inSet when highest bit of inControl = 1
+	static JPH_INLINE Vec4		sSelect(Vec4Arg inNotSet, Vec4Arg inSet, UVec4Arg inControl);
 
 	/// Logical or (component wise)
 	static JPH_INLINE Vec4		sOr(Vec4Arg inV1, Vec4Arg inV2);

+ 8 - 5
Jolt/Math/Vec4.inl

@@ -251,16 +251,19 @@ Vec4 Vec4::sFusedMultiplyAdd(Vec4Arg inMul1, Vec4Arg inMul2, Vec4Arg inAdd)
 #endif
 }
 
-Vec4 Vec4::sSelect(Vec4Arg inV1, Vec4Arg inV2, UVec4Arg inControl)
+Vec4 Vec4::sSelect(Vec4Arg inNotSet, Vec4Arg inSet, UVec4Arg inControl)
 {
-#if defined(JPH_USE_SSE4_1)
-	return _mm_blendv_ps(inV1.mValue, inV2.mValue, _mm_castsi128_ps(inControl.mValue));
+#if defined(JPH_USE_SSE4_1) && !defined(JPH_PLATFORM_WASM) // _mm_blendv_ps has problems on FireFox
+	return _mm_blendv_ps(inNotSet.mValue, inSet.mValue, _mm_castsi128_ps(inControl.mValue));
+#elif defined(JPH_USE_SSE)
+	__m128 is_set = _mm_castsi128_ps(_mm_srai_epi32(inControl.mValue, 31));
+	return _mm_or_ps(_mm_and_ps(is_set, inSet.mValue), _mm_andnot_ps(is_set, inNotSet.mValue));
 #elif defined(JPH_USE_NEON)
-	return vbslq_f32(vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(inControl.mValue), 31)), inV2.mValue, inV1.mValue);
+	return vbslq_f32(vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(inControl.mValue), 31)), inSet.mValue, inNotSet.mValue);
 #else
 	Vec4 result;
 	for (int i = 0; i < 4; i++)
-		result.mF32[i] = inControl.mU32[i] ? inV2.mF32[i] : inV1.mF32[i];
+		result.mF32[i] = (inControl.mU32[i] & 0x80000000u) ? inSet.mF32[i] : inNotSet.mF32[i];
 	return result;
 #endif
 }

+ 5 - 0
UnitTests/Math/DVec3Tests.cpp

@@ -141,8 +141,13 @@ TEST_SUITE("DVec3Tests")
 
 	TEST_CASE("TestDVec3Select")
 	{
+		const double cTrue2 = BitCast<double>(uint64(1) << 63);
+		const double cFalse2 = BitCast<double>(~uint64(0) >> 1);
+
 		CHECK(DVec3::sSelect(DVec3(1, 2, 3), DVec3(4, 5, 6), DVec3(DVec3::cTrue, DVec3::cFalse, DVec3::cTrue)) == DVec3(4, 2, 6));
 		CHECK(DVec3::sSelect(DVec3(1, 2, 3), DVec3(4, 5, 6), DVec3(DVec3::cFalse, DVec3::cTrue, DVec3::cFalse)) == DVec3(1, 5, 3));
+		CHECK(DVec3::sSelect(DVec3(1, 2, 3), DVec3(4, 5, 6), DVec3(cTrue2, cFalse2, cTrue2)) == DVec3(4, 2, 6));
+		CHECK(DVec3::sSelect(DVec3(1, 2, 3), DVec3(4, 5, 6), DVec3(cFalse2, cTrue2, cFalse2)) == DVec3(1, 5, 3));
 	}
 
 	TEST_CASE("TestDVec3BitOps")

+ 2 - 0
UnitTests/Math/UVec4Tests.cpp

@@ -185,6 +185,8 @@ TEST_SUITE("UVec4Tests")
 	{
 		CHECK(UVec4::sSelect(UVec4(1, 2, 3, 4), UVec4(5, 6, 7, 8), UVec4(0x80000000U, 0, 0x80000000U, 0)) == UVec4(5, 2, 7, 4));
 		CHECK(UVec4::sSelect(UVec4(1, 2, 3, 4), UVec4(5, 6, 7, 8), UVec4(0, 0x80000000U, 0, 0x80000000U)) == UVec4(1, 6, 3, 8));
+		CHECK(UVec4::sSelect(UVec4(1, 2, 3, 4), UVec4(5, 6, 7, 8), UVec4(0xffffffffU, 0x7fffffffU, 0xffffffffU, 0x7fffffffU)) == UVec4(5, 2, 7, 4));
+		CHECK(UVec4::sSelect(UVec4(1, 2, 3, 4), UVec4(5, 6, 7, 8), UVec4(0x7fffffffU, 0xffffffffU, 0x7fffffffU, 0xffffffffU)) == UVec4(1, 6, 3, 8));
 	}
 
 	TEST_CASE("TestUVec4BitOps")

+ 2 - 0
UnitTests/Math/Vec3Tests.cpp

@@ -129,6 +129,8 @@ TEST_SUITE("Vec3Tests")
 	{
 		CHECK(Vec3::sSelect(Vec3(1, 2, 3), Vec3(4, 5, 6), UVec4(0x80000000U, 0, 0x80000000U, 0)) == Vec3(4, 2, 6));
 		CHECK(Vec3::sSelect(Vec3(1, 2, 3), Vec3(4, 5, 6), UVec4(0, 0x80000000U, 0, 0x80000000U)) == Vec3(1, 5, 3));
+		CHECK(Vec3::sSelect(Vec3(1, 2, 3), Vec3(4, 5, 6), UVec4(0xffffffffU, 0x7fffffffU, 0xffffffffU, 0x7fffffffU)) == Vec3(4, 2, 6));
+		CHECK(Vec3::sSelect(Vec3(1, 2, 3), Vec3(4, 5, 6), UVec4(0x7fffffffU, 0xffffffffU, 0x7fffffffU, 0xffffffffU)) == Vec3(1, 5, 3));
 	}
 
 	TEST_CASE("TestVec3BitOps")

+ 2 - 0
UnitTests/Math/Vec4Tests.cpp

@@ -117,6 +117,8 @@ TEST_SUITE("Vec4Tests")
 	{
 		CHECK(Vec4::sSelect(Vec4(1, 2, 3, 4), Vec4(5, 6, 7, 8), UVec4(0x80000000U, 0, 0x80000000U, 0)) == Vec4(5, 2, 7, 4));
 		CHECK(Vec4::sSelect(Vec4(1, 2, 3, 4), Vec4(5, 6, 7, 8), UVec4(0, 0x80000000U, 0, 0x80000000U)) == Vec4(1, 6, 3, 8));
+		CHECK(Vec4::sSelect(Vec4(1, 2, 3, 4), Vec4(5, 6, 7, 8), UVec4(0xffffffffU, 0x7fffffffU, 0xffffffffU, 0x7fffffffU)) == Vec4(5, 2, 7, 4));
+		CHECK(Vec4::sSelect(Vec4(1, 2, 3, 4), Vec4(5, 6, 7, 8), UVec4(0x7fffffffU, 0xffffffffU, 0x7fffffffU, 0xffffffffU)) == Vec4(1, 6, 3, 8));
 	}
 
 	TEST_CASE("TestVec4BitOps")