RayTracingPass.cpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <AzCore/Asset/AssetCommon.h>
  9. #include <AzCore/Asset/AssetManagerBus.h>
  10. #include <Atom/RHI/CommandList.h>
  11. #include <Atom/RHI/Factory.h>
  12. #include <Atom/RHI/FrameScheduler.h>
  13. #include <Atom/RHI/SingleDeviceDispatchRaysItem.h>
  14. #include <Atom/RHI/RHISystemInterface.h>
  15. #include <Atom/RHI/SingleDevicePipelineState.h>
  16. #include <Atom/RPI.Reflect/Pass/PassTemplate.h>
  17. #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
  18. #include <Atom/RPI.Public/Base.h>
  19. #include <Atom/RPI.Public/Pass/PassUtils.h>
  20. #include <Atom/RPI.Public/RPIUtils.h>
  21. #include <Atom/RPI.Public/RenderPipeline.h>
  22. #include <Atom/RPI.Public/Scene.h>
  23. #include <Atom/RPI.Public/View.h>
  24. #include <RayTracing/RayTracingPass.h>
  25. #include <RayTracing/RayTracingPassData.h>
  26. #include <RayTracing/RayTracingFeatureProcessor.h>
  27. namespace AZ
  28. {
  29. namespace Render
  30. {
  31. RPI::Ptr<RayTracingPass> RayTracingPass::Create(const RPI::PassDescriptor& descriptor)
  32. {
  33. RPI::Ptr<RayTracingPass> pass = aznew RayTracingPass(descriptor);
  34. return pass;
  35. }
  36. RayTracingPass::RayTracingPass(const RPI::PassDescriptor& descriptor)
  37. : RenderPass(descriptor)
  38. , m_passDescriptor(descriptor)
  39. {
  40. if (RHI::RHISystemInterface::Get()->GetRayTracingSupport() == RHI::MultiDevice::NoDevices)
  41. {
  42. // raytracing is not supported on this platform
  43. SetEnabled(false);
  44. return;
  45. }
  46. m_passData = RPI::PassUtils::GetPassData<RayTracingPassData>(m_passDescriptor);
  47. if (m_passData == nullptr)
  48. {
  49. AZ_Error("PassSystem", false, "RayTracingPass [%s]: Invalid RayTracingPassData", GetPathName().GetCStr());
  50. return;
  51. }
  52. CreatePipelineState();
  53. }
  54. RayTracingPass::~RayTracingPass()
  55. {
  56. RPI::ShaderReloadNotificationBus::MultiHandler::BusDisconnect();
  57. }
  58. void RayTracingPass::CreatePipelineState()
  59. {
  60. struct RTShaderLib
  61. {
  62. AZ::Data::AssetId m_shaderAssetId;
  63. AZ::Data::Instance<AZ::RPI::Shader> m_shader;
  64. AZ::RHI::PipelineStateDescriptorForRayTracing m_pipelineStateDescriptor;
  65. AZ::Name m_rayGenerationShaderName;
  66. AZ::Name m_missShaderName;
  67. AZ::Name m_closestHitShaderName;
  68. AZ::Name m_closestHitProceduralShaderName;
  69. };
  70. AZStd::fixed_vector<RTShaderLib, 4> shaderLibs;
  71. auto loadRayTracingShader = [&](auto& assetReference) -> RTShaderLib&
  72. {
  73. auto it = std::find_if(
  74. shaderLibs.begin(),
  75. shaderLibs.end(),
  76. [&](auto& entry)
  77. {
  78. return entry.m_shaderAssetId == assetReference.m_assetId;
  79. });
  80. if (it != shaderLibs.end())
  81. {
  82. return *it;
  83. }
  84. auto shaderAsset{ AZ::RPI::FindShaderAsset(assetReference.m_assetId, assetReference.m_filePath) };
  85. AZ_Assert(shaderAsset.IsReady(), "Failed to load shader %s", assetReference.m_filePath.c_str());
  86. auto shader{ AZ::RPI::Shader::FindOrCreate(shaderAsset) };
  87. auto shaderVariant{ shader->GetVariant(AZ::RPI::ShaderAsset::RootShaderVariantStableId) };
  88. AZ::RHI::PipelineStateDescriptorForRayTracing pipelineStateDescriptor;
  89. shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
  90. auto& shaderLib = shaderLibs.emplace_back();
  91. shaderLib.m_shaderAssetId = assetReference.m_assetId;
  92. shaderLib.m_shader = shader;
  93. shaderLib.m_pipelineStateDescriptor = pipelineStateDescriptor;
  94. return shaderLib;
  95. };
  96. auto& rayGenShaderLib{ loadRayTracingShader(m_passData->m_rayGenerationShaderAssetReference) };
  97. rayGenShaderLib.m_rayGenerationShaderName = m_passData->m_rayGenerationShaderName;
  98. m_rayGenerationShader = rayGenShaderLib.m_shader;
  99. auto& closestHitShaderLib{ loadRayTracingShader(m_passData->m_closestHitShaderAssetReference) };
  100. closestHitShaderLib.m_closestHitShaderName = m_passData->m_closestHitShaderName;
  101. m_closestHitShader = closestHitShaderLib.m_shader;
  102. if (!m_passData->m_closestHitProceduralShaderName.empty())
  103. {
  104. auto& closestHitProceduralShaderLib{ loadRayTracingShader(m_passData->m_closestHitProceduralShaderAssetReference) };
  105. closestHitProceduralShaderLib.m_closestHitProceduralShaderName = m_passData->m_closestHitProceduralShaderName;
  106. m_closestHitProceduralShader = closestHitProceduralShaderLib.m_shader;
  107. }
  108. auto& missShaderLib{ loadRayTracingShader(m_passData->m_missShaderAssetReference) };
  109. missShaderLib.m_missShaderName = m_passData->m_missShaderName;
  110. m_missShader = missShaderLib.m_shader;
  111. m_globalPipelineState = m_rayGenerationShader->AcquirePipelineState(shaderLibs.front().m_pipelineStateDescriptor);
  112. AZ_Assert(m_globalPipelineState, "Failed to acquire ray tracing global pipeline state");
  113. // create global srg
  114. const auto& globalSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RayTracingGlobalSrgBindingSlot);
  115. AZ_Error("PassSystem", globalSrgLayout != nullptr, "RayTracingPass [%s] Failed to find RayTracingGlobalSrg layout", GetPathName().GetCStr());
  116. m_shaderResourceGroup = RPI::ShaderResourceGroup::Create( m_rayGenerationShader->GetAsset(), m_rayGenerationShader->GetSupervariantIndex(), globalSrgLayout->GetName());
  117. AZ_Assert(m_shaderResourceGroup, "RayTracingPass [%s]: Failed to create RayTracingGlobalSrg", GetPathName().GetCStr());
  118. RPI::PassUtils::BindDataMappingsToSrg(m_passDescriptor, m_shaderResourceGroup.get());
  119. // check to see if the shader requires the View, Scene, or RayTracingMaterial Srgs
  120. const auto& viewSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::View);
  121. m_requiresViewSrg = (viewSrgLayout != nullptr);
  122. const auto& sceneSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::Scene);
  123. m_requiresSceneSrg = (sceneSrgLayout != nullptr);
  124. const auto& rayTracingMaterialSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RayTracingMaterialSrgBindingSlot);
  125. m_requiresRayTracingMaterialSrg = (rayTracingMaterialSrgLayout != nullptr);
  126. const auto& rayTracingSceneSrgLayout = m_rayGenerationShader->FindShaderResourceGroupLayout(RayTracingSceneSrgBindingSlot);
  127. m_requiresRayTracingSceneSrg = (rayTracingSceneSrgLayout != nullptr);
  128. // build the ray tracing pipeline state descriptor
  129. RHI::MultiDeviceRayTracingPipelineStateDescriptor descriptor;
  130. descriptor.Build()
  131. ->PipelineState(m_globalPipelineState.get())
  132. ->MaxPayloadSize(m_passData->m_maxPayloadSize)
  133. ->MaxAttributeSize(m_passData->m_maxAttributeSize)
  134. ->MaxRecursionDepth(m_passData->m_maxRecursionDepth);
  135. for (auto& shaderLib : shaderLibs)
  136. {
  137. descriptor.ShaderLibrary(shaderLib.m_pipelineStateDescriptor);
  138. if (!shaderLib.m_rayGenerationShaderName.IsEmpty())
  139. {
  140. descriptor.RayGenerationShaderName(AZ::Name{ m_passData->m_rayGenerationShaderName });
  141. }
  142. if (!shaderLib.m_closestHitShaderName.IsEmpty())
  143. {
  144. descriptor.ClosestHitShaderName(AZ::Name{ m_passData->m_closestHitShaderName });
  145. }
  146. if (!shaderLib.m_closestHitProceduralShaderName.IsEmpty())
  147. {
  148. descriptor.ClosestHitShaderName(AZ::Name{ m_passData->m_closestHitProceduralShaderName });
  149. }
  150. if (!shaderLib.m_missShaderName.IsEmpty())
  151. {
  152. descriptor.MissShaderName(AZ::Name{ m_passData->m_missShaderName });
  153. }
  154. }
  155. descriptor.HitGroup(AZ::Name("HitGroup"))->ClosestHitShaderName(AZ::Name(m_passData->m_closestHitShaderName.c_str()));
  156. RayTracingFeatureProcessor* rayTracingFeatureProcessor =
  157. GetScene() ? GetScene()->GetFeatureProcessor<RayTracingFeatureProcessor>() : nullptr;
  158. if (rayTracingFeatureProcessor && !m_passData->m_closestHitProceduralShaderName.empty())
  159. {
  160. const auto& proceduralGeometryTypes = rayTracingFeatureProcessor->GetProceduralGeometryTypes();
  161. for (auto it = proceduralGeometryTypes.cbegin(); it != proceduralGeometryTypes.cend(); ++it)
  162. {
  163. auto shaderVariant{ it->m_intersectionShader->GetVariant(AZ::RPI::ShaderAsset::RootShaderVariantStableId) };
  164. AZ::RHI::PipelineStateDescriptorForRayTracing pipelineStateDescriptor;
  165. shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
  166. descriptor.ShaderLibrary(pipelineStateDescriptor);
  167. descriptor.IntersectionShaderName(it->m_intersectionShaderName);
  168. descriptor.HitGroup(it->m_name)
  169. ->ClosestHitShaderName(AZ::Name(m_passData->m_closestHitProceduralShaderName))
  170. ->IntersectionShaderName(it->m_intersectionShaderName);
  171. }
  172. }
  173. // create the ray tracing pipeline state object
  174. m_rayTracingPipelineState = aznew RHI::MultiDeviceRayTracingPipelineState;
  175. m_rayTracingPipelineState->Init(RHI::RHISystemInterface::Get()->GetRayTracingSupport(), descriptor);
  176. // make sure the shader table rebuilds if we're hotreloading
  177. m_rayTracingRevision = 0;
  178. // store the max ray length
  179. m_maxRayLength = m_passData->m_maxRayLength;
  180. RPI::ShaderReloadNotificationBus::MultiHandler::BusDisconnect();
  181. RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_rayGenerationShaderAssetReference.m_assetId);
  182. RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_closestHitShaderAssetReference.m_assetId);
  183. RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_closestHitProceduralShaderAssetReference.m_assetId);
  184. RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_missShaderAssetReference.m_assetId);
  185. RPI::ShaderReloadNotificationBus::MultiHandler::BusConnect(m_passData->m_intersectionShaderAssetReference.m_assetId);
  186. }
  187. Data::Instance<RPI::Shader> RayTracingPass::LoadShader(const RPI::AssetReference& shaderAssetReference)
  188. {
  189. Data::Asset<RPI::ShaderAsset> shaderAsset;
  190. if (shaderAssetReference.m_assetId.IsValid())
  191. {
  192. shaderAsset = RPI::FindShaderAsset(shaderAssetReference.m_assetId, shaderAssetReference.m_filePath);
  193. }
  194. if (!shaderAsset.IsReady())
  195. {
  196. AZ_Error("PassSystem", false, "RayTracingPass [%s]: Failed to load shader asset [%s]", GetPathName().GetCStr(), shaderAssetReference.m_filePath.data());
  197. return nullptr;
  198. }
  199. return RPI::Shader::FindOrCreate(shaderAsset);
  200. }
  201. bool RayTracingPass::IsEnabled() const
  202. {
  203. if (!RenderPass::IsEnabled())
  204. {
  205. return false;
  206. }
  207. RPI::Scene* scene = m_pipeline->GetScene();
  208. if (!scene)
  209. {
  210. return false;
  211. }
  212. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  213. if (!rayTracingFeatureProcessor)
  214. {
  215. return false;
  216. }
  217. return true;
  218. }
  219. void RayTracingPass::FrameBeginInternal(FramePrepareParams params)
  220. {
  221. RPI::Scene* scene = m_pipeline->GetScene();
  222. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  223. if (!rayTracingFeatureProcessor)
  224. {
  225. return;
  226. }
  227. if (!m_rayTracingShaderTable)
  228. {
  229. RHI::MultiDeviceRayTracingBufferPools& rayTracingBufferPools = rayTracingFeatureProcessor->GetBufferPools();
  230. m_rayTracingShaderTable = aznew RHI::MultiDeviceRayTracingShaderTable;
  231. m_rayTracingShaderTable->Init(RHI::RHISystemInterface::Get()->GetRayTracingSupport(), rayTracingBufferPools);
  232. }
  233. RPI::RenderPass::FrameBeginInternal(params);
  234. }
  235. void RayTracingPass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
  236. {
  237. RPI::Scene* scene = m_pipeline->GetScene();
  238. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  239. AZ_Assert(rayTracingFeatureProcessor, "RayTracingPass requires the RayTracingFeatureProcessor");
  240. RPI::RenderPass::SetupFrameGraphDependencies(frameGraph);
  241. frameGraph.SetEstimatedItemCount(1);
  242. // TLAS
  243. {
  244. const RHI::Ptr<RHI::MultiDeviceBuffer>& rayTracingTlasBuffer = rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer();
  245. if (rayTracingTlasBuffer)
  246. {
  247. AZ::RHI::AttachmentId tlasAttachmentId = rayTracingFeatureProcessor->GetTlasAttachmentId();
  248. if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(tlasAttachmentId) == false)
  249. {
  250. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(tlasAttachmentId, rayTracingTlasBuffer);
  251. AZ_Assert(result == RHI::ResultCode::Success, "Failed to import ray tracing TLAS buffer with error %d", result);
  252. }
  253. uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer()->GetDescriptor().m_byteCount);
  254. RHI::BufferViewDescriptor tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRaw(0, tlasBufferByteCount);
  255. RHI::BufferScopeAttachmentDescriptor desc;
  256. desc.m_attachmentId = tlasAttachmentId;
  257. desc.m_bufferViewDescriptor = tlasBufferViewDescriptor;
  258. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
  259. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite);
  260. }
  261. }
  262. }
  263. void RayTracingPass::CompileResources(const RHI::FrameGraphCompileContext& context)
  264. {
  265. RPI::Scene* scene = m_pipeline->GetScene();
  266. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  267. AZ_Assert(rayTracingFeatureProcessor, "RayTracingPass requires the RayTracingFeatureProcessor");
  268. if (m_shaderResourceGroup != nullptr)
  269. {
  270. auto constantIndex = m_shaderResourceGroup->FindShaderInputConstantIndex(Name("m_maxRayLength"));
  271. if (constantIndex.IsValid())
  272. {
  273. m_shaderResourceGroup->SetConstant(constantIndex, m_maxRayLength);
  274. }
  275. BindPassSrg(context, m_shaderResourceGroup);
  276. m_shaderResourceGroup->Compile();
  277. }
  278. uint32_t proceduralGeometryTypeRevision = rayTracingFeatureProcessor->GetProceduralGeometryTypeRevision();
  279. if (m_proceduralGeometryTypeRevision != proceduralGeometryTypeRevision)
  280. {
  281. CreatePipelineState();
  282. RPI::SceneNotificationBus::Event(
  283. GetScene()->GetId(),
  284. &RPI::SceneNotification::OnRenderPipelineChanged,
  285. GetRenderPipeline(),
  286. RPI::SceneNotification::RenderPipelineChangeType::PassChanged);
  287. m_proceduralGeometryTypeRevision = proceduralGeometryTypeRevision;
  288. }
  289. uint32_t rayTracingRevision = rayTracingFeatureProcessor->GetRevision();
  290. if (m_rayTracingRevision != rayTracingRevision)
  291. {
  292. // scene changed, need to rebuild the shader table
  293. m_rayTracingRevision = rayTracingRevision;
  294. AZStd::shared_ptr<RHI::MultiDeviceRayTracingShaderTableDescriptor> descriptor = AZStd::make_shared<RHI::MultiDeviceRayTracingShaderTableDescriptor>();
  295. if (rayTracingFeatureProcessor->HasGeometry())
  296. {
  297. // build the ray tracing shader table descriptor
  298. RHI::MultiDeviceRayTracingShaderTableDescriptor* descriptorBuild = descriptor->Build(AZ::Name("RayTracingShaderTable"), m_rayTracingPipelineState)
  299. ->RayGenerationRecord(AZ::Name(m_passData->m_rayGenerationShaderName.c_str()))
  300. ->MissRecord(AZ::Name(m_passData->m_missShaderName.c_str()));
  301. // add a hit group for standard meshes mesh to the shader table
  302. descriptorBuild->HitGroupRecord(AZ::Name("HitGroup"));
  303. // add a hit group for each procedural geometry type to the shader table
  304. const auto& proceduralGeometryTypes = rayTracingFeatureProcessor->GetProceduralGeometryTypes();
  305. for (auto it = proceduralGeometryTypes.cbegin(); it != proceduralGeometryTypes.cend(); ++it)
  306. {
  307. descriptorBuild->HitGroupRecord(it->m_name);
  308. // TODO(intersection): Set per-hitgroup SRG once RayTracingPipelineState supports local root signatures
  309. }
  310. }
  311. m_rayTracingShaderTable->Build(descriptor);
  312. }
  313. }
  314. void RayTracingPass::BuildCommandListInternal(const RHI::FrameGraphExecuteContext& context)
  315. {
  316. RPI::Scene* scene = m_pipeline->GetScene();
  317. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  318. AZ_Assert(rayTracingFeatureProcessor, "RayTracingPass requires the RayTracingFeatureProcessor");
  319. if (!rayTracingFeatureProcessor ||
  320. !rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer() ||
  321. !rayTracingFeatureProcessor->HasGeometry() ||
  322. !m_rayTracingShaderTable)
  323. {
  324. return;
  325. }
  326. RHI::SingleDeviceDispatchRaysItem dispatchRaysItem;
  327. // calculate thread counts if this is a full screen raytracing pass
  328. if (m_passData->m_makeFullscreenPass)
  329. {
  330. RPI::PassAttachment* outputAttachment = nullptr;
  331. if (GetOutputCount() > 0)
  332. {
  333. outputAttachment = GetOutputBinding(0).GetAttachment().get();
  334. }
  335. else if (GetInputOutputCount() > 0)
  336. {
  337. outputAttachment = GetInputOutputBinding(0).GetAttachment().get();
  338. }
  339. AZ_Assert(outputAttachment != nullptr, "[RayTracingPass '%s']: A fullscreen RayTracing pass must have a valid output or input/output.", GetPathName().GetCStr());
  340. AZ_Assert(outputAttachment->GetAttachmentType() == RHI::AttachmentType::Image, "[RayTracingPass '%s']: The output of a fullscreen RayTracing pass must be an image.", GetPathName().GetCStr());
  341. RHI::Size imageSize = outputAttachment->m_descriptor.m_image.m_size;
  342. dispatchRaysItem.m_arguments.m_direct.m_width = imageSize.m_width;
  343. dispatchRaysItem.m_arguments.m_direct.m_height = imageSize.m_height;
  344. dispatchRaysItem.m_arguments.m_direct.m_depth = imageSize.m_depth;
  345. }
  346. else
  347. {
  348. dispatchRaysItem.m_arguments.m_direct.m_width = m_passData->m_threadCountX;
  349. dispatchRaysItem.m_arguments.m_direct.m_height = m_passData->m_threadCountY;
  350. dispatchRaysItem.m_arguments.m_direct.m_depth = m_passData->m_threadCountZ;
  351. }
  352. // bind RayTracingGlobal, RayTracingScene, and View Srgs
  353. // [GFX TODO][ATOM-15610] Add RenderPass::SetSrgsForRayTracingDispatch
  354. AZStd::vector<RHI::SingleDeviceShaderResourceGroup*> shaderResourceGroups = { m_shaderResourceGroup->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get() };
  355. if (m_requiresRayTracingSceneSrg)
  356. {
  357. shaderResourceGroups.push_back(rayTracingFeatureProcessor->GetRayTracingSceneSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get());
  358. }
  359. if (m_requiresViewSrg)
  360. {
  361. RPI::ViewPtr view = m_pipeline->GetFirstView(GetPipelineViewTag());
  362. if (view)
  363. {
  364. shaderResourceGroups.push_back(view->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get());
  365. }
  366. }
  367. if (m_requiresSceneSrg)
  368. {
  369. shaderResourceGroups.push_back(scene->GetShaderResourceGroup()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get());
  370. }
  371. if (m_requiresRayTracingMaterialSrg)
  372. {
  373. shaderResourceGroups.push_back(rayTracingFeatureProcessor->GetRayTracingMaterialSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get());
  374. }
  375. dispatchRaysItem.m_shaderResourceGroupCount = aznumeric_cast<uint32_t>(shaderResourceGroups.size());
  376. dispatchRaysItem.m_shaderResourceGroups = shaderResourceGroups.data();
  377. dispatchRaysItem.m_rayTracingPipelineState =
  378. m_rayTracingPipelineState->GetDeviceRayTracingPipelineState(context.GetDeviceIndex()).get();
  379. dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable->GetDeviceRayTracingShaderTable(context.GetDeviceIndex()).get();
  380. dispatchRaysItem.m_globalPipelineState = m_globalPipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
  381. // submit the DispatchRays item
  382. context.GetCommandList()->Submit(dispatchRaysItem);
  383. }
  384. void RayTracingPass::OnShaderReinitialized([[maybe_unused]] const RPI::Shader& shader)
  385. {
  386. CreatePipelineState();
  387. }
  388. void RayTracingPass::OnShaderAssetReinitialized([[maybe_unused]] const Data::Asset<RPI::ShaderAsset>& shaderAsset)
  389. {
  390. CreatePipelineState();
  391. }
  392. void RayTracingPass::OnShaderVariantReinitialized(const RPI::ShaderVariant&)
  393. {
  394. CreatePipelineState();
  395. }
  396. } // namespace Render
  397. } // namespace AZ