RayTracingAccelerationStructurePass.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/Feature/Mesh/MeshFeatureProcessor.h>
  9. #include <Atom/RHI/BufferFrameAttachment.h>
  10. #include <Atom/RHI/BufferScopeAttachment.h>
  11. #include <Atom/RHI/CommandList.h>
  12. #include <Atom/RHI/FrameScheduler.h>
  13. #include <Atom/RHI/RHISystemInterface.h>
  14. #include <Atom/RHI/ScopeProducerFunction.h>
  15. #include <Atom/RPI.Public/Buffer/Buffer.h>
  16. #include <Atom/RPI.Public/Buffer/BufferSystemInterface.h>
  17. #include <Atom/RPI.Public/RenderPipeline.h>
  18. #include <Atom/RPI.Public/Scene.h>
  19. #include <RayTracing/RayTracingAccelerationStructurePass.h>
  20. #include <RayTracing/RayTracingFeatureProcessor.h>
  21. namespace AZ
  22. {
  23. namespace Render
  24. {
  25. RPI::Ptr<RayTracingAccelerationStructurePass> RayTracingAccelerationStructurePass::Create(const RPI::PassDescriptor& descriptor)
  26. {
  27. RPI::Ptr<RayTracingAccelerationStructurePass> rayTracingAccelerationStructurePass = aznew RayTracingAccelerationStructurePass(descriptor);
  28. return AZStd::move(rayTracingAccelerationStructurePass);
  29. }
  30. RayTracingAccelerationStructurePass::RayTracingAccelerationStructurePass(const RPI::PassDescriptor& descriptor)
  31. : Pass(descriptor)
  32. {
  33. // disable this pass if we're on a platform that doesn't support raytracing
  34. if (RHI::RHISystemInterface::Get()->GetRayTracingSupport() == RHI::MultiDevice::NoDevices)
  35. {
  36. SetEnabled(false);
  37. }
  38. }
  39. void RayTracingAccelerationStructurePass::BuildInternal()
  40. {
  41. auto deviceIndex = Pass::GetDeviceIndex();
  42. InitScope(
  43. RHI::ScopeId(AZStd::string(GetPathName().GetCStr() + AZStd::to_string(deviceIndex))),
  44. AZ::RHI::HardwareQueueClass::Compute,
  45. deviceIndex);
  46. }
  47. void RayTracingAccelerationStructurePass::FrameBeginInternal(FramePrepareParams params)
  48. {
  49. if (GetScopeId().IsEmpty())
  50. {
  51. InitScope(RHI::ScopeId(GetPathName()), RHI::HardwareQueueClass::Compute, Pass::GetDeviceIndex());
  52. }
  53. params.m_frameGraphBuilder->ImportScopeProducer(*this);
  54. ReadbackScopeQueryResults();
  55. RPI::Scene* scene = m_pipeline->GetScene();
  56. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  57. if (rayTracingFeatureProcessor)
  58. {
  59. auto revision = rayTracingFeatureProcessor->BeginFrame();
  60. m_rayTracingRevisionOutDated = revision != m_rayTracingRevision;
  61. if (m_rayTracingRevisionOutDated)
  62. {
  63. m_rayTracingRevision = revision;
  64. }
  65. }
  66. }
  67. RHI::Ptr<RPI::Query> RayTracingAccelerationStructurePass::GetQuery(RPI::ScopeQueryType queryType)
  68. {
  69. auto typeIndex{ static_cast<uint32_t>(queryType) };
  70. if (!m_scopeQueries[typeIndex])
  71. {
  72. RHI::Ptr<RPI::Query> query;
  73. switch (queryType)
  74. {
  75. case RPI::ScopeQueryType::Timestamp:
  76. query = RPI::GpuQuerySystemInterface::Get()->CreateQuery(
  77. RHI::QueryType::Timestamp, RHI::QueryPoolScopeAttachmentType::Global, RHI::ScopeAttachmentAccess::Write);
  78. break;
  79. case RPI::ScopeQueryType::PipelineStatistics:
  80. query = RPI::GpuQuerySystemInterface::Get()->CreateQuery(
  81. RHI::QueryType::PipelineStatistics, RHI::QueryPoolScopeAttachmentType::Global,
  82. RHI::ScopeAttachmentAccess::Write);
  83. break;
  84. }
  85. m_scopeQueries[typeIndex] = query;
  86. }
  87. return m_scopeQueries[typeIndex];
  88. }
  89. template<typename Func>
  90. inline void RayTracingAccelerationStructurePass::ExecuteOnTimestampQuery(Func&& func)
  91. {
  92. if (IsTimestampQueryEnabled())
  93. {
  94. auto query{ GetQuery(RPI::ScopeQueryType::Timestamp) };
  95. if (query)
  96. {
  97. func(query);
  98. }
  99. }
  100. }
  101. template<typename Func>
  102. inline void RayTracingAccelerationStructurePass::ExecuteOnPipelineStatisticsQuery(Func&& func)
  103. {
  104. if (IsPipelineStatisticsQueryEnabled())
  105. {
  106. auto query{ GetQuery(RPI::ScopeQueryType::PipelineStatistics) };
  107. if (query)
  108. {
  109. func(query);
  110. }
  111. }
  112. }
  113. RPI::TimestampResult RayTracingAccelerationStructurePass::GetTimestampResultInternal() const
  114. {
  115. return m_timestampResult;
  116. }
  117. RPI::PipelineStatisticsResult RayTracingAccelerationStructurePass::GetPipelineStatisticsResultInternal() const
  118. {
  119. return m_statisticsResult;
  120. }
  121. void RayTracingAccelerationStructurePass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
  122. {
  123. RPI::Scene* scene = m_pipeline->GetScene();
  124. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  125. if (rayTracingFeatureProcessor)
  126. {
  127. if (m_rayTracingRevisionOutDated)
  128. {
  129. // create the TLAS buffers based on the descriptor
  130. RHI::Ptr<RHI::RayTracingTlas>& rayTracingTlas = rayTracingFeatureProcessor->GetTlas();
  131. // import and attach the TLAS buffer
  132. const RHI::Ptr<RHI::Buffer>& rayTracingTlasBuffer = rayTracingTlas->GetTlasBuffer();
  133. if (rayTracingTlasBuffer && rayTracingFeatureProcessor->HasGeometry())
  134. {
  135. AZ::RHI::AttachmentId tlasAttachmentId = rayTracingFeatureProcessor->GetTlasAttachmentId();
  136. if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(tlasAttachmentId) == false)
  137. {
  138. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(tlasAttachmentId, rayTracingTlasBuffer);
  139. AZ_Assert(result == RHI::ResultCode::Success, "Failed to import ray tracing TLAS buffer with error %d", result);
  140. }
  141. uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(rayTracingTlasBuffer->GetDescriptor().m_byteCount);
  142. RHI::BufferViewDescriptor tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount);
  143. RHI::BufferScopeAttachmentDescriptor desc;
  144. desc.m_attachmentId = tlasAttachmentId;
  145. desc.m_bufferViewDescriptor = tlasBufferViewDescriptor;
  146. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::DontCare;
  147. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Write, RHI::ScopeAttachmentStage::RayTracingShader);
  148. }
  149. }
  150. // Attach output data from the skinning pass. This is needed to ensure that this pass is executed after
  151. // the skinning pass has finished. We assume that the pipeline has a skinning pass with this output available.
  152. if (rayTracingFeatureProcessor->GetSkinnedMeshCount() > 0)
  153. {
  154. auto skinningPassPtr = FindAdjacentPass(AZ::Name("SkinningPass"));
  155. auto skinnedMeshOutputStreamBindingPtr = skinningPassPtr->FindAttachmentBinding(AZ::Name("SkinnedMeshOutputStream"));
  156. [[maybe_unused]] auto result = frameGraph.UseShaderAttachment(skinnedMeshOutputStreamBindingPtr->m_unifiedScopeDesc.GetAsBuffer(), RHI::ScopeAttachmentAccess::Read, RHI::ScopeAttachmentStage::RayTracingShader);
  157. AZ_Assert(result == AZ::RHI::ResultCode::Success, "Failed to attach SkinnedMeshOutputStream buffer with error %d", result);
  158. }
  159. AddScopeQueryToFrameGraph(frameGraph);
  160. }
  161. }
  162. void RayTracingAccelerationStructurePass::BuildCommandList(const RHI::FrameGraphExecuteContext& context)
  163. {
  164. RPI::Scene* scene = m_pipeline->GetScene();
  165. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  166. if (!rayTracingFeatureProcessor)
  167. {
  168. return;
  169. }
  170. if (!rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer())
  171. {
  172. return;
  173. }
  174. if (!m_rayTracingRevisionOutDated && rayTracingFeatureProcessor->GetSkinnedMeshCount() == 0)
  175. {
  176. // TLAS is up to date
  177. return;
  178. }
  179. if (!rayTracingFeatureProcessor->HasGeometry())
  180. {
  181. // no ray tracing meshes in the scene
  182. return;
  183. }
  184. BeginScopeQuery(context);
  185. // build newly added or skinned BLAS objects
  186. AZStd::vector<const AZ::RHI::DeviceRayTracingBlas*> changedBlasList;
  187. RayTracingFeatureProcessor::BlasInstanceMap& blasInstances = rayTracingFeatureProcessor->GetBlasInstances();
  188. for (auto& blasInstance : blasInstances)
  189. {
  190. const bool isSkinnedMesh = blasInstance.second.m_isSkinnedMesh;
  191. const bool buildBlas = (blasInstance.second.m_blasBuilt & RHI::MultiDevice::DeviceMask(1 << context.GetDeviceIndex())) ==
  192. RHI::MultiDevice::NoDevices;
  193. if (buildBlas || isSkinnedMesh)
  194. {
  195. for (auto submeshIndex = 0; submeshIndex < blasInstance.second.m_subMeshes.size(); ++submeshIndex)
  196. {
  197. auto& submeshBlasInstance = blasInstance.second.m_subMeshes[submeshIndex];
  198. changedBlasList.push_back(submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(context.GetDeviceIndex()).get());
  199. if (buildBlas)
  200. {
  201. // Always build the BLAS, if it has not previously been built
  202. context.GetCommandList()->BuildBottomLevelAccelerationStructure(*submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
  203. continue;
  204. }
  205. // Determine if a skinned mesh BLAS needs to be updated or completely rebuilt. For now, we want to rebuild a BLAS every
  206. // SKINNED_BLAS_REBUILD_FRAME_INTERVAL frames, while updating it all other frames. This is based on the assumption that
  207. // by adding together the asset ID hash, submesh index, and frame count, we get a value that allows us to uniformly
  208. // distribute rebuilding all skinned mesh BLASs over all frames.
  209. auto assetGuid = blasInstance.first.m_guid.GetHash();
  210. if (isSkinnedMesh && (assetGuid + submeshIndex + m_frameCount) % SKINNED_BLAS_REBUILD_FRAME_INTERVAL != 0)
  211. {
  212. // Skinned mesh that simply needs an update
  213. context.GetCommandList()->UpdateBottomLevelAccelerationStructure(
  214. *submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
  215. }
  216. else
  217. {
  218. // Fall back to building the BLAS in any case
  219. context.GetCommandList()->BuildBottomLevelAccelerationStructure(
  220. *submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
  221. }
  222. }
  223. AZStd::lock_guard lock(rayTracingFeatureProcessor->GetBlasBuiltMutex());
  224. blasInstance.second.m_blasBuilt |= RHI::MultiDevice::DeviceMask(1 << context.GetDeviceIndex());
  225. }
  226. }
  227. // build the TLAS object
  228. context.GetCommandList()->BuildTopLevelAccelerationStructure(*rayTracingFeatureProcessor->GetTlas()->GetDeviceRayTracingTlas(context.GetDeviceIndex()), changedBlasList);
  229. ++m_frameCount;
  230. EndScopeQuery(context);
  231. }
  232. void RayTracingAccelerationStructurePass::AddScopeQueryToFrameGraph(RHI::FrameGraphInterface frameGraph)
  233. {
  234. const auto addToFrameGraph = [&frameGraph](RHI::Ptr<RPI::Query> query)
  235. {
  236. query->AddToFrameGraph(frameGraph);
  237. };
  238. ExecuteOnTimestampQuery(addToFrameGraph);
  239. ExecuteOnPipelineStatisticsQuery(addToFrameGraph);
  240. }
  241. void RayTracingAccelerationStructurePass::BeginScopeQuery(const RHI::FrameGraphExecuteContext& context)
  242. {
  243. const auto beginQuery = [&context, this](RHI::Ptr<RPI::Query> query)
  244. {
  245. if (query->BeginQuery(context) == RPI::QueryResultCode::Fail)
  246. {
  247. AZ_UNUSED(this); // Prevent unused warning in release builds
  248. AZ_WarningOnce(
  249. "RayTracingAccelerationStructurePass", false,
  250. "BeginScopeQuery failed. Make sure AddScopeQueryToFrameGraph was called in SetupFrameGraphDependencies"
  251. " for this pass: %s",
  252. this->RTTI_GetTypeName());
  253. }
  254. };
  255. ExecuteOnTimestampQuery(beginQuery);
  256. ExecuteOnPipelineStatisticsQuery(beginQuery);
  257. }
  258. void RayTracingAccelerationStructurePass::EndScopeQuery(const RHI::FrameGraphExecuteContext& context)
  259. {
  260. const auto endQuery = [&context](const RHI::Ptr<RPI::Query>& query)
  261. {
  262. query->EndQuery(context);
  263. };
  264. // This scope query implementation should be replaced by the feature linked below on GitHub:
  265. // [GHI-16945] Feature Request - Add GPU timestamp and pipeline statistic support for scopes
  266. ExecuteOnTimestampQuery(endQuery);
  267. ExecuteOnPipelineStatisticsQuery(endQuery);
  268. m_lastDeviceIndex = context.GetDeviceIndex();
  269. }
  270. void RayTracingAccelerationStructurePass::ReadbackScopeQueryResults()
  271. {
  272. ExecuteOnTimestampQuery(
  273. [this](const RHI::Ptr<RPI::Query>& query)
  274. {
  275. const uint32_t TimestampResultQueryCount{ 2u };
  276. uint64_t timestampResult[TimestampResultQueryCount] = { 0 };
  277. query->GetLatestResult(&timestampResult, sizeof(uint64_t) * TimestampResultQueryCount, m_lastDeviceIndex);
  278. m_timestampResult = RPI::TimestampResult(timestampResult[0], timestampResult[1], RHI::HardwareQueueClass::Compute);
  279. });
  280. ExecuteOnPipelineStatisticsQuery(
  281. [this](const RHI::Ptr<RPI::Query>& query)
  282. {
  283. query->GetLatestResult(&m_statisticsResult, sizeof(RPI::PipelineStatisticsResult), m_lastDeviceIndex);
  284. });
  285. }
  286. } // namespace Render
  287. } // namespace AZ