3
0

RayTracingAccelerationStructurePass.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/Feature/Mesh/MeshFeatureProcessor.h>
  9. #include <Atom/RHI/BufferFrameAttachment.h>
  10. #include <Atom/RHI/BufferScopeAttachment.h>
  11. #include <Atom/RHI/CommandList.h>
  12. #include <Atom/RHI/FrameScheduler.h>
  13. #include <Atom/RHI/RHISystemInterface.h>
  14. #include <Atom/RPI.Public/Buffer/Buffer.h>
  15. #include <Atom/RPI.Public/Buffer/BufferSystemInterface.h>
  16. #include <Atom/RPI.Public/RenderPipeline.h>
  17. #include <Atom/RPI.Public/Scene.h>
  18. #include <RayTracing/RayTracingAccelerationStructurePass.h>
  19. #include <RayTracing/RayTracingFeatureProcessor.h>
  20. namespace AZ
  21. {
  22. namespace Render
  23. {
  24. RPI::Ptr<RayTracingAccelerationStructurePass> RayTracingAccelerationStructurePass::Create(const RPI::PassDescriptor& descriptor)
  25. {
  26. RPI::Ptr<RayTracingAccelerationStructurePass> rayTracingAccelerationStructurePass = aznew RayTracingAccelerationStructurePass(descriptor);
  27. return AZStd::move(rayTracingAccelerationStructurePass);
  28. }
  29. RayTracingAccelerationStructurePass::RayTracingAccelerationStructurePass(const RPI::PassDescriptor& descriptor)
  30. : Pass(descriptor)
  31. {
  32. // disable this pass if we're on a platform that doesn't support raytracing
  33. RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
  34. if (device->GetFeatures().m_rayTracing == false)
  35. {
  36. SetEnabled(false);
  37. }
  38. }
  39. void RayTracingAccelerationStructurePass::BuildInternal()
  40. {
  41. InitScope(RHI::ScopeId(GetPathName()), AZ::RHI::HardwareQueueClass::Compute);
  42. }
  43. void RayTracingAccelerationStructurePass::FrameBeginInternal(FramePrepareParams params)
  44. {
  45. m_timestampResult = RPI::TimestampResult();
  46. if(GetScopeId().IsEmpty())
  47. {
  48. InitScope(RHI::ScopeId(GetPathName()), RHI::HardwareQueueClass::Compute);
  49. }
  50. params.m_frameGraphBuilder->ImportScopeProducer(*this);
  51. ReadbackScopeQueryResults();
  52. }
  53. RHI::Ptr<RPI::Query> RayTracingAccelerationStructurePass::GetQuery(RPI::ScopeQueryType queryType)
  54. {
  55. auto typeIndex{ static_cast<uint32_t>(queryType) };
  56. if (!m_scopeQueries[typeIndex])
  57. {
  58. RHI::Ptr<RPI::Query> query;
  59. switch (queryType)
  60. {
  61. case RPI::ScopeQueryType::Timestamp:
  62. query = RPI::GpuQuerySystemInterface::Get()->CreateQuery(
  63. RHI::QueryType::Timestamp, RHI::QueryPoolScopeAttachmentType::Global, RHI::ScopeAttachmentAccess::Write);
  64. break;
  65. case RPI::ScopeQueryType::PipelineStatistics:
  66. query = RPI::GpuQuerySystemInterface::Get()->CreateQuery(
  67. RHI::QueryType::PipelineStatistics, RHI::QueryPoolScopeAttachmentType::Global,
  68. RHI::ScopeAttachmentAccess::Write);
  69. break;
  70. }
  71. m_scopeQueries[typeIndex] = query;
  72. }
  73. return m_scopeQueries[typeIndex];
  74. }
  75. template<typename Func>
  76. inline void RayTracingAccelerationStructurePass::ExecuteOnTimestampQuery(Func&& func)
  77. {
  78. if (IsTimestampQueryEnabled())
  79. {
  80. auto query{ GetQuery(RPI::ScopeQueryType::Timestamp) };
  81. if (query)
  82. {
  83. func(query);
  84. }
  85. }
  86. }
  87. template<typename Func>
  88. inline void RayTracingAccelerationStructurePass::ExecuteOnPipelineStatisticsQuery(Func&& func)
  89. {
  90. if (IsPipelineStatisticsQueryEnabled())
  91. {
  92. auto query{ GetQuery(RPI::ScopeQueryType::PipelineStatistics) };
  93. if (query)
  94. {
  95. func(query);
  96. }
  97. }
  98. }
  99. RPI::TimestampResult RayTracingAccelerationStructurePass::GetTimestampResultInternal() const
  100. {
  101. return m_timestampResult;
  102. }
  103. RPI::PipelineStatisticsResult RayTracingAccelerationStructurePass::GetPipelineStatisticsResultInternal() const
  104. {
  105. return m_statisticsResult;
  106. }
  107. void RayTracingAccelerationStructurePass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
  108. {
  109. RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
  110. RPI::Scene* scene = m_pipeline->GetScene();
  111. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  112. if (rayTracingFeatureProcessor)
  113. {
  114. if (rayTracingFeatureProcessor->GetRevision() != m_rayTracingRevision)
  115. {
  116. RHI::RayTracingBufferPools& rayTracingBufferPools = rayTracingFeatureProcessor->GetBufferPools();
  117. RayTracingFeatureProcessor::SubMeshVector& subMeshes = rayTracingFeatureProcessor->GetSubMeshes();
  118. // create the TLAS descriptor
  119. RHI::RayTracingTlasDescriptor tlasDescriptor;
  120. RHI::RayTracingTlasDescriptor* tlasDescriptorBuild = tlasDescriptor.Build();
  121. uint32_t instanceIndex = 0;
  122. for (auto& subMesh : subMeshes)
  123. {
  124. tlasDescriptorBuild->Instance()
  125. ->InstanceID(instanceIndex)
  126. ->InstanceMask(subMesh.m_mesh->m_instanceMask)
  127. ->HitGroupIndex(0)
  128. ->Blas(subMesh.m_blas)
  129. ->Transform(subMesh.m_mesh->m_transform)
  130. ->NonUniformScale(subMesh.m_mesh->m_nonUniformScale)
  131. ->Transparent(subMesh.m_material.m_irradianceColor.GetA() < 1.0f)
  132. ;
  133. instanceIndex++;
  134. }
  135. unsigned proceduralHitGroupIndex = 1; // Hit group 0 is used for normal meshes
  136. const auto& proceduralGeometryTypes = rayTracingFeatureProcessor->GetProceduralGeometryTypes();
  137. AZStd::unordered_map<Name, unsigned> geometryTypeMap;
  138. geometryTypeMap.reserve(proceduralGeometryTypes.size());
  139. for (auto it = proceduralGeometryTypes.cbegin(); it != proceduralGeometryTypes.cend(); ++it)
  140. {
  141. geometryTypeMap[it->m_name] = proceduralHitGroupIndex++;
  142. }
  143. for (const auto& proceduralGeometry : rayTracingFeatureProcessor->GetProceduralGeometries())
  144. {
  145. tlasDescriptorBuild->Instance()
  146. ->InstanceID(instanceIndex)
  147. ->InstanceMask(proceduralGeometry.m_instanceMask)
  148. ->HitGroupIndex(geometryTypeMap[proceduralGeometry.m_typeHandle->m_name])
  149. ->Blas(proceduralGeometry.m_blas)
  150. ->Transform(proceduralGeometry.m_transform)
  151. ->NonUniformScale(proceduralGeometry.m_nonUniformScale)
  152. ;
  153. instanceIndex++;
  154. }
  155. // create the TLAS buffers based on the descriptor
  156. RHI::Ptr<RHI::RayTracingTlas>& rayTracingTlas = rayTracingFeatureProcessor->GetTlas();
  157. rayTracingTlas->CreateBuffers(*device, &tlasDescriptor, rayTracingBufferPools);
  158. // import and attach the TLAS buffer
  159. const RHI::Ptr<RHI::Buffer>& rayTracingTlasBuffer = rayTracingTlas->GetTlasBuffer();
  160. if (rayTracingTlasBuffer && rayTracingFeatureProcessor->HasGeometry())
  161. {
  162. AZ::RHI::AttachmentId tlasAttachmentId = rayTracingFeatureProcessor->GetTlasAttachmentId();
  163. if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(tlasAttachmentId) == false)
  164. {
  165. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(tlasAttachmentId, rayTracingTlasBuffer);
  166. AZ_Assert(result == RHI::ResultCode::Success, "Failed to import ray tracing TLAS buffer with error %d", result);
  167. }
  168. uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(rayTracingTlasBuffer->GetDescriptor().m_byteCount);
  169. RHI::BufferViewDescriptor tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount);
  170. RHI::BufferScopeAttachmentDescriptor desc;
  171. desc.m_attachmentId = tlasAttachmentId;
  172. desc.m_bufferViewDescriptor = tlasBufferViewDescriptor;
  173. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::DontCare;
  174. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Write);
  175. }
  176. }
  177. // Attach output data from the skinning pass. This is needed to ensure that this pass is executed after
  178. // the skinning pass has finished. We assume that the pipeline has a skinning pass with this output available.
  179. if (rayTracingFeatureProcessor->GetSkinnedMeshCount() > 0)
  180. {
  181. auto skinningPassPtr = FindAdjacentPass(AZ::Name("SkinningPass"));
  182. auto skinnedMeshOutputStreamBindingPtr = skinningPassPtr->FindAttachmentBinding(AZ::Name("SkinnedMeshOutputStream"));
  183. [[maybe_unused]] auto result = frameGraph.UseShaderAttachment(skinnedMeshOutputStreamBindingPtr->m_unifiedScopeDesc.GetAsBuffer(), RHI::ScopeAttachmentAccess::Read);
  184. AZ_Assert(result == AZ::RHI::ResultCode::Success, "Failed to attach SkinnedMeshOutputStream buffer with error %d", result);
  185. }
  186. // update and compile the RayTracingSceneSrg and RayTracingMaterialSrg
  187. // Note: the timing of this update is very important, it needs to be updated after the TLAS is allocated so it can
  188. // be set on the RayTracingSceneSrg for this frame, and the ray tracing mesh data in the RayTracingSceneSrg must
  189. // exactly match the TLAS. Any mismatch in this data may result in a TDR.
  190. rayTracingFeatureProcessor->UpdateRayTracingSrgs();
  191. AddScopeQueryToFrameGraph(frameGraph);
  192. }
  193. }
  194. void RayTracingAccelerationStructurePass::BuildCommandList(const RHI::FrameGraphExecuteContext& context)
  195. {
  196. RPI::Scene* scene = m_pipeline->GetScene();
  197. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  198. if (!rayTracingFeatureProcessor)
  199. {
  200. return;
  201. }
  202. if (!rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer())
  203. {
  204. return;
  205. }
  206. if (rayTracingFeatureProcessor->GetRevision() == m_rayTracingRevision && rayTracingFeatureProcessor->GetSkinnedMeshCount() == 0)
  207. {
  208. // TLAS is up to date
  209. return;
  210. }
  211. // update the stored revision, even if we don't have any meshes to process
  212. m_rayTracingRevision = rayTracingFeatureProcessor->GetRevision();
  213. if (!rayTracingFeatureProcessor->HasGeometry())
  214. {
  215. // no ray tracing meshes in the scene
  216. return;
  217. }
  218. BeginScopeQuery(context);
  219. // build newly added or skinned BLAS objects
  220. AZStd::vector<const AZ::RHI::RayTracingBlas*> changedBlasList;
  221. RayTracingFeatureProcessor::BlasInstanceMap& blasInstances = rayTracingFeatureProcessor->GetBlasInstances();
  222. for (auto& blasInstance : blasInstances)
  223. {
  224. const bool isSkinnedMesh = blasInstance.second.m_isSkinnedMesh;
  225. if (blasInstance.second.m_blasBuilt == false || isSkinnedMesh)
  226. {
  227. for (auto submeshIndex = 0; submeshIndex < blasInstance.second.m_subMeshes.size(); ++submeshIndex)
  228. {
  229. auto& submeshBlasInstance = blasInstance.second.m_subMeshes[submeshIndex];
  230. changedBlasList.push_back(submeshBlasInstance.m_blas.get());
  231. if (blasInstance.second.m_blasBuilt == false)
  232. {
  233. // Always build the BLAS, if it has not previously been built
  234. context.GetCommandList()->BuildBottomLevelAccelerationStructure(*submeshBlasInstance.m_blas);
  235. continue;
  236. }
  237. // Determine if a skinned mesh BLAS needs to be updated or completely rebuilt. For now, we want to rebuild a BLAS every
  238. // SKINNED_BLAS_REBUILD_FRAME_INTERVAL frames, while updating it all other frames. This is based on the assumption that
  239. // by adding together the asset ID hash, submesh index, and frame count, we get a value that allows us to uniformly
  240. // distribute rebuilding all skinned mesh BLASs over all frames.
  241. auto assetGuid = blasInstance.first.m_guid.GetHash();
  242. if (isSkinnedMesh && (assetGuid + submeshIndex + m_frameCount) % SKINNED_BLAS_REBUILD_FRAME_INTERVAL != 0)
  243. {
  244. // Skinned mesh that simply needs an update
  245. context.GetCommandList()->UpdateBottomLevelAccelerationStructure(*submeshBlasInstance.m_blas);
  246. }
  247. else
  248. {
  249. // Fall back to building the BLAS in any case
  250. context.GetCommandList()->BuildBottomLevelAccelerationStructure(*submeshBlasInstance.m_blas);
  251. }
  252. }
  253. blasInstance.second.m_blasBuilt = true;
  254. }
  255. }
  256. // build the TLAS object
  257. context.GetCommandList()->BuildTopLevelAccelerationStructure(*rayTracingFeatureProcessor->GetTlas(), changedBlasList);
  258. ++m_frameCount;
  259. EndScopeQuery(context);
  260. }
  261. void RayTracingAccelerationStructurePass::AddScopeQueryToFrameGraph(RHI::FrameGraphInterface frameGraph)
  262. {
  263. const auto addToFrameGraph = [&frameGraph](RHI::Ptr<RPI::Query> query)
  264. {
  265. query->AddToFrameGraph(frameGraph);
  266. };
  267. ExecuteOnTimestampQuery(addToFrameGraph);
  268. ExecuteOnPipelineStatisticsQuery(addToFrameGraph);
  269. }
  270. void RayTracingAccelerationStructurePass::BeginScopeQuery(const RHI::FrameGraphExecuteContext& context)
  271. {
  272. const auto beginQuery = [&context, this](RHI::Ptr<RPI::Query> query)
  273. {
  274. if (query->BeginQuery(context) == RPI::QueryResultCode::Fail)
  275. {
  276. AZ_UNUSED(this); // Prevent unused warning in release builds
  277. AZ_WarningOnce(
  278. "RayTracingAccelerationStructurePass", false,
  279. "BeginScopeQuery failed. Make sure AddScopeQueryToFrameGraph was called in SetupFrameGraphDependencies"
  280. " for this pass: %s",
  281. this->RTTI_GetTypeName());
  282. }
  283. };
  284. ExecuteOnTimestampQuery(beginQuery);
  285. ExecuteOnPipelineStatisticsQuery(beginQuery);
  286. }
  287. void RayTracingAccelerationStructurePass::EndScopeQuery(const RHI::FrameGraphExecuteContext& context)
  288. {
  289. const auto endQuery = [&context](const RHI::Ptr<RPI::Query>& query)
  290. {
  291. query->EndQuery(context);
  292. };
  293. // This scope query implementation should be replaced by the feature linked below on GitHub:
  294. // [GHI-16945] Feature Request - Add GPU timestamp and pipeline statistic support for scopes
  295. ExecuteOnTimestampQuery(endQuery);
  296. ExecuteOnPipelineStatisticsQuery(endQuery);
  297. }
  298. void RayTracingAccelerationStructurePass::ReadbackScopeQueryResults()
  299. {
  300. ExecuteOnTimestampQuery(
  301. [this](const RHI::Ptr<RPI::Query>& query)
  302. {
  303. const uint32_t TimestampResultQueryCount{ 2u };
  304. uint64_t timestampResult[TimestampResultQueryCount] = { 0 };
  305. query->GetLatestResult(&timestampResult, sizeof(uint64_t) * TimestampResultQueryCount);
  306. m_timestampResult = RPI::TimestampResult(timestampResult[0], timestampResult[1], RHI::HardwareQueueClass::Graphics);
  307. });
  308. ExecuteOnPipelineStatisticsQuery(
  309. [this](const RHI::Ptr<RPI::Query>& query)
  310. {
  311. query->GetLatestResult(&m_statisticsResult, sizeof(RPI::PipelineStatisticsResult));
  312. });
  313. }
  314. } // namespace RPI
  315. } // namespace AZ