Browse Source

One RayTracingAccelerationStructurePass per device (#18313)

Reverts the previous change that created multiple scopes within one pass;
this change creates multiple passes (one per device) instead.

Signed-off-by: Joerg H. Mueller <[email protected]>
jhmueller-huawei 11 months ago
parent
commit
d42709efb5

+ 29 - 118
Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingAccelerationStructurePass.cpp

@@ -42,37 +42,22 @@ namespace AZ
 
         void RayTracingAccelerationStructurePass::BuildInternal()
         {
-            const auto deviceMask{ RHI::RHISystemInterface::Get()->GetRayTracingSupport() };
-            const auto deviceCount{ RHI::RHISystemInterface::Get()->GetDeviceCount() };
-            for (auto deviceIndex{ 0 }; deviceIndex < deviceCount; ++deviceIndex)
-            {
-                if ((AZStd::to_underlying(deviceMask) >> deviceIndex) & 1)
-                {
-                    m_scopeProducers[deviceIndex] = AZStd::make_shared<RHI::ScopeProducerFunctionNoData>(
-                        RHI::ScopeId{ AZStd::string(GetPathName().GetCStr() + AZStd::to_string(deviceIndex)) },
-                        AZStd::bind(&RayTracingAccelerationStructurePass::SetupFrameGraphDependencies, this, AZStd::placeholders::_1),
-                        [](const RHI::FrameGraphCompileContext&)
-                        {
-                        },
-                        AZStd::bind(&RayTracingAccelerationStructurePass::BuildCommandList, this, AZStd::placeholders::_1),
-                        RHI::HardwareQueueClass::Compute,
-                        deviceIndex);
-                }
-            }
+            auto deviceIndex = Pass::GetDeviceIndex();
+            InitScope(
+                RHI::ScopeId(AZStd::string(GetPathName().GetCStr() + AZStd::to_string(deviceIndex))),
+                AZ::RHI::HardwareQueueClass::Compute,
+                deviceIndex);
         }
 
         void RayTracingAccelerationStructurePass::FrameBeginInternal(FramePrepareParams params)
         {
-            const auto deviceMask{ RHI::RHISystemInterface::Get()->GetRayTracingSupport() };
-            const auto deviceCount{ RHI::RHISystemInterface::Get()->GetDeviceCount() };
-            for (auto deviceIndex{ 0 }; deviceIndex < deviceCount; ++deviceIndex)
+            if (GetScopeId().IsEmpty())
             {
-                if ((AZStd::to_underlying(deviceMask) >> deviceIndex) & 1)
-                {
-                    params.m_frameGraphBuilder->ImportScopeProducer(*m_scopeProducers[deviceIndex]);
-                }
+                InitScope(RHI::ScopeId(GetPathName()), RHI::HardwareQueueClass::Compute, Pass::GetDeviceIndex());
             }
 
+            params.m_frameGraphBuilder->ImportScopeProducer(*this);
+
             ReadbackScopeQueryResults();
 
             RPI::Scene* scene = m_pipeline->GetScene();
@@ -80,65 +65,12 @@ namespace AZ
 
             if (rayTracingFeatureProcessor)
             {
-                m_rayTracingRevisionOutDated = rayTracingFeatureProcessor->GetRevision() != m_rayTracingRevision;
+                auto revision = rayTracingFeatureProcessor->BeginFrame();
+                m_rayTracingRevisionOutDated = revision != m_rayTracingRevision;
                 if (m_rayTracingRevisionOutDated)
                 {
-                    m_rayTracingRevision = rayTracingFeatureProcessor->GetRevision();
-
-                    RHI::RayTracingBufferPools& rayTracingBufferPools = rayTracingFeatureProcessor->GetBufferPools();
-                    RayTracingFeatureProcessor::SubMeshVector& subMeshes = rayTracingFeatureProcessor->GetSubMeshes();
-
-                    // create the TLAS descriptor
-                    RHI::RayTracingTlasDescriptor tlasDescriptor;
-                    RHI::RayTracingTlasDescriptor* tlasDescriptorBuild = tlasDescriptor.Build();
-
-                    uint32_t instanceIndex = 0;
-                    for (auto& subMesh : subMeshes)
-                    {
-                        tlasDescriptorBuild->Instance()
-                            ->InstanceID(instanceIndex)
-                            ->InstanceMask(subMesh.m_mesh->m_instanceMask)
-                            ->HitGroupIndex(0)
-                            ->Blas(subMesh.m_blas)
-                            ->Transform(subMesh.m_mesh->m_transform)
-                            ->NonUniformScale(subMesh.m_mesh->m_nonUniformScale)
-                            ->Transparent(subMesh.m_material.m_irradianceColor.GetA() < 1.0f);
-
-                        instanceIndex++;
-                    }
-
-                    unsigned proceduralHitGroupIndex = 1; // Hit group 0 is used for normal meshes
-                    const auto& proceduralGeometryTypes = rayTracingFeatureProcessor->GetProceduralGeometryTypes();
-                    AZStd::unordered_map<Name, unsigned> geometryTypeMap;
-                    geometryTypeMap.reserve(proceduralGeometryTypes.size());
-                    for (auto it = proceduralGeometryTypes.cbegin(); it != proceduralGeometryTypes.cend(); ++it)
-                    {
-                        geometryTypeMap[it->m_name] = proceduralHitGroupIndex++;
-                    }
-
-                    for (const auto& proceduralGeometry : rayTracingFeatureProcessor->GetProceduralGeometries())
-                    {
-                        tlasDescriptorBuild->Instance()
-                            ->InstanceID(instanceIndex)
-                            ->InstanceMask(proceduralGeometry.m_instanceMask)
-                            ->HitGroupIndex(geometryTypeMap[proceduralGeometry.m_typeHandle->m_name])
-                            ->Blas(proceduralGeometry.m_blas)
-                            ->Transform(proceduralGeometry.m_transform)
-                            ->NonUniformScale(proceduralGeometry.m_nonUniformScale);
-                        instanceIndex++;
-                    }
-
-                    // create the TLAS buffers based on the descriptor
-                    RHI::Ptr<RHI::RayTracingTlas>& rayTracingTlas = rayTracingFeatureProcessor->GetTlas();
-                    rayTracingTlas->CreateBuffers(
-                        RHI::RHISystemInterface::Get()->GetRayTracingSupport(), &tlasDescriptor, rayTracingBufferPools);
+                    m_rayTracingRevision = revision;
                 }
-
-                // update and compile the RayTracingSceneSrg and RayTracingMaterialSrg
-                // Note: the timing of this update is very important, it needs to be updated after the TLAS is allocated so it can
-                // be set on the RayTracingSceneSrg for this frame, and the ray tracing mesh data in the RayTracingSceneSrg must
-                // exactly match the TLAS.  Any mismatch in this data may result in a TDR.
-                rayTracingFeatureProcessor->UpdateRayTracingSrgs();
             }
         }
 
@@ -195,26 +127,12 @@ namespace AZ
 
         RPI::TimestampResult RayTracingAccelerationStructurePass::GetTimestampResultInternal() const
         {
-            RPI::TimestampResult result{};
-
-            for (auto& [deviceIndex, timestampResult] : m_timestampResults)
-            {
-                result.Add(timestampResult);
-            }
-
-            return result;
+            return m_timestampResult;
         }
 
         RPI::PipelineStatisticsResult RayTracingAccelerationStructurePass::GetPipelineStatisticsResultInternal() const
         {
-            RPI::PipelineStatisticsResult result{};
-
-            for (auto& [deviceIndex, statisticsResult] : m_statisticsResults)
-            {
-                result += statisticsResult;
-            }
-
-            return result;
+            return m_statisticsResult;
         }
 
         void RayTracingAccelerationStructurePass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
@@ -389,33 +307,26 @@ namespace AZ
             // [GHI-16945] Feature Request - Add GPU timestamp and pipeline statistic support for scopes
             ExecuteOnTimestampQuery(endQuery);
             ExecuteOnPipelineStatisticsQuery(endQuery);
+
+            m_lastDeviceIndex = context.GetDeviceIndex();
         }
 
         void RayTracingAccelerationStructurePass::ReadbackScopeQueryResults()
         {
-            const auto deviceMask{ RHI::RHISystemInterface::Get()->GetRayTracingSupport() };
-            const auto deviceCount{ RHI::RHISystemInterface::Get()->GetDeviceCount() };
-            for (auto deviceIndex{ 0 }; deviceIndex < deviceCount; ++deviceIndex)
-            {
-                if ((AZStd::to_underlying(deviceMask) >> deviceIndex) & 1)
+            ExecuteOnTimestampQuery(
+                [this](const RHI::Ptr<RPI::Query>& query)
                 {
-                    ExecuteOnTimestampQuery(
-                        [this, deviceIndex](const RHI::Ptr<RPI::Query>& query)
-                        {
-                            const uint32_t TimestampResultQueryCount{ 2u };
-                            uint64_t timestampResult[TimestampResultQueryCount] = { 0 };
-                            query->GetLatestResult(&timestampResult, sizeof(uint64_t) * TimestampResultQueryCount, deviceIndex);
-                            m_timestampResults[deviceIndex] =
-                                RPI::TimestampResult(timestampResult[0], timestampResult[1], RHI::HardwareQueueClass::Compute);
-                        });
-
-                    ExecuteOnPipelineStatisticsQuery(
-                        [this, deviceIndex](const RHI::Ptr<RPI::Query>& query)
-                        {
-                            query->GetLatestResult(&m_statisticsResults[deviceIndex], sizeof(RPI::PipelineStatisticsResult), deviceIndex);
-                        });
-                }
-            }
+                    const uint32_t TimestampResultQueryCount{ 2u };
+                    uint64_t timestampResult[TimestampResultQueryCount] = { 0 };
+                    query->GetLatestResult(&timestampResult, sizeof(uint64_t) * TimestampResultQueryCount, m_lastDeviceIndex);
+                    m_timestampResult = RPI::TimestampResult(timestampResult[0], timestampResult[1], RHI::HardwareQueueClass::Compute);
+                });
+
+            ExecuteOnPipelineStatisticsQuery(
+                [this](const RHI::Ptr<RPI::Query>& query)
+                {
+                    query->GetLatestResult(&m_statisticsResult, sizeof(RPI::PipelineStatisticsResult), m_lastDeviceIndex);
+                });
         }
     }   // namespace RPI
 }   // namespace AZ

