3
0

RayTracingPass.cpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <AzCore/Asset/AssetCommon.h>
  9. #include <AzCore/Asset/AssetManagerBus.h>
  10. #include <Atom/RHI/CommandList.h>
  11. #include <Atom/RHI/Factory.h>
  12. #include <Atom/RHI/FrameScheduler.h>
  13. #include <Atom/RHI/DispatchRaysItem.h>
  14. #include <Atom/RHI/RHISystemInterface.h>
  15. #include <Atom/RHI/PipelineState.h>
  16. #include <Atom/RPI.Reflect/Pass/PassTemplate.h>
  17. #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
  18. #include <Atom/RPI.Public/Base.h>
  19. #include <Atom/RPI.Public/Pass/PassUtils.h>
  20. #include <Atom/RPI.Public/RPIUtils.h>
  21. #include <Atom/RPI.Public/RenderPipeline.h>
  22. #include <Atom/RPI.Public/Scene.h>
  23. #include <Atom/RPI.Public/View.h>
  24. #include <RayTracing/RayTracingPass.h>
  25. #include <RayTracing/RayTracingPassData.h>
  26. #include <RayTracing/RayTracingFeatureProcessor.h>
  27. namespace AZ
  28. {
  29. namespace Render
  30. {
  31. RPI::Ptr<RayTracingPass> RayTracingPass::Create(const RPI::PassDescriptor& descriptor)
  32. {
  33. RPI::Ptr<RayTracingPass> pass = aznew RayTracingPass(descriptor);
  34. return pass;
  35. }
  36. RayTracingPass::RayTracingPass(const RPI::PassDescriptor& descriptor)
  37. : RenderPass(descriptor)
  38. , m_passDescriptor(descriptor)
  39. {
  40. RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
  41. if (device->GetFeatures().m_rayTracing == false)
  42. {
  43. // raytracing is not supported on this platform
  44. SetEnabled(false);
  45. return;
  46. }
  47. m_passData = RPI::PassUtils::GetPassData<RayTracingPassData>(m_passDescriptor);
  48. if (m_passData == nullptr)
  49. {
  50. AZ_Error("PassSystem", false, "RayTracingPass [%s]: Invalid RayTracingPassData", GetPathName().GetCStr());
  51. return;
  52. }
  53. CreatePipelineState();
  54. }
  55. RayTracingPass::~RayTracingPass()
  56. {
  57. RPI::ShaderReloadNotificationBus::MultiHandler::BusDisconnect();
  58. }
  59. void RayTracingPass::CreatePipelineState()
  60. {
  61. RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
  62. struct RTShaderLib
  63. {
  64. AZ::Data::AssetId m_shaderAssetId;
  65. AZ::Data::Instance<AZ::RPI::Shader> m_shader;
  66. AZ::RHI::PipelineStateDescriptorForRayTracing m_pipelineStateDescriptor;
  67. AZ::Name m_rayGenerationShaderName;
  68. AZ::Name m_missShaderName;
  69. AZ::Name m_closestHitShaderName;
  70. AZ::Name m_closestHitProceduralShaderName;
  71. };
  72. AZStd::fixed_vector<RTShaderLib, 4> shaderLibs;
  73. auto loadRayTracingShader = [&](auto& assetReference) -> RTShaderLib&
  74. {
  75. auto it = std::find_if(
  76. shaderLibs.begin(),
  77. shaderLibs.end(),
  78. [&](auto& entry)
  79. {
  80. return entry.m_shaderAssetId == assetReference.m_assetId;
  81. });
  82. if (it != shaderLibs.end())
  83. {
  84. return *it;
  85. }
  86. auto shaderAsset{ AZ::RPI::FindShaderAsset(assetReference.m_assetId, assetReference.m_filePath) };
  87. AZ_Assert(shaderAsset.IsReady(), "Failed to load shader %s", assetReference.m_filePath.c_str());
  88. auto shader{ AZ::RPI::Shader::FindOrCreate(shaderAsset) };
  89. auto shaderVariant{ shader->GetVariant(AZ::RPI::ShaderAsset::RootShaderVariantStableId) };
  90. AZ::RHI::PipelineStateDescriptorForRayTracing pipelineStateDescriptor;
  91. shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
  92. auto& shaderLib = shaderLibs.emplace_back();
  93. shaderLib.m_shaderAssetId = assetReference.m_assetId;
  94. shaderLib.m_shader = shader;
  95. shaderLib.m_pipelineStateDescriptor = pipelineStateDescriptor;
  96. return shaderLib;
  97. };
  98. auto& rayGenShaderLib{ loadRayTracingShader(m_passData->m_rayGenerationShaderAssetReference) };
  99. rayGenShaderLib.m_rayGenerationShaderName = m_passData->m_rayGenerationShaderName;
  100. m_rayGenerationShader = rayGenShaderLib.m_shader;
  101. auto& closestHitShaderLib{ loadRayTracingShader(m_passData->m_closestHitShaderAssetReference) };
  102. closestHitShaderLib.m_closestHitShaderName = m_passData->m_closestHitShaderName;
  103. m_closestHitShader = closestHitShaderLib.m_shader;
  104. if (!m_passData->m_closestHitProceduralShaderName.empty())
  105. {
  106. auto& closestHitProceduralShaderLib{ loadRayTracingShader(m_passData->m_closestHitProceduralShaderAssetReference) };
  107. closestHitProceduralShaderLib.m_closestHitProceduralShaderName = m_passData->m_closestHitProceduralShaderName;
  108. m_closestHitProceduralShader = closestHitProceduralShaderLib.m_shader;
  109. }
  110. auto& missShaderLib{ loadRayTracingShader(m_passData->m_missShaderAssetReference) };
  111. missShaderLib.m_missShaderName = m_passData->m_missShaderName;
  112. m_missShader = missShaderLib.m_shader;
  113. m_globalPipelineState = m_rayGenerationShader->AcquirePipelineState(shaderLibs.front().m_pipelineStateDescriptor);
  114. AZ_Assert(m_globalPipelineState, "Failed to acquire ray tracing global pipeline state");
  115. // create global srg
  116. const auto& globalSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RayTracingGlobalSrgBindingSlot);
  117. AZ_Error("PassSystem", globalSrgLayout != nullptr, "RayTracingPass [%s] Failed to find RayTracingGlobalSrg layout", GetPathName().GetCStr());
  118. m_shaderResourceGroup = RPI::ShaderResourceGroup::Create( m_rayGenerationShader->GetAsset(), m_rayGenerationShader->GetSupervariantIndex(), globalSrgLayout->GetName());
  119. AZ_Assert(m_shaderResourceGroup, "RayTracingPass [%s]: Failed to create RayTracingGlobalSrg", GetPathName().GetCStr());
  120. RPI::PassUtils::BindDataMappingsToSrg(m_passDescriptor, m_shaderResourceGroup.get());
  121. // check to see if the shader requires the View, Scene, or RayTracingMaterial Srgs
  122. const auto& viewSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::View);
  123. m_requiresViewSrg = (viewSrgLayout != nullptr);
  124. const auto& sceneSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::Scene);
  125. m_requiresSceneSrg = (sceneSrgLayout != nullptr);
  126. const auto& rayTracingMaterialSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RayTracingMaterialSrgBindingSlot);
  127. m_requiresRayTracingMaterialSrg = (rayTracingMaterialSrgLayout != nullptr);
  128. const auto& rayTracingSceneSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RayTracingSceneSrgBindingSlot);
  129. m_requiresRayTracingSceneSrg = (rayTracingSceneSrgLayout != nullptr);
  130. // build the ray tracing pipeline state descriptor
  131. RHI::RayTracingPipelineStateDescriptor descriptor;
  132. descriptor.Build()
  133. ->PipelineState(m_globalPipelineState.get())
  134. ->MaxPayloadSize(m_passData->m_maxPayloadSize)
  135. ->MaxAttributeSize(m_passData->m_maxAttributeSize)
  136. ->MaxRecursionDepth(m_passData->m_maxRecursionDepth);
  137. for (auto& shaderLib : shaderLibs)
  138. {
  139. descriptor.ShaderLibrary(shaderLib.m_pipelineStateDescriptor);
  140. if (!shaderLib.m_rayGenerationShaderName.IsEmpty())
  141. {
  142. descriptor.RayGenerationShaderName(AZ::Name{ m_passData->m_rayGenerationShaderName });
  143. }
  144. if (!shaderLib.m_closestHitShaderName.IsEmpty())
  145. {
  146. descriptor.ClosestHitShaderName(AZ::Name{ m_passData->m_closestHitShaderName });
  147. }
  148. if (!shaderLib.m_closestHitProceduralShaderName.IsEmpty())
  149. {
  150. descriptor.ClosestHitShaderName(AZ::Name{ m_passData->m_closestHitProceduralShaderName });
  151. }
  152. if (!shaderLib.m_missShaderName.IsEmpty())
  153. {
  154. descriptor.MissShaderName(AZ::Name{ m_passData->m_missShaderName });
  155. }
  156. }
  157. descriptor.HitGroup(AZ::Name("HitGroup"))->ClosestHitShaderName(AZ::Name(m_passData->m_closestHitShaderName.c_str()));
  158. RayTracingFeatureProcessor* rayTracingFeatureProcessor =
  159. GetScene() ? GetScene()->GetFeatureProcessor<RayTracingFeatureProcessor>() : nullptr;
  160. if (rayTracingFeatureProcessor && !m_passData->m_closestHitProceduralShaderName.empty())
  161. {
  162. const auto& proceduralGeometryTypes = rayTracingFeatureProcessor->GetProceduralGeometryTypes();
  163. for (auto it = proceduralGeometryTypes.cbegin(); it != proceduralGeometryTypes.cend(); ++it)
  164. {
  165. auto shaderVariant{ it->m_intersectionShader->GetVariant(AZ::RPI::ShaderAsset::RootShaderVariantStableId) };
  166. AZ::RHI::PipelineStateDescriptorForRayTracing pipelineStateDescriptor;
  167. shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
  168. descriptor.ShaderLibrary(pipelineStateDescriptor);
  169. descriptor.IntersectionShaderName(it->m_intersectionShaderName);
  170. descriptor.HitGroup(it->m_name)
  171. ->ClosestHitShaderName(AZ::Name(m_passData->m_closestHitProceduralShaderName))
  172. ->IntersectionShaderName(it->m_intersectionShaderName);
  173. }
  174. }
  175. // create the ray tracing pipeline state object
  176. m_rayTracingPipelineState = RHI::Factory::Get().CreateRayTracingPipelineState();
  177. m_rayTracingPipelineState->Init(*device.get(), &descriptor);
  178. // make sure the shader table rebuilds if we're hotreloading
  179. m_rayTracingRevision = 0;
  180. // store the max ray length
  181. m_maxRayLength = m_passData->m_maxRayLength;
  182. RPI::ShaderReloadNotificationBus::MultiHandler::BusDisconnect();
  183. RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_rayGenerationShaderAssetReference.m_assetId);
  184. RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_closestHitShaderAssetReference.m_assetId);
  185. RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_closestHitProceduralShaderAssetReference.m_assetId);
  186. RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_missShaderAssetReference.m_assetId);
  187. RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_intersectionShaderAssetReference.m_assetId);
  188. }
  189. Data::Instance<RPI::Shader> RayTracingPass::LoadShader(const RPI::AssetReference& shaderAssetReference)
  190. {
  191. Data::Asset<RPI::ShaderAsset> shaderAsset;
  192. if (shaderAssetReference.m_assetId.IsValid())
  193. {
  194. shaderAsset = RPI::FindShaderAsset(shaderAssetReference.m_assetId, shaderAssetReference.m_filePath);
  195. }
  196. if (!shaderAsset.IsReady())
  197. {
  198. AZ_Error("PassSystem", false, "RayTracingPass [%s]: Failed to load shader asset [%s]", GetPathName().GetCStr(), shaderAssetReference.m_filePath.data());
  199. return nullptr;
  200. }
  201. return RPI::Shader::FindOrCreate(shaderAsset);
  202. }
  203. bool RayTracingPass::IsEnabled() const
  204. {
  205. if (!RenderPass::IsEnabled())
  206. {
  207. return false;
  208. }
  209. RPI::Scene* scene = m_pipeline->GetScene();
  210. if (!scene)
  211. {
  212. return false;
  213. }
  214. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  215. if (!rayTracingFeatureProcessor)
  216. {
  217. return false;
  218. }
  219. return true;
  220. }
  221. void RayTracingPass::FrameBeginInternal(FramePrepareParams params)
  222. {
  223. RPI::Scene* scene = m_pipeline->GetScene();
  224. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  225. if (!rayTracingFeatureProcessor)
  226. {
  227. return;
  228. }
  229. if (!m_rayTracingShaderTable)
  230. {
  231. RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
  232. RHI::RayTracingBufferPools& rayTracingBufferPools = rayTracingFeatureProcessor->GetBufferPools();
  233. m_rayTracingShaderTable = RHI::Factory::Get().CreateRayTracingShaderTable();
  234. m_rayTracingShaderTable->Init(*device.get(), rayTracingBufferPools);
  235. }
  236. RPI::RenderPass::FrameBeginInternal(params);
  237. }
  238. void RayTracingPass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
  239. {
  240. RPI::Scene* scene = m_pipeline->GetScene();
  241. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  242. AZ_Assert(rayTracingFeatureProcessor, "RayTracingPass requires the RayTracingFeatureProcessor");
  243. RPI::RenderPass::SetupFrameGraphDependencies(frameGraph);
  244. frameGraph.SetEstimatedItemCount(1);
  245. // TLAS
  246. {
  247. const RHI::Ptr<RHI::Buffer>& rayTracingTlasBuffer = rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer();
  248. if (rayTracingTlasBuffer)
  249. {
  250. AZ::RHI::AttachmentId tlasAttachmentId = rayTracingFeatureProcessor->GetTlasAttachmentId();
  251. if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(tlasAttachmentId) == false)
  252. {
  253. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(tlasAttachmentId, rayTracingTlasBuffer);
  254. AZ_Assert(result == RHI::ResultCode::Success, "Failed to import ray tracing TLAS buffer with error %d", result);
  255. }
  256. uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer()->GetDescriptor().m_byteCount);
  257. RHI::BufferViewDescriptor tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRaw(0, tlasBufferByteCount);
  258. RHI::BufferScopeAttachmentDescriptor desc;
  259. desc.m_attachmentId = tlasAttachmentId;
  260. desc.m_bufferViewDescriptor = tlasBufferViewDescriptor;
  261. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
  262. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite);
  263. }
  264. }
  265. }
  266. void RayTracingPass::CompileResources(const RHI::FrameGraphCompileContext& context)
  267. {
  268. RPI::Scene* scene = m_pipeline->GetScene();
  269. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  270. AZ_Assert(rayTracingFeatureProcessor, "RayTracingPass requires the RayTracingFeatureProcessor");
  271. if (m_shaderResourceGroup != nullptr)
  272. {
  273. auto constantIndex = m_shaderResourceGroup->FindShaderInputConstantIndex(Name("m_maxRayLength"));
  274. if (constantIndex.IsValid())
  275. {
  276. m_shaderResourceGroup->SetConstant(constantIndex, m_maxRayLength);
  277. }
  278. BindPassSrg(context, m_shaderResourceGroup);
  279. m_shaderResourceGroup->Compile();
  280. }
  281. uint32_t proceduralGeometryTypeRevision = rayTracingFeatureProcessor->GetProceduralGeometryTypeRevision();
  282. if (m_proceduralGeometryTypeRevision != proceduralGeometryTypeRevision)
  283. {
  284. CreatePipelineState();
  285. RPI::SceneNotificationBus::Event(
  286. GetScene()->GetId(),
  287. &RPI::SceneNotification::OnRenderPipelineChanged,
  288. GetRenderPipeline(),
  289. RPI::SceneNotification::RenderPipelineChangeType::PassChanged);
  290. m_proceduralGeometryTypeRevision = proceduralGeometryTypeRevision;
  291. }
  292. uint32_t rayTracingRevision = rayTracingFeatureProcessor->GetRevision();
  293. if (m_rayTracingRevision != rayTracingRevision)
  294. {
  295. // scene changed, need to rebuild the shader table
  296. m_rayTracingRevision = rayTracingRevision;
  297. AZStd::shared_ptr<RHI::RayTracingShaderTableDescriptor> descriptor = AZStd::make_shared<RHI::RayTracingShaderTableDescriptor>();
  298. if (rayTracingFeatureProcessor->HasGeometry())
  299. {
  300. // build the ray tracing shader table descriptor
  301. RHI::RayTracingShaderTableDescriptor* descriptorBuild = descriptor->Build(AZ::Name("RayTracingShaderTable"), m_rayTracingPipelineState)
  302. ->RayGenerationRecord(AZ::Name(m_passData->m_rayGenerationShaderName.c_str()))
  303. ->MissRecord(AZ::Name(m_passData->m_missShaderName.c_str()));
  304. // add a hit group for standard meshes mesh to the shader table
  305. descriptorBuild->HitGroupRecord(AZ::Name("HitGroup"));
  306. // add a hit group for each procedural geometry type to the shader table
  307. const auto& proceduralGeometryTypes = rayTracingFeatureProcessor->GetProceduralGeometryTypes();
  308. for (auto it = proceduralGeometryTypes.cbegin(); it != proceduralGeometryTypes.cend(); ++it)
  309. {
  310. descriptorBuild->HitGroupRecord(it->m_name);
  311. // TODO(intersection): Set per-hitgroup SRG once RayTracingPipelineState supports local root signatures
  312. }
  313. }
  314. m_rayTracingShaderTable->Build(descriptor);
  315. }
  316. }
  317. void RayTracingPass::BuildCommandListInternal(const RHI::FrameGraphExecuteContext& context)
  318. {
  319. RPI::Scene* scene = m_pipeline->GetScene();
  320. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  321. AZ_Assert(rayTracingFeatureProcessor, "RayTracingPass requires the RayTracingFeatureProcessor");
  322. if (!rayTracingFeatureProcessor ||
  323. !rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer() ||
  324. !rayTracingFeatureProcessor->HasGeometry() ||
  325. !m_rayTracingShaderTable)
  326. {
  327. return;
  328. }
  329. RHI::DispatchRaysItem dispatchRaysItem;
  330. // calculate thread counts if this is a full screen raytracing pass
  331. if (m_passData->m_makeFullscreenPass)
  332. {
  333. RPI::PassAttachment* outputAttachment = nullptr;
  334. if (GetOutputCount() > 0)
  335. {
  336. outputAttachment = GetOutputBinding(0).GetAttachment().get();
  337. }
  338. else if (GetInputOutputCount() > 0)
  339. {
  340. outputAttachment = GetInputOutputBinding(0).GetAttachment().get();
  341. }
  342. AZ_Assert(outputAttachment != nullptr, "[RayTracingPass '%s']: A fullscreen RayTracing pass must have a valid output or input/output.", GetPathName().GetCStr());
  343. AZ_Assert(outputAttachment->GetAttachmentType() == RHI::AttachmentType::Image, "[RayTracingPass '%s']: The output of a fullscreen RayTracing pass must be an image.", GetPathName().GetCStr());
  344. RHI::Size imageSize = outputAttachment->m_descriptor.m_image.m_size;
  345. dispatchRaysItem.m_arguments.m_direct.m_width = imageSize.m_width;
  346. dispatchRaysItem.m_arguments.m_direct.m_height = imageSize.m_height;
  347. dispatchRaysItem.m_arguments.m_direct.m_depth = imageSize.m_depth;
  348. }
  349. else
  350. {
  351. dispatchRaysItem.m_arguments.m_direct.m_width = m_passData->m_threadCountX;
  352. dispatchRaysItem.m_arguments.m_direct.m_height = m_passData->m_threadCountY;
  353. dispatchRaysItem.m_arguments.m_direct.m_depth = m_passData->m_threadCountZ;
  354. }
  355. // bind RayTracingGlobal, RayTracingScene, and View Srgs
  356. // [GFX TODO][ATOM-15610] Add RenderPass::SetSrgsForRayTracingDispatch
  357. AZStd::vector<RHI::ShaderResourceGroup*> shaderResourceGroups = { m_shaderResourceGroup->GetRHIShaderResourceGroup() };
  358. if (m_requiresRayTracingSceneSrg)
  359. {
  360. shaderResourceGroups.push_back(rayTracingFeatureProcessor->GetRayTracingSceneSrg()->GetRHIShaderResourceGroup());
  361. }
  362. if (m_requiresViewSrg)
  363. {
  364. RPI::ViewPtr view = m_pipeline->GetFirstView(GetPipelineViewTag());
  365. if (view)
  366. {
  367. shaderResourceGroups.push_back(view->GetRHIShaderResourceGroup());
  368. }
  369. }
  370. if (m_requiresSceneSrg)
  371. {
  372. shaderResourceGroups.push_back(scene->GetShaderResourceGroup()->GetRHIShaderResourceGroup());
  373. }
  374. if (m_requiresRayTracingMaterialSrg)
  375. {
  376. shaderResourceGroups.push_back(rayTracingFeatureProcessor->GetRayTracingMaterialSrg()->GetRHIShaderResourceGroup());
  377. }
  378. dispatchRaysItem.m_shaderResourceGroupCount = aznumeric_cast<uint32_t>(shaderResourceGroups.size());
  379. dispatchRaysItem.m_shaderResourceGroups = shaderResourceGroups.data();
  380. dispatchRaysItem.m_rayTracingPipelineState = m_rayTracingPipelineState.get();
  381. dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable.get();
  382. dispatchRaysItem.m_globalPipelineState = m_globalPipelineState.get();
  383. // submit the DispatchRays item
  384. context.GetCommandList()->Submit(dispatchRaysItem);
  385. }
  386. void RayTracingPass::OnShaderReinitialized([[maybe_unused]] const RPI::Shader& shader)
  387. {
  388. CreatePipelineState();
  389. }
  390. void RayTracingPass::OnShaderAssetReinitialized([[maybe_unused]] const Data::Asset<RPI::ShaderAsset>& shaderAsset)
  391. {
  392. CreatePipelineState();
  393. }
  394. void RayTracingPass::OnShaderVariantReinitialized(const RPI::ShaderVariant&)
  395. {
  396. CreatePipelineState();
  397. }
  398. } // namespace Render
  399. } // namespace AZ