瀏覽代碼

Improve SRG compiling performance (#18728)

Signed-off-by: Martin Sattlecker <[email protected]>
Martin Sattlecker 6 月之前
父節點
當前提交
1eb502c36c

+ 1 - 0
Code/Framework/AtomCore/AtomCore/atomcore_files.cmake

@@ -15,6 +15,7 @@ set(FILES
     Instance/InstanceDatabase.h
     Instance/InstanceDatabase.h
     std/containers/fixed_vector_set.h
     std/containers/fixed_vector_set.h
     std/containers/lru_cache.h
     std/containers/lru_cache.h
+    std/containers/small_vector.h
     std/containers/vector_set.h
     std/containers/vector_set.h
     std/containers/vector_set_base.h
     std/containers/vector_set_base.h
     std/parallel/concurrency_checker.h
     std/parallel/concurrency_checker.h

+ 245 - 0
Code/Framework/AtomCore/AtomCore/std/containers/small_vector.h

@@ -0,0 +1,245 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#pragma once
+
+#include <AzCore/std/containers/fixed_vector.h>
+#include <AzCore/std/containers/variant.h>
+#include <AzCore/std/containers/vector.h>
+
+namespace AZStd
+{
+    // Class that servers a similar purpose as boosts small_vector
+    //  https://www.boost.org/doc/libs/1_86_0/doc/html/container/non_standard_containers.html#container.non_standard_containers.small_vector
+    // If the there are less than `FixedSize` elements in the vector, the data is stored in a fixed_vector
+    // If there are more, an AZstd::vector allocates memory on the heap
+    //
+    // This is a pretty simple implementation that does not expose all vector functions.
+    //  E.g. iterators are only accessible through the span() function
+    template<class T, size_t FixedSize>
+    class small_vector
+    {
+    public:
+        small_vector() = default;
+        small_vector(const small_vector<T, FixedSize>&) = default;
+        small_vector(small_vector<T, FixedSize>&&) = default;
+        small_vector& operator=(const small_vector<T, FixedSize>&) = default;
+        small_vector& operator=(small_vector<T, FixedSize>&&) = default;
+
+        small_vector(size_t newSize, const T& value)
+        {
+            resize(newSize, value);
+        }
+
+        void push_back(const T& value)
+        {
+            if (AZStd::holds_alternative<FixedVectorT>(m_data))
+            {
+                if (span().size() >= FixedSize)
+                {
+                    ConvertToHeapVector();
+                    AZStd::get<HeapVectorT>(m_data).push_back(value);
+                }
+                else
+                {
+                    AZStd::get<FixedVectorT>(m_data).push_back(value);
+                }
+            }
+            else if (AZStd::holds_alternative<HeapVectorT>(m_data))
+            {
+                AZStd::get<HeapVectorT>(m_data).push_back(value);
+            }
+            else
+            {
+                AZ_Assert(false, "small_vector::push_back: Empty variant");
+            }
+        }
+
+        template<typename... Args, typename = AZStd::enable_if_t<AZStd::is_constructible_v<T, Args...>>>
+        T& emplace_back(Args&&... args) noexcept
+        {
+            if (AZStd::holds_alternative<FixedVectorT>(m_data))
+            {
+                if (span().size() >= FixedSize)
+                {
+                    ConvertToHeapVector();
+                    return AZStd::get<HeapVectorT>(m_data).emplace_back(AZStd::forward<Args>(args)...);
+                }
+                else
+                {
+                    return AZStd::get<FixedVectorT>(m_data).emplace_back(AZStd::forward<Args>(args)...);
+                }
+            }
+            else if (AZStd::holds_alternative<HeapVectorT>(m_data))
+            {
+                return AZStd::get<HeapVectorT>(m_data).emplace_back(AZStd::forward<Args>(args)...);
+            }
+            else
+            {
+                AZ_Assert(false, "small_vector::emplace_back: Empty variant");
+                return span().front();
+            }
+        }
+
+        void erase(size_t position)
+        {
+            if (AZStd::holds_alternative<FixedVectorT>(m_data))
+            {
+                auto& fixed = AZStd::get<FixedVectorT>(m_data);
+                fixed.erase(fixed.begin() + position);
+            }
+            else if (AZStd::holds_alternative<HeapVectorT>(m_data))
+            {
+                auto& heap = AZStd::get<HeapVectorT>(m_data);
+                heap.erase(heap.begin() + position);
+            }
+            else
+            {
+                AZ_Assert(false, "small_vector::erase: Empty variant");
+            }
+        }
+
+        void resize(size_t newSize, const T& value)
+        {
+            if (AZStd::holds_alternative<FixedVectorT>(m_data))
+            {
+                if (newSize > FixedSize)
+                {
+                    ConvertToHeapVector();
+                    AZStd::get<HeapVectorT>(m_data).resize(newSize, value);
+                }
+                else
+                {
+                    AZStd::get<FixedVectorT>(m_data).resize(newSize, value);
+                }
+            }
+            else if (AZStd::holds_alternative<HeapVectorT>(m_data))
+            {
+                AZStd::get<HeapVectorT>(m_data).resize(newSize, value);
+            }
+            else
+            {
+                AZ_Assert(false, "small_vector::resize: Empty variant");
+            }
+        }
+
+        void resize(size_t newSize)
+        {
+            resize(newSize, {});
+        }
+
+        void reserve(size_t newCapacity)
+        {
+            if (AZStd::holds_alternative<FixedVectorT>(m_data))
+            {
+                if (newCapacity > FixedSize)
+                {
+                    ConvertToHeapVector();
+                    AZStd::get<HeapVectorT>(m_data).reserve(newCapacity);
+                }
+                else
+                {
+                    AZStd::get<FixedVectorT>(m_data).reserve(newCapacity);
+                }
+            }
+            else if (AZStd::holds_alternative<HeapVectorT>(m_data))
+            {
+                AZStd::get<HeapVectorT>(m_data).reserve(newCapacity);
+            }
+            else
+            {
+                AZ_Assert(false, "small_vector::reserve: Empty variant");
+            }
+        }
+
+        size_t size()
+        {
+            return AZStd::visit(
+                [](auto& vector)
+                {
+                    return vector.size();
+                },
+                m_data);
+        }
+
+        AZStd::span<T> span()
+        {
+            return AZStd::visit(
+                [](auto& vector) -> AZStd::span<T>
+                {
+                    return { vector };
+                },
+                m_data);
+        }
+
+        AZStd::span<const T> span() const
+        {
+            return AZStd::visit(
+                [](auto& vector) -> AZStd::span<const T>
+                {
+                    return { vector };
+                },
+                m_data);
+        }
+
+        T& operator[](size_t pos)
+        {
+            return AZStd::visit(
+                [&](auto& vector) -> T&
+                {
+                    return vector[pos];
+                },
+                m_data);
+        }
+
+        const T& operator[](size_t pos) const
+        {
+            return AZStd::visit(
+                [&](auto& vector) -> T&
+                {
+                    return vector[pos];
+                },
+                m_data);
+        }
+
+        void clear()
+        {
+            AZStd::visit(
+                [&](auto& vector)
+                {
+                    vector.clear();
+                },
+                m_data);
+        }
+
+        bool empty() const
+        {
+            return span().empty();
+        }
+
+        size_t size() const
+        {
+            return span().size();
+        }
+
+    private:
+        void ConvertToHeapVector()
+        {
+            if (AZStd::holds_alternative<FixedVectorT>(m_data))
+            {
+                auto data = span();
+                auto newData = HeapVectorT(data.begin(), data.end());
+                m_data = AZStd::move(newData);
+            }
+        }
+
+        using FixedVectorT = AZStd::fixed_vector<T, FixedSize>;
+        using HeapVectorT = AZStd::vector<T>;
+
+        AZStd::variant<FixedVectorT, HeapVectorT> m_data;
+    };
+} // namespace AZStd

+ 1 - 1
Code/Framework/AzCore/AzCore/std/containers/fixed_vector.h

@@ -548,7 +548,7 @@ namespace AZStd
 
 
         // Removes unused capacity - For fixed_vector this only asserts
         // Removes unused capacity - For fixed_vector this only asserts
         // that the supplied capacity is not longer than the fixed_vector capacity
         // that the supplied capacity is not longer than the fixed_vector capacity
-        void reserve(size_type newCapacity)
+        void reserve([[maybe_unused]] size_type newCapacity)
         {
         {
             // No-op - Implemented to provide consistent std::vector
             // No-op - Implemented to provide consistent std::vector
             AZSTD_CONTAINER_ASSERT(newCapacity <= capacity(),
             AZSTD_CONTAINER_ASSERT(newCapacity <= capacity(),

+ 15 - 12
Gems/Atom/RHI/DX12/Code/Source/RHI/DescriptorContext.cpp

@@ -5,14 +5,15 @@
  * SPDX-License-Identifier: Apache-2.0 OR MIT
  * SPDX-License-Identifier: Apache-2.0 OR MIT
  *
  *
  */
  */
-#include <RHI/DescriptorContext.h>
+#include <Atom/RHI.Reflect/DX12/PlatformLimitsDescriptor.h>
+#include <Atom/RHI/DeviceShaderResourceGroupPool.h>
+#include <AtomCore/std/containers/small_vector.h>
 #include <RHI/Buffer.h>
 #include <RHI/Buffer.h>
 #include <RHI/Conversions.h>
 #include <RHI/Conversions.h>
+#include <RHI/DescriptorContext.h>
 #include <RHI/Device.h>
 #include <RHI/Device.h>
 #include <RHI/Image.h>
 #include <RHI/Image.h>
 #include <RHI/ShaderResourceGroupPool.h>
 #include <RHI/ShaderResourceGroupPool.h>
-#include <Atom/RHI.Reflect/DX12/PlatformLimitsDescriptor.h>
-#include <Atom/RHI/DeviceShaderResourceGroupPool.h>
 
 
 namespace AZ
 namespace AZ
 {
 {
@@ -470,7 +471,7 @@ namespace AZ
         {
         {
             GetPool(table.GetType(), table.GetFlags()).ReleaseTable(table);
             GetPool(table.GetType(), table.GetFlags()).ReleaseTable(table);
         }
         }
-        
+
         void DescriptorContext::UpdateDescriptorTableRange(
         void DescriptorContext::UpdateDescriptorTableRange(
             DescriptorTable gpuDestinationTable,
             DescriptorTable gpuDestinationTable,
             const DescriptorHandle* cpuSourceDescriptors,
             const DescriptorHandle* cpuSourceDescriptors,
@@ -479,7 +480,8 @@ namespace AZ
             const uint32_t DescriptorCount = gpuDestinationTable.GetSize();
             const uint32_t DescriptorCount = gpuDestinationTable.GetSize();
 
 
             // Resolve source descriptors to platform handles.
             // Resolve source descriptors to platform handles.
-            AZStd::vector<D3D12_CPU_DESCRIPTOR_HANDLE> cpuSourceHandles;
+            constexpr size_t FixedSize = 16;
+            AZStd::small_vector<D3D12_CPU_DESCRIPTOR_HANDLE, FixedSize> cpuSourceHandles;
             cpuSourceHandles.reserve(DescriptorCount);
             cpuSourceHandles.reserve(DescriptorCount);
             for (uint32_t i = 0; i < DescriptorCount; ++i)
             for (uint32_t i = 0; i < DescriptorCount; ++i)
             {
             {
@@ -490,16 +492,17 @@ namespace AZ
             D3D12_CPU_DESCRIPTOR_HANDLE gpuDestinationHandle = GetCpuPlatformHandleForTable(gpuDestinationTable);
             D3D12_CPU_DESCRIPTOR_HANDLE gpuDestinationHandle = GetCpuPlatformHandleForTable(gpuDestinationTable);
 
 
             // An array of descriptor sizes for each range. We just want N ranges with 1 descriptor each.
             // An array of descriptor sizes for each range. We just want N ranges with 1 descriptor each.
-            AZStd::vector<uint32_t> rangeCounts(DescriptorCount, 1);
+            AZStd::small_vector<uint32_t, FixedSize> rangeCountsFixed;
+            rangeCountsFixed.resize(DescriptorCount, 1);
 
 
             //We are gathering N source descriptors into a contiguous destination table.
             //We are gathering N source descriptors into a contiguous destination table.
             m_device->CopyDescriptors(
             m_device->CopyDescriptors(
-                1,                      // Number of destination ranges.
-                &gpuDestinationHandle,  // Destination range array.
-                &DescriptorCount,       // Number of destination table elements in each range.
-                DescriptorCount,        // Number of source ranges.
-                cpuSourceHandles.data(),// Source range array
-                rangeCounts.data(),     // Number of elements in each source range.
+                1, // Number of destination ranges.
+                &gpuDestinationHandle, // Destination range array.
+                &DescriptorCount, // Number of destination table elements in each range.
+                DescriptorCount, // Number of source ranges.
+                cpuSourceHandles.span().data(), // Source range array
+                rangeCountsFixed.span().data(), // Number of elements in each source range.
                 heapType);
                 heapType);
         }
         }
 
 

+ 51 - 47
Gems/Atom/RHI/DX12/Code/Source/RHI/ShaderResourceGroupPool.cpp

@@ -6,7 +6,7 @@
  *
  *
  */
  */
 
 
-#include <RHI/ShaderResourceGroupPool.h>
+#include <AtomCore/std/containers/small_vector.h>
 #include <RHI/Buffer.h>
 #include <RHI/Buffer.h>
 #include <RHI/BufferView.h>
 #include <RHI/BufferView.h>
 #include <RHI/Conversions.h>
 #include <RHI/Conversions.h>
@@ -14,54 +14,58 @@
 #include <RHI/Device.h>
 #include <RHI/Device.h>
 #include <RHI/Image.h>
 #include <RHI/Image.h>
 #include <RHI/ImageView.h>
 #include <RHI/ImageView.h>
+#include <RHI/ShaderResourceGroupPool.h>
 
 
 namespace AZ
 namespace AZ
 {
 {
     namespace DX12
     namespace DX12
     {
     {
         template<typename T, typename U>
         template<typename T, typename U>
-        AZStd::vector<DescriptorHandle> ShaderResourceGroupPool::GetSRVsFromImageViews(const AZStd::span<const RHI::ConstPtr<T>>& imageViews, D3D12_SRV_DIMENSION dimension)
+        void ShaderResourceGroupPool::GetSRVsFromImageViews(
+            const AZStd::span<const RHI::ConstPtr<T>>& imageViews,
+            D3D12_SRV_DIMENSION dimension,
+            AZStd::small_vector<DescriptorHandle, SRGViewsFixedSize>& result)
         {
         {
-            AZStd::vector<DescriptorHandle> cpuSourceDescriptors(imageViews.size(), m_descriptorContext->GetNullHandleSRV(dimension));
+            result.resize(imageViews.size(), m_descriptorContext->GetNullHandleSRV(dimension));
 
 
-            for (size_t i = 0; i < cpuSourceDescriptors.size(); ++i)
+            for (size_t i = 0; i < result.size(); ++i)
             {
             {
                 if (imageViews[i])
                 if (imageViews[i])
                 {
                 {
-                    cpuSourceDescriptors[i] = AZStd::static_pointer_cast<const U>(imageViews[i])->GetReadDescriptor();
+                    result.span()[i] = AZStd::static_pointer_cast<const U>(imageViews[i])->GetReadDescriptor();
                 }
                 }
             }
             }
-
-            return cpuSourceDescriptors;
         }
         }
 
 
         template<typename T, typename U>
         template<typename T, typename U>
-        AZStd::vector<DescriptorHandle> ShaderResourceGroupPool::GetUAVsFromImageViews(const AZStd::span<const RHI::ConstPtr<T>>& imageViews, D3D12_UAV_DIMENSION dimension)
+        void ShaderResourceGroupPool::GetUAVsFromImageViews(
+            const AZStd::span<const RHI::ConstPtr<T>>& imageViews,
+            D3D12_UAV_DIMENSION dimension,
+            AZStd::small_vector<DescriptorHandle, SRGViewsFixedSize>& result)
         {
         {
-            AZStd::vector<DescriptorHandle> cpuSourceDescriptors(imageViews.size(), m_descriptorContext->GetNullHandleUAV(dimension));
-            for (size_t i = 0; i < cpuSourceDescriptors.size(); ++i)
+            result.resize(imageViews.size(), m_descriptorContext->GetNullHandleUAV(dimension));
+            for (size_t i = 0; i < result.size(); ++i)
             {
             {
                 if (imageViews[i])
                 if (imageViews[i])
                 {
                 {
-                    cpuSourceDescriptors[i] = AZStd::static_pointer_cast<const U>(imageViews[i])->GetReadWriteDescriptor();
+                    result.span()[i] = AZStd::static_pointer_cast<const U>(imageViews[i])->GetReadWriteDescriptor();
                 }
                 }
             }
             }
-
-            return cpuSourceDescriptors;
         }
         }
 
 
-        AZStd::vector<DescriptorHandle> ShaderResourceGroupPool::GetCBVsFromBufferViews(const AZStd::span<const RHI::ConstPtr<RHI::DeviceBufferView>>& bufferViews)
+        void ShaderResourceGroupPool::GetCBVsFromBufferViews(
+            const AZStd::span<const RHI::ConstPtr<RHI::DeviceBufferView>>& bufferViews,
+            AZStd::small_vector<DescriptorHandle, SRGViewsFixedSize>& result)
         {
         {
-            AZStd::vector<DescriptorHandle> cpuSourceDescriptors(bufferViews.size(), m_descriptorContext->GetNullHandleCBV());
+            result.resize(bufferViews.size(), m_descriptorContext->GetNullHandleCBV());
 
 
             for (size_t i = 0; i < bufferViews.size(); ++i)
             for (size_t i = 0; i < bufferViews.size(); ++i)
             {
             {
                 if (bufferViews[i])
                 if (bufferViews[i])
                 {
                 {
-                    cpuSourceDescriptors[i] = AZStd::static_pointer_cast<const BufferView>(bufferViews[i])->GetConstantDescriptor();
+                    result.span()[i] = AZStd::static_pointer_cast<const BufferView>(bufferViews[i])->GetConstantDescriptor();
                 }
                 }
             }
             }
-            return cpuSourceDescriptors;
         }
         }
 
 
         RHI::Ptr<ShaderResourceGroupPool> ShaderResourceGroupPool::Create()
         RHI::Ptr<ShaderResourceGroupPool> ShaderResourceGroupPool::Create()
@@ -282,22 +286,24 @@ namespace AZ
 
 
                     AZStd::span<const RHI::ConstPtr<RHI::DeviceBufferView>> bufferViews = groupData.GetBufferViewArray(bufferInputIndex);
                     AZStd::span<const RHI::ConstPtr<RHI::DeviceBufferView>> bufferViews = groupData.GetBufferViewArray(bufferInputIndex);
                     D3D12_DESCRIPTOR_RANGE_TYPE descriptorRangeType = ConvertShaderInputBufferAccess(shaderInputBuffer.m_access);
                     D3D12_DESCRIPTOR_RANGE_TYPE descriptorRangeType = ConvertShaderInputBufferAccess(shaderInputBuffer.m_access);
-                    AZStd::vector<DescriptorHandle> descriptorHandles;
+                    AZStd::small_vector<DescriptorHandle, SRGViewsFixedSize> descriptorHandles;
                     switch (descriptorRangeType)
                     switch (descriptorRangeType)
                     {
                     {
                         case D3D12_DESCRIPTOR_RANGE_TYPE_SRV:
                         case D3D12_DESCRIPTOR_RANGE_TYPE_SRV:
                         {
                         {
-                            descriptorHandles = GetSRVsFromImageViews< RHI::DeviceBufferView, BufferView> (bufferViews, D3D12_SRV_DIMENSION_BUFFER);
+                            GetSRVsFromImageViews<RHI::DeviceBufferView, BufferView>(
+                                bufferViews, D3D12_SRV_DIMENSION_BUFFER, descriptorHandles);
                             break;
                             break;
                         }
                         }
                         case D3D12_DESCRIPTOR_RANGE_TYPE_UAV:
                         case D3D12_DESCRIPTOR_RANGE_TYPE_UAV:
                         {
                         {
-                            descriptorHandles = GetUAVsFromImageViews<RHI::DeviceBufferView, BufferView>(bufferViews, D3D12_UAV_DIMENSION_BUFFER);
+                            GetUAVsFromImageViews<RHI::DeviceBufferView, BufferView>(
+                                bufferViews, D3D12_UAV_DIMENSION_BUFFER, descriptorHandles);
                             break;
                             break;
                         }
                         }
                         case D3D12_DESCRIPTOR_RANGE_TYPE_CBV:
                         case D3D12_DESCRIPTOR_RANGE_TYPE_CBV:
                         {
                         {
-                            descriptorHandles = GetCBVsFromBufferViews(bufferViews);
+                            GetCBVsFromBufferViews(bufferViews, descriptorHandles);
                             break;
                             break;
                         }
                         }
                         default:
                         default:
@@ -305,7 +311,7 @@ namespace AZ
                             break;
                             break;
                     }
                     }
 
 
-                    UpdateDescriptorTableRange(descriptorTable, descriptorHandles, bufferInputIndex);
+                    UpdateDescriptorTableRange(descriptorTable, descriptorHandles.span(), bufferInputIndex);
                     ++shaderInputIndex;
                     ++shaderInputIndex;
                 }
                 }
             }
             }
@@ -320,19 +326,19 @@ namespace AZ
                     AZStd::span<const RHI::ConstPtr<RHI::DeviceImageView>> imageViews = groupData.GetImageViewArray(imageInputIndex);
                     AZStd::span<const RHI::ConstPtr<RHI::DeviceImageView>> imageViews = groupData.GetImageViewArray(imageInputIndex);
                     D3D12_DESCRIPTOR_RANGE_TYPE descriptorRangeType = ConvertShaderInputImageAccess(shaderInputImage.m_access);
                     D3D12_DESCRIPTOR_RANGE_TYPE descriptorRangeType = ConvertShaderInputImageAccess(shaderInputImage.m_access);
 
 
-                    AZStd::vector<DescriptorHandle> descriptorHandles;
+                    AZStd::small_vector<DescriptorHandle, SRGViewsFixedSize> descriptorHandles;
                     switch (descriptorRangeType)
                     switch (descriptorRangeType)
                     {
                     {
                         case D3D12_DESCRIPTOR_RANGE_TYPE_SRV:
                         case D3D12_DESCRIPTOR_RANGE_TYPE_SRV:
                         {
                         {
-                            descriptorHandles =
-                                GetSRVsFromImageViews<RHI::DeviceImageView, ImageView>(imageViews, ConvertSRVDimension(shaderInputImage.m_type));
+                            GetSRVsFromImageViews<RHI::DeviceImageView, ImageView>(
+                                imageViews, ConvertSRVDimension(shaderInputImage.m_type), descriptorHandles);
                             break;
                             break;
                         }
                         }
                         case D3D12_DESCRIPTOR_RANGE_TYPE_UAV:
                         case D3D12_DESCRIPTOR_RANGE_TYPE_UAV:
                         {
                         {
-                            descriptorHandles =
-                                GetUAVsFromImageViews<RHI::DeviceImageView, ImageView>(imageViews, ConvertUAVDimension(shaderInputImage.m_type));
+                            GetUAVsFromImageViews<RHI::DeviceImageView, ImageView>(
+                                imageViews, ConvertUAVDimension(shaderInputImage.m_type), descriptorHandles);
                             break;
                             break;
                         }
                         }
                         default:
                         default:
@@ -340,7 +346,7 @@ namespace AZ
                         break;
                         break;
                     }
                     }
 
 
-                    UpdateDescriptorTableRange(descriptorTable, descriptorHandles, imageInputIndex);
+                    UpdateDescriptorTableRange(descriptorTable, descriptorHandles.span(), imageInputIndex);
                     ++shaderInputIndex;
                     ++shaderInputIndex;
                 }
                 }
             }
             }
@@ -489,17 +495,17 @@ namespace AZ
 
 
             D3D12_DESCRIPTOR_RANGE_TYPE descriptorRangeType = ConvertShaderInputBufferAccess(bufferAccess);
             D3D12_DESCRIPTOR_RANGE_TYPE descriptorRangeType = ConvertShaderInputBufferAccess(bufferAccess);
 
 
-            AZStd::vector<DescriptorHandle> descriptorHandles;
+            AZStd::small_vector<DescriptorHandle, SRGViewsFixedSize> descriptorHandles;
             switch (descriptorRangeType)
             switch (descriptorRangeType)
             {
             {
             case D3D12_DESCRIPTOR_RANGE_TYPE_SRV:
             case D3D12_DESCRIPTOR_RANGE_TYPE_SRV:
                 {
                 {
-                    descriptorHandles = GetSRVsFromImageViews<RHI::DeviceBufferView, BufferView>(bufferViews, D3D12_SRV_DIMENSION_BUFFER);
+                    GetSRVsFromImageViews<RHI::DeviceBufferView, BufferView>(bufferViews, D3D12_SRV_DIMENSION_BUFFER, descriptorHandles);
                     break;
                     break;
                 }
                 }
             case D3D12_DESCRIPTOR_RANGE_TYPE_UAV:
             case D3D12_DESCRIPTOR_RANGE_TYPE_UAV:
                 {
                 {
-                    descriptorHandles = GetUAVsFromImageViews<RHI::DeviceBufferView, BufferView>(bufferViews, D3D12_UAV_DIMENSION_BUFFER);
+                    GetUAVsFromImageViews<RHI::DeviceBufferView, BufferView>(bufferViews, D3D12_UAV_DIMENSION_BUFFER, descriptorHandles);
                     break;
                     break;
                 }
                 }
             default:
             default:
@@ -508,7 +514,7 @@ namespace AZ
             }
             }
 
 
             m_descriptorContext->UpdateDescriptorTableRange(
             m_descriptorContext->UpdateDescriptorTableRange(
-                descriptorTable, descriptorHandles.data(), D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
+                descriptorTable, descriptorHandles.span().data(), D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
         }
         }
 
 
         void ShaderResourceGroupPool::UpdateUnboundedImagesDescTable(
         void ShaderResourceGroupPool::UpdateUnboundedImagesDescTable(
@@ -530,17 +536,17 @@ namespace AZ
 
 
             D3D12_DESCRIPTOR_RANGE_TYPE descriptorRangeType = ConvertShaderInputImageAccess(imageAccess);
             D3D12_DESCRIPTOR_RANGE_TYPE descriptorRangeType = ConvertShaderInputImageAccess(imageAccess);
 
 
-            AZStd::vector<DescriptorHandle> descriptorHandles;
+            AZStd::small_vector<DescriptorHandle, SRGViewsFixedSize> descriptorHandles;
             switch (descriptorRangeType)
             switch (descriptorRangeType)
             {
             {
             case D3D12_DESCRIPTOR_RANGE_TYPE_SRV:
             case D3D12_DESCRIPTOR_RANGE_TYPE_SRV:
                 {
                 {
-                    descriptorHandles = GetSRVsFromImageViews<RHI::DeviceImageView, ImageView>(imageViews, ConvertSRVDimension(imageType));
+                    GetSRVsFromImageViews<RHI::DeviceImageView, ImageView>(imageViews, ConvertSRVDimension(imageType), descriptorHandles);
                     break;
                     break;
                 }
                 }
             case D3D12_DESCRIPTOR_RANGE_TYPE_UAV:
             case D3D12_DESCRIPTOR_RANGE_TYPE_UAV:
                 {
                 {
-                    descriptorHandles = GetUAVsFromImageViews<RHI::DeviceImageView, ImageView>(imageViews, ConvertUAVDimension(imageType));
+                    GetUAVsFromImageViews<RHI::DeviceImageView, ImageView>(imageViews, ConvertUAVDimension(imageType), descriptorHandles);
                     break;
                     break;
                 }
                 }
             default:
             default:
@@ -549,7 +555,7 @@ namespace AZ
             }
             }
 
 
             m_descriptorContext->UpdateDescriptorTableRange(
             m_descriptorContext->UpdateDescriptorTableRange(
-                descriptorTable, descriptorHandles.data(), D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
+                descriptorTable, descriptorHandles.span().data(), D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
         }
         }
 
 
         void ShaderResourceGroupPool::OnFrameEnd()
         void ShaderResourceGroupPool::OnFrameEnd()
@@ -580,18 +586,14 @@ namespace AZ
         }
         }
 
 
         void ShaderResourceGroupPool::UpdateDescriptorTableRange(
         void ShaderResourceGroupPool::UpdateDescriptorTableRange(
-            DescriptorTable descriptorTable,
-            const AZStd::vector<DescriptorHandle>& handles,
-            RHI::ShaderInputBufferIndex bufferInputIndex)
+            DescriptorTable descriptorTable, const AZStd::span<DescriptorHandle>& handles, RHI::ShaderInputBufferIndex bufferInputIndex)
         {
         {
             const DescriptorTable gpuDestinationTable = GetBufferTable(descriptorTable, bufferInputIndex);
             const DescriptorTable gpuDestinationTable = GetBufferTable(descriptorTable, bufferInputIndex);
             m_descriptorContext->UpdateDescriptorTableRange(gpuDestinationTable, handles.data(), D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
             m_descriptorContext->UpdateDescriptorTableRange(gpuDestinationTable, handles.data(), D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
         }
         }
 
 
         void ShaderResourceGroupPool::UpdateDescriptorTableRange(
         void ShaderResourceGroupPool::UpdateDescriptorTableRange(
-            DescriptorTable descriptorTable,
-            const AZStd::vector<DescriptorHandle>& handles,
-            RHI::ShaderInputImageIndex imageInputIndex)
+            DescriptorTable descriptorTable, const AZStd::span<DescriptorHandle>& handles, RHI::ShaderInputImageIndex imageInputIndex)
         {
         {
             const DescriptorTable gpuDestinationTable = GetImageTable(descriptorTable, imageInputIndex);
             const DescriptorTable gpuDestinationTable = GetImageTable(descriptorTable, imageInputIndex);
             m_descriptorContext->UpdateDescriptorTableRange(gpuDestinationTable, handles.data(), D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
             m_descriptorContext->UpdateDescriptorTableRange(gpuDestinationTable, handles.data(), D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
@@ -603,17 +605,19 @@ namespace AZ
             AZStd::span<const RHI::SamplerState> samplerStates)
             AZStd::span<const RHI::SamplerState> samplerStates)
         {
         {
             const DescriptorHandle nullHandle = m_descriptorContext->GetNullHandleSampler();
             const DescriptorHandle nullHandle = m_descriptorContext->GetNullHandleSampler();
-            AZStd::vector<DescriptorHandle> cpuSourceDescriptors(aznumeric_caster(samplerStates.size()), nullHandle);
+            AZStd::small_vector<DescriptorHandle, SRGViewsFixedSize> cpuSourceDescriptors(
+                aznumeric_caster(samplerStates.size()), nullHandle);
             auto& device = static_cast<Device&>(GetDevice());
             auto& device = static_cast<Device&>(GetDevice());
-            AZStd::vector<RHI::ConstPtr<Sampler>> samplers(samplerStates.size(), nullptr);
+            AZStd::small_vector<RHI::ConstPtr<Sampler>, SRGViewsFixedSize> samplers(samplerStates.size(), nullptr);
             for (size_t i = 0; i < samplerStates.size(); ++i)
             for (size_t i = 0; i < samplerStates.size(); ++i)
             {
             {
-                samplers[i] = device.AcquireSampler(samplerStates[i]);
-                cpuSourceDescriptors[i] = samplers[i]->GetDescriptorHandle();
+                samplers.span()[i] = device.AcquireSampler(samplerStates[i]);
+                cpuSourceDescriptors.span()[i] = samplers.span()[i]->GetDescriptorHandle();
             }
             }
 
 
             const DescriptorTable gpuDestinationTable = GetSamplerTable(descriptorTable, samplerInputIndex);
             const DescriptorTable gpuDestinationTable = GetSamplerTable(descriptorTable, samplerInputIndex);
-            m_descriptorContext->UpdateDescriptorTableRange(gpuDestinationTable, cpuSourceDescriptors.data(), D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER);
+            m_descriptorContext->UpdateDescriptorTableRange(
+                gpuDestinationTable, cpuSourceDescriptors.span().data(), D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER);
         }
         }
     }
     }
 }
 }

+ 20 - 11
Gems/Atom/RHI/DX12/Code/Source/RHI/ShaderResourceGroupPool.h

@@ -7,10 +7,12 @@
  */
  */
 #pragma once
 #pragma once
 
 
-#include <RHI/ShaderResourceGroup.h>
-#include <RHI/MemorySubAllocator.h>
-#include <Atom/RHI/FrameEventBus.h>
 #include <Atom/RHI/DeviceShaderResourceGroupPool.h>
 #include <Atom/RHI/DeviceShaderResourceGroupPool.h>
+#include <Atom/RHI/FrameEventBus.h>
+#include <RHI/Descriptor.h>
+#include <RHI/MemorySubAllocator.h>
+#include <RHI/ShaderResourceGroup.h>
+
 
 
 namespace AZ
 namespace AZ
 {
 {
@@ -20,6 +22,7 @@ namespace AZ
         class ImageView;
         class ImageView;
         class ShaderResourceGroupLayout;
         class ShaderResourceGroupLayout;
         class DescriptorContext;
         class DescriptorContext;
+        constexpr size_t SRGViewsFixedSize = 16;
 
 
         class ShaderResourceGroupPool final
         class ShaderResourceGroupPool final
             : public RHI::DeviceShaderResourceGroupPool
             : public RHI::DeviceShaderResourceGroupPool
@@ -71,13 +74,11 @@ namespace AZ
 
 
             void UpdateDescriptorTableRange(
             void UpdateDescriptorTableRange(
                 DescriptorTable descriptorTable,
                 DescriptorTable descriptorTable,
-                const AZStd::vector<DescriptorHandle>& descriptors,
+                const AZStd::span<DescriptorHandle>& descriptors,
                 RHI::ShaderInputBufferIndex bufferInputIndex);
                 RHI::ShaderInputBufferIndex bufferInputIndex);
 
 
             void UpdateDescriptorTableRange(
             void UpdateDescriptorTableRange(
-                DescriptorTable descriptorTable,
-                const AZStd::vector<DescriptorHandle>& descriptors,
-                RHI::ShaderInputImageIndex imageIndex);
+                DescriptorTable descriptorTable, const AZStd::span<DescriptorHandle>& descriptors, RHI::ShaderInputImageIndex imageIndex);
 
 
             void UpdateDescriptorTableRange(
             void UpdateDescriptorTableRange(
                 DescriptorTable descriptorTable,
                 DescriptorTable descriptorTable,
@@ -93,12 +94,20 @@ namespace AZ
             DescriptorTable GetSamplerTable(DescriptorTable descriptorTable, RHI::ShaderInputSamplerIndex samplerInputIndex) const;
             DescriptorTable GetSamplerTable(DescriptorTable descriptorTable, RHI::ShaderInputSamplerIndex samplerInputIndex) const;
 
 
             template<typename T, typename U>
             template<typename T, typename U>
-            AZStd::vector<DescriptorHandle> GetSRVsFromImageViews(const AZStd::span<const RHI::ConstPtr<T>>& imageViews, D3D12_SRV_DIMENSION dimension);
+            void GetSRVsFromImageViews(
+                const AZStd::span<const RHI::ConstPtr<T>>& imageViews,
+                D3D12_SRV_DIMENSION dimension,
+                AZStd::small_vector<DescriptorHandle, SRGViewsFixedSize>& result);
 
 
             template<typename T, typename U>
             template<typename T, typename U>
-            AZStd::vector<DescriptorHandle> GetUAVsFromImageViews(const AZStd::span<const RHI::ConstPtr<T>>& bufferViews, D3D12_UAV_DIMENSION dimension);
-
-            AZStd::vector<DescriptorHandle> GetCBVsFromBufferViews(const AZStd::span<const RHI::ConstPtr<RHI::DeviceBufferView>>& bufferViews);
+            void GetUAVsFromImageViews(
+                const AZStd::span<const RHI::ConstPtr<T>>& bufferViews,
+                D3D12_UAV_DIMENSION dimension,
+                AZStd::small_vector<DescriptorHandle, SRGViewsFixedSize>& result);
+
+            void GetCBVsFromBufferViews(
+                const AZStd::span<const RHI::ConstPtr<RHI::DeviceBufferView>>& bufferViews,
+                AZStd::small_vector<DescriptorHandle, SRGViewsFixedSize>& result);
 
 
             MemoryPoolSubAllocator m_constantAllocator;
             MemoryPoolSubAllocator m_constantAllocator;
             DescriptorContext* m_descriptorContext = nullptr;
             DescriptorContext* m_descriptorContext = nullptr;

+ 77 - 49
Gems/Atom/RHI/Vulkan/Code/Source/RHI/DescriptorSet.cpp

@@ -39,12 +39,17 @@ namespace AZ
             }            
             }            
         }
         }
 
 
+        void DescriptorSet::ReserveUpdateData(size_t numUpdates)
+        {
+            m_updateData.reserve(numUpdates);
+        }
+
         void DescriptorSet::UpdateBufferViews(uint32_t layoutIndex, const AZStd::span<const RHI::ConstPtr<RHI::DeviceBufferView>>& bufViews)
         void DescriptorSet::UpdateBufferViews(uint32_t layoutIndex, const AZStd::span<const RHI::ConstPtr<RHI::DeviceBufferView>>& bufViews)
         {
         {
             const DescriptorSetLayout& layout = *m_descriptor.m_descriptorSetLayout;
             const DescriptorSetLayout& layout = *m_descriptor.m_descriptorSetLayout;
             VkDescriptorType type = layout.GetDescriptorType(layoutIndex);
             VkDescriptorType type = layout.GetDescriptorType(layoutIndex);
-            
-            WriteDescriptorData data;
+
+            auto& data = m_updateData.emplace_back();
             data.m_layoutIndex = layoutIndex;
             data.m_layoutIndex = layoutIndex;
 
 
             if (type == VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER ||
             if (type == VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER ||
@@ -117,15 +122,13 @@ namespace AZ
                     }
                     }
                 }
                 }
             }
             }
-
-            m_updateData.push_back(AZStd::move(data));
         }
         }
 
 
         void DescriptorSet::UpdateImageViews(uint32_t layoutIndex, const AZStd::span<const RHI::ConstPtr<RHI::DeviceImageView>>& imageViews, RHI::ShaderInputImageType imageType)
         void DescriptorSet::UpdateImageViews(uint32_t layoutIndex, const AZStd::span<const RHI::ConstPtr<RHI::DeviceImageView>>& imageViews, RHI::ShaderInputImageType imageType)
         {
         {
             const DescriptorSetLayout& layout = *m_descriptor.m_descriptorSetLayout;
             const DescriptorSetLayout& layout = *m_descriptor.m_descriptorSetLayout;
 
 
-            WriteDescriptorData data;
+            auto& data = m_updateData.emplace_back();
             data.m_layoutIndex = layoutIndex;
             data.m_layoutIndex = layoutIndex;
 
 
             data.m_imageViewsInfo.resize(imageViews.size());
             data.m_imageViewsInfo.resize(imageViews.size());
@@ -165,15 +168,13 @@ namespace AZ
                 
                 
                 data.m_imageViewsInfo[i]  = imageInfo;
                 data.m_imageViewsInfo[i]  = imageInfo;
             }
             }
-
-            m_updateData.push_back(AZStd::move(data));
         }
         }
 
 
         void DescriptorSet::UpdateSamplers(uint32_t layoutIndex, const AZStd::span<const RHI::SamplerState>& samplers)
         void DescriptorSet::UpdateSamplers(uint32_t layoutIndex, const AZStd::span<const RHI::SamplerState>& samplers)
         {
         {
             auto& device = static_cast<Device&>(GetDevice());
             auto& device = static_cast<Device&>(GetDevice());
 
 
-            WriteDescriptorData data;
+            auto& data = m_updateData.emplace_back();
             data.m_layoutIndex = layoutIndex;
             data.m_layoutIndex = layoutIndex;
 
 
             VkDescriptorImageInfo imageInfo = {};
             VkDescriptorImageInfo imageInfo = {};
@@ -185,8 +186,6 @@ namespace AZ
                 imageInfo.sampler = device.AcquireSampler(samplerDesc)->GetNativeSampler();
                 imageInfo.sampler = device.AcquireSampler(samplerDesc)->GetNativeSampler();
                 data.m_imageViewsInfo.push_back(imageInfo);
                 data.m_imageViewsInfo.push_back(imageInfo);
             }
             }
