RayTracingAmbientOcclusionPass.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/Feature/TransformService/TransformServiceFeatureProcessorInterface.h>
  9. #include <Atom/RHI/CommandList.h>
  10. #include <Atom/RHI/DeviceDispatchRaysItem.h>
  11. #include <Atom/RHI/Factory.h>
  12. #include <Atom/RHI/FrameScheduler.h>
  13. #include <Atom/RHI/RHISystemInterface.h>
  14. #include <Atom/RHI/ScopeProducerFunction.h>
  15. #include <Atom/RPI.Public/Buffer/Buffer.h>
  16. #include <Atom/RPI.Public/Buffer/BufferSystemInterface.h>
  17. #include <Atom/RPI.Public/Pass/PassUtils.h>
  18. #include <Atom/RPI.Public/RPIUtils.h>
  19. #include <Atom/RPI.Public/RenderPipeline.h>
  20. #include <Atom/RPI.Public/Scene.h>
  21. #include <Atom/RPI.Public/View.h>
  22. #include <Passes/RayTracingAmbientOcclusionPass.h>
  23. namespace AZ
  24. {
  25. namespace Render
  26. {
  27. RPI::Ptr<RayTracingAmbientOcclusionPass> RayTracingAmbientOcclusionPass::Create(const RPI::PassDescriptor& descriptor)
  28. {
  29. RPI::Ptr<RayTracingAmbientOcclusionPass> pass = aznew RayTracingAmbientOcclusionPass(descriptor);
  30. return AZStd::move(pass);
  31. }
  32. RayTracingAmbientOcclusionPass::RayTracingAmbientOcclusionPass(const RPI::PassDescriptor& descriptor)
  33. : RPI::RenderPass(descriptor)
  34. {
  35. RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
  36. if (device->GetFeatures().m_rayTracing == false)
  37. {
  38. // ray tracing is not supported on this platform
  39. SetEnabled(false);
  40. }
  41. }
  42. RayTracingAmbientOcclusionPass::~RayTracingAmbientOcclusionPass()
  43. {
  44. }
  45. void RayTracingAmbientOcclusionPass::CreateRayTracingPipelineState()
  46. {
  47. // load ray generation shader
  48. const char* rayGenerationShaderFilePath = "Shaders/RayTracing/RTAOGeneration.azshader";
  49. m_rayGenerationShader = RPI::LoadShader(rayGenerationShaderFilePath);
  50. AZ_Assert(m_rayGenerationShader, "Failed to load ray generation shader");
  51. auto rayGenerationShaderVariant = m_rayGenerationShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  52. RHI::PipelineStateDescriptorForRayTracing rayGenerationShaderDescriptor;
  53. rayGenerationShaderVariant.ConfigurePipelineState(rayGenerationShaderDescriptor);
  54. // load miss shader
  55. const char* missShaderFilePath = "Shaders/RayTracing/RTAOMiss.azshader";
  56. m_missShader = RPI::LoadShader(missShaderFilePath);
  57. AZ_Assert(m_missShader, "Failed to load miss shader");
  58. auto missShaderVariant = m_missShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  59. RHI::PipelineStateDescriptorForRayTracing missShaderDescriptor;
  60. missShaderVariant.ConfigurePipelineState(missShaderDescriptor);
  61. // Load closest hit shader
  62. // This can be removed when the following issue is fixed.
  63. // [ATOM-15087] RayTracingShaderTable shouldn't report an error if there is no hit group in the descriptor
  64. const char* hitShaderFilePath = "Shaders/RayTracing/RTAOClosestHit.azshader";
  65. m_hitShader = RPI::LoadShader(hitShaderFilePath);
  66. AZ_Assert(m_hitShader, "Failed to load closest hit shader");
  67. auto hitShaderVariant = m_hitShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  68. RHI::PipelineStateDescriptorForRayTracing hitShaderDescriptor;
  69. hitShaderVariant.ConfigurePipelineState(hitShaderDescriptor);
  70. // global pipeline state
  71. m_globalPipelineState = m_rayGenerationShader->AcquirePipelineState(rayGenerationShaderDescriptor);
  72. AZ_Assert(m_globalPipelineState, "Failed to acquire ray tracing global pipeline state");
  73. //Get pass srg
  74. m_shaderResourceGroup = RPI::ShaderResourceGroup::Create(m_rayGenerationShader->GetAsset(), Name { "RayTracingGlobalSrg" });
  75. AZ_Assert(m_shaderResourceGroup, "[RayTracingAmbientOcclusionPass '%s']: Failed to create SRG from shader asset '%s'",
  76. GetPathName().GetCStr(), rayGenerationShaderFilePath);
  77. RHI::RayTracingPipelineStateDescriptor descriptor;
  78. descriptor.Build()
  79. ->PipelineState(m_globalPipelineState.get())
  80. ->ShaderLibrary(rayGenerationShaderDescriptor)
  81. ->RayGenerationShaderName(AZ::Name("AoRayGen"))
  82. ->ShaderLibrary(missShaderDescriptor)
  83. ->MissShaderName(AZ::Name("AoMiss"))
  84. ->ShaderLibrary(hitShaderDescriptor)
  85. ->ClosestHitShaderName(AZ::Name("AoClosestHit"))
  86. ->HitGroup(AZ::Name("ClosestHitGroup"))
  87. ->ClosestHitShaderName(AZ::Name("AoClosestHit"))
  88. ;
  89. // create the ray tracing pipeline state object
  90. m_rayTracingPipelineState = aznew RHI::RayTracingPipelineState;
  91. m_rayTracingPipelineState->Init(RHI::MultiDevice::AllDevices, descriptor);
  92. }
  93. void RayTracingAmbientOcclusionPass::FrameBeginInternal(FramePrepareParams params)
  94. {
  95. if (m_createRayTracingPipelineState)
  96. {
  97. RPI::Scene* scene = m_pipeline->GetScene();
  98. m_rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessorInterface>();
  99. CreateRayTracingPipelineState();
  100. m_createRayTracingPipelineState = false;
  101. }
  102. if (!m_rayTracingShaderTable)
  103. {
  104. RHI::RayTracingBufferPools& rayTracingBufferPools = m_rayTracingFeatureProcessor->GetBufferPools();
  105. // Build shader table once. Since we are not using local srg so we don't need to rebuild it even when scene changed
  106. m_rayTracingShaderTable = aznew RHI::RayTracingShaderTable;
  107. m_rayTracingShaderTable->Init(RHI::MultiDevice::AllDevices, rayTracingBufferPools);
  108. AZStd::shared_ptr<RHI::RayTracingShaderTableDescriptor> descriptor = AZStd::make_shared<RHI::RayTracingShaderTableDescriptor>();
  109. descriptor->Build(AZ::Name("RayTracingAOShaderTable"), m_rayTracingPipelineState)
  110. ->RayGenerationRecord(AZ::Name("AoRayGen"))
  111. ->MissRecord(AZ::Name("AoMiss"))
  112. ->HitGroupRecord(AZ::Name("ClosestHitGroup"))
  113. ;
  114. m_rayTracingShaderTable->Build(descriptor);
  115. }
  116. RenderPass::FrameBeginInternal(params);
  117. }
  118. void RayTracingAmbientOcclusionPass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
  119. {
  120. RenderPass::SetupFrameGraphDependencies(frameGraph);
  121. frameGraph.SetEstimatedItemCount(1);
  122. }
  123. void RayTracingAmbientOcclusionPass::CompileResources(const RHI::FrameGraphCompileContext& context)
  124. {
  125. if (!m_shaderResourceGroup)
  126. {
  127. return;
  128. }
  129. // Bind pass attachments to global srg
  130. BindPassSrg(context, m_shaderResourceGroup);
  131. // Bind others for global srg
  132. const RHI::ShaderResourceGroupLayout* srgLayout = m_shaderResourceGroup->GetLayout();
  133. RHI::ShaderInputBufferIndex bufferIndex;
  134. RHI::ShaderInputConstantIndex constantIndex;
  135. // Bind scene TLAS buffer
  136. auto tlasBuffer = m_rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer();
  137. if (tlasBuffer)
  138. {
  139. // TLAS
  140. uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(tlasBuffer->GetDescriptor().m_byteCount);
  141. RHI::BufferViewDescriptor bufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount);
  142. bufferIndex = srgLayout->FindShaderInputBufferIndex(AZ::Name("m_scene"));
  143. m_shaderResourceGroup->SetBufferView(bufferIndex, tlasBuffer->BuildBufferView(bufferViewDescriptor).get());
  144. }
  145. // Bind constants
  146. // float m_aoRadius
  147. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_aoRadius"));
  148. m_shaderResourceGroup->SetConstant(constantIndex, m_rayMaxT);
  149. // uint m_frameCount
  150. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_frameCount"));
  151. m_shaderResourceGroup->SetConstant(constantIndex, m_frameCount++);
  152. // float m_rayMinT
  153. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_rayMinT"));
  154. m_shaderResourceGroup->SetConstant(constantIndex, m_rayMinT);
  155. // uint m_numRays
  156. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_numRays"));
  157. m_shaderResourceGroup->SetConstant(constantIndex, m_rayNumber);
  158. // Matrix4x4 m_viewProjectionInverseMatrix. This is the copy of same constant from ViewSrg.
  159. // Although we don't have access to ViewSrg in ray tracing shader at this moment
  160. constantIndex = srgLayout->FindShaderInputConstantIndex(AZ::Name("m_viewProjectionInverseMatrix"));
  161. const AZStd::vector<RPI::ViewPtr>& views = m_pipeline->GetViews(RPI::PipelineViewTag{"MainCamera"});
  162. Matrix4x4 clipToWorld = views[0]->GetWorldToClipMatrix();
  163. clipToWorld.InvertFull();
  164. m_shaderResourceGroup->SetConstant(constantIndex, clipToWorld);
  165. m_shaderResourceGroup->Compile();
  166. }
  167. void RayTracingAmbientOcclusionPass::BuildCommandListInternal([[maybe_unused]] const RHI::FrameGraphExecuteContext& context)
  168. {
  169. RPI::Scene* scene = m_pipeline->GetScene();
  170. RayTracingFeatureProcessorInterface* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessorInterface>();
  171. AZ_Assert(rayTracingFeatureProcessor, "RayTracingAmbientOcclusionPass requires the RayTracingFeatureProcessor");
  172. if (!rayTracingFeatureProcessor->GetSubMeshCount())
  173. {
  174. return;
  175. }
  176. if (!m_rayTracingShaderTable)
  177. {
  178. return;
  179. }
  180. RPI::PassAttachment* outputAttachment = GetOutputBinding(0).GetAttachment().get();
  181. RHI::Size targetImageSize = outputAttachment->m_descriptor.m_image.m_size;
  182. const RHI::DeviceShaderResourceGroup* shaderResourceGroups[] = {
  183. m_shaderResourceGroup->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get()
  184. };
  185. RHI::DeviceDispatchRaysItem dispatchRaysItem;
  186. dispatchRaysItem.m_arguments.m_direct.m_width = targetImageSize.m_width;
  187. dispatchRaysItem.m_arguments.m_direct.m_height = targetImageSize.m_height;
  188. dispatchRaysItem.m_arguments.m_direct.m_depth = 1;
  189. dispatchRaysItem.m_rayTracingPipelineState = m_rayTracingPipelineState->GetDeviceRayTracingPipelineState(context.GetDeviceIndex()).get();
  190. dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable->GetDeviceRayTracingShaderTable(context.GetDeviceIndex()).get();
  191. dispatchRaysItem.m_shaderResourceGroupCount = 1;
  192. dispatchRaysItem.m_shaderResourceGroups = shaderResourceGroups;
  193. dispatchRaysItem.m_globalPipelineState = m_globalPipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
  194. // submit the DispatchRays item
  195. context.GetCommandList()->Submit(dispatchRaysItem);
  196. }
  197. uint32_t RayTracingAmbientOcclusionPass::GetRayNumberPerPixel()
  198. {
  199. return m_rayNumber;
  200. }
  201. void RayTracingAmbientOcclusionPass::SetRayNumberPerPixel(uint32_t rayNumber)
  202. {
  203. m_rayNumber = rayNumber;
  204. }
  205. float RayTracingAmbientOcclusionPass::GetRayExtentMin()
  206. {
  207. return m_rayMinT;
  208. }
  209. void RayTracingAmbientOcclusionPass::SetRayExtentMin(float minT)
  210. {
  211. m_rayMinT = minT;
  212. }
  213. float RayTracingAmbientOcclusionPass::GetRayExtentMax()
  214. {
  215. return m_rayMaxT;
  216. }
  217. void RayTracingAmbientOcclusionPass::SetRayExtentMax(float maxT)
  218. {
  219. m_rayMaxT = maxT;
  220. }
  } // namespace Render
  222. } // namespace AZ