Bladeren bron

Implemented a CPU compute fallback, mainly for debugging purposes (#1859)

Jorrit Rouwe 3 weken geleden
bovenliggende
commit
65c7e8bba1

+ 31 - 0
Jolt/Compute/CPU/ComputeBufferCPU.cpp

@@ -0,0 +1,31 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+#include <Jolt/Compute/CPU/ComputeBufferCPU.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Allocates mSize * mStride bytes for the buffer and, when inData is provided, copies that many bytes from it.
+/// When inData is nullptr no data is copied into the new allocation.
+ComputeBufferCPU::ComputeBufferCPU(EType inType, uint64 inSize, uint inStride, const void *inData) :
+	ComputeBuffer(inType, inSize, inStride)
+{
+	// Total size of the buffer in bytes
+	const size_t num_bytes = static_cast<size_t>(mSize) * mStride;
+
+	mData = Allocate(num_bytes);
+
+	// Initial contents are optional
+	if (inData == nullptr)
+		return;
+
+	memcpy(mData, inData, num_bytes);
+}
+
+/// Releases the memory that was allocated by the constructor
+ComputeBufferCPU::~ComputeBufferCPU()
+{
+	Free(mData);
+}
+
+/// Returns this buffer itself as the read back buffer, no second buffer is created
+ComputeBufferResult ComputeBufferCPU::CreateReadBackBuffer() const
+{
+	// This function is const but the result is set from a non-const pointer, hence the const_cast
+	ComputeBufferCPU *self = const_cast<ComputeBufferCPU *>(this);
+
+	ComputeBufferResult readback;
+	readback.Set(self);
+	return readback;
+}
+
+JPH_NAMESPACE_END

+ 32 - 0
Jolt/Compute/CPU/ComputeBufferCPU.h

@@ -0,0 +1,32 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Compute/ComputeBuffer.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// ComputeBuffer implementation for the CPU compute system, the buffer owns a single block of memory (mData)
+class JPH_EXPORT ComputeBufferCPU final : public ComputeBuffer
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Constructor, inData is optional and provides the initial contents of the buffer
+	ComputeBufferCPU(EType inType, uint64 inSize, uint inStride, const void *inData);
+
+	/// Destructor, frees the memory of the buffer
+	~ComputeBufferCPU() override;
+
+	/// Returns this buffer itself, see the implementation
+	ComputeBufferResult CreateReadBackBuffer() const override;
+
+	/// Direct access to the memory of the buffer
+	void *GetData() const
+	{
+		return mData;
+	}
+
+private:
+	/// Mapping hands out the memory of the buffer directly, regardless of inMode
+	void *MapInternal(EMode inMode) override
+	{
+		return mData;
+	}
+
+	/// Unmapping requires no work
+	void UnmapInternal() override
+	{
+		/* Nothing to do */
+	}
+
+	void *mData; ///< Memory block allocated in the constructor and freed in the destructor
+};
+
+JPH_NAMESPACE_END

+ 96 - 0
Jolt/Compute/CPU/ComputeQueueCPU.cpp

@@ -0,0 +1,96 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+#include <Jolt/Compute/CPU/ComputeQueueCPU.h>
+#include <Jolt/Compute/CPU/ComputeShaderCPU.h>
+#include <Jolt/Compute/CPU/ComputeBufferCPU.h>
+#include <Jolt/Compute/CPU/ShaderWrapper.h>
+#include <Jolt/Shaders/HLSLToCPP.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Destructor, expects that every shader that was set through SetShader has been consumed by Dispatch
+ComputeQueueCPU::~ComputeQueueCPU()
+{
+	JPH_ASSERT(mShader == nullptr && mWrapper == nullptr);
+
+	// In builds where the assert above is compiled out, a queue that is destroyed between SetShader and Dispatch
+	// would otherwise leak the wrapper. Deleting a null pointer is a no-op, so this is safe in the expected case.
+	delete mWrapper;
+}
+
+/// Makes inShader the active shader and creates the wrapper that the following Set*Buffer / Dispatch calls operate on.
+/// inShader must be a ComputeShaderCPU and must not be nullptr. The previous shader must have been dispatched already.
+void ComputeQueueCPU::SetShader(const ComputeShader *inShader)
+{
+	JPH_ASSERT(inShader != nullptr);
+	JPH_ASSERT(mShader == nullptr && mWrapper == nullptr);
+
+	// If a previous shader was never dispatched (asserted above), release its wrapper and the buffer references
+	// that were collected for it instead of leaking them. In the expected case this is a no-op.
+	delete mWrapper;
+	mUsedBuffers.clear();
+
+	mShader = static_cast<const ComputeShaderCPU *>(inShader);
+	mWrapper = mShader->CreateWrapper();
+}
+
+/// Binds a constant buffer to the shader variable named inName, a nullptr buffer is ignored.
+/// SetShader must have been called first because the binding is forwarded to the wrapper of the active shader.
+void ComputeQueueCPU::SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer)
+{
+	if (inBuffer == nullptr)
+		return;
+
+	// Without an active shader there is no wrapper to bind to
+	JPH_ASSERT(mWrapper != nullptr);
+	JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::ConstantBuffer);
+
+	// Hand the memory of the buffer and its size in bytes to the shader
+	const ComputeBufferCPU *buffer = static_cast<const ComputeBufferCPU *>(inBuffer);
+	mWrapper->Bind(inName, buffer->GetData(), buffer->GetSize() * buffer->GetStride());
+
+	// Hold a reference until Dispatch has run
+	mUsedBuffers.insert(buffer);
+}
+
+/// Binds an upload, read only or read/write buffer to the shader variable named inName, a nullptr buffer is ignored.
+/// SetShader must have been called first because the binding is forwarded to the wrapper of the active shader.
+void ComputeQueueCPU::SetBuffer(const char *inName, const ComputeBuffer *inBuffer)
+{
+	if (inBuffer == nullptr)
+		return;
+
+	// Without an active shader there is no wrapper to bind to
+	JPH_ASSERT(mWrapper != nullptr);
+	JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::UploadBuffer || inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
+
+	// Hand the memory of the buffer and its size in bytes to the shader
+	const ComputeBufferCPU *buffer = static_cast<const ComputeBufferCPU *>(inBuffer);
+	mWrapper->Bind(inName, buffer->GetData(), buffer->GetSize() * buffer->GetStride());
+
+	// Hold a reference until Dispatch has run
+	mUsedBuffers.insert(buffer);
+}
+
+/// Binds a read/write buffer to the shader variable named inName, a nullptr buffer is ignored.
+/// inBarrier is not used by this implementation.
+/// SetShader must have been called first because the binding is forwarded to the wrapper of the active shader.
+void ComputeQueueCPU::SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier)
+{
+	if (inBuffer == nullptr)
+		return;
+
+	// Without an active shader there is no wrapper to bind to
+	JPH_ASSERT(mWrapper != nullptr);
+	JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
+
+	// Hand the memory of the buffer and its size in bytes to the shader
+	const ComputeBufferCPU *buffer = static_cast<const ComputeBufferCPU *>(inBuffer);
+	mWrapper->Bind(inName, buffer->GetData(), buffer->GetSize() * buffer->GetStride());
+
+	// Hold a reference until Dispatch has run
+	mUsedBuffers.insert(buffer);
+}
+
+void ComputeQueueCPU::ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc)
+{
+	/* Nothing to read back: ComputeBufferCPU::CreateReadBackBuffer returns the source buffer itself, so no copy is made here */
+}
+
+/// Runs the active shader once for every thread of the dispatch, one thread after the other on the calling thread.
+/// Afterwards the shader, its wrapper and all buffer references are released, so SetShader (and the Set*Buffer calls)
+/// must be repeated before the next Dispatch.
+void ComputeQueueCPU::Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ)
+{
+	// Dispatch clears the shader at the end, so a Dispatch without a preceding SetShader would dereference null below
+	JPH_ASSERT(mShader != nullptr && mWrapper != nullptr);
+
+	// Total number of threads along each axis (thread groups times group size)
+	uint nx = inThreadGroupsX * mShader->GetGroupSizeX();
+	uint ny = inThreadGroupsY * mShader->GetGroupSizeY();
+	uint nz = inThreadGroupsZ * mShader->GetGroupSizeZ();
+
+	// Execute all threads sequentially, x varies fastest
+	for (uint z = 0; z < nz; ++z)
+		for (uint y = 0; y < ny; ++y)
+			for (uint x = 0; x < nx; ++x)
+			{
+				HLSLToCPP::uint3 tid { x, y, z };
+				mWrapper->Main(tid);
+			}
+
+	// The wrapper is only used for a single dispatch
+	delete mWrapper;
+	mWrapper = nullptr;
+
+	// Execution is finished, the buffers and the shader no longer need to be kept alive
+	mUsedBuffers.clear();
+	mShader = nullptr;
+}
+
+void ComputeQueueCPU::Execute()
+{
+	/* Nothing to do: Dispatch has already run the shader on the calling thread */
+}
+
+void ComputeQueueCPU::Wait()
+{
+	/* Nothing to do: Dispatch only returns after all shader threads have been executed */
+}
+
+JPH_NAMESPACE_END

+ 38 - 0
Jolt/Compute/CPU/ComputeQueueCPU.h

@@ -0,0 +1,38 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Compute/ComputeQueue.h>
+#include <Jolt/Compute/CPU/ComputeShaderCPU.h>
+#include <Jolt/Core/UnorderedSet.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// ComputeQueue implementation for the CPU compute system
+class JPH_EXPORT ComputeQueueCPU final : public ComputeQueue
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Destructor, asserts that no shader is active anymore
+	~ComputeQueueCPU() override;
+
+	// See: ComputeQueue
+	void SetShader(const ComputeShader *inShader) override;
+	void SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer) override;
+	void SetBuffer(const char *inName, const ComputeBuffer *inBuffer) override;
+	void SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier = EBarrier::Yes) override;
+	void ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc) override;
+	void Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ) override;
+	void Execute() override;
+	void Wait() override;
+
+private:
+	/// Shader that was set by SetShader, reset to nullptr by Dispatch
+	RefConst<ComputeShaderCPU> mShader = nullptr;
+
+	/// Wrapper that was created from mShader, deleted by Dispatch
+	ShaderWrapper *mWrapper = nullptr;
+
+	/// Buffers that were bound for the current shader, referenced here until Dispatch has run so that they cannot be freed while in use
+	UnorderedSet<RefConst<ComputeBuffer>> mUsedBuffers;
+};
+
+JPH_NAMESPACE_END

