3
0

RayTracingAccelerationStructurePass.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/Feature/Mesh/MeshFeatureProcessor.h>
  9. #include <Atom/RHI/BufferFrameAttachment.h>
  10. #include <Atom/RHI/BufferScopeAttachment.h>
  11. #include <Atom/RHI/CommandList.h>
  12. #include <Atom/RHI/FrameScheduler.h>
  13. #include <Atom/RHI/RHISystemInterface.h>
  14. #include <Atom/RPI.Public/Buffer/Buffer.h>
  15. #include <Atom/RPI.Public/Buffer/BufferSystemInterface.h>
  16. #include <Atom/RPI.Public/RenderPipeline.h>
  17. #include <Atom/RPI.Public/Scene.h>
  18. #include <RayTracing/RayTracingAccelerationStructurePass.h>
  19. #include <RayTracing/RayTracingFeatureProcessor.h>
  20. namespace AZ
  21. {
  22. namespace Render
  23. {
  24. RPI::Ptr<RayTracingAccelerationStructurePass> RayTracingAccelerationStructurePass::Create(const RPI::PassDescriptor& descriptor)
  25. {
  26. RPI::Ptr<RayTracingAccelerationStructurePass> rayTracingAccelerationStructurePass = aznew RayTracingAccelerationStructurePass(descriptor);
  27. return AZStd::move(rayTracingAccelerationStructurePass);
  28. }
  29. RayTracingAccelerationStructurePass::RayTracingAccelerationStructurePass(const RPI::PassDescriptor& descriptor)
  30. : Pass(descriptor)
  31. {
  32. // disable this pass if we're on a platform that doesn't support raytracing
  33. if (RHI::RHISystemInterface::Get()->GetRayTracingSupport() == RHI::MultiDevice::NoDevices)
  34. {
  35. SetEnabled(false);
  36. }
  37. }
  38. void RayTracingAccelerationStructurePass::BuildInternal()
  39. {
  40. InitScope(RHI::ScopeId(GetPathName()), AZ::RHI::HardwareQueueClass::Compute);
  41. }
  42. void RayTracingAccelerationStructurePass::FrameBeginInternal(FramePrepareParams params)
  43. {
  44. m_timestampResult = RPI::TimestampResult();
  45. if(GetScopeId().IsEmpty())
  46. {
  47. InitScope(RHI::ScopeId(GetPathName()), RHI::HardwareQueueClass::Compute);
  48. }
  49. params.m_frameGraphBuilder->ImportScopeProducer(*this);
  50. ReadbackScopeQueryResults();
  51. }
  52. RHI::Ptr<RPI::Query> RayTracingAccelerationStructurePass::GetQuery(RPI::ScopeQueryType queryType)
  53. {
  54. auto typeIndex{ static_cast<uint32_t>(queryType) };
  55. if (!m_scopeQueries[typeIndex])
  56. {
  57. RHI::Ptr<RPI::Query> query;
  58. switch (queryType)
  59. {
  60. case RPI::ScopeQueryType::Timestamp:
  61. query = RPI::GpuQuerySystemInterface::Get()->CreateQuery(
  62. RHI::QueryType::Timestamp, RHI::QueryPoolScopeAttachmentType::Global, RHI::ScopeAttachmentAccess::Write);
  63. break;
  64. case RPI::ScopeQueryType::PipelineStatistics:
  65. query = RPI::GpuQuerySystemInterface::Get()->CreateQuery(
  66. RHI::QueryType::PipelineStatistics, RHI::QueryPoolScopeAttachmentType::Global,
  67. RHI::ScopeAttachmentAccess::Write);
  68. break;
  69. }
  70. m_scopeQueries[typeIndex] = query;
  71. }
  72. return m_scopeQueries[typeIndex];
  73. }
  74. template<typename Func>
  75. inline void RayTracingAccelerationStructurePass::ExecuteOnTimestampQuery(Func&& func)
  76. {
  77. if (IsTimestampQueryEnabled())
  78. {
  79. auto query{ GetQuery(RPI::ScopeQueryType::Timestamp) };
  80. if (query)
  81. {
  82. func(query);
  83. }
  84. }
  85. }
  86. template<typename Func>
  87. inline void RayTracingAccelerationStructurePass::ExecuteOnPipelineStatisticsQuery(Func&& func)
  88. {
  89. if (IsPipelineStatisticsQueryEnabled())
  90. {
  91. auto query{ GetQuery(RPI::ScopeQueryType::PipelineStatistics) };
  92. if (query)
  93. {
  94. func(query);
  95. }
  96. }
  97. }
  98. RPI::TimestampResult RayTracingAccelerationStructurePass::GetTimestampResultInternal() const
  99. {
  100. return m_timestampResult;
  101. }
  102. RPI::PipelineStatisticsResult RayTracingAccelerationStructurePass::GetPipelineStatisticsResultInternal() const
  103. {
  104. return m_statisticsResult;
  105. }
  106. void RayTracingAccelerationStructurePass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
  107. {
  108. RPI::Scene* scene = m_pipeline->GetScene();
  109. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  110. if (rayTracingFeatureProcessor)
  111. {
  112. if (rayTracingFeatureProcessor->GetRevision() != m_rayTracingRevision)
  113. {
  114. RHI::RayTracingBufferPools& rayTracingBufferPools = rayTracingFeatureProcessor->GetBufferPools();
  115. RayTracingFeatureProcessor::SubMeshVector& subMeshes = rayTracingFeatureProcessor->GetSubMeshes();
  116. // create the TLAS descriptor
  117. RHI::RayTracingTlasDescriptor tlasDescriptor;
  118. RHI::RayTracingTlasDescriptor* tlasDescriptorBuild = tlasDescriptor.Build();
  119. uint32_t instanceIndex = 0;
  120. for (auto& subMesh : subMeshes)
  121. {
  122. tlasDescriptorBuild->Instance()
  123. ->InstanceID(instanceIndex)
  124. ->InstanceMask(subMesh.m_mesh->m_instanceMask)
  125. ->HitGroupIndex(0)
  126. ->Blas(subMesh.m_blas)
  127. ->Transform(subMesh.m_mesh->m_transform)
  128. ->NonUniformScale(subMesh.m_mesh->m_nonUniformScale)
  129. ->Transparent(subMesh.m_material.m_irradianceColor.GetA() < 1.0f)
  130. ;
  131. instanceIndex++;
  132. }
  133. unsigned proceduralHitGroupIndex = 1; // Hit group 0 is used for normal meshes
  134. const auto& proceduralGeometryTypes = rayTracingFeatureProcessor->GetProceduralGeometryTypes();
  135. AZStd::unordered_map<Name, unsigned> geometryTypeMap;
  136. geometryTypeMap.reserve(proceduralGeometryTypes.size());
  137. for (auto it = proceduralGeometryTypes.cbegin(); it != proceduralGeometryTypes.cend(); ++it)
  138. {
  139. geometryTypeMap[it->m_name] = proceduralHitGroupIndex++;
  140. }
  141. for (const auto& proceduralGeometry : rayTracingFeatureProcessor->GetProceduralGeometries())
  142. {
  143. tlasDescriptorBuild->Instance()
  144. ->InstanceID(instanceIndex)
  145. ->InstanceMask(proceduralGeometry.m_instanceMask)
  146. ->HitGroupIndex(geometryTypeMap[proceduralGeometry.m_typeHandle->m_name])
  147. ->Blas(proceduralGeometry.m_blas)
  148. ->Transform(proceduralGeometry.m_transform)
  149. ->NonUniformScale(proceduralGeometry.m_nonUniformScale)
  150. ;
  151. instanceIndex++;
  152. }
  153. // create the TLAS buffers based on the descriptor
  154. RHI::Ptr<RHI::RayTracingTlas>& rayTracingTlas = rayTracingFeatureProcessor->GetTlas();
  155. rayTracingTlas->CreateBuffers(RHI::RHISystemInterface::Get()->GetRayTracingSupport(), &tlasDescriptor, rayTracingBufferPools);
  156. // import and attach the TLAS buffer
  157. const RHI::Ptr<RHI::Buffer>& rayTracingTlasBuffer = rayTracingTlas->GetTlasBuffer();
  158. if (rayTracingTlasBuffer && rayTracingFeatureProcessor->HasGeometry())
  159. {
  160. AZ::RHI::AttachmentId tlasAttachmentId = rayTracingFeatureProcessor->GetTlasAttachmentId();
  161. if (frameGraph.GetAttachmentDatabase().IsAttachmentValid(tlasAttachmentId) == false)
  162. {
  163. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(tlasAttachmentId, rayTracingTlasBuffer);
  164. AZ_Assert(result == RHI::ResultCode::Success, "Failed to import ray tracing TLAS buffer with error %d", result);
  165. }
  166. uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(rayTracingTlasBuffer->GetDescriptor().m_byteCount);
  167. RHI::BufferViewDescriptor tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount);
  168. RHI::BufferScopeAttachmentDescriptor desc;
  169. desc.m_attachmentId = tlasAttachmentId;
  170. desc.m_bufferViewDescriptor = tlasBufferViewDescriptor;
  171. desc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::DontCare;
  172. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Write, RHI::ScopeAttachmentStage::RayTracingShader);
  173. }
  174. }
  175. // Attach output data from the skinning pass. This is needed to ensure that this pass is executed after
  176. // the skinning pass has finished. We assume that the pipeline has a skinning pass with this output available.
  177. if (rayTracingFeatureProcessor->GetSkinnedMeshCount() > 0)
  178. {
  179. auto skinningPassPtr = FindAdjacentPass(AZ::Name("SkinningPass"));
  180. auto skinnedMeshOutputStreamBindingPtr = skinningPassPtr->FindAttachmentBinding(AZ::Name("SkinnedMeshOutputStream"));
  181. [[maybe_unused]] auto result = frameGraph.UseShaderAttachment(skinnedMeshOutputStreamBindingPtr->m_unifiedScopeDesc.GetAsBuffer(), RHI::ScopeAttachmentAccess::Read, RHI::ScopeAttachmentStage::RayTracingShader);
  182. AZ_Assert(result == AZ::RHI::ResultCode::Success, "Failed to attach SkinnedMeshOutputStream buffer with error %d", result);
  183. }
  184. // update and compile the RayTracingSceneSrg and RayTracingMaterialSrg
  185. // Note: the timing of this update is very important, it needs to be updated after the TLAS is allocated so it can
  186. // be set on the RayTracingSceneSrg for this frame, and the ray tracing mesh data in the RayTracingSceneSrg must
  187. // exactly match the TLAS. Any mismatch in this data may result in a TDR.
  188. rayTracingFeatureProcessor->UpdateRayTracingSrgs();
  189. AddScopeQueryToFrameGraph(frameGraph);
  190. }
  191. }
  192. void RayTracingAccelerationStructurePass::BuildCommandList(const RHI::FrameGraphExecuteContext& context)
  193. {
  194. RPI::Scene* scene = m_pipeline->GetScene();
  195. RayTracingFeatureProcessor* rayTracingFeatureProcessor = scene->GetFeatureProcessor<RayTracingFeatureProcessor>();
  196. if (!rayTracingFeatureProcessor)
  197. {
  198. return;
  199. }
  200. if (!rayTracingFeatureProcessor->GetTlas()->GetTlasBuffer())
  201. {
  202. return;
  203. }
  204. if (rayTracingFeatureProcessor->GetRevision() == m_rayTracingRevision && rayTracingFeatureProcessor->GetSkinnedMeshCount() == 0)
  205. {
  206. // TLAS is up to date
  207. return;
  208. }
  209. // update the stored revision, even if we don't have any meshes to process
  210. m_rayTracingRevision = rayTracingFeatureProcessor->GetRevision();
  211. if (!rayTracingFeatureProcessor->HasGeometry())
  212. {
  213. // no ray tracing meshes in the scene
  214. return;
  215. }
  216. BeginScopeQuery(context);
  217. // build newly added or skinned BLAS objects
  218. AZStd::vector<const AZ::RHI::DeviceRayTracingBlas*> changedBlasList;
  219. RayTracingFeatureProcessor::BlasInstanceMap& blasInstances = rayTracingFeatureProcessor->GetBlasInstances();
  220. for (auto& blasInstance : blasInstances)
  221. {
  222. const bool isSkinnedMesh = blasInstance.second.m_isSkinnedMesh;
  223. if (blasInstance.second.m_blasBuilt == false || isSkinnedMesh)
  224. {
  225. for (auto submeshIndex = 0; submeshIndex < blasInstance.second.m_subMeshes.size(); ++submeshIndex)
  226. {
  227. auto& submeshBlasInstance = blasInstance.second.m_subMeshes[submeshIndex];
  228. changedBlasList.push_back(submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(context.GetDeviceIndex()).get());
  229. if (blasInstance.second.m_blasBuilt == false)
  230. {
  231. // Always build the BLAS, if it has not previously been built
  232. context.GetCommandList()->BuildBottomLevelAccelerationStructure(*submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
  233. continue;
  234. }
  235. // Determine if a skinned mesh BLAS needs to be updated or completely rebuilt. For now, we want to rebuild a BLAS every
  236. // SKINNED_BLAS_REBUILD_FRAME_INTERVAL frames, while updating it all other frames. This is based on the assumption that
  237. // by adding together the asset ID hash, submesh index, and frame count, we get a value that allows us to uniformly
  238. // distribute rebuilding all skinned mesh BLASs over all frames.
  239. auto assetGuid = blasInstance.first.m_guid.GetHash();
  240. if (isSkinnedMesh && (assetGuid + submeshIndex + m_frameCount) % SKINNED_BLAS_REBUILD_FRAME_INTERVAL != 0)
  241. {
  242. // Skinned mesh that simply needs an update
  243. context.GetCommandList()->UpdateBottomLevelAccelerationStructure(
  244. *submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
  245. }
  246. else
  247. {
  248. // Fall back to building the BLAS in any case
  249. context.GetCommandList()->BuildBottomLevelAccelerationStructure(
  250. *submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
  251. }
  252. }
  253. blasInstance.second.m_blasBuilt = true;
  254. }
  255. }
  256. // build the TLAS object
  257. context.GetCommandList()->BuildTopLevelAccelerationStructure(*rayTracingFeatureProcessor->GetTlas()->GetDeviceRayTracingTlas(context.GetDeviceIndex()), changedBlasList);
  258. ++m_frameCount;
  259. EndScopeQuery(context);
  260. }
  261. void RayTracingAccelerationStructurePass::AddScopeQueryToFrameGraph(RHI::FrameGraphInterface frameGraph)
  262. {
  263. const auto addToFrameGraph = [&frameGraph](RHI::Ptr<RPI::Query> query)
  264. {
  265. query->AddToFrameGraph(frameGraph);
  266. };
  267. ExecuteOnTimestampQuery(addToFrameGraph);
  268. ExecuteOnPipelineStatisticsQuery(addToFrameGraph);
  269. }
  270. void RayTracingAccelerationStructurePass::BeginScopeQuery(const RHI::FrameGraphExecuteContext& context)
  271. {
  272. const auto beginQuery = [&context, this](RHI::Ptr<RPI::Query> query)
  273. {
  274. if (query->BeginQuery(context) == RPI::QueryResultCode::Fail)
  275. {
  276. AZ_UNUSED(this); // Prevent unused warning in release builds
  277. AZ_WarningOnce(
  278. "RayTracingAccelerationStructurePass", false,
  279. "BeginScopeQuery failed. Make sure AddScopeQueryToFrameGraph was called in SetupFrameGraphDependencies"
  280. " for this pass: %s",
  281. this->RTTI_GetTypeName());
  282. }
  283. };
  284. ExecuteOnTimestampQuery(beginQuery);
  285. ExecuteOnPipelineStatisticsQuery(beginQuery);
  286. }
  287. void RayTracingAccelerationStructurePass::EndScopeQuery(const RHI::FrameGraphExecuteContext& context)
  288. {
  289. const auto endQuery = [&context](const RHI::Ptr<RPI::Query>& query)
  290. {
  291. query->EndQuery(context);
  292. };
  293. // This scope query implementation should be replaced by the feature linked below on GitHub:
  294. // [GHI-16945] Feature Request - Add GPU timestamp and pipeline statistic support for scopes
  295. ExecuteOnTimestampQuery(endQuery);
  296. ExecuteOnPipelineStatisticsQuery(endQuery);
  297. m_lastDeviceIndex = context.GetDeviceIndex();
  298. }
  299. void RayTracingAccelerationStructurePass::ReadbackScopeQueryResults()
  300. {
  301. ExecuteOnTimestampQuery(
  302. [this](const RHI::Ptr<RPI::Query>& query)
  303. {
  304. const uint32_t TimestampResultQueryCount{ 2u };
  305. uint64_t timestampResult[TimestampResultQueryCount] = { 0 };
  306. query->GetLatestResult(&timestampResult, sizeof(uint64_t) * TimestampResultQueryCount, m_lastDeviceIndex);
  307. m_timestampResult = RPI::TimestampResult(timestampResult[0], timestampResult[1], RHI::HardwareQueueClass::Graphics);
  308. });
  309. ExecuteOnPipelineStatisticsQuery(
  310. [this](const RHI::Ptr<RPI::Query>& query)
  311. {
  312. query->GetLatestResult(&m_statisticsResult, sizeof(RPI::PipelineStatisticsResult), m_lastDeviceIndex);
  313. });
  314. }
  315. } // namespace Render
  316. } // namespace AZ