Parcourir la source

Merge pull request #59 from aws-lumberyard-dev/Atom/dmcdiar/ATOM-13575

[ATOM-13575] Move the RHI::RayTracingShaderTable build into the RHI frame
dmcdiarmid-ly il y a 4 ans
Parent
commit
b0b80302c7

+ 8 - 6
Gem/Code/Source/Passes/RayTracingAmbientOcclusionPass.cpp

@@ -128,19 +128,21 @@ namespace AZ
 
             if (!m_rayTracingShaderTable)
             {
+                RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
+                RHI::RayTracingBufferPools& rayTracingBufferPools = m_rayTracingFeatureProcessor->GetBufferPools();
+
                 // Build shader table once. Since we are not using local srg so we don't need to rebuild it even when scene changed 
                 m_rayTracingShaderTable = RHI::Factory::Get().CreateRayTracingShaderTable();
-                RHI::RayTracingShaderTableDescriptor descriptor;
-                descriptor.Build(AZ::Name("RayTracingAOShaderTable"), m_rayTracingPipelineState)
+                m_rayTracingShaderTable->Init(*device.get(), rayTracingBufferPools);
+
+                AZStd::shared_ptr<RHI::RayTracingShaderTableDescriptor> descriptor = AZStd::make_shared<RHI::RayTracingShaderTableDescriptor>();
+                descriptor->Build(AZ::Name("RayTracingAOShaderTable"), m_rayTracingPipelineState)
                     ->RayGenerationRecord(AZ::Name("AoRayGen"))
                     ->MissRecord(AZ::Name("AoMiss"))
                     ->HitGroupRecord(AZ::Name("ClosestHitGroup"))
                     ;
 
-                RHI::Ptr<RHI::Device> device = RHI::RHISystemInterface::Get()->GetDevice();
-                RHI::RayTracingBufferPools& rayTracingBufferPools = m_rayTracingFeatureProcessor->GetBufferPools();
-
-                m_rayTracingShaderTable->Init(*device.get(), &descriptor, rayTracingBufferPools);
+                m_rayTracingShaderTable->Build(descriptor);
             }
 
             RenderPass::FrameBeginInternal(params);

+ 51 - 76
Gem/Code/Source/RHI/RayTracingExampleComponent.cpp

@@ -42,6 +42,30 @@ namespace AtomSampleViewer
         m_supportRHISamplePipeline = true;
     }
 
+    void RayTracingExampleComponent::Activate()
+    {
+        CreateResourcePools();
+        CreateGeometry();
+        CreateFullScreenBuffer();
+        CreateOutputTexture();
+        CreateRasterShader();
+        CreateRayTracingAccelerationStructureObjects();
+        CreateRayTracingPipelineState();
+        CreateRayTracingShaderTable();
+        CreateRayTracingAccelerationTableScope();
+        CreateRayTracingDispatchScope();
+        CreateRasterScope();
+
+        RHI::RHISystemNotificationBus::Handler::BusConnect();
+    }
+
+    void RayTracingExampleComponent::Deactivate()
+    {
+        RHI::RHISystemNotificationBus::Handler::BusDisconnect();
+        m_windowContext = nullptr;
+        m_scopeProducers.clear();
+    }
+
     void RayTracingExampleComponent::CreateResourcePools()
     {
         RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
@@ -295,55 +319,11 @@ namespace AtomSampleViewer
         m_rayTracingPipelineState->Init(*device.get(), &descriptor);
     }
 
-    void RayTracingExampleComponent::CreateRayTracingShaderTableScope()
+    void RayTracingExampleComponent::CreateRayTracingShaderTable()
     {
-        struct ScopeData
-        {
-        };
-
-        const auto prepareFunction = [this]([[maybe_unused]] RHI::FrameGraphInterface& scopeBuilder, [[maybe_unused]] ScopeData& scopeData)
-        {
-        };
-
-        const auto compileFunction = [this]([[maybe_unused]] const RHI::FrameGraphCompileContext& context, [[maybe_unused]] const ScopeData& scopeData)
-        {
-        };
-
-        const auto executeFunction = [this]([[maybe_unused]] const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
-        {
-            RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
-
-            // build the ray tracing shader table descriptor
-            RHI::RayTracingShaderTableDescriptor descriptor;
-            descriptor.Build(AZ::Name("RayTracingExampleShaderTable"), m_rayTracingPipelineState)
-                ->RayGenerationRecord(AZ::Name("RayGenerationShader"))
-                ->MissRecord(AZ::Name("MissShader"))
-                ->HitGroupRecord(AZ::Name("HitGroupGradient")) // triangle1
-                ->HitGroupRecord(AZ::Name("HitGroupGradient")) // triangle2
-                ->HitGroupRecord(AZ::Name("HitGroupSolid")) // triangle3
-                ->HitGroupRecord(AZ::Name("HitGroupSolid")) // rectangle
-            ;
-
-            // initialize the ray tracing shader table object
-            m_rayTracingShaderTable->Init(*device.get(), &descriptor, *m_rayTracingBufferPools);
-        };
-
-        // create the shader table once, outside of the scope
+        RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
         m_rayTracingShaderTable = RHI::Factory::Get().CreateRayTracingShaderTable();
-
-        m_scopeProducers.emplace_back(
-            aznew RHI::ScopeProducerFunction<
-            ScopeData,
-            decltype(prepareFunction),
-            decltype(compileFunction),
-            decltype(executeFunction)>(
-                RHI::ScopeId{ "RayTracingBuildShaderTable" },
-                ScopeData{},
-                prepareFunction,
-                compileFunction,
-                executeFunction));
-
-        m_shaderTableScopeId = m_scopeProducers.back()->GetScopeId();
+        m_rayTracingShaderTable->Init(*device.get(), *m_rayTracingBufferPools);
     }
 
     void RayTracingExampleComponent::CreateRayTracingAccelerationTableScope()
@@ -463,6 +443,16 @@ namespace AtomSampleViewer
             m_rayTracingTlas->CreateBuffers(*device, &tlasDescriptor, *m_rayTracingBufferPools);
 
             m_tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRaw(0, (uint32_t)m_rayTracingTlas->GetTlasBuffer()->GetDescriptor().m_byteCount);
+
+            [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(m_tlasBufferAttachmentId, m_rayTracingTlas->GetTlasBuffer());
+            AZ_Error(RayTracingExampleName, result == RHI::ResultCode::Success, "Failed to import TLAS buffer with error %d", result);
+
+            RHI::BufferScopeAttachmentDescriptor desc;
+            desc.m_attachmentId = m_tlasBufferAttachmentId;
+            desc.m_bufferViewDescriptor = m_tlasBufferViewDescriptor;
+            desc.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
+
+            frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite);
         };
 
         RHI::EmptyCompileFunction<ScopeData> compileFunction;