+ 38 - 0
Jolt/Compute/CPU/ComputeShaderCPU.h

@@ -0,0 +1,38 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Compute/ComputeShader.h>
+
+JPH_NAMESPACE_BEGIN
+
+class ShaderWrapper;
+
+/// ComputeShader implementation for the CPU compute system, stores the factory function that creates the C++ wrapper of the shader
+class JPH_EXPORT ComputeShaderCPU : public ComputeShader
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Function that creates a ShaderWrapper instance
+	using CreateShader = ShaderWrapper *(*)();
+
+	/// Constructor, forwards the thread group size to ComputeShader and remembers the factory function
+	ComputeShaderCPU(CreateShader inCreateShader, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) :
+		ComputeShader(inGroupSizeX, inGroupSizeY, inGroupSizeZ),
+		mCreateShader(inCreateShader)
+	{
+	}
+
+	/// Calls the factory function and returns the wrapper it created (ComputeQueueCPU deletes the wrapper after Dispatch)
+	ShaderWrapper *CreateWrapper() const { return mCreateShader(); }
+
+private:
+	CreateShader mCreateShader; ///< Factory function passed to the constructor
+};
+
+JPH_NAMESPACE_END

+ 46 - 0
Jolt/Compute/CPU/ComputeSystemCPU.cpp

@@ -0,0 +1,46 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+#include <Jolt/Compute/CPU/ComputeSystemCPU.h>
+#include <Jolt/Compute/CPU/ComputeQueueCPU.h>
+#include <Jolt/Compute/CPU/ComputeBufferCPU.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Creates a ComputeShaderCPU for the shader that was registered under inName.
+/// Returns the error "Compute shader not found" when no shader with that name was registered.
+ComputeShaderResult ComputeSystemCPU::CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ)
+{
+	ComputeShaderResult result;
+
+	// Look up the factory function for this shader
+	const ShaderRegistry::const_iterator entry = mShaderRegistry.find(inName);
+	if (entry != mShaderRegistry.end())
+		result.Set(new ComputeShaderCPU(entry->second, inGroupSizeX, inGroupSizeY, inGroupSizeZ));
+	else
+		result.SetError("Compute shader not found");
+
+	return result;
+}
+
+/// Creates a ComputeBufferCPU with the requested type, size and stride, inData (optional) provides the initial contents
+ComputeBufferResult ComputeSystemCPU::CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData)
+{
+	ComputeBufferCPU *buffer = new ComputeBufferCPU(inType, inSize, inStride, inData);
+
+	ComputeBufferResult result;
+	result.Set(buffer);
+	return result;
+}
+
+/// Creates a new ComputeQueueCPU
+ComputeQueueResult ComputeSystemCPU::CreateComputeQueue()
+{
+	ComputeQueueCPU *queue = new ComputeQueueCPU();
+
+	ComputeQueueResult result;
+	result.Set(queue);
+	return result;
+}
+
+/// Factory function that creates a ComputeSystemCPU, no shaders are registered by this function
+ComputeSystemResult CreateComputeSystemCPU()
+{
+	ComputeSystemCPU *system = new ComputeSystemCPU();
+
+	ComputeSystemResult result;
+	result.Set(system);
+	return result;
+}
+
+JPH_NAMESPACE_END

+ 47 - 0
Jolt/Compute/CPU/ComputeSystemCPU.h

@@ -0,0 +1,47 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Compute/ComputeSystem.h>
+#include <Jolt/Core/UnorderedMap.h>
+#include <Jolt/Compute/CPU/ComputeShaderCPU.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// ComputeSystem implementation that runs a workload on the CPU.
+/// This is intended mainly for debugging purposes and is not optimized for performance.
+/// Shaders must be registered through RegisterShader before CreateComputeShader can find them.
+class JPH_EXPORT ComputeSystemCPU : public ComputeSystem
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Function that creates the wrapper of a shader
+	using CreateShader = ComputeShaderCPU::CreateShader;
+
+	// See: ComputeSystem
+	ComputeShaderResult CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) override;
+	ComputeBufferResult CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData = nullptr) override;
+	ComputeQueueResult CreateComputeQueue() override;
+
+	/// Registers (or replaces) the factory function for the shader named inName.
+	/// The registry stores a string_view of inName and not a copy of the string, so the string must stay valid
+	/// for as long as the shader can be looked up (e.g. pass a string literal).
+	void RegisterShader(const char *inName, CreateShader inCreateShader)
+	{
+		mShaderRegistry[inName] = inCreateShader;
+	}
+
+private:
+	/// Maps a shader name to the function that creates its wrapper
+	using ShaderRegistry = UnorderedMap<string_view, CreateShader>;
+
+	ShaderRegistry mShaderRegistry; ///< All shaders that were registered through RegisterShader
+};
+
+// Internal helpers: the name (RegisterShader followed by the shader name) and the signature of the function that registers a shader with a ComputeSystemCPU
+#define JPH_SHADER_WRAPPER_FUNCTION_NAME(name)		RegisterShader##name
+#define JPH_SHADER_WRAPPER_FUNCTION(sys, name)		void JPH_EXPORT JPH_SHADER_WRAPPER_FUNCTION_NAME(name)(ComputeSystemCPU *sys)
+
+/// Macro to declare the register function of shader 'name'. The macro opens namespace JPH itself, so it is presumably meant to be used outside of any namespace.
+#define JPH_DECLARE_REGISTER_SHADER(name)			namespace JPH { class ComputeSystemCPU; JPH_SHADER_WRAPPER_FUNCTION(, name); }
+
+/// Macro to register shader 'name' with the ComputeSystemCPU that sys points to, calls the function declared by JPH_DECLARE_REGISTER_SHADER(name)
+#define JPH_REGISTER_SHADER(sys, name)				JPH::JPH_SHADER_WRAPPER_FUNCTION_NAME(name)(sys)
+
+JPH_NAMESPACE_END

+ 25 - 0
Jolt/Compute/CPU/ShaderWrapper.h

@@ -0,0 +1,25 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+JPH_NAMESPACE_BEGIN
+
+namespace HLSLToCPP { struct uint3; }
+
+/// Interface that wraps a compute shader so that it can be called from C++.
+/// ComputeQueueCPU binds buffers through Bind and then calls Main once for every thread of a dispatch.
+class ShaderWrapper
+{
+public:
+	/// Destructor, virtual because wrappers are deleted through a ShaderWrapper pointer
+	virtual ~ShaderWrapper() = default;
+
+	/// Bind buffer to shader
+	/// @param inName Name under which the shader knows the buffer
+	/// @param inData Start of the memory of the buffer
+	/// @param inSize Size of the buffer (ComputeQueueCPU passes GetSize() * GetStride(), i.e. the size in bytes)
+	virtual void Bind(const char *inName, void *inData, uint64 inSize) = 0;
+
+	/// Execute a single shader thread
+	/// @param inThreadID Index of the thread within the whole dispatch (ComputeQueueCPU counts x, y and z from 0 up to thread groups times group size)
+	virtual void Main(const HLSLToCPP::uint3 &inThreadID) = 0;
+};
+
+JPH_NAMESPACE_END

+ 4 - 0
Jolt/Compute/ComputeSystem.h

@@ -40,6 +40,10 @@ using ComputeSystemResult = Result<Ref<ComputeSystem>>;
 extern JPH_EXPORT ComputeSystemResult	CreateComputeSystemVK();
 #endif
 
+/// Factory function to create a compute system that runs compute shaders on the CPU instead of on a GPU.
+/// This is intended mainly for debugging purposes and is not optimized for performance.
+extern JPH_EXPORT ComputeSystemResult	CreateComputeSystemCPU();
+
 #ifdef JPH_USE_DX12
 
 /// Factory function to create a compute system using DirectX 12

+ 11 - 1
Jolt/Jolt.cmake