+ 10 - 8
Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingAccelerationStructurePass.h

@@ -17,7 +17,9 @@ namespace AZ
     namespace Render
     {
         //! This pass builds the RayTracing acceleration structures for a scene
-        class RayTracingAccelerationStructurePass final : public RPI::Pass
+        class RayTracingAccelerationStructurePass final
+            : public RPI::Pass
+            , public RHI::ScopeProducer
         {
         public:
             AZ_RPI_PASS(RayTracingAccelerationStructurePass);
@@ -38,8 +40,8 @@ namespace AZ
             explicit RayTracingAccelerationStructurePass(const RPI::PassDescriptor& descriptor);
 
             // Scope producer functions
-            void SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph);
-            void BuildCommandList(const RHI::FrameGraphExecuteContext& context);
+            void SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph) override;
+            void BuildCommandList(const RHI::FrameGraphExecuteContext& context) override;
 
             // Pass overrides
             void BuildInternal() override;
@@ -69,9 +71,6 @@ namespace AZ
             // Used to set some build options for the TLASes
             static AZ::RHI::RayTracingAccelerationStructureBuildFlags CreateRayTracingAccelerationStructureBuildFlags(bool isSkinnedMesh);
 
-            // Scope producers for each device
-            AZStd::unordered_map<int, AZStd::shared_ptr<AZ::RHI::ScopeProducer>> m_scopeProducers;
-
             // buffer view descriptor for the TLAS
             RHI::BufferViewDescriptor m_tlasBufferViewDescriptor;
 
