瀏覽代碼

Create helper function for getting threads per (#4480)

* Create helper function for getting threads per
group from a compute shader

Added GetComputeShaderNumThreads() functions to RPIUtils.
By default the function returns 1, 1, 1 in case of errors.

Updated existing code that was looking for 'numthreads' attribute data
with the new GetComputeShaderNumThreads() API.

Signed-off-by: garrieta <[email protected]>
galibzon 3 年之前
父節點
當前提交
643bd84739

+ 3 - 20
Gems/Atom/Feature/Common/Code/Source/DiffuseGlobalIllumination/DiffuseProbeGridBlendDistancePass.cpp

@@ -54,27 +54,10 @@ namespace AZ
             m_srgLayout = m_shader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::Pass);
 
             // retrieve the number of threads per thread group from the shader
-            const auto numThreads = m_shader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name{ "numthreads" });
-            if (numThreads)
+            const auto outcome = RPI::GetComputeShaderNumThreads(m_shader->GetAsset(), m_dispatchArgs);
+            if (!outcome.IsSuccess())
             {
-                const RHI::ShaderStageAttributeArguments& args = *numThreads;
-                bool validArgs = args.size() == 3;
-                if (validArgs)
-                {
-                    validArgs &= args[0].type() == azrtti_typeid<int>();
-                    validArgs &= args[1].type() == azrtti_typeid<int>();
-                    validArgs &= args[2].type() == azrtti_typeid<int>();
-                }
-
-                if (!validArgs)
-                {
-                    AZ_Error("PassSystem", false, "[DiffuseProbeGridBlendDistancePass '%s']: Shader '%s' contains invalid numthreads arguments.", GetPathName().GetCStr(), shaderFilePath.c_str());
-                    return;
-                }
-
-                m_dispatchArgs.m_threadsPerGroupX = static_cast<uint16_t>(AZStd::any_cast<int>(args[0]));
-                m_dispatchArgs.m_threadsPerGroupY = static_cast<uint16_t>(AZStd::any_cast<int>(args[1]));
-                m_dispatchArgs.m_threadsPerGroupZ = static_cast<uint16_t>(AZStd::any_cast<int>(args[2]));
+                AZ_Error("PassSystem", false, "[DiffuseProbeGridBlendDistancePass '%s']: Shader '%s' contains invalid numthreads arguments:\n%s", GetPathName().GetCStr(), shaderFilePath.c_str(), outcome.GetError().c_str());
             }
         }
 

+ 3 - 20
Gems/Atom/Feature/Common/Code/Source/DiffuseGlobalIllumination/DiffuseProbeGridBlendIrradiancePass.cpp

@@ -54,27 +54,10 @@ namespace AZ
             m_srgLayout = m_shader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::Pass);
 
             // retrieve the number of threads per thread group from the shader
-            const auto numThreads = m_shader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name{ "numthreads" });
-            if (numThreads)
+            const auto outcome = RPI::GetComputeShaderNumThreads(m_shader->GetAsset(), m_dispatchArgs);
+            if (!outcome.IsSuccess())
             {
-                const RHI::ShaderStageAttributeArguments& args = *numThreads;
-                bool validArgs = args.size() == 3;
-                if (validArgs)
-                {
-                    validArgs &= args[0].type() == azrtti_typeid<int>();
-                    validArgs &= args[1].type() == azrtti_typeid<int>();
-                    validArgs &= args[2].type() == azrtti_typeid<int>();
-                }
-
-                if (!validArgs)
-                {
-                    AZ_Error("PassSystem", false, "[DiffuseProbeBlendIrradiancePass '%s']: Shader '%s' contains invalid numthreads arguments.", GetPathName().GetCStr(), shaderFilePath.c_str());
-                    return;
-                }
-
-                m_dispatchArgs.m_threadsPerGroupX = static_cast<uint16_t>(AZStd::any_cast<int>(args[0]));
-                m_dispatchArgs.m_threadsPerGroupY = static_cast<uint16_t>(AZStd::any_cast<int>(args[1]));
-                m_dispatchArgs.m_threadsPerGroupZ = static_cast<uint16_t>(AZStd::any_cast<int>(args[2]));
+                AZ_Error("PassSystem", false, "[DiffuseProbeBlendIrradiancePass '%s']: Shader '%s' contains invalid numthreads arguments:\n%s", GetPathName().GetCStr(), shaderFilePath.c_str(), outcome.GetError().c_str());
             }
         }
 

+ 3 - 20
Gems/Atom/Feature/Common/Code/Source/DiffuseGlobalIllumination/DiffuseProbeGridBorderUpdatePass.cpp

