// ShaderProgramImpl.cpp
// Copyright (C) 2009-2023, Panagiotis Christopoulos Charitos and contributors.
// All rights reserved.
// Code licensed under the BSD License.
// http://www.anki3d.org/LICENSE

#include <AnKi/Gr/Vulkan/ShaderProgramImpl.h>
#include <AnKi/Gr/Shader.h>
#include <AnKi/Gr/Vulkan/ShaderImpl.h>
#include <AnKi/Gr/Vulkan/GrManagerImpl.h>
#include <AnKi/Gr/Vulkan/Pipeline.h>

namespace anki {
  11. ShaderProgramImpl::~ShaderProgramImpl()
  12. {
  13. if(m_graphics.m_pplineFactory)
  14. {
  15. m_graphics.m_pplineFactory->destroy();
  16. deleteInstance(getMemoryPool(), m_graphics.m_pplineFactory);
  17. }
  18. if(m_compute.m_ppline)
  19. {
  20. vkDestroyPipeline(getDevice(), m_compute.m_ppline, nullptr);
  21. }
  22. if(m_rt.m_ppline)
  23. {
  24. vkDestroyPipeline(getDevice(), m_rt.m_ppline, nullptr);
  25. }
  26. m_shaders.destroy(getMemoryPool());
  27. m_rt.m_allHandles.destroy(getMemoryPool());
  28. }
  29. Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
  30. {
  31. ANKI_ASSERT(inf.isValid());
  32. // Create the shader references
  33. //
  34. HashMapRaii<U64, U32> shaderUuidToMShadersIdx(&getMemoryPool()); // Shader UUID to m_shaders idx
  35. if(inf.m_computeShader)
  36. {
  37. m_shaders.emplaceBack(getMemoryPool(), inf.m_computeShader);
  38. }
  39. else if(inf.m_graphicsShaders[ShaderType::kVertex])
  40. {
  41. for(const ShaderPtr& s : inf.m_graphicsShaders)
  42. {
  43. if(s)
  44. {
  45. m_shaders.emplaceBack(getMemoryPool(), s);
  46. }
  47. }
  48. }
  49. else
  50. {
  51. // Ray tracing
  52. m_shaders.resizeStorage(getMemoryPool(), inf.m_rayTracingShaders.m_rayGenShaders.getSize()
  53. + inf.m_rayTracingShaders.m_missShaders.getSize()
  54. + 1); // Plus at least one hit shader
  55. for(const ShaderPtr& s : inf.m_rayTracingShaders.m_rayGenShaders)
  56. {
  57. m_shaders.emplaceBack(getMemoryPool(), s);
  58. }
  59. for(const ShaderPtr& s : inf.m_rayTracingShaders.m_missShaders)
  60. {
  61. m_shaders.emplaceBack(getMemoryPool(), s);
  62. }
  63. m_rt.m_missShaderCount = inf.m_rayTracingShaders.m_missShaders.getSize();
  64. for(const RayTracingHitGroup& group : inf.m_rayTracingShaders.m_hitGroups)
  65. {
  66. if(group.m_anyHitShader)
  67. {
  68. auto it = shaderUuidToMShadersIdx.find(group.m_anyHitShader->getUuid());
  69. if(it == shaderUuidToMShadersIdx.getEnd())
  70. {
  71. shaderUuidToMShadersIdx.emplace(group.m_anyHitShader->getUuid(), m_shaders.getSize());
  72. m_shaders.emplaceBack(getMemoryPool(), group.m_anyHitShader);
  73. }
  74. }
  75. if(group.m_closestHitShader)
  76. {
  77. auto it = shaderUuidToMShadersIdx.find(group.m_closestHitShader->getUuid());
  78. if(it == shaderUuidToMShadersIdx.getEnd())
  79. {
  80. shaderUuidToMShadersIdx.emplace(group.m_closestHitShader->getUuid(), m_shaders.getSize());
  81. m_shaders.emplaceBack(getMemoryPool(), group.m_closestHitShader);
  82. }
  83. }
  84. }
  85. }
  86. ANKI_ASSERT(m_shaders.getSize() > 0);
  87. // Merge bindings
  88. //
  89. Array2d<DescriptorBinding, kMaxDescriptorSets, kMaxBindingsPerDescriptorSet> bindings;
  90. Array<U32, kMaxDescriptorSets> counts = {};
  91. U32 descriptorSetCount = 0;
  92. for(U32 set = 0; set < kMaxDescriptorSets; ++set)
  93. {
  94. for(ShaderPtr& shader : m_shaders)
  95. {
  96. m_stages |= ShaderTypeBit(1 << shader->getShaderType());
  97. const ShaderImpl& simpl = static_cast<const ShaderImpl&>(*shader);
  98. m_refl.m_activeBindingMask[set] |= simpl.m_activeBindingMask[set];
  99. for(U32 i = 0; i < simpl.m_bindings[set].getSize(); ++i)
  100. {
  101. Bool bindingFound = false;
  102. for(U32 j = 0; j < counts[set]; ++j)
  103. {
  104. if(bindings[set][j].m_binding == simpl.m_bindings[set][i].m_binding)
  105. {
  106. // Found the binding
  107. ANKI_ASSERT(bindings[set][j].m_type == simpl.m_bindings[set][i].m_type);
  108. bindings[set][j].m_stageMask |= simpl.m_bindings[set][i].m_stageMask;
  109. bindingFound = true;
  110. break;
  111. }
  112. }
  113. if(!bindingFound)
  114. {
  115. // New binding
  116. bindings[set][counts[set]++] = simpl.m_bindings[set][i];
  117. }
  118. }
  119. if(simpl.m_pushConstantsSize > 0)
  120. {
  121. if(m_refl.m_pushConstantsSize > 0)
  122. {
  123. ANKI_ASSERT(m_refl.m_pushConstantsSize == simpl.m_pushConstantsSize);
  124. }
  125. m_refl.m_pushConstantsSize = max(m_refl.m_pushConstantsSize, simpl.m_pushConstantsSize);
  126. }
  127. }
  128. // We may end up with ppline layouts with "empty" dslayouts. That's fine, we want it.
  129. if(counts[set])
  130. {
  131. descriptorSetCount = set + 1;
  132. }
  133. }
  134. // Create the descriptor set layouts
  135. //
  136. for(U32 set = 0; set < descriptorSetCount; ++set)
  137. {
  138. DescriptorSetLayoutInitInfo dsinf;
  139. dsinf.m_bindings = WeakArray<DescriptorBinding>((counts[set]) ? &bindings[set][0] : nullptr, counts[set]);
  140. ANKI_CHECK(
  141. getGrManagerImpl().getDescriptorSetFactory().newDescriptorSetLayout(dsinf, m_descriptorSetLayouts[set]));
  142. // Even if the dslayout is empty we will have to list it because we'll have to bind a DS for it.
  143. m_refl.m_descriptorSetMask.set(set);
  144. }
  145. // Create the ppline layout
  146. //
  147. WeakArray<DescriptorSetLayout> dsetLayouts((descriptorSetCount) ? &m_descriptorSetLayouts[0] : nullptr,
  148. descriptorSetCount);
  149. ANKI_CHECK(getGrManagerImpl().getPipelineLayoutFactory().newPipelineLayout(dsetLayouts, m_refl.m_pushConstantsSize,
  150. m_pplineLayout));
  151. // Get some masks
  152. //
  153. const Bool graphicsProg = !!(m_stages & ShaderTypeBit::kAllGraphics);
  154. if(graphicsProg)
  155. {
  156. m_refl.m_attributeMask =
  157. static_cast<const ShaderImpl&>(*inf.m_graphicsShaders[ShaderType::kVertex]).m_attributeMask;
  158. m_refl.m_colorAttachmentWritemask =
  159. static_cast<const ShaderImpl&>(*inf.m_graphicsShaders[ShaderType::kFragment]).m_colorAttachmentWritemask;
  160. const U32 attachmentCount = m_refl.m_colorAttachmentWritemask.getEnabledBitCount();
  161. for(U32 i = 0; i < attachmentCount; ++i)
  162. {
  163. ANKI_ASSERT(m_refl.m_colorAttachmentWritemask.get(i) && "Should write to all attachments");
  164. }
  165. }
  166. // Init the create infos
  167. //
  168. if(graphicsProg)
  169. {
  170. for(const ShaderPtr& shader : m_shaders)
  171. {
  172. const ShaderImpl& shaderImpl = static_cast<const ShaderImpl&>(*shader);
  173. VkPipelineShaderStageCreateInfo& createInf =
  174. m_graphics.m_shaderCreateInfos[m_graphics.m_shaderCreateInfoCount++];
  175. createInf = {};
  176. createInf.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
  177. createInf.stage = VkShaderStageFlagBits(convertShaderTypeBit(ShaderTypeBit(1 << shader->getShaderType())));
  178. createInf.pName = "main";
  179. createInf.module = shaderImpl.m_handle;
  180. createInf.pSpecializationInfo = shaderImpl.getSpecConstInfo();
  181. }
  182. }
  183. // Create the factory
  184. //
  185. if(graphicsProg)
  186. {
  187. m_graphics.m_pplineFactory = anki::newInstance<PipelineFactory>(getMemoryPool());
  188. m_graphics.m_pplineFactory->init(&getMemoryPool(), getGrManagerImpl().getDevice(),
  189. getGrManagerImpl().getPipelineCache()
  190. #if ANKI_PLATFORM_MOBILE
  191. ,
  192. getGrManagerImpl().getGlobalCreatePipelineMutex()
  193. #endif
  194. );
  195. }
  196. // Create the pipeline if compute
  197. //
  198. if(!!(m_stages & ShaderTypeBit::kCompute))
  199. {
  200. const ShaderImpl& shaderImpl = static_cast<const ShaderImpl&>(*m_shaders[0]);
  201. VkComputePipelineCreateInfo ci = {};
  202. if(!!(getGrManagerImpl().getExtensions() & VulkanExtensions::kKHR_pipeline_executable_properties))
  203. {
  204. ci.flags |= VK_PIPELINE_CREATE_CAPTURE_STATISTICS_BIT_KHR;
  205. }
  206. ci.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
  207. ci.layout = m_pplineLayout.getHandle();
  208. ci.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
  209. ci.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
  210. ci.stage.pName = "main";
  211. ci.stage.module = shaderImpl.m_handle;
  212. ci.stage.pSpecializationInfo = shaderImpl.getSpecConstInfo();
  213. ANKI_TRACE_SCOPED_EVENT(VkPipelineCreate);
  214. ANKI_VK_CHECK(vkCreateComputePipelines(getDevice(), getGrManagerImpl().getPipelineCache(), 1, &ci, nullptr,
  215. &m_compute.m_ppline));
  216. getGrManagerImpl().printPipelineShaderInfo(m_compute.m_ppline, getName(), ShaderTypeBit::kCompute);
  217. }
  218. // Create the RT pipeline
  219. //
  220. if(!!(m_stages & ShaderTypeBit::kAllRayTracing))
  221. {
  222. // Create shaders
  223. DynamicArrayRaii<VkPipelineShaderStageCreateInfo> stages(m_shaders.getSize(), &getMemoryPool());
  224. for(U32 i = 0; i < stages.getSize(); ++i)
  225. {
  226. const ShaderImpl& impl = static_cast<const ShaderImpl&>(*m_shaders[i]);
  227. VkPipelineShaderStageCreateInfo& stage = stages[i];
  228. stage = {};
  229. stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
  230. stage.stage = VkShaderStageFlagBits(convertShaderTypeBit(ShaderTypeBit(1 << impl.getShaderType())));
  231. stage.pName = "main";
  232. stage.module = impl.m_handle;
  233. stage.pSpecializationInfo = impl.getSpecConstInfo();
  234. }
  235. // Create groups
  236. VkRayTracingShaderGroupCreateInfoKHR defaultGroup = {};
  237. defaultGroup.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR;
  238. defaultGroup.generalShader = VK_SHADER_UNUSED_KHR;
  239. defaultGroup.closestHitShader = VK_SHADER_UNUSED_KHR;
  240. defaultGroup.anyHitShader = VK_SHADER_UNUSED_KHR;
  241. defaultGroup.intersectionShader = VK_SHADER_UNUSED_KHR;
  242. U32 groupCount = inf.m_rayTracingShaders.m_rayGenShaders.getSize()
  243. + inf.m_rayTracingShaders.m_missShaders.getSize()
  244. + inf.m_rayTracingShaders.m_hitGroups.getSize();
  245. DynamicArrayRaii<VkRayTracingShaderGroupCreateInfoKHR> groups(groupCount, defaultGroup, &getMemoryPool());
  246. // 1st group is the ray gen
  247. groupCount = 0;
  248. for(U32 i = 0; i < inf.m_rayTracingShaders.m_rayGenShaders.getSize(); ++i)
  249. {
  250. groups[groupCount].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
  251. groups[groupCount].generalShader = groupCount;
  252. ++groupCount;
  253. }
  254. // Miss
  255. for(U32 i = 0; i < inf.m_rayTracingShaders.m_missShaders.getSize(); ++i)
  256. {
  257. groups[groupCount].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
  258. groups[groupCount].generalShader = groupCount;
  259. ++groupCount;
  260. }
  261. // The rest of the groups are hit
  262. for(U32 i = 0; i < inf.m_rayTracingShaders.m_hitGroups.getSize(); ++i)
  263. {
  264. groups[groupCount].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR;
  265. if(inf.m_rayTracingShaders.m_hitGroups[i].m_anyHitShader)
  266. {
  267. groups[groupCount].anyHitShader =
  268. *shaderUuidToMShadersIdx.find(inf.m_rayTracingShaders.m_hitGroups[i].m_anyHitShader->getUuid());
  269. }
  270. if(inf.m_rayTracingShaders.m_hitGroups[i].m_closestHitShader)
  271. {
  272. groups[groupCount].closestHitShader =
  273. *shaderUuidToMShadersIdx.find(inf.m_rayTracingShaders.m_hitGroups[i].m_closestHitShader->getUuid());
  274. }
  275. ++groupCount;
  276. }
  277. ANKI_ASSERT(groupCount == groups.getSize());
  278. VkRayTracingPipelineCreateInfoKHR ci = {};
  279. ci.sType = VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR;
  280. ci.stageCount = stages.getSize();
  281. ci.pStages = &stages[0];
  282. ci.groupCount = groups.getSize();
  283. ci.pGroups = &groups[0];
  284. ci.maxPipelineRayRecursionDepth = inf.m_rayTracingShaders.m_maxRecursionDepth;
  285. ci.layout = m_pplineLayout.getHandle();
  286. {
  287. ANKI_TRACE_SCOPED_EVENT(VkPipelineCreate);
  288. ANKI_VK_CHECK(vkCreateRayTracingPipelinesKHR(
  289. getDevice(), VK_NULL_HANDLE, getGrManagerImpl().getPipelineCache(), 1, &ci, nullptr, &m_rt.m_ppline));
  290. }
  291. // Get RT handles
  292. const U32 handleArraySize =
  293. getGrManagerImpl().getPhysicalDeviceRayTracingProperties().shaderGroupHandleSize * groupCount;
  294. m_rt.m_allHandles.create(getMemoryPool(), handleArraySize, 0);
  295. ANKI_VK_CHECK(vkGetRayTracingShaderGroupHandlesKHR(getDevice(), m_rt.m_ppline, 0, groupCount, handleArraySize,
  296. &m_rt.m_allHandles[0]));
  297. }
  298. return Error::kNone;
  299. }
} // end namespace anki