3
0

DiffuseProbeGridVisualizationPreparePass.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/RHI/CommandList.h>
  9. #include <Atom/RHI/RHISystemInterface.h>
  10. #include <Atom/RPI.Public/RenderPipeline.h>
  11. #include <Atom/RPI.Public/RPIUtils.h>
  12. #include <Atom/RPI.Public/Scene.h>
  13. #include <DiffuseProbeGrid_Traits_Platform.h>
  14. #include <Render/DiffuseProbeGridFeatureProcessor.h>
  15. #include <Render/DiffuseProbeGridVisualizationPreparePass.h>
  16. #include <RayTracing/RayTracingFeatureProcessor.h>
  17. namespace AZ
  18. {
  19. namespace Render
  20. {
  21. RPI::Ptr<DiffuseProbeGridVisualizationPreparePass> DiffuseProbeGridVisualizationPreparePass::Create(const RPI::PassDescriptor& descriptor)
  22. {
  23. RPI::Ptr<DiffuseProbeGridVisualizationPreparePass> diffuseProbeGridVisualizationPreparePass = aznew DiffuseProbeGridVisualizationPreparePass(descriptor);
  24. return AZStd::move(diffuseProbeGridVisualizationPreparePass);
  25. }
  26. DiffuseProbeGridVisualizationPreparePass::DiffuseProbeGridVisualizationPreparePass(const RPI::PassDescriptor& descriptor)
  27. : RenderPass(descriptor)
  28. {
  29. // disable this pass if we're on a platform that doesn't support raytracing
  30. RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
  31. if (device->GetFeatures().m_rayTracing == false || !AZ_TRAIT_DIFFUSE_GI_PASSES_SUPPORTED)
  32. {
  33. SetEnabled(false);
  34. }
  35. else
  36. {
  37. LoadShader();
  38. }
  39. }
  40. void DiffuseProbeGridVisualizationPreparePass::LoadShader()
  41. {
  42. // load shaders
  43. // Note: the shader may not be available on all platforms
  44. AZStd::string shaderFilePath = "Shaders/DiffuseGlobalIllumination/DiffuseProbeGridVisualizationPrepare.azshader";
  45. m_shader = RPI::LoadCriticalShader(shaderFilePath);
  46. if (m_shader == nullptr)
  47. {
  48. return;
  49. }
  50. RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
  51. const auto& shaderVariant = m_shader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  52. shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
  53. m_pipelineState = m_shader->AcquirePipelineState(pipelineStateDescriptor);
  54. AZ_Assert(m_pipelineState, "Failed to acquire pipeline state");
  55. m_srgLayout = m_shader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::Pass);
  56. AZ_Assert(m_srgLayout.get(), "Failed to find Srg layout");
  57. const auto outcome = RPI::GetComputeShaderNumThreads(m_shader->GetAsset(), m_dispatchArgs);
  58. if (!outcome.IsSuccess())
  59. {
  60. AZ_Error("PassSystem", false, "[DiffuseProbeGridVisualizationPreparePass '%s']: Shader '%s' contains invalid numthreads arguments:\n%s", GetPathName().GetCStr(), shaderFilePath.c_str(), outcome.GetError().c_str());
  61. }
  62. }
  63. bool DiffuseProbeGridVisualizationPreparePass::ShouldUpdate(const AZStd::shared_ptr<DiffuseProbeGrid>& diffuseProbeGrid) const
  64. {
  65. return (diffuseProbeGrid->GetVisualizationEnabled() && diffuseProbeGrid->GetVisualizationTlasUpdateRequired());
  66. }
  67. bool DiffuseProbeGridVisualizationPreparePass::IsEnabled() const
  68. {
  69. if (!RenderPass::IsEnabled())
  70. {
  71. return false;
  72. }
  73. RPI::Scene* scene = m_pipeline->GetScene();
  74. if (!scene)
  75. {
  76. return false;
  77. }
  78. DiffuseProbeGridFeatureProcessor* diffuseProbeGridFeatureProcessor = scene->GetFeatureProcessor<DiffuseProbeGridFeatureProcessor>();
  79. if (diffuseProbeGridFeatureProcessor)
  80. {
  81. for (auto& diffuseProbeGrid : diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids())
  82. {
  83. if (ShouldUpdate(diffuseProbeGrid))
  84. {
  85. return true;
  86. }
  87. }
  88. }
  89. return false;
  90. }
  91. void DiffuseProbeGridVisualizationPreparePass::FrameBeginInternal(FramePrepareParams params)
  92. {
  93. RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
  94. RPI::Scene* scene = m_pipeline->GetScene();
  95. DiffuseProbeGridFeatureProcessor* diffuseProbeGridFeatureProcessor = scene->GetFeatureProcessor<DiffuseProbeGridFeatureProcessor>();
  96. for (auto& diffuseProbeGrid : diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids())
  97. {
  98. if (!ShouldUpdate(diffuseProbeGrid))
  99. {
  100. continue;
  101. }
  102. // create the TLAS descriptor by adding an instance entry for each probe in the grid
  103. RHI::RayTracingTlasDescriptor tlasDescriptor;
  104. RHI::RayTracingTlasDescriptor* tlasDescriptorBuild = tlasDescriptor.Build();
  105. // initialize the transform for each probe to Identity(), they will be updated by the compute shader
  106. AZ::Transform transform = AZ::Transform::Identity();
  107. uint32_t probeCount = diffuseProbeGrid->GetTotalProbeCount();
  108. for (uint32_t index = 0; index < probeCount; ++index)
  109. {
  110. tlasDescriptorBuild->Instance()
  111. ->InstanceID(index)
  112. ->HitGroupIndex(0)
  113. ->Blas(diffuseProbeGridFeatureProcessor->GetVisualizationBlas())
  114. ->Transform(transform)
  115. ;
  116. }
  117. // create the TLAS buffers from on the descriptor
  118. RHI::Ptr<RHI::RayTracingTlas>& visualizationTlas = diffuseProbeGrid->GetVisualizationTlas();
  119. visualizationTlas->CreateBuffers(*device, &tlasDescriptor, diffuseProbeGridFeatureProcessor->GetVisualizationBufferPools());
  120. }
  121. RenderPass::FrameBeginInternal(params);
  122. }
  123. void DiffuseProbeGridVisualizationPreparePass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
  124. {
  125. RenderPass::SetupFrameGraphDependencies(frameGraph);
  126. RPI::Scene* scene = m_pipeline->GetScene();
  127. DiffuseProbeGridFeatureProcessor* diffuseProbeGridFeatureProcessor = scene->GetFeatureProcessor<DiffuseProbeGridFeatureProcessor>();
  128. frameGraph.SetEstimatedItemCount(aznumeric_cast<uint32_t>(diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids().size()));
  129. for (auto& diffuseProbeGrid : diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids())
  130. {
  131. if (!ShouldUpdate(diffuseProbeGrid))
  132. {
  133. continue;
  134. }
  135. // import and attach the visualization TLAS and probe data
  136. RHI::Ptr<RHI::RayTracingTlas>& visualizationTlas = diffuseProbeGrid->GetVisualizationTlas();
  137. const RHI::Ptr<RHI::Buffer>& tlasBuffer = visualizationTlas->GetTlasBuffer();
  138. const RHI::Ptr<RHI::Buffer>& tlasInstancesBuffer = visualizationTlas->GetTlasInstancesBuffer();
  139. if (tlasBuffer && tlasInstancesBuffer)
  140. {
  141. // TLAS buffer
  142. {
  143. AZ::RHI::AttachmentId attachmentId = diffuseProbeGrid->GetProbeVisualizationTlasAttachmentId();
  144. if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(attachmentId) == false)
  145. {
  146. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(attachmentId, tlasBuffer);
  147. AZ_Assert(result == RHI::ResultCode::Success, "Failed to import DiffuseProbeGrid visualization TLAS buffer with error %d", result);
  148. }
  149. uint32_t byteCount = aznumeric_cast<uint32_t>(tlasBuffer->GetDescriptor().m_byteCount);
  150. RHI::BufferViewDescriptor bufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(byteCount);
  151. RHI::BufferScopeAttachmentDescriptor desc;
  152. desc.m_attachmentId = attachmentId;
  153. desc.m_bufferViewDescriptor = bufferViewDescriptor;
  154. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::DontCare;
  155. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Write);
  156. }
  157. // TLAS Instances buffer
  158. {
  159. AZ::RHI::AttachmentId attachmentId = diffuseProbeGrid->GetProbeVisualizationTlasInstancesAttachmentId();
  160. if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(attachmentId) == false)
  161. {
  162. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(attachmentId, tlasInstancesBuffer);
  163. AZ_Assert(result == RHI::ResultCode::Success, "Failed to import DiffuseProbeGrid visualization TLAS Instances buffer with error %d", result);
  164. }
  165. uint32_t byteCount = aznumeric_cast<uint32_t>(tlasInstancesBuffer->GetDescriptor().m_byteCount);
  166. RHI::BufferViewDescriptor bufferViewDescriptor = RHI::BufferViewDescriptor::CreateStructured(0, byteCount / RayTracingTlasInstanceElementSize, RayTracingTlasInstanceElementSize);
  167. RHI::BufferScopeAttachmentDescriptor desc;
  168. desc.m_attachmentId = attachmentId;
  169. desc.m_bufferViewDescriptor = bufferViewDescriptor;
  170. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::DontCare;
  171. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Write);
  172. }
  173. // grid data
  174. {
  175. RHI::BufferScopeAttachmentDescriptor desc;
  176. desc.m_attachmentId = diffuseProbeGrid->GetGridDataBufferAttachmentId();
  177. desc.m_bufferViewDescriptor = diffuseProbeGrid->GetRenderData()->m_gridDataBufferViewDescriptor;
  178. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
  179. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Read);
  180. }
  181. // probe data
  182. {
  183. AZ::RHI::AttachmentId attachmentId = diffuseProbeGrid->GetProbeDataImageAttachmentId();
  184. if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(attachmentId) == false)
  185. {
  186. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportImage(attachmentId, diffuseProbeGrid->GetProbeDataImage());
  187. AZ_Assert(result == RHI::ResultCode::Success, "Failed to import DiffuseProbeGrid probe data buffer with error %d", result);
  188. }
  189. RHI::ImageScopeAttachmentDescriptor desc;
  190. desc.m_attachmentId = attachmentId;
  191. desc.m_imageViewDescriptor = diffuseProbeGrid->GetRenderData()->m_probeDataImageViewDescriptor;
  192. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
  193. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Read);
  194. }
  195. }
  196. }
  197. }
  198. void DiffuseProbeGridVisualizationPreparePass::CompileResources([[maybe_unused]] const RHI::FrameGraphCompileContext& context)
  199. {
  200. RPI::Scene* scene = m_pipeline->GetScene();
  201. DiffuseProbeGridFeatureProcessor* diffuseProbeGridFeatureProcessor = scene->GetFeatureProcessor<DiffuseProbeGridFeatureProcessor>();
  202. for (auto& diffuseProbeGrid : diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids())
  203. {
  204. if (!ShouldUpdate(diffuseProbeGrid))
  205. {
  206. continue;
  207. }
  208. // the DiffuseProbeGrid Srg must be updated in the Compile phase in order to successfully bind the ReadWrite shader inputs
  209. // (see ValidateSetImageView() in ShaderResourceGroupData.cpp)
  210. diffuseProbeGrid->UpdateVisualizationPrepareSrg(m_shader, m_srgLayout);
  211. diffuseProbeGrid->GetVisualizationPrepareSrg()->Compile();
  212. }
  213. }
  214. void DiffuseProbeGridVisualizationPreparePass::BuildCommandListInternal(const RHI::FrameGraphExecuteContext& context)
  215. {
  216. RHI::CommandList* commandList = context.GetCommandList();
  217. RPI::Scene* scene = m_pipeline->GetScene();
  218. DiffuseProbeGridFeatureProcessor* diffuseProbeGridFeatureProcessor = scene->GetFeatureProcessor<DiffuseProbeGridFeatureProcessor>();
  219. // submit the DispatchItems for each DiffuseProbeGrid in this range
  220. for (uint32_t index = context.GetSubmitRange().m_startIndex; index < context.GetSubmitRange().m_endIndex; ++index)
  221. {
  222. AZStd::shared_ptr<DiffuseProbeGrid> diffuseProbeGrid = diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids()[index];
  223. if (!ShouldUpdate(diffuseProbeGrid))
  224. {
  225. continue;
  226. }
  227. const RHI::ShaderResourceGroup* shaderResourceGroup = diffuseProbeGrid->GetVisualizationPrepareSrg()->GetRHIShaderResourceGroup();
  228. commandList->SetShaderResourceGroupForDispatch(*shaderResourceGroup);
  229. RHI::DispatchItem dispatchItem;
  230. dispatchItem.m_arguments = m_dispatchArgs;
  231. dispatchItem.m_pipelineState = m_pipelineState;
  232. dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsX = diffuseProbeGrid->GetTotalProbeCount();
  233. dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsY = 1;
  234. dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsZ = 1;
  235. commandList->Submit(dispatchItem, index);
  236. }
  237. }
  238. } // namespace Render
  239. } // namespace AZ