@@ -67,27 +67,10 @@ namespace AZ
             srgLayout = shader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::Pass);
 
             // retrieve the number of threads per thread group from the shader
-            const auto numThreads = shader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name{ "numthreads" });
-            if (numThreads)
+            const auto outcome = RPI::GetComputeShaderNumThreads(shader->GetAsset(), dispatchArgs);
+            if (!outcome.IsSuccess())
             {
-                const RHI::ShaderStageAttributeArguments& args = *numThreads;
-                bool validArgs = args.size() == 3;
-                if (validArgs)
-                {
-                    validArgs &= args[0].type() == azrtti_typeid<int>();
-                    validArgs &= args[1].type() == azrtti_typeid<int>();
-                    validArgs &= args[2].type() == azrtti_typeid<int>();
-                }
-
-                if (!validArgs)
-                {
-                    AZ_Error("PassSystem", false, "[DiffuseProbeGridBorderUpdatePass '%s']: Shader '%s' contains invalid numthreads arguments.", GetPathName().GetCStr(), shaderFilePath.c_str());
-                    return;
-                }
-
-                dispatchArgs.m_threadsPerGroupX = static_cast<uint16_t>(AZStd::any_cast<int>(args[0]));
-                dispatchArgs.m_threadsPerGroupY = static_cast<uint16_t>(AZStd::any_cast<int>(args[1]));
-                dispatchArgs.m_threadsPerGroupZ = static_cast<uint16_t>(AZStd::any_cast<int>(args[2]));
+                AZ_Error("PassSystem", false, "[DiffuseProbeGridBorderUpdatePass '%s']: Shader '%s' contains invalid numthreads arguments:\n%s", GetPathName().GetCStr(), shaderFilePath.c_str(), outcome.GetError().c_str());
             }
         }
 

+ 3 - 20
Gems/Atom/Feature/Common/Code/Source/DiffuseGlobalIllumination/DiffuseProbeGridClassificationPass.cpp

@@ -58,27 +58,10 @@ namespace AZ
             m_srgLayout = m_shader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::Pass);
 
             // retrieve the number of threads per thread group from the shader
-            const auto numThreads = m_shader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name{ "numthreads" });
-            if (numThreads)
+            const auto outcome = RPI::GetComputeShaderNumThreads(m_shader->GetAsset(), m_dispatchArgs);
+            if (!outcome.IsSuccess())
             {
-                const RHI::ShaderStageAttributeArguments& args = *numThreads;
-                bool validArgs = args.size() == 3;
-                if (validArgs)
-                {
-                    validArgs &= args[0].type() == azrtti_typeid<int>();
-                    validArgs &= args[1].type() == azrtti_typeid<int>();
-                    validArgs &= args[2].type() == azrtti_typeid<int>();
-                }
-
-                if (!validArgs)
-                {
-                    AZ_Error("PassSystem", false, "[DiffuseProbeClassificationPass '%s']: Shader '%s' contains invalid numthreads arguments.", GetPathName().GetCStr(), shaderFilePath.c_str());
-                    return;
-                }
-
-                m_dispatchArgs.m_threadsPerGroupX = static_cast<uint16_t>(AZStd::any_cast<int>(args[0]));
-                m_dispatchArgs.m_threadsPerGroupY = static_cast<uint16_t>(AZStd::any_cast<int>(args[1]));
-                m_dispatchArgs.m_threadsPerGroupZ = static_cast<uint16_t>(AZStd::any_cast<int>(args[2]));
+                AZ_Error("PassSystem", false, "[DiffuseProbeClassificationPass '%s']: Shader '%s' contains invalid numthreads arguments:\n%s", GetPathName().GetCStr(), shaderFilePath.c_str(), outcome.GetError().c_str());
             }
         }
 

+ 3 - 20
Gems/Atom/Feature/Common/Code/Source/DiffuseGlobalIllumination/DiffuseProbeGridRelocationPass.cpp

@@ -58,27 +58,10 @@ namespace AZ
             m_srgLayout = m_shader->FindShaderResourceGroupLayout(RPI::SrgBindingSlot::Pass);
 
             // retrieve the number of threads per thread group from the shader
