Jelajahi Sumber

Fixed an issue with the JPH_Mat44 definition: it ended up being used as row major instead of column major (#1868)

Jorrit Rouwe 2 minggu lalu
induk
melakukan
2de5386965

+ 62 - 62
Jolt/Compute/CPU/HLSLToCPP.h

@@ -43,8 +43,8 @@ struct float2
 	float &				operator [] (uint inIndex)							{ return (&x)[inIndex]; }
 
 	// Swizzling (note return value is const to prevent assignment to swizzled results)
-	const float2		xy() const											{ return float2(x, y); }
-	const float2		yx() const											{ return float2(y, x); }
+	const float2		swizzle_xy() const									{ return float2(x, y); }
+	const float2		swizzle_yx() const									{ return float2(y, x); }
 
 	float				x, y;
 };
@@ -109,14 +109,14 @@ struct float3
 	float &				operator [] (uint inIndex)							{ return (&x)[inIndex]; }
 
 	// Swizzling (note return value is const to prevent assignment to swizzled results)
-	const float2		xy() const											{ return float2(x, y); }
-	const float2		yx() const											{ return float2(y, x); }
-	const float3		xyz() const											{ return float3(x, y, z); }
-	const float3		xzy() const											{ return float3(x, z, y); }
-	const float3		yxz() const											{ return float3(y, x, z); }
-	const float3		yzx() const											{ return float3(y, z, x); }
-	const float3		zxy() const											{ return float3(z, x, y); }
-	const float3		zyx() const											{ return float3(z, y, x); }
+	const float2		swizzle_xy() const									{ return float2(x, y); }
+	const float2		swizzle_yx() const									{ return float2(y, x); }
+	const float3		swizzle_xyz() const									{ return float3(x, y, z); }
+	const float3		swizzle_xzy() const									{ return float3(x, z, y); }
+	const float3		swizzle_yxz() const									{ return float3(y, x, z); }
+	const float3		swizzle_yzx() const									{ return float3(y, z, x); }
+	const float3		swizzle_zxy() const									{ return float3(z, x, y); }
+	const float3		swizzle_zyx() const									{ return float3(z, y, x); }
 
 	float				x, y, z;
 };
@@ -184,17 +184,17 @@ struct float4
 	float &				operator [] (uint inIndex)							{ return (&x)[inIndex]; }
 
 	// Swizzling (note return value is const to prevent assignment to swizzled results)
-	const float2		xy() const											{ return float2(x, y); }
-	const float2		yx() const											{ return float2(y, x); }
-	const float3		xyz() const											{ return float3(x, y, z); }
-	const float3		xzy() const											{ return float3(x, z, y); }
-	const float3		yxz() const											{ return float3(y, x, z); }
-	const float3		yzx() const											{ return float3(y, z, x); }
-	const float3		zxy() const											{ return float3(z, x, y); }
-	const float3		zyx() const											{ return float3(z, y, x); }
-	const float4		xywz() const										{ return float4(x, y, w, z); }
-	const float4		xwyz() const										{ return float4(x, w, y, z); }
-	const float4		wxyz() const										{ return float4(w, x, y, z); }
+	const float2		swizzle_xy() const									{ return float2(x, y); }
+	const float2		swizzle_yx() const									{ return float2(y, x); }
+	const float3		swizzle_xyz() const									{ return float3(x, y, z); }
+	const float3		swizzle_xzy() const									{ return float3(x, z, y); }
+	const float3		swizzle_yxz() const									{ return float3(y, x, z); }
+	const float3		swizzle_yzx() const									{ return float3(y, z, x); }
+	const float3		swizzle_zxy() const									{ return float3(z, x, y); }
+	const float3		swizzle_zyx() const									{ return float3(z, y, x); }
+	const float4		swizzle_xywz() const								{ return float4(x, y, w, z); }
+	const float4		swizzle_xwyz() const								{ return float4(x, w, y, z); }
+	const float4		swizzle_wxyz() const								{ return float4(w, x, y, z); }
 
 	float				x, y, z, w;
 };
@@ -254,12 +254,12 @@ struct uint3
 	uint32 &			operator [] (uint inIndex)							{ return (&x)[inIndex]; }
 
 	// Swizzling (note return value is const to prevent assignment to swizzled results)