@@ -88,10 +87,13 @@ namespace AZ
             static constexpr uint32_t SKINNED_BLAS_REBUILD_FRAME_INTERVAL = 8;
 
             // Readback results from the Timestamp queries
-            AZStd::unordered_map<int, AZ::RPI::TimestampResult> m_timestampResults;
+            AZ::RPI::TimestampResult m_timestampResult{};
 
             // Readback results from the PipelineStatistics queries
-            AZStd::unordered_map<int, AZ::RPI::PipelineStatisticsResult> m_statisticsResults;
+            AZ::RPI::PipelineStatisticsResult m_statisticsResult{};
+
+            // The device index the pass ran on during the last frame, necessary to read the queries.
+            int m_lastDeviceIndex = RHI::MultiDevice::DefaultDeviceIndex;
 
             // For each ScopeProducer an instance of the ScopeQuery is created, which consists
             // of a Timestamp and PipelineStatistic query.

+ 111 - 6
Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingFeatureProcessor.cpp

@@ -748,6 +748,65 @@ namespace AZ
             }
         }
 
+        uint32_t RayTracingFeatureProcessor::BeginFrame()
+        {
+            if (m_tlasRevision != m_revision)
+            {
+                m_tlasRevision = m_revision;
+
+                // create the TLAS descriptor
+                RHI::RayTracingTlasDescriptor tlasDescriptor;
+                RHI::RayTracingTlasDescriptor* tlasDescriptorBuild = tlasDescriptor.Build();
+
+                uint32_t instanceIndex = 0;
+                for (auto& subMesh : m_subMeshes)
+                {
+                    tlasDescriptorBuild->Instance()
+                        ->InstanceID(instanceIndex)
+                        ->InstanceMask(subMesh.m_mesh->m_instanceMask)
+                        ->HitGroupIndex(0)
+                        ->Blas(subMesh.m_blas)
+                        ->Transform(subMesh.m_mesh->m_transform)
+                        ->NonUniformScale(subMesh.m_mesh->m_nonUniformScale)
+                        ->Transparent(subMesh.m_material.m_irradianceColor.GetA() < 1.0f);
+
+                    instanceIndex++;
+                }
+
+                unsigned proceduralHitGroupIndex = 1; // Hit group 0 is used for normal meshes
+                AZStd::unordered_map<Name, unsigned> geometryTypeMap;
+                geometryTypeMap.reserve(m_proceduralGeometryTypes.size());
+                for (auto it = m_proceduralGeometryTypes.cbegin(); it != m_proceduralGeometryTypes.cend(); ++it)
+                {
+                    geometryTypeMap[it->m_name] = proceduralHitGroupIndex++;
+                }
+
+                for (const auto& proceduralGeometry : m_proceduralGeometry)
+                {
+                    tlasDescriptorBuild->Instance()
+                        ->InstanceID(instanceIndex)
+                        ->InstanceMask(proceduralGeometry.m_instanceMask)
+                        ->HitGroupIndex(geometryTypeMap[proceduralGeometry.m_typeHandle->m_name])
+                        ->Blas(proceduralGeometry.m_blas)
+                        ->Transform(proceduralGeometry.m_transform)
+                        ->NonUniformScale(proceduralGeometry.m_nonUniformScale);
+                    instanceIndex++;
+                }
+
+                // create the TLAS buffers based on the descriptor
+                RHI::Ptr<RHI::RayTracingTlas>& rayTracingTlas = m_tlas;
+                rayTracingTlas->CreateBuffers(RHI::RHISystemInterface::Get()->GetRayTracingSupport(), &tlasDescriptor, *m_bufferPools);
+            }
+
+            // update and compile the RayTracingSceneSrg and RayTracingMaterialSrg
+            // Note: the timing of this update is very important, it needs to be updated after the TLAS is allocated so it can
+            // be set on the RayTracingSceneSrg for this frame, and the ray tracing mesh data in the RayTracingSceneSrg must
+            // exactly match the TLAS.  Any mismatch in this data may result in a TDR.
+            UpdateRayTracingSrgs();
+
+            return m_revision;
+        }
+
         void RayTracingFeatureProcessor::UpdateRayTracingSrgs()
         {
             AZ_PROFILE_SCOPE(AzRender, "RayTracingFeatureProcessor::UpdateRayTracingSrgs");
@@ -1067,19 +1126,65 @@ namespace AZ
                 return;
             }
 