-
-            m_updateData.push_back(AZStd::move(data));
         }
         }
 
 
         void DescriptorSet::UpdateConstantData(AZStd::span<const uint8_t> rawData)
         void DescriptorSet::UpdateConstantData(AZStd::span<const uint8_t> rawData)
@@ -199,7 +198,7 @@ namespace AZ
             memcpy(mappedData, rawData.data(), rawData.size());
             memcpy(mappedData, rawData.data(), rawData.size());
             memoryView->Unmap(RHI::HostMemoryAccess::Write);
             memoryView->Unmap(RHI::HostMemoryAccess::Write);
 
 
-            WriteDescriptorData data;
+            WriteDescriptorData& data = m_updateData.emplace_back();
             data.m_layoutIndex = layout.GetLayoutIndexFromGroupIndex(0, DescriptorSetLayout::ResourceType::ConstantData);
             data.m_layoutIndex = layout.GetLayoutIndexFromGroupIndex(0, DescriptorSetLayout::ResourceType::ConstantData);
 
 
             VkDescriptorBufferInfo bufferInfo;
             VkDescriptorBufferInfo bufferInfo;
@@ -207,7 +206,6 @@ namespace AZ
             bufferInfo.offset = memoryView->GetOffset();
             bufferInfo.offset = memoryView->GetOffset();
             bufferInfo.range = rawData.size();
             bufferInfo.range = rawData.size();
             data.m_bufferViewsInfo.push_back(bufferInfo);
             data.m_bufferViewsInfo.push_back(bufferInfo);