-	const uint3			xyz() const											{ return uint3(x, y, z); }
-	const uint3			xzy() const											{ return uint3(x, z, y); }
-	const uint3			yxz() const											{ return uint3(y, x, z); }
-	const uint3			yzx() const											{ return uint3(y, z, x); }
-	const uint3			zxy() const											{ return uint3(z, x, y); }
-	const uint3			zyx() const											{ return uint3(z, y, x); }
+	const uint3			swizzle_xyz() const									{ return uint3(x, y, z); }
+	const uint3			swizzle_xzy() const									{ return uint3(x, z, y); }
+	const uint3			swizzle_yxz() const									{ return uint3(y, x, z); }
+	const uint3			swizzle_yzx() const									{ return uint3(y, z, x); }
+	const uint3			swizzle_zxy() const									{ return uint3(z, x, y); }
+	const uint3			swizzle_zyx() const									{ return uint3(z, y, x); }
 
 	uint32				x, y, z;
 };
@@ -311,15 +311,15 @@ struct uint4
 	uint32 &			operator [] (uint inIndex)							{ return (&x)[inIndex]; }
 
 	// Swizzling (note return value is const to prevent assignment to swizzled results)
-	const uint3			xyz() const											{ return uint3(x, y, z); }
-	const uint3			xzy() const											{ return uint3(x, z, y); }
-	const uint3			yxz() const											{ return uint3(y, x, z); }
-	const uint3			yzx() const											{ return uint3(y, z, x); }
-	const uint3			zxy() const											{ return uint3(z, x, y); }
-	const uint3			zyx() const											{ return uint3(z, y, x); }
-	const uint4			xywz() const										{ return uint4(x, y, w, z); }
-	const uint4			xwyz() const										{ return uint4(x, w, y, z); }
-	const uint4			wxyz() const										{ return uint4(w, x, y, z); }
+	const uint3			swizzle_xyz() const									{ return uint3(x, y, z); }
+	const uint3			swizzle_xzy() const									{ return uint3(x, z, y); }
+	const uint3			swizzle_yxz() const									{ return uint3(y, x, z); }
+	const uint3			swizzle_yzx() const									{ return uint3(y, z, x); }
+	const uint3			swizzle_zxy() const									{ return uint3(z, x, y); }
+	const uint3			swizzle_zyx() const									{ return uint3(z, y, x); }
+	const uint4			swizzle_xywz() const								{ return uint4(x, y, w, z); }
+	const uint4			swizzle_xwyz() const								{ return uint4(x, w, y, z); }
+	const uint4			swizzle_wxyz() const								{ return uint4(w, x, y, z); }
 
 	uint32				x, y, z, w;
 };
@@ -369,12 +369,12 @@ struct int3
 	int &				operator [] (uint inIndex)							{ return (&x)[inIndex]; }
 
 	// Swizzling (note return value is const to prevent assignment to swizzled results)
-	const int3			xyz() const											{ return int3(x, y, z); }
-	const int3			xzy() const											{ return int3(x, z, y); }
-	const int3			yxz() const											{ return int3(y, x, z); }
-	const int3			yzx() const											{ return int3(y, z, x); }
-	const int3			zxy() const											{ return int3(z, x, y); }
-	const int3			zyx() const											{ return int3(z, y, x); }
+	const int3			swizzle_xyz() const									{ return int3(x, y, z); }
+	const int3			swizzle_xzy() const									{ return int3(x, z, y); }
+	const int3			swizzle_yxz() const									{ return int3(y, x, z); }
+	const int3			swizzle_yzx() const									{ return int3(y, z, x); }
+	const int3			swizzle_zxy() const									{ return int3(z, x, y); }
+	const int3			swizzle_zyx() const									{ return int3(z, y, x); }
 
 	int					x, y, z;
 };
@@ -428,15 +428,15 @@ struct int4
 	int &				operator [] (uint inIndex)							{ return (&x)[inIndex]; }
 
 	// Swizzling (note return value is const to prevent assignment to swizzled results)
