// Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
// SPDX-FileCopyrightText: 2025 Jorrit Rouwe
// SPDX-License-Identifier: MIT

#include <Jolt/Jolt.h>

#ifdef JPH_USE_VK

#include <Jolt/Compute/VK/ComputeQueueVK.h>
#include <Jolt/Compute/VK/ComputeBufferVK.h>
#include <Jolt/Compute/VK/ComputeSystemVK.h>

JPH_NAMESPACE_BEGIN

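// The destructor waits for any in-flight submit to finish before freeing the Vulkan objects,
// so destroying the queue while work is still executing is safe.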
ComputeQueueVK::~ComputeQueueVK()
{
	Wait();

	VkDevice device = mComputeSystem->GetDevice();

	if (mCommandBuffer != VK_NULL_HANDLE)
		vkFreeCommandBuffers(device, mCommandPool, 1, &mCommandBuffer);
	if (mCommandPool != VK_NULL_HANDLE)
		vkDestroyCommandPool(device, mCommandPool, nullptr);
	if (mDescriptorPool != VK_NULL_HANDLE)
		vkDestroyDescriptorPool(device, mDescriptorPool, nullptr);
	if (mFence != VK_NULL_HANDLE)
		vkDestroyFence(device, mFence, nullptr);
}

bool ComputeQueueVK::Initialize(uint32 inComputeQueueIndex)
{
	// Get the queue
	VkDevice device = mComputeSystem->GetDevice();
	vkGetDeviceQueue(device, inComputeQueueIndex, 0, &mQueue);

	// Create a command pool
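	// VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT is required so that Wait() can reset the
	// single command buffer allocated from this pool and record it again for the next batch of work.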
	VkCommandPoolCreateInfo pool_info = {};
	pool_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
	pool_info.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
	pool_info.queueFamilyIndex = inComputeQueueIndex;
	if (VKFailed(vkCreateCommandPool(device, &pool_info, nullptr, &mCommandPool)))
		return false;

	// Create descriptor pool
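	// The pool provides up to 256 descriptor sets, drawing from 1024 uniform buffer and 16K storage
	// buffer descriptors. Dispatch() allocates a fresh set from it for every dispatch (when the shader
	// has bindings) and Wait() resets the whole pool once the GPU is done, so descriptors are reused.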
	VkDescriptorPoolSize descriptor_pool_sizes[] = {
		{ VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, 1024 },
		{ VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 16 * 1024 },
	};
	VkDescriptorPoolCreateInfo descriptor_info = {};
	descriptor_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
	descriptor_info.poolSizeCount = std::size(descriptor_pool_sizes);
	descriptor_info.pPoolSizes = descriptor_pool_sizes;
	descriptor_info.maxSets = 256;
	if (VKFailed(vkCreateDescriptorPool(device, &descriptor_info, nullptr, &mDescriptorPool)))
		return false;

	// Create a command buffer
	VkCommandBufferAllocateInfo alloc_info = {};
	alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
	alloc_info.commandPool = mCommandPool;
	alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
	alloc_info.commandBufferCount = 1;
	if (VKFailed(vkAllocateCommandBuffers(device, &alloc_info, &mCommandBuffer)))
		return false;

	// Create a fence
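	// The fence starts out unsignaled; Execute() resets it and hands it to vkQueueSubmit, and
	// Wait() blocks on it to find out when the submitted work has completed.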
	VkFenceCreateInfo fence_info = {};
	fence_info.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
	if (VKFailed(vkCreateFence(device, &fence_info, nullptr, &mFence)))
		return false;

	return true;
}

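// Lazily starts recording the command buffer. The Set* functions and Dispatch() call this so that
// recording only begins when there is actual work; ONE_TIME_SUBMIT_BIT matches the usage pattern of
// submitting the buffer once in Execute() and resetting it in Wait().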
bool ComputeQueueVK::BeginCommandBuffer()
{
	if (!mCommandBufferRecording)
	{
		VkCommandBufferBeginInfo begin_info = {};
		begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
		begin_info.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
		if (VKFailed(vkBeginCommandBuffer(mCommandBuffer, &begin_info)))
			return false;

		mCommandBufferRecording = true;
	}

	return true;
}

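// Selects the shader for the next Dispatch() and caches its VkDescriptorBufferInfo list; the Set*
// functions below fill in the actual buffer handles, which Dispatch() then writes into a descriptor set.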
void ComputeQueueVK::SetShader(const ComputeShader *inShader)
{
	mShader = static_cast<const ComputeShaderVK *>(inShader);
	mBufferInfos = mShader->GetBufferInfos();
}

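// Binds a constant (uniform) buffer to the shader binding named inName; the descriptor points at the
// CPU side VkBuffer and a barrier makes it visible to the compute shader.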
void ComputeQueueVK::SetConstantBuffer(const char *inName, const ComputeBuffer *inBuffer)
{
	if (inBuffer == nullptr)
		return;

	JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::ConstantBuffer);

	if (!BeginCommandBuffer())
		return;

	const ComputeBufferVK *buffer = static_cast<const ComputeBufferVK *>(inBuffer);
	buffer->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_ACCESS_UNIFORM_READ_BIT, false);

	uint index = mShader->NameToBufferInfoIndex(inName);
	JPH_ASSERT(mShader->GetLayoutBindings()[index].descriptorType == VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER);
	mBufferInfos[index].buffer = buffer->GetBufferCPU();
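	// Track the buffer; it is held until Wait() clears mUsedBuffers, after which it can be freed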
	mUsedBuffers.insert(buffer);
}

void ComputeQueueVK::SyncCPUToGPU(const ComputeBufferVK *inBuffer)
{
	// Ensure that any CPU writes are visible to the GPU
	if (inBuffer->SyncCPUToGPU(mCommandBuffer))
	{
		// After the first upload, the CPU buffer is no longer needed for Buffer and RWBuffer types
		if (inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType() == ComputeBuffer::EType::RWBuffer)
			mDelayedFreedBuffers.push_back(inBuffer->ReleaseBufferCPU());
	}
}

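// Binds a read only storage buffer to the shader binding named inName. Any pending CPU writes are
// uploaded to the GPU side buffer first, and the descriptor points at the GPU side VkBuffer.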
void ComputeQueueVK::SetBuffer(const char *inName, const ComputeBuffer *inBuffer)
{
	if (inBuffer == nullptr)
		return;

	JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::UploadBuffer || inBuffer->GetType() == ComputeBuffer::EType::Buffer || inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);

	if (!BeginCommandBuffer())
		return;

	const ComputeBufferVK *buffer = static_cast<const ComputeBufferVK *>(inBuffer);
	SyncCPUToGPU(buffer);
	buffer->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_ACCESS_SHADER_READ_BIT, false);

	uint index = mShader->NameToBufferInfoIndex(inName);
	JPH_ASSERT(mShader->GetLayoutBindings()[index].descriptorType == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
	mBufferInfos[index].buffer = buffer->GetBufferGPU();
	mUsedBuffers.insert(buffer);
}

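// Binds a read/write storage buffer to the shader binding named inName. The inBarrier flag is
// forwarded to ComputeBufferVK::Barrier, presumably to order this dispatch after an earlier one
// that wrote to the same buffer.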
void ComputeQueueVK::SetRWBuffer(const char *inName, ComputeBuffer *inBuffer, EBarrier inBarrier)
{
	if (inBuffer == nullptr)
		return;

	JPH_ASSERT(inBuffer->GetType() == ComputeBuffer::EType::RWBuffer);

	if (!BeginCommandBuffer())
		return;

	const ComputeBufferVK *buffer = static_cast<const ComputeBufferVK *>(inBuffer);
	SyncCPUToGPU(buffer);
	buffer->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VkAccessFlagBits(VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT), inBarrier == EBarrier::Yes);

	uint index = mShader->NameToBufferInfoIndex(inName);
	JPH_ASSERT(mShader->GetLayoutBindings()[index].descriptorType == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
	mBufferInfos[index].buffer = buffer->GetBufferGPU();
	mUsedBuffers.insert(buffer);
}

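// Records a copy of inSrc's GPU buffer into the CPU readable inDst buffer; the data can be read on
// the CPU once Execute() and Wait() have completed.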
void ComputeQueueVK::ScheduleReadback(ComputeBuffer *inDst, const ComputeBuffer *inSrc)
{
	if (inDst == nullptr || inSrc == nullptr)
		return;

	JPH_ASSERT(inDst->GetType() == ComputeBuffer::EType::ReadbackBuffer);

	if (!BeginCommandBuffer())
		return;

	const ComputeBufferVK *src_vk = static_cast<const ComputeBufferVK *>(inSrc);
	ComputeBufferVK *dst_vk = static_cast<ComputeBufferVK *>(inDst);

	// Barrier to start reading from GPU buffer and writing to CPU buffer
	src_vk->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_ACCESS_TRANSFER_READ_BIT, false);
	dst_vk->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_ACCESS_TRANSFER_WRITE_BIT, false);

	// Copy
	VkBufferCopy copy = {};
	copy.srcOffset = 0;
	copy.dstOffset = 0;
	copy.size = src_vk->GetSize() * src_vk->GetStride();
	vkCmdCopyBuffer(mCommandBuffer, src_vk->GetBufferGPU(), dst_vk->GetBufferCPU(), 1, &copy);

	// Barrier to indicate that CPU can read from the buffer
	dst_vk->Barrier(mCommandBuffer, VK_PIPELINE_STAGE_HOST_BIT, VK_ACCESS_HOST_READ_BIT, false);

	mUsedBuffers.insert(src_vk);
	mUsedBuffers.insert(dst_vk);
}

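// Binds the compute pipeline, allocates and fills a descriptor set for the buffers that were bound
// through the Set* functions (if the shader has any bindings) and records the dispatch.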
void ComputeQueueVK::Dispatch(uint inThreadGroupsX, uint inThreadGroupsY, uint inThreadGroupsZ)
{
	if (!BeginCommandBuffer())
		return;

	vkCmdBindPipeline(mCommandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, mShader->GetPipeline());

	VkDevice device = mComputeSystem->GetDevice();
	const Array<VkDescriptorSetLayoutBinding> &ds_bindings = mShader->GetLayoutBindings();
	if (!ds_bindings.empty())
	{
		// Create a descriptor set
		VkDescriptorSetAllocateInfo alloc_info = {};
		alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
		alloc_info.descriptorPool = mDescriptorPool;
		alloc_info.descriptorSetCount = 1;
		VkDescriptorSetLayout ds_layout = mShader->GetDescriptorSetLayout();
		alloc_info.pSetLayouts = &ds_layout;
		VkDescriptorSet descriptor_set;
		if (VKFailed(vkAllocateDescriptorSets(device, &alloc_info, &descriptor_set)))
			return;

		// Write the values to the descriptor set
		Array<VkWriteDescriptorSet> writes;
		writes.reserve(ds_bindings.size());
		for (uint32 i = 0; i < (uint32)ds_bindings.size(); ++i)
		{
			VkWriteDescriptorSet w = {};
			w.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
			w.dstSet = descriptor_set;
			w.dstBinding = ds_bindings[i].binding;
			w.dstArrayElement = 0;
			w.descriptorCount = ds_bindings[i].descriptorCount;
			w.descriptorType = ds_bindings[i].descriptorType;
			w.pBufferInfo = &mBufferInfos[i];
			writes.push_back(w);
		}
		vkUpdateDescriptorSets(device, (uint32)writes.size(), writes.data(), 0, nullptr);

		// Bind the descriptor set
		vkCmdBindDescriptorSets(mCommandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, mShader->GetPipelineLayout(), 0, 1, &descriptor_set, 0, nullptr);
	}

	vkCmdDispatch(mCommandBuffer, inThreadGroupsX, inThreadGroupsY, inThreadGroupsZ);
}

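// Ends recording and submits the command buffer to the compute queue; mFence signals when the GPU
// has finished. Wait() must be called before new work can be recorded, since the command buffer is
// only reset there.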
void ComputeQueueVK::Execute()
{
	// End command buffer
	if (!mCommandBufferRecording)
		return;
	if (VKFailed(vkEndCommandBuffer(mCommandBuffer)))
		return;
	mCommandBufferRecording = false;

	// Reset fence
	VkDevice device = mComputeSystem->GetDevice();
	if (VKFailed(vkResetFences(device, 1, &mFence)))
		return;

	// Submit
	VkSubmitInfo submit = {};
	submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
	submit.commandBufferCount = 1;
	submit.pCommandBuffers = &mCommandBuffer;
	if (VKFailed(vkQueueSubmit(mQueue, 1, &submit, mFence)))
		return;

	// Clear the current shader
	mShader = nullptr;

	// Mark that we're executing
	mIsExecuting = true;
}

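// Blocks until the fence signals, then recycles the command buffer and descriptor pool for the next
// run and releases the buffers that were kept alive for this one.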
void ComputeQueueVK::Wait()
{
	if (!mIsExecuting)
		return;

	// Wait for the work to complete
	VkDevice device = mComputeSystem->GetDevice();
	if (VKFailed(vkWaitForFences(device, 1, &mFence, VK_TRUE, UINT64_MAX)))
		return;

	// Reset command buffer so it can be reused
	if (mCommandBuffer != VK_NULL_HANDLE)
		vkResetCommandBuffer(mCommandBuffer, 0);

	// Allow reusing the descriptors for next run
	vkResetDescriptorPool(device, mDescriptorPool, 0);

	// Buffers can be freed now
	mUsedBuffers.clear();

	// Free delayed buffers
	for (BufferVK &buffer : mDelayedFreedBuffers)
		mComputeSystem->FreeBuffer(buffer);
	mDelayedFreedBuffers.clear();

	mIsExecuting = false;
}

JPH_NAMESPACE_END

#endif // JPH_USE_VK