Sfoglia il codice sorgente

Added interface to run compute shaders on the GPU with implementations for DX12, Vulkan and Metal. (#1847)

Currently it is only used by a unit test, further users of this system will follow later.
Jorrit Rouwe 1 mese fa
parent
commit
5ac132df68
96 ha cambiato i file con 4930 aggiunte e 1191 eliminazioni
  1. 49 14
      .github/workflows/build.yml
  2. 24 6
      .github/workflows/determinism_check.yml
  3. 2 2
      .gitignore
  4. 45 19
      Build/CMakeLists.txt
  5. 6 1
      Build/README.md
  6. 1 1
      Build/cmake_linux_mingw.sh
  7. 1 1
      Build/cmake_vs2022_cl_32bit.bat
  8. 1 1
      Build/cmake_windows_mingw.sh
  9. 1 1
      Build/iOS/UnitTestsInfo.plist
  10. 1 1
      Build/ubuntu24_install_vulkan_sdk.sh
  11. 61 0
      Jolt/Compute/ComputeBuffer.h
  12. 67 0
      Jolt/Compute/ComputeQueue.h
  13. 38 0
      Jolt/Compute/ComputeShader.h
  14. 69 0
      Jolt/Compute/ComputeSystem.h
  15. 149 0
      Jolt/Compute/DX12/ComputeBufferDX12.cpp
  16. 50 0
      Jolt/Compute/DX12/ComputeBufferDX12.h
  17. 221 0
      Jolt/Compute/DX12/ComputeQueueDX12.cpp
  18. 61 0
      Jolt/Compute/DX12/ComputeQueueDX12.h
  19. 54 0
      Jolt/Compute/DX12/ComputeShaderDX12.h
  20. 412 0
      Jolt/Compute/DX12/ComputeSystemDX12.cpp
  21. 52 0
      Jolt/Compute/DX12/ComputeSystemDX12.h
  22. 147 0
      Jolt/Compute/DX12/ComputeSystemDX12Impl.cpp
  23. 33 0
      Jolt/Compute/DX12/ComputeSystemDX12Impl.h
  24. 40 0
      Jolt/Compute/DX12/IncludeDX12.h
  25. 37 0
      Jolt/Compute/MTL/ComputeBufferMTL.h
  26. 44 0
      Jolt/Compute/MTL/ComputeBufferMTL.mm
  27. 49 0
      Jolt/Compute/MTL/ComputeQueueMTL.h
  28. 123 0
      Jolt/Compute/MTL/ComputeQueueMTL.mm
  29. 39 0
      Jolt/Compute/MTL/ComputeShaderMTL.h
  30. 34 0
      Jolt/Compute/MTL/ComputeShaderMTL.mm
  31. 40 0
      Jolt/Compute/MTL/ComputeSystemMTL.h
  32. 89 0
      Jolt/Compute/MTL/ComputeSystemMTL.mm
  33. 28 0
      Jolt/Compute/MTL/ComputeSystemMTLImpl.h
  34. 39 0
      Jolt/Compute/MTL/ComputeSystemMTLImpl.mm
  35. 41 0
      Jolt/Compute/VK/BufferVK.h
  36. 130 0
      Jolt/Compute/VK/ComputeBufferVK.cpp
  37. 51 0
      Jolt/Compute/VK/ComputeBufferVK.h
  38. 304 0
      Jolt/Compute/VK/ComputeQueueVK.cpp
  39. 66 0
      Jolt/Compute/VK/ComputeQueueVK.h
  40. 232 0
      Jolt/Compute/VK/ComputeShaderVK.cpp
  41. 53 0
      Jolt/Compute/VK/ComputeShaderVK.h
  42. 90 0
      Jolt/Compute/VK/ComputeSystemVK.cpp
  43. 57 0
      Jolt/Compute/VK/ComputeSystemVK.h
  44. 308 0
      Jolt/Compute/VK/ComputeSystemVKImpl.cpp
  45. 57 0
      Jolt/Compute/VK/ComputeSystemVKImpl.h
  46. 168 0
      Jolt/Compute/VK/ComputeSystemVKWithAllocator.cpp
  47. 70 0
      Jolt/Compute/VK/ComputeSystemVKWithAllocator.h
  48. 30 0
      Jolt/Compute/VK/IncludeVK.h
  49. 10 0
      Jolt/Core/Core.h
  50. 36 0
      Jolt/Core/IncludeWindows.h
  51. 1 14
      Jolt/Core/JobSystemThreadPool.cpp
  52. 1 15
      Jolt/Core/Semaphore.cpp
  53. 1 15
      Jolt/Core/TickCounter.cpp
  54. 193 3
      Jolt/Jolt.cmake
  55. 0 0
      Jolt/Physics/Collision/Shape/TaperedCapsuleShape.gliffy
  56. 75 0
      Jolt/Shaders/ShaderCore.h
  57. 13 0
      Jolt/Shaders/ShaderMat44.h
  58. 16 0
      Jolt/Shaders/ShaderMath.h
  59. 18 0
      Jolt/Shaders/ShaderPlane.h
  60. 114 0
      Jolt/Shaders/ShaderQuat.h
  61. 28 0
      Jolt/Shaders/ShaderVec3.h
  62. 19 0
      Jolt/Shaders/TestCompute.h
  63. 26 0
      Jolt/Shaders/TestCompute.hlsl
  64. 9 0
      Jolt/Shaders/TestComputeBindings.h
  65. 1 1
      JoltViewer/JoltViewer.cmake
  66. 6 1
      JoltViewer/JoltViewer.cpp
  67. 1 1
      Samples/Samples.cmake
  68. 28 3
      Samples/SamplesApp.cpp
  69. 1 0
      Samples/SamplesApp.h
  70. 7 0
      Samples/Tests/Test.h
  71. 14 1
      TestFramework/Application/Application.cpp
  72. 0 117
      TestFramework/Renderer/DX12/CommandQueueDX12.h
  73. 0 40
      TestFramework/Renderer/DX12/ConstantBufferDX12.cpp
  74. 0 32
      TestFramework/Renderer/DX12/ConstantBufferDX12.h
  75. 2 0
      TestFramework/Renderer/DX12/DescriptorHeapDX12.h
  76. 40 142
      TestFramework/Renderer/DX12/RendererDX12.cpp
  77. 11 19
      TestFramework/Renderer/DX12/RendererDX12.h
  78. 5 4
      TestFramework/Renderer/MTL/RendererMTL.h
  79. 11 3
      TestFramework/Renderer/MTL/RendererMTL.mm
  80. 4 0
      TestFramework/Renderer/Renderer.h
  81. 0 21
      TestFramework/Renderer/VK/BufferVK.h
  82. 0 32
      TestFramework/Renderer/VK/ConstantBufferVK.cpp
  83. 0 30
      TestFramework/Renderer/VK/ConstantBufferVK.h
  84. 2 3
      TestFramework/Renderer/VK/PixelShaderVK.h
  85. 3 5
      TestFramework/Renderer/VK/RenderInstancesVK.cpp
  86. 6 10
      TestFramework/Renderer/VK/RenderPrimitiveVK.cpp
  87. 1 1
      TestFramework/Renderer/VK/RenderPrimitiveVK.h
  88. 96 422
      TestFramework/Renderer/VK/RendererVK.cpp
  89. 24 74
      TestFramework/Renderer/VK/RendererVK.h
  90. 2 4
      TestFramework/Renderer/VK/TextureVK.cpp
  91. 2 3
      TestFramework/Renderer/VK/TextureVK.h
  92. 2 3
      TestFramework/Renderer/VK/VertexShaderVK.h
  93. 108 104
      TestFramework/TestFramework.cmake
  94. 4 21
      TestFramework/TestFramework.h
  95. 218 0
      UnitTests/Compute/ComputeTests.cpp
  96. 7 0
      UnitTests/UnitTests.cmake

+ 49 - 14
.github/workflows/build.yml

@@ -42,7 +42,7 @@ jobs:
       run: ctest --output-on-failure --verbose
 
   linux_clang_tsan:
-    runs-on: ubuntu-24.04
+    runs-on: ubuntu-latest
     name: Linux Clang Sanitizers
     strategy:
         fail-fast: false
@@ -52,6 +52,8 @@ jobs:
     steps:
     - name: Checkout Code
       uses: actions/checkout@v6
+    - name: Install Vulkan
+      run: ${{github.workspace}}/Build/ubuntu24_install_vulkan_sdk.sh
     - name: Configure CMake
       working-directory: ${{github.workspace}}/Build
       run: ./cmake_linux_clang_gcc.sh ${{matrix.build_type}} ${{env.UBUNTU_CLANG_VERSION}} -DTARGET_VIEWER=OFF -DTARGET_SAMPLES=OFF -DTARGET_HELLO_WORLD=OFF -DTARGET_UNIT_TESTS=ON -DTARGET_PERFORMANCE_TEST=ON
@@ -68,7 +70,7 @@ jobs:
       run: ./PerformanceTest -q=LinearCast -t=max -s=Ragdoll
 
   linux-clang-so:
-    runs-on: ubuntu-24.04
+    runs-on: ubuntu-latest
     name: Linux Clang Shared Library
     strategy:
         fail-fast: false
@@ -78,6 +80,8 @@ jobs:
     steps:
     - name: Checkout Code
       uses: actions/checkout@v6
+    - name: Install Vulkan
+      run: ${{github.workspace}}/Build/ubuntu24_install_vulkan_sdk.sh
     - name: Configure CMake
       working-directory: ${{github.workspace}}/Build
       run: ./cmake_linux_clang_gcc.sh ${{matrix.build_type}} ${{env.UBUNTU_CLANG_VERSION}} -DBUILD_SHARED_LIBS=YES
@@ -88,7 +92,7 @@ jobs:
       run: ctest --output-on-failure --verbose
 
   linux-clang-32-bit:
-    runs-on: ubuntu-24.04
+    runs-on: ubuntu-latest
     name: Linux Clang 32-bit
     strategy:
         fail-fast: false
@@ -112,7 +116,7 @@ jobs:
       run: ctest --output-on-failure --verbose
 
   linux-clang-use-std-vector:
-    runs-on: ubuntu-24.04
+    runs-on: ubuntu-latest
     name: Linux Clang using std::vector
     strategy:
         fail-fast: false
@@ -123,6 +127,8 @@ jobs:
     steps:
     - name: Checkout Code
       uses: actions/checkout@v6
+    - name: Install Vulkan
+      run: ${{github.workspace}}/Build/ubuntu24_install_vulkan_sdk.sh
     - name: Configure CMake
       working-directory: ${{github.workspace}}/Build
       run: ./cmake_linux_clang_gcc.sh ${{matrix.build_type}} ${{env.UBUNTU_CLANG_VERSION}} -DDOUBLE_PRECISION=${{matrix.double_precision}} -DUSE_STD_VECTOR=ON
@@ -133,7 +139,7 @@ jobs:
       run: ctest --output-on-failure --verbose
 
   linux-gcc:
-    runs-on: ubuntu-24.04
+    runs-on: ubuntu-latest
     name: Linux GCC
     strategy:
         fail-fast: false
@@ -155,7 +161,7 @@ jobs:
       run: ctest --output-on-failure --verbose
 
   linux-gcc-so:
-    runs-on: ubuntu-24.04
+    runs-on: ubuntu-latest
     name: Linux GCC Shared Library
     strategy:
         fail-fast: false
@@ -165,6 +171,8 @@ jobs:
     steps:
     - name: Checkout Code
       uses: actions/checkout@v6
+    - name: Install Vulkan
+      run: ${{github.workspace}}/Build/ubuntu24_install_vulkan_sdk.sh
     - name: Configure CMake
       working-directory: ${{github.workspace}}/Build
       run: ./cmake_linux_clang_gcc.sh ${{matrix.build_type}} ${{env.UBUNTU_GCC_VERSION}} -DBUILD_SHARED_LIBS=Yes
@@ -203,7 +211,7 @@ jobs:
     - name: Test
       working-directory: Build/MinGW_${{matrix.build_type}}
       run: ctest --output-on-failure --verbose
-      
+
   msvc_cl:
     runs-on: windows-latest
     name: Visual Studio CL
@@ -224,7 +232,12 @@ jobs:
       run: msbuild Build\VS2022_CL\JoltPhysics.sln /property:Configuration=${{matrix.build_type}} -m
     - name: Test
       working-directory: ${{github.workspace}}/Build/VS2022_CL/${{matrix.build_type}}
-      run: ./UnitTests.exe
+      shell: cmd
+      # We need to run vcvarsall to set the search path to include the Windows SDK, if not dxcompiler.dll will not be found and the UnitTests application will crash
+      run: |
+        for /f "delims=" %%i in ('vswhere -latest -property installationPath') do set VS_PATH=%%i
+        call "%VS_PATH%\VC\Auxiliary\Build\vcvarsall.bat" x64
+        UnitTests.exe
 
   msvc_cl_no_object_stream:
     runs-on: windows-latest
@@ -245,7 +258,12 @@ jobs:
       run: msbuild Build\VS2022_CL\JoltPhysics.sln /property:Configuration=${{matrix.build_type}} -m
     - name: Test
       working-directory: ${{github.workspace}}/Build/VS2022_CL/${{matrix.build_type}}
-      run: ./UnitTests.exe
+      shell: cmd
+      # We need to run vcvarsall to set the search path to include the Windows SDK, if not dxcompiler.dll will not be found and the UnitTests application will crash
+      run: |
+        for /f "delims=" %%i in ('vswhere -latest -property installationPath') do set VS_PATH=%%i
+        call "%VS_PATH%\VC\Auxiliary\Build\vcvarsall.bat" x64
+        UnitTests.exe
 
   msvc_cl_dll:
     runs-on: windows-latest
@@ -266,7 +284,12 @@ jobs:
       run: msbuild Build\VS2022_CL\JoltPhysics.sln /property:Configuration=${{matrix.build_type}} -m
     - name: Test
       working-directory: ${{github.workspace}}/Build/VS2022_CL/${{matrix.build_type}}
-      run: ./UnitTests.exe
+      shell: cmd
+      # We need to run vcvarsall to set the search path to include the Windows SDK, if not dxcompiler.dll will not be found and the UnitTests application will crash
+      run: |
+        for /f "delims=" %%i in ('vswhere -latest -property installationPath') do set VS_PATH=%%i
+        call "%VS_PATH%\VC\Auxiliary\Build\vcvarsall.bat" x64
+        UnitTests.exe
 
   msvc_cl_32_bit:
     runs-on: windows-latest
@@ -287,7 +310,12 @@ jobs:
       run: msbuild Build\VS2022_CL_32_BIT\JoltPhysics.sln /property:Configuration=${{matrix.build_type}} -m
     - name: Test
       working-directory: ${{github.workspace}}/Build/VS2022_CL_32_BIT/${{matrix.build_type}}
-      run: ./UnitTests.exe
+      shell: cmd
+      # We need to run vcvarsall to set the search path to include the Windows SDK, if not dxcompiler.dll will not be found and the UnitTests application will crash
+      run: |
+        for /f "delims=" %%i in ('vswhere -latest -property installationPath') do set VS_PATH=%%i
+        call "%VS_PATH%\VC\Auxiliary\Build\vcvarsall.bat" x86
+        UnitTests.exe
 
   msvc_cl_arm:
     runs-on: windows-latest
@@ -349,7 +377,12 @@ jobs:
       run: msbuild Build\VS2022_Clang\JoltPhysics.sln /property:Configuration=${{matrix.build_type}} -m
     - name: Test
       working-directory: ${{github.workspace}}/Build/VS2022_Clang/${{matrix.build_type}}
-      run: ./UnitTests.exe
+      shell: cmd
+      # We need to run vcvarsall to set the search path to include the Windows SDK, if not dxcompiler.dll will not be found and the UnitTests application will crash
+      run: |
+        for /f "delims=" %%i in ('vswhere -latest -property installationPath') do set VS_PATH=%%i
+        call "%VS_PATH%\VC\Auxiliary\Build\vcvarsall.bat" x64
+        UnitTests.exe
 
   macos:
     runs-on: macos-latest
@@ -374,7 +407,9 @@ jobs:
       run: cmake --build ${{github.workspace}}/Build/MacOS_${{matrix.build_type}} -j $(nproc)
     - name: Test
       working-directory: ${{github.workspace}}/Build/MacOS_${{matrix.build_type}}
-      run: ctest --output-on-failure --verbose
+      run: |
+        source ${VULKAN_SDK_INSTALL}/setup-env.sh
+        ctest --output-on-failure --verbose
 
   android:
     runs-on: ubuntu-latest
@@ -429,7 +464,7 @@ jobs:
       working-directory: ${{github.workspace}}/Build
       run: ./cmake_linux_emscripten.sh Distribution -DTARGET_HELLO_WORLD=OFF -DTARGET_PERFORMANCE_TEST=OFF
     - name: Build
-      run: cmake --build ${{github.workspace}}/Build/WASM_Distribution -j $(nproc)      
+      run: cmake --build ${{github.workspace}}/Build/WASM_Distribution -j $(nproc)
     - name: Test
       working-directory: ${{github.workspace}}/Build/WASM_Distribution
       run: node UnitTests.js

+ 24 - 6
.github/workflows/determinism_check.yml

@@ -99,7 +99,12 @@ jobs:
       run: msbuild Build\VS2022_CL\JoltPhysics.sln /property:Configuration=Distribution
     - name: Unit Tests
       working-directory: ${{github.workspace}}/Build/VS2022_CL/Distribution
-      run: ./UnitTests.exe
+      shell: cmd
+      # We need to run vcvarsall to set the search path to include the Windows SDK, if not dxcompiler.dll will not be found and the UnitTests application will crash
+      run: |
+        for /f "delims=" %%i in ('vswhere -latest -property installationPath') do set VS_PATH=%%i
+        call "%VS_PATH%\VC\Auxiliary\Build\vcvarsall.bat" x64
+        UnitTests.exe
     - name: Test ConvexVsMesh
       working-directory: ${{github.workspace}}/Build/VS2022_CL/Distribution
       run: ./PerformanceTest -q=LinearCast -t=max -s=ConvexVsMesh "-validate_hash=$env:CONVEX_VS_MESH_HASH"
@@ -129,7 +134,12 @@ jobs:
       run: msbuild Build\VS2022_CL_32BIT\JoltPhysics.sln /property:Configuration=Distribution
     - name: Unit Tests
       working-directory: ${{github.workspace}}/Build/VS2022_CL_32BIT/Distribution
-      run: ./UnitTests.exe
+      shell: cmd
+      # We need to run vcvarsall to set the search path to include the Windows SDK, if not dxcompiler.dll will not be found and the UnitTests application will crash
+      run: |
+        for /f "delims=" %%i in ('vswhere -latest -property installationPath') do set VS_PATH=%%i
+        call "%VS_PATH%\VC\Auxiliary\Build\vcvarsall.bat" x86
+        UnitTests.exe
     - name: Test ConvexVsMesh
       working-directory: ${{github.workspace}}/Build/VS2022_CL_32BIT/Distribution
       run: ./PerformanceTest -q=LinearCast -t=max -s=ConvexVsMesh "-validate_hash=$env:CONVEX_VS_MESH_HASH"
@@ -146,18 +156,26 @@ jobs:
   macos:
     runs-on: macos-latest
     name: macOS Determinism Check
+    env:
+        VULKAN_SDK_INSTALL: ${{github.workspace}}/vulkan_sdk
 
     steps:
     - name: Checkout Code
       uses: actions/checkout@v6
+    - name: Install Vulkan SDK
+      run: ${{github.workspace}}/Build/macos_install_vulkan_sdk.sh ${VULKAN_SDK_INSTALL}
     - name: Configure CMake
       working-directory: ${{github.workspace}}/Build
-      run: ./cmake_linux_clang_gcc.sh Distribution clang++ -DCROSS_PLATFORM_DETERMINISTIC=ON -DTARGET_VIEWER=OFF -DTARGET_SAMPLES=OFF -DTARGET_HELLO_WORLD=OFF -DTARGET_UNIT_TESTS=ON -DTARGET_PERFORMANCE_TEST=ON
+      run: |
+        source ${VULKAN_SDK_INSTALL}/setup-env.sh
+        ./cmake_linux_clang_gcc.sh Distribution clang++ -DCROSS_PLATFORM_DETERMINISTIC=ON -DTARGET_VIEWER=OFF -DTARGET_SAMPLES=OFF -DTARGET_HELLO_WORLD=OFF -DTARGET_UNIT_TESTS=ON -DTARGET_PERFORMANCE_TEST=ON
     - name: Build
       run: cmake --build ${{github.workspace}}/Build/Linux_Distribution -j $(nproc)
     - name: Unit Tests
       working-directory: ${{github.workspace}}/Build/Linux_Distribution
-      run: ctest --output-on-failure --verbose
+      run: |
+        source ${VULKAN_SDK_INSTALL}/setup-env.sh
+        ctest --output-on-failure --verbose
     - name: Test ConvexVsMesh
       working-directory: ${{github.workspace}}/Build/Linux_Distribution
       run: ./PerformanceTest -q=LinearCast -t=max -s=ConvexVsMesh -validate_hash=${CONVEX_VS_MESH_HASH}
@@ -374,7 +392,7 @@ jobs:
       working-directory: ${{github.workspace}}/Build
       run: ./cmake_linux_emscripten.sh Distribution -DCROSS_PLATFORM_DETERMINISTIC=ON -DTARGET_VIEWER=OFF -DTARGET_SAMPLES=OFF -DTARGET_HELLO_WORLD=OFF -DTARGET_UNIT_TESTS=ON -DTARGET_PERFORMANCE_TEST=ON
     - name: Build
-      run: cmake --build ${{github.workspace}}/Build/WASM_Distribution -j $(nproc)      
+      run: cmake --build ${{github.workspace}}/Build/WASM_Distribution -j $(nproc)
     - name: Unit Tests
       working-directory: ${{github.workspace}}/Build/WASM_Distribution
       run: node UnitTests.js
@@ -412,7 +430,7 @@ jobs:
       working-directory: ${{github.workspace}}/Build
       run: ./cmake_linux_emscripten.sh Distribution -DCROSS_PLATFORM_DETERMINISTIC=ON -DTARGET_VIEWER=OFF -DTARGET_SAMPLES=OFF -DTARGET_HELLO_WORLD=OFF -DTARGET_UNIT_TESTS=ON -DTARGET_PERFORMANCE_TEST=ON -DJPH_USE_WASM64=ON
     - name: Build
-      run: cmake --build ${{github.workspace}}/Build/WASM_Distribution -j $(nproc)      
+      run: cmake --build ${{github.workspace}}/Build/WASM_Distribution -j $(nproc)
     - name: Unit Tests
       working-directory: ${{github.workspace}}/Build/WASM_Distribution
       run: node --experimental-wasm-memory64 UnitTests.js

+ 2 - 2
.gitignore

@@ -10,5 +10,5 @@
 /snapshot.bin
 /*.jor
 /detlog.txt
-/Assets/Shaders/VK/*.spv
-/Assets/Shaders/MTL/*.metallib
+*.spv
+*.metallib

+ 45 - 19
Build/CMakeLists.txt

@@ -110,7 +110,7 @@ option(USE_STD_VECTOR "Use std::vector instead of own Array class" OFF)
 option(ENABLE_OBJECT_STREAM "Compile the ObjectStream class and RTTI attribute information" ON)
 
 # Enable installation
-option(ENABLE_INSTALL "Generate installation target"  ON)
+option(ENABLE_INSTALL "Generate installation target" ON)
 
 include(CMakeDependentOption)
 
@@ -118,8 +118,17 @@ include(CMakeDependentOption)
 # Windows Store only supports the DLL version
 cmake_dependent_option(USE_STATIC_MSVC_RUNTIME_LIBRARY "Use the static MSVC runtime library" ON "MSVC;NOT WINDOWS_STORE" OFF)
 
-# Enable Vulkan instead of DirectX
-cmake_dependent_option(JPH_ENABLE_VULKAN "Enable Vulkan" ON "LINUX" OFF)
+# Option to compile with DirectX 12 compute
+option(JPH_USE_DX12 "Use DirectX" ON)
+
+# Use DXC compiler to compile shaders, when off falls back to FXC
+option(JPH_USE_DXC "Use DXC shader compiler" ON)
+
+# Option to compile with Vulkan compute
+option(JPH_USE_VK "Use Vulkan" ON)
+
+# Option to compile with Metal compute
+option(JPH_USE_MTL "Use Metal" ON)
 
 # Determine which configurations exist
 if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) # Only do this when we're at the top level, see: https://gitlab.kitware.com/cmake/cmake/-/issues/24181
@@ -353,7 +362,36 @@ if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
 	if (TARGET_UNIT_TESTS)
 		# Create UnitTests executable
 		include(${PHYSICS_REPO_ROOT}/UnitTests/UnitTests.cmake)
-		add_executable(UnitTests ${UNIT_TESTS_SRC_FILES})
+
+		if (APPLE)
+			# Icon
+			set(JPH_ICON "${CMAKE_CURRENT_SOURCE_DIR}/macOS/icon.icns")
+			set_source_files_properties(${JPH_ICON} PROPERTIES MACOSX_PACKAGE_LOCATION "Resources")
+
+			# macOS configuration
+			add_executable(UnitTests MACOSX_BUNDLE ${UNIT_TESTS_SRC_FILES} ${UNIT_TESTS_ASSETS} ${JPH_ICON})
+
+			# Make sure that all unit test assets move to the Resources folder in the package
+			foreach(ASSET_FILE ${UNIT_TESTS_ASSETS})
+				string(REPLACE ${PHYSICS_REPO_ROOT} "Resources" ASSET_DST ${ASSET_FILE})
+				get_filename_component(ASSET_DST ${ASSET_DST} DIRECTORY)
+				set_source_files_properties(${ASSET_FILE} PROPERTIES MACOSX_PACKAGE_LOCATION ${ASSET_DST})
+			endforeach()
+
+			set_property(TARGET UnitTests PROPERTY MACOSX_BUNDLE_INFO_PLIST "${CMAKE_CURRENT_SOURCE_DIR}/iOS/UnitTestsInfo.plist")
+			set_property(TARGET UnitTests PROPERTY XCODE_ATTRIBUTE_PRODUCT_BUNDLE_IDENTIFIER "com.joltphysics.unittests")
+			set_property(TARGET UnitTests PROPERTY BUILD_RPATH "/usr/local/lib" INSTALL_RPATH "/usr/local/lib") # to find the Vulkan shared lib
+
+			# Ensure that we enable SSE4.2 for the x86_64 build, XCode builds multiple architectures
+			set_property(TARGET UnitTests PROPERTY XCODE_ATTRIBUTE_OTHER_CPLUSPLUSFLAGS[arch=x86_64] "$(inherited) -msse4.2 -mpopcnt")
+
+			# Unit tests are in the app bundle on macOS
+			set(UNIT_TEST_COMMAND UnitTests.app/Contents/MacOS/UnitTests)
+		else()
+			add_executable(UnitTests ${UNIT_TESTS_SRC_FILES})
+
+			set(UNIT_TEST_COMMAND UnitTests)
+		endif()
 		target_include_directories(UnitTests PUBLIC ${UNIT_TESTS_ROOT})
 		target_link_libraries(UnitTests LINK_PUBLIC Jolt)
 
@@ -374,21 +412,10 @@ if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
 			target_link_options(UnitTests PUBLIC "/SUBSYSTEM:CONSOLE")
 		endif()
 
-		if (IOS)
-			# Set the bundle information
-			set_property(TARGET UnitTests PROPERTY MACOSX_BUNDLE_INFO_PLIST "${CMAKE_CURRENT_SOURCE_DIR}/iOS/UnitTestsInfo.plist")
-			set_property(TARGET UnitTests PROPERTY XCODE_ATTRIBUTE_PRODUCT_BUNDLE_IDENTIFIER "com.joltphysics.unittests")
-		endif()
-
-		if (XCODE)
-			# Ensure that we enable SSE4.2 for the x86_64 build, XCode builds multiple architectures
-			set_property(TARGET UnitTests PROPERTY XCODE_ATTRIBUTE_OTHER_CPLUSPLUSFLAGS[arch=x86_64] "$(inherited) -msse4.2 -mpopcnt")
-		endif()
-
 		# Register unit tests as a test so that it can be run with:
 		# ctest --output-on-failure
 		enable_testing()
-		add_test(UnitTests UnitTests)
+		add_test(UnitTests ${UNIT_TEST_COMMAND})
 	endif()
 
 	if (NOT "${CMAKE_SYSTEM_NAME}" STREQUAL "WindowsStore")
@@ -427,7 +454,6 @@ if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
 	endif()
 
 	if ((WIN32 OR LINUX OR ("${CMAKE_SYSTEM_NAME}" MATCHES "Darwin")) AND NOT ("${CMAKE_VS_PLATFORM_NAME}" STREQUAL "ARM")) # ARM 32-bit is missing dinput8.lib
-		# Windows only targets
 		if (TARGET_SAMPLES OR TARGET_VIEWER)
 			include(${PHYSICS_REPO_ROOT}/TestFramework/TestFramework.cmake)
 		endif()
@@ -435,14 +461,14 @@ if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
 			if (TEST_FRAMEWORK_AVAILABLE)
 				include(${PHYSICS_REPO_ROOT}/Samples/Samples.cmake)
 			else()
-				message("Cannot build Samples because Vulkan/DirectX SDK is not available!")
+				message("Cannot build Samples because Vulkan/DirectX/Metal SDK is not available!")
 			endif()
 		endif()
 		if (TARGET_VIEWER)
 			if (TEST_FRAMEWORK_AVAILABLE)
 				include(${PHYSICS_REPO_ROOT}/JoltViewer/JoltViewer.cmake)
 			else()
-				message("Cannot build JoltViewer because Vulkan/DirectX SDK is not available!")
+				message("Cannot build JoltViewer because Vulkan/DirectX/Metal SDK is not available!")
 			endif()
 		endif()
 	endif()

+ 6 - 1
Build/README.md

@@ -39,6 +39,9 @@ There are a number of user configurable defines that turn on/off certain feature
 		<li>JPH_NO_FORCE_INLINE - Don't use force inlining but fall back to a regular 'inline'.</li>
 		<li>JPH_USE_STD_VECTOR - Use std::vector instead of Jolt's own Array class.</li>
 		<li>CPP_RTTI_ENABLED - Enable C++ RTTI for the library. Disabled by default.</li>
+		<li>JPH_USE_DX12 - Implement the DX12 version of ComputeSystem.</li>
+		<li>JPH_USE_VK - Implement the Vulkan version of ComputeSystem.</li>
+		<li>JPH_USE_MTL - Implement the Metal version of ComputeSystem.</li>
 	</ul>
 </details>
 
@@ -130,7 +133,7 @@ To implement your custom memory allocator override Allocate, Free, Reallocate, A
 	<ul>
 		<li>Install clang (apt-get install clang)</li>
 		<li>Install cmake (apt-get install cmake)</li>
-		<li>If you want to build the Samples or JoltViewer, install the <a href="https://vulkan.lunarg.com/doc/view/latest/linux/getting_started_ubuntu.html">Vulkan SDK</a></li>
+		<li>If you want to build the Samples, JoltViewer or use the ComputeSystem, install the <a href="https://vulkan.lunarg.com/doc/view/latest/linux/getting_started_ubuntu.html">Vulkan SDK</a></li>
 		<li>Run: ./cmake_linux_clang_gcc.sh</li>
 		<li>Go to the Linux_Debug folder</li>
 		<li>Run: make -j$(nproc) && ./UnitTests</li>
@@ -152,6 +155,7 @@ To implement your custom memory allocator override Allocate, Free, Reallocate, A
 	<summary>macOS</summary>
 	<ul>
 		<li>Install XCode</li>
+		<li>Install the Vulkan SDK or the dxc and spirv-cross tools (required to cross compile hlsl shaders to Metal)</li>
 		<li>Download CMake 3.23+ (https://cmake.org/download/)</li>
 		<li>Run: ./cmake_xcode_macos.sh</li>
 		<li>This will open XCode with a newly generated project</li>
@@ -164,6 +168,7 @@ To implement your custom memory allocator override Allocate, Free, Reallocate, A
 	<summary>iOS</summary>
 	<ul>
 		<li>Install XCode</li>
+		<li>Install the Vulkan SDK or the dxc and spirv-cross tools (required to cross compile hlsl shaders to Metal)</li>
 		<li>Download CMake 3.23+ (https://cmake.org/download/)</li>
 		<li>Run: ./cmake_xcode.ios.sh</li>
 		<li>This will open XCode with a newly generated project</li>

+ 1 - 1
Build/cmake_linux_mingw.sh

@@ -14,6 +14,6 @@ echo Usage: ./cmake_linux_mingw.sh [Configuration]
 echo "Possible configurations: Debug, Release (default), Distribution"
 echo Generating Makefile for build type \"$BUILD_TYPE\" in folder \"$BUILD_DIR\"
 
-cmake -S . -B $BUILD_DIR -DCMAKE_TOOLCHAIN_FILE=mingw-w64-x86_64.cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE "${@}"
+cmake -S . -B $BUILD_DIR -DCMAKE_TOOLCHAIN_FILE=mingw-w64-x86_64.cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DJPH_USE_DXC=OFF "${@}"
 
 echo Compile by running \"cmake --build $BUILD_DIR -j $(nproc)\"

+ 1 - 1
Build/cmake_vs2022_cl_32bit.bat

@@ -1,3 +1,3 @@
 @echo off
-cmake -S . -B VS2022_CL_32BIT -G "Visual Studio 17 2022" -A Win32 -DUSE_SSE4_1=OFF -DUSE_SSE4_2=OFF -DUSE_AVX=OFF -DUSE_AVX2=OFF -DUSE_AVX512=OFF -DUSE_LZCNT=OFF -DUSE_TZCNT=OFF -DUSE_F16C=OFF -DUSE_FMADD=OFF %*
+cmake -S . -B VS2022_CL_32BIT -G "Visual Studio 17 2022" -A Win32 -DUSE_SSE4_1=OFF -DUSE_SSE4_2=OFF -DUSE_AVX=OFF -DUSE_AVX2=OFF -DUSE_AVX512=OFF -DUSE_LZCNT=OFF -DUSE_TZCNT=OFF -DUSE_F16C=OFF -DUSE_FMADD=OFF -DJPH_USE_VK=OFF %*
 echo Open VS2022_CL_32BIT\JoltPhysics.sln to build the project.

+ 1 - 1
Build/cmake_windows_mingw.sh

@@ -14,6 +14,6 @@ echo Usage: ./cmake_windows_mingw.sh [Configuration]
 echo "Possible configurations: Debug (default), Release, Distribution"
 echo Generating Makefile for build type \"$BUILD_TYPE\" in folder \"$BUILD_DIR\"
 
-cmake -S . -B $BUILD_DIR -G "MinGW Makefiles" -DCMAKE_BUILD_TYPE=$BUILD_TYPE "${@}"
+cmake -S . -B $BUILD_DIR -G "MinGW Makefiles" -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DJPH_USE_DXC=OFF "${@}"
 
 echo Compile by running \"cmake --build $BUILD_DIR -j $(nproc)\"

+ 1 - 1
Build/iOS/UnitTestsInfo.plist

@@ -9,7 +9,7 @@
 	<key>CFBundleGetInfoString</key>
 	<string></string>
 	<key>CFBundleIconFile</key>
-	<string></string>
+	<string>icon.icns</string>
 	<key>CFBundleIdentifier</key>
 	<string>com.joltphysics.unittests</string>
 	<key>CFBundleInfoDictionaryVersion</key>

+ 1 - 1
Build/ubuntu24_install_vulkan_sdk.sh

@@ -1,4 +1,4 @@
 wget -qO- https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo tee /etc/apt/trusted.gpg.d/lunarg.asc
 sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-noble.list http://packages.lunarg.com/vulkan/lunarg-vulkan-noble.list
 sudo apt update
-sudo apt install vulkan-sdk
+sudo apt install vulkan-sdk mesa-vulkan-drivers

+ 61 - 0
Jolt/Compute/ComputeBuffer.h

@@ -0,0 +1,61 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Core/Reference.h>
+#include <Jolt/Core/NonCopyable.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Buffer that can be read from / written to by a compute shader
+class JPH_EXPORT ComputeBuffer : public RefTarget<ComputeBuffer>, public NonCopyable
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Type of buffer
+	enum class EType
+	{
+		UploadBuffer,			///< Buffer that can be written on the CPU and then uploaded to the GPU.
+		ReadbackBuffer,			///< Buffer to be sent from the GPU to the CPU, used to read back data.
+		ConstantBuffer,			///< A smallish buffer that is used to pass constants to a shader.
+		Buffer,					///< Buffer that can be read from by a shader. Must be initialized with data at construction time and is read only thereafter.
+		RWBuffer,				///< Buffer that can be read from and written to by a shader.
+	};
+
+	/// Constructor / Destructor
+								ComputeBuffer(EType inType, uint64 inSize, uint inStride) : mType(inType), mSize(inSize), mStride(inStride) { }
+	virtual						~ComputeBuffer() = default;
+
+	/// Properties
+	EType						GetType() const									{ return mType; }
+	uint64						GetSize() const									{ return mSize; }
+	uint						GetStride() const								{ return mStride; }
+
+	/// Mode in which the buffer is accessed
+	enum class EMode
+	{
+		Read,					///< Read only access to the buffer
+		Write,					///< Write only access to the buffer (this will discard all previous data in the buffer)
+	};
+
+	/// Map / unmap buffer (get pointer to data).
+	void *						Map(EMode inMode)								{ return MapInternal(inMode); }
+	template <typename T> T *	Map(EMode inMode)								{ JPH_ASSERT(sizeof(T) == mStride); return reinterpret_cast<T *>(MapInternal(inMode)); }
+	virtual void				Unmap() = 0;
+
+	/// Create a readback buffer of the same size and stride that can be used to read the data stored in this buffer on CPU.
+	/// Note that this could also be implemented as 'return this' in case the underlying implementation allows locking GPU data on CPU directly.
+	virtual Ref<ComputeBuffer>	CreateReadBackBuffer() const = 0;
+
+protected:
+	EType						mType;
+	uint64						mSize;
+	uint						mStride;
+
+	virtual void *				MapInternal(EMode inMode) = 0;
+};
+
+JPH_NAMESPACE_END

+ 67 - 0
Jolt/Compute/ComputeQueue.h

@@ -0,0 +1,67 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Core/Reference.h>
+#include <Jolt/Core/NonCopyable.h>
+
+JPH_NAMESPACE_BEGIN
+
+class ComputeShader;
+class ComputeBuffer;
+
+/// A command queue for executing compute workloads on the GPU.
+class JPH_EXPORT ComputeQueue : public RefTarget<ComputeQueue>, public NonCopyable
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Destructor
+	virtual					~ComputeQueue() = default;
+
+	/// Activate a shader. Shader must be set first before buffers can be bound.
+	/// After every Dispatch call, the shader must be set again and all buffers must be bound again.
+	virtual void			SetShader(const ComputeShader *inShader) = 0;
+
+	/// If a barrier should be placed before accessing the buffer
+	enum class EBarrier
+	{
+		Yes,
+		No
+	};
+
+	/// Bind a constant buffer to the shader. Note that the contents of the buffer cannot be modified until execution finishes.
+	virtual void			SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer) = 0;
+
+	/// Bind a read only buffer to the shader. Note that the contents of the buffer cannot be modified on CPU until execution finishes (only relevant for buffers of type UploadBuffer).
+	virtual void			SetBuffer(const char *inName, const ComputeBuffer *inBuffer) = 0;
+
+	/// Bind a read/write buffer to the shader.
+	/// @param inName Name of the buffer as specified in the shader.
+	/// @param inBuffer The buffer to bind.
+	/// @param inBarrier If set to Yes, a barrier will be placed before accessing the buffer to ensure all previous writes to the buffer are visible.
+	virtual void 			SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier = EBarrier::Yes) = 0;
+
+	/// Dispatch a compute shader with the specified number of thread groups
+	virtual void			Dispatch(uint inThreadGroupsX, uint inThreadGroupsY = 1, uint inThreadGroupsZ = 1) = 0;
+
+	/// Schedule buffer to be copied from GPU to CPU
+	virtual void			ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc) = 0;
+
+	/// Execute accumulated command list
+	virtual void			Execute() = 0;
+
+	/// After executing, this waits until execution is done
+	virtual void			Wait() = 0;
+
+	/// Execute and wait for the command list to finish
+	void					ExecuteAndWait()
+	{
+		Execute();
+		Wait();
+	}
+};
+
+JPH_NAMESPACE_END

+ 38 - 0
Jolt/Compute/ComputeShader.h

@@ -0,0 +1,38 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Core/Reference.h>
+#include <Jolt/Core/NonCopyable.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Compute shader handle
+class JPH_EXPORT ComputeShader : public RefTarget<ComputeShader>, public NonCopyable
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Constructor / destructor
+							ComputeShader(uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) :
+		mGroupSizeX(inGroupSizeX),
+		mGroupSizeY(inGroupSizeY),
+		mGroupSizeZ(inGroupSizeZ)
+	{
+	}
+	virtual					~ComputeShader() = default;
+
+	/// Get group sizes
+	uint32					GetGroupSizeX() const						{ return mGroupSizeX; }
+	uint32					GetGroupSizeY() const						{ return mGroupSizeY; }
+	uint32					GetGroupSizeZ() const						{ return mGroupSizeZ; }
+
+private:
+	uint32					mGroupSizeX;
+	uint32					mGroupSizeY;
+	uint32					mGroupSizeZ;
+};
+
+JPH_NAMESPACE_END

+ 69 - 0
Jolt/Compute/ComputeSystem.h

@@ -0,0 +1,69 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Compute/ComputeShader.h>
+#include <Jolt/Compute/ComputeBuffer.h>
+#include <Jolt/Compute/ComputeQueue.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Interface to run a workload on the GPU
+class JPH_EXPORT ComputeSystem : public RefTarget<ComputeSystem>, public NonCopyable
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Destructor
+	virtual							~ComputeSystem() = default;
+
+	/// Compile a compute shader
+	virtual Ref<ComputeShader>		CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY = 1, uint32 inGroupSizeZ = 1) = 0;
+
+	/// Create a buffer for use with a compute shader
+	virtual Ref<ComputeBuffer>		CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData = nullptr) = 0;
+
+	/// Create a queue for executing compute shaders
+	virtual Ref<ComputeQueue>		CreateComputeQueue() = 0;
+
+	/// Callback used when loading shaders
+	using ShaderLoader = std::function<bool(const char *inName, Array<uint8> &outData)>;
+	ShaderLoader					mShaderLoader = [](const char *, Array<uint8> &) { JPH_ASSERT(false, "Override this function"); return false; };
+};
+
+#ifdef JPH_USE_VK
+/// Factory function to create a compute system using Vulkan
+extern JPH_EXPORT ComputeSystem *	CreateComputeSystemVK();
+#endif
+
+#ifdef JPH_USE_DX12
+
+/// Factory function to create a compute system using DirectX 12
+extern JPH_EXPORT ComputeSystem *	CreateComputeSystemDX12();
+
+/// Factory function to create the default compute system for this platform
+inline ComputeSystem *				CreateComputeSystem()		{ return CreateComputeSystemDX12(); }
+
+#elif defined(JPH_USE_MTL)
+
+/// Factory function to create a compute system using Metal
+extern JPH_EXPORT ComputeSystem *	CreateComputeSystemMTL();
+
+/// Factory function to create the default compute system for this platform
+inline ComputeSystem *				CreateComputeSystem()		{ return CreateComputeSystemMTL(); }
+
+#elif defined(JPH_USE_VK)
+
+/// Factory function to create the default compute system for this platform
+inline ComputeSystem *				CreateComputeSystem()		{ return CreateComputeSystemVK(); }
+
+#else
+
+/// Fallback implementation when no compute system is available
+inline ComputeSystem *				CreateComputeSystem()		{ return nullptr; }
+
+#endif
+
+JPH_NAMESPACE_END

+ 149 - 0
Jolt/Compute/DX12/ComputeBufferDX12.cpp

@@ -0,0 +1,149 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#ifdef JPH_USE_DX12
+
+#include <Jolt/Compute/DX12/ComputeBufferDX12.h>
+#include <Jolt/Compute/DX12/ComputeSystemDX12.h>
+
+JPH_NAMESPACE_BEGIN
+
+ComputeBufferDX12::ComputeBufferDX12(ComputeSystemDX12 *inComputeSystem, EType inType, uint64 inSize, uint inStride, const void *inData) :
+	ComputeBuffer(inType, inSize, inStride),
+	mComputeSystem(inComputeSystem)
+{
+	uint64 buffer_size = inSize * inStride;
+
+	switch (inType)
+	{
+	case EType::UploadBuffer:
+		mBufferCPU = mComputeSystem->CreateD3DResource(D3D12_HEAP_TYPE_UPLOAD, D3D12_RESOURCE_STATE_GENERIC_READ, D3D12_RESOURCE_FLAG_NONE, buffer_size);
+		mBufferGPU = mComputeSystem->CreateD3DResource(D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_FLAG_NONE, buffer_size);
+		break;
+
+	case EType::ConstantBuffer:
+		mBufferCPU = mComputeSystem->CreateD3DResource(D3D12_HEAP_TYPE_UPLOAD, D3D12_RESOURCE_STATE_GENERIC_READ, D3D12_RESOURCE_FLAG_NONE, buffer_size);
+		break;
+
+	case EType::ReadbackBuffer:
+		JPH_ASSERT(inData == nullptr, "Can't upload data to a readback buffer");
+		mBufferCPU = mComputeSystem->CreateD3DResource(D3D12_HEAP_TYPE_READBACK, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_FLAG_NONE, buffer_size);
+		break;
+
+	case EType::Buffer:
+		JPH_ASSERT(inData != nullptr);
+		mBufferCPU = mComputeSystem->CreateD3DResource(D3D12_HEAP_TYPE_UPLOAD, D3D12_RESOURCE_STATE_GENERIC_READ, D3D12_RESOURCE_FLAG_NONE, buffer_size);
+		mBufferGPU = mComputeSystem->CreateD3DResource(D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_FLAG_NONE, buffer_size);
+		mNeedsSync = true;
+		break;
+
+	case EType::RWBuffer:
+		if (inData != nullptr)
+		{
+			mBufferCPU = mComputeSystem->CreateD3DResource(D3D12_HEAP_TYPE_UPLOAD, D3D12_RESOURCE_STATE_GENERIC_READ, D3D12_RESOURCE_FLAG_NONE, buffer_size);
+			mNeedsSync = true;
+		}
+		mBufferGPU = mComputeSystem->CreateD3DResource(D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS, buffer_size);
+		break;
+	}
+
+	// Copy data to upload buffer
+	if (inData != nullptr)
+	{
+		void *data = nullptr;
+		D3D12_RANGE range = { 0, 0 }; // We're not going to read
+		mBufferCPU->Map(0, &range, &data);
+		memcpy(data, inData, size_t(buffer_size));
+		mBufferCPU->Unmap(0, nullptr);
+	}
+}
+
+bool ComputeBufferDX12::Barrier(ID3D12GraphicsCommandList *inCommandList, D3D12_RESOURCE_STATES inTo) const
+{
+	// Check if state changed
+	if (mCurrentState == inTo)
+		return false;
+
+	// Only buffers in GPU memory can change state
+	if (mType != ComputeBuffer::EType::Buffer && mType != ComputeBuffer::EType::RWBuffer)
+		return true;
+
+	D3D12_RESOURCE_BARRIER barrier;
+	barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_TRANSITION;
+	barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE;
+	barrier.Transition.pResource = GetResourceGPU();
+	barrier.Transition.StateBefore = mCurrentState;
+	barrier.Transition.StateAfter = inTo;
+	barrier.Transition.Subresource = D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES;
+	inCommandList->ResourceBarrier(1, &barrier);
+
+	mCurrentState = inTo;
+	return true;
+}
+
+void ComputeBufferDX12::RWBarrier(ID3D12GraphicsCommandList *inCommandList)
+{
+	JPH_ASSERT(mCurrentState == D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
+
+	D3D12_RESOURCE_BARRIER barrier;
+	barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_UAV;
+	barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE;
+	barrier.Transition.pResource = GetResourceGPU();
+	inCommandList->ResourceBarrier(1, &barrier);
+}
+
+bool ComputeBufferDX12::SyncCPUToGPU(ID3D12GraphicsCommandList *inCommandList) const
+{
+	if (!mNeedsSync)
+		return false;
+
+	Barrier(inCommandList, D3D12_RESOURCE_STATE_COPY_DEST);
+
+	inCommandList->CopyResource(GetResourceGPU(), GetResourceCPU());
+
+	mNeedsSync = false;
+	return true;
+}
+
+void *ComputeBufferDX12::MapInternal(EMode inMode)
+{
+	void *mapped_resource = nullptr;
+
+	switch (inMode)
+	{
+	case EMode::Read:
+		JPH_ASSERT(mType == EType::ReadbackBuffer);
+		if (HRFailed(mBufferCPU->Map(0, nullptr, &mapped_resource)))
+			return nullptr;
+		break;
+
+	case EMode::Write:
+		{
+			JPH_ASSERT(mType == EType::UploadBuffer || mType == EType::ConstantBuffer);
+			D3D12_RANGE range = { 0, 0 }; // We're not going to read
+			if (HRFailed(mBufferCPU->Map(0, &range, &mapped_resource)))
+				return nullptr;
+			mNeedsSync = true;
+		}
+		break;
+	}
+
+	return mapped_resource;
+}
+
+void ComputeBufferDX12::Unmap()
+{
+	mBufferCPU->Unmap(0, nullptr);
+}
+
+Ref<ComputeBuffer> ComputeBufferDX12::CreateReadBackBuffer() const
+{
+	return mComputeSystem->CreateComputeBuffer(EType::ReadbackBuffer, mSize, mStride);
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_DX12

+ 50 - 0
Jolt/Compute/DX12/ComputeBufferDX12.h

@@ -0,0 +1,50 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Compute/ComputeBuffer.h>
+
+#ifdef JPH_USE_DX12
+
+#include <Jolt/Compute/DX12/IncludeDX12.h>
+
+JPH_NAMESPACE_BEGIN
+
+class ComputeSystemDX12;
+
+/// Buffer that can be read from / written to by a compute shader
+class JPH_EXPORT ComputeBufferDX12 final : public ComputeBuffer
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Constructor
+									ComputeBufferDX12(ComputeSystemDX12 *inComputeSystem, EType inType, uint64 inSize, uint inStride, const void *inData);
+
+	ID3D12Resource *				GetResourceCPU() const									{ return mBufferCPU.Get(); }
+	ID3D12Resource *				GetResourceGPU() const									{ return mBufferGPU.Get(); }
+	ComPtr<ID3D12Resource>			ReleaseResourceCPU() const								{ return std::move(mBufferCPU); }
+
+	bool							Barrier(ID3D12GraphicsCommandList *inCommandList, D3D12_RESOURCE_STATES inTo) const;
+	void							RWBarrier(ID3D12GraphicsCommandList *inCommandList);
+	bool							SyncCPUToGPU(ID3D12GraphicsCommandList *inCommandList) const;
+
+	virtual void					Unmap() override;
+
+	Ref<ComputeBuffer>				CreateReadBackBuffer() const override;
+
+private:
+	virtual void *					MapInternal(EMode inMode) override;
+
+	ComputeSystemDX12 *				mComputeSystem;
+	mutable ComPtr<ID3D12Resource>	mBufferCPU;
+	ComPtr<ID3D12Resource>			mBufferGPU;
+	mutable bool					mNeedsSync = false;										///< If this buffer needs to be synced from CPU to GPU
+	mutable D3D12_RESOURCE_STATES	mCurrentState = D3D12_RESOURCE_STATE_COPY_DEST;			///< State of the GPU buffer so we can do proper barriers
+};
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_DX12

+ 221 - 0
Jolt/Compute/DX12/ComputeQueueDX12.cpp

@@ -0,0 +1,221 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#ifdef JPH_USE_DX12
+
+#include <Jolt/Compute/DX12/ComputeQueueDX12.h>
+#include <Jolt/Compute/DX12/ComputeShaderDX12.h>
+#include <Jolt/Compute/DX12/ComputeBufferDX12.h>
+
+JPH_NAMESPACE_BEGIN
+
+ComputeQueueDX12::~ComputeQueueDX12()
+{
+	Wait();
+
+	if (mFenceEvent != INVALID_HANDLE_VALUE)
+		CloseHandle(mFenceEvent);
+}
+
+bool ComputeQueueDX12::Initialize(ID3D12Device *inDevice, D3D12_COMMAND_LIST_TYPE inType)
+{
+	D3D12_COMMAND_QUEUE_DESC queue_desc = {};
+	queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
+	queue_desc.Type = inType;
+	queue_desc.Priority = D3D12_COMMAND_QUEUE_PRIORITY_HIGH;
+	if (HRFailed(inDevice->CreateCommandQueue(&queue_desc, IID_PPV_ARGS(&mCommandQueue))))
+		return false;
+
+	if (HRFailed(inDevice->CreateCommandAllocator(inType, IID_PPV_ARGS(&mCommandAllocator))))
+		return false;
+
+	// Create the command list
+	if (HRFailed(inDevice->CreateCommandList(0, inType, mCommandAllocator.Get(), nullptr, IID_PPV_ARGS(&mCommandList))))
+		return false;
+
+	// Command lists are created in the recording state, but there is nothing to record yet. The main loop expects it to be closed, so close it now
+	if (HRFailed(mCommandList->Close()))
+		return false;
+
+	// Create synchronization object
+	if (HRFailed(inDevice->CreateFence(mFenceValue, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&mFence))))
+		return false;
+
+	// Increment fence value so we don't skip waiting the first time a command list is executed
+	mFenceValue++;
+
+	// Create an event handle to use for frame synchronization
+	mFenceEvent = CreateEvent(nullptr, FALSE, FALSE, nullptr);
+	if (HRFailed(HRESULT_FROM_WIN32(GetLastError())))
+		return false;
+
+	return true;
+}
+
+ID3D12GraphicsCommandList *ComputeQueueDX12::Start()
+{
+	JPH_ASSERT(!mIsExecuting);
+
+	if (!mIsStarted)
+	{
+		// Reset the allocator
+		if (HRFailed(mCommandAllocator->Reset()))
+			return nullptr;
+
+		// Reset the command list
+		if (HRFailed(mCommandList->Reset(mCommandAllocator.Get(), nullptr)))
+			return nullptr;
+
+		// Now we have started recording commands
+		mIsStarted = true;
+	}
+
+	return mCommandList.Get();
+}
+
+void ComputeQueueDX12::SetShader(const ComputeShader *inShader)
+{
+	ID3D12GraphicsCommandList *command_list = Start();
+	mShader = static_cast<const ComputeShaderDX12 *>(inShader);
+	command_list->SetPipelineState(mShader->GetPipelineState());
+	command_list->SetComputeRootSignature(mShader->GetRootSignature());
+}
+
+void ComputeQueueDX12::SyncCPUToGPU(const ComputeBufferDX12 *inBuffer)
+{
+	// Ensure that any CPU writes are visible to the GPU
+	if (inBuffer->SyncCPUToGPU(mCommandList.Get()))
+	{
+		// After the first upload, the CPU buffer is no longer needed for Buffer and RWBuffer types
+		if (inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType() == ComputeBuffer::EType::RWBuffer)
+			mDelayedFreedBuffers.emplace_back(inBuffer->ReleaseResourceCPU());
+	}
+}
+
+void ComputeQueueDX12::SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer)
+{
+	if (inBuffer == nullptr)
+		return;
+	JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::ConstantBuffer);
+
+	ID3D12GraphicsCommandList *command_list = Start();
+	const ComputeBufferDX12 *buffer = static_cast<const ComputeBufferDX12 *>(inBuffer);
+	command_list->SetComputeRootConstantBufferView(mShader->NameToIndex(inName), buffer->GetResourceCPU()->GetGPUVirtualAddress());
+
+	mUsedBuffers.insert(buffer);
+}
+
+void ComputeQueueDX12::SetBuffer(const char *inName, const ComputeBuffer *inBuffer)
+{
+	if (inBuffer == nullptr)
+		return;
+	JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::UploadBuffer || inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
+
+	ID3D12GraphicsCommandList *command_list = Start();
+	const ComputeBufferDX12 *buffer = static_cast<const ComputeBufferDX12 *>(inBuffer);
+	uint parameter_index = mShader->NameToIndex(inName);
+	SyncCPUToGPU(buffer);
+	buffer->Barrier(command_list, D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE);
+	command_list->SetComputeRootShaderResourceView(parameter_index, buffer->GetResourceGPU()->GetGPUVirtualAddress());
+
+	mUsedBuffers.insert(buffer);
+}
+
+void ComputeQueueDX12::SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier)
+{
+	if (inBuffer == nullptr)
+		return;
+	JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
+
+	ID3D12GraphicsCommandList *command_list = Start();
+	ComputeBufferDX12 *buffer = static_cast<ComputeBufferDX12 *>(inBuffer);
+	uint parameter_index = mShader->NameToIndex(inName);
+	SyncCPUToGPU(buffer);
+	if (!buffer->Barrier(command_list, D3D12_RESOURCE_STATE_UNORDERED_ACCESS) && inBarrier == EBarrier::Yes)
+		buffer->RWBarrier(command_list);
+	command_list->SetComputeRootUnorderedAccessView(parameter_index, buffer->GetResourceGPU()->GetGPUVirtualAddress());
+
+	mUsedBuffers.insert(buffer);
+}
+
+void ComputeQueueDX12::ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc)
+{
+	if (inDst == nullptr || inSrc == nullptr)
+		return;
+	JPH_ASSERT(inDst->GetType() == ComputeBuffer::EType::ReadbackBuffer);
+
+	ID3D12GraphicsCommandList *command_list = Start();
+	ComputeBufferDX12 *dst = static_cast<ComputeBufferDX12 *>(inDst);
+	const ComputeBufferDX12 *src = static_cast<const ComputeBufferDX12 *>(inSrc);
+	dst->Barrier(command_list, D3D12_RESOURCE_STATE_COPY_DEST);
+	src->Barrier(command_list, D3D12_RESOURCE_STATE_COPY_SOURCE);
+	command_list->CopyResource(dst->GetResourceCPU(), src->GetResourceGPU());
+
+	mUsedBuffers.insert(src);
+	mUsedBuffers.insert(dst);
+}
+
+void ComputeQueueDX12::Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ)
+{
+	ID3D12GraphicsCommandList *command_list = Start();
+	command_list->Dispatch(inThreadGroupsX, inThreadGroupsY, inThreadGroupsZ);
+}
+
+void ComputeQueueDX12::Execute()
+{
+	JPH_ASSERT(mIsStarted);
+	JPH_ASSERT(!mIsExecuting);
+
+	// Close the command list
+	if (HRFailed(mCommandList->Close()))
+		return;
+
+	// Execute the command list
+	ID3D12CommandList *command_lists[] = { mCommandList.Get() };
+	mCommandQueue->ExecuteCommandLists(std::size(command_lists), command_lists);
+
+	// Schedule a Signal command in the queue
+	if (HRFailed(mCommandQueue->Signal(mFence.Get(), mFenceValue)))
+		return;
+
+	// Clear the current shader
+	mShader = nullptr;
+
+	// Mark that we're executing
+	mIsExecuting = true;
+}
+
+void ComputeQueueDX12::Wait()
+{
+	// Check if we've been started
+	if (mIsExecuting)
+	{
+		if (mFence->GetCompletedValue() < mFenceValue)
+		{
+			// Wait until the fence has been processed
+			if (HRFailed(mFence->SetEventOnCompletion(mFenceValue, mFenceEvent)))
+				return;
+			WaitForSingleObjectEx(mFenceEvent, INFINITE, FALSE);
+		}
+
+		// Increment the fence value
+		mFenceValue++;
+
+		// Buffers can be freed now
+		mUsedBuffers.clear();
+
+		// Free buffers
+		mDelayedFreedBuffers.clear();
+
+		// Done executing
+		mIsExecuting = false;
+		mIsStarted = false;
+	}
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_DX12

+ 61 - 0
Jolt/Compute/DX12/ComputeQueueDX12.h

@@ -0,0 +1,61 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#ifdef JPH_USE_DX12
+
+#include <Jolt/Compute/ComputeQueue.h>
+#include <Jolt/Compute/DX12/ComputeShaderDX12.h>
+#include <Jolt/Core/UnorderedSet.h>
+
+JPH_NAMESPACE_BEGIN
+
+class ComputeBufferDX12;
+
+/// A command queue for DirectX for executing compute workloads on the GPU.
+class JPH_EXPORT ComputeQueueDX12 final : public ComputeQueue
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Destructor
+	virtual								~ComputeQueueDX12() override;
+
+	/// Initialize the queue
+	bool								Initialize(ID3D12Device *inDevice, D3D12_COMMAND_LIST_TYPE inType);
+
+	/// Start the command list (requires waiting until the previous one is finished)
+	ID3D12GraphicsCommandList *			Start();
+
+	// See: ComputeQueue
+	virtual void						SetShader(const ComputeShader *inShader) override;
+	virtual void						SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer) override;
+	virtual void						SetBuffer(const char *inName, const ComputeBuffer *inBuffer) override;
+	virtual void 						SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier = EBarrier::Yes) override;
+	virtual void						ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc) override;
+	virtual void						Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ) override;
+	virtual void						Execute() override;
+	virtual void						Wait() override;
+
+private:
+	/// Copy the CPU buffer to the GPU buffer if needed
+	void								SyncCPUToGPU(const ComputeBufferDX12 *inBuffer);
+
+	ComPtr<ID3D12CommandQueue>			mCommandQueue;								///< The command queue that will hold command lists
+	ComPtr<ID3D12CommandAllocator>		mCommandAllocator;							///< Allocator that holds the memory for the commands
+	ComPtr<ID3D12GraphicsCommandList>	mCommandList;								///< The command list that will hold the render commands / state changes
+	HANDLE								mFenceEvent = INVALID_HANDLE_VALUE;			///< Fence event, used to wait for rendering to complete
+	ComPtr<ID3D12Fence>					mFence;										///< Fence object, used to signal the fence event
+	UINT64								mFenceValue = 0;							///< Current fence value, each time we need to wait we will signal the fence with this value, wait for it and then increase the value
+	RefConst<ComputeShaderDX12>			mShader = nullptr;							///< Current active shader
+	bool								mIsStarted = false;							///< If the command list has been started (reset) and is ready to record commands
+	bool								mIsExecuting = false;						///< If a command list is currently executing on the queue
+	UnorderedSet<RefConst<ComputeBuffer>> mUsedBuffers;								///< Buffers that are in use by the current execution, these will be retained until execution is finished so that we don't free buffers that are in use
+	Array<ComPtr<ID3D12Resource>>		mDelayedFreedBuffers;						///< Buffers freed during the execution
+};
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_DX12

+ 54 - 0
Jolt/Compute/DX12/ComputeShaderDX12.h

@@ -0,0 +1,54 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#ifdef JPH_USE_DX12
+
+#include <Jolt/Compute/ComputeShader.h>
+#include <Jolt/Compute/DX12/IncludeDX12.h>
+#include <Jolt/Core/UnorderedMap.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Compute shader handle for DirectX
+class JPH_EXPORT ComputeShaderDX12 : public ComputeShader
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Constructor
+									ComputeShaderDX12(ComPtr<ID3DBlob> inShader, ComPtr<ID3D12RootSignature> inRootSignature, ComPtr<ID3D12PipelineState> inPipelineState, Array<String> &&inBindingNames, UnorderedMap<string_view, uint> &&inNameToIndex, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) :
+		ComputeShader(inGroupSizeX, inGroupSizeY, inGroupSizeZ),
+		mShader(inShader),
+		mRootSignature(inRootSignature),
+		mPipelineState(inPipelineState),
+		mBindingNames(std::move(inBindingNames)),
+		mNameToIndex(std::move(inNameToIndex))
+	{
+	}
+
+	/// Get index of shader parameter
+	uint							NameToIndex(const char *inName) const
+	{
+		UnorderedMap<string_view, uint>::const_iterator it = mNameToIndex.find(inName);
+		JPH_ASSERT(it != mNameToIndex.end());
+		return it->second;
+	}
+
+	/// Getters
+	ID3D12PipelineState *			GetPipelineState() const				{ return mPipelineState.Get(); }
+	ID3D12RootSignature *			GetRootSignature() const				{ return mRootSignature.Get(); }
+
+private:
+	ComPtr<ID3DBlob>				mShader;								///< The compiled shader
+	ComPtr<ID3D12RootSignature>		mRootSignature;							///< The root signature for this shader
+	ComPtr<ID3D12PipelineState>		mPipelineState;							///< The pipeline state object for this shader
+	Array<String>					mBindingNames;							///< A list of binding names, mNameToIndex points to these strings
+	UnorderedMap<string_view, uint>	mNameToIndex;							///< Maps names to indices for the shader parameters, using a string_view so we can do find() without an allocation
+};
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_DX12

+ 412 - 0
Jolt/Compute/DX12/ComputeSystemDX12.cpp

@@ -0,0 +1,412 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#ifdef JPH_USE_DX12
+
+#include <Jolt/Compute/DX12/ComputeSystemDX12.h>
+#include <Jolt/Compute/DX12/ComputeQueueDX12.h>
+#include <Jolt/Compute/DX12/ComputeShaderDX12.h>
+#include <Jolt/Compute/DX12/ComputeBufferDX12.h>
+#include <Jolt/Core/StringTools.h>
+#include <Jolt/Core/UnorderedMap.h>
+
+JPH_SUPPRESS_WARNINGS_STD_BEGIN
+JPH_MSVC_SUPPRESS_WARNING(5204) // 'X': class has virtual functions, but its trivial destructor is not virtual; instances of objects derived from this class may not be destructed correctly
+JPH_MSVC2026_PLUS_SUPPRESS_WARNING(4865) // wingdi.h(2806,1): '<unnamed-enum-DISPLAYCONFIG_OUTPUT_TECHNOLOGY_OTHER>': the underlying type will change from 'int' to '__int64' when '/Zc:enumTypes' is specified on the command line
+#include <fstream>
+#include <d3dcompiler.h>
+#include <dxcapi.h>
+#ifdef JPH_DEBUG
+	#include <d3d12sdklayers.h>
+#endif
+JPH_SUPPRESS_WARNINGS_STD_END
+
+JPH_NAMESPACE_BEGIN
+
+void ComputeSystemDX12::Initialize(ID3D12Device *inDevice, EDebug inDebug)
+{
+	mDevice = inDevice;
+	mDebug = inDebug;
+}
+
+void ComputeSystemDX12::Shutdown()
+{
+	mDevice.Reset();
+}
+
+ComPtr<ID3D12Resource> ComputeSystemDX12::CreateD3DResource(D3D12_HEAP_TYPE inHeapType, D3D12_RESOURCE_STATES inResourceState, D3D12_RESOURCE_FLAGS inFlags, uint64 inSize)
+{
+	// Create a new resource
+	D3D12_RESOURCE_DESC desc;
+	desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
+	desc.Alignment = 0;
+	desc.Width = inSize;
+	desc.Height = 1;
+	desc.DepthOrArraySize = 1;
+	desc.MipLevels = 1;
+	desc.Format = DXGI_FORMAT_UNKNOWN;
+	desc.SampleDesc.Count = 1;
+	desc.SampleDesc.Quality = 0;
+	desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
+	desc.Flags = inFlags;
+
+	D3D12_HEAP_PROPERTIES heap_properties = {};
+	heap_properties.Type = inHeapType;
+	heap_properties.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN;
+	heap_properties.MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN;
+	heap_properties.CreationNodeMask = 1;
+	heap_properties.VisibleNodeMask = 1;
+
+	ComPtr<ID3D12Resource> resource;
+	if (HRFailed(mDevice->CreateCommittedResource(&heap_properties, D3D12_HEAP_FLAG_NONE, &desc, inResourceState, nullptr, IID_PPV_ARGS(&resource))))
+		return nullptr;
+	return resource;
+}
+
+Ref<ComputeShader> ComputeSystemDX12::CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ)
+{
+	// Read shader source file
+	Array<uint8> data;
+	String file_name = String(inName) + ".hlsl";
+	if (!mShaderLoader(file_name.c_str(), data))
+		return nullptr;
+
+#ifndef JPH_USE_DXC // Use FXC, the old shader compiler?
+
+	UINT flags = D3DCOMPILE_ENABLE_STRICTNESS | D3DCOMPILE_WARNINGS_ARE_ERRORS | D3DCOMPILE_ALL_RESOURCES_BOUND;
+#ifdef JPH_DEBUG
+	flags |= D3DCOMPILE_SKIP_OPTIMIZATION;
+#else
+	flags |= D3DCOMPILE_OPTIMIZATION_LEVEL3;
+#endif
+	if (mDebug == EDebug::DebugSymbols)
+		flags |= D3DCOMPILE_DEBUG;
+
+	const D3D_SHADER_MACRO defines[] =
+	{
+		{ nullptr, nullptr }
+	};
+
+	// Handles loading include files through the shader loader
+	struct IncludeHandler : public ID3DInclude
+	{
+								IncludeHandler(const ShaderLoader &inShaderLoader) : mShaderLoader(inShaderLoader) { }
+		virtual					~IncludeHandler() = default;
+
+		STDMETHOD				(Open)(D3D_INCLUDE_TYPE, LPCSTR inFileName, LPCVOID, LPCVOID *outData, UINT *outNumBytes) override
+		{
+			// Read the header file
+			Array<uint8> file_data;
+			if (!mShaderLoader(inFileName, file_data))
+				return E_FAIL;
+			if (file_data.empty())
+			{
+				*outData = nullptr;
+				*outNumBytes = 0;
+				return S_OK;
+			}
+
+			// Copy to a new memory block
+			void *mem = CoTaskMemAlloc(file_data.size());
+			if (mem == nullptr)
+				return E_OUTOFMEMORY;
+			memcpy(mem, file_data.data(), file_data.size());
+			*outData = mem;
+			*outNumBytes = (UINT)file_data.size();
+			return S_OK;
+		}
+
+		STDMETHOD				(Close)(LPCVOID inData) override
+		{
+			if (inData != nullptr)
+				CoTaskMemFree(const_cast<void *>(inData));
+			return S_OK;
+		}
+
+	private:
+		const ShaderLoader &	mShaderLoader;
+	};
+	IncludeHandler include_handler(mShaderLoader);
+
+	// Compile source
+	ComPtr<ID3DBlob> shader_blob, error_blob;
+	if (FAILED(D3DCompile(&data[0],
+				(uint)data.size(),
+				file_name.c_str(),
+				defines,
+				&include_handler,
+				"main",
+				"cs_5_0",
+				flags,
+				0,
+				shader_blob.GetAddressOf(),
+				error_blob.GetAddressOf())))
+	{
+		if (error_blob)
+			Trace("Shader compile error: %s", (const char *)error_blob->GetBufferPointer());
+		return nullptr;
+	}
+
+	// Get shader description
+	ComPtr<ID3D12ShaderReflection> reflector;
+	if (FAILED(D3DReflect(shader_blob->GetBufferPointer(), shader_blob->GetBufferSize(), IID_PPV_ARGS(&reflector))))
+		return nullptr;
+
+#else
+
+	ComPtr<IDxcUtils> utils;
+	DxcCreateInstance(CLSID_DxcUtils, IID_PPV_ARGS(utils.GetAddressOf()));
+
+	// Custom include handler that forwards include loads to mShaderLoader
+	struct DxcIncludeHandler : public IDxcIncludeHandler
+	{
+								DxcIncludeHandler(IDxcUtils *inUtils, const ShaderLoader &inLoader) : mUtils(inUtils), mShaderLoader(inLoader) { }
+		virtual					~DxcIncludeHandler() = default;
+
+		STDMETHODIMP			QueryInterface(REFIID riid, void **ppvObject) override
+		{
+			JPH_ASSERT(false);
+			return E_NOINTERFACE;
+		}
+
+		STDMETHODIMP_(ULONG)	AddRef(void) override
+		{
+			// Allocated on the stack, we don't do ref counting
+			return 1;
+		}
+
+		STDMETHODIMP_(ULONG)	Release(void) override
+		{
+			// Allocated on the stack, we don't do ref counting
+			return 1;
+		}
+
+		// IDxcIncludeHandler::LoadSource uses IDxcBlob**
+		STDMETHODIMP			LoadSource(LPCWSTR inFilename, IDxcBlob **outIncludeSource) override
+		{
+			*outIncludeSource = nullptr;
+
+			// Convert to UTF-8
+			char file_name[MAX_PATH];
+			WideCharToMultiByte(CP_UTF8, 0, inFilename, -1, file_name, sizeof(file_name), nullptr, nullptr);
+
+			// Load the header
+			Array<uint8> file_data;
+			if (!mShaderLoader(file_name, file_data))
+				return E_FAIL;
+
+			// Create a blob from the loaded data
+			ComPtr<IDxcBlobEncoding> blob_encoder;
+			HRESULT hr = mUtils->CreateBlob(file_data.empty()? nullptr : file_data.data(), (uint)file_data.size(), CP_UTF8, blob_encoder.GetAddressOf());
+			if (FAILED(hr))
+				return hr;
+
+			// Return as IDxcBlob
+			*outIncludeSource = blob_encoder.Detach();
+			return S_OK;
+		}
+
+		IDxcUtils *				mUtils;
+		const ShaderLoader &	mShaderLoader;
+	};
+	DxcIncludeHandler include_handler(utils.Get(), mShaderLoader);
+
+	ComPtr<IDxcBlobEncoding> source;
+	if (HRFailed(utils->CreateBlob(data.data(), (uint)data.size(), CP_UTF8, source.GetAddressOf())))
+		return nullptr;
+
+	ComPtr<IDxcCompiler3> compiler;
+	DxcCreateInstance(CLSID_DxcCompiler, IID_PPV_ARGS(compiler.GetAddressOf()));
+
+	Array<LPCWSTR> arguments;
+	arguments.push_back(L"-E");
+	arguments.push_back(L"main");
+	arguments.push_back(L"-T");
+	arguments.push_back(L"cs_6_0");
+	arguments.push_back(DXC_ARG_WARNINGS_ARE_ERRORS);
+	arguments.push_back(DXC_ARG_OPTIMIZATION_LEVEL3);
+	arguments.push_back(DXC_ARG_ALL_RESOURCES_BOUND);
+	if (mDebug == EDebug::DebugSymbols)
+		arguments.push_back(DXC_ARG_DEBUG);
+
+	// Compile the shader
+	DxcBuffer source_buffer;
+	source_buffer.Ptr = source->GetBufferPointer();
+	source_buffer.Size = source->GetBufferSize();
+	source_buffer.Encoding = 0;
+	ComPtr<IDxcResult> result;
+	if (FAILED(compiler->Compile(&source_buffer, arguments.data(), (uint32)arguments.size(), &include_handler, IID_PPV_ARGS(result.GetAddressOf()))))
+		return nullptr;
+
+	// Check for compilation errors
+	ComPtr<IDxcBlobUtf8> errors;
+	result->GetOutput(DXC_OUT_ERRORS, IID_PPV_ARGS(errors.GetAddressOf()), nullptr);
+	if (errors != nullptr && errors->GetStringLength() > 0)
+	{
+		Trace((char *)errors->GetBufferPointer());
+		return nullptr;
+	}
+
+	// Get the compiled shader code
+	ComPtr<ID3DBlob> shader_blob;
+	if (HRFailed(result->GetOutput(DXC_OUT_OBJECT, IID_PPV_ARGS(shader_blob.GetAddressOf()), nullptr)))
+		return nullptr;
+
+	if (mDebug == EDebug::DebugSymbols)
+	{
+		// Get shader hash and create PDB file name
+		ComPtr<IDxcBlob> hash;
+		if (HRFailed(result->GetOutput(DXC_OUT_SHADER_HASH, IID_PPV_ARGS(hash.GetAddressOf()), nullptr)))
+			return nullptr;
+		DxcShaderHash *hash_buf = (DxcShaderHash *)hash->GetBufferPointer();
+		String hash_str;
+		for (BYTE b : hash_buf->HashDigest)
+			hash_str += StringFormat("%02x", b);
+		hash_str += ".pdb";
+
+		// Get PDB file from the compiler
+		ComPtr<IDxcBlob> pdb;
+		if (HRFailed(result->GetOutput(DXC_OUT_PDB, IID_PPV_ARGS(pdb.GetAddressOf()), nullptr)))
+			return nullptr;
+
+		// Write PDB file to the temp folder
+		char temp_path[MAX_PATH];
+		GetTempPathA(MAX_PATH, temp_path);
+		std::ofstream pdb_stream((temp_path + hash_str).c_str(), std::ios::out | std::ios::binary);
+		pdb_stream.write((const char *)pdb->GetBufferPointer(), pdb->GetBufferSize());
+	}
+
+	// Get reflection data
+	ComPtr<IDxcBlob> reflection_data;
+	if (HRFailed(result->GetOutput(DXC_OUT_REFLECTION, IID_PPV_ARGS(reflection_data.GetAddressOf()), nullptr)))
+		return nullptr;
+	DxcBuffer reflection_buffer;
+	reflection_buffer.Ptr = reflection_data->GetBufferPointer();
+	reflection_buffer.Size = reflection_data->GetBufferSize();
+	reflection_buffer.Encoding = 0;
+	ComPtr<ID3D12ShaderReflection> reflector;
+	if (HRFailed(utils->CreateReflection(&reflection_buffer, IID_PPV_ARGS(reflector.GetAddressOf()))))
+		return nullptr;
+
+#endif // JPH_USE_DXC
+
+	// Get the shader description
+	D3D12_SHADER_DESC shader_desc;
+	if (HRFailed(reflector->GetDesc(&shader_desc)))
+		return nullptr;
+
+	// Verify that the group sizes match the shader's thread group size
+	UINT thread_group_size_x, thread_group_size_y, thread_group_size_z;
+	if (HRFailed(reflector->GetThreadGroupSize(&thread_group_size_x, &thread_group_size_y, &thread_group_size_z)))
+		return nullptr;
+	JPH_ASSERT(inGroupSizeX == thread_group_size_x, "Group size X mismatch");
+	JPH_ASSERT(inGroupSizeY == thread_group_size_y, "Group size Y mismatch");
+	JPH_ASSERT(inGroupSizeZ == thread_group_size_z, "Group size Z mismatch");
+
+	// Convert parameters to root signature description
+	Array<String> binding_names;
+	binding_names.reserve(shader_desc.BoundResources);
+	UnorderedMap<string_view, uint> name_to_index;
+	Array<D3D12_ROOT_PARAMETER1> root_params;
+	for (UINT i = 0; i < shader_desc.BoundResources; ++i)
+	{
+		D3D12_SHADER_INPUT_BIND_DESC bind_desc;
+		reflector->GetResourceBindingDesc(i, &bind_desc);
+
+		D3D12_ROOT_PARAMETER1 param = {};
+		param.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
+
+		switch (bind_desc.Type)
+		{
+		case D3D_SIT_CBUFFER:
+			param.ParameterType = D3D12_ROOT_PARAMETER_TYPE_CBV;
+			break;
+
+		case D3D_SIT_STRUCTURED:
+		case D3D_SIT_BYTEADDRESS:
+			param.ParameterType = D3D12_ROOT_PARAMETER_TYPE_SRV;
+			break;
+
+		case D3D_SIT_UAV_RWTYPED:
+		case D3D_SIT_UAV_RWSTRUCTURED:
+		case D3D_SIT_UAV_RWBYTEADDRESS:
+        case D3D_SIT_UAV_APPEND_STRUCTURED:
+        case D3D_SIT_UAV_CONSUME_STRUCTURED:
+		case D3D_SIT_UAV_RWSTRUCTURED_WITH_COUNTER:
+			param.ParameterType = D3D12_ROOT_PARAMETER_TYPE_UAV;
+			break;
+
+        case D3D_SIT_TBUFFER:
+        case D3D_SIT_TEXTURE:
+        case D3D_SIT_SAMPLER:
+        case D3D_SIT_RTACCELERATIONSTRUCTURE:
+		case D3D_SIT_UAV_FEEDBACKTEXTURE:
+			JPH_ASSERT(false, "Unsupported shader input type");
+			continue;
+		}
+
+		param.Descriptor.RegisterSpace = bind_desc.Space;
+		param.Descriptor.ShaderRegister = bind_desc.BindPoint;
+		param.Descriptor.Flags = D3D12_ROOT_DESCRIPTOR_FLAG_DATA_VOLATILE;
+
+		binding_names.push_back(bind_desc.Name); // Add all strings to a pool to keep them alive
+		name_to_index[string_view(binding_names.back())] = (uint)root_params.size();
+		root_params.push_back(param);
+	}
+
+	// Create the root signature
+	D3D12_VERSIONED_ROOT_SIGNATURE_DESC root_sig_desc = {};
+	root_sig_desc.Version = D3D_ROOT_SIGNATURE_VERSION_1_1;
+	root_sig_desc.Desc_1_1.NumParameters = (UINT)root_params.size();
+	root_sig_desc.Desc_1_1.pParameters = root_params.data();
+	root_sig_desc.Desc_1_1.NumStaticSamplers = 0;
+	root_sig_desc.Desc_1_1.pStaticSamplers = nullptr;
+	root_sig_desc.Desc_1_1.Flags = D3D12_ROOT_SIGNATURE_FLAG_NONE;
+	ComPtr<ID3DBlob> serialized_sig;
+	ComPtr<ID3DBlob> root_sig_error_blob;
+	if (FAILED(D3D12SerializeVersionedRootSignature(&root_sig_desc, &serialized_sig, &root_sig_error_blob)))
+	{
+		if (root_sig_error_blob)
+			Trace("Failed to create root signature: %s", (const char *)root_sig_error_blob->GetBufferPointer());
+		return nullptr;
+	}
+	ComPtr<ID3D12RootSignature> root_sig;
+	if (FAILED(mDevice->CreateRootSignature(0, serialized_sig->GetBufferPointer(), serialized_sig->GetBufferSize(), IID_PPV_ARGS(&root_sig))))
+		return nullptr;
+
+	// Create a pipeline state object from the root signature and the shader
+	ComPtr<ID3D12PipelineState> pipeline_state;
+	D3D12_COMPUTE_PIPELINE_STATE_DESC compute_state_desc = {};
+	compute_state_desc.pRootSignature = root_sig.Get();
+	compute_state_desc.CS = { shader_blob->GetBufferPointer(), shader_blob->GetBufferSize() };
+	if (FAILED(mDevice->CreateComputePipelineState(&compute_state_desc, IID_PPV_ARGS(&pipeline_state))))
+		return nullptr;
+
+	// Set name on DX12 objects for easier debugging
+	wchar_t w_name[1024];
+	size_t converted_chars = 0;
+	mbstowcs_s(&converted_chars, w_name, 1024, inName, _TRUNCATE);
+	pipeline_state->SetName(w_name);
+
+	return new ComputeShaderDX12(shader_blob, root_sig, pipeline_state, std::move(binding_names), std::move(name_to_index), inGroupSizeX, inGroupSizeY, inGroupSizeZ);
+}
+
+Ref<ComputeBuffer> ComputeSystemDX12::CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData)
+{
+	return new ComputeBufferDX12(this, inType, inSize, inStride, inData);
+}
+
+Ref<ComputeQueue> ComputeSystemDX12::CreateComputeQueue()
+{
+	Ref<ComputeQueueDX12> queue = new ComputeQueueDX12();
+	if (!queue->Initialize(mDevice.Get(), D3D12_COMMAND_LIST_TYPE_COMPUTE))
+		return nullptr;
+	return queue.GetPtr();
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_DX12

+ 52 - 0
Jolt/Compute/DX12/ComputeSystemDX12.h

@@ -0,0 +1,52 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Core/UnorderedMap.h>
+#include <Jolt/Compute/ComputeSystem.h>
+
+#ifdef JPH_USE_DX12
+
+#include <Jolt/Compute/DX12/IncludeDX12.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Interface to run a workload on the GPU using DirectX 12.
+/// Minimal implementation that can integrate with your own DirectX 12 setup.
+class JPH_EXPORT ComputeSystemDX12 : public ComputeSystem
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// How we want to compile our shaders
+	enum class EDebug
+	{
+		NoDebugSymbols,
+		DebugSymbols
+	};
+
+	/// Initialize / shutdown
+	void							Initialize(ID3D12Device *inDevice, EDebug inDebug);
+	void							Shutdown();
+
+	// See: ComputeSystem
+	virtual Ref<ComputeShader>		CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) override;
+	virtual Ref<ComputeBuffer>		CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData = nullptr) override;
+	virtual Ref<ComputeQueue>		CreateComputeQueue() override;
+
+	/// Access to the DX12 device
+	ID3D12Device *					GetDevice() const								{ return mDevice.Get(); }
+
+	// Function to create a ID3D12Resource on specified heap with specified state
+	ComPtr<ID3D12Resource>			CreateD3DResource(D3D12_HEAP_TYPE inHeapType, D3D12_RESOURCE_STATES inResourceState, D3D12_RESOURCE_FLAGS inFlags, uint64 inSize);
+
+private:
+	ComPtr<ID3D12Device>			mDevice;
+	EDebug							mDebug = EDebug::NoDebugSymbols;
+};
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_DX12

+ 147 - 0
Jolt/Compute/DX12/ComputeSystemDX12Impl.cpp

@@ -0,0 +1,147 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#ifdef JPH_USE_DX12
+
+#include <Jolt/Compute/DX12/ComputeSystemDX12Impl.h>
+
+#ifdef JPH_DEBUG
+	#include <d3d12sdklayers.h>
+#endif
+
+JPH_NAMESPACE_BEGIN
+
+ComputeSystemDX12Impl::~ComputeSystemDX12Impl()
+{
+	Shutdown();
+	mDXGIFactory.Reset();
+
+#ifdef JPH_DEBUG
+	// Test for leaks
+	ComPtr<IDXGIDebug1> dxgi_debug;
+	if (SUCCEEDED(DXGIGetDebugInterface1(0, IID_PPV_ARGS(&dxgi_debug))))
+		dxgi_debug->ReportLiveObjects(DXGI_DEBUG_ALL, DXGI_DEBUG_RLO_ALL);
+#endif
+}
+
+bool ComputeSystemDX12Impl::Initialize()
+{
+#if defined(JPH_DEBUG)
+	// Enable the D3D12 debug layer
+	ComPtr<ID3D12Debug> debug_controller;
+	if (SUCCEEDED(D3D12GetDebugInterface(IID_PPV_ARGS(&debug_controller))))
+		debug_controller->EnableDebugLayer();
+#endif
+
+	// Create DXGI factory
+	if (HRFailed(CreateDXGIFactory1(IID_PPV_ARGS(&mDXGIFactory))))
+		return false;
+
+	// Find adapter
+	ComPtr<IDXGIAdapter1> adapter;
+	ComPtr<ID3D12Device> device;
+
+	HRESULT result = E_FAIL;
+
+	// First check if we have the Windows 1803 IDXGIFactory6 interface
+	ComPtr<IDXGIFactory6> factory6;
+	if (SUCCEEDED(mDXGIFactory->QueryInterface(IID_PPV_ARGS(&factory6))))
+	{
+		for (int search_software = 0; search_software < 2 && device == nullptr; ++search_software)
+			for (UINT index = 0; factory6->EnumAdapterByGpuPreference(index, DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE, IID_PPV_ARGS(&adapter)) != DXGI_ERROR_NOT_FOUND; ++index)
+			{
+				DXGI_ADAPTER_DESC1 desc;
+				adapter->GetDesc1(&desc);
+
+				// We don't want software renderers in the first pass
+				int is_software = (desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE) != 0? 1 : 0;
+				if (search_software != is_software)
+					continue;
+
+				// Check to see whether the adapter supports Direct3D 12
+			#if defined(JPH_PLATFORM_WINDOWS) && defined(_DEBUG)
+				int prev_state = _CrtSetDbgFlag(0); // Temporarily disable leak detection as this call reports false positives
+			#endif
+				result = D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&device));
+			#if defined(JPH_PLATFORM_WINDOWS) && defined(_DEBUG)
+				_CrtSetDbgFlag(prev_state);
+			#endif
+				if (SUCCEEDED(result))
+					break;
+			}
+	}
+	else
+	{
+		// Fall back to the older method that may not get the fastest GPU
+		for (int search_software = 0; search_software < 2 && device == nullptr; ++search_software)
+			for (UINT index = 0; mDXGIFactory->EnumAdapters1(index, &adapter) != DXGI_ERROR_NOT_FOUND; ++index)
+			{
+				DXGI_ADAPTER_DESC1 desc;
+				adapter->GetDesc1(&desc);
+
+				// We don't want software renderers in the first pass
+				int is_software = (desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE) != 0? 1 : 0;
+				if (search_software != is_software)
+					continue;
+
+				// Check to see whether the adapter supports Direct3D 12
+			#if defined(JPH_PLATFORM_WINDOWS) && defined(_DEBUG)
+				int prev_state = _CrtSetDbgFlag(0); // Temporarily disable leak detection as this call reports false positives
+			#endif
+				result = D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&device));
+			#if defined(JPH_PLATFORM_WINDOWS) && defined(_DEBUG)
+				_CrtSetDbgFlag(prev_state);
+			#endif
+				if (SUCCEEDED(result))
+					break;
+			}
+	}
+
+	// Check if we managed to obtain a device
+	if (HRFailed(result))
+		return false;
+
+	// Initialize the compute interface
+	ComputeSystemDX12::Initialize(device.Get(), EDebug::DebugSymbols);
+
+#ifdef JPH_DEBUG
+	// Enable breaking on errors
+	ComPtr<ID3D12InfoQueue> info_queue;
+	if (SUCCEEDED(device.As(&info_queue)))
+	{
+		info_queue->SetBreakOnSeverity(D3D12_MESSAGE_SEVERITY_CORRUPTION, TRUE);
+		info_queue->SetBreakOnSeverity(D3D12_MESSAGE_SEVERITY_ERROR, TRUE);
+		info_queue->SetBreakOnSeverity(D3D12_MESSAGE_SEVERITY_WARNING, TRUE);
+
+		// Disable an error that triggers on Windows 11 with a hybrid graphic system
+		// See: https://stackoverflow.com/questions/69805245/directx-12-application-is-crashing-in-windows-11
+		D3D12_MESSAGE_ID hide[] =
+		{
+			D3D12_MESSAGE_ID_RESOURCE_BARRIER_MISMATCHING_COMMAND_LIST_TYPE,
+		};
+		D3D12_INFO_QUEUE_FILTER filter = { };
+		filter.DenyList.NumIDs = static_cast<UINT>( std::size( hide ) );
+		filter.DenyList.pIDList = hide;
+		info_queue->AddStorageFilterEntries( &filter );
+	}
+#endif // JPH_DEBUG
+
+	return true;
+}
+
+ComputeSystem *CreateComputeSystemDX12()
+{
+	ComputeSystemDX12Impl *compute = new ComputeSystemDX12Impl();
+	if (compute->Initialize())
+		return compute;
+
+	delete compute;
+	return nullptr;
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_DX12

+ 33 - 0
Jolt/Compute/DX12/ComputeSystemDX12Impl.h

@@ -0,0 +1,33 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#ifdef JPH_USE_DX12
+
+#include <Jolt/Compute/DX12/ComputeSystemDX12.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Implementation of ComputeSystemDX12 that fully initializes DirectX 12
+class JPH_EXPORT ComputeSystemDX12Impl : public ComputeSystemDX12
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Destructor
+	virtual 						~ComputeSystemDX12Impl() override;
+
+	/// Initialize the compute system
+	bool							Initialize();
+
+	IDXGIFactory4 *					GetDXGIFactory() const						{ return mDXGIFactory.Get(); }
+
+private:
+	ComPtr<IDXGIFactory4>			mDXGIFactory;
+};
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_DX12

+ 40 - 0
Jolt/Compute/DX12/IncludeDX12.h

@@ -0,0 +1,40 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#ifdef JPH_USE_DX12
+
+#include <Jolt/Core/IncludeWindows.h>
+
+JPH_SUPPRESS_WARNINGS_STD_BEGIN
+JPH_MSVC_SUPPRESS_WARNING(4265) // 'X': class has virtual functions, but its non-trivial destructor is not virtual; instances of this class may not be destructed correctly
+JPH_MSVC_SUPPRESS_WARNING(4625) // 'X': copy constructor was implicitly defined as deleted
+JPH_MSVC_SUPPRESS_WARNING(4626) // 'X': assignment operator was implicitly defined as deleted
+JPH_MSVC_SUPPRESS_WARNING(5204) // 'X': class has virtual functions, but its trivial destructor is not virtual; instances of objects derived from this class may not be destructed correctly
+JPH_MSVC_SUPPRESS_WARNING(5220) // 'X': a non-static data member with a volatile qualified type no longer implies
+JPH_MSVC2026_PLUS_SUPPRESS_WARNING(4865) // wingdi.h(2806,1): '<unnamed-enum-DISPLAYCONFIG_OUTPUT_TECHNOLOGY_OTHER>': the underlying type will change from 'int' to '__int64' when '/Zc:enumTypes' is specified on the command line
+#include <d3d12.h>
+#include <dxgi1_6.h>
+#include <dxgidebug.h>
+#include <wrl.h>
+JPH_SUPPRESS_WARNINGS_STD_END
+
+JPH_NAMESPACE_BEGIN
+
+using Microsoft::WRL::ComPtr;
+
+inline bool HRFailed(HRESULT inHR)
+{
+	if (SUCCEEDED(inHR))
+		return false;
+
+	Trace("Call failed with error code: %08X", inHR);
+	JPH_ASSERT(false);
+	return true;
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_DX12

+ 37 - 0
Jolt/Compute/MTL/ComputeBufferMTL.h

@@ -0,0 +1,37 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#ifdef JPH_USE_MTL
+
+#include <Jolt/Compute/MTL/ComputeSystemMTL.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Buffer that can be read from / written to by a compute shader
+class JPH_EXPORT ComputeBufferMTL final : public ComputeBuffer
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Constructor
+									ComputeBufferMTL(ComputeSystemMTL *inComputeSystem, EType inType, uint64 inSize, uint inStride, const void *inData);
+	virtual							~ComputeBufferMTL() override;
+
+	virtual void					Unmap() override;
+
+	virtual Ref<ComputeBuffer>		CreateReadBackBuffer() const override;
+
+	id<MTLBuffer>					GetBuffer() const							{ return mBuffer; }
+
+private:
+	virtual void *					MapInternal(EMode inMode) override;
+
+	id<MTLBuffer>					mBuffer;
+};
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_MTL

+ 44 - 0
Jolt/Compute/MTL/ComputeBufferMTL.mm

@@ -0,0 +1,44 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#ifdef JPH_USE_MTL
+
+#include <Jolt/Compute/MTL/ComputeBufferMTL.h>
+
+JPH_NAMESPACE_BEGIN
+
+ComputeBufferMTL::ComputeBufferMTL(ComputeSystemMTL *inComputeSystem, EType inType, uint64 inSize, uint inStride, const void *inData) :
+	ComputeBuffer(inType, inSize, inStride)
+{
+	NSUInteger size = NSUInteger(inSize) * inStride;
+	if (inData != nullptr)
+		mBuffer = [inComputeSystem->GetDevice() newBufferWithBytes: inData length: size options: MTLResourceCPUCacheModeDefaultCache | MTLResourceStorageModeShared | MTLResourceHazardTrackingModeTracked];
+	else
+		mBuffer = [inComputeSystem->GetDevice() newBufferWithLength: size options: MTLResourceCPUCacheModeDefaultCache | MTLResourceStorageModeShared | MTLResourceHazardTrackingModeTracked];
+}
+
+ComputeBufferMTL::~ComputeBufferMTL()
+{
+	[mBuffer release];
+}
+
+void *ComputeBufferMTL::MapInternal(EMode inMode)
+{
+	return mBuffer.contents;
+}
+
+void ComputeBufferMTL::Unmap()
+{
+}
+
+Ref<ComputeBuffer> ComputeBufferMTL::CreateReadBackBuffer() const
+{
+	return const_cast<ComputeBufferMTL *>(this);
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_MTL

+ 49 - 0
Jolt/Compute/MTL/ComputeQueueMTL.h

@@ -0,0 +1,49 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#ifdef JPH_USE_MTL
+
+#include <MetalKit/MetalKit.h>
+
+#include <Jolt/Compute/ComputeQueue.h>
+
+JPH_NAMESPACE_BEGIN
+
+class ComputeShaderMTL;
+
+/// A command queue for Metal for executing compute workloads on the GPU.
+class JPH_EXPORT ComputeQueueMTL final : public ComputeQueue
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Constructor / destructor
+										ComputeQueueMTL(id<MTLDevice> inDevice);
+	virtual								~ComputeQueueMTL() override;
+
+	// See: ComputeQueue
+	virtual void						SetShader(const ComputeShader *inShader) override;
+	virtual void						SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer) override;
+	virtual void						SetBuffer(const char *inName, const ComputeBuffer *inBuffer) override;
+	virtual void 						SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier = EBarrier::Yes) override;
+	virtual void						ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc) override;
+	virtual void						Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ) override;
+	virtual void						Execute() override;
+	virtual void						Wait() override;
+
+private:
+	void								BeginCommandBuffer();
+
+	id<MTLCommandQueue>					mCommandQueue;
+	id<MTLCommandBuffer> 				mCommandBuffer;
+	id<MTLComputeCommandEncoder>		mComputeEncoder;
+	RefConst<ComputeShaderMTL>			mShader;
+	bool								mIsExecuting = false;
+};
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_MTL

+ 123 - 0
Jolt/Compute/MTL/ComputeQueueMTL.mm

@@ -0,0 +1,123 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#ifdef JPH_USE_MTL
+
+#include <Jolt/Compute/MTL/ComputeQueueMTL.h>
+#include <Jolt/Compute/MTL/ComputeShaderMTL.h>
+#include <Jolt/Compute/MTL/ComputeBufferMTL.h>
+#include <Jolt/Compute/MTL/ComputeSystemMTL.h>
+
+JPH_NAMESPACE_BEGIN
+
+ComputeQueueMTL::~ComputeQueueMTL()
+{
+	Wait();
+
+	[mCommandQueue release];
+}
+
+ComputeQueueMTL::ComputeQueueMTL(id<MTLDevice> inDevice)
+{
+	// Create the command queue
+	mCommandQueue = [inDevice newCommandQueue];
+}
+
+void ComputeQueueMTL::BeginCommandBuffer()
+{
+	if (mCommandBuffer == nil)
+	{
+		// Start a new command buffer
+		mCommandBuffer = [mCommandQueue commandBuffer];
+		mComputeEncoder = [mCommandBuffer computeCommandEncoder];
+	}
+}
+
+void ComputeQueueMTL::SetShader(const ComputeShader *inShader)
+{
+	BeginCommandBuffer();
+
+	mShader = static_cast<const ComputeShaderMTL *>(inShader);
+
+	[mComputeEncoder setComputePipelineState: mShader->GetPipelineState()];
+}
+
+void ComputeQueueMTL::SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer)
+{
+	if (inBuffer == nullptr)
+		return;
+	JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::ConstantBuffer);
+
+	BeginCommandBuffer();
+
+	const ComputeBufferMTL *buffer = static_cast<const ComputeBufferMTL *>(inBuffer);
+	[mComputeEncoder setBuffer: buffer->GetBuffer() offset: 0 atIndex: mShader->NameToBindingIndex(inName)];
+}
+
+void ComputeQueueMTL::SetBuffer(const char *inName, const ComputeBuffer *inBuffer)
+{
+	if (inBuffer == nullptr)
+		return;
+	JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::UploadBuffer || inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
+
+	BeginCommandBuffer();
+
+	const ComputeBufferMTL *buffer = static_cast<const ComputeBufferMTL *>(inBuffer);
+	[mComputeEncoder setBuffer: buffer->GetBuffer() offset: 0 atIndex: mShader->NameToBindingIndex(inName)];
+}
+
+void ComputeQueueMTL::SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier)
+{
+	if (inBuffer == nullptr)
+		return;
+	JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
+
+	BeginCommandBuffer();
+
+	const ComputeBufferMTL *buffer = static_cast<const ComputeBufferMTL *>(inBuffer);
+	[mComputeEncoder setBuffer: buffer->GetBuffer() offset: 0 atIndex: mShader->NameToBindingIndex(inName)];
+}
+
+void ComputeQueueMTL::ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc)
+{
+	JPH_ASSERT(inDst == inSrc); // Since ComputeBuffer::CreateReadBackBuffer returns the same buffer, we don't need to copy
+}
+
+void ComputeQueueMTL::Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ)
+{
+	BeginCommandBuffer();
+
+	MTLSize thread_groups = MTLSizeMake(inThreadGroupsX, inThreadGroupsY, inThreadGroupsZ);
+	MTLSize group_size = MTLSizeMake(mShader->GetGroupSizeX(), mShader->GetGroupSizeY(), mShader->GetGroupSizeZ());
+	[mComputeEncoder dispatchThreadgroups: thread_groups threadsPerThreadgroup: group_size];
+}
+
+void ComputeQueueMTL::Execute()
+{
+	// End command buffer
+	if (mCommandBuffer == nil)
+		return;
+
+	[mComputeEncoder endEncoding];
+	[mCommandBuffer commit];
+	mShader = nullptr;
+	mIsExecuting = true;
+}
+
+void ComputeQueueMTL::Wait()
+{
+	if (!mIsExecuting)
+		return;
+
+	[mCommandBuffer waitUntilCompleted];
+	mComputeEncoder = nil;
+	mCommandBuffer = nil;
+	mIsExecuting = false;
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_MTL

+ 39 - 0
Jolt/Compute/MTL/ComputeShaderMTL.h

@@ -0,0 +1,39 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#ifdef JPH_USE_MTL
+
+#include <MetalKit/MetalKit.h>
+
+#include <Jolt/Compute/ComputeShader.h>
+#include <Jolt/Core/UnorderedMap.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Compute shader handle for Metal
+class JPH_EXPORT ComputeShaderMTL : public ComputeShader
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Constructor
+								ComputeShaderMTL(id<MTLComputePipelineState> inPipelineState, MTLComputePipelineReflection *inReflection, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ);
+	virtual						~ComputeShaderMTL() override 					{ [mPipelineState release]; }
+
+	/// Access to the function
+	id<MTLComputePipelineState>	GetPipelineState() const						{ return mPipelineState; }
+
+	/// Get index of buffer name
+	uint						NameToBindingIndex(const char *inName) const;
+
+private:
+	id<MTLComputePipelineState>	mPipelineState;
+	UnorderedMap<String, uint>	mNameToBindingIndex;
+};
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_MTL

+ 34 - 0
Jolt/Compute/MTL/ComputeShaderMTL.mm

@@ -0,0 +1,34 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#ifdef JPH_USE_MTL
+
+#include <Jolt/Compute/MTL/ComputeShaderMTL.h>
+
+JPH_NAMESPACE_BEGIN
+
+ComputeShaderMTL::ComputeShaderMTL(id<MTLComputePipelineState> inPipelineState, MTLComputePipelineReflection *inReflection, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) :
+	ComputeShader(inGroupSizeX, inGroupSizeY, inGroupSizeZ),
+	mPipelineState(inPipelineState)
+{
+	for (id<MTLBinding> binding in inReflection.bindings)
+	{
+		const char *name = [binding.name UTF8String];
+		uint index = uint(binding.index);
+		mNameToBindingIndex[name] = index;
+	}
+}
+
+uint ComputeShaderMTL::NameToBindingIndex(const char *inName) const
+{
+	UnorderedMap<String, uint>::const_iterator it = mNameToBindingIndex.find(inName);
+	JPH_ASSERT(it != mNameToBindingIndex.end());
+	return it->second;
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_MTL

+ 40 - 0
Jolt/Compute/MTL/ComputeSystemMTL.h

@@ -0,0 +1,40 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Compute/ComputeSystem.h>
+
+#ifdef JPH_USE_MTL
+
+#include <MetalKit/MetalKit.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Interface to run a workload on the GPU
+class JPH_EXPORT ComputeSystemMTL : public ComputeSystem
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	// Initialize / shutdown the compute system
+	bool							Initialize(id<MTLDevice> inDevice);
+	void							Shutdown();
+
+	// See: ComputeSystem
+	virtual Ref<ComputeShader>		CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) override;
+	virtual Ref<ComputeBuffer>		CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData = nullptr) override;
+	virtual Ref<ComputeQueue>		CreateComputeQueue() override;
+
+	/// Get the metal device
+	id<MTLDevice>					GetDevice() const						{ return mDevice; }
+
+private:
+	id<MTLDevice>					mDevice;
+	id<MTLLibrary>					mShaderLibrary;
+};
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_MTL

+ 89 - 0
Jolt/Compute/MTL/ComputeSystemMTL.mm

@@ -0,0 +1,89 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#ifdef JPH_USE_MTL
+
+#include <Jolt/Compute/MTL/ComputeSystemMTL.h>
+#include <Jolt/Compute/MTL/ComputeBufferMTL.h>
+#include <Jolt/Compute/MTL/ComputeShaderMTL.h>
+#include <Jolt/Compute/MTL/ComputeQueueMTL.h>
+
+JPH_NAMESPACE_BEGIN
+
+bool ComputeSystemMTL::Initialize(id<MTLDevice> inDevice)
+{
+	mDevice = [inDevice retain];
+
+	return true;
+}
+
+void ComputeSystemMTL::Shutdown()
+{
+	[mShaderLibrary release];
+	[mDevice release];
+}
+
+Ref<ComputeShader> ComputeSystemMTL::CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ)
+{
+	if (mShaderLibrary == nil)
+	{
+		// Load the shader library containing all shaders
+		Array<uint8> *data = new Array<uint8>();
+		if (!mShaderLoader("Jolt.metallib", *data))
+		{
+			JPH_ASSERT(false, "Failed to load shader library");
+			delete data;
+			return nullptr;
+		}
+
+		// Convert to dispatch data
+		dispatch_data_t data_dispatch = dispatch_data_create(data->data(), data->size(), nullptr, ^{ delete data; });
+
+		// Create the library
+		NSError *error = nullptr;
+		mShaderLibrary = [mDevice newLibraryWithData: data_dispatch error: &error];
+		if (error != nil)
+		{
+			JPH_ASSERT(false, "Failed to load shader library");
+			return nullptr;
+		}
+	}
+
+	// Get the shader function
+	id<MTLFunction> function = [mShaderLibrary newFunctionWithName: [NSString stringWithCString: inName encoding: NSUTF8StringEncoding]];
+	if (function == nil)
+	{
+		Trace("Failed to create compute shader: %s", inName);
+		return nullptr;
+	}
+
+	// Create the pipeline
+	NSError *error = nil;
+	MTLComputePipelineReflection *reflection = nil;
+	id<MTLComputePipelineState> pipeline_state = [mDevice newComputePipelineStateWithFunction: function options: MTLPipelineOptionBindingInfo | MTLPipelineOptionBufferTypeInfo reflection: &reflection error: &error];
+	if (error != nil || pipeline_state == nil)
+	{
+		JPH_ASSERT(false, "Failed to create compute pipeline");
+		[function release];
+		return nullptr;
+	}
+
+	return new ComputeShaderMTL(pipeline_state, reflection, inGroupSizeX, inGroupSizeY, inGroupSizeZ);
+}
+
+Ref<ComputeBuffer> ComputeSystemMTL::CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData)
+{
+	return new ComputeBufferMTL(this, inType, inSize, inStride, inData);
+}
+
+Ref<ComputeQueue> ComputeSystemMTL::CreateComputeQueue()
+{
+	return new ComputeQueueMTL(mDevice);
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_MTL

+ 28 - 0
Jolt/Compute/MTL/ComputeSystemMTLImpl.h

@@ -0,0 +1,28 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#ifdef JPH_USE_MTL
+
+#include <Jolt/Compute/MTL/ComputeSystemMTL.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Interface to run a workload on the GPU that fully initializes Metal.
+class JPH_EXPORT ComputeSystemMTLImpl : public ComputeSystemMTL
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Destructor
+	virtual							~ComputeSystemMTLImpl() override;
+
+	/// Initialize / shutdown the compute system
+	bool							Initialize();
+};
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_MTL

+ 39 - 0
Jolt/Compute/MTL/ComputeSystemMTLImpl.mm

@@ -0,0 +1,39 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#ifdef JPH_USE_MTL
+
+#include <Jolt/Compute/MTL/ComputeSystemMTLImpl.h>
+
+JPH_NAMESPACE_BEGIN
+
+ComputeSystemMTLImpl::~ComputeSystemMTLImpl()
+{
+	Shutdown();
+
+	[GetDevice() release];
+}
+
+bool ComputeSystemMTLImpl::Initialize()
+{
+	id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+
+	return ComputeSystemMTL::Initialize(device);
+}
+
+ComputeSystem *CreateComputeSystemMTL()
+{
+	ComputeSystemMTLImpl *compute = new ComputeSystemMTLImpl;
+	if (compute->Initialize())
+		return compute;
+
+	delete compute;
+	return nullptr;
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_MTL

+ 41 - 0
Jolt/Compute/VK/BufferVK.h

@@ -0,0 +1,41 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2024 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Compute/VK/IncludeVK.h>
+#include <Jolt/Core/Reference.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Simple wrapper class to manage a Vulkan memory block
+class MemoryVK : public RefTarget<MemoryVK>
+{
+public:
+								~MemoryVK()
+	{
+		// We should have unmapped and freed the block before destruction
+		JPH_ASSERT(mMappedCount == 0);
+		JPH_ASSERT(mMemory == VK_NULL_HANDLE);
+	}
+
+	VkDeviceMemory				mMemory = VK_NULL_HANDLE;		///< The Vulkan memory handle
+	VkDeviceSize				mSize = 0;						///< Size of the memory block
+	VkDeviceSize				mBufferSize = 0;				///< Size of each of the buffers that this memory block has been divided into
+	VkMemoryPropertyFlags		mProperties = 0;				///< Vulkan memory properties used to allocate this block
+	int							mMappedCount = 0;				///< How often buffers using this memory block were mapped
+	void *						mMappedPtr = nullptr;			///< The CPU address of the memory block when mapped
+};
+
+/// Simple wrapper class to manage a Vulkan buffer
+class BufferVK
+{
+public:
+	Ref<MemoryVK>				mMemory;						///< The memory block that contains the buffer (note that filling this in is optional if you do your own buffer allocation)
+	VkBuffer					mBuffer = VK_NULL_HANDLE;		///< The Vulkan buffer handle
+	VkDeviceSize				mOffset = 0;					///< Offset in the memory block where the buffer starts
+	VkDeviceSize				mSize = 0;						///< Real size of the buffer
+};
+
+JPH_NAMESPACE_END

+ 130 - 0
Jolt/Compute/VK/ComputeBufferVK.cpp

@@ -0,0 +1,130 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#ifdef JPH_USE_VK
+
+#include <Jolt/Compute/VK/ComputeBufferVK.h>
+#include <Jolt/Compute/VK/ComputeSystemVK.h>
+
+JPH_NAMESPACE_BEGIN
+
+ComputeBufferVK::ComputeBufferVK(ComputeSystemVK *inComputeSystem, EType inType, uint64 inSize, uint inStride, const void *inData) :
+	ComputeBuffer(inType, inSize, inStride),
+	mComputeSystem(inComputeSystem)
+{
+	VkDeviceSize buffer_size = VkDeviceSize(inSize * inStride);
+
+	switch (inType)
+	{
+	case EType::Buffer:
+		JPH_ASSERT(inData != nullptr);
+		[[fallthrough]];
+
+	case EType::UploadBuffer:
+	case EType::RWBuffer:
+		mComputeSystem->CreateBuffer(buffer_size, VK_BUFFER_USAGE_TRANSFER_SRC_BIT, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT | VK_MEMORY_PROPERTY_HOST_CACHED_BIT, mBufferCPU);
+		mComputeSystem->CreateBuffer(buffer_size, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT, mBufferGPU);
+		if (inData != nullptr)
+		{
+			void *data = mComputeSystem->MapBuffer(mBufferCPU);
+			memcpy(data, inData, size_t(buffer_size));
+			mComputeSystem->UnmapBuffer(mBufferCPU);
+			mNeedsSync = true;
+		}
+		break;
+
+	case EType::ConstantBuffer:
+		mComputeSystem->CreateBuffer(buffer_size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT | VK_MEMORY_PROPERTY_HOST_CACHED_BIT, mBufferCPU);
+		if (inData != nullptr)
+		{
+			void* data = mComputeSystem->MapBuffer(mBufferCPU);
+			memcpy(data, inData, size_t(buffer_size));
+			mComputeSystem->UnmapBuffer(mBufferCPU);
+		}
+		break;
+
+	case EType::ReadbackBuffer:
+		JPH_ASSERT(inData == nullptr, "Can't upload data to a readback buffer");
+		mComputeSystem->CreateBuffer(buffer_size, VK_BUFFER_USAGE_TRANSFER_DST_BIT, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT | VK_MEMORY_PROPERTY_HOST_CACHED_BIT, mBufferCPU);
+		break;
+	}
+}
+
+ComputeBufferVK::~ComputeBufferVK()
+{
+	mComputeSystem->FreeBuffer(mBufferGPU);
+	mComputeSystem->FreeBuffer(mBufferCPU);
+}
+
+void ComputeBufferVK::Barrier(VkCommandBuffer inCommandBuffer, VkPipelineStageFlags inToStage, VkAccessFlagBits inToFlags, bool inForce) const
+{
+	if (mAccessStage == inToStage && mAccessFlagBits == inToFlags && !inForce)
+		return;
+
+	VkBufferMemoryBarrier b = {};
+	b.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER;
+	b.srcAccessMask = mAccessFlagBits;
+	b.dstAccessMask = inToFlags;
+	b.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
+	b.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
+	b.buffer = mBufferGPU.mBuffer != VK_NULL_HANDLE? mBufferGPU.mBuffer : mBufferCPU.mBuffer;
+	b.offset = 0;
+	b.size = VK_WHOLE_SIZE;
+	vkCmdPipelineBarrier(inCommandBuffer, mAccessStage, inToStage, 0, 0, nullptr, 1, &b, 0, nullptr);
+
+	mAccessStage = inToStage;
+	mAccessFlagBits = inToFlags;
+}
+
+bool ComputeBufferVK::SyncCPUToGPU(VkCommandBuffer inCommandBuffer) const
+{
+	if (!mNeedsSync)
+		return false;
+
+	// Barrier before write
+	Barrier(inCommandBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_ACCESS_TRANSFER_WRITE_BIT, false);
+
+	// Copy from CPU to GPU
+	VkBufferCopy copy = {};
+	copy.srcOffset = 0;
+	copy.dstOffset = 0;
+	copy.size = GetSize() * GetStride();
+	vkCmdCopyBuffer(inCommandBuffer, mBufferCPU.mBuffer, mBufferGPU.mBuffer, 1, &copy);
+
+	mNeedsSync = false;
+	return true;
+}
+
+void *ComputeBufferVK::MapInternal(EMode inMode)
+{
+	switch (inMode)
+	{
+	case EMode::Read:
+		JPH_ASSERT(mType == EType::ReadbackBuffer);
+		break;
+
+	case EMode::Write:
+		JPH_ASSERT(mType == EType::UploadBuffer || mType == EType::ConstantBuffer);
+		mNeedsSync = true;
+		break;
+	}
+
+	return mComputeSystem->MapBuffer(mBufferCPU);
+}
+
+void ComputeBufferVK::Unmap()
+{
+	mComputeSystem->UnmapBuffer(mBufferCPU);
+}
+
+Ref<ComputeBuffer> ComputeBufferVK::CreateReadBackBuffer() const
+{
+	return mComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::ReadbackBuffer, mSize, mStride);
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_VK

+ 51 - 0
Jolt/Compute/VK/ComputeBufferVK.h

@@ -0,0 +1,51 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Compute/ComputeBuffer.h>
+
+#ifdef JPH_USE_VK
+
+#include <Jolt/Compute/VK/BufferVK.h>
+
+JPH_NAMESPACE_BEGIN
+
+class ComputeSystemVK;
+
+/// Buffer that can be read from / written to by a compute shader
+class JPH_EXPORT ComputeBufferVK final : public ComputeBuffer
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Constructor
+									ComputeBufferVK(ComputeSystemVK *inComputeSystem, EType inType, uint64 inSize, uint inStride, const void *inData);
+	virtual							~ComputeBufferVK() override;
+
+	virtual void					Unmap() override;
+
+	virtual Ref<ComputeBuffer>		CreateReadBackBuffer() const override;
+
+	VkBuffer						GetBufferCPU() const									{ return mBufferCPU.mBuffer; }
+	VkBuffer						GetBufferGPU() const									{ return mBufferGPU.mBuffer; }
+	BufferVK						ReleaseBufferCPU() const								{ BufferVK tmp = mBufferCPU; mBufferCPU = BufferVK(); return tmp; }
+
+	void							Barrier(VkCommandBuffer inCommandBuffer, VkPipelineStageFlags inToStage, VkAccessFlagBits inToFlags, bool inForce) const;
+	bool							SyncCPUToGPU(VkCommandBuffer inCommandBuffer) const;
+
+private:
+	virtual void *					MapInternal(EMode inMode) override;
+
+	ComputeSystemVK *				mComputeSystem;
+	mutable BufferVK				mBufferCPU;
+	BufferVK						mBufferGPU;
+	mutable bool					mNeedsSync = false;										///< If this buffer needs to be synced from CPU to GPU
+	mutable VkAccessFlagBits		mAccessFlagBits = VK_ACCESS_SHADER_READ_BIT;			///< Access flags of the last usage, used for barriers
+	mutable VkPipelineStageFlags	mAccessStage = VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT;	///< Pipeline stage of the last usage, used for barriers
+};
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_VK

+ 304 - 0
Jolt/Compute/VK/ComputeQueueVK.cpp

@@ -0,0 +1,304 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#ifdef JPH_USE_VK
+
+#include <Jolt/Compute/VK/ComputeQueueVK.h>
+#include <Jolt/Compute/VK/ComputeBufferVK.h>
+#include <Jolt/Compute/VK/ComputeSystemVK.h>
+
+JPH_NAMESPACE_BEGIN
+
+ComputeQueueVK::~ComputeQueueVK()
+{
+	Wait();
+
+	VkDevice device = mComputeSystem->GetDevice();
+
+	if (mCommandBuffer != VK_NULL_HANDLE)
+		vkFreeCommandBuffers(device, mCommandPool, 1, &mCommandBuffer);
+
+	if (mCommandPool != VK_NULL_HANDLE)
+		vkDestroyCommandPool(device, mCommandPool, nullptr);
+
+	if (mDescriptorPool != VK_NULL_HANDLE)
+		vkDestroyDescriptorPool(device, mDescriptorPool, nullptr);
+
+	if (mFence != VK_NULL_HANDLE)
+		vkDestroyFence(device, mFence, nullptr);
+}
+
+bool ComputeQueueVK::Initialize(uint32 inComputeQueueIndex)
+{
+	// Get the queue
+	VkDevice device = mComputeSystem->GetDevice();
+	vkGetDeviceQueue(device, inComputeQueueIndex, 0, &mQueue);
+
+	// Create a command pool
+	VkCommandPoolCreateInfo pool_info = {};
+	pool_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
+	pool_info.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
+	pool_info.queueFamilyIndex = inComputeQueueIndex;
+	if (VKFailed(vkCreateCommandPool(device, &pool_info, nullptr, &mCommandPool)))
+		return false;
+
+	// Create descriptor pool
+	VkDescriptorPoolSize descriptor_pool_sizes[] = {
+		{ VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, 1024 },
+		{ VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 16 * 1024 },
+	};
+	VkDescriptorPoolCreateInfo descriptor_info = {};
+	descriptor_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
+	descriptor_info.poolSizeCount = std::size(descriptor_pool_sizes);
+	descriptor_info.pPoolSizes = descriptor_pool_sizes;
+	descriptor_info.maxSets = 256;
+	if (VKFailed(vkCreateDescriptorPool(device, &descriptor_info, nullptr, &mDescriptorPool)))
+		return false;
+
+	// Create a command buffer
+	VkCommandBufferAllocateInfo alloc_info = {};
+	alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
+	alloc_info.commandPool = mCommandPool;
+	alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
+	alloc_info.commandBufferCount = 1;
+	if (VKFailed(vkAllocateCommandBuffers(device, &alloc_info, &mCommandBuffer)))
+		return false;
+
+	// Create a fence
+	VkFenceCreateInfo fence_info = {};
+	fence_info.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
+	if (VKFailed(vkCreateFence(device, &fence_info, nullptr, &mFence)))
+		return false;
+
+	return true;
+}
+
+bool ComputeQueueVK::BeginCommandBuffer()
+{
+	if (!mCommandBufferRecording)
+	{
+		VkCommandBufferBeginInfo begin_info = {};
+		begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
+		begin_info.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
+		if (VKFailed(vkBeginCommandBuffer(mCommandBuffer, &begin_info)))
+			return false;
+		mCommandBufferRecording = true;
+	}
+	return true;
+}
+
+void ComputeQueueVK::SetShader(const ComputeShader *inShader)
+{
+	mShader = static_cast<const ComputeShaderVK *>(inShader);
+	mBufferInfos = mShader->GetBufferInfos();
+}
+
+void ComputeQueueVK::SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer)
+{
+	if (inBuffer == nullptr)
+		return;
+	JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::ConstantBuffer);
+
+	if (!BeginCommandBuffer())
+		return;
+
+	const ComputeBufferVK *buffer = static_cast<const ComputeBufferVK *>(inBuffer);
+	buffer->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_ACCESS_UNIFORM_READ_BIT, false);
+
+	uint index = mShader->NameToBufferInfoIndex(inName);
+	JPH_ASSERT(mShader->GetLayoutBindings()[index].descriptorType == VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER);
+	mBufferInfos[index].buffer = buffer->GetBufferCPU();
+
+	mUsedBuffers.insert(buffer);
+}
+
+void ComputeQueueVK::SyncCPUToGPU(const ComputeBufferVK *inBuffer)
+{
+	// Ensure that any CPU writes are visible to the GPU
+	if (inBuffer->SyncCPUToGPU(mCommandBuffer))
+	{
+		// After the first upload, the CPU buffer is no longer needed for Buffer and RWBuffer types
+		if (inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType()  == ComputeBuffer::EType::RWBuffer)
+			mDelayedFreedBuffers.push_back(inBuffer->ReleaseBufferCPU());
+	}
+}
+
+void ComputeQueueVK::SetBuffer(const char *inName, const ComputeBuffer *inBuffer)
+{
+	if (inBuffer == nullptr)
+		return;
+	JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::UploadBuffer || inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
+
+	if (!BeginCommandBuffer())
+		return;
+
+	const ComputeBufferVK *buffer = static_cast<const ComputeBufferVK *>(inBuffer);
+	SyncCPUToGPU(buffer);
+	buffer->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_ACCESS_SHADER_READ_BIT, false);
+
+	uint index = mShader->NameToBufferInfoIndex(inName);
+	JPH_ASSERT(mShader->GetLayoutBindings()[index].descriptorType == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
+	mBufferInfos[index].buffer = buffer->GetBufferGPU();
+
+	mUsedBuffers.insert(buffer);
+}
+
+void ComputeQueueVK::SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier)
+{
+	if (inBuffer == nullptr)
+		return;
+	JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
+
+	if (!BeginCommandBuffer())
+		return;
+
+	const ComputeBufferVK *buffer = static_cast<const ComputeBufferVK *>(inBuffer);
+	SyncCPUToGPU(buffer);
+	buffer->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VkAccessFlagBits(VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT), inBarrier == EBarrier::Yes);
+
+	uint index = mShader->NameToBufferInfoIndex(inName);
+	JPH_ASSERT(mShader->GetLayoutBindings()[index].descriptorType == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
+	mBufferInfos[index].buffer = buffer->GetBufferGPU();
+
+	mUsedBuffers.insert(buffer);
+}
+
+void ComputeQueueVK::ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc)
+{
+	if (inDst == nullptr || inSrc == nullptr)
+		return;
+	JPH_ASSERT(inDst->GetType() == ComputeBuffer::EType::ReadbackBuffer);
+
+	if (!BeginCommandBuffer())
+		return;
+
+	const ComputeBufferVK *src_vk = static_cast<const ComputeBufferVK *>(inSrc);
+	ComputeBufferVK *dst_vk = static_cast<ComputeBufferVK *>(inDst);
+
+	// Barrier to start reading from GPU buffer and writing to CPU buffer
+	src_vk->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_ACCESS_TRANSFER_READ_BIT, false);
+	dst_vk->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_ACCESS_TRANSFER_WRITE_BIT, false);
+
+	// Copy
+	VkBufferCopy copy = {};
+	copy.srcOffset = 0;
+	copy.dstOffset = 0;
+	copy.size = src_vk->GetSize() * src_vk->GetStride();
+	vkCmdCopyBuffer(mCommandBuffer, src_vk->GetBufferGPU(), dst_vk->GetBufferCPU(), 1, &copy);
+
+	// Barrier to indicate that CPU can read from the buffer
+	dst_vk->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_HOST_BIT, VK_ACCESS_HOST_READ_BIT, false);
+
+	mUsedBuffers.insert(src_vk);
+	mUsedBuffers.insert(dst_vk);
+}
+
+void ComputeQueueVK::Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ)
+{
+	if (!BeginCommandBuffer())
+		return;
+
+	vkCmdBindPipeline(mCommandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, mShader->GetPipeline());
+
+	VkDevice device = mComputeSystem->GetDevice();
+	const Array<VkDescriptorSetLayoutBinding> &ds_bindings = mShader->GetLayoutBindings();
+	if (!ds_bindings.empty())
+	{
+		// Create a descriptor set
+		VkDescriptorSetAllocateInfo alloc_info = {};
+		alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
+		alloc_info.descriptorPool = mDescriptorPool;
+		alloc_info.descriptorSetCount = 1;
+		VkDescriptorSetLayout ds_layout = mShader->GetDescriptorSetLayout();
+		alloc_info.pSetLayouts = &ds_layout;
+		VkDescriptorSet descriptor_set;
+		if (VKFailed(vkAllocateDescriptorSets(device, &alloc_info, &descriptor_set)))
+			return;
+
+		// Write the values to the descriptor set
+		Array<VkWriteDescriptorSet> writes;
+		writes.reserve(ds_bindings.size());
+		for (uint32 i = 0; i < (uint32)ds_bindings.size(); ++i)
+		{
+			VkWriteDescriptorSet w = {};
+			w.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
+			w.dstSet = descriptor_set;
+			w.dstBinding = ds_bindings[i].binding;
+			w.dstArrayElement = 0;
+			w.descriptorCount = ds_bindings[i].descriptorCount;
+			w.descriptorType = ds_bindings[i].descriptorType;
+			w.pBufferInfo = &mBufferInfos[i];
+			writes.push_back(w);
+		}
+		vkUpdateDescriptorSets(device, (uint32)writes.size(), writes.data(), 0, nullptr);
+
+		// Bind the descriptor set
+		vkCmdBindDescriptorSets(mCommandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, mShader->GetPipelineLayout(), 0, 1, &descriptor_set, 0, nullptr);
+	}
+
+	vkCmdDispatch(mCommandBuffer, inThreadGroupsX, inThreadGroupsY, inThreadGroupsZ);
+}
+
+void ComputeQueueVK::Execute()
+{
+	// End command buffer
+	if (!mCommandBufferRecording)
+		return;
+	if (VKFailed(vkEndCommandBuffer(mCommandBuffer)))
+		return;
+	mCommandBufferRecording = false;
+
+	// Reset fence
+	VkDevice device = mComputeSystem->GetDevice();
+	if (VKFailed(vkResetFences(device, 1, &mFence)))
+		return;
+
+	// Submit
+	VkSubmitInfo submit = {};
+	submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
+	submit.commandBufferCount = 1;
+	submit.pCommandBuffers = &mCommandBuffer;
+	if (VKFailed(vkQueueSubmit(mQueue, 1, &submit, mFence)))
+		return;
+
+	// Clear the current shader
+	mShader = nullptr;
+
+	// Mark that we're executing
+	mIsExecuting = true;
+}
+
+void ComputeQueueVK::Wait()
+{
+	if (!mIsExecuting)
+		return;
+
+	// Wait for the work to complete
+	VkDevice device = mComputeSystem->GetDevice();
+	if (VKFailed(vkWaitForFences(device, 1, &mFence, VK_TRUE, UINT64_MAX)))
+		return;
+
+	// Reset command buffer so it can be reused
+	if (mCommandBuffer != VK_NULL_HANDLE)
+		vkResetCommandBuffer(mCommandBuffer, 0);
+
+	// Allow reusing the descriptors for next run
+	vkResetDescriptorPool(device, mDescriptorPool, 0);
+
+	// Buffers can be freed now
+	mUsedBuffers.clear();
+
+	// Free delayed buffers
+	for (BufferVK &buffer : mDelayedFreedBuffers)
+		mComputeSystem->FreeBuffer(buffer);
+	mDelayedFreedBuffers.clear();
+
+	mIsExecuting = false;
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_VK

+ 66 - 0
Jolt/Compute/VK/ComputeQueueVK.h

@@ -0,0 +1,66 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Compute/ComputeQueue.h>
+
+#ifdef JPH_USE_VK
+
+#include <Jolt/Compute/VK/ComputeShaderVK.h>
+#include <Jolt/Compute/VK/BufferVK.h>
+#include <Jolt/Core/UnorderedMap.h>
+#include <Jolt/Core/UnorderedSet.h>
+
+JPH_NAMESPACE_BEGIN
+
+class ComputeSystemVK;
+class ComputeBufferVK;
+
+/// A command queue for Vulkan for executing compute workloads on the GPU.
+class JPH_EXPORT ComputeQueueVK final : public ComputeQueue
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Constructor / Destructor
+										ComputeQueueVK(ComputeSystemVK *inComputeSystem) : mComputeSystem(inComputeSystem) { }
+	virtual								~ComputeQueueVK() override;
+
+	/// Initialize the queue
+	bool								Initialize(uint32 inComputeQueueIndex);
+
+	// See: ComputeQueue
+	virtual void						SetShader(const ComputeShader *inShader) override;
+	virtual void						SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer) override;
+	virtual void						SetBuffer(const char *inName, const ComputeBuffer *inBuffer) override;
+	virtual void 						SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier = EBarrier::Yes) override;
+	virtual void						ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc) override;
+	virtual void						Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ) override;
+	virtual void						Execute() override;
+	virtual void						Wait() override;
+
+private:
+	bool								BeginCommandBuffer();
+
+	// Copy the CPU buffer to the GPU buffer if needed
+	void								SyncCPUToGPU(const ComputeBufferVK *inBuffer);
+
+	ComputeSystemVK *					mComputeSystem;
+	VkQueue								mQueue = VK_NULL_HANDLE;
+	VkCommandPool						mCommandPool = VK_NULL_HANDLE;
+	VkDescriptorPool					mDescriptorPool = VK_NULL_HANDLE;
+	VkCommandBuffer						mCommandBuffer = VK_NULL_HANDLE;
+	bool								mCommandBufferRecording = false;				///< If we are currently recording commands into the command buffer
+	VkFence								mFence = VK_NULL_HANDLE;
+	bool								mIsExecuting = false;							///< If Execute has been called and we are waiting for it to finish
+	RefConst<ComputeShaderVK>			mShader;										///< Shader that has been activated
+	Array<VkDescriptorBufferInfo>		mBufferInfos;									///< List of parameters that will be sent to the current shader
+	UnorderedSet<RefConst<ComputeBuffer>> mUsedBuffers;									///< Buffers that are in use by the current execution, these will be retained until execution is finished so that we don't free buffers that are in use
+	Array<BufferVK>						mDelayedFreedBuffers;							///< Hardware buffers that need to be freed after execution is done
+};
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_VK

+ 232 - 0
Jolt/Compute/VK/ComputeShaderVK.cpp

@@ -0,0 +1,232 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#ifdef JPH_USE_VK
+
+#include <Jolt/Compute/VK/ComputeShaderVK.h>
+
+JPH_NAMESPACE_BEGIN
+
+ComputeShaderVK::~ComputeShaderVK()
+{
+	if (mShaderModule != VK_NULL_HANDLE)
+		vkDestroyShaderModule(mDevice, mShaderModule, nullptr);
+
+	if (mDescriptorSetLayout != VK_NULL_HANDLE)
+		vkDestroyDescriptorSetLayout(mDevice, mDescriptorSetLayout, nullptr);
+
+	if (mPipelineLayout != VK_NULL_HANDLE)
+		vkDestroyPipelineLayout(mDevice, mPipelineLayout, nullptr);
+
+	if (mPipeline != VK_NULL_HANDLE)
+		vkDestroyPipeline(mDevice, mPipeline, nullptr);
+}
+
+bool ComputeShaderVK::Initialize(const Array<uint8> &inSPVCode, VkBuffer inDummyBuffer)
+{
+	const uint32 *spv_words = reinterpret_cast<const uint32 *>(inSPVCode.data());
+	size_t spv_word_count = inSPVCode.size() / sizeof(uint32);
+
+	// Minimal SPIR-V parser to extract name to binding info
+	UnorderedMap<uint32, String> id_to_name;
+	UnorderedMap<uint32, uint32> id_to_binding;
+	UnorderedMap<uint32, VkDescriptorType> id_to_descriptor_type;
+	UnorderedMap<uint32, uint32> pointer_to_pointee;
+	UnorderedMap<uint32, uint32> var_to_ptr_type;
+	size_t i = 5; // Skip 5 word header
+	while (i < spv_word_count)
+	{
+		// Parse next word
+		uint32 word = spv_words[i];
+		uint16 opcode = uint16(word & 0xffff);
+		uint16 word_count = uint16(word >> 16);
+		if (word_count == 0 || i + word_count > spv_word_count)
+			break;
+
+		switch (opcode)
+		{
+		case 5: // OpName
+			if (word_count >= 2)
+			{
+				uint32 target_id = spv_words[i + 1];
+				const char* name = reinterpret_cast<const char*>(&spv_words[i + 2]);
+				if (*name != 0)
+					id_to_name.insert({ target_id, name });
+			}
+			break;
+
+		case 16: // OpExecutionMode
+			if (word_count >= 6)
+			{
+				uint32 execution_mode = spv_words[i + 2];
+				if (execution_mode == 17) // LocalSize
+				{
+					// Assert that the group size provided matches the one in the shader
+					JPH_ASSERT(GetGroupSizeX() == spv_words[i + 3], "Group size X mismatch");
+					JPH_ASSERT(GetGroupSizeY() == spv_words[i + 4], "Group size Y mismatch");
+					JPH_ASSERT(GetGroupSizeZ() == spv_words[i + 5], "Group size Z mismatch");
+				}
+			}
+			break;
+
+		case 32: // OpTypePointer
+			if (word_count >= 4)
+			{
+				uint32 result_id = spv_words[i + 1];
+				uint32 type_id = spv_words[i + 3];
+				pointer_to_pointee.insert({ result_id, type_id });
+			}
+			break;
+
+		case 59: // OpVariable
+			if (word_count >= 3)
+			{
+				uint32 ptr_type_id = spv_words[i + 1];
+				uint32 result_id = spv_words[i + 2];
+				var_to_ptr_type.insert({ result_id, ptr_type_id });
+			}
+			break;
+
+		case 71: // OpDecorate
+			if (word_count >= 3)
+			{
+				uint32 target_id = spv_words[i + 1];
+				uint32 decoration = spv_words[i + 2];
+				if (decoration == 2) // Block
+				{
+					id_to_descriptor_type.insert({ target_id, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER });
+				}
+				else if (decoration == 3) // BufferBlock
+				{
+					id_to_descriptor_type.insert({ target_id, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER });
+				}
+				else if (decoration == 33 && word_count >= 4) // Binding
+				{
+					uint32 binding = spv_words[i + 3];
+					id_to_binding.insert({ target_id, binding });
+				}
+			}
+			break;
+
+		default:
+			break;
+		}
+
+		i += word_count;
+	}
+
+	// Build name to binding map
+	UnorderedMap<String, std::pair<uint32, VkDescriptorType>> name_to_binding;
+	for (const UnorderedMap<uint32, uint32>::value_type& entry : id_to_binding)
+	{
+		uint32 target_id = entry.first;
+		uint32 binding = entry.second;
+
+		// Get the name of the variable
+		UnorderedMap<uint32, String>::const_iterator it_name = id_to_name.find(target_id);
+		if (it_name != id_to_name.end())
+		{
+			// Find variable that links to the target
+			UnorderedMap<uint32, uint32>::const_iterator it_var_ptr = var_to_ptr_type.find(target_id);
+			if (it_var_ptr != var_to_ptr_type.end())
+			{
+				// Find type pointed at
+				uint32 ptr_type = it_var_ptr->second;
+				UnorderedMap<uint32, uint32>::const_iterator it_pointee = pointer_to_pointee.find(ptr_type);
+				if (it_pointee != pointer_to_pointee.end())
+				{
+					uint32 pointee_type = it_pointee->second;
+
+					// Find descriptor type
+					UnorderedMap<uint32, VkDescriptorType>::iterator it_descriptor_type = id_to_descriptor_type.find(pointee_type);
+					VkDescriptorType descriptor_type = it_descriptor_type != id_to_descriptor_type.end() ? it_descriptor_type->second : VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+
+					name_to_binding.insert({ it_name->second, { binding, descriptor_type } });
+					continue;
+				}
+			}
+		}
+	}
+
+	// Create layout bindings and buffer infos
+	if (!name_to_binding.empty())
+	{
+		mLayoutBindings.reserve(name_to_binding.size());
+		mBufferInfos.reserve(name_to_binding.size());
+
+		mBindingNames.reserve(name_to_binding.size());
+		for (const UnorderedMap<String, std::pair<uint32, VkDescriptorType>>::value_type &b : name_to_binding)
+		{
+			const String &name = b.first;
+			uint binding = b.second.first;
+			VkDescriptorType descriptor_type = b.second.second;
+
+			VkDescriptorSetLayoutBinding l = {};
+			l.binding = binding;
+			l.descriptorCount = 1;
+			l.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
+			l.descriptorType = descriptor_type;
+			mLayoutBindings.push_back(l);
+
+			mBindingNames.push_back(name); // Add all strings to a pool to keep them alive
+			mNameToBufferInfoIndex[string_view(mBindingNames.back())] = (uint32)mBufferInfos.size();
+
+			VkDescriptorBufferInfo bi = {};
+			bi.offset = 0;
+			bi.range = VK_WHOLE_SIZE;
+			bi.buffer = inDummyBuffer; // Avoid: The Vulkan spec states: If the nullDescriptor feature is not enabled, buffer must not be VK_NULL_HANDLE
+			mBufferInfos.push_back(bi);
+		}
+
+		// Create descriptor set layout
+		VkDescriptorSetLayoutCreateInfo layout_info = {};
+		layout_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
+		layout_info.bindingCount = (uint32)mLayoutBindings.size();
+		layout_info.pBindings = mLayoutBindings.data();
+		if (VKFailed(vkCreateDescriptorSetLayout(mDevice, &layout_info, nullptr, &mDescriptorSetLayout)))
+			return false;
+	}
+
+	// Create pipeline layout
+	VkPipelineLayoutCreateInfo pl_info = {};
+	pl_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
+	pl_info.setLayoutCount = mDescriptorSetLayout != VK_NULL_HANDLE ? 1 : 0;
+	pl_info.pSetLayouts = mDescriptorSetLayout != VK_NULL_HANDLE ? &mDescriptorSetLayout : nullptr;
+	if (VKFailed(vkCreatePipelineLayout(mDevice, &pl_info, nullptr, &mPipelineLayout)))
+		return false;
+
+	// Create shader module
+	VkShaderModuleCreateInfo create_info = {};
+	create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
+	create_info.codeSize = inSPVCode.size();
+	create_info.pCode = spv_words;
+	if (VKFailed(vkCreateShaderModule(mDevice, &create_info, nullptr, &mShaderModule)))
+		return false;
+
+	// Create compute pipeline
+	VkComputePipelineCreateInfo pipe_info = {};
+	pipe_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
+	pipe_info.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
+	pipe_info.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
+	pipe_info.stage.module = mShaderModule;
+	pipe_info.stage.pName = "main";
+	pipe_info.layout = mPipelineLayout;
+	if (VKFailed(vkCreateComputePipelines(mDevice, VK_NULL_HANDLE, 1, &pipe_info, nullptr, &mPipeline)))
+		return false;
+
+	return true;
+}
+
+uint32 ComputeShaderVK::NameToBufferInfoIndex(const char *inName) const
+{
+	UnorderedMap<string_view, uint>::const_iterator it = mNameToBufferInfoIndex.find(inName);
+	JPH_ASSERT(it != mNameToBufferInfoIndex.end());
+	return it->second;
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_VK

+ 53 - 0
Jolt/Compute/VK/ComputeShaderVK.h

@@ -0,0 +1,53 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Compute/ComputeShader.h>
+
+#ifdef JPH_USE_VK
+
+#include <Jolt/Compute/VK/IncludeVK.h>
+#include <Jolt/Core/UnorderedMap.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Compute shader handle for Vulkan
+class JPH_EXPORT ComputeShaderVK : public ComputeShader
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Constructor / destructor
+										ComputeShaderVK(VkDevice inDevice, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) : ComputeShader(inGroupSizeX, inGroupSizeY, inGroupSizeZ), mDevice(inDevice) { }
+	virtual								~ComputeShaderVK() override;
+
+	/// Initialize from SPIR-V code
+	bool								Initialize(const Array<uint8> &inSPVCode, VkBuffer inDummyBuffer);
+
+	/// Get index of parameter in buffer infos
+	uint32								NameToBufferInfoIndex(const char *inName) const;
+
+	/// Getters
+	VkPipeline							GetPipeline() const							{ return mPipeline; }
+	VkPipelineLayout					GetPipelineLayout() const					{ return mPipelineLayout; }
+	VkDescriptorSetLayout				GetDescriptorSetLayout() const				{ return mDescriptorSetLayout; }
+	const Array<VkDescriptorSetLayoutBinding> &GetLayoutBindings() const			{ return mLayoutBindings; }
+	const Array<VkDescriptorBufferInfo> &GetBufferInfos() const						{ return mBufferInfos; }
+
+private:
+	VkDevice							mDevice;
+	VkShaderModule						mShaderModule = VK_NULL_HANDLE;
+	VkPipelineLayout					mPipelineLayout = VK_NULL_HANDLE;
+	VkPipeline							mPipeline = VK_NULL_HANDLE;
+	VkDescriptorSetLayout				mDescriptorSetLayout = VK_NULL_HANDLE;
+	Array<String>						mBindingNames;								///< A list of binding names, mNameToBufferInfoIndex points to these strings
+	UnorderedMap<string_view, uint32>	mNameToBufferInfoIndex;						///< Binding name to buffer index, using a string_view so we can do find() without an allocation
+	Array<VkDescriptorSetLayoutBinding>	mLayoutBindings;
+	Array<VkDescriptorBufferInfo>		mBufferInfos;
+};
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_VK

+ 90 - 0
Jolt/Compute/VK/ComputeSystemVK.cpp

@@ -0,0 +1,90 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#ifdef JPH_USE_VK
+
+#include <Jolt/Compute/VK/ComputeSystemVK.h>
+#include <Jolt/Compute/VK/ComputeShaderVK.h>
+#include <Jolt/Compute/VK/ComputeBufferVK.h>
+#include <Jolt/Compute/VK/ComputeQueueVK.h>
+
+JPH_NAMESPACE_BEGIN
+
+bool ComputeSystemVK::Initialize(VkPhysicalDevice inPhysicalDevice, VkDevice inDevice, uint32 inComputeQueueIndex)
+{
+	mPhysicalDevice = inPhysicalDevice;
+	mDevice = inDevice;
+	mComputeQueueIndex = inComputeQueueIndex;
+
+	// Get function to set a debug name
+	mVkSetDebugUtilsObjectNameEXT = reinterpret_cast<PFN_vkSetDebugUtilsObjectNameEXT>(reinterpret_cast<void *>(vkGetDeviceProcAddr(mDevice, "vkSetDebugUtilsObjectNameEXT")));
+
+	if (!InitializeMemory())
+		return false;
+
+	// Create the dummy buffer. This is used to bind to shaders for which we have no buffer. We can't rely on VK_EXT_robustness2 being available to set nullDescriptor = VK_TRUE (it is unavailable on macOS).
+	CreateBuffer(1024, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT, mDummyBuffer);
+
+	return true;
+}
+
+void ComputeSystemVK::Shutdown()
+{
+	if (mDevice != VK_NULL_HANDLE)
+		vkDeviceWaitIdle(mDevice);
+
+	// Free the dummy buffer
+	FreeBuffer(mDummyBuffer);
+
+	ShutdownMemory();
+}
+
+Ref<ComputeShader> ComputeSystemVK::CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ)
+{
+	// Read shader source file
+	Array<uint8> data;
+	String file_name = String(inName) + ".spv";
+	if (!mShaderLoader(file_name.c_str(), data))
+		return nullptr;
+
+	Ref<ComputeShaderVK> shader = new ComputeShaderVK(mDevice, inGroupSizeX, inGroupSizeY, inGroupSizeZ);
+	if (!shader->Initialize(data, mDummyBuffer.mBuffer))
+	{
+		Trace("Failed to create compute shader: %s", file_name.c_str());
+		return nullptr;
+	}
+
+	// Name the pipeline so we can easily find it in a profile
+	if (mVkSetDebugUtilsObjectNameEXT != nullptr)
+	{
+		VkDebugUtilsObjectNameInfoEXT info = {};
+		info.sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_OBJECT_NAME_INFO_EXT;
+		info.pNext = nullptr;
+		info.objectType = VK_OBJECT_TYPE_PIPELINE;
+		info.objectHandle = (uint64)shader->GetPipeline();
+		info.pObjectName = inName;
+		mVkSetDebugUtilsObjectNameEXT(mDevice, &info);
+	}
+
+	return shader.GetPtr();
+}
+
+Ref<ComputeBuffer> ComputeSystemVK::CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData)
+{
+	return new ComputeBufferVK(this, inType, inSize, inStride, inData);
+}
+
+Ref<ComputeQueue> ComputeSystemVK::CreateComputeQueue()
+{
+	Ref<ComputeQueueVK> q = new ComputeQueueVK(this);
+	if (!q->Initialize(mComputeQueueIndex))
+		return nullptr;
+	return q.GetPtr();
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_VK

+ 57 - 0
Jolt/Compute/VK/ComputeSystemVK.h

@@ -0,0 +1,57 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <Jolt/Compute/ComputeSystem.h>
+
+#ifdef JPH_USE_VK
+
+#include <Jolt/Compute/VK/ComputeQueueVK.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Interface to run a workload on the GPU using Vulkan.
+/// Minimal implementation that can integrate with your own Vulkan setup.
+class JPH_EXPORT ComputeSystemVK : public ComputeSystem
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	// Initialize / shutdown the compute system
+	bool							Initialize(VkPhysicalDevice inPhysicalDevice, VkDevice inDevice, uint32 inComputeQueueIndex);
+	void							Shutdown();
+
+	// See: ComputeSystem
+	virtual Ref<ComputeShader>		CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ) override;
+	virtual Ref<ComputeBuffer>		CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData = nullptr) override;
+	virtual Ref<ComputeQueue>		CreateComputeQueue() override;
+
+	/// Access to the Vulkan device
+	VkDevice						GetDevice() const												{ return mDevice; }
+
+	/// Allow the application to override buffer creation and memory mapping in case it uses its own allocator
+	virtual void					CreateBuffer(VkDeviceSize inSize, VkBufferUsageFlags inUsage, VkMemoryPropertyFlags inProperties, BufferVK &outBuffer) = 0;
+	virtual void					FreeBuffer(BufferVK &ioBuffer) = 0;
+	virtual void *					MapBuffer(BufferVK &ioBuffer) = 0;
+	virtual void					UnmapBuffer(BufferVK &ioBuffer) = 0;
+
+protected:
+	/// Initialize / shutdown the memory subsystem
+	virtual bool					InitializeMemory() = 0;
+	virtual void					ShutdownMemory() = 0;
+
+	VkPhysicalDevice				mPhysicalDevice = VK_NULL_HANDLE;
+	VkDevice						mDevice = VK_NULL_HANDLE;
+	uint32							mComputeQueueIndex = 0;
+	PFN_vkSetDebugUtilsObjectNameEXT mVkSetDebugUtilsObjectNameEXT = nullptr;
+
+private:
+	// Buffer that can be bound when we have no buffer
+	BufferVK						mDummyBuffer;
+};
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_VK

+ 308 - 0
Jolt/Compute/VK/ComputeSystemVKImpl.cpp

@@ -0,0 +1,308 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#ifdef JPH_USE_VK
+
+#include <Jolt/Compute/VK/ComputeSystemVKImpl.h>
+#include <Jolt/Core/QuickSort.h>
+
+JPH_NAMESPACE_BEGIN
+
+#ifdef JPH_DEBUG
+
+static VKAPI_ATTR VkBool32 VKAPI_CALL sVulkanDebugCallback(VkDebugUtilsMessageSeverityFlagBitsEXT inSeverity, [[maybe_unused]] VkDebugUtilsMessageTypeFlagsEXT inType, const VkDebugUtilsMessengerCallbackDataEXT *inCallbackData, [[maybe_unused]] void *inUserData)
+{
+	if (inSeverity & (VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT))
+		Trace("VK: %s", inCallbackData->pMessage);
+	JPH_ASSERT((inSeverity & VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT) == 0);
+	return VK_FALSE;
+}
+
+#endif // JPH_DEBUG
+
+ComputeSystemVKImpl::~ComputeSystemVKImpl()
+{
+	ComputeSystemVK::Shutdown();
+
+	if (mDevice != VK_NULL_HANDLE)
+		vkDestroyDevice(mDevice, nullptr);
+
+#ifdef JPH_DEBUG
+	PFN_vkDestroyDebugUtilsMessengerEXT vkDestroyDebugUtilsMessengerEXT = (PFN_vkDestroyDebugUtilsMessengerEXT)(void *)vkGetInstanceProcAddr(mInstance, "vkDestroyDebugUtilsMessengerEXT");
+	if (mInstance != VK_NULL_HANDLE && mDebugMessenger != VK_NULL_HANDLE && vkDestroyDebugUtilsMessengerEXT != nullptr)
+		vkDestroyDebugUtilsMessengerEXT(mInstance, mDebugMessenger, nullptr);
+#endif
+
+	if (mInstance != VK_NULL_HANDLE)
+		vkDestroyInstance(mInstance, nullptr);
+}
+
+bool ComputeSystemVKImpl::Initialize()
+{
+	// Required instance extensions
+	Array<const char *> required_instance_extensions;
+	required_instance_extensions.push_back(VK_KHR_SURFACE_EXTENSION_NAME);
+	required_instance_extensions.push_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME);
+#ifdef JPH_PLATFORM_MACOS
+	required_instance_extensions.push_back("VK_KHR_portability_enumeration");
+	required_instance_extensions.push_back("VK_KHR_get_physical_device_properties2");
+#endif
+	GetInstanceExtensions(required_instance_extensions);
+
+	// Required device extensions
+	Array<const char *> required_device_extensions;
+	required_device_extensions.push_back(VK_EXT_SCALAR_BLOCK_LAYOUT_EXTENSION_NAME);
+#ifdef JPH_PLATFORM_MACOS
+	required_device_extensions.push_back("VK_KHR_portability_subset"); // VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME
+#endif
+	GetDeviceExtensions(required_device_extensions);
+
+	// Query supported instance extensions
+	uint32 instance_extension_count = 0;
+	if (VKFailed(vkEnumerateInstanceExtensionProperties(nullptr, &instance_extension_count, nullptr)))
+		return false;
+	Array<VkExtensionProperties> instance_extensions;
+	instance_extensions.resize(instance_extension_count);
+	if (VKFailed(vkEnumerateInstanceExtensionProperties(nullptr, &instance_extension_count, instance_extensions.data())))
+		return false;
+
+	// Query supported validation layers
+	uint32 validation_layer_count;
+	vkEnumerateInstanceLayerProperties(&validation_layer_count, nullptr);
+	Array<VkLayerProperties> validation_layers(validation_layer_count);
+	vkEnumerateInstanceLayerProperties(&validation_layer_count, validation_layers.data());
+
+	VkApplicationInfo app_info = {};
+	app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
+	app_info.apiVersion = VK_API_VERSION_1_1;
+
+	// Create Vulkan instance
+	VkInstanceCreateInfo instance_create_info = {};
+	instance_create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
+#ifdef JPH_PLATFORM_MACOS
+	instance_create_info.flags = VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR;
+#endif
+	instance_create_info.pApplicationInfo = &app_info;
+
+#ifdef JPH_DEBUG
+	// Enable validation layer if supported
+	const char *desired_validation_layers[] = { "VK_LAYER_KHRONOS_validation" };
+	for (const VkLayerProperties &p : validation_layers)
+		if (strcmp(desired_validation_layers[0], p.layerName) == 0)
+		{
+			instance_create_info.enabledLayerCount = 1;
+			instance_create_info.ppEnabledLayerNames = desired_validation_layers;
+			break;
+		}
+
+	// Setup debug messenger callback if the extension is supported
+	VkDebugUtilsMessengerCreateInfoEXT messenger_create_info = {};
+	for (const VkExtensionProperties &ext : instance_extensions)
+		if (strcmp(VK_EXT_DEBUG_UTILS_EXTENSION_NAME, ext.extensionName) == 0)
+		{
+			messenger_create_info.sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT;
+			messenger_create_info.messageSeverity = VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT;
+			messenger_create_info.messageType = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT;
+			messenger_create_info.pfnUserCallback = sVulkanDebugCallback;
+			instance_create_info.pNext = &messenger_create_info;
+			required_instance_extensions.push_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME);
+			break;
+		}
+#endif
+
+	instance_create_info.enabledExtensionCount = (uint32)required_instance_extensions.size();
+	instance_create_info.ppEnabledExtensionNames = required_instance_extensions.data();
+	if (VKFailed(vkCreateInstance(&instance_create_info, nullptr, &mInstance)))
+		return false;
+
+#ifdef JPH_DEBUG
+	// Finalize debug messenger callback
+	PFN_vkCreateDebugUtilsMessengerEXT vkCreateDebugUtilsMessengerEXT = (PFN_vkCreateDebugUtilsMessengerEXT)(std::uintptr_t)vkGetInstanceProcAddr(mInstance, "vkCreateDebugUtilsMessengerEXT");
+	if (vkCreateDebugUtilsMessengerEXT != nullptr)
+		if (VKFailed(vkCreateDebugUtilsMessengerEXT(mInstance, &messenger_create_info, nullptr, &mDebugMessenger)))
+			return false;
+#endif
+
+	// Notify that instance has been created
+	OnInstanceCreated();
+
+	// Select device
+	uint32 device_count = 0;
+	if (VKFailed(vkEnumeratePhysicalDevices(mInstance, &device_count, nullptr)))
+		return false;
+	Array<VkPhysicalDevice> devices;
+	devices.resize(device_count);
+	if (VKFailed(vkEnumeratePhysicalDevices(mInstance, &device_count, devices.data())))
+		return false;
+	struct Device
+	{
+		VkPhysicalDevice		mPhysicalDevice;
+		String					mName;
+		VkSurfaceFormatKHR		mFormat;
+		uint32					mGraphicsQueueIndex;
+		uint32					mPresentQueueIndex;
+		uint32					mComputeQueueIndex;
+		int						mScore;
+	};
+	Array<Device> available_devices;
+	for (VkPhysicalDevice device : devices)
+	{
+		// Get device properties
+		VkPhysicalDeviceProperties properties;
+		vkGetPhysicalDeviceProperties(device, &properties);
+
+		// Test if it is an appropriate type
+		int score = 0;
+		switch (properties.deviceType)
+		{
+		case VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU:
+			score = 30;
+			break;
+		case VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU:
+			score = 20;
+			break;
+		case VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU:
+			score = 10;
+			break;
+		case VK_PHYSICAL_DEVICE_TYPE_CPU:
+			score = 5;
+			break;
+		case VK_PHYSICAL_DEVICE_TYPE_OTHER:
+		case VK_PHYSICAL_DEVICE_TYPE_MAX_ENUM:
+			continue;
+		}
+
+		// Check if the device supports all our required extensions
+		uint32 device_extension_count;
+		vkEnumerateDeviceExtensionProperties(device, nullptr, &device_extension_count, nullptr);
+		Array<VkExtensionProperties> available_extensions;
+		available_extensions.resize(device_extension_count);
+		vkEnumerateDeviceExtensionProperties(device, nullptr, &device_extension_count, available_extensions.data());
+		int found_extensions = 0;
+		for (const char *required_device_extension : required_device_extensions)
+			for (const VkExtensionProperties &ext : available_extensions)
+				if (strcmp(required_device_extension, ext.extensionName) == 0)
+				{
+					found_extensions++;
+					break;
+				}
+		if (found_extensions != int(required_device_extensions.size()))
+			continue;
+
+		// Find the right queues
+		uint32 queue_family_count = 0;
+		vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, nullptr);
+		Array<VkQueueFamilyProperties> queue_families;
+		queue_families.resize(queue_family_count);
+		vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, queue_families.data());
+		uint32 graphics_queue = ~uint32(0);
+		uint32 present_queue = ~uint32(0);
+		uint32 compute_queue = ~uint32(0);
+		for (uint32 i = 0; i < uint32(queue_families.size()); ++i)
+		{
+			if (queue_families[i].queueFlags & VK_QUEUE_GRAPHICS_BIT)
+			{
+				graphics_queue = i;
+
+				if (queue_families[i].queueFlags & VK_QUEUE_COMPUTE_BIT)
+					compute_queue = i;
+			}
+
+			if (HasPresentSupport(device, i))
+				present_queue = i;
+
+			if (graphics_queue != ~uint32(0) && present_queue != ~uint32(0) && compute_queue != ~uint32(0))
+				break;
+		}
+		if (graphics_queue == ~uint32(0) || present_queue == ~uint32(0) || compute_queue == ~uint32(0))
+			continue;
+
+		// Select surface format
+		VkSurfaceFormatKHR selected_format = SelectFormat(device);
+		if (selected_format.format == VK_FORMAT_UNDEFINED)
+			continue;
+
+		// Add the device
+		available_devices.push_back({ device, properties.deviceName, selected_format, graphics_queue, present_queue, compute_queue, score });
+	}
+	if (available_devices.empty())
+		return false;
+
+	// Sort the devices by score
+	QuickSort(available_devices.begin(), available_devices.end(), [](const Device &inLHS, const Device &inRHS) {
+		return inLHS.mScore > inRHS.mScore;
+	});
+	const Device &selected_device = available_devices[0];
+
+	// Create device
+	float queue_priority = 1.0f;
+	VkDeviceQueueCreateInfo queue_create_info[3] = {};
+	for (size_t i = 0; i < std::size(queue_create_info); ++i)
+	{
+		queue_create_info[i].sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
+		queue_create_info[i].queueCount = 1;
+		queue_create_info[i].pQueuePriorities = &queue_priority;
+	}
+	uint32 num_queues = 0;
+	queue_create_info[num_queues++].queueFamilyIndex = selected_device.mGraphicsQueueIndex;
+	for (uint32 i = 0; i < num_queues; ++i)
+		if (queue_create_info[i].queueFamilyIndex != selected_device.mPresentQueueIndex)
+			queue_create_info[num_queues++].queueFamilyIndex = selected_device.mPresentQueueIndex;
+	for (uint32 i = 0; i < num_queues; ++i)
+		if (queue_create_info[i].queueFamilyIndex != selected_device.mComputeQueueIndex)
+			queue_create_info[num_queues++].queueFamilyIndex = selected_device.mComputeQueueIndex;
+
+	VkPhysicalDeviceScalarBlockLayoutFeatures enable_scalar_block = {};
+	enable_scalar_block.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SCALAR_BLOCK_LAYOUT_FEATURES;
+	enable_scalar_block.scalarBlockLayout = VK_TRUE;
+
+	VkPhysicalDeviceFeatures2 enabled_features2 = {};
+	enabled_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
+	GetEnabledFeatures(enabled_features2);
+	enable_scalar_block.pNext = enabled_features2.pNext;
+	enabled_features2.pNext = &enable_scalar_block;
+
+	VkDeviceCreateInfo device_create_info = {};
+	device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
+	device_create_info.queueCreateInfoCount = num_queues;
+	device_create_info.pQueueCreateInfos = queue_create_info;
+	device_create_info.enabledLayerCount = instance_create_info.enabledLayerCount;
+	device_create_info.ppEnabledLayerNames = instance_create_info.ppEnabledLayerNames;
+	device_create_info.enabledExtensionCount = uint32(required_device_extensions.size());
+	device_create_info.ppEnabledExtensionNames = required_device_extensions.data();
+	device_create_info.pNext = &enabled_features2;
+	device_create_info.pEnabledFeatures = nullptr;
+
+	VkDevice device = VK_NULL_HANDLE;
+	if (VKFailed(vkCreateDevice(selected_device.mPhysicalDevice, &device_create_info, nullptr, &device)))
+		return false;
+
+	// Get the queues
+	mGraphicsQueueIndex = selected_device.mGraphicsQueueIndex;
+	mPresentQueueIndex = selected_device.mPresentQueueIndex;
+	vkGetDeviceQueue(device, mGraphicsQueueIndex, 0, &mGraphicsQueue);
+	vkGetDeviceQueue(device, mPresentQueueIndex, 0, &mPresentQueue);
+
+	// Store selected format
+	mSelectedFormat = selected_device.mFormat;
+
+	// Initialize the compute system
+	return ComputeSystemVK::Initialize(selected_device.mPhysicalDevice, device, selected_device.mComputeQueueIndex);
+}
+
+ComputeSystem *CreateComputeSystemVK()
+{
+	ComputeSystemVKImpl *compute = new ComputeSystemVKImpl;
+	if (compute->Initialize())
+		return compute;
+
+	delete compute;
+	return nullptr;
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_VK

+ 57 - 0
Jolt/Compute/VK/ComputeSystemVKImpl.h

@@ -0,0 +1,57 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#ifdef JPH_USE_VK
+
+#include <Jolt/Compute/VK/ComputeSystemVKWithAllocator.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// Implementation of ComputeSystemVK that fully initializes Vulkan
+class JPH_EXPORT ComputeSystemVKImpl : public ComputeSystemVKWithAllocator
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Destructor
+	virtual							~ComputeSystemVKImpl() override;
+
+	/// Initialize the compute system
+	bool							Initialize();
+
+protected:
+	/// Override to perform actions once the instance has been created
+	virtual void					OnInstanceCreated()																{ /* Do nothing */ }
+
+	/// Override to add platform specific instance extensions
+	virtual void					GetInstanceExtensions(Array<const char *> &outExtensions)						{ /* Add nothing */ }
+
+	/// Override to add platform specific device extensions
+	virtual void					GetDeviceExtensions(Array<const char *> &outExtensions)							{ /* Add nothing */ }
+
+	/// Override  to enable specific features
+	virtual void					GetEnabledFeatures(VkPhysicalDeviceFeatures2 &ioFeatures)						{ /* Add nothing */ }
+
+	/// Override to check for present support on a given device and queue family
+	virtual bool					HasPresentSupport(VkPhysicalDevice inDevice, uint32 inQueueFamilyIndex)			{ return true; }
+
+	/// Override to select the surface format
+	virtual VkSurfaceFormatKHR		SelectFormat(VkPhysicalDevice inDevice)											{ return { VK_FORMAT_B8G8R8A8_UNORM, VK_COLOR_SPACE_SRGB_NONLINEAR_KHR }; }
+
+	VkInstance						mInstance = VK_NULL_HANDLE;
+#ifdef JPH_DEBUG
+	VkDebugUtilsMessengerEXT		mDebugMessenger = VK_NULL_HANDLE;
+#endif
+	uint32							mGraphicsQueueIndex = 0;
+	uint32							mPresentQueueIndex = 0;
+	VkQueue							mGraphicsQueue = VK_NULL_HANDLE;
+	VkQueue							mPresentQueue = VK_NULL_HANDLE;
+	VkSurfaceFormatKHR				mSelectedFormat;
+};
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_VK

+ 168 - 0
Jolt/Compute/VK/ComputeSystemVKWithAllocator.cpp

@@ -0,0 +1,168 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include <Jolt/Jolt.h>
+
+#ifdef JPH_USE_VK
+
+#include <Jolt/Compute/VK/ComputeSystemVKWithAllocator.h>
+#include <Jolt/Compute/VK/ComputeShaderVK.h>
+#include <Jolt/Compute/VK/ComputeBufferVK.h>
+#include <Jolt/Compute/VK/ComputeQueueVK.h>
+
+JPH_NAMESPACE_BEGIN
+
+bool ComputeSystemVKWithAllocator::InitializeMemory()
+{
+	// Get memory properties
+	vkGetPhysicalDeviceMemoryProperties(mPhysicalDevice, &mMemoryProperties);
+
+	return true;
+}
+
+void ComputeSystemVKWithAllocator::ShutdownMemory()
+{
+	// Free all memory
+	for (MemoryCache::value_type &mc : mMemoryCache)
+		for (Memory &m : mc.second)
+			if (m.mOffset == 0)
+				FreeMemory(*m.mMemory);
+	mMemoryCache.clear();
+}
+
+uint32 ComputeSystemVKWithAllocator::FindMemoryType(uint32 inTypeFilter, VkMemoryPropertyFlags inProperties)
+{
+	for (uint32 i = 0; i < mMemoryProperties.memoryTypeCount; i++)
+		if ((inTypeFilter & (1 << i))
+			&& (mMemoryProperties.memoryTypes[i].propertyFlags & inProperties) == inProperties)
+			return i;
+
+	JPH_ASSERT(false, "Failed to find memory type!");
+	return 0;
+}
+
+void ComputeSystemVKWithAllocator::AllocateMemory(VkDeviceSize inSize, uint32 inMemoryTypeBits, VkMemoryPropertyFlags inProperties, MemoryVK &ioMemory)
+{
+	JPH_ASSERT(ioMemory.mMemory == VK_NULL_HANDLE);
+
+	ioMemory.mSize = inSize;
+	ioMemory.mProperties = inProperties;
+
+	VkMemoryAllocateInfo alloc_info = {};
+	alloc_info.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+	alloc_info.allocationSize = inSize;
+	alloc_info.memoryTypeIndex = FindMemoryType(inMemoryTypeBits, inProperties);
+	vkAllocateMemory(mDevice, &alloc_info, nullptr, &ioMemory.mMemory);
+}
+
+void ComputeSystemVKWithAllocator::FreeMemory(MemoryVK &ioMemory)
+{
+	vkFreeMemory(mDevice, ioMemory.mMemory, nullptr);
+	ioMemory.mMemory = VK_NULL_HANDLE;
+}
+
+void ComputeSystemVKWithAllocator::CreateBuffer(VkDeviceSize inSize, VkBufferUsageFlags inUsage, VkMemoryPropertyFlags inProperties, BufferVK &outBuffer)
+{
+	// Create a new buffer
+	outBuffer.mSize = inSize;
+
+	VkBufferCreateInfo create_info = {};
+	create_info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
+	create_info.size = inSize;
+	create_info.usage = inUsage;
+	create_info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
+	if (VKFailed(vkCreateBuffer(mDevice, &create_info, nullptr, &outBuffer.mBuffer)))
+	{
+		outBuffer.mBuffer = VK_NULL_HANDLE;
+		return;
+	}
+
+	VkMemoryRequirements mem_requirements;
+	vkGetBufferMemoryRequirements(mDevice, outBuffer.mBuffer, &mem_requirements);
+
+	if (mem_requirements.size > cMaxAllocSize)
+	{
+		// Allocate block directly
+		Ref<MemoryVK> memory_vk = new MemoryVK();
+		memory_vk->mBufferSize = mem_requirements.size;
+		AllocateMemory(mem_requirements.size, mem_requirements.memoryTypeBits, inProperties, *memory_vk);
+		outBuffer.mMemory = memory_vk;
+		outBuffer.mOffset = 0;
+	}
+	else
+	{
+		// Round allocation to the next power of 2 so that we can use a simple block based allocator
+		VkDeviceSize buffer_size = max(VkDeviceSize(GetNextPowerOf2(uint32(mem_requirements.size))), cMinAllocSize);
+
+		// Ensure that we have memory available from the right pool
+		Array<Memory> &mem_array = mMemoryCache[{ buffer_size, inProperties }];
+		if (mem_array.empty())
+		{
+			// Allocate a bigger block
+			Ref<MemoryVK> memory_vk = new MemoryVK();
+			memory_vk->mBufferSize = buffer_size;
+			AllocateMemory(cBlockSize, mem_requirements.memoryTypeBits, inProperties, *memory_vk);
+
+			// Divide into sub blocks
+			for (VkDeviceSize offset = 0; offset < cBlockSize; offset += buffer_size)
+				mem_array.push_back({ memory_vk, offset });
+		}
+
+		// Claim memory from the pool
+		Memory &memory = mem_array.back();
+		outBuffer.mMemory = memory.mMemory;
+		outBuffer.mOffset = memory.mOffset;
+		mem_array.pop_back();
+	}
+
+	// Bind the memory to the buffer
+	vkBindBufferMemory(mDevice, outBuffer.mBuffer, outBuffer.mMemory->mMemory, outBuffer.mOffset);
+}
+
+void ComputeSystemVKWithAllocator::FreeBuffer(BufferVK &ioBuffer)
+{
+	if (ioBuffer.mBuffer != VK_NULL_HANDLE)
+	{
+		// Destroy the buffer
+		vkDestroyBuffer(mDevice, ioBuffer.mBuffer, nullptr);
+		ioBuffer.mBuffer = VK_NULL_HANDLE;
+
+		// Hand the memory back to the cache
+		VkDeviceSize buffer_size = ioBuffer.mMemory->mBufferSize;
+		if (buffer_size > cMaxAllocSize)
+			FreeMemory(*ioBuffer.mMemory);
+		else
+			mMemoryCache[{ buffer_size, ioBuffer.mMemory->mProperties }].push_back({ ioBuffer.mMemory, ioBuffer.mOffset });
+
+		ioBuffer = BufferVK();
+	}
+}
+
+void *ComputeSystemVKWithAllocator::MapBuffer(BufferVK& ioBuffer)
+{
+	if (++ioBuffer.mMemory->mMappedCount == 1)
+	{
+		if (VKFailed(vkMapMemory(mDevice, ioBuffer.mMemory->mMemory, 0, VK_WHOLE_SIZE, 0, &ioBuffer.mMemory->mMappedPtr)))
+		{
+			ioBuffer.mMemory->mMappedCount = 0;
+			return nullptr;
+		}
+	}
+
+	return static_cast<uint8 *>(ioBuffer.mMemory->mMappedPtr) + ioBuffer.mOffset;
+}
+
+void ComputeSystemVKWithAllocator::UnmapBuffer(BufferVK& ioBuffer)
+{
+	JPH_ASSERT(ioBuffer.mMemory->mMappedCount > 0);
+	if (--ioBuffer.mMemory->mMappedCount == 0)
+	{
+		vkUnmapMemory(mDevice, ioBuffer.mMemory->mMemory);
+		ioBuffer.mMemory->mMappedPtr = nullptr;
+	}
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_VK

+ 70 - 0
Jolt/Compute/VK/ComputeSystemVKWithAllocator.h

@@ -0,0 +1,70 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#ifdef JPH_USE_VK
+
+#include <Jolt/Compute/VK/ComputeSystemVK.h>
+#include <Jolt/Core/UnorderedMap.h>
+
+JPH_NAMESPACE_BEGIN
+
+/// This extends ComputeSystemVK to provide a default implementation for memory allocation and mapping.
+/// It uses a simple block based allocator to reduce the number of allocations done to Vulkan.
+class JPH_EXPORT ComputeSystemVKWithAllocator : public ComputeSystemVK
+{
+public:
+	JPH_OVERRIDE_NEW_DELETE
+
+	/// Allow the application to override buffer creation and memory mapping in case it uses its own allocator
+	virtual void					CreateBuffer(VkDeviceSize inSize, VkBufferUsageFlags inUsage, VkMemoryPropertyFlags inProperties, BufferVK &outBuffer) override;
+	virtual void					FreeBuffer(BufferVK &ioBuffer) override;
+	virtual void *					MapBuffer(BufferVK &ioBuffer) override;
+	virtual void					UnmapBuffer(BufferVK &ioBuffer) override;
+
+protected:
+	virtual bool					InitializeMemory() override;
+	virtual void					ShutdownMemory() override;
+
+	uint32							FindMemoryType(uint32 inTypeFilter, VkMemoryPropertyFlags inProperties);
+	void							AllocateMemory(VkDeviceSize inSize, uint32 inMemoryTypeBits, VkMemoryPropertyFlags inProperties, MemoryVK &ioMemory);
+	void							FreeMemory(MemoryVK &ioMemory);
+
+	VkPhysicalDeviceMemoryProperties mMemoryProperties;
+
+private:
+	// Smaller allocations (from cMinAllocSize to cMaxAllocSize) will be done in blocks of cBlockSize bytes.
+	// We do this because there is a limit to the number of allocations that we can make in Vulkan.
+	static constexpr VkDeviceSize	cMinAllocSize = 512;
+	static constexpr VkDeviceSize	cMaxAllocSize = 65536;
+	static constexpr VkDeviceSize	cBlockSize = 524288;
+
+	struct MemoryKey
+	{
+		bool						operator == (const MemoryKey &inRHS) const
+		{
+			return mSize == inRHS.mSize && mProperties == inRHS.mProperties;
+		}
+
+		VkDeviceSize				mSize;
+		VkMemoryPropertyFlags		mProperties;
+	};
+
+	JPH_MAKE_HASH_STRUCT(MemoryKey, MemoryKeyHasher, t.mProperties, t.mSize)
+
+	struct Memory
+	{
+		Ref<MemoryVK>				mMemory;
+		VkDeviceSize				mOffset;
+	};
+
+	using MemoryCache = UnorderedMap<MemoryKey, Array<Memory>, MemoryKeyHasher>;
+
+	MemoryCache						mMemoryCache;
+};
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_VK

+ 30 - 0
Jolt/Compute/VK/IncludeVK.h

@@ -0,0 +1,30 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#ifdef JPH_USE_VK
+
+JPH_SUPPRESS_WARNINGS_STD_BEGIN
+JPH_CLANG_SUPPRESS_WARNING("-Wc++98-compat-pedantic")
+
+#include <vulkan/vulkan.h>
+
+JPH_SUPPRESS_WARNINGS_STD_END
+
+JPH_NAMESPACE_BEGIN
+
+inline bool VKFailed(VkResult inResult)
+{
+	if (inResult == VK_SUCCESS)
+		return false;
+
+	Trace("Vulkan call failed with error code: %d", (int)inResult);
+	JPH_ASSERT(false);
+	return true;
+}
+
+JPH_NAMESPACE_END
+
+#endif // JPH_USE_VK

+ 10 - 0
Jolt/Core/Core.h

@@ -649,4 +649,14 @@ static_assert(sizeof(uint64) == 8, "Invalid size of uint64");
 	#define JPH_TSAN_NO_SANITIZE
 #endif
 
+// DirectX 12 is only supported on Windows
+#if defined(JPH_USE_DX12) && !defined(JPH_PLATFORM_WINDOWS)
+	#undef JPH_USE_DX12
+#endif // JPH_PLATFORM_WINDOWS
+
+// Metal is only supported on Apple platforms
+#if defined(JPH_USE_METAL) && !defined(JPH_PLATFORM_MACOS) && !defined(JPH_PLATFORM_IOS)
+	#undef JPH_USE_METAL
+#endif // !JPH_PLATFORM_MACOS && !JPH_PLATFORM_IOS
+
 JPH_NAMESPACE_END

+ 36 - 0
Jolt/Core/IncludeWindows.h

@@ -0,0 +1,36 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#ifdef JPH_PLATFORM_WINDOWS
+
+JPH_SUPPRESS_WARNING_PUSH
+JPH_MSVC_SUPPRESS_WARNING(5039) // winbase.h(13179): warning C5039: 'TpSetCallbackCleanupGroup': pointer or reference to potentially throwing function passed to 'extern "C"' function under -EHc. Undefined behavior may occur if this function throws an exception.
+JPH_MSVC2026_PLUS_SUPPRESS_WARNING(4865) // wingdi.h(2806,1): '<unnamed-enum-DISPLAYCONFIG_OUTPUT_TECHNOLOGY_OTHER>': the underlying type will change from 'int' to '__int64' when '/Zc:enumTypes' is specified on the command line
+JPH_CLANG_SUPPRESS_WARNING("-Wreserved-macro-identifier") // Complains about _WIN32_WINNT being reserved
+
+#ifndef WINVER
+	#define WINVER 0x0A00 // Targeting Windows 10 and above
+#endif
+
+#ifndef _WIN32_WINNT
+	#define _WIN32_WINNT 0x0A00
+#endif
+
+#ifndef WIN32_LEAN_AND_MEAN
+	#define WIN32_LEAN_AND_MEAN
+#endif
+
+#ifndef NOMINMAX
+	#define NOMINMAX
+#endif
+
+#ifndef JPH_COMPILER_MINGW
+	#include <Windows.h>
+#else
+	#include <windows.h>
+#endif
+
+JPH_SUPPRESS_WARNING_POP
+
+#endif

+ 1 - 14
Jolt/Core/JobSystemThreadPool.cpp

@@ -7,21 +7,8 @@
 #include <Jolt/Core/JobSystemThreadPool.h>
 #include <Jolt/Core/Profiler.h>
 #include <Jolt/Core/FPException.h>
+#include <Jolt/Core/IncludeWindows.h>
 
-#ifdef JPH_PLATFORM_WINDOWS
-	JPH_SUPPRESS_WARNING_PUSH
-	JPH_MSVC_SUPPRESS_WARNING(5039) // winbase.h(13179): warning C5039: 'TpSetCallbackCleanupGroup': pointer or reference to potentially throwing function passed to 'extern "C"' function under -EHc. Undefined behavior may occur if this function throws an exception.
-	JPH_MSVC2026_PLUS_SUPPRESS_WARNING(4865) // wingdi.h(2806,1): '<unnamed-enum-DISPLAYCONFIG_OUTPUT_TECHNOLOGY_OTHER>': the underlying type will change from 'int' to '__int64' when '/Zc:enumTypes' is specified on the command line
-	#ifndef WIN32_LEAN_AND_MEAN
-		#define WIN32_LEAN_AND_MEAN
-	#endif
-#ifndef JPH_COMPILER_MINGW
-	#include <Windows.h>
-#else
-	#include <windows.h>
-#endif
-	JPH_SUPPRESS_WARNING_POP
-#endif
 #ifdef JPH_PLATFORM_LINUX
 	#include <sys/prctl.h>
 #endif

+ 1 - 15
Jolt/Core/Semaphore.cpp

@@ -5,21 +5,7 @@
 #include <Jolt/Jolt.h>
 
 #include <Jolt/Core/Semaphore.h>
-
-#ifdef JPH_PLATFORM_WINDOWS
-	JPH_SUPPRESS_WARNING_PUSH
-	JPH_MSVC_SUPPRESS_WARNING(5039) // winbase.h(13179): warning C5039: 'TpSetCallbackCleanupGroup': pointer or reference to potentially throwing function passed to 'extern "C"' function under -EHc. Undefined behavior may occur if this function throws an exception.
-	JPH_MSVC2026_PLUS_SUPPRESS_WARNING(4865) // wingdi.h(2806,1): '<unnamed-enum-DISPLAYCONFIG_OUTPUT_TECHNOLOGY_OTHER>': the underlying type will change from 'int' to '__int64' when '/Zc:enumTypes' is specified on the command line
-#ifndef WIN32_LEAN_AND_MEAN
-		#define WIN32_LEAN_AND_MEAN
-	#endif
-#ifndef JPH_COMPILER_MINGW
-	#include <Windows.h>
-#else
-	#include <windows.h>
-#endif
-	JPH_SUPPRESS_WARNING_POP
-#endif
+#include <Jolt/Core/IncludeWindows.h>
 
 JPH_NAMESPACE_BEGIN
 

+ 1 - 15
Jolt/Core/TickCounter.cpp

@@ -5,21 +5,7 @@
 #include <Jolt/Jolt.h>
 
 #include <Jolt/Core/TickCounter.h>
-
-#if defined(JPH_PLATFORM_WINDOWS)
-	JPH_SUPPRESS_WARNING_PUSH
-	JPH_MSVC_SUPPRESS_WARNING(5039) // winbase.h(13179): warning C5039: 'TpSetCallbackCleanupGroup': pointer or reference to potentially throwing function passed to 'extern "C"' function under -EHc. Undefined behavior may occur if this function throws an exception.
-	JPH_MSVC2026_PLUS_SUPPRESS_WARNING(4865) // wingdi.h(2806,1): '<unnamed-enum-DISPLAYCONFIG_OUTPUT_TECHNOLOGY_OTHER>': the underlying type will change from 'int' to '__int64' when '/Zc:enumTypes' is specified on the command line
-#ifndef WIN32_LEAN_AND_MEAN
-		#define WIN32_LEAN_AND_MEAN
-	#endif
-#ifndef JPH_COMPILER_MINGW
-	#include <Windows.h>
-#else
-	#include <windows.h>
-#endif
-	JPH_SUPPRESS_WARNING_POP
-#endif
+#include <Jolt/Core/IncludeWindows.h>
 
 JPH_NAMESPACE_BEGIN
 

+ 193 - 3
Jolt/Jolt.cmake

@@ -14,6 +14,10 @@ set(JOLT_PHYSICS_SRC_FILES
 	${JOLT_PHYSICS_ROOT}/AABBTree/NodeCodec/NodeCodecQuadTreeHalfFloat.h
 	${JOLT_PHYSICS_ROOT}/AABBTree/TriangleCodec/TriangleCodecIndexed8BitPackSOA4Flags.h
 	${JOLT_PHYSICS_ROOT}/ConfigurationString.h
+	${JOLT_PHYSICS_ROOT}/Compute/ComputeBuffer.h
+	${JOLT_PHYSICS_ROOT}/Compute/ComputeQueue.h
+	${JOLT_PHYSICS_ROOT}/Compute/ComputeSystem.h
+	${JOLT_PHYSICS_ROOT}/Compute/ComputeShader.h
 	${JOLT_PHYSICS_ROOT}/Core/ARMNeon.h
 	${JOLT_PHYSICS_ROOT}/Core/Array.h
 	${JOLT_PHYSICS_ROOT}/Core/Atomics.h
@@ -31,6 +35,7 @@ set(JOLT_PHYSICS_SRC_FILES
 	${JOLT_PHYSICS_ROOT}/Core/FPFlushDenormals.h
 	${JOLT_PHYSICS_ROOT}/Core/HashCombine.h
 	${JOLT_PHYSICS_ROOT}/Core/HashTable.h
+	${JOLT_PHYSICS_ROOT}/Core/IncludeWindows.h
 	${JOLT_PHYSICS_ROOT}/Core/InsertionSort.h
 	${JOLT_PHYSICS_ROOT}/Core/IssueReporting.cpp
 	${JOLT_PHYSICS_ROOT}/Core/IssueReporting.h
@@ -457,16 +462,174 @@ if (ENABLE_OBJECT_STREAM)
 	)
 endif()
 
-if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Windows")
+if (JPH_USE_DX12 OR JPH_USE_VK OR JPH_USE_MTL)
+	# Compute shaders
+	set(JOLT_PHYSICS_SHADERS
+		${JOLT_PHYSICS_ROOT}/Shaders/TestCompute.hlsl
+	)
+
+	set(JOLT_PHYSICS_SHADER_HEADERS
+		${JOLT_PHYSICS_ROOT}/Shaders/ShaderCore.h
+		${JOLT_PHYSICS_ROOT}/Shaders/ShaderMat44.h
+		${JOLT_PHYSICS_ROOT}/Shaders/ShaderMath.h
+		${JOLT_PHYSICS_ROOT}/Shaders/ShaderPlane.h
+		${JOLT_PHYSICS_ROOT}/Shaders/ShaderQuat.h
+		${JOLT_PHYSICS_ROOT}/Shaders/ShaderVec3.h
+		${JOLT_PHYSICS_ROOT}/Shaders/TestCompute.h
+		${JOLT_PHYSICS_ROOT}/Shaders/TestComputeBindings.h
+	)
+endif()
+
+if (WIN32)
 	# Add natvis file
 	set(JOLT_PHYSICS_SRC_FILES ${JOLT_PHYSICS_SRC_FILES} ${JOLT_PHYSICS_ROOT}/Jolt.natvis)
+
+	# Set properties to compile shaders as compute shaders
+	set_source_files_properties(${JOLT_PHYSICS_SHADERS} PROPERTIES VS_SHADER_FLAGS "/WX /T cs_5_0")
+
+	# DirectX support
+	if (JPH_USE_DX12)
+		# DirectX source files
+		set(JOLT_PHYSICS_SRC_FILES
+			${JOLT_PHYSICS_SRC_FILES}
+			${JOLT_PHYSICS_ROOT}/Compute/DX12/ComputeQueueDX12.cpp
+			${JOLT_PHYSICS_ROOT}/Compute/DX12/ComputeQueueDX12.h
+			${JOLT_PHYSICS_ROOT}/Compute/DX12/ComputeBufferDX12.cpp
+			${JOLT_PHYSICS_ROOT}/Compute/DX12/ComputeBufferDX12.h
+			${JOLT_PHYSICS_ROOT}/Compute/DX12/ComputeSystemDX12.cpp
+			${JOLT_PHYSICS_ROOT}/Compute/DX12/ComputeSystemDX12.h
+			${JOLT_PHYSICS_ROOT}/Compute/DX12/ComputeSystemDX12Impl.cpp
+			${JOLT_PHYSICS_ROOT}/Compute/DX12/ComputeSystemDX12Impl.h
+			${JOLT_PHYSICS_ROOT}/Compute/DX12/ComputeShaderDX12.h
+			${JOLT_PHYSICS_ROOT}/Compute/DX12/IncludeDX12.h
+		)
+	endif()
+else()
+	set(JPH_USE_DX12 OFF)
+endif()
+
+if (APPLE)
+	# Metal support
+	if (JPH_USE_MTL)
+		# Metal source files
+		set(JOLT_PHYSICS_SRC_FILES
+			${JOLT_PHYSICS_SRC_FILES}
+			${JOLT_PHYSICS_ROOT}/Compute/MTL/ComputeBufferMTL.mm
+			${JOLT_PHYSICS_ROOT}/Compute/MTL/ComputeBufferMTL.h
+			${JOLT_PHYSICS_ROOT}/Compute/MTL/ComputeQueueMTL.mm
+			${JOLT_PHYSICS_ROOT}/Compute/MTL/ComputeQueueMTL.h
+			${JOLT_PHYSICS_ROOT}/Compute/MTL/ComputeShaderMTL.mm
+			${JOLT_PHYSICS_ROOT}/Compute/MTL/ComputeShaderMTL.h
+			${JOLT_PHYSICS_ROOT}/Compute/MTL/ComputeSystemMTL.mm
+			${JOLT_PHYSICS_ROOT}/Compute/MTL/ComputeSystemMTL.h
+			${JOLT_PHYSICS_ROOT}/Compute/MTL/ComputeSystemMTLImpl.mm
+			${JOLT_PHYSICS_ROOT}/Compute/MTL/ComputeSystemMTLImpl.h
+		)
+
+		find_program(DXC_COMPILER NAMES dxc)
+		find_program(SPIRV_CROSS_COMPILER NAMES spirv-cross)
+		if (NOT DXC_COMPILER)
+			MESSAGE("Application 'dxc' not found. Can't compile compute shaders. Some functionality will be unavailable. You can install it by e.g. installing the Vulkan SDK.")
+		elseif (NOT SPIRV_CROSS_COMPILER)
+			MESSAGE("Application 'spirv-cross' not found. Can't compile compute shaders. Some functionality will be unavailable. You can install it by e.g. installing the Vulkan SDK.")
+		else()
+			# Determine target for shader compiler
+			if (IOS)
+				set(METAL_SDK_TARGET "iphonesimulator")
+			else()
+				set(METAL_SDK_TARGET "macosx")
+			endif()
+
+			# Compile Metal shaders
+			foreach(SHADER ${JOLT_PHYSICS_SHADERS})
+				cmake_path(GET SHADER STEM SHADER_STEM) # Filename without extension
+				set(SPV_SHADER "${CMAKE_CURRENT_BINARY_DIR}/${SHADER_STEM}.spv")
+				set(MTL_SHADER "${CMAKE_CURRENT_BINARY_DIR}/${SHADER_STEM}.metal")
+				set(AIR_SHADER "${CMAKE_CURRENT_BINARY_DIR}/${SHADER_STEM}.air")
+				add_custom_command(OUTPUT ${AIR_SHADER}
+					COMMAND ${DXC_COMPILER} -E main -T cs_6_0 -I Jolt/Shaders -WX -O3 -all_resources_bound ${SHADER} -spirv -fvk-use-dx-layout -fspv-entrypoint-name=${SHADER_STEM} -Fo ${SPV_SHADER}
+					COMMAND ${SPIRV_CROSS_COMPILER} ${SPV_SHADER} --msl --output ${MTL_SHADER}
+					COMMAND xcrun -sdk ${METAL_SDK_TARGET} metal -c ${MTL_SHADER} -o ${AIR_SHADER}
+					DEPENDS ${SHADER} ${JOLT_PHYSICS_SHADER_HEADERS} # Currently don't have a way to detect header dependencies, so making dependent on all
+					COMMENT "Compiling Metal ${SHADER}")
+				list(APPEND JOLT_PHYSICS_MTL_SHADERS ${AIR_SHADER})
+			endforeach()
+
+			# Link Metal shaders
+			set(JOLT_PHYSICS_METAL_LIB ${JOLT_PHYSICS_ROOT}/Shaders/Jolt.metallib)
+			add_custom_command(OUTPUT ${JOLT_PHYSICS_METAL_LIB}
+				COMMAND xcrun -sdk ${METAL_SDK_TARGET} metallib -o ${JOLT_PHYSICS_METAL_LIB} ${JOLT_PHYSICS_MTL_SHADERS}
+				DEPENDS ${JOLT_PHYSICS_MTL_SHADERS}
+				COMMENT "Linking shaders")
+
+			# Group intermediate files
+			source_group(Intermediate FILES ${JOLT_PHYSICS_MTL_SHADERS} ${JOLT_PHYSICS_METAL_LIB})
+		endif()
+	endif()
+
+	# Ignore PCH files for .mm files
+	foreach(SRC_FILE ${JOLT_PHYSICS_SRC_FILES})
+		if (SRC_FILE MATCHES "\.mm")
+			set_source_files_properties(${SRC_FILE} PROPERTIES SKIP_PRECOMPILE_HEADERS ON)
+		endif()
+	endforeach()
+else()
+	set(JPH_USE_MTL OFF)
+endif()
+
+# Vulkan support
+if (JPH_USE_VK)
+	find_package(Vulkan)
+	if (Vulkan_FOUND)
+		# Vulkan source files
+		set(JOLT_PHYSICS_SRC_FILES
+			${JOLT_PHYSICS_SRC_FILES}
+			${JOLT_PHYSICS_ROOT}/Compute/VK/BufferVK.h
+			${JOLT_PHYSICS_ROOT}/Compute/VK/ComputeBufferVK.cpp
+			${JOLT_PHYSICS_ROOT}/Compute/VK/ComputeBufferVK.h
+			${JOLT_PHYSICS_ROOT}/Compute/VK/ComputeQueueVK.cpp
+			${JOLT_PHYSICS_ROOT}/Compute/VK/ComputeQueueVK.h
+			${JOLT_PHYSICS_ROOT}/Compute/VK/ComputeShaderVK.cpp
+			${JOLT_PHYSICS_ROOT}/Compute/VK/ComputeShaderVK.h
+			${JOLT_PHYSICS_ROOT}/Compute/VK/ComputeSystemVK.cpp
+			${JOLT_PHYSICS_ROOT}/Compute/VK/ComputeSystemVK.h
+			${JOLT_PHYSICS_ROOT}/Compute/VK/ComputeSystemVKImpl.cpp
+			${JOLT_PHYSICS_ROOT}/Compute/VK/ComputeSystemVKImpl.h
+			${JOLT_PHYSICS_ROOT}/Compute/VK/ComputeSystemVKWithAllocator.cpp
+			${JOLT_PHYSICS_ROOT}/Compute/VK/ComputeSystemVKWithAllocator.h
+			${JOLT_PHYSICS_ROOT}/Compute/VK/IncludeVK.h
+		)
+
+		# TODO: For some reason it errors on finding dxc when we specify the dxc component to find_vulkan (and update cmake version)
+		# For now, just set it manually
+		string(REPLACE "glslc" "dxc" Vulkan_dxc_EXECUTABLE ${Vulkan_GLSLC_EXECUTABLE})
+
+		# Compile Vulkan shaders
+		foreach(SHADER ${JOLT_PHYSICS_SHADERS})
+			string(REPLACE ".hlsl" ".spv" SPV_SHADER ${SHADER})
+			add_custom_command(OUTPUT ${SPV_SHADER}
+				# We use dxc instead of: ${Vulkan_GLSLC_EXECUTABLE} -fshader-stage=compute ${SHADER} -o ${SPV_SHADER}
+				# The glslc compiler has the following issues:
+				# - All buffers bind to slot 0. We don't want to manually specify registers so this requires going into the SPIRV code and patching it.
+				# - It automatically aligns float3 to 16 byte boundaries which wastes a lot of memory in structs. We only seem to be able to override this alignment when compiling a GLSL shader and not with HLSL.
+				COMMAND ${Vulkan_dxc_EXECUTABLE} -E main -T cs_6_0 -I Jolt/Shaders -WX -O3 -all_resources_bound ${SHADER} -spirv -fvk-use-dx-layout -Fo ${SPV_SHADER}
+				DEPENDS ${SHADER} ${JOLT_PHYSICS_SHADER_HEADERS} # Currently don't have a way to detect header dependencies, so making dependent on all
+				COMMENT "Compiling Vulkan ${SHADER}")
+			list(APPEND JOLT_PHYSICS_SPV_SHADERS ${SPV_SHADER})
+		endforeach()
+
+		# Group intermediate files
+		source_group(Intermediate FILES ${JOLT_PHYSICS_SPV_SHADERS})
+	else()
+		set(JPH_USE_VK OFF)
+	endif()
 endif()
 
 # Group source files
-source_group(TREE ${JOLT_PHYSICS_ROOT} FILES ${JOLT_PHYSICS_SRC_FILES})
+source_group(TREE ${JOLT_PHYSICS_ROOT} FILES ${JOLT_PHYSICS_SRC_FILES} ${JOLT_PHYSICS_SHADERS} ${JOLT_PHYSICS_SHADER_HEADERS})
 
 # Create Jolt lib
-add_library(Jolt ${JOLT_PHYSICS_SRC_FILES})
+add_library(Jolt ${JOLT_PHYSICS_SRC_FILES} ${JOLT_PHYSICS_SHADERS} ${JOLT_PHYSICS_SHADER_HEADERS} ${JOLT_PHYSICS_SPV_SHADERS} ${JOLT_PHYSICS_METAL_LIB})
 add_library(Jolt::Jolt ALIAS Jolt)
 
 if (BUILD_SHARED_LIBS)
@@ -566,6 +729,33 @@ if (JPH_TRACK_SIMULATION_STATS)
 	target_compile_definitions(Jolt PUBLIC JPH_TRACK_SIMULATION_STATS)
 endif()
 
+# Compile against DirectX 12
+if (JPH_USE_DX12)
+	target_compile_definitions(Jolt PUBLIC JPH_USE_DX12)
+	target_link_libraries(Jolt LINK_PUBLIC dxgi.lib d3d12.lib d3dcompiler.lib dxguid.lib)
+
+	# Use DXC compiler to compile shaders, when off falls back to FXC
+	if (JPH_USE_DXC)
+		target_compile_definitions(Jolt PUBLIC JPH_USE_DXC)
+		target_link_libraries(Jolt LINK_PUBLIC dxcompiler.lib)
+	endif()
+endif()
+
+# Compile against Vulkan
+if (JPH_USE_VK)
+	target_compile_definitions(Jolt PUBLIC JPH_USE_VK)
+
+	target_include_directories(Jolt PUBLIC ${Vulkan_INCLUDE_DIRS})
+	target_link_libraries(Jolt LINK_PUBLIC ${Vulkan_LIBRARIES})
+endif()
+
+# Compile against Metal
+if (JPH_USE_MTL)
+	target_compile_definitions(Jolt PUBLIC JPH_USE_MTL)
+
+	target_link_libraries(Jolt LINK_PUBLIC "-framework Foundation -framework Metal -framework MetalKit")
+endif()
+
 # Enable the debug renderer
 if (DEBUG_RENDERER_IN_DISTRIBUTION)
 	target_compile_definitions(Jolt PUBLIC "JPH_DEBUG_RENDERER")

File diff suppressed because it is too large
+ 0 - 0
Jolt/Physics/Collision/Shape/TaperedCapsuleShape.gliffy


+ 75 - 0
Jolt/Shaders/ShaderCore.h

@@ -0,0 +1,75 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#ifndef JPH_SHADER_OVERRIDE_MACROS
+
+#ifdef __cplusplus
+	JPH_SUPPRESS_WARNING_PUSH
+	JPH_SUPPRESS_WARNINGS
+
+	using JPH_float = float;
+	using JPH_float3 = JPH::Float3;
+	using JPH_float4 = JPH::Float4;
+	using JPH_uint = JPH::uint32;
+	using JPH_uint3 = JPH::uint32[3];
+	using JPH_uint4 = JPH::uint32[4];
+	using JPH_int = int;
+	using JPH_int3 = int[3];
+	using JPH_int4 = int[4];
+	using JPH_Quat = JPH::Float4;
+	using JPH_Plane = JPH::Float4;
+	using JPH_Mat44 = JPH::Float4[4]; // matrix, column major
+
+	#define JPH_SHADER_CONSTANTS_BEGIN(type, name)	struct type {
+	#define JPH_SHADER_CONSTANTS_MEMBER(type, name)	type c##name;
+	#define JPH_SHADER_CONSTANTS_END				};
+
+	#define JPH_SHADER_BIND_BEGIN(name)
+	#define JPH_SHADER_BIND_END
+	#define JPH_SHADER_BIND_BUFFER(type, name)
+	#define JPH_SHADER_BIND_RW_BUFFER(type, name)
+
+	JPH_SUPPRESS_WARNING_POP
+#else
+	#pragma pack_matrix(column_major)
+
+	typedef float JPH_float;
+	typedef float3 JPH_float3;
+	typedef float4 JPH_float4;
+	typedef uint JPH_uint;
+	typedef uint3 JPH_uint3;
+	typedef uint4 JPH_uint4;
+	typedef int JPH_int;
+	typedef int3 JPH_int3;
+	typedef int4 JPH_int4;
+	typedef float4 JPH_Quat; // xyz = imaginary part, w = real part
+	typedef float4 JPH_Plane; // xyz = normal, w = constant
+	typedef float4x4 JPH_Mat44; // matrix, column major
+
+	#define JPH_SHADER_CONSTANTS_BEGIN(type, name)	cbuffer name {
+	#define JPH_SHADER_CONSTANTS_MEMBER(type, name)	type c##name;
+	#define JPH_SHADER_CONSTANTS_END				};
+
+	#define JPH_SHADER_FUNCTION_BEGIN(return_type, name, group_size_x, group_size_y, group_size_z) \
+		[numthreads(group_size_x, group_size_y, group_size_z)] \
+		return_type name(
+	#define JPH_SHADER_PARAM_THREAD_ID(name)		uint3 name : SV_DispatchThreadID
+	#define JPH_SHADER_FUNCTION_END					)
+
+	#define JPH_SHADER_BUFFER(type)					StructuredBuffer<type>
+	#define JPH_SHADER_RW_BUFFER(type)				RWStructuredBuffer<type>
+
+	#define JPH_SHADER_BIND_BEGIN(name)
+	#define JPH_SHADER_BIND_END
+	#define JPH_SHADER_BIND_BUFFER(type, name)		JPH_SHADER_BUFFER(type) name;
+	#define JPH_SHADER_BIND_RW_BUFFER(type, name)	JPH_SHADER_RW_BUFFER(type) name;
+
+	#define JPH_AtomicAdd							InterlockedAdd
+#endif
+
+#define JPH_SHADER_STRUCT_BEGIN(name)				struct name {
+#define JPH_SHADER_STRUCT_MEMBER(type, name)		type m##name;
+#define JPH_SHADER_STRUCT_END						};
+
+#endif // JPH_OVERRIDE_SHADER_MACROS

+ 13 - 0
Jolt/Shaders/ShaderMat44.h

@@ -0,0 +1,13 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+inline float3 JPH_Mat44MulVec3(JPH_Mat44 inLHS, float3 inRHS)
+{
+	return inLHS[0].xyz * inRHS.x + inLHS[1].xyz * inRHS.y + inLHS[2].xyz * inRHS.z + inLHS[3].xyz;
+}
+
+inline float3 JPH_Mat44Mul3x3Vec3(JPH_Mat44 inLHS, float3 inRHS)
+{
+	return inLHS[0].xyz * inRHS.x + inLHS[1].xyz * inRHS.y + inLHS[2].xyz * inRHS.z;
+}

+ 16 - 0
Jolt/Shaders/ShaderMath.h

@@ -0,0 +1,16 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+// Calculate inValue^2
+inline float JPH_Square(float inValue)
+{
+	return inValue * inValue;
+}
+
+// Get the closest point on a line segment defined by inA + x * inAB for x e [0, 1] to the origin
+inline float3 JPH_GetClosestPointOnLine(in float3 inA, in float3 inAB)
+{
+	float v = clamp(-dot(inA, inAB) / dot(inAB, inAB), 0.0f, 1.0f);
+	return inA + v * inAB;
+}

+ 18 - 0
Jolt/Shaders/ShaderPlane.h

@@ -0,0 +1,18 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+JPH_Plane JPH_PlaneFromPointAndNormal(float3 inPoint, float3 inNormal)
+{
+	return JPH_Plane(inNormal, -dot(inNormal, inPoint));
+}
+
+float3 JPH_PlaneGetNormal(JPH_Plane inPlane)
+{
+	return inPlane.xyz;
+}
+
+float JPH_PlaneSignedDistance(JPH_Plane inPlane, float3 inPoint)
+{
+	return dot(inPoint, inPlane.xyz) + inPlane.w;
+}

+ 114 - 0
Jolt/Shaders/ShaderQuat.h

@@ -0,0 +1,114 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+inline float3 JPH_QuatMulVec3(JPH_Quat inLHS, float3 inRHS)
+{
+	float3 xyz = inLHS.xyz;
+	float3 yzx = inLHS.yzx;
+	float3 q_cross_p = (inRHS.yzx * xyz - yzx * inRHS).yzx;
+	float3 q_cross_q_cross_p = (q_cross_p.yzx * xyz - yzx * q_cross_p).yzx;
+	float3 v = inLHS.w * q_cross_p + q_cross_q_cross_p;
+	return inRHS + (v + v);
+}
+
+inline JPH_Quat JPH_QuatMulQuat(JPH_Quat inLHS, JPH_Quat inRHS)
+{
+	float x = inLHS.w * inRHS.x + inLHS.x * inRHS.w + inLHS.y * inRHS.z - inLHS.z * inRHS.y;
+	float y = inLHS.w * inRHS.y - inLHS.x * inRHS.z + inLHS.y * inRHS.w + inLHS.z * inRHS.x;
+	float z = inLHS.w * inRHS.z + inLHS.x * inRHS.y - inLHS.y * inRHS.x + inLHS.z * inRHS.w;
+	float w = inLHS.w * inRHS.w - inLHS.x * inRHS.x - inLHS.y * inRHS.y - inLHS.z * inRHS.z;
+	return JPH_Quat(x, y, z, w);
+}
+
+inline JPH_Quat JPH_QuatImaginaryMulQuat(float3 inLHS, JPH_Quat inRHS)
+{
+	float x = +inLHS.x * inRHS.w + inLHS.y * inRHS.z - inLHS.z * inRHS.y;
+	float y = -inLHS.x * inRHS.z + inLHS.y * inRHS.w + inLHS.z * inRHS.x;
+	float z = +inLHS.x * inRHS.y - inLHS.y * inRHS.x + inLHS.z * inRHS.w;
+	float w = -inLHS.x * inRHS.x - inLHS.y * inRHS.y - inLHS.z * inRHS.z;
+	return JPH_Quat(x, y, z, w);
+}
+
+inline float3 JPH_QuatRotateAxisZ(JPH_Quat inRotation)
+{
+	return (inRotation.z + inRotation.z) * inRotation.xyz + (inRotation.w + inRotation.w) * float3(inRotation.y, -inRotation.x, inRotation.w) - float3(0, 0, 1);
+}
+
+inline JPH_Quat JPH_QuatConjugate(JPH_Quat inRotation)
+{
+	return JPH_Quat(-inRotation.x, -inRotation.y, -inRotation.z, inRotation.w);
+}
+
+inline JPH_Quat JPH_QuatDecompress(uint inValue)
+{
+	const float cOneOverSqrt2 = 0.70710678f;
+	const uint cNumBits = 9;
+	const uint cMask = (1u << cNumBits) - 1;
+	const uint cMaxValue = cMask - 1; // Need odd number of buckets to quantize to or else we can't encode 0
+	const float cScale = 2.0f * cOneOverSqrt2 / float(cMaxValue);
+
+	// Restore two components
+	float3 v3 = float3(inValue & cMask, (inValue >> cNumBits) & cMask, (inValue >> (2 * cNumBits)) & cMask) * cScale - float3(cOneOverSqrt2, cOneOverSqrt2, cOneOverSqrt2);
+
+	// Restore the highest component
+	float4 v = float4(v3, sqrt(max(1.0f - dot(v3, v3), 0.0f)));
+
+	// Extract sign
+	if ((inValue & 0x80000000u) != 0)
+		v = -v;
+
+	// Swizzle the components in place
+	uint max_element = (inValue >> 29) & 3;
+	v = max_element == 0? v.wxyz : (max_element == 1? v.xwyz : (max_element == 2? v.xywz : v));
+
+	return v;
+}
+
+inline JPH_Quat JPH_QuatFromMat33(float3 inCol0, float3 inCol1, float3 inCol2)
+{
+	float tr = inCol0.x + inCol1.y + inCol2.z;
+	if (tr >= 0.0f)
+	{
+		float s = sqrt(tr + 1.0f);
+		float is = 0.5f / s;
+		return JPH_Quat(
+			(inCol1.z - inCol2.y) * is,
+			(inCol2.x - inCol0.z) * is,
+			(inCol0.y - inCol1.x) * is,
+			0.5f * s);
+	}
+	else
+	{
+		if (inCol0.x > inCol1.y && inCol0.x > inCol2.z)
+		{
+			float s = sqrt(inCol0.x - (inCol1.y + inCol2.z) + 1);
+			float is = 0.5f / s;
+			return JPH_Quat(
+				0.5f * s,
+				(inCol1.x + inCol0.y) * is,
+				(inCol0.z + inCol2.x) * is,
+				(inCol1.z - inCol2.y) * is);
+		}
+		else if (inCol1.y > inCol2.z)
+		{
+			float s = sqrt(inCol1.y - (inCol2.z + inCol0.x) + 1);
+			float is = 0.5f / s;
+			return JPH_Quat(
+				(inCol1.x + inCol0.y) * is,
+				0.5f * s,
+				(inCol2.y + inCol1.z) * is,
+				(inCol2.x - inCol0.z) * is);
+		}
+		else
+		{
+			float s = sqrt(inCol2.z - (inCol0.x + inCol1.y) + 1);
+			float is = 0.5f / s;
+			return JPH_Quat(
+				(inCol0.z + inCol2.x) * is,
+				(inCol2.y + inCol1.z) * is,
+				0.5f * s,
+				(inCol0.y - inCol1.x) * is);
+		}
+	}
+}

+ 28 - 0
Jolt/Shaders/ShaderVec3.h

@@ -0,0 +1,28 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+inline float3 JPH_Vec3DecompressUnit(uint inValue)
+{
+	const float cOneOverSqrt2 = 0.70710678f;
+	const uint cNumBits = 14;
+	const uint cMask = (1u << cNumBits) - 1;
+	const uint cMaxValue = cMask - 1; // Need odd number of buckets to quantize to or else we can't encode 0
+	const float cScale = 2.0f * cOneOverSqrt2 / float(cMaxValue);
+
+	// Restore two components
+	float2 v2 = float2(inValue & cMask, (inValue >> cNumBits) & cMask) * cScale - float2(cOneOverSqrt2, cOneOverSqrt2);
+
+	// Restore the highest component
+	float3 v = float3(v2, sqrt(max(1.0f - dot(v2, v2), 0.0f)));
+
+	// Extract sign
+	if ((inValue & 0x80000000u) != 0)
+		v = -v;
+
+	// Swizzle the components in place
+	uint max_element = (inValue >> 29) & 3;
+	v = max_element == 0? v.zxy : (max_element == 1? v.xzy : v);
+
+	return v;
+}

+ 19 - 0
Jolt/Shaders/TestCompute.h

@@ -0,0 +1,19 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#ifdef __cplusplus
+	#pragma once
+#endif
+
+#include "ShaderCore.h"
+
+static const int cTestComputeGroupSize = 64;
+
+JPH_SHADER_CONSTANTS_BEGIN(TestComputeContext, gContext)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_float3,		Float3Value)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint,		UIntValue)		// Test that this value packs correctly with the float3 preceding it
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_float3,		Float3Value2)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint,		UIntValue2)
+	JPH_SHADER_CONSTANTS_MEMBER(JPH_uint,		NumElements)
+JPH_SHADER_CONSTANTS_END

+ 26 - 0
Jolt/Shaders/TestCompute.hlsl

@@ -0,0 +1,26 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "TestCompute.h"
+#include "TestComputeBindings.h"
+
+JPH_SHADER_FUNCTION_BEGIN(void, main, cTestComputeGroupSize, 1, 1)
+	JPH_SHADER_PARAM_THREAD_ID(tid)
+JPH_SHADER_FUNCTION_END
+{
+	// Ensure that we do not write out of bounds
+	if (tid.x >= cNumElements)
+		return;
+
+	if (cUIntValue2 == 0)
+	{
+		// First write, uses optional data and tests that the packing of float3/uint3's works
+		gData[tid.x] = gOptionalData[tid.x] + int(cFloat3Value2.y) + gUploadData[0];
+	}
+	else
+	{
+		// Read-modify-write gData
+		gData[tid.x] = (gData[tid.x] + cUIntValue) * cUIntValue2;
+	}
+}

+ 9 - 0
Jolt/Shaders/TestComputeBindings.h

@@ -0,0 +1,9 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+JPH_SHADER_BIND_BEGIN(JPH_TestCompute)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gUploadData)
+	JPH_SHADER_BIND_BUFFER(JPH_uint, gOptionalData)
+	JPH_SHADER_BIND_RW_BUFFER(JPH_uint, gData)
+JPH_SHADER_BIND_END

+ 1 - 1
JoltViewer/JoltViewer.cmake

@@ -12,7 +12,7 @@ set(JOLT_VIEWER_SRC_FILES
 source_group(TREE ${JOLT_VIEWER_ROOT} FILES ${JOLT_VIEWER_SRC_FILES})
 
 # Create JoltViewer executable
-if ("${CMAKE_SYSTEM_NAME}" MATCHES "Darwin")
+if (APPLE)
 	# Icon
 	set(JPH_ICON "${CMAKE_CURRENT_SOURCE_DIR}/macOS/icon.icns")
 	set_source_files_properties(${JPH_ICON} PROPERTIES MACOSX_PACKAGE_LOCATION "Resources")

+ 6 - 1
JoltViewer/JoltViewer.cpp

@@ -29,10 +29,15 @@ JPH_GCC_SUPPRESS_WARNING("-Wswitch")
 JoltViewer::JoltViewer(const String &inCommandLine) :
 	Application("Jolt Viewer", inCommandLine)
 {
-	// Get file name from command line
+	// Explode command line into separate arguments
 	Array<String> args;
 	StringToVector(inCommandLine, args, " ");
 
+	// Remove entries starting with `-`
+	for (int i = (int)args.size() - 1; i >= 0; --i)
+		if (!args[i].empty() && args[i].at(0) == '-')
+			args.erase(args.begin() + i);
+
 	// Check arguments
 	if (args.size() != 2 || args[1].empty())
 		FatalError("Usage: JoltViewer <recording filename>");

+ 1 - 1
Samples/Samples.cmake

@@ -338,7 +338,7 @@ set(SAMPLES_ASSETS
 source_group(TREE ${SAMPLES_ROOT} FILES ${SAMPLES_SRC_FILES})
 
 # Create Samples executable
-if ("${CMAKE_SYSTEM_NAME}" MATCHES "Darwin")
+if (APPLE)
 	# Icon
 	set(JPH_ICON "${CMAKE_CURRENT_SOURCE_DIR}/macOS/icon.icns")
 	set_source_files_properties(${JPH_ICON} PROPERTIES MACOSX_PACKAGE_LOCATION "Resources")

+ 28 - 3
Samples/SamplesApp.cpp

@@ -49,6 +49,7 @@
 #include <Utils/ShapeCreator.h>
 #include <Utils/CustomMemoryHook.h>
 #include <Utils/SoftBodyCreator.h>
+#include <Utils/ReadData.h>
 #include <Renderer/DebugRendererImp.h>
 
 JPH_SUPPRESS_WARNINGS_STD_BEGIN
@@ -470,6 +471,22 @@ SamplesApp::SamplesApp(const String &inCommandLine) :
 	// Create single threaded job system for validating
 	mJobSystemValidating = new JobSystemSingleThreaded(cMaxPhysicsJobs);
 
+	// Set shader loader
+	mRenderer->GetComputeSystem().mShaderLoader = [](const char *inName, Array<uint8> &outData) {
+	#ifdef JPH_PLATFORM_MACOS
+		// In macOS the shaders are copied to the bundle
+		String base_path = "Jolt/Shaders/";
+	#else
+		// On other platforms they are in the Jolt source folder
+		String base_path = "../Jolt/Shaders/";
+	#endif
+		outData = ReadData((base_path + inName).c_str());
+		return true;
+	};
+
+	// Create compute queue
+	mComputeQueue = mRenderer->GetComputeSystem().CreateComputeQueue();
+
 	{
 		// Disable allocation checking
 		DisableCustomMemoryHook dcmh;
@@ -642,10 +659,16 @@ SamplesApp::SamplesApp(const String &inCommandLine) :
 		mDebugUI->ShowMenu(main_menu);
 	}
 
-	// Get test name from command line
-	String cmd_line = ToLower(inCommandLine);
+	// Explode command line into separate arguments
 	Array<String> args;
-	StringToVector(cmd_line, args, " ");
+	StringToVector(ToLower(inCommandLine), args, " ");
+
+	// Remove entries starting with `-`
+	for (int i = (int)args.size() - 1; i >= 0; --i)
+		if (!args[i].empty() && args[i].at(0) == '-')
+			args.erase(args.begin() + i);
+
+	// Get test name from command line
 	if (args.size() == 2)
 	{
 		String cmd = args[1];
@@ -689,6 +712,7 @@ SamplesApp::~SamplesApp()
 	delete mTest;
 	delete mContactListener;
 	delete mPhysicsSystem;
+	mComputeQueue = nullptr;
 	delete mJobSystemValidating;
 	delete mJobSystem;
 	delete mTempAllocator;
@@ -736,6 +760,7 @@ void SamplesApp::StartTest(const RTTI *inRTTI)
 	mTest = static_cast<Test *>(inRTTI->CreateObject());
 	mTest->SetPhysicsSystem(mPhysicsSystem);
 	mTest->SetJobSystem(mJobSystem);
+	mTest->SetComputeSystem(&mRenderer->GetComputeSystem(), mComputeQueue);
 	mTest->SetDebugRenderer(mDebugRenderer);
 	mTest->SetTempAllocator(mTempAllocator);
 	if (mInstallContactListener)

+ 1 - 0
Samples/SamplesApp.h

@@ -92,6 +92,7 @@ private:
 	TempAllocator *			mTempAllocator = nullptr;									// Allocator for temporary allocations
 	JobSystem *				mJobSystem = nullptr;										// The job system that runs physics jobs
 	JobSystem *				mJobSystemValidating = nullptr;								// The job system to use when validating determinism
+	Ref<ComputeQueue>		mComputeQueue = nullptr;									// The compute queue to use for compute jobs
 	BPLayerInterfaceImpl	mBroadPhaseLayerInterface;									// The broadphase layer interface that maps object layers to broadphase layers
 	ObjectVsBroadPhaseLayerFilterImpl mObjectVsBroadPhaseLayerFilter;					// Class that filters object vs broadphase layers
 	ObjectLayerPairFilterImpl mObjectVsObjectLayerFilter;								// Class that filters object vs object layers

+ 7 - 0
Samples/Tests/Test.h

@@ -15,6 +15,8 @@ class UIElement;
 namespace JPH {
 	class StateRecorder;
 	class JobSystem;
+	class ComputeSystem;
+	class ComputeQueue;
 	class ContactListener;
 	class DebugRenderer;
 }
@@ -33,6 +35,9 @@ public:
 	// Set the job system
 	void			SetJobSystem(JobSystem *inJobSystem)						{ mJobSystem = inJobSystem; }
 
+	// Set compute system and queue
+	void			SetComputeSystem(ComputeSystem *inComputeSystem, ComputeQueue *inComputeQueue) { mComputeSystem = inComputeSystem; mComputeQueue = inComputeQueue; }
+
 	// Set the debug renderer
 	void			SetDebugRenderer(DebugRenderer *inDebugRenderer)			{ mDebugRenderer = inDebugRenderer; }
 
@@ -128,6 +133,8 @@ protected:
 	void			SetBodyLabel(const BodyID &inBodyID, const String &inLabel)	{ mBodyLabels[inBodyID] = inLabel; }
 
 	JobSystem *		mJobSystem = nullptr;
+	ComputeSystem *	mComputeSystem = nullptr;
+	ComputeQueue *	mComputeQueue = nullptr;
 	PhysicsSystem *	mPhysicsSystem = nullptr;
 	BodyInterface *	mBodyInterface = nullptr;
 	DebugRenderer *	mDebugRenderer = nullptr;

+ 14 - 1
TestFramework/Application/Application.cpp

@@ -27,6 +27,10 @@
 	#include <Window/ApplicationWindowMacOS.h>
 #endif
 
+#ifdef JPH_USE_VK
+extern Renderer *CreateRendererVK();
+#endif
+
 JPH_GCC_SUPPRESS_WARNING("-Wswitch")
 
 // Constructor
@@ -61,6 +65,10 @@ Application::Application(const char *inApplicationName, [[maybe_unused]] const S
 	// Register physics types with the factory
 	RegisterTypes();
 
+	// Explode command line into separate arguments
+	Array<String> args;
+	StringToVector(ToLower(inCommandLine), args, " ");
+
 	{
 		// Disable allocation checking
 		DisableCustomMemoryHook dcmh;
@@ -78,7 +86,12 @@ Application::Application(const char *inApplicationName, [[maybe_unused]] const S
 		mWindow->Initialize(inApplicationName);
 
 		// Create renderer
-		mRenderer = Renderer::sCreate();
+	#ifdef JPH_USE_VK
+		if (std::find(args.begin(), args.end(), "-vulkan") != args.end())
+			mRenderer = CreateRendererVK();
+		else
+	#endif
+			mRenderer = Renderer::sCreate();
 		mRenderer->Initialize(mWindow);
 
 		// Create font

+ 0 - 117
TestFramework/Renderer/DX12/CommandQueueDX12.h

@@ -1,117 +0,0 @@
-// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
-// SPDX-FileCopyrightText: 2021 Jorrit Rouwe
-// SPDX-License-Identifier: MIT
-
-#pragma once
-
-#include <Renderer/DX12/FatalErrorIfFailedDX12.h>
-
-/// Holds a number of DirectX operations with logic to wait for completion
-class CommandQueueDX12
-{
-public:
-	/// Destructor
-										~CommandQueueDX12()
-	{
-		WaitUntilFinished();
-
-		if (mFenceEvent != INVALID_HANDLE_VALUE)
-			CloseHandle(mFenceEvent);
-	}
-
-	/// Initialize the queue
-	void								Initialize(ID3D12Device *inDevice)
-	{
-		D3D12_COMMAND_QUEUE_DESC queue_desc = {};
-		queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
-		queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
-		FatalErrorIfFailed(inDevice->CreateCommandQueue(&queue_desc, IID_PPV_ARGS(&mCommandQueue)));
-
-		FatalErrorIfFailed(inDevice->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&mCommandAllocator)));
-
-		// Create the command list
-		FatalErrorIfFailed(inDevice->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, mCommandAllocator.Get(), nullptr, IID_PPV_ARGS(&mCommandList)));
-
-		// Command lists are created in the recording state, but there is nothing to record yet. The main loop expects it to be closed, so close it now
-		FatalErrorIfFailed(mCommandList->Close());
-
-		// Create synchronization object
-		FatalErrorIfFailed(inDevice->CreateFence(mFenceValue, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&mFence)));
-
-		// Increment fence value so we don't skip waiting the first time a command list is executed
-		mFenceValue++;
-
-		// Create an event handle to use for frame synchronization
-		mFenceEvent = CreateEvent(nullptr, FALSE, FALSE, nullptr);
-		if (mFenceEvent == nullptr)
-			FatalErrorIfFailed(HRESULT_FROM_WIN32(GetLastError()));
-	}
-
-	/// Start the command list (requires waiting until the previous one is finished)
-	ID3D12GraphicsCommandList *			Start()
-	{
-		// Reset the allocator
-		FatalErrorIfFailed(mCommandAllocator->Reset());
-
-		// Reset the command list
-		FatalErrorIfFailed(mCommandList->Reset(mCommandAllocator.Get(), nullptr));
-
-		return mCommandList.Get();
-	}
-
-	/// Execute accumulated command list
-	void								Execute()
-	{
-		JPH_ASSERT(!mIsExecuting);
-
-		// Close the command list
-		FatalErrorIfFailed(mCommandList->Close());
-
-		// Execute the command list
-		ID3D12CommandList* ppCommandLists[] = { mCommandList.Get() };
-		mCommandQueue->ExecuteCommandLists(_countof(ppCommandLists), ppCommandLists);
-
-		// Schedule a Signal command in the queue
-		FatalErrorIfFailed(mCommandQueue->Signal(mFence.Get(), mFenceValue));
-
-		// Mark that we're executing
-		mIsExecuting = true;
-	}
-
-	/// After executing, this waits until execution is done
-	void								WaitUntilFinished()
-	{
-		// Check if we've been started
-		if (mIsExecuting)
-		{
-			if (mFence->GetCompletedValue() < mFenceValue)
-			{
-				// Wait until the fence has been processed
-				FatalErrorIfFailed(mFence->SetEventOnCompletion(mFenceValue, mFenceEvent));
-				WaitForSingleObjectEx(mFenceEvent, INFINITE, FALSE);
-			}
-
-			// Increment the fence value
-			mFenceValue++;
-
-			// Done executing
-			mIsExecuting = false;
-		}
-	}
-
-	/// Execute and wait for the command list to finish
-	void								ExecuteAndWait()
-	{
-		Execute();
-		WaitUntilFinished();
-	}
-
-private:
-	ComPtr<ID3D12CommandQueue>			mCommandQueue;								///< The command queue that will hold command lists
-	ComPtr<ID3D12CommandAllocator>		mCommandAllocator;							///< Allocator that holds the memory for the commands
-	ComPtr<ID3D12GraphicsCommandList>	mCommandList;								///< The command list that will hold the render commands / state changes
-	HANDLE								mFenceEvent = INVALID_HANDLE_VALUE;			///< Fence event, used to wait for rendering to complete
-	ComPtr<ID3D12Fence>					mFence;										///< Fence object, used to signal the fence event
-	UINT64								mFenceValue = 0;							///< Current fence value, each time we need to wait we will signal the fence with this value, wait for it and then increase the value
-	bool								mIsExecuting = false;						///< If a commandlist is currently executing on the queue
-};

+ 0 - 40
TestFramework/Renderer/DX12/ConstantBufferDX12.cpp

@@ -1,40 +0,0 @@
-// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
-// SPDX-FileCopyrightText: 2021 Jorrit Rouwe
-// SPDX-License-Identifier: MIT
-
-#include <TestFramework.h>
-
-#include <Renderer/DX12/ConstantBufferDX12.h>
-#include <Renderer/DX12/RendererDX12.h>
-#include <Renderer/DX12/FatalErrorIfFailedDX12.h>
-
-ConstantBufferDX12::ConstantBufferDX12(RendererDX12 *inRenderer, uint64 inBufferSize) :
-	mRenderer(inRenderer)
-{
-	mBuffer = mRenderer->CreateD3DResourceOnUploadHeap(inBufferSize);
-	mBufferSize = inBufferSize;
-}
-
-ConstantBufferDX12::~ConstantBufferDX12()
-{
-	if (mBuffer != nullptr)
-		mRenderer->RecycleD3DResourceOnUploadHeap(mBuffer.Get(), mBufferSize);
-}
-
-void *ConstantBufferDX12::MapInternal()
-{
-	void *mapped_resource;
-	D3D12_RANGE range = { 0, 0 }; // We're not going to read
-	FatalErrorIfFailed(mBuffer->Map(0, &range, &mapped_resource));
-	return mapped_resource;
-}
-
-void ConstantBufferDX12::Unmap()
-{
-	mBuffer->Unmap(0, nullptr);
-}
-
-void ConstantBufferDX12::Bind(int inSlot)
-{
-	mRenderer->GetCommandList()->SetGraphicsRootConstantBufferView(inSlot, mBuffer->GetGPUVirtualAddress());
-}

+ 0 - 32
TestFramework/Renderer/DX12/ConstantBufferDX12.h

@@ -1,32 +0,0 @@
-// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
-// SPDX-FileCopyrightText: 2021 Jorrit Rouwe
-// SPDX-License-Identifier: MIT
-
-#pragma once
-
-class RendererDX12;
-
-/// A binary blob that can be used to pass constants to a shader
-class ConstantBufferDX12
-{
-public:
-	/// Constructor
-										ConstantBufferDX12(RendererDX12 *inRenderer, uint64 inBufferSize);
-										~ConstantBufferDX12();
-
-	/// Map / unmap buffer (get pointer to data). This will discard all data in the buffer.
-	template <typename T> T *			Map()											{ return reinterpret_cast<T *>(MapInternal()); }
-	void								Unmap();
-
-	// Bind the constant buffer to a slot
-	void								Bind(int inSlot);
-
-private:
-	friend class RendererDX12;
-
-	void *								MapInternal();
-
-	RendererDX12 *						mRenderer;
-	ComPtr<ID3D12Resource>				mBuffer;
-	uint64								mBufferSize;
-};

+ 2 - 0
TestFramework/Renderer/DX12/DescriptorHeapDX12.h

@@ -4,6 +4,8 @@
 
 #pragma once
 
+#include <Renderer/DX12/FatalErrorIfFailedDX12.h>
+
 /// DirectX descriptor heap, used to allocate handles for resources to bind them to shaders
 class DescriptorHeapDX12
 {

+ 40 - 142
TestFramework/Renderer/DX12/RendererDX12.cpp

@@ -13,15 +13,19 @@
 #include <Renderer/DX12/RenderInstancesDX12.h>
 #include <Renderer/DX12/FatalErrorIfFailedDX12.h>
 #include <Window/ApplicationWindowWin.h>
+#include <Jolt/Compute/DX12/ComputeBufferDX12.h>
 #include <Jolt/Core/Profiler.h>
 #include <Utils/ReadData.h>
 #include <Utils/Log.h>
 #include <Utils/AssetStream.h>
 
 #include <d3dcompiler.h>
-#ifdef JPH_DEBUG
-	#include <d3d12sdklayers.h>
-#endif
+
+RendererDX12::RendererDX12()
+{
+	// Ensure ComputeSystem doesn't get destructed
+	ComputeSystem::SetEmbedded();
+}
 
 RendererDX12::~RendererDX12()
 {
@@ -65,7 +69,7 @@ void RendererDX12::CreateRenderTargets()
 		mRenderTargetViews[n] = mRTVHeap.Allocate();
 
 		FatalErrorIfFailed(mSwapChain->GetBuffer(n, IID_PPV_ARGS(&mRenderTargets[n])));
-		mDevice->CreateRenderTargetView(mRenderTargets[n].Get(), nullptr, mRenderTargetViews[n]);
+		GetDevice()->CreateRenderTargetView(mRenderTargets[n].Get(), nullptr, mRenderTargetViews[n]);
 	}
 }
 
@@ -104,7 +108,8 @@ void RendererDX12::CreateDepthBuffer()
 	depth_stencil_desc.Layout = D3D12_TEXTURE_LAYOUT_UNKNOWN;
 	depth_stencil_desc.Flags = D3D12_RESOURCE_FLAG_ALLOW_DEPTH_STENCIL;
 
-	FatalErrorIfFailed(mDevice->CreateCommittedResource(&heap_properties, D3D12_HEAP_FLAG_NONE, &depth_stencil_desc, D3D12_RESOURCE_STATE_DEPTH_WRITE, &clear_value, IID_PPV_ARGS(&mDepthStencilBuffer)));
+	ID3D12Device *device = GetDevice();
+	FatalErrorIfFailed(device->CreateCommittedResource(&heap_properties, D3D12_HEAP_FLAG_NONE, &depth_stencil_desc, D3D12_RESOURCE_STATE_DEPTH_WRITE, &clear_value, IID_PPV_ARGS(&mDepthStencilBuffer)));
 
 	// Allocate depth stencil view
 	D3D12_DEPTH_STENCIL_VIEW_DESC depth_stencil_view_desc = {};
@@ -113,108 +118,36 @@ void RendererDX12::CreateDepthBuffer()
 	depth_stencil_view_desc.Flags = D3D12_DSV_FLAG_NONE;
 
 	mDepthStencilView = mDSVHeap.Allocate();
-	mDevice->CreateDepthStencilView(mDepthStencilBuffer.Get(), &depth_stencil_view_desc, mDepthStencilView);
+	device->CreateDepthStencilView(mDepthStencilBuffer.Get(), &depth_stencil_view_desc, mDepthStencilView);
 }
 
 void RendererDX12::Initialize(ApplicationWindow *inWindow)
 {
 	Renderer::Initialize(inWindow);
 
-#if defined(JPH_DEBUG)
-	// Enable the D3D12 debug layer
-	ComPtr<ID3D12Debug> debug_controller;
-	if (SUCCEEDED(D3D12GetDebugInterface(IID_PPV_ARGS(&debug_controller))))
-		debug_controller->EnableDebugLayer();
-#endif
-
-	// Create DXGI factory
-	FatalErrorIfFailed(CreateDXGIFactory1(IID_PPV_ARGS(&mDXGIFactory)));
-
-	// Find adapter
-	ComPtr<IDXGIAdapter1> adapter;
-
-	HRESULT result = E_FAIL;
-
-	// First check if we have the Windows 1803 IDXGIFactory6 interface
-	ComPtr<IDXGIFactory6> factory6;
-	if (SUCCEEDED(mDXGIFactory->QueryInterface(IID_PPV_ARGS(&factory6))))
-	{
-		for (UINT index = 0; DXGI_ERROR_NOT_FOUND != factory6->EnumAdapterByGpuPreference(index, DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE, IID_PPV_ARGS(&adapter)); ++index)
-		{
-			DXGI_ADAPTER_DESC1 desc;
-			adapter->GetDesc1(&desc);
-
-			// We don't want software renderers
-			if (desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE)
-				continue;
-
-			// Check to see whether the adapter supports Direct3D 12
-			result = D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&mDevice));
-			if (SUCCEEDED(result))
-				break;
-		}
-	}
-	else
-	{
-		// Fall back to the older method that may not get the fastest GPU
-		for (UINT index = 0; DXGI_ERROR_NOT_FOUND != mDXGIFactory->EnumAdapters1(index, &adapter); ++index)
-		{
-			DXGI_ADAPTER_DESC1 desc;
-			adapter->GetDesc1(&desc);
-
-			// We don't want software renderers
-			if (desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE)
-				continue;
-
-			// Check to see whether the adapter supports Direct3D 12
-			result = D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&mDevice));
-			if (SUCCEEDED(result))
-				break;
-		}
-	}
-
-	// Check if we managed to obtain a device
-	FatalErrorIfFailed(result);
-
-#ifdef JPH_DEBUG
-	// Enable breaking on errors
-	ComPtr<ID3D12InfoQueue> info_queue;
-	if (SUCCEEDED(mDevice.As(&info_queue)))
-	{
-		info_queue->SetBreakOnSeverity(D3D12_MESSAGE_SEVERITY_CORRUPTION, TRUE);
-		info_queue->SetBreakOnSeverity(D3D12_MESSAGE_SEVERITY_ERROR, TRUE);
-		info_queue->SetBreakOnSeverity(D3D12_MESSAGE_SEVERITY_WARNING, TRUE);
-
-		// Disable an error that triggers on Windows 11 with a hybrid graphic system
-		// See: https://stackoverflow.com/questions/69805245/directx-12-application-is-crashing-in-windows-11
-		D3D12_MESSAGE_ID hide[] =
-		{
-			D3D12_MESSAGE_ID_RESOURCE_BARRIER_MISMATCHING_COMMAND_LIST_TYPE,
-		};
-		D3D12_INFO_QUEUE_FILTER filter = { };
-		filter.DenyList.NumIDs = static_cast<UINT>( std::size( hide ) );
-		filter.DenyList.pIDList = hide;
-		info_queue->AddStorageFilterEntries( &filter );
-	}
-#endif // JPH_DEBUG
+	if (!ComputeSystemDX12Impl::Initialize())
+		FatalError("Failed to initialize DirectX");
 
 	// Disable full screen transitions
-	FatalErrorIfFailed(mDXGIFactory->MakeWindowAssociation(static_cast<ApplicationWindowWin *>(mWindow)->GetWindowHandle(), DXGI_MWA_NO_ALT_ENTER));
+	IDXGIFactory4 *factory = GetDXGIFactory();
+	FatalErrorIfFailed(factory->MakeWindowAssociation(static_cast<ApplicationWindowWin *>(mWindow)->GetWindowHandle(), DXGI_MWA_NO_ALT_ENTER));
 
 	// Create heaps
-	mRTVHeap.Init(mDevice.Get(), D3D12_DESCRIPTOR_HEAP_TYPE_RTV, D3D12_DESCRIPTOR_HEAP_FLAG_NONE, 2);
-	mDSVHeap.Init(mDevice.Get(), D3D12_DESCRIPTOR_HEAP_TYPE_DSV, D3D12_DESCRIPTOR_HEAP_FLAG_NONE, 4);
-	mSRVHeap.Init(mDevice.Get(), D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV, D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE, 128);
+	ID3D12Device *device = GetDevice();
+	mRTVHeap.Init(device, D3D12_DESCRIPTOR_HEAP_TYPE_RTV, D3D12_DESCRIPTOR_HEAP_FLAG_NONE, 2);
+	mDSVHeap.Init(device, D3D12_DESCRIPTOR_HEAP_TYPE_DSV, D3D12_DESCRIPTOR_HEAP_FLAG_NONE, 4);
+	mSRVHeap.Init(device, D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV, D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE, 128);
 
 	// Create a command queue
 	D3D12_COMMAND_QUEUE_DESC queue_desc = {};
 	queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
 	queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
-	FatalErrorIfFailed(mDevice->CreateCommandQueue(&queue_desc, IID_PPV_ARGS(&mCommandQueue)));
+	queue_desc.Priority = D3D12_COMMAND_QUEUE_PRIORITY_NORMAL;
+	FatalErrorIfFailed(device->CreateCommandQueue(&queue_desc, IID_PPV_ARGS(&mCommandQueue)));
 
 	// Create a command allocator for each frame
 	for (uint n = 0; n < cFrameCount; n++)
-		FatalErrorIfFailed(mDevice->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&mCommandAllocators[n])));
+		FatalErrorIfFailed(device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&mCommandAllocators[n])));
 
 	// Describe and create the swap chain
 	DXGI_SWAP_CHAIN_DESC swap_chain_desc = {};
@@ -229,7 +162,7 @@ void RendererDX12::Initialize(ApplicationWindow *inWindow)
 	swap_chain_desc.Windowed = TRUE;
 
 	ComPtr<IDXGISwapChain> swap_chain;
-	FatalErrorIfFailed(mDXGIFactory->CreateSwapChain(mCommandQueue.Get(), &swap_chain_desc, &swap_chain));
+	FatalErrorIfFailed(factory->CreateSwapChain(mCommandQueue.Get(), &swap_chain_desc, &swap_chain));
 	FatalErrorIfFailed(swap_chain.As(&mSwapChain));
 	mFrameIndex = mSwapChain->GetCurrentBackBufferIndex();
 
@@ -299,16 +232,16 @@ void RendererDX12::Initialize(ApplicationWindow *inWindow)
 	ComPtr<ID3DBlob> signature;
 	ComPtr<ID3DBlob> error;
 	FatalErrorIfFailed(D3D12SerializeRootSignature(&root_signature_desc, D3D_ROOT_SIGNATURE_VERSION_1, &signature, &error));
-	FatalErrorIfFailed(mDevice->CreateRootSignature(0, signature->GetBufferPointer(), signature->GetBufferSize(), IID_PPV_ARGS(&mRootSignature)));
+	FatalErrorIfFailed(device->CreateRootSignature(0, signature->GetBufferPointer(), signature->GetBufferSize(), IID_PPV_ARGS(&mRootSignature)));
 
 	// Create the command list
-	FatalErrorIfFailed(mDevice->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, mCommandAllocators[mFrameIndex].Get(), nullptr, IID_PPV_ARGS(&mCommandList)));
+	FatalErrorIfFailed(device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, mCommandAllocators[mFrameIndex].Get(), nullptr, IID_PPV_ARGS(&mCommandList)));
 
 	// Command lists are created in the recording state, but there is nothing to record yet. The main loop expects it to be closed, so close it now
 	FatalErrorIfFailed(mCommandList->Close());
 
 	// Create synchronization object
-	FatalErrorIfFailed(mDevice->CreateFence(mFenceValues[mFrameIndex], D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&mFence)));
+	FatalErrorIfFailed(device->CreateFence(mFenceValues[mFrameIndex], D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&mFence)));
 
 	// Increment fence value so we don't skip waiting the first time a command list is executed
 	mFenceValues[mFrameIndex]++;
@@ -319,14 +252,14 @@ void RendererDX12::Initialize(ApplicationWindow *inWindow)
 		FatalErrorIfFailed(HRESULT_FROM_WIN32(GetLastError()));
 
 	// Initialize the queue used to upload resources to the GPU
-	mUploadQueue.Initialize(mDevice.Get());
+	mUploadQueue.Initialize(device, D3D12_COMMAND_LIST_TYPE_DIRECT);
 
 	// Create constant buffer. One per frame to avoid overwriting the constant buffer while the GPU is still using it.
 	for (uint n = 0; n < cFrameCount; ++n)
 	{
-		mVertexShaderConstantBufferProjection[n] = CreateConstantBuffer(sizeof(VertexShaderConstantBuffer));
-		mVertexShaderConstantBufferOrtho[n] = CreateConstantBuffer(sizeof(VertexShaderConstantBuffer));
-		mPixelShaderConstantBuffer[n] = CreateConstantBuffer(sizeof(PixelShaderConstantBuffer));
+		mVertexShaderConstantBufferProjection[n] = CreateComputeBuffer(ComputeBuffer::EType::ConstantBuffer, 1, sizeof(VertexShaderConstantBuffer));
+		mVertexShaderConstantBufferOrtho[n] = CreateComputeBuffer(ComputeBuffer::EType::ConstantBuffer, 1, sizeof(VertexShaderConstantBuffer));
+		mPixelShaderConstantBuffer[n] = CreateComputeBuffer(ComputeBuffer::EType::ConstantBuffer, 1, sizeof(PixelShaderConstantBuffer));
 	}
 
 	// Create depth only texture (no color buffer, as seen from light)
@@ -380,7 +313,7 @@ bool RendererDX12::BeginFrame(const CameraState &inCamera, float inWorldScale)
 
 	// Set SRV heap
 	ID3D12DescriptorHeap *heaps[] = { mSRVHeap.Get() };
-	mCommandList->SetDescriptorHeaps(_countof(heaps), heaps);
+	mCommandList->SetDescriptorHeaps(std::size(heaps), heaps);
 
 	// Indicate that the back buffer will be used as a render target.
 	D3D12_RESOURCE_BARRIER barrier;
@@ -398,12 +331,12 @@ bool RendererDX12::BeginFrame(const CameraState &inCamera, float inWorldScale)
 	mCommandList->ClearDepthStencilView(mDepthStencilView, D3D12_CLEAR_FLAG_DEPTH, 0.0f, 0, 0, nullptr);
 
 	// Set constants for vertex shader in projection mode
-	VertexShaderConstantBuffer *vs = mVertexShaderConstantBufferProjection[mFrameIndex]->Map<VertexShaderConstantBuffer>();
+	VertexShaderConstantBuffer *vs = mVertexShaderConstantBufferProjection[mFrameIndex]->Map<VertexShaderConstantBuffer>(ComputeBuffer::EMode::Write);
 	*vs = mVSBuffer;
 	mVertexShaderConstantBufferProjection[mFrameIndex]->Unmap();
 
 	// Set constants for vertex shader in ortho mode
-	vs = mVertexShaderConstantBufferOrtho[mFrameIndex]->Map<VertexShaderConstantBuffer>();
+	vs = mVertexShaderConstantBufferOrtho[mFrameIndex]->Map<VertexShaderConstantBuffer>(ComputeBuffer::EMode::Write);
 	*vs = mVSBufferOrtho;
 	mVertexShaderConstantBufferOrtho[mFrameIndex]->Unmap();
 
@@ -411,12 +344,12 @@ bool RendererDX12::BeginFrame(const CameraState &inCamera, float inWorldScale)
 	SetProjectionMode();
 
 	// Set constants for pixel shader
-	PixelShaderConstantBuffer *ps = mPixelShaderConstantBuffer[mFrameIndex]->Map<PixelShaderConstantBuffer>();
+	PixelShaderConstantBuffer *ps = mPixelShaderConstantBuffer[mFrameIndex]->Map<PixelShaderConstantBuffer>(ComputeBuffer::EMode::Write);
 	*ps = mPSBuffer;
 	mPixelShaderConstantBuffer[mFrameIndex]->Unmap();
 
 	// Set the pixel shader constant buffer data.
-	mPixelShaderConstantBuffer[mFrameIndex]->Bind(1);
+	mCommandList->SetGraphicsRootConstantBufferView(1, static_cast<ComputeBufferDX12 *>(mPixelShaderConstantBuffer[mFrameIndex].GetPtr())->GetResourceCPU()->GetGPUVirtualAddress());
 
 	// Start drawing the shadow pass
 	mShadowMap->SetAsRenderTarget(true);
@@ -464,7 +397,7 @@ void RendererDX12::EndFrame()
 
 	// Execute the command list
 	ID3D12CommandList* command_lists[] = { mCommandList.Get() };
-	mCommandQueue->ExecuteCommandLists(_countof(command_lists), command_lists);
+	mCommandQueue->ExecuteCommandLists(std::size(command_lists), command_lists);
 
 	// Present the frame
 	FatalErrorIfFailed(mSwapChain->Present(1, 0));
@@ -500,14 +433,14 @@ void RendererDX12::SetProjectionMode()
 {
 	JPH_ASSERT(mInFrame);
 
-	mVertexShaderConstantBufferProjection[mFrameIndex]->Bind(0);
+	mCommandList->SetGraphicsRootConstantBufferView(0, static_cast<ComputeBufferDX12 *>(mVertexShaderConstantBufferProjection[mFrameIndex].GetPtr())->GetResourceCPU()->GetGPUVirtualAddress());
 }
 
 void RendererDX12::SetOrthoMode()
 {
 	JPH_ASSERT(mInFrame);
 
-	mVertexShaderConstantBufferOrtho[mFrameIndex]->Bind(0);
+	mCommandList->SetGraphicsRootConstantBufferView(0, static_cast<ComputeBufferDX12 *>(mVertexShaderConstantBufferOrtho[mFrameIndex].GetPtr())->GetResourceCPU()->GetGPUVirtualAddress());
 }
 
 Ref<Texture> RendererDX12::CreateTexture(const Surface *inSurface)
@@ -595,11 +528,6 @@ Ref<PixelShader> RendererDX12::CreatePixelShader(const char *inName)
 	return new PixelShaderDX12(shader_blob);
 }
 
-unique_ptr<ConstantBufferDX12> RendererDX12::CreateConstantBuffer(uint inBufferSize)
-{
-	return make_unique<ConstantBufferDX12>(this, inBufferSize);
-}
-
 unique_ptr<PipelineState> RendererDX12::CreatePipelineState(const VertexShader *inVertexShader, const PipelineState::EInputDescription *inInputDescription, uint inInputDescriptionCount, const PixelShader *inPixelShader, PipelineState::EDrawPass inDrawPass, PipelineState::EFillMode inFillMode, PipelineState::ETopology inTopology, PipelineState::EDepthTest inDepthTest, PipelineState::EBlendMode inBlendMode, PipelineState::ECullMode inCullMode)
 {
 	return make_unique<PipelineStateDX12>(this, static_cast<const VertexShaderDX12 *>(inVertexShader), inInputDescription, inInputDescriptionCount, static_cast<const PixelShaderDX12 *>(inPixelShader), inDrawPass, inFillMode, inTopology, inDepthTest, inBlendMode, inCullMode);
@@ -615,34 +543,6 @@ RenderInstances *RendererDX12::CreateRenderInstances()
 	return new RenderInstancesDX12(this);
 }
 
-ComPtr<ID3D12Resource> RendererDX12::CreateD3DResource(D3D12_HEAP_TYPE inHeapType, D3D12_RESOURCE_STATES inResourceState, uint64 inSize)
-{
-	// Create a new resource
-	D3D12_RESOURCE_DESC desc;
-	desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
-	desc.Alignment = 0;
-	desc.Width = inSize;
-	desc.Height = 1;
-	desc.DepthOrArraySize = 1;
-	desc.MipLevels = 1;
-	desc.Format = DXGI_FORMAT_UNKNOWN;
-	desc.SampleDesc.Count = 1;
-	desc.SampleDesc.Quality = 0;
-	desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
-	desc.Flags = D3D12_RESOURCE_FLAG_NONE;
-
-	D3D12_HEAP_PROPERTIES heap_properties = {};
-	heap_properties.Type = inHeapType;
-	heap_properties.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN;
-	heap_properties.MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN;
-	heap_properties.CreationNodeMask = 1;
-	heap_properties.VisibleNodeMask = 1;
-
-	ComPtr<ID3D12Resource> resource;
-	FatalErrorIfFailed(mDevice->CreateCommittedResource(&heap_properties, D3D12_HEAP_FLAG_NONE, &desc, inResourceState, nullptr, IID_PPV_ARGS(&resource)));
-	return resource;
-}
-
 void RendererDX12::CopyD3DResource(ID3D12Resource *inDest, const void *inSrc, uint64 inSize)
 {
 	// Copy data to destination buffer
@@ -678,7 +578,7 @@ void RendererDX12::CopyD3DResource(ID3D12Resource *inDest, ID3D12Resource *inSrc
 ComPtr<ID3D12Resource> RendererDX12::CreateD3DResourceOnDefaultHeap(const void *inData, uint64 inSize)
 {
 	ComPtr<ID3D12Resource> upload = CreateD3DResourceOnUploadHeap(inSize);
-	ComPtr<ID3D12Resource> resource = CreateD3DResource(D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COMMON, inSize);
+	ComPtr<ID3D12Resource> resource = CreateD3DResource(D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_FLAG_NONE, inSize);
 	CopyD3DResource(upload.Get(), inData, inSize);
 	CopyD3DResource(resource.Get(), upload.Get(), inSize);
 	RecycleD3DResourceOnUploadHeap(upload.Get(), inSize);
@@ -696,7 +596,7 @@ ComPtr<ID3D12Resource> RendererDX12::CreateD3DResourceOnUploadHeap(uint64 inSize
 		return resource;
 	}
 
-	return CreateD3DResource(D3D12_HEAP_TYPE_UPLOAD, D3D12_RESOURCE_STATE_GENERIC_READ, inSize);
+	return CreateD3DResource(D3D12_HEAP_TYPE_UPLOAD, D3D12_RESOURCE_STATE_GENERIC_READ, D3D12_RESOURCE_FLAG_NONE, inSize);
 }
 
 void RendererDX12::RecycleD3DResourceOnUploadHeap(ID3D12Resource *inResource, uint64 inSize)
@@ -711,9 +611,7 @@ void RendererDX12::RecycleD3DObject(ID3D12Object *inResource)
 		mDelayReleased[mFrameIndex].push_back(inResource);
 }
 
-#ifndef JPH_ENABLE_VULKAN
 Renderer *Renderer::sCreate()
 {
 	return new RendererDX12;
 }
-#endif

+ 11 - 19
TestFramework/Renderer/DX12/RendererDX12.h

@@ -4,22 +4,23 @@
 
 #pragma once
 
-#include <Jolt/Core/UnorderedMap.h>
 #include <Renderer/Renderer.h>
-#include <Renderer/DX12/CommandQueueDX12.h>
+#include <Jolt/Compute/DX12/ComputeSystemDX12Impl.h>
+#include <Jolt/Compute/DX12/ComputeQueueDX12.h>
 #include <Renderer/DX12/DescriptorHeapDX12.h>
-#include <Renderer/DX12/ConstantBufferDX12.h>
 #include <Renderer/DX12/TextureDX12.h>
 
 /// DirectX 12 renderer
-class RendererDX12 : public Renderer
+class RendererDX12 : public Renderer, public ComputeSystemDX12Impl
 {
 public:
-	/// Destructor
+	/// Constructor / destructor
+									RendererDX12();
 	virtual							~RendererDX12() override;
 
 	// See: Renderer
 	virtual void					Initialize(ApplicationWindow *inWindow) override;
+	virtual ComputeSystem &			GetComputeSystem() override			{ return *this; }
 	virtual bool					BeginFrame(const CameraState &inCamera, float inWorldScale) override;
 	virtual void					EndShadowPass() override;
 	virtual void					EndFrame() override;
@@ -34,9 +35,6 @@ public:
 	virtual Texture *				GetShadowMap() const override		{ return mShadowMap.GetPtr(); }
 	virtual void					OnWindowResize() override;
 
-	/// Create a constant buffer
-	unique_ptr<ConstantBufferDX12>	CreateConstantBuffer(uint inBufferSize);
-
 	/// Create a buffer on the default heap (usable for permanent buffers)
 	ComPtr<ID3D12Resource>			CreateD3DResourceOnDefaultHeap(const void *inData, uint64 inSize);
 
@@ -50,10 +48,9 @@ public:
 	void							RecycleD3DObject(ID3D12Object *inResource);
 
 	/// Access to the most important DirectX structures
-	ID3D12Device *					GetDevice()							{ return mDevice.Get(); }
 	ID3D12RootSignature *			GetRootSignature()					{ return mRootSignature.Get(); }
 	ID3D12GraphicsCommandList *		GetCommandList()					{ JPH_ASSERT(mInFrame); return mCommandList.Get(); }
-	CommandQueueDX12 &				GetUploadQueue()					{ return mUploadQueue; }
+	ComputeQueueDX12 &				GetUploadQueue()					{ return mUploadQueue; }
 	DescriptorHeapDX12 &			GetDSVHeap()						{ return mDSVHeap; }
 	DescriptorHeapDX12 &			GetSRVHeap()						{ return mSRVHeap; }
 
@@ -67,9 +64,6 @@ private:
 	// Create a depth buffer for the back buffer
 	void							CreateDepthBuffer();
 
-	// Function to create a ID3D12Resource on specified heap with specified state
-	ComPtr<ID3D12Resource>			CreateD3DResource(D3D12_HEAP_TYPE inHeapType, D3D12_RESOURCE_STATES inResourceState, uint64 inSize);
-
 	// Copy CPU memory into a ID3D12Resource
 	void							CopyD3DResource(ID3D12Resource *inDest, const void *inSrc, uint64 inSize);
 
@@ -77,8 +71,6 @@ private:
 	void							CopyD3DResource(ID3D12Resource *inDest, ID3D12Resource *inSrc, uint64 inSize);
 
 	// DirectX interfaces
-	ComPtr<IDXGIFactory4>			mDXGIFactory;
-	ComPtr<ID3D12Device>			mDevice;
 	DescriptorHeapDX12				mRTVHeap;							///< Render target view heap
 	DescriptorHeapDX12				mDSVHeap;							///< Depth stencil view heap
 	DescriptorHeapDX12				mSRVHeap;							///< Shader resource view heap
@@ -92,10 +84,10 @@ private:
 	ComPtr<ID3D12GraphicsCommandList> mCommandList;						///< The command list
 	ComPtr<ID3D12RootSignature>		mRootSignature;						///< The root signature, we have a simple application so we only need 1, which is suitable for all our shaders
 	Ref<TextureDX12>				mShadowMap;							///< Used to render shadow maps
-	CommandQueueDX12				mUploadQueue;						///< Queue used to upload resources to GPU memory
-	unique_ptr<ConstantBufferDX12>	mVertexShaderConstantBufferProjection[cFrameCount];
-	unique_ptr<ConstantBufferDX12>	mVertexShaderConstantBufferOrtho[cFrameCount];
-	unique_ptr<ConstantBufferDX12>	mPixelShaderConstantBuffer[cFrameCount];
+	ComputeQueueDX12				mUploadQueue;						///< Queue used to upload resources to GPU memory
+	Ref<ComputeBuffer>				mVertexShaderConstantBufferProjection[cFrameCount];
+	Ref<ComputeBuffer>				mVertexShaderConstantBufferOrtho[cFrameCount];
+	Ref<ComputeBuffer>				mPixelShaderConstantBuffer[cFrameCount];
 
 	// Synchronization objects used to finish rendering and swapping before reusing a command queue
 	HANDLE							mFenceEvent;						///< Fence event to wait for the previous frame rendering to complete (in order to free 1 of the buffers)

+ 5 - 4
TestFramework/Renderer/MTL/RendererMTL.h

@@ -6,17 +6,19 @@
 
 #include <Renderer/Renderer.h>
 #include <Renderer/MTL/TextureMTL.h>
-
-#include <MetalKit/MetalKit.h>
+#include <Jolt/Compute/MTL/ComputeSystemMTL.h>
 
 /// Metal renderer
-class RendererMTL : public Renderer
+class RendererMTL : public Renderer, public ComputeSystemMTL
 {
 public:
+	/// Constructor / destructor
+									RendererMTL();
 	virtual 						~RendererMTL() override;
 	
 	// See: Renderer
 	virtual void					Initialize(ApplicationWindow *inWindow) override;
+	virtual ComputeSystem &			GetComputeSystem() override										{ return *this; }
 	virtual bool					BeginFrame(const CameraState &inCamera, float inWorldScale) override;
 	virtual void					EndShadowPass() override;
 	virtual void					EndFrame() override;
@@ -32,7 +34,6 @@ public:
 	virtual void					OnWindowResize() override										{ }
 
 	MTKView *						GetView() const													{ return mView; }
-	id<MTLDevice>					GetDevice() const												{ return mView.device; }
 	id<MTLRenderCommandEncoder>		GetRenderEncoder() const										{ return mRenderEncoder; }
 
 private:

+ 11 - 3
TestFramework/Renderer/MTL/RendererMTL.mm

@@ -17,11 +17,19 @@
 #include <Utils/AssetStream.h>
 #include <Jolt/Core/Profiler.h>
 
+RendererMTL::RendererMTL()
+{
+	// Ensure ComputeSystem doesn't get destructed
+	ComputeSystem::SetEmbedded();
+}
+
 RendererMTL::~RendererMTL()
 {
 	[mCommandQueue release];
 	[mShadowRenderPass release];
 	[mShaderLibrary release];
+
+	ComputeSystemMTL::Shutdown();
 }
 
 void RendererMTL::Initialize(ApplicationWindow *inWindow)
@@ -30,7 +38,9 @@ void RendererMTL::Initialize(ApplicationWindow *inWindow)
 
 	mView = static_cast<ApplicationWindowMacOS *>(inWindow)->GetMetalView();
 
-	id<MTLDevice> device = GetDevice();
+	id<MTLDevice> device = mView.device;
+
+	ComputeSystemMTL::Initialize(device);
 
 	// Load the shader library containing all shaders for the test framework
 	NSError *error = nullptr;
@@ -176,9 +186,7 @@ RenderInstances *RendererMTL::CreateRenderInstances()
 	return new RenderInstancesMTL(this);
 }
 
-#ifndef JPH_ENABLE_VULKAN
 Renderer *Renderer::sCreate()
 {
 	return new RendererMTL;
 }
-#endif

+ 4 - 0
TestFramework/Renderer/Renderer.h

@@ -12,6 +12,7 @@
 #include <Renderer/PixelShader.h>
 #include <Renderer/RenderPrimitive.h>
 #include <Renderer/RenderInstances.h>
+#include <Jolt/Compute/ComputeSystem.h>
 #include <memory>
 
 // Forward declares
@@ -38,6 +39,9 @@ public:
 	/// Initialize renderer
 	virtual void					Initialize(ApplicationWindow *inWindow);
 
+	/// Access to the compute interface
+	virtual ComputeSystem &			GetComputeSystem() = 0;
+
 	/// Start / end drawing a frame
 	virtual bool					BeginFrame(const CameraState &inCamera, float inWorldScale);
 	virtual void					EndShadowPass() = 0;

+ 0 - 21
TestFramework/Renderer/VK/BufferVK.h

@@ -1,21 +0,0 @@
-// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
-// SPDX-FileCopyrightText: 2024 Jorrit Rouwe
-// SPDX-License-Identifier: MIT
-
-#pragma once
-
-#include <vulkan/vulkan.h>
-
-/// Simple wrapper class to manage a Vulkan buffer
-class BufferVK
-{
-public:
-	VkBuffer					mBuffer = VK_NULL_HANDLE;
-	VkDeviceMemory				mMemory = VK_NULL_HANDLE;
-	VkDeviceSize				mOffset = 0;
-	VkDeviceSize				mSize = 0;
-
-	VkBufferUsageFlags			mUsage;
-	VkMemoryPropertyFlags		mProperties;
-	VkDeviceSize				mAllocatedSize;
-};

+ 0 - 32
TestFramework/Renderer/VK/ConstantBufferVK.cpp

@@ -1,32 +0,0 @@
-// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
-// SPDX-FileCopyrightText: 2024 Jorrit Rouwe
-// SPDX-License-Identifier: MIT
-
-#include <TestFramework.h>
-
-#include <Renderer/VK/ConstantBufferVK.h>
-#include <Renderer/VK/RendererVK.h>
-#include <Renderer/VK/FatalErrorIfFailedVK.h>
-
-ConstantBufferVK::ConstantBufferVK(RendererVK *inRenderer, VkDeviceSize inBufferSize) :
-	mRenderer(inRenderer)
-{
-	mRenderer->CreateBuffer(inBufferSize, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT, mBuffer);
-}
-
-ConstantBufferVK::~ConstantBufferVK()
-{
-	mRenderer->FreeBuffer(mBuffer);
-}
-
-void *ConstantBufferVK::MapInternal()
-{
-	void *data = nullptr;
-	FatalErrorIfFailed(vkMapMemory(mRenderer->GetDevice(), mBuffer.mMemory, mBuffer.mOffset, mBuffer.mSize, 0, &data));
-	return data;
-}
-
-void ConstantBufferVK::Unmap()
-{
-	vkUnmapMemory(mRenderer->GetDevice(), mBuffer.mMemory);
-}

+ 0 - 30
TestFramework/Renderer/VK/ConstantBufferVK.h

@@ -1,30 +0,0 @@
-// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
-// SPDX-FileCopyrightText: 2024 Jorrit Rouwe
-// SPDX-License-Identifier: MIT
-
-#pragma once
-
-#include <Renderer/VK/BufferVK.h>
-
-class RendererVK;
-
-/// A binary blob that can be used to pass constants to a shader
-class ConstantBufferVK
-{
-public:
-	/// Constructor
-										ConstantBufferVK(RendererVK *inRenderer, VkDeviceSize inBufferSize);
-										~ConstantBufferVK();
-
-	/// Map / unmap buffer (get pointer to data). This will discard all data in the buffer.
-	template <typename T> T *			Map()											{ return reinterpret_cast<T *>(MapInternal()); }
-	void								Unmap();
-
-	VkBuffer							GetBuffer() const								{ return mBuffer.mBuffer; }
-
-private:
-	void *								MapInternal();
-
-	RendererVK *						mRenderer;
-	BufferVK							mBuffer;
-};

+ 2 - 3
TestFramework/Renderer/VK/PixelShaderVK.h

@@ -14,8 +14,7 @@ class PixelShaderVK : public PixelShader
 public:
 	/// Constructor
 							PixelShaderVK(VkDevice inDevice, VkShaderModule inShaderModule) :
-		mDevice(inDevice),
-		mStageInfo()
+		mDevice(inDevice)
 	{
 		mStageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
 		mStageInfo.stage = VK_SHADER_STAGE_FRAGMENT_BIT;
@@ -30,5 +29,5 @@ public:
 	}
 
 	VkDevice				mDevice;
-	VkPipelineShaderStageCreateInfo mStageInfo;
+	VkPipelineShaderStageCreateInfo mStageInfo = {};
 };

+ 3 - 5
TestFramework/Renderer/VK/RenderInstancesVK.cpp

@@ -10,7 +10,7 @@
 
 void RenderInstancesVK::Clear()
 {
-	mRenderer->FreeBuffer(mInstancesBuffer);
+	mRenderer->FreeBufferDelayed(mInstancesBuffer);
 }
 
 void RenderInstancesVK::CreateBuffer(int inNumInstances, int inInstanceSize)
@@ -22,14 +22,12 @@ void RenderInstancesVK::CreateBuffer(int inNumInstances, int inInstanceSize)
 
 void *RenderInstancesVK::Lock()
 {
-	void *data;
-	FatalErrorIfFailed(vkMapMemory(mRenderer->GetDevice(), mInstancesBuffer.mMemory, mInstancesBuffer.mOffset, mInstancesBuffer.mSize, 0, &data));
-	return data;
+	return mRenderer->MapBuffer(mInstancesBuffer);
 }
 
 void RenderInstancesVK::Unlock()
 {
-	vkUnmapMemory(mRenderer->GetDevice(), mInstancesBuffer.mMemory);
+	mRenderer->UnmapBuffer(mInstancesBuffer);
 }
 
 void RenderInstancesVK::Draw(RenderPrimitive *inPrimitive, int inStartInstance, int inNumInstances) const

+ 6 - 10
TestFramework/Renderer/VK/RenderPrimitiveVK.cpp

@@ -9,7 +9,7 @@
 
 void RenderPrimitiveVK::ReleaseVertexBuffer()
 {
-	mRenderer->FreeBuffer(mVertexBuffer);
+	mRenderer->FreeBufferDelayed(mVertexBuffer);
 	mVertexBufferDeviceLocal = false;
 
 	RenderPrimitive::ReleaseVertexBuffer();
@@ -17,7 +17,7 @@ void RenderPrimitiveVK::ReleaseVertexBuffer()
 
 void RenderPrimitiveVK::ReleaseIndexBuffer()
 {
-	mRenderer->FreeBuffer(mIndexBuffer);
+	mRenderer->FreeBufferDelayed(mIndexBuffer);
 	mIndexBufferDeviceLocal = false;
 
 	RenderPrimitive::ReleaseIndexBuffer();
@@ -41,14 +41,12 @@ void *RenderPrimitiveVK::LockVertexBuffer()
 {
 	JPH_ASSERT(!mVertexBufferDeviceLocal);
 
-	void *data;
-	FatalErrorIfFailed(vkMapMemory(mRenderer->GetDevice(), mVertexBuffer.mMemory, mVertexBuffer.mOffset, VkDeviceSize(mNumVtx) * mVtxSize, 0, &data));
-	return data;
+	return mRenderer->MapBuffer(mVertexBuffer);
 }
 
 void RenderPrimitiveVK::UnlockVertexBuffer()
 {
-	vkUnmapMemory(mRenderer->GetDevice(), mVertexBuffer.mMemory);
+	mRenderer->UnmapBuffer(mVertexBuffer);
 }
 
 void RenderPrimitiveVK::CreateIndexBuffer(int inNumIdx, const uint32 *inData)
@@ -69,14 +67,12 @@ uint32 *RenderPrimitiveVK::LockIndexBuffer()
 {
 	JPH_ASSERT(!mIndexBufferDeviceLocal);
 
-	void *data;
-	vkMapMemory(mRenderer->GetDevice(), mIndexBuffer.mMemory, mIndexBuffer.mOffset, VkDeviceSize(mNumIdx) * sizeof(uint32), 0, &data);
-	return reinterpret_cast<uint32 *>(data);
+	return reinterpret_cast<uint32 *>(mRenderer->MapBuffer(mIndexBuffer));
 }
 
 void RenderPrimitiveVK::UnlockIndexBuffer()
 {
-	vkUnmapMemory(mRenderer->GetDevice(), mIndexBuffer.mMemory);
+	mRenderer->UnmapBuffer(mIndexBuffer);
 }
 
 void RenderPrimitiveVK::Draw() const

+ 1 - 1
TestFramework/Renderer/VK/RenderPrimitiveVK.h

@@ -6,7 +6,7 @@
 
 #include <Renderer/RenderPrimitive.h>
 #include <Renderer/VK/RendererVK.h>
-#include <Renderer/VK/BufferVK.h>
+#include <Jolt/Compute/VK/BufferVK.h>
 
 /// Vulkan implementation of a render primitive
 class RenderPrimitiveVK : public RenderPrimitive

+ 96 - 422
TestFramework/Renderer/VK/RendererVK.cpp

@@ -15,8 +15,8 @@
 #include <Utils/Log.h>
 #include <Utils/ReadData.h>
 #include <Jolt/Core/Profiler.h>
-#include <Jolt/Core/QuickSort.h>
 #include <Jolt/Core/RTTI.h>
+#include <Jolt/Compute/VK/ComputeBufferVK.h>
 
 JPH_SUPPRESS_WARNINGS_STD_BEGIN
 #ifdef JPH_PLATFORM_WINDOWS
@@ -31,54 +31,42 @@ JPH_SUPPRESS_WARNINGS_STD_BEGIN
 #endif
 JPH_SUPPRESS_WARNINGS_STD_END
 
-#ifdef JPH_DEBUG
-
-static VKAPI_ATTR VkBool32 VKAPI_CALL sVulkanDebugCallback(VkDebugUtilsMessageSeverityFlagBitsEXT inSeverity, [[maybe_unused]] VkDebugUtilsMessageTypeFlagsEXT inType, const VkDebugUtilsMessengerCallbackDataEXT *inCallbackData, [[maybe_unused]] void *inUserData)
+RendererVK::RendererVK()
 {
-	Trace("VK: %s", inCallbackData->pMessage);
-	JPH_ASSERT((inSeverity & VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT) == 0);
-	return VK_FALSE;
+	// Ensure ComputeSystem doesn't get destructed
+	ComputeSystem::SetEmbedded();
 }
 
-#endif // JPH_DEBUG
-
 RendererVK::~RendererVK()
 {
 	vkDeviceWaitIdle(mDevice);
 
-	// Trace allocation stats
-	Trace("VK: Max allocations: %u, max size: %u MB", mMaxNumAllocations, uint32(mMaxTotalAllocated >> 20));
-
 	// Destroy the shadow map
 	mShadowMap = nullptr;
 	vkDestroyFramebuffer(mDevice, mShadowFrameBuffer, nullptr);
 
 	// Release constant buffers
-	for (unique_ptr<ConstantBufferVK> &cb : mVertexShaderConstantBufferProjection)
+	for (Ref<ComputeBuffer> &cb : mVertexShaderConstantBufferProjection)
 		cb = nullptr;
-	for (unique_ptr<ConstantBufferVK> &cb : mVertexShaderConstantBufferOrtho)
+	for (Ref<ComputeBuffer> &cb : mVertexShaderConstantBufferOrtho)
 		cb = nullptr;
-	for (unique_ptr<ConstantBufferVK> &cb : mPixelShaderConstantBuffer)
+	for (Ref<ComputeBuffer> &cb : mPixelShaderConstantBuffer)
 		cb = nullptr;
 	
 	// Free all buffers
-	for (BufferCache &bc : mFreedBuffers)
-		for (BufferCache::value_type &vt : bc)
-			for (BufferVK &bvk : vt.second)
-				FreeBufferInternal(bvk);
-	for (BufferCache::value_type &vt : mBufferCache)
-		for (BufferVK &bvk : vt.second)
-			FreeBufferInternal(bvk);
-
-	// Free all blocks in the memory cache
-	for (MemoryCache::value_type &mc : mMemoryCache)
-		for (Memory &m : mc.second)
-			if (m.mOffset == 0)
-				vkFreeMemory(mDevice, m.mMemory, nullptr); // Don't care about memory tracking anymore
-	
+	for (Array<BufferVK> &buffers : mPerFrameFreedBuffers)
+	{
+		for (BufferVK& buffer : buffers)
+			FreeBuffer(buffer);
+		buffers.clear();
+	}
+
 	for (VkFence fence : mInFlightFences)
 		vkDestroyFence(mDevice, fence, nullptr);
 
+	for (uint32 i = 0; i < cFrameCount; ++i)
+		vkFreeCommandBuffers(mDevice, mCommandPool, 1, &mCommandBuffers[i]);
+
 	vkDestroyCommandPool(mDevice, mCommandPool, nullptr);
 
 	vkDestroyPipelineLayout(mDevice, mPipelineLayout, nullptr);
@@ -97,102 +85,10 @@ RendererVK::~RendererVK()
 	DestroySwapChain();
 
 	vkDestroySurfaceKHR(mInstance, mSurface, nullptr);
-
-	vkDestroyDevice(mDevice, nullptr);
-
-#ifdef JPH_DEBUG
-	PFN_vkDestroyDebugUtilsMessengerEXT vkDestroyDebugUtilsMessengerEXT = (PFN_vkDestroyDebugUtilsMessengerEXT)(void *)vkGetInstanceProcAddr(mInstance, "vkDestroyDebugUtilsMessengerEXT");
-	if (vkDestroyDebugUtilsMessengerEXT != nullptr)
-		vkDestroyDebugUtilsMessengerEXT(mInstance, mDebugMessenger, nullptr);
-#endif
-
-	 vkDestroyInstance(mInstance, nullptr);
 }
 
-void RendererVK::Initialize(ApplicationWindow *inWindow)
+void RendererVK::OnInstanceCreated()
 {
-	Renderer::Initialize(inWindow);
-
-	// Flip the sign of the projection matrix
-	mPerspectiveYSign = -1.0f;
-
-	// Required instance extensions
-	Array<const char *> required_instance_extensions;
-	required_instance_extensions.push_back(VK_KHR_SURFACE_EXTENSION_NAME);
-#ifdef JPH_PLATFORM_WINDOWS
-	required_instance_extensions.push_back(VK_KHR_WIN32_SURFACE_EXTENSION_NAME);
-#elif defined(JPH_PLATFORM_LINUX)
-	required_instance_extensions.push_back(VK_KHR_XLIB_SURFACE_EXTENSION_NAME);
-#elif defined(JPH_PLATFORM_MACOS)
-	required_instance_extensions.push_back(VK_EXT_METAL_SURFACE_EXTENSION_NAME);
-	required_instance_extensions.push_back("VK_KHR_portability_enumeration");
-	required_instance_extensions.push_back("VK_KHR_get_physical_device_properties2");
-#endif
-
-	// Required device extensions
-	Array<const char *> required_device_extensions;
-	required_device_extensions.push_back(VK_KHR_SWAPCHAIN_EXTENSION_NAME);
-#ifdef JPH_PLATFORM_MACOS
-	required_device_extensions.push_back("VK_KHR_portability_subset"); // VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME
-#endif
-
-	// Query supported instance extensions
-	uint32 instance_extension_count = 0;
-	FatalErrorIfFailed(vkEnumerateInstanceExtensionProperties(nullptr, &instance_extension_count, nullptr));
-	Array<VkExtensionProperties> instance_extensions;
-	instance_extensions.resize(instance_extension_count);
-	FatalErrorIfFailed(vkEnumerateInstanceExtensionProperties(nullptr, &instance_extension_count, instance_extensions.data()));
-
-	// Query supported validation layers
-	uint32 validation_layer_count;
-	vkEnumerateInstanceLayerProperties(&validation_layer_count, nullptr);
-	Array<VkLayerProperties> validation_layers(validation_layer_count);
-	vkEnumerateInstanceLayerProperties(&validation_layer_count, validation_layers.data());
-
-	// Create Vulkan instance
-	VkInstanceCreateInfo instance_create_info = {};
-	instance_create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
-#ifdef JPH_PLATFORM_MACOS
-	instance_create_info.flags = VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR;
-#endif
-
-#ifdef JPH_DEBUG
-	// Enable validation layer if supported
-	const char *desired_validation_layers[] = { "VK_LAYER_KHRONOS_validation" };
-	for (const VkLayerProperties &p : validation_layers)
-		if (strcmp(desired_validation_layers[0], p.layerName) == 0)
-		{
-			instance_create_info.enabledLayerCount = 1;
-			instance_create_info.ppEnabledLayerNames = desired_validation_layers;
-			break;
-		}
-
-	// Setup debug messenger callback if the extension is supported
-	VkDebugUtilsMessengerCreateInfoEXT messenger_create_info = {};
-	for (const VkExtensionProperties &ext : instance_extensions)
-		if (strcmp(VK_EXT_DEBUG_UTILS_EXTENSION_NAME, ext.extensionName) == 0)
-		{
-			messenger_create_info.sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT;
-			messenger_create_info.messageSeverity = VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT;
-			messenger_create_info.messageType = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT;
-			messenger_create_info.pfnUserCallback = sVulkanDebugCallback;
-			instance_create_info.pNext = &messenger_create_info;
-			required_instance_extensions.push_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME);
-			break;
-		}
-#endif
-
-	instance_create_info.enabledExtensionCount = (uint32)required_instance_extensions.size();
-	instance_create_info.ppEnabledExtensionNames = required_instance_extensions.data();
-	FatalErrorIfFailed(vkCreateInstance(&instance_create_info, nullptr, &mInstance));
-
-#ifdef JPH_DEBUG
-	// Finalize debug messenger callback
-	PFN_vkCreateDebugUtilsMessengerEXT vkCreateDebugUtilsMessengerEXT = (PFN_vkCreateDebugUtilsMessengerEXT)(std::uintptr_t)vkGetInstanceProcAddr(mInstance, "vkCreateDebugUtilsMessengerEXT");
-	if (vkCreateDebugUtilsMessengerEXT != nullptr)
-		FatalErrorIfFailed(vkCreateDebugUtilsMessengerEXT(mInstance, &messenger_create_info, nullptr, &mDebugMessenger));
-#endif
-
 	// Create surface
 #ifdef JPH_PLATFORM_WINDOWS
 	VkWin32SurfaceCreateInfoKHR surface_create_info = {};
@@ -213,154 +109,29 @@ void RendererVK::Initialize(ApplicationWindow *inWindow)
 	surface_create_info.pLayer = static_cast<ApplicationWindowMacOS *>(mWindow)->GetMetalLayer();
 	FatalErrorIfFailed(vkCreateMetalSurfaceEXT(mInstance, &surface_create_info, nullptr, &mSurface));
 #endif
+}
 
-	// Select device
-	uint32 device_count = 0;
-	FatalErrorIfFailed(vkEnumeratePhysicalDevices(mInstance, &device_count, nullptr));
-	Array<VkPhysicalDevice> devices;
-	devices.resize(device_count);
-	FatalErrorIfFailed(vkEnumeratePhysicalDevices(mInstance, &device_count, devices.data()));
-	struct Device
-	{
-		VkPhysicalDevice		mPhysicalDevice;
-		String					mName;
-		VkSurfaceFormatKHR		mFormat;
-		uint32					mGraphicsQueueIndex;
-		uint32					mPresentQueueIndex;
-		int						mScore;
-	};
-	Array<Device> available_devices;
-	for (VkPhysicalDevice device : devices)
-	{
-		// Get device properties
-		VkPhysicalDeviceProperties properties;
-		vkGetPhysicalDeviceProperties(device, &properties);
-
-		// Test if it is an appropriate type
-		int score = 0;
-		switch (properties.deviceType)
-		{
-		case VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU:
-			score = 30;
-			break;
-		case VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU:
-			score = 20;
-			break;
-		case VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU:
-			score = 10;
-			break;
-		case VK_PHYSICAL_DEVICE_TYPE_CPU:
-			score = 5;
-			break;
-		case VK_PHYSICAL_DEVICE_TYPE_OTHER:
-		case VK_PHYSICAL_DEVICE_TYPE_MAX_ENUM:
-			continue;
-		}
-
-		// Check if the device supports all our required extensions
-		uint32 device_extension_count;
-		vkEnumerateDeviceExtensionProperties(device, nullptr, &device_extension_count, nullptr);
-		Array<VkExtensionProperties> available_extensions;
-		available_extensions.resize(device_extension_count);
-		vkEnumerateDeviceExtensionProperties(device, nullptr, &device_extension_count, available_extensions.data());
-		int found_extensions = 0;
-		for (const char *required_device_extension : required_device_extensions)
-			for (const VkExtensionProperties &ext : available_extensions)
-				if (strcmp(required_device_extension, ext.extensionName) == 0)
-				{
-					found_extensions++;
-					break;
-				}
-		if (found_extensions != int(required_device_extensions.size()))
-			continue;
-
-		// Find the right queues
-		uint32 queue_family_count = 0;
-		vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, nullptr);
-		Array<VkQueueFamilyProperties> queue_families;
-		queue_families.resize(queue_family_count);
-		vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, queue_families.data());
-		uint32 graphics_queue = ~uint32(0);
-		uint32 present_queue = ~uint32(0);
-		for (uint32 i = 0; i < uint32(queue_families.size()); ++i)
-		{
-			if (queue_families[i].queueFlags & VK_QUEUE_GRAPHICS_BIT)
-				graphics_queue = i;
-
-			VkBool32 present_support = false;
-			vkGetPhysicalDeviceSurfaceSupportKHR(device, i, mSurface, &present_support);
-			if (present_support)
-				present_queue = i;
+void RendererVK::Initialize(ApplicationWindow *inWindow)
+{
+	Renderer::Initialize(inWindow);
 
-			if (graphics_queue != ~uint32(0) && present_queue != ~uint32(0))
-				break;
-		}
-		if (graphics_queue == ~uint32(0) || present_queue == ~uint32(0))
-			continue;
+	// Flip the sign of the projection matrix
+	mPerspectiveYSign = -1.0f;
 
-		// Select surface format
-		VkSurfaceFormatKHR selected_format = SelectFormat(device);
-		if (selected_format.format == VK_FORMAT_UNDEFINED)
-			continue;
+	if (!ComputeSystemVKImpl::Initialize())
+		FatalError("Unable to initialize Vulkan");
 
-		// Add the device
-		available_devices.push_back({ device, properties.deviceName, selected_format, graphics_queue, present_queue, score });
-	}
-	if (available_devices.empty())
-		FatalError("No Vulkan device found!");
-	QuickSort(available_devices.begin(), available_devices.end(), [](const Device &inLHS, const Device &inRHS) {
-		return inLHS.mScore > inRHS.mScore;
-	});
-	const Device &selected_device = available_devices[0];
-	Trace("Selected device: %s", selected_device.mName.c_str());
-	mPhysicalDevice = selected_device.mPhysicalDevice;
-
-	// Get memory properties
-	vkGetPhysicalDeviceMemoryProperties(mPhysicalDevice, &mMemoryProperties);
-
-	// Get features
-	VkPhysicalDeviceFeatures physical_device_features = {};
-	vkGetPhysicalDeviceFeatures(mPhysicalDevice, &physical_device_features);
-
-	// Create device
-	float queue_priority = 1.0f;
-	VkDeviceQueueCreateInfo queue_create_info[2] = {};
-	for (size_t i = 0; i < std::size(queue_create_info); ++i)
-	{
-		queue_create_info[i].sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
-		queue_create_info[i].queueCount = 1;
-		queue_create_info[i].pQueuePriorities = &queue_priority;
-	}
-	queue_create_info[0].queueFamilyIndex = selected_device.mGraphicsQueueIndex;
-	queue_create_info[1].queueFamilyIndex = selected_device.mPresentQueueIndex;
-
-	VkPhysicalDeviceFeatures device_features = {};
-
-	if (!physical_device_features.fillModeNonSolid)
-		FatalError("fillModeNonSolid not supported!");
-	device_features.fillModeNonSolid = VK_TRUE;
-
-	VkDeviceCreateInfo device_create_info = {};
-	device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
-	device_create_info.queueCreateInfoCount = selected_device.mGraphicsQueueIndex != selected_device.mPresentQueueIndex? 2 : 1;
-	device_create_info.pQueueCreateInfos = queue_create_info;
-	device_create_info.enabledLayerCount = instance_create_info.enabledLayerCount;
-	device_create_info.ppEnabledLayerNames = instance_create_info.ppEnabledLayerNames;
-	device_create_info.enabledExtensionCount = uint32(required_device_extensions.size());
-	device_create_info.ppEnabledExtensionNames = required_device_extensions.data();
-	device_create_info.pEnabledFeatures = &device_features;
-	FatalErrorIfFailed(vkCreateDevice(selected_device.mPhysicalDevice, &device_create_info, nullptr, &mDevice));
-
-	// Get the queues
-	mGraphicsQueueIndex = selected_device.mGraphicsQueueIndex;
-	mPresentQueueIndex = selected_device.mPresentQueueIndex;
-	vkGetDeviceQueue(mDevice, mGraphicsQueueIndex, 0, &mGraphicsQueue);
-	vkGetDeviceQueue(mDevice, mPresentQueueIndex, 0, &mPresentQueue);
+	// Check if fill mode non solid is supported
+	VkPhysicalDeviceFeatures2 supported_features2 = {};
+	supported_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
+	vkGetPhysicalDeviceFeatures2(mPhysicalDevice, &supported_features2);
+	if (!supported_features2.features.fillModeNonSolid)
+		FatalError("This Vulkan implementation does not support fill mode non solid");
 
 	VkCommandPoolCreateInfo pool_info = {};
 	pool_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
 	pool_info.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
-	pool_info.queueFamilyIndex = selected_device.mGraphicsQueueIndex;
+	pool_info.queueFamilyIndex = mGraphicsQueueIndex;
 	FatalErrorIfFailed(vkCreateCommandPool(mDevice, &pool_info, nullptr, &mCommandPool));
 
 	VkCommandBufferAllocateInfo command_buffer_info = {};
@@ -380,9 +151,9 @@ void RendererVK::Initialize(ApplicationWindow *inWindow)
 	// Create constant buffer. One per frame to avoid overwriting the constant buffer while the GPU is still using it.
 	for (uint n = 0; n < cFrameCount; ++n)
 	{
-		mVertexShaderConstantBufferProjection[n] = CreateConstantBuffer(sizeof(VertexShaderConstantBuffer));
-		mVertexShaderConstantBufferOrtho[n] = CreateConstantBuffer(sizeof(VertexShaderConstantBuffer));
-		mPixelShaderConstantBuffer[n] = CreateConstantBuffer(sizeof(PixelShaderConstantBuffer));
+		mVertexShaderConstantBufferProjection[n] = CreateComputeBuffer(ComputeBuffer::EType::ConstantBuffer, 1, sizeof(VertexShaderConstantBuffer));
+		mVertexShaderConstantBufferOrtho[n] = CreateComputeBuffer(ComputeBuffer::EType::ConstantBuffer, 1, sizeof(VertexShaderConstantBuffer));
+		mPixelShaderConstantBuffer[n] = CreateComputeBuffer(ComputeBuffer::EType::ConstantBuffer, 1, sizeof(PixelShaderConstantBuffer));
 	}
 
 	// Create descriptor set layout for the uniform buffers
@@ -445,11 +216,11 @@ void RendererVK::Initialize(ApplicationWindow *inWindow)
 	for (uint i = 0; i < cFrameCount; i++)
 	{
 		VkDescriptorBufferInfo vs_buffer_info = {};
-		vs_buffer_info.buffer = mVertexShaderConstantBufferProjection[i]->GetBuffer();
+		vs_buffer_info.buffer = StaticCast<ComputeBufferVK>(mVertexShaderConstantBufferProjection[i])->GetBufferCPU();
 		vs_buffer_info.range = sizeof(VertexShaderConstantBuffer);
 
 		VkDescriptorBufferInfo ps_buffer_info = {};
-		ps_buffer_info.buffer = mPixelShaderConstantBuffer[i]->GetBuffer();
+		ps_buffer_info.buffer = StaticCast<ComputeBufferVK>(mPixelShaderConstantBuffer[i])->GetBufferCPU();
 		ps_buffer_info.range = sizeof(PixelShaderConstantBuffer);
 
 		VkWriteDescriptorSet descriptor_write[2] = {};
@@ -473,7 +244,7 @@ void RendererVK::Initialize(ApplicationWindow *inWindow)
 	for (uint i = 0; i < cFrameCount; i++)
 	{
 		VkDescriptorBufferInfo vs_buffer_info = {};
-		vs_buffer_info.buffer = mVertexShaderConstantBufferOrtho[i]->GetBuffer();
+		vs_buffer_info.buffer = StaticCast<ComputeBufferVK>(mVertexShaderConstantBufferOrtho[i])->GetBufferCPU();
 		vs_buffer_info.range = sizeof(VertexShaderConstantBuffer);
 
 		VkWriteDescriptorSet descriptor_write = {};
@@ -561,7 +332,7 @@ void RendererVK::Initialize(ApplicationWindow *inWindow)
 		// Create normal render pass
 		VkAttachmentDescription attachments_normal[2] = {};
 		VkAttachmentDescription &color_attachment = attachments_normal[0];
-		color_attachment.format = selected_device.mFormat.format;
+		color_attachment.format = mSelectedFormat.format;
 		color_attachment.samples = VK_SAMPLE_COUNT_1_BIT;
 		color_attachment.loadOp = VK_ATTACHMENT_LOAD_OP_CLEAR;
 		color_attachment.storeOp = VK_ATTACHMENT_STORE_OP_STORE;
@@ -611,6 +382,34 @@ void RendererVK::Initialize(ApplicationWindow *inWindow)
 	CreateSwapChain(mPhysicalDevice);
 }
 
+void RendererVK::GetInstanceExtensions(Array<const char *> &outExtensions)
+{
+#ifdef JPH_PLATFORM_WINDOWS
+	outExtensions.push_back(VK_KHR_WIN32_SURFACE_EXTENSION_NAME);
+#elif defined(JPH_PLATFORM_LINUX)
+	outExtensions.push_back(VK_KHR_XLIB_SURFACE_EXTENSION_NAME);
+#elif defined(JPH_PLATFORM_MACOS)
+	outExtensions.push_back(VK_EXT_METAL_SURFACE_EXTENSION_NAME);
+#endif
+}
+
+void RendererVK::GetDeviceExtensions(Array<const char *> &outExtensions)
+{
+	outExtensions.push_back(VK_KHR_SWAPCHAIN_EXTENSION_NAME);
+}
+
+void RendererVK::GetEnabledFeatures(VkPhysicalDeviceFeatures2 &ioFeatures)
+{
+	ioFeatures.features.fillModeNonSolid = VK_TRUE;
+}
+
+bool RendererVK::HasPresentSupport(VkPhysicalDevice inDevice, uint32 inQueueFamilyIndex)
+{
+	VkBool32 present_support = false;
+	vkGetPhysicalDeviceSurfaceSupportKHR(inDevice, inQueueFamilyIndex, mSurface, &present_support);
+	return present_support == VK_TRUE;
+}
+
 VkSurfaceFormatKHR RendererVK::SelectFormat(VkPhysicalDevice inDevice)
 {
 	uint32 format_count;
@@ -777,7 +576,6 @@ void RendererVK::DestroySwapChain()
 
 		DestroyImage(mDepthImage, mDepthImageMemory);
 		mDepthImage = VK_NULL_HANDLE;
-		mDepthImageMemory = VK_NULL_HANDLE;
 	}
 
 	for (VkFramebuffer frame_buffer : mSwapChainFramebuffers)
@@ -876,17 +674,14 @@ bool RendererVK::BeginFrame(const CameraState &inCamera, float inWorldScale)
 	FreeSemaphore(mImageAvailableSemaphores[mImageIndex]);
 	mImageAvailableSemaphores[mImageIndex] = semaphore;
 
-	// Free buffers that weren't used this frame
-	for (BufferCache::value_type &vt : mBufferCache)
-		for (BufferVK &bvk : vt.second)
-			FreeBufferInternal(bvk);
-	mBufferCache.clear();
-
-	// Recycle the buffers that were freed
-	mBufferCache.swap(mFreedBuffers[mFrameIndex]);
-	
 	vkResetFences(mDevice, 1, &mInFlightFences[mFrameIndex]);
 
+	// Release the buffers that belonged to the previous frame with this index. Nothing should be using them anymore.
+	Array<BufferVK> &buffers = mPerFrameFreedBuffers[mFrameIndex];
+	for (BufferVK &buffer : buffers)
+		FreeBuffer(buffer);
+	buffers.clear();
+
 	VkCommandBuffer command_buffer = GetCommandBuffer();
 	FatalErrorIfFailed(vkResetCommandBuffer(command_buffer, 0));
 
@@ -908,17 +703,17 @@ bool RendererVK::BeginFrame(const CameraState &inCamera, float inWorldScale)
 	vkCmdBeginRenderPass(command_buffer, &render_pass_begin_info, VK_SUBPASS_CONTENTS_INLINE);
 
 	// Set constants for vertex shader in projection mode
-	VertexShaderConstantBuffer *vs = mVertexShaderConstantBufferProjection[mFrameIndex]->Map<VertexShaderConstantBuffer>();
+	VertexShaderConstantBuffer *vs = mVertexShaderConstantBufferProjection[mFrameIndex]->Map<VertexShaderConstantBuffer>(ComputeBuffer::EMode::Write);
 	*vs = mVSBuffer;
 	mVertexShaderConstantBufferProjection[mFrameIndex]->Unmap();
 
 	// Set constants for vertex shader in ortho mode
-	vs = mVertexShaderConstantBufferOrtho[mFrameIndex]->Map<VertexShaderConstantBuffer>();
+	vs = mVertexShaderConstantBufferOrtho[mFrameIndex]->Map<VertexShaderConstantBuffer>(ComputeBuffer::EMode::Write);
 	*vs = mVSBufferOrtho;
 	mVertexShaderConstantBufferOrtho[mFrameIndex]->Unmap();
 
 	// Set constants for pixel shader
-	PixelShaderConstantBuffer *ps = mPixelShaderConstantBuffer[mFrameIndex]->Map<PixelShaderConstantBuffer>();
+	PixelShaderConstantBuffer *ps = mPixelShaderConstantBuffer[mFrameIndex]->Map<PixelShaderConstantBuffer>(ComputeBuffer::EMode::Write);
 	*ps = mPSBuffer;
 	mPixelShaderConstantBuffer[mFrameIndex]->Unmap();
 
@@ -1015,7 +810,7 @@ Ref<Texture> RendererVK::CreateTexture(const Surface *inSurface)
 
 Ref<VertexShader> RendererVK::CreateVertexShader(const char *inName)
 {
-	Array<uint8> data = ReadData((String("Shaders/VK/") + inName + ".vert.spv").c_str());
+	Array<uint8> data = ReadData((String("Shaders/VK/") + inName + ".spv").c_str());
 
 	VkShaderModuleCreateInfo create_info = {};
 	create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
@@ -1029,7 +824,7 @@ Ref<VertexShader> RendererVK::CreateVertexShader(const char *inName)
 
 Ref<PixelShader> RendererVK::CreatePixelShader(const char *inName)
 {
-	Array<uint8> data = ReadData((String("Shaders/VK/") + inName + ".frag.spv").c_str());
+	Array<uint8> data = ReadData((String("Shaders/VK/") + inName + ".spv").c_str());
 
 	VkShaderModuleCreateInfo create_info = {};
 	create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
@@ -1056,104 +851,6 @@ RenderInstances *RendererVK::CreateRenderInstances()
 	return new RenderInstancesVK(this);
 }
 
-uint32 RendererVK::FindMemoryType(uint32 inTypeFilter, VkMemoryPropertyFlags inProperties)
-{
-	for (uint32 i = 0; i < mMemoryProperties.memoryTypeCount; i++)
-		if ((inTypeFilter & (1 << i))
-			&& (mMemoryProperties.memoryTypes[i].propertyFlags & inProperties) == inProperties)
-			return i;
-
-	FatalError("Failed to find memory type!");
-}
-
-void RendererVK::AllocateMemory(VkDeviceSize inSize, uint32 inMemoryTypeBits, VkMemoryPropertyFlags inProperties, VkDeviceMemory &outMemory)
-{
-	VkMemoryAllocateInfo alloc_info = {};
-	alloc_info.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
-	alloc_info.allocationSize = inSize;
-	alloc_info.memoryTypeIndex = FindMemoryType(inMemoryTypeBits, inProperties);
-	FatalErrorIfFailed(vkAllocateMemory(mDevice, &alloc_info, nullptr, &outMemory));
-
-	// Track allocation
-	++mNumAllocations;
-	mTotalAllocated += inSize;
-
-	// Track max usage
-	mMaxTotalAllocated = max(mMaxTotalAllocated, mTotalAllocated);
-	mMaxNumAllocations = max(mMaxNumAllocations, mNumAllocations);
-}
-
-void RendererVK::FreeMemory(VkDeviceMemory inMemory, VkDeviceSize inSize)
-{
-	vkFreeMemory(mDevice, inMemory, nullptr);
-
-	// Track free
-	--mNumAllocations;
-	mTotalAllocated -= inSize;
-}
-
-void RendererVK::CreateBuffer(VkDeviceSize inSize, VkBufferUsageFlags inUsage, VkMemoryPropertyFlags inProperties, BufferVK &outBuffer)
-{
-	// Check the cache
-	BufferCache::iterator i = mBufferCache.find({ inSize, inUsage, inProperties });
-	if (i != mBufferCache.end() && !i->second.empty())
-	{
-		outBuffer = i->second.back();
-		i->second.pop_back();
-		return;
-	}
-
-	// Create a new buffer
-	outBuffer.mSize = inSize;
-	outBuffer.mUsage = inUsage;
-	outBuffer.mProperties = inProperties;
-
-	VkBufferCreateInfo create_info = {};
-	create_info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
-	create_info.size = inSize;
-	create_info.usage = inUsage;
-	create_info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
-	FatalErrorIfFailed(vkCreateBuffer(mDevice, &create_info, nullptr, &outBuffer.mBuffer));
-
-	VkMemoryRequirements mem_requirements;
-	vkGetBufferMemoryRequirements(mDevice, outBuffer.mBuffer, &mem_requirements);
-
-	if (mem_requirements.size > cMaxAllocSize)
-	{
-		// Allocate block directly
-		AllocateMemory(mem_requirements.size, mem_requirements.memoryTypeBits, inProperties, outBuffer.mMemory);
-		outBuffer.mAllocatedSize = mem_requirements.size;
-		outBuffer.mOffset = 0;
-	}
-	else
-	{
-		// Round allocation to the next power of 2 so that we can use a simple block based allocator
-		outBuffer.mAllocatedSize = max(VkDeviceSize(GetNextPowerOf2(uint32(mem_requirements.size))), cMinAllocSize);
-
-		// Ensure that we have memory available from the right pool
-		Array<Memory> &mem_array = mMemoryCache[{ outBuffer.mAllocatedSize, outBuffer.mUsage, outBuffer.mProperties }];
-		if (mem_array.empty())
-		{
-			// Allocate a bigger block
-			VkDeviceMemory device_memory;
-			AllocateMemory(cBlockSize, mem_requirements.memoryTypeBits, inProperties, device_memory);
-
-			// Divide into sub blocks
-			for (VkDeviceSize offset = 0; offset < cBlockSize; offset += outBuffer.mAllocatedSize)
-				mem_array.push_back({ device_memory, offset });
-		}
-
-		// Claim memory from the pool
-		Memory &memory = mem_array.back();
-		outBuffer.mMemory = memory.mMemory;
-		outBuffer.mOffset = memory.mOffset;
-		mem_array.pop_back();
-	}
-
-	// Bind the memory to the buffer
-	vkBindBufferMemory(mDevice, outBuffer.mBuffer, outBuffer.mMemory, outBuffer.mOffset);
-}
-
 VkCommandBuffer RendererVK::StartTempCommandBuffer()
 {
 	VkCommandBufferAllocateInfo alloc_info = {};
@@ -1204,10 +901,9 @@ void RendererVK::CreateDeviceLocalBuffer(const void *inData, VkDeviceSize inSize
 	BufferVK staging_buffer;
 	CreateBuffer(inSize, VK_BUFFER_USAGE_TRANSFER_SRC_BIT, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT, staging_buffer);
 
-	void *data;
-	vkMapMemory(mDevice, staging_buffer.mMemory, staging_buffer.mOffset, inSize, 0, &data);
+	void *data = MapBuffer(staging_buffer);
 	memcpy(data, inData, (size_t)inSize);
-	vkUnmapMemory(mDevice, staging_buffer.mMemory);
+	UnmapBuffer(staging_buffer);
 
 	CreateBuffer(inSize, inUsage | VK_BUFFER_USAGE_TRANSFER_DST_BIT, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT, outBuffer);
 
@@ -1216,33 +912,6 @@ void RendererVK::CreateDeviceLocalBuffer(const void *inData, VkDeviceSize inSize
 	FreeBuffer(staging_buffer);
 }
 
-void RendererVK::FreeBuffer(BufferVK &ioBuffer)
-{
-	if (ioBuffer.mBuffer != VK_NULL_HANDLE)
-	{
-		JPH_ASSERT(mFrameIndex < cFrameCount);
-		mFreedBuffers[mFrameIndex][{ ioBuffer.mSize, ioBuffer.mUsage, ioBuffer.mProperties }].push_back(ioBuffer);
-	}
-}
-
-void RendererVK::FreeBufferInternal(BufferVK &ioBuffer)
-{
-	// Destroy the buffer
-	vkDestroyBuffer(mDevice, ioBuffer.mBuffer, nullptr);
-	ioBuffer.mBuffer = VK_NULL_HANDLE;
-
-	if (ioBuffer.mAllocatedSize > cMaxAllocSize)
-		FreeMemory(ioBuffer.mMemory, ioBuffer.mAllocatedSize);
-	else
-		mMemoryCache[{ ioBuffer.mAllocatedSize, ioBuffer.mUsage, ioBuffer.mProperties }].push_back({ ioBuffer.mMemory, ioBuffer.mOffset });
-	ioBuffer.mMemory = VK_NULL_HANDLE;
-}
-
-unique_ptr<ConstantBufferVK> RendererVK::CreateConstantBuffer(VkDeviceSize inBufferSize)
-{
-	return make_unique<ConstantBufferVK>(this, inBufferSize);
-}
-
 VkImageView RendererVK::CreateImageView(VkImage inImage, VkFormat inFormat, VkImageAspectFlags inAspectFlags)
 {
 	VkImageViewCreateInfo view_info = {};
@@ -1260,7 +929,7 @@ VkImageView RendererVK::CreateImageView(VkImage inImage, VkFormat inFormat, VkIm
 	return image_view;
 }
 
-void RendererVK::CreateImage(uint32 inWidth, uint32 inHeight, VkFormat inFormat, VkImageTiling inTiling, VkImageUsageFlags inUsage, VkMemoryPropertyFlags inProperties, VkImage &outImage, VkDeviceMemory &outMemory)
+void RendererVK::CreateImage(uint32 inWidth, uint32 inHeight, VkFormat inFormat, VkImageTiling inTiling, VkImageUsageFlags inUsage, VkMemoryPropertyFlags inProperties, VkImage &outImage, MemoryVK &ioMemory)
 {
 	VkImageCreateInfo image_info = {};
 	image_info.sType = VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO;
@@ -1281,19 +950,16 @@ void RendererVK::CreateImage(uint32 inWidth, uint32 inHeight, VkFormat inFormat,
 	VkMemoryRequirements mem_requirements;
 	vkGetImageMemoryRequirements(mDevice, outImage, &mem_requirements);
 
-	AllocateMemory(mem_requirements.size, mem_requirements.memoryTypeBits, inProperties, outMemory);
+	AllocateMemory(mem_requirements.size, mem_requirements.memoryTypeBits, inProperties, ioMemory);
 
-	vkBindImageMemory(mDevice, outImage, outMemory, 0);
+	vkBindImageMemory(mDevice, outImage, ioMemory.mMemory, 0);
 }
 
-void RendererVK::DestroyImage(VkImage inImage, VkDeviceMemory inMemory)
+void RendererVK::DestroyImage(VkImage inImage, MemoryVK &ioMemory)
 {
-	VkMemoryRequirements mem_requirements;
-	vkGetImageMemoryRequirements(mDevice, inImage, &mem_requirements);
-
 	vkDestroyImage(mDevice, inImage, nullptr);
 
-	FreeMemory(inMemory, mem_requirements.size);
+	FreeMemory(ioMemory);
 }
 
 void RendererVK::UpdateViewPortAndScissorRect(uint32 inWidth, uint32 inHeight)
@@ -1316,9 +982,17 @@ void RendererVK::UpdateViewPortAndScissorRect(uint32 inWidth, uint32 inHeight)
 	vkCmdSetScissor(command_buffer, 0, 1, &scissor);
 }
 
-#ifdef JPH_ENABLE_VULKAN
+// The default renderer is Vulkan if we don't have DirectX 12 or Metal
+#if !defined(JPH_USE_DX12) && !defined(JPH_USE_MTL)
+
 Renderer *Renderer::sCreate()
 {
 	return new RendererVK;
 }
-#endif
+
+#endif // !defined(JPH_USE_DX12) && !defined(JPH_USE_MTL)
+
+Renderer *CreateRendererVK()
+{
+	return new RendererVK;
+}

+ 24 - 74
TestFramework/Renderer/VK/RendererVK.h

@@ -5,21 +5,20 @@
 #pragma once
 
 #include <Renderer/Renderer.h>
-#include <Renderer/VK/ConstantBufferVK.h>
 #include <Renderer/VK/TextureVK.h>
-#include <Jolt/Core/UnorderedMap.h>
-
-#include <vulkan/vulkan.h>
+#include <Jolt/Compute/VK/ComputeSystemVKImpl.h>
 
 /// Vulkan renderer
-class RendererVK : public Renderer
+class RendererVK : public Renderer, public ComputeSystemVKImpl
 {
 public:
-	/// Destructor
+	/// Constructor / destructor
+									RendererVK();
 	virtual							~RendererVK() override;
 
 	// See: Renderer
 	virtual void					Initialize(ApplicationWindow *inWindow) override;
+	virtual ComputeSystem &			GetComputeSystem() override										{ return *this; }
 	virtual bool					BeginFrame(const CameraState &inCamera, float inWorldScale) override;
 	virtual void					EndShadowPass() override;
 	virtual void					EndFrame() override;
@@ -34,7 +33,6 @@ public:
 	virtual Texture *				GetShadowMap() const override									{ return mShadowMap.GetPtr(); }
 	virtual void					OnWindowResize() override;
 
-	VkDevice						GetDevice() const												{ return mDevice; }
 	VkDescriptorPool				GetDescriptorPool() const										{ return mDescriptorPool; }
 	VkDescriptorSetLayout			GetDescriptorSetLayoutTexture() const							{ return mDescriptorSetLayoutTexture; }
 	VkSampler						GetTextureSamplerRepeat() const									{ return mTextureSamplerRepeat; }
@@ -45,39 +43,30 @@ public:
 	VkCommandBuffer					GetCommandBuffer()												{ JPH_ASSERT(mInFrame); return mCommandBuffers[mFrameIndex]; }
 	VkCommandBuffer					StartTempCommandBuffer();
 	void							EndTempCommandBuffer(VkCommandBuffer inCommandBuffer);
-	void							AllocateMemory(VkDeviceSize inSize, uint32 inMemoryTypeBits, VkMemoryPropertyFlags inProperties, VkDeviceMemory &outMemory);
-	void							FreeMemory(VkDeviceMemory inMemory, VkDeviceSize inSize);
-	void							CreateBuffer(VkDeviceSize inSize, VkBufferUsageFlags inUsage, VkMemoryPropertyFlags inProperties, BufferVK &outBuffer);
 	void							CopyBuffer(VkBuffer inSrc, VkBuffer inDst, VkDeviceSize inSize);
 	void							CreateDeviceLocalBuffer(const void *inData, VkDeviceSize inSize, VkBufferUsageFlags inUsage, BufferVK &outBuffer);
-	void							FreeBuffer(BufferVK &ioBuffer);
-	unique_ptr<ConstantBufferVK>	CreateConstantBuffer(VkDeviceSize inBufferSize);
-	void							CreateImage(uint32 inWidth, uint32 inHeight, VkFormat inFormat, VkImageTiling inTiling, VkImageUsageFlags inUsage, VkMemoryPropertyFlags inProperties, VkImage &outImage, VkDeviceMemory &outMemory);
-	void							DestroyImage(VkImage inImage, VkDeviceMemory inMemory);
+	void							CreateImage(uint32 inWidth, uint32 inHeight, VkFormat inFormat, VkImageTiling inTiling, VkImageUsageFlags inUsage, VkMemoryPropertyFlags inProperties, VkImage &outImage, MemoryVK &ioMemory);
+	void							DestroyImage(VkImage inImage, MemoryVK &ioMemory);
 	VkImageView						CreateImageView(VkImage inImage, VkFormat inFormat, VkImageAspectFlags inAspectFlags);
 	VkFormat						FindDepthFormat();
+	void							FreeBufferDelayed(const BufferVK &inBuffer)						{ mPerFrameFreedBuffers[mFrameIndex].push_back(inBuffer); }
+
+protected:
+	// Callbacks from ComputeSystemVKImpl
+	virtual void					OnInstanceCreated() override;
+	virtual void					GetInstanceExtensions(Array<const char *> &outExtensions) override;
+	virtual void					GetDeviceExtensions(Array<const char *> &outExtensions) override;
+	virtual void					GetEnabledFeatures(VkPhysicalDeviceFeatures2 &ioFeatures) override;
+	virtual bool					HasPresentSupport(VkPhysicalDevice inDevice, uint32 inQueueFamilyIndex) override;
+	virtual VkSurfaceFormatKHR		SelectFormat(VkPhysicalDevice inDevice) override;
 
 private:
-	uint32							FindMemoryType(uint32 inTypeFilter, VkMemoryPropertyFlags inProperties);
-	void							FreeBufferInternal(BufferVK &ioBuffer);
-	VkSurfaceFormatKHR				SelectFormat(VkPhysicalDevice inDevice);
 	void							CreateSwapChain(VkPhysicalDevice inDevice);
 	void							DestroySwapChain();
 	void							UpdateViewPortAndScissorRect(uint32 inWidth, uint32 inHeight);
 	VkSemaphore						AllocateSemaphore();
 	void							FreeSemaphore(VkSemaphore inSemaphore);
 
-	VkInstance						mInstance = VK_NULL_HANDLE;
-#ifdef JPH_DEBUG
-	VkDebugUtilsMessengerEXT		mDebugMessenger = VK_NULL_HANDLE;
-#endif
-	VkPhysicalDevice				mPhysicalDevice = VK_NULL_HANDLE;
-	VkPhysicalDeviceMemoryProperties mMemoryProperties;
-	VkDevice						mDevice = VK_NULL_HANDLE;
-	uint32							mGraphicsQueueIndex = 0;
-	uint32							mPresentQueueIndex = 0;
-	VkQueue							mGraphicsQueue = VK_NULL_HANDLE;
-	VkQueue							mPresentQueue = VK_NULL_HANDLE;
 	VkSurfaceKHR					mSurface = VK_NULL_HANDLE;
 	VkSwapchainKHR					mSwapChain = VK_NULL_HANDLE;
 	bool							mSubOptimalSwapChain = false;
@@ -86,7 +75,7 @@ private:
 	VkExtent2D						mSwapChainExtent;
 	Array<VkImageView>				mSwapChainImageViews;
 	VkImage							mDepthImage = VK_NULL_HANDLE;
-	VkDeviceMemory					mDepthImageMemory = VK_NULL_HANDLE;
+	MemoryVK						mDepthImageMemory;
 	VkImageView						mDepthImageView = VK_NULL_HANDLE;
 	VkDescriptorSetLayout			mDescriptorSetLayoutUBO = VK_NULL_HANDLE;
 	VkDescriptorSetLayout			mDescriptorSetLayoutTexture = VK_NULL_HANDLE;
@@ -108,49 +97,10 @@ private:
 	Array<VkSemaphore>				mRenderFinishedSemaphores;
 	VkFence							mInFlightFences[cFrameCount];
 	Ref<TextureVK>					mShadowMap;
-	unique_ptr<ConstantBufferVK>	mVertexShaderConstantBufferProjection[cFrameCount];
-	unique_ptr<ConstantBufferVK>	mVertexShaderConstantBufferOrtho[cFrameCount];
-	unique_ptr<ConstantBufferVK>	mPixelShaderConstantBuffer[cFrameCount];
-
-	struct Key
-	{
-		bool						operator == (const Key &inRHS) const
-		{
-			return mSize == inRHS.mSize && mUsage == inRHS.mUsage && mProperties == inRHS.mProperties;
-		}
-
-		VkDeviceSize				mSize;
-		VkBufferUsageFlags			mUsage;
-		VkMemoryPropertyFlags		mProperties;
-	};
-
-	JPH_MAKE_HASH_STRUCT(Key, KeyHasher, t.mSize, t.mUsage, t.mProperties)
-
-	// We try to recycle buffers from frame to frame
-	using BufferCache = UnorderedMap<Key, Array<BufferVK>, KeyHasher>;
-
-	BufferCache						mFreedBuffers[cFrameCount];
-	BufferCache						mBufferCache;
-
-	// Smaller allocations (from cMinAllocSize to cMaxAllocSize) will be done in blocks of cBlockSize bytes.
-	// We do this because there is a limit to the number of allocations that we can make in Vulkan.
-	static constexpr VkDeviceSize	cMinAllocSize = 512;
-	static constexpr VkDeviceSize	cMaxAllocSize = 65536;
-	static constexpr VkDeviceSize	cBlockSize = 524288;
-
-	JPH_MAKE_HASH_STRUCT(Key, MemKeyHasher, t.mUsage, t.mProperties, t.mSize)
-
-	struct Memory
-	{
-		VkDeviceMemory				mMemory;
-		VkDeviceSize				mOffset;
-	};
-
-	using MemoryCache = UnorderedMap<Key, Array<Memory>, KeyHasher>;
-
-	MemoryCache						mMemoryCache;
-	uint32							mNumAllocations = 0;
-	uint32							mMaxNumAllocations = 0;
-	VkDeviceSize					mTotalAllocated = 0;
-	VkDeviceSize					mMaxTotalAllocated = 0;
+	Ref<ComputeBuffer>				mVertexShaderConstantBufferProjection[cFrameCount];
+	Ref<ComputeBuffer>				mVertexShaderConstantBufferOrtho[cFrameCount];
+	Ref<ComputeBuffer>				mPixelShaderConstantBuffer[cFrameCount];
+	Array<BufferVK>					mPerFrameFreedBuffers[cFrameCount];
 };
+
+extern Renderer *CreateRendererVK();

+ 2 - 4
TestFramework/Renderer/VK/TextureVK.cpp

@@ -48,18 +48,16 @@ TextureVK::TextureVK(RendererVK *inRenderer, const Surface *inSurface) :
 
 	int bpp = surface->GetBytesPerPixel();
 	VkDeviceSize image_size = VkDeviceSize(mWidth) * mHeight * bpp;
-	VkDevice device = mRenderer->GetDevice();
 
 	BufferVK staging_buffer;
 	mRenderer->CreateBuffer(image_size, VK_BUFFER_USAGE_TRANSFER_SRC_BIT, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT, staging_buffer);
 
 	// Copy data to upload texture
 	surface->Lock(ESurfaceLockMode::Read);
-	void *data;
-	vkMapMemory(device, staging_buffer.mMemory, staging_buffer.mOffset, image_size, 0, &data);
+	void *data = mRenderer->MapBuffer(staging_buffer);
 	for (int y = 0; y < mHeight; ++y)
 		memcpy(reinterpret_cast<uint8 *>(data) + y * mWidth * bpp, surface->GetData() + y * surface->GetStride(), mWidth * bpp);
-	vkUnmapMemory(device, staging_buffer.mMemory);
+	mRenderer->UnmapBuffer(staging_buffer);
 	surface->UnLock();
 
 	// Create destination image

+ 2 - 3
TestFramework/Renderer/VK/TextureVK.h

@@ -5,8 +5,7 @@
 #pragma once
 
 #include <Renderer/Texture.h>
-
-#include <vulkan/vulkan.h>
+#include <Jolt/Compute/VK/BufferVK.h>
 
 class RendererVK;
 
@@ -29,7 +28,7 @@ private:
 
 	RendererVK *						mRenderer;
 	VkImage								mImage = VK_NULL_HANDLE;
-	VkDeviceMemory						mImageMemory = VK_NULL_HANDLE;
+	MemoryVK							mImageMemory;
 	VkImageView							mImageView = VK_NULL_HANDLE;
 	VkDescriptorSet						mDescriptorSet = VK_NULL_HANDLE;
 };

+ 2 - 3
TestFramework/Renderer/VK/VertexShaderVK.h

@@ -14,8 +14,7 @@ class VertexShaderVK : public VertexShader
 public:
 	/// Constructor
 							VertexShaderVK(VkDevice inDevice, VkShaderModule inShaderModule) :
-		mDevice(inDevice),
-		mStageInfo()
+		mDevice(inDevice)
 	{
 		mStageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
 		mStageInfo.stage = VK_SHADER_STAGE_VERTEX_BIT;
@@ -30,5 +29,5 @@ public:
 	}
 
 	VkDevice				mDevice;
-	VkPipelineShaderStageCreateInfo mStageInfo;
+	VkPipelineShaderStageCreateInfo mStageInfo = {};
 };

+ 108 - 104
TestFramework/TestFramework.cmake

@@ -1,7 +1,5 @@
-# Find Vulkan
-find_package(Vulkan)
-if (NOT CROSS_COMPILE_ARM AND (Vulkan_FOUND OR WIN32 OR ("${CMAKE_SYSTEM_NAME}" MATCHES "Darwin")))
-	# We have Vulkan/DirectX so we can compile TestFramework
+if (NOT CROSS_COMPILE_ARM AND (JPH_USE_VK OR JPH_USE_DX12 OR JPH_USE_MTL))
+	# We have Vulkan/DirectX/Metal so we can compile TestFramework
 	set(TEST_FRAMEWORK_AVAILABLE TRUE)
 
 	# Root
@@ -90,55 +88,60 @@ if (NOT CROSS_COMPILE_ARM AND (Vulkan_FOUND OR WIN32 OR ("${CMAKE_SYSTEM_NAME}"
 			${TEST_FRAMEWORK_ROOT}/Input/Win/KeyboardWin.h
 			${TEST_FRAMEWORK_ROOT}/Input/Win/MouseWin.cpp
 			${TEST_FRAMEWORK_ROOT}/Input/Win/MouseWin.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/ConstantBufferDX12.cpp
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/ConstantBufferDX12.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/CommandQueueDX12.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/DescriptorHeapDX12.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/FatalErrorIfFailedDX12.cpp
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/FatalErrorIfFailedDX12.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/PipelineStateDX12.cpp
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/PipelineStateDX12.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/PixelShaderDX12.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/RendererDX12.cpp
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/RendererDX12.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/RenderInstancesDX12.cpp
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/RenderInstancesDX12.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/RenderPrimitiveDX12.cpp
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/RenderPrimitiveDX12.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/TextureDX12.cpp
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/TextureDX12.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/DX12/VertexShaderDX12.h
 			${TEST_FRAMEWORK_ROOT}/Utils/AssetStream.cpp
 			${TEST_FRAMEWORK_ROOT}/Utils/Log.cpp
 			${TEST_FRAMEWORK_ROOT}/Window/ApplicationWindowWin.cpp
 			${TEST_FRAMEWORK_ROOT}/Window/ApplicationWindowWin.h
 		)
 
-		# HLSL vertex shaders
-		set(TEST_FRAMEWORK_SRC_FILES_SHADERS
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/VertexConstants.h
-		)
-		set(TEST_FRAMEWORK_HLSL_VERTEX_SHADERS
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/FontVertexShader.hlsl
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/LineVertexShader.hlsl
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/TriangleDepthVertexShader.hlsl
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/TriangleVertexShader.hlsl
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/UIVertexShader.hlsl
-		)
-		set(TEST_FRAMEWORK_SRC_FILES_SHADERS ${TEST_FRAMEWORK_SRC_FILES_SHADERS} ${TEST_FRAMEWORK_HLSL_VERTEX_SHADERS})
-		set_source_files_properties(${TEST_FRAMEWORK_HLSL_VERTEX_SHADERS} PROPERTIES VS_SHADER_FLAGS "/WX /T vs_5_0")
+		# Include the DirectX renderer
+		if (JPH_USE_DX12)
+			# DirectX source files
+			set(TEST_FRAMEWORK_SRC_FILES
+				${TEST_FRAMEWORK_SRC_FILES}
+				${TEST_FRAMEWORK_ROOT}/Renderer/DX12/DescriptorHeapDX12.h
+				${TEST_FRAMEWORK_ROOT}/Renderer/DX12/FatalErrorIfFailedDX12.cpp
+				${TEST_FRAMEWORK_ROOT}/Renderer/DX12/FatalErrorIfFailedDX12.h
+				${TEST_FRAMEWORK_ROOT}/Renderer/DX12/PipelineStateDX12.cpp
+				${TEST_FRAMEWORK_ROOT}/Renderer/DX12/PipelineStateDX12.h
+				${TEST_FRAMEWORK_ROOT}/Renderer/DX12/PixelShaderDX12.h
+				${TEST_FRAMEWORK_ROOT}/Renderer/DX12/RendererDX12.cpp
+				${TEST_FRAMEWORK_ROOT}/Renderer/DX12/RendererDX12.h
+				${TEST_FRAMEWORK_ROOT}/Renderer/DX12/RenderInstancesDX12.cpp
+				${TEST_FRAMEWORK_ROOT}/Renderer/DX12/RenderInstancesDX12.h
+				${TEST_FRAMEWORK_ROOT}/Renderer/DX12/RenderPrimitiveDX12.cpp
+				${TEST_FRAMEWORK_ROOT}/Renderer/DX12/RenderPrimitiveDX12.h
+				${TEST_FRAMEWORK_ROOT}/Renderer/DX12/TextureDX12.cpp
+				${TEST_FRAMEWORK_ROOT}/Renderer/DX12/TextureDX12.h
+				${TEST_FRAMEWORK_ROOT}/Renderer/DX12/VertexShaderDX12.h
+			)
 
-		# HLSL pixel shaders
-		set(TEST_FRAMEWORK_HLSL_PIXEL_SHADERS
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/FontPixelShader.hlsl
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/LinePixelShader.hlsl
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/TriangleDepthPixelShader.hlsl
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/TrianglePixelShader.hlsl
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/UIPixelShader.hlsl
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/UIPixelShaderUntextured.hlsl
-		)
-		set(TEST_FRAMEWORK_SRC_FILES_SHADERS ${TEST_FRAMEWORK_SRC_FILES_SHADERS} ${TEST_FRAMEWORK_HLSL_PIXEL_SHADERS})
-		set_source_files_properties(${TEST_FRAMEWORK_HLSL_PIXEL_SHADERS} PROPERTIES VS_SHADER_FLAGS "/WX /T ps_5_0")
+			# HLSL vertex shaders
+			set(TEST_FRAMEWORK_SRC_FILES_SHADERS
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/VertexConstants.h
+			)
+			set(TEST_FRAMEWORK_HLSL_VERTEX_SHADERS
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/FontVertexShader.hlsl
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/LineVertexShader.hlsl
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/TriangleDepthVertexShader.hlsl
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/TriangleVertexShader.hlsl
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/UIVertexShader.hlsl
+			)
+			set(TEST_FRAMEWORK_SRC_FILES_SHADERS ${TEST_FRAMEWORK_SRC_FILES_SHADERS} ${TEST_FRAMEWORK_HLSL_VERTEX_SHADERS})
+			set_source_files_properties(${TEST_FRAMEWORK_HLSL_VERTEX_SHADERS} PROPERTIES VS_SHADER_FLAGS "/WX /T vs_5_0")
+
+			# HLSL pixel shaders
+			set(TEST_FRAMEWORK_HLSL_PIXEL_SHADERS
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/FontPixelShader.hlsl
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/LinePixelShader.hlsl
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/TriangleDepthPixelShader.hlsl
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/TrianglePixelShader.hlsl
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/UIPixelShader.hlsl
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/DX/UIPixelShaderUntextured.hlsl
+			)
+			set(TEST_FRAMEWORK_SRC_FILES_SHADERS ${TEST_FRAMEWORK_SRC_FILES_SHADERS} ${TEST_FRAMEWORK_HLSL_PIXEL_SHADERS})
+			set_source_files_properties(${TEST_FRAMEWORK_HLSL_PIXEL_SHADERS} PROPERTIES VS_SHADER_FLAGS "/WX /T ps_5_0")
+		endif()
 	endif()
 
 	if (LINUX)
@@ -156,24 +159,10 @@ if (NOT CROSS_COMPILE_ARM AND (Vulkan_FOUND OR WIN32 OR ("${CMAKE_SYSTEM_NAME}"
 		)
 	endif()
 		
-	if ("${CMAKE_SYSTEM_NAME}" MATCHES "Darwin")
+	if (APPLE)
 		# macOS source files
 		set(TEST_FRAMEWORK_SRC_FILES
 			${TEST_FRAMEWORK_SRC_FILES}
-			${TEST_FRAMEWORK_ROOT}/Renderer/MTL/FatalErrorIfFailedMTL.mm
-			${TEST_FRAMEWORK_ROOT}/Renderer/MTL/FatalErrorIfFailedMTL.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/MTL/PipelineStateMTL.mm
-			${TEST_FRAMEWORK_ROOT}/Renderer/MTL/PipelineStateMTL.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/MTL/PixelShaderMTL.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/MTL/RendererMTL.mm
-			${TEST_FRAMEWORK_ROOT}/Renderer/MTL/RendererMTL.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/MTL/RenderInstancesMTL.mm
-			${TEST_FRAMEWORK_ROOT}/Renderer/MTL/RenderInstancesMTL.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/MTL/RenderPrimitiveMTL.mm
-			${TEST_FRAMEWORK_ROOT}/Renderer/MTL/RenderPrimitiveMTL.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/MTL/TextureMTL.mm
-			${TEST_FRAMEWORK_ROOT}/Renderer/MTL/TextureMTL.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/MTL/VertexShaderMTL.h
 			${TEST_FRAMEWORK_ROOT}/Input/MacOS/KeyboardMacOS.mm
 			${TEST_FRAMEWORK_ROOT}/Input/MacOS/KeyboardMacOS.h
 			${TEST_FRAMEWORK_ROOT}/Input/MacOS/MouseMacOS.mm
@@ -184,44 +173,63 @@ if (NOT CROSS_COMPILE_ARM AND (Vulkan_FOUND OR WIN32 OR ("${CMAKE_SYSTEM_NAME}"
 			${TEST_FRAMEWORK_ROOT}/Window/ApplicationWindowMacOS.h
 		)
 
-		# Metal shaders
-		set(TEST_FRAMEWORK_SRC_FILES_SHADERS
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/MTL/VertexConstants.h
-		)
-		set(TEST_FRAMEWORK_METAL_SHADERS
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/MTL/FontShader.metal
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/MTL/LineShader.metal
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/MTL/TriangleShader.metal
-			${PHYSICS_REPO_ROOT}/Assets/Shaders/MTL/UIShader.metal
-		)
+		# Include the Metal renderer
+		if (JPH_USE_MTL)
+			# Metal source files
+			set(TEST_FRAMEWORK_SRC_FILES
+				${TEST_FRAMEWORK_SRC_FILES}
+				${TEST_FRAMEWORK_ROOT}/Renderer/MTL/FatalErrorIfFailedMTL.mm
+				${TEST_FRAMEWORK_ROOT}/Renderer/MTL/FatalErrorIfFailedMTL.h
+				${TEST_FRAMEWORK_ROOT}/Renderer/MTL/PipelineStateMTL.mm
+				${TEST_FRAMEWORK_ROOT}/Renderer/MTL/PipelineStateMTL.h
+				${TEST_FRAMEWORK_ROOT}/Renderer/MTL/PixelShaderMTL.h
+				${TEST_FRAMEWORK_ROOT}/Renderer/MTL/RendererMTL.mm
+				${TEST_FRAMEWORK_ROOT}/Renderer/MTL/RendererMTL.h
+				${TEST_FRAMEWORK_ROOT}/Renderer/MTL/RenderInstancesMTL.mm
+				${TEST_FRAMEWORK_ROOT}/Renderer/MTL/RenderInstancesMTL.h
+				${TEST_FRAMEWORK_ROOT}/Renderer/MTL/RenderPrimitiveMTL.mm
+				${TEST_FRAMEWORK_ROOT}/Renderer/MTL/RenderPrimitiveMTL.h
+				${TEST_FRAMEWORK_ROOT}/Renderer/MTL/TextureMTL.mm
+				${TEST_FRAMEWORK_ROOT}/Renderer/MTL/TextureMTL.h
+				${TEST_FRAMEWORK_ROOT}/Renderer/MTL/VertexShaderMTL.h
+			)
 
-		# Compile Metal shaders
-		foreach(SHADER ${TEST_FRAMEWORK_METAL_SHADERS})
-			cmake_path(GET SHADER FILENAME AIR_SHADER)
-			set(AIR_SHADER "${CMAKE_CURRENT_BINARY_DIR}/${AIR_SHADER}.air")
-			add_custom_command(OUTPUT ${AIR_SHADER}
-				COMMAND xcrun -sdk macosx metal -c ${SHADER} -o ${AIR_SHADER}
-				DEPENDS ${SHADER}
-				COMMENT "Compiling ${SHADER}")
-			list(APPEND TEST_FRAMEWORK_AIR_SHADERS ${AIR_SHADER})
-		endforeach()
+			# Metal shaders
+			set(TEST_FRAMEWORK_SRC_FILES_SHADERS
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/MTL/VertexConstants.h
+			)
+			set(TEST_FRAMEWORK_METAL_SHADERS
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/MTL/FontShader.metal
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/MTL/LineShader.metal
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/MTL/TriangleShader.metal
+				${PHYSICS_REPO_ROOT}/Assets/Shaders/MTL/UIShader.metal
+			)
 
-		# Link Metal shaders
-		set(TEST_FRAMEWORK_METAL_LIB ${PHYSICS_REPO_ROOT}/Assets/Shaders/MTL/Shaders.metallib)
-		add_custom_command(OUTPUT ${TEST_FRAMEWORK_METAL_LIB}
-			COMMAND xcrun -sdk macosx metallib -o ${TEST_FRAMEWORK_METAL_LIB} ${TEST_FRAMEWORK_AIR_SHADERS}
-			DEPENDS ${TEST_FRAMEWORK_AIR_SHADERS}
-			COMMENT "Linking shaders")
+			# Compile Metal shaders
+			foreach(SHADER ${TEST_FRAMEWORK_METAL_SHADERS})
+				cmake_path(GET SHADER FILENAME AIR_SHADER)
+				set(AIR_SHADER "${CMAKE_CURRENT_BINARY_DIR}/${AIR_SHADER}.air")
+				add_custom_command(OUTPUT ${AIR_SHADER}
+					COMMAND xcrun -sdk macosx metal -c ${SHADER} -o ${AIR_SHADER}
+					DEPENDS ${SHADER}
+					COMMENT "Compiling ${SHADER}")
+				list(APPEND TEST_FRAMEWORK_MTL_SHADERS ${AIR_SHADER})
+			endforeach()
+
+			# Link Metal shaders
+			set(TEST_FRAMEWORK_METAL_LIB ${PHYSICS_REPO_ROOT}/Assets/Shaders/MTL/Shaders.metallib)
+			add_custom_command(OUTPUT ${TEST_FRAMEWORK_METAL_LIB}
+				COMMAND xcrun -sdk macosx metallib -o ${TEST_FRAMEWORK_METAL_LIB} ${TEST_FRAMEWORK_MTL_SHADERS}
+				DEPENDS ${TEST_FRAMEWORK_MTL_SHADERS}
+				COMMENT "Linking shaders")
+		endif()
 	endif()
 
-	# Include the Vulkan library
-	if (Vulkan_FOUND)
+	# Include the Vulkan renderer
+	if (JPH_USE_VK)
 		# Vulkan source files
 		set(TEST_FRAMEWORK_SRC_FILES
 			${TEST_FRAMEWORK_SRC_FILES}
-			${TEST_FRAMEWORK_ROOT}/Renderer/VK/BufferVK.h
-			${TEST_FRAMEWORK_ROOT}/Renderer/VK/ConstantBufferVK.cpp
-			${TEST_FRAMEWORK_ROOT}/Renderer/VK/ConstantBufferVK.h
 			${TEST_FRAMEWORK_ROOT}/Renderer/VK/FatalErrorIfFailedVK.cpp
 			${TEST_FRAMEWORK_ROOT}/Renderer/VK/FatalErrorIfFailedVK.h
 			${TEST_FRAMEWORK_ROOT}/Renderer/VK/PipelineStateVK.cpp
@@ -259,7 +267,8 @@ if (NOT CROSS_COMPILE_ARM AND (Vulkan_FOUND OR WIN32 OR ("${CMAKE_SYSTEM_NAME}"
 
 		# Compile GLSL shaders
 		foreach(SHADER ${TEST_FRAMEWORK_GLSL_SHADERS})
-			set(SPV_SHADER ${SHADER}.spv)
+			string(REPLACE ".vert" ".spv" SPV_SHADER ${SHADER})
+			string(REPLACE ".frag" ".spv" SPV_SHADER ${SPV_SHADER})
 			add_custom_command(OUTPUT ${SPV_SHADER}
 				COMMAND ${Vulkan_GLSLC_EXECUTABLE} ${SHADER} -o ${SPV_SHADER}
 				DEPENDS ${SHADER}
@@ -272,6 +281,8 @@ if (NOT CROSS_COMPILE_ARM AND (Vulkan_FOUND OR WIN32 OR ("${CMAKE_SYSTEM_NAME}"
 	set(TEST_FRAMEWORK_ASSETS
 		${PHYSICS_REPO_ROOT}/Assets/Fonts/Roboto-Regular.ttf
 		${PHYSICS_REPO_ROOT}/Assets/UI.tga
+		${JOLT_PHYSICS_SPV_SHADERS}
+		${JOLT_PHYSICS_METAL_LIB}
 		${TEST_FRAMEWORK_SRC_FILES_SHADERS}
 		${TEST_FRAMEWORK_HLSL_VERTEX_SHADERS}
 		${TEST_FRAMEWORK_HLSL_PIXEL_SHADERS}
@@ -293,29 +304,22 @@ if (NOT CROSS_COMPILE_ARM AND (Vulkan_FOUND OR WIN32 OR ("${CMAKE_SYSTEM_NAME}"
 	target_include_directories(TestFramework PUBLIC ${TEST_FRAMEWORK_ROOT})
 	target_precompile_headers(TestFramework PUBLIC ${TEST_FRAMEWORK_ROOT}/TestFramework.h)
 
-	if (Vulkan_FOUND)
-		# Vulkan configuration
-		target_include_directories(TestFramework PUBLIC ${Vulkan_INCLUDE_DIRS})
-		target_link_libraries(TestFramework LINK_PUBLIC Jolt ${Vulkan_LIBRARIES})
-		if (JPH_ENABLE_VULKAN)
-			target_compile_definitions(TestFramework PRIVATE JPH_ENABLE_VULKAN)
-		endif()
-	endif()
 	if (WIN32)
 		# Windows configuration
-		target_link_libraries(TestFramework LINK_PUBLIC Jolt dxguid.lib dinput8.lib dxgi.lib d3d12.lib d3dcompiler.lib shcore.lib)
+		target_link_libraries(TestFramework LINK_PUBLIC Jolt dinput8.lib shcore.lib)
 	endif()
 	if (LINUX)
 		# Linux configuration
 		target_link_libraries(TestFramework LINK_PUBLIC Jolt X11)
 	endif()
-	if ("${CMAKE_SYSTEM_NAME}" MATCHES "Darwin")
+	if (APPLE)
 		# macOS configuration
-		target_link_libraries(TestFramework LINK_PUBLIC Jolt "-framework Cocoa -framework Metal -framework MetalKit -framework GameController")
+		target_link_libraries(TestFramework LINK_PUBLIC Jolt "-framework Cocoa -framework GameController")
 
 		# Make sure that all test framework assets move to the Resources folder in the package
 		foreach(ASSET_FILE ${TEST_FRAMEWORK_ASSETS})
 			string(REPLACE ${PHYSICS_REPO_ROOT}/Assets "Resources" ASSET_DST ${ASSET_FILE})
+			string(REPLACE ${PHYSICS_REPO_ROOT}/Jolt "Resources/Jolt" ASSET_DST ${ASSET_DST})
 			get_filename_component(ASSET_DST ${ASSET_DST} DIRECTORY)
 			set_source_files_properties(${ASSET_FILE} PROPERTIES MACOSX_PACKAGE_LOCATION ${ASSET_DST})
 		endforeach()

+ 4 - 21
TestFramework/TestFramework.h

@@ -19,30 +19,13 @@ JPH_MSVC_SUPPRESS_WARNING(4062) // enumerator 'X' in switch of enum 'X' is not h
 
 #ifdef JPH_PLATFORM_WINDOWS
 
-// Targeting Windows 10 and above
-#define WINVER 0x0A00
-#define _WIN32_WINNT 0x0A00
-
-JPH_SUPPRESS_WARNING_PUSH
-JPH_MSVC_SUPPRESS_WARNING(5039) // winbase.h(13179): warning C5039: 'TpSetCallbackCleanupGroup': pointer or reference to potentially throwing function passed to 'extern "C"' function under -EHc. Undefined behavior may occur if this function throws an exception.
-JPH_MSVC_SUPPRESS_WARNING(5204) // implements.h(65): warning C5204: 'Microsoft::WRL::CloakedIid<IMarshal>': class has virtual functions, but its trivial destructor is not virtual; instances of objects derived from this class may not be destructed correctly
-JPH_MSVC_SUPPRESS_WARNING(4265) // implements.h(1449): warning C4265: 'Microsoft::WRL::FtmBase': class has virtual functions, but its non-trivial destructor is not virtual; instances of this class may not be destructed correctly
-JPH_MSVC_SUPPRESS_WARNING(5220) // implements.h(1648): warning C5220: 'Microsoft::WRL::Details::RuntimeClassImpl<Microsoft::WRL::RuntimeClassFlags<2>,true,false,true,IWeakReference>::refcount_': a non-static data member with a volatile qualified type no longer implies
-JPH_MSVC_SUPPRESS_WARNING(4986) // implements.h(2343): warning C4986: 'Microsoft::WRL::Details::RuntimeClassImpl<RuntimeClassFlagsT,true,true,false,I0,TInterfaces...>::GetWeakReference': exception specification does not match previous declaration
-JPH_MSVC2026_PLUS_SUPPRESS_WARNING(4865) // wingdi.h(2806,1): '<unnamed-enum-DISPLAYCONFIG_OUTPUT_TECHNOLOGY_OTHER>': the underlying type will change from 'int' to '__int64' when '/Zc:enumTypes' is specified on the command line
-#define WIN32_LEAN_AND_MEAN
+#ifdef _WIN32_WINNT
+	#undef _WIN32_WINNT
+#endif
 #define Ellipse DrawEllipse // Windows.h defines a name that we would like to use
-#include <windows.h>
+#include <Jolt/Compute/DX12/IncludeDX12.h>
 #undef Ellipse
-#undef min // We'd like to use std::min and max instead of the ones defined in windows.h
-#undef max
 #undef DrawText // We don't want this to map to DrawTextW
-#include <d3d12.h>
-#include <dxgi1_6.h>
-#include <wrl.h> // for ComPtr
-JPH_SUPPRESS_WARNING_POP
-
-using Microsoft::WRL::ComPtr;
 
 #elif defined(JPH_PLATFORM_LINUX)
 

+ 218 - 0
UnitTests/Compute/ComputeTests.cpp

@@ -0,0 +1,218 @@
+// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
+// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
+// SPDX-License-Identifier: MIT
+
+#include "UnitTestFramework.h"
+
+#if defined(JPH_USE_DX12) || defined(JPH_USE_MTL) || defined(JPH_USE_VK)
+
+#include <Jolt/Compute/ComputeSystem.h>
+#include <Jolt/Shaders/TestCompute.h>
+#include <Jolt/Core/IncludeWindows.h>
+
+JPH_SUPPRESS_WARNINGS_STD_BEGIN
+#include <fstream>
+#include <filesystem>
+#ifdef JPH_PLATFORM_LINUX
+#include <unistd.h>
+#endif
+JPH_SUPPRESS_WARNINGS_STD_END
+
+#if defined(JPH_PLATFORM_MACOS) || defined(JPH_PLATFORM_IOS)
+#include <CoreFoundation/CoreFoundation.h>
+#endif
+
+TEST_SUITE("ComputeTests")
+{
+	static void RunTests(ComputeSystem *inComputeSystem)
+	{
+		inComputeSystem->mShaderLoader = [](const char *inName, Array<uint8> &outData) {
+		#if defined(JPH_PLATFORM_MACOS) || defined(JPH_PLATFORM_IOS)
+			// In macOS the shaders are copied to the bundle
+			CFBundleRef bundle = CFBundleGetMainBundle();
+			CFURLRef resources = CFBundleCopyResourcesDirectoryURL(bundle);
+			CFURLRef absolute = CFURLCopyAbsoluteURL(resources);
+			CFRelease(resources);
+			CFStringRef path_string = CFURLCopyFileSystemPath(absolute, kCFURLPOSIXPathStyle);
+			CFRelease(absolute);
+			char path[PATH_MAX];
+			CFStringGetCString(path_string, path, PATH_MAX, kCFStringEncodingUTF8);
+			CFRelease(path_string);
+    		String base_path = String(path) + "/Jolt/Shaders/";
+		#else
+			// On other platforms, start searching up from the application path
+			#ifdef JPH_PLATFORM_WINDOWS
+				char application_path[MAX_PATH] = { 0 };
+				GetModuleFileName(nullptr, application_path, MAX_PATH);
+			#elif defined(JPH_PLATFORM_LINUX)
+				char application_path[PATH_MAX] = { 0 };
+				int count = readlink("/proc/self/exe", application_path, PATH_MAX);
+				if (count > 0)
+					application_path[count] = 0;
+			#else
+				#error Unsupported platform
+			#endif
+			String base_path;
+			filesystem::path shader_path(application_path);
+			while (!shader_path.empty())
+			{
+				filesystem::path parent_path = shader_path.parent_path();
+				if (parent_path == shader_path)
+					break;
+				shader_path = parent_path;
+				filesystem::path full_path = shader_path / "Jolt" / "Shaders" / "";
+				if (filesystem::exists(full_path))
+				{
+					base_path = String(full_path.string());
+					break;
+				}
+			}
+		#endif
+
+			// Open file
+			std::ifstream input((base_path + inName).c_str(), std::ios::in | std::ios::binary);
+			if (!input.is_open())
+				return false;
+
+			// Read contents of file
+			input.seekg(0, ios_base::end);
+			ifstream::pos_type length = input.tellg();
+			input.seekg(0, ios_base::beg);
+			outData.resize(size_t(length));
+			input.read((char *)&outData[0], length);
+			return true;
+		};
+
+		constexpr uint32 cNumElements = 1234; // Not a multiple of cTestComputeGroupSize
+		constexpr uint32 cNumIterations = 10;
+		constexpr JPH_float3 cFloat3Value = JPH_float3(0, 0, 0);
+		constexpr JPH_float3 cFloat3Value2 = JPH_float3(0, 13, 0);
+		constexpr uint32 cUIntValue = 7;
+		constexpr uint32 cUploadValue = 42;
+
+		// Can't change context buffer while commands are queued, so create multiple constant buffers
+		Ref<ComputeBuffer> context[cNumIterations];
+		for (uint32 iter = 0; iter < cNumIterations; ++iter)
+			context[iter] = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::ConstantBuffer, 1, sizeof(TestComputeContext));
+		CHECK(context != nullptr);
+
+		// Create an upload buffer
+		Ref<ComputeBuffer> upload_buffer = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::UploadBuffer, 1, sizeof(uint32));
+		CHECK(upload_buffer != nullptr);
+		uint32 *upload_data = upload_buffer->Map<uint32>(ComputeBuffer::EMode::Write);
+		upload_data[0] = cUploadValue;
+		upload_buffer->Unmap();
+
+		// Create a read buffer
+		UnitTestRandom rnd;
+		Array<uint32> optional_data(cNumElements);
+		for (uint32 &d : optional_data)
+			d = rnd();
+		Ref<ComputeBuffer> optional_buffer = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::Buffer, cNumElements, sizeof(uint32), optional_data.data());
+		CHECK(optional_buffer != nullptr);
+
+		// Create a read-write buffer
+		Ref<ComputeBuffer> buffer = inComputeSystem->CreateComputeBuffer(ComputeBuffer::EType::RWBuffer, cNumElements, sizeof(uint32));
+		CHECK(buffer != nullptr);
+
+		// Create a read back buffer
+		Ref<ComputeBuffer> readback_buffer = buffer->CreateReadBackBuffer();
+		CHECK(readback_buffer != nullptr);
+
+		// Create the shader
+		Ref<ComputeShader> shader = inComputeSystem->CreateComputeShader("TestCompute", cTestComputeGroupSize);
+		CHECK(shader != nullptr);
+		if (shader == nullptr)
+		{
+			Trace("Shader could not be loaded. This can fail on macOS when dxc or spirv-cross could not be found so the shaders could not be compiled.");
+			return;
+		}
+
+		// Create the queue
+		Ref<ComputeQueue> queue = inComputeSystem->CreateComputeQueue();
+
+		// Schedule work
+		for (uint32 iter = 0; iter < cNumIterations; ++iter)
+		{
+			// Fill in the context
+			TestComputeContext *value = context[iter]->Map<TestComputeContext>(ComputeBuffer::EMode::Write);
+			value->cFloat3Value = cFloat3Value;
+			value->cUIntValue = cUIntValue;
+			value->cFloat3Value2 = cFloat3Value2;
+			value->cUIntValue2 = iter;
+			value->cNumElements = cNumElements;
+			context[iter]->Unmap();
+
+			queue->SetShader(shader);
+			queue->SetConstantBuffer("gContext", context[iter]);
+			context[iter] = nullptr; // Release the reference to ensure the queue keeps ownership
+			queue->SetBuffer("gOptionalData", optional_buffer);
+			optional_buffer = nullptr; // Release the reference so we test that the queue keeps ownership and that in the 2nd iteration we can set a null buffer
+			queue->SetBuffer("gUploadData", upload_buffer);
+			queue->SetRWBuffer("gData", buffer);
+			queue->Dispatch((cNumElements + cTestComputeGroupSize - 1) / cTestComputeGroupSize);
+		}
+
+		// Run all queued commands
+		queue->ScheduleReadback(readback_buffer, buffer);
+		queue->ExecuteAndWait();
+
+		// Calculate the expected result
+		Array<uint32> expected_data(cNumElements);
+		for (uint32 iter = 0; iter < cNumIterations; ++iter)
+		{
+			// Copy of the shader logic
+			uint cUIntValue2 = iter;
+			if (cUIntValue2 == 0)
+			{
+				// First write, uses optional data and tests that the packing of float3/uint3's works
+				for (uint32 i = 0; i < cNumElements; ++i)
+					expected_data[i] = optional_data[i] + int(cFloat3Value2.y) + cUploadValue;
+			}
+			else
+			{
+				// Read-modify-write gData
+				for (uint32 i = 0; i < cNumElements; ++i)
+					expected_data[i] = (expected_data[i] + cUIntValue) * cUIntValue2;
+			}
+		}
+
+		// Compare computed data with expected data
+		uint32 *data = readback_buffer->Map<uint32>(ComputeBuffer::EMode::Read);
+		for (uint32 i = 0; i < cNumElements; ++i)
+			CHECK(data[i] == expected_data[i]);
+		readback_buffer->Unmap();
+	}
+
+#ifdef JPH_USE_DX12
+	TEST_CASE("TestComputeDX12")
+	{
+		Ref<ComputeSystem> compute_system = CreateComputeSystemDX12();
+		CHECK(compute_system != nullptr);
+		if (compute_system != nullptr)
+			RunTests(compute_system);
+	}
+#endif // JPH_USE_DX12
+
+#ifdef JPH_USE_MTL
+	TEST_CASE("TestComputeMTL")
+	{
+		Ref<ComputeSystem> compute_system = CreateComputeSystemMTL();
+		CHECK(compute_system != nullptr);
+		if (compute_system != nullptr)
+			RunTests(compute_system);
+	}
+#endif // JPH_USE_MTL
+
+#ifdef JPH_USE_VK
+	TEST_CASE("TestComputeVK")
+	{
+		Ref<ComputeSystem> compute_system = CreateComputeSystemVK();
+		CHECK(compute_system != nullptr);
+		if (compute_system != nullptr)
+			RunTests(compute_system);
+	}
+#endif // JPH_USE_VK
+}
+
+#endif // defined(JPH_USE_DX12) || defined(JPH_USE_MTL) || defined(JPH_USE_VK)

+ 7 - 0
UnitTests/UnitTests.cmake

@@ -3,6 +3,7 @@ set(UNIT_TESTS_ROOT ${PHYSICS_REPO_ROOT}/UnitTests)
 
 # Source files
 set(UNIT_TESTS_SRC_FILES
+	${UNIT_TESTS_ROOT}/Compute/ComputeTests.cpp
 	${UNIT_TESTS_ROOT}/Core/ArrayTest.cpp
 	${UNIT_TESTS_ROOT}/Core/BinaryHeapTest.cpp
 	${UNIT_TESTS_ROOT}/Core/FPFlushDenormalsTest.cpp
@@ -89,5 +90,11 @@ if (ENABLE_OBJECT_STREAM)
 	)
 endif()
 
+# Assets used by the unit tests
+set(UNIT_TESTS_ASSETS
+	${JOLT_PHYSICS_SPV_SHADERS}
+	${JOLT_PHYSICS_METAL_LIB}
+)
+
 # Group source files
 source_group(TREE ${UNIT_TESTS_ROOT} FILES ${UNIT_TESTS_SRC_FILES})

Some files were not shown because too many files changed in this diff