-	const int3			xyz() const											{ return int3(x, y, z); }
-	const int3			xzy() const											{ return int3(x, z, y); }
-	const int3			yxz() const											{ return int3(y, x, z); }
-	const int3			yzx() const											{ return int3(y, z, x); }
-	const int3			zxy() const											{ return int3(z, x, y); }
-	const int3			zyx() const											{ return int3(z, y, x); }
-	const int4			xywz() const										{ return int4(x, y, w, z); }
-	const int4			xwyz() const										{ return int4(x, w, y, z); }
-	const int4			wxyz() const										{ return int4(w, x, y, z); }
+	const int3			swizzle_xyz() const									{ return int3(x, y, z); }
+	const int3			swizzle_xzy() const									{ return int3(x, z, y); }
+	const int3			swizzle_yxz() const									{ return int3(y, x, z); }
+	const int3			swizzle_yzx() const									{ return int3(y, z, x); }
+	const int3			swizzle_zxy() const									{ return int3(z, x, y); }
+	const int3			swizzle_zyx() const									{ return int3(z, y, x); }
+	const int4			swizzle_xywz() const								{ return int4(x, y, w, z); }
+	const int4			swizzle_xwyz() const								{ return int4(x, w, y, z); }
+	const int4			swizzle_wxyz() const								{ return int4(w, x, y, z); }
 
 	int					x, y, z, w;
 };
@@ -508,17 +508,17 @@ constexpr				float3::float3(const uint3 &inV)					: x(float(inV.x)), y(float(inV
 constexpr				float4::float4(const int4 &inV)						: x(float(inV.x)), y(float(inV.y)), z(float(inV.z)), w(float(inV.w)) { }
 
 // Swizzle operators
-#define xy				xy()
-#define yx				yx()
-#define xyz				xyz()
-#define xzy				xzy()
-#define yxz				yxz()
-#define yzx				yzx()
-#define zxy				zxy()
-#define zyx				zyx()
-#define xywz			xywz()
-#define xwyz			xwyz()
-#define wxyz			wxyz()
+#define xy				swizzle_xy()
+#define yx				swizzle_yx()
+#define xyz				swizzle_xyz()
+#define xzy				swizzle_xzy()
+#define yxz				swizzle_yxz()
+#define yzx				swizzle_yzx()
+#define zxy				swizzle_zxy()
+#define zyx				swizzle_zyx()
+#define xywz			swizzle_xywz()
+#define xwyz			swizzle_xwyz()
+#define wxyz			swizzle_wxyz()
 
 } // HLSLToCPP
 

+ 2 - 0
Jolt/Jolt.cmake

@@ -479,6 +479,7 @@ if (JPH_USE_DX12 OR JPH_USE_VK OR JPH_USE_MTL)
 	# Compute shaders
 	set(JOLT_PHYSICS_SHADERS
 		${JOLT_PHYSICS_ROOT}/Shaders/TestCompute.hlsl
+		${JOLT_PHYSICS_ROOT}/Shaders/TestCompute2.hlsl
 	)
 
 	set(JOLT_PHYSICS_SHADER_HEADERS
@@ -489,6 +490,7 @@ if (JPH_USE_DX12 OR JPH_USE_VK OR JPH_USE_MTL)
 		${JOLT_PHYSICS_ROOT}/Shaders/ShaderQuat.h
 		${JOLT_PHYSICS_ROOT}/Shaders/ShaderVec3.h
 		${JOLT_PHYSICS_ROOT}/Shaders/TestComputeBindings.h
+		${JOLT_PHYSICS_ROOT}/Shaders/TestCompute2Bindings.h
 	)
 endif()
 

+ 1 - 3
Jolt/Shaders/ShaderCore.h

@@ -34,8 +34,6 @@
 
 	JPH_SUPPRESS_WARNING_POP
 #else
-	#pragma pack_matrix(column_major)
-
 	typedef float JPH_float;
 	typedef float3 JPH_float3;
 	typedef float4 JPH_float4;
@@ -47,7 +45,7 @@
 	typedef int4 JPH_int4;
 	typedef float4 JPH_Quat; // xyz = imaginary part, w = real part
 	typedef float4 JPH_Plane; // xyz = normal, w = constant
-	typedef float4x4 JPH_Mat44; // matrix, column major
+	typedef float4 JPH_Mat44[4]; // matrix, column major
 
 	#define JPH_SHADER_CONSTANT(type, name, value)	static const type name = value;
 

+ 1 - 1
Jolt/Shaders/ShaderMat44.h

@@ -2,7 +2,7 @@
 // SPDX-FileCopyrightText: 2025 Jorrit Rouwe
 // SPDX-License-Identifier: MIT
 
-inline float3 JPH_Mat44MulVec3(JPH_Mat44 inLHS, float3 inRHS)
+inline float3 JPH_Mat44Mul3x4Vec3(JPH_Mat44 inLHS, float3 inRHS) // Multiplies inRHS by the xyz parts of the four float4 entries of inLHS
 {
 	return inLHS[0].xyz * inRHS.x + inLHS[1].xyz * inRHS.y + inLHS[2].xyz * inRHS.z + inLHS[3].xyz; // inLHS[3].xyz is added unscaled, i.e. inRHS is treated as having an implicit w of 1; the w parts of inLHS are not used
 }

+ 25 - 0
Jolt/Shaders/TestCompute2.hlsl

@@ -0,0 +1,25 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "TestCompute2Bindings.h"
+#include "ShaderMat44.h"
+#include "ShaderVec3.h"
+#include "ShaderQuat.h"
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cTestCompute2GroupSize, 1, 1) // Test entry point: runs the Mat44 / Vec3 / Quat helper functions on one input element and stores the results
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	TestCompute2Input input = gInput[tid.x]; // Each invocation reads the input element at its own x index
+	TestCompute2Output output;
+
+	output.mMul3x4Output = JPH_Mat44Mul3x4Vec3(input.mMat44Value, input.mMat44MulValue); // Same matrix and vector are fed to both multiply helpers
+	output.mMul3x3Output = JPH_Mat44Mul3x3Vec3(input.mMat44Value, input.mMat44MulValue);
+
+	output.mDecompressedVec3 = JPH_Vec3DecompressUnit(input.mCompressedVec3); // Expands the 32-bit compressed value into a float3
+
+	output.mDecompressedQuat = JPH_QuatDecompress(input.mCompressedQuat); // Expands the 32-bit compressed value into a JPH_Quat
+
+	gOutput[tid.x] = output; // All results are written back at the same index the input was read from
+}