@@ -18,6 +18,14 @@ set(JOLT_PHYSICS_SRC_FILES
 	${JOLT_PHYSICS_ROOT}/Compute/ComputeQueue.h
 	${JOLT_PHYSICS_ROOT}/Compute/ComputeSystem.h
 	${JOLT_PHYSICS_ROOT}/Compute/ComputeShader.h
+	${JOLT_PHYSICS_ROOT}/Compute/CPU/ComputeQueueCPU.cpp
+	${JOLT_PHYSICS_ROOT}/Compute/CPU/ComputeQueueCPU.h
+	${JOLT_PHYSICS_ROOT}/Compute/CPU/ComputeBufferCPU.cpp
+	${JOLT_PHYSICS_ROOT}/Compute/CPU/ComputeBufferCPU.h
+	${JOLT_PHYSICS_ROOT}/Compute/CPU/ComputeSystemCPU.cpp
+	${JOLT_PHYSICS_ROOT}/Compute/CPU/ComputeSystemCPU.h
+	${JOLT_PHYSICS_ROOT}/Compute/CPU/ComputeShaderCPU.h
+	${JOLT_PHYSICS_ROOT}/Compute/CPU/ShaderWrapper.h
 	${JOLT_PHYSICS_ROOT}/Core/ARMNeon.h
 	${JOLT_PHYSICS_ROOT}/Core/Array.h
 	${JOLT_PHYSICS_ROOT}/Core/Atomics.h
@@ -424,6 +432,9 @@ set(JOLT_PHYSICS_SRC_FILES
 	${JOLT_PHYSICS_ROOT}/Renderer/DebugRendererRecorder.h
 	${JOLT_PHYSICS_ROOT}/Renderer/DebugRendererSimple.cpp
 	${JOLT_PHYSICS_ROOT}/Renderer/DebugRendererSimple.h
+	${JOLT_PHYSICS_ROOT}/Shaders/HLSLToCPP.h
+	${JOLT_PHYSICS_ROOT}/Shaders/ShaderWrapperCreator.h
+	${JOLT_PHYSICS_ROOT}/Shaders/TestComputeWrapper.cpp
 	${JOLT_PHYSICS_ROOT}/Skeleton/SkeletalAnimation.cpp
 	${JOLT_PHYSICS_ROOT}/Skeleton/SkeletalAnimation.h
 	${JOLT_PHYSICS_ROOT}/Skeleton/Skeleton.cpp
@@ -475,7 +486,6 @@ if (JPH_USE_DX12 OR JPH_USE_VK OR JPH_USE_MTL)
 		${JOLT_PHYSICS_ROOT}/Shaders/ShaderPlane.h
 		${JOLT_PHYSICS_ROOT}/Shaders/ShaderQuat.h
 		${JOLT_PHYSICS_ROOT}/Shaders/ShaderVec3.h
-		${JOLT_PHYSICS_ROOT}/Shaders/TestCompute.h
 		${JOLT_PHYSICS_ROOT}/Shaders/TestComputeBindings.h
 	)
 endif()

+ 521 - 0
Jolt/Shaders/HLSLToCPP.h

@@ -0,0 +1,521 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+JPH_NAMESPACE_BEGIN
+
+/// Emulates HLSL vector types and operations in C++
+/// Note doesn't emulate things like barriers and group shared memory
+namespace HLSLToCPP {
+
+using std::sqrt; // The scalar versions of these HLSL intrinsics are taken straight from the C++ standard library
+using std::min;
+using std::max;
+using std::round; // NOTE(review): std::round rounds halfway cases away from zero, the HLSL documentation specifies round half to even for round() - confirm that shaders don't depend on the difference
+
+//////////////////////////////////////////////////////////////////////////////////////////
+// float2
+//////////////////////////////////////////////////////////////////////////////////////////
+
+/// Emulation of the HLSL float2 type: two floats with component wise operators
+struct float2
+{
+	// Constructors
+	inline float2() = default;
+	constexpr float2(float inX, float inY) : x(inX), y(inY) { }
+	explicit constexpr float2(float inS) : x(inS), y(inS) { } // Replicates inS to all components
+
+	// Compound assignment (vectors are applied per component, scalars to every component)
+	constexpr float2 &operator += (const float2 &inRHS) { x += inRHS.x; y += inRHS.y; return *this; }
+	constexpr float2 &operator -= (const float2 &inRHS) { x -= inRHS.x; y -= inRHS.y; return *this; }
+	constexpr float2 &operator *= (float inRHS) { x *= inRHS; y *= inRHS; return *this; }
+	constexpr float2 &operator /= (float inRHS) { x /= inRHS; y /= inRHS; return *this; }
+	constexpr float2 &operator *= (const float2 &inRHS) { x *= inRHS.x; y *= inRHS.y; return *this; }
+	constexpr float2 &operator /= (const float2 &inRHS) { x /= inRHS.x; y /= inRHS.y; return *this; }
+
+	// Exact comparison, vectors are equal when all components are equal
+	constexpr bool operator == (const float2 &inRHS) const { return x == inRHS.x && y == inRHS.y; }
+	constexpr bool operator != (const float2 &inRHS) const { return x != inRHS.x || y != inRHS.y; }
+
+	// Component access by index (0 = x, 1 = y), the index is not range checked
+	const float &operator [] (uint inIndex) const { return *(&x + inIndex); }
+	float &operator [] (uint inIndex) { return *(&x + inIndex); }
+
+	// Swizzling (the result is a const value so that assigning to a swizzled result doesn't compile)
+	const float2 xy() const { return { x, y }; }
+	const float2 yx() const { return { y, x }; }
+
+	float x, y;
+};
+
+// Arithmetic operators, all of them work per component
+constexpr float2 operator - (const float2 &inA) { return { -inA.x, -inA.y }; }
+constexpr float2 operator + (const float2 &inA, const float2 &inB) { float2 r(inA); r += inB; return r; }
+constexpr float2 operator - (const float2 &inA, const float2 &inB) { float2 r(inA); r -= inB; return r; }
+constexpr float2 operator * (const float2 &inA, const float2 &inB) { float2 r(inA); r *= inB; return r; }
+constexpr float2 operator / (const float2 &inA, const float2 &inB) { float2 r(inA); r /= inB; return r; }
+constexpr float2 operator * (const float2 &inA, float inS) { float2 r(inA); r *= inS; return r; }
+constexpr float2 operator * (float inS, const float2 &inA) { float2 r(inA); r *= inS; return r; }
+constexpr float2 operator / (const float2 &inA, float inS) { float2 r(inA); r /= inS; return r; }
+
+// Dot product
+constexpr float dot(const float2 &inA, const float2 &inB) { return inA.x * inB.x + inA.y * inB.y; }
+
+// Component wise minimum
+constexpr float2 min(const float2 &inA, const float2 &inB) { return { min(inA.x, inB.x), min(inA.y, inB.y) }; }
+
+// Component wise maximum
+constexpr float2 max(const float2 &inA, const float2 &inB) { return { max(inA.x, inB.x), max(inA.y, inB.y) }; }
+
+// Length of the vector
+inline float length(const float2 &inV) { const float len_sq = dot(inV, inV); return sqrt(len_sq); }
+
+// Normalization, divides by the length without checking for a zero length vector
+inline float2 normalize(const float2 &inV) { const float len = length(inV); return inV / len; }
+
+// Rounds every component with std::round, the result is still a float vector
+inline float2 round(const float2 &inV) { return { round(inV.x), round(inV.y) }; }
+
+//////////////////////////////////////////////////////////////////////////////////////////
+// float3
+//////////////////////////////////////////////////////////////////////////////////////////
+
+/// Emulation of the HLSL float3 type: three floats with component wise operators
+struct float3
+{
+	// Constructors
+	inline float3() = default;
+	constexpr float3(const float2 &inV, float inZ) : x(inV.x), y(inV.y), z(inZ) { }
+	constexpr float3(float inX, float inY, float inZ) : x(inX), y(inY), z(inZ) { }
+	explicit constexpr float3(float inS) : x(inS), y(inS), z(inS) { } // Replicates inS to all components
+	explicit constexpr float3(const struct uint3 &inV); // Declaration only, uint3 has not been defined at this point
+
+	// Compound assignment (vectors are applied per component, scalars to every component)
+	constexpr float3 &operator += (const float3 &inRHS) { x += inRHS.x; y += inRHS.y; z += inRHS.z; return *this; }
+	constexpr float3 &operator -= (const float3 &inRHS) { x -= inRHS.x; y -= inRHS.y; z -= inRHS.z; return *this; }
+	constexpr float3 &operator *= (float inRHS) { x *= inRHS; y *= inRHS; z *= inRHS; return *this; }
+	constexpr float3 &operator /= (float inRHS) { x /= inRHS; y /= inRHS; z /= inRHS; return *this; }
+	constexpr float3 &operator *= (const float3 &inRHS) { x *= inRHS.x; y *= inRHS.y; z *= inRHS.z; return *this; }
+	constexpr float3 &operator /= (const float3 &inRHS) { x /= inRHS.x; y /= inRHS.y; z /= inRHS.z; return *this; }
+
+	// Exact comparison, vectors are equal when all components are equal
+	constexpr bool operator == (const float3 &inRHS) const { return x == inRHS.x && y == inRHS.y && z == inRHS.z; }
+	constexpr bool operator != (const float3 &inRHS) const { return x != inRHS.x || y != inRHS.y || z != inRHS.z; }
+
+	// Component access by index (0 = x, 1 = y, 2 = z), the index is not range checked
+	const float &operator [] (uint inIndex) const { return *(&x + inIndex); }
+	float &operator [] (uint inIndex) { return *(&x + inIndex); }
+
+	// Swizzling (the result is a const value so that assigning to a swizzled result doesn't compile)
+	const float2 xy() const { return { x, y }; }
+	const float2 yx() const { return { y, x }; }
+	const float3 xyz() const { return { x, y, z }; }
+	const float3 xzy() const { return { x, z, y }; }
+	const float3 yxz() const { return { y, x, z }; }
+	const float3 yzx() const { return { y, z, x }; }
+	const float3 zxy() const { return { z, x, y }; }
+	const float3 zyx() const { return { z, y, x }; }
+
+	float x, y, z;
+};
+
+// Arithmetic operators, all of them work per component
+constexpr float3 operator - (const float3 &inA) { return { -inA.x, -inA.y, -inA.z }; }
+constexpr float3 operator + (const float3 &inA, const float3 &inB) { float3 r(inA); r += inB; return r; }
+constexpr float3 operator - (const float3 &inA, const float3 &inB) { float3 r(inA); r -= inB; return r; }
+constexpr float3 operator * (const float3 &inA, const float3 &inB) { float3 r(inA); r *= inB; return r; }
+constexpr float3 operator / (const float3 &inA, const float3 &inB) { float3 r(inA); r /= inB; return r; }
+constexpr float3 operator * (const float3 &inA, float inS) { float3 r(inA); r *= inS; return r; }
+constexpr float3 operator * (float inS, const float3 &inA) { float3 r(inA); r *= inS; return r; }
+constexpr float3 operator / (const float3 &inA, float inS) { float3 r(inA); r /= inS; return r; }
+
+// Dot product
+constexpr float dot(const float3 &inA, const float3 &inB) { return inA.x * inB.x + inA.y * inB.y + inA.z * inB.z; }
+
+// Component wise minimum
+constexpr float3 min(const float3 &inA, const float3 &inB) { return { min(inA.x, inB.x), min(inA.y, inB.y), min(inA.z, inB.z) }; }
+
+// Component wise maximum
+constexpr float3 max(const float3 &inA, const float3 &inB) { return { max(inA.x, inB.x), max(inA.y, inB.y), max(inA.z, inB.z) }; }
+
+// Length of the vector
+inline float length(const float3 &inV) { const float len_sq = dot(inV, inV); return sqrt(len_sq); }
+
+// Normalization, divides by the length without checking for a zero length vector
+inline float3 normalize(const float3 &inV) { const float len = length(inV); return inV / len; }
+
+// Rounds every component with std::round, the result is still a float vector
+inline float3 round(const float3 &inV) { return { round(inV.x), round(inV.y), round(inV.z) }; }
+
+// Cross product
+constexpr float3 cross(const float3 &inA, const float3 &inB)
+{
+	return {
+		inA.y * inB.z - inA.z * inB.y,
+		inA.z * inB.x - inA.x * inB.z,
+		inA.x * inB.y - inA.y * inB.x
+	};
+}
+
+//////////////////////////////////////////////////////////////////////////////////////////
+// float4
+//////////////////////////////////////////////////////////////////////////////////////////
+
+/// Emulation of the HLSL float4 type: four floats with component wise operators
+struct float4
+{
+	// Constructors
+	inline float4() = default;
+	constexpr float4(const float3 &inV, float inW) : x(inV.x), y(inV.y), z(inV.z), w(inW) { }
+	constexpr float4(float inX, float inY, float inZ, float inW) : x(inX), y(inY), z(inZ), w(inW) { }
+	explicit constexpr float4(float inS) : x(inS), y(inS), z(inS), w(inS) { } // Replicates inS to all components
+	explicit constexpr float4(const struct int4 &inV); // Declaration only, int4 has not been defined at this point
+
+	// Compound assignment (vectors are applied per component, scalars to every component)
+	constexpr float4 &operator += (const float4 &inRHS) { x += inRHS.x; y += inRHS.y; z += inRHS.z; w += inRHS.w; return *this; }
+	constexpr float4 &operator -= (const float4 &inRHS) { x -= inRHS.x; y -= inRHS.y; z -= inRHS.z; w -= inRHS.w; return *this; }
+	constexpr float4 &operator *= (float inRHS) { x *= inRHS; y *= inRHS; z *= inRHS; w *= inRHS; return *this; }
+	constexpr float4 &operator /= (float inRHS) { x /= inRHS; y /= inRHS; z /= inRHS; w /= inRHS; return *this; }
+	constexpr float4 &operator *= (const float4 &inRHS) { x *= inRHS.x; y *= inRHS.y; z *= inRHS.z; w *= inRHS.w; return *this; }
+	constexpr float4 &operator /= (const float4 &inRHS) { x /= inRHS.x; y /= inRHS.y; z /= inRHS.z; w /= inRHS.w; return *this; }
+
+	// Exact comparison, vectors are equal when all components are equal
+	constexpr bool operator == (const float4 &inRHS) const { return x == inRHS.x && y == inRHS.y && z == inRHS.z && w == inRHS.w; }
+	constexpr bool operator != (const float4 &inRHS) const { return x != inRHS.x || y != inRHS.y || z != inRHS.z || w != inRHS.w; }
+
+	// Component access by index (0 = x, 1 = y, 2 = z, 3 = w), the index is not range checked
+	const float &operator [] (uint inIndex) const { return *(&x + inIndex); }
+	float &operator [] (uint inIndex) { return *(&x + inIndex); }
+
+	// Swizzling (the result is a const value so that assigning to a swizzled result doesn't compile)
+	const float2 xy() const { return { x, y }; }
+	const float2 yx() const { return { y, x }; }
+	const float3 xyz() const { return { x, y, z }; }
+	const float3 xzy() const { return { x, z, y }; }
+	const float3 yxz() const { return { y, x, z }; }
+	const float3 yzx() const { return { y, z, x }; }
+	const float3 zxy() const { return { z, x, y }; }
+	const float3 zyx() const { return { z, y, x }; }
+	const float4 xywz() const { return { x, y, w, z }; }
+	const float4 xwyz() const { return { x, w, y, z }; }
+	const float4 wxyz() const { return { w, x, y, z }; }
+
+	float x, y, z, w;
+};
+
+// Arithmetic operators, all of them work per component
+constexpr float4 operator - (const float4 &inA) { return { -inA.x, -inA.y, -inA.z, -inA.w }; }
+constexpr float4 operator + (const float4 &inA, const float4 &inB) { float4 r(inA); r += inB; return r; }
+constexpr float4 operator - (const float4 &inA, const float4 &inB) { float4 r(inA); r -= inB; return r; }
+constexpr float4 operator * (const float4 &inA, const float4 &inB) { float4 r(inA); r *= inB; return r; }
+constexpr float4 operator / (const float4 &inA, const float4 &inB) { float4 r(inA); r /= inB; return r; }
+constexpr float4 operator * (const float4 &inA, float inS) { float4 r(inA); r *= inS; return r; }
+constexpr float4 operator * (float inS, const float4 &inA) { float4 r(inA); r *= inS; return r; }
+constexpr float4 operator / (const float4 &inA, float inS) { float4 r(inA); r /= inS; return r; }
+
+// Dot product
+constexpr float dot(const float4 &inA, const float4 &inB) { return inA.x * inB.x + inA.y * inB.y + inA.z * inB.z + inA.w * inB.w; }
+
+// Component wise minimum
+constexpr float4 min(const float4 &inA, const float4 &inB) { return { min(inA.x, inB.x), min(inA.y, inB.y), min(inA.z, inB.z), min(inA.w, inB.w) }; }
+
+// Component wise maximum
+constexpr float4 max(const float4 &inA, const float4 &inB) { return { max(inA.x, inB.x), max(inA.y, inB.y), max(inA.z, inB.z), max(inA.w, inB.w) }; }
+
+// Length of the vector
+inline float length(const float4 &inV) { const float len_sq = dot(inV, inV); return sqrt(len_sq); }
+
+// Normalization, divides by the length without checking for a zero length vector
+inline float4 normalize(const float4 &inV) { const float len = length(inV); return inV / len; }
+
+// Rounds every component with std::round, the result is still a float vector
+inline float4 round(const float4 &inV) { return { round(inV.x), round(inV.y), round(inV.z), round(inV.w) }; }
+
+//////////////////////////////////////////////////////////////////////////////////////////
+// uint3
+//////////////////////////////////////////////////////////////////////////////////////////
+
+/// Emulation of the HLSL uint3 type: three 32 bit unsigned integers with component wise operators
+struct uint3
+{
+	// Constructors
+	inline uint3() = default;
+	constexpr uint3(uint32 inX, uint32 inY, uint32 inZ) : x(inX), y(inY), z(inZ) { }
+	explicit constexpr uint3(const float3 &inV) : x(uint32(inV.x)), y(uint32(inV.y)), z(uint32(inV.z)) { } // Converts every component with a cast to uint32
+
+	// Compound assignment (vectors are applied per component, scalars to every component)
+	constexpr uint3 &operator += (const uint3 &inRHS) { x += inRHS.x; y += inRHS.y; z += inRHS.z; return *this; }
+	constexpr uint3 &operator -= (const uint3 &inRHS) { x -= inRHS.x; y -= inRHS.y; z -= inRHS.z; return *this; }
+	constexpr uint3 &operator *= (uint32 inRHS) { x *= inRHS; y *= inRHS; z *= inRHS; return *this; }
+	constexpr uint3 &operator /= (uint32 inRHS) { x /= inRHS; y /= inRHS; z /= inRHS; return *this; }
+	constexpr uint3 &operator *= (const uint3 &inRHS) { x *= inRHS.x; y *= inRHS.y; z *= inRHS.z; return *this; }
+	constexpr uint3 &operator /= (const uint3 &inRHS) { x /= inRHS.x; y /= inRHS.y; z /= inRHS.z; return *this; }
+
+	// Comparison, vectors are equal when all components are equal
+	constexpr bool operator == (const uint3 &inRHS) const { return x == inRHS.x && y == inRHS.y && z == inRHS.z; }
+	constexpr bool operator != (const uint3 &inRHS) const { return x != inRHS.x || y != inRHS.y || z != inRHS.z; }
+
+	// Component access by index (0 = x, 1 = y, 2 = z), the index is not range checked
+	const uint32 &operator [] (uint inIndex) const { return *(&x + inIndex); }
+	uint32 &operator [] (uint inIndex) { return *(&x + inIndex); }
+
+	// Swizzling (the result is a const value so that assigning to a swizzled result doesn't compile)
+	const uint3 xyz() const { return { x, y, z }; }
+	const uint3 xzy() const { return { x, z, y }; }
+	const uint3 yxz() const { return { y, x, z }; }
+	const uint3 yzx() const { return { y, z, x }; }
+	const uint3 zxy() const { return { z, x, y }; }
+	const uint3 zyx() const { return { z, y, x }; }
+
+	uint32 x, y, z;
+};
+
+// Arithmetic operators, all of them work per component
+constexpr uint3 operator + (const uint3 &inA, const uint3 &inB) { uint3 r(inA); r += inB; return r; }
+constexpr uint3 operator - (const uint3 &inA, const uint3 &inB) { uint3 r(inA); r -= inB; return r; }
+constexpr uint3 operator * (const uint3 &inA, const uint3 &inB) { uint3 r(inA); r *= inB; return r; }
+constexpr uint3 operator / (const uint3 &inA, const uint3 &inB) { uint3 r(inA); r /= inB; return r; }
+constexpr uint3 operator * (const uint3 &inA, uint32 inS) { uint3 r(inA); r *= inS; return r; }
+constexpr uint3 operator * (uint32 inS, const uint3 &inA) { uint3 r(inA); r *= inS; return r; }
+constexpr uint3 operator / (const uint3 &inA, uint32 inS) { uint3 r(inA); r /= inS; return r; }
+
+// Dot product
+constexpr uint32 dot(const uint3 &inA, const uint3 &inB) { return inA.x * inB.x + inA.y * inB.y + inA.z * inB.z; }
+
+// Component wise minimum
+constexpr uint3 min(const uint3 &inA, const uint3 &inB) { return { min(inA.x, inB.x), min(inA.y, inB.y), min(inA.z, inB.z) }; }
+
+// Component wise maximum
+constexpr uint3 max(const uint3 &inA, const uint3 &inB) { return { max(inA.x, inB.x), max(inA.y, inB.y), max(inA.z, inB.z) }; }
+
+//////////////////////////////////////////////////////////////////////////////////////////
+// uint4
+//////////////////////////////////////////////////////////////////////////////////////////
+
+/// Emulation of the HLSL uint4 type: four 32 bit unsigned integers with component wise operators
+struct uint4
+{
+	// Constructors
+	inline uint4() = default;
+	constexpr uint4(const uint3 &inV, uint32 inW) : x(inV.x), y(inV.y), z(inV.z), w(inW) { }
+	constexpr uint4(uint32 inX, uint32 inY, uint32 inZ, uint32 inW) : x(inX), y(inY), z(inZ), w(inW) { }
+	explicit constexpr uint4(uint32 inS) : x(inS), y(inS), z(inS), w(inS) { } // Replicates inS to all components
+
+	// Compound assignment (vectors are applied per component, scalars to every component)
+	constexpr uint4 &operator += (const uint4 &inRHS) { x += inRHS.x; y += inRHS.y; z += inRHS.z; w += inRHS.w; return *this; }
+	constexpr uint4 &operator -= (const uint4 &inRHS) { x -= inRHS.x; y -= inRHS.y; z -= inRHS.z; w -= inRHS.w; return *this; }
+	constexpr uint4 &operator *= (uint32 inRHS) { x *= inRHS; y *= inRHS; z *= inRHS; w *= inRHS; return *this; }
+	constexpr uint4 &operator /= (uint32 inRHS) { x /= inRHS; y /= inRHS; z /= inRHS; w /= inRHS; return *this; }
+	constexpr uint4 &operator *= (const uint4 &inRHS) { x *= inRHS.x; y *= inRHS.y; z *= inRHS.z; w *= inRHS.w; return *this; }
+	constexpr uint4 &operator /= (const uint4 &inRHS) { x /= inRHS.x; y /= inRHS.y; z /= inRHS.z; w /= inRHS.w; return *this; }
+
+	// Comparison, vectors are equal when all components are equal
+	constexpr bool operator == (const uint4 &inRHS) const { return x == inRHS.x && y == inRHS.y && z == inRHS.z && w == inRHS.w; }
+	constexpr bool operator != (const uint4 &inRHS) const { return x != inRHS.x || y != inRHS.y || z != inRHS.z || w != inRHS.w; }
+
+	// Component access by index (0 = x, 1 = y, 2 = z, 3 = w), the index is not range checked
+	const uint32 &operator [] (uint inIndex) const { return *(&x + inIndex); }
+	uint32 &operator [] (uint inIndex) { return *(&x + inIndex); }
+
+	// Swizzling (the result is a const value so that assigning to a swizzled result doesn't compile)
+	const uint3 xyz() const { return { x, y, z }; }
+	const uint3 xzy() const { return { x, z, y }; }
+	const uint3 yxz() const { return { y, x, z }; }
+	const uint3 yzx() const { return { y, z, x }; }
+	const uint3 zxy() const { return { z, x, y }; }
+	const uint3 zyx() const { return { z, y, x }; }
+	const uint4 xywz() const { return { x, y, w, z }; }
+	const uint4 xwyz() const { return { x, w, y, z }; }
+	const uint4 wxyz() const { return { w, x, y, z }; }
+
+	uint32 x, y, z, w;
+};
+
+// Arithmetic operators, all of them work per component
+constexpr uint4 operator + (const uint4 &inA, const uint4 &inB) { uint4 r(inA); r += inB; return r; }
+constexpr uint4 operator - (const uint4 &inA, const uint4 &inB) { uint4 r(inA); r -= inB; return r; }
+constexpr uint4 operator * (const uint4 &inA, const uint4 &inB) { uint4 r(inA); r *= inB; return r; }
+constexpr uint4 operator / (const uint4 &inA, const uint4 &inB) { uint4 r(inA); r /= inB; return r; }
+constexpr uint4 operator * (const uint4 &inA, uint32 inS) { uint4 r(inA); r *= inS; return r; }
+constexpr uint4 operator * (uint32 inS, const uint4 &inA) { uint4 r(inA); r *= inS; return r; }
+constexpr uint4 operator / (const uint4 &inA, uint32 inS) { uint4 r(inA); r /= inS; return r; }
+
+// Dot product
+constexpr uint32 dot(const uint4 &inA, const uint4 &inB) { return inA.x * inB.x + inA.y * inB.y + inA.z * inB.z + inA.w * inB.w; }
+
+// Component wise minimum
+constexpr uint4 min(const uint4 &inA, const uint4 &inB) { return { min(inA.x, inB.x), min(inA.y, inB.y), min(inA.z, inB.z), min(inA.w, inB.w) }; }
+
+// Component wise maximum
+constexpr uint4 max(const uint4 &inA, const uint4 &inB) { return { max(inA.x, inB.x), max(inA.y, inB.y), max(inA.z, inB.z), max(inA.w, inB.w) }; }
+
+//////////////////////////////////////////////////////////////////////////////////////////
+// int3
+//////////////////////////////////////////////////////////////////////////////////////////
+
+struct int3
+{
+	inline				int3() = default; // Trivial default constructor; the struct declares no default member initializers
+	constexpr			int3(int inX, int inY, int inZ)						: x(inX), y(inY), z(inZ) { }
+	explicit constexpr	int3(const float3 &inV)								: x(int(inV.x)), y(int(inV.y)), z(int(inV.z)) { } // Each component is converted with int(), which truncates toward zero
+
+	// Compound assignment: component-wise for int3 operands, the same scalar applied to every component for int operands (integer division, divisors are not checked for zero)
+	constexpr int3 &	operator += (const int3 &inRHS)						{ x += inRHS.x; y += inRHS.y; z += inRHS.z; return *this; }
+	constexpr int3 &	operator -= (const int3 &inRHS)						{ x -= inRHS.x; y -= inRHS.y; z -= inRHS.z; return *this; }
+	constexpr int3 &	operator *= (int inRHS)								{ x *= inRHS; y *= inRHS; z *= inRHS; return *this; }
+	constexpr int3 &	operator /= (int inRHS)								{ x /= inRHS; y /= inRHS; z /= inRHS; return *this; }
+	constexpr int3 &	operator *= (const int3 &inRHS)						{ x *= inRHS.x; y *= inRHS.y; z *= inRHS.z; return *this; }
+	constexpr int3 &	operator /= (const int3 &inRHS)						{ x /= inRHS.x; y /= inRHS.y; z /= inRHS.z; return *this; }
+
+	// Equality: two values are equal when all three components match
+	constexpr bool		operator == (const int3 &inRHS) const				{ return x == inRHS.x && y == inRHS.y && z == inRHS.z; }
+	constexpr bool		operator != (const int3 &inRHS) const				{ return !(*this == inRHS); }
+
+	// Component access by index (0 = x, 1 = y, 2 = z); indexes from the address of x and performs no range check on inIndex
+	const int &			operator [] (uint inIndex) const					{ return (&x)[inIndex]; }
+	int &				operator [] (uint inIndex)							{ return (&x)[inIndex]; }
+
+	// Swizzling: each function returns a new int3 with the components in the named order (note return value is const to prevent assignment to swizzled results)
+	const int3			xyz() const											{ return int3(x, y, z); }
+	const int3			xzy() const											{ return int3(x, z, y); }
+	const int3			yxz() const											{ return int3(y, x, z); }
+	const int3			yzx() const											{ return int3(y, z, x); }
+	const int3			zxy() const											{ return int3(z, x, y); }
+	const int3			zyx() const											{ return int3(z, y, x); }
+
+	int					x, y, z;
+};
+
+// Unary negation and component-wise binary operators for int3 (the scalar overloads apply inS to every component; division is integer division and divisors are not checked for zero)
+constexpr int3			operator - (const int3 &inA)						{ return int3(-inA.x, -inA.y, -inA.z); }
+constexpr int3			operator + (const int3 &inA, const int3 &inB)		{ return int3(inA.x + inB.x, inA.y + inB.y, inA.z + inB.z); }
+constexpr int3			operator - (const int3 &inA, const int3 &inB)		{ return int3(inA.x - inB.x, inA.y - inB.y, inA.z - inB.z); }
+constexpr int3			operator * (const int3 &inA, const int3 &inB)		{ return int3(inA.x * inB.x, inA.y * inB.y, inA.z * inB.z); }
+constexpr int3			operator / (const int3 &inA, const int3 &inB)		{ return int3(inA.x / inB.x, inA.y / inB.y, inA.z / inB.z); }
+constexpr int3			operator * (const int3 &inA, int inS)				{ return int3(inA.x * inS, inA.y * inS, inA.z * inS); }
+constexpr int3			operator * (int inS, const int3 &inA)				{ return inA * inS; }
+constexpr int3			operator / (const int3 &inA, int inS)				{ return int3(inA.x / inS, inA.y / inS, inA.z / inS); }
+
+// Dot product: sum of the three component-wise products
+constexpr int			dot(const int3 &inA, const int3 &inB)				{ return inA.x * inB.x + inA.y * inB.y + inA.z * inB.z; }
+
+// Component-wise minimum of two int3 values
+constexpr int3			min(const int3 &inA, const int3 &inB)				{ return int3(min(inA.x, inB.x), min(inA.y, inB.y), min(inA.z, inB.z)); }
+
+// Component-wise maximum of two int3 values
+constexpr int3			max(const int3 &inA, const int3 &inB)				{ return int3(max(inA.x, inB.x), max(inA.y, inB.y), max(inA.z, inB.z)); }
+
+//////////////////////////////////////////////////////////////////////////////////////////
+// int4
+//////////////////////////////////////////////////////////////////////////////////////////
+
+struct int4
+{
+	// Constructors: trivial default, int3 plus w, one value per component, one scalar replicated to all components, and conversion from float4 (each component truncated toward zero by int())
+	inline				int4() = default;
+	constexpr			int4(const int3 &inV, int inW)						: x(inV.x), y(inV.y), z(inV.z), w(inW) { }
+	constexpr			int4(int inX, int inY, int inZ, int inW)			: x(inX), y(inY), z(inZ), w(inW) { }
+	explicit constexpr	int4(int inS)										: x(inS), y(inS), z(inS), w(inS) { }
+	explicit constexpr	int4(const float4 &inV)								: x(int(inV.x)), y(int(inV.y)), z(int(inV.z)), w(int(inV.w)) { }
+
+	// Compound assignment: component-wise for int4 operands, the same scalar applied to every component for int operands (integer division, divisors are not checked for zero)
+	constexpr int4 &	operator += (const int4 &inRHS)						{ x += inRHS.x; y += inRHS.y; z += inRHS.z; w += inRHS.w; return *this; }
+	constexpr int4 &	operator -= (const int4 &inRHS)						{ x -= inRHS.x; y -= inRHS.y; z -= inRHS.z; w -= inRHS.w; return *this; }
+	constexpr int4 &	operator *= (int inRHS)								{ x *= inRHS; y *= inRHS; z *= inRHS; w *= inRHS; return *this; }
+	constexpr int4 &	operator /= (int inRHS)								{ x /= inRHS; y /= inRHS; z /= inRHS; w /= inRHS; return *this; }
+	constexpr int4 &	operator *= (const int4 &inRHS)						{ x *= inRHS.x; y *= inRHS.y; z *= inRHS.z; w *= inRHS.w; return *this; }
+	constexpr int4 &	operator /= (const int4 &inRHS)						{ x /= inRHS.x; y /= inRHS.y; z /= inRHS.z; w /= inRHS.w; return *this; }
+
+	// Equality: two values are equal when all four components match
+	constexpr bool		operator == (const int4 &inRHS) const				{ return x == inRHS.x && y == inRHS.y && z == inRHS.z && w == inRHS.w; }
+	constexpr bool		operator != (const int4 &inRHS) const				{ return !(*this == inRHS); }
+
+	// Component access by index (0 = x, 1 = y, 2 = z, 3 = w); indexes from the address of x and performs no range check on inIndex
+	const int &			operator [] (uint inIndex) const					{ return (&x)[inIndex]; }
+	int &				operator [] (uint inIndex)							{ return (&x)[inIndex]; }
+
+	// Swizzling: each function returns a new int3 / int4 with the components in the named order (note return value is const to prevent assignment to swizzled results)
+	const int3			xyz() const											{ return int3(x, y, z); }
+	const int3			xzy() const											{ return int3(x, z, y); }
+	const int3			yxz() const											{ return int3(y, x, z); }
+	const int3			yzx() const											{ return int3(y, z, x); }
+	const int3			zxy() const											{ return int3(z, x, y); }
+	const int3			zyx() const											{ return int3(z, y, x); }
+	const int4			xywz() const										{ return int4(x, y, w, z); }
+	const int4			xwyz() const										{ return int4(x, w, y, z); }
+	const int4			wxyz() const										{ return int4(w, x, y, z); }
+
+	int					x, y, z, w;
+};
+
+// Unary negation and component-wise binary operators for int4 (the scalar overloads apply inS to every component; division is integer division and divisors are not checked for zero)
+constexpr int4			operator - (const int4 &inA)						{ return int4(-inA.x, -inA.y, -inA.z, -inA.w); }
+constexpr int4			operator + (const int4 &inA, const int4 &inB)		{ return int4(inA.x + inB.x, inA.y + inB.y, inA.z + inB.z, inA.w + inB.w); }
+constexpr int4			operator - (const int4 &inA, const int4 &inB)		{ return int4(inA.x - inB.x, inA.y - inB.y, inA.z - inB.z, inA.w - inB.w); }
+constexpr int4			operator * (const int4 &inA, const int4 &inB)		{ return int4(inA.x * inB.x, inA.y * inB.y, inA.z * inB.z, inA.w * inB.w); }
+constexpr int4			operator / (const int4 &inA, const int4 &inB)		{ return int4(inA.x / inB.x, inA.y / inB.y, inA.z / inB.z, inA.w / inB.w); }
+constexpr int4			operator * (const int4 &inA, int inS)				{ return int4(inA.x * inS, inA.y * inS, inA.z * inS, inA.w * inS); }
+constexpr int4			operator * (int inS, const int4 &inA)				{ return inA * inS; }
+constexpr int4			operator / (const int4 &inA, int inS)				{ return int4(inA.x / inS, inA.y / inS, inA.z / inS, inA.w / inS); }
+
+// Dot product: sum of the four component-wise products
+constexpr int			dot(const int4 &inA, const int4 &inB)				{ return inA.x * inB.x + inA.y * inB.y + inA.z * inB.z + inA.w * inB.w; }
+
+// Component-wise minimum of two int4 values
+constexpr int4			min(const int4 &inA, const int4 &inB)				{ return int4(min(inA.x, inB.x), min(inA.y, inB.y), min(inA.z, inB.z), min(inA.w, inB.w)); }
+
+// Component-wise maximum of two int4 values
+constexpr int4			max(const int4 &inA, const int4 &inB)				{ return int4(max(inA.x, inB.x), max(inA.y, inB.y), max(inA.z, inB.z), max(inA.w, inB.w)); }
+
+//////////////////////////////////////////////////////////////////////////////////////////
+// Mat44
+//////////////////////////////////////////////////////////////////////////////////////////
+
+struct Mat44
+{
+	// Constructors (the four float4 arguments are stored, in order, as c[0] .. c[3])
+	inline				Mat44() = default;
+	constexpr 			Mat44(const float4 &inC0, const float4 &inC1, const float4 &inC2, const float4 &inC3) : c { inC0, inC1, inC2, inC3 } { }
+
+	// Column access: inIndex is passed straight to c[] without a range check
+	float4 &			operator [] (uint inIndex)							{ return c[inIndex]; }
+	const float4 &		operator [] (uint inIndex) const					{ return c[inIndex]; }
+
+private:
+	float4				c[4]; // The four columns of the matrix
+};
+
+//////////////////////////////////////////////////////////////////////////////////////////
+// Other types
+//////////////////////////////////////////////////////////////////////////////////////////
+
+using Quat = float4; // Quaternions are stored in a plain float4
+using Plane = float4; // Planes are stored in a plain float4 (ShaderCore.h documents JPH_Plane as xyz = normal, w = constant)
+
+// Clamp inValue to [inMinValue, inMaxValue]: the lower bound is applied first and the upper bound last, so the upper bound wins if the bounds are inverted
+template <class T>
+constexpr T				clamp(const T &inValue, const T &inMinValue, const T &inMaxValue)
+{
+	return min(max(inValue, inMinValue), inMaxValue);
+}
+
+// Atomically add inValue to ioT and return the value of ioT after the addition
+template <class T>
+T						JPH_AtomicAdd(T &ioT, const T &inValue)
+{
+	std::atomic<T> *value = reinterpret_cast<std::atomic<T> *>(&ioT); // Accesses the plain T in place as a std::atomic<T> - NOTE(review): assumes std::atomic<T> has the same size and layout as T, confirm for every T this is used with
+	return value->fetch_add(inValue) + inValue; // fetch_add returns the previous value, adding inValue yields the new value
+}
+
+// Bitcast float4 to int4: every component is passed through BitCast<int> individually instead of being numerically converted
+inline int4				asint(const float4 &inV)							{ return int4(BitCast<int>(inV.x), BitCast<int>(inV.y), BitCast<int>(inV.z), BitCast<int>(inV.w)); }
+
+// Converting constructors that couldn't be defined earlier because they read the members of uint3 / int4; each component is converted with float()
+constexpr				float3::float3(const uint3 &inV)					: x(float(inV.x)), y(float(inV.y)), z(float(inV.z)) { }
+constexpr				float4::float4(const int4 &inV)						: x(float(inV.x)), y(float(inV.y)), z(float(inV.z)), w(float(inV.w)) { }
+
+// Swizzle operators: object-like macros that turn HLSL member swizzles (e.g. v.xyz) into calls to the swizzle functions (v.xyz()). They replace every later occurrence of these identifiers, so code compiled after this point cannot use them as variable names
+#define xy				xy()
+#define yx				yx()
+#define xyz				xyz()
+#define xzy				xzy()
+#define yxz				yxz()
+#define yzx				yzx()
+#define zxy				zxy()
+#define zyx				zyx()
+#define xywz			xywz()
+#define xwyz			xwyz()
+#define wxyz			wxyz()
+
+} // HLSLToCPP
+
+JPH_NAMESPACE_END

+ 8 - 0
Jolt/Shaders/ShaderCore.h

@@ -21,6 +21,8 @@
 	using JPH_Plane = JPH::Float4;
 	using JPH_Mat44 = JPH::Float4[4]; // matrix, column major
 
+	#define JPH_SHADER_CONSTANT(type, name, value)	constexpr type name = value;
+
 	#define JPH_SHADER_CONSTANTS_BEGIN(type, name)	struct type {
 	#define JPH_SHADER_CONSTANTS_MEMBER(type, name)	type c##name;
 	#define JPH_SHADER_CONSTANTS_END				};
@@ -47,6 +49,8 @@
 	typedef float4 JPH_Plane; // xyz = normal, w = constant
 	typedef float4x4 JPH_Mat44; // matrix, column major
 
+	#define JPH_SHADER_CONSTANT(type, name, value)	static const type name = value;
+
 	#define JPH_SHADER_CONSTANTS_BEGIN(type, name)	cbuffer name {
 	#define JPH_SHADER_CONSTANTS_MEMBER(type, name)	type c##name;
 	#define JPH_SHADER_CONSTANTS_END				};
@@ -72,4 +76,8 @@
 #define JPH_SHADER_STRUCT_MEMBER(type, name)		type m##name;
 #define JPH_SHADER_STRUCT_END						};
 
+#define JPH_IN(type)								in type
+#define JPH_OUT(type)								out type
+#define JPH_IN_OUT(type)							in out type
+
 #endif // JPH_OVERRIDE_SHADER_MACROS

+ 1 - 1
Jolt/Shaders/ShaderMath.h

@@ -9,7 +9,7 @@ inline float JPH_Square(float inValue)
 }
 
 // Get the closest point on a line segment defined by inA + x * inAB for x e [0, 1] to the origin