-            const auto numThreads = m_shader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name{ "numthreads" });
-            if (numThreads)
+            const auto outcome = RPI::GetComputeShaderNumThreads(m_shader->GetAsset(), m_dispatchArgs);
+            if (!outcome.IsSuccess())
             {
-                const RHI::ShaderStageAttributeArguments& args = *numThreads;
-                bool validArgs = args.size() == 3;
-                if (validArgs)
-                {
-                    validArgs &= args[0].type() == azrtti_typeid<int>();
-                    validArgs &= args[1].type() == azrtti_typeid<int>();
-                    validArgs &= args[2].type() == azrtti_typeid<int>();
-                }
-
-                if (!validArgs)
-                {
-                    AZ_Error("PassSystem", false, "[DiffuseProbeRelocationPass '%s']: Shader '%s' contains invalid numthreads arguments.", GetPathName().GetCStr(), shaderFilePath.c_str());
-                    return;
-                }
-
-                m_dispatchArgs.m_threadsPerGroupX = static_cast<uint16_t>(AZStd::any_cast<int>(args[0]));
-                m_dispatchArgs.m_threadsPerGroupY = static_cast<uint16_t>(AZStd::any_cast<int>(args[1]));
-                m_dispatchArgs.m_threadsPerGroupZ = static_cast<uint16_t>(AZStd::any_cast<int>(args[2]));
+                AZ_Error("PassSystem", false, "[DiffuseProbeRelocationPass '%s']: Shader '%s' contains invalid numthreads arguments:\n%s", GetPathName().GetCStr(), shaderFilePath.c_str(), outcome.GetError().c_str());
             }
         }
 

+ 4 - 7
Gems/Atom/Feature/Common/Code/Source/MorphTargets/MorphTargetDispatchItem.cpp

@@ -13,6 +13,7 @@
 #include <Atom/RPI.Public/Shader/Shader.h>
 #include <Atom/RPI.Public/Model/ModelLod.h>
 #include <Atom/RPI.Public/Buffer/Buffer.h>
+#include <Atom/RPI.Public/RPIUtils.h>
 
 #include <Atom/RHI/Factory.h>
 #include <Atom/RHI/BufferView.h>
@@ -79,15 +80,11 @@ namespace AZ
             m_dispatchItem.m_pipelineState = m_morphTargetShader->AcquirePipelineState(pipelineStateDescriptor);
 
             // Get the threads-per-group values from the compute shader [numthreads(x,y,z)]
-            const auto& numThreads = m_morphTargetShader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, AZ::Name{ "numthreads" });
             auto& arguments = m_dispatchItem.m_arguments.m_direct;
-            if (numThreads)
+            const auto outcome = RPI::GetComputeShaderNumThreads(m_morphTargetShader->GetAsset(), arguments);
+            if (!outcome.IsSuccess())
             {
-                const auto& args = *numThreads;
-                // Check that the arguments are valid integers, and fall back to 1,1,1 if there is an error
-                arguments.m_threadsPerGroupX = static_cast<uint16_t>(args[0].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[0]) : 1);
-                arguments.m_threadsPerGroupY = static_cast<uint16_t>(args[1].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[1]) : 1);
-                arguments.m_threadsPerGroupZ = static_cast<uint16_t>(args[2].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[2]) : 1);
+                AZ_Error("MorphTargetDispatchItem", false, outcome.GetError().c_str());
             }
 
             arguments.m_totalNumberOfThreadsX = m_morphTargetMetaData.m_vertexCount;

+ 6 - 8
Gems/Atom/Feature/Common/Code/Source/SkinnedMesh/SkinnedMeshDispatchItem.cpp

@@ -14,6 +14,7 @@
 #include <Atom/RPI.Public/Shader/Shader.h>
 #include <Atom/RPI.Public/Model/ModelLod.h>
 #include <Atom/RPI.Public/Buffer/Buffer.h>
+#include <Atom/RPI.Public/RPIUtils.h>
 
 #include <Atom/RHI/Factory.h>
 #include <Atom/RHI/BufferView.h>
@@ -199,17 +200,14 @@ namespace AZ
             m_instanceSrg->Compile();
             m_dispatchItem.m_uniqueShaderResourceGroup = m_instanceSrg->GetRHIShaderResourceGroup();
             m_dispatchItem.m_pipelineState = m_skinningShader->AcquirePipelineState(pipelineStateDescriptor);
-            
-            const auto& numThreads = m_skinningShader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, AZ::Name{ "numthreads" });
+
             auto& arguments = m_dispatchItem.m_arguments.m_direct;
-            if (numThreads)
+            const auto outcome = RPI::GetComputeShaderNumThreads(m_skinningShader->GetAsset(), arguments);
+            if (!outcome.IsSuccess())
             {
-                const auto& args = *numThreads;
-                arguments.m_threadsPerGroupX = static_cast<uint16_t>(args[0].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[0]) : 1);
-                arguments.m_threadsPerGroupY = static_cast<uint16_t>(args[1].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[1]) : 1);
-                arguments.m_threadsPerGroupZ = static_cast<uint16_t>(args[2].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[2]) : 1);
+                AZ_Error("SkinnedMeshInputBuffers", false, outcome.GetError().c_str());
             }
