// ShaderProgramImpl.cpp
  1. // Copyright (C) 2009-2021, Panagiotis Christopoulos Charitos and contributors.
  2. // All rights reserved.
  3. // Code licensed under the BSD License.
  4. // http://www.anki3d.org/LICENSE
  5. #include <AnKi/Gr/Vulkan/ShaderProgramImpl.h>
  6. #include <AnKi/Gr/Shader.h>
  7. #include <AnKi/Gr/Vulkan/ShaderImpl.h>
  8. #include <AnKi/Gr/Vulkan/GrManagerImpl.h>
  9. #include <AnKi/Gr/Vulkan/Pipeline.h>
  10. namespace anki
  11. {
ShaderProgramImpl::~ShaderProgramImpl()
{
	// Graphics programs own a pipeline factory (graphics pipelines are created lazily elsewhere); release the
	// factory's pipelines first, then free the factory object itself through the same allocator that created it.
	if(m_graphics.m_pplineFactory)
	{
		m_graphics.m_pplineFactory->destroy();
		getAllocator().deleteInstance(m_graphics.m_pplineFactory);
	}

	// Compute and ray tracing programs own a single VkPipeline each.
	if(m_compute.m_ppline)
	{
		vkDestroyPipeline(getDevice(), m_compute.m_ppline, nullptr);
	}

	if(m_rt.m_ppline)
	{
		vkDestroyPipeline(getDevice(), m_rt.m_ppline, nullptr);
	}

	// Drop the shader references and the cached ray tracing shader group handles.
	m_shaders.destroy(getAllocator());
	m_rt.m_allHandles.destroy(getAllocator());
}
Error ShaderProgramImpl::init(const ShaderProgramInitInfo& inf)
{
	ANKI_ASSERT(inf.isValid());

	// Create the shader references
	//
	// A ray tracing hit shader may appear in several hit groups; this map deduplicates by shader UUID so each
	// unique shader occupies exactly one slot (stage index) in m_shaders.
	HashMapAuto<U64, U32> shaderUuidToMShadersIdx(getAllocator()); // Shader UUID to m_shaders idx

	if(inf.m_computeShader)
	{
		// Compute program: a single shader.
		m_shaders.emplaceBack(getAllocator(), inf.m_computeShader);
	}
	else if(inf.m_graphicsShaders[ShaderType::VERTEX])
	{
		// Graphics program: gather whichever stages are present.
		for(const ShaderPtr& s : inf.m_graphicsShaders)
		{
			if(s)
			{
				m_shaders.emplaceBack(getAllocator(), s);
			}
		}
	}
	else
	{
		// Ray tracing
		m_shaders.resizeStorage(getAllocator(), inf.m_rayTracingShaders.m_rayGenShaders.getSize()
												   + inf.m_rayTracingShaders.m_missShaders.getSize()
												   + 1); // Plus at least one hit shader

		// Order matters: ray gen shaders first, then miss, then the deduplicated hit shaders. The RT group
		// creation below depends on this layout to derive the stage indices.
		for(const ShaderPtr& s : inf.m_rayTracingShaders.m_rayGenShaders)
		{
			m_shaders.emplaceBack(getAllocator(), s);
		}

		for(const ShaderPtr& s : inf.m_rayTracingShaders.m_missShaders)
		{
			m_shaders.emplaceBack(getAllocator(), s);
		}

		m_rt.m_missShaderCount = inf.m_rayTracingShaders.m_missShaders.getSize();

		for(const RayTracingHitGroup& group : inf.m_rayTracingShaders.m_hitGroups)
		{
			// Only add a hit shader the first time its UUID is seen; remember its m_shaders index for the groups.
			if(group.m_anyHitShader)
			{
				auto it = shaderUuidToMShadersIdx.find(group.m_anyHitShader->getUuid());
				if(it == shaderUuidToMShadersIdx.getEnd())
				{
					shaderUuidToMShadersIdx.emplace(group.m_anyHitShader->getUuid(), m_shaders.getSize());
					m_shaders.emplaceBack(getAllocator(), group.m_anyHitShader);
				}
			}

			if(group.m_closestHitShader)
			{
				auto it = shaderUuidToMShadersIdx.find(group.m_closestHitShader->getUuid());
				if(it == shaderUuidToMShadersIdx.getEnd())
				{
					shaderUuidToMShadersIdx.emplace(group.m_closestHitShader->getUuid(), m_shaders.getSize());
					m_shaders.emplaceBack(getAllocator(), group.m_closestHitShader);
				}
			}
		}
	}

	ANKI_ASSERT(m_shaders.getSize() > 0);

	// Merge bindings
	//
	// Several stages may declare the same (set, binding) pair; merge them into one DescriptorBinding carrying the
	// union of the stage masks. The descriptor type must agree across stages (asserted below).
	Array2d<DescriptorBinding, MAX_DESCRIPTOR_SETS, MAX_BINDINGS_PER_DESCRIPTOR_SET> bindings;
	Array<U32, MAX_DESCRIPTOR_SETS> counts = {};
	U32 descriptorSetCount = 0;
	for(U32 set = 0; set < MAX_DESCRIPTOR_SETS; ++set)
	{
		for(ShaderPtr& shader : m_shaders)
		{
			// Accumulate the stage mask (redundantly re-ORed per set iteration, but harmless).
			m_stages |= ShaderTypeBit(1 << shader->getShaderType());

			const ShaderImpl& simpl = static_cast<const ShaderImpl&>(*shader);

			m_refl.m_activeBindingMask[set] |= simpl.m_activeBindingMask[set];

			for(U32 i = 0; i < simpl.m_bindings[set].getSize(); ++i)
			{
				Bool bindingFound = false;
				for(U32 j = 0; j < counts[set]; ++j)
				{
					if(bindings[set][j].m_binding == simpl.m_bindings[set][i].m_binding)
					{
						// Found the binding
						ANKI_ASSERT(bindings[set][j].m_type == simpl.m_bindings[set][i].m_type);
						bindings[set][j].m_stageMask |= simpl.m_bindings[set][i].m_stageMask;
						bindingFound = true;
						break;
					}
				}

				if(!bindingFound)
				{
					// New binding
					bindings[set][counts[set]++] = simpl.m_bindings[set][i];
				}
			}

			// Every stage that declares push constants must declare the same block size.
			if(simpl.m_pushConstantsSize > 0)
			{
				if(m_refl.m_pushConstantsSize > 0)
				{
					ANKI_ASSERT(m_refl.m_pushConstantsSize == simpl.m_pushConstantsSize);
				}

				m_refl.m_pushConstantsSize = max(m_refl.m_pushConstantsSize, simpl.m_pushConstantsSize);
			}
		}

		// We may end up with ppline layouts with "empty" dslayouts. That's fine, we want it.
		if(counts[set])
		{
			descriptorSetCount = set + 1;
		}
	}

	// Create the descriptor set layouts
	//
	for(U32 set = 0; set < descriptorSetCount; ++set)
	{
		// NOTE(review): this local `inf` shadows the function parameter `inf`; consider renaming.
		DescriptorSetLayoutInitInfo inf;
		inf.m_bindings = WeakArray<DescriptorBinding>((counts[set]) ? &bindings[set][0] : nullptr, counts[set]);

		ANKI_CHECK(
			getGrManagerImpl().getDescriptorSetFactory().newDescriptorSetLayout(inf, m_descriptorSetLayouts[set]));

		// Even if the dslayout is empty we will have to list it because we'll have to bind a DS for it.
		m_refl.m_descriptorSetMask.set(set);
	}

	// Create the ppline layout
	//
	WeakArray<DescriptorSetLayout> dsetLayouts((descriptorSetCount) ? &m_descriptorSetLayouts[0] : nullptr,
											   descriptorSetCount);
	ANKI_CHECK(getGrManagerImpl().getPipelineLayoutFactory().newPipelineLayout(dsetLayouts, m_refl.m_pushConstantsSize,
																			   m_pplineLayout));

	// Get some masks
	//
	const Bool graphicsProg = !!(m_stages & ShaderTypeBit::ALL_GRAPHICS);
	if(graphicsProg)
	{
		// Attribute mask comes from the vertex shader, color writes from the fragment shader.
		m_refl.m_attributeMask =
			static_cast<const ShaderImpl&>(*inf.m_graphicsShaders[ShaderType::VERTEX]).m_attributeMask;
		m_refl.m_colorAttachmentWritemask =
			static_cast<const ShaderImpl&>(*inf.m_graphicsShaders[ShaderType::FRAGMENT]).m_colorAttachmentWritemask;

		// Sanity check: the written color attachments must be contiguous starting from attachment 0.
		const U32 attachmentCount = m_refl.m_colorAttachmentWritemask.getEnabledBitCount();
		for(U32 i = 0; i < attachmentCount; ++i)
		{
			ANKI_ASSERT(m_refl.m_colorAttachmentWritemask.get(i) && "Should write to all attachments");
		}
	}

	// Init the create infos
	//
	if(graphicsProg)
	{
		// Pre-bake the per-stage create infos; they are reused for every lazily created graphics pipeline.
		for(const ShaderPtr& shader : m_shaders)
		{
			const ShaderImpl& shaderImpl = static_cast<const ShaderImpl&>(*shader);

			// NOTE(review): this local `inf` also shadows the function parameter `inf`.
			VkPipelineShaderStageCreateInfo& inf = m_graphics.m_shaderCreateInfos[m_graphics.m_shaderCreateInfoCount++];
			inf = {};
			inf.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
			inf.stage = VkShaderStageFlagBits(convertShaderTypeBit(ShaderTypeBit(1 << shader->getShaderType())));
			inf.pName = "main";
			inf.module = shaderImpl.m_handle;
			inf.pSpecializationInfo = shaderImpl.getSpecConstInfo();
		}
	}

	// Create the factory
	//
	if(graphicsProg)
	{
		// Graphics pipelines depend on runtime render state, so they are created on demand through this factory.
		m_graphics.m_pplineFactory = getAllocator().newInstance<PipelineFactory>();
		m_graphics.m_pplineFactory->init(getGrManagerImpl().getAllocator(), getGrManagerImpl().getDevice(),
										 getGrManagerImpl().getPipelineCache());
	}

	// Create the pipeline if compute
	//
	if(!!(m_stages & ShaderTypeBit::COMPUTE))
	{
		// Compute programs have exactly one shader, stored at m_shaders[0].
		const ShaderImpl& shaderImpl = static_cast<const ShaderImpl&>(*m_shaders[0]);

		VkComputePipelineCreateInfo ci = {};
		ci.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
		ci.layout = m_pplineLayout.getHandle();
		ci.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
		ci.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
		ci.stage.pName = "main";
		ci.stage.module = shaderImpl.m_handle;
		ci.stage.pSpecializationInfo = shaderImpl.getSpecConstInfo();

		ANKI_TRACE_SCOPED_EVENT(VK_PIPELINE_CREATE);
		ANKI_VK_CHECK(vkCreateComputePipelines(getDevice(), getGrManagerImpl().getPipelineCache(), 1, &ci, nullptr,
											   &m_compute.m_ppline));

		getGrManagerImpl().printPipelineShaderInfo(m_compute.m_ppline, getName(), ShaderTypeBit::COMPUTE);
	}

	// Create the RT pipeline
	//
	if(!!(m_stages & ShaderTypeBit::ALL_RAY_TRACING))
	{
		// Create shaders
		// One VkPipelineShaderStageCreateInfo per entry of m_shaders, in the same order.
		DynamicArrayAuto<VkPipelineShaderStageCreateInfo> stages(getAllocator(), m_shaders.getSize());
		for(U32 i = 0; i < stages.getSize(); ++i)
		{
			const ShaderImpl& impl = static_cast<const ShaderImpl&>(*m_shaders[i]);

			VkPipelineShaderStageCreateInfo& stage = stages[i];
			stage = {};
			stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
			stage.stage = VkShaderStageFlagBits(convertShaderTypeBit(ShaderTypeBit(1 << impl.getShaderType())));
			stage.pName = "main";
			stage.module = impl.m_handle;
			stage.pSpecializationInfo = impl.getSpecConstInfo();
		}

		// Create groups
		// Template with all shader slots marked unused; each group below overrides only what it needs.
		VkRayTracingShaderGroupCreateInfoKHR defaultGroup = {};
		defaultGroup.sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR;
		defaultGroup.generalShader = VK_SHADER_UNUSED_KHR;
		defaultGroup.closestHitShader = VK_SHADER_UNUSED_KHR;
		defaultGroup.anyHitShader = VK_SHADER_UNUSED_KHR;
		defaultGroup.intersectionShader = VK_SHADER_UNUSED_KHR;

		U32 groupCount = inf.m_rayTracingShaders.m_rayGenShaders.getSize()
						 + inf.m_rayTracingShaders.m_missShaders.getSize()
						 + inf.m_rayTracingShaders.m_hitGroups.getSize();
		DynamicArrayAuto<VkRayTracingShaderGroupCreateInfoKHR> groups(getAllocator(), groupCount, defaultGroup);

		// 1st group is the ray gen
		// Because ray gen and miss shaders were pushed to m_shaders first and in this same order, the stage index
		// of each general shader equals the running group index. groupCount is reused as that running index.
		groupCount = 0;
		for(U32 i = 0; i < inf.m_rayTracingShaders.m_rayGenShaders.getSize(); ++i)
		{
			groups[groupCount].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
			groups[groupCount].generalShader = groupCount;
			++groupCount;
		}

		// Miss
		for(U32 i = 0; i < inf.m_rayTracingShaders.m_missShaders.getSize(); ++i)
		{
			groups[groupCount].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
			groups[groupCount].generalShader = groupCount;
			++groupCount;
		}

		// The rest of the groups are hit
		for(U32 i = 0; i < inf.m_rayTracingShaders.m_hitGroups.getSize(); ++i)
		{
			groups[groupCount].type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR;

			// Hit shader stage indices come from the UUID map populated during shader gathering above.
			if(inf.m_rayTracingShaders.m_hitGroups[i].m_anyHitShader)
			{
				groups[groupCount].anyHitShader =
					*shaderUuidToMShadersIdx.find(inf.m_rayTracingShaders.m_hitGroups[i].m_anyHitShader->getUuid());
			}

			if(inf.m_rayTracingShaders.m_hitGroups[i].m_closestHitShader)
			{
				groups[groupCount].closestHitShader =
					*shaderUuidToMShadersIdx.find(inf.m_rayTracingShaders.m_hitGroups[i].m_closestHitShader->getUuid());
			}

			++groupCount;
		}

		ANKI_ASSERT(groupCount == groups.getSize());

		VkRayTracingPipelineCreateInfoKHR ci = {};
		ci.sType = VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR;
		ci.stageCount = stages.getSize();
		ci.pStages = &stages[0];
		ci.groupCount = groups.getSize();
		ci.pGroups = &groups[0];
		ci.maxPipelineRayRecursionDepth = inf.m_rayTracingShaders.m_maxRecursionDepth;
		ci.layout = m_pplineLayout.getHandle();

		{
			ANKI_TRACE_SCOPED_EVENT(VK_PIPELINE_CREATE);
			ANKI_VK_CHECK(vkCreateRayTracingPipelinesKHR(
				getDevice(), VK_NULL_HANDLE, getGrManagerImpl().getPipelineCache(), 1, &ci, nullptr, &m_rt.m_ppline));
		}

		// Get RT handles
		// Cache all shader group handles so shader binding tables can be built later without querying the driver.
		const U32 handleArraySize =
			getGrManagerImpl().getPhysicalDeviceRayTracingProperties().shaderGroupHandleSize * groupCount;
		m_rt.m_allHandles.create(getAllocator(), handleArraySize, 0);
		ANKI_VK_CHECK(vkGetRayTracingShaderGroupHandlesKHR(getDevice(), m_rt.m_ppline, 0, groupCount, handleArraySize,
														   &m_rt.m_allHandles[0]));
	}

	return Error::NONE;
}
  291. } // end namespace anki