-inline float3 JPH_GetClosestPointOnLine(in float3 inA, in float3 inAB)
+inline float3 JPH_GetClosestPointOnLine(float3 inA, float3 inAB)
 {
 	float v = clamp(-dot(inA, inAB) / dot(inAB, inAB), 0.0f, 1.0f);
 	return inA + v * inAB;

+ 5 - 5
Jolt/Shaders/ShaderQuat.h

@@ -4,10 +4,10 @@
 
 inline float3 JPH_QuatMulVec3(JPH_Quat inLHS, float3 inRHS)
 {
-	float3 xyz = inLHS.xyz;
-	float3 yzx = inLHS.yzx;
-	float3 q_cross_p = (inRHS.yzx * xyz - yzx * inRHS).yzx;
-	float3 q_cross_q_cross_p = (q_cross_p.yzx * xyz - yzx * q_cross_p).yzx;
+	float3 v_xyz = inLHS.xyz;
+	float3 v_yzx = inLHS.yzx;
+	float3 q_cross_p = (inRHS.yzx * v_xyz - v_yzx * inRHS).yzx;
+	float3 q_cross_q_cross_p = (q_cross_p.yzx * v_xyz - v_yzx * q_cross_p).yzx;
 	float3 v = inLHS.w * q_cross_p + q_cross_q_cross_p;
 	return inRHS + (v + v);
 }
@@ -49,7 +49,7 @@ inline JPH_Quat JPH_QuatDecompress(uint inValue)
 	const float cScale = 2.0f * cOneOverSqrt2 / float(cMaxValue);
 
 	// Restore two components
-	float3 v3 = float3(inValue & cMask, (inValue >> cNumBits) & cMask, (inValue >> (2 * cNumBits)) & cMask) * cScale - float3(cOneOverSqrt2, cOneOverSqrt2, cOneOverSqrt2);
+	float3 v3 = float3(float(inValue & cMask), float((inValue >> cNumBits) & cMask), float((inValue >> (2 * cNumBits)) & cMask)) * cScale - float3(cOneOverSqrt2, cOneOverSqrt2, cOneOverSqrt2);
 
 	// Restore the highest component
 	float4 v = float4(v3, sqrt(max(1.0f - dot(v3, v3), 0.0f)));

+ 1 - 1
Jolt/Shaders/ShaderVec3.h

@@ -11,7 +11,7 @@ inline float3 JPH_Vec3DecompressUnit(uint inValue)
 	const float cScale = 2.0f * cOneOverSqrt2 / float(cMaxValue);
 
 	// Restore two components
-	float2 v2 = float2(inValue & cMask, (inValue >> cNumBits) & cMask) * cScale - float2(cOneOverSqrt2, cOneOverSqrt2);
+	float2 v2 = float2(float(inValue & cMask), float((inValue >> cNumBits) & cMask)) * cScale - float2(cOneOverSqrt2, cOneOverSqrt2);
 
 	// Restore the highest component
 	float3 v = float3(v2, sqrt(max(1.0f - dot(v2, v2), 0.0f)));

+ 156 - 0
Jolt/Shaders/ShaderWrapperCreator.h

@@ -0,0 +1,156 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Core/HashCombine.h>
+#include <Jolt/Compute/CPU/ComputeSystemCPU.h>
+#include <Jolt/Compute/CPU/ShaderWrapper.h>
+#include <Jolt/Shaders/HLSLToCPP.h>
+
+JPH_NAMESPACE_BEGIN
+
+#define JPH_SHADER_OVERRIDE_MACROS // NOTE(review): presumably makes ShaderCore.h skip its own definitions of the macros below - confirm the guard name used there
+
+using namespace HLSLToCPP;
+
+#define JPH_SHADER_CONSTANT(type, name, value)	inline static constexpr type name = value; // Shader constants become static constexpr members of the wrapper class
+
+#define JPH_SHADER_CONSTANTS_BEGIN(type, name)	struct type { alignas(16) int dummy; } name; // Marker member placed directly before the constants; alignas(16) ensures that the first constant is 16 byte aligned
+#define JPH_SHADER_CONSTANTS_MEMBER(type, name)	type c##name; // Each constant becomes a member named c<name>
+#define JPH_SHADER_CONSTANTS_END
+
+#define JPH_SHADER_BUFFER(type)					const type * // Read only buffers are pointers to const
+#define JPH_SHADER_RW_BUFFER(type)				type * // Read/write buffers are plain pointers
+
+#define JPH_SHADER_BIND_BEGIN(name)
+#define JPH_SHADER_BIND_END
+#define JPH_SHADER_BIND_BUFFER(type, name)		const type *name = nullptr; // Bound buffers are pointer members that stay null until Bind assigns them
+#define JPH_SHADER_BIND_RW_BUFFER(type, name)	type *name = nullptr;
+
+#define JPH_SHADER_FUNCTION_BEGIN(return_type, name, group_size_x, group_size_y, group_size_z) \
+		virtual void Main(
+#define JPH_SHADER_PARAM_THREAD_ID(name)		const HLSLToCPP::uint3 &name // The thread id is the parameter of Main
+#define JPH_SHADER_FUNCTION_END					) override // Closes the parameter list opened by JPH_SHADER_FUNCTION_BEGIN; return type, name and group sizes are ignored
+
+#define JPH_SHADER_STRUCT_BEGIN(name)			struct name {
+#define JPH_SHADER_STRUCT_MEMBER(type, name)	type m##name;
+#define JPH_SHADER_STRUCT_END					};
+
+#define JPH_TO_STRING(name)						JPH_TO_STRING2(name) // Extra level of indirection so that name is macro expanded before it is stringized
+#define JPH_TO_STRING2(name)					#name
+
+#define JPH_SHADER_CLASS_NAME(name)				JPH_SHADER_CLASS_NAME2(name) // Extra level of indirection so that name is macro expanded before it is pasted
+#define JPH_SHADER_CLASS_NAME2(name)			name##ShaderWrapper
+
+#define JPH_SHADER_HEADER_NAME(name)			JPH_TO_STRING(name.hlsl) // Include name of the shader source: "<name>.hlsl"
+
+#define JPH_BINDINGS_HEADER_NAME(name)			JPH_BINDINGS_HEADER_NAME2(name)
+#define JPH_BINDINGS_HEADER_NAME2(name)			JPH_TO_STRING(name##Bindings.h) // Include name of the bindings header: "<name>Bindings.h"
+
+#define JPH_IN(type)							const type & // Input parameters are passed by const reference
+#define JPH_OUT(type)							type & // Output parameters are passed by reference
+#define JPH_IN_OUT(type)						type & // Input/output parameters are passed by reference
+
+/// @cond INTERNAL
+class JPH_SHADER_CLASS_NAME(JPH_SHADER_NAME) : public ShaderWrapper
+{
+public:
+	// Map the JPH_ shader type names onto C++ types (the vector, quaternion, plane and matrix types come from HLSLToCPP)
+	using JPH_float = float;
+	using JPH_float3 = HLSLToCPP::float3;
+	using JPH_float4 = HLSLToCPP::float4;
+	using JPH_uint = uint;
+	using JPH_uint3 = HLSLToCPP::uint3;
+	using JPH_uint4 = HLSLToCPP::uint4;
+	using JPH_int = int;
+	using JPH_int3 = HLSLToCPP::int3;
+	using JPH_int4 = HLSLToCPP::int4;
+	using JPH_Quat = HLSLToCPP::Quat;
+	using JPH_Plane = HLSLToCPP::Plane;
+	using JPH_Mat44 = HLSLToCPP::Mat44;
+
+	// Include the actual shader inside the class body: with the macros defined above its entry point becomes the Main override and its constants and bindings become members
+	#include JPH_SHADER_HEADER_NAME(JPH_SHADER_NAME)
+
+	/// Bind a buffer to the shader: inName is matched by hash against the names declared in the bindings header, constant buffers copy inSize bytes from inData, other buffers only store the inData pointer
+	virtual void			Bind(const char *inName, void *inData, uint64 inSize) override
+	{
+		// Don't redefine constants
+		#undef JPH_SHADER_CONSTANT
+		#define JPH_SHADER_CONSTANT(type, name, value)
+
+		// Don't redefine structs
+		#undef JPH_SHADER_STRUCT_BEGIN
+		#undef JPH_SHADER_STRUCT_MEMBER
+		#undef JPH_SHADER_STRUCT_END
+		#define JPH_SHADER_STRUCT_BEGIN(name)
+		#define JPH_SHADER_STRUCT_MEMBER(type, name)
+		#define JPH_SHADER_STRUCT_END
+
+		// When a constant buffer is bound, copy the data into the members (inSize is not checked against the size of the members here)
+		#undef JPH_SHADER_CONSTANTS_BEGIN
+		#undef JPH_SHADER_CONSTANTS_MEMBER
+		#define JPH_SHADER_CONSTANTS_BEGIN(type, name)	case HashString(#name): memcpy(&name + 1, inData, size_t(inSize));	break; // Very hacky way to get the address of the first constant and to copy the entire block of constants
+		#define JPH_SHADER_CONSTANTS_MEMBER(type, name)
+
+		// When a buffer is bound, set the pointer (the data itself is not copied)
+		#undef JPH_SHADER_BIND_BUFFER
+		#undef JPH_SHADER_BIND_RW_BUFFER
+		#define JPH_SHADER_BIND_BUFFER(type, name)		case HashString(#name): name = (const type *)inData;		break;
+		#define JPH_SHADER_BIND_RW_BUFFER(type, name)	case HashString(#name): name = (type *)inData;				break;
+
+		switch (HashString(inName))
+		{
+		// Include the bindings header only, with the macros above every declaration in it expands to a case label (or to nothing)
+		#include JPH_BINDINGS_HEADER_NAME(JPH_SHADER_NAME)
+
+		default:
+			JPH_ASSERT(false, "Buffer cannot be bound to this shader");
+			break;
+		}
+	}
+
+	/// Factory function to create a shader wrapper for this shader, returns a new heap allocated instance
+	static ShaderWrapper *	sCreate()
+	{
+		return new JPH_SHADER_CLASS_NAME(JPH_SHADER_NAME)();
+	}
+};
+/// @endcond
+
+// Declare the register function first (without a parameter name) to stop clang from complaining that the register function is missing a prototype
+JPH_SHADER_WRAPPER_FUNCTION(, JPH_SHADER_NAME);
+
+/// Register this wrapper: passes the shader name as a string together with the sCreate factory to inComputeSystem
+JPH_SHADER_WRAPPER_FUNCTION(inComputeSystem, JPH_SHADER_NAME)
+{
+	inComputeSystem->RegisterShader(JPH_TO_STRING(JPH_SHADER_NAME), JPH_SHADER_CLASS_NAME(JPH_SHADER_NAME)::sCreate);
+}
+
+// Clean up the helper macros defined by this header so that they don't leak into code that follows the include
+// NOTE(review): JPH_SHADER_OVERRIDE_MACROS is not undefined here - confirm whether that is intended
+#undef JPH_SHADER_CONSTANT
+#undef JPH_SHADER_CONSTANTS_BEGIN
+#undef JPH_SHADER_CONSTANTS_MEMBER
+#undef JPH_SHADER_CONSTANTS_END
+#undef JPH_SHADER_BUFFER
+#undef JPH_SHADER_RW_BUFFER
+#undef JPH_SHADER_BIND_BEGIN
+#undef JPH_SHADER_BIND_END
+#undef JPH_SHADER_BIND_BUFFER
+#undef JPH_SHADER_BIND_RW_BUFFER
+#undef JPH_SHADER_FUNCTION_BEGIN
+#undef JPH_SHADER_PARAM_THREAD_ID
+#undef JPH_SHADER_FUNCTION_END
+#undef JPH_SHADER_STRUCT_BEGIN
+#undef JPH_SHADER_STRUCT_MEMBER
+#undef JPH_SHADER_STRUCT_END
+#undef JPH_TO_STRING
+#undef JPH_TO_STRING2
+#undef JPH_SHADER_CLASS_NAME
+#undef JPH_SHADER_CLASS_NAME2
+#undef JPH_SHADER_HEADER_NAME
+#undef JPH_BINDINGS_HEADER_NAME
+#undef JPH_BINDINGS_HEADER_NAME2
+#undef JPH_IN // Was missing: JPH_IN is defined above next to JPH_OUT / JPH_IN_OUT and must not leak either
+#undef JPH_OUT
+#undef JPH_IN_OUT
+
+JPH_NAMESPACE_END

+ 0 - 19
Jolt/Shaders/TestCompute.h

@@ -1,19 +0,0 @@
-// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
-// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
-// SPDX-License-Identifier: MIT
-
-#ifdef __cplusplus
-	#pragma once
-#endif
-
-#include "ShaderCore.h"
-
-static const int cTestComputeGroupSize = 64;
-
-JPH_SHADER_CONSTANTS_BEGIN(TestComputeContext, gContext)
-	JPH_SHADER_CONSTANTS_MEMBER(JPH_float3,		Float3Value)
-	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint,		UIntValue)		// Test that this value packs correctly with the float3 preceding it
-	JPH_SHADER_CONSTANTS_MEMBER(JPH_float3,		Float3Value2)
-	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint,		UIntValue2)
-	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint,		NumElements)
-JPH_SHADER_CONSTANTS_END