+ 26 - 0
Jolt/Shaders/TestCompute2Bindings.h

@@ -0,0 +1,26 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "ShaderCore.h"
+
+JPH_SHADER_CONSTANT(int, cTestCompute2GroupSize, 1)
+
+JPH_SHADER_STRUCT_BEGIN(TestCompute2Input) // Input element read by the TestCompute2 shader through gInput
+	JPH_SHADER_STRUCT_MEMBER(JPH_Mat44,			Mat44Value) // Matrix passed to the Mat44 multiply helpers
+	JPH_SHADER_STRUCT_MEMBER(JPH_float3,		Mat44MulValue) // Vector that is multiplied by Mat44Value
+	JPH_SHADER_STRUCT_MEMBER(JPH_uint,			CompressedVec3) // Argument for JPH_Vec3DecompressUnit
+	JPH_SHADER_STRUCT_MEMBER(JPH_uint,			CompressedQuat) // Argument for JPH_QuatDecompress
+JPH_SHADER_STRUCT_END(TestCompute2Input) // Name must match the BEGIN macro, as for the other structs in this file
+
+JPH_SHADER_STRUCT_BEGIN(TestCompute2Output) // Output element written by the TestCompute2 shader through gOutput
+	JPH_SHADER_STRUCT_MEMBER(JPH_float3,		Mul3x4Output) // Result of JPH_Mat44Mul3x4Vec3
+	JPH_SHADER_STRUCT_MEMBER(JPH_float3,		Mul3x3Output) // Result of JPH_Mat44Mul3x3Vec3
+	JPH_SHADER_STRUCT_MEMBER(JPH_float3,		DecompressedVec3) // Result of JPH_Vec3DecompressUnit
+	JPH_SHADER_STRUCT_MEMBER(JPH_Quat,			DecompressedQuat) // Result of JPH_QuatDecompress
+JPH_SHADER_STRUCT_END(TestCompute2Output)
+
+JPH_SHADER_BIND_BEGIN(JPH_TestCompute2) // Buffer bindings used by TestCompute2.hlsl
+	JPH_SHADER_BIND_BUFFER(TestCompute2Input, gInput) // Indexed as gInput[tid.x] by the shader; only read there
+	JPH_SHADER_BIND_RW_BUFFER(TestCompute2Output, gOutput) // Indexed as gOutput[tid.x] by the shader; written there
+JPH_SHADER_BIND_END(JPH_TestCompute2)

