ComputeQueueDX12.cpp 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. // Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
  2. // SPDX-FileCopyrightText: 2025 Jorrit Rouwe
  3. // SPDX-License-Identifier: MIT
  4. #include <Jolt/Jolt.h>
  5. #ifdef JPH_USE_DX12
  6. #include <Jolt/Compute/DX12/ComputeQueueDX12.h>
  7. #include <Jolt/Compute/DX12/ComputeShaderDX12.h>
  8. #include <Jolt/Compute/DX12/ComputeBufferDX12.h>
  9. JPH_NAMESPACE_BEGIN
  10. ComputeQueueDX12::~ComputeQueueDX12()
  11. {
  12. Wait();
  13. if (mFenceEvent != INVALID_HANDLE_VALUE)
  14. CloseHandle(mFenceEvent);
  15. }
  16. bool ComputeQueueDX12::Initialize(ID3D12Device *inDevice, D3D12_COMMAND_LIST_TYPE inType)
  17. {
  18. D3D12_COMMAND_QUEUE_DESC queue_desc = {};
  19. queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
  20. queue_desc.Type = inType;
  21. queue_desc.Priority = D3D12_COMMAND_QUEUE_PRIORITY_HIGH;
  22. if (HRFailed(inDevice->CreateCommandQueue(&queue_desc, IID_PPV_ARGS(&mCommandQueue))))
  23. return false;
  24. if (HRFailed(inDevice->CreateCommandAllocator(inType, IID_PPV_ARGS(&mCommandAllocator))))
  25. return false;
  26. // Create the command list
  27. if (HRFailed(inDevice->CreateCommandList(0, inType, mCommandAllocator.Get(), nullptr, IID_PPV_ARGS(&mCommandList))))
  28. return false;
  29. // Command lists are created in the recording state, but there is nothing to record yet. The main loop expects it to be closed, so close it now
  30. if (HRFailed(mCommandList->Close()))
  31. return false;
  32. // Create synchronization object
  33. if (HRFailed(inDevice->CreateFence(mFenceValue, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&mFence))))
  34. return false;
  35. // Increment fence value so we don't skip waiting the first time a command list is executed
  36. mFenceValue++;
  37. // Create an event handle to use for frame synchronization
  38. mFenceEvent = CreateEvent(nullptr, FALSE, FALSE, nullptr);
  39. if (HRFailed(HRESULT_FROM_WIN32(GetLastError())))
  40. return false;
  41. return true;
  42. }
  43. ID3D12GraphicsCommandList *ComputeQueueDX12::Start()
  44. {
  45. JPH_ASSERT(!mIsExecuting);
  46. if (!mIsStarted)
  47. {
  48. // Reset the allocator
  49. if (HRFailed(mCommandAllocator->Reset()))
  50. return nullptr;
  51. // Reset the command list
  52. if (HRFailed(mCommandList->Reset(mCommandAllocator.Get(), nullptr)))
  53. return nullptr;
  54. // Now we have started recording commands
  55. mIsStarted = true;
  56. }
  57. return mCommandList.Get();
  58. }
  59. void ComputeQueueDX12::SetShader(const ComputeShader *inShader)
  60. {
  61. ID3D12GraphicsCommandList *command_list = Start();
  62. mShader = static_cast<const ComputeShaderDX12 *>(inShader);
  63. command_list->SetPipelineState(mShader->GetPipelineState());
  64. command_list->SetComputeRootSignature(mShader->GetRootSignature());
  65. }
  66. void ComputeQueueDX12::SyncCPUToGPU(const ComputeBufferDX12 *inBuffer)
  67. {
  68. // Ensure that any CPU writes are visible to the GPU
  69. if (inBuffer->SyncCPUToGPU(mCommandList.Get()))
  70. {
  71. // After the first upload, the CPU buffer is no longer needed for Buffer and RWBuffer types
  72. if (inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType() == ComputeBuffer::EType::RWBuffer)
  73. mDelayedFreedBuffers.emplace_back(inBuffer->ReleaseResourceCPU());
  74. }
  75. }
  76. void ComputeQueueDX12::SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer)
  77. {
  78. if (inBuffer == nullptr)
  79. return;
  80. JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::ConstantBuffer);
  81. ID3D12GraphicsCommandList *command_list = Start();
  82. const ComputeBufferDX12 *buffer = static_cast<const ComputeBufferDX12 *>(inBuffer);
  83. command_list->SetComputeRootConstantBufferView(mShader->NameToIndex(inName), buffer->GetResourceCPU()->GetGPUVirtualAddress());
  84. mUsedBuffers.insert(buffer);
  85. }
  86. void ComputeQueueDX12::SetBuffer(const char *inName, const ComputeBuffer *inBuffer)
  87. {
  88. if (inBuffer == nullptr)
  89. return;
  90. JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::UploadBuffer || inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
  91. ID3D12GraphicsCommandList *command_list = Start();
  92. const ComputeBufferDX12 *buffer = static_cast<const ComputeBufferDX12 *>(inBuffer);
  93. uint parameter_index = mShader->NameToIndex(inName);
  94. SyncCPUToGPU(buffer);
  95. buffer->Barrier(command_list, D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE);
  96. command_list->SetComputeRootShaderResourceView(parameter_index, buffer->GetResourceGPU()->GetGPUVirtualAddress());
  97. mUsedBuffers.insert(buffer);
  98. }
  99. void ComputeQueueDX12::SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier)
  100. {
  101. if (inBuffer == nullptr)
  102. return;
  103. JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);
  104. ID3D12GraphicsCommandList *command_list = Start();
  105. ComputeBufferDX12 *buffer = static_cast<ComputeBufferDX12 *>(inBuffer);
  106. uint parameter_index = mShader->NameToIndex(inName);
  107. SyncCPUToGPU(buffer);
  108. if (!buffer->Barrier(command_list, D3D12_RESOURCE_STATE_UNORDERED_ACCESS) && inBarrier == EBarrier::Yes)
  109. buffer->RWBarrier(command_list);
  110. command_list->SetComputeRootUnorderedAccessView(parameter_index, buffer->GetResourceGPU()->GetGPUVirtualAddress());
  111. mUsedBuffers.insert(buffer);
  112. }
  113. void ComputeQueueDX12::ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc)
  114. {
  115. if (inDst == nullptr || inSrc == nullptr)
  116. return;
  117. JPH_ASSERT(inDst->GetType() == ComputeBuffer::EType::ReadbackBuffer);
  118. ID3D12GraphicsCommandList *command_list = Start();
  119. ComputeBufferDX12 *dst = static_cast<ComputeBufferDX12 *>(inDst);
  120. const ComputeBufferDX12 *src = static_cast<const ComputeBufferDX12 *>(inSrc);
  121. dst->Barrier(command_list, D3D12_RESOURCE_STATE_COPY_DEST);
  122. src->Barrier(command_list, D3D12_RESOURCE_STATE_COPY_SOURCE);
  123. command_list->CopyResource(dst->GetResourceCPU(), src->GetResourceGPU());
  124. mUsedBuffers.insert(src);
  125. mUsedBuffers.insert(dst);
  126. }
  127. void ComputeQueueDX12::Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ)
  128. {
  129. ID3D12GraphicsCommandList *command_list = Start();
  130. command_list->Dispatch(inThreadGroupsX, inThreadGroupsY, inThreadGroupsZ);
  131. }
  132. void ComputeQueueDX12::Execute()
  133. {
  134. JPH_ASSERT(mIsStarted);
  135. JPH_ASSERT(!mIsExecuting);
  136. // Close the command list
  137. if (HRFailed(mCommandList->Close()))
  138. return;
  139. // Execute the command list
  140. ID3D12CommandList *command_lists[] = { mCommandList.Get() };
  141. mCommandQueue->ExecuteCommandLists(std::size(command_lists), command_lists);
  142. // Schedule a Signal command in the queue
  143. if (HRFailed(mCommandQueue->Signal(mFence.Get(), mFenceValue)))
  144. return;
  145. // Clear the current shader
  146. mShader = nullptr;
  147. // Mark that we're executing
  148. mIsExecuting = true;
  149. }
  150. void ComputeQueueDX12::Wait()
  151. {
  152. // Check if we've been started
  153. if (mIsExecuting)
  154. {
  155. if (mFence->GetCompletedValue() < mFenceValue)
  156. {
  157. // Wait until the fence has been processed
  158. if (HRFailed(mFence->SetEventOnCompletion(mFenceValue, mFenceEvent)))
  159. return;
  160. WaitForSingleObjectEx(mFenceEvent, INFINITE, FALSE);
  161. }
  162. // Increment the fence value
  163. mFenceValue++;
  164. // Buffers can be freed now
  165. mUsedBuffers.clear();
  166. // Free buffers
  167. mDelayedFreedBuffers.clear();
  168. // Done executing
  169. mIsExecuting = false;
  170. mIsStarted = false;
  171. }
  172. }
  173. JPH_NAMESPACE_END
  174. #endif // JPH_USE_DX12