@@ -513,18 +503,14 @@ namespace AtomSampleViewer
             // attach TLAS buffer
             if (m_rayTracingTlas->GetTlasBuffer())
             {
-                [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(m_tlasBufferAttachmentId, m_rayTracingTlas->GetTlasBuffer());
-                AZ_Error(RayTracingExampleName, result == RHI::ResultCode::Success, "Failed to import TLAS buffer with error %d", result);
-
                 RHI::BufferScopeAttachmentDescriptor desc;
                 desc.m_attachmentId = m_tlasBufferAttachmentId;
                 desc.m_bufferViewDescriptor = m_tlasBufferViewDescriptor;
                 desc.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
 
-                frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Read);
+                frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite);
             }
 
-            frameGraph.ExecuteAfter(m_shaderTableScopeId);
             frameGraph.SetEstimatedItemCount(1);
         };
 
@@ -586,6 +572,19 @@ namespace AtomSampleViewer
 
                 m_globalSrg->SetConstantArray(hitSolidDataConstantIndex, hitSolidData);
                 m_globalSrg->Compile();
+
+                // update the ray tracing shader table
+                AZStd::shared_ptr<RHI::RayTracingShaderTableDescriptor> descriptor = AZStd::make_shared<RHI::RayTracingShaderTableDescriptor>();
+                descriptor->Build(AZ::Name("RayTracingExampleShaderTable"), m_rayTracingPipelineState)
+                    ->RayGenerationRecord(AZ::Name("RayGenerationShader"))
+                    ->MissRecord(AZ::Name("MissShader"))
+                    ->HitGroupRecord(AZ::Name("HitGroupGradient")) // triangle1
+                    ->HitGroupRecord(AZ::Name("HitGroupGradient")) // triangle2
+                    ->HitGroupRecord(AZ::Name("HitGroupSolid")) // triangle3
+                    ->HitGroupRecord(AZ::Name("HitGroupSolid")) // rectangle
+                    ;
+
+                m_rayTracingShaderTable->Build(descriptor);
             }
         };
 
@@ -704,28 +703,4 @@ namespace AtomSampleViewer
                 compileFunction,
                 executeFunction));
     }
-
-    void RayTracingExampleComponent::Activate()
-    {
-        CreateResourcePools();
-        CreateGeometry();
-        CreateFullScreenBuffer();
-        CreateOutputTexture();
-        CreateRasterShader();
-        CreateRayTracingAccelerationStructureObjects();
-        CreateRayTracingPipelineState();
-        CreateRayTracingShaderTableScope();
-        CreateRayTracingAccelerationTableScope();
-        CreateRayTracingDispatchScope();
-        CreateRasterScope();
-
-        RHI::RHISystemNotificationBus::Handler::BusConnect();
-    }
-
-    void RayTracingExampleComponent::Deactivate()
-    {
-        RHI::RHISystemNotificationBus::Handler::BusDisconnect();
-        m_windowContext = nullptr;
-        m_scopeProducers.clear();
-    }
 } // namespace AtomSampleViewer

+ 1 - 2
Gem/Code/Source/RHI/RayTracingExampleComponent.h

@@ -61,7 +61,7 @@ namespace AtomSampleViewer
         void CreateRasterShader();
         void CreateRayTracingAccelerationStructureObjects();
         void CreateRayTracingPipelineState();
-        void CreateRayTracingShaderTableScope();
+        void CreateRayTracingShaderTable();
         void CreateRayTracingAccelerationTableScope();
         void CreateRayTracingDispatchScope();
         void CreateRasterScope();
@@ -106,7 +106,6 @@ namespace AtomSampleViewer
 
         // ray tracing shader table
         RHI::Ptr<RHI::RayTracingShaderTable> m_rayTracingShaderTable;
-        RHI::ScopeId m_shaderTableScopeId;
 
         // ray tracing global shader resource group and pipeline state
         Data::Instance<RPI::ShaderResourceGroup> m_globalSrg;