+ 7 - 0
Jolt/Shaders/TestComputeWrapper.cpp

@@ -10,3 +10,10 @@
 #include <Jolt/Compute/CPU/WrapShaderBindings.h>
 #include "TestComputeBindings.h"
 #include <Jolt/Compute/CPU/WrapShaderEnd.h>
+
+#define JPH_SHADER_NAME TestCompute2
+#include <Jolt/Compute/CPU/WrapShaderBegin.h>
+#include "TestCompute2.hlsl"
+#include <Jolt/Compute/CPU/WrapShaderBindings.h>
+#include "TestCompute2Bindings.h"
+#include <Jolt/Compute/CPU/WrapShaderEnd.h>

+ 179 - 104
UnitTests/Compute/ComputeTests.cpp

@@ -6,8 +6,8 @@
 
 #include <Jolt/Compute/ComputeSystem.h>
 #include <Jolt/Compute/CPU/ComputeSystemCPU.h>
-#include <Jolt/Shaders/ShaderCore.h>
 #include <Jolt/Shaders/TestComputeBindings.h>
+#include <Jolt/Shaders/TestCompute2Bindings.h>
 #include <Jolt/Core/IncludeWindows.h>
 #include <Jolt/Core/RTTI.h>
 
@@ -24,6 +24,7 @@ JPH_SUPPRESS_WARNINGS_STD_END
 #endif
 
 JPH_DECLARE_REGISTER_SHADER(TestCompute)
+JPH_DECLARE_REGISTER_SHADER(TestCompute2)
 
 TEST_SUITE("ComputeTests")
 {
@@ -105,127 +106,200 @@ TEST_SUITE("ComputeTests")
 			return true;
 		};
 
+		// Create a queue
+		ComputeQueueResult queue_result = inComputeSystem->CreateComputeQueue();
+		CHECK(!queue_result.HasError());
+		Ref<ComputeQueue> queue = queue_result.Get();
+		CHECK(queue != nullptr);
+
 		// Test failing shader creation
 		{
 			ComputeShaderResult shader_result = inComputeSystem->CreateComputeShader("NonExistingShader", 64);
 			CHECK(shader_result.HasError());
 		}
 
-		constexpr uint32 cNumElements = 1234; // Not a multiple of cTestComputeGroupSize
-		constexpr uint32 cNumIterations = 10;
-		constexpr JPH_float3 cFloat3Value = JPH_float3(0, 0, 0);
-		constexpr JPH_float3 cFloat3Value2 = JPH_float3(0, 13, 0);
-		constexpr uint32 cUIntValue = 7;
-		constexpr uint32 cUploadValue = 42;
-
-		// Can't change context buffer while commands are queued, so create multiple constant buffers
-		Ref<ComputeBuffer> context[cNumIterations];
-		for (uint32 iter = 0; iter < cNumIterations; ++iter)
-		{
-			ComputeBufferResult buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::ConstantBuffer, 1, sizeof(TestComputeContext));
-			CHECK(!buffer_result.HasError());
-			context[iter] = buffer_result.Get();
-		}
-		CHECK(context != nullptr);
-
-		// Create an upload buffer
-		ComputeBufferResult upload_buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::UploadBuffer, 1, sizeof(uint32));
-		CHECK(!upload_buffer_result.HasError());
-		Ref<ComputeBuffer> upload_buffer = upload_buffer_result.Get();
-		CHECK(upload_buffer != nullptr);
-		uint32 *upload_data = upload_buffer->Map<uint32>(ComputeBuffer::EMode::Write);
-		upload_data[0] = cUploadValue;
-		upload_buffer->Unmap();
-
-		// Create a read buffer
-		UnitTestRandom rnd;
-		Array<uint32> optional_data(cNumElements);
-		for (uint32 &d : optional_data)
-			d = rnd();
-		ComputeBufferResult optional_buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, cNumElements, sizeof(uint32), optional_data.data());
-		CHECK(!optional_buffer_result.HasError());
-		Ref<ComputeBuffer> optional_buffer = optional_buffer_result.Get();
-		CHECK(optional_buffer != nullptr);
-
-		// Create a read-write buffer
-		ComputeBufferResult buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, cNumElements, sizeof(uint32));
-		CHECK(!buffer_result.HasError());
-		Ref<ComputeBuffer> buffer = buffer_result.Get();
-		CHECK(buffer != nullptr);
-
-		// Create a read back buffer
-		ComputeBufferResult readback_buffer_result = buffer->CreateReadBackBuffer();
-		CHECK(!readback_buffer_result.HasError());
-		Ref<ComputeBuffer> readback_buffer = readback_buffer_result.Get();
-		CHECK(readback_buffer != nullptr);
-
-		// Create the shader
-		ComputeShaderResult shader_result = inComputeSystem->CreateComputeShader("TestCompute", cTestComputeGroupSize);
-		if (shader_result.HasError())
 		{
-			Trace("Shader could not be created: %s", shader_result.GetError().c_str());
-			return;
-		}
-		Ref<ComputeShader> shader = shader_result.Get();
-		CHECK(shader != nullptr);
+			constexpr uint32 cNumElements = 1234; // Not a multiple of cTestComputeGroupSize
+			constexpr uint32 cNumIterations = 10;
+			constexpr JPH_float3 cFloat3Value = JPH_float3(0, 0, 0);
+			constexpr JPH_float3 cFloat3Value2 = JPH_float3(0, 13, 0);
+			constexpr uint32 cUIntValue = 7;
+			constexpr uint32 cUploadValue = 42;
 