-            // only enable the RayTracingAccelerationStructurePass on the first pipeline in this scene, this will avoid multiple updates to the same AS
-            bool enabled = true;
+            // only enable the RayTracingAccelerationStructurePass for each device on the first pipeline in this scene, this will avoid
+            // multiple updates to the same AS
             if (changeType == RPI::SceneNotification::RenderPipelineChangeType::Added
                 || changeType == RPI::SceneNotification::RenderPipelineChangeType::Removed)
             {
-                AZ::RPI::PassFilter passFilter = AZ::RPI::PassFilter::CreateWithPassName(AZ::Name("RayTracingAccelerationStructurePass"), GetParentScene());
-                AZ::RPI::PassSystemInterface::Get()->ForEachPass(passFilter, [&enabled](AZ::RPI::Pass* pass) -> AZ::RPI::PassFilterExecutionFlow
+                AZ::RPI::Pass* firstRayTracingAccelerationStructurePass{ nullptr };
+                auto rayTracingDeviceMask{ RHI::RHISystemInterface::Get()->GetRayTracingSupport() };
+                AZ::RHI::MultiDevice::DeviceMask devicesToAdd{ rayTracingDeviceMask };
+
+                AZ::RPI::PassFilter passFilter =
+                    AZ::RPI::PassFilter::CreateWithTemplateName(AZ::Name("RayTracingAccelerationStructurePassTemplate"), GetParentScene());
+                AZ::RPI::PassSystemInterface::Get()->ForEachPass(
+                    passFilter,
+                    [&devicesToAdd, &firstRayTracingAccelerationStructurePass, &rayTracingDeviceMask](
+                        AZ::RPI::Pass* pass) -> AZ::RPI::PassFilterExecutionFlow
                     {
-                        pass->SetEnabled(enabled);
-                        enabled = false;
+                        if (!firstRayTracingAccelerationStructurePass)
+                        {
+                            firstRayTracingAccelerationStructurePass = pass;
+                        }
+
+                        // we always set an invalid device index to the first available device
+                        if (pass->GetDeviceIndex() == RHI::MultiDevice::InvalidDeviceIndex)
+                        {
+                            pass->SetDeviceIndex(az_ctz_u32(AZStd::to_underlying(rayTracingDeviceMask)));
+                        }
+
+                        auto mask = RHI::MultiDevice::DeviceMask(AZ_BIT(pass->GetDeviceIndex()));
+
+                        // only have one RayTracingAccelerationStructurePass per device
+                        pass->SetEnabled((mask & devicesToAdd) != RHI::MultiDevice::NoDevices);
+                        devicesToAdd &= ~mask;
 
                         return AZ::RPI::PassFilterExecutionFlow::ContinueVisitingPasses;
                     });
+
+                // we only add the passes on the other devices if the pipeline contains one in the first place
+                if (firstRayTracingAccelerationStructurePass)
+                {
+                    // add passes for the remaining devices
+                    while (devicesToAdd != RHI::MultiDevice::NoDevices)
+                    {
+                        auto deviceIndex{ az_ctz_u32(AZStd::to_underlying(devicesToAdd)) };
+
+                        AZStd::shared_ptr<RPI::PassRequest> passRequest = AZStd::make_shared<RPI::PassRequest>();
+                        passRequest->m_templateName = Name("RayTracingAccelerationStructurePassTemplate");
+                        passRequest->m_passName = Name("RayTracingAccelerationStructurePass" + AZStd::to_string(deviceIndex));
+
+                        AZStd::shared_ptr<RPI::PassData> passData = AZStd::make_shared<RPI::PassData>();
+                        passData->m_deviceIndex = deviceIndex;
+                        passRequest->m_passData = passData;
+
+                        auto pass = RPI::PassSystemInterface::Get()->CreatePassFromRequest(passRequest.get());
+
+                        renderPipeline->AddPassAfter(pass, firstRayTracingAccelerationStructurePass->GetName());
+
+                        devicesToAdd &= RHI::MultiDevice::DeviceMask(~AZ_BIT(deviceIndex));
+                    }
+                }
             }
         }
 