+ 1 - 1
Jolt/Shaders/TestCompute.hlsl

@@ -2,7 +2,7 @@
 // SPDX-FileCopyrightText: 2025 Jorrit Rouwe
 // SPDX-License-Identifier: MIT
 
-#include "TestCompute.h"
+#include "ShaderCore.h"
 #include "TestComputeBindings.h"
 
 JPH_SHADER_FUNCTION_BEGIN(void, main, cTestComputeGroupSize, 1, 1)

+ 10 - 0
Jolt/Shaders/TestComputeBindings.h

@@ -2,6 +2,16 @@
 // SPDX-FileCopyrightText: 2025 Jorrit Rouwe
 // SPDX-License-Identifier: MIT
 
+JPH_SHADER_CONSTANT(int, cTestComputeGroupSize, 64) // Group size in x that TestCompute.hlsl passes to JPH_SHADER_FUNCTION_BEGIN
+
+JPH_SHADER_CONSTANTS_BEGIN(TestComputeContext, gContext)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_float3,		Float3Value)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint,		UIntValue)		// Test that this value packs correctly with the float3 preceding it
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_float3,		Float3Value2)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint,		UIntValue2)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint,		NumElements)
+JPH_SHADER_CONSTANTS_END
+
 JPH_SHADER_BIND_BEGIN(JPH_TestCompute)
 	JPH_SHADER_BIND_BUFFER(JPH_uint, gUploadData)
 	JPH_SHADER_BIND_BUFFER(JPH_uint, gOptionalData)

+ 9 - 0
Jolt/Shaders/TestComputeWrapper.cpp

@@ -0,0 +1,9 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2026 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#define JPH_SHADER_NAME TestCompute
+#include "ShaderWrapperCreator.h" // Generates the TestComputeShaderWrapper class and its register function from TestCompute.hlsl and TestComputeBindings.h
+#undef JPH_SHADER_NAME

+ 14 - 3
Samples/SamplesApp.cpp

@@ -12,6 +12,7 @@
 #include <Jolt/Core/StreamWrapper.h>
 #include <Jolt/Core/StringTools.h>
 #include <Jolt/Geometry/OrientedBox.h>
+#include <Jolt/Compute/CPU/ComputeSystemCPU.h>
 #include <Jolt/Physics/PhysicsSystem.h>
 #include <Jolt/Physics/StateRecorderImpl.h>
 #include <Jolt/Physics/Body/BodyCreationSettings.h>
@@ -472,7 +473,8 @@ SamplesApp::SamplesApp(const String &inCommandLine) :
 	mJobSystemValidating = new JobSystemSingleThreaded(cMaxPhysicsJobs);
 
 	// Set shader loader
-	mRenderer->GetComputeSystem().mShaderLoader = [](const char *inName, Array<uint8> &outData, String &outError) {
+	mComputeSystem = &mRenderer->GetComputeSystem();
+	mComputeSystem->mShaderLoader = [](const char *inName, Array<uint8> &outData, String &outError) {
 	#ifdef JPH_PLATFORM_MACOS
 		// In macOS the shaders are copied to the bundle
 		String base_path = "Jolt/Shaders/";
@@ -485,11 +487,15 @@ SamplesApp::SamplesApp(const String &inCommandLine) :
 	};
 
 	// Create compute queue
-	ComputeQueueResult queue_result = mRenderer->GetComputeSystem().CreateComputeQueue();
+	ComputeQueueResult queue_result = mComputeSystem->CreateComputeQueue();
 	if (queue_result.HasError())
 		FatalError(queue_result.GetError().c_str());
 	mComputeQueue = queue_result.Get();
 
+	// Create compute system CPU
+	mComputeSystemCPU = StaticCast<ComputeSystemCPU>(CreateComputeSystemCPU().Get());
+	mComputeQueueCPU = mComputeSystemCPU->CreateComputeQueue().Get();
+
 	{
 		// Disable allocation checking
 		DisableCustomMemoryHook dcmh;
@@ -547,6 +553,7 @@ SamplesApp::SamplesApp(const String &inCommandLine) :
 			mDebugUI->CreateCheckBox(phys_settings, "Record State For Playback", mRecordState, [this](UICheckBox::EState inState) { mRecordState = inState == UICheckBox::STATE_CHECKED; });
 			mDebugUI->CreateCheckBox(phys_settings, "Check Determinism", mCheckDeterminism, [this](UICheckBox::EState inState) { mCheckDeterminism = inState == UICheckBox::STATE_CHECKED; });
 			mDebugUI->CreateCheckBox(phys_settings, "Install Contact Listener", mInstallContactListener, [this](UICheckBox::EState inState) { mInstallContactListener = inState == UICheckBox::STATE_CHECKED; StartTest(mTestClass); });
+			mDebugUI->CreateCheckBox(phys_settings, "Use GPU Compute System", mUseGPUCompute, [this](UICheckBox::EState inState) { mUseGPUCompute = inState == UICheckBox::STATE_CHECKED; StartTest(mTestClass); });
 			mDebugUI->ShowMenu(phys_settings);
 		});
 	#ifdef JPH_DEBUG_RENDERER
@@ -716,6 +723,7 @@ SamplesApp::~SamplesApp()
 	delete mContactListener;
 	delete mPhysicsSystem;
 	mComputeQueue = nullptr;
+	mComputeSystem = nullptr;
 	delete mJobSystemValidating;
 	delete mJobSystem;
 	delete mTempAllocator;
@@ -763,7 +771,10 @@ void SamplesApp::StartTest(const RTTI *inRTTI)
 	mTest = static_cast<Test *>(inRTTI->CreateObject());
 	mTest->SetPhysicsSystem(mPhysicsSystem);
 	mTest->SetJobSystem(mJobSystem);
-	mTest->SetComputeSystem(&mRenderer->GetComputeSystem(), mComputeQueue);
+	if (mUseGPUCompute)
+		mTest->SetComputeSystem(mComputeSystem, mComputeQueue);
+	else
+		mTest->SetComputeSystem(mComputeSystemCPU, mComputeQueueCPU);
 	mTest->SetDebugRenderer(mDebugRenderer);
 	mTest->SetTempAllocator(mTempAllocator);
 	if (mInstallContactListener)

+ 5 - 0
Samples/SamplesApp.h

@@ -18,6 +18,7 @@
 namespace JPH {
 	class JobSystem;
 	class TempAllocator;
+	class ComputeSystemCPU;
 };
 
 // Application class that runs the samples
@@ -92,7 +93,10 @@ private:
 	TempAllocator *			mTempAllocator = nullptr;									// Allocator for temporary allocations
 	JobSystem *				mJobSystem = nullptr;										// The job system that runs physics jobs
 	JobSystem *				mJobSystemValidating = nullptr;								// The job system to use when validating determinism
+	Ref<ComputeSystem>		mComputeSystem = nullptr;									// The compute system to use for compute jobs
 	Ref<ComputeQueue>		mComputeQueue = nullptr;									// The compute queue to use for compute jobs
+	Ref<ComputeSystemCPU>	mComputeSystemCPU = nullptr;								// The compute system to use for CPU compute jobs
+	Ref<ComputeQueue>		mComputeQueueCPU = nullptr;									// The compute queue to use for CPU compute jobs
 	BPLayerInterfaceImpl	mBroadPhaseLayerInterface;									// The broadphase layer interface that maps object layers to broadphase layers
 	ObjectVsBroadPhaseLayerFilterImpl mObjectVsBroadPhaseLayerFilter;					// Class that filters object vs broadphase layers
 	ObjectLayerPairFilterImpl mObjectVsObjectLayerFilter;								// Class that filters object vs object layers
@@ -128,6 +132,7 @@ private:
 
 	// Test settings
 	bool					mInstallContactListener = false;							// When true, the contact listener is installed the next time the test is reset
+	bool					mUseGPUCompute = true;										// When true, uses the GPU compute system for compute jobs
 
 	// State recording and determinism checks
 	bool					mRecordState = false;										// When true, the state of the physics system is recorded in mPlaybackFrames every physics update

+ 20 - 6
UnitTests/Compute/ComputeTests.cpp

@@ -4,11 +4,12 @@
 
 #include "UnitTestFramework.h"
 
-#if defined(JPH_USE_DX12) || defined(JPH_USE_MTL) || defined(JPH_USE_VK)
-
 #include <Jolt/Compute/ComputeSystem.h>
-#include <Jolt/Shaders/TestCompute.h>
+#include <Jolt/Compute/CPU/ComputeSystemCPU.h>
+#include <Jolt/Shaders/ShaderCore.h>
+#include <Jolt/Shaders/TestComputeBindings.h>
 #include <Jolt/Core/IncludeWindows.h>
+#include <Jolt/Core/RTTI.h>
 
 JPH_SUPPRESS_WARNINGS_STD_BEGIN
 #include <fstream>
@@ -22,6 +23,8 @@ JPH_SUPPRESS_WARNINGS_STD_END
 #include <CoreFoundation/CoreFoundation.h>
 #endif
 
+JPH_DECLARE_REGISTER_SHADER(TestCompute)
+
 TEST_SUITE("ComputeTests")
 {
 	static const char *cInvalidShaderName = "InvalidShader";
@@ -60,7 +63,8 @@ TEST_SUITE("ComputeTests")
 				if (count > 0)
 					application_path[count] = 0;
 			#else
-				#error Unsupported platform
+				// Not implemented
+				const char *application_path = "";
 			#endif
 			String base_path;
 			filesystem::path shader_path(application_path);
@@ -269,6 +273,16 @@ TEST_SUITE("ComputeTests")
 		}
 	}
 #endif // JPH_USE_VK
-}
 
-#endif // defined(JPH_USE_DX12) || defined(JPH_USE_MTL) || defined(JPH_USE_VK)
+	TEST_CASE("TestComputeCPU") // Runs the compute tests on the CPU compute system
+	{
+		ComputeSystemResult compute_system = CreateComputeSystemCPU();
+		CHECK(!compute_system.HasError());
+		if (!compute_system.HasError()) // Only continue when creation succeeded, the CHECK above has already reported a failure
+		{
+			CHECK(compute_system.Get() != nullptr);
+			JPH_REGISTER_SHADER(StaticCast<ComputeSystemCPU>(compute_system.Get()), TestCompute); // Register the C++ wrapper of the TestCompute shader with the CPU compute system
+			RunTests(compute_system.Get()); // Run the RunTests helper against the CPU compute system
+		}
+	}
+}