Selaa lähdekoodia

Fix validation warnings due to incorrect buffer bind flags for the AABB staging buffer when building a BLAS with procedural geometry (#18012)

Signed-off-by: Markus Prettner <[email protected]>
Markus Prettner 1 vuosi sitten
vanhempi
commit
1f2621241b

+ 3 - 0
Gems/Atom/RHI/Code/Include/Atom/RHI/RayTracingBufferPools.h

@@ -31,6 +31,7 @@ namespace AZ::RHI
         // accessors
         const RHI::Ptr<RHI::BufferPool>& GetShaderTableBufferPool() const;
         const RHI::Ptr<RHI::BufferPool>& GetScratchBufferPool() const;
+        const RHI::Ptr<RHI::BufferPool>& GetAabbStagingBufferPool() const;
         const RHI::Ptr<RHI::BufferPool>& GetBlasBufferPool() const;
         const RHI::Ptr<RHI::BufferPool>& GetTlasInstancesBufferPool() const;
         const RHI::Ptr<RHI::BufferPool>& GetTlasBufferPool() const;
@@ -43,6 +44,7 @@ namespace AZ::RHI
 
         virtual RHI::BufferBindFlags GetShaderTableBufferBindFlags() const { return RHI::BufferBindFlags::ShaderRead | RHI::BufferBindFlags::CopyRead | RHI::BufferBindFlags::RayTracingShaderTable; }
         virtual RHI::BufferBindFlags GetScratchBufferBindFlags() const { return RHI::BufferBindFlags::ShaderReadWrite | RHI::BufferBindFlags::RayTracingScratchBuffer; }
+        virtual RHI::BufferBindFlags GetAabbStagingBufferBindFlags() const { return RHI::BufferBindFlags::CopyRead; }
         virtual RHI::BufferBindFlags GetBlasBufferBindFlags() const { return RHI::BufferBindFlags::ShaderReadWrite | RHI::BufferBindFlags::RayTracingAccelerationStructure; }
         virtual RHI::BufferBindFlags GetTlasInstancesBufferBindFlags() const { return RHI::BufferBindFlags::ShaderReadWrite; }
         virtual RHI::BufferBindFlags GetTlasBufferBindFlags() const { return RHI::BufferBindFlags::RayTracingAccelerationStructure; }
@@ -51,6 +53,7 @@ namespace AZ::RHI
         bool m_initialized = false;
         RHI::Ptr<RHI::BufferPool> m_shaderTableBufferPool;
         RHI::Ptr<RHI::BufferPool> m_scratchBufferPool;
+        RHI::Ptr<RHI::BufferPool> m_aabbStagingBufferPool;
         RHI::Ptr<RHI::BufferPool> m_blasBufferPool;
         RHI::Ptr<RHI::BufferPool> m_tlasInstancesBufferPool;
         RHI::Ptr<RHI::BufferPool> m_tlasBufferPool;

+ 18 - 0
Gems/Atom/RHI/Code/Source/RHI/RayTracingBufferPools.cpp

@@ -31,6 +31,12 @@ namespace AZ::RHI
         return m_scratchBufferPool;
     }
 
+    const RHI::Ptr<RHI::BufferPool>& RayTracingBufferPools::GetAabbStagingBufferPool() const
+    {
+        AZ_Assert(m_initialized, "RayTracingBufferPools was not initialized");
+        return m_aabbStagingBufferPool;
+    }
+
     const RHI::Ptr<RHI::BufferPool>& RayTracingBufferPools::GetBlasBufferPool() const
     {
         AZ_Assert(m_initialized, "RayTracingBufferPools was not initialized");
@@ -80,6 +86,18 @@ namespace AZ::RHI
             AZ_Assert(resultCode == RHI::ResultCode::Success, "Failed to initialize ray tracing scratch buffer pool");
         }
 