+ 9 - 1
Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingFeatureProcessor.h

@@ -375,7 +375,12 @@ namespace AZ
             //! Retrieves the GPU buffer containing information for all ray tracing materials.
             const Data::Instance<RPI::Buffer> GetMaterialInfoGpuBuffer() const { return m_materialInfoGpuBuffer.GetCurrentBuffer(); }
 
-            //! Updates the RayTracingSceneSrg and RayTracingMaterialSrg, called after the TLAS is allocated in the RayTracingAccelerationStructurePass
+            //! If necessary recreates TLAS buffers and updates the ray tracing SRGs. Should only be called by the
+            //! RayTracingAccelerationStructurePass. Returns the current revision.
+            uint32_t BeginFrame();
+
+            //! Updates the RayTracingSceneSrg and RayTracingMaterialSrg, called after the TLAS is allocated in the
+            //! RayTracingAccelerationStructurePass
             void UpdateRayTracingSrgs();
 
             struct SubMeshBlasInstance
@@ -435,6 +440,9 @@ namespace AZ
             // current revision number of ray tracing data
             uint32_t m_revision = 0;
 
+            // latest tlas revision number
+            uint32_t m_tlasRevision = 0;
+
             uint32_t m_proceduralGeometryTypeRevision = 0;
 
             // total number of ray tracing sub-meshes