-		// Create the queue
-		ComputeQueueResult queue_result = inComputeSystem->CreateComputeQueue();
-		CHECK(!queue_result.HasError());
-		Ref<ComputeQueue> queue = queue_result.Get();
-		CHECK(queue != nullptr);
+			// Can't change context buffer while commands are queued, so create multiple constant buffers
+			Ref<ComputeBuffer> context[cNumIterations];
+			for (uint32 iter = 0; iter < cNumIterations; ++iter)
+			{
+				ComputeBufferResult buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::ConstantBuffer, 1, sizeof(TestComputeContext));
+				CHECK(!buffer_result.HasError());
+				context[iter] = buffer_result.Get();
+			}
+			CHECK(context != nullptr);
 
-		// Schedule work
-		for (uint32 iter = 0; iter < cNumIterations; ++iter)
-		{
-			// Fill in the context
-			TestComputeContext *value = context[iter]->Map<TestComputeContext>(ComputeBuffer::EMode::Write);
-			value->cFloat3Value = cFloat3Value;
-			value->cUIntValue = cUIntValue;
-			value->cFloat3Value2 = cFloat3Value2;
-			value->cUIntValue2 = iter;
-			value->cNumElements = cNumElements;
-			context[iter]->Unmap();
+			// Create an upload buffer
+			ComputeBufferResult upload_buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::UploadBuffer, 1, sizeof(uint32));
+			CHECK(!upload_buffer_result.HasError());
+			Ref<ComputeBuffer> upload_buffer = upload_buffer_result.Get();
+			CHECK(upload_buffer != nullptr);
+			uint32 *upload_data = upload_buffer->Map<uint32>(ComputeBuffer::EMode::Write);
+			upload_data[0] = cUploadValue;
+			upload_buffer->Unmap();
 
-			queue->SetShader(shader);
-			queue->SetConstantBuffer("gContext", context[iter]);
-			context[iter] = nullptr; // Release the reference to ensure the queue keeps ownership
-			queue->SetBuffer("gOptionalData", optional_buffer);
-			optional_buffer = nullptr; // Release the reference so we test that the queue keeps ownership and that in the 2nd iteration we can set a null buffer
-			queue->SetBuffer("gUploadData", upload_buffer);
-			queue->SetRWBuffer("gData", buffer);
-			queue->Dispatch((cNumElements + cTestComputeGroupSize - 1) / cTestComputeGroupSize);
-		}
+			// Create a read buffer
+			UnitTestRandom rnd;
+			Array<uint32> optional_data(cNumElements);
+			for (uint32 &d : optional_data)
+				d = rnd();
+			ComputeBufferResult optional_buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, cNumElements, sizeof(uint32), optional_data.data());
+			CHECK(!optional_buffer_result.HasError());
+			Ref<ComputeBuffer> optional_buffer = optional_buffer_result.Get();
+			CHECK(optional_buffer != nullptr);
 