-
+ 
             arguments.m_totalNumberOfThreadsX = xThreads;
             arguments.m_totalNumberOfThreadsY = yThreads;
             arguments.m_totalNumberOfThreadsZ = 1;

+ 18 - 0
Gems/Atom/RPI/Code/Include/Atom/RPI.Public/RPIUtils.h

@@ -11,6 +11,7 @@
  
 #include <AtomCore/Instance/Instance.h>
 
+#include <Atom/RHI/DispatchItem.h>
 #include <Atom/RPI.Public/Base.h>
 #include <Atom/RPI.Public/Image/StreamingImage.h>
 #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
@@ -40,6 +41,23 @@ namespace AZ
 
         //! Loads a streaming image asset for the given file path
         Data::Instance<RPI::StreamingImage> LoadStreamingTexture(AZStd::string_view path);
+
+        //! Looks for a three arguments attribute named @attributeName in the given shader asset.
+        //! Assigns the value to each non-null output variables.
+        //! @param shaderAsset
+        //! @param attributeName
+        //! @param numThreadsX Can be NULL. If not NULL it takes the value of the 1st argument of the attribute. Becomes 1 on error.
+        //! @param numThreadsY Can be NULL. If not NULL it takes the value of the 2nd argument of the attribute. Becomes 1 on error.
+        //! @param numThreadsZ Can be NULL. If not NULL it takes the value of the 3rd argument of the attribute. Becomes 1 on error.
+        //! @returns An Outcome instance with error message in case of error.
+        AZ::Outcome<void, AZStd::string> GetComputeShaderNumThreads(const Data::Asset<ShaderAsset>& shaderAsset, const AZ::Name& attributeName, uint16_t* numThreadsX, uint16_t* numThreadsY, uint16_t* numThreadsZ);
+
+        //! Same as above, but assumes the name of the attribute to be 'numthreads'.
+        AZ::Outcome<void, AZStd::string> GetComputeShaderNumThreads(const Data::Asset<ShaderAsset>& shaderAsset, uint16_t* numThreadsX, uint16_t* numThreadsY, uint16_t* numThreadsZ);
+
+        //! Same as above. Provided as a convenience when all arguments of the 'numthreads' attributes should be assigned to RHI::DispatchDirect::m_threadsPerGroup* variables.
+        AZ::Outcome<void, AZStd::string> GetComputeShaderNumThreads(const Data::Asset<ShaderAsset>& shaderAsset, RHI::DispatchDirect& dispatchDirect);
+        
     }   // namespace RPI
 }   // namespace AZ
 

+ 5 - 22
Gems/Atom/RPI/Code/Source/RPI.Public/Pass/ComputePass.cpp

@@ -107,30 +107,13 @@ namespace AZ
             dispatchArgs.m_totalNumberOfThreadsY = passData->m_totalNumberOfThreadsY;
             dispatchArgs.m_totalNumberOfThreadsZ = passData->m_totalNumberOfThreadsZ;
 
-            const auto numThreads = m_shader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name{ "numthreads" });
-            if (numThreads)
+            const auto outcome = RPI::GetComputeShaderNumThreads(m_shader->GetAsset(), dispatchArgs);
+            if (!outcome.IsSuccess())
             {
-                const RHI::ShaderStageAttributeArguments& args = *numThreads;
-                bool validArgs = args.size() == 3;
-                if (validArgs)
-                {
-                    validArgs &= args[0].type() == azrtti_typeid<int>();
-                    validArgs &= args[1].type() == azrtti_typeid<int>();
-                    validArgs &= args[2].type() == azrtti_typeid<int>();
-                }
-
-                if (!validArgs)
-                {
-                    AZ_Error("PassSystem", false, "[ComputePass '%s']: Shader '%s' contains invalid numthreads arguments.",
-                        GetPathName().GetCStr(),
-                        passData->m_shaderReference.m_filePath.data());
-                    return;
-                }
-
-                dispatchArgs.m_threadsPerGroupX = aznumeric_cast<uint16_t>(AZStd::any_cast<int>(args[0]));
-                dispatchArgs.m_threadsPerGroupY = aznumeric_cast<uint16_t>(AZStd::any_cast<int>(args[1]));
-                dispatchArgs.m_threadsPerGroupZ = aznumeric_cast<uint16_t>(AZStd::any_cast<int>(args[2]));
+                AZ_Error("PassSystem", false, "[ComputePass '%s']: Shader '%.*s' contains invalid numthreads arguments:\n%s",
+                        GetPathName().GetCStr(), passData->m_shaderReference.m_filePath.size(), passData->m_shaderReference.m_filePath.data(), outcome.GetError().c_str());
             }
