ShaderProgramImpl.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. // Copyright (C) 2009-2023, Panagiotis Christopoulos Charitos and contributors.
  2. // All rights reserved.
  3. // Code licensed under the BSD License.
  4. // http://www.anki3d.org/LICENSE
  5. #include <AnKi/Gr/Vulkan/ShaderProgramImpl.h>
  6. #include <AnKi/Gr/Shader.h>
  7. #include <AnKi/Gr/Vulkan/ShaderImpl.h>
  8. #include <AnKi/Gr/Vulkan/GrManagerImpl.h>
  9. #include <AnKi/Gr/Vulkan/Pipeline.h>
  10. namespace anki {
  11. ShaderProgramImpl::~ShaderProgramImpl()
  12. {
  13. if(m_graphics.m_pplineFactory)
  14. {
  15. m_graphics.m_pplineFactory->destroy();
  16. deleteInstance(GrMemoryPool::getSingleton(), m_graphics.m_pplineFactory);
  17. }
  18. if(m_compute.m_ppline)
  19. {
  20. vkDestroyPipeline(getVkDevice(), m_compute.m_ppline, nullptr);
  21. }
  22. if(m_rt.m_ppline)
  23. {
  24. vkDestroyPipeline(getVkDevice(), m_rt.m_ppline, nullptr);
  25. }
  26. }
  27. Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
  28. {
  29. ANKI_ASSERT(inf.isValid());
  30. // Create the shader references
  31. //
  32. GrHashMap<U64, U32> shaderUuidToMShadersIdx; // Shader UUID to m_shaders idx
  33. if(inf.m_computeShader)
  34. {
  35. m_shaders.emplaceBack(inf.m_computeShader);
  36. }
  37. else if(inf.m_graphicsShaders[ShaderType::kVertex])
  38. {
  39. for(const ShaderPtr& s : inf.m_graphicsShaders)
  40. {
  41. if(s)
  42. {
  43. m_shaders.emplaceBack(s);
  44. }
  45. }
  46. }
  47. else
  48. {
  49. // Ray tracing
  50. m_shaders.resizeStorage(inf.m_rayTracingShaders.m_rayGenShaders.getSize() + inf.m_rayTracingShaders.m_missShaders.getSize()
  51. + 1); // Plus at least one hit shader
  52. for(const ShaderPtr& s : inf.m_rayTracingShaders.m_rayGenShaders)
  53. {
  54. m_shaders.emplaceBack(s);
  55. }
  56. for(const ShaderPtr& s : inf.m_rayTracingShaders.m_missShaders)
  57. {
  58. m_shaders.emplaceBack(s);
  59. }
  60. m_rt.m_missShaderCount = inf.m_rayTracingShaders.m_missShaders.getSize();
  61. for(const RayTracingHitGroup& group : inf.m_rayTracingShaders.m_hitGroups)
  62. {
  63. if(group.m_anyHitShader)
  64. {
  65. auto it = shaderUuidToMShadersIdx.find(group.m_anyHitShader->getUuid());
  66. if(it == shaderUuidToMShadersIdx.getEnd())
  67. {
  68. shaderUuidToMShadersIdx.emplace(group.m_anyHitShader->getUuid(), m_shaders.getSize());
  69. m_shaders.emplaceBack(group.m_anyHitShader);
  70. }
  71. }
  72. if(group.m_closestHitShader)
  73. {
  74. auto it = shaderUuidToMShadersIdx.find(group.m_closestHitShader->getUuid());
  75. if(it == shaderUuidToMShadersIdx.getEnd())
  76. {
  77. shaderUuidToMShadersIdx.emplace(group.m_closestHitShader->getUuid(), m_shaders.getSize());
  78. m_shaders.emplaceBack(group.m_closestHitShader);
  79. }
  80. }
  81. }
  82. }
  83. ANKI_ASSERT(m_shaders.getSize() > 0);
  84. // Merge bindings
  85. //
  86. Array2d<DescriptorBinding, kMaxDescriptorSets, kMaxBindingsPerDescriptorSet> bindings;
  87. Array<U32, kMaxDescriptorSets> counts = {};
  88. U32 descriptorSetCount = 0;
  89. for(U32 set = 0; set < kMaxDescriptorSets; ++set)
  90. {
  91. for(ShaderPtr& shader : m_shaders)
  92. {
  93. m_stages |= ShaderTypeBit(1 << shader->getShaderType());
  94. const ShaderImpl& simpl = static_cast<const ShaderImpl&>(*shader);
  95. m_refl.m_activeBindingMask[set] |= simpl.m_activeBindingMask[set];
  96. for(U32 i = 0; i < simpl.m_bindings[set].getSize(); ++i)
  97. {
  98. Bool bindingFound = false;
  99. for(U32 j = 0; j < counts[set]; ++j)
  100. {
  101. if(bindings[set][j].m_binding == simpl.m_bindings[set][i].m_binding)
  102. {
  103. // Found the binding
  104. ANKI_ASSERT(bindings[set][j].m_type == simpl.m_bindings[set][i].m_type);
  105. bindings[set][j].m_stageMask |= simpl.m_bindings[set][i].m_stageMask;
  106. bindingFound = true;
  107. break;
  108. }
  109. }
  110. if(!bindingFound)
  111. {
  112. // New binding
  113. bindings[set][counts[set]++] = simpl.m_bindings[set][i];
  114. }
  115. }
  116. if(simpl.m_pushConstantsSize > 0)
  117. {
  118. if(m_refl.m_pushConstantsSize > 0)
  119. {
  120. ANKI_ASSERT(m_refl.m_pushConstantsSize == simpl.m_pushConstantsSize);
  121. }
  122. m_refl.m_pushConstantsSize = max(m_refl.m_pushConstantsSize, simpl.m_pushConstantsSize);
  123. }
  124. }
  125. // We may end up with ppline layouts with "empty" dslayouts. That's fine, we want it.
  126. if(counts[set])
  127. {
  128. descriptorSetCount = set + 1;
  129. }
  130. }
  131. // Create the descriptor set layouts
  132. //
  133. for(U32 set = 0; set < descriptorSetCount; ++set)
  134. {
  135. DescriptorSetLayoutInitInfo dsinf;
  136. dsinf.m_bindings = WeakArray<DescriptorBinding>((counts[set]) ? &bindings[set][0] : nullptr, counts[set]);
  137. ANKI_CHECK(getGrManagerImpl().getDescriptorSetFactory().newDescriptorSetLayout(dsinf, m_descriptorSetLayouts[set]));
  138. // Even if the dslayout is empty we will have to list it because we'll have to bind a DS for it.
  139. m_refl.m_descriptorSetMask.set(set);
  140. }
  141. // Create the ppline layout
  142. //
  143. WeakArray<DescriptorSetLayout> dsetLayouts((descriptorSetCount) ? &m_descriptorSetLayouts[0] : nullptr, descriptorSetCount);
  144. ANKI_CHECK(getGrManagerImpl().getPipelineLayoutFactory().newPipelineLayout(dsetLayouts, m_refl.m_pushConstantsSize, m_pplineLayout));
  145. // Get some masks
  146. //
  147. const Bool graphicsProg = !!(m_stages & ShaderTypeBit::kAllGraphics);
  148. if(graphicsProg)
  149. {
  150. m_refl.m_attributeMask = static_cast<const ShaderImpl&>(*inf.m_graphicsShaders[ShaderType::kVertex]).m_attributeMask;
  151. m_refl.m_colorAttachmentWritemask = static_cast<const ShaderImpl&>(*inf.m_graphicsShaders[ShaderType::kFragment]).m_colorAttachmentWritemask;
  152. const U32 attachmentCount = m_refl.m_colorAttachmentWritemask.getEnabledBitCount();
  153. for(U32 i = 0; i < attachmentCount; ++i)
  154. {
  155. ANKI_ASSERT(m_refl.m_colorAttachmentWritemask.get(i) && "Should write to all attachments");
  156. }
  157. }
  158. // Init the create infos
  159. //
  160. if(graphicsProg)
  161. {
  162. for(const ShaderPtr& shader : m_shaders)
  163. {
  164. const ShaderImpl& shaderImpl = static_cast<const ShaderImpl&>(*shader);
  165. VkPipelineShaderStageCreateInfo& createInf = m_graphics.m_shaderCreateInfos[m_graphics.m_shaderCreateInfoCount++];
  166. createInf = {};
  167. createInf.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
  168. createInf.stage = VkShaderStageFlagBits(convertShaderTypeBit(ShaderTypeBit(1 << shader->getShaderType())));
  169. createInf.pName = "main";
  170. createInf.module = shaderImpl.m_handle;
  171. createInf.pSpecializationInfo = shaderImpl.getSpecConstInfo();
  172. }
  173. }
  174. // Create the factory
  175. //
  176. if(graphicsProg)
  177. {
  178. m_graphics.m_pplineFactory = anki::newInstance<PipelineFactory>(GrMemoryPool::getSingleton());
  179. m_graphics.m_pplineFactory->init(getGrManagerImpl().getPipelineCache()
  180. #if ANKI_PLATFORM_MOBILE
  181. ,
  182. getGrManagerImpl().getGlobalCreatePipelineMutex()
  183. #endif
  184. );
  185. }
  186. // Create the pipeline if compute
  187. //
  188. if(!!(m_stages & ShaderTypeBit::kCompute))
  189. {
  190. const ShaderImpl& shaderImpl = static_cast<const ShaderImpl&>(*m_shaders[0]);
  191. VkComputePipelineCreateInfo ci = {};
  192. if(!!(getGrManagerImpl().getExtensions() & VulkanExtensions::kKHR_pipeline_executable_properties))
  193. {
  194. ci.flags |= VK_PIPELINE_CREATE_CAPTURE_STATISTICS_BIT_KHR;
  195. }
  196. ci.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
  197. ci.layout = m_pplineLayout.getHandle();
  198. ci.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
  199. ci.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
  200. ci.stage.pName = "main";
  201. ci.stage.module = shaderImpl.m_handle;
  202. ci.stage.pSpecializationInfo = shaderImpl.getSpecConstInfo();
  203. ANKI_TRACE_SCOPED_EVENT(VkPipelineCreate);
  204. ANKI_VK_CHECK(vkCreateComputePipelines(getVkDevice(), getGrManagerImpl().getPipelineCache(), 1, &ci, nullptr, &m_compute.m_ppline));
  205. getGrManagerImpl().printPipelineShaderInfo(m_compute.m_ppline, getName(), ShaderTypeBit::kCompute);
  206. }
  207. // Create the RT pipeline
  208. //
  209. if(!!(m_stages & ShaderTypeBit::kAllRayTracing))
  210. {
  211. // Create shaders
  212. GrDynamicArray<VkPipelineShaderStageCreateInfo> stages;
  213. stages.resize(m_shaders.getSize());
  214. for(U32 i = 0; i < stages.getSize(); ++i)
  215. {
  216. const ShaderImpl& impl = static_cast<const ShaderImpl&>(*m_shaders[i]);
  217. VkPipelineShaderStageCreateInfo& stage = stages[i];
  218. stage = {};
  219. stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
  220. stage.stage = VkShaderStageFlagBits(convertShaderTypeBit(ShaderTypeBit(1 << impl.getShaderType())));
  221. stage.pName = "main";
  222. stage.module = impl.m_handle;
  223. stage.pSpecializationInfo = impl.getSpecConstInfo();
  224. }
  225. // Create groups
  226. VkRayTracingShaderGroupCreateInfoKHR defaultGroup = {};
  227. defaultGroup.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR;
  228. defaultGroup.generalShader = VK_SHADER_UNUSED_KHR;
  229. defaultGroup.closestHitShader = VK_SHADER_UNUSED_KHR;
  230. defaultGroup.anyHitShader = VK_SHADER_UNUSED_KHR;
  231. defaultGroup.intersectionShader = VK_SHADER_UNUSED_KHR;
  232. U32 groupCount = inf.m_rayTracingShaders.m_rayGenShaders.getSize() + inf.m_rayTracingShaders.m_missShaders.getSize()
  233. + inf.m_rayTracingShaders.m_hitGroups.getSize();
  234. GrDynamicArray<VkRayTracingShaderGroupCreateInfoKHR> groups;
  235. groups.resize(groupCount, defaultGroup);
  236. // 1st group is the ray gen
  237. groupCount = 0;
  238. for(U32 i = 0; i < inf.m_rayTracingShaders.m_rayGenShaders.getSize(); ++i)
  239. {
  240. groups[groupCount].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
  241. groups[groupCount].generalShader = groupCount;
  242. ++groupCount;
  243. }
  244. // Miss
  245. for(U32 i = 0; i < inf.m_rayTracingShaders.m_missShaders.getSize(); ++i)
  246. {
  247. groups[groupCount].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
  248. groups[groupCount].generalShader = groupCount;
  249. ++groupCount;
  250. }
  251. // The rest of the groups are hit
  252. for(U32 i = 0; i < inf.m_rayTracingShaders.m_hitGroups.getSize(); ++i)
  253. {
  254. groups[groupCount].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR;
  255. if(inf.m_rayTracingShaders.m_hitGroups[i].m_anyHitShader)
  256. {
  257. groups[groupCount].anyHitShader = *shaderUuidToMShadersIdx.find(inf.m_rayTracingShaders.m_hitGroups[i].m_anyHitShader->getUuid());
  258. }
  259. if(inf.m_rayTracingShaders.m_hitGroups[i].m_closestHitShader)
  260. {
  261. groups[groupCount].closestHitShader =
  262. *shaderUuidToMShadersIdx.find(inf.m_rayTracingShaders.m_hitGroups[i].m_closestHitShader->getUuid());
  263. }
  264. ++groupCount;
  265. }
  266. ANKI_ASSERT(groupCount == groups.getSize());
  267. VkRayTracingPipelineCreateInfoKHR ci = {};
  268. ci.sType = VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR;
  269. ci.stageCount = stages.getSize();
  270. ci.pStages = &stages[0];
  271. ci.groupCount = groups.getSize();
  272. ci.pGroups = &groups[0];
  273. ci.maxPipelineRayRecursionDepth = inf.m_rayTracingShaders.m_maxRecursionDepth;
  274. ci.layout = m_pplineLayout.getHandle();
  275. {
  276. ANKI_TRACE_SCOPED_EVENT(VkPipelineCreate);
  277. ANKI_VK_CHECK(vkCreateRayTracingPipelinesKHR(getVkDevice(), VK_NULL_HANDLE, getGrManagerImpl().getPipelineCache(), 1, &ci, nullptr,
  278. &m_rt.m_ppline));
  279. }
  280. // Get RT handles
  281. const U32 handleArraySize = getGrManagerImpl().getPhysicalDeviceRayTracingProperties().shaderGroupHandleSize * groupCount;
  282. m_rt.m_allHandles.resize(handleArraySize, 0_U8);
  283. ANKI_VK_CHECK(vkGetRayTracingShaderGroupHandlesKHR(getVkDevice(), m_rt.m_ppline, 0, groupCount, handleArraySize, &m_rt.m_allHandles[0]));
  284. }
  285. return Error::kNone;
  286. }
  287. } // end namespace anki