+        // create AABB buffer pool
+        {
+            RHI::BufferPoolDescriptor bufferPoolDesc;
+            bufferPoolDesc.m_heapMemoryLevel = RHI::HeapMemoryLevel::Device;
+            bufferPoolDesc.m_bindFlags = GetAabbStagingBufferBindFlags();
+
+            m_aabbStagingBufferPool = RHI::Factory::Get().CreateBufferPool();
+            m_aabbStagingBufferPool->SetName(Name("RayTracingAabbStagingBufferPool"));
+            [[maybe_unused]] RHI::ResultCode resultCode = m_aabbStagingBufferPool->Init(*device, bufferPoolDesc);
+            AZ_Assert(resultCode == RHI::ResultCode::Success, "Failed to initialize ray tracing AABB staging buffer pool");
+        }
+
         // create BLAS buffer pool
         {
             RHI::BufferPoolDescriptor bufferPoolDesc;

+ 2 - 2
Gems/Atom/RHI/DX12/Code/Source/RHI/RayTracingBlas.cpp

@@ -40,7 +40,7 @@ namespace AZ
                 const AZ::Aabb& aabb = descriptor->GetAABB();
                 buffers.m_aabbBuffer = RHI::Factory::Get().CreateBuffer();
                 AZ::RHI::BufferDescriptor blasBufferDescriptor;
-                blasBufferDescriptor.m_bindFlags = RHI::BufferBindFlags::ShaderReadWrite | RHI::BufferBindFlags::RayTracingAccelerationStructure;
+                blasBufferDescriptor.m_bindFlags = RHI::BufferBindFlags::CopyRead;
                 blasBufferDescriptor.m_byteCount = sizeof(D3D12_RAYTRACING_AABB);
                 blasBufferDescriptor.m_alignment = D3D12_RAYTRACING_AABB_BYTE_ALIGNMENT;
 
@@ -56,7 +56,7 @@ namespace AZ
                 blasBufferRequest.m_buffer = buffers.m_aabbBuffer.get();
                 blasBufferRequest.m_initialData = &rtAabb;
                 blasBufferRequest.m_descriptor = blasBufferDescriptor;
-                auto resultCode = bufferPools.GetBlasBufferPool()->InitBuffer(blasBufferRequest);
+                auto resultCode = bufferPools.GetAabbStagingBufferPool()->InitBuffer(blasBufferRequest);
                 if (resultCode != AZ::RHI::ResultCode::Success)
                 {
                     AZ_Error("RayTracing", false, "Failed to initialize BLAS buffer index buffer with error code: %d", resultCode);

+ 2 - 2
Gems/Atom/RHI/Vulkan/Code/Source/RHI/RayTracingBlas.cpp

@@ -51,7 +51,7 @@ namespace AZ
                 const AZ::Aabb& aabb = descriptor->GetAABB();
                 buffers.m_aabbBuffer = RHI::Factory::Get().CreateBuffer();
                 AZ::RHI::BufferDescriptor blasBufferDescriptor;
-                blasBufferDescriptor.m_bindFlags = RHI::BufferBindFlags::ShaderReadWrite | RHI::BufferBindFlags::RayTracingAccelerationStructure;
+                blasBufferDescriptor.m_bindFlags = RHI::BufferBindFlags::CopyRead | RHI::BufferBindFlags::RayTracingAccelerationStructure;
                 blasBufferDescriptor.m_byteCount = sizeof(VkAabbPositionsKHR);
                 blasBufferDescriptor.m_alignment = RHI::AlignUp(sizeof(VkAabbPositionsKHR), 8);
 
@@ -67,7 +67,7 @@ namespace AZ
                 blasBufferRequest.m_buffer = buffers.m_aabbBuffer.get();
                 blasBufferRequest.m_initialData = &rtAabb;
                 blasBufferRequest.m_descriptor = blasBufferDescriptor;
-                auto resultCode = bufferPools.GetBlasBufferPool()->InitBuffer(blasBufferRequest);
+                auto resultCode = bufferPools.GetAabbStagingBufferPool()->InitBuffer(blasBufferRequest);
                 if (resultCode != AZ::RHI::ResultCode::Success)
                 {
                     AZ_Error("RayTracing", false, "Failed to initialize BLAS buffer index buffer with error code: %d", resultCode);

+ 1 - 0
Gems/Atom/RHI/Vulkan/Code/Source/RHI/RayTracingBufferPools.h

@@ -24,6 +24,7 @@ namespace AZ
             static RHI::Ptr<RayTracingBufferPools> Create() { return aznew RayTracingBufferPools; }
 
         protected:
+            RHI::BufferBindFlags GetAabbStagingBufferBindFlags() const override { return RHI::BufferBindFlags::CopyRead | RHI::BufferBindFlags::RayTracingAccelerationStructure; }
             RHI::BufferBindFlags GetShaderTableBufferBindFlags() const override { return RHI::BufferBindFlags::CopyRead | RHI::BufferBindFlags::RayTracingShaderTable; }
             RHI::BufferBindFlags GetTlasInstancesBufferBindFlags() const override { return RHI::BufferBindFlags::ShaderReadWrite | RHI::BufferBindFlags::RayTracingAccelerationStructure; }