瀏覽代碼

support srg constants with different values on different GPUs (#18816)

Signed-off-by: Karl Haubenwallner <[email protected]>
Karl Haubenwallner 5 月之前
父節點
當前提交
cd1b6e4c8d

+ 32 - 0
Gems/Atom/RHI/Code/Include/Atom/RHI/ShaderResourceGroupData.h

@@ -92,6 +92,10 @@ namespace AZ::RHI
         template<typename T>
         template<typename T>
         bool SetConstant(ShaderInputConstantIndex inputIndex, const T& value);
         bool SetConstant(ShaderInputConstantIndex inputIndex, const T& value);
 
 
+        //! Assigns a device-specific value of type T to the constant shader input.
+        template<typename T>
+        bool SetConstant(ShaderInputConstantIndex inputIndex, const AZStd::unordered_map<int, T>& values);
+
         //! Assigns a specified number of rows from a Matrix
         //! Assigns a specified number of rows from a Matrix
         template<typename T>
         template<typename T>
         bool SetConstantMatrixRows(ShaderInputConstantIndex inputIndex, const T& value, uint32_t rowCount);
         bool SetConstantMatrixRows(ShaderInputConstantIndex inputIndex, const T& value, uint32_t rowCount);
@@ -270,6 +274,34 @@ namespace AZ::RHI
         return isValidAll;
         return isValidAll;
     }
     }
 
 
+    template<typename T>
+    bool ShaderResourceGroupData::SetConstant(ShaderInputConstantIndex inputIndex, const AZStd::unordered_map<int, T>& values)
+    {
+        EnableResourceTypeCompilation(ResourceTypeMask::ConstantDataMask);
+
+        bool isValidAll = true;
+        bool foundValidDevice = false;
+
+        for (auto& [deviceIndex, deviceShaderResourceGroupData] : m_deviceShaderResourceGroupDatas)
+        {
+            auto deviceValueIterator = values.find(deviceIndex);
+            if (deviceValueIterator != values.end())
+            {
+                // Use the data for the first valid device for the getters
+                if (!foundValidDevice)
+                {
+                    foundValidDevice = true;
+                    m_constantsData.SetConstant(inputIndex, deviceValueIterator->second);
+                }
+                isValidAll &= deviceShaderResourceGroupData.SetConstant(inputIndex, deviceValueIterator->second);
+            }
+        }
+        // We need at least one valid device
+        isValidAll &= foundValidDevice;
+
+        return isValidAll;
+    }
+
     template<typename T>
     template<typename T>
     bool ShaderResourceGroupData::SetConstant(ShaderInputConstantIndex inputIndex, const T& value, uint32_t arrayIndex)
     bool ShaderResourceGroupData::SetConstant(ShaderInputConstantIndex inputIndex, const T& value, uint32_t arrayIndex)
     {
     {

+ 24 - 0
Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Shader/ShaderResourceGroup.h

@@ -238,6 +238,14 @@ namespace AZ
             template <typename T>
             template <typename T>
             bool SetConstant(RHI::ShaderInputConstantIndex inputIndex, const T& value);
             bool SetConstant(RHI::ShaderInputConstantIndex inputIndex, const T& value);
 
 
+            /// Assign a device-specific value of type T to the constant shader input. Note that the corresponding GetConstant() - function
+            /// returns only the value of one device.
+            template<typename T>
+            bool SetConstant(RHI::ShaderInputNameIndex& inputIndex, const AZStd::unordered_map<int, T>& values);
+
+            template<typename T>
+            bool SetConstant(RHI::ShaderInputConstantIndex& inputIndex, const AZStd::unordered_map<int, T>& values);
+
             /// Assigns the specified number of rows from a Matrix
             /// Assigns the specified number of rows from a Matrix
             template <typename T>
             template <typename T>
             bool SetConstantMatrixRows(RHI::ShaderInputNameIndex& inputIndex, const T& value, uint32_t rowCount);
             bool SetConstantMatrixRows(RHI::ShaderInputNameIndex& inputIndex, const T& value, uint32_t rowCount);
@@ -426,6 +434,22 @@ namespace AZ
             return false;
             return false;
         }
         }
 
 
+        template<typename T>
+        bool ShaderResourceGroup::SetConstant(RHI::ShaderInputConstantIndex& inputIndex, const AZStd::unordered_map<int, T>& values)
+        {
+            return m_data.SetConstant(inputIndex, values);
+        }
+
+        template<typename T>
+        bool ShaderResourceGroup::SetConstant(RHI::ShaderInputNameIndex& inputIndex, const AZStd::unordered_map<int, T>& values)
+        {
+            if (inputIndex.ValidateOrFindConstantIndex(GetLayout()))
+            {
+                return SetConstant(inputIndex.GetConstantIndex(), values);
+            }
+            return false;
+        }
+
         template <typename T>
         template <typename T>
         bool ShaderResourceGroup::SetConstantMatrixRows(RHI::ShaderInputConstantIndex inputIndex, const T& value, uint32_t rowCount)
         bool ShaderResourceGroup::SetConstantMatrixRows(RHI::ShaderInputConstantIndex inputIndex, const T& value, uint32_t rowCount)
         {
         {