+
             m_dispatchItem.m_arguments = dispatchArgs;
 
             m_isFullscreenPass = passData->m_makeFullscreenPass;

+ 74 - 0
Gems/Atom/RPI/Code/Source/RPI.Public/RPIUtils.cpp

@@ -143,5 +143,79 @@ namespace AZ
 
             return RPI::StreamingImage::FindOrCreate(streamingImageAsset);
         }
+
+        //! A helper function for GetComputeShaderNumThreads(), to consolidate error messages, etc.
+        static bool GetAttributeArgumentByIndex(const Data::Asset<ShaderAsset>& shaderAsset, const AZ::Name& attributeName, const RHI::ShaderStageAttributeArguments& args, const size_t argIndex, uint16_t* value, AZStd::string& errorMsg)
+        {
+            if (value)
+            {
+                const auto numArguments = args.size();
+                if (numArguments > argIndex)
+                {
+                    if (args[argIndex].type() == azrtti_typeid<int>())
+                    {
+                        *value = aznumeric_caster(AZStd::any_cast<int>(args[argIndex]));
+                    }
+                    else
+                    {
+                        errorMsg = AZStd::string::format("Was expecting argument '%zu' in attribute '%s' to be of type 'int' from shader asset '%s'", argIndex, attributeName.GetCStr(), shaderAsset.GetHint().c_str());
+                        return false;
+                    }
+                }
+                else
+                {
+                     errorMsg = AZStd::string::format("Was expecting at least '%zu' arguments in attribute '%s' from shader asset '%s'", argIndex + 1, attributeName.GetCStr(), shaderAsset.GetHint().c_str());
+                     return false;
+                }
+            }
+            return true;
+        }
+
+        AZ::Outcome<void, AZStd::string> GetComputeShaderNumThreads(const Data::Asset<ShaderAsset>& shaderAsset, const AZ::Name& attributeName, uint16_t* numThreadsX, uint16_t* numThreadsY, uint16_t* numThreadsZ)
+        {
+            // Set default 1, 1, 1 now. In case of errors later this is what the caller will get.
+            if (numThreadsX)
+            {
+                *numThreadsX = 1;
+            }
+            if (numThreadsY)
+            {
+                *numThreadsY = 1;
+            }
+            if (numThreadsZ)
+            {
+                *numThreadsZ = 1;
+            }
+            const auto numThreads = shaderAsset->GetAttribute(RHI::ShaderStage::Compute, attributeName);
+            if (!numThreads)
+            {
+                return AZ::Failure(AZStd::string::format("Couldn't find attribute '%s' in shader asset '%s'", attributeName.GetCStr(), shaderAsset.GetHint().c_str()));
+            }
+            const RHI::ShaderStageAttributeArguments& args = *numThreads;
+            AZStd::string errorMsg;
+            if (!GetAttributeArgumentByIndex(shaderAsset, attributeName, args, 0, numThreadsX, errorMsg))
+            {
+                return AZ::Failure(errorMsg);
+            }
+            if (!GetAttributeArgumentByIndex(shaderAsset, attributeName, args, 1, numThreadsY, errorMsg))
+            {
+                return AZ::Failure(errorMsg);
+            }
+            if (!GetAttributeArgumentByIndex(shaderAsset, attributeName, args, 2, numThreadsZ, errorMsg))
+            {
+                return AZ::Failure(errorMsg);
+            }
+            return AZ::Success();
+        }
+
+        AZ::Outcome<void, AZStd::string> GetComputeShaderNumThreads(const Data::Asset<ShaderAsset>& shaderAsset, uint16_t* numThreadsX, uint16_t* numThreadsY, uint16_t* numThreadsZ)
+        {
+            return GetComputeShaderNumThreads(shaderAsset, Name{ "numthreads" }, numThreadsX, numThreadsY, numThreadsZ);
+        }
+
+        AZ::Outcome<void, AZStd::string> GetComputeShaderNumThreads(const Data::Asset<ShaderAsset>& shaderAsset, RHI::DispatchDirect& dispatchDirect)
+        {
+            return GetComputeShaderNumThreads(shaderAsset, &dispatchDirect.m_threadsPerGroupX, &dispatchDirect.m_threadsPerGroupY, &dispatchDirect.m_threadsPerGroupZ);
+        }
     }
 }