-            m_updateData.push_back(AZStd::move(data));
         }
         }
 
 
         RHI::Ptr<DescriptorSet> DescriptorSet::Create()
         RHI::Ptr<DescriptorSet> DescriptorSet::Create()
@@ -312,10 +310,27 @@ namespace AZ
                 AllocateDescriptorSetWithUnboundedArray();
                 AllocateDescriptorSetWithUnboundedArray();
             }
             }
 
 
-            AZStd::vector<VkWriteDescriptorSet> writeDescSetDescs;
-            AZStd::vector<VkWriteDescriptorSetAccelerationStructureKHR> writeAccelerationStructureDescs;
+            AZStd::small_vector<VkWriteDescriptorSet, ViewsFixedsize> writeDescSetDescs;
+            writeDescSetDescs.reserve(m_updateData.size());
+            AZStd::small_vector<VkWriteDescriptorSetAccelerationStructureKHR, ViewsFixedsize> writeAccelerationStructureDescs;
+
             const DescriptorSetLayout& layout = *m_descriptor.m_descriptorSetLayout;
             const DescriptorSetLayout& layout = *m_descriptor.m_descriptorSetLayout;
-            for (const WriteDescriptorData& updateData : m_updateData)
+            {
+                // Reserve memory for the acceleration structures descriptors
+                // We need t a pointer to the entries which may be invalidated in push_back without reserving memory
+                size_t numAccelerationStructureEntries = 0;
+                for (const WriteDescriptorData& updateData : m_updateData.span())
+                {
+                    const VkDescriptorType descType = layout.GetDescriptorType(updateData.m_layoutIndex);
+                    if (descType == VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR)
+                    {
+                        numAccelerationStructureEntries++;
+                    }
+                }
+                writeAccelerationStructureDescs.reserve(numAccelerationStructureEntries);
+            }
+
+            for (const WriteDescriptorData& updateData : m_updateData.span())
             {
             {
                 const VkDescriptorType descType = layout.GetDescriptorType(updateData.m_layoutIndex);
                 const VkDescriptorType descType = layout.GetDescriptorType(updateData.m_layoutIndex);
 
 
@@ -330,13 +345,16 @@ namespace AZ
                 {
                 {
                 case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
                 case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
                 case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER:
                 case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER:
-                    AZ_Assert(!updateData.m_bufferViewsInfo.empty(), "BufferInfo is empty.");
-                    for (const RHI::Interval& interval : GetValidDescriptorsIntervals(updateData.m_bufferViewsInfo))
                     {
                     {
-                        writeDescSet.pBufferInfo = updateData.m_bufferViewsInfo.data() + interval.m_min;
-                        writeDescSet.dstArrayElement = interval.m_min;
-                        writeDescSet.descriptorCount = interval.m_max - interval.m_min;
-                        writeDescSetDescs.push_back(AZStd::move(writeDescSet));
+                        AZ_Assert(!updateData.m_bufferViewsInfo.empty(), "BufferInfo is empty.");
+                        auto intervals = GetValidDescriptorsIntervals(updateData.m_bufferViewsInfo.span());
+                        for (const RHI::Interval& interval : intervals.span())
+                        {
+                            writeDescSet.pBufferInfo = updateData.m_bufferViewsInfo.span().data() + interval.m_min;
+                            writeDescSet.dstArrayElement = interval.m_min;
+                            writeDescSet.descriptorCount = interval.m_max - interval.m_min;
+                            writeDescSetDescs.push_back(AZStd::move(writeDescSet));
+                        }
                     }
                     }
                     break;
                     break;
                 case VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT:
                 case VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT:
@@ -344,42 +362,52 @@ namespace AZ
                 case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
                 case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
                 case VK_DESCRIPTOR_TYPE_SAMPLER:
                 case VK_DESCRIPTOR_TYPE_SAMPLER:
                 case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
                 case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
-                    AZ_Assert(!updateData.m_imageViewsInfo.empty(), "ImageInfo is empty.");
-                    for (const RHI::Interval& interval : GetValidDescriptorsIntervals(updateData.m_imageViewsInfo))
                     {
                     {
-                        writeDescSet.pImageInfo = updateData.m_imageViewsInfo.data() + interval.m_min;
-                        writeDescSet.dstArrayElement = interval.m_min;
-                        writeDescSet.descriptorCount = interval.m_max - interval.m_min;
-                        writeDescSetDescs.push_back(AZStd::move(writeDescSet));
+                        AZ_Assert(!updateData.m_imageViewsInfo.empty(), "ImageInfo is empty.");
+                        auto intervals = GetValidDescriptorsIntervals(updateData.m_imageViewsInfo.span());
+                        for (const RHI::Interval& interval : intervals.span())
+                        {
+                            writeDescSet.pImageInfo = updateData.m_imageViewsInfo.span().data() + interval.m_min;
+                            writeDescSet.dstArrayElement = interval.m_min;
+                            writeDescSet.descriptorCount = interval.m_max - interval.m_min;
+                            writeDescSetDescs.push_back(AZStd::move(writeDescSet));
+                        }
                     }
                     }
                     break;
                     break;
                 case VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER:
                 case VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER:
                 case VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:
                 case VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:
-                    AZ_Assert(!updateData.m_texelBufferViews.empty(), "TexelInfo list is empty.");
-                    for (const RHI::Interval& interval : GetValidDescriptorsIntervals(updateData.m_texelBufferViews))
                     {
                     {
-                        writeDescSet.pTexelBufferView = updateData.m_texelBufferViews.data() + interval.m_min;
-                        writeDescSet.dstArrayElement = interval.m_min;
-                        writeDescSet.descriptorCount = interval.m_max - interval.m_min;
-                        writeDescSetDescs.push_back(AZStd::move(writeDescSet));
+                        AZ_Assert(!updateData.m_texelBufferViews.empty(), "TexelInfo list is empty.");
+                        auto intervals = GetValidDescriptorsIntervals(updateData.m_texelBufferViews.span());
+                        for (const RHI::Interval& interval : intervals.span())
+                        {
+                            writeDescSet.pTexelBufferView = updateData.m_texelBufferViews.span().data() + interval.m_min;
+                            writeDescSet.dstArrayElement = interval.m_min;
+                            writeDescSet.descriptorCount = interval.m_max - interval.m_min;
+                            writeDescSetDescs.push_back(AZStd::move(writeDescSet));
+                        }
                     }
                     }
                     break;
                     break;
                 case VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR:
                 case VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR:
-                    AZ_Assert(!updateData.m_bufferViewsInfo.empty(), "BufferInfo is empty.");
-                    AZ_Assert(!updateData.m_accelerationStructures.empty(), "AccelerationStructures is empty.");
-                    for (const RHI::Interval& interval : GetValidDescriptorsIntervals(updateData.m_bufferViewsInfo))
                     {
                     {
-                        // acceleration structure descriptor is added as the pNext in the VkWriteDescriptorSet
-                        VkWriteDescriptorSetAccelerationStructureKHR writeAccelerationStructure = {};
-                        writeAccelerationStructure.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR;
-                        writeAccelerationStructure.accelerationStructureCount = interval.m_max - interval.m_min;
-                        writeAccelerationStructure.pAccelerationStructures = updateData.m_accelerationStructures.data() + interval.m_min;
-                        writeAccelerationStructureDescs.push_back(AZStd::move(writeAccelerationStructure));
-
-                        writeDescSet.dstArrayElement = interval.m_min;
-                        writeDescSet.descriptorCount = interval.m_max - interval.m_min;
-                        writeDescSet.pNext = &writeAccelerationStructureDescs.back();
-                        writeDescSetDescs.push_back(AZStd::move(writeDescSet));
+                        AZ_Assert(!updateData.m_bufferViewsInfo.empty(), "BufferInfo is empty.");
+                        AZ_Assert(!updateData.m_accelerationStructures.empty(), "AccelerationStructures is empty.");
+                        auto intervals = GetValidDescriptorsIntervals(updateData.m_bufferViewsInfo.span());
+                        for (const RHI::Interval& interval : intervals.span())
+                        {
+                            // acceleration structure descriptor is added as the pNext in the VkWriteDescriptorSet
+                            VkWriteDescriptorSetAccelerationStructureKHR writeAccelerationStructure = {};
+                            writeAccelerationStructure.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR;
+                            writeAccelerationStructure.accelerationStructureCount = interval.m_max - interval.m_min;
+                            writeAccelerationStructure.pAccelerationStructures =
+                                updateData.m_accelerationStructures.span().data() + interval.m_min;
+                            writeAccelerationStructureDescs.push_back(AZStd::move(writeAccelerationStructure));
+
+                            writeDescSet.dstArrayElement = interval.m_min;
+                            writeDescSet.descriptorCount = interval.m_max - interval.m_min;
+                            writeDescSet.pNext = &writeAccelerationStructureDescs.span().back();
+                            writeDescSetDescs.push_back(AZStd::move(writeDescSet));
+                        }
                     }
                     }
                     break;
                     break;
                 default:
                 default:
@@ -392,7 +420,7 @@ namespace AZ
             {
             {
                 auto& device = static_cast<Device&>(GetDevice());
                 auto& device = static_cast<Device&>(GetDevice());
                 device.GetContext().UpdateDescriptorSets(
                 device.GetContext().UpdateDescriptorSets(
-                    device.GetNativeDevice(), static_cast<uint32_t>(writeDescSetDescs.size()), writeDescSetDescs.data(), 0, nullptr);
+                    device.GetNativeDevice(), static_cast<uint32_t>(writeDescSetDescs.size()), writeDescSetDescs.span().data(), 0, nullptr);
             }
             }
 
 
             m_updateData.clear();
             m_updateData.clear();
@@ -404,7 +432,7 @@ namespace AZ
             uint32_t unboundedArraySize = 0;
             uint32_t unboundedArraySize = 0;
 
 
             // find the unbounded array in the update data
             // find the unbounded array in the update data
-            for (const WriteDescriptorData& updateData : m_updateData)
+            for (const WriteDescriptorData& updateData : m_updateData.span())
             {
             {
                 if ((layout.GetNativeBindingFlags()[updateData.m_layoutIndex] & VK_DESCRIPTOR_BINDING_VARIABLE_DESCRIPTOR_COUNT_BIT) != 0)
                 if ((layout.GetNativeBindingFlags()[updateData.m_layoutIndex] & VK_DESCRIPTOR_BINDING_VARIABLE_DESCRIPTOR_COUNT_BIT) != 0)
                 {
                 {

+ 21 - 14
Gems/Atom/RHI/Vulkan/Code/Source/RHI/DescriptorSet.h

@@ -7,14 +7,15 @@
  */
  */
 #pragma once
 #pragma once
 
 
-#include <Atom/RHI/DeviceObject.h>
+#include <Atom/RHI.Reflect/SamplerState.h>
 #include <Atom/RHI/DeviceBuffer.h>
 #include <Atom/RHI/DeviceBuffer.h>
 #include <Atom/RHI/DeviceBufferView.h>
 #include <Atom/RHI/DeviceBufferView.h>
 #include <Atom/RHI/DeviceImage.h>
 #include <Atom/RHI/DeviceImage.h>
 #include <Atom/RHI/DeviceImageView.h>
 #include <Atom/RHI/DeviceImageView.h>
-#include <Atom/RHI.Reflect/SamplerState.h>
-#include <AzCore/std/containers/span.h>
+#include <Atom/RHI/DeviceObject.h>
+#include <AtomCore/std/containers/small_vector.h>
 #include <AzCore/Memory/PoolAllocator.h>
 #include <AzCore/Memory/PoolAllocator.h>
+#include <AzCore/std/containers/span.h>
 #include <RHI/Buffer.h>
 #include <RHI/Buffer.h>
 
 
 namespace AZ
 namespace AZ
@@ -38,6 +39,8 @@ namespace AZ
             using Base = RHI::DeviceObject;
             using Base = RHI::DeviceObject;
             friend class DescriptorPool;
             friend class DescriptorPool;
 
 
+            static constexpr size_t ViewsFixedsize = 16;
+
         public:
         public:
             
             
             //Using SystemAllocator here instead of ThreadPoolAllocator as it gets slower when
             //Using SystemAllocator here instead of ThreadPoolAllocator as it gets slower when
@@ -60,6 +63,8 @@ namespace AZ
 
 
             void CommitUpdates();
             void CommitUpdates();
 
 
+            void ReserveUpdateData(size_t numUpdates);
+
             void UpdateBufferViews(uint32_t index, const AZStd::span<const RHI::ConstPtr<RHI::DeviceBufferView>>& bufViews);
             void UpdateBufferViews(uint32_t index, const AZStd::span<const RHI::ConstPtr<RHI::DeviceBufferView>>& bufViews);
             void UpdateImageViews(uint32_t index, const AZStd::span<const RHI::ConstPtr<RHI::DeviceImageView>>& imageViews, RHI::ShaderInputImageType imageType);
             void UpdateImageViews(uint32_t index, const AZStd::span<const RHI::ConstPtr<RHI::DeviceImageView>>& imageViews, RHI::ShaderInputImageType imageType);
             void UpdateSamplers(uint32_t index, const AZStd::span<const RHI::SamplerState>& samplers);
             void UpdateSamplers(uint32_t index, const AZStd::span<const RHI::SamplerState>& samplers);
@@ -71,10 +76,10 @@ namespace AZ
             struct WriteDescriptorData
             struct WriteDescriptorData
             {
             {
                 uint32_t m_layoutIndex = 0;
                 uint32_t m_layoutIndex = 0;
-                AZStd::vector<VkDescriptorBufferInfo> m_bufferViewsInfo;
-                AZStd::vector<VkDescriptorImageInfo> m_imageViewsInfo;
-                AZStd::vector<VkBufferView> m_texelBufferViews;
-                AZStd::vector<VkAccelerationStructureKHR> m_accelerationStructures;
+                AZStd::small_vector<VkDescriptorBufferInfo, ViewsFixedsize> m_bufferViewsInfo;
+                AZStd::small_vector<VkDescriptorImageInfo, ViewsFixedsize> m_imageViewsInfo;
+                AZStd::small_vector<VkBufferView, ViewsFixedsize> m_texelBufferViews;
+                AZStd::small_vector<VkAccelerationStructureKHR, ViewsFixedsize> m_accelerationStructures;
             };
             };
 
 
             DescriptorSet() = default;
             DescriptorSet() = default;
@@ -93,7 +98,7 @@ namespace AZ
             void AllocateDescriptorSetWithUnboundedArray();
             void AllocateDescriptorSetWithUnboundedArray();
 
 
             template<typename T>
             template<typename T>
-            AZStd::vector<RHI::Interval> GetValidDescriptorsIntervals(const AZStd::vector<T>& descriptorsInfo) const;
+            AZStd::small_vector<RHI::Interval, ViewsFixedsize> GetValidDescriptorsIntervals(const AZStd::span<T>& descriptorsInfo) const;
 
 
             static bool IsNullDescriptorInfo(const VkDescriptorBufferInfo& descriptorInfo);
             static bool IsNullDescriptorInfo(const VkDescriptorBufferInfo& descriptorInfo);
             static bool IsNullDescriptorInfo(const VkDescriptorImageInfo& descriptorInfo);
             static bool IsNullDescriptorInfo(const VkDescriptorImageInfo& descriptorInfo);
@@ -102,7 +107,7 @@ namespace AZ
             Descriptor m_descriptor;
             Descriptor m_descriptor;
 
 
             VkDescriptorSet m_nativeDescriptorSet = VK_NULL_HANDLE;
             VkDescriptorSet m_nativeDescriptorSet = VK_NULL_HANDLE;
-            AZStd::vector<WriteDescriptorData> m_updateData;
+            AZStd::small_vector<WriteDescriptorData, ViewsFixedsize> m_updateData;
             RHI::Ptr<Buffer> m_constantDataBuffer;
             RHI::Ptr<Buffer> m_constantDataBuffer;
             RHI::Ptr<BufferView> m_constantDataBufferView;
             RHI::Ptr<BufferView> m_constantDataBufferView;
             bool m_nullDescriptorSupported = false;
             bool m_nullDescriptorSupported = false;
@@ -110,15 +115,17 @@ namespace AZ
         };
         };
 
 
         template<typename T>
         template<typename T>
-        AZStd::vector<RHI::Interval> DescriptorSet::GetValidDescriptorsIntervals(const AZStd::vector<T>& descriptorsInfo) const
+        AZStd::small_vector<RHI::Interval, DescriptorSet::ViewsFixedsize> DescriptorSet::GetValidDescriptorsIntervals(
+            const AZStd::span<T>& descriptorsInfo) const
         {
         {
+            AZStd::small_vector<RHI::Interval, DescriptorSet::ViewsFixedsize> intervals;
             // if Null descriptors are supported, then we just return one interval that covers the whole range.
             // if Null descriptors are supported, then we just return one interval that covers the whole range.
             if (m_nullDescriptorSupported)
             if (m_nullDescriptorSupported)
             {
             {
-                return { RHI::Interval(0, aznumeric_caster(descriptorsInfo.size())) };
+                intervals.push_back(RHI::Interval(0, aznumeric_caster(descriptorsInfo.size())));
+                return intervals;
             }
             }
 
 
-            AZStd::vector<RHI::Interval> intervals;
             auto beginInterval = descriptorsInfo.begin();
             auto beginInterval = descriptorsInfo.begin();
             auto endInterval = beginInterval;
             auto endInterval = beginInterval;
             bool (*IsNullFuntion)(const T&) = &DescriptorSet::IsNullDescriptorInfo;
             bool (*IsNullFuntion)(const T&) = &DescriptorSet::IsNullDescriptorInfo;
@@ -129,8 +136,8 @@ namespace AZ
                 {
                 {
                     endInterval = AZStd::find_if(beginInterval, descriptorsInfo.end(), IsNullFuntion);
                     endInterval = AZStd::find_if(beginInterval, descriptorsInfo.end(), IsNullFuntion);
 
 
-                    intervals.emplace_back();
-                    RHI::Interval& interval = intervals.back();
+                    intervals.push_back({});
+                    RHI::Interval& interval = intervals.span().back();
                     interval.m_min = aznumeric_caster(AZStd::distance(descriptorsInfo.begin(), beginInterval));
                     interval.m_min = aznumeric_caster(AZStd::distance(descriptorsInfo.begin(), beginInterval));
                     interval.m_max = endInterval == descriptorsInfo.end() ? static_cast<uint32_t>(descriptorsInfo.size()) : static_cast<uint32_t>(AZStd::distance(descriptorsInfo.begin(), endInterval));
                     interval.m_max = endInterval == descriptorsInfo.end() ? static_cast<uint32_t>(descriptorsInfo.size()) : static_cast<uint32_t>(AZStd::distance(descriptorsInfo.begin(), endInterval));
                 }
                 }

+ 14 - 0
Gems/Atom/RHI/Vulkan/Code/Source/RHI/ShaderResourceGroupPool.cpp

@@ -111,6 +111,20 @@ namespace AZ
 
 
             const RHI::ShaderResourceGroupLayout* layout = groupData.GetLayout();
             const RHI::ShaderResourceGroupLayout* layout = groupData.GetLayout();
 
 
+            {
+                size_t numUpdates = 0;
+                numUpdates += layout->GetShaderInputListForBuffers().size();
+                numUpdates += layout->GetShaderInputListForImages().size();
+                numUpdates += layout->GetShaderInputListForBufferUnboundedArrays().size();
+                numUpdates += layout->GetShaderInputListForImageUnboundedArrays().size();
+                numUpdates += layout->GetShaderInputListForSamplers().size();
+                if (!groupData.GetConstantData().empty())
+                {
+                    numUpdates++;
+                }
+                descriptorSet.ReserveUpdateData(numUpdates);
+            }
+
             for (uint32_t groupIndex = 0; groupIndex < static_cast<uint32_t>(layout->GetShaderInputListForBuffers().size()); ++groupIndex)
             for (uint32_t groupIndex = 0; groupIndex < static_cast<uint32_t>(layout->GetShaderInputListForBuffers().size()); ++groupIndex)
             {
             {
                 const RHI::ShaderInputBufferIndex index(groupIndex);
                 const RHI::ShaderInputBufferIndex index(groupIndex);