|
@@ -143,5 +143,79 @@ namespace AZ
|
|
|
|
|
|
return RPI::StreamingImage::FindOrCreate(streamingImageAsset);
|
|
|
}
|
|
|
+
|
|
|
+ //! A helper function for GetComputeShaderNumThreads(), to consolidate error messages, etc.
|
|
|
+ static bool GetAttributeArgumentByIndex(const Data::Asset<ShaderAsset>& shaderAsset, const AZ::Name& attributeName, const RHI::ShaderStageAttributeArguments& args, const size_t argIndex, uint16_t* value, AZStd::string& errorMsg)
|
|
|
+ {
|
|
|
+ if (value)
|
|
|
+ {
|
|
|
+ const auto numArguments = args.size();
|
|
|
+ if (numArguments > argIndex)
|
|
|
+ {
|
|
|
+ if (args[argIndex].type() == azrtti_typeid<int>())
|
|
|
+ {
|
|
|
+ *value = aznumeric_caster(AZStd::any_cast<int>(args[argIndex]));
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ errorMsg = AZStd::string::format("Was expecting argument '%zu' in attribute '%s' to be of type 'int' from shader asset '%s'", argIndex, attributeName.GetCStr(), shaderAsset.GetHint().c_str());
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ errorMsg = AZStd::string::format("Was expecting at least '%zu' arguments in attribute '%s' from shader asset '%s'", argIndex + 1, attributeName.GetCStr(), shaderAsset.GetHint().c_str());
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ AZ::Outcome<void, AZStd::string> GetComputeShaderNumThreads(const Data::Asset<ShaderAsset>& shaderAsset, const AZ::Name& attributeName, uint16_t* numThreadsX, uint16_t* numThreadsY, uint16_t* numThreadsZ)
|
|
|
+ {
|
|
|
+ // Set default 1, 1, 1 now. In case of errors later this is what the caller will get.
|
|
|
+ if (numThreadsX)
|
|
|
+ {
|
|
|
+ *numThreadsX = 1;
|
|
|
+ }
|
|
|
+ if (numThreadsY)
|
|
|
+ {
|
|
|
+ *numThreadsY = 1;
|
|
|
+ }
|
|
|
+ if (numThreadsZ)
|
|
|
+ {
|
|
|
+ *numThreadsZ = 1;
|
|
|
+ }
|
|
|
+ const auto numThreads = shaderAsset->GetAttribute(RHI::ShaderStage::Compute, attributeName);
|
|
|
+ if (!numThreads)
|
|
|
+ {
|
|
|
+ return AZ::Failure(AZStd::string::format("Couldn't find attribute '%s' in shader asset '%s'", attributeName.GetCStr(), shaderAsset.GetHint().c_str()));
|
|
|
+ }
|
|
|
+ const RHI::ShaderStageAttributeArguments& args = *numThreads;
|
|
|
+ AZStd::string errorMsg;
|
|
|
+ if (!GetAttributeArgumentByIndex(shaderAsset, attributeName, args, 0, numThreadsX, errorMsg))
|
|
|
+ {
|
|
|
+ return AZ::Failure(errorMsg);
|
|
|
+ }
|
|
|
+ if (!GetAttributeArgumentByIndex(shaderAsset, attributeName, args, 1, numThreadsY, errorMsg))
|
|
|
+ {
|
|
|
+ return AZ::Failure(errorMsg);
|
|
|
+ }
|
|
|
+ if (!GetAttributeArgumentByIndex(shaderAsset, attributeName, args, 2, numThreadsZ, errorMsg))
|
|
|
+ {
|
|
|
+ return AZ::Failure(errorMsg);
|
|
|
+ }
|
|
|
+ return AZ::Success();
|
|
|
+ }
|
|
|
+
|
|
|
+ AZ::Outcome<void, AZStd::string> GetComputeShaderNumThreads(const Data::Asset<ShaderAsset>& shaderAsset, uint16_t* numThreadsX, uint16_t* numThreadsY, uint16_t* numThreadsZ)
|
|
|
+ {
|
|
|
+ return GetComputeShaderNumThreads(shaderAsset, Name{ "numthreads" }, numThreadsX, numThreadsY, numThreadsZ);
|
|
|
+ }
|
|
|
+
|
|
|
+ AZ::Outcome<void, AZStd::string> GetComputeShaderNumThreads(const Data::Asset<ShaderAsset>& shaderAsset, RHI::DispatchDirect& dispatchDirect)
|
|
|
+ {
|
|
|
+ return GetComputeShaderNumThreads(shaderAsset, &dispatchDirect.m_threadsPerGroupX, &dispatchDirect.m_threadsPerGroupY, &dispatchDirect.m_threadsPerGroupZ);
|
|
|
+ }
|
|
|
}
|
|
|
}
|