3
0

RayTracingAccelerationStructurePass.cpp 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/RHI/FrameScheduler.h>
  9. #include <Atom/RHI/CommandList.h>
  10. #include <Atom/RHI/RHISystemInterface.h>
  11. #include <Atom/RPI.Public/Buffer/BufferSystemInterface.h>
  12. #include <Atom/RPI.Public/Buffer/Buffer.h>
  13. #include <Atom/RPI.Public/RenderPipeline.h>
  14. #include <Atom/RPI.Public/Scene.h>
  15. #include <Atom/Feature/Mesh/MeshFeatureProcessor.h>
  16. #include <RayTracing/RayTracingFeatureProcessor.h>
  17. #include <RayTracing/RayTracingAccelerationStructurePass.h>
  18. namespace AZ
  19. {
  20. namespace Render
  21. {
  22. RPI::Ptr<RayTracingAccelerationStructurePass> RayTracingAccelerationStructurePass::Create(const RPI::PassDescriptor& descriptor)
  23. {
  24. RPI::Ptr<RayTracingAccelerationStructurePass> rayTracingAccelerationStructurePass = aznew RayTracingAccelerationStructurePass(descriptor);
  25. return AZStd::move(rayTracingAccelerationStructurePass);
  26. }
  27. RayTracingAccelerationStructurePass::RayTracingAccelerationStructurePass(const RPI::PassDescriptor& descriptor)
  28. : Pass(descriptor)
  29. {
  30. // disable this pass if we're on a platform that doesn't support raytracing
  31. RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
  32. if (device->GetFeatures().m_rayTracing == false)
  33. {
  34. SetEnabled(false);
  35. }
  36. }
  37. void RayTracingAccelerationStructurePass::BuildInternal()
  38. {
  39. InitScope(RHI::ScopeId(GetPathName()));
  40. }
  41. void RayTracingAccelerationStructurePass::FrameBeginInternal(FramePrepareParams params)
  42. {
  43. params.m_frameGraphBuilder->ImportScopeProducer(*this);
  44. }
  45. void RayTracingAccelerationStructurePass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
  46. {
  47. RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
  48. RPI::Scene* scene = m_pipeline->GetScene();
  49. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  50. if (rayTracingFeatureProcessor)
  51. {
  52. if (rayTracingFeatureProcessor->GetRevision() != m_rayTracingRevision)
  53. {
  54. RHI::RayTracingBufferPools& rayTracingBufferPools = rayTracingFeatureProcessor->GetBufferPools();
  55. RayTracingFeatureProcessor::SubMeshVector& subMeshes = rayTracingFeatureProcessor->GetSubMeshes();
  56. uint32_t rayTracingSubMeshCount = rayTracingFeatureProcessor->GetSubMeshCount();
  57. // create the TLAS descriptor
  58. RHI::RayTracingTlasDescriptor tlasDescriptor;
  59. RHI::RayTracingTlasDescriptor* tlasDescriptorBuild = tlasDescriptor.Build();
  60. uint32_t instanceIndex = 0;
  61. for (auto& subMesh : subMeshes)
  62. {
  63. tlasDescriptorBuild->Instance()
  64. ->InstanceID(instanceIndex)
  65. ->HitGroupIndex(0)
  66. ->Blas(subMesh.m_blas)
  67. ->Transform(subMesh.m_mesh->m_transform)
  68. ->NonUniformScale(subMesh.m_mesh->m_nonUniformScale)
  69. ->Transparent(subMesh.m_irradianceColor.GetA() < 1.0f)
  70. ;
  71. instanceIndex++;
  72. }
  73. // create the TLAS buffers based on the descriptor
  74. RHI::Ptr<RHI::RayTracingTlas>& rayTracingTlas = rayTracingFeatureProcessor->GetTlas();
  75. rayTracingTlas->CreateBuffers(*device, &tlasDescriptor, rayTracingBufferPools);
  76. // import and attach the TLAS buffer
  77. const RHI::Ptr<RHI::Buffer>& rayTracingTlasBuffer = rayTracingTlas->GetTlasBuffer();
  78. if (rayTracingTlasBuffer && rayTracingSubMeshCount)
  79. {
  80. AZ::RHI::AttachmentId tlasAttachmentId = rayTracingFeatureProcessor->GetTlasAttachmentId();
  81. if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(tlasAttachmentId) == false)
  82. {
  83. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(tlasAttachmentId, rayTracingTlasBuffer);
  84. AZ_Assert(result == RHI::ResultCode::Success, "Failed to import ray tracing TLAS buffer with error %d", result);
  85. }
  86. uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(rayTracingTlasBuffer->GetDescriptor().m_byteCount);
  87. RHI::BufferViewDescriptor tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount);
  88. RHI::BufferScopeAttachmentDescriptor desc;
  89. desc.m_attachmentId = tlasAttachmentId;
  90. desc.m_bufferViewDescriptor = tlasBufferViewDescriptor;
  91. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::DontCare;
  92. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Write);
  93. }
  94. }
  95. // update and compile the RayTracingSceneSrg and RayTracingMaterialSrg
  96. // Note: the timing of this update is very important, it needs to be updated after the TLAS is allocated so it can
  97. // be set on the RayTracingSceneSrg for this frame, and the ray tracing mesh data in the RayTracingSceneSrg must
  98. // exactly match the TLAS. Any mismatch in this data may result in a TDR.
  99. rayTracingFeatureProcessor->UpdateRayTracingSrgs();
  100. }
  101. }
  102. void RayTracingAccelerationStructurePass::BuildCommandList(const RHI::FrameGraphExecuteContext& context)
  103. {
  104. RPI::Scene* scene = m_pipeline->GetScene();
  105. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  106. if (!rayTracingFeatureProcessor)
  107. {
  108. return;
  109. }
  110. if (!rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer())
  111. {
  112. return;
  113. }
  114. if (rayTracingFeatureProcessor->GetRevision() == m_rayTracingRevision)
  115. {
  116. // TLAS is up to date
  117. return;
  118. }
  119. // update the stored revision, even if we don't have any meshes to process
  120. m_rayTracingRevision = rayTracingFeatureProcessor->GetRevision();
  121. if (!rayTracingFeatureProcessor->GetSubMeshCount())
  122. {
  123. // no ray tracing meshes in the scene
  124. return;
  125. }
  126. // build newly added BLAS objects
  127. RayTracingFeatureProcessor::BlasInstanceMap& blasInstances = rayTracingFeatureProcessor->GetBlasInstances();
  128. for (auto& blasInstance : blasInstances)
  129. {
  130. if (blasInstance.second.m_blasBuilt == false)
  131. {
  132. for (auto& blasInstanceSubMesh : blasInstance.second.m_subMeshes)
  133. {
  134. context.GetCommandList()->BuildBottomLevelAccelerationStructure(*blasInstanceSubMesh.m_blas);
  135. }
  136. blasInstance.second.m_blasBuilt = true;
  137. }
  138. }
  139. // build the TLAS object
  140. context.GetCommandList()->BuildTopLevelAccelerationStructure(*rayTracingFeatureProcessor->GetTlas());
  141. }
  } // namespace Render
  143. } // namespace AZ