DiffuseProbeGridVisualizationRayTracingPass.cpp 16 KB


  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/RHI/CommandList.h>
  9. #include <Atom/RHI/DispatchRaysItem.h>
  10. #include <Atom/RHI/Factory.h>
  11. #include <Atom/RHI/RHISystemInterface.h>
  12. #include <Atom/RPI.Public/RenderPipeline.h>
  13. #include <Atom/RPI.Public/Scene.h>
  14. #include <Atom/RPI.Public/RPIUtils.h>
  15. #include <Atom/RPI.Public/View.h>
  16. #include <DiffuseProbeGrid_Traits_Platform.h>
  17. #include <Render/DiffuseProbeGridFeatureProcessor.h>
  18. #include <Render/DiffuseProbeGridVisualizationRayTracingPass.h>
  19. #include <RayTracing/RayTracingFeatureProcessor.h>
  20. namespace AZ
  21. {
  22. namespace Render
  23. {
  24. RPI::Ptr<DiffuseProbeGridVisualizationRayTracingPass> DiffuseProbeGridVisualizationRayTracingPass::Create(const RPI::PassDescriptor& descriptor)
  25. {
  26. RPI::Ptr<DiffuseProbeGridVisualizationRayTracingPass> pass = aznew DiffuseProbeGridVisualizationRayTracingPass(descriptor);
  27. return AZStd::move(pass);
  28. }
  29. DiffuseProbeGridVisualizationRayTracingPass::DiffuseProbeGridVisualizationRayTracingPass(const RPI::PassDescriptor& descriptor)
  30. : RPI::RenderPass(descriptor)
  31. {
  32. RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
  33. if (device->GetFeatures().m_rayTracing == false || !AZ_TRAIT_DIFFUSE_GI_PASSES_SUPPORTED)
  34. {
  35. // raytracing or GI is not supported on this platform
  36. SetEnabled(false);
  37. }
  38. }
  39. void DiffuseProbeGridVisualizationRayTracingPass::CreateRayTracingPipelineState()
  40. {
  41. RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
  42. // load the ray tracing shader
  43. // Note: the shader may not be available on all platforms
  44. AZStd::string shaderFilePath = "Shaders/DiffuseGlobalIllumination/DiffuseProbeGridVisualizationRayTracing.azshader";
  45. m_rayTracingShader = RPI::LoadCriticalShader(shaderFilePath);
  46. if (m_rayTracingShader == nullptr)
  47. {
  48. return;
  49. }
  50. auto shaderVariant = m_rayTracingShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  51. RHI::PipelineStateDescriptorForRayTracing rayGenerationShaderDescriptor;
  52. shaderVariant.ConfigurePipelineState(rayGenerationShaderDescriptor);
  53. // closest hit shader
  54. AZStd::string closestHitShaderFilePath = "Shaders/DiffuseGlobalIllumination/DiffuseProbeGridVisualizationRayTracingClosestHit.azshader";
  55. m_closestHitShader = RPI::LoadCriticalShader(closestHitShaderFilePath);
  56. auto closestHitShaderVariant = m_closestHitShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  57. RHI::PipelineStateDescriptorForRayTracing closestHitShaderDescriptor;
  58. closestHitShaderVariant.ConfigurePipelineState(closestHitShaderDescriptor);
  59. // miss shader
  60. AZStd::string missShaderFilePath = "Shaders/DiffuseGlobalIllumination/DiffuseProbeGridVisualizationRayTracingMiss.azshader";
  61. m_missShader = RPI::LoadCriticalShader(missShaderFilePath);
  62. auto missShaderVariant = m_missShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  63. RHI::PipelineStateDescriptorForRayTracing missShaderDescriptor;
  64. missShaderVariant.ConfigurePipelineState(missShaderDescriptor);
  65. // global pipeline state and Srg
  66. m_globalPipelineState = m_rayTracingShader->AcquirePipelineState(rayGenerationShaderDescriptor);
  67. AZ_Assert(m_globalPipelineState, "Failed to acquire ray tracing global pipeline state");
  68. m_globalSrgLayout = m_rayTracingShader->FindShaderResourceGroupLayout(Name{ "RayTracingGlobalSrg" });
  69. AZ_Assert(m_globalSrgLayout != nullptr, "Failed to find RayTracingGlobalSrg layout for shader [%s]", shaderFilePath.c_str());
  70. // build the ray tracing pipeline state descriptor
  71. RHI::RayTracingPipelineStateDescriptor descriptor;
  72. descriptor.Build()
  73. ->PipelineState(m_globalPipelineState.get())
  74. ->MaxPayloadSize(64)
  75. ->MaxAttributeSize(32)
  76. ->MaxRecursionDepth(2)
  77. ->ShaderLibrary(rayGenerationShaderDescriptor)
  78. ->RayGenerationShaderName(AZ::Name("RayGen"))
  79. ->ShaderLibrary(missShaderDescriptor)
  80. ->MissShaderName(AZ::Name("Miss"))
  81. ->ShaderLibrary(closestHitShaderDescriptor)
  82. ->ClosestHitShaderName(AZ::Name("ClosestHit"))
  83. ->HitGroup(AZ::Name("HitGroup"))
  84. ->ClosestHitShaderName(AZ::Name("ClosestHit"));
  85. // create the ray tracing pipeline state object
  86. m_rayTracingPipelineState = RHI::Factory::Get().CreateRayTracingPipelineState();
  87. m_rayTracingPipelineState->Init(*device.get(), &descriptor);
  88. }
  89. bool DiffuseProbeGridVisualizationRayTracingPass::IsEnabled() const
  90. {
  91. if (!RenderPass::IsEnabled())
  92. {
  93. return false;
  94. }
  95. RPI::Scene* scene = m_pipeline->GetScene();
  96. if (!scene)
  97. {
  98. return false;
  99. }
  100. DiffuseProbeGridFeatureProcessor* diffuseProbeGridFeatureProcessor = scene->GetFeatureProcessor<DiffuseProbeGridFeatureProcessor>();
  101. if (diffuseProbeGridFeatureProcessor)
  102. {
  103. for (auto& diffuseProbeGrid : diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids())
  104. {
  105. if (diffuseProbeGrid->GetVisualizationEnabled())
  106. {
  107. return true;
  108. }
  109. }
  110. }
  111. return false;
  112. }
  113. void DiffuseProbeGridVisualizationRayTracingPass::FrameBeginInternal(FramePrepareParams params)
  114. {
  115. RPI::Scene* scene = m_pipeline->GetScene();
  116. DiffuseProbeGridFeatureProcessor* diffuseProbeGridFeatureProcessor = scene->GetFeatureProcessor<DiffuseProbeGridFeatureProcessor>();
  117. if (!m_initialized)
  118. {
  119. CreateRayTracingPipelineState();
  120. m_initialized = true;
  121. }
  122. if (!m_rayTracingShaderTable)
  123. {
  124. RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
  125. RHI::RayTracingBufferPools& rayTracingBufferPools = diffuseProbeGridFeatureProcessor->GetVisualizationBufferPools();
  126. m_rayTracingShaderTable = RHI::Factory::Get().CreateRayTracingShaderTable();
  127. m_rayTracingShaderTable->Init(*device.get(), rayTracingBufferPools);
  128. AZStd::shared_ptr<RHI::RayTracingShaderTableDescriptor> descriptor = AZStd::make_shared<RHI::RayTracingShaderTableDescriptor>();
  129. // build the ray tracing shader table descriptor
  130. descriptor->Build(AZ::Name("RayTracingShaderTable"), m_rayTracingPipelineState)
  131. ->RayGenerationRecord(AZ::Name("RayGen"))
  132. ->MissRecord(AZ::Name("Miss"))
  133. ->HitGroupRecord(AZ::Name("HitGroup"))
  134. ;
  135. m_rayTracingShaderTable->Build(descriptor);
  136. }
  137. RenderPass::FrameBeginInternal(params);
  138. }
  139. void DiffuseProbeGridVisualizationRayTracingPass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
  140. {
  141. RenderPass::SetupFrameGraphDependencies(frameGraph);
  142. RPI::Scene* scene = m_pipeline->GetScene();
  143. DiffuseProbeGridFeatureProcessor* diffuseProbeGridFeatureProcessor = scene->GetFeatureProcessor<DiffuseProbeGridFeatureProcessor>();
  144. frameGraph.SetEstimatedItemCount(aznumeric_cast<uint32_t>(diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids().size()));
  145. for (auto& diffuseProbeGrid : diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids())
  146. {
  147. if (!diffuseProbeGrid->GetVisualizationEnabled())
  148. {
  149. continue;
  150. }
  151. // TLAS
  152. {
  153. AZ::RHI::AttachmentId tlasAttachmentId = diffuseProbeGrid->GetProbeVisualizationTlasAttachmentId();
  154. const RHI::Ptr<RHI::Buffer>& visualizationTlasBuffer = diffuseProbeGrid->GetVisualizationTlas()->GetTlasBuffer();
  155. if (visualizationTlasBuffer)
  156. {
  157. if (!frameGraph.GetAttachmentDatabase().IsAttachmentValid(tlasAttachmentId))
  158. {
  159. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(tlasAttachmentId, visualizationTlasBuffer);
  160. AZ_Assert(result == RHI::ResultCode::Success, "Failed to import ray tracing TLAS buffer with error %d", result);
  161. }
  162. uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(visualizationTlasBuffer->GetDescriptor().m_byteCount);
  163. RHI::BufferViewDescriptor tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRaw(0, tlasBufferByteCount);
  164. RHI::BufferScopeAttachmentDescriptor desc;
  165. desc.m_attachmentId = tlasAttachmentId;
  166. desc.m_bufferViewDescriptor = tlasBufferViewDescriptor;
  167. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
  168. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite);
  169. }
  170. }
  171. // grid data
  172. {
  173. RHI::BufferScopeAttachmentDescriptor desc;
  174. desc.m_attachmentId = diffuseProbeGrid->GetGridDataBufferAttachmentId();
  175. desc.m_bufferViewDescriptor = diffuseProbeGrid->GetRenderData()->m_gridDataBufferViewDescriptor;
  176. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
  177. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Read);
  178. }
  179. // probe irradiance
  180. {
  181. RHI::ImageScopeAttachmentDescriptor desc;
  182. desc.m_attachmentId = diffuseProbeGrid->GetIrradianceImageAttachmentId();
  183. desc.m_imageViewDescriptor = diffuseProbeGrid->GetRenderData()->m_probeIrradianceImageViewDescriptor;
  184. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
  185. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Read);
  186. }
  187. // probe distance
  188. {
  189. RHI::ImageScopeAttachmentDescriptor desc;
  190. desc.m_attachmentId = diffuseProbeGrid->GetDistanceImageAttachmentId();
  191. desc.m_imageViewDescriptor = diffuseProbeGrid->GetRenderData()->m_probeDistanceImageViewDescriptor;
  192. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
  193. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Read);
  194. }
  195. // probe data
  196. {
  197. RHI::ImageScopeAttachmentDescriptor desc;
  198. desc.m_attachmentId = diffuseProbeGrid->GetProbeDataImageAttachmentId();
  199. desc.m_imageViewDescriptor = diffuseProbeGrid->GetRenderData()->m_probeDataImageViewDescriptor;
  200. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
  201. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Read);
  202. }
  203. }
  204. // retrieve the visualization image size, this will determine the number of rays to cast
  205. RPI::Ptr<RPI::PassAttachment> visualizationImageAttachment = m_ownedAttachments[0];
  206. AZ_Assert(visualizationImageAttachment.get(), "Invalid DiffuseProbeGrid Visualization image");
  207. m_outputAttachmentSize = visualizationImageAttachment->GetTransientImageDescriptor().m_imageDescriptor.m_size;
  208. }
  209. void DiffuseProbeGridVisualizationRayTracingPass::CompileResources([[maybe_unused]] const RHI::FrameGraphCompileContext& context)
  210. {
  211. const RHI::ImageView* outputImageView = context.GetImageView(GetOutputBinding(0).GetAttachment()->GetAttachmentId());
  212. AZ_Assert(outputImageView, "Failed to retrieve output ImageView");
  213. RPI::Scene* scene = m_pipeline->GetScene();
  214. DiffuseProbeGridFeatureProcessor* diffuseProbeGridFeatureProcessor = scene->GetFeatureProcessor<DiffuseProbeGridFeatureProcessor>();
  215. for (auto& diffuseProbeGrid : diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids())
  216. {
  217. if (!diffuseProbeGrid->GetVisualizationEnabled())
  218. {
  219. continue;
  220. }
  221. // the DiffuseProbeGridVisualization Srg must be updated in the Compile phase in order to successfully bind the ReadWrite shader
  222. // inputs (see line ValidateSetImageView() in ShaderResourceGroupData.cpp)
  223. diffuseProbeGrid->UpdateVisualizationRayTraceSrg(m_rayTracingShader, m_globalSrgLayout, outputImageView);
  224. diffuseProbeGrid->GetVisualizationRayTraceSrg()->Compile();
  225. }
  226. }
  227. void DiffuseProbeGridVisualizationRayTracingPass::BuildCommandListInternal([[maybe_unused]] const RHI::FrameGraphExecuteContext& context)
  228. {
  229. RPI::Scene* scene = m_pipeline->GetScene();
  230. DiffuseProbeGridFeatureProcessor* diffuseProbeGridFeatureProcessor = scene->GetFeatureProcessor<DiffuseProbeGridFeatureProcessor>();
  231. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  232. AZ_Assert(rayTracingFeatureProcessor, "DiffuseProbeGridVisualizationRayTracingPass requires the RayTracingFeatureProcessor");
  233. const AZStd::vector<RPI::ViewPtr>& views = m_pipeline->GetViews(RPI::PipelineViewTag{ "MainCamera" });
  234. if (views.empty())
  235. {
  236. return;
  237. }
  238. // submit the DispatchRaysItems for each DiffuseProbeGrid in this range
  239. for (uint32_t index = context.GetSubmitRange().m_startIndex; index < context.GetSubmitRange().m_endIndex; ++index)
  240. {
  241. AZStd::shared_ptr<DiffuseProbeGrid> diffuseProbeGrid = diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids()[index];
  242. if (!diffuseProbeGrid->GetVisualizationEnabled())
  243. {
  244. continue;
  245. }
  246. const RHI::ShaderResourceGroup* shaderResourceGroups[] = {
  247. diffuseProbeGrid->GetVisualizationRayTraceSrg()->GetRHIShaderResourceGroup(),
  248. rayTracingFeatureProcessor->GetRayTracingSceneSrg()->GetRHIShaderResourceGroup(),
  249. views[0]->GetRHIShaderResourceGroup(),
  250. };
  251. RHI::DispatchRaysItem dispatchRaysItem;
  252. dispatchRaysItem.m_arguments.m_direct.m_width = m_outputAttachmentSize.m_width;
  253. dispatchRaysItem.m_arguments.m_direct.m_height = m_outputAttachmentSize.m_height;
  254. dispatchRaysItem.m_arguments.m_direct.m_depth = 1;
  255. dispatchRaysItem.m_rayTracingPipelineState = m_rayTracingPipelineState.get();
  256. dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable.get();
  257. dispatchRaysItem.m_shaderResourceGroupCount = RHI::ArraySize(shaderResourceGroups);
  258. dispatchRaysItem.m_shaderResourceGroups = shaderResourceGroups;
  259. dispatchRaysItem.m_globalPipelineState = m_globalPipelineState.get();
  260. // submit the DispatchRays item
  261. context.GetCommandList()->Submit(dispatchRaysItem, index);
  262. }
  263. }
} // namespace Render
  265. } // namespace AZ