-		// Run all queued commands
-		queue->ScheduleReadback(readback_buffer, buffer);
-		queue->ExecuteAndWait();
+			// Create a read-write buffer
+			ComputeBufferResult buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, cNumElements, sizeof(uint32));
+			CHECK(!buffer_result.HasError());
+			Ref<ComputeBuffer> buffer = buffer_result.Get();
+			CHECK(buffer != nullptr);
 
-		// Calculate the expected result
-		Array<uint32> expected_data(cNumElements);
-		for (uint32 iter = 0; iter < cNumIterations; ++iter)
-		{
-			// Copy of the shader logic
-			uint cUIntValue2 = iter;
-			if (cUIntValue2 == 0)
+			// Create a read back buffer
+			ComputeBufferResult readback_buffer_result = buffer->CreateReadBackBuffer();
+			CHECK(!readback_buffer_result.HasError());
+			Ref<ComputeBuffer> readback_buffer = readback_buffer_result.Get();
+			CHECK(readback_buffer != nullptr);
+
+			// Create the shader
+			ComputeShaderResult shader_result = inComputeSystem->CreateComputeShader("TestCompute", cTestComputeGroupSize);
+			if (shader_result.HasError())
 			{
-				// First write, uses optional data and tests that the packing of float3/uint3's works
-				for (uint32 i = 0; i < cNumElements; ++i)
-					expected_data[i] = optional_data[i] + int(cFloat3Value2.y) + cUploadValue;
+				Trace("Shader could not be created: %s", shader_result.GetError().c_str());
+				return;
 			}
-			else
+			Ref<ComputeShader> shader = shader_result.Get();
+			CHECK(shader != nullptr);
+
+			// Schedule work
+			for (uint32 iter = 0; iter < cNumIterations; ++iter)
 			{
-				// Read-modify-write gData
-				for (uint32 i = 0; i < cNumElements; ++i)
-					expected_data[i] = (expected_data[i] + cUIntValue) * cUIntValue2;
+				// Fill in the context
+				TestComputeContext *value = context[iter]->Map<TestComputeContext>(ComputeBuffer::EMode::Write);
+				value->cFloat3Value = cFloat3Value;
+				value->cUIntValue = cUIntValue;
+				value->cFloat3Value2 = cFloat3Value2;
+				value->cUIntValue2 = iter;
+				value->cNumElements = cNumElements;
+				context[iter]->Unmap();
+
+				queue->SetShader(shader);
+				queue->SetConstantBuffer("gContext", context[iter]);
+				context[iter] = nullptr; // Release the reference to ensure the queue keeps ownership
+				queue->SetBuffer("gOptionalData", optional_buffer);
+				optional_buffer = nullptr; // Release the reference so we test that the queue keeps ownership and that in the 2nd iteration we can set a null buffer
+				queue->SetBuffer("gUploadData", upload_buffer);
+				queue->SetRWBuffer("gData", buffer);
+				queue->Dispatch((cNumElements + cTestComputeGroupSize - 1) / cTestComputeGroupSize);
 			}
+
+			// Run all queued commands
+			queue->ScheduleReadback(readback_buffer, buffer);
+			queue->ExecuteAndWait();
+
+			// Calculate the expected result
+			Array<uint32> expected_data(cNumElements);
+			for (uint32 iter = 0; iter < cNumIterations; ++iter)
+			{
+				// Copy of the shader logic
+				uint cUIntValue2 = iter;
+				if (cUIntValue2 == 0)
+				{
+					// First write, uses optional data and tests that the packing of float3/uint3's works
+					for (uint32 i = 0; i < cNumElements; ++i)
+						expected_data[i] = optional_data[i] + int(cFloat3Value2.y) + cUploadValue;
+				}
+				else
+				{
+					// Read-modify-write gData
+					for (uint32 i = 0; i < cNumElements; ++i)
+						expected_data[i] = (expected_data[i] + cUIntValue) * cUIntValue2;
+				}
+			}
+
+			// Compare computed data with expected data
+			uint32 *data = readback_buffer->Map<uint32>(ComputeBuffer::EMode::Read);
+			for (uint32 i = 0; i < cNumElements; ++i)
+				CHECK(data[i] == expected_data[i]);
+			readback_buffer->Unmap();
 		}
 
-		// Compare computed data with expected data
-		uint32 *data = readback_buffer->Map<uint32>(ComputeBuffer::EMode::Read);
-		for (uint32 i = 0; i < cNumElements; ++i)
-			CHECK(data[i] == expected_data[i]);
-		readback_buffer->Unmap();
+		// Test helper functions
+		{
+			// Create the shader
+			ComputeShaderResult shader_result = inComputeSystem->CreateComputeShader("TestCompute2", cTestCompute2GroupSize);
+			if (shader_result.HasError())
+			{
+				Trace("Shader could not be created: %s", shader_result.GetError().c_str());
+				return;
+			}
+			Ref<ComputeShader> shader = shader_result.Get();
+			CHECK(shader != nullptr);
+
+			const Mat44 cMat44Value(Vec4(2, 3, 5, 0), Vec4(7, 11, 13, 0), Vec4(13, 15, 17, 0), Vec4(17, 19, 23, 0));
+			const Vec3 cMat44MulValue(29, 31, 37);
+
+			const Vec3 cDecompressedVec3(Vec3(-2, 3, -5).Normalized());
+			const uint32 cCompressedVec3 = cDecompressedVec3.CompressUnitVector();
+
+			const Quat cDecompressedQuat(Vec4(2, -3, 5, -7).Normalized());
+			const uint32 cCompressedQuat = cDecompressedQuat.CompressUnitQuat();
+
+			// Generate input data
+			TestCompute2Input input;
+			cMat44Value.StoreFloat4x4(input.mMat44Value);
+			cMat44MulValue.StoreFloat3(&input.mMat44MulValue);
+			input.mCompressedVec3 = cCompressedVec3;
+			input.mCompressedQuat = cCompressedQuat;
+
+			// Create input buffer
+			ComputeBufferResult buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, 1, sizeof(TestCompute2Input), &input);
+			CHECK(!buffer_result.HasError());
+			Ref<ComputeBuffer> input_buffer = buffer_result.Get();
+
+			// Create a read-write buffer for the output
+			buffer_result = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, 1, sizeof(TestCompute2Output));
+			CHECK(!buffer_result.HasError());
+			Ref<ComputeBuffer> output_buffer = buffer_result.Get();
+			CHECK(output_buffer != nullptr);
+
+			// Create a read back buffer
+			buffer_result = output_buffer->CreateReadBackBuffer();
+			CHECK(!buffer_result.HasError());
+			Ref<ComputeBuffer> readback_buffer = buffer_result.Get();
+			CHECK(readback_buffer != nullptr);
+
+			// Execute the shader
+			queue->SetShader(shader);
+			queue->SetBuffer("gInput", input_buffer);
+			queue->SetRWBuffer("gOutput", output_buffer);
+			queue->Dispatch(1);
+			queue->ScheduleReadback(readback_buffer, output_buffer);
+			queue->ExecuteAndWait();
+
+			// Verify the output
+			TestCompute2Output *output = readback_buffer->Map<TestCompute2Output>(ComputeBuffer::EMode::Read);
+
+			const Vec3 expected_mul3x4 = cMat44Value * cMat44MulValue;
+			CHECK(Vec3(output->mMul3x4Output) == expected_mul3x4);
+
+			const Vec3 expected_mul3x3 = cMat44Value.Multiply3x3(cMat44MulValue);
+			CHECK(Vec3(output->mMul3x3Output) == expected_mul3x3);
+
+			const Vec3 expected_decompressed_vec3 = Vec3::sDecompressUnitVector(cCompressedVec3);
+			CHECK(Vec3(output->mDecompressedVec3).IsClose(expected_decompressed_vec3));
+
+			const Quat expected_decompressed_quat = Quat::sDecompressUnitQuat(cCompressedQuat);
+			CHECK(Quat(output->mDecompressedQuat).IsClose(expected_decompressed_quat));
+
+			readback_buffer->Unmap();
+		}
 	}
 
 #ifdef JPH_USE_DX12
@@ -282,6 +356,7 @@ TEST_SUITE("ComputeTests")
 		{
 			CHECK(compute_system.Get() != nullptr);
 			JPH_REGISTER_SHADER(StaticCast<ComputeSystemCPU>(compute_system.Get()), TestCompute);
+			JPH_REGISTER_SHADER(StaticCast<ComputeSystemCPU>(compute_system.Get()), TestCompute2);
 			RunTests(compute_system.Get());
 		}
 	}