ComputeSystemDX12.cpp 14 KB


  1. // Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
  2. // SPDX-FileCopyrightText: 2025 Jorrit Rouwe
  3. // SPDX-License-Identifier: MIT
  4. #include <Jolt/Jolt.h>
  5. #ifdef JPH_USE_DX12
  6. #include <Jolt/Compute/DX12/ComputeSystemDX12.h>
  7. #include <Jolt/Compute/DX12/ComputeQueueDX12.h>
  8. #include <Jolt/Compute/DX12/ComputeShaderDX12.h>
  9. #include <Jolt/Compute/DX12/ComputeBufferDX12.h>
  10. #include <Jolt/Core/StringTools.h>
  11. #include <Jolt/Core/UnorderedMap.h>
  12. JPH_SUPPRESS_WARNINGS_STD_BEGIN
  13. JPH_MSVC_SUPPRESS_WARNING(5204) // 'X': class has virtual functions, but its trivial destructor is not virtual; instances of objects derived from this class may not be destructed correctly
  14. JPH_MSVC2026_PLUS_SUPPRESS_WARNING(4865) // wingdi.h(2806,1): '<unnamed-enum-DISPLAYCONFIG_OUTPUT_TECHNOLOGY_OTHER>': the underlying type will change from 'int' to '__int64' when '/Zc:enumTypes' is specified on the command line
  15. #include <fstream>
  16. #include <d3dcompiler.h>
  17. #include <dxcapi.h>
  18. #ifdef JPH_DEBUG
  19. #include <d3d12sdklayers.h>
  20. #endif
  21. JPH_SUPPRESS_WARNINGS_STD_END
  22. JPH_NAMESPACE_BEGIN
  // Run-time type information for ComputeSystemDX12; ComputeSystem is declared as its base class
  23. JPH_IMPLEMENT_RTTI_VIRTUAL(ComputeSystemDX12)
  24. {
  25. JPH_ADD_BASE_CLASS(ComputeSystemDX12, ComputeSystem)
  26. }
  27. void ComputeSystemDX12::Initialize(ID3D12Device *inDevice, EDebug inDebug)
  28. {
  29. mDevice = inDevice;
  30. mDebug = inDebug;
  31. }
  32. void ComputeSystemDX12::Shutdown()
  33. {
  34. mDevice.Reset();
  35. }
  36. ComPtr<ID3D12Resource> ComputeSystemDX12::CreateD3DResource(D3D12_HEAP_TYPE inHeapType, D3D12_RESOURCE_STATES inResourceState, D3D12_RESOURCE_FLAGS inFlags, uint64 inSize)
  37. {
  38. // Create a new resource
  39. D3D12_RESOURCE_DESC desc;
  40. desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
  41. desc.Alignment = 0;
  42. desc.Width = inSize;
  43. desc.Height = 1;
  44. desc.DepthOrArraySize = 1;
  45. desc.MipLevels = 1;
  46. desc.Format = DXGI_FORMAT_UNKNOWN;
  47. desc.SampleDesc.Count = 1;
  48. desc.SampleDesc.Quality = 0;
  49. desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
  50. desc.Flags = inFlags;
  51. D3D12_HEAP_PROPERTIES heap_properties = {};
  52. heap_properties.Type = inHeapType;
  53. heap_properties.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN;
  54. heap_properties.MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN;
  55. heap_properties.CreationNodeMask = 1;
  56. heap_properties.VisibleNodeMask = 1;
  57. ComPtr<ID3D12Resource> resource;
  58. if (HRFailed(mDevice->CreateCommittedResource(&heap_properties, D3D12_HEAP_FLAG_NONE, &desc, inResourceState, nullptr, IID_PPV_ARGS(&resource))))
  59. return nullptr;
  60. return resource;
  61. }
  62. ComputeShaderResult ComputeSystemDX12::CreateComputeShader(const char *inName, uint32 inGroupSizeX, uint32 inGroupSizeY, uint32 inGroupSizeZ)
  63. {
  64. ComputeShaderResult result;
  65. // Read shader source file
  66. Array<uint8> data;
  67. String error;
  68. String file_name = String(inName) + ".hlsl";
  69. if (!mShaderLoader(file_name.c_str(), data, error))
  70. {
  71. result.SetError(error);
  72. return result;
  73. }
  74. #ifndef JPH_USE_DXC // Use FXC, the old shader compiler?
  75. UINT flags = D3DCOMPILE_ENABLE_STRICTNESS | D3DCOMPILE_WARNINGS_ARE_ERRORS | D3DCOMPILE_ALL_RESOURCES_BOUND;
  76. #ifdef JPH_DEBUG
  77. flags |= D3DCOMPILE_SKIP_OPTIMIZATION;
  78. #else
  79. flags |= D3DCOMPILE_OPTIMIZATION_LEVEL3;
  80. #endif
  81. if (mDebug == EDebug::DebugSymbols)
  82. flags |= D3DCOMPILE_DEBUG;
  83. const D3D_SHADER_MACRO defines[] =
  84. {
  85. { nullptr, nullptr }
  86. };
  87. // Handles loading include files through the shader loader
  88. struct IncludeHandler : public ID3DInclude
  89. {
  90. IncludeHandler(const ShaderLoader &inShaderLoader) : mShaderLoader(inShaderLoader) { }
  91. virtual ~IncludeHandler() = default;
  92. STDMETHOD (Open)(D3D_INCLUDE_TYPE, LPCSTR inFileName, LPCVOID, LPCVOID *outData, UINT *outNumBytes) override
  93. {
  94. // Read the header file
  95. Array<uint8> file_data;
  96. String error;
  97. if (!mShaderLoader(inFileName, file_data, error))
  98. return E_FAIL;
  99. if (file_data.empty())
  100. {
  101. *outData = nullptr;
  102. *outNumBytes = 0;
  103. return S_OK;
  104. }
  105. // Copy to a new memory block
  106. void *mem = CoTaskMemAlloc(file_data.size());
  107. if (mem == nullptr)
  108. return E_OUTOFMEMORY;
  109. memcpy(mem, file_data.data(), file_data.size());
  110. *outData = mem;
  111. *outNumBytes = (UINT)file_data.size();
  112. return S_OK;
  113. }
  114. STDMETHOD (Close)(LPCVOID inData) override
  115. {
  116. if (inData != nullptr)
  117. CoTaskMemFree(const_cast<void *>(inData));
  118. return S_OK;
  119. }
  120. private:
  121. const ShaderLoader & mShaderLoader;
  122. };
  123. IncludeHandler include_handler(mShaderLoader);
  124. // Compile source
  125. ComPtr<ID3DBlob> shader_blob, error_blob;
  126. if (FAILED(D3DCompile(&data[0],
  127. (uint)data.size(),
  128. file_name.c_str(),
  129. defines,
  130. &include_handler,
  131. "main",
  132. "cs_5_0",
  133. flags,
  134. 0,
  135. shader_blob.GetAddressOf(),
  136. error_blob.GetAddressOf())))
  137. {
  138. if (error_blob)
  139. result.SetError((const char *)error_blob->GetBufferPointer());
  140. else
  141. result.SetError("Shader compile error");
  142. return result;
  143. }
  144. // Get shader description
  145. ComPtr<ID3D12ShaderReflection> reflector;
  146. if (FAILED(D3DReflect(shader_blob->GetBufferPointer(), shader_blob->GetBufferSize(), IID_PPV_ARGS(&reflector))))
  147. {
  148. result.SetError("Failed to reflect shader");
  149. return result;
  150. }
  151. #else
  152. ComPtr<IDxcUtils> utils;
  153. DxcCreateInstance(CLSID_DxcUtils, IID_PPV_ARGS(utils.GetAddressOf()));
  154. // Custom include handler that forwards include loads to mShaderLoader
  155. struct DxcIncludeHandler : public IDxcIncludeHandler
  156. {
  157. DxcIncludeHandler(IDxcUtils *inUtils, const ShaderLoader &inLoader) : mUtils(inUtils), mShaderLoader(inLoader) { }
  158. virtual ~DxcIncludeHandler() = default;
  159. STDMETHODIMP QueryInterface(REFIID riid, void **ppvObject) override
  160. {
  161. JPH_ASSERT(false);
  162. return E_NOINTERFACE;
  163. }
  164. STDMETHODIMP_(ULONG) AddRef(void) override
  165. {
  166. // Allocated on the stack, we don't do ref counting
  167. return 1;
  168. }
  169. STDMETHODIMP_(ULONG) Release(void) override
  170. {
  171. // Allocated on the stack, we don't do ref counting
  172. return 1;
  173. }
  174. // IDxcIncludeHandler::LoadSource uses IDxcBlob**
  175. STDMETHODIMP LoadSource(LPCWSTR inFilename, IDxcBlob **outIncludeSource) override
  176. {
  177. *outIncludeSource = nullptr;
  178. // Convert to UTF-8
  179. char file_name[MAX_PATH];
  180. WideCharToMultiByte(CP_UTF8, 0, inFilename, -1, file_name, sizeof(file_name), nullptr, nullptr);
  181. // Load the header
  182. Array<uint8> file_data;
  183. String error;
  184. if (!mShaderLoader(file_name, file_data, error))
  185. return E_FAIL;
  186. // Create a blob from the loaded data
  187. ComPtr<IDxcBlobEncoding> blob_encoder;
  188. HRESULT hr = mUtils->CreateBlob(file_data.empty()? nullptr : file_data.data(), (uint)file_data.size(), CP_UTF8, blob_encoder.GetAddressOf());
  189. if (FAILED(hr))
  190. return hr;
  191. // Return as IDxcBlob
  192. *outIncludeSource = blob_encoder.Detach();
  193. return S_OK;
  194. }
  195. IDxcUtils * mUtils;
  196. const ShaderLoader & mShaderLoader;
  197. };
  198. DxcIncludeHandler include_handler(utils.Get(), mShaderLoader);
  199. ComPtr<IDxcBlobEncoding> source;
  200. if (HRFailed(utils->CreateBlob(data.data(), (uint)data.size(), CP_UTF8, source.GetAddressOf()), result))
  201. return result;
  202. ComPtr<IDxcCompiler3> compiler;
  203. DxcCreateInstance(CLSID_DxcCompiler, IID_PPV_ARGS(compiler.GetAddressOf()));
  204. Array<LPCWSTR> arguments;
  205. arguments.push_back(L"-E");
  206. arguments.push_back(L"main");
  207. arguments.push_back(L"-T");
  208. arguments.push_back(L"cs_6_0");
  209. arguments.push_back(DXC_ARG_WARNINGS_ARE_ERRORS);
  210. arguments.push_back(DXC_ARG_OPTIMIZATION_LEVEL3);
  211. arguments.push_back(DXC_ARG_ALL_RESOURCES_BOUND);
  212. if (mDebug == EDebug::DebugSymbols)
  213. {
  214. arguments.push_back(DXC_ARG_DEBUG);
  215. arguments.push_back(L"-Qembed_debug");
  216. }
  217. // Provide file name so tools know what the original shader was called (the actual source comes from the blob)
  218. wchar_t w_file_name[MAX_PATH];
  219. MultiByteToWideChar(CP_UTF8, 0, file_name.c_str(), -1, w_file_name, MAX_PATH);
  220. arguments.push_back(w_file_name);
  221. // Compile the shader
  222. DxcBuffer source_buffer;
  223. source_buffer.Ptr = source->GetBufferPointer();
  224. source_buffer.Size = source->GetBufferSize();
  225. source_buffer.Encoding = 0;
  226. ComPtr<IDxcResult> compile_result;
  227. if (FAILED(compiler->Compile(&source_buffer, arguments.data(), (uint32)arguments.size(), &include_handler, IID_PPV_ARGS(compile_result.GetAddressOf()))))
  228. {
  229. result.SetError("Failed to compile shader");
  230. return result;
  231. }
  232. // Check for compilation errors
  233. ComPtr<IDxcBlobUtf8> errors;
  234. compile_result->GetOutput(DXC_OUT_ERRORS, IID_PPV_ARGS(errors.GetAddressOf()), nullptr);
  235. if (errors != nullptr && errors->GetStringLength() > 0)
  236. {
  237. result.SetError((const char *)errors->GetBufferPointer());
  238. return result;
  239. }
  240. // Get the compiled shader code
  241. ComPtr<ID3DBlob> shader_blob;
  242. if (HRFailed(compile_result->GetOutput(DXC_OUT_OBJECT, IID_PPV_ARGS(shader_blob.GetAddressOf()), nullptr), result))
  243. return result;
  244. // Get reflection data
  245. ComPtr<IDxcBlob> reflection_data;
  246. if (HRFailed(compile_result->GetOutput(DXC_OUT_REFLECTION, IID_PPV_ARGS(reflection_data.GetAddressOf()), nullptr), result))
  247. return result;
  248. DxcBuffer reflection_buffer;
  249. reflection_buffer.Ptr = reflection_data->GetBufferPointer();
  250. reflection_buffer.Size = reflection_data->GetBufferSize();
  251. reflection_buffer.Encoding = 0;
  252. ComPtr<ID3D12ShaderReflection> reflector;
  253. if (HRFailed(utils->CreateReflection(&reflection_buffer, IID_PPV_ARGS(reflector.GetAddressOf())), result))
  254. return result;
  255. #endif // JPH_USE_DXC
  256. // Get the shader description
  257. D3D12_SHADER_DESC shader_desc;
  258. if (HRFailed(reflector->GetDesc(&shader_desc), result))
  259. return result;
  260. // Verify that the group sizes match the shader's thread group size
  261. UINT thread_group_size_x, thread_group_size_y, thread_group_size_z;
  262. if (HRFailed(reflector->GetThreadGroupSize(&thread_group_size_x, &thread_group_size_y, &thread_group_size_z), result))
  263. return result;
  264. JPH_ASSERT(inGroupSizeX == thread_group_size_x, "Group size X mismatch");
  265. JPH_ASSERT(inGroupSizeY == thread_group_size_y, "Group size Y mismatch");
  266. JPH_ASSERT(inGroupSizeZ == thread_group_size_z, "Group size Z mismatch");
  267. // Convert parameters to root signature description
  268. Array<String> binding_names;
  269. binding_names.reserve(shader_desc.BoundResources);
  270. UnorderedMap<string_view, uint> name_to_index;
  271. Array<D3D12_ROOT_PARAMETER1> root_params;
  272. for (UINT i = 0; i < shader_desc.BoundResources; ++i)
  273. {
  274. D3D12_SHADER_INPUT_BIND_DESC bind_desc;
  275. reflector->GetResourceBindingDesc(i, &bind_desc);
  276. D3D12_ROOT_PARAMETER1 param = {};
  277. param.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
  278. switch (bind_desc.Type)
  279. {
  280. case D3D_SIT_CBUFFER:
  281. param.ParameterType = D3D12_ROOT_PARAMETER_TYPE_CBV;
  282. break;
  283. case D3D_SIT_STRUCTURED:
  284. case D3D_SIT_BYTEADDRESS:
  285. param.ParameterType = D3D12_ROOT_PARAMETER_TYPE_SRV;
  286. break;
  287. case D3D_SIT_UAV_RWTYPED:
  288. case D3D_SIT_UAV_RWSTRUCTURED:
  289. case D3D_SIT_UAV_RWBYTEADDRESS:
  290. case D3D_SIT_UAV_APPEND_STRUCTURED:
  291. case D3D_SIT_UAV_CONSUME_STRUCTURED:
  292. case D3D_SIT_UAV_RWSTRUCTURED_WITH_COUNTER:
  293. param.ParameterType = D3D12_ROOT_PARAMETER_TYPE_UAV;
  294. break;
  295. case D3D_SIT_TBUFFER:
  296. case D3D_SIT_TEXTURE:
  297. case D3D_SIT_SAMPLER:
  298. case D3D_SIT_RTACCELERATIONSTRUCTURE:
  299. case D3D_SIT_UAV_FEEDBACKTEXTURE:
  300. JPH_ASSERT(false, "Unsupported shader input type");
  301. continue;
  302. }
  303. param.Descriptor.RegisterSpace = bind_desc.Space;
  304. param.Descriptor.ShaderRegister = bind_desc.BindPoint;
  305. param.Descriptor.Flags = D3D12_ROOT_DESCRIPTOR_FLAG_DATA_VOLATILE;
  306. binding_names.push_back(bind_desc.Name); // Add all strings to a pool to keep them alive
  307. name_to_index[string_view(binding_names.back())] = (uint)root_params.size();
  308. root_params.push_back(param);
  309. }
  310. // Create the root signature
  311. D3D12_VERSIONED_ROOT_SIGNATURE_DESC root_sig_desc = {};
  312. root_sig_desc.Version = D3D_ROOT_SIGNATURE_VERSION_1_1;
  313. root_sig_desc.Desc_1_1.NumParameters = (UINT)root_params.size();
  314. root_sig_desc.Desc_1_1.pParameters = root_params.data();
  315. root_sig_desc.Desc_1_1.NumStaticSamplers = 0;
  316. root_sig_desc.Desc_1_1.pStaticSamplers = nullptr;
  317. root_sig_desc.Desc_1_1.Flags = D3D12_ROOT_SIGNATURE_FLAG_NONE;
  318. ComPtr<ID3DBlob> serialized_sig;
  319. ComPtr<ID3DBlob> root_sig_error_blob;
  320. if (FAILED(D3D12SerializeVersionedRootSignature(&root_sig_desc, &serialized_sig, &root_sig_error_blob)))
  321. {
  322. if (root_sig_error_blob)
  323. {
  324. error = StringFormat("Failed to create root signature: %s", (const char *)root_sig_error_blob->GetBufferPointer());
  325. result.SetError(error);
  326. }
  327. else
  328. result.SetError("Failed to create root signature");
  329. return result;
  330. }
  331. ComPtr<ID3D12RootSignature> root_sig;
  332. if (FAILED(mDevice->CreateRootSignature(0, serialized_sig->GetBufferPointer(), serialized_sig->GetBufferSize(), IID_PPV_ARGS(&root_sig))))
  333. {
  334. result.SetError("Failed to create root signature");
  335. return result;
  336. }
  337. // Create a pipeline state object from the root signature and the shader
  338. ComPtr<ID3D12PipelineState> pipeline_state;
  339. D3D12_COMPUTE_PIPELINE_STATE_DESC compute_state_desc = {};
  340. compute_state_desc.pRootSignature = root_sig.Get();
  341. compute_state_desc.CS = { shader_blob->GetBufferPointer(), shader_blob->GetBufferSize() };
  342. if (FAILED(mDevice->CreateComputePipelineState(&compute_state_desc, IID_PPV_ARGS(&pipeline_state))))
  343. {
  344. result.SetError("Failed to create compute pipeline state");
  345. return result;
  346. }
  347. // Set name on DX12 objects for easier debugging
  348. wchar_t w_name[1024];
  349. size_t converted_chars = 0;
  350. mbstowcs_s(&converted_chars, w_name, 1024, inName, _TRUNCATE);
  351. pipeline_state->SetName(w_name);
  352. result.Set(new ComputeShaderDX12(shader_blob, root_sig, pipeline_state, std::move(binding_names), std::move(name_to_index), inGroupSizeX, inGroupSizeY, inGroupSizeZ));
  353. return result;
  354. }
  355. ComputeBufferResult ComputeSystemDX12::CreateComputeBuffer(ComputeBuffer::EType inType, uint64 inSize, uint inStride, const void *inData)
  356. {
  357. ComputeBufferResult result;
  358. Ref<ComputeBufferDX12> buffer = new ComputeBufferDX12(this, inType, inSize, inStride);
  359. if (!buffer->Initialize(inData))
  360. {
  361. result.SetError("Failed to create compute buffer");
  362. return result;
  363. }
  364. result.Set(buffer.GetPtr());
  365. return result;
  366. }
  367. ComputeQueueResult ComputeSystemDX12::CreateComputeQueue()
  368. {
  369. ComputeQueueResult result;
  370. Ref<ComputeQueueDX12> queue = new ComputeQueueDX12();
  371. if (!queue->Initialize(mDevice.Get(), D3D12_COMMAND_LIST_TYPE_COMPUTE, result))
  372. return result;
  373. result.Set(queue.GetPtr());
  374. return result;
  375. }
  376. JPH_NAMESPACE_END
  377. #endif // JPH_USE_DX12