Browse Source

Add support for specialization constants for shader options (#18019)

- Add support of Specialization Constants for dx12 and vulkan
- Add configure function to specialize a pipeline descriptor
- Refactor ShaderVariantAsset job to batch variants
- Add signing of DX12 shader bytecode at runtime
- Add support for function constants on Metal

Signed-off-by: Akio Gaule <[email protected]>
Akio Gaule 1 year ago
parent
commit
fdabdc28e1
100 changed files with 1947 additions and 283 deletions
  1. 0 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Android/shader_build_options.settings
  2. 0 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Linux/shader_build_options.settings
  3. 0 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Mac/shader_build_options.settings
  4. 0 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Windows/dx12/shader_build_options.settings
  5. 0 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Windows/shader_build_options.settings
  6. 0 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Windows/vulkan/shader_build_options.settings
  7. 0 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/iOS/shader_build_options.settings
  8. 1 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/README.md
  9. 0 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/shader_build_options.settings
  10. 18 2
      Gems/Atom/Asset/Shader/Code/Source/Editor/AzslCompiler.cpp
  11. 1 1
      Gems/Atom/Asset/Shader/Code/Source/Editor/AzslCompiler.h
  12. 3 3
      Gems/Atom/Asset/Shader/Code/Source/Editor/AzslShaderBuilderSystemComponent.cpp
  13. 18 2
      Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderAssetBuilder.cpp
  14. 13 7
      Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderBuildArgumentsManager.h
  15. 4 2
      Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderBuilderUtility.cpp
  16. 2 1
      Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderBuilderUtility.h
  17. 124 45
      Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderVariantAssetBuilder.cpp
  18. 4 0
      Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderVariantAssetBuilder.h
  19. 1 1
      Gems/Atom/Asset/Shader/Registry/atom_shaders.setreg
  20. 3 3
      Gems/Atom/Feature/Common/Code/Source/CoreLights/DepthExponentiationPass.cpp
  21. 9 8
      Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.cpp
  22. 2 2
      Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.h
  23. 2 2
      Gems/Atom/Feature/Common/Code/Source/MorphTargets/MorphTargetDispatchItem.cpp
  24. 5 6
      Gems/Atom/Feature/Common/Code/Source/PostProcessing/BlendColorGradingLutsPass.cpp
  25. 2 2
      Gems/Atom/Feature/Common/Code/Source/PostProcessing/PostProcessingShaderOptionBase.cpp
  26. 2 7
      Gems/Atom/Feature/Common/Code/Source/PostProcessing/SMAABasePass.cpp
  27. 0 1
      Gems/Atom/Feature/Common/Code/Source/PostProcessing/SMAABasePass.h
  28. 1 1
      Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingPass.cpp
  29. 14 27
      Gems/Atom/Feature/Common/Code/Source/ScreenSpace/DeferredFogPass.cpp
  30. 0 4
      Gems/Atom/Feature/Common/Code/Source/ScreenSpace/DeferredFogPass.h
  31. 2 2
      Gems/Atom/Feature/Common/Code/Source/SkinnedMesh/SkinnedMeshDispatchItem.cpp
  32. 11 2
      Gems/Atom/Feature/Common/Code/Source/SkyAtmosphere/SkyAtmospherePass.cpp
  33. 3 1
      Gems/Atom/RHI/Code/Include/Atom/RHI.Edit/ShaderPlatformInterface.h
  34. 10 4
      Gems/Atom/RHI/Code/Include/Atom/RHI/PipelineStateDescriptor.h
  35. 45 0
      Gems/Atom/RHI/Code/Include/Atom/RHI/SpecializationConstant.h
  36. 23 16
      Gems/Atom/RHI/Code/Source/RHI/PipelineStateDescriptor.cpp
  37. 31 0
      Gems/Atom/RHI/Code/Source/RHI/SpecializationConstant.cpp
  38. 2 0
      Gems/Atom/RHI/Code/atom_rhi_public_files.cmake
  39. 12 0
      Gems/Atom/RHI/DX12/Code/CMakeLists.txt
  40. 13 0
      Gems/Atom/RHI/DX12/Code/Include/Atom/RHI.Reflect/DX12/ShaderStageFunction.h
  41. 73 4
      Gems/Atom/RHI/DX12/Code/Source/RHI.Builders/ShaderPlatformInterface.cpp
  42. 5 2
      Gems/Atom/RHI/DX12/Code/Source/RHI.Builders/ShaderPlatformInterface.h
  43. 18 2
      Gems/Atom/RHI/DX12/Code/Source/RHI.Reflect/ShaderStageFunction.cpp
  44. 2 0
      Gems/Atom/RHI/DX12/Code/Source/RHI/DX12.h
  45. 14 6
      Gems/Atom/RHI/DX12/Code/Source/RHI/PipelineState.cpp
  46. 7 2
      Gems/Atom/RHI/DX12/Code/Source/RHI/RayTracingPipelineState.cpp
  47. 237 0
      Gems/Atom/RHI/DX12/Code/Source/RHI/ShaderUtils.cpp
  48. 36 0
      Gems/Atom/RHI/DX12/Code/Source/RHI/ShaderUtils.h
  49. 2 0
      Gems/Atom/RHI/DX12/Code/atom_rhi_dx12_private_common_files.cmake
  50. 12 0
      Gems/Atom/RHI/DX12/Code/openssl_md5_files.cmake
  51. 38 0
      Gems/Atom/RHI/DX12/External/md5/README.md
  52. 289 0
      Gems/Atom/RHI/DX12/External/md5/openssl/md5.c
  53. 54 0
      Gems/Atom/RHI/DX12/External/md5/openssl/md5.h
  54. 2 1
      Gems/Atom/RHI/Metal/Code/Source/RHI.Builders/ShaderPlatformInterface.cpp
  55. 2 1
      Gems/Atom/RHI/Metal/Code/Source/RHI.Builders/ShaderPlatformInterface.h
  56. 41 7
      Gems/Atom/RHI/Metal/Code/Source/RHI/PipelineState.cpp
  57. 3 2
      Gems/Atom/RHI/Metal/Code/Source/RHI/PipelineState.h
  58. 2 1
      Gems/Atom/RHI/Null/Code/Source/RHI.Builders/ShaderPlatformInterface.cpp
  59. 2 1
      Gems/Atom/RHI/Null/Code/Source/RHI.Builders/ShaderPlatformInterface.h
  60. 2 1
      Gems/Atom/RHI/Vulkan/Code/Source/RHI.Builders/ShaderPlatformInterface.cpp
  61. 2 1
      Gems/Atom/RHI/Vulkan/Code/Source/RHI.Builders/ShaderPlatformInterface.h
  62. 3 1
      Gems/Atom/RHI/Vulkan/Code/Source/RHI/Pipeline.cpp
  63. 3 0
      Gems/Atom/RHI/Vulkan/Code/Source/RHI/Pipeline.h
  64. 11 3
      Gems/Atom/RHI/Vulkan/Code/Source/RHI/RayTracingPipelineState.cpp
  65. 66 0
      Gems/Atom/RHI/Vulkan/Code/Source/RHI/SpecializationConstantData.cpp
  66. 37 0
      Gems/Atom/RHI/Vulkan/Code/Source/RHI/SpecializationConstantData.h
  67. 2 0
      Gems/Atom/RHI/Vulkan/Code/atom_rhi_vulkan_private_common_files.cmake
  68. 5 1
      Gems/Atom/RPI/Code/Include/Atom/RPI.Edit/Shader/ShaderVariantAssetCreator.h
  69. 3 0
      Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Pass/ComputePass.h
  70. 4 0
      Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Pass/FullscreenTrianglePass.h
  71. 4 0
      Gems/Atom/RPI/Code/Include/Atom/RPI.Public/PipelineState.h
  72. 22 2
      Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Shader/ShaderVariant.h
  73. 18 0
      Gems/Atom/RPI/Code/Include/Atom/RPI.Reflect/Shader/ShaderAsset.h
  74. 3 0
      Gems/Atom/RPI/Code/Include/Atom/RPI.Reflect/Shader/ShaderAssetCreator.h
  75. 18 1
      Gems/Atom/RPI/Code/Include/Atom/RPI.Reflect/Shader/ShaderOptionGroupLayout.h
  76. 2 2
      Gems/Atom/RPI/Code/Include/Atom/RPI.Reflect/Shader/ShaderVariantAsset.h
  77. 5 1
      Gems/Atom/RPI/Code/Source/RPI.Edit/Shader/ShaderVariantAssetCreator.cpp
  78. 1 1
      Gems/Atom/RPI/Code/Source/RPI.Public/MeshDrawPacket.cpp
  79. 19 1
      Gems/Atom/RPI/Code/Source/RPI.Public/Pass/ComputePass.cpp
  80. 17 3
      Gems/Atom/RPI/Code/Source/RPI.Public/Pass/FullscreenTrianglePass.cpp
  81. 1 1
      Gems/Atom/RPI/Code/Source/RPI.Public/Pass/Specific/ImageAttachmentPreviewPass.cpp
  82. 27 16
      Gems/Atom/RPI/Code/Source/RPI.Public/PipelineState.cpp
  83. 3 2
      Gems/Atom/RPI/Code/Source/RPI.Public/Shader/Shader.cpp
  84. 78 2
      Gems/Atom/RPI/Code/Source/RPI.Public/Shader/ShaderVariant.cpp
  85. 24 3
      Gems/Atom/RPI/Code/Source/RPI.Reflect/Shader/ShaderAsset.cpp
  86. 17 0
      Gems/Atom/RPI/Code/Source/RPI.Reflect/Shader/ShaderAssetCreator.cpp
  87. 37 3
      Gems/Atom/RPI/Code/Source/RPI.Reflect/Shader/ShaderOptionGroupLayout.cpp
  88. 212 5
      Gems/Atom/RPI/Code/Tests/Shader/ShaderTests.cpp
  89. 2 2
      Gems/AtomLyIntegration/EditorModeFeedback/Code/Source/Draw/EditorStateMeshDrawPacket.cpp
  90. 0 14
      Gems/AtomTressFX/Assets/Passes/HairParentShortCutPass.pass
  91. 10 1
      Gems/AtomTressFX/Assets/Passes/HairShortCutGeometryDepthAlpha.pass
  92. 8 1
      Gems/AtomTressFX/Assets/Passes/HairShortCutGeometryShading.pass
  93. 37 21
      Gems/AtomTressFX/Code/Passes/HairGeometryRasterPass.cpp
  94. 4 0
      Gems/AtomTressFX/Code/Passes/HairGeometryRasterPass.h
  95. 8 4
      Gems/AtomTressFX/Code/Passes/HairPPLLRasterPass.cpp
  96. 1 0
      Gems/AtomTressFX/Code/Passes/HairPPLLRasterPass.h
  97. 5 7
      Gems/AtomTressFX/Code/Passes/HairPPLLResolvePass.cpp
  98. 0 1
      Gems/AtomTressFX/Code/Passes/HairPPLLResolvePass.h
  99. 5 1
      Gems/AtomTressFX/Code/Passes/HairShortCutGeometryDepthAlphaPass.cpp
  100. 1 0
      Gems/AtomTressFX/Code/Passes/HairShortCutGeometryDepthAlphaPass.h

+ 0 - 0
Gems/Atom/Asset/Shader/Config/Platform/Android/shader_build_options.json → Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Android/shader_build_options.settings


+ 0 - 0
Gems/Atom/Asset/Shader/Config/Platform/Linux/shader_build_options.json → Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Linux/shader_build_options.settings


+ 0 - 0
Gems/Atom/Asset/Shader/Config/Platform/Mac/shader_build_options.json → Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Mac/shader_build_options.settings


+ 0 - 0
Gems/Atom/Asset/Shader/Config/Platform/Windows/dx12/shader_build_options.json → Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Windows/dx12/shader_build_options.settings


+ 0 - 0
Gems/Atom/Asset/Shader/Config/Platform/Windows/shader_build_options.json → Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Windows/shader_build_options.settings


+ 0 - 0
Gems/Atom/Asset/Shader/Config/Platform/Windows/vulkan/shader_build_options.json → Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Windows/vulkan/shader_build_options.settings


+ 0 - 0
Gems/Atom/Asset/Shader/Config/Platform/iOS/shader_build_options.json → Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/iOS/shader_build_options.settings


+ 1 - 0
Gems/Atom/Asset/Shader/Assets/Config/Shader/README.md

@@ -0,0 +1 @@
+The files in these folders are config files for building shaders, not real assets that will be used at runtime. They are in the "Assets" folder so we can create source dependencies with the azsl shader files to be able to detect changes in order to rebuild the shaders.

+ 0 - 0
Gems/Atom/Asset/Shader/Config/shader_build_options.json → Gems/Atom/Asset/Shader/Assets/Config/Shader/shader_build_options.settings


+ 18 - 2
Gems/Atom/Asset/Shader/Code/Source/Editor/AzslCompiler.cpp

@@ -900,7 +900,10 @@ namespace AZ
             return CompileToFileAndPrepareJsonDocument(output, "--options", "options.json") == BuildResult::Success;
         }
 
-        bool AzslCompiler::ParseOptionsPopulateOptionGroupLayout(const rapidjson::Document& input, RPI::Ptr<RPI::ShaderOptionGroupLayout>& shaderOptionGroupLayout) const
+        bool AzslCompiler::ParseOptionsPopulateOptionGroupLayout(
+            const rapidjson::Document& input,
+            RPI::Ptr<RPI::ShaderOptionGroupLayout>& shaderOptionGroupLayout,
+            bool& outUseSpecializationConstants) const
         {
             auto totalBitOffset = (uint32_t) 0u;
 
@@ -933,6 +936,12 @@ namespace AZ
                     return false;
                 };
 
+            outUseSpecializationConstants = false;
+            if (input.HasMember("specializationConstants"))
+            {
+                outUseSpecializationConstants = input["specializationConstants"].GetBool();
+            }
+
             const rapidjson::Value& shaderOptions = input["ShaderOptions"];
             AZ_Assert(shaderOptions.IsArray(), "Attribute ShaderOptions must be an array");
 
@@ -1037,13 +1046,20 @@ namespace AZ
                         cost = optionEntry["costImpact"].GetUint();
                     }
 
+                    int specializationId = -1;
+                    if (optionEntry.HasMember("specializationId"))
+                    {
+                        specializationId = optionEntry["specializationId"].GetInt();
+                    }
+
                     RPI::ShaderOptionDescriptor shaderOption(Name(optionName), 
                                                              optionType,
                                                              keyOffset,
                                                              order,
                                                              idIndexList,
                                                              defaultValueId,
-                                                             cost);
+                                                             cost,
+                                                             specializationId);
 
                     if (!shaderOptionGroupLayout->AddShaderOption(shaderOption))
                     {

+ 1 - 1
Gems/Atom/Asset/Shader/Code/Source/Editor/AzslCompiler.h

@@ -61,7 +61,7 @@ namespace AZ
             //! make sense of a --srg json document and fill up the srg data container
             bool ParseSrgPopulateSrgData(const rapidjson::Document& input, SrgDataContainer& outSrgData) const;
             //! make sense of a --option json document and fill up the shader option group layout
-            bool ParseOptionsPopulateOptionGroupLayout(const rapidjson::Document& input, RPI::Ptr<RPI::ShaderOptionGroupLayout>& shaderOptionGroupLayout) const;
+            bool ParseOptionsPopulateOptionGroupLayout(const rapidjson::Document& input, RPI::Ptr<RPI::ShaderOptionGroupLayout>& shaderOptionGroupLayout, bool& outSpecializationConstants) const;
             //! make sense of a --bindingdep json documment and fill up the binding dependencies object
             bool ParseBindingdepPopulateBindingDependencies(const rapidjson::Document& input, BindingDependencies& bindingDependencies) const;
             //! make sense of a --srg json document and fill up the root constant data

+ 3 - 3
Gems/Atom/Asset/Shader/Code/Source/Editor/AzslShaderBuilderSystemComponent.cpp

@@ -83,7 +83,7 @@ namespace AZ
             // Register Shader Asset Builder
             AssetBuilderSDK::AssetBuilderDesc shaderAssetBuilderDescriptor;
             shaderAssetBuilderDescriptor.m_name = "Shader Asset Builder";
-            shaderAssetBuilderDescriptor.m_version = 123; // Metal shader debug symbols controlled via settings registry or shader build arguments
+            shaderAssetBuilderDescriptor.m_version = 124; // Add specialization constants for shader options
             shaderAssetBuilderDescriptor.m_patterns.push_back(AssetBuilderSDK::AssetBuilderPattern( AZStd::string::format("*.%s", RPI::ShaderSourceData::Extension), AssetBuilderSDK::AssetBuilderPattern::PatternType::Wildcard));
             shaderAssetBuilderDescriptor.m_busId = azrtti_typeid<ShaderAssetBuilder>();
             shaderAssetBuilderDescriptor.m_createJobFunction = AZStd::bind(&ShaderAssetBuilder::CreateJobs, &m_shaderAssetBuilder, AZStd::placeholders::_1, AZStd::placeholders::_2);
@@ -108,7 +108,7 @@ namespace AZ
                 shaderVariantAssetBuilderDescriptor.m_name = "Shader Variant Asset Builder";
                 // Both "Shader Variant Asset Builder" and "Shader Asset Builder" produce ShaderVariantAsset products. If you update
                 // ShaderVariantAsset you will need to update BOTH version numbers, not just "Shader Variant Asset Builder".
-                shaderVariantAssetBuilderDescriptor.m_version = 40; // Metal shader debug symbols controlled via settings registry or shader build arguments
+                shaderVariantAssetBuilderDescriptor.m_version = 41; // Add specialization constants for shader options
                 shaderVariantAssetBuilderDescriptor.m_patterns.push_back(AssetBuilderSDK::AssetBuilderPattern(AZStd::string::format("*.%s", HashedVariantListSourceData::Extension), AssetBuilderSDK::AssetBuilderPattern::PatternType::Wildcard));
                 shaderVariantAssetBuilderDescriptor.m_patterns.push_back(AssetBuilderSDK::AssetBuilderPattern(AZStd::string::format("*.%s", HashedVariantInfoSourceData::Extension), AssetBuilderSDK::AssetBuilderPattern::PatternType::Wildcard));
                 shaderVariantAssetBuilderDescriptor.m_busId = azrtti_typeid<ShaderVariantAssetBuilder>();
@@ -121,7 +121,7 @@ namespace AZ
                 // Register Shader Variant List Builder
                 AssetBuilderSDK::AssetBuilderDesc shaderVariantListBuilderDescriptor;
                 shaderVariantListBuilderDescriptor.m_name = "Shader Variant List Builder";
-                shaderVariantListBuilderDescriptor.m_version = 1; // First version of ShaderVariantListBuilder
+                shaderVariantListBuilderDescriptor.m_version = 2; // Add specialization constants for shader options
                 shaderVariantListBuilderDescriptor.m_patterns.push_back(AssetBuilderSDK::AssetBuilderPattern(AZStd::string::format("*.%s", RPI::ShaderVariantListSourceData::Extension), AssetBuilderSDK::AssetBuilderPattern::PatternType::Wildcard));
                 shaderVariantListBuilderDescriptor.m_busId = azrtti_typeid<ShaderVariantListBuilder>();
                 shaderVariantListBuilderDescriptor.m_createJobFunction = AZStd::bind(&ShaderVariantListBuilder::CreateJobs, &m_shaderVariantListBuilder, AZStd::placeholders::_1, AZStd::placeholders::_2);

+ 18 - 2
Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderAssetBuilder.cpp

@@ -202,6 +202,15 @@ namespace AZ
                 response.m_sourceFileDependencyList.emplace_back(AZStd::move(includeFileDependency));
             }
 
+            // Add the shader_build_option files as source dependencies
+            AZStd::unordered_map<AZStd::string, AZ::IO::FixedMaxPath> configFiles = ShaderBuildArgumentsManager::DiscoverConfigurationFiles();
+            for (const auto& pair : configFiles)
+            {
+                AssetBuilderSDK::SourceFileDependency includeFileDependency;
+                includeFileDependency.m_sourceFileDependencyPath = pair.second.c_str();
+                response.m_sourceFileDependencyList.emplace_back(AZStd::move(includeFileDependency));
+            }
+
             for (const AssetBuilderSDK::PlatformInfo& platformInfo : request.m_enabledPlatforms)
             {
                 AZ_TraceContext("For platform", platformInfo.m_identifier.data());
@@ -497,9 +506,14 @@ namespace AZ
                     RPI::Ptr<RPI::ShaderOptionGroupLayout> shaderOptionGroupLayout = RPI::ShaderOptionGroupLayout::Create();
                     BindingDependencies bindingDependencies;
                     RootConstantData rootConstantData;
+                    bool usesSpecializationConstants = false;
                     AssetBuilderSDK::ProcessJobResultCode azslJsonReadResult = ShaderBuilderUtility::PopulateAzslDataFromJsonFiles(
                         ShaderAssetBuilderName, subProductsPaths, azslData, srgLayoutList,
-                        shaderOptionGroupLayout, bindingDependencies, rootConstantData, request.m_tempDirPath);
+                        shaderOptionGroupLayout,
+                        bindingDependencies,
+                        rootConstantData,
+                        request.m_tempDirPath,
+                        usesSpecializationConstants);
                     if (azslJsonReadResult != AssetBuilderSDK::ProcessJobResult_Success)
                     {
                         response.m_resultCode = azslJsonReadResult;
@@ -507,6 +521,7 @@ namespace AZ
                     }
 
                     shaderAssetCreator.SetSrgLayoutList(srgLayoutList);
+                    shaderAssetCreator.SetUseSpecializationConstants(usesSpecializationConstants);
 
                     if (!finalShaderOptionGroupLayout)
                     {
@@ -665,7 +680,8 @@ namespace AZ
                         variantAssetId,
                         superVariantAzslinStemName,
                         hlslFullPath,
-                        hlslSourceCode};
+                        hlslSourceCode,
+                        usesSpecializationConstants };
 
                     // Preserve the Temp folder when shaders are compiled with debug symbols
                     // or because the ShaderSourceData has m_keepTempFolder set to true.

+ 13 - 7
Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderBuildArgumentsManager.h

@@ -65,8 +65,8 @@ namespace AZ
             // The value of this registry key is customizable by the user.
             static constexpr char ConfigPathRegistryKey[] = "/O3DE/Atom/Shaders/Build/ConfigPath";
 
-            static constexpr char DefaultConfigPathDirectory[] = "@gemroot:AtomShader@/Config";
-            static constexpr char ShaderBuildOptionsJson[] = "shader_build_options.json";
+            static constexpr char DefaultConfigPathDirectory[] = "@gemroot:AtomShader@/Assets/Config/Shader";
+            static constexpr char ShaderBuildOptionsJson[] = "shader_build_options.settings";
             static constexpr char PlatformsDir[] = "Platform";
 
             //! Always loads all the factory arguments provided by the Atom Gem. In addition
@@ -105,17 +105,25 @@ namespace AZ
             //! @remark: The "" (global) arguments are never popped, regardless of how many times this function is called.
             void PopArgumentScope();
 
+            //! Finds the shader build config files from the default locations. Returns a map where the key is the name of the scope,
+            //! and the value is a fully qualified file path.
+            //! Remarks: Posible scope names are:
+            //!     "global"
+            //!     "<platform>". Example "Android", "Windows", etc
+            //!     "<platform>.<rhi>". Example "Windows.dx12" or "Windows.vulkan".
+            static AZStd::unordered_map<AZStd::string, AZ::IO::FixedMaxPath> DiscoverConfigurationFiles();
+
         private:
             friend class ::UnitTest::ShaderBuildArgumentsTests;
             void Init(AZStd::unordered_map<AZStd::string, AZ::RHI::ShaderBuildArguments> && removeBuildArgumentsMap
                     , AZStd::unordered_map<AZStd::string, AZ::RHI::ShaderBuildArguments> && addBuildArgumentsMap);
 
             //! @returns A fully qualified path where the factory settings, as provided by Atom, are found.
-            AZ::IO::FixedMaxPath GetDefaultConfigDirectoryPath();
+            static AZ::IO::FixedMaxPath GetDefaultConfigDirectoryPath();
 
             //! @returns A fully qualified path where the user customized command line arguments are found.
             //!     The returned path will be empty if the user did not customize the path in the registry.
-            AZ::IO::FixedMaxPath GetUserConfigDirectoryPath();
+            static AZ::IO::FixedMaxPath GetUserConfigDirectoryPath();
 
             //! @param dirPath Starting directory for the search of  shader_build_options.json files.
             //! @returns A map where the key is the name of the scope, and the value is a fully qualified file path.
@@ -123,9 +131,7 @@ namespace AZ
             //!     "global"
             //!     "<platform>". Example "Android", "Windows", etc
             //!     "<platform>.<rhi>". Example "Windows.dx12" or "Windows.vulkan".
-            AZStd::unordered_map<AZStd::string, AZ::IO::FixedMaxPath> DiscoverConfigurationFilesInDirectory(const AZ::IO::FixedMaxPath& dirPath);
-
-            AZStd::unordered_map<AZStd::string, AZ::IO::FixedMaxPath> DiscoverConfigurationFiles();
+            static AZStd::unordered_map<AZStd::string, AZ::IO::FixedMaxPath> DiscoverConfigurationFilesInDirectory(const AZ::IO::FixedMaxPath& dirPath);
 
             const AZ::RHI::ShaderBuildArguments& PushArgumentsInternal(const AZStd::string& name, const AZ::RHI::ShaderBuildArguments& arguments);
 

+ 4 - 2
Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderBuilderUtility.cpp

@@ -139,7 +139,8 @@ namespace AZ
                 RPI::Ptr<RPI::ShaderOptionGroupLayout> shaderOptionGroupLayout,
                 BindingDependencies& bindingDependencies,
                 RootConstantData& rootConstantData,
-                const AZStd::string& tempFolder)
+                const AZStd::string& tempFolder,
+                bool& useSpecializationConstants)
             {
                 AzslCompiler azslc(azslData.m_preprocessedFullPath,  // set the input file for eventual error messages, but the compiler won't be called on it.
                                    tempFolder);
@@ -188,7 +189,8 @@ namespace AZ
 
                 // The shader options define what options are available, what are the allowed values/range
                 // for each option and what is its default value.
-                if (!azslc.ParseOptionsPopulateOptionGroupLayout(outcomes[AzslSubProducts::options].GetValue(), shaderOptionGroupLayout))
+                if (!azslc.ParseOptionsPopulateOptionGroupLayout(
+                        outcomes[AzslSubProducts::options].GetValue(), shaderOptionGroupLayout, useSpecializationConstants))
                 {
                     AZ_Error(builderName, false, "Failed to find a valid list of shader options!");
                     return AssetBuilderSDK::ProcessJobResult_Failed;

+ 2 - 1
Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderBuilderUtility.h

@@ -71,7 +71,8 @@ namespace AZ
                 RPI::Ptr<RPI::ShaderOptionGroupLayout> shaderOptionGroupLayout,
                 BindingDependencies& bindingDependencies,
                 RootConstantData& rootConstantData,
-                const AZStd::string& tempFolder);
+                const AZStd::string& tempFolder,
+                bool& useSpecializationConstants);
 
 
             RHI::ShaderHardwareStage ToAssetBuilderShaderType(RPI::ShaderStageType stageType);

+ 124 - 45
Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderVariantAssetBuilder.cpp

@@ -175,6 +175,8 @@ namespace AZ
                 return;
             }
 
+            AZStd::string hashedVariantInfoDescriptorString;
+            RPI::JsonUtils::SaveObjectToJsonString(hashedVariantInfoDescriptor, hashedVariantInfoDescriptorString);
             AZStd::string hashedVariantInfoParentPath(request.m_watchFolder.data());
             AZStd::string hashedVariantListFullPath = GetHashedVariantListPathFromVariantInfoPath(hashedVariantInfoParentPath, hashedVariantInfoRelativePath);
             
@@ -193,6 +195,9 @@ namespace AZ
             
                 jobDescriptor.m_jobKey = GetShaderVariantAssetJobKey();
                 jobDescriptor.SetPlatformIdentifier(info.m_identifier.data());
+
+                // Add the content of the hashedVariantInfo file as a parameter to avoid reading it again.
+                jobDescriptor.m_jobParameters.emplace(ShaderVariantInfoJobParam, hashedVariantInfoDescriptorString);
             
                 // The ShaderVariantAssets should be built AFTER the ShaderVariantTreeAsset.
                 // With "OrderOnly" dependency, We make sure ShaderVariantTreeAsset completes before ShaderVariantAsset runs,
@@ -239,7 +244,8 @@ namespace AZ
             const AssetBuilderSDK::PlatformInfo& platformInfo,
             const AzslCompiler& azslCompiler,
             const AZStd::string& shaderSourceFileFullPath,
-            const RPI::SupervariantIndex supervariantIndex)
+            const RPI::SupervariantIndex supervariantIndex,
+            bool& useSpecializationConstants)
         {
             auto optionsGroupPathOutcome = ShaderBuilderUtility::ObtainBuildArtifactPathFromShaderAssetBuilder(
                 shaderPlatformInterface->GetAPIUniqueIndex(), platformInfo.m_identifier, shaderSourceFileFullPath, supervariantIndex.GetIndex(),
@@ -259,7 +265,8 @@ namespace AZ
                 AZ_Error(ShaderVariantAssetBuilderName, false, "%s", jsonOutcome.GetError().c_str());
                 return nullptr;
             }
-            if (!azslCompiler.ParseOptionsPopulateOptionGroupLayout(jsonOutcome.GetValue(), shaderOptionGroupLayout))
+            if (!azslCompiler.ParseOptionsPopulateOptionGroupLayout(
+                    jsonOutcome.GetValue(), shaderOptionGroupLayout, useSpecializationConstants))
             {
                 AZ_Error(ShaderVariantAssetBuilderName, false, "Failed to find a valid list of shader options!");
                 return nullptr;
@@ -473,28 +480,63 @@ namespace AZ
             ShaderBuilderUtility::GetAbsolutePathToAzslFile(shaderSourceFileFullPath, shaderSourceDescriptor.m_source, azslFullPath);
             AzslCompiler azslc(azslFullPath, request.m_tempDirPath);
 
+            auto supervariantList = ShaderBuilderUtility::GetSupervariantListFromShaderSourceData(shaderSourceDescriptor);
+
             AZStd::string previousLoopApiName;
+            bool usesVariants = false;
             for (RHI::ShaderPlatformInterface* shaderPlatformInterface : platformInterfaces)
             {
                 auto thisLoopApiName = shaderPlatformInterface->GetAPIName().GetStringView();
-                RPI::Ptr<RPI::ShaderOptionGroupLayout> loopLocal_ShaderOptionGroupLayout =
-                    LoadShaderOptionsGroupLayoutFromShaderAssetBuilder(
-                        shaderPlatformInterface, request.m_platformInfo, azslc, shaderSourceFileFullPath, RPI::DefaultSupervariantIndex);
-                if (!loopLocal_ShaderOptionGroupLayout)
-                {
-                    response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
-                    return;
-                }
-                if (shaderOptionGroupLayout && shaderOptionGroupLayout->GetHash() != loopLocal_ShaderOptionGroupLayout->GetHash())
+                for (uint32_t supervariantIndexCounter = 0; supervariantIndexCounter < supervariantList.size(); ++supervariantIndexCounter)
                 {
-                    AZ_Error(ShaderVariantAssetBuilderName, false, "There was a discrepancy in shader options between %s and %s", previousLoopApiName.c_str(), thisLoopApiName.data());
-                    response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
-                    return;
+                    RPI::SupervariantIndex supervariantIndex(supervariantIndexCounter);
+                    bool usesSpecialization = false;
+                    RPI::Ptr<RPI::ShaderOptionGroupLayout> loopLocal_ShaderOptionGroupLayout =
+                        LoadShaderOptionsGroupLayoutFromShaderAssetBuilder(
+                            shaderPlatformInterface,
+                            request.m_platformInfo,
+                            azslc,
+                            shaderSourceFileFullPath,
+                            supervariantIndex,
+                            usesSpecialization);
+                    if (!loopLocal_ShaderOptionGroupLayout)
+                    {
+                        response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
+                        return;
+                    }
+                    if (shaderOptionGroupLayout && shaderOptionGroupLayout->GetHash() != loopLocal_ShaderOptionGroupLayout->GetHash())
+                    {
+                        AZ_Error(
+                            ShaderVariantAssetBuilderName,
+                            false,
+                            "There was a discrepancy in shader options between %s and %s",
+                            previousLoopApiName.c_str(),
+                            thisLoopApiName.data());
+                        response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
+                        return;
+                    }
+
+                    // Check if there's a supervariant that needs to generate the variants
+                    if (!usesSpecialization || !loopLocal_ShaderOptionGroupLayout->IsFullySpecialized())
+                    {
+                        usesVariants = true;
+                    }
+                    shaderOptionGroupLayout = loopLocal_ShaderOptionGroupLayout;
                 }
-                shaderOptionGroupLayout = loopLocal_ShaderOptionGroupLayout;
                 previousLoopApiName = thisLoopApiName;
             }
 
+            if (!usesVariants)
+            {
+                // No need to create the variant tree since all supervariants are fully specialized. Exit gracefully.
+                AZ_TracePrintf(
+                    ShaderVariantAssetBuilderName,
+                    "No azshadervarianttree is produced on behalf of %s because all valid RHI backends are using specialization constants for shader options.\n",
+                    shaderSourceFileFullPath.c_str());
+                response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Success;
+                return;
+            }
+
             AZStd::vector<RPI::ShaderVariantListSourceData::VariantInfo> variantInfos;
             variantInfos.reserve(hashedVariantListDescriptor.m_hashedVariants.size());
             for (const auto& hashedVariantInfo : hashedVariantListDescriptor.m_hashedVariants)
@@ -533,8 +575,7 @@ namespace AZ
 
             AZ_TracePrintf(ShaderVariantAssetBuilderName, "Shader Variant Tree Asset [%s] compiled successfully.\n", assetPath.c_str());
 
-            response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Success;
- 
+            response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Success; 
         }
 
 
@@ -545,8 +586,18 @@ namespace AZ
             AZStd::string hashedVariantInfoFullPath;
             AZ::StringFunc::Path::ConstructFull(request.m_watchFolder.data(), request.m_sourceFile.data(), hashedVariantInfoFullPath, true);
 
+
+            AZStd::string hashedVariantInfoDescriptorString;
+            if (!request.m_jobDescription.m_jobParameters.contains(ShaderVariantInfoJobParam))
+            {
+                AZ_Error(ShaderVariantAssetBuilderName, false, "Missing job Parameter: ShaderVariantInfoJobParam");
+                response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
+                return;
+            }
+            hashedVariantInfoDescriptorString = request.m_jobDescription.m_jobParameters.at(ShaderVariantInfoJobParam);
+
             HashedVariantInfoSourceData hashedVariantInfoDescriptor;
-            if (!RPI::JsonUtils::LoadObjectFromFile(hashedVariantInfoFullPath, hashedVariantInfoDescriptor, AZStd::numeric_limits<size_t>::max()))
+            if (!RPI::JsonUtils::LoadObjectFromJsonString(hashedVariantInfoDescriptorString, hashedVariantInfoDescriptor))
             {
                 AZ_Assert(false, "Failed to parse Hashed Variant Info Descriptor JSON [%s]", hashedVariantInfoFullPath.c_str());
                 response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
@@ -598,12 +649,17 @@ namespace AZ
                 AZ_TraceContext("Platform API", apiName);
 
                 buildArgsManager.PushArgumentScope(apiName);
-                buildArgsManager.PushArgumentScope(shaderSourceDescriptor.m_removeBuildArguments, shaderSourceDescriptor.m_addBuildArguments, shaderSourceDescriptor.m_definitions);
+                buildArgsManager.PushArgumentScope(
+                    shaderSourceDescriptor.m_removeBuildArguments,
+                    shaderSourceDescriptor.m_addBuildArguments,
+                    shaderSourceDescriptor.m_definitions);
 
                 // Loop through all the Supervariants.
-                uint32_t supervariantIndexCounter = 0;
-                for (const auto& supervariantInfo : supervariantList)
+                for (uint32_t supervariantIndexCounter = 0;
+                    supervariantIndexCounter < supervariantList.size();
+                    ++supervariantIndexCounter)
                 {
+                    const auto& supervariantInfo = supervariantList[supervariantIndexCounter];
                     RPI::SupervariantIndex supervariantIndex(supervariantIndexCounter);
 
                     // Check if we were canceled before we do any heavy processing of
@@ -614,7 +670,8 @@ namespace AZ
                         return;
                     }
 
-                    buildArgsManager.PushArgumentScope(supervariantInfo.m_removeBuildArguments, supervariantInfo.m_addBuildArguments, supervariantInfo.m_definitions);
+                    buildArgsManager.PushArgumentScope(
+                        supervariantInfo.m_removeBuildArguments, supervariantInfo.m_addBuildArguments, supervariantInfo.m_definitions);
 
                     AZStd::string shaderStemNamePrefix = shaderFileName;
                     if (supervariantIndex.GetIndex() > 0)
@@ -628,22 +685,40 @@ namespace AZ
                     // 3- hlsl code.
 
                     // 1- ShaderOptionsGroupLayout
+                    // The ShaderOptionsGroupLayout is the same for all platforms and supervariants, but the each supervariant
+                    // can have the use of specialization constants on or off.
+                    bool usesSpecializationConstants = false;
+                    shaderOptionGroupLayout = LoadShaderOptionsGroupLayoutFromShaderAssetBuilder(
+                        shaderPlatformInterface,
+                        request.m_platformInfo,
+                        azslc,
+                        shaderSourceFileFullPath,
+                        supervariantIndex,
+                        usesSpecializationConstants);
                     if (!shaderOptionGroupLayout)
                     {
-                        shaderOptionGroupLayout =
-                            LoadShaderOptionsGroupLayoutFromShaderAssetBuilder(
-                                shaderPlatformInterface, request.m_platformInfo, azslc, shaderSourceFileFullPath, supervariantIndex);
-                        if (!shaderOptionGroupLayout)
-                        {
-                            response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
-                            return;
-                        }
+                        response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
+                        return;
+                    }
+
+                    if (usesSpecializationConstants && shaderOptionGroupLayout->IsFullySpecialized())
+                    {
+                        // No need to create the shader variants since all supervariants are fully specialized.
+                        AZ_TracePrintf(
+                            ShaderVariantAssetBuilderName,
+                            "No azshaderVariant is produced on behalf of %s, super variant %s, because it's using specialization "
+                            "constants "
+                            "for shader options.\n",
+                            shaderSourceFileFullPath.c_str(),
+                            supervariantInfo.m_name.GetCStr());
+                        buildArgsManager.PopArgumentScope();
+                        continue;
                     }
 
                     // 2- entryFunctions.
                     AzslFunctions azslFunctions;
                     LoadShaderFunctionsFromShaderAssetBuilder(
-                        shaderPlatformInterface, request.m_platformInfo, azslc, shaderSourceFileFullPath, supervariantIndex,  azslFunctions);
+                        shaderPlatformInterface, request.m_platformInfo, azslc, shaderSourceFileFullPath, supervariantIndex, azslFunctions);
                     if (azslFunctions.empty())
                     {
                         response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
@@ -652,7 +727,7 @@ namespace AZ
                     MapOfStringToStageType shaderEntryPoints;
                     if (shaderSourceDescriptor.m_programSettings.m_entryPoints.empty())
                     {
-                        AZ_Error(ShaderVariantAssetBuilderName, false,  "ProgramSettings must specify entry points.");
+                        AZ_Error(ShaderVariantAssetBuilderName, false, "ProgramSettings must specify entry points.");
                         response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
                         return;
                     }
@@ -671,7 +746,7 @@ namespace AZ
                         response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
                         return;
                     }
-                    
+
                     //! It is important to keep this refcounted pointer outside of the if block to prevent it from being destroyed.
                     RHI::Ptr<RHI::PipelineLayoutDescriptor> pipelineLayoutDescriptor;
                     if (shaderPlatformInterface->VariantCompilationRequiresSrgLayoutData())
@@ -710,16 +785,18 @@ namespace AZ
                     }
 
                     // Setup the shader variant creation context:
-                    ShaderVariantCreationContext shaderVariantCreationContext =
-                    {
-                        *shaderPlatformInterface, request.m_platformInfo, buildArgsManager.GetCurrentArguments(), request.m_tempDirPath,
-                        shaderSourceDescriptor,
-                        *shaderOptionGroupLayout.get(),
-                        shaderEntryPoints,
-                        Uuid::CreateRandom(),
-                        shaderStemNamePrefix,
-                        hlslSourcePath, hlslCode
-                    };
+                    ShaderVariantCreationContext shaderVariantCreationContext = { *shaderPlatformInterface,
+                                                                                  request.m_platformInfo,
+                                                                                  buildArgsManager.GetCurrentArguments(),
+                                                                                  request.m_tempDirPath,
+                                                                                  shaderSourceDescriptor,
+                                                                                  *shaderOptionGroupLayout.get(),
+                                                                                  shaderEntryPoints,
+                                                                                  Uuid::CreateRandom(),
+                                                                                  shaderStemNamePrefix,
+                                                                                  hlslSourcePath,
+                                                                                  hlslCode,
+                                                                                  usesSpecializationConstants };
 
                     // Preserve the Temp folder when shaders are compiled with debug symbols
                     // or because the ShaderSourceData has m_keepTempFolder set to true.
@@ -767,7 +844,6 @@ namespace AZ
                         }
                     }
                     buildArgsManager.PopArgumentScope(); // Pop the supervariant build arguments.
-                    supervariantIndexCounter++;
                 } // End of supervariant for block
 
                 buildArgsManager.PopArgumentScope(); // Pop the .shader build arguments.
@@ -923,7 +999,10 @@ namespace AZ
                 RHI::ShaderPlatformInterface::StageDescriptor descriptor;
                 bool shaderWasCompiled = creationContext.m_shaderPlatformInterface.CompilePlatformInternal(
                     creationContext.m_platformInfo, variantShaderSourcePath, shaderEntryName, assetBuilderShaderType,
-                    creationContext.m_tempDirPath, descriptor, creationContext.m_shaderBuildArguments);
+                    creationContext.m_tempDirPath,
+                    descriptor,
+                    creationContext.m_shaderBuildArguments,
+                    creationContext.m_useSpecializationConstants);
 
                 if (!shaderWasCompiled)
                 {

+ 4 - 0
Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderVariantAssetBuilder.h

@@ -43,6 +43,7 @@ namespace AZ
             const AZStd::string& m_shaderStemNamePrefix; //<shaderName>-<supervariantName>
             const AZStd::string& m_hlslSourcePath;
             const AZStd::string& m_hlslSourceContent;
+            const bool m_useSpecializationConstants = false;
         };
 
 
@@ -88,6 +89,9 @@ namespace AZ
             void ShutDown() override { };
 
         private:
+            // Content of the hashedVariantInfo file 
+            static constexpr uint32_t ShaderVariantInfoJobParam = 0;
+
             AZ_DISABLE_COPY_MOVE(ShaderVariantAssetBuilder);
 
             static constexpr uint32_t ShaderSourceFilePathJobParam = 1;

+ 1 - 1
Gems/Atom/Asset/Shader/Registry/atom_shaders.setreg

@@ -4,7 +4,7 @@
             "Shaders": {
                 "BuildVariants": true,
                 "Build": {
-                    "ConfigPath": "@gemroot:AtomShader@/Config"
+                    "ConfigPath": "@gemroot:AtomShader@/Assets/Config/Shader"
                 }
             }
         }

+ 3 - 3
Gems/Atom/Feature/Common/Code/Source/CoreLights/DepthExponentiationPass.cpp

@@ -48,7 +48,7 @@ namespace AZ
             RPI::ShaderOptionGroup shaderOption = m_shader->CreateShaderOptionGroup();
             shaderOption.SetValue(m_optionName, m_optionValues[typeIndex]);
 
-            if (m_shaderResourceGroup)
+            if (!m_shaderVariant[typeIndex].m_isFullyBaked && m_shaderResourceGroup)
             {
                 m_shaderResourceGroup->SetShaderVariantKeyFallbackValue(shaderOption.GetShaderVariantKeyFallbackValue());
             }
@@ -108,9 +108,9 @@ namespace AZ
                 RPI::ShaderVariant shaderVariant = m_shader->GetVariant(shaderOption.GetShaderVariantId());
 
                 RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
-                shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
+                shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, shaderOption);
 
-                ShaderVariantInfo variationInfo{shaderVariant.IsFullyBaked(),
+                ShaderVariantInfo variationInfo{!shaderVariant.UseKeyFallback(),
                     m_shader->AcquirePipelineState(pipelineStateDescriptor)
                 };
                 m_shaderVariant.push_back(AZStd::move(variationInfo));

+ 9 - 8
Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.cpp

@@ -87,8 +87,8 @@ namespace AZ
 
         void LightCullingTilePreparePass::ChooseShaderVariant()
         {
-            const AZ::RPI::ShaderVariant& shaderVariant = CreateShaderVariant();
-            CreatePipelineStateFromShaderVariant(shaderVariant);
+            auto [shaderVariant, shaderOptions] = CreateShaderVariant();
+            CreatePipelineStateFromShaderVariant(shaderVariant, shaderOptions);
         }
 
         AZ::Name LightCullingTilePreparePass::GetMultiSampleName()
@@ -121,30 +121,31 @@ namespace AZ
         AZ::RPI::ShaderOptionGroup LightCullingTilePreparePass::CreateShaderOptionGroup()
         {
             RPI::ShaderOptionGroup shaderOptionGroup = m_shader->CreateShaderOptionGroup();
-            shaderOptionGroup.SetUnspecifiedToDefaultValues();
             shaderOptionGroup.SetValue(m_msaaOptionName, GetMultiSampleName());
+            shaderOptionGroup.SetUnspecifiedToDefaultValues();
             return shaderOptionGroup;
         }
 
-        void LightCullingTilePreparePass::CreatePipelineStateFromShaderVariant(const RPI::ShaderVariant& shaderVariant)
+        void LightCullingTilePreparePass::CreatePipelineStateFromShaderVariant(
+            const RPI::ShaderVariant& shaderVariant, const RPI::ShaderOptionGroup& shaderOptions)
         {
             AZ::RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
-            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
+            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, shaderOptions);
             m_msaaPipelineState = m_shader->AcquirePipelineState(pipelineStateDescriptor);
             AZ_Error("LightCulling", m_msaaPipelineState, "Failed to acquire pipeline state for shader");
         }
 
-        const AZ::RPI::ShaderVariant& LightCullingTilePreparePass::CreateShaderVariant()
+        AZStd::pair<const AZ::RPI::ShaderVariant&, RPI::ShaderOptionGroup> LightCullingTilePreparePass::CreateShaderVariant()
         {
             RPI::ShaderOptionGroup shaderOptionGroup = CreateShaderOptionGroup();
             const RPI::ShaderVariant& shaderVariant = m_shader->GetVariant(shaderOptionGroup.GetShaderVariantId());
 
             //Set the fallbackkey
-            if (m_drawSrg)
+            if (shaderVariant.UseKeyFallback() && m_drawSrg)
             {
                 m_drawSrg->SetShaderVariantKeyFallbackValue(shaderOptionGroup.GetShaderVariantKeyFallbackValue());
             }
-            return shaderVariant;
+            return { shaderVariant, shaderOptionGroup };
         }
 
         void LightCullingTilePreparePass::SetConstantData()

+ 2 - 2
Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.h

@@ -66,8 +66,8 @@ namespace AZ
             AZStd::array<float, 2> ComputeUnprojectConstants() const;
             AZ::RHI::Size GetDepthBufferDimensions();
             void ChooseShaderVariant();
-            const AZ::RPI::ShaderVariant& CreateShaderVariant();
-            void CreatePipelineStateFromShaderVariant(const RPI::ShaderVariant& shaderVariant);
+            AZStd::pair<const AZ::RPI::ShaderVariant&, RPI::ShaderOptionGroup> CreateShaderVariant();
+            void CreatePipelineStateFromShaderVariant(const RPI::ShaderVariant& shaderVariant, const RPI::ShaderOptionGroup& options);
             void SetConstantData();
             void OnShaderReloaded();
 

+ 2 - 2
Gems/Atom/Feature/Common/Code/Source/MorphTargets/MorphTargetDispatchItem.cpp

@@ -66,13 +66,13 @@ namespace AZ
                 return false;
             }
 
-            if (!shaderVariant.IsFullyBaked() && m_instanceSrg->HasShaderVariantKeyFallbackEntry())
+            if (shaderVariant.UseKeyFallback() && m_instanceSrg->HasShaderVariantKeyFallbackEntry())
             {
                 m_instanceSrg->SetShaderVariantKeyFallbackValue(shaderOptionGroup.GetShaderVariantKeyFallbackValue());
             }
 
             RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
-            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
+            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, shaderOptionGroup);
 
 
             InitRootConstants(pipelineStateDescriptor.m_pipelineLayoutDescriptor->GetRootConstantsLayout());

+ 5 - 6
Gems/Atom/Feature/Common/Code/Source/PostProcessing/BlendColorGradingLutsPass.cpp

@@ -65,10 +65,10 @@ namespace AZ
                 RPI::ShaderVariant shaderVariant = m_shader->GetVariant(shaderOption.GetShaderVariantId());
 
                 RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
-                shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
+                shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, shaderOption);
 
                 ShaderVariantInfo variantInfo{
-                    shaderVariant.IsFullyBaked(),
+                    !shaderVariant.UseKeyFallback(),
                     m_shader->AcquirePipelineState(pipelineStateDescriptor)
                 };
                 m_shaderVariant.push_back(AZStd::move(variantInfo));
@@ -91,11 +91,10 @@ namespace AZ
                 m_currentShaderVariantIndex = m_numSourceLuts;
             }
 
-            auto shaderOption = m_shader->CreateShaderOptionGroup();
-            shaderOption.SetValue(m_numSourceLutsShaderVariantOptionName, RPI::ShaderOptionValue{ m_numSourceLuts });
-
             if (!m_shaderVariant[m_currentShaderVariantIndex].m_isFullyBaked)
             {
+                auto shaderOption = m_shader->CreateShaderOptionGroup();
+                shaderOption.SetValue(m_numSourceLutsShaderVariantOptionName, RPI::ShaderOptionValue{ m_numSourceLuts });
                 m_currentShaderVariantKeyFallbackValue = shaderOption.GetShaderVariantKeyFallbackValue();
             }
             m_needToUpdateShaderVariant = false;
@@ -196,7 +195,7 @@ namespace AZ
                     m_shaderResourceGroup->SetConstant(m_shaderInputSourceLut4ShaperScaleIndex, m_colorGradingShaperParams[3].m_scale);
                 }
 
-                if (m_shaderResourceGroup->HasShaderVariantKeyFallbackEntry())
+                if (!m_shaderVariant[m_currentShaderVariantIndex].m_isFullyBaked && m_shaderResourceGroup->HasShaderVariantKeyFallbackEntry())
                 {
                     m_shaderResourceGroup->SetShaderVariantKeyFallbackValue(m_currentShaderVariantKeyFallbackValue);
                 }

+ 2 - 2
Gems/Atom/Feature/Common/Code/Source/PostProcessing/PostProcessingShaderOptionBase.cpp

@@ -25,7 +25,7 @@ namespace AZ
 
             auto shaderVariant = shader->GetVariant(shaderOption.GetShaderVariantId());
 
-            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
+            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, shaderOption);
             pipelineStateDescriptor.m_renderAttachmentConfiguration = renderAttachmentConfiguration;
             pipelineStateDescriptor.m_renderStates.m_multisampleState = multisampleState;
 
@@ -37,7 +37,7 @@ namespace AZ
             pipelineStateDescriptor.m_inputStreamLayout = inputStreamLayout;
 
             m_shaderVariantTable[variationKey].m_pipelineState = shader->AcquirePipelineState(pipelineStateDescriptor);
-            m_shaderVariantTable[variationKey].m_isFullyBaked = shaderVariant.IsFullyBaked();
+            m_shaderVariantTable[variationKey].m_isFullyBaked = !shaderVariant.UseKeyFallback();
         }
 
         void PostProcessingShaderOptionBase::UpdateShaderVariant(const AZ::RPI::ShaderOptionGroup& shaderOption)

+ 2 - 7
Gems/Atom/Feature/Common/Code/Source/PostProcessing/SMAABasePass.cpp

@@ -70,11 +70,6 @@ namespace AZ
             if (m_needToUpdateSRG)
             {
                 UpdateSRG();
-
-                if (m_shaderResourceGroup->HasShaderVariantKeyFallbackEntry())
-                {
-                    m_shaderResourceGroup->SetShaderVariantKeyFallbackValue(m_currentShaderVariantKeyFallbackValue);
-                }
                 m_needToUpdateSRG = false;
             }
 
@@ -86,8 +81,8 @@ namespace AZ
             auto shaderOption = m_shader->CreateShaderOptionGroup();
 
             GetCurrentShaderOption(shaderOption);
-
-            m_currentShaderVariantKeyFallbackValue = shaderOption.GetShaderVariantKeyFallbackValue();
+            shaderOption.SetUnspecifiedToDefaultValues();
+            UpdateShaderOptions(shaderOption.GetShaderVariantId());
             m_needToUpdateShaderVariant = false;
             InvalidateSRG();
         }

+ 0 - 1
Gems/Atom/Feature/Common/Code/Source/PostProcessing/SMAABasePass.h

@@ -60,7 +60,6 @@ namespace AZ
             // Scope producer functions...
             void CompileResources(const RHI::FrameGraphCompileContext& context) override;
 
-            AZ::RPI::ShaderVariantKey m_currentShaderVariantKeyFallbackValue;
             bool m_needToUpdateShaderVariant = false;
             bool m_needToUpdateSRG = true;
         };

+ 1 - 1
Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingPass.cpp

@@ -95,7 +95,7 @@ namespace AZ
                 auto shader{ AZ::RPI::Shader::FindOrCreate(shaderAsset, supervariantName) };
                 auto shaderVariant{ shader->GetVariant(AZ::RPI::ShaderAsset::RootShaderVariantStableId) };
                 AZ::RHI::PipelineStateDescriptorForRayTracing pipelineStateDescriptor;
-                shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
+                shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, shader->GetDefaultShaderOptions());
                 auto& shaderLib = shaderLibs.emplace_back();
                 shaderLib.m_shaderAssetId = assetReference.m_assetId;
                 shaderLib.m_shader = shader;

+ 14 - 27
Gems/Atom/Feature/Common/Code/Source/ScreenSpace/DeferredFogPass.cpp

@@ -209,35 +209,37 @@ namespace AZ
 
         void DeferredFogPass::UpdateShaderOptions()
         {
-            RPI::ShaderOptionGroup shaderOption = m_shader->CreateShaderOptionGroup();
+            RPI::ShaderOptionGroup shaderOptions = m_shader->CreateShaderOptionGroup();
             DeferredFogSettings* fogSettings = GetPassFogSettings();
 
             // [TODO][ATOM-13659] - AZ::Name all over our code base should use init with string and
             // hash key for the iterations themselves.
-            shaderOption.SetValue(AZ::Name("o_enableFogLayer"),
+            shaderOptions.SetValue(
+                AZ::Name("o_enableFogLayer"),
                 r_fogLayerSupport && fogSettings->GetEnableFogLayerShaderOption() ? AZ::Name("true") : AZ::Name("false"));
-            shaderOption.SetValue(AZ::Name("o_useNoiseTexture"),
+            shaderOptions.SetValue(
+                AZ::Name("o_useNoiseTexture"),
                 r_fogTurbulenceSupport && fogSettings->GetUseNoiseTextureShaderOption() ? AZ::Name("true") : AZ::Name("false"));
             switch (fogSettings->GetFogMode())
             {
             case FogMode::Linear:
-                shaderOption.SetValue(m_fogModeOptionName, AZ::Name("FogMode::LinearMode"));
+                shaderOptions.SetValue(m_fogModeOptionName, AZ::Name("FogMode::LinearMode"));
                 break;
             case FogMode::Exponential:
-                shaderOption.SetValue(m_fogModeOptionName, AZ::Name("FogMode::ExponentialMode"));
+                shaderOptions.SetValue(m_fogModeOptionName, AZ::Name("FogMode::ExponentialMode"));
                 break;
             case FogMode::ExponentialSquared:
-                shaderOption.SetValue(m_fogModeOptionName, AZ::Name("FogMode::ExponentialSquaredMode"));
+                shaderOptions.SetValue(m_fogModeOptionName, AZ::Name("FogMode::ExponentialSquaredMode"));
                 break;
             default:
                 AZ_Error("DeferredFogPass", false, "Invalid fog mode %d", fogSettings->GetFogMode());
                 break;
             }
-
-            // The following method returns the specified options, as well as fall back values for all 
-            // non-specified options.  If all were set you can use the method GetShaderVariantKey that is 
-            // cheaper but will not make sure the populated values has the default fall back for any unset bit.
-            m_ShaderOptions = shaderOption.GetShaderVariantKeyFallbackValue();
+            shaderOptions.SetUnspecifiedToDefaultValues();
+            if (m_pipelineStateForDraw.GetShaderVariantId() != shaderOptions.GetShaderVariantId())
+            {
+                FullscreenTrianglePass::UpdateShaderOptions(shaderOptions.GetShaderVariantId());
+            }
         }
 
         void DeferredFogPass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
@@ -248,26 +250,11 @@ namespace AZ
             DeferredFogSettings* fogSettings = GetPassFogSettings();
 
             UpdateEnable(fogSettings);
-
             // Update and set the per pass shader options - this will update the current required
             // shader variant and if doesn't exist, it will be created via the compile stage
-            if (m_shaderResourceGroup->HasShaderVariantKeyFallbackEntry())
-            {
-                UpdateShaderOptions();
-            }
-
+            UpdateShaderOptions();
             SetSrgConstants();
         }
-  
-        void DeferredFogPass::CompileResources(const RHI::FrameGraphCompileContext& context)
-        {
-            if (m_shaderResourceGroup->HasShaderVariantKeyFallbackEntry())
-            {
-                m_shaderResourceGroup->SetShaderVariantKeyFallbackValue(m_ShaderOptions);
-            }
-
-            FullscreenTrianglePass::CompileResources(context);
-        }
     }   // namespace Render
 }   // namespace AZ
 

+ 0 - 4
Gems/Atom/Feature/Common/Code/Source/ScreenSpace/DeferredFogPass.h

@@ -59,7 +59,6 @@ namespace AZ
 
             // Scope producer functions...
             void SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph) override;
-            void CompileResources(const RHI::FrameGraphCompileContext& context) override;
 
             //! Set the binding indices of all members of the SRG
             void SetSrgBindIndices();
@@ -77,9 +76,6 @@ namespace AZ
             // actively pass them to the shader.
             DeferredFogSettings m_fallbackSettings;
 
-            // Shader options for variant generation (texture and layer activation in this case)
-            AZ::RPI::ShaderVariantKey m_ShaderOptions;
-
             // Fog mode option name
             const AZ::Name m_fogModeOptionName;
         };       

+ 2 - 2
Gems/Atom/Feature/Common/Code/Source/SkinnedMesh/SkinnedMeshDispatchItem.cpp

@@ -77,7 +77,7 @@ namespace AZ
             const RPI::ShaderVariant& shaderVariant = m_skinningShader->GetVariant(m_shaderOptionGroup.GetShaderVariantId());
 
             RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
-            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
+            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, m_shaderOptionGroup);
 
             auto perInstanceSrgLayout = m_skinningShader->FindShaderResourceGroupLayout(AZ::Name{ "InstanceSrg" });
             if (!perInstanceSrgLayout)
@@ -94,7 +94,7 @@ namespace AZ
             }
 
             // If the shader variation is not fully baked, set the fallback key to use a runtime branch for the shader options
-            if (!shaderVariant.IsFullyBaked() && m_instanceSrg->HasShaderVariantKeyFallbackEntry())
+            if (shaderVariant.UseKeyFallback() && m_instanceSrg->HasShaderVariantKeyFallbackEntry())
             {
                 m_instanceSrg->SetShaderVariantKeyFallbackValue(m_shaderOptionGroup.GetShaderVariantKeyFallbackValue());
             }

+ 11 - 2
Gems/Atom/Feature/Common/Code/Source/SkyAtmosphere/SkyAtmospherePass.cpp

@@ -207,6 +207,7 @@ namespace AZ::Render
 
     void SkyAtmospherePass::UpdatePassData()
     {
+        uint32_t childIndex = 0;
         for (auto passData : m_atmospherePassData)
         {
             passData.m_srg->SetConstant(passData.m_index, m_constants);
@@ -217,8 +218,16 @@ namespace AZ::Render
             passData.m_shaderOptionGroup.SetValue(AZ::Name("o_enableFastAerialPerspective"), AZ::RPI::ShaderOptionValue{ m_fastAerialPerspectiveEnabled });
             passData.m_shaderOptionGroup.SetValue(AZ::Name("o_enableAerialPerspective"), AZ::RPI::ShaderOptionValue{ m_aerialPerspectiveEnabled });
 
-            auto key = passData.m_shaderOptionGroup.GetShaderVariantKeyFallbackValue();
-            passData.m_srg->SetShaderVariantKeyFallbackValue(key);
+            const auto& pass = m_children[childIndex];
+            if (auto fullscreenPass = azrtti_cast<RPI::FullscreenTrianglePass*>(pass); fullscreenPass != nullptr)
+            {
+                fullscreenPass->UpdateShaderOptions(passData.m_shaderOptionGroup.GetShaderVariantId());
+            }
+            else if (auto computePass = azrtti_cast<RPI::ComputePass*>(pass); computePass != nullptr)
+            {
+                computePass->UpdateShaderOptions(passData.m_shaderOptionGroup.GetShaderVariantId());
+            }
+            childIndex++;
         }
     }
 

+ 3 - 1
Gems/Atom/RHI/Code/Include/Atom/RHI.Edit/ShaderPlatformInterface.h

@@ -93,6 +93,7 @@ namespace AZ::RHI
             AZStd::string m_entryFunctionName;
 
             ByProducts m_byProducts;  //!< Optional; used for debug information
+            AZStd::string m_extraData; //!< Optional; extra data that can be pass for creating the Stage function.
         };
 
         //! @apiUniqueIndex See GetApiUniqueIndex() for details.
@@ -123,7 +124,8 @@ namespace AZ::RHI
             ShaderHardwareStage shaderStage,
             const AZStd::string& tempFolderPath,
             StageDescriptor& outputDescriptor,
-            const RHI::ShaderBuildArguments& shaderBuildArguments) const = 0;
+            const RHI::ShaderBuildArguments& shaderBuildArguments,
+            const bool useSpecializationConstants) const = 0;
 
         //! Query whether the shaders are set to build with debug information
         virtual bool BuildHasDebugInfo(const RHI::ShaderBuildArguments& shaderBuildArguments) const

+ 10 - 4
Gems/Atom/RHI/Code/Include/Atom/RHI/PipelineStateDescriptor.h

@@ -13,6 +13,7 @@
 #include <Atom/RHI.Reflect/ShaderStageFunction.h>
 #include <AzCore/Utils/TypeHash.h>
 #include <Atom/RHI.Reflect/PipelineLayoutDescriptor.h>
+#include <Atom/RHI/SpecializationConstant.h>
 
 namespace AZ::RHI
 {
@@ -38,16 +39,21 @@ namespace AZ::RHI
         PipelineStateType GetType() const;
 
         //! Returns the hash of the pipeline state descriptor contents.
-        virtual HashValue64 GetHash() const = 0;
+        HashValue64 GetHash() const;
 
         bool operator == (const PipelineStateDescriptor& rhs) const;
 
         //! The pipeline layout describing the shader resource bindings.
         ConstPtr<PipelineLayoutDescriptor> m_pipelineLayoutDescriptor = nullptr;
 
+        //! Values for specialization constants.
+        AZStd::vector<SpecializationConstant> m_specializationData;
+
     protected:
         PipelineStateDescriptor(PipelineStateType pipelineStateType);
 
+        virtual HashValue64 GetHashInternal() const = 0;
+
     private:
         PipelineStateType m_type = PipelineStateType::Count;
     };
@@ -69,7 +75,7 @@ namespace AZ::RHI
         PipelineStateDescriptorForDispatch();
 
         /// Computes the hash value for this descriptor.
-        HashValue64 GetHash() const override;
+        HashValue64 GetHashInternal() const override;
 
         bool operator == (const PipelineStateDescriptorForDispatch& rhs) const;
 
@@ -91,7 +97,7 @@ namespace AZ::RHI
         PipelineStateDescriptorForDraw();
 
         /// Computes the hash value for this descriptor.
-        HashValue64 GetHash() const override;
+        HashValue64 GetHashInternal() const override;
 
         bool operator == (const PipelineStateDescriptorForDraw& rhs) const;
 
@@ -124,7 +130,7 @@ namespace AZ::RHI
         PipelineStateDescriptorForRayTracing();
 
         //! Computes the hash value for this descriptor.
-        HashValue64 GetHash() const override;
+        HashValue64 GetHashInternal() const override;
 
         bool operator == (const PipelineStateDescriptorForRayTracing& rhs) const;
 

+ 45 - 0
Gems/Atom/RHI/Code/Include/Atom/RHI/SpecializationConstant.h

@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#pragma once
+
+#include <AzCore/Utils/TypeHash.h>
+#include <Atom/RHI.Reflect/Handle.h>
+
+namespace AZ::RHI
+{
+    //! Holds a value for a specialization constant
+    using SpecializationValue = RHI::Handle<uint32_t, struct SpecializationConstant>;
+
+    //! Supported types for specialization constants
+    enum class SpecializationType : uint32_t
+    {
+        Integer,
+        Bool,
+        Invalid
+    };
+
+    //! Contains all the necessary information and value of a specialization constant
+    //! so it can be used when creating a PipelineState.
+    struct SpecializationConstant
+    {
+        SpecializationConstant() = default;
+
+        //! Name of the constant
+        Name m_name;
+        //! Id of the constant
+        uint32_t m_id = 0;
+        //! Value of the constant
+        SpecializationValue m_value;
+        //! Type of the constant
+        SpecializationType m_type = SpecializationType::Invalid;
+
+        bool operator==(const SpecializationConstant& rhs) const;
+        //! Returns a hash of the constant
+        HashValue64 GetHash() const;
+    };
+}

+ 23 - 16
Gems/Atom/RHI/Code/Source/RHI/PipelineStateDescriptor.cpp

@@ -22,9 +22,22 @@ namespace AZ::RHI
         return m_type;
     }
 
-    bool PipelineStateDescriptor::operator == (const PipelineStateDescriptor& rhs) const
+    HashValue64 PipelineStateDescriptor::GetHash() const
     {
-        return m_type == rhs.m_type;
+        AZ_Assert(m_pipelineLayoutDescriptor, "Pipeline layout descriptor is null.");
+        AZ::HashValue64 seed = AZ::HashValue64{ 0 };
+        seed = TypeHash64(m_pipelineLayoutDescriptor->GetHash(), seed);
+        for (const auto& constant : m_specializationData)
+        {
+            seed = TypeHash64(constant.GetHash(), seed);
+        }
+        seed = TypeHash64(GetHashInternal(), seed);
+        return seed;
+    }
+
+    bool PipelineStateDescriptor::operator==(const PipelineStateDescriptor& rhs) const
+    {
+        return m_type == rhs.m_type && m_specializationData == rhs.m_specializationData;
     }
 
     PipelineStateDescriptorForDraw::PipelineStateDescriptorForDraw()
@@ -39,21 +52,17 @@ namespace AZ::RHI
         : PipelineStateDescriptor(PipelineStateType::RayTracing)
     {}
 
-    AZ::HashValue64 PipelineStateDescriptorForDispatch::GetHash() const
+    AZ::HashValue64 PipelineStateDescriptorForDispatch::GetHashInternal() const
     {
-        AZ_Assert(m_pipelineLayoutDescriptor, "Pipeline layout descriptor is null.");
         AZ_Assert(m_computeFunction, "Compute function is null.");
 
         AZ::HashValue64 seed = AZ::HashValue64{ 0 };
-        seed = TypeHash64(m_pipelineLayoutDescriptor->GetHash(), seed);
         seed = TypeHash64(m_computeFunction->GetHash(), seed);
         return seed;
     }
 
-    AZ::HashValue64 PipelineStateDescriptorForDraw::GetHash() const
+    AZ::HashValue64 PipelineStateDescriptorForDraw::GetHashInternal() const
     {
-        AZ_Assert(m_pipelineLayoutDescriptor, "m_pipelineLayoutDescriptor is null.");
-
         AZ::HashValue64 seed = AZ::HashValue64{ 0 };
 
         if (m_vertexFunction)
@@ -69,7 +78,6 @@ namespace AZ::RHI
             seed = TypeHash64(m_fragmentFunction->GetHash(), seed);
         }
 
-        seed = TypeHash64(m_pipelineLayoutDescriptor->GetHash(), seed);
         seed = TypeHash64(m_inputStreamLayout.GetHash(), seed);
         seed = TypeHash64(m_renderAttachmentConfiguration.GetHash(), seed);
 
@@ -78,12 +86,9 @@ namespace AZ::RHI
         return seed;
     }
 
-    AZ::HashValue64 PipelineStateDescriptorForRayTracing::GetHash() const
+    AZ::HashValue64 PipelineStateDescriptorForRayTracing::GetHashInternal() const
     {
-        AZ_Assert(m_pipelineLayoutDescriptor, "Pipeline layout descriptor is null.");
-
         AZ::HashValue64 seed = AZ::HashValue64{ 0 };
-        seed = TypeHash64(m_pipelineLayoutDescriptor->GetHash(), seed);
         seed = TypeHash64(m_rayTracingFunction->GetHash(), seed);
         return seed;
     }
@@ -93,18 +98,20 @@ namespace AZ::RHI
         return m_fragmentFunction == rhs.m_fragmentFunction && m_pipelineLayoutDescriptor == rhs.m_pipelineLayoutDescriptor &&
             m_renderStates == rhs.m_renderStates && m_vertexFunction == rhs.m_vertexFunction &&
             m_geometryFunction == rhs.m_geometryFunction && m_inputStreamLayout == rhs.m_inputStreamLayout && 
-            m_renderAttachmentConfiguration == rhs.m_renderAttachmentConfiguration;
+            m_renderAttachmentConfiguration == rhs.m_renderAttachmentConfiguration && m_specializationData == rhs.m_specializationData;
     }
 
     bool PipelineStateDescriptorForDispatch::operator == (const PipelineStateDescriptorForDispatch& rhs) const
     {
         return m_computeFunction == rhs.m_computeFunction &&
-            m_pipelineLayoutDescriptor == rhs.m_pipelineLayoutDescriptor;
+            m_pipelineLayoutDescriptor == rhs.m_pipelineLayoutDescriptor &&
+            m_specializationData == rhs.m_specializationData;
     }
 
     bool PipelineStateDescriptorForRayTracing::operator == (const PipelineStateDescriptorForRayTracing& rhs) const
     {
         return m_pipelineLayoutDescriptor == rhs.m_pipelineLayoutDescriptor &&
-            m_rayTracingFunction == rhs.m_rayTracingFunction;
+            m_rayTracingFunction == rhs.m_rayTracingFunction &&
+            m_specializationData == rhs.m_specializationData;
     }
 }

+ 31 - 0
Gems/Atom/RHI/Code/Source/RHI/SpecializationConstant.cpp

@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <Atom/RHI/SpecializationConstant.h>
+
+namespace AZ::RHI
+{
+    bool SpecializationConstant::operator==(const SpecializationConstant& rhs) const
+    {
+        return
+            m_value == rhs.m_value &&
+            m_name == rhs.m_name &&
+            m_id == rhs.m_id &&
+            m_type == rhs.m_type;
+    }
+
+    HashValue64 SpecializationConstant::GetHash() const
+    {
+        AZ::HashValue64 seed = AZ::HashValue64{ 0 };
+        seed = TypeHash64(m_value.GetIndex(), seed);
+        seed = TypeHash64(m_name.GetHash(), seed);
+        seed = TypeHash64(m_id, seed);
+        seed = TypeHash64(m_type, seed);
+        return seed;
+    }    
+}

+ 2 - 0
Gems/Atom/RHI/Code/atom_rhi_public_files.cmake

@@ -276,4 +276,6 @@ set(FILES
     Include/Atom/RHI/XRRenderingInterface.h
     Include/Atom/RHI/DeviceDispatchRaysIndirectBuffer.h
     Include/Atom/RHI/DispatchRaysIndirectBuffer.h
+    Include/Atom/RHI/SpecializationConstant.h
+    Source/RHI/SpecializationConstant.cpp
 )

+ 12 - 0
Gems/Atom/RHI/DX12/Code/CMakeLists.txt

@@ -90,6 +90,16 @@ ly_add_target(
             ../External/AMD_D3D12MemoryAllocator/v2.0.1
 )
 
+ly_add_target(
+    NAME OpenSSL_md5 STATIC
+    NAMESPACE Gem
+    FILES_CMAKE
+        openssl_md5_files.cmake
+    INCLUDE_DIRECTORIES
+        INTERFACE
+            ../External/md5
+)
+
 ly_add_target(
     NAME ${gem_name}.Reflect STATIC
     NAMESPACE Gem
@@ -131,6 +141,8 @@ ly_add_target(
             Gem::Amd_DX12MA
             3rdParty::d3dx12
             ${AFTERMATH_BUILD_DEPENDENCY}
+        PRIVATE
+            Gem::OpenSSL_md5
     COMPILE_DEFINITIONS 
         PRIVATE
             ${USE_NSIGHT_AFTERMATH_DEFINE}

+ 13 - 0
Gems/Atom/RHI/DX12/Code/Include/Atom/RHI.Reflect/DX12/ShaderStageFunction.h

@@ -33,6 +33,12 @@ namespace AZ
         using ShaderByteCode = AZStd::vector<uint8_t>;
         using ShaderByteCodeView = AZStd::span<const uint8_t>;
 
+        //! Sentinel value used when patching shaders for specialization constants
+        constexpr uint32_t SCSentinelValue = 0x45678900;
+        //! Mask that marks which bytes are used for the sentinel and which
+        //! ones are used for the specialization constant id.
+        constexpr uint64_t SCSentinelMask = 0xffffffffffffff00;
+
         /**
          * A set of indices used to access physical sub-stages within a virtual stage.
          */
@@ -60,6 +66,12 @@ namespace AZ
             /// Returns the assigned byte code.
             ShaderByteCodeView GetByteCode(uint32_t subStageIndex = 0) const;
 
+            using SpecializationOffsets = AZStd::unordered_map<uint32_t, uint32_t>;
+            void SetSpecializationOffsets(uint32_t subStageIndex, const SpecializationOffsets& offsets);
+            const SpecializationOffsets& GetSpecializationOffsets(uint32_t subStageIndex = 0) const;
+
+            bool UseSpecializationConstants(uint32_t subStageIndex = 0) const;
+
         private:
             ShaderStageFunction() = default;
             ShaderStageFunction(RHI::ShaderStage shaderStage);
@@ -74,6 +86,7 @@ namespace AZ
             ///////////////////////////////////////////////////////////////////
 
             AZStd::array<ShaderByteCode, ShaderSubStageCountMax> m_byteCodes;
+            AZStd::array<SpecializationOffsets, ShaderSubStageCountMax> m_specializationOffsets;
         };
     }
 }

+ 73 - 4
Gems/Atom/RHI/DX12/Code/Source/RHI.Builders/ShaderPlatformInterface.cpp

@@ -16,6 +16,7 @@
 
 #include <AzCore/IO/FileIO.h>
 #include <AzCore/IO/SystemFile.h>
+#include <AzCore/Serialization/Json/JsonUtils.h>
 #include <AzFramework/StringFunc/StringFunc.h>
 
 namespace AZ
@@ -50,6 +51,34 @@ namespace AZ
             const int byteCodeIndex = 0;
             newShaderStageFunction->SetByteCode(byteCodeIndex, byteCode);
 
+            // Read the json data with the specialization constants offsets.
+            // If the shader was not compiled with specialization constants this attribute will be empty.
+            AZStd::string fileName;
+            if (!stageDescriptor.m_extraData.empty())
+            {
+                auto jsonOutcome = JsonSerializationUtils::ReadJsonFile(stageDescriptor.m_extraData);
+                if (!jsonOutcome.IsSuccess())
+                {
+                    AZ_Error(DX12ShaderPlatformName, false, "%s", jsonOutcome.GetError().c_str());
+                    return nullptr;
+                }
+
+                const rapidjson::Document& doc = jsonOutcome.GetValue();
+                ShaderStageFunction::SpecializationOffsets offsets;
+                for (auto itr = doc.MemberBegin(); itr != doc.MemberEnd(); ++itr)
+                {
+                    if (!AZ::StringFunc::LooksLikeInt(itr->name.GetString()))
+                    {
+                        AZ_Error(DX12ShaderPlatformName, false, "SpecializationId %s is not an Int", itr->name.GetString());
+                        continue;
+                    }
+                    uint32_t specializationId = static_cast<uint32_t>(AZ::StringFunc::ToInt(itr->name.GetString()));
+                    uint32_t offset = itr->value.GetUint();
+                    offsets[specializationId] = offset;
+                }
+                newShaderStageFunction->SetSpecializationOffsets(byteCodeIndex, offsets);
+            }
+         
             newShaderStageFunction->Finalize();
 
             return newShaderStageFunction;
@@ -149,10 +178,11 @@ namespace AZ
             RHI::ShaderHardwareStage shaderStage,
             const AZStd::string& tempFolderPath,
             StageDescriptor& outputDescriptor,
-            const RHI::ShaderBuildArguments& shaderBuildArguments) const
+            const RHI::ShaderBuildArguments& shaderBuildArguments,
+            const bool useSpecializationConstants) const
         {
             AZStd::vector<uint8_t> shaderByteCode;
-
+            AZStd::string specializationOffsetsFile;
             // Compile HLSL shader to byte code
             bool compiledSucessfully = CompileHLSLShader(
                 shaderSourcePath,                        // shader source filepath
@@ -161,7 +191,9 @@ namespace AZ
                 shaderStage,                             // shader stage (vertex shader, pixel shader, ...)
                 shaderBuildArguments,
                 shaderByteCode,                          // compiled shader output
-                outputDescriptor.m_byProducts);          // dynamic branch count output & byproduct files
+                outputDescriptor.m_byProducts,           // dynamic branch count output & byproduct files
+                specializationOffsetsFile,               // path to the json file with the specialization offsets
+                useSpecializationConstants);             // if the shader stage it's using specialization constants
 
             if (!compiledSucessfully)
             {
@@ -174,6 +206,7 @@ namespace AZ
             {
                 outputDescriptor.m_stageType = shaderStage;
                 outputDescriptor.m_byteCode = AZStd::move(shaderByteCode);
+                outputDescriptor.m_extraData = AZStd::move(specializationOffsetsFile);
             }
             else
             {
@@ -197,7 +230,9 @@ namespace AZ
             const RHI::ShaderHardwareStage shaderStageType,
             const RHI::ShaderBuildArguments& shaderBuildArguments,
             AZStd::vector<uint8_t>& compiledShader,
-            ByProducts& byProducts) const
+            ByProducts& byProducts,
+            AZStd::string& specializationOffsetsFile,
+            const bool useSpecializationConstants) const
         {
             // Shader compiler executable
             const auto dxcRelativePath = RHI::GetDirectXShaderCompilerPath("Builders/DirectXShaderCompiler/dxc.exe");
@@ -298,6 +333,40 @@ namespace AZ
                 return false;
             }
 
+            if (useSpecializationConstants)
+            {
+                // Need to patch the shader so it can be used with specialization constants.
+                const auto dxscRelativePath = RHI::GetDirectXShaderCompilerPath("Builders/DirectXShaderCompiler/dxsc.exe");
+
+                AZStd::string shaderOutputCommon;
+                AzFramework::StringFunc::Path::GetFileName(shaderSourceFile.c_str(), shaderOutputCommon);
+                AzFramework::StringFunc::Path::Join(tempFolder.c_str(), shaderOutputCommon.c_str(), shaderOutputCommon);
+
+                AZStd::string patchedShaderOutput = shaderOutputCommon;
+                AzFramework::StringFunc::Path::ReplaceExtension(patchedShaderOutput, "dxil.patched.bin");
+                AZStd::string offsetsOutput = shaderOutputCommon;
+                AzFramework::StringFunc::Path::ReplaceExtension(offsetsOutput, "offsets.json");
+
+                const auto dxscCommandOptions = AZStd::string::format(
+                    //   1.sentinel    3.offsets_output   
+                    //     |    2.output    |   4.dxil-in
+                    //     |       |        |      |
+                    "-sv=%lu -o=\"%s\" -f=\"%s\" \"%s\"",
+                    static_cast<unsigned long>(SCSentinelValue), // 1
+                    patchedShaderOutput.c_str(), // 2
+                    offsetsOutput.c_str(), // 3
+                    shaderOutputFile.c_str() // 4
+                );
+
+                if (!RHI::ExecuteShaderCompiler(dxscRelativePath, dxscCommandOptions, shaderSourceFile, tempFolder, "DXSC"))
+                {
+                    return false;
+                }
+                shaderOutputFile = patchedShaderOutput;
+
+                specializationOffsetsFile = offsetsOutput;
+            }
+
             auto shaderOutputFileLoadResult = AZ::RHI::LoadFileBytes(shaderOutputFile.c_str());
             if (!shaderOutputFileLoadResult)
             {

+ 5 - 2
Gems/Atom/RHI/DX12/Code/Source/RHI.Builders/ShaderPlatformInterface.h

@@ -45,7 +45,8 @@ namespace AZ
                 RHI::ShaderHardwareStage shaderStage,
                 const AZStd::string& tempFolderPath,
                 StageDescriptor& outputDescriptor,
-                const RHI::ShaderBuildArguments& shaderBuildArguments) const override;
+                const RHI::ShaderBuildArguments& shaderBuildArguments,
+                const bool useSpecializationConstants) const override;
 
             const char* GetAzslHeader(const AssetBuilderSDK::PlatformInfo& platform) const override;
 
@@ -59,7 +60,9 @@ namespace AZ
                 const RHI::ShaderHardwareStage shaderStageType,
                 const RHI::ShaderBuildArguments& shaderBuildArguments,
                 AZStd::vector<uint8_t>& m_byteCode,
-                ByProducts& products) const;
+                ByProducts& products,
+                AZStd::string& specializationOffsetsFile,
+                const bool useSpecializationConstants) const;
 
             const Name m_apiName;
         };

+ 18 - 2
Gems/Atom/RHI/DX12/Code/Source/RHI.Reflect/ShaderStageFunction.cpp

@@ -19,8 +19,9 @@ namespace AZ
             if (SerializeContext* serializeContext = azrtti_cast<SerializeContext*>(context))
             {
                 serializeContext->Class<ShaderStageFunction, RHI::ShaderStageFunction>()
-                    ->Version(1)
-                    ->Field("m_byteCodes", &ShaderStageFunction::m_byteCodes);
+                    ->Version(2)
+                    ->Field("m_byteCodes", &ShaderStageFunction::m_byteCodes)
+                    ->Field("m_specializationOffsets", &ShaderStageFunction::m_specializationOffsets);
             }
         }
 
@@ -44,6 +45,21 @@ namespace AZ
             return ShaderByteCodeView(m_byteCodes[subStageIndex]);
         }
 
+        void ShaderStageFunction::SetSpecializationOffsets(uint32_t subStageIndex, const SpecializationOffsets& offsets)
+        {
+            m_specializationOffsets[subStageIndex] = offsets;
+        }
+
+        const ShaderStageFunction::SpecializationOffsets& ShaderStageFunction::GetSpecializationOffsets(uint32_t subStageIndex) const
+        {
+            return m_specializationOffsets[subStageIndex];
+        }
+
+        bool ShaderStageFunction::UseSpecializationConstants(uint32_t subStageIndex) const
+        {
+            return !GetSpecializationOffsets(subStageIndex).empty();
+        }
+
         RHI::ResultCode ShaderStageFunction::FinalizeInternal()
         {
             bool emptyByteCodes = true;

+ 2 - 0
Gems/Atom/RHI/DX12/Code/Source/RHI/DX12.h

@@ -33,6 +33,8 @@
 #define IID_GRAPHICS_PPV_ARGS(ppType) IID_PPV_ARGS(ppType)
 #endif
 
+#define MAKE_FOURCC(a, b, c, d) (((uint32_t)(d) << 24) | ((uint32_t)(c) << 16) | ((uint32_t)(b) << 8) | (uint32_t)(a))
+
 namespace AZ
 {
     namespace DX12

+ 14 - 6
Gems/Atom/RHI/DX12/Code/Source/RHI/PipelineState.cpp

@@ -10,6 +10,8 @@
 #include <Atom/RHI.Reflect/DX12/ShaderStageFunction.h>
 #include <RHI/Conversions.h>
 #include <RHI/Device.h>
+#include <RHI/ShaderUtils.h>
+
 namespace AZ
 {
     namespace DX12
@@ -55,20 +57,24 @@ namespace AZ
             // Shader state.
             RHI::ConstPtr<PipelineLayout> pipelineLayout = device.AcquirePipelineLayout(*descriptor.m_pipelineLayoutDescriptor);
             pipelineStateDesc.pRootSignature = pipelineLayout->Get();
-
+            // Cache used for saving the patched version of the shader
+            AZStd::vector<ShaderByteCode> shaderByteCodeCache;
             if (const ShaderStageFunction* vertexFunction = azrtti_cast<const ShaderStageFunction*>(descriptor.m_vertexFunction.get()))
             {
-                pipelineStateDesc.VS = D3D12BytecodeFromView(vertexFunction->GetByteCode());
+                pipelineStateDesc.VS =
+                    D3D12BytecodeFromView(ShaderUtils::PatchShaderFunction(*vertexFunction, descriptor, shaderByteCodeCache));
             }
 
             if (const ShaderStageFunction* geometryFunction = azrtti_cast<const ShaderStageFunction*>(descriptor.m_geometryFunction.get()))
             {
-                pipelineStateDesc.GS = D3D12BytecodeFromView(geometryFunction->GetByteCode());
+                pipelineStateDesc.GS =
+                    D3D12BytecodeFromView(ShaderUtils::PatchShaderFunction(*geometryFunction, descriptor, shaderByteCodeCache));
             }
 
             if (const ShaderStageFunction* fragmentFunction = azrtti_cast<const ShaderStageFunction*>(descriptor.m_fragmentFunction.get()))
             {
-                pipelineStateDesc.PS = D3D12BytecodeFromView(fragmentFunction->GetByteCode());
+                pipelineStateDesc.PS =
+                    D3D12BytecodeFromView(ShaderUtils::PatchShaderFunction(*fragmentFunction, descriptor, shaderByteCodeCache));
             }
 
             const RHI::RenderAttachmentConfiguration& renderAttachmentConfiguration = descriptor.m_renderAttachmentConfiguration;
@@ -130,10 +136,12 @@ namespace AZ
 
             RHI::ConstPtr<PipelineLayout> pipelineLayout = device.AcquirePipelineLayout(*descriptor.m_pipelineLayoutDescriptor);
             pipelineStateDesc.pRootSignature = pipelineLayout->Get();
-
+            // Cache used for saving the patched version of the shader
+            AZStd::vector<ShaderByteCode> shaderByteCodeCache;
             if (const ShaderStageFunction* computeFunction = azrtti_cast<const ShaderStageFunction*>(descriptor.m_computeFunction.get()))
             {
-                pipelineStateDesc.CS = D3D12BytecodeFromView(computeFunction->GetByteCode());
+                pipelineStateDesc.CS =
+                    D3D12BytecodeFromView(ShaderUtils::PatchShaderFunction(*computeFunction, descriptor, shaderByteCodeCache));
             }
 
             PipelineLibrary* pipelineLibrary = static_cast<PipelineLibrary*>(pipelineLibraryBase);

+ 7 - 2
Gems/Atom/RHI/DX12/Code/Source/RHI/RayTracingPipelineState.cpp

@@ -10,6 +10,8 @@
 #include <Atom/RHI.Reflect/DX12/ShaderStageFunction.h>
 #include <RHI/Conversions.h>
 #include <RHI/Device.h>
+#include <RHI/ShaderUtils.h>
+
 namespace AZ
 {
     namespace DX12
@@ -53,12 +55,15 @@ namespace AZ
             // add DXIL Libraries
             AZStd::vector<D3D12_DXIL_LIBRARY_DESC> libraryDescs;
             libraryDescs.reserve(dxilLibraryCount);
+            AZStd::vector<ShaderByteCode> patchedShaderCache;
             for (const RHI::RayTracingShaderLibrary& shaderLibrary : descriptor->GetShaderLibraries())
             {
                 const ShaderStageFunction* rayTracingFunction = azrtti_cast<const ShaderStageFunction*>(shaderLibrary.m_descriptor.m_rayTracingFunction.get());
-        
+                ShaderByteCodeView byteCode =
+                    ShaderUtils::PatchShaderFunction(*rayTracingFunction, shaderLibrary.m_descriptor, patchedShaderCache);
+
                 D3D12_DXIL_LIBRARY_DESC libraryDesc = {};
-                libraryDesc.DXILLibrary = D3D12_SHADER_BYTECODE{ rayTracingFunction->GetByteCode().data(), rayTracingFunction->GetByteCode().size() };
+                libraryDesc.DXILLibrary = D3D12_SHADER_BYTECODE{ byteCode.data(), byteCode.size() };
                 libraryDesc.NumExports = 0; // all shaders
                 libraryDesc.pExports = nullptr;
                 libraryDescs.push_back(libraryDesc);

+ 237 - 0
Gems/Atom/RHI/DX12/Code/Source/RHI/ShaderUtils.cpp

@@ -0,0 +1,237 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <RHI/ShaderUtils.h>
+#include <openssl/md5.h>
+#include <Atom/RHI.Reflect/DX12/ShaderStageFunction.h>
+
+namespace AZ::DX12
+{
+    static const uint32_t FOURCC_DXBC = MAKE_FOURCC('D', 'X', 'B', 'C');
+
+    // Modify the bits in the bytecode with a new value following the VBR rules of encoding
+    uint64_t TamperBits(uint8_t* byteCode, uint32_t patchVal, uint64_t offset)
+    {
+        constexpr uint32_t BitsPerByte = 8;
+        // LSB is used for encoding signed/unsigned
+        // VBR left shift all values to leave space for the sign bit
+        patchVal <<= 1;
+
+        uint64_t original = 0;
+        uint64_t currentOffset = offset;
+
+        auto SetBitFunc = [&](uint64_t pos, bool value)
+        {
+            if (value)
+            {
+                byteCode[pos / BitsPerByte] |= (uint8_t)1 << pos % BitsPerByte;
+            }
+            else
+            {
+                byteCode[pos / BitsPerByte] &= ~((uint8_t)1 << pos % BitsPerByte);
+            }
+        };
+        uint32_t originalValueBitIndex = 0;
+        // 32 bits take 5 full bytes in VBR.
+        const uint32_t sentinelBytes = 5;
+        for (uint32_t i = 0; i < sentinelBytes; i++)
+        {
+            // Patch all bits except the continuation bit
+            for (uint32_t j = 0; j < (BitsPerByte - 1); j++)
+            {
+                bool currentValue = ((byteCode[currentOffset / BitsPerByte] & ((uint8_t)1 << currentOffset % BitsPerByte)) != 0);
+                bool newValue = (patchVal & (uint64_t)1) != 0;
+                SetBitFunc(currentOffset, newValue);
+                original |= (uint64_t)currentValue << originalValueBitIndex;
+                patchVal >>= 1;
+                currentOffset++;
+                originalValueBitIndex++;
+            }
+            // Set continuation bit
+            SetBitFunc(currentOffset, true);
+            currentOffset++;
+        }
+        // MSB in VBR doesn't have a continuation bit (because it's the last byte)
+        SetBitFunc(currentOffset - 1, false);
+        // VBR left shift values for sign bit, so we right shift the value we found
+        return original >> 1;
+    }
+
+    ShaderByteCode ShaderUtils::PatchShaderFunction(
+        const ShaderStageFunction& shaderFunction, const RHI::PipelineStateDescriptor& descriptor)
+    {
+        ShaderByteCode patched(shaderFunction.GetByteCode().size());
+        ::memcpy(patched.data(), shaderFunction.GetByteCode().data(), patched.size());
+        const AZStd::vector<RHI::SpecializationConstant>& specializationConstants = descriptor.m_specializationData;
+        for (const auto& element : shaderFunction.GetSpecializationOffsets())
+        {
+            auto findIter = AZStd::find_if(
+                specializationConstants.begin(),
+                specializationConstants.end(),
+                [&](const RHI::SpecializationConstant& constantData)
+                {
+                    return constantData.m_id == element.first;
+                });
+
+            if (findIter == specializationConstants.end())
+            {
+                AZ_Error("ShaderUtils", false, "Specialization constant %d doesn't not have a value", element.first);
+                continue;
+            }
+
+            [[maybe_unused]] uint64_t sentinelFound = TamperBits(patched.data(), findIter->m_value.GetIndex(), element.second);
+            AZ_Assert(
+                static_cast<uint32_t>(sentinelFound & SCSentinelMask) == SCSentinelValue,
+                "Invalid sentinel value found %lu",
+                sentinelFound);
+        }
+
+        // Re-sign the shader bytecode after we patch it
+        if (!SignByteCode(patched))
+        {
+            AZ_Error("ShaderUtils", false, "Failed to sign container");
+            return {};
+        }
+
+        return patched;
+    }
+
+    ShaderByteCodeView ShaderUtils::PatchShaderFunction(
+        const ShaderStageFunction& shaderFunction,
+        const RHI::PipelineStateDescriptor& descriptor,
+        AZStd::vector<ShaderByteCode>& patchedShaderContainer)
+    {
+        if (!shaderFunction.UseSpecializationConstants())
+        {
+            // No need to patch anything
+            return shaderFunction.GetByteCode();
+        }
+
+        ShaderByteCode patchedShader = PatchShaderFunction(shaderFunction, descriptor);
+        patchedShaderContainer.emplace_back(AZStd::move(patchedShader));
+        return patchedShaderContainer.back();
+    }
+
+    bool ShaderUtils::SignByteCode(ShaderByteCode& bytecode)
+    {
+        // Original signing code from Renderdoc
+        struct FileHeader
+        {
+            uint32_t fourcc; // "DXBC"
+            uint32_t hashValue[4]; // unknown hash function and data
+            uint32_t containerVersion;
+            uint32_t fileLength;
+            uint32_t numChunks;
+            // uint32 chunkOffsets[numChunks]; follows
+        };
+
+        if (bytecode.size() < sizeof(FileHeader) || bytecode.data() == nullptr)
+        {
+            return false;
+        }
+
+        FileHeader* header = reinterpret_cast<FileHeader*>(bytecode.data());
+
+        if (header->fourcc != FOURCC_DXBC)
+        {
+            return false;
+        }
+
+        if (header->fileLength != static_cast<uint32_t>(bytecode.size()))
+        {
+            return false;
+        }
+
+        _MD5_CTX md5ctx = {};
+        _MD5_Init(&md5ctx);
+
+        // the hashable data starts immediately after the hash.
+        AZStd::byte* data = reinterpret_cast<AZStd::byte*>(&header->containerVersion);
+        uint32_t length = uint32_t(bytecode.size() - offsetof(FileHeader, containerVersion));
+
+        // we need to know the number of bits for putting in the trailing padding.
+        uint32_t numBits = length * 8;
+        uint32_t numBitsPart2 = (numBits >> 2) | 1;
+
+        // MD5 works on 64-byte chunks, process the first set of whole chunks, leaving 0-63 bytes left
+        // over
+        uint32_t leftoverLength = length % 64;
+        _MD5_Update(&md5ctx, data, length - leftoverLength);
+
+        data += length - leftoverLength;
+
+        uint32_t block[16] = {};
+        AZ_Assert(sizeof(block) == 64, "Block is not properly sized for MD5 round");
+
+        // normally MD5 finishes by appending a 1 bit to the bitstring. Since we are only appending bytes
+        // this would be an 0x80 byte (the first bit is considered to be the MSB). Then it pads out with
+        // zeroes until it has 56 bytes in the last block and appends appends the message length as a
+        // 64-bit integer as the final part of that block.
+        // in other words, normally whatever is leftover from the actual message gets one byte appended,
+        // then if there's at least 8 bytes left we'll append the length. Otherwise we pad that block with
+        // 0s and create a new block with the length at the end.
+        // Or as the original RFC/spec says: padding is always performed regardless of whether the
+        // original buffer already ended in exactly a 56 byte block.
+        //
+        // The DXBC finalisation is slightly different (previous work suggests this is due to a bug in the
+        // original implementation and it was maybe intended to be exactly MD5?):
+        //
+        // The length provided in the padding block is not 64-bit properly: the second dword with the high
+        // bits is instead the number of nybbles(?) with 1 OR'd on. The length is also split, so if it's
+        // in
+        // a padding block the low bits are in the first dword and the upper bits in the last. If there's
+        // no padding block the low dword is passed in first before the leftovers of the message and then
+        // the upper bits at the end.
+
+        // if the leftovers uses at least 56, we can't fit both the trailing 1 and the 64-bit length, so
+        // we need a padding block and then our own block for the length.
+        if (leftoverLength >= 56)
+        {
+            // pass in the leftover data padded out to 64 bytes with zeroes
+            _MD5_Update(&md5ctx, data, leftoverLength);
+
+            block[0] = 0x80; // first padding bit is 1
+            _MD5_Update(&md5ctx, block, 64 - leftoverLength);
+
+            // the final block contains the number of bits in the first dword, and the weird upper bits
+            block[0] = numBits;
+            block[15] = numBitsPart2;
+
+            // process this block directly, we're replacing the call to MD5_Final here manually
+            _MD5_Update(&md5ctx, block, 64);
+        }
+        else
+        {
+            // the leftovers mean we can put the padding inside the final block. But first we pass the "low"
+            // number of bits:
+            _MD5_Update(&md5ctx, &numBits, sizeof(numBits));
+
+            if (leftoverLength)
+            {
+                _MD5_Update(&md5ctx, data, leftoverLength);
+            }
+
+            uint32_t paddingBytes = 64 - leftoverLength - 4;
+
+            // prepare the remainder of this block, starting with the 0x80 padding start right after the
+            // leftovers and the first part of the bit length above.
+            block[0] = 0x80;
+            // then add the remainder of the 'length' here in the final part of the block
+            ::memcpy(((AZStd::byte*)block) + paddingBytes - 4, &numBitsPart2, 4);
+
+            _MD5_Update(&md5ctx, block, paddingBytes);
+        }
+
+        header->hashValue[0] = md5ctx.a;
+        header->hashValue[1] = md5ctx.b;
+        header->hashValue[2] = md5ctx.c;
+        header->hashValue[3] = md5ctx.d;
+
+        return true;
+    }
+} // namespace AZ

+ 36 - 0
Gems/Atom/RHI/DX12/Code/Source/RHI/ShaderUtils.h

@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#pragma once
+
+#include <Atom/RHI/PipelineStateDescriptor.h>
+#include <Atom/RHI.Reflect/DX12/ShaderStageFunction.h>
+
+namespace AZ::DX12
+{
+    namespace ShaderUtils
+    {
+        //! Patch a shader bytecode with the proper values of the specialization constants found in the
+        //! pipeline descriptor.
+        ShaderByteCode PatchShaderFunction(
+            const ShaderStageFunction& shaderFunction,
+            const RHI::PipelineStateDescriptor& descriptor);
+
+        //! Patch a shader bytecode with the proper values of the specialization constants found in the
+        //! pipeline descriptor. If the pipeline descriptor is not using specialization constants, it returns the
+        //! shader bytecode unchanged. If it needs to patch it, the patched shader bytecode is stored in the provided container.
+        //! Refer to RFC (https://github.com/o3de/sig-graphics-audio/blob/main/rfcs/SpecializationConstants/SpecializationConstants.md)
+        //! for more details on how specialization constants works on DX12 
+        ShaderByteCodeView PatchShaderFunction(
+            const ShaderStageFunction& shaderFunction,
+            const RHI::PipelineStateDescriptor& descriptor,
+            AZStd::vector<ShaderByteCode>& patchedShaderContainer);
+
+        //! Signs a DXIL blob so it can be used by the driver. Only needed if the bytecode has been modified.
+        bool SignByteCode(ShaderByteCode& bytecode);
+    }
+} // namespace AZ

+ 2 - 0
Gems/Atom/RHI/DX12/Code/atom_rhi_dx12_private_common_files.cmake

@@ -123,4 +123,6 @@ set(FILES
     Source/RHI/RayTracingPipelineState.h
     Source/RHI/RayTracingShaderTable.cpp
     Source/RHI/RayTracingShaderTable.h
+    Source/RHI/ShaderUtils.cpp
+    Source/RHI/ShaderUtils.h
 )

+ 12 - 0
Gems/Atom/RHI/DX12/Code/openssl_md5_files.cmake

@@ -0,0 +1,12 @@
+#
+# Copyright (c) Contributors to the Open 3D Engine Project.
+# For complete copyright and license terms please see the LICENSE at the root of this distribution.
+#
+# SPDX-License-Identifier: Apache-2.0 OR MIT
+#
+#
+
+set(FILES
+    ../External/md5/openssl/md5.c
+    ../External/md5/openssl/md5.h
+)

+ 38 - 0
Gems/Atom/RHI/DX12/External/md5/README.md

@@ -0,0 +1,38 @@
+Fetched from https://openwall.info/wiki/people/solar/software/public-domain-source-code/md5 on 2021-07-07
+
+Public domain licensed:
+
+> This is an OpenSSL-compatible implementation of the RSA Data Security, Inc.
+> MD5 Message-Digest Algorithm (RFC 1321).
+>
+> Homepage:
+> http://openwall.info/wiki/people/solar/software/public-domain-source-code/md5
+>
+> Author:
+> Alexander Peslyak, better known as Solar Designer <solar at openwall.com>
+>
+> This software was written by Alexander Peslyak in 2001.  No copyright is
+> claimed, and the software is hereby placed in the public domain.
+> In case this attempt to disclaim copyright and place the software in the
+> public domain is deemed null and void, then the software is
+> Copyright (c) 2001 Alexander Peslyak and it is hereby released to the
+> general public under the following terms:
+>
+> Redistribution and use in source and binary forms, with or without
+> modification, are permitted.
+>
+> There's ABSOLUTELY NO WARRANTY, express or implied.
+>
+> (This is a heavily cut-down "BSD license".)
+>
+> This differs from Colin Plumb's older public domain implementation in that
+> no exactly 32-bit integer data type is required (any 32-bit or wider
+> unsigned integer data type will do), there's no compile-time endianness
+> configuration, and the function prototypes match OpenSSL's.  No code from
+> Colin Plumb's implementation has been reused; this comment merely compares
+> the properties of the two independent implementations.
+>
+> The primary goals of this implementation are portability and ease of use.
+> It is meant to be fast, but not as fast as possible.  Some known
+> optimizations are not included to reduce source code size and avoid
+> compile-time configuration.

+ 289 - 0
Gems/Atom/RHI/DX12/External/md5/openssl/md5.c

@@ -0,0 +1,289 @@
+/*
+ * This is an OpenSSL-compatible implementation of the RSA Data Security, Inc.
+ * MD5 Message-Digest Algorithm (RFC 1321).
+ *
+ * Homepage:
+ * http://openwall.info/wiki/people/solar/software/public-domain-source-code/md5
+ *
+ * Author:
+ * Alexander Peslyak, better known as Solar Designer <solar at openwall.com>
+ *
+ * This software was written by Alexander Peslyak in 2001.  No copyright is
+ * claimed, and the software is hereby placed in the public domain.
+ * In case this attempt to disclaim copyright and place the software in the
+ * public domain is deemed null and void, then the software is
+ * Copyright (c) 2001 Alexander Peslyak and it is hereby released to the
+ * general public under the following terms:
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted.
+ *
+ * There's ABSOLUTELY NO WARRANTY, express or implied.
+ *
+ * (This is a heavily cut-down "BSD license".)
+ *
+ * This differs from Colin Plumb's older public domain implementation in that
+ * no exactly 32-bit integer data type is required (any 32-bit or wider
+ * unsigned integer data type will do), there's no compile-time endianness
+ * configuration, and the function prototypes match OpenSSL's.  No code from
+ * Colin Plumb's implementation has been reused; this comment merely compares
+ * the properties of the two independent implementations.
+ *
+ * The primary goals of this implementation are portability and ease of use.
+ * It is meant to be fast, but not as fast as possible.  Some known
+ * optimizations are not included to reduce source code size and avoid
+ * compile-time configuration.
+ */
+
+#ifndef HAVE_OPENSSL
+
+#include "md5.h"
+
+/*
+ * The basic MD5 functions.
+ *
+ * F and G are optimized compared to their RFC 1321 definitions for
+ * architectures that lack an AND-NOT instruction, just like in Colin Plumb's
+ * implementation.
+ */
+#define F(x, y, z)			((z) ^ ((x) & ((y) ^ (z))))
+#define G(x, y, z)			((y) ^ ((z) & ((x) ^ (y))))
+#define H(x, y, z)			(((x) ^ (y)) ^ (z))
+#define H2(x, y, z)			((x) ^ ((y) ^ (z)))
+#define I(x, y, z)			((y) ^ ((x) | ~(z)))
+
+/*
+ * The MD5 transformation for all four rounds.
+ */
+#define STEP(f, a, b, c, d, x, t, s) \
+	(a) += f((b), (c), (d)) + (x) + (t); \
+	(a) = (((a) << (s)) | (((a) & 0xffffffff) >> (32 - (s)))); \
+	(a) += (b);
+
+/*
+ * SET reads 4 input bytes in little-endian byte order and stores them in a
+ * properly aligned word in host byte order.
+ *
+ * The check for little-endian architectures that tolerate unaligned memory
+ * accesses is just an optimization.  Nothing will break if it fails to detect
+ * a suitable architecture.
+ *
+ * Unfortunately, this optimization may be a C strict aliasing rules violation
+ * if the caller's data buffer has effective type that cannot be aliased by
+ * MD5_u32plus.  In practice, this problem may occur if these MD5 routines are
+ * inlined into a calling function, or with future and dangerously advanced
+ * link-time optimizations.  For the time being, keeping these MD5 routines in
+ * their own translation unit avoids the problem.
+ */
+#if defined(__i386__) || defined(__x86_64__) || defined(__vax__)
+#define SET(n) \
+	(*(_MD5_u32plus *)&ptr[(n) * 4])
+#define GET(n) \
+	SET(n)
+#else
+#define SET(n) \
+	(ctx->block[(n)] = \
+	(_MD5_u32plus)ptr[(n) * 4] | \
+	((_MD5_u32plus)ptr[(n) * 4 + 1] << 8) | \
+	((_MD5_u32plus)ptr[(n) * 4 + 2] << 16) | \
+	((_MD5_u32plus)ptr[(n) * 4 + 3] << 24))
+#define GET(n) \
+	(ctx->block[(n)])
+#endif
+
+/*
+ * This processes one or more 64-byte data blocks, but does NOT update the bit
+ * counters.  There are no alignment requirements.
+ */
+static const void *body(_MD5_CTX *ctx, const void *data, unsigned long size)
+{
+	const unsigned char *ptr;
+	_MD5_u32plus a, b, c, d;
+	_MD5_u32plus saved_a, saved_b, saved_c, saved_d;
+
+	ptr = (const unsigned char *)data;
+
+	a = ctx->a;
+	b = ctx->b;
+	c = ctx->c;
+	d = ctx->d;
+
+	do {
+		saved_a = a;
+		saved_b = b;
+		saved_c = c;
+		saved_d = d;
+
+/* Round 1 */
+		STEP(F, a, b, c, d, SET(0), 0xd76aa478, 7)
+		STEP(F, d, a, b, c, SET(1), 0xe8c7b756, 12)
+		STEP(F, c, d, a, b, SET(2), 0x242070db, 17)
+		STEP(F, b, c, d, a, SET(3), 0xc1bdceee, 22)
+		STEP(F, a, b, c, d, SET(4), 0xf57c0faf, 7)
+		STEP(F, d, a, b, c, SET(5), 0x4787c62a, 12)
+		STEP(F, c, d, a, b, SET(6), 0xa8304613, 17)
+		STEP(F, b, c, d, a, SET(7), 0xfd469501, 22)
+		STEP(F, a, b, c, d, SET(8), 0x698098d8, 7)
+		STEP(F, d, a, b, c, SET(9), 0x8b44f7af, 12)
+		STEP(F, c, d, a, b, SET(10), 0xffff5bb1, 17)
+		STEP(F, b, c, d, a, SET(11), 0x895cd7be, 22)
+		STEP(F, a, b, c, d, SET(12), 0x6b901122, 7)
+		STEP(F, d, a, b, c, SET(13), 0xfd987193, 12)
+		STEP(F, c, d, a, b, SET(14), 0xa679438e, 17)
+		STEP(F, b, c, d, a, SET(15), 0x49b40821, 22)
+
+/* Round 2 */
+		STEP(G, a, b, c, d, GET(1), 0xf61e2562, 5)
+		STEP(G, d, a, b, c, GET(6), 0xc040b340, 9)
+		STEP(G, c, d, a, b, GET(11), 0x265e5a51, 14)
+		STEP(G, b, c, d, a, GET(0), 0xe9b6c7aa, 20)
+		STEP(G, a, b, c, d, GET(5), 0xd62f105d, 5)
+		STEP(G, d, a, b, c, GET(10), 0x02441453, 9)
+		STEP(G, c, d, a, b, GET(15), 0xd8a1e681, 14)
+		STEP(G, b, c, d, a, GET(4), 0xe7d3fbc8, 20)
+		STEP(G, a, b, c, d, GET(9), 0x21e1cde6, 5)
+		STEP(G, d, a, b, c, GET(14), 0xc33707d6, 9)
+		STEP(G, c, d, a, b, GET(3), 0xf4d50d87, 14)
+		STEP(G, b, c, d, a, GET(8), 0x455a14ed, 20)
+		STEP(G, a, b, c, d, GET(13), 0xa9e3e905, 5)
+		STEP(G, d, a, b, c, GET(2), 0xfcefa3f8, 9)
+		STEP(G, c, d, a, b, GET(7), 0x676f02d9, 14)
+		STEP(G, b, c, d, a, GET(12), 0x8d2a4c8a, 20)
+
+/* Round 3 */
+		STEP(H, a, b, c, d, GET(5), 0xfffa3942, 4)
+		STEP(H2, d, a, b, c, GET(8), 0x8771f681, 11)
+		STEP(H, c, d, a, b, GET(11), 0x6d9d6122, 16)
+		STEP(H2, b, c, d, a, GET(14), 0xfde5380c, 23)
+		STEP(H, a, b, c, d, GET(1), 0xa4beea44, 4)
+		STEP(H2, d, a, b, c, GET(4), 0x4bdecfa9, 11)
+		STEP(H, c, d, a, b, GET(7), 0xf6bb4b60, 16)
+		STEP(H2, b, c, d, a, GET(10), 0xbebfbc70, 23)
+		STEP(H, a, b, c, d, GET(13), 0x289b7ec6, 4)
+		STEP(H2, d, a, b, c, GET(0), 0xeaa127fa, 11)
+		STEP(H, c, d, a, b, GET(3), 0xd4ef3085, 16)
+		STEP(H2, b, c, d, a, GET(6), 0x04881d05, 23)
+		STEP(H, a, b, c, d, GET(9), 0xd9d4d039, 4)
+		STEP(H2, d, a, b, c, GET(12), 0xe6db99e5, 11)
+		STEP(H, c, d, a, b, GET(15), 0x1fa27cf8, 16)
+		STEP(H2, b, c, d, a, GET(2), 0xc4ac5665, 23)
+
+/* Round 4 */
+		STEP(I, a, b, c, d, GET(0), 0xf4292244, 6)
+		STEP(I, d, a, b, c, GET(7), 0x432aff97, 10)
+		STEP(I, c, d, a, b, GET(14), 0xab9423a7, 15)
+		STEP(I, b, c, d, a, GET(5), 0xfc93a039, 21)
+		STEP(I, a, b, c, d, GET(12), 0x655b59c3, 6)
+		STEP(I, d, a, b, c, GET(3), 0x8f0ccc92, 10)
+		STEP(I, c, d, a, b, GET(10), 0xffeff47d, 15)
+		STEP(I, b, c, d, a, GET(1), 0x85845dd1, 21)
+		STEP(I, a, b, c, d, GET(8), 0x6fa87e4f, 6)
+		STEP(I, d, a, b, c, GET(15), 0xfe2ce6e0, 10)
+		STEP(I, c, d, a, b, GET(6), 0xa3014314, 15)
+		STEP(I, b, c, d, a, GET(13), 0x4e0811a1, 21)
+		STEP(I, a, b, c, d, GET(4), 0xf7537e82, 6)
+		STEP(I, d, a, b, c, GET(11), 0xbd3af235, 10)
+		STEP(I, c, d, a, b, GET(2), 0x2ad7d2bb, 15)
+		STEP(I, b, c, d, a, GET(9), 0xeb86d391, 21)
+
+		a += saved_a;
+		b += saved_b;
+		c += saved_c;
+		d += saved_d;
+
+		ptr += 64;
+	} while (size -= 64);
+
+	ctx->a = a;
+	ctx->b = b;
+	ctx->c = c;
+	ctx->d = d;
+
+	return ptr;
+}
+
+void _MD5_Init(_MD5_CTX *ctx)
+{
+	ctx->a = 0x67452301;
+	ctx->b = 0xefcdab89;
+	ctx->c = 0x98badcfe;
+	ctx->d = 0x10325476;
+
+	ctx->lo = 0;
+	ctx->hi = 0;
+}
+
+void _MD5_Update(_MD5_CTX *ctx, const void *data, unsigned long size)
+{
+	_MD5_u32plus saved_lo;
+	unsigned long used, available;
+
+	saved_lo = ctx->lo;
+	if ((ctx->lo = (saved_lo + size) & 0x1fffffff) < saved_lo)
+		ctx->hi++;
+	ctx->hi += size >> 29;
+
+	used = saved_lo & 0x3f;
+
+	if (used) {
+		available = 64 - used;
+
+		if (size < available) {
+			memcpy(&ctx->buffer[used], data, size);
+			return;
+		}
+
+		memcpy(&ctx->buffer[used], data, available);
+		data = (const unsigned char *)data + available;
+		size -= available;
+		body(ctx, ctx->buffer, 64);
+	}
+
+	if (size >= 64) {
+		data = body(ctx, data, size & ~(unsigned long)0x3f);
+		size &= 0x3f;
+	}
+
+	memcpy(ctx->buffer, data, size);
+}
+
+#define OUT(dst, src) \
+	(dst)[0] = (unsigned char)(src); \
+	(dst)[1] = (unsigned char)((src) >> 8); \
+	(dst)[2] = (unsigned char)((src) >> 16); \
+	(dst)[3] = (unsigned char)((src) >> 24);
+
+void _MD5_Final(unsigned char *result, _MD5_CTX *ctx)
+{
+	unsigned long used, available;
+
+	used = ctx->lo & 0x3f;
+
+	ctx->buffer[used++] = 0x80;
+
+	available = 64 - used;
+
+	if (available < 8) {
+		memset(&ctx->buffer[used], 0, available);
+		body(ctx, ctx->buffer, 64);
+		used = 0;
+		available = 64;
+	}
+
+	memset(&ctx->buffer[used], 0, available - 8);
+
+	ctx->lo <<= 3;
+	OUT(&ctx->buffer[56], ctx->lo)
+	OUT(&ctx->buffer[60], ctx->hi)
+
+	body(ctx, ctx->buffer, 64);
+
+	OUT(&result[0], ctx->a)
+	OUT(&result[4], ctx->b)
+	OUT(&result[8], ctx->c)
+	OUT(&result[12], ctx->d)
+
+	memset(ctx, 0, sizeof(*ctx));
+}
+
+#endif

+ 54 - 0
Gems/Atom/RHI/DX12/External/md5/openssl/md5.h

@@ -0,0 +1,54 @@
+/*
+ * This is an OpenSSL-compatible implementation of the RSA Data Security, Inc.
+ * MD5 Message-Digest Algorithm (RFC 1321).
+ *
+ * Homepage:
+ * http://openwall.info/wiki/people/solar/software/public-domain-source-code/md5
+ *
+ * Author:
+ * Alexander Peslyak, better known as Solar Designer <solar at openwall.com>
+ *
+ * This software was written by Alexander Peslyak in 2001.  No copyright is
+ * claimed, and the software is hereby placed in the public domain.
+ * In case this attempt to disclaim copyright and place the software in the
+ * public domain is deemed null and void, then the software is
+ * Copyright (c) 2001 Alexander Peslyak and it is hereby released to the
+ * general public under the following terms:
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted.
+ *
+ * There's ABSOLUTELY NO WARRANTY, express or implied.
+ *
+ * See md5.c for more information.
+ */
+
+#ifdef HAVE_OPENSSL
+#include <openssl/md5.h>
+#elif !defined(_MD5_H)
+#define _MD5_H
+
+/* Added by baldurk, for C++ compatibility */
+#if defined(__cplusplus)
+extern "C" {
+#endif
+
+/* Any 32-bit or wider unsigned integer data type will do */
+typedef unsigned int _MD5_u32plus;
+
+typedef struct {
+	_MD5_u32plus lo, hi;
+	_MD5_u32plus a, b, c, d;
+	unsigned char buffer[64];
+	_MD5_u32plus block[16];
+} _MD5_CTX;
+
+extern void _MD5_Init(_MD5_CTX *ctx);
+extern void _MD5_Update(_MD5_CTX *ctx, const void *data, unsigned long size);
+extern void _MD5_Final(unsigned char *result, _MD5_CTX *ctx);
+
+#if defined(__cplusplus)
+};    // extern "C"
+#endif
+
+#endif

+ 2 - 1
Gems/Atom/RHI/Metal/Code/Source/RHI.Builders/ShaderPlatformInterface.cpp

@@ -173,7 +173,8 @@ namespace AZ
            RHI::ShaderHardwareStage shaderStage,
            const AZStd::string& tempFolderPath,
            StageDescriptor& outputDescriptor,
-           const RHI::ShaderBuildArguments& shaderBuildArguments) const
+           const RHI::ShaderBuildArguments& shaderBuildArguments,
+           [[maybe_unused]] const bool useSpecializationConstants) const
         {
             for ([[maybe_unused]] auto srgLayout : m_srgLayouts)
             {

+ 2 - 1
Gems/Atom/RHI/Metal/Code/Source/RHI.Builders/ShaderPlatformInterface.h

@@ -51,7 +51,8 @@ namespace AZ
                 RHI::ShaderHardwareStage shaderStage,
                 const AZStd::string& tempFolderPath,
                 StageDescriptor& outputDescriptor,
-                const RHI::ShaderBuildArguments& shaderBuildArguments) const override;
+                const RHI::ShaderBuildArguments& shaderBuildArguments,
+                const bool useSpecializationConstants) const override; 
 
             const char* GetAzslHeader(const AssetBuilderSDK::PlatformInfo& platform) const override;
 

+ 41 - 7
Gems/Atom/RHI/Metal/Code/Source/RHI/PipelineState.cpp

@@ -69,7 +69,7 @@ namespace AZ
             return m_primitiveTopology;
         }
         
-        id<MTLFunction> PipelineState::CompileShader(id<MTLDevice> mtlDevice, const AZStd::string_view sourceStr, const AZStd::string_view entryPoint, const ShaderStageFunction* shaderFunction)
+        id<MTLFunction> PipelineState::CompileShader(id<MTLDevice> mtlDevice, const AZStd::string_view sourceStr, const AZStd::string_view entryPoint, const ShaderStageFunction* shaderFunction, MTLFunctionConstantValues* constantValues)
         {
             id<MTLFunction> pFunction = nullptr;
             NSString* source = [[NSString alloc] initWithCString : sourceStr.data() encoding: NSASCIIStringEncoding];
@@ -125,7 +125,7 @@ namespace AZ
             if (lib)
             {
                 NSString* entryPointStr = [[NSString alloc] initWithCString : entryPoint.data() encoding: NSASCIIStringEncoding];
-                pFunction = [lib newFunctionWithName:entryPointStr];
+                pFunction = [lib newFunctionWithName:entryPointStr constantValues:constantValues error:&error];
                 [entryPointStr release];
                 entryPointStr = nil;
                 [lib release];
@@ -170,9 +170,11 @@ namespace AZ
             [vertexDescriptor release];
             vertexDescriptor = nil;
             
-            m_renderPipelineDesc.vertexFunction = ExtractMtlFunction(device.GetMtlDevice(), descriptor.m_vertexFunction.get());
+            MTLFunctionConstantValues* constantValues = CreateFunctionConstantsValues(descriptor);
+            
+            m_renderPipelineDesc.vertexFunction = ExtractMtlFunction(device.GetMtlDevice(), descriptor.m_vertexFunction.get(), constantValues);
             AZ_Assert(m_renderPipelineDesc.vertexFunction, "Vertex mtlFuntion can not be null");
-            m_renderPipelineDesc.fragmentFunction = ExtractMtlFunction(device.GetMtlDevice(), descriptor.m_fragmentFunction.get());
+            m_renderPipelineDesc.fragmentFunction = ExtractMtlFunction(device.GetMtlDevice(), descriptor.m_fragmentFunction.get(), constantValues);
             
             RHI::Format depthStencilFormat = attachmentsConfiguration.GetDepthStencilFormat();
             if(descriptor.m_renderStates.m_depthStencilState.m_stencil.m_enable || IsDepthStencilMerged(depthStencilFormat))
@@ -226,6 +228,8 @@ namespace AZ
                 m_renderPipelineDesc = nil;
             }
             
+            [constantValues release];
+            constantValues = nil;
              
             m_pipelineStateMultiSampleState = descriptor.m_renderStates.m_multisampleState;
             
@@ -257,7 +261,8 @@ namespace AZ
             m_computePipelineDesc = [[MTLComputePipelineDescriptor alloc] init];
             RHI::ConstPtr<PipelineLayout> pipelineLayout = device.AcquirePipelineLayout(*descriptor.m_pipelineLayoutDescriptor);
             AZ_Assert(pipelineLayout, "PipelineLayout can not be null");
-            m_computePipelineDesc.computeFunction = ExtractMtlFunction(device.GetMtlDevice(), descriptor.m_computeFunction.get());
+            MTLFunctionConstantValues* constantValues = CreateFunctionConstantsValues(descriptor);
+            m_computePipelineDesc.computeFunction = ExtractMtlFunction(device.GetMtlDevice(), descriptor.m_computeFunction.get(), constantValues);
             AZ_Assert(m_computePipelineDesc.computeFunction, "Compute mtlFuntion can not be null");
             
             PipelineLibrary* pipelineLibrary = static_cast<PipelineLibrary*>(pipelineLibraryBase);
@@ -279,6 +284,9 @@ namespace AZ
                 [m_computePipelineDesc release];
                 m_computePipelineDesc = nil;
             }
+                                                                       
+            [constantValues release];
+            constantValues = nil;
             
             if (m_computePipelineState)
             {
@@ -300,7 +308,7 @@ namespace AZ
             return RHI::ResultCode::Fail;
         }
 
-        id<MTLFunction> PipelineState::ExtractMtlFunction(id<MTLDevice> mtlDevice, const RHI::ShaderStageFunction* stageFunc)
+        id<MTLFunction> PipelineState::ExtractMtlFunction(id<MTLDevice> mtlDevice, const RHI::ShaderStageFunction* stageFunc, MTLFunctionConstantValues* constantValues)
         {
             // set the bound shader state settings
             if (stageFunc)
@@ -309,13 +317,39 @@ namespace AZ
                 AZStd::string_view strView(shaderFunction->GetSourceCode());
                 
                 id<MTLFunction> mtlFunction = nil;
-                mtlFunction = CompileShader(mtlDevice, strView, shaderFunction->GetEntryFunctionName(), shaderFunction);
+                mtlFunction = CompileShader(mtlDevice, strView, shaderFunction->GetEntryFunctionName(), shaderFunction, constantValues);
 
                 return mtlFunction;
             }
             
             return nil;
         }
+    
+        MTLFunctionConstantValues* PipelineState::CreateFunctionConstantsValues(const RHI::PipelineStateDescriptor& pipelineDescriptor) const
+        {
+            MTLFunctionConstantValues* constantValues = [[MTLFunctionConstantValues alloc] init];
+            for(const auto& specializationData : pipelineDescriptor.m_specializationData)
+            {
+                uint32_t value = specializationData.m_value.GetIndex();
+                MTLDataType type;
+                switch(specializationData.m_type)
+                {
+                    case RHI::SpecializationType::Integer:
+                        type = MTLDataTypeInt;
+                        break;
+                    case RHI::SpecializationType::Bool:
+                        type = MTLDataTypeBool;
+                        break;
+                    default:
+                        AZ_Assert(false, "Invalid specialization type %d", specializationData.m_type);
+                        type = MTLDataTypeInt;
+                        break;
+                }
+                [constantValues setConstantValue:&value type:type atIndex:specializationData.m_id];
+            }
+            return constantValues;
+        }
+    
         void PipelineState::ShutdownInternal()
         {
             if (m_graphicsPipelineState)

+ 3 - 2
Gems/Atom/RHI/Metal/Code/Source/RHI/PipelineState.h

@@ -71,8 +71,9 @@ namespace AZ
             void ShutdownInternal() override;
             //////////////////////////////////////////////////////////////////////////
 
-            id<MTLFunction> CompileShader(id<MTLDevice> mtlDevice, const AZStd::string_view filePath, const AZStd::string_view entryPoint, const ShaderStageFunction* shaderFunction);
-            id<MTLFunction> ExtractMtlFunction(id<MTLDevice> mtlDevice, const RHI::ShaderStageFunction* stageFunc);
+            id<MTLFunction> CompileShader(id<MTLDevice> mtlDevice, const AZStd::string_view filePath, const AZStd::string_view entryPoint, const ShaderStageFunction* shaderFunction, MTLFunctionConstantValues* constantValues);
+            id<MTLFunction> ExtractMtlFunction(id<MTLDevice> mtlDevice, const RHI::ShaderStageFunction* stageFunc, MTLFunctionConstantValues* constantValues);
+            MTLFunctionConstantValues* CreateFunctionConstantsValues(const RHI::PipelineStateDescriptor& pipelineDescriptor) const;
             
             RHI::ConstPtr<PipelineLayout> m_pipelineLayout;
             AZStd::atomic_bool m_isCompiled = {false};

+ 2 - 1
Gems/Atom/RHI/Null/Code/Source/RHI.Builders/ShaderPlatformInterface.cpp

@@ -100,7 +100,8 @@ namespace AZ
             [[maybe_unused]] const AssetBuilderSDK::PlatformInfo& platform, [[maybe_unused]] const AZStd::string& shaderSourcePath,
             [[maybe_unused]] const AZStd::string& functionName, [[maybe_unused]] RHI::ShaderHardwareStage shaderStage,
             [[maybe_unused]] const AZStd::string& tempFolderPath, [[maybe_unused]] StageDescriptor& outputDescriptor,
-            [[maybe_unused]] const RHI::ShaderBuildArguments& shaderBuildArguments) const
+            [[maybe_unused]] const RHI::ShaderBuildArguments& shaderBuildArguments,
+            [[maybe_unused]] const bool useSpecializationConstants) const
         {
             outputDescriptor.m_stageType = shaderStage;
             return true;

+ 2 - 1
Gems/Atom/RHI/Null/Code/Source/RHI.Builders/ShaderPlatformInterface.h

@@ -33,7 +33,8 @@ namespace AZ
             bool CompilePlatformInternal(
                 const AssetBuilderSDK::PlatformInfo& platform, const AZStd::string& shaderSourcePath, const AZStd::string& functionName,
                 RHI::ShaderHardwareStage shaderStage, const AZStd::string& tempFolderPath, StageDescriptor& outputDescriptor,
-                const RHI::ShaderBuildArguments& shaderBuildArguments) const override;
+                const RHI::ShaderBuildArguments& shaderBuildArguments,
+                const bool useSpecializationConstants) const override;
             const char* GetAzslHeader(const AssetBuilderSDK::PlatformInfo& platform) const override;
             bool BuildPipelineLayoutDescriptor(
                 RHI::Ptr<RHI::PipelineLayoutDescriptor> pipelineLayoutDescriptor,

+ 2 - 1
Gems/Atom/RHI/Vulkan/Code/Source/RHI.Builders/ShaderPlatformInterface.cpp

@@ -112,7 +112,8 @@ namespace AZ
             RHI::ShaderHardwareStage shaderAssetType,
             const AZStd::string& tempFolderPath,
             StageDescriptor& outputDescriptor,
-            const RHI::ShaderBuildArguments& shaderBuildArguments) const
+            const RHI::ShaderBuildArguments& shaderBuildArguments,
+            [[maybe_unused]] const bool useSpecializationConstants) const
         {
             AZStd::vector<uint8_t> shaderByteCode;
 

+ 2 - 1
Gems/Atom/RHI/Vulkan/Code/Source/RHI.Builders/ShaderPlatformInterface.h

@@ -49,7 +49,8 @@ namespace AZ
                 RHI::ShaderHardwareStage shaderStage,
                 const AZStd::string& tempFolderPath,
                 StageDescriptor& outputDescriptor,
-                const RHI::ShaderBuildArguments& shaderBuildArguments) const override;
+                const RHI::ShaderBuildArguments& shaderBuildArguments,
+                const bool useSpecializationConstants) const override;
 
             const char* GetAzslHeader(const AssetBuilderSDK::PlatformInfo& platform) const override;
 

+ 3 - 1
Gems/Atom/RHI/Vulkan/Code/Source/RHI/Pipeline.cpp

@@ -32,6 +32,7 @@ namespace AZ
             }
 
             Base::Init(*descriptor.m_device);
+            m_specializationConstantData.Init(*descriptor.m_pipelineDescritor);
 
             RHI::ResultCode result = InitInternal(descriptor, *layout);
             RETURN_RESULT_IF_UNSUCCESSFUL(result);
@@ -91,6 +92,7 @@ namespace AZ
                 device.GetContext().DestroyPipeline(device.GetNativeDevice(), m_nativePipeline, VkSystemAllocator::Get());
                 m_nativePipeline = VK_NULL_HANDLE;
             }
+            m_specializationConstantData.Shutdown();
             Base::Shutdown();
         }
 
@@ -137,7 +139,7 @@ namespace AZ
             createInfo.stage = stageBits;
             createInfo.module = shaderModule->GetNativeShaderModule();
             createInfo.pName = shaderModule->GetEntryFunctionName().c_str();
-            createInfo.pSpecializationInfo = nullptr;
+            createInfo.pSpecializationInfo = m_specializationConstantData.GetVkSpecializationInfo();
         }
     }
 }

+ 3 - 0
Gems/Atom/RHI/Vulkan/Code/Source/RHI/Pipeline.h

@@ -14,6 +14,7 @@
 #include <AzCore/std/containers/list.h>
 #include <RHI/PipelineLayout.h>
 #include <RHI/ShaderModule.h>
+#include <RHI/SpecializationConstantData.h>
 
 namespace AZ
 {
@@ -78,6 +79,8 @@ namespace AZ
             RHI::Ptr<PipelineLayout> m_pipelineLayout;
             AZStd::list<RHI::Ptr<ShaderModule>> m_shaderModules;
             VkPipeline m_nativePipeline = VK_NULL_HANDLE;
+            // Contains the values of the specialization constants
+            SpecializationConstantData m_specializationConstantData;
         };
     }
 }

+ 11 - 3
Gems/Atom/RHI/Vulkan/Code/Source/RHI/RayTracingPipelineState.cpp

@@ -7,9 +7,10 @@
  */
 
 #include <RHI/RayTracingPipelineState.h>
+#include <RHI/Device.h>
+#include <RHI/SpecializationConstantData.h>
 #include <Atom/RHI.Reflect/SamplerState.h>
 #include <Atom/RHI.Reflect/Vulkan/ShaderStageFunction.h>
-#include <RHI/Device.h>
 #include <Atom/RHI.Reflect/VkAllocator.h>
 
 namespace AZ
@@ -37,11 +38,14 @@ namespace AZ
             // process shader libraries into shader stages and groups
             AZStd::vector<VkPipelineShaderStageCreateInfo> stages;
             AZStd::vector<VkRayTracingShaderGroupCreateInfoKHR> groups;
+            AZStd::vector<SpecializationConstantData> specializationDataVector(descriptor->GetShaderLibraries().size());
+
             m_shaderModules.reserve(descriptor->GetShaderLibraries().size());
-            for (const RHI::RayTracingShaderLibrary& shaderLibrary : descriptor->GetShaderLibraries())
+            const auto& libraries = descriptor->GetShaderLibraries();
+            for (uint32_t i = 0; i < libraries.size(); ++i)
             {
+                const RHI::RayTracingShaderLibrary& shaderLibrary = libraries[i];
                 const ShaderStageFunction* rayTracingFunction = azrtti_cast<const ShaderStageFunction*>(shaderLibrary.m_descriptor.m_rayTracingFunction.get());
-
                 VkShaderModule& shaderModule = m_shaderModules.emplace_back();
                 VkShaderModuleCreateInfo moduleCreateInfo = {};
                 moduleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
@@ -50,9 +54,13 @@ namespace AZ
                 device.GetContext().CreateShaderModule(
                     device.GetNativeDevice(), &moduleCreateInfo, VkSystemAllocator::Get(), &shaderModule);
 
+                SpecializationConstantData& specializationData = specializationDataVector[i];
+                specializationData.Init(shaderLibrary.m_descriptor);
+
                 VkPipelineShaderStageCreateInfo stageCreateInfo = {};
                 stageCreateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
                 stageCreateInfo.module = shaderModule;
+                stageCreateInfo.pSpecializationInfo = specializationData.GetVkSpecializationInfo();
 
                 // ray generation
                 if (!shaderLibrary.m_rayGenerationShaderName.IsEmpty())

+ 66 - 0
Gems/Atom/RHI/Vulkan/Code/Source/RHI/SpecializationConstantData.cpp

@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#include <RHI/Vulkan.h>
+#include <RHI/SpecializationConstantData.h>
+
+namespace AZ::Vulkan
+{
+    template<class T>
+    void AddSpecializationValue(AZStd::vector<uint8_t>& data, const RHI::SpecializationValue& specValue)
+    {
+        size_t size = sizeof(T);
+        size_t offset = data.size();
+        data.resize(data.size() + size);
+        T value = static_cast<T>(specValue.GetIndex());
+        memcpy(data.data() + offset, &value, size);
+    }
+
+    RHI::ResultCode SpecializationConstantData::Init(const RHI::PipelineStateDescriptor& pipelineDescriptor)
+    {
+        m_specializationData.reserve(pipelineDescriptor.m_specializationData.size() * sizeof(uint32_t));
+        for (const RHI::SpecializationConstant& specialization : pipelineDescriptor.m_specializationData)
+        {
+            m_specializationMap.emplace_back();
+            VkSpecializationMapEntry& entry = m_specializationMap.back();
+            entry.constantID = specialization.m_id;
+            entry.offset = aznumeric_cast<uint32_t>(m_specializationData.size());
+            switch (specialization.m_type)
+            {
+            case RHI::SpecializationType::Integer:
+                entry.size = sizeof(uint32_t);
+                AddSpecializationValue<uint32_t>(m_specializationData, specialization.m_value);
+                break;
+            case RHI::SpecializationType::Bool:
+                entry.size = sizeof(VkBool32);
+                AddSpecializationValue<VkBool32>(m_specializationData, specialization.m_value);
+                break;
+            default:
+                AZ_Assert(false, "Invalid specialization type %d", specialization.m_type);
+                return RHI::ResultCode::InvalidArgument;
+            }
+        }
+
+        m_specializationInfo.dataSize = m_specializationData.size();
+        m_specializationInfo.mapEntryCount = aznumeric_cast<uint32_t>(m_specializationMap.size());
+        m_specializationInfo.pData = m_specializationData.data();
+        m_specializationInfo.pMapEntries = m_specializationMap.data();
+        return RHI::ResultCode::Success;
+    }
+
+    void SpecializationConstantData::Shutdown()
+    {
+        m_specializationInfo = {};
+        m_specializationMap.clear();
+        m_specializationData.clear();
+    }
+
+    const VkSpecializationInfo* SpecializationConstantData::GetVkSpecializationInfo() const
+    {
+        return m_specializationInfo.mapEntryCount ? &m_specializationInfo : nullptr;
+    }
+} // namespace AZ::Vulkan

+ 37 - 0
Gems/Atom/RHI/Vulkan/Code/Source/RHI/SpecializationConstantData.h

@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#pragma once
+
+#include <Atom/RHI/PipelineStateDescriptor.h>
+#include <AzCore/std/containers/vector.h>
+
+namespace AZ::Vulkan
+{
+    //! Contains the vulkan structure needed for using specialization constants
+    class SpecializationConstantData
+    {
+    public:
+        SpecializationConstantData() = default;
+
+        //! Initialize the contents with the specialization constants information of the pipeline descriptor
+        RHI::ResultCode Init(const RHI::PipelineStateDescriptor& descriptor);
+        //! Release any data previously used
+        void Shutdown();
+
+        //! Returns the vulkan specialization info
+        const VkSpecializationInfo* GetVkSpecializationInfo() const;
+
+    private:
+        // Vulkan structure for using specialization constants
+        VkSpecializationInfo m_specializationInfo{};
+        // Vector with the mapping information of the specialization constants (ids and offsets).
+        AZStd::vector<VkSpecializationMapEntry> m_specializationMap;
+        // Memory buffer with the values of the specialization constants
+        AZStd::vector<uint8_t> m_specializationData;
+    };
+}

+ 2 - 0
Gems/Atom/RHI/Vulkan/Code/atom_rhi_vulkan_private_common_files.cmake

@@ -169,4 +169,6 @@ set(FILES
     Source/RHI/Conversion.cpp
     Source/RHI/Conversion.h
     Source/RHI/WindowSurfaceBus.h
+    Source/RHI/SpecializationConstantData.cpp
+    Source/RHI/SpecializationConstantData.h
 )

+ 5 - 1
Gems/Atom/RPI/Code/Include/Atom/RPI.Edit/Shader/ShaderVariantAssetCreator.h

@@ -28,7 +28,11 @@ namespace AZ
             //!        It is still useful, because when creating the Root Variant for the ShaderAsset this assetId should
             //!        match the value that will be assigned by the asset processor because the Root Variant is serialized
             //!        as a Data::Asset<ShaderVariantAsset> inside the ShaderAsset.
-            void Begin(const AZ::Data::AssetId& assetId, const ShaderVariantId& shaderVariantId, RPI::ShaderVariantStableId stableId, bool isFullyBaked);
+            void Begin(
+                const AZ::Data::AssetId& assetId,
+                const ShaderVariantId& shaderVariantId,
+                RPI::ShaderVariantStableId stableId,
+                bool isFullyBaked);
 
             //! Finalizes and assigns ownership of the asset to result, if successful. 
             //! Otherwise false is returned and result is left untouched.

+ 3 - 0
Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Pass/ComputePass.h

@@ -52,6 +52,9 @@ namespace AZ
             using ComputeShaderReloadedCallback = AZStd::function<void(ComputePass* computePass)>;
             void SetComputeShaderReloadedCallback(ComputeShaderReloadedCallback callback);
 
+            //! Updates the shader variant being used by the pass
+            void UpdateShaderOptions(const ShaderVariantId& shaderVariantId);
+
         protected:
             ComputePass(const PassDescriptor& descriptor, AZ::Name supervariant = AZ::Name(""));
 

+ 4 - 0
Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Pass/FullscreenTrianglePass.h

@@ -46,6 +46,7 @@ namespace AZ
 
             //! Updates the shader options used in this pass
             void UpdateShaderOptions(const ShaderOptionList& shaderOptions);
+            void UpdateShaderOptions(const ShaderVariantId& shaderVariantId);
 
         protected:
             FullscreenTrianglePass(const PassDescriptor& descriptor);
@@ -64,6 +65,9 @@ namespace AZ
             void OnShaderAssetReinitialized(const Data::Asset<ShaderAsset>& shaderAsset) override;
             void OnShaderVariantReinitialized(const ShaderVariant& shaderVariant) override;
 
+            // Common code when updating the shader variant with new options
+            void UpdateShaderOptionsCommon();
+
             RHI::Viewport m_viewportState;
             RHI::Scissor m_scissorState;
 

+ 4 - 0
Gems/Atom/RPI/Code/Include/Atom/RPI.Public/PipelineState.h

@@ -38,6 +38,7 @@ namespace AZ
             //! Initialize the pipeline state from a shader and one of its shader variant
             //! The previous data will be reset
             void Init(const Data::Instance<Shader>& shader, const ShaderOptionList* optionAndValues = nullptr);
+            void Init(const Data::Instance<Shader>& shader, const ShaderVariantId& shaderVariantId);
 
             //! Update the pipeline state descriptor for the specified scene
             //! This is usually called when Scene's render pipelines changed
@@ -79,6 +80,9 @@ namespace AZ
             //! Clear all the states and references
             void Shutdown();
 
+            //! Returns the id of the shader variant being used
+            const ShaderVariantId& GetShaderVariantId() const;
+
         private:
             ///////////////////////////////////////////////////////////////////
             // ShaderReloadNotificationBus overrides...

+ 22 - 2
Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Shader/ShaderVariant.h

@@ -26,9 +26,17 @@ namespace AZ
             virtual ~ShaderVariant();
             AZ_DEFAULT_COPY_MOVE(ShaderVariant);
 
-            //! Fills a pipeline state descriptor with settings provided by the ShaderVariant. (Note that
-            //! this does not fill the InputStreamLayout or OutputAttachmentLayout as that also requires 
+            //! Fills a pipeline state descriptor with settings provided by the ShaderVariant. 
+            //! It also configures the specialization constants if they are being used by the shader variant.
+            //! (Note that this does not fill the InputStreamLayout or OutputAttachmentLayout as that also requires 
             //! information from the mesh data and pass system and must be done as a separate step).
+            void ConfigurePipelineState(RHI::PipelineStateDescriptor& descriptor, const ShaderVariantId& specialization) const;
+            //! Fills a pipeline state descriptor with settings provided by the ShaderVariant.
+            //! It also configures the specialization constants if they are being used by the shader variant.
+            void ConfigurePipelineState(RHI::PipelineStateDescriptor& descriptor, const ShaderOptionGroup& specialization) const;
+            //! Fills a pipeline state descriptor with settings provided by the ShaderVariant.
+            //! Only use this function if the shader variant is not using ANY specialization constant. Otherwise
+            //! an error will be raised and the default values will be used.
             void ConfigurePipelineState(RHI::PipelineStateDescriptor& descriptor) const;
 
             const ShaderVariantId& GetShaderVariantId() const { return m_shaderVariantAsset->GetShaderVariantId(); }
@@ -38,6 +46,15 @@ namespace AZ
             //! If the shader variant is not fully baked, the ShaderVariantKeyFallbackValue must be correctly set when drawing.
             bool IsFullyBaked() const { return m_shaderVariantAsset->IsFullyBaked(); }
 
+            //! Returns whether the variant is using specialization constants for all of the options.
+            bool IsFullySpecialized() const;
+
+            //! Return true if this shader variant has at least one shader option using specialization constant.
+            bool UseSpecializationConstants() const;
+
+            //! Return true if this variant needs the ShaderVariantKeyFallbackValue to be correctly set when drawing.
+            bool UseKeyFallback() const;
+
             //! Return the timestamp when this asset was built.
             //! This is used to synchronize versions of the ShaderAsset and ShaderVariantAsset, especially during hot-reload.
             //! This timestamp must be >= than the ShaderAsset timestamp.
@@ -70,6 +87,9 @@ namespace AZ
 
             const RHI::RenderStates* m_renderStates = nullptr; // Cached from ShaderAsset.
             SupervariantIndex m_supervariantIndex;
+
+            // True if there's at least one shader option that is using a specialization constant.
+            bool m_useSpecializationConstants = false;
         };
     }
 }

+ 18 - 0
Gems/Atom/RPI/Code/Include/Atom/RPI.Reflect/Shader/ShaderAsset.h

@@ -216,6 +216,20 @@ namespace AZ
                 return GetAttribute(shaderStage, attributeName, DefaultSupervariantIndex);
             }
 
+            //! Returns if the supervariant uses specialization constants for at least one shader options.
+            bool UseSpecializationConstants(SupervariantIndex supervariantIndex) const;
+            bool UseSpecializationConstants() const
+            {
+                return UseSpecializationConstants(DefaultSupervariantIndex);
+            }
+
+            //! Returns true if the supervariant is fully specialized (all shader options are specialization constants)
+            bool IsFullySpecialized(SupervariantIndex supervariantIndex) const;
+            bool IsFullySpecialized() const
+            {
+                return IsFullySpecialized(DefaultSupervariantIndex);
+            }
+
         private:
             ///////////////////////////////////////////////////////////////////
             /// ShaderVariantFinderNotificationBus overrides
@@ -242,6 +256,7 @@ namespace AZ
                 RHI::RenderStates m_renderStates;
                 RHI::ShaderStageAttributeMapList m_attributeMaps;
                 Data::Asset<ShaderVariantAsset> m_rootShaderVariantAsset;
+                bool m_useSpecializationConstants = false;
             };
 
             //! Container of shader data that is specific to an RHI API.
@@ -313,6 +328,9 @@ namespace AZ
             mutable AZStd::shared_mutex m_variantTreeMutex;
 
             bool m_shaderVariantTreeLoadWasRequested = false;
+
+            //! True if all supervariants are fully specialized
+            bool m_isFullySpecialized = false;
         };
 
         class ShaderAssetHandler final

+ 3 - 0
Gems/Atom/RPI/Code/Include/Atom/RPI.Reflect/Shader/ShaderAssetCreator.h

@@ -73,6 +73,9 @@ namespace AZ
             //! [Required] There's always a root variant for each supervariant.
             void SetRootShaderVariantAsset(Data::Asset<ShaderVariantAsset> shaderVariantAsset);
 
+            //! Set if the supervariant uses specialization constants for shader options.
+            void SetUseSpecializationConstants(bool value);
+
             bool EndSupervariant();
 
             bool EndAPI();

+ 18 - 1
Gems/Atom/RPI/Code/Include/Atom/RPI.Reflect/Shader/ShaderOptionGroupLayout.h

@@ -90,7 +90,8 @@ namespace AZ
                                    uint32_t order,
                                    const ShaderOptionValues& nameIndexList,
                                    const Name& defaultValue = {},
-                                   uint32_t cost = 0);
+                                   uint32_t cost = 0,
+                                   int specializationId = -1);
 
             AZ_DEFAULT_COPY_MOVE(ShaderOptionDescriptor);
 
@@ -105,6 +106,9 @@ namespace AZ
 
             uint32_t GetCostEstimate() const;
 
+            //! Return the specialization id. -1 if this option can't be specialize.
+            int GetSpecializationId() const;
+
             //! Returns the mask comprising bits specific to this option.
             ShaderVariantKey GetBitMask() const;
 
@@ -192,6 +196,7 @@ namespace AZ
             uint32_t m_bitCount = 0;
             uint32_t m_order = 0;          //!< The order (or rank) of the shader option dictates its priority. Lower order (rank) is higher priority.
             uint32_t m_costEstimate = 0;
+            int m_specializationId = -1; //< Specialization id. A value of -1 means no specialization.
             ShaderVariantKey m_bitMask;
             ShaderVariantKey m_bitMaskNot;
 
@@ -263,6 +268,13 @@ namespace AZ
 
             HashValue64 GetHash() const;
 
+            //! Returns true if all shader options of the layout are using specialization constants. Please note that each
+            //! supervariant can have specialization constants off even if the layout is IsFullySpecialized.
+            bool IsFullySpecialized() const;
+
+            //! Returns true if at least one shader option is using specialization constant.
+            bool UseSpecializationConstants() const;
+
         private:
             ShaderOptionGroupLayout() = default;
 
@@ -281,6 +293,11 @@ namespace AZ
             using NameReflectionMapForOptions = RHI::NameIdReflectionMap<ShaderOptionIndex>;
             NameReflectionMapForOptions m_nameReflectionForOptions;
             HashValue64 m_hash = HashValue64{ 0 };
+
+            // True if all shader options are using specialization constants
+            bool m_isFullySpecialized = false;
+            // True if at least one shader options is using specialization constants
+            bool m_useSpecializationConstants = false;
         };
     } // namespace RPI
 

+ 2 - 2
Gems/Atom/RPI/Code/Include/Atom/RPI.Reflect/Shader/ShaderVariantAsset.h

@@ -58,14 +58,14 @@ namespace AZ
 
             //! Returns whether the variant is fully baked variant (all options are static branches), or false if the
             //! variant uses dynamic branches for some shader options.
-            //! If the shader variant is not fully baked, the ShaderVariantKeyFallbackValue must be correctly set when drawing.
+            //! If the shader variant is not fully baked or fully specialized, the ShaderVariantKeyFallbackValue must be correctly set when drawing.
             bool IsFullyBaked() const;
 
             //! Return the timestamp when this asset was built, and it must be >= than the timestamp of the main ShaderAsset.
             //! This is used to synchronize versions of the ShaderAsset and ShaderVariantAsset, especially during hot-reload.
             AZ::u64 GetBuildTimestamp() const;
 
-            bool IsRootVariant() const { return m_stableId == RPI::RootShaderVariantStableId; } 
+            bool IsRootVariant() const { return m_stableId == RPI::RootShaderVariantStableId; }
 
         private:
             //! Called by asset creators to assign the asset to a ready state.

+ 5 - 1
Gems/Atom/RPI/Code/Source/RPI.Edit/Shader/ShaderVariantAssetCreator.cpp

@@ -15,7 +15,11 @@ namespace AZ
 {
     namespace RPI
     {
-        void ShaderVariantAssetCreator::Begin(const AZ::Data::AssetId& assetId, const ShaderVariantId& shaderVariantId, RPI::ShaderVariantStableId stableId, bool isFullyBaked)
+        void ShaderVariantAssetCreator::Begin(
+            const AZ::Data::AssetId& assetId,
+            const ShaderVariantId& shaderVariantId,
+            RPI::ShaderVariantStableId stableId,
+            bool isFullyBaked)
         {
             BeginCommon(assetId);
 

+ 1 - 1
Gems/Atom/RPI/Code/Source/RPI.Public/MeshDrawPacket.cpp

@@ -376,7 +376,7 @@ namespace AZ
 #endif
 
                 RHI::PipelineStateDescriptorForDraw pipelineStateDescriptor;
-                variant.ConfigurePipelineState(pipelineStateDescriptor);
+                variant.ConfigurePipelineState(pipelineStateDescriptor, shaderOptions);
 
                 // Render states need to merge the runtime variation.
                 // This allows materials to customize the render states that the shader uses.

+ 19 - 1
Gems/Atom/RPI/Code/Source/RPI.Public/Pass/ComputePass.cpp

@@ -144,9 +144,14 @@ namespace AZ
 
             // Setup pipeline state...
             RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
-            m_shader->GetDefaultVariant().ConfigurePipelineState(pipelineStateDescriptor);
+            ShaderOptionGroup options = m_shader->GetDefaultShaderOptions();
+            m_shader->GetDefaultVariant().ConfigurePipelineState(pipelineStateDescriptor, options);
 
             m_dispatchItem.SetPipelineState(m_shader->AcquirePipelineState(pipelineStateDescriptor));
+            if (m_drawSrg && m_shader->GetDefaultVariant().UseKeyFallback())
+            {
+                m_drawSrg->SetShaderVariantKeyFallbackValue(options.GetShaderVariantKeyFallbackValue());
+            }
 
             OnShaderReloadedInternal();
 
@@ -255,6 +260,19 @@ namespace AZ
             m_shaderReloadedCallback = callback;
         }
 
+        void ComputePass::UpdateShaderOptions(const ShaderVariantId& shaderVariantId)
+        {
+            const ShaderVariant& shaderVariant = m_shader->GetVariant(shaderVariantId);
+            RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
+            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, shaderVariantId);
+
+            m_dispatchItem.SetPipelineState(m_shader->AcquirePipelineState(pipelineStateDescriptor));
+            if (m_drawSrg && shaderVariant.UseKeyFallback())
+            {
+                m_drawSrg->SetShaderVariantKeyFallbackValue(shaderVariantId.m_key);
+            }
+        }
+
         void ComputePass::OnShaderReloadedInternal()
         {
             if (m_shaderReloadedCallback)

+ 17 - 3
Gems/Atom/RPI/Code/Source/RPI.Public/Pass/FullscreenTrianglePass.cpp

@@ -68,6 +68,12 @@ namespace AZ
             LoadShader();
         }
 
+        void FullscreenTrianglePass::UpdateShaderOptionsCommon()
+        {
+            m_pipelineStateForDraw.UpdateSrgVariantFallback(m_shaderResourceGroup);
+            BuildDrawItem();
+        }
+
         void FullscreenTrianglePass::LoadShader()
         {
             AZ_Assert(GetPassState() != PassState::Rendering, "FullscreenTrianglePass - Reloading shader during Rendering phase!");
@@ -120,7 +126,7 @@ namespace AZ
             // Store stencil reference value for the draw call
             m_stencilRef = passData->m_stencilRef;
 
-            m_pipelineStateForDraw.Init(m_shader);
+            m_pipelineStateForDraw.Init(m_shader, m_shader->GetDefaultShaderOptions().GetShaderVariantId());
 
             UpdateSrgs();
 
@@ -191,8 +197,16 @@ namespace AZ
             if (m_shader)
             {
                 m_pipelineStateForDraw.Init(m_shader, &shaderOptions);
-                m_pipelineStateForDraw.UpdateSrgVariantFallback(m_shaderResourceGroup);
-                BuildDrawItem();
+                UpdateShaderOptionsCommon();
+            }
+        }
+
+        void FullscreenTrianglePass::UpdateShaderOptions(const ShaderVariantId& shaderVariantId)
+        {
+            if (m_shader)
+            {
+                m_pipelineStateForDraw.Init(m_shader, shaderVariantId);
+                UpdateShaderOptionsCommon();
             }
         }
 

+ 1 - 1
Gems/Atom/RPI/Code/Source/RPI.Public/Pass/Specific/ImageAttachmentPreviewPass.cpp

@@ -324,7 +324,7 @@ namespace AZ
 
                 shaderOption.SetValue(AZ::Name(optionName), AZ::Name(optionValues[index]));
 
-                m_shader->GetVariant(shaderOption.GetShaderVariantId()).ConfigurePipelineState(pipelineDesc);
+                m_shader->GetVariant(shaderOption.GetShaderVariantId()).ConfigurePipelineState(pipelineDesc, shaderOption);
                 pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout = attachmentsLayout;
                 pipelineDesc.m_inputStreamLayout = inputStreamLayout;
                 previewInfo.m_shaderVariantKeyFallback = shaderOption.GetShaderVariantKeyFallbackValue();

+ 27 - 16
Gems/Atom/RPI/Code/Source/RPI.Public/PipelineState.cpp

@@ -51,16 +51,8 @@ namespace AZ
 
         void PipelineStateForDraw::Init(const Data::Instance<RPI::Shader>& shader, const ShaderOptionList* optionAndValues)
         {
-            // Reset some variables
-            m_pipelineState = nullptr;
-            m_shaderVariantId = ShaderVariantId{};
-
-            // Reset some flags
-            m_dirty = true;
-            m_isShaderVariantReady = true;
-                        
             // Get shader variant from the shader
-            auto shaderVariant = shader->GetRootVariant();
+            ShaderVariantId shaderVariant = {};
             if (optionAndValues)
             {
                 RPI::ShaderOptionGroup shaderOptionGroup = shader->CreateShaderOptionGroup();
@@ -69,15 +61,29 @@ namespace AZ
                 {
                     shaderOptionGroup.SetValue(optionAndValue.first, optionAndValue.second);
                 }
-                m_shaderVariantId = shaderOptionGroup.GetShaderVariantId();
-                shaderVariant = shader->GetVariant(m_shaderVariantId);
-                m_isShaderVariantReady = shaderVariant.IsFullyBaked();
+                shaderVariant = shaderOptionGroup.GetShaderVariantId();
             }
+            Init(shader, shaderVariant);
+        }
+
+        void PipelineStateForDraw::Init(const Data::Instance<Shader>& shader, const ShaderVariantId& shaderVariantId)
+        {
+            // Reset some variables
+            m_pipelineState = nullptr;
+
+            // Reset some flags
+            m_dirty = true;
+            m_isShaderVariantReady = true;
+
+            // Get shader variant from the shader
+            m_shaderVariantId = shaderVariantId;
+            ShaderVariant shaderVariant = shader->GetVariant(m_shaderVariantId);
+            m_isShaderVariantReady = !shaderVariant.UseKeyFallback();
 
             // Fill the descriptor with data from shader variant
-            shaderVariant.ConfigurePipelineState(m_descriptor);
+            shaderVariant.ConfigurePipelineState(m_descriptor, m_shaderVariantId);
 
-            // Connect to shader reload notification bus to rebuilt pipeline state when shader or shader variant changed. 
+            // Connect to shader reload notification bus to rebuilt pipeline state when shader or shader variant changed.
             ShaderReloadNotificationBus::MultiHandler::BusDisconnect();
             ShaderReloadNotificationBus::MultiHandler::BusConnect(shader->GetAsset().GetId());
 
@@ -90,11 +96,11 @@ namespace AZ
         void PipelineStateForDraw::RefreshShaderVariant()
         {
             auto shaderVariant = m_shader->GetVariant(m_shaderVariantId);
-            m_isShaderVariantReady = shaderVariant.IsFullyBaked();
+            m_isShaderVariantReady = !shaderVariant.UseKeyFallback();
 
             auto multisampleState = m_descriptor.m_renderStates.m_multisampleState;
 
-            shaderVariant.ConfigurePipelineState(m_descriptor);
+            shaderVariant.ConfigurePipelineState(m_descriptor, m_shaderVariantId);
 
             // Recover multisampleState if it was set from output data
             if (m_hasOutputData)
@@ -233,5 +239,10 @@ namespace AZ
                         
             ShaderReloadNotificationBus::MultiHandler::BusDisconnect();
         }
+
+        const ShaderVariantId& PipelineStateForDraw::GetShaderVariantId() const
+        {
+            return m_shaderVariantId;
+        }
     }
 }

+ 3 - 2
Gems/Atom/RPI/Code/Source/RPI.Public/Shader/Shader.cpp

@@ -501,8 +501,9 @@ namespace AZ
             if (drawSrgLayout)
             {
                 drawSrg = RPI::ShaderResourceGroup::Create(m_asset, GetSupervariantIndex(), drawSrgLayout->GetName());
-
-                if (drawSrgLayout->HasShaderVariantKeyFallbackEntry())
+                bool useFallbackKey = !shaderOptions.GetShaderOptionLayout()->IsFullySpecialized() ||
+                    !m_asset->UseSpecializationConstants(GetSupervariantIndex());
+                if (useFallbackKey && drawSrgLayout->HasShaderVariantKeyFallbackEntry())
                 {
                     drawSrg->SetShaderVariantKeyFallbackValue(shaderOptions.GetShaderVariantKeyFallbackValue());
                 }

+ 78 - 2
Gems/Atom/RPI/Code/Source/RPI.Public/Shader/ShaderVariant.cpp

@@ -29,7 +29,7 @@ namespace AZ
             m_pipelineStateType = shaderAsset->GetPipelineStateType();
             m_pipelineLayoutDescriptor = shaderAsset->GetPipelineLayoutDescriptor(supervariantIndex);
             m_renderStates = &shaderAsset->GetRenderStates(supervariantIndex);
-
+            m_useSpecializationConstants = shaderAsset->UseSpecializationConstants(supervariantIndex);
             return true;
         }
 
@@ -38,7 +38,16 @@ namespace AZ
 
         }
 
-        void ShaderVariant::ConfigurePipelineState(RHI::PipelineStateDescriptor& descriptor) const
+        void ShaderVariant::ConfigurePipelineState(
+            RHI::PipelineStateDescriptor& descriptor,
+            const ShaderVariantId& specialization) const
+        {
+            ConfigurePipelineState(descriptor, ShaderOptionGroup(m_shaderAsset->GetShaderOptionGroupLayout(), specialization));
+        }
+
+        void ShaderVariant::ConfigurePipelineState(
+            RHI::PipelineStateDescriptor& descriptor,
+            const ShaderOptionGroup& specialization) const
         {
             descriptor.m_pipelineLayoutDescriptor = m_pipelineLayoutDescriptor;
 
@@ -76,6 +85,73 @@ namespace AZ
                 AZ_Assert(false, "Unexpected PipelineStateType");
                 break;
             }
+
+            if (m_useSpecializationConstants)
+            {
+                // Configure specialization data for the shader
+                AZ_Assert(
+                    specialization.GetShaderOptionLayout() == m_shaderAsset->GetShaderOptionGroupLayout(),
+                    "OptionGroup for specialization is different to the one in the ShaderAsset");
+                descriptor.m_specializationData.clear();
+                ShaderOptionGroup options = specialization;
+                options.SetUnspecifiedToDefaultValues();
+                for (auto& option : options.GetShaderOptionLayout()->GetShaderOptions())
+                {
+                    if (option.GetSpecializationId() >= 0)
+                    {
+                        descriptor.m_specializationData.emplace_back();
+                        auto& specializationData = descriptor.m_specializationData.back();
+                        specializationData.m_name = option.GetName();
+                        specializationData.m_id = option.GetSpecializationId();
+                        specializationData.m_value = RHI::SpecializationValue(option.Get(options).GetIndex());
+                        switch (option.GetType())
+                        {
+                        case ShaderOptionType::Boolean:
+                            specializationData.m_type = RHI::SpecializationType::Bool;
+                            break;
+                        case ShaderOptionType::Enumeration:
+                        case ShaderOptionType::IntegerRange:
+                            specializationData.m_type = RHI::SpecializationType::Integer;
+                            break;
+                        default:
+                            break;
+                        }
+                    }
+                }
+            }
+        }
+
+        void ShaderVariant::ConfigurePipelineState(RHI::PipelineStateDescriptor& descriptor) const
+        {
+            auto layout = m_shaderAsset->GetShaderOptionGroupLayout();
+            for ([[maybe_unused]] auto& option : layout->GetShaderOptions())
+            {
+                if (m_useSpecializationConstants && option.GetSpecializationId() >= 0)
+                {
+                    AZ_Error(
+                        "ConfigurePipelineState",
+                        !m_useSpecializationConstants || option.GetSpecializationId() < 0,
+                        "Configuring PipelineStateDescriptor without specializing option %s.\
+                         Call ConfigurePipelineState with specialization data. Default value will be used.",
+                        option.GetName().GetCStr());
+                }
+            }
+            ConfigurePipelineState(descriptor, ShaderOptionGroup(layout));
+        }
+
+        bool ShaderVariant::IsFullySpecialized() const
+        {
+            return m_shaderAsset->IsFullySpecialized(m_supervariantIndex);
+        }
+
+        bool ShaderVariant::UseSpecializationConstants() const
+        {
+            return m_shaderAsset->UseSpecializationConstants(m_supervariantIndex);
+        }
+
+        bool ShaderVariant::UseKeyFallback() const
+        {
+            return !(IsFullyBaked() || IsFullySpecialized());
         }
 
     } // namespace RPI

+ 24 - 3
Gems/Atom/RPI/Code/Source/RPI.Reflect/Shader/ShaderAsset.cpp

@@ -72,6 +72,7 @@ namespace AZ
                     ->Field("RenderStates", &Supervariant::m_renderStates)
                     ->Field("AttributeMapList", &Supervariant::m_attributeMaps)
                     ->Field("RootVariantAsset", &Supervariant::m_rootShaderVariantAsset)
+                    ->Field("UseSpecializationConstants", &Supervariant::m_useSpecializationConstants)
                     ;
             }
         }
@@ -194,7 +195,7 @@ namespace AZ
             Data::Asset<ShaderAsset> thisAsset(this, Data::AssetLoadBehavior::Default);
             Data::Asset<ShaderVariantAsset> shaderVariantAsset =
                 variantFinder->GetShaderVariantAssetByVariantId(thisAsset, shaderVariantId, supervariantIndex);
-            if (!shaderVariantAsset)
+            if (!shaderVariantAsset && !IsFullySpecialized(supervariantIndex))
             {
                 variantFinder->QueueLoadShaderVariantAssetByVariantId(thisAsset, shaderVariantId, supervariantIndex);
             }
@@ -206,7 +207,7 @@ namespace AZ
             uint32_t dynamicOptionCount = aznumeric_cast<uint32_t>(GetShaderOptionGroupLayout()->GetShaderOptions().size());
             ShaderVariantSearchResult variantSearchResult{RootShaderVariantStableId,  dynamicOptionCount };
 
-            if (!dynamicOptionCount)
+            if (!dynamicOptionCount || m_isFullySpecialized)
             {
                 // The shader has no options at all. There's nothing to search.
                 return variantSearchResult;
@@ -245,7 +246,9 @@ namespace AZ
         Data::Asset<ShaderVariantAsset> ShaderAsset::GetVariantAsset(
             ShaderVariantStableId shaderVariantStableId, SupervariantIndex supervariantIndex) const
         {
-            if (!shaderVariantStableId.IsValid() || shaderVariantStableId == RootShaderVariantStableId)
+            if (!shaderVariantStableId.IsValid() ||
+                shaderVariantStableId == RootShaderVariantStableId ||
+                IsFullySpecialized(supervariantIndex))
             {
                 return GetRootVariantAsset(supervariantIndex);
             }
@@ -457,6 +460,22 @@ namespace AZ
             return attrPair->second;
         }
 
+        bool ShaderAsset::UseSpecializationConstants(SupervariantIndex supervariantIndex) const
+        {
+            auto supervariant = GetSupervariant(supervariantIndex);
+            if (!supervariant)
+            {
+                return false;
+            }
+
+            return supervariant->m_useSpecializationConstants;
+        }
+
+        bool ShaderAsset::IsFullySpecialized(SupervariantIndex supervariantIndex) const
+        {            
+            return UseSpecializationConstants(supervariantIndex) && m_shaderOptionGroupLayout->IsFullySpecialized();
+        }
+
         ShaderAsset::ShaderApiDataContainer& ShaderAsset::GetCurrentShaderApiData()
         {
             const size_t perApiShaderDataCount = m_perAPIShaderData.size();
@@ -556,12 +575,14 @@ namespace AZ
                 }
             }
 
+            m_isFullySpecialized = m_shaderOptionGroupLayout->IsFullySpecialized();
             // Common finalize check
             for (const auto& shaderApiData : m_perAPIShaderData)
             {
                 const auto& supervariants = shaderApiData.m_supervariants;
                 for (const auto& supervariant : supervariants)
                 {
+                    m_isFullySpecialized &= supervariant.m_useSpecializationConstants;
                     bool beTrue = supervariant.m_attributeMaps.size() == RHI::ShaderStageCount;
                     if (!beTrue)
                     {

+ 17 - 0
Gems/Atom/RPI/Code/Source/RPI.Reflect/Shader/ShaderAssetCreator.cpp

@@ -238,6 +238,20 @@ namespace AZ
             m_currentSupervariant->m_rootShaderVariantAsset = shaderVariantAsset;
         }
 
+        void ShaderAssetCreator::SetUseSpecializationConstants(bool value)
+        {
+            if (!ValidateIsReady())
+            {
+                return;
+            }
+            if (!m_currentSupervariant)
+            {
+                ReportError("BeginSupervariant() should be called first before calling %s", __FUNCTION__);
+                return;
+            }
+            m_currentSupervariant->m_useSpecializationConstants = value;
+        }
+
         static RHI::PipelineStateType GetPipelineStateType(const Data::Asset<ShaderVariantAsset>& shaderVariantAsset)
         {
             if (shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Vertex) ||
@@ -343,6 +357,9 @@ namespace AZ
                 }
             }
 
+            m_currentSupervariant->m_useSpecializationConstants =
+                m_currentSupervariant->m_useSpecializationConstants && m_asset->m_shaderOptionGroupLayout->UseSpecializationConstants();
+
             m_currentSupervariant = nullptr;
             return true;
         }

+ 37 - 3
Gems/Atom/RPI/Code/Source/RPI.Reflect/Shader/ShaderOptionGroupLayout.cpp

@@ -85,7 +85,7 @@ namespace AZ
             if (auto* serializeContext = azrtti_cast<SerializeContext*>(context))
             {
                 serializeContext->Class<ShaderOptionDescriptor>()
-                    ->Version(5)  // 5: addition of m_costEstimate field
+                    ->Version(6)  // 6: addition of m_specializationId field
                     ->Field("m_name", &ShaderOptionDescriptor::m_name)
                     ->Field("m_type", &ShaderOptionDescriptor::m_type)
                     ->Field("m_defaultValue", &ShaderOptionDescriptor::m_defaultValue)
@@ -99,6 +99,7 @@ namespace AZ
                     ->Field("m_bitMaskNot", &ShaderOptionDescriptor::m_bitMaskNot)
                     ->Field("m_hash", &ShaderOptionDescriptor::m_hash)
                     ->Field("m_nameReflectionForValues", &ShaderOptionDescriptor::m_nameReflectionForValues)
+                    ->Field("m_specializationId", &ShaderOptionDescriptor::m_specializationId)
                     ;
             }
 
@@ -130,7 +131,8 @@ namespace AZ
                                                        uint32_t order,
                                                        const ShaderOptionValues& nameIndexList,
                                                        const Name& defaultValue,
-                                                       uint32_t cost)
+                                                       uint32_t cost,
+                                                       int specializationId)
 
             : m_name{name}
             , m_type{optionType}
@@ -138,6 +140,7 @@ namespace AZ
             , m_order{order}
             , m_costEstimate{cost}
             , m_defaultValue{defaultValue}
+            , m_specializationId{specializationId}
         {
             for (auto pair : nameIndexList)
             {   // Registers the pair in the lookup table
@@ -187,6 +190,11 @@ namespace AZ
             return m_costEstimate;
         }
 
+        int ShaderOptionDescriptor::GetSpecializationId() const
+        {
+            return m_specializationId;
+        }
+
         ShaderVariantKey ShaderOptionDescriptor::GetBitMask() const
         {
             return m_bitMask;
@@ -452,11 +460,13 @@ namespace AZ
             if (auto* serializeContext = azrtti_cast<SerializeContext*>(context))
             {
                 serializeContext->Class<ShaderOptionGroupLayout>()
-                    ->Version(2)
+                    ->Version(3)
                     ->Field("m_bitMask", &ShaderOptionGroupLayout::m_bitMask)
                     ->Field("m_options", &ShaderOptionGroupLayout::m_options)
                     ->Field("m_nameReflectionForOptions", &ShaderOptionGroupLayout::m_nameReflectionForOptions)
                     ->Field("m_hash", &ShaderOptionGroupLayout::m_hash)
+                    ->Field("m_isFullySpecialized", &ShaderOptionGroupLayout::m_isFullySpecialized)
+                    ->Field("m_useSpecializationConstants", &ShaderOptionGroupLayout::m_useSpecializationConstants)
                     ;
             }
 
@@ -486,6 +496,16 @@ namespace AZ
             return m_hash;
         }
 
+        bool ShaderOptionGroupLayout::IsFullySpecialized() const
+        {
+            return m_isFullySpecialized;
+        }
+
+        bool ShaderOptionGroupLayout::UseSpecializationConstants() const
+        {
+            return m_useSpecializationConstants;
+        }
+
         void ShaderOptionGroupLayout::Clear()
         {
             m_options.clear();
@@ -506,6 +526,20 @@ namespace AZ
                 hash = TypeHash64(option.GetHash(), hash);
             }
             m_hash = hash;
+            m_isFullySpecialized = !AZStd::any_of(
+                m_options.begin(),
+                m_options.end(),
+                [](const ShaderOptionDescriptor& elem)
+                {
+                    return elem.GetSpecializationId() < 0;
+                });
+            m_useSpecializationConstants = AZStd::any_of(
+                m_options.begin(),
+                m_options.end(),
+                [](const ShaderOptionDescriptor& elem)
+                {
+                    return elem.GetSpecializationId() >= 0;
+                });
         }
 
         bool ShaderOptionGroupLayout::ValidateIsFinalized() const

+ 212 - 5
Gems/Atom/RPI/Code/Tests/Shader/ShaderTests.cpp

@@ -133,6 +133,16 @@ namespace UnitTest
         : public RPITestFixture
     {
     protected:
+        enum class SpecializationType
+        {
+            None = 0,
+            Partial,
+            Full,
+            Count
+        };
+
+        static const uint32_t SpecializationTypeCount = static_cast<uint32_t>(SpecializationType::Count);
+
         void SetUp() override
         {
             using namespace AZ;
@@ -219,10 +229,35 @@ namespace UnitTest
                                                           Name("Off") };
             bitOffset = m_bindings[3].GetBitOffset() + m_bindings[3].GetBitCount();
 
+            AZStd::vector<RPI::ShaderOptionValuePair> idList4;
+            idList4.push_back({ Name("True"), RPI::ShaderOptionValue(0) }); // 1+ bit
+            idList4.push_back({ Name("False"), RPI::ShaderOptionValue(1) }); // ...
+
+            for (uint32_t i = 0; i < m_bindingsFullSpecialization.size(); ++i)
+            {
+                m_bindingsFullSpecialization[i] = RPI::ShaderOptionDescriptor{
+                    Name{ AZStd::to_string(i) }, RPI::ShaderOptionType::Boolean, i, i, idList4, Name("True"), 0,
+                                                 aznumeric_caster(i) };
+            }
+
+            for (uint32_t i = 0; i < m_bindingsPartialSpecialization.size(); ++i)
+            {
+                m_bindingsPartialSpecialization[i] = RPI::ShaderOptionDescriptor{ Name{ AZStd::to_string(i) },
+                                                                                  RPI::ShaderOptionType::Boolean,
+                                                                                  i,
+                                                                                  i,
+                                                                                  idList4,
+                                                                                  Name("True"),
+                                                                                  0,
+                                                                                  aznumeric_caster(i % 2) ? aznumeric_caster(i) : -1 };
+            }
+
             m_name = Name("TestName");
             m_drawListName = Name("DrawListTagName");
             m_pipelineLayoutDescriptor = TestPipelineLayoutDescriptor::Create();
             m_shaderOptionGroupLayoutForAsset = CreateShaderOptionLayout();
+            m_shaderOptionGroupLayoutForAssetPartialSpecialization = CreateShaderOptionLayout({}, SpecializationType::Partial);
+            m_shaderOptionGroupLayoutForAssetFullSpecialization = CreateShaderOptionLayout({}, SpecializationType::Full);
             m_shaderOptionGroupLayoutForVariants = m_shaderOptionGroupLayoutForAsset;
 
             // Just set up a couple values, not the whole struct, for some basic checking later that the struct is copied.
@@ -253,27 +288,46 @@ namespace UnitTest
             for (size_t i = 0; i < m_bindings.size(); ++i)
             {
                 m_bindings[i] = {};
+                m_bindingsFullSpecialization[i] = {};
+                m_bindingsPartialSpecialization[i] = {};
             }
 
             m_srgLayouts.clear();
             m_pipelineLayoutDescriptor = nullptr;
             m_shaderOptionGroupLayoutForAsset = nullptr;
+            m_shaderOptionGroupLayoutForAssetPartialSpecialization = nullptr;
+            m_shaderOptionGroupLayoutForAssetFullSpecialization = nullptr;
             m_shaderOptionGroupLayoutForVariants = nullptr;
 
             RPITestFixture::TearDown();
         }
 
-        AZ::RPI::Ptr<AZ::RPI::ShaderOptionGroupLayout> CreateShaderOptionLayout(AZ::RHI::Handle<size_t> indexToOmit = {})
+        const AZ::RPI::ShaderOptionDescriptor& GetShaderOptionDescriptor(SpecializationType specializationType, uint32_t index)
+        {
+            switch (specializationType)
+            {
+            case SpecializationType::Partial:
+                return m_bindingsPartialSpecialization[index];
+            case SpecializationType::Full:
+                return m_bindingsFullSpecialization[index];
+            case SpecializationType::None:
+            default:
+                return m_bindings[index];            
+            }
+        }
+
+        AZ::RPI::Ptr<AZ::RPI::ShaderOptionGroupLayout> CreateShaderOptionLayout(
+            AZ::RHI::Handle<size_t> indexToOmit = {}, SpecializationType specializationType = SpecializationType::None)
         {
             using namespace AZ;
 
             RPI::Ptr<RPI::ShaderOptionGroupLayout> layout = RPI::ShaderOptionGroupLayout::Create();
-            for (size_t i = 0; i < m_bindings.size(); ++i)
+            for (uint32_t i = 0; i < m_bindings.size(); ++i)
             {
                 // Allows omitting a single option to test for missing options.
                 if (indexToOmit.GetIndex() != i)
                 {
-                    layout->AddShaderOption(m_bindings[i]);
+                    layout->AddShaderOption(GetShaderOptionDescriptor(specializationType, i));
                 }
             }
             layout->Finalize();
@@ -409,15 +463,31 @@ namespace UnitTest
             return shaderVariantAsset;
         }
 
+        AZ::RPI::Ptr<AZ::RPI::ShaderOptionGroupLayout> GetShaderOptionGroupForAssets(SpecializationType specializationType)
+        {
+            switch (specializationType)
+            {
+            case SpecializationType::None:
+                return m_shaderOptionGroupLayoutForAsset;
+            case SpecializationType::Partial:
+                return m_shaderOptionGroupLayoutForAssetPartialSpecialization;
+            case SpecializationType::Full:
+                return m_shaderOptionGroupLayoutForAssetFullSpecialization;
+            default:
+                return nullptr;
+            }
+        }
+
         void BeginCreatingTestShaderAsset(AZ::RPI::ShaderAssetCreator& creator,
-            const AZStd::vector<RHI::ShaderStage>& stagesToActivate = {RHI::ShaderStage::Vertex, RHI::ShaderStage::Fragment} )
+            const AZStd::vector<RHI::ShaderStage>& stagesToActivate = {RHI::ShaderStage::Vertex, RHI::ShaderStage::Fragment},
+            SpecializationType specializationType = SpecializationType::None)
         {
             using namespace AZ;
 
             creator.Begin(Uuid::CreateRandom());
             creator.SetName(m_name);
             creator.SetDrawListName(m_drawListName);
-            creator.SetShaderOptionGroupLayout(m_shaderOptionGroupLayoutForAsset);
+            creator.SetShaderOptionGroupLayout(GetShaderOptionGroupForAssets(specializationType));
 
             creator.BeginAPI(RHI::Factory::Get().GetType());
 
@@ -430,6 +500,8 @@ namespace UnitTest
             creator.SetInputContract(CreateSimpleShaderInputContract());
             creator.SetOutputContract(CreateSimpleShaderOutputContract());
 
+            creator.SetUseSpecializationConstants(specializationType != SpecializationType::None);
+
             RHI::ShaderStageAttributeMapList attributeMaps;
             attributeMaps.resize(RHI::ShaderStageCount);
             creator.SetShaderStageAttributeMapList(attributeMaps);
@@ -569,11 +641,16 @@ namespace UnitTest
         }
 
         AZStd::array<AZ::RPI::ShaderOptionDescriptor, 4> m_bindings;
+        AZStd::array<AZ::RPI::ShaderOptionDescriptor, 4> m_bindingsFullSpecialization;
+        AZStd::array<AZ::RPI::ShaderOptionDescriptor, 4> m_bindingsPartialSpecialization;
+
 
         AZ::Name m_name;
         AZ::Name m_drawListName;
         AZ::RHI::Ptr<AZ::RHI::PipelineLayoutDescriptor> m_pipelineLayoutDescriptor;
         AZ::RPI::Ptr<AZ::RPI::ShaderOptionGroupLayout> m_shaderOptionGroupLayoutForAsset;
+        AZ::RPI::Ptr<AZ::RPI::ShaderOptionGroupLayout> m_shaderOptionGroupLayoutForAssetPartialSpecialization;
+        AZ::RPI::Ptr<AZ::RPI::ShaderOptionGroupLayout> m_shaderOptionGroupLayoutForAssetFullSpecialization;
         AZ::RPI::Ptr<AZ::RPI::ShaderOptionGroupLayout> m_shaderOptionGroupLayoutForVariants;
 
         AZ::RHI::RenderStates m_renderStates;
@@ -876,6 +953,62 @@ namespace UnitTest
         EXPECT_FALSE(shaderOptionGroupLayout->FindShaderOptionIndex(Name{ "Invalid" }).IsValid());
     }
 
+    TEST_F(ShaderTests, ShaderOptionGroupLayoutSpecializationTest)
+    {
+        using namespace AZ;
+        AZStd::vector<RPI::ShaderOptionValuePair> idList4;
+        idList4.push_back({ Name("True"), RPI::ShaderOptionValue(0) });
+        idList4.push_back({ Name("False"), RPI::ShaderOptionValue(1) });
+
+        {
+            RPI::Ptr<RPI::ShaderOptionGroupLayout> shaderOptionGroupLayout = RPI::ShaderOptionGroupLayout::Create();
+            bool success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized1" }, RPI::ShaderOptionType::Boolean, 0, 0, idList4, Name("False"), 0, 0 });
+            EXPECT_TRUE(success);
+            success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized2" }, RPI::ShaderOptionType::Boolean, 1, 1, idList4, Name("False"), 0, 1 });
+            EXPECT_TRUE(success);
+            success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized3" }, RPI::ShaderOptionType::Boolean, 2, 2, idList4, Name("False"), 0, 2 });
+            EXPECT_TRUE(success);
+            shaderOptionGroupLayout->Finalize();
+            EXPECT_TRUE(shaderOptionGroupLayout->IsFullySpecialized());
+            EXPECT_TRUE(shaderOptionGroupLayout->UseSpecializationConstants());
+        }
+
+        {
+            RPI::Ptr<RPI::ShaderOptionGroupLayout> shaderOptionGroupLayout = RPI::ShaderOptionGroupLayout::Create();
+            bool success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized1" }, RPI::ShaderOptionType::Boolean, 0, 0, idList4, Name("False"), 0, 0 });
+            EXPECT_TRUE(success);
+            success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized2" }, RPI::ShaderOptionType::Boolean, 1, 1, idList4, Name("False"), 0, -1 });
+            EXPECT_TRUE(success);
+            success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized3" }, RPI::ShaderOptionType::Boolean, 2, 2, idList4, Name("False"), 0, 1 });
+            EXPECT_TRUE(success);
+            shaderOptionGroupLayout->Finalize();
+            EXPECT_FALSE(shaderOptionGroupLayout->IsFullySpecialized());
+            EXPECT_TRUE(shaderOptionGroupLayout->UseSpecializationConstants());
+        }
+
+        {
+            RPI::Ptr<RPI::ShaderOptionGroupLayout> shaderOptionGroupLayout = RPI::ShaderOptionGroupLayout::Create();
+            bool success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized1" }, RPI::ShaderOptionType::Boolean, 0, 0, idList4, Name("False"), 0, -1 });
+            EXPECT_TRUE(success);
+            success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized2" }, RPI::ShaderOptionType::Boolean, 1, 1, idList4, Name("False"), 0, -1 });
+            EXPECT_TRUE(success);
+            success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized3" }, RPI::ShaderOptionType::Boolean, 2, 2, idList4, Name("False"), 0, -1 });
+            EXPECT_TRUE(success);
+            shaderOptionGroupLayout->Finalize();
+            EXPECT_FALSE(shaderOptionGroupLayout->IsFullySpecialized());
+            EXPECT_FALSE(shaderOptionGroupLayout->UseSpecializationConstants());
+        }
+    }
+
     TEST_F(ShaderTests, ImplicitDefaultValue)
     {
         // Add shader option with no default value.
@@ -1826,6 +1959,41 @@ namespace UnitTest
         EXPECT_EQ(resultG.GetStableId().GetIndex(), stableId5);
     }
 
+    TEST_F(ShaderTests, ShaderAsset_SpecializationConstants)
+    {
+        {
+            AZ::RPI::ShaderAssetCreator creator;
+            BeginCreatingTestShaderAsset(creator, { RHI::ShaderStage::Compute }, SpecializationType::None);
+            AZ::Data::Asset<AZ::RPI::ShaderAsset> shaderAsset = EndCreatingTestShaderAsset(creator);
+            EXPECT_FALSE(shaderAsset->UseSpecializationConstants());
+            EXPECT_FALSE(shaderAsset->IsFullySpecialized());
+        }
+
+        {
+            AZ::RPI::ShaderAssetCreator creator;
+            BeginCreatingTestShaderAsset(creator, { RHI::ShaderStage::Compute }, SpecializationType::Partial);
+            AZ::Data::Asset<AZ::RPI::ShaderAsset> shaderAsset = EndCreatingTestShaderAsset(creator);
+            EXPECT_TRUE(shaderAsset->UseSpecializationConstants());
+            EXPECT_FALSE(shaderAsset->IsFullySpecialized());
+        }
+
+        {
+            AZ::RPI::ShaderAssetCreator creator;
+            BeginCreatingTestShaderAsset(creator, { RHI::ShaderStage::Compute }, SpecializationType::Full);
+            AZ::Data::Asset<AZ::RPI::ShaderAsset> shaderAsset = EndCreatingTestShaderAsset(creator);
+            EXPECT_TRUE(shaderAsset->UseSpecializationConstants());
+            EXPECT_TRUE(shaderAsset->IsFullySpecialized());
+        }
+
+        m_shaderOptionGroupLayoutForAssetFullSpecialization = m_shaderOptionGroupLayoutForAsset;
+        {
+            AZ::RPI::ShaderAssetCreator creator;
+            BeginCreatingTestShaderAsset(creator, { RHI::ShaderStage::Compute }, SpecializationType::Full);
+            AZ::Data::Asset<AZ::RPI::ShaderAsset> shaderAsset = EndCreatingTestShaderAsset(creator);
+            EXPECT_FALSE(shaderAsset->UseSpecializationConstants());
+            EXPECT_FALSE(shaderAsset->IsFullySpecialized());
+        }
+    }
 
     TEST_F(ShaderTests, ShaderVariantAsset_IsFullyBaked)
     {
@@ -1853,5 +2021,44 @@ namespace UnitTest
         EXPECT_FALSE(shaderVariantAsset->IsFullyBaked());
         EXPECT_FALSE(ShaderOptionGroup(m_shaderOptionGroupLayoutForAsset, shaderVariantAsset->GetShaderVariantId()).IsFullySpecified());
     }
+
+    TEST_F(ShaderTests, ShaderVariantAsset_IsFullySpecialized)
+    {
+        using namespace AZ;
+        using namespace AZ::RPI;
+
+         {
+            AZ::RPI::ShaderAssetCreator creator;
+            BeginCreatingTestShaderAsset(creator, { RHI::ShaderStage::Compute }, SpecializationType::None);
+            AZ::Data::Asset<AZ::RPI::ShaderAsset> shaderAsset = EndCreatingTestShaderAsset(creator);
+            Data::Instance<RPI::Shader> shader = RPI::Shader::FindOrCreate(shaderAsset);
+            const RPI::ShaderVariant& rootShaderVariant = shader->GetVariant(RPI::ShaderVariantStableId{ 0 });
+            EXPECT_TRUE(rootShaderVariant.UseKeyFallback());
+            EXPECT_FALSE(rootShaderVariant.IsFullySpecialized());
+            EXPECT_FALSE(rootShaderVariant.UseSpecializationConstants());
+         }
+
+         {
+             AZ::RPI::ShaderAssetCreator creator;
+             BeginCreatingTestShaderAsset(creator, { RHI::ShaderStage::Compute }, SpecializationType::Partial);
+             AZ::Data::Asset<AZ::RPI::ShaderAsset> shaderAsset = EndCreatingTestShaderAsset(creator);
+             Data::Instance<RPI::Shader> shader = RPI::Shader::FindOrCreate(shaderAsset);
+             const RPI::ShaderVariant& rootShaderVariant = shader->GetVariant(RPI::ShaderVariantStableId{ 0 });
+             EXPECT_TRUE(rootShaderVariant.UseKeyFallback());
+             EXPECT_FALSE(rootShaderVariant.IsFullySpecialized());
+             EXPECT_TRUE(rootShaderVariant.UseSpecializationConstants());
+         }
+
+         {
+             AZ::RPI::ShaderAssetCreator creator;
+             BeginCreatingTestShaderAsset(creator, { RHI::ShaderStage::Compute }, SpecializationType::Full);
+             AZ::Data::Asset<AZ::RPI::ShaderAsset> shaderAsset = EndCreatingTestShaderAsset(creator);
+             Data::Instance<RPI::Shader> shader = RPI::Shader::FindOrCreate(shaderAsset);
+             const RPI::ShaderVariant& rootShaderVariant = shader->GetVariant(RPI::ShaderVariantStableId{ 0 });
+             EXPECT_FALSE(rootShaderVariant.UseKeyFallback());
+             EXPECT_TRUE(rootShaderVariant.IsFullySpecialized());
+             EXPECT_TRUE(rootShaderVariant.UseSpecializationConstants());
+         }
+    }
 }
 

+ 2 - 2
Gems/AtomLyIntegration/EditorModeFeedback/Code/Source/Draw/EditorStateMeshDrawPacket.cpp

@@ -190,7 +190,7 @@ namespace AZ::Render
             const RPI::ShaderVariant& variant = shader->GetVariant(finalVariantId);
 
             RHI::PipelineStateDescriptorForDraw pipelineStateDescriptor;
-            variant.ConfigurePipelineState(pipelineStateDescriptor);
+            variant.ConfigurePipelineState(pipelineStateDescriptor, shaderOptions);
 
             // Render states need to merge the runtime variation.
             // This allows materials to customize the render states that the shader uses.
@@ -221,7 +221,7 @@ namespace AZ::Render
                 // If the DrawSrg exists we must create and bind it, otherwise the CommandList will fail validation for SRG being null
                 drawSrg = RPI::ShaderResourceGroup::Create(shader->GetAsset(), shader->GetSupervariantIndex(), drawSrgLayout->GetName());
 
-                if (!variant.IsFullyBaked() && drawSrgLayout->HasShaderVariantKeyFallbackEntry())
+                if (variant.UseKeyFallback() && drawSrgLayout->HasShaderVariantKeyFallbackEntry())
                 {
                     drawSrg->SetShaderVariantKeyFallbackValue(shaderOptions.GetShaderVariantKeyFallbackValue());
                 }

+ 0 - 14
Gems/AtomTressFX/Assets/Passes/HairParentShortCutPass.pass

@@ -193,13 +193,6 @@
                                 "Attachment": "Depth"
                             }
                         },
-                        {
-                            "LocalSlot": "InverseAlphaRTOutput",
-                            "AttachmentRef": {
-                                "Pass": "This",
-                                "Attachment": "InverseAlphaRTOutput"
-                            }
-                        },
                         {
                             "LocalSlot": "HairDepthsTextureArray",
                             "AttachmentRef": {
@@ -237,13 +230,6 @@
                     "TemplateName": "HairShortCutGeometryShadingPassTemplate",
                     "Enabled": true,
                     "Connections": [
-                        {
-                            "LocalSlot": "HairColorRenderTarget",
-                            "AttachmentRef": {
-                                "Pass": "This",
-                                "Attachment": "HairColorRenderTarget"
-                            }
-                        },
                         {
                             // The final render target - this is MSAA mode RT - would it be cheaper to
                             // use non-MSAA and then copy?

+ 10 - 1
Gems/AtomTressFX/Assets/Passes/HairShortCutGeometryDepthAlpha.pass

@@ -50,7 +50,7 @@
                 {
                     // This buffer is used as the render target and should be at non-MSAA screen resolution
                     // to make sure no overwork is done.
-                    "Name": "InverseAlphaRTOutput",
+                    "Name": "InverseAlphaRTImage",
                     "SizeSource": {
                         "Source": {
                             "Pass": "Parent",
@@ -85,6 +85,15 @@
                         ]
                     }
                 }
+            ],
+             "Connections": [
+                {
+                    "LocalSlot": "InverseAlphaRTOutput",
+                    "AttachmentRef": {
+                        "Pass": "This",
+                        "Attachment": "InverseAlphaRTImage"
+                    }
+                }
             ],
             "PassData": {
                 "$type": "RasterPassData",

+ 8 - 1
Gems/AtomTressFX/Assets/Passes/HairShortCutGeometryShading.pass

@@ -113,7 +113,7 @@
                 {
                     // The shader hair color render target - important to have at a non-MSAA mode
                     // so that no overwork is done on sampling.
-                    "Name": "HairColorRenderTarget",
+                    "Name": "HairColorImage",
                     "SizeSource": {
                         "Source": {
                             "Pass": "Parent",
@@ -144,6 +144,13 @@
                         "Pass": "This",
                         "Attachment": "BRDFTexture"
                     }
+                },
+                {
+                    "LocalSlot": "HairColorRenderTarget",
+                    "AttachmentRef": {
+                        "Pass": "This",
+                        "Attachment": "HairColorImage"
+                    }
                 }
             ],
             "PassData": {

+ 37 - 21
Gems/AtomTressFX/Code/Passes/HairGeometryRasterPass.cpp

@@ -90,6 +90,40 @@ namespace AZ
                 return (RPI::RasterPass::IsEnabled() && m_initialized) ? true : false;
             }
 
+            bool HairGeometryRasterPass::UpdateShaderOptions(const RPI::ShaderVariantId& variantId)
+            {
+                m_currentShaderVariantId = variantId;
+                const RPI::ShaderVariant& shaderVariant = m_shader->GetVariant(m_currentShaderVariantId);
+                RHI::PipelineStateDescriptorForDraw pipelineStateDescriptor;
+                shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, m_currentShaderVariantId);
+
+                RPI::Scene* scene = GetScene();
+                if (!scene)
+                {
+                    AZ_Error("Hair Gem", false, "Scene could not be acquired");
+                    return false;
+                }
+                RHI::DrawListTag drawListTag = m_shader->GetDrawListTag();
+                scene->ConfigurePipelineState(drawListTag, pipelineStateDescriptor);
+
+                pipelineStateDescriptor.m_renderAttachmentConfiguration = GetRenderAttachmentConfiguration();
+                pipelineStateDescriptor.m_inputStreamLayout.SetTopology(AZ::RHI::PrimitiveTopology::TriangleList);
+                pipelineStateDescriptor.m_inputStreamLayout.Finalize();
+
+                m_pipelineState = m_shader->AcquirePipelineState(pipelineStateDescriptor);
+                if (!m_pipelineState)
+                {
+                    AZ_Error("Hair Gem", false, "Pipeline state could not be acquired");
+                    return false;
+                }
+
+                if (m_shaderResourceGroup->HasShaderVariantKeyFallbackEntry() && shaderVariant.UseKeyFallback())
+                {
+                    m_shaderResourceGroup->SetShaderVariantKeyFallbackValue(m_currentShaderVariantId.m_key);
+                }
+                return true;
+            }
+
             bool HairGeometryRasterPass::LoadShaderAndPipelineState()
             {
                 RPI::ShaderReloadNotificationBus::Handler::BusDisconnect();
@@ -135,27 +169,9 @@ namespace AZ
                     }
                 }
 
-                const RPI::ShaderVariant& shaderVariant = m_shader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
-                RHI::PipelineStateDescriptorForDraw pipelineStateDescriptor;
-                shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
-
-                RPI::Scene* scene = GetScene();
-                if (!scene)
-                {
-                    AZ_Error("Hair Gem", false, "Scene could not be acquired" );
-                    return false;
-                }
-                RHI::DrawListTag drawListTag = m_shader->GetDrawListTag();
-                scene->ConfigurePipelineState(drawListTag, pipelineStateDescriptor);
-
-                pipelineStateDescriptor.m_renderAttachmentConfiguration = GetRenderAttachmentConfiguration();
-                pipelineStateDescriptor.m_inputStreamLayout.SetTopology(AZ::RHI::PrimitiveTopology::TriangleList);
-                pipelineStateDescriptor.m_inputStreamLayout.Finalize();
-
-                m_pipelineState = m_shader->AcquirePipelineState(pipelineStateDescriptor);
-                if (!m_pipelineState)
+                if (!UpdateShaderOptions(m_shader->GetDefaultShaderOptions().GetShaderVariantId()))
                 {
-                    AZ_Error("Hair Gem", false, "Pipeline state could not be acquired");
+                    AZ_Error("Hair Gem", false, "Failed to create pipeline state");
                     return false;
                 }
 
@@ -177,6 +193,7 @@ namespace AZ
             void HairGeometryRasterPass::SchedulePacketBuild(HairRenderObject* hairObject)
             {
                 m_newRenderObjects.insert(hairObject);
+                BuildDrawPacket(hairObject);
             }
 
             bool HairGeometryRasterPass::BuildDrawPacket(HairRenderObject* hairObject)
@@ -249,7 +266,6 @@ namespace AZ
                 for (HairRenderObject* newObject : m_newRenderObjects)
                 {
                     newObject->BindPerObjectSrgForRaster();
-                    BuildDrawPacket(newObject);
                 }
 
                 // Clear the new added objects - BuildDrawPacket should only be carried out once per

+ 4 - 0
Gems/AtomTressFX/Code/Passes/HairGeometryRasterPass.h

@@ -81,6 +81,9 @@ namespace AZ
                 // Scope producer functions...
                 void CompileResources(const RHI::FrameGraphCompileContext& context) override;
 
+                //! Updates the shader variant being used by the pass
+                bool UpdateShaderOptions(const RPI::ShaderVariantId& variantId);
+
             protected:
                 HairFeatureProcessor* m_featureProcessor = nullptr;
 
@@ -104,6 +107,7 @@ namespace AZ
                 AZStd::unordered_set<HairRenderObject*> m_newRenderObjects;
 
                 bool m_initialized = false;
+                RPI::ShaderVariantId m_currentShaderVariantId;
             };
 
         } // namespace Hair

+ 8 - 4
Gems/AtomTressFX/Code/Passes/HairPPLLRasterPass.cpp

@@ -51,20 +51,24 @@ namespace AZ
             //! Once supported, this will be done via data driven code and the method can be removed.
             void HairPPLLRasterPass::BuildInternal()
             {
-                RasterPass::BuildInternal();    // change this to call parent if the method exists
+                HairGeometryRasterPass::BuildInternal(); // change this to call parent if the method exists
 
                 if (!AcquireFeatureProcessor())
                 {
                     return;
                 }
 
+                // Output
+                AttachBufferToSlot(Name{ "PerPixelLinkedList" }, m_featureProcessor->GetPerPixelListBuffer());
+            }
+
+            void HairPPLLRasterPass::InitializeInternal()
+            {
                 if (!LoadShaderAndPipelineState())
                 {
                     return;
                 }
-
-                // Output
-                AttachBufferToSlot(Name{ "PerPixelLinkedList" }, m_featureProcessor->GetPerPixelListBuffer());
+                HairGeometryRasterPass::InitializeInternal();
             }
 
         } // namespace Hair

+ 1 - 0
Gems/AtomTressFX/Code/Passes/HairPPLLRasterPass.h

@@ -47,6 +47,7 @@ namespace AZ
 
                 // Pass behavior overrides
                 void BuildInternal() override;
+                void InitializeInternal() override;                
             };
 
         } // namespace Hair

+ 5 - 7
Gems/AtomTressFX/Code/Passes/HairPPLLResolvePass.cpp

@@ -60,8 +60,12 @@ namespace AZ
                 shaderOption.SetValue(o_enableMarschner_TT, AZ::RPI::ShaderOptionValue{ m_hairGlobalSettings.m_enableMarschner_TT });
                 shaderOption.SetValue(o_enableLongtitudeCoeff, AZ::RPI::ShaderOptionValue{ m_hairGlobalSettings.m_enableLongtitudeCoeff });
                 shaderOption.SetValue(o_enableAzimuthCoeff, AZ::RPI::ShaderOptionValue{ m_hairGlobalSettings.m_enableAzimuthCoeff });
+                shaderOption.SetUnspecifiedToDefaultValues();
 
-                m_shaderOptions = shaderOption.GetShaderVariantKeyFallbackValue();
+                if (m_pipelineStateForDraw.GetShaderVariantId() != shaderOption.GetShaderVariantId())
+                {
+                    UpdateShaderOptions(shaderOption.GetShaderVariantId());
+                }
             }
 
             RPI::Ptr<HairPPLLResolvePass> HairPPLLResolvePass::Create(const RPI::PassDescriptor& descriptor)
@@ -117,12 +121,6 @@ namespace AZ
 
                 UpdateGlobalShaderOptions();
 
-                if (m_shaderResourceGroup->HasShaderVariantKeyFallbackEntry())
-                {
-                    m_shaderResourceGroup->SetShaderVariantKeyFallbackValue(m_shaderOptions);
-                }
-
-
                 SrgBufferDescriptor descriptor = SrgBufferDescriptor(
                     RPI::CommonBufferPoolType::ReadWrite, RHI::Format::Unknown,
                     PPLL_NODE_SIZE, RESERVED_PIXELS_FOR_OIT,

+ 0 - 1
Gems/AtomTressFX/Code/Passes/HairPPLLResolvePass.h

@@ -76,7 +76,6 @@ namespace AZ
 
                 HairGlobalSettings m_hairGlobalSettings;
                 HairFeatureProcessor* m_featureProcessor = nullptr;
-                AZ::RPI::ShaderVariantKey m_shaderOptions;
             };
 
         } // namespace Hair

+ 5 - 1
Gems/AtomTressFX/Code/Passes/HairShortCutGeometryDepthAlphaPass.cpp

@@ -35,14 +35,18 @@ namespace AZ
 
             void HairShortCutGeometryDepthAlphaPass::BuildInternal()
             {
-                RasterPass::BuildInternal();    // change this to call parent if the method exists
+                HairGeometryRasterPass::BuildInternal(); // change this to call parent if the method exists
 
                 if (!AcquireFeatureProcessor())
                 {
                     return;
                 }
+            }
 
+            void HairShortCutGeometryDepthAlphaPass::InitializeInternal()
+            {
                 LoadShaderAndPipelineState();
+                HairGeometryRasterPass::InitializeInternal();
             }
 
         } // namespace Hair

+ 1 - 0
Gems/AtomTressFX/Code/Passes/HairShortCutGeometryDepthAlphaPass.h

@@ -42,6 +42,7 @@ namespace AZ
 
                 // Pass behavior overrides
                 void BuildInternal() override;
+                void InitializeInternal() override;
             };
 
         } // namespace Hair

Some files were not shown because too many files changed in this diff