2
0

ComputeShaderVK.cpp 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. // Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
  2. // SPDX-FileCopyrightText: 2025 Jorrit Rouwe
  3. // SPDX-License-Identifier: MIT
  4. #include <Jolt/Jolt.h>
  5. #ifdef JPH_USE_VK
  6. #include <Jolt/Compute/VK/ComputeShaderVK.h>
  7. JPH_NAMESPACE_BEGIN
  8. ComputeShaderVK::~ComputeShaderVK()
  9. {
  10. if (mShaderModule != VK_NULL_HANDLE)
  11. vkDestroyShaderModule(mDevice, mShaderModule, nullptr);
  12. if (mDescriptorSetLayout != VK_NULL_HANDLE)
  13. vkDestroyDescriptorSetLayout(mDevice, mDescriptorSetLayout, nullptr);
  14. if (mPipelineLayout != VK_NULL_HANDLE)
  15. vkDestroyPipelineLayout(mDevice, mPipelineLayout, nullptr);
  16. if (mPipeline != VK_NULL_HANDLE)
  17. vkDestroyPipeline(mDevice, mPipeline, nullptr);
  18. }
  19. bool ComputeShaderVK::Initialize(const Array<uint8> &inSPVCode, VkBuffer inDummyBuffer, ComputeShaderResult &outResult)
  20. {
  21. const uint32 *spv_words = reinterpret_cast<const uint32 *>(inSPVCode.data());
  22. size_t spv_word_count = inSPVCode.size() / sizeof(uint32);
  23. // Minimal SPIR-V parser to extract name to binding info
  24. UnorderedMap<uint32, String> id_to_name;
  25. UnorderedMap<uint32, uint32> id_to_binding;
  26. UnorderedMap<uint32, VkDescriptorType> id_to_descriptor_type;
  27. UnorderedMap<uint32, uint32> pointer_to_pointee;
  28. UnorderedMap<uint32, uint32> var_to_ptr_type;
  29. size_t i = 5; // Skip 5 word header
  30. while (i < spv_word_count)
  31. {
  32. // Parse next word
  33. uint32 word = spv_words[i];
  34. uint16 opcode = uint16(word & 0xffff);
  35. uint16 word_count = uint16(word >> 16);
  36. if (word_count == 0 || i + word_count > spv_word_count)
  37. break;
  38. switch (opcode)
  39. {
  40. case 5: // OpName
  41. if (word_count >= 2)
  42. {
  43. uint32 target_id = spv_words[i + 1];
  44. const char* name = reinterpret_cast<const char*>(&spv_words[i + 2]);
  45. if (*name != 0)
  46. id_to_name.insert({ target_id, name });
  47. }
  48. break;
  49. case 16: // OpExecutionMode
  50. if (word_count >= 6)
  51. {
  52. uint32 execution_mode = spv_words[i + 2];
  53. if (execution_mode == 17) // LocalSize
  54. {
  55. // Assert that the group size provided matches the one in the shader
  56. JPH_ASSERT(GetGroupSizeX() == spv_words[i + 3], "Group size X mismatch");
  57. JPH_ASSERT(GetGroupSizeY() == spv_words[i + 4], "Group size Y mismatch");
  58. JPH_ASSERT(GetGroupSizeZ() == spv_words[i + 5], "Group size Z mismatch");
  59. }
  60. }
  61. break;
  62. case 32: // OpTypePointer
  63. if (word_count >= 4)
  64. {
  65. uint32 result_id = spv_words[i + 1];
  66. uint32 type_id = spv_words[i + 3];
  67. pointer_to_pointee.insert({ result_id, type_id });
  68. }
  69. break;
  70. case 59: // OpVariable
  71. if (word_count >= 3)
  72. {
  73. uint32 ptr_type_id = spv_words[i + 1];
  74. uint32 result_id = spv_words[i + 2];
  75. var_to_ptr_type.insert({ result_id, ptr_type_id });
  76. }
  77. break;
  78. case 71: // OpDecorate
  79. if (word_count >= 3)
  80. {
  81. uint32 target_id = spv_words[i + 1];
  82. uint32 decoration = spv_words[i + 2];
  83. if (decoration == 2) // Block
  84. {
  85. id_to_descriptor_type.insert({ target_id, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER });
  86. }
  87. else if (decoration == 3) // BufferBlock
  88. {
  89. id_to_descriptor_type.insert({ target_id, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER });
  90. }
  91. else if (decoration == 33 && word_count >= 4) // Binding
  92. {
  93. uint32 binding = spv_words[i + 3];
  94. id_to_binding.insert({ target_id, binding });
  95. }
  96. }
  97. break;
  98. default:
  99. break;
  100. }
  101. i += word_count;
  102. }
  103. // Build name to binding map
  104. UnorderedMap<String, std::pair<uint32, VkDescriptorType>> name_to_binding;
  105. for (const UnorderedMap<uint32, uint32>::value_type &entry : id_to_binding)
  106. {
  107. uint32 target_id = entry.first;
  108. uint32 binding = entry.second;
  109. // Get the name of the variable
  110. UnorderedMap<uint32, String>::const_iterator it_name = id_to_name.find(target_id);
  111. if (it_name != id_to_name.end())
  112. {
  113. // Find variable that links to the target
  114. UnorderedMap<uint32, uint32>::const_iterator it_var_ptr = var_to_ptr_type.find(target_id);
  115. if (it_var_ptr != var_to_ptr_type.end())
  116. {
  117. // Find type pointed at
  118. uint32 ptr_type = it_var_ptr->second;
  119. UnorderedMap<uint32, uint32>::const_iterator it_pointee = pointer_to_pointee.find(ptr_type);
  120. if (it_pointee != pointer_to_pointee.end())
  121. {
  122. uint32 pointee_type = it_pointee->second;
  123. // Find descriptor type
  124. UnorderedMap<uint32, VkDescriptorType>::iterator it_descriptor_type = id_to_descriptor_type.find(pointee_type);
  125. VkDescriptorType descriptor_type = it_descriptor_type != id_to_descriptor_type.end() ? it_descriptor_type->second : VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
  126. name_to_binding.insert({ it_name->second, { binding, descriptor_type } });
  127. continue;
  128. }
  129. }
  130. }
  131. }
  132. // Create layout bindings and buffer infos
  133. if (!name_to_binding.empty())
  134. {
  135. mLayoutBindings.reserve(name_to_binding.size());
  136. mBufferInfos.reserve(name_to_binding.size());
  137. mBindingNames.reserve(name_to_binding.size());
  138. for (const UnorderedMap<String, std::pair<uint32, VkDescriptorType>>::value_type &b : name_to_binding)
  139. {
  140. const String &name = b.first;
  141. uint binding = b.second.first;
  142. VkDescriptorType descriptor_type = b.second.second;
  143. VkDescriptorSetLayoutBinding l = {};
  144. l.binding = binding;
  145. l.descriptorCount = 1;
  146. l.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
  147. l.descriptorType = descriptor_type;
  148. mLayoutBindings.push_back(l);
  149. mBindingNames.push_back(name); // Add all strings to a pool to keep them alive
  150. mNameToBufferInfoIndex[string_view(mBindingNames.back())] = (uint32)mBufferInfos.size();
  151. VkDescriptorBufferInfo bi = {};
  152. bi.offset = 0;
  153. bi.range = VK_WHOLE_SIZE;
  154. bi.buffer = inDummyBuffer; // Avoid: The Vulkan spec states: If the nullDescriptor feature is not enabled, buffer must not be VK_NULL_HANDLE
  155. mBufferInfos.push_back(bi);
  156. }
  157. // Create descriptor set layout
  158. VkDescriptorSetLayoutCreateInfo layout_info = {};
  159. layout_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
  160. layout_info.bindingCount = (uint32)mLayoutBindings.size();
  161. layout_info.pBindings = mLayoutBindings.data();
  162. if (VKFailed(vkCreateDescriptorSetLayout(mDevice, &layout_info, nullptr, &mDescriptorSetLayout), outResult))
  163. return false;
  164. }
  165. // Create pipeline layout
  166. VkPipelineLayoutCreateInfo pl_info = {};
  167. pl_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
  168. pl_info.setLayoutCount = mDescriptorSetLayout != VK_NULL_HANDLE ? 1 : 0;
  169. pl_info.pSetLayouts = mDescriptorSetLayout != VK_NULL_HANDLE ? &mDescriptorSetLayout : nullptr;
  170. if (VKFailed(vkCreatePipelineLayout(mDevice, &pl_info, nullptr, &mPipelineLayout), outResult))
  171. return false;
  172. // Create shader module
  173. VkShaderModuleCreateInfo create_info = {};
  174. create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
  175. create_info.codeSize = inSPVCode.size();
  176. create_info.pCode = spv_words;
  177. if (VKFailed(vkCreateShaderModule(mDevice, &create_info, nullptr, &mShaderModule), outResult))
  178. return false;
  179. // Create compute pipeline
  180. VkComputePipelineCreateInfo pipe_info = {};
  181. pipe_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
  182. pipe_info.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
  183. pipe_info.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
  184. pipe_info.stage.module = mShaderModule;
  185. pipe_info.stage.pName = "main";
  186. pipe_info.layout = mPipelineLayout;
  187. if (VKFailed(vkCreateComputePipelines(mDevice, VK_NULL_HANDLE, 1, &pipe_info, nullptr, &mPipeline), outResult))
  188. return false;
  189. return true;
  190. }
  191. uint32 ComputeShaderVK::NameToBufferInfoIndex(const char *inName) const
  192. {
  193. UnorderedMap<string_view, uint>::const_iterator it = mNameToBufferInfoIndex.find(inName);
  194. JPH_ASSERT(it != mNameToBufferInfoIndex.end());
  195. return it->second;
  196. }
  197. JPH_NAMESPACE_END
  198. #endif // JPH_USE_VK