Forráskód Böngészése

Add support for specialization constants for shader options (#18019)

- Add support of Specialization Constants for dx12 and vulkan
- Add configure function to specialize a pipeline descriptor
- Refactor ShaderVariantAsset job to batch variants
- Add signing of DX12 shader bytecode at runtime
- Add support for function constants on Metal

Signed-off-by: Akio Gaule <[email protected]>
Akio Gaule 1 éve
szülő
commit
fdabdc28e1
100 módosított fájl, 1947 hozzáadás és 283 törlés
  1. 0 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Android/shader_build_options.settings
  2. 0 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Linux/shader_build_options.settings
  3. 0 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Mac/shader_build_options.settings
  4. 0 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Windows/dx12/shader_build_options.settings
  5. 0 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Windows/shader_build_options.settings
  6. 0 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Windows/vulkan/shader_build_options.settings
  7. 0 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/iOS/shader_build_options.settings
  8. 1 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/README.md
  9. 0 0
      Gems/Atom/Asset/Shader/Assets/Config/Shader/shader_build_options.settings
  10. 18 2
      Gems/Atom/Asset/Shader/Code/Source/Editor/AzslCompiler.cpp
  11. 1 1
      Gems/Atom/Asset/Shader/Code/Source/Editor/AzslCompiler.h
  12. 3 3
      Gems/Atom/Asset/Shader/Code/Source/Editor/AzslShaderBuilderSystemComponent.cpp
  13. 18 2
      Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderAssetBuilder.cpp
  14. 13 7
      Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderBuildArgumentsManager.h
  15. 4 2
      Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderBuilderUtility.cpp
  16. 2 1
      Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderBuilderUtility.h
  17. 124 45
      Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderVariantAssetBuilder.cpp
  18. 4 0
      Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderVariantAssetBuilder.h
  19. 1 1
      Gems/Atom/Asset/Shader/Registry/atom_shaders.setreg
  20. 3 3
      Gems/Atom/Feature/Common/Code/Source/CoreLights/DepthExponentiationPass.cpp
  21. 9 8
      Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.cpp
  22. 2 2
      Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.h
  23. 2 2
      Gems/Atom/Feature/Common/Code/Source/MorphTargets/MorphTargetDispatchItem.cpp
  24. 5 6
      Gems/Atom/Feature/Common/Code/Source/PostProcessing/BlendColorGradingLutsPass.cpp
  25. 2 2
      Gems/Atom/Feature/Common/Code/Source/PostProcessing/PostProcessingShaderOptionBase.cpp
  26. 2 7
      Gems/Atom/Feature/Common/Code/Source/PostProcessing/SMAABasePass.cpp
  27. 0 1
      Gems/Atom/Feature/Common/Code/Source/PostProcessing/SMAABasePass.h
  28. 1 1
      Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingPass.cpp
  29. 14 27
      Gems/Atom/Feature/Common/Code/Source/ScreenSpace/DeferredFogPass.cpp
  30. 0 4
      Gems/Atom/Feature/Common/Code/Source/ScreenSpace/DeferredFogPass.h
  31. 2 2
      Gems/Atom/Feature/Common/Code/Source/SkinnedMesh/SkinnedMeshDispatchItem.cpp
  32. 11 2
      Gems/Atom/Feature/Common/Code/Source/SkyAtmosphere/SkyAtmospherePass.cpp
  33. 3 1
      Gems/Atom/RHI/Code/Include/Atom/RHI.Edit/ShaderPlatformInterface.h
  34. 10 4
      Gems/Atom/RHI/Code/Include/Atom/RHI/PipelineStateDescriptor.h
  35. 45 0
      Gems/Atom/RHI/Code/Include/Atom/RHI/SpecializationConstant.h
  36. 23 16
      Gems/Atom/RHI/Code/Source/RHI/PipelineStateDescriptor.cpp
  37. 31 0
      Gems/Atom/RHI/Code/Source/RHI/SpecializationConstant.cpp
  38. 2 0
      Gems/Atom/RHI/Code/atom_rhi_public_files.cmake
  39. 12 0
      Gems/Atom/RHI/DX12/Code/CMakeLists.txt
  40. 13 0
      Gems/Atom/RHI/DX12/Code/Include/Atom/RHI.Reflect/DX12/ShaderStageFunction.h
  41. 73 4
      Gems/Atom/RHI/DX12/Code/Source/RHI.Builders/ShaderPlatformInterface.cpp
  42. 5 2
      Gems/Atom/RHI/DX12/Code/Source/RHI.Builders/ShaderPlatformInterface.h
  43. 18 2
      Gems/Atom/RHI/DX12/Code/Source/RHI.Reflect/ShaderStageFunction.cpp
  44. 2 0
      Gems/Atom/RHI/DX12/Code/Source/RHI/DX12.h
  45. 14 6
      Gems/Atom/RHI/DX12/Code/Source/RHI/PipelineState.cpp
  46. 7 2
      Gems/Atom/RHI/DX12/Code/Source/RHI/RayTracingPipelineState.cpp
  47. 237 0
      Gems/Atom/RHI/DX12/Code/Source/RHI/ShaderUtils.cpp
  48. 36 0
      Gems/Atom/RHI/DX12/Code/Source/RHI/ShaderUtils.h
  49. 2 0
      Gems/Atom/RHI/DX12/Code/atom_rhi_dx12_private_common_files.cmake
  50. 12 0
      Gems/Atom/RHI/DX12/Code/openssl_md5_files.cmake
  51. 38 0
      Gems/Atom/RHI/DX12/External/md5/README.md
  52. 289 0
      Gems/Atom/RHI/DX12/External/md5/openssl/md5.c
  53. 54 0
      Gems/Atom/RHI/DX12/External/md5/openssl/md5.h
  54. 2 1
      Gems/Atom/RHI/Metal/Code/Source/RHI.Builders/ShaderPlatformInterface.cpp
  55. 2 1
      Gems/Atom/RHI/Metal/Code/Source/RHI.Builders/ShaderPlatformInterface.h
  56. 41 7
      Gems/Atom/RHI/Metal/Code/Source/RHI/PipelineState.cpp
  57. 3 2
      Gems/Atom/RHI/Metal/Code/Source/RHI/PipelineState.h
  58. 2 1
      Gems/Atom/RHI/Null/Code/Source/RHI.Builders/ShaderPlatformInterface.cpp
  59. 2 1
      Gems/Atom/RHI/Null/Code/Source/RHI.Builders/ShaderPlatformInterface.h
  60. 2 1
      Gems/Atom/RHI/Vulkan/Code/Source/RHI.Builders/ShaderPlatformInterface.cpp
  61. 2 1
      Gems/Atom/RHI/Vulkan/Code/Source/RHI.Builders/ShaderPlatformInterface.h
  62. 3 1
      Gems/Atom/RHI/Vulkan/Code/Source/RHI/Pipeline.cpp
  63. 3 0
      Gems/Atom/RHI/Vulkan/Code/Source/RHI/Pipeline.h
  64. 11 3
      Gems/Atom/RHI/Vulkan/Code/Source/RHI/RayTracingPipelineState.cpp
  65. 66 0
      Gems/Atom/RHI/Vulkan/Code/Source/RHI/SpecializationConstantData.cpp
  66. 37 0
      Gems/Atom/RHI/Vulkan/Code/Source/RHI/SpecializationConstantData.h
  67. 2 0
      Gems/Atom/RHI/Vulkan/Code/atom_rhi_vulkan_private_common_files.cmake
  68. 5 1
      Gems/Atom/RPI/Code/Include/Atom/RPI.Edit/Shader/ShaderVariantAssetCreator.h
  69. 3 0
      Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Pass/ComputePass.h
  70. 4 0
      Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Pass/FullscreenTrianglePass.h
  71. 4 0
      Gems/Atom/RPI/Code/Include/Atom/RPI.Public/PipelineState.h
  72. 22 2
      Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Shader/ShaderVariant.h
  73. 18 0
      Gems/Atom/RPI/Code/Include/Atom/RPI.Reflect/Shader/ShaderAsset.h
  74. 3 0
      Gems/Atom/RPI/Code/Include/Atom/RPI.Reflect/Shader/ShaderAssetCreator.h
  75. 18 1
      Gems/Atom/RPI/Code/Include/Atom/RPI.Reflect/Shader/ShaderOptionGroupLayout.h
  76. 2 2
      Gems/Atom/RPI/Code/Include/Atom/RPI.Reflect/Shader/ShaderVariantAsset.h
  77. 5 1
      Gems/Atom/RPI/Code/Source/RPI.Edit/Shader/ShaderVariantAssetCreator.cpp
  78. 1 1
      Gems/Atom/RPI/Code/Source/RPI.Public/MeshDrawPacket.cpp
  79. 19 1
      Gems/Atom/RPI/Code/Source/RPI.Public/Pass/ComputePass.cpp
  80. 17 3
      Gems/Atom/RPI/Code/Source/RPI.Public/Pass/FullscreenTrianglePass.cpp
  81. 1 1
      Gems/Atom/RPI/Code/Source/RPI.Public/Pass/Specific/ImageAttachmentPreviewPass.cpp
  82. 27 16
      Gems/Atom/RPI/Code/Source/RPI.Public/PipelineState.cpp
  83. 3 2
      Gems/Atom/RPI/Code/Source/RPI.Public/Shader/Shader.cpp
  84. 78 2
      Gems/Atom/RPI/Code/Source/RPI.Public/Shader/ShaderVariant.cpp
  85. 24 3
      Gems/Atom/RPI/Code/Source/RPI.Reflect/Shader/ShaderAsset.cpp
  86. 17 0
      Gems/Atom/RPI/Code/Source/RPI.Reflect/Shader/ShaderAssetCreator.cpp
  87. 37 3
      Gems/Atom/RPI/Code/Source/RPI.Reflect/Shader/ShaderOptionGroupLayout.cpp
  88. 212 5
      Gems/Atom/RPI/Code/Tests/Shader/ShaderTests.cpp
  89. 2 2
      Gems/AtomLyIntegration/EditorModeFeedback/Code/Source/Draw/EditorStateMeshDrawPacket.cpp
  90. 0 14
      Gems/AtomTressFX/Assets/Passes/HairParentShortCutPass.pass
  91. 10 1
      Gems/AtomTressFX/Assets/Passes/HairShortCutGeometryDepthAlpha.pass
  92. 8 1
      Gems/AtomTressFX/Assets/Passes/HairShortCutGeometryShading.pass
  93. 37 21
      Gems/AtomTressFX/Code/Passes/HairGeometryRasterPass.cpp
  94. 4 0
      Gems/AtomTressFX/Code/Passes/HairGeometryRasterPass.h
  95. 8 4
      Gems/AtomTressFX/Code/Passes/HairPPLLRasterPass.cpp
  96. 1 0
      Gems/AtomTressFX/Code/Passes/HairPPLLRasterPass.h
  97. 5 7
      Gems/AtomTressFX/Code/Passes/HairPPLLResolvePass.cpp
  98. 0 1
      Gems/AtomTressFX/Code/Passes/HairPPLLResolvePass.h
  99. 5 1
      Gems/AtomTressFX/Code/Passes/HairShortCutGeometryDepthAlphaPass.cpp
  100. 1 0
      Gems/AtomTressFX/Code/Passes/HairShortCutGeometryDepthAlphaPass.h

+ 0 - 0
Gems/Atom/Asset/Shader/Config/Platform/Android/shader_build_options.json → Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Android/shader_build_options.settings


+ 0 - 0
Gems/Atom/Asset/Shader/Config/Platform/Linux/shader_build_options.json → Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Linux/shader_build_options.settings


+ 0 - 0
Gems/Atom/Asset/Shader/Config/Platform/Mac/shader_build_options.json → Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Mac/shader_build_options.settings


+ 0 - 0
Gems/Atom/Asset/Shader/Config/Platform/Windows/dx12/shader_build_options.json → Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Windows/dx12/shader_build_options.settings


+ 0 - 0
Gems/Atom/Asset/Shader/Config/Platform/Windows/shader_build_options.json → Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Windows/shader_build_options.settings


+ 0 - 0
Gems/Atom/Asset/Shader/Config/Platform/Windows/vulkan/shader_build_options.json → Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/Windows/vulkan/shader_build_options.settings


+ 0 - 0
Gems/Atom/Asset/Shader/Config/Platform/iOS/shader_build_options.json → Gems/Atom/Asset/Shader/Assets/Config/Shader/Platform/iOS/shader_build_options.settings


+ 1 - 0
Gems/Atom/Asset/Shader/Assets/Config/Shader/README.md

@@ -0,0 +1 @@
+The files in these folders are config files for building shaders, not real assets that will be used at runtime. They are in the "Assets" folder so we can create source dependencies with the azsl shader files to be able to detect changes in order to rebuild the shaders.

+ 0 - 0
Gems/Atom/Asset/Shader/Config/shader_build_options.json → Gems/Atom/Asset/Shader/Assets/Config/Shader/shader_build_options.settings


+ 18 - 2
Gems/Atom/Asset/Shader/Code/Source/Editor/AzslCompiler.cpp

@@ -900,7 +900,10 @@ namespace AZ
             return CompileToFileAndPrepareJsonDocument(output, "--options", "options.json") == BuildResult::Success;
             return CompileToFileAndPrepareJsonDocument(output, "--options", "options.json") == BuildResult::Success;
         }
         }
 
 
-        bool AzslCompiler::ParseOptionsPopulateOptionGroupLayout(const rapidjson::Document& input, RPI::Ptr<RPI::ShaderOptionGroupLayout>& shaderOptionGroupLayout) const
+        bool AzslCompiler::ParseOptionsPopulateOptionGroupLayout(
+            const rapidjson::Document& input,
+            RPI::Ptr<RPI::ShaderOptionGroupLayout>& shaderOptionGroupLayout,
+            bool& outUseSpecializationConstants) const
         {
         {
             auto totalBitOffset = (uint32_t) 0u;
             auto totalBitOffset = (uint32_t) 0u;
 
 
@@ -933,6 +936,12 @@ namespace AZ
                     return false;
                     return false;
                 };
                 };
 
 
+            outUseSpecializationConstants = false;
+            if (input.HasMember("specializationConstants"))
+            {
+                outUseSpecializationConstants = input["specializationConstants"].GetBool();
+            }
+
             const rapidjson::Value& shaderOptions = input["ShaderOptions"];
             const rapidjson::Value& shaderOptions = input["ShaderOptions"];
             AZ_Assert(shaderOptions.IsArray(), "Attribute ShaderOptions must be an array");
             AZ_Assert(shaderOptions.IsArray(), "Attribute ShaderOptions must be an array");
 
 
@@ -1037,13 +1046,20 @@ namespace AZ
                         cost = optionEntry["costImpact"].GetUint();
                         cost = optionEntry["costImpact"].GetUint();
                     }
                     }
 
 
+                    int specializationId = -1;
+                    if (optionEntry.HasMember("specializationId"))
+                    {
+                        specializationId = optionEntry["specializationId"].GetInt();
+                    }
+
                     RPI::ShaderOptionDescriptor shaderOption(Name(optionName), 
                     RPI::ShaderOptionDescriptor shaderOption(Name(optionName), 
                                                              optionType,
                                                              optionType,
                                                              keyOffset,
                                                              keyOffset,
                                                              order,
                                                              order,
                                                              idIndexList,
                                                              idIndexList,
                                                              defaultValueId,
                                                              defaultValueId,
-                                                             cost);
+                                                             cost,
+                                                             specializationId);
 
 
                     if (!shaderOptionGroupLayout->AddShaderOption(shaderOption))
                     if (!shaderOptionGroupLayout->AddShaderOption(shaderOption))
                     {
                     {

+ 1 - 1
Gems/Atom/Asset/Shader/Code/Source/Editor/AzslCompiler.h

@@ -61,7 +61,7 @@ namespace AZ
             //! make sense of a --srg json document and fill up the srg data container
             //! make sense of a --srg json document and fill up the srg data container
             bool ParseSrgPopulateSrgData(const rapidjson::Document& input, SrgDataContainer& outSrgData) const;
             bool ParseSrgPopulateSrgData(const rapidjson::Document& input, SrgDataContainer& outSrgData) const;
             //! make sense of a --option json document and fill up the shader option group layout
             //! make sense of a --option json document and fill up the shader option group layout
-            bool ParseOptionsPopulateOptionGroupLayout(const rapidjson::Document& input, RPI::Ptr<RPI::ShaderOptionGroupLayout>& shaderOptionGroupLayout) const;
+            bool ParseOptionsPopulateOptionGroupLayout(const rapidjson::Document& input, RPI::Ptr<RPI::ShaderOptionGroupLayout>& shaderOptionGroupLayout, bool& outSpecializationConstants) const;
             //! make sense of a --bindingdep json documment and fill up the binding dependencies object
             //! make sense of a --bindingdep json documment and fill up the binding dependencies object
             bool ParseBindingdepPopulateBindingDependencies(const rapidjson::Document& input, BindingDependencies& bindingDependencies) const;
             bool ParseBindingdepPopulateBindingDependencies(const rapidjson::Document& input, BindingDependencies& bindingDependencies) const;
             //! make sense of a --srg json document and fill up the root constant data
             //! make sense of a --srg json document and fill up the root constant data

+ 3 - 3
Gems/Atom/Asset/Shader/Code/Source/Editor/AzslShaderBuilderSystemComponent.cpp

@@ -83,7 +83,7 @@ namespace AZ
             // Register Shader Asset Builder
             // Register Shader Asset Builder
             AssetBuilderSDK::AssetBuilderDesc shaderAssetBuilderDescriptor;
             AssetBuilderSDK::AssetBuilderDesc shaderAssetBuilderDescriptor;
             shaderAssetBuilderDescriptor.m_name = "Shader Asset Builder";
             shaderAssetBuilderDescriptor.m_name = "Shader Asset Builder";
-            shaderAssetBuilderDescriptor.m_version = 123; // Metal shader debug symbols controlled via settings registry or shader build arguments
+            shaderAssetBuilderDescriptor.m_version = 124; // Add specialization constants for shader options
             shaderAssetBuilderDescriptor.m_patterns.push_back(AssetBuilderSDK::AssetBuilderPattern( AZStd::string::format("*.%s", RPI::ShaderSourceData::Extension), AssetBuilderSDK::AssetBuilderPattern::PatternType::Wildcard));
             shaderAssetBuilderDescriptor.m_patterns.push_back(AssetBuilderSDK::AssetBuilderPattern( AZStd::string::format("*.%s", RPI::ShaderSourceData::Extension), AssetBuilderSDK::AssetBuilderPattern::PatternType::Wildcard));
             shaderAssetBuilderDescriptor.m_busId = azrtti_typeid<ShaderAssetBuilder>();
             shaderAssetBuilderDescriptor.m_busId = azrtti_typeid<ShaderAssetBuilder>();
             shaderAssetBuilderDescriptor.m_createJobFunction = AZStd::bind(&ShaderAssetBuilder::CreateJobs, &m_shaderAssetBuilder, AZStd::placeholders::_1, AZStd::placeholders::_2);
             shaderAssetBuilderDescriptor.m_createJobFunction = AZStd::bind(&ShaderAssetBuilder::CreateJobs, &m_shaderAssetBuilder, AZStd::placeholders::_1, AZStd::placeholders::_2);
@@ -108,7 +108,7 @@ namespace AZ
                 shaderVariantAssetBuilderDescriptor.m_name = "Shader Variant Asset Builder";
                 shaderVariantAssetBuilderDescriptor.m_name = "Shader Variant Asset Builder";
                 // Both "Shader Variant Asset Builder" and "Shader Asset Builder" produce ShaderVariantAsset products. If you update
                 // Both "Shader Variant Asset Builder" and "Shader Asset Builder" produce ShaderVariantAsset products. If you update
                 // ShaderVariantAsset you will need to update BOTH version numbers, not just "Shader Variant Asset Builder".
                 // ShaderVariantAsset you will need to update BOTH version numbers, not just "Shader Variant Asset Builder".
-                shaderVariantAssetBuilderDescriptor.m_version = 40; // Metal shader debug symbols controlled via settings registry or shader build arguments
+                shaderVariantAssetBuilderDescriptor.m_version = 41; // Add specialization constants for shader options
                 shaderVariantAssetBuilderDescriptor.m_patterns.push_back(AssetBuilderSDK::AssetBuilderPattern(AZStd::string::format("*.%s", HashedVariantListSourceData::Extension), AssetBuilderSDK::AssetBuilderPattern::PatternType::Wildcard));
                 shaderVariantAssetBuilderDescriptor.m_patterns.push_back(AssetBuilderSDK::AssetBuilderPattern(AZStd::string::format("*.%s", HashedVariantListSourceData::Extension), AssetBuilderSDK::AssetBuilderPattern::PatternType::Wildcard));
                 shaderVariantAssetBuilderDescriptor.m_patterns.push_back(AssetBuilderSDK::AssetBuilderPattern(AZStd::string::format("*.%s", HashedVariantInfoSourceData::Extension), AssetBuilderSDK::AssetBuilderPattern::PatternType::Wildcard));
                 shaderVariantAssetBuilderDescriptor.m_patterns.push_back(AssetBuilderSDK::AssetBuilderPattern(AZStd::string::format("*.%s", HashedVariantInfoSourceData::Extension), AssetBuilderSDK::AssetBuilderPattern::PatternType::Wildcard));
                 shaderVariantAssetBuilderDescriptor.m_busId = azrtti_typeid<ShaderVariantAssetBuilder>();
                 shaderVariantAssetBuilderDescriptor.m_busId = azrtti_typeid<ShaderVariantAssetBuilder>();
@@ -121,7 +121,7 @@ namespace AZ
                 // Register Shader Variant List Builder
                 // Register Shader Variant List Builder
                 AssetBuilderSDK::AssetBuilderDesc shaderVariantListBuilderDescriptor;
                 AssetBuilderSDK::AssetBuilderDesc shaderVariantListBuilderDescriptor;
                 shaderVariantListBuilderDescriptor.m_name = "Shader Variant List Builder";
                 shaderVariantListBuilderDescriptor.m_name = "Shader Variant List Builder";
-                shaderVariantListBuilderDescriptor.m_version = 1; // First version of ShaderVariantListBuilder
+                shaderVariantListBuilderDescriptor.m_version = 2; // Add specialization constants for shader options
                 shaderVariantListBuilderDescriptor.m_patterns.push_back(AssetBuilderSDK::AssetBuilderPattern(AZStd::string::format("*.%s", RPI::ShaderVariantListSourceData::Extension), AssetBuilderSDK::AssetBuilderPattern::PatternType::Wildcard));
                 shaderVariantListBuilderDescriptor.m_patterns.push_back(AssetBuilderSDK::AssetBuilderPattern(AZStd::string::format("*.%s", RPI::ShaderVariantListSourceData::Extension), AssetBuilderSDK::AssetBuilderPattern::PatternType::Wildcard));
                 shaderVariantListBuilderDescriptor.m_busId = azrtti_typeid<ShaderVariantListBuilder>();
                 shaderVariantListBuilderDescriptor.m_busId = azrtti_typeid<ShaderVariantListBuilder>();
                 shaderVariantListBuilderDescriptor.m_createJobFunction = AZStd::bind(&ShaderVariantListBuilder::CreateJobs, &m_shaderVariantListBuilder, AZStd::placeholders::_1, AZStd::placeholders::_2);
                 shaderVariantListBuilderDescriptor.m_createJobFunction = AZStd::bind(&ShaderVariantListBuilder::CreateJobs, &m_shaderVariantListBuilder, AZStd::placeholders::_1, AZStd::placeholders::_2);

+ 18 - 2
Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderAssetBuilder.cpp

@@ -202,6 +202,15 @@ namespace AZ
                 response.m_sourceFileDependencyList.emplace_back(AZStd::move(includeFileDependency));
                 response.m_sourceFileDependencyList.emplace_back(AZStd::move(includeFileDependency));
             }
             }
 
 
+            // Add the shader_build_option files as source dependencies
+            AZStd::unordered_map<AZStd::string, AZ::IO::FixedMaxPath> configFiles = ShaderBuildArgumentsManager::DiscoverConfigurationFiles();
+            for (const auto& pair : configFiles)
+            {
+                AssetBuilderSDK::SourceFileDependency includeFileDependency;
+                includeFileDependency.m_sourceFileDependencyPath = pair.second.c_str();
+                response.m_sourceFileDependencyList.emplace_back(AZStd::move(includeFileDependency));
+            }
+
             for (const AssetBuilderSDK::PlatformInfo& platformInfo : request.m_enabledPlatforms)
             for (const AssetBuilderSDK::PlatformInfo& platformInfo : request.m_enabledPlatforms)
             {
             {
                 AZ_TraceContext("For platform", platformInfo.m_identifier.data());
                 AZ_TraceContext("For platform", platformInfo.m_identifier.data());
@@ -497,9 +506,14 @@ namespace AZ
                     RPI::Ptr<RPI::ShaderOptionGroupLayout> shaderOptionGroupLayout = RPI::ShaderOptionGroupLayout::Create();
                     RPI::Ptr<RPI::ShaderOptionGroupLayout> shaderOptionGroupLayout = RPI::ShaderOptionGroupLayout::Create();
                     BindingDependencies bindingDependencies;
                     BindingDependencies bindingDependencies;
                     RootConstantData rootConstantData;
                     RootConstantData rootConstantData;
+                    bool usesSpecializationConstants = false;
                     AssetBuilderSDK::ProcessJobResultCode azslJsonReadResult = ShaderBuilderUtility::PopulateAzslDataFromJsonFiles(
                     AssetBuilderSDK::ProcessJobResultCode azslJsonReadResult = ShaderBuilderUtility::PopulateAzslDataFromJsonFiles(
                         ShaderAssetBuilderName, subProductsPaths, azslData, srgLayoutList,
                         ShaderAssetBuilderName, subProductsPaths, azslData, srgLayoutList,
-                        shaderOptionGroupLayout, bindingDependencies, rootConstantData, request.m_tempDirPath);
+                        shaderOptionGroupLayout,
+                        bindingDependencies,
+                        rootConstantData,
+                        request.m_tempDirPath,
+                        usesSpecializationConstants);
                     if (azslJsonReadResult != AssetBuilderSDK::ProcessJobResult_Success)
                     if (azslJsonReadResult != AssetBuilderSDK::ProcessJobResult_Success)
                     {
                     {
                         response.m_resultCode = azslJsonReadResult;
                         response.m_resultCode = azslJsonReadResult;
@@ -507,6 +521,7 @@ namespace AZ
                     }
                     }
 
 
                     shaderAssetCreator.SetSrgLayoutList(srgLayoutList);
                     shaderAssetCreator.SetSrgLayoutList(srgLayoutList);
+                    shaderAssetCreator.SetUseSpecializationConstants(usesSpecializationConstants);
 
 
                     if (!finalShaderOptionGroupLayout)
                     if (!finalShaderOptionGroupLayout)
                     {
                     {
@@ -665,7 +680,8 @@ namespace AZ
                         variantAssetId,
                         variantAssetId,
                         superVariantAzslinStemName,
                         superVariantAzslinStemName,
                         hlslFullPath,
                         hlslFullPath,
-                        hlslSourceCode};
+                        hlslSourceCode,
+                        usesSpecializationConstants };
 
 
                     // Preserve the Temp folder when shaders are compiled with debug symbols
                     // Preserve the Temp folder when shaders are compiled with debug symbols
                     // or because the ShaderSourceData has m_keepTempFolder set to true.
                     // or because the ShaderSourceData has m_keepTempFolder set to true.

+ 13 - 7
Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderBuildArgumentsManager.h

@@ -65,8 +65,8 @@ namespace AZ
             // The value of this registry key is customizable by the user.
             // The value of this registry key is customizable by the user.
             static constexpr char ConfigPathRegistryKey[] = "/O3DE/Atom/Shaders/Build/ConfigPath";
             static constexpr char ConfigPathRegistryKey[] = "/O3DE/Atom/Shaders/Build/ConfigPath";
 
 
-            static constexpr char DefaultConfigPathDirectory[] = "@gemroot:AtomShader@/Config";
-            static constexpr char ShaderBuildOptionsJson[] = "shader_build_options.json";
+            static constexpr char DefaultConfigPathDirectory[] = "@gemroot:AtomShader@/Assets/Config/Shader";
+            static constexpr char ShaderBuildOptionsJson[] = "shader_build_options.settings";
             static constexpr char PlatformsDir[] = "Platform";
             static constexpr char PlatformsDir[] = "Platform";
 
 
             //! Always loads all the factory arguments provided by the Atom Gem. In addition
             //! Always loads all the factory arguments provided by the Atom Gem. In addition
@@ -105,17 +105,25 @@ namespace AZ
             //! @remark: The "" (global) arguments are never popped, regardless of how many times this function is called.
             //! @remark: The "" (global) arguments are never popped, regardless of how many times this function is called.
             void PopArgumentScope();
             void PopArgumentScope();
 
 
+            //! Finds the shader build config files from the default locations. Returns a map where the key is the name of the scope,
+            //! and the value is a fully qualified file path.
+            //! Remarks: Posible scope names are:
+            //!     "global"
+            //!     "<platform>". Example "Android", "Windows", etc
+            //!     "<platform>.<rhi>". Example "Windows.dx12" or "Windows.vulkan".
+            static AZStd::unordered_map<AZStd::string, AZ::IO::FixedMaxPath> DiscoverConfigurationFiles();
+
         private:
         private:
             friend class ::UnitTest::ShaderBuildArgumentsTests;
             friend class ::UnitTest::ShaderBuildArgumentsTests;
             void Init(AZStd::unordered_map<AZStd::string, AZ::RHI::ShaderBuildArguments> && removeBuildArgumentsMap
             void Init(AZStd::unordered_map<AZStd::string, AZ::RHI::ShaderBuildArguments> && removeBuildArgumentsMap
                     , AZStd::unordered_map<AZStd::string, AZ::RHI::ShaderBuildArguments> && addBuildArgumentsMap);
                     , AZStd::unordered_map<AZStd::string, AZ::RHI::ShaderBuildArguments> && addBuildArgumentsMap);
 
 
             //! @returns A fully qualified path where the factory settings, as provided by Atom, are found.
             //! @returns A fully qualified path where the factory settings, as provided by Atom, are found.
-            AZ::IO::FixedMaxPath GetDefaultConfigDirectoryPath();
+            static AZ::IO::FixedMaxPath GetDefaultConfigDirectoryPath();
 
 
             //! @returns A fully qualified path where the user customized command line arguments are found.
             //! @returns A fully qualified path where the user customized command line arguments are found.
             //!     The returned path will be empty if the user did not customize the path in the registry.
             //!     The returned path will be empty if the user did not customize the path in the registry.
-            AZ::IO::FixedMaxPath GetUserConfigDirectoryPath();
+            static AZ::IO::FixedMaxPath GetUserConfigDirectoryPath();
 
 
             //! @param dirPath Starting directory for the search of  shader_build_options.json files.
             //! @param dirPath Starting directory for the search of  shader_build_options.json files.
             //! @returns A map where the key is the name of the scope, and the value is a fully qualified file path.
             //! @returns A map where the key is the name of the scope, and the value is a fully qualified file path.
@@ -123,9 +131,7 @@ namespace AZ
             //!     "global"
             //!     "global"
             //!     "<platform>". Example "Android", "Windows", etc
             //!     "<platform>". Example "Android", "Windows", etc
             //!     "<platform>.<rhi>". Example "Windows.dx12" or "Windows.vulkan".
             //!     "<platform>.<rhi>". Example "Windows.dx12" or "Windows.vulkan".
-            AZStd::unordered_map<AZStd::string, AZ::IO::FixedMaxPath> DiscoverConfigurationFilesInDirectory(const AZ::IO::FixedMaxPath& dirPath);
-
-            AZStd::unordered_map<AZStd::string, AZ::IO::FixedMaxPath> DiscoverConfigurationFiles();
+            static AZStd::unordered_map<AZStd::string, AZ::IO::FixedMaxPath> DiscoverConfigurationFilesInDirectory(const AZ::IO::FixedMaxPath& dirPath);
 
 
             const AZ::RHI::ShaderBuildArguments& PushArgumentsInternal(const AZStd::string& name, const AZ::RHI::ShaderBuildArguments& arguments);
             const AZ::RHI::ShaderBuildArguments& PushArgumentsInternal(const AZStd::string& name, const AZ::RHI::ShaderBuildArguments& arguments);
 
 

+ 4 - 2
Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderBuilderUtility.cpp

@@ -139,7 +139,8 @@ namespace AZ
                 RPI::Ptr<RPI::ShaderOptionGroupLayout> shaderOptionGroupLayout,
                 RPI::Ptr<RPI::ShaderOptionGroupLayout> shaderOptionGroupLayout,
                 BindingDependencies& bindingDependencies,
                 BindingDependencies& bindingDependencies,
                 RootConstantData& rootConstantData,
                 RootConstantData& rootConstantData,
-                const AZStd::string& tempFolder)
+                const AZStd::string& tempFolder,
+                bool& useSpecializationConstants)
             {
             {
                 AzslCompiler azslc(azslData.m_preprocessedFullPath,  // set the input file for eventual error messages, but the compiler won't be called on it.
                 AzslCompiler azslc(azslData.m_preprocessedFullPath,  // set the input file for eventual error messages, but the compiler won't be called on it.
                                    tempFolder);
                                    tempFolder);
@@ -188,7 +189,8 @@ namespace AZ
 
 
                 // The shader options define what options are available, what are the allowed values/range
                 // The shader options define what options are available, what are the allowed values/range
                 // for each option and what is its default value.
                 // for each option and what is its default value.
-                if (!azslc.ParseOptionsPopulateOptionGroupLayout(outcomes[AzslSubProducts::options].GetValue(), shaderOptionGroupLayout))
+                if (!azslc.ParseOptionsPopulateOptionGroupLayout(
+                        outcomes[AzslSubProducts::options].GetValue(), shaderOptionGroupLayout, useSpecializationConstants))
                 {
                 {
                     AZ_Error(builderName, false, "Failed to find a valid list of shader options!");
                     AZ_Error(builderName, false, "Failed to find a valid list of shader options!");
                     return AssetBuilderSDK::ProcessJobResult_Failed;
                     return AssetBuilderSDK::ProcessJobResult_Failed;

+ 2 - 1
Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderBuilderUtility.h

@@ -71,7 +71,8 @@ namespace AZ
                 RPI::Ptr<RPI::ShaderOptionGroupLayout> shaderOptionGroupLayout,
                 RPI::Ptr<RPI::ShaderOptionGroupLayout> shaderOptionGroupLayout,
                 BindingDependencies& bindingDependencies,
                 BindingDependencies& bindingDependencies,
                 RootConstantData& rootConstantData,
                 RootConstantData& rootConstantData,
-                const AZStd::string& tempFolder);
+                const AZStd::string& tempFolder,
+                bool& useSpecializationConstants);
 
 
 
 
             RHI::ShaderHardwareStage ToAssetBuilderShaderType(RPI::ShaderStageType stageType);
             RHI::ShaderHardwareStage ToAssetBuilderShaderType(RPI::ShaderStageType stageType);

+ 124 - 45
Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderVariantAssetBuilder.cpp

@@ -175,6 +175,8 @@ namespace AZ
                 return;
                 return;
             }
             }
 
 
+            AZStd::string hashedVariantInfoDescriptorString;
+            RPI::JsonUtils::SaveObjectToJsonString(hashedVariantInfoDescriptor, hashedVariantInfoDescriptorString);
             AZStd::string hashedVariantInfoParentPath(request.m_watchFolder.data());
             AZStd::string hashedVariantInfoParentPath(request.m_watchFolder.data());
             AZStd::string hashedVariantListFullPath = GetHashedVariantListPathFromVariantInfoPath(hashedVariantInfoParentPath, hashedVariantInfoRelativePath);
             AZStd::string hashedVariantListFullPath = GetHashedVariantListPathFromVariantInfoPath(hashedVariantInfoParentPath, hashedVariantInfoRelativePath);
             
             
@@ -193,6 +195,9 @@ namespace AZ
             
             
                 jobDescriptor.m_jobKey = GetShaderVariantAssetJobKey();
                 jobDescriptor.m_jobKey = GetShaderVariantAssetJobKey();
                 jobDescriptor.SetPlatformIdentifier(info.m_identifier.data());
                 jobDescriptor.SetPlatformIdentifier(info.m_identifier.data());
+
+                // Add the content of the hashedVariantInfo file as a parameter to avoid reading it again.
+                jobDescriptor.m_jobParameters.emplace(ShaderVariantInfoJobParam, hashedVariantInfoDescriptorString);
             
             
                 // The ShaderVariantAssets should be built AFTER the ShaderVariantTreeAsset.
                 // The ShaderVariantAssets should be built AFTER the ShaderVariantTreeAsset.
                 // With "OrderOnly" dependency, We make sure ShaderVariantTreeAsset completes before ShaderVariantAsset runs,
                 // With "OrderOnly" dependency, We make sure ShaderVariantTreeAsset completes before ShaderVariantAsset runs,
@@ -239,7 +244,8 @@ namespace AZ
             const AssetBuilderSDK::PlatformInfo& platformInfo,
             const AssetBuilderSDK::PlatformInfo& platformInfo,
             const AzslCompiler& azslCompiler,
             const AzslCompiler& azslCompiler,
             const AZStd::string& shaderSourceFileFullPath,
             const AZStd::string& shaderSourceFileFullPath,
-            const RPI::SupervariantIndex supervariantIndex)
+            const RPI::SupervariantIndex supervariantIndex,
+            bool& useSpecializationConstants)
         {
         {
             auto optionsGroupPathOutcome = ShaderBuilderUtility::ObtainBuildArtifactPathFromShaderAssetBuilder(
             auto optionsGroupPathOutcome = ShaderBuilderUtility::ObtainBuildArtifactPathFromShaderAssetBuilder(
                 shaderPlatformInterface->GetAPIUniqueIndex(), platformInfo.m_identifier, shaderSourceFileFullPath, supervariantIndex.GetIndex(),
                 shaderPlatformInterface->GetAPIUniqueIndex(), platformInfo.m_identifier, shaderSourceFileFullPath, supervariantIndex.GetIndex(),
@@ -259,7 +265,8 @@ namespace AZ
                 AZ_Error(ShaderVariantAssetBuilderName, false, "%s", jsonOutcome.GetError().c_str());
                 AZ_Error(ShaderVariantAssetBuilderName, false, "%s", jsonOutcome.GetError().c_str());
                 return nullptr;
                 return nullptr;
             }
             }
-            if (!azslCompiler.ParseOptionsPopulateOptionGroupLayout(jsonOutcome.GetValue(), shaderOptionGroupLayout))
+            if (!azslCompiler.ParseOptionsPopulateOptionGroupLayout(
+                    jsonOutcome.GetValue(), shaderOptionGroupLayout, useSpecializationConstants))
             {
             {
                 AZ_Error(ShaderVariantAssetBuilderName, false, "Failed to find a valid list of shader options!");
                 AZ_Error(ShaderVariantAssetBuilderName, false, "Failed to find a valid list of shader options!");
                 return nullptr;
                 return nullptr;
@@ -473,28 +480,63 @@ namespace AZ
             ShaderBuilderUtility::GetAbsolutePathToAzslFile(shaderSourceFileFullPath, shaderSourceDescriptor.m_source, azslFullPath);
             ShaderBuilderUtility::GetAbsolutePathToAzslFile(shaderSourceFileFullPath, shaderSourceDescriptor.m_source, azslFullPath);
             AzslCompiler azslc(azslFullPath, request.m_tempDirPath);
             AzslCompiler azslc(azslFullPath, request.m_tempDirPath);
 
 
+            auto supervariantList = ShaderBuilderUtility::GetSupervariantListFromShaderSourceData(shaderSourceDescriptor);
+
             AZStd::string previousLoopApiName;
             AZStd::string previousLoopApiName;
+            bool usesVariants = false;
             for (RHI::ShaderPlatformInterface* shaderPlatformInterface : platformInterfaces)
             for (RHI::ShaderPlatformInterface* shaderPlatformInterface : platformInterfaces)
             {
             {
                 auto thisLoopApiName = shaderPlatformInterface->GetAPIName().GetStringView();
                 auto thisLoopApiName = shaderPlatformInterface->GetAPIName().GetStringView();
-                RPI::Ptr<RPI::ShaderOptionGroupLayout> loopLocal_ShaderOptionGroupLayout =
-                    LoadShaderOptionsGroupLayoutFromShaderAssetBuilder(
-                        shaderPlatformInterface, request.m_platformInfo, azslc, shaderSourceFileFullPath, RPI::DefaultSupervariantIndex);
-                if (!loopLocal_ShaderOptionGroupLayout)
-                {
-                    response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
-                    return;
-                }
-                if (shaderOptionGroupLayout && shaderOptionGroupLayout->GetHash() != loopLocal_ShaderOptionGroupLayout->GetHash())
+                for (uint32_t supervariantIndexCounter = 0; supervariantIndexCounter < supervariantList.size(); ++supervariantIndexCounter)
                 {
                 {
-                    AZ_Error(ShaderVariantAssetBuilderName, false, "There was a discrepancy in shader options between %s and %s", previousLoopApiName.c_str(), thisLoopApiName.data());
-                    response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
-                    return;
+                    RPI::SupervariantIndex supervariantIndex(supervariantIndexCounter);
+                    bool usesSpecialization = false;
+                    RPI::Ptr<RPI::ShaderOptionGroupLayout> loopLocal_ShaderOptionGroupLayout =
+                        LoadShaderOptionsGroupLayoutFromShaderAssetBuilder(
+                            shaderPlatformInterface,
+                            request.m_platformInfo,
+                            azslc,
+                            shaderSourceFileFullPath,
+                            supervariantIndex,
+                            usesSpecialization);
+                    if (!loopLocal_ShaderOptionGroupLayout)
+                    {
+                        response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
+                        return;
+                    }
+                    if (shaderOptionGroupLayout && shaderOptionGroupLayout->GetHash() != loopLocal_ShaderOptionGroupLayout->GetHash())
+                    {
+                        AZ_Error(
+                            ShaderVariantAssetBuilderName,
+                            false,
+                            "There was a discrepancy in shader options between %s and %s",
+                            previousLoopApiName.c_str(),
+                            thisLoopApiName.data());
+                        response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
+                        return;
+                    }
+
+                    // Check if there's a supervariant that needs to generate the variants
+                    if (!usesSpecialization || !loopLocal_ShaderOptionGroupLayout->IsFullySpecialized())
+                    {
+                        usesVariants = true;
+                    }
+                    shaderOptionGroupLayout = loopLocal_ShaderOptionGroupLayout;
                 }
                 }
-                shaderOptionGroupLayout = loopLocal_ShaderOptionGroupLayout;
                 previousLoopApiName = thisLoopApiName;
                 previousLoopApiName = thisLoopApiName;
             }
             }
 
 
+            if (!usesVariants)
+            {
+                // No need to create the variant tree since all supervariants are fully specialized. Exit gracefully.
+                AZ_TracePrintf(
+                    ShaderVariantAssetBuilderName,
+                    "No azshadervarianttree is produced on behalf of %s because all valid RHI backends are using specialization constants for shader options.\n",
+                    shaderSourceFileFullPath.c_str());
+                response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Success;
+                return;
+            }
+
             AZStd::vector<RPI::ShaderVariantListSourceData::VariantInfo> variantInfos;
             AZStd::vector<RPI::ShaderVariantListSourceData::VariantInfo> variantInfos;
             variantInfos.reserve(hashedVariantListDescriptor.m_hashedVariants.size());
             variantInfos.reserve(hashedVariantListDescriptor.m_hashedVariants.size());
             for (const auto& hashedVariantInfo : hashedVariantListDescriptor.m_hashedVariants)
             for (const auto& hashedVariantInfo : hashedVariantListDescriptor.m_hashedVariants)
@@ -533,8 +575,7 @@ namespace AZ
 
 
             AZ_TracePrintf(ShaderVariantAssetBuilderName, "Shader Variant Tree Asset [%s] compiled successfully.\n", assetPath.c_str());
             AZ_TracePrintf(ShaderVariantAssetBuilderName, "Shader Variant Tree Asset [%s] compiled successfully.\n", assetPath.c_str());
 
 
-            response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Success;
- 
+            response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Success; 
         }
         }
 
 
 
 
@@ -545,8 +586,18 @@ namespace AZ
             AZStd::string hashedVariantInfoFullPath;
             AZStd::string hashedVariantInfoFullPath;
             AZ::StringFunc::Path::ConstructFull(request.m_watchFolder.data(), request.m_sourceFile.data(), hashedVariantInfoFullPath, true);
             AZ::StringFunc::Path::ConstructFull(request.m_watchFolder.data(), request.m_sourceFile.data(), hashedVariantInfoFullPath, true);
 
 
+
+            AZStd::string hashedVariantInfoDescriptorString;
+            if (!request.m_jobDescription.m_jobParameters.contains(ShaderVariantInfoJobParam))
+            {
+                AZ_Error(ShaderVariantAssetBuilderName, false, "Missing job Parameter: ShaderVariantInfoJobParam");
+                response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
+                return;
+            }
+            hashedVariantInfoDescriptorString = request.m_jobDescription.m_jobParameters.at(ShaderVariantInfoJobParam);
+
             HashedVariantInfoSourceData hashedVariantInfoDescriptor;
             HashedVariantInfoSourceData hashedVariantInfoDescriptor;
-            if (!RPI::JsonUtils::LoadObjectFromFile(hashedVariantInfoFullPath, hashedVariantInfoDescriptor, AZStd::numeric_limits<size_t>::max()))
+            if (!RPI::JsonUtils::LoadObjectFromJsonString(hashedVariantInfoDescriptorString, hashedVariantInfoDescriptor))
             {
             {
                 AZ_Assert(false, "Failed to parse Hashed Variant Info Descriptor JSON [%s]", hashedVariantInfoFullPath.c_str());
                 AZ_Assert(false, "Failed to parse Hashed Variant Info Descriptor JSON [%s]", hashedVariantInfoFullPath.c_str());
                 response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
                 response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
@@ -598,12 +649,17 @@ namespace AZ
                 AZ_TraceContext("Platform API", apiName);
                 AZ_TraceContext("Platform API", apiName);
 
 
                 buildArgsManager.PushArgumentScope(apiName);
                 buildArgsManager.PushArgumentScope(apiName);
-                buildArgsManager.PushArgumentScope(shaderSourceDescriptor.m_removeBuildArguments, shaderSourceDescriptor.m_addBuildArguments, shaderSourceDescriptor.m_definitions);
+                buildArgsManager.PushArgumentScope(
+                    shaderSourceDescriptor.m_removeBuildArguments,
+                    shaderSourceDescriptor.m_addBuildArguments,
+                    shaderSourceDescriptor.m_definitions);
 
 
                 // Loop through all the Supervariants.
                 // Loop through all the Supervariants.
-                uint32_t supervariantIndexCounter = 0;
-                for (const auto& supervariantInfo : supervariantList)
+                for (uint32_t supervariantIndexCounter = 0;
+                    supervariantIndexCounter < supervariantList.size();
+                    ++supervariantIndexCounter)
                 {
                 {
+                    const auto& supervariantInfo = supervariantList[supervariantIndexCounter];
                     RPI::SupervariantIndex supervariantIndex(supervariantIndexCounter);
                     RPI::SupervariantIndex supervariantIndex(supervariantIndexCounter);
 
 
                     // Check if we were canceled before we do any heavy processing of
                     // Check if we were canceled before we do any heavy processing of
@@ -614,7 +670,8 @@ namespace AZ
                         return;
                         return;
                     }
                     }
 
 
-                    buildArgsManager.PushArgumentScope(supervariantInfo.m_removeBuildArguments, supervariantInfo.m_addBuildArguments, supervariantInfo.m_definitions);
+                    buildArgsManager.PushArgumentScope(
+                        supervariantInfo.m_removeBuildArguments, supervariantInfo.m_addBuildArguments, supervariantInfo.m_definitions);
 
 
                     AZStd::string shaderStemNamePrefix = shaderFileName;
                     AZStd::string shaderStemNamePrefix = shaderFileName;
                     if (supervariantIndex.GetIndex() > 0)
                     if (supervariantIndex.GetIndex() > 0)
@@ -628,22 +685,40 @@ namespace AZ
                     // 3- hlsl code.
                     // 3- hlsl code.
 
 
                     // 1- ShaderOptionsGroupLayout
                     // 1- ShaderOptionsGroupLayout
+                    // The ShaderOptionsGroupLayout is the same for all platforms and supervariants, but the each supervariant
+                    // can have the use of specialization constants on or off.
+                    bool usesSpecializationConstants = false;
+                    shaderOptionGroupLayout = LoadShaderOptionsGroupLayoutFromShaderAssetBuilder(
+                        shaderPlatformInterface,
+                        request.m_platformInfo,
+                        azslc,
+                        shaderSourceFileFullPath,
+                        supervariantIndex,
+                        usesSpecializationConstants);
                     if (!shaderOptionGroupLayout)
                     if (!shaderOptionGroupLayout)
                     {
                     {
-                        shaderOptionGroupLayout =
-                            LoadShaderOptionsGroupLayoutFromShaderAssetBuilder(
-                                shaderPlatformInterface, request.m_platformInfo, azslc, shaderSourceFileFullPath, supervariantIndex);
-                        if (!shaderOptionGroupLayout)
-                        {
-                            response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
-                            return;
-                        }
+                        response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
+                        return;
+                    }
+
+                    if (usesSpecializationConstants && shaderOptionGroupLayout->IsFullySpecialized())
+                    {
+                        // No need to create the shader variants since all supervariants are fully specialized.
+                        AZ_TracePrintf(
+                            ShaderVariantAssetBuilderName,
+                            "No azshaderVariant is produced on behalf of %s, super variant %s, because it's using specialization "
+                            "constants "
+                            "for shader options.\n",
+                            shaderSourceFileFullPath.c_str(),
+                            supervariantInfo.m_name.GetCStr());
+                        buildArgsManager.PopArgumentScope();
+                        continue;
                     }
                     }
 
 
                     // 2- entryFunctions.
                     // 2- entryFunctions.
                     AzslFunctions azslFunctions;
                     AzslFunctions azslFunctions;
                     LoadShaderFunctionsFromShaderAssetBuilder(
                     LoadShaderFunctionsFromShaderAssetBuilder(
-                        shaderPlatformInterface, request.m_platformInfo, azslc, shaderSourceFileFullPath, supervariantIndex,  azslFunctions);
+                        shaderPlatformInterface, request.m_platformInfo, azslc, shaderSourceFileFullPath, supervariantIndex, azslFunctions);
                     if (azslFunctions.empty())
                     if (azslFunctions.empty())
                     {
                     {
                         response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
                         response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
@@ -652,7 +727,7 @@ namespace AZ
                     MapOfStringToStageType shaderEntryPoints;
                     MapOfStringToStageType shaderEntryPoints;
                     if (shaderSourceDescriptor.m_programSettings.m_entryPoints.empty())
                     if (shaderSourceDescriptor.m_programSettings.m_entryPoints.empty())
                     {
                     {
-                        AZ_Error(ShaderVariantAssetBuilderName, false,  "ProgramSettings must specify entry points.");
+                        AZ_Error(ShaderVariantAssetBuilderName, false, "ProgramSettings must specify entry points.");
                         response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
                         response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
                         return;
                         return;
                     }
                     }
@@ -671,7 +746,7 @@ namespace AZ
                         response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
                         response.m_resultCode = AssetBuilderSDK::ProcessJobResult_Failed;
                         return;
                         return;
                     }
                     }
-                    
+
                     //! It is important to keep this refcounted pointer outside of the if block to prevent it from being destroyed.
                     //! It is important to keep this refcounted pointer outside of the if block to prevent it from being destroyed.
                     RHI::Ptr<RHI::PipelineLayoutDescriptor> pipelineLayoutDescriptor;
                     RHI::Ptr<RHI::PipelineLayoutDescriptor> pipelineLayoutDescriptor;
                     if (shaderPlatformInterface->VariantCompilationRequiresSrgLayoutData())
                     if (shaderPlatformInterface->VariantCompilationRequiresSrgLayoutData())
@@ -710,16 +785,18 @@ namespace AZ
                     }
                     }
 
 
                     // Setup the shader variant creation context:
                     // Setup the shader variant creation context:
-                    ShaderVariantCreationContext shaderVariantCreationContext =
-                    {
-                        *shaderPlatformInterface, request.m_platformInfo, buildArgsManager.GetCurrentArguments(), request.m_tempDirPath,
-                        shaderSourceDescriptor,
-                        *shaderOptionGroupLayout.get(),
-                        shaderEntryPoints,
-                        Uuid::CreateRandom(),
-                        shaderStemNamePrefix,
-                        hlslSourcePath, hlslCode
-                    };
+                    ShaderVariantCreationContext shaderVariantCreationContext = { *shaderPlatformInterface,
+                                                                                  request.m_platformInfo,
+                                                                                  buildArgsManager.GetCurrentArguments(),
+                                                                                  request.m_tempDirPath,
+                                                                                  shaderSourceDescriptor,
+                                                                                  *shaderOptionGroupLayout.get(),
+                                                                                  shaderEntryPoints,
+                                                                                  Uuid::CreateRandom(),
+                                                                                  shaderStemNamePrefix,
+                                                                                  hlslSourcePath,
+                                                                                  hlslCode,
+                                                                                  usesSpecializationConstants };
 
 
                     // Preserve the Temp folder when shaders are compiled with debug symbols
                     // Preserve the Temp folder when shaders are compiled with debug symbols
                     // or because the ShaderSourceData has m_keepTempFolder set to true.
                     // or because the ShaderSourceData has m_keepTempFolder set to true.
@@ -767,7 +844,6 @@ namespace AZ
                         }
                         }
                     }
                     }
                     buildArgsManager.PopArgumentScope(); // Pop the supervariant build arguments.
                     buildArgsManager.PopArgumentScope(); // Pop the supervariant build arguments.
-                    supervariantIndexCounter++;
                 } // End of supervariant for block
                 } // End of supervariant for block
 
 
                 buildArgsManager.PopArgumentScope(); // Pop the .shader build arguments.
                 buildArgsManager.PopArgumentScope(); // Pop the .shader build arguments.
@@ -923,7 +999,10 @@ namespace AZ
                 RHI::ShaderPlatformInterface::StageDescriptor descriptor;
                 RHI::ShaderPlatformInterface::StageDescriptor descriptor;
                 bool shaderWasCompiled = creationContext.m_shaderPlatformInterface.CompilePlatformInternal(
                 bool shaderWasCompiled = creationContext.m_shaderPlatformInterface.CompilePlatformInternal(
                     creationContext.m_platformInfo, variantShaderSourcePath, shaderEntryName, assetBuilderShaderType,
                     creationContext.m_platformInfo, variantShaderSourcePath, shaderEntryName, assetBuilderShaderType,
-                    creationContext.m_tempDirPath, descriptor, creationContext.m_shaderBuildArguments);
+                    creationContext.m_tempDirPath,
+                    descriptor,
+                    creationContext.m_shaderBuildArguments,
+                    creationContext.m_useSpecializationConstants);
 
 
                 if (!shaderWasCompiled)
                 if (!shaderWasCompiled)
                 {
                 {

+ 4 - 0
Gems/Atom/Asset/Shader/Code/Source/Editor/ShaderVariantAssetBuilder.h

@@ -43,6 +43,7 @@ namespace AZ
             const AZStd::string& m_shaderStemNamePrefix; //<shaderName>-<supervariantName>
             const AZStd::string& m_shaderStemNamePrefix; //<shaderName>-<supervariantName>
             const AZStd::string& m_hlslSourcePath;
             const AZStd::string& m_hlslSourcePath;
             const AZStd::string& m_hlslSourceContent;
             const AZStd::string& m_hlslSourceContent;
+            const bool m_useSpecializationConstants = false;
         };
         };
 
 
 
 
@@ -88,6 +89,9 @@ namespace AZ
             void ShutDown() override { };
             void ShutDown() override { };
 
 
         private:
         private:
+            // Content of the hashedVariantInfo file 
+            static constexpr uint32_t ShaderVariantInfoJobParam = 0;
+
             AZ_DISABLE_COPY_MOVE(ShaderVariantAssetBuilder);
             AZ_DISABLE_COPY_MOVE(ShaderVariantAssetBuilder);
 
 
             static constexpr uint32_t ShaderSourceFilePathJobParam = 1;
             static constexpr uint32_t ShaderSourceFilePathJobParam = 1;

+ 1 - 1
Gems/Atom/Asset/Shader/Registry/atom_shaders.setreg

@@ -4,7 +4,7 @@
             "Shaders": {
             "Shaders": {
                 "BuildVariants": true,
                 "BuildVariants": true,
                 "Build": {
                 "Build": {
-                    "ConfigPath": "@gemroot:AtomShader@/Config"
+                    "ConfigPath": "@gemroot:AtomShader@/Assets/Config/Shader"
                 }
                 }
             }
             }
         }
         }

+ 3 - 3
Gems/Atom/Feature/Common/Code/Source/CoreLights/DepthExponentiationPass.cpp

@@ -48,7 +48,7 @@ namespace AZ
             RPI::ShaderOptionGroup shaderOption = m_shader->CreateShaderOptionGroup();
             RPI::ShaderOptionGroup shaderOption = m_shader->CreateShaderOptionGroup();
             shaderOption.SetValue(m_optionName, m_optionValues[typeIndex]);
             shaderOption.SetValue(m_optionName, m_optionValues[typeIndex]);
 
 
-            if (m_shaderResourceGroup)
+            if (!m_shaderVariant[typeIndex].m_isFullyBaked && m_shaderResourceGroup)
             {
             {
                 m_shaderResourceGroup->SetShaderVariantKeyFallbackValue(shaderOption.GetShaderVariantKeyFallbackValue());
                 m_shaderResourceGroup->SetShaderVariantKeyFallbackValue(shaderOption.GetShaderVariantKeyFallbackValue());
             }
             }
@@ -108,9 +108,9 @@ namespace AZ
                 RPI::ShaderVariant shaderVariant = m_shader->GetVariant(shaderOption.GetShaderVariantId());
                 RPI::ShaderVariant shaderVariant = m_shader->GetVariant(shaderOption.GetShaderVariantId());
 
 
                 RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
                 RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
-                shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
+                shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, shaderOption);
 
 
-                ShaderVariantInfo variationInfo{shaderVariant.IsFullyBaked(),
+                ShaderVariantInfo variationInfo{!shaderVariant.UseKeyFallback(),
                     m_shader->AcquirePipelineState(pipelineStateDescriptor)
                     m_shader->AcquirePipelineState(pipelineStateDescriptor)
                 };
                 };
                 m_shaderVariant.push_back(AZStd::move(variationInfo));
                 m_shaderVariant.push_back(AZStd::move(variationInfo));

+ 9 - 8
Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.cpp

@@ -87,8 +87,8 @@ namespace AZ
 
 
         void LightCullingTilePreparePass::ChooseShaderVariant()
         void LightCullingTilePreparePass::ChooseShaderVariant()
         {
         {
-            const AZ::RPI::ShaderVariant& shaderVariant = CreateShaderVariant();
-            CreatePipelineStateFromShaderVariant(shaderVariant);
+            auto [shaderVariant, shaderOptions] = CreateShaderVariant();
+            CreatePipelineStateFromShaderVariant(shaderVariant, shaderOptions);
         }
         }
 
 
         AZ::Name LightCullingTilePreparePass::GetMultiSampleName()
         AZ::Name LightCullingTilePreparePass::GetMultiSampleName()
@@ -121,30 +121,31 @@ namespace AZ
         AZ::RPI::ShaderOptionGroup LightCullingTilePreparePass::CreateShaderOptionGroup()
         AZ::RPI::ShaderOptionGroup LightCullingTilePreparePass::CreateShaderOptionGroup()
         {
         {
             RPI::ShaderOptionGroup shaderOptionGroup = m_shader->CreateShaderOptionGroup();
             RPI::ShaderOptionGroup shaderOptionGroup = m_shader->CreateShaderOptionGroup();
-            shaderOptionGroup.SetUnspecifiedToDefaultValues();
             shaderOptionGroup.SetValue(m_msaaOptionName, GetMultiSampleName());
             shaderOptionGroup.SetValue(m_msaaOptionName, GetMultiSampleName());
+            shaderOptionGroup.SetUnspecifiedToDefaultValues();
             return shaderOptionGroup;
             return shaderOptionGroup;
         }
         }
 
 
-        void LightCullingTilePreparePass::CreatePipelineStateFromShaderVariant(const RPI::ShaderVariant& shaderVariant)
+        void LightCullingTilePreparePass::CreatePipelineStateFromShaderVariant(
+            const RPI::ShaderVariant& shaderVariant, const RPI::ShaderOptionGroup& shaderOptions)
         {
         {
             AZ::RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
             AZ::RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
-            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
+            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, shaderOptions);
             m_msaaPipelineState = m_shader->AcquirePipelineState(pipelineStateDescriptor);
             m_msaaPipelineState = m_shader->AcquirePipelineState(pipelineStateDescriptor);
             AZ_Error("LightCulling", m_msaaPipelineState, "Failed to acquire pipeline state for shader");
             AZ_Error("LightCulling", m_msaaPipelineState, "Failed to acquire pipeline state for shader");
         }
         }
 
 
-        const AZ::RPI::ShaderVariant& LightCullingTilePreparePass::CreateShaderVariant()
+        AZStd::pair<const AZ::RPI::ShaderVariant&, RPI::ShaderOptionGroup> LightCullingTilePreparePass::CreateShaderVariant()
         {
         {
             RPI::ShaderOptionGroup shaderOptionGroup = CreateShaderOptionGroup();
             RPI::ShaderOptionGroup shaderOptionGroup = CreateShaderOptionGroup();
             const RPI::ShaderVariant& shaderVariant = m_shader->GetVariant(shaderOptionGroup.GetShaderVariantId());
             const RPI::ShaderVariant& shaderVariant = m_shader->GetVariant(shaderOptionGroup.GetShaderVariantId());
 
 
             //Set the fallbackkey
             //Set the fallbackkey
-            if (m_drawSrg)
+            if (shaderVariant.UseKeyFallback() && m_drawSrg)
             {
             {
                 m_drawSrg->SetShaderVariantKeyFallbackValue(shaderOptionGroup.GetShaderVariantKeyFallbackValue());
                 m_drawSrg->SetShaderVariantKeyFallbackValue(shaderOptionGroup.GetShaderVariantKeyFallbackValue());
             }
             }
-            return shaderVariant;
+            return { shaderVariant, shaderOptionGroup };
         }
         }
 
 
         void LightCullingTilePreparePass::SetConstantData()
         void LightCullingTilePreparePass::SetConstantData()

+ 2 - 2
Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.h

@@ -66,8 +66,8 @@ namespace AZ
             AZStd::array<float, 2> ComputeUnprojectConstants() const;
             AZStd::array<float, 2> ComputeUnprojectConstants() const;
             AZ::RHI::Size GetDepthBufferDimensions();
             AZ::RHI::Size GetDepthBufferDimensions();
             void ChooseShaderVariant();
             void ChooseShaderVariant();
-            const AZ::RPI::ShaderVariant& CreateShaderVariant();
-            void CreatePipelineStateFromShaderVariant(const RPI::ShaderVariant& shaderVariant);
+            AZStd::pair<const AZ::RPI::ShaderVariant&, RPI::ShaderOptionGroup> CreateShaderVariant();
+            void CreatePipelineStateFromShaderVariant(const RPI::ShaderVariant& shaderVariant, const RPI::ShaderOptionGroup& options);
             void SetConstantData();
             void SetConstantData();
             void OnShaderReloaded();
             void OnShaderReloaded();
 
 

+ 2 - 2
Gems/Atom/Feature/Common/Code/Source/MorphTargets/MorphTargetDispatchItem.cpp

@@ -66,13 +66,13 @@ namespace AZ
                 return false;
                 return false;
             }
             }
 
 
-            if (!shaderVariant.IsFullyBaked() && m_instanceSrg->HasShaderVariantKeyFallbackEntry())
+            if (shaderVariant.UseKeyFallback() && m_instanceSrg->HasShaderVariantKeyFallbackEntry())
             {
             {
                 m_instanceSrg->SetShaderVariantKeyFallbackValue(shaderOptionGroup.GetShaderVariantKeyFallbackValue());
                 m_instanceSrg->SetShaderVariantKeyFallbackValue(shaderOptionGroup.GetShaderVariantKeyFallbackValue());
             }
             }
 
 
             RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
             RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
-            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
+            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, shaderOptionGroup);
 
 
 
 
             InitRootConstants(pipelineStateDescriptor.m_pipelineLayoutDescriptor->GetRootConstantsLayout());
             InitRootConstants(pipelineStateDescriptor.m_pipelineLayoutDescriptor->GetRootConstantsLayout());

+ 5 - 6
Gems/Atom/Feature/Common/Code/Source/PostProcessing/BlendColorGradingLutsPass.cpp

@@ -65,10 +65,10 @@ namespace AZ
                 RPI::ShaderVariant shaderVariant = m_shader->GetVariant(shaderOption.GetShaderVariantId());
                 RPI::ShaderVariant shaderVariant = m_shader->GetVariant(shaderOption.GetShaderVariantId());
 
 
                 RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
                 RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
-                shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
+                shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, shaderOption);
 
 
                 ShaderVariantInfo variantInfo{
                 ShaderVariantInfo variantInfo{
-                    shaderVariant.IsFullyBaked(),
+                    !shaderVariant.UseKeyFallback(),
                     m_shader->AcquirePipelineState(pipelineStateDescriptor)
                     m_shader->AcquirePipelineState(pipelineStateDescriptor)
                 };
                 };
                 m_shaderVariant.push_back(AZStd::move(variantInfo));
                 m_shaderVariant.push_back(AZStd::move(variantInfo));
@@ -91,11 +91,10 @@ namespace AZ
                 m_currentShaderVariantIndex = m_numSourceLuts;
                 m_currentShaderVariantIndex = m_numSourceLuts;
             }
             }
 
 
-            auto shaderOption = m_shader->CreateShaderOptionGroup();
-            shaderOption.SetValue(m_numSourceLutsShaderVariantOptionName, RPI::ShaderOptionValue{ m_numSourceLuts });
-
             if (!m_shaderVariant[m_currentShaderVariantIndex].m_isFullyBaked)
             if (!m_shaderVariant[m_currentShaderVariantIndex].m_isFullyBaked)
             {
             {
+                auto shaderOption = m_shader->CreateShaderOptionGroup();
+                shaderOption.SetValue(m_numSourceLutsShaderVariantOptionName, RPI::ShaderOptionValue{ m_numSourceLuts });
                 m_currentShaderVariantKeyFallbackValue = shaderOption.GetShaderVariantKeyFallbackValue();
                 m_currentShaderVariantKeyFallbackValue = shaderOption.GetShaderVariantKeyFallbackValue();
             }
             }
             m_needToUpdateShaderVariant = false;
             m_needToUpdateShaderVariant = false;
@@ -196,7 +195,7 @@ namespace AZ
                     m_shaderResourceGroup->SetConstant(m_shaderInputSourceLut4ShaperScaleIndex, m_colorGradingShaperParams[3].m_scale);
                     m_shaderResourceGroup->SetConstant(m_shaderInputSourceLut4ShaperScaleIndex, m_colorGradingShaperParams[3].m_scale);
                 }
                 }
 
 
-                if (m_shaderResourceGroup->HasShaderVariantKeyFallbackEntry())
+                if (!m_shaderVariant[m_currentShaderVariantIndex].m_isFullyBaked && m_shaderResourceGroup->HasShaderVariantKeyFallbackEntry())
                 {
                 {
                     m_shaderResourceGroup->SetShaderVariantKeyFallbackValue(m_currentShaderVariantKeyFallbackValue);
                     m_shaderResourceGroup->SetShaderVariantKeyFallbackValue(m_currentShaderVariantKeyFallbackValue);
                 }
                 }

+ 2 - 2
Gems/Atom/Feature/Common/Code/Source/PostProcessing/PostProcessingShaderOptionBase.cpp

@@ -25,7 +25,7 @@ namespace AZ
 
 
             auto shaderVariant = shader->GetVariant(shaderOption.GetShaderVariantId());
             auto shaderVariant = shader->GetVariant(shaderOption.GetShaderVariantId());
 
 
-            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
+            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, shaderOption);
             pipelineStateDescriptor.m_renderAttachmentConfiguration = renderAttachmentConfiguration;
             pipelineStateDescriptor.m_renderAttachmentConfiguration = renderAttachmentConfiguration;
             pipelineStateDescriptor.m_renderStates.m_multisampleState = multisampleState;
             pipelineStateDescriptor.m_renderStates.m_multisampleState = multisampleState;
 
 
@@ -37,7 +37,7 @@ namespace AZ
             pipelineStateDescriptor.m_inputStreamLayout = inputStreamLayout;
             pipelineStateDescriptor.m_inputStreamLayout = inputStreamLayout;
 
 
             m_shaderVariantTable[variationKey].m_pipelineState = shader->AcquirePipelineState(pipelineStateDescriptor);
             m_shaderVariantTable[variationKey].m_pipelineState = shader->AcquirePipelineState(pipelineStateDescriptor);
-            m_shaderVariantTable[variationKey].m_isFullyBaked = shaderVariant.IsFullyBaked();
+            m_shaderVariantTable[variationKey].m_isFullyBaked = !shaderVariant.UseKeyFallback();
         }
         }
 
 
         void PostProcessingShaderOptionBase::UpdateShaderVariant(const AZ::RPI::ShaderOptionGroup& shaderOption)
         void PostProcessingShaderOptionBase::UpdateShaderVariant(const AZ::RPI::ShaderOptionGroup& shaderOption)

+ 2 - 7
Gems/Atom/Feature/Common/Code/Source/PostProcessing/SMAABasePass.cpp

@@ -70,11 +70,6 @@ namespace AZ
             if (m_needToUpdateSRG)
             if (m_needToUpdateSRG)
             {
             {
                 UpdateSRG();
                 UpdateSRG();
-
-                if (m_shaderResourceGroup->HasShaderVariantKeyFallbackEntry())
-                {
-                    m_shaderResourceGroup->SetShaderVariantKeyFallbackValue(m_currentShaderVariantKeyFallbackValue);
-                }
                 m_needToUpdateSRG = false;
                 m_needToUpdateSRG = false;
             }
             }
 
 
@@ -86,8 +81,8 @@ namespace AZ
             auto shaderOption = m_shader->CreateShaderOptionGroup();
             auto shaderOption = m_shader->CreateShaderOptionGroup();
 
 
             GetCurrentShaderOption(shaderOption);
             GetCurrentShaderOption(shaderOption);
-
-            m_currentShaderVariantKeyFallbackValue = shaderOption.GetShaderVariantKeyFallbackValue();
+            shaderOption.SetUnspecifiedToDefaultValues();
+            UpdateShaderOptions(shaderOption.GetShaderVariantId());
             m_needToUpdateShaderVariant = false;
             m_needToUpdateShaderVariant = false;
             InvalidateSRG();
             InvalidateSRG();
         }
         }

+ 0 - 1
Gems/Atom/Feature/Common/Code/Source/PostProcessing/SMAABasePass.h

@@ -60,7 +60,6 @@ namespace AZ
             // Scope producer functions...
             // Scope producer functions...
             void CompileResources(const RHI::FrameGraphCompileContext& context) override;
             void CompileResources(const RHI::FrameGraphCompileContext& context) override;
 
 
-            AZ::RPI::ShaderVariantKey m_currentShaderVariantKeyFallbackValue;
             bool m_needToUpdateShaderVariant = false;
             bool m_needToUpdateShaderVariant = false;
             bool m_needToUpdateSRG = true;
             bool m_needToUpdateSRG = true;
         };
         };

+ 1 - 1
Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingPass.cpp

@@ -95,7 +95,7 @@ namespace AZ
                 auto shader{ AZ::RPI::Shader::FindOrCreate(shaderAsset, supervariantName) };
                 auto shader{ AZ::RPI::Shader::FindOrCreate(shaderAsset, supervariantName) };
                 auto shaderVariant{ shader->GetVariant(AZ::RPI::ShaderAsset::RootShaderVariantStableId) };
                 auto shaderVariant{ shader->GetVariant(AZ::RPI::ShaderAsset::RootShaderVariantStableId) };
                 AZ::RHI::PipelineStateDescriptorForRayTracing pipelineStateDescriptor;
                 AZ::RHI::PipelineStateDescriptorForRayTracing pipelineStateDescriptor;
-                shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
+                shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, shader->GetDefaultShaderOptions());
                 auto& shaderLib = shaderLibs.emplace_back();
                 auto& shaderLib = shaderLibs.emplace_back();
                 shaderLib.m_shaderAssetId = assetReference.m_assetId;
                 shaderLib.m_shaderAssetId = assetReference.m_assetId;
                 shaderLib.m_shader = shader;
                 shaderLib.m_shader = shader;

+ 14 - 27
Gems/Atom/Feature/Common/Code/Source/ScreenSpace/DeferredFogPass.cpp

@@ -209,35 +209,37 @@ namespace AZ
 
 
         void DeferredFogPass::UpdateShaderOptions()
         void DeferredFogPass::UpdateShaderOptions()
         {
         {
-            RPI::ShaderOptionGroup shaderOption = m_shader->CreateShaderOptionGroup();
+            RPI::ShaderOptionGroup shaderOptions = m_shader->CreateShaderOptionGroup();
             DeferredFogSettings* fogSettings = GetPassFogSettings();
             DeferredFogSettings* fogSettings = GetPassFogSettings();
 
 
             // [TODO][ATOM-13659] - AZ::Name all over our code base should use init with string and
             // [TODO][ATOM-13659] - AZ::Name all over our code base should use init with string and
             // hash key for the iterations themselves.
             // hash key for the iterations themselves.
-            shaderOption.SetValue(AZ::Name("o_enableFogLayer"),
+            shaderOptions.SetValue(
+                AZ::Name("o_enableFogLayer"),
                 r_fogLayerSupport && fogSettings->GetEnableFogLayerShaderOption() ? AZ::Name("true") : AZ::Name("false"));
                 r_fogLayerSupport && fogSettings->GetEnableFogLayerShaderOption() ? AZ::Name("true") : AZ::Name("false"));
-            shaderOption.SetValue(AZ::Name("o_useNoiseTexture"),
+            shaderOptions.SetValue(
+                AZ::Name("o_useNoiseTexture"),
                 r_fogTurbulenceSupport && fogSettings->GetUseNoiseTextureShaderOption() ? AZ::Name("true") : AZ::Name("false"));
                 r_fogTurbulenceSupport && fogSettings->GetUseNoiseTextureShaderOption() ? AZ::Name("true") : AZ::Name("false"));
             switch (fogSettings->GetFogMode())
             switch (fogSettings->GetFogMode())
             {
             {
             case FogMode::Linear:
             case FogMode::Linear:
-                shaderOption.SetValue(m_fogModeOptionName, AZ::Name("FogMode::LinearMode"));
+                shaderOptions.SetValue(m_fogModeOptionName, AZ::Name("FogMode::LinearMode"));
                 break;
                 break;
             case FogMode::Exponential:
             case FogMode::Exponential:
-                shaderOption.SetValue(m_fogModeOptionName, AZ::Name("FogMode::ExponentialMode"));
+                shaderOptions.SetValue(m_fogModeOptionName, AZ::Name("FogMode::ExponentialMode"));
                 break;
                 break;
             case FogMode::ExponentialSquared:
             case FogMode::ExponentialSquared:
-                shaderOption.SetValue(m_fogModeOptionName, AZ::Name("FogMode::ExponentialSquaredMode"));
+                shaderOptions.SetValue(m_fogModeOptionName, AZ::Name("FogMode::ExponentialSquaredMode"));
                 break;
                 break;
             default:
             default:
                 AZ_Error("DeferredFogPass", false, "Invalid fog mode %d", fogSettings->GetFogMode());
                 AZ_Error("DeferredFogPass", false, "Invalid fog mode %d", fogSettings->GetFogMode());
                 break;
                 break;
             }
             }
-
-            // The following method returns the specified options, as well as fall back values for all 
-            // non-specified options.  If all were set you can use the method GetShaderVariantKey that is 
-            // cheaper but will not make sure the populated values has the default fall back for any unset bit.
-            m_ShaderOptions = shaderOption.GetShaderVariantKeyFallbackValue();
+            shaderOptions.SetUnspecifiedToDefaultValues();
+            if (m_pipelineStateForDraw.GetShaderVariantId() != shaderOptions.GetShaderVariantId())
+            {
+                FullscreenTrianglePass::UpdateShaderOptions(shaderOptions.GetShaderVariantId());
+            }
         }
         }
 
 
         void DeferredFogPass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
         void DeferredFogPass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
@@ -248,26 +250,11 @@ namespace AZ
             DeferredFogSettings* fogSettings = GetPassFogSettings();
             DeferredFogSettings* fogSettings = GetPassFogSettings();
 
 
             UpdateEnable(fogSettings);
             UpdateEnable(fogSettings);
-
             // Update and set the per pass shader options - this will update the current required
             // Update and set the per pass shader options - this will update the current required
             // shader variant and if doesn't exist, it will be created via the compile stage
             // shader variant and if doesn't exist, it will be created via the compile stage
-            if (m_shaderResourceGroup->HasShaderVariantKeyFallbackEntry())
-            {
-                UpdateShaderOptions();
-            }
-
+            UpdateShaderOptions();
             SetSrgConstants();
             SetSrgConstants();
         }
         }
-  
-        void DeferredFogPass::CompileResources(const RHI::FrameGraphCompileContext& context)
-        {
-            if (m_shaderResourceGroup->HasShaderVariantKeyFallbackEntry())
-            {
-                m_shaderResourceGroup->SetShaderVariantKeyFallbackValue(m_ShaderOptions);
-            }
-
-            FullscreenTrianglePass::CompileResources(context);
-        }
     }   // namespace Render
     }   // namespace Render
 }   // namespace AZ
 }   // namespace AZ
 
 

+ 0 - 4
Gems/Atom/Feature/Common/Code/Source/ScreenSpace/DeferredFogPass.h

@@ -59,7 +59,6 @@ namespace AZ
 
 
             // Scope producer functions...
             // Scope producer functions...
             void SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph) override;
             void SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph) override;
-            void CompileResources(const RHI::FrameGraphCompileContext& context) override;
 
 
             //! Set the binding indices of all members of the SRG
             //! Set the binding indices of all members of the SRG
             void SetSrgBindIndices();
             void SetSrgBindIndices();
@@ -77,9 +76,6 @@ namespace AZ
             // actively pass them to the shader.
             // actively pass them to the shader.
             DeferredFogSettings m_fallbackSettings;
             DeferredFogSettings m_fallbackSettings;
 
 
-            // Shader options for variant generation (texture and layer activation in this case)
-            AZ::RPI::ShaderVariantKey m_ShaderOptions;
-
             // Fog mode option name
             // Fog mode option name
             const AZ::Name m_fogModeOptionName;
             const AZ::Name m_fogModeOptionName;
         };       
         };       

+ 2 - 2
Gems/Atom/Feature/Common/Code/Source/SkinnedMesh/SkinnedMeshDispatchItem.cpp

@@ -77,7 +77,7 @@ namespace AZ
             const RPI::ShaderVariant& shaderVariant = m_skinningShader->GetVariant(m_shaderOptionGroup.GetShaderVariantId());
             const RPI::ShaderVariant& shaderVariant = m_skinningShader->GetVariant(m_shaderOptionGroup.GetShaderVariantId());
 
 
             RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
             RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
-            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
+            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, m_shaderOptionGroup);
 
 
             auto perInstanceSrgLayout = m_skinningShader->FindShaderResourceGroupLayout(AZ::Name{ "InstanceSrg" });
             auto perInstanceSrgLayout = m_skinningShader->FindShaderResourceGroupLayout(AZ::Name{ "InstanceSrg" });
             if (!perInstanceSrgLayout)
             if (!perInstanceSrgLayout)
@@ -94,7 +94,7 @@ namespace AZ
             }
             }
 
 
             // If the shader variation is not fully baked, set the fallback key to use a runtime branch for the shader options
             // If the shader variation is not fully baked, set the fallback key to use a runtime branch for the shader options
-            if (!shaderVariant.IsFullyBaked() && m_instanceSrg->HasShaderVariantKeyFallbackEntry())
+            if (shaderVariant.UseKeyFallback() && m_instanceSrg->HasShaderVariantKeyFallbackEntry())
             {
             {
                 m_instanceSrg->SetShaderVariantKeyFallbackValue(m_shaderOptionGroup.GetShaderVariantKeyFallbackValue());
                 m_instanceSrg->SetShaderVariantKeyFallbackValue(m_shaderOptionGroup.GetShaderVariantKeyFallbackValue());
             }
             }

+ 11 - 2
Gems/Atom/Feature/Common/Code/Source/SkyAtmosphere/SkyAtmospherePass.cpp

@@ -207,6 +207,7 @@ namespace AZ::Render
 
 
     void SkyAtmospherePass::UpdatePassData()
     void SkyAtmospherePass::UpdatePassData()
     {
     {
+        uint32_t childIndex = 0;
         for (auto passData : m_atmospherePassData)
         for (auto passData : m_atmospherePassData)
         {
         {
             passData.m_srg->SetConstant(passData.m_index, m_constants);
             passData.m_srg->SetConstant(passData.m_index, m_constants);
@@ -217,8 +218,16 @@ namespace AZ::Render
             passData.m_shaderOptionGroup.SetValue(AZ::Name("o_enableFastAerialPerspective"), AZ::RPI::ShaderOptionValue{ m_fastAerialPerspectiveEnabled });
             passData.m_shaderOptionGroup.SetValue(AZ::Name("o_enableFastAerialPerspective"), AZ::RPI::ShaderOptionValue{ m_fastAerialPerspectiveEnabled });
             passData.m_shaderOptionGroup.SetValue(AZ::Name("o_enableAerialPerspective"), AZ::RPI::ShaderOptionValue{ m_aerialPerspectiveEnabled });
             passData.m_shaderOptionGroup.SetValue(AZ::Name("o_enableAerialPerspective"), AZ::RPI::ShaderOptionValue{ m_aerialPerspectiveEnabled });
 
 
-            auto key = passData.m_shaderOptionGroup.GetShaderVariantKeyFallbackValue();
-            passData.m_srg->SetShaderVariantKeyFallbackValue(key);
+            const auto& pass = m_children[childIndex];
+            if (auto fullscreenPass = azrtti_cast<RPI::FullscreenTrianglePass*>(pass); fullscreenPass != nullptr)
+            {
+                fullscreenPass->UpdateShaderOptions(passData.m_shaderOptionGroup.GetShaderVariantId());
+            }
+            else if (auto computePass = azrtti_cast<RPI::ComputePass*>(pass); computePass != nullptr)
+            {
+                computePass->UpdateShaderOptions(passData.m_shaderOptionGroup.GetShaderVariantId());
+            }
+            childIndex++;
         }
         }
     }
     }
 
 

+ 3 - 1
Gems/Atom/RHI/Code/Include/Atom/RHI.Edit/ShaderPlatformInterface.h

@@ -93,6 +93,7 @@ namespace AZ::RHI
             AZStd::string m_entryFunctionName;
             AZStd::string m_entryFunctionName;
 
 
             ByProducts m_byProducts;  //!< Optional; used for debug information
             ByProducts m_byProducts;  //!< Optional; used for debug information
+            AZStd::string m_extraData; //!< Optional; extra data that can be pass for creating the Stage function.
         };
         };
 
 
         //! @apiUniqueIndex See GetApiUniqueIndex() for details.
         //! @apiUniqueIndex See GetApiUniqueIndex() for details.
@@ -123,7 +124,8 @@ namespace AZ::RHI
             ShaderHardwareStage shaderStage,
             ShaderHardwareStage shaderStage,
             const AZStd::string& tempFolderPath,
             const AZStd::string& tempFolderPath,
             StageDescriptor& outputDescriptor,
             StageDescriptor& outputDescriptor,
-            const RHI::ShaderBuildArguments& shaderBuildArguments) const = 0;
+            const RHI::ShaderBuildArguments& shaderBuildArguments,
+            const bool useSpecializationConstants) const = 0;
 
 
         //! Query whether the shaders are set to build with debug information
         //! Query whether the shaders are set to build with debug information
         virtual bool BuildHasDebugInfo(const RHI::ShaderBuildArguments& shaderBuildArguments) const
         virtual bool BuildHasDebugInfo(const RHI::ShaderBuildArguments& shaderBuildArguments) const

+ 10 - 4
Gems/Atom/RHI/Code/Include/Atom/RHI/PipelineStateDescriptor.h

@@ -13,6 +13,7 @@
 #include <Atom/RHI.Reflect/ShaderStageFunction.h>
 #include <Atom/RHI.Reflect/ShaderStageFunction.h>
 #include <AzCore/Utils/TypeHash.h>
 #include <AzCore/Utils/TypeHash.h>
 #include <Atom/RHI.Reflect/PipelineLayoutDescriptor.h>
 #include <Atom/RHI.Reflect/PipelineLayoutDescriptor.h>
+#include <Atom/RHI/SpecializationConstant.h>
 
 
 namespace AZ::RHI
 namespace AZ::RHI
 {
 {
@@ -38,16 +39,21 @@ namespace AZ::RHI
         PipelineStateType GetType() const;
         PipelineStateType GetType() const;
 
 
         //! Returns the hash of the pipeline state descriptor contents.
         //! Returns the hash of the pipeline state descriptor contents.
-        virtual HashValue64 GetHash() const = 0;
+        HashValue64 GetHash() const;
 
 
         bool operator == (const PipelineStateDescriptor& rhs) const;
         bool operator == (const PipelineStateDescriptor& rhs) const;
 
 
         //! The pipeline layout describing the shader resource bindings.
         //! The pipeline layout describing the shader resource bindings.
         ConstPtr<PipelineLayoutDescriptor> m_pipelineLayoutDescriptor = nullptr;
         ConstPtr<PipelineLayoutDescriptor> m_pipelineLayoutDescriptor = nullptr;
 
 
+        //! Values for specialization constants.
+        AZStd::vector<SpecializationConstant> m_specializationData;
+
     protected:
     protected:
         PipelineStateDescriptor(PipelineStateType pipelineStateType);
         PipelineStateDescriptor(PipelineStateType pipelineStateType);
 
 
+        virtual HashValue64 GetHashInternal() const = 0;
+
     private:
     private:
         PipelineStateType m_type = PipelineStateType::Count;
         PipelineStateType m_type = PipelineStateType::Count;
     };
     };
@@ -69,7 +75,7 @@ namespace AZ::RHI
         PipelineStateDescriptorForDispatch();
         PipelineStateDescriptorForDispatch();
 
 
         /// Computes the hash value for this descriptor.
         /// Computes the hash value for this descriptor.
-        HashValue64 GetHash() const override;
+        HashValue64 GetHashInternal() const override;
 
 
         bool operator == (const PipelineStateDescriptorForDispatch& rhs) const;
         bool operator == (const PipelineStateDescriptorForDispatch& rhs) const;
 
 
@@ -91,7 +97,7 @@ namespace AZ::RHI
         PipelineStateDescriptorForDraw();
         PipelineStateDescriptorForDraw();
 
 
         /// Computes the hash value for this descriptor.
         /// Computes the hash value for this descriptor.
-        HashValue64 GetHash() const override;
+        HashValue64 GetHashInternal() const override;
 
 
         bool operator == (const PipelineStateDescriptorForDraw& rhs) const;
         bool operator == (const PipelineStateDescriptorForDraw& rhs) const;
 
 
@@ -124,7 +130,7 @@ namespace AZ::RHI
         PipelineStateDescriptorForRayTracing();
         PipelineStateDescriptorForRayTracing();
 
 
         //! Computes the hash value for this descriptor.
         //! Computes the hash value for this descriptor.
-        HashValue64 GetHash() const override;
+        HashValue64 GetHashInternal() const override;
 
 
         bool operator == (const PipelineStateDescriptorForRayTracing& rhs) const;
         bool operator == (const PipelineStateDescriptorForRayTracing& rhs) const;
 
 

+ 45 - 0
Gems/Atom/RHI/Code/Include/Atom/RHI/SpecializationConstant.h

@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#pragma once
+
+#include <AzCore/Utils/TypeHash.h>
+#include <Atom/RHI.Reflect/Handle.h>
+
+namespace AZ::RHI
+{
+    //! Holds a value for a specialization constant
+    using SpecializationValue = RHI::Handle<uint32_t, struct SpecializationConstant>;
+
+    //! Supported types for specialization constants
+    enum class SpecializationType : uint32_t
+    {
+        Integer,
+        Bool,
+        Invalid
+    };
+
+    //! Contains all the necessary information and value of a specialization constant
+    //! so it can be used when creating a PipelineState.
+    struct SpecializationConstant
+    {
+        SpecializationConstant() = default;
+
+        //! Name of the constant
+        Name m_name;
+        //! Id of the constant
+        uint32_t m_id = 0;
+        //! Value of the constant
+        SpecializationValue m_value;
+        //! Type of the constant
+        SpecializationType m_type = SpecializationType::Invalid;
+
+        bool operator==(const SpecializationConstant& rhs) const;
+        //! Returns a hash of the constant
+        HashValue64 GetHash() const;
+    };
+}

+ 23 - 16
Gems/Atom/RHI/Code/Source/RHI/PipelineStateDescriptor.cpp

@@ -22,9 +22,22 @@ namespace AZ::RHI
         return m_type;
         return m_type;
     }
     }
 
 
-    bool PipelineStateDescriptor::operator == (const PipelineStateDescriptor& rhs) const
+    HashValue64 PipelineStateDescriptor::GetHash() const
     {
     {
-        return m_type == rhs.m_type;
+        AZ_Assert(m_pipelineLayoutDescriptor, "Pipeline layout descriptor is null.");
+        AZ::HashValue64 seed = AZ::HashValue64{ 0 };
+        seed = TypeHash64(m_pipelineLayoutDescriptor->GetHash(), seed);
+        for (const auto& constant : m_specializationData)
+        {
+            seed = TypeHash64(constant.GetHash(), seed);
+        }
+        seed = TypeHash64(GetHashInternal(), seed);
+        return seed;
+    }
+
+    bool PipelineStateDescriptor::operator==(const PipelineStateDescriptor& rhs) const
+    {
+        return m_type == rhs.m_type && m_specializationData == rhs.m_specializationData;
     }
     }
 
 
     PipelineStateDescriptorForDraw::PipelineStateDescriptorForDraw()
     PipelineStateDescriptorForDraw::PipelineStateDescriptorForDraw()
@@ -39,21 +52,17 @@ namespace AZ::RHI
         : PipelineStateDescriptor(PipelineStateType::RayTracing)
         : PipelineStateDescriptor(PipelineStateType::RayTracing)
     {}
     {}
 
 
-    AZ::HashValue64 PipelineStateDescriptorForDispatch::GetHash() const
+    AZ::HashValue64 PipelineStateDescriptorForDispatch::GetHashInternal() const
     {
     {
-        AZ_Assert(m_pipelineLayoutDescriptor, "Pipeline layout descriptor is null.");
         AZ_Assert(m_computeFunction, "Compute function is null.");
         AZ_Assert(m_computeFunction, "Compute function is null.");
 
 
         AZ::HashValue64 seed = AZ::HashValue64{ 0 };
         AZ::HashValue64 seed = AZ::HashValue64{ 0 };
-        seed = TypeHash64(m_pipelineLayoutDescriptor->GetHash(), seed);
         seed = TypeHash64(m_computeFunction->GetHash(), seed);
         seed = TypeHash64(m_computeFunction->GetHash(), seed);
         return seed;
         return seed;
     }
     }
 
 
-    AZ::HashValue64 PipelineStateDescriptorForDraw::GetHash() const
+    AZ::HashValue64 PipelineStateDescriptorForDraw::GetHashInternal() const
     {
     {
-        AZ_Assert(m_pipelineLayoutDescriptor, "m_pipelineLayoutDescriptor is null.");
-
         AZ::HashValue64 seed = AZ::HashValue64{ 0 };
         AZ::HashValue64 seed = AZ::HashValue64{ 0 };
 
 
         if (m_vertexFunction)
         if (m_vertexFunction)
@@ -69,7 +78,6 @@ namespace AZ::RHI
             seed = TypeHash64(m_fragmentFunction->GetHash(), seed);
             seed = TypeHash64(m_fragmentFunction->GetHash(), seed);
         }
         }
 
 
-        seed = TypeHash64(m_pipelineLayoutDescriptor->GetHash(), seed);
         seed = TypeHash64(m_inputStreamLayout.GetHash(), seed);
         seed = TypeHash64(m_inputStreamLayout.GetHash(), seed);
         seed = TypeHash64(m_renderAttachmentConfiguration.GetHash(), seed);
         seed = TypeHash64(m_renderAttachmentConfiguration.GetHash(), seed);
 
 
@@ -78,12 +86,9 @@ namespace AZ::RHI
         return seed;
         return seed;
     }
     }
 
 
-    AZ::HashValue64 PipelineStateDescriptorForRayTracing::GetHash() const
+    AZ::HashValue64 PipelineStateDescriptorForRayTracing::GetHashInternal() const
     {
     {
-        AZ_Assert(m_pipelineLayoutDescriptor, "Pipeline layout descriptor is null.");
-
         AZ::HashValue64 seed = AZ::HashValue64{ 0 };
         AZ::HashValue64 seed = AZ::HashValue64{ 0 };
-        seed = TypeHash64(m_pipelineLayoutDescriptor->GetHash(), seed);
         seed = TypeHash64(m_rayTracingFunction->GetHash(), seed);
         seed = TypeHash64(m_rayTracingFunction->GetHash(), seed);
         return seed;
         return seed;
     }
     }
@@ -93,18 +98,20 @@ namespace AZ::RHI
         return m_fragmentFunction == rhs.m_fragmentFunction && m_pipelineLayoutDescriptor == rhs.m_pipelineLayoutDescriptor &&
         return m_fragmentFunction == rhs.m_fragmentFunction && m_pipelineLayoutDescriptor == rhs.m_pipelineLayoutDescriptor &&
             m_renderStates == rhs.m_renderStates && m_vertexFunction == rhs.m_vertexFunction &&
             m_renderStates == rhs.m_renderStates && m_vertexFunction == rhs.m_vertexFunction &&
             m_geometryFunction == rhs.m_geometryFunction && m_inputStreamLayout == rhs.m_inputStreamLayout && 
             m_geometryFunction == rhs.m_geometryFunction && m_inputStreamLayout == rhs.m_inputStreamLayout && 
-            m_renderAttachmentConfiguration == rhs.m_renderAttachmentConfiguration;
+            m_renderAttachmentConfiguration == rhs.m_renderAttachmentConfiguration && m_specializationData == rhs.m_specializationData;
     }
     }
 
 
     bool PipelineStateDescriptorForDispatch::operator == (const PipelineStateDescriptorForDispatch& rhs) const
     bool PipelineStateDescriptorForDispatch::operator == (const PipelineStateDescriptorForDispatch& rhs) const
     {
     {
         return m_computeFunction == rhs.m_computeFunction &&
         return m_computeFunction == rhs.m_computeFunction &&
-            m_pipelineLayoutDescriptor == rhs.m_pipelineLayoutDescriptor;
+            m_pipelineLayoutDescriptor == rhs.m_pipelineLayoutDescriptor &&
+            m_specializationData == rhs.m_specializationData;
     }
     }
 
 
     bool PipelineStateDescriptorForRayTracing::operator == (const PipelineStateDescriptorForRayTracing& rhs) const
     bool PipelineStateDescriptorForRayTracing::operator == (const PipelineStateDescriptorForRayTracing& rhs) const
     {
     {
         return m_pipelineLayoutDescriptor == rhs.m_pipelineLayoutDescriptor &&
         return m_pipelineLayoutDescriptor == rhs.m_pipelineLayoutDescriptor &&
-            m_rayTracingFunction == rhs.m_rayTracingFunction;
+            m_rayTracingFunction == rhs.m_rayTracingFunction &&
+            m_specializationData == rhs.m_specializationData;
     }
     }
 }
 }

+ 31 - 0
Gems/Atom/RHI/Code/Source/RHI/SpecializationConstant.cpp

@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <Atom/RHI/SpecializationConstant.h>
+
+namespace AZ::RHI
+{
+    bool SpecializationConstant::operator==(const SpecializationConstant& rhs) const
+    {
+        return
+            m_value == rhs.m_value &&
+            m_name == rhs.m_name &&
+            m_id == rhs.m_id &&
+            m_type == rhs.m_type;
+    }
+
+    HashValue64 SpecializationConstant::GetHash() const
+    {
+        AZ::HashValue64 seed = AZ::HashValue64{ 0 };
+        seed = TypeHash64(m_value.GetIndex(), seed);
+        seed = TypeHash64(m_name.GetHash(), seed);
+        seed = TypeHash64(m_id, seed);
+        seed = TypeHash64(m_type, seed);
+        return seed;
+    }    
+}

+ 2 - 0
Gems/Atom/RHI/Code/atom_rhi_public_files.cmake

@@ -276,4 +276,6 @@ set(FILES
     Include/Atom/RHI/XRRenderingInterface.h
     Include/Atom/RHI/XRRenderingInterface.h
     Include/Atom/RHI/DeviceDispatchRaysIndirectBuffer.h
     Include/Atom/RHI/DeviceDispatchRaysIndirectBuffer.h
     Include/Atom/RHI/DispatchRaysIndirectBuffer.h
     Include/Atom/RHI/DispatchRaysIndirectBuffer.h
+    Include/Atom/RHI/SpecializationConstant.h
+    Source/RHI/SpecializationConstant.cpp
 )
 )

+ 12 - 0
Gems/Atom/RHI/DX12/Code/CMakeLists.txt

@@ -90,6 +90,16 @@ ly_add_target(
             ../External/AMD_D3D12MemoryAllocator/v2.0.1
             ../External/AMD_D3D12MemoryAllocator/v2.0.1
 )
 )
 
 
+ly_add_target(
+    NAME OpenSSL_md5 STATIC
+    NAMESPACE Gem
+    FILES_CMAKE
+        openssl_md5_files.cmake
+    INCLUDE_DIRECTORIES
+        INTERFACE
+            ../External/md5
+)
+
 ly_add_target(
 ly_add_target(
     NAME ${gem_name}.Reflect STATIC
     NAME ${gem_name}.Reflect STATIC
     NAMESPACE Gem
     NAMESPACE Gem
@@ -131,6 +141,8 @@ ly_add_target(
             Gem::Amd_DX12MA
             Gem::Amd_DX12MA
             3rdParty::d3dx12
             3rdParty::d3dx12
             ${AFTERMATH_BUILD_DEPENDENCY}
             ${AFTERMATH_BUILD_DEPENDENCY}
+        PRIVATE
+            Gem::OpenSSL_md5
     COMPILE_DEFINITIONS 
     COMPILE_DEFINITIONS 
         PRIVATE
         PRIVATE
             ${USE_NSIGHT_AFTERMATH_DEFINE}
             ${USE_NSIGHT_AFTERMATH_DEFINE}

+ 13 - 0
Gems/Atom/RHI/DX12/Code/Include/Atom/RHI.Reflect/DX12/ShaderStageFunction.h

@@ -33,6 +33,12 @@ namespace AZ
         using ShaderByteCode = AZStd::vector<uint8_t>;
         using ShaderByteCode = AZStd::vector<uint8_t>;
         using ShaderByteCodeView = AZStd::span<const uint8_t>;
         using ShaderByteCodeView = AZStd::span<const uint8_t>;
 
 
+        //! Sentinel value used when patching shaders for specialization constants
+        constexpr uint32_t SCSentinelValue = 0x45678900;
+        //! Mask that marks which bytes are used for the sentinel and which
+        //! ones are used for the specialization constant id.
+        constexpr uint64_t SCSentinelMask = 0xffffffffffffff00;
+
         /**
         /**
          * A set of indices used to access physical sub-stages within a virtual stage.
          * A set of indices used to access physical sub-stages within a virtual stage.
          */
          */
@@ -60,6 +66,12 @@ namespace AZ
             /// Returns the assigned byte code.
             /// Returns the assigned byte code.
             ShaderByteCodeView GetByteCode(uint32_t subStageIndex = 0) const;
             ShaderByteCodeView GetByteCode(uint32_t subStageIndex = 0) const;
 
 
+            using SpecializationOffsets = AZStd::unordered_map<uint32_t, uint32_t>;
+            void SetSpecializationOffsets(uint32_t subStageIndex, const SpecializationOffsets& offsets);
+            const SpecializationOffsets& GetSpecializationOffsets(uint32_t subStageIndex = 0) const;
+
+            bool UseSpecializationConstants(uint32_t subStageIndex = 0) const;
+
         private:
         private:
             ShaderStageFunction() = default;
             ShaderStageFunction() = default;
             ShaderStageFunction(RHI::ShaderStage shaderStage);
             ShaderStageFunction(RHI::ShaderStage shaderStage);
@@ -74,6 +86,7 @@ namespace AZ
             ///////////////////////////////////////////////////////////////////
             ///////////////////////////////////////////////////////////////////
 
 
             AZStd::array<ShaderByteCode, ShaderSubStageCountMax> m_byteCodes;
             AZStd::array<ShaderByteCode, ShaderSubStageCountMax> m_byteCodes;
+            AZStd::array<SpecializationOffsets, ShaderSubStageCountMax> m_specializationOffsets;
         };
         };
     }
     }
 }
 }

+ 73 - 4
Gems/Atom/RHI/DX12/Code/Source/RHI.Builders/ShaderPlatformInterface.cpp

@@ -16,6 +16,7 @@
 
 
 #include <AzCore/IO/FileIO.h>
 #include <AzCore/IO/FileIO.h>
 #include <AzCore/IO/SystemFile.h>
 #include <AzCore/IO/SystemFile.h>
+#include <AzCore/Serialization/Json/JsonUtils.h>
 #include <AzFramework/StringFunc/StringFunc.h>
 #include <AzFramework/StringFunc/StringFunc.h>
 
 
 namespace AZ
 namespace AZ
@@ -50,6 +51,34 @@ namespace AZ
             const int byteCodeIndex = 0;
             const int byteCodeIndex = 0;
             newShaderStageFunction->SetByteCode(byteCodeIndex, byteCode);
             newShaderStageFunction->SetByteCode(byteCodeIndex, byteCode);
 
 
+            // Read the json data with the specialization constants offsets.
+            // If the shader was not compiled with specialization constants this attribute will be empty.
+            AZStd::string fileName;
+            if (!stageDescriptor.m_extraData.empty())
+            {
+                auto jsonOutcome = JsonSerializationUtils::ReadJsonFile(stageDescriptor.m_extraData);
+                if (!jsonOutcome.IsSuccess())
+                {
+                    AZ_Error(DX12ShaderPlatformName, false, "%s", jsonOutcome.GetError().c_str());
+                    return nullptr;
+                }
+
+                const rapidjson::Document& doc = jsonOutcome.GetValue();
+                ShaderStageFunction::SpecializationOffsets offsets;
+                for (auto itr = doc.MemberBegin(); itr != doc.MemberEnd(); ++itr)
+                {
+                    if (!AZ::StringFunc::LooksLikeInt(itr->name.GetString()))
+                    {
+                        AZ_Error(DX12ShaderPlatformName, false, "SpecializationId %s is not an Int", itr->name.GetString());
+                        continue;
+                    }
+                    uint32_t specializationId = static_cast<uint32_t>(AZ::StringFunc::ToInt(itr->name.GetString()));
+                    uint32_t offset = itr->value.GetUint();
+                    offsets[specializationId] = offset;
+                }
+                newShaderStageFunction->SetSpecializationOffsets(byteCodeIndex, offsets);
+            }
+         
             newShaderStageFunction->Finalize();
             newShaderStageFunction->Finalize();
 
 
             return newShaderStageFunction;
             return newShaderStageFunction;
@@ -149,10 +178,11 @@ namespace AZ
             RHI::ShaderHardwareStage shaderStage,
             RHI::ShaderHardwareStage shaderStage,
             const AZStd::string& tempFolderPath,
             const AZStd::string& tempFolderPath,
             StageDescriptor& outputDescriptor,
             StageDescriptor& outputDescriptor,
-            const RHI::ShaderBuildArguments& shaderBuildArguments) const
+            const RHI::ShaderBuildArguments& shaderBuildArguments,
+            const bool useSpecializationConstants) const
         {
         {
             AZStd::vector<uint8_t> shaderByteCode;
             AZStd::vector<uint8_t> shaderByteCode;
-
+            AZStd::string specializationOffsetsFile;
             // Compile HLSL shader to byte code
             // Compile HLSL shader to byte code
             bool compiledSucessfully = CompileHLSLShader(
             bool compiledSucessfully = CompileHLSLShader(
                 shaderSourcePath,                        // shader source filepath
                 shaderSourcePath,                        // shader source filepath
@@ -161,7 +191,9 @@ namespace AZ
                 shaderStage,                             // shader stage (vertex shader, pixel shader, ...)
                 shaderStage,                             // shader stage (vertex shader, pixel shader, ...)
                 shaderBuildArguments,
                 shaderBuildArguments,
                 shaderByteCode,                          // compiled shader output
                 shaderByteCode,                          // compiled shader output
-                outputDescriptor.m_byProducts);          // dynamic branch count output & byproduct files
+                outputDescriptor.m_byProducts,           // dynamic branch count output & byproduct files
+                specializationOffsetsFile,               // path to the json file with the specialization offsets
+                useSpecializationConstants);             // if the shader stage it's using specialization constants
 
 
             if (!compiledSucessfully)
             if (!compiledSucessfully)
             {
             {
@@ -174,6 +206,7 @@ namespace AZ
             {
             {
                 outputDescriptor.m_stageType = shaderStage;
                 outputDescriptor.m_stageType = shaderStage;
                 outputDescriptor.m_byteCode = AZStd::move(shaderByteCode);
                 outputDescriptor.m_byteCode = AZStd::move(shaderByteCode);
+                outputDescriptor.m_extraData = AZStd::move(specializationOffsetsFile);
             }
             }
             else
             else
             {
             {
@@ -197,7 +230,9 @@ namespace AZ
             const RHI::ShaderHardwareStage shaderStageType,
             const RHI::ShaderHardwareStage shaderStageType,
             const RHI::ShaderBuildArguments& shaderBuildArguments,
             const RHI::ShaderBuildArguments& shaderBuildArguments,
             AZStd::vector<uint8_t>& compiledShader,
             AZStd::vector<uint8_t>& compiledShader,
-            ByProducts& byProducts) const
+            ByProducts& byProducts,
+            AZStd::string& specializationOffsetsFile,
+            const bool useSpecializationConstants) const
         {
         {
             // Shader compiler executable
             // Shader compiler executable
             const auto dxcRelativePath = RHI::GetDirectXShaderCompilerPath("Builders/DirectXShaderCompiler/dxc.exe");
             const auto dxcRelativePath = RHI::GetDirectXShaderCompilerPath("Builders/DirectXShaderCompiler/dxc.exe");
@@ -298,6 +333,40 @@ namespace AZ
                 return false;
                 return false;
             }
             }
 
 
+            if (useSpecializationConstants)
+            {
+                // Need to patch the shader so it can be used with specialization constants.
+                const auto dxscRelativePath = RHI::GetDirectXShaderCompilerPath("Builders/DirectXShaderCompiler/dxsc.exe");
+
+                AZStd::string shaderOutputCommon;
+                AzFramework::StringFunc::Path::GetFileName(shaderSourceFile.c_str(), shaderOutputCommon);
+                AzFramework::StringFunc::Path::Join(tempFolder.c_str(), shaderOutputCommon.c_str(), shaderOutputCommon);
+
+                AZStd::string patchedShaderOutput = shaderOutputCommon;
+                AzFramework::StringFunc::Path::ReplaceExtension(patchedShaderOutput, "dxil.patched.bin");
+                AZStd::string offsetsOutput = shaderOutputCommon;
+                AzFramework::StringFunc::Path::ReplaceExtension(offsetsOutput, "offsets.json");
+
+                const auto dxscCommandOptions = AZStd::string::format(
+                    //   1.sentinel    3.offsets_output   
+                    //     |    2.output    |   4.dxil-in
+                    //     |       |        |      |
+                    "-sv=%lu -o=\"%s\" -f=\"%s\" \"%s\"",
+                    static_cast<unsigned long>(SCSentinelValue), // 1
+                    patchedShaderOutput.c_str(), // 2
+                    offsetsOutput.c_str(), // 3
+                    shaderOutputFile.c_str() // 4
+                );
+
+                if (!RHI::ExecuteShaderCompiler(dxscRelativePath, dxscCommandOptions, shaderSourceFile, tempFolder, "DXSC"))
+                {
+                    return false;
+                }
+                shaderOutputFile = patchedShaderOutput;
+
+                specializationOffsetsFile = offsetsOutput;
+            }
+
             auto shaderOutputFileLoadResult = AZ::RHI::LoadFileBytes(shaderOutputFile.c_str());
             auto shaderOutputFileLoadResult = AZ::RHI::LoadFileBytes(shaderOutputFile.c_str());
             if (!shaderOutputFileLoadResult)
             if (!shaderOutputFileLoadResult)
             {
             {

+ 5 - 2
Gems/Atom/RHI/DX12/Code/Source/RHI.Builders/ShaderPlatformInterface.h

@@ -45,7 +45,8 @@ namespace AZ
                 RHI::ShaderHardwareStage shaderStage,
                 RHI::ShaderHardwareStage shaderStage,
                 const AZStd::string& tempFolderPath,
                 const AZStd::string& tempFolderPath,
                 StageDescriptor& outputDescriptor,
                 StageDescriptor& outputDescriptor,
-                const RHI::ShaderBuildArguments& shaderBuildArguments) const override;
+                const RHI::ShaderBuildArguments& shaderBuildArguments,
+                const bool useSpecializationConstants) const override;
 
 
             const char* GetAzslHeader(const AssetBuilderSDK::PlatformInfo& platform) const override;
             const char* GetAzslHeader(const AssetBuilderSDK::PlatformInfo& platform) const override;
 
 
@@ -59,7 +60,9 @@ namespace AZ
                 const RHI::ShaderHardwareStage shaderStageType,
                 const RHI::ShaderHardwareStage shaderStageType,
                 const RHI::ShaderBuildArguments& shaderBuildArguments,
                 const RHI::ShaderBuildArguments& shaderBuildArguments,
                 AZStd::vector<uint8_t>& m_byteCode,
                 AZStd::vector<uint8_t>& m_byteCode,
-                ByProducts& products) const;
+                ByProducts& products,
+                AZStd::string& specializationOffsetsFile,
+                const bool useSpecializationConstants) const;
 
 
             const Name m_apiName;
             const Name m_apiName;
         };
         };

+ 18 - 2
Gems/Atom/RHI/DX12/Code/Source/RHI.Reflect/ShaderStageFunction.cpp

@@ -19,8 +19,9 @@ namespace AZ
             if (SerializeContext* serializeContext = azrtti_cast<SerializeContext*>(context))
             if (SerializeContext* serializeContext = azrtti_cast<SerializeContext*>(context))
             {
             {
                 serializeContext->Class<ShaderStageFunction, RHI::ShaderStageFunction>()
                 serializeContext->Class<ShaderStageFunction, RHI::ShaderStageFunction>()
-                    ->Version(1)
-                    ->Field("m_byteCodes", &ShaderStageFunction::m_byteCodes);
+                    ->Version(2)
+                    ->Field("m_byteCodes", &ShaderStageFunction::m_byteCodes)
+                    ->Field("m_specializationOffsets", &ShaderStageFunction::m_specializationOffsets);
             }
             }
         }
         }
 
 
@@ -44,6 +45,21 @@ namespace AZ
             return ShaderByteCodeView(m_byteCodes[subStageIndex]);
             return ShaderByteCodeView(m_byteCodes[subStageIndex]);
         }
         }
 
 
+        void ShaderStageFunction::SetSpecializationOffsets(uint32_t subStageIndex, const SpecializationOffsets& offsets)
+        {
+            m_specializationOffsets[subStageIndex] = offsets;
+        }
+
+        const ShaderStageFunction::SpecializationOffsets& ShaderStageFunction::GetSpecializationOffsets(uint32_t subStageIndex) const
+        {
+            return m_specializationOffsets[subStageIndex];
+        }
+
+        bool ShaderStageFunction::UseSpecializationConstants(uint32_t subStageIndex) const
+        {
+            return !GetSpecializationOffsets(subStageIndex).empty();
+        }
+
         RHI::ResultCode ShaderStageFunction::FinalizeInternal()
         RHI::ResultCode ShaderStageFunction::FinalizeInternal()
         {
         {
             bool emptyByteCodes = true;
             bool emptyByteCodes = true;

+ 2 - 0
Gems/Atom/RHI/DX12/Code/Source/RHI/DX12.h

@@ -33,6 +33,8 @@
 #define IID_GRAPHICS_PPV_ARGS(ppType) IID_PPV_ARGS(ppType)
 #define IID_GRAPHICS_PPV_ARGS(ppType) IID_PPV_ARGS(ppType)
 #endif
 #endif
 
 
+#define MAKE_FOURCC(a, b, c, d) (((uint32_t)(d) << 24) | ((uint32_t)(c) << 16) | ((uint32_t)(b) << 8) | (uint32_t)(a))
+
 namespace AZ
 namespace AZ
 {
 {
     namespace DX12
     namespace DX12

+ 14 - 6
Gems/Atom/RHI/DX12/Code/Source/RHI/PipelineState.cpp

@@ -10,6 +10,8 @@
 #include <Atom/RHI.Reflect/DX12/ShaderStageFunction.h>
 #include <Atom/RHI.Reflect/DX12/ShaderStageFunction.h>
 #include <RHI/Conversions.h>
 #include <RHI/Conversions.h>
 #include <RHI/Device.h>
 #include <RHI/Device.h>
+#include <RHI/ShaderUtils.h>
+
 namespace AZ
 namespace AZ
 {
 {
     namespace DX12
     namespace DX12
@@ -55,20 +57,24 @@ namespace AZ
             // Shader state.
             // Shader state.
             RHI::ConstPtr<PipelineLayout> pipelineLayout = device.AcquirePipelineLayout(*descriptor.m_pipelineLayoutDescriptor);
             RHI::ConstPtr<PipelineLayout> pipelineLayout = device.AcquirePipelineLayout(*descriptor.m_pipelineLayoutDescriptor);
             pipelineStateDesc.pRootSignature = pipelineLayout->Get();
             pipelineStateDesc.pRootSignature = pipelineLayout->Get();
-
+            // Cache used for saving the patched version of the shader
+            AZStd::vector<ShaderByteCode> shaderByteCodeCache;
             if (const ShaderStageFunction* vertexFunction = azrtti_cast<const ShaderStageFunction*>(descriptor.m_vertexFunction.get()))
             if (const ShaderStageFunction* vertexFunction = azrtti_cast<const ShaderStageFunction*>(descriptor.m_vertexFunction.get()))
             {
             {
-                pipelineStateDesc.VS = D3D12BytecodeFromView(vertexFunction->GetByteCode());
+                pipelineStateDesc.VS =
+                    D3D12BytecodeFromView(ShaderUtils::PatchShaderFunction(*vertexFunction, descriptor, shaderByteCodeCache));
             }
             }
 
 
             if (const ShaderStageFunction* geometryFunction = azrtti_cast<const ShaderStageFunction*>(descriptor.m_geometryFunction.get()))
             if (const ShaderStageFunction* geometryFunction = azrtti_cast<const ShaderStageFunction*>(descriptor.m_geometryFunction.get()))
             {
             {
-                pipelineStateDesc.GS = D3D12BytecodeFromView(geometryFunction->GetByteCode());
+                pipelineStateDesc.GS =
+                    D3D12BytecodeFromView(ShaderUtils::PatchShaderFunction(*geometryFunction, descriptor, shaderByteCodeCache));
             }
             }
 
 
             if (const ShaderStageFunction* fragmentFunction = azrtti_cast<const ShaderStageFunction*>(descriptor.m_fragmentFunction.get()))
             if (const ShaderStageFunction* fragmentFunction = azrtti_cast<const ShaderStageFunction*>(descriptor.m_fragmentFunction.get()))
             {
             {
-                pipelineStateDesc.PS = D3D12BytecodeFromView(fragmentFunction->GetByteCode());
+                pipelineStateDesc.PS =
+                    D3D12BytecodeFromView(ShaderUtils::PatchShaderFunction(*fragmentFunction, descriptor, shaderByteCodeCache));
             }
             }
 
 
             const RHI::RenderAttachmentConfiguration& renderAttachmentConfiguration = descriptor.m_renderAttachmentConfiguration;
             const RHI::RenderAttachmentConfiguration& renderAttachmentConfiguration = descriptor.m_renderAttachmentConfiguration;
@@ -130,10 +136,12 @@ namespace AZ
 
 
             RHI::ConstPtr<PipelineLayout> pipelineLayout = device.AcquirePipelineLayout(*descriptor.m_pipelineLayoutDescriptor);
             RHI::ConstPtr<PipelineLayout> pipelineLayout = device.AcquirePipelineLayout(*descriptor.m_pipelineLayoutDescriptor);
             pipelineStateDesc.pRootSignature = pipelineLayout->Get();
             pipelineStateDesc.pRootSignature = pipelineLayout->Get();
-
+            // Cache used for saving the patched version of the shader
+            AZStd::vector<ShaderByteCode> shaderByteCodeCache;
             if (const ShaderStageFunction* computeFunction = azrtti_cast<const ShaderStageFunction*>(descriptor.m_computeFunction.get()))
             if (const ShaderStageFunction* computeFunction = azrtti_cast<const ShaderStageFunction*>(descriptor.m_computeFunction.get()))
             {
             {
-                pipelineStateDesc.CS = D3D12BytecodeFromView(computeFunction->GetByteCode());
+                pipelineStateDesc.CS =
+                    D3D12BytecodeFromView(ShaderUtils::PatchShaderFunction(*computeFunction, descriptor, shaderByteCodeCache));
             }
             }
 
 
             PipelineLibrary* pipelineLibrary = static_cast<PipelineLibrary*>(pipelineLibraryBase);
             PipelineLibrary* pipelineLibrary = static_cast<PipelineLibrary*>(pipelineLibraryBase);

+ 7 - 2
Gems/Atom/RHI/DX12/Code/Source/RHI/RayTracingPipelineState.cpp

@@ -10,6 +10,8 @@
 #include <Atom/RHI.Reflect/DX12/ShaderStageFunction.h>
 #include <Atom/RHI.Reflect/DX12/ShaderStageFunction.h>
 #include <RHI/Conversions.h>
 #include <RHI/Conversions.h>
 #include <RHI/Device.h>
 #include <RHI/Device.h>
+#include <RHI/ShaderUtils.h>
+
 namespace AZ
 namespace AZ
 {
 {
     namespace DX12
     namespace DX12
@@ -53,12 +55,15 @@ namespace AZ
             // add DXIL Libraries
             // add DXIL Libraries
             AZStd::vector<D3D12_DXIL_LIBRARY_DESC> libraryDescs;
             AZStd::vector<D3D12_DXIL_LIBRARY_DESC> libraryDescs;
             libraryDescs.reserve(dxilLibraryCount);
             libraryDescs.reserve(dxilLibraryCount);
+            AZStd::vector<ShaderByteCode> patchedShaderCache;
             for (const RHI::RayTracingShaderLibrary& shaderLibrary : descriptor->GetShaderLibraries())
             for (const RHI::RayTracingShaderLibrary& shaderLibrary : descriptor->GetShaderLibraries())
             {
             {
                 const ShaderStageFunction* rayTracingFunction = azrtti_cast<const ShaderStageFunction*>(shaderLibrary.m_descriptor.m_rayTracingFunction.get());
                 const ShaderStageFunction* rayTracingFunction = azrtti_cast<const ShaderStageFunction*>(shaderLibrary.m_descriptor.m_rayTracingFunction.get());
-        
+                ShaderByteCodeView byteCode =
+                    ShaderUtils::PatchShaderFunction(*rayTracingFunction, shaderLibrary.m_descriptor, patchedShaderCache);
+
                 D3D12_DXIL_LIBRARY_DESC libraryDesc = {};
                 D3D12_DXIL_LIBRARY_DESC libraryDesc = {};
-                libraryDesc.DXILLibrary = D3D12_SHADER_BYTECODE{ rayTracingFunction->GetByteCode().data(), rayTracingFunction->GetByteCode().size() };
+                libraryDesc.DXILLibrary = D3D12_SHADER_BYTECODE{ byteCode.data(), byteCode.size() };
                 libraryDesc.NumExports = 0; // all shaders
                 libraryDesc.NumExports = 0; // all shaders
                 libraryDesc.pExports = nullptr;
                 libraryDesc.pExports = nullptr;
                 libraryDescs.push_back(libraryDesc);
                 libraryDescs.push_back(libraryDesc);

+ 237 - 0
Gems/Atom/RHI/DX12/Code/Source/RHI/ShaderUtils.cpp

@@ -0,0 +1,237 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <RHI/ShaderUtils.h>
+#include <openssl/md5.h>
+#include <Atom/RHI.Reflect/DX12/ShaderStageFunction.h>
+
+namespace AZ::DX12
+{
+    static const uint32_t FOURCC_DXBC = MAKE_FOURCC('D', 'X', 'B', 'C');
+
+    // Modify the bits in the bytecode with a new value following the VBR rules of encoding
+    uint64_t TamperBits(uint8_t* byteCode, uint32_t patchVal, uint64_t offset)
+    {
+        constexpr uint32_t BitsPerByte = 8;
+        // LSB is used for encoding signed/unsigned
+        // VBR left shift all values to leave space for the sign bit
+        patchVal <<= 1;
+
+        uint64_t original = 0;
+        uint64_t currentOffset = offset;
+
+        auto SetBitFunc = [&](uint64_t pos, bool value)
+        {
+            if (value)
+            {
+                byteCode[pos / BitsPerByte] |= (uint8_t)1 << pos % BitsPerByte;
+            }
+            else
+            {
+                byteCode[pos / BitsPerByte] &= ~((uint8_t)1 << pos % BitsPerByte);
+            }
+        };
+        uint32_t originalValueBitIndex = 0;
+        // 32 bits take 5 full bytes in VBR.
+        const uint32_t sentinelBytes = 5;
+        for (uint32_t i = 0; i < sentinelBytes; i++)
+        {
+            // Patch all bits except the continuation bit
+            for (uint32_t j = 0; j < (BitsPerByte - 1); j++)
+            {
+                bool currentValue = ((byteCode[currentOffset / BitsPerByte] & ((uint8_t)1 << currentOffset % BitsPerByte)) != 0);
+                bool newValue = (patchVal & (uint64_t)1) != 0;
+                SetBitFunc(currentOffset, newValue);
+                original |= (uint64_t)currentValue << originalValueBitIndex;
+                patchVal >>= 1;
+                currentOffset++;
+                originalValueBitIndex++;
+            }
+            // Set continuation bit
+            SetBitFunc(currentOffset, true);
+            currentOffset++;
+        }
+        // MSB in VBR doesn't have a continuation bit (because it's the last byte)
+        SetBitFunc(currentOffset - 1, false);
+        // VBR left shift values for sign bit, so we right shift the value we found
+        return original >> 1;
+    }
+
+    ShaderByteCode ShaderUtils::PatchShaderFunction(
+        const ShaderStageFunction& shaderFunction, const RHI::PipelineStateDescriptor& descriptor)
+    {
+        ShaderByteCode patched(shaderFunction.GetByteCode().size());
+        ::memcpy(patched.data(), shaderFunction.GetByteCode().data(), patched.size());
+        const AZStd::vector<RHI::SpecializationConstant>& specializationConstants = descriptor.m_specializationData;
+        for (const auto& element : shaderFunction.GetSpecializationOffsets())
+        {
+            auto findIter = AZStd::find_if(
+                specializationConstants.begin(),
+                specializationConstants.end(),
+                [&](const RHI::SpecializationConstant& constantData)
+                {
+                    return constantData.m_id == element.first;
+                });
+
+            if (findIter == specializationConstants.end())
+            {
+                AZ_Error("ShaderUtils", false, "Specialization constant %d doesn't not have a value", element.first);
+                continue;
+            }
+
+            [[maybe_unused]] uint64_t sentinelFound = TamperBits(patched.data(), findIter->m_value.GetIndex(), element.second);
+            AZ_Assert(
+                static_cast<uint32_t>(sentinelFound & SCSentinelMask) == SCSentinelValue,
+                "Invalid sentinel value found %lu",
+                sentinelFound);
+        }
+
+        // Re-sign the shader bytecode after we patch it
+        if (!SignByteCode(patched))
+        {
+            AZ_Error("ShaderUtils", false, "Failed to sign container");
+            return {};
+        }
+
+        return patched;
+    }
+
+    ShaderByteCodeView ShaderUtils::PatchShaderFunction(
+        const ShaderStageFunction& shaderFunction,
+        const RHI::PipelineStateDescriptor& descriptor,
+        AZStd::vector<ShaderByteCode>& patchedShaderContainer)
+    {
+        if (!shaderFunction.UseSpecializationConstants())
+        {
+            // No need to patch anything
+            return shaderFunction.GetByteCode();
+        }
+
+        ShaderByteCode patchedShader = PatchShaderFunction(shaderFunction, descriptor);
+        patchedShaderContainer.emplace_back(AZStd::move(patchedShader));
+        return patchedShaderContainer.back();
+    }
+
+    bool ShaderUtils::SignByteCode(ShaderByteCode& bytecode)
+    {
+        // Original signing code from Renderdoc
+        struct FileHeader
+        {
+            uint32_t fourcc; // "DXBC"
+            uint32_t hashValue[4]; // unknown hash function and data
+            uint32_t containerVersion;
+            uint32_t fileLength;
+            uint32_t numChunks;
+            // uint32 chunkOffsets[numChunks]; follows
+        };
+
+        if (bytecode.size() < sizeof(FileHeader) || bytecode.data() == nullptr)
+        {
+            return false;
+        }
+
+        FileHeader* header = reinterpret_cast<FileHeader*>(bytecode.data());
+
+        if (header->fourcc != FOURCC_DXBC)
+        {
+            return false;
+        }
+
+        if (header->fileLength != static_cast<uint32_t>(bytecode.size()))
+        {
+            return false;
+        }
+
+        _MD5_CTX md5ctx = {};
+        _MD5_Init(&md5ctx);
+
+        // the hashable data starts immediately after the hash.
+        AZStd::byte* data = reinterpret_cast<AZStd::byte*>(&header->containerVersion);
+        uint32_t length = uint32_t(bytecode.size() - offsetof(FileHeader, containerVersion));
+
+        // we need to know the number of bits for putting in the trailing padding.
+        uint32_t numBits = length * 8;
+        uint32_t numBitsPart2 = (numBits >> 2) | 1;
+
+        // MD5 works on 64-byte chunks, process the first set of whole chunks, leaving 0-63 bytes left
+        // over
+        uint32_t leftoverLength = length % 64;
+        _MD5_Update(&md5ctx, data, length - leftoverLength);
+
+        data += length - leftoverLength;
+
+        uint32_t block[16] = {};
+        AZ_Assert(sizeof(block) == 64, "Block is not properly sized for MD5 round");
+
+        // normally MD5 finishes by appending a 1 bit to the bitstring. Since we are only appending bytes
+        // this would be an 0x80 byte (the first bit is considered to be the MSB). Then it pads out with
+        // zeroes until it has 56 bytes in the last block and appends appends the message length as a
+        // 64-bit integer as the final part of that block.
+        // in other words, normally whatever is leftover from the actual message gets one byte appended,
+        // then if there's at least 8 bytes left we'll append the length. Otherwise we pad that block with
+        // 0s and create a new block with the length at the end.
+        // Or as the original RFC/spec says: padding is always performed regardless of whether the
+        // original buffer already ended in exactly a 56 byte block.
+        //
+        // The DXBC finalisation is slightly different (previous work suggests this is due to a bug in the
+        // original implementation and it was maybe intended to be exactly MD5?):
+        //
+        // The length provided in the padding block is not 64-bit properly: the second dword with the high
+        // bits is instead the number of nybbles(?) with 1 OR'd on. The length is also split, so if it's
+        // in
+        // a padding block the low bits are in the first dword and the upper bits in the last. If there's
+        // no padding block the low dword is passed in first before the leftovers of the message and then
+        // the upper bits at the end.
+
+        // if the leftovers uses at least 56, we can't fit both the trailing 1 and the 64-bit length, so
+        // we need a padding block and then our own block for the length.
+        if (leftoverLength >= 56)
+        {
+            // pass in the leftover data padded out to 64 bytes with zeroes
+            _MD5_Update(&md5ctx, data, leftoverLength);
+
+            block[0] = 0x80; // first padding bit is 1
+            _MD5_Update(&md5ctx, block, 64 - leftoverLength);
+
+            // the final block contains the number of bits in the first dword, and the weird upper bits
+            block[0] = numBits;
+            block[15] = numBitsPart2;
+
+            // process this block directly, we're replacing the call to MD5_Final here manually
+            _MD5_Update(&md5ctx, block, 64);
+        }
+        else
+        {
+            // the leftovers mean we can put the padding inside the final block. But first we pass the "low"
+            // number of bits:
+            _MD5_Update(&md5ctx, &numBits, sizeof(numBits));
+
+            if (leftoverLength)
+            {
+                _MD5_Update(&md5ctx, data, leftoverLength);
+            }
+
+            uint32_t paddingBytes = 64 - leftoverLength - 4;
+
+            // prepare the remainder of this block, starting with the 0x80 padding start right after the
+            // leftovers and the first part of the bit length above.
+            block[0] = 0x80;
+            // then add the remainder of the 'length' here in the final part of the block
+            ::memcpy(((AZStd::byte*)block) + paddingBytes - 4, &numBitsPart2, 4);
+
+            _MD5_Update(&md5ctx, block, paddingBytes);
+        }
+
+        header->hashValue[0] = md5ctx.a;
+        header->hashValue[1] = md5ctx.b;
+        header->hashValue[2] = md5ctx.c;
+        header->hashValue[3] = md5ctx.d;
+
+        return true;
+    }
+} // namespace AZ

+ 36 - 0
Gems/Atom/RHI/DX12/Code/Source/RHI/ShaderUtils.h

@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#pragma once
+
+#include <Atom/RHI/PipelineStateDescriptor.h>
+#include <Atom/RHI.Reflect/DX12/ShaderStageFunction.h>
+
+namespace AZ::DX12
+{
+    namespace ShaderUtils
+    {
+        //! Patch a shader bytecode with the proper values of the specialization constants found in the
+        //! pipeline descriptor.
+        ShaderByteCode PatchShaderFunction(
+            const ShaderStageFunction& shaderFunction,
+            const RHI::PipelineStateDescriptor& descriptor);
+
+        //! Patch a shader bytecode with the proper values of the specialization constants found in the
+        //! pipeline descriptor. If the pipeline descriptor is not using specialization constants, it returns the
+        //! shader bytecode unchanged. If it needs to patch it, the patched shader bytecode is stored in the provided container.
+        //! Refer to RFC (https://github.com/o3de/sig-graphics-audio/blob/main/rfcs/SpecializationConstants/SpecializationConstants.md)
+        //! for more details on how specialization constants works on DX12 
+        ShaderByteCodeView PatchShaderFunction(
+            const ShaderStageFunction& shaderFunction,
+            const RHI::PipelineStateDescriptor& descriptor,
+            AZStd::vector<ShaderByteCode>& patchedShaderContainer);
+
+        //! Signs a DXIL blob so it can be used by the driver. Only needed if the bytecode has been modified.
+        bool SignByteCode(ShaderByteCode& bytecode);
+    }
+} // namespace AZ

+ 2 - 0
Gems/Atom/RHI/DX12/Code/atom_rhi_dx12_private_common_files.cmake

@@ -123,4 +123,6 @@ set(FILES
     Source/RHI/RayTracingPipelineState.h
     Source/RHI/RayTracingPipelineState.h
     Source/RHI/RayTracingShaderTable.cpp
     Source/RHI/RayTracingShaderTable.cpp
     Source/RHI/RayTracingShaderTable.h
     Source/RHI/RayTracingShaderTable.h
+    Source/RHI/ShaderUtils.cpp
+    Source/RHI/ShaderUtils.h
 )
 )

+ 12 - 0
Gems/Atom/RHI/DX12/Code/openssl_md5_files.cmake

@@ -0,0 +1,12 @@
+#
+# Copyright (c) Contributors to the Open 3D Engine Project.
+# For complete copyright and license terms please see the LICENSE at the root of this distribution.
+#
+# SPDX-License-Identifier: Apache-2.0 OR MIT
+#
+#
+
+set(FILES
+    ../External/md5/openssl/md5.c
+    ../External/md5/openssl/md5.h
+)

+ 38 - 0
Gems/Atom/RHI/DX12/External/md5/README.md

@@ -0,0 +1,38 @@
+Fetched from https://openwall.info/wiki/people/solar/software/public-domain-source-code/md5 on 2021-07-07
+
+Public domain licensed:
+
+> This is an OpenSSL-compatible implementation of the RSA Data Security, Inc.
+> MD5 Message-Digest Algorithm (RFC 1321).
+>
+> Homepage:
+> http://openwall.info/wiki/people/solar/software/public-domain-source-code/md5
+>
+> Author:
+> Alexander Peslyak, better known as Solar Designer <solar at openwall.com>
+>
+> This software was written by Alexander Peslyak in 2001.  No copyright is
+> claimed, and the software is hereby placed in the public domain.
+> In case this attempt to disclaim copyright and place the software in the
+> public domain is deemed null and void, then the software is
+> Copyright (c) 2001 Alexander Peslyak and it is hereby released to the
+> general public under the following terms:
+>
+> Redistribution and use in source and binary forms, with or without
+> modification, are permitted.
+>
+> There's ABSOLUTELY NO WARRANTY, express or implied.
+>
+> (This is a heavily cut-down "BSD license".)
+>
+> This differs from Colin Plumb's older public domain implementation in that
+> no exactly 32-bit integer data type is required (any 32-bit or wider
+> unsigned integer data type will do), there's no compile-time endianness
+> configuration, and the function prototypes match OpenSSL's.  No code from
+> Colin Plumb's implementation has been reused; this comment merely compares
+> the properties of the two independent implementations.
+>
+> The primary goals of this implementation are portability and ease of use.
+> It is meant to be fast, but not as fast as possible.  Some known
+> optimizations are not included to reduce source code size and avoid
+> compile-time configuration.

+ 289 - 0
Gems/Atom/RHI/DX12/External/md5/openssl/md5.c

@@ -0,0 +1,289 @@
+/*
+ * This is an OpenSSL-compatible implementation of the RSA Data Security, Inc.
+ * MD5 Message-Digest Algorithm (RFC 1321).
+ *
+ * Homepage:
+ * http://openwall.info/wiki/people/solar/software/public-domain-source-code/md5
+ *
+ * Author:
+ * Alexander Peslyak, better known as Solar Designer <solar at openwall.com>
+ *
+ * This software was written by Alexander Peslyak in 2001.  No copyright is
+ * claimed, and the software is hereby placed in the public domain.
+ * In case this attempt to disclaim copyright and place the software in the
+ * public domain is deemed null and void, then the software is
+ * Copyright (c) 2001 Alexander Peslyak and it is hereby released to the
+ * general public under the following terms:
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted.
+ *
+ * There's ABSOLUTELY NO WARRANTY, express or implied.
+ *
+ * (This is a heavily cut-down "BSD license".)
+ *
+ * This differs from Colin Plumb's older public domain implementation in that
+ * no exactly 32-bit integer data type is required (any 32-bit or wider
+ * unsigned integer data type will do), there's no compile-time endianness
+ * configuration, and the function prototypes match OpenSSL's.  No code from
+ * Colin Plumb's implementation has been reused; this comment merely compares
+ * the properties of the two independent implementations.
+ *
+ * The primary goals of this implementation are portability and ease of use.
+ * It is meant to be fast, but not as fast as possible.  Some known
+ * optimizations are not included to reduce source code size and avoid
+ * compile-time configuration.
+ */
+
+#ifndef HAVE_OPENSSL
+
+#include "md5.h"
+
+/*
+ * The basic MD5 functions.
+ *
+ * F and G are optimized compared to their RFC 1321 definitions for
+ * architectures that lack an AND-NOT instruction, just like in Colin Plumb's
+ * implementation.
+ */
+#define F(x, y, z)			((z) ^ ((x) & ((y) ^ (z))))
+#define G(x, y, z)			((y) ^ ((z) & ((x) ^ (y))))
+#define H(x, y, z)			(((x) ^ (y)) ^ (z))
+#define H2(x, y, z)			((x) ^ ((y) ^ (z)))
+#define I(x, y, z)			((y) ^ ((x) | ~(z)))
+
+/*
+ * The MD5 transformation for all four rounds.
+ */
+#define STEP(f, a, b, c, d, x, t, s) \
+	(a) += f((b), (c), (d)) + (x) + (t); \
+	(a) = (((a) << (s)) | (((a) & 0xffffffff) >> (32 - (s)))); \
+	(a) += (b);
+
+/*
+ * SET reads 4 input bytes in little-endian byte order and stores them in a
+ * properly aligned word in host byte order.
+ *
+ * The check for little-endian architectures that tolerate unaligned memory
+ * accesses is just an optimization.  Nothing will break if it fails to detect
+ * a suitable architecture.
+ *
+ * Unfortunately, this optimization may be a C strict aliasing rules violation
+ * if the caller's data buffer has effective type that cannot be aliased by
+ * MD5_u32plus.  In practice, this problem may occur if these MD5 routines are
+ * inlined into a calling function, or with future and dangerously advanced
+ * link-time optimizations.  For the time being, keeping these MD5 routines in
+ * their own translation unit avoids the problem.
+ */
+#if defined(__i386__) || defined(__x86_64__) || defined(__vax__)
+#define SET(n) \
+	(*(_MD5_u32plus *)&ptr[(n) * 4])
+#define GET(n) \
+	SET(n)
+#else
+#define SET(n) \
+	(ctx->block[(n)] = \
+	(_MD5_u32plus)ptr[(n) * 4] | \
+	((_MD5_u32plus)ptr[(n) * 4 + 1] << 8) | \
+	((_MD5_u32plus)ptr[(n) * 4 + 2] << 16) | \
+	((_MD5_u32plus)ptr[(n) * 4 + 3] << 24))
+#define GET(n) \
+	(ctx->block[(n)])
+#endif
+
+/*
+ * This processes one or more 64-byte data blocks, but does NOT update the bit
+ * counters.  There are no alignment requirements.
+ */
+static const void *body(_MD5_CTX *ctx, const void *data, unsigned long size)
+{
+	const unsigned char *ptr;
+	_MD5_u32plus a, b, c, d;
+	_MD5_u32plus saved_a, saved_b, saved_c, saved_d;
+
+	ptr = (const unsigned char *)data;
+
+	a = ctx->a;
+	b = ctx->b;
+	c = ctx->c;
+	d = ctx->d;
+
+	do {
+		saved_a = a;
+		saved_b = b;
+		saved_c = c;
+		saved_d = d;
+
+/* Round 1 */
+		STEP(F, a, b, c, d, SET(0), 0xd76aa478, 7)
+		STEP(F, d, a, b, c, SET(1), 0xe8c7b756, 12)
+		STEP(F, c, d, a, b, SET(2), 0x242070db, 17)
+		STEP(F, b, c, d, a, SET(3), 0xc1bdceee, 22)
+		STEP(F, a, b, c, d, SET(4), 0xf57c0faf, 7)
+		STEP(F, d, a, b, c, SET(5), 0x4787c62a, 12)
+		STEP(F, c, d, a, b, SET(6), 0xa8304613, 17)
+		STEP(F, b, c, d, a, SET(7), 0xfd469501, 22)
+		STEP(F, a, b, c, d, SET(8), 0x698098d8, 7)
+		STEP(F, d, a, b, c, SET(9), 0x8b44f7af, 12)
+		STEP(F, c, d, a, b, SET(10), 0xffff5bb1, 17)
+		STEP(F, b, c, d, a, SET(11), 0x895cd7be, 22)
+		STEP(F, a, b, c, d, SET(12), 0x6b901122, 7)
+		STEP(F, d, a, b, c, SET(13), 0xfd987193, 12)
+		STEP(F, c, d, a, b, SET(14), 0xa679438e, 17)
+		STEP(F, b, c, d, a, SET(15), 0x49b40821, 22)
+
+/* Round 2 */
+		STEP(G, a, b, c, d, GET(1), 0xf61e2562, 5)
+		STEP(G, d, a, b, c, GET(6), 0xc040b340, 9)
+		STEP(G, c, d, a, b, GET(11), 0x265e5a51, 14)
+		STEP(G, b, c, d, a, GET(0), 0xe9b6c7aa, 20)
+		STEP(G, a, b, c, d, GET(5), 0xd62f105d, 5)
+		STEP(G, d, a, b, c, GET(10), 0x02441453, 9)
+		STEP(G, c, d, a, b, GET(15), 0xd8a1e681, 14)
+		STEP(G, b, c, d, a, GET(4), 0xe7d3fbc8, 20)
+		STEP(G, a, b, c, d, GET(9), 0x21e1cde6, 5)
+		STEP(G, d, a, b, c, GET(14), 0xc33707d6, 9)
+		STEP(G, c, d, a, b, GET(3), 0xf4d50d87, 14)
+		STEP(G, b, c, d, a, GET(8), 0x455a14ed, 20)
+		STEP(G, a, b, c, d, GET(13), 0xa9e3e905, 5)
+		STEP(G, d, a, b, c, GET(2), 0xfcefa3f8, 9)
+		STEP(G, c, d, a, b, GET(7), 0x676f02d9, 14)
+		STEP(G, b, c, d, a, GET(12), 0x8d2a4c8a, 20)
+
+/* Round 3 */
+		STEP(H, a, b, c, d, GET(5), 0xfffa3942, 4)
+		STEP(H2, d, a, b, c, GET(8), 0x8771f681, 11)
+		STEP(H, c, d, a, b, GET(11), 0x6d9d6122, 16)
+		STEP(H2, b, c, d, a, GET(14), 0xfde5380c, 23)
+		STEP(H, a, b, c, d, GET(1), 0xa4beea44, 4)
+		STEP(H2, d, a, b, c, GET(4), 0x4bdecfa9, 11)
+		STEP(H, c, d, a, b, GET(7), 0xf6bb4b60, 16)
+		STEP(H2, b, c, d, a, GET(10), 0xbebfbc70, 23)
+		STEP(H, a, b, c, d, GET(13), 0x289b7ec6, 4)
+		STEP(H2, d, a, b, c, GET(0), 0xeaa127fa, 11)
+		STEP(H, c, d, a, b, GET(3), 0xd4ef3085, 16)
+		STEP(H2, b, c, d, a, GET(6), 0x04881d05, 23)
+		STEP(H, a, b, c, d, GET(9), 0xd9d4d039, 4)
+		STEP(H2, d, a, b, c, GET(12), 0xe6db99e5, 11)
+		STEP(H, c, d, a, b, GET(15), 0x1fa27cf8, 16)
+		STEP(H2, b, c, d, a, GET(2), 0xc4ac5665, 23)
+
+/* Round 4 */
+		STEP(I, a, b, c, d, GET(0), 0xf4292244, 6)
+		STEP(I, d, a, b, c, GET(7), 0x432aff97, 10)
+		STEP(I, c, d, a, b, GET(14), 0xab9423a7, 15)
+		STEP(I, b, c, d, a, GET(5), 0xfc93a039, 21)
+		STEP(I, a, b, c, d, GET(12), 0x655b59c3, 6)
+		STEP(I, d, a, b, c, GET(3), 0x8f0ccc92, 10)
+		STEP(I, c, d, a, b, GET(10), 0xffeff47d, 15)
+		STEP(I, b, c, d, a, GET(1), 0x85845dd1, 21)
+		STEP(I, a, b, c, d, GET(8), 0x6fa87e4f, 6)
+		STEP(I, d, a, b, c, GET(15), 0xfe2ce6e0, 10)
+		STEP(I, c, d, a, b, GET(6), 0xa3014314, 15)
+		STEP(I, b, c, d, a, GET(13), 0x4e0811a1, 21)
+		STEP(I, a, b, c, d, GET(4), 0xf7537e82, 6)
+		STEP(I, d, a, b, c, GET(11), 0xbd3af235, 10)
+		STEP(I, c, d, a, b, GET(2), 0x2ad7d2bb, 15)
+		STEP(I, b, c, d, a, GET(9), 0xeb86d391, 21)
+
+		a += saved_a;
+		b += saved_b;
+		c += saved_c;
+		d += saved_d;
+
+		ptr += 64;
+	} while (size -= 64);
+
+	ctx->a = a;
+	ctx->b = b;
+	ctx->c = c;
+	ctx->d = d;
+
+	return ptr;
+}
+
+void _MD5_Init(_MD5_CTX *ctx)
+{
+	ctx->a = 0x67452301;
+	ctx->b = 0xefcdab89;
+	ctx->c = 0x98badcfe;
+	ctx->d = 0x10325476;
+
+	ctx->lo = 0;
+	ctx->hi = 0;
+}
+
+void _MD5_Update(_MD5_CTX *ctx, const void *data, unsigned long size)
+{
+	_MD5_u32plus saved_lo;
+	unsigned long used, available;
+
+	saved_lo = ctx->lo;
+	if ((ctx->lo = (saved_lo + size) & 0x1fffffff) < saved_lo)
+		ctx->hi++;
+	ctx->hi += size >> 29;
+
+	used = saved_lo & 0x3f;
+
+	if (used) {
+		available = 64 - used;
+
+		if (size < available) {
+			memcpy(&ctx->buffer[used], data, size);
+			return;
+		}
+
+		memcpy(&ctx->buffer[used], data, available);
+		data = (const unsigned char *)data + available;
+		size -= available;
+		body(ctx, ctx->buffer, 64);
+	}
+
+	if (size >= 64) {
+		data = body(ctx, data, size & ~(unsigned long)0x3f);
+		size &= 0x3f;
+	}
+
+	memcpy(ctx->buffer, data, size);
+}
+
+#define OUT(dst, src) \
+	(dst)[0] = (unsigned char)(src); \
+	(dst)[1] = (unsigned char)((src) >> 8); \
+	(dst)[2] = (unsigned char)((src) >> 16); \
+	(dst)[3] = (unsigned char)((src) >> 24);
+
+void _MD5_Final(unsigned char *result, _MD5_CTX *ctx)
+{
+	unsigned long used, available;
+
+	used = ctx->lo & 0x3f;
+
+	ctx->buffer[used++] = 0x80;
+
+	available = 64 - used;
+
+	if (available < 8) {
+		memset(&ctx->buffer[used], 0, available);
+		body(ctx, ctx->buffer, 64);
+		used = 0;
+		available = 64;
+	}
+
+	memset(&ctx->buffer[used], 0, available - 8);
+
+	ctx->lo <<= 3;
+	OUT(&ctx->buffer[56], ctx->lo)
+	OUT(&ctx->buffer[60], ctx->hi)
+
+	body(ctx, ctx->buffer, 64);
+
+	OUT(&result[0], ctx->a)
+	OUT(&result[4], ctx->b)
+	OUT(&result[8], ctx->c)
+	OUT(&result[12], ctx->d)
+
+	memset(ctx, 0, sizeof(*ctx));
+}
+
+#endif

+ 54 - 0
Gems/Atom/RHI/DX12/External/md5/openssl/md5.h

@@ -0,0 +1,54 @@
+/*
+ * This is an OpenSSL-compatible implementation of the RSA Data Security, Inc.
+ * MD5 Message-Digest Algorithm (RFC 1321).
+ *
+ * Homepage:
+ * http://openwall.info/wiki/people/solar/software/public-domain-source-code/md5
+ *
+ * Author:
+ * Alexander Peslyak, better known as Solar Designer <solar at openwall.com>
+ *
+ * This software was written by Alexander Peslyak in 2001.  No copyright is
+ * claimed, and the software is hereby placed in the public domain.
+ * In case this attempt to disclaim copyright and place the software in the
+ * public domain is deemed null and void, then the software is
+ * Copyright (c) 2001 Alexander Peslyak and it is hereby released to the
+ * general public under the following terms:
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted.
+ *
+ * There's ABSOLUTELY NO WARRANTY, express or implied.
+ *
+ * See md5.c for more information.
+ */
+
+#ifdef HAVE_OPENSSL
+#include <openssl/md5.h>
+#elif !defined(_MD5_H)
+#define _MD5_H
+
+/* Added by baldurk, for C++ compatibility */
+#if defined(__cplusplus)
+extern "C" {
+#endif
+
+/* Any 32-bit or wider unsigned integer data type will do */
+typedef unsigned int _MD5_u32plus;
+
+typedef struct {
+	_MD5_u32plus lo, hi;
+	_MD5_u32plus a, b, c, d;
+	unsigned char buffer[64];
+	_MD5_u32plus block[16];
+} _MD5_CTX;
+
+extern void _MD5_Init(_MD5_CTX *ctx);
+extern void _MD5_Update(_MD5_CTX *ctx, const void *data, unsigned long size);
+extern void _MD5_Final(unsigned char *result, _MD5_CTX *ctx);
+
+#if defined(__cplusplus)
+};    // extern "C"
+#endif
+
+#endif

+ 2 - 1
Gems/Atom/RHI/Metal/Code/Source/RHI.Builders/ShaderPlatformInterface.cpp

@@ -173,7 +173,8 @@ namespace AZ
            RHI::ShaderHardwareStage shaderStage,
            RHI::ShaderHardwareStage shaderStage,
            const AZStd::string& tempFolderPath,
            const AZStd::string& tempFolderPath,
            StageDescriptor& outputDescriptor,
            StageDescriptor& outputDescriptor,
-           const RHI::ShaderBuildArguments& shaderBuildArguments) const
+           const RHI::ShaderBuildArguments& shaderBuildArguments,
+           [[maybe_unused]] const bool useSpecializationConstants) const
         {
         {
             for ([[maybe_unused]] auto srgLayout : m_srgLayouts)
             for ([[maybe_unused]] auto srgLayout : m_srgLayouts)
             {
             {

+ 2 - 1
Gems/Atom/RHI/Metal/Code/Source/RHI.Builders/ShaderPlatformInterface.h

@@ -51,7 +51,8 @@ namespace AZ
                 RHI::ShaderHardwareStage shaderStage,
                 RHI::ShaderHardwareStage shaderStage,
                 const AZStd::string& tempFolderPath,
                 const AZStd::string& tempFolderPath,
                 StageDescriptor& outputDescriptor,
                 StageDescriptor& outputDescriptor,
-                const RHI::ShaderBuildArguments& shaderBuildArguments) const override;
+                const RHI::ShaderBuildArguments& shaderBuildArguments,
+                const bool useSpecializationConstants) const override; 
 
 
             const char* GetAzslHeader(const AssetBuilderSDK::PlatformInfo& platform) const override;
             const char* GetAzslHeader(const AssetBuilderSDK::PlatformInfo& platform) const override;
 
 

+ 41 - 7
Gems/Atom/RHI/Metal/Code/Source/RHI/PipelineState.cpp

@@ -69,7 +69,7 @@ namespace AZ
             return m_primitiveTopology;
             return m_primitiveTopology;
         }
         }
         
         
-        id<MTLFunction> PipelineState::CompileShader(id<MTLDevice> mtlDevice, const AZStd::string_view sourceStr, const AZStd::string_view entryPoint, const ShaderStageFunction* shaderFunction)
+        id<MTLFunction> PipelineState::CompileShader(id<MTLDevice> mtlDevice, const AZStd::string_view sourceStr, const AZStd::string_view entryPoint, const ShaderStageFunction* shaderFunction, MTLFunctionConstantValues* constantValues)
         {
         {
             id<MTLFunction> pFunction = nullptr;
             id<MTLFunction> pFunction = nullptr;
             NSString* source = [[NSString alloc] initWithCString : sourceStr.data() encoding: NSASCIIStringEncoding];
             NSString* source = [[NSString alloc] initWithCString : sourceStr.data() encoding: NSASCIIStringEncoding];
@@ -125,7 +125,7 @@ namespace AZ
             if (lib)
             if (lib)
             {
             {
                 NSString* entryPointStr = [[NSString alloc] initWithCString : entryPoint.data() encoding: NSASCIIStringEncoding];
                 NSString* entryPointStr = [[NSString alloc] initWithCString : entryPoint.data() encoding: NSASCIIStringEncoding];
-                pFunction = [lib newFunctionWithName:entryPointStr];
+                pFunction = [lib newFunctionWithName:entryPointStr constantValues:constantValues error:&error];
                 [entryPointStr release];
                 [entryPointStr release];
                 entryPointStr = nil;
                 entryPointStr = nil;
                 [lib release];
                 [lib release];
@@ -170,9 +170,11 @@ namespace AZ
             [vertexDescriptor release];
             [vertexDescriptor release];
             vertexDescriptor = nil;
             vertexDescriptor = nil;
             
             
-            m_renderPipelineDesc.vertexFunction = ExtractMtlFunction(device.GetMtlDevice(), descriptor.m_vertexFunction.get());
+            MTLFunctionConstantValues* constantValues = CreateFunctionConstantsValues(descriptor);
+            
+            m_renderPipelineDesc.vertexFunction = ExtractMtlFunction(device.GetMtlDevice(), descriptor.m_vertexFunction.get(), constantValues);
             AZ_Assert(m_renderPipelineDesc.vertexFunction, "Vertex mtlFuntion can not be null");
             AZ_Assert(m_renderPipelineDesc.vertexFunction, "Vertex mtlFuntion can not be null");
-            m_renderPipelineDesc.fragmentFunction = ExtractMtlFunction(device.GetMtlDevice(), descriptor.m_fragmentFunction.get());
+            m_renderPipelineDesc.fragmentFunction = ExtractMtlFunction(device.GetMtlDevice(), descriptor.m_fragmentFunction.get(), constantValues);
             
             
             RHI::Format depthStencilFormat = attachmentsConfiguration.GetDepthStencilFormat();
             RHI::Format depthStencilFormat = attachmentsConfiguration.GetDepthStencilFormat();
             if(descriptor.m_renderStates.m_depthStencilState.m_stencil.m_enable || IsDepthStencilMerged(depthStencilFormat))
             if(descriptor.m_renderStates.m_depthStencilState.m_stencil.m_enable || IsDepthStencilMerged(depthStencilFormat))
@@ -226,6 +228,8 @@ namespace AZ
                 m_renderPipelineDesc = nil;
                 m_renderPipelineDesc = nil;
             }
             }
             
             
+            [constantValues release];
+            constantValues = nil;
              
              
             m_pipelineStateMultiSampleState = descriptor.m_renderStates.m_multisampleState;
             m_pipelineStateMultiSampleState = descriptor.m_renderStates.m_multisampleState;
             
             
@@ -257,7 +261,8 @@ namespace AZ
             m_computePipelineDesc = [[MTLComputePipelineDescriptor alloc] init];
             m_computePipelineDesc = [[MTLComputePipelineDescriptor alloc] init];
             RHI::ConstPtr<PipelineLayout> pipelineLayout = device.AcquirePipelineLayout(*descriptor.m_pipelineLayoutDescriptor);
             RHI::ConstPtr<PipelineLayout> pipelineLayout = device.AcquirePipelineLayout(*descriptor.m_pipelineLayoutDescriptor);
             AZ_Assert(pipelineLayout, "PipelineLayout can not be null");
             AZ_Assert(pipelineLayout, "PipelineLayout can not be null");
-            m_computePipelineDesc.computeFunction = ExtractMtlFunction(device.GetMtlDevice(), descriptor.m_computeFunction.get());
+            MTLFunctionConstantValues* constantValues = CreateFunctionConstantsValues(descriptor);
+            m_computePipelineDesc.computeFunction = ExtractMtlFunction(device.GetMtlDevice(), descriptor.m_computeFunction.get(), constantValues);
             AZ_Assert(m_computePipelineDesc.computeFunction, "Compute mtlFuntion can not be null");
             AZ_Assert(m_computePipelineDesc.computeFunction, "Compute mtlFuntion can not be null");
             
             
             PipelineLibrary* pipelineLibrary = static_cast<PipelineLibrary*>(pipelineLibraryBase);
             PipelineLibrary* pipelineLibrary = static_cast<PipelineLibrary*>(pipelineLibraryBase);
@@ -279,6 +284,9 @@ namespace AZ
                 [m_computePipelineDesc release];
                 [m_computePipelineDesc release];
                 m_computePipelineDesc = nil;
                 m_computePipelineDesc = nil;
             }
             }
+                                                                       
+            [constantValues release];
+            constantValues = nil;
             
             
             if (m_computePipelineState)
             if (m_computePipelineState)
             {
             {
@@ -300,7 +308,7 @@ namespace AZ
             return RHI::ResultCode::Fail;
             return RHI::ResultCode::Fail;
         }
         }
 
 
-        id<MTLFunction> PipelineState::ExtractMtlFunction(id<MTLDevice> mtlDevice, const RHI::ShaderStageFunction* stageFunc)
+        id<MTLFunction> PipelineState::ExtractMtlFunction(id<MTLDevice> mtlDevice, const RHI::ShaderStageFunction* stageFunc, MTLFunctionConstantValues* constantValues)
         {
         {
             // set the bound shader state settings
             // set the bound shader state settings
             if (stageFunc)
             if (stageFunc)
@@ -309,13 +317,39 @@ namespace AZ
                 AZStd::string_view strView(shaderFunction->GetSourceCode());
                 AZStd::string_view strView(shaderFunction->GetSourceCode());
                 
                 
                 id<MTLFunction> mtlFunction = nil;
                 id<MTLFunction> mtlFunction = nil;
-                mtlFunction = CompileShader(mtlDevice, strView, shaderFunction->GetEntryFunctionName(), shaderFunction);
+                mtlFunction = CompileShader(mtlDevice, strView, shaderFunction->GetEntryFunctionName(), shaderFunction, constantValues);
 
 
                 return mtlFunction;
                 return mtlFunction;
             }
             }
             
             
             return nil;
             return nil;
         }
         }
+    
+        MTLFunctionConstantValues* PipelineState::CreateFunctionConstantsValues(const RHI::PipelineStateDescriptor& pipelineDescriptor) const
+        {
+            MTLFunctionConstantValues* constantValues = [[MTLFunctionConstantValues alloc] init];
+            for(const auto& specializationData : pipelineDescriptor.m_specializationData)
+            {
+                uint32_t value = specializationData.m_value.GetIndex();
+                MTLDataType type;
+                switch(specializationData.m_type)
+                {
+                    case RHI::SpecializationType::Integer:
+                        type = MTLDataTypeInt;
+                        break;
+                    case RHI::SpecializationType::Bool:
+                        type = MTLDataTypeBool;
+                        break;
+                    default:
+                        AZ_Assert(false, "Invalid specialization type %d", specializationData.m_type);
+                        type = MTLDataTypeInt;
+                        break;
+                }
+                [constantValues setConstantValue:&value type:type atIndex:specializationData.m_id];
+            }
+            return constantValues;
+        }
+    
         void PipelineState::ShutdownInternal()
         void PipelineState::ShutdownInternal()
         {
         {
             if (m_graphicsPipelineState)
             if (m_graphicsPipelineState)

+ 3 - 2
Gems/Atom/RHI/Metal/Code/Source/RHI/PipelineState.h

@@ -71,8 +71,9 @@ namespace AZ
             void ShutdownInternal() override;
             void ShutdownInternal() override;
             //////////////////////////////////////////////////////////////////////////
             //////////////////////////////////////////////////////////////////////////
 
 
-            id<MTLFunction> CompileShader(id<MTLDevice> mtlDevice, const AZStd::string_view filePath, const AZStd::string_view entryPoint, const ShaderStageFunction* shaderFunction);
-            id<MTLFunction> ExtractMtlFunction(id<MTLDevice> mtlDevice, const RHI::ShaderStageFunction* stageFunc);
+            id<MTLFunction> CompileShader(id<MTLDevice> mtlDevice, const AZStd::string_view filePath, const AZStd::string_view entryPoint, const ShaderStageFunction* shaderFunction, MTLFunctionConstantValues* constantValues);
+            id<MTLFunction> ExtractMtlFunction(id<MTLDevice> mtlDevice, const RHI::ShaderStageFunction* stageFunc, MTLFunctionConstantValues* constantValues);
+            MTLFunctionConstantValues* CreateFunctionConstantsValues(const RHI::PipelineStateDescriptor& pipelineDescriptor) const;
             
             
             RHI::ConstPtr<PipelineLayout> m_pipelineLayout;
             RHI::ConstPtr<PipelineLayout> m_pipelineLayout;
             AZStd::atomic_bool m_isCompiled = {false};
             AZStd::atomic_bool m_isCompiled = {false};

+ 2 - 1
Gems/Atom/RHI/Null/Code/Source/RHI.Builders/ShaderPlatformInterface.cpp

@@ -100,7 +100,8 @@ namespace AZ
             [[maybe_unused]] const AssetBuilderSDK::PlatformInfo& platform, [[maybe_unused]] const AZStd::string& shaderSourcePath,
             [[maybe_unused]] const AssetBuilderSDK::PlatformInfo& platform, [[maybe_unused]] const AZStd::string& shaderSourcePath,
             [[maybe_unused]] const AZStd::string& functionName, [[maybe_unused]] RHI::ShaderHardwareStage shaderStage,
             [[maybe_unused]] const AZStd::string& functionName, [[maybe_unused]] RHI::ShaderHardwareStage shaderStage,
             [[maybe_unused]] const AZStd::string& tempFolderPath, [[maybe_unused]] StageDescriptor& outputDescriptor,
             [[maybe_unused]] const AZStd::string& tempFolderPath, [[maybe_unused]] StageDescriptor& outputDescriptor,
-            [[maybe_unused]] const RHI::ShaderBuildArguments& shaderBuildArguments) const
+            [[maybe_unused]] const RHI::ShaderBuildArguments& shaderBuildArguments,
+            [[maybe_unused]] const bool useSpecializationConstants) const
         {
         {
             outputDescriptor.m_stageType = shaderStage;
             outputDescriptor.m_stageType = shaderStage;
             return true;
             return true;

+ 2 - 1
Gems/Atom/RHI/Null/Code/Source/RHI.Builders/ShaderPlatformInterface.h

@@ -33,7 +33,8 @@ namespace AZ
             bool CompilePlatformInternal(
             bool CompilePlatformInternal(
                 const AssetBuilderSDK::PlatformInfo& platform, const AZStd::string& shaderSourcePath, const AZStd::string& functionName,
                 const AssetBuilderSDK::PlatformInfo& platform, const AZStd::string& shaderSourcePath, const AZStd::string& functionName,
                 RHI::ShaderHardwareStage shaderStage, const AZStd::string& tempFolderPath, StageDescriptor& outputDescriptor,
                 RHI::ShaderHardwareStage shaderStage, const AZStd::string& tempFolderPath, StageDescriptor& outputDescriptor,
-                const RHI::ShaderBuildArguments& shaderBuildArguments) const override;
+                const RHI::ShaderBuildArguments& shaderBuildArguments,
+                const bool useSpecializationConstants) const override;
             const char* GetAzslHeader(const AssetBuilderSDK::PlatformInfo& platform) const override;
             const char* GetAzslHeader(const AssetBuilderSDK::PlatformInfo& platform) const override;
             bool BuildPipelineLayoutDescriptor(
             bool BuildPipelineLayoutDescriptor(
                 RHI::Ptr<RHI::PipelineLayoutDescriptor> pipelineLayoutDescriptor,
                 RHI::Ptr<RHI::PipelineLayoutDescriptor> pipelineLayoutDescriptor,

+ 2 - 1
Gems/Atom/RHI/Vulkan/Code/Source/RHI.Builders/ShaderPlatformInterface.cpp

@@ -112,7 +112,8 @@ namespace AZ
             RHI::ShaderHardwareStage shaderAssetType,
             RHI::ShaderHardwareStage shaderAssetType,
             const AZStd::string& tempFolderPath,
             const AZStd::string& tempFolderPath,
             StageDescriptor& outputDescriptor,
             StageDescriptor& outputDescriptor,
-            const RHI::ShaderBuildArguments& shaderBuildArguments) const
+            const RHI::ShaderBuildArguments& shaderBuildArguments,
+            [[maybe_unused]] const bool useSpecializationConstants) const
         {
         {
             AZStd::vector<uint8_t> shaderByteCode;
             AZStd::vector<uint8_t> shaderByteCode;
 
 

+ 2 - 1
Gems/Atom/RHI/Vulkan/Code/Source/RHI.Builders/ShaderPlatformInterface.h

@@ -49,7 +49,8 @@ namespace AZ
                 RHI::ShaderHardwareStage shaderStage,
                 RHI::ShaderHardwareStage shaderStage,
                 const AZStd::string& tempFolderPath,
                 const AZStd::string& tempFolderPath,
                 StageDescriptor& outputDescriptor,
                 StageDescriptor& outputDescriptor,
-                const RHI::ShaderBuildArguments& shaderBuildArguments) const override;
+                const RHI::ShaderBuildArguments& shaderBuildArguments,
+                const bool useSpecializationConstants) const override;
 
 
             const char* GetAzslHeader(const AssetBuilderSDK::PlatformInfo& platform) const override;
             const char* GetAzslHeader(const AssetBuilderSDK::PlatformInfo& platform) const override;
 
 

+ 3 - 1
Gems/Atom/RHI/Vulkan/Code/Source/RHI/Pipeline.cpp

@@ -32,6 +32,7 @@ namespace AZ
             }
             }
 
 
             Base::Init(*descriptor.m_device);
             Base::Init(*descriptor.m_device);
+            m_specializationConstantData.Init(*descriptor.m_pipelineDescritor);
 
 
             RHI::ResultCode result = InitInternal(descriptor, *layout);
             RHI::ResultCode result = InitInternal(descriptor, *layout);
             RETURN_RESULT_IF_UNSUCCESSFUL(result);
             RETURN_RESULT_IF_UNSUCCESSFUL(result);
@@ -91,6 +92,7 @@ namespace AZ
                 device.GetContext().DestroyPipeline(device.GetNativeDevice(), m_nativePipeline, VkSystemAllocator::Get());
                 device.GetContext().DestroyPipeline(device.GetNativeDevice(), m_nativePipeline, VkSystemAllocator::Get());
                 m_nativePipeline = VK_NULL_HANDLE;
                 m_nativePipeline = VK_NULL_HANDLE;
             }
             }
+            m_specializationConstantData.Shutdown();
             Base::Shutdown();
             Base::Shutdown();
         }
         }
 
 
@@ -137,7 +139,7 @@ namespace AZ
             createInfo.stage = stageBits;
             createInfo.stage = stageBits;
             createInfo.module = shaderModule->GetNativeShaderModule();
             createInfo.module = shaderModule->GetNativeShaderModule();
             createInfo.pName = shaderModule->GetEntryFunctionName().c_str();
             createInfo.pName = shaderModule->GetEntryFunctionName().c_str();
-            createInfo.pSpecializationInfo = nullptr;
+            createInfo.pSpecializationInfo = m_specializationConstantData.GetVkSpecializationInfo();
         }
         }
     }
     }
 }
 }

+ 3 - 0
Gems/Atom/RHI/Vulkan/Code/Source/RHI/Pipeline.h

@@ -14,6 +14,7 @@
 #include <AzCore/std/containers/list.h>
 #include <AzCore/std/containers/list.h>
 #include <RHI/PipelineLayout.h>
 #include <RHI/PipelineLayout.h>
 #include <RHI/ShaderModule.h>
 #include <RHI/ShaderModule.h>
+#include <RHI/SpecializationConstantData.h>
 
 
 namespace AZ
 namespace AZ
 {
 {
@@ -78,6 +79,8 @@ namespace AZ
             RHI::Ptr<PipelineLayout> m_pipelineLayout;
             RHI::Ptr<PipelineLayout> m_pipelineLayout;
             AZStd::list<RHI::Ptr<ShaderModule>> m_shaderModules;
             AZStd::list<RHI::Ptr<ShaderModule>> m_shaderModules;
             VkPipeline m_nativePipeline = VK_NULL_HANDLE;
             VkPipeline m_nativePipeline = VK_NULL_HANDLE;
+            // Contains the values of the specialization constants
+            SpecializationConstantData m_specializationConstantData;
         };
         };
     }
     }
 }
 }

+ 11 - 3
Gems/Atom/RHI/Vulkan/Code/Source/RHI/RayTracingPipelineState.cpp

@@ -7,9 +7,10 @@
  */
  */
 
 
 #include <RHI/RayTracingPipelineState.h>
 #include <RHI/RayTracingPipelineState.h>
+#include <RHI/Device.h>
+#include <RHI/SpecializationConstantData.h>
 #include <Atom/RHI.Reflect/SamplerState.h>
 #include <Atom/RHI.Reflect/SamplerState.h>
 #include <Atom/RHI.Reflect/Vulkan/ShaderStageFunction.h>
 #include <Atom/RHI.Reflect/Vulkan/ShaderStageFunction.h>
-#include <RHI/Device.h>
 #include <Atom/RHI.Reflect/VkAllocator.h>
 #include <Atom/RHI.Reflect/VkAllocator.h>
 
 
 namespace AZ
 namespace AZ
@@ -37,11 +38,14 @@ namespace AZ
             // process shader libraries into shader stages and groups
             // process shader libraries into shader stages and groups
             AZStd::vector<VkPipelineShaderStageCreateInfo> stages;
             AZStd::vector<VkPipelineShaderStageCreateInfo> stages;
             AZStd::vector<VkRayTracingShaderGroupCreateInfoKHR> groups;
             AZStd::vector<VkRayTracingShaderGroupCreateInfoKHR> groups;
+            AZStd::vector<SpecializationConstantData> specializationDataVector(descriptor->GetShaderLibraries().size());
+
             m_shaderModules.reserve(descriptor->GetShaderLibraries().size());
             m_shaderModules.reserve(descriptor->GetShaderLibraries().size());
-            for (const RHI::RayTracingShaderLibrary& shaderLibrary : descriptor->GetShaderLibraries())
+            const auto& libraries = descriptor->GetShaderLibraries();
+            for (uint32_t i = 0; i < libraries.size(); ++i)
             {
             {
+                const RHI::RayTracingShaderLibrary& shaderLibrary = libraries[i];
                 const ShaderStageFunction* rayTracingFunction = azrtti_cast<const ShaderStageFunction*>(shaderLibrary.m_descriptor.m_rayTracingFunction.get());
                 const ShaderStageFunction* rayTracingFunction = azrtti_cast<const ShaderStageFunction*>(shaderLibrary.m_descriptor.m_rayTracingFunction.get());
-
                 VkShaderModule& shaderModule = m_shaderModules.emplace_back();
                 VkShaderModule& shaderModule = m_shaderModules.emplace_back();
                 VkShaderModuleCreateInfo moduleCreateInfo = {};
                 VkShaderModuleCreateInfo moduleCreateInfo = {};
                 moduleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
                 moduleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
@@ -50,9 +54,13 @@ namespace AZ
                 device.GetContext().CreateShaderModule(
                 device.GetContext().CreateShaderModule(
                     device.GetNativeDevice(), &moduleCreateInfo, VkSystemAllocator::Get(), &shaderModule);
                     device.GetNativeDevice(), &moduleCreateInfo, VkSystemAllocator::Get(), &shaderModule);
 
 
+                SpecializationConstantData& specializationData = specializationDataVector[i];
+                specializationData.Init(shaderLibrary.m_descriptor);
+
                 VkPipelineShaderStageCreateInfo stageCreateInfo = {};
                 VkPipelineShaderStageCreateInfo stageCreateInfo = {};
                 stageCreateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
                 stageCreateInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
                 stageCreateInfo.module = shaderModule;
                 stageCreateInfo.module = shaderModule;
+                stageCreateInfo.pSpecializationInfo = specializationData.GetVkSpecializationInfo();
 
 
                 // ray generation
                 // ray generation
                 if (!shaderLibrary.m_rayGenerationShaderName.IsEmpty())
                 if (!shaderLibrary.m_rayGenerationShaderName.IsEmpty())

+ 66 - 0
Gems/Atom/RHI/Vulkan/Code/Source/RHI/SpecializationConstantData.cpp

@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#include <RHI/Vulkan.h>
+#include <RHI/SpecializationConstantData.h>
+
+namespace AZ::Vulkan
+{
+    template<class T>
+    void AddSpecializationValue(AZStd::vector<uint8_t>& data, const RHI::SpecializationValue& specValue)
+    {
+        size_t size = sizeof(T);
+        size_t offset = data.size();
+        data.resize(data.size() + size);
+        T value = static_cast<T>(specValue.GetIndex());
+        memcpy(data.data() + offset, &value, size);
+    }
+
+    RHI::ResultCode SpecializationConstantData::Init(const RHI::PipelineStateDescriptor& pipelineDescriptor)
+    {
+        m_specializationData.reserve(pipelineDescriptor.m_specializationData.size() * sizeof(uint32_t));
+        for (const RHI::SpecializationConstant& specialization : pipelineDescriptor.m_specializationData)
+        {
+            m_specializationMap.emplace_back();
+            VkSpecializationMapEntry& entry = m_specializationMap.back();
+            entry.constantID = specialization.m_id;
+            entry.offset = aznumeric_cast<uint32_t>(m_specializationData.size());
+            switch (specialization.m_type)
+            {
+            case RHI::SpecializationType::Integer:
+                entry.size = sizeof(uint32_t);
+                AddSpecializationValue<uint32_t>(m_specializationData, specialization.m_value);
+                break;
+            case RHI::SpecializationType::Bool:
+                entry.size = sizeof(VkBool32);
+                AddSpecializationValue<VkBool32>(m_specializationData, specialization.m_value);
+                break;
+            default:
+                AZ_Assert(false, "Invalid specialization type %d", specialization.m_type);
+                return RHI::ResultCode::InvalidArgument;
+            }
+        }
+
+        m_specializationInfo.dataSize = m_specializationData.size();
+        m_specializationInfo.mapEntryCount = aznumeric_cast<uint32_t>(m_specializationMap.size());
+        m_specializationInfo.pData = m_specializationData.data();
+        m_specializationInfo.pMapEntries = m_specializationMap.data();
+        return RHI::ResultCode::Success;
+    }
+
+    void SpecializationConstantData::Shutdown()
+    {
+        m_specializationInfo = {};
+        m_specializationMap.clear();
+        m_specializationData.clear();
+    }
+
+    const VkSpecializationInfo* SpecializationConstantData::GetVkSpecializationInfo() const
+    {
+        return m_specializationInfo.mapEntryCount ? &m_specializationInfo : nullptr;
+    }
+} // namespace AZ::Vulkan

+ 37 - 0
Gems/Atom/RHI/Vulkan/Code/Source/RHI/SpecializationConstantData.h

@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#pragma once
+
+#include <Atom/RHI/PipelineStateDescriptor.h>
+#include <AzCore/std/containers/vector.h>
+
+namespace AZ::Vulkan
+{
+    //! Contains the vulkan structure needed for using specialization constants
+    class SpecializationConstantData
+    {
+    public:
+        SpecializationConstantData() = default;
+
+        //! Initialize the contents with the specialization constants information of the pipeline descriptor
+        RHI::ResultCode Init(const RHI::PipelineStateDescriptor& descriptor);
+        //! Release any data previously used
+        void Shutdown();
+
+        //! Returns the vulkan specialization info
+        const VkSpecializationInfo* GetVkSpecializationInfo() const;
+
+    private:
+        // Vulkan structure for using specialization constants
+        VkSpecializationInfo m_specializationInfo{};
+        // Vector with the mapping information of the specialization constants (ids and offsets).
+        AZStd::vector<VkSpecializationMapEntry> m_specializationMap;
+        // Memory buffer with the values of the specialization constants
+        AZStd::vector<uint8_t> m_specializationData;
+    };
+}

+ 2 - 0
Gems/Atom/RHI/Vulkan/Code/atom_rhi_vulkan_private_common_files.cmake

@@ -169,4 +169,6 @@ set(FILES
     Source/RHI/Conversion.cpp
     Source/RHI/Conversion.cpp
     Source/RHI/Conversion.h
     Source/RHI/Conversion.h
     Source/RHI/WindowSurfaceBus.h
     Source/RHI/WindowSurfaceBus.h
+    Source/RHI/SpecializationConstantData.cpp
+    Source/RHI/SpecializationConstantData.h
 )
 )

+ 5 - 1
Gems/Atom/RPI/Code/Include/Atom/RPI.Edit/Shader/ShaderVariantAssetCreator.h

@@ -28,7 +28,11 @@ namespace AZ
             //!        It is still useful, because when creating the Root Variant for the ShaderAsset this assetId should
             //!        It is still useful, because when creating the Root Variant for the ShaderAsset this assetId should
             //!        match the value that will be assigned by the asset processor because the Root Variant is serialized
             //!        match the value that will be assigned by the asset processor because the Root Variant is serialized
             //!        as a Data::Asset<ShaderVariantAsset> inside the ShaderAsset.
             //!        as a Data::Asset<ShaderVariantAsset> inside the ShaderAsset.
-            void Begin(const AZ::Data::AssetId& assetId, const ShaderVariantId& shaderVariantId, RPI::ShaderVariantStableId stableId, bool isFullyBaked);
+            void Begin(
+                const AZ::Data::AssetId& assetId,
+                const ShaderVariantId& shaderVariantId,
+                RPI::ShaderVariantStableId stableId,
+                bool isFullyBaked);
 
 
             //! Finalizes and assigns ownership of the asset to result, if successful. 
             //! Finalizes and assigns ownership of the asset to result, if successful. 
             //! Otherwise false is returned and result is left untouched.
             //! Otherwise false is returned and result is left untouched.

+ 3 - 0
Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Pass/ComputePass.h

@@ -52,6 +52,9 @@ namespace AZ
             using ComputeShaderReloadedCallback = AZStd::function<void(ComputePass* computePass)>;
             using ComputeShaderReloadedCallback = AZStd::function<void(ComputePass* computePass)>;
             void SetComputeShaderReloadedCallback(ComputeShaderReloadedCallback callback);
             void SetComputeShaderReloadedCallback(ComputeShaderReloadedCallback callback);
 
 
+            //! Updates the shader variant being used by the pass
+            void UpdateShaderOptions(const ShaderVariantId& shaderVariantId);
+
         protected:
         protected:
             ComputePass(const PassDescriptor& descriptor, AZ::Name supervariant = AZ::Name(""));
             ComputePass(const PassDescriptor& descriptor, AZ::Name supervariant = AZ::Name(""));
 
 

+ 4 - 0
Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Pass/FullscreenTrianglePass.h

@@ -46,6 +46,7 @@ namespace AZ
 
 
             //! Updates the shader options used in this pass
             //! Updates the shader options used in this pass
             void UpdateShaderOptions(const ShaderOptionList& shaderOptions);
             void UpdateShaderOptions(const ShaderOptionList& shaderOptions);
+            void UpdateShaderOptions(const ShaderVariantId& shaderVariantId);
 
 
         protected:
         protected:
             FullscreenTrianglePass(const PassDescriptor& descriptor);
             FullscreenTrianglePass(const PassDescriptor& descriptor);
@@ -64,6 +65,9 @@ namespace AZ
             void OnShaderAssetReinitialized(const Data::Asset<ShaderAsset>& shaderAsset) override;
             void OnShaderAssetReinitialized(const Data::Asset<ShaderAsset>& shaderAsset) override;
             void OnShaderVariantReinitialized(const ShaderVariant& shaderVariant) override;
             void OnShaderVariantReinitialized(const ShaderVariant& shaderVariant) override;
 
 
+            // Common code when updating the shader variant with new options
+            void UpdateShaderOptionsCommon();
+
             RHI::Viewport m_viewportState;
             RHI::Viewport m_viewportState;
             RHI::Scissor m_scissorState;
             RHI::Scissor m_scissorState;
 
 

+ 4 - 0
Gems/Atom/RPI/Code/Include/Atom/RPI.Public/PipelineState.h

@@ -38,6 +38,7 @@ namespace AZ
             //! Initialize the pipeline state from a shader and one of its shader variant
             //! Initialize the pipeline state from a shader and one of its shader variant
             //! The previous data will be reset
             //! The previous data will be reset
             void Init(const Data::Instance<Shader>& shader, const ShaderOptionList* optionAndValues = nullptr);
             void Init(const Data::Instance<Shader>& shader, const ShaderOptionList* optionAndValues = nullptr);
+            void Init(const Data::Instance<Shader>& shader, const ShaderVariantId& shaderVariantId);
 
 
             //! Update the pipeline state descriptor for the specified scene
             //! Update the pipeline state descriptor for the specified scene
             //! This is usually called when Scene's render pipelines changed
             //! This is usually called when Scene's render pipelines changed
@@ -79,6 +80,9 @@ namespace AZ
             //! Clear all the states and references
             //! Clear all the states and references
             void Shutdown();
             void Shutdown();
 
 
+            //! Returns the id of the shader variant being used
+            const ShaderVariantId& GetShaderVariantId() const;
+
         private:
         private:
             ///////////////////////////////////////////////////////////////////
             ///////////////////////////////////////////////////////////////////
             // ShaderReloadNotificationBus overrides...
             // ShaderReloadNotificationBus overrides...

+ 22 - 2
Gems/Atom/RPI/Code/Include/Atom/RPI.Public/Shader/ShaderVariant.h

@@ -26,9 +26,17 @@ namespace AZ
             virtual ~ShaderVariant();
             virtual ~ShaderVariant();
             AZ_DEFAULT_COPY_MOVE(ShaderVariant);
             AZ_DEFAULT_COPY_MOVE(ShaderVariant);
 
 
-            //! Fills a pipeline state descriptor with settings provided by the ShaderVariant. (Note that
-            //! this does not fill the InputStreamLayout or OutputAttachmentLayout as that also requires 
+            //! Fills a pipeline state descriptor with settings provided by the ShaderVariant. 
+            //! It also configures the specialization constants if they are being used by the shader variant.
+            //! (Note that this does not fill the InputStreamLayout or OutputAttachmentLayout as that also requires 
             //! information from the mesh data and pass system and must be done as a separate step).
             //! information from the mesh data and pass system and must be done as a separate step).
+            void ConfigurePipelineState(RHI::PipelineStateDescriptor& descriptor, const ShaderVariantId& specialization) const;
+            //! Fills a pipeline state descriptor with settings provided by the ShaderVariant.
+            //! It also configures the specialization constants if they are being used by the shader variant.
+            void ConfigurePipelineState(RHI::PipelineStateDescriptor& descriptor, const ShaderOptionGroup& specialization) const;
+            //! Fills a pipeline state descriptor with settings provided by the ShaderVariant.
+            //! Only use this function if the shader variant is not using ANY specialization constant. Otherwise
+            //! an error will be raised and the default values will be used.
             void ConfigurePipelineState(RHI::PipelineStateDescriptor& descriptor) const;
             void ConfigurePipelineState(RHI::PipelineStateDescriptor& descriptor) const;
 
 
             const ShaderVariantId& GetShaderVariantId() const { return m_shaderVariantAsset->GetShaderVariantId(); }
             const ShaderVariantId& GetShaderVariantId() const { return m_shaderVariantAsset->GetShaderVariantId(); }
@@ -38,6 +46,15 @@ namespace AZ
             //! If the shader variant is not fully baked, the ShaderVariantKeyFallbackValue must be correctly set when drawing.
             //! If the shader variant is not fully baked, the ShaderVariantKeyFallbackValue must be correctly set when drawing.
             bool IsFullyBaked() const { return m_shaderVariantAsset->IsFullyBaked(); }
             bool IsFullyBaked() const { return m_shaderVariantAsset->IsFullyBaked(); }
 
 
+            //! Returns whether the variant is using specialization constants for all of the options.
+            bool IsFullySpecialized() const;
+
+            //! Return true if this shader variant has at least one shader option using specialization constant.
+            bool UseSpecializationConstants() const;
+
+            //! Return true if this variant needs the ShaderVariantKeyFallbackValue to be correctly set when drawing.
+            bool UseKeyFallback() const;
+
             //! Return the timestamp when this asset was built.
             //! Return the timestamp when this asset was built.
             //! This is used to synchronize versions of the ShaderAsset and ShaderVariantAsset, especially during hot-reload.
             //! This is used to synchronize versions of the ShaderAsset and ShaderVariantAsset, especially during hot-reload.
             //! This timestamp must be >= than the ShaderAsset timestamp.
             //! This timestamp must be >= than the ShaderAsset timestamp.
@@ -70,6 +87,9 @@ namespace AZ
 
 
             const RHI::RenderStates* m_renderStates = nullptr; // Cached from ShaderAsset.
             const RHI::RenderStates* m_renderStates = nullptr; // Cached from ShaderAsset.
             SupervariantIndex m_supervariantIndex;
             SupervariantIndex m_supervariantIndex;
+
+            // True if there's at least one shader option that is using a specialization constant.
+            bool m_useSpecializationConstants = false;
         };
         };
     }
     }
 }
 }

+ 18 - 0
Gems/Atom/RPI/Code/Include/Atom/RPI.Reflect/Shader/ShaderAsset.h

@@ -216,6 +216,20 @@ namespace AZ
                 return GetAttribute(shaderStage, attributeName, DefaultSupervariantIndex);
                 return GetAttribute(shaderStage, attributeName, DefaultSupervariantIndex);
             }
             }
 
 
+            //! Returns if the supervariant uses specialization constants for at least one shader options.
+            bool UseSpecializationConstants(SupervariantIndex supervariantIndex) const;
+            bool UseSpecializationConstants() const
+            {
+                return UseSpecializationConstants(DefaultSupervariantIndex);
+            }
+
+            //! Returns true if the supervariant is fully specialized (all shader options are specialization constants)
+            bool IsFullySpecialized(SupervariantIndex supervariantIndex) const;
+            bool IsFullySpecialized() const
+            {
+                return IsFullySpecialized(DefaultSupervariantIndex);
+            }
+
         private:
         private:
             ///////////////////////////////////////////////////////////////////
             ///////////////////////////////////////////////////////////////////
             /// ShaderVariantFinderNotificationBus overrides
             /// ShaderVariantFinderNotificationBus overrides
@@ -242,6 +256,7 @@ namespace AZ
                 RHI::RenderStates m_renderStates;
                 RHI::RenderStates m_renderStates;
                 RHI::ShaderStageAttributeMapList m_attributeMaps;
                 RHI::ShaderStageAttributeMapList m_attributeMaps;
                 Data::Asset<ShaderVariantAsset> m_rootShaderVariantAsset;
                 Data::Asset<ShaderVariantAsset> m_rootShaderVariantAsset;
+                bool m_useSpecializationConstants = false;
             };
             };
 
 
             //! Container of shader data that is specific to an RHI API.
             //! Container of shader data that is specific to an RHI API.
@@ -313,6 +328,9 @@ namespace AZ
             mutable AZStd::shared_mutex m_variantTreeMutex;
             mutable AZStd::shared_mutex m_variantTreeMutex;
 
 
             bool m_shaderVariantTreeLoadWasRequested = false;
             bool m_shaderVariantTreeLoadWasRequested = false;
+
+            //! True if all supervariants are fully specialized
+            bool m_isFullySpecialized = false;
         };
         };
 
 
         class ShaderAssetHandler final
         class ShaderAssetHandler final

+ 3 - 0
Gems/Atom/RPI/Code/Include/Atom/RPI.Reflect/Shader/ShaderAssetCreator.h

@@ -73,6 +73,9 @@ namespace AZ
             //! [Required] There's always a root variant for each supervariant.
             //! [Required] There's always a root variant for each supervariant.
             void SetRootShaderVariantAsset(Data::Asset<ShaderVariantAsset> shaderVariantAsset);
             void SetRootShaderVariantAsset(Data::Asset<ShaderVariantAsset> shaderVariantAsset);
 
 
+            //! Set if the supervariant uses specialization constants for shader options.
+            void SetUseSpecializationConstants(bool value);
+
             bool EndSupervariant();
             bool EndSupervariant();
 
 
             bool EndAPI();
             bool EndAPI();

+ 18 - 1
Gems/Atom/RPI/Code/Include/Atom/RPI.Reflect/Shader/ShaderOptionGroupLayout.h

@@ -90,7 +90,8 @@ namespace AZ
                                    uint32_t order,
                                    uint32_t order,
                                    const ShaderOptionValues& nameIndexList,
                                    const ShaderOptionValues& nameIndexList,
                                    const Name& defaultValue = {},
                                    const Name& defaultValue = {},
-                                   uint32_t cost = 0);
+                                   uint32_t cost = 0,
+                                   int specializationId = -1);
 
 
             AZ_DEFAULT_COPY_MOVE(ShaderOptionDescriptor);
             AZ_DEFAULT_COPY_MOVE(ShaderOptionDescriptor);
 
 
@@ -105,6 +106,9 @@ namespace AZ
 
 
             uint32_t GetCostEstimate() const;
             uint32_t GetCostEstimate() const;
 
 
+            //! Return the specialization id. -1 if this option can't be specialize.
+            int GetSpecializationId() const;
+
             //! Returns the mask comprising bits specific to this option.
             //! Returns the mask comprising bits specific to this option.
             ShaderVariantKey GetBitMask() const;
             ShaderVariantKey GetBitMask() const;
 
 
@@ -192,6 +196,7 @@ namespace AZ
             uint32_t m_bitCount = 0;
             uint32_t m_bitCount = 0;
             uint32_t m_order = 0;          //!< The order (or rank) of the shader option dictates its priority. Lower order (rank) is higher priority.
             uint32_t m_order = 0;          //!< The order (or rank) of the shader option dictates its priority. Lower order (rank) is higher priority.
             uint32_t m_costEstimate = 0;
             uint32_t m_costEstimate = 0;
+            int m_specializationId = -1; //< Specialization id. A value of -1 means no specialization.
             ShaderVariantKey m_bitMask;
             ShaderVariantKey m_bitMask;
             ShaderVariantKey m_bitMaskNot;
             ShaderVariantKey m_bitMaskNot;
 
 
@@ -263,6 +268,13 @@ namespace AZ
 
 
             HashValue64 GetHash() const;
             HashValue64 GetHash() const;
 
 
+            //! Returns true if all shader options of the layout are using specialization constants. Please note that each
+            //! supervariant can have specialization constants off even if the layout is IsFullySpecialized.
+            bool IsFullySpecialized() const;
+
+            //! Returns true if at least one shader option is using specialization constant.
+            bool UseSpecializationConstants() const;
+
         private:
         private:
             ShaderOptionGroupLayout() = default;
             ShaderOptionGroupLayout() = default;
 
 
@@ -281,6 +293,11 @@ namespace AZ
             using NameReflectionMapForOptions = RHI::NameIdReflectionMap<ShaderOptionIndex>;
             using NameReflectionMapForOptions = RHI::NameIdReflectionMap<ShaderOptionIndex>;
             NameReflectionMapForOptions m_nameReflectionForOptions;
             NameReflectionMapForOptions m_nameReflectionForOptions;
             HashValue64 m_hash = HashValue64{ 0 };
             HashValue64 m_hash = HashValue64{ 0 };
+
+            // True if all shader options are using specialization constants
+            bool m_isFullySpecialized = false;
+            // True if at least one shader options is using specialization constants
+            bool m_useSpecializationConstants = false;
         };
         };
     } // namespace RPI
     } // namespace RPI
 
 

+ 2 - 2
Gems/Atom/RPI/Code/Include/Atom/RPI.Reflect/Shader/ShaderVariantAsset.h

@@ -58,14 +58,14 @@ namespace AZ
 
 
             //! Returns whether the variant is fully baked variant (all options are static branches), or false if the
             //! Returns whether the variant is fully baked variant (all options are static branches), or false if the
             //! variant uses dynamic branches for some shader options.
             //! variant uses dynamic branches for some shader options.
-            //! If the shader variant is not fully baked, the ShaderVariantKeyFallbackValue must be correctly set when drawing.
+            //! If the shader variant is not fully baked or fully specialized, the ShaderVariantKeyFallbackValue must be correctly set when drawing.
             bool IsFullyBaked() const;
             bool IsFullyBaked() const;
 
 
             //! Return the timestamp when this asset was built, and it must be >= than the timestamp of the main ShaderAsset.
             //! Return the timestamp when this asset was built, and it must be >= than the timestamp of the main ShaderAsset.
             //! This is used to synchronize versions of the ShaderAsset and ShaderVariantAsset, especially during hot-reload.
             //! This is used to synchronize versions of the ShaderAsset and ShaderVariantAsset, especially during hot-reload.
             AZ::u64 GetBuildTimestamp() const;
             AZ::u64 GetBuildTimestamp() const;
 
 
-            bool IsRootVariant() const { return m_stableId == RPI::RootShaderVariantStableId; } 
+            bool IsRootVariant() const { return m_stableId == RPI::RootShaderVariantStableId; }
 
 
         private:
         private:
             //! Called by asset creators to assign the asset to a ready state.
             //! Called by asset creators to assign the asset to a ready state.

+ 5 - 1
Gems/Atom/RPI/Code/Source/RPI.Edit/Shader/ShaderVariantAssetCreator.cpp

@@ -15,7 +15,11 @@ namespace AZ
 {
 {
     namespace RPI
     namespace RPI
     {
     {
-        void ShaderVariantAssetCreator::Begin(const AZ::Data::AssetId& assetId, const ShaderVariantId& shaderVariantId, RPI::ShaderVariantStableId stableId, bool isFullyBaked)
+        void ShaderVariantAssetCreator::Begin(
+            const AZ::Data::AssetId& assetId,
+            const ShaderVariantId& shaderVariantId,
+            RPI::ShaderVariantStableId stableId,
+            bool isFullyBaked)
         {
         {
             BeginCommon(assetId);
             BeginCommon(assetId);
 
 

+ 1 - 1
Gems/Atom/RPI/Code/Source/RPI.Public/MeshDrawPacket.cpp

@@ -376,7 +376,7 @@ namespace AZ
 #endif
 #endif
 
 
                 RHI::PipelineStateDescriptorForDraw pipelineStateDescriptor;
                 RHI::PipelineStateDescriptorForDraw pipelineStateDescriptor;
-                variant.ConfigurePipelineState(pipelineStateDescriptor);
+                variant.ConfigurePipelineState(pipelineStateDescriptor, shaderOptions);
 
 
                 // Render states need to merge the runtime variation.
                 // Render states need to merge the runtime variation.
                 // This allows materials to customize the render states that the shader uses.
                 // This allows materials to customize the render states that the shader uses.

+ 19 - 1
Gems/Atom/RPI/Code/Source/RPI.Public/Pass/ComputePass.cpp

@@ -144,9 +144,14 @@ namespace AZ
 
 
             // Setup pipeline state...
             // Setup pipeline state...
             RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
             RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
-            m_shader->GetDefaultVariant().ConfigurePipelineState(pipelineStateDescriptor);
+            ShaderOptionGroup options = m_shader->GetDefaultShaderOptions();
+            m_shader->GetDefaultVariant().ConfigurePipelineState(pipelineStateDescriptor, options);
 
 
             m_dispatchItem.SetPipelineState(m_shader->AcquirePipelineState(pipelineStateDescriptor));
             m_dispatchItem.SetPipelineState(m_shader->AcquirePipelineState(pipelineStateDescriptor));
+            if (m_drawSrg && m_shader->GetDefaultVariant().UseKeyFallback())
+            {
+                m_drawSrg->SetShaderVariantKeyFallbackValue(options.GetShaderVariantKeyFallbackValue());
+            }
 
 
             OnShaderReloadedInternal();
             OnShaderReloadedInternal();
 
 
@@ -255,6 +260,19 @@ namespace AZ
             m_shaderReloadedCallback = callback;
             m_shaderReloadedCallback = callback;
         }
         }
 
 
+        void ComputePass::UpdateShaderOptions(const ShaderVariantId& shaderVariantId)
+        {
+            const ShaderVariant& shaderVariant = m_shader->GetVariant(shaderVariantId);
+            RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
+            shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, shaderVariantId);
+
+            m_dispatchItem.SetPipelineState(m_shader->AcquirePipelineState(pipelineStateDescriptor));
+            if (m_drawSrg && shaderVariant.UseKeyFallback())
+            {
+                m_drawSrg->SetShaderVariantKeyFallbackValue(shaderVariantId.m_key);
+            }
+        }
+
         void ComputePass::OnShaderReloadedInternal()
         void ComputePass::OnShaderReloadedInternal()
         {
         {
             if (m_shaderReloadedCallback)
             if (m_shaderReloadedCallback)

+ 17 - 3
Gems/Atom/RPI/Code/Source/RPI.Public/Pass/FullscreenTrianglePass.cpp

@@ -68,6 +68,12 @@ namespace AZ
             LoadShader();
             LoadShader();
         }
         }
 
 
+        void FullscreenTrianglePass::UpdateShaderOptionsCommon()
+        {
+            m_pipelineStateForDraw.UpdateSrgVariantFallback(m_shaderResourceGroup);
+            BuildDrawItem();
+        }
+
         void FullscreenTrianglePass::LoadShader()
         void FullscreenTrianglePass::LoadShader()
         {
         {
             AZ_Assert(GetPassState() != PassState::Rendering, "FullscreenTrianglePass - Reloading shader during Rendering phase!");
             AZ_Assert(GetPassState() != PassState::Rendering, "FullscreenTrianglePass - Reloading shader during Rendering phase!");
@@ -120,7 +126,7 @@ namespace AZ
             // Store stencil reference value for the draw call
             // Store stencil reference value for the draw call
             m_stencilRef = passData->m_stencilRef;
             m_stencilRef = passData->m_stencilRef;
 
 
-            m_pipelineStateForDraw.Init(m_shader);
+            m_pipelineStateForDraw.Init(m_shader, m_shader->GetDefaultShaderOptions().GetShaderVariantId());
 
 
             UpdateSrgs();
             UpdateSrgs();
 
 
@@ -191,8 +197,16 @@ namespace AZ
             if (m_shader)
             if (m_shader)
             {
             {
                 m_pipelineStateForDraw.Init(m_shader, &shaderOptions);
                 m_pipelineStateForDraw.Init(m_shader, &shaderOptions);
-                m_pipelineStateForDraw.UpdateSrgVariantFallback(m_shaderResourceGroup);
-                BuildDrawItem();
+                UpdateShaderOptionsCommon();
+            }
+        }
+
+        void FullscreenTrianglePass::UpdateShaderOptions(const ShaderVariantId& shaderVariantId)
+        {
+            if (m_shader)
+            {
+                m_pipelineStateForDraw.Init(m_shader, shaderVariantId);
+                UpdateShaderOptionsCommon();
             }
             }
         }
         }
 
 

+ 1 - 1
Gems/Atom/RPI/Code/Source/RPI.Public/Pass/Specific/ImageAttachmentPreviewPass.cpp

@@ -324,7 +324,7 @@ namespace AZ
 
 
                 shaderOption.SetValue(AZ::Name(optionName), AZ::Name(optionValues[index]));
                 shaderOption.SetValue(AZ::Name(optionName), AZ::Name(optionValues[index]));
 
 
-                m_shader->GetVariant(shaderOption.GetShaderVariantId()).ConfigurePipelineState(pipelineDesc);
+                m_shader->GetVariant(shaderOption.GetShaderVariantId()).ConfigurePipelineState(pipelineDesc, shaderOption);
                 pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout = attachmentsLayout;
                 pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout = attachmentsLayout;
                 pipelineDesc.m_inputStreamLayout = inputStreamLayout;
                 pipelineDesc.m_inputStreamLayout = inputStreamLayout;
                 previewInfo.m_shaderVariantKeyFallback = shaderOption.GetShaderVariantKeyFallbackValue();
                 previewInfo.m_shaderVariantKeyFallback = shaderOption.GetShaderVariantKeyFallbackValue();

+ 27 - 16
Gems/Atom/RPI/Code/Source/RPI.Public/PipelineState.cpp

@@ -51,16 +51,8 @@ namespace AZ
 
 
         void PipelineStateForDraw::Init(const Data::Instance<RPI::Shader>& shader, const ShaderOptionList* optionAndValues)
         void PipelineStateForDraw::Init(const Data::Instance<RPI::Shader>& shader, const ShaderOptionList* optionAndValues)
         {
         {
-            // Reset some variables
-            m_pipelineState = nullptr;
-            m_shaderVariantId = ShaderVariantId{};
-
-            // Reset some flags
-            m_dirty = true;
-            m_isShaderVariantReady = true;
-                        
             // Get shader variant from the shader
             // Get shader variant from the shader
-            auto shaderVariant = shader->GetRootVariant();
+            ShaderVariantId shaderVariant = {};
             if (optionAndValues)
             if (optionAndValues)
             {
             {
                 RPI::ShaderOptionGroup shaderOptionGroup = shader->CreateShaderOptionGroup();
                 RPI::ShaderOptionGroup shaderOptionGroup = shader->CreateShaderOptionGroup();
@@ -69,15 +61,29 @@ namespace AZ
                 {
                 {
                     shaderOptionGroup.SetValue(optionAndValue.first, optionAndValue.second);
                     shaderOptionGroup.SetValue(optionAndValue.first, optionAndValue.second);
                 }
                 }
-                m_shaderVariantId = shaderOptionGroup.GetShaderVariantId();
-                shaderVariant = shader->GetVariant(m_shaderVariantId);
-                m_isShaderVariantReady = shaderVariant.IsFullyBaked();
+                shaderVariant = shaderOptionGroup.GetShaderVariantId();
             }
             }
+            Init(shader, shaderVariant);
+        }
+
+        void PipelineStateForDraw::Init(const Data::Instance<Shader>& shader, const ShaderVariantId& shaderVariantId)
+        {
+            // Reset some variables
+            m_pipelineState = nullptr;
+
+            // Reset some flags
+            m_dirty = true;
+            m_isShaderVariantReady = true;
+
+            // Get shader variant from the shader
+            m_shaderVariantId = shaderVariantId;
+            ShaderVariant shaderVariant = shader->GetVariant(m_shaderVariantId);
+            m_isShaderVariantReady = !shaderVariant.UseKeyFallback();
 
 
             // Fill the descriptor with data from shader variant
             // Fill the descriptor with data from shader variant
-            shaderVariant.ConfigurePipelineState(m_descriptor);
+            shaderVariant.ConfigurePipelineState(m_descriptor, m_shaderVariantId);
 
 
-            // Connect to shader reload notification bus to rebuilt pipeline state when shader or shader variant changed. 
+            // Connect to shader reload notification bus to rebuilt pipeline state when shader or shader variant changed.
             ShaderReloadNotificationBus::MultiHandler::BusDisconnect();
             ShaderReloadNotificationBus::MultiHandler::BusDisconnect();
             ShaderReloadNotificationBus::MultiHandler::BusConnect(shader->GetAsset().GetId());
             ShaderReloadNotificationBus::MultiHandler::BusConnect(shader->GetAsset().GetId());
 
 
@@ -90,11 +96,11 @@ namespace AZ
         void PipelineStateForDraw::RefreshShaderVariant()
         void PipelineStateForDraw::RefreshShaderVariant()
         {
         {
             auto shaderVariant = m_shader->GetVariant(m_shaderVariantId);
             auto shaderVariant = m_shader->GetVariant(m_shaderVariantId);
-            m_isShaderVariantReady = shaderVariant.IsFullyBaked();
+            m_isShaderVariantReady = !shaderVariant.UseKeyFallback();
 
 
             auto multisampleState = m_descriptor.m_renderStates.m_multisampleState;
             auto multisampleState = m_descriptor.m_renderStates.m_multisampleState;
 
 
-            shaderVariant.ConfigurePipelineState(m_descriptor);
+            shaderVariant.ConfigurePipelineState(m_descriptor, m_shaderVariantId);
 
 
             // Recover multisampleState if it was set from output data
             // Recover multisampleState if it was set from output data
             if (m_hasOutputData)
             if (m_hasOutputData)
@@ -233,5 +239,10 @@ namespace AZ
                         
                         
             ShaderReloadNotificationBus::MultiHandler::BusDisconnect();
             ShaderReloadNotificationBus::MultiHandler::BusDisconnect();
         }
         }
+
+        const ShaderVariantId& PipelineStateForDraw::GetShaderVariantId() const
+        {
+            return m_shaderVariantId;
+        }
     }
     }
 }
 }

+ 3 - 2
Gems/Atom/RPI/Code/Source/RPI.Public/Shader/Shader.cpp

@@ -501,8 +501,9 @@ namespace AZ
             if (drawSrgLayout)
             if (drawSrgLayout)
             {
             {
                 drawSrg = RPI::ShaderResourceGroup::Create(m_asset, GetSupervariantIndex(), drawSrgLayout->GetName());
                 drawSrg = RPI::ShaderResourceGroup::Create(m_asset, GetSupervariantIndex(), drawSrgLayout->GetName());
-
-                if (drawSrgLayout->HasShaderVariantKeyFallbackEntry())
+                bool useFallbackKey = !shaderOptions.GetShaderOptionLayout()->IsFullySpecialized() ||
+                    !m_asset->UseSpecializationConstants(GetSupervariantIndex());
+                if (useFallbackKey && drawSrgLayout->HasShaderVariantKeyFallbackEntry())
                 {
                 {
                     drawSrg->SetShaderVariantKeyFallbackValue(shaderOptions.GetShaderVariantKeyFallbackValue());
                     drawSrg->SetShaderVariantKeyFallbackValue(shaderOptions.GetShaderVariantKeyFallbackValue());
                 }
                 }

+ 78 - 2
Gems/Atom/RPI/Code/Source/RPI.Public/Shader/ShaderVariant.cpp

@@ -29,7 +29,7 @@ namespace AZ
             m_pipelineStateType = shaderAsset->GetPipelineStateType();
             m_pipelineStateType = shaderAsset->GetPipelineStateType();
             m_pipelineLayoutDescriptor = shaderAsset->GetPipelineLayoutDescriptor(supervariantIndex);
             m_pipelineLayoutDescriptor = shaderAsset->GetPipelineLayoutDescriptor(supervariantIndex);
             m_renderStates = &shaderAsset->GetRenderStates(supervariantIndex);
             m_renderStates = &shaderAsset->GetRenderStates(supervariantIndex);
-
+            m_useSpecializationConstants = shaderAsset->UseSpecializationConstants(supervariantIndex);
             return true;
             return true;
         }
         }
 
 
@@ -38,7 +38,16 @@ namespace AZ
 
 
         }
         }
 
 
-        void ShaderVariant::ConfigurePipelineState(RHI::PipelineStateDescriptor& descriptor) const
+        void ShaderVariant::ConfigurePipelineState(
+            RHI::PipelineStateDescriptor& descriptor,
+            const ShaderVariantId& specialization) const
+        {
+            ConfigurePipelineState(descriptor, ShaderOptionGroup(m_shaderAsset->GetShaderOptionGroupLayout(), specialization));
+        }
+
+        void ShaderVariant::ConfigurePipelineState(
+            RHI::PipelineStateDescriptor& descriptor,
+            const ShaderOptionGroup& specialization) const
         {
         {
             descriptor.m_pipelineLayoutDescriptor = m_pipelineLayoutDescriptor;
             descriptor.m_pipelineLayoutDescriptor = m_pipelineLayoutDescriptor;
 
 
@@ -76,6 +85,73 @@ namespace AZ
                 AZ_Assert(false, "Unexpected PipelineStateType");
                 AZ_Assert(false, "Unexpected PipelineStateType");
                 break;
                 break;
             }
             }
+
+            if (m_useSpecializationConstants)
+            {
+                // Configure specialization data for the shader
+                AZ_Assert(
+                    specialization.GetShaderOptionLayout() == m_shaderAsset->GetShaderOptionGroupLayout(),
+                    "OptionGroup for specialization is different to the one in the ShaderAsset");
+                descriptor.m_specializationData.clear();
+                ShaderOptionGroup options = specialization;
+                options.SetUnspecifiedToDefaultValues();
+                for (auto& option : options.GetShaderOptionLayout()->GetShaderOptions())
+                {
+                    if (option.GetSpecializationId() >= 0)
+                    {
+                        descriptor.m_specializationData.emplace_back();
+                        auto& specializationData = descriptor.m_specializationData.back();
+                        specializationData.m_name = option.GetName();
+                        specializationData.m_id = option.GetSpecializationId();
+                        specializationData.m_value = RHI::SpecializationValue(option.Get(options).GetIndex());
+                        switch (option.GetType())
+                        {
+                        case ShaderOptionType::Boolean:
+                            specializationData.m_type = RHI::SpecializationType::Bool;
+                            break;
+                        case ShaderOptionType::Enumeration:
+                        case ShaderOptionType::IntegerRange:
+                            specializationData.m_type = RHI::SpecializationType::Integer;
+                            break;
+                        default:
+                            break;
+                        }
+                    }
+                }
+            }
+        }
+
+        void ShaderVariant::ConfigurePipelineState(RHI::PipelineStateDescriptor& descriptor) const
+        {
+            auto layout = m_shaderAsset->GetShaderOptionGroupLayout();
+            for ([[maybe_unused]] auto& option : layout->GetShaderOptions())
+            {
+                if (m_useSpecializationConstants && option.GetSpecializationId() >= 0)
+                {
+                    AZ_Error(
+                        "ConfigurePipelineState",
+                        !m_useSpecializationConstants || option.GetSpecializationId() < 0,
+                        "Configuring PipelineStateDescriptor without specializing option %s.\
+                         Call ConfigurePipelineState with specialization data. Default value will be used.",
+                        option.GetName().GetCStr());
+                }
+            }
+            ConfigurePipelineState(descriptor, ShaderOptionGroup(layout));
+        }
+
+        bool ShaderVariant::IsFullySpecialized() const
+        {
+            return m_shaderAsset->IsFullySpecialized(m_supervariantIndex);
+        }
+
+        bool ShaderVariant::UseSpecializationConstants() const
+        {
+            return m_shaderAsset->UseSpecializationConstants(m_supervariantIndex);
+        }
+
+        bool ShaderVariant::UseKeyFallback() const
+        {
+            return !(IsFullyBaked() || IsFullySpecialized());
         }
         }
 
 
     } // namespace RPI
     } // namespace RPI

+ 24 - 3
Gems/Atom/RPI/Code/Source/RPI.Reflect/Shader/ShaderAsset.cpp

@@ -72,6 +72,7 @@ namespace AZ
                     ->Field("RenderStates", &Supervariant::m_renderStates)
                     ->Field("RenderStates", &Supervariant::m_renderStates)
                     ->Field("AttributeMapList", &Supervariant::m_attributeMaps)
                     ->Field("AttributeMapList", &Supervariant::m_attributeMaps)
                     ->Field("RootVariantAsset", &Supervariant::m_rootShaderVariantAsset)
                     ->Field("RootVariantAsset", &Supervariant::m_rootShaderVariantAsset)
+                    ->Field("UseSpecializationConstants", &Supervariant::m_useSpecializationConstants)
                     ;
                     ;
             }
             }
         }
         }
@@ -194,7 +195,7 @@ namespace AZ
             Data::Asset<ShaderAsset> thisAsset(this, Data::AssetLoadBehavior::Default);
             Data::Asset<ShaderAsset> thisAsset(this, Data::AssetLoadBehavior::Default);
             Data::Asset<ShaderVariantAsset> shaderVariantAsset =
             Data::Asset<ShaderVariantAsset> shaderVariantAsset =
                 variantFinder->GetShaderVariantAssetByVariantId(thisAsset, shaderVariantId, supervariantIndex);
                 variantFinder->GetShaderVariantAssetByVariantId(thisAsset, shaderVariantId, supervariantIndex);
-            if (!shaderVariantAsset)
+            if (!shaderVariantAsset && !IsFullySpecialized(supervariantIndex))
             {
             {
                 variantFinder->QueueLoadShaderVariantAssetByVariantId(thisAsset, shaderVariantId, supervariantIndex);
                 variantFinder->QueueLoadShaderVariantAssetByVariantId(thisAsset, shaderVariantId, supervariantIndex);
             }
             }
@@ -206,7 +207,7 @@ namespace AZ
             uint32_t dynamicOptionCount = aznumeric_cast<uint32_t>(GetShaderOptionGroupLayout()->GetShaderOptions().size());
             uint32_t dynamicOptionCount = aznumeric_cast<uint32_t>(GetShaderOptionGroupLayout()->GetShaderOptions().size());
             ShaderVariantSearchResult variantSearchResult{RootShaderVariantStableId,  dynamicOptionCount };
             ShaderVariantSearchResult variantSearchResult{RootShaderVariantStableId,  dynamicOptionCount };
 
 
-            if (!dynamicOptionCount)
+            if (!dynamicOptionCount || m_isFullySpecialized)
             {
             {
                 // The shader has no options at all. There's nothing to search.
                 // The shader has no options at all. There's nothing to search.
                 return variantSearchResult;
                 return variantSearchResult;
@@ -245,7 +246,9 @@ namespace AZ
         Data::Asset<ShaderVariantAsset> ShaderAsset::GetVariantAsset(
         Data::Asset<ShaderVariantAsset> ShaderAsset::GetVariantAsset(
             ShaderVariantStableId shaderVariantStableId, SupervariantIndex supervariantIndex) const
             ShaderVariantStableId shaderVariantStableId, SupervariantIndex supervariantIndex) const
         {
         {
-            if (!shaderVariantStableId.IsValid() || shaderVariantStableId == RootShaderVariantStableId)
+            if (!shaderVariantStableId.IsValid() ||
+                shaderVariantStableId == RootShaderVariantStableId ||
+                IsFullySpecialized(supervariantIndex))
             {
             {
                 return GetRootVariantAsset(supervariantIndex);
                 return GetRootVariantAsset(supervariantIndex);
             }
             }
@@ -457,6 +460,22 @@ namespace AZ
             return attrPair->second;
             return attrPair->second;
         }
         }
 
 
+        bool ShaderAsset::UseSpecializationConstants(SupervariantIndex supervariantIndex) const
+        {
+            auto supervariant = GetSupervariant(supervariantIndex);
+            if (!supervariant)
+            {
+                return false;
+            }
+
+            return supervariant->m_useSpecializationConstants;
+        }
+
+        bool ShaderAsset::IsFullySpecialized(SupervariantIndex supervariantIndex) const
+        {            
+            return UseSpecializationConstants(supervariantIndex) && m_shaderOptionGroupLayout->IsFullySpecialized();
+        }
+
         ShaderAsset::ShaderApiDataContainer& ShaderAsset::GetCurrentShaderApiData()
         ShaderAsset::ShaderApiDataContainer& ShaderAsset::GetCurrentShaderApiData()
         {
         {
             const size_t perApiShaderDataCount = m_perAPIShaderData.size();
             const size_t perApiShaderDataCount = m_perAPIShaderData.size();
@@ -556,12 +575,14 @@ namespace AZ
                 }
                 }
             }
             }
 
 
+            m_isFullySpecialized = m_shaderOptionGroupLayout->IsFullySpecialized();
             // Common finalize check
             // Common finalize check
             for (const auto& shaderApiData : m_perAPIShaderData)
             for (const auto& shaderApiData : m_perAPIShaderData)
             {
             {
                 const auto& supervariants = shaderApiData.m_supervariants;
                 const auto& supervariants = shaderApiData.m_supervariants;
                 for (const auto& supervariant : supervariants)
                 for (const auto& supervariant : supervariants)
                 {
                 {
+                    m_isFullySpecialized &= supervariant.m_useSpecializationConstants;
                     bool beTrue = supervariant.m_attributeMaps.size() == RHI::ShaderStageCount;
                     bool beTrue = supervariant.m_attributeMaps.size() == RHI::ShaderStageCount;
                     if (!beTrue)
                     if (!beTrue)
                     {
                     {

+ 17 - 0
Gems/Atom/RPI/Code/Source/RPI.Reflect/Shader/ShaderAssetCreator.cpp

@@ -238,6 +238,20 @@ namespace AZ
             m_currentSupervariant->m_rootShaderVariantAsset = shaderVariantAsset;
             m_currentSupervariant->m_rootShaderVariantAsset = shaderVariantAsset;
         }
         }
 
 
+        void ShaderAssetCreator::SetUseSpecializationConstants(bool value)
+        {
+            if (!ValidateIsReady())
+            {
+                return;
+            }
+            if (!m_currentSupervariant)
+            {
+                ReportError("BeginSupervariant() should be called first before calling %s", __FUNCTION__);
+                return;
+            }
+            m_currentSupervariant->m_useSpecializationConstants = value;
+        }
+
         static RHI::PipelineStateType GetPipelineStateType(const Data::Asset<ShaderVariantAsset>& shaderVariantAsset)
         static RHI::PipelineStateType GetPipelineStateType(const Data::Asset<ShaderVariantAsset>& shaderVariantAsset)
         {
         {
             if (shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Vertex) ||
             if (shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Vertex) ||
@@ -343,6 +357,9 @@ namespace AZ
                 }
                 }
             }
             }
 
 
+            m_currentSupervariant->m_useSpecializationConstants =
+                m_currentSupervariant->m_useSpecializationConstants && m_asset->m_shaderOptionGroupLayout->UseSpecializationConstants();
+
             m_currentSupervariant = nullptr;
             m_currentSupervariant = nullptr;
             return true;
             return true;
         }
         }

+ 37 - 3
Gems/Atom/RPI/Code/Source/RPI.Reflect/Shader/ShaderOptionGroupLayout.cpp

@@ -85,7 +85,7 @@ namespace AZ
             if (auto* serializeContext = azrtti_cast<SerializeContext*>(context))
             if (auto* serializeContext = azrtti_cast<SerializeContext*>(context))
             {
             {
                 serializeContext->Class<ShaderOptionDescriptor>()
                 serializeContext->Class<ShaderOptionDescriptor>()
-                    ->Version(5)  // 5: addition of m_costEstimate field
+                    ->Version(6)  // 6: addition of m_specializationId field
                     ->Field("m_name", &ShaderOptionDescriptor::m_name)
                     ->Field("m_name", &ShaderOptionDescriptor::m_name)
                     ->Field("m_type", &ShaderOptionDescriptor::m_type)
                     ->Field("m_type", &ShaderOptionDescriptor::m_type)
                     ->Field("m_defaultValue", &ShaderOptionDescriptor::m_defaultValue)
                     ->Field("m_defaultValue", &ShaderOptionDescriptor::m_defaultValue)
@@ -99,6 +99,7 @@ namespace AZ
                     ->Field("m_bitMaskNot", &ShaderOptionDescriptor::m_bitMaskNot)
                     ->Field("m_bitMaskNot", &ShaderOptionDescriptor::m_bitMaskNot)
                     ->Field("m_hash", &ShaderOptionDescriptor::m_hash)
                     ->Field("m_hash", &ShaderOptionDescriptor::m_hash)
                     ->Field("m_nameReflectionForValues", &ShaderOptionDescriptor::m_nameReflectionForValues)
                     ->Field("m_nameReflectionForValues", &ShaderOptionDescriptor::m_nameReflectionForValues)
+                    ->Field("m_specializationId", &ShaderOptionDescriptor::m_specializationId)
                     ;
                     ;
             }
             }
 
 
@@ -130,7 +131,8 @@ namespace AZ
                                                        uint32_t order,
                                                        uint32_t order,
                                                        const ShaderOptionValues& nameIndexList,
                                                        const ShaderOptionValues& nameIndexList,
                                                        const Name& defaultValue,
                                                        const Name& defaultValue,
-                                                       uint32_t cost)
+                                                       uint32_t cost,
+                                                       int specializationId)
 
 
             : m_name{name}
             : m_name{name}
             , m_type{optionType}
             , m_type{optionType}
@@ -138,6 +140,7 @@ namespace AZ
             , m_order{order}
             , m_order{order}
             , m_costEstimate{cost}
             , m_costEstimate{cost}
             , m_defaultValue{defaultValue}
             , m_defaultValue{defaultValue}
+            , m_specializationId{specializationId}
         {
         {
             for (auto pair : nameIndexList)
             for (auto pair : nameIndexList)
             {   // Registers the pair in the lookup table
             {   // Registers the pair in the lookup table
@@ -187,6 +190,11 @@ namespace AZ
             return m_costEstimate;
             return m_costEstimate;
         }
         }
 
 
+        int ShaderOptionDescriptor::GetSpecializationId() const
+        {
+            return m_specializationId;
+        }
+
         ShaderVariantKey ShaderOptionDescriptor::GetBitMask() const
         ShaderVariantKey ShaderOptionDescriptor::GetBitMask() const
         {
         {
             return m_bitMask;
             return m_bitMask;
@@ -452,11 +460,13 @@ namespace AZ
             if (auto* serializeContext = azrtti_cast<SerializeContext*>(context))
             if (auto* serializeContext = azrtti_cast<SerializeContext*>(context))
             {
             {
                 serializeContext->Class<ShaderOptionGroupLayout>()
                 serializeContext->Class<ShaderOptionGroupLayout>()
-                    ->Version(2)
+                    ->Version(3)
                     ->Field("m_bitMask", &ShaderOptionGroupLayout::m_bitMask)
                     ->Field("m_bitMask", &ShaderOptionGroupLayout::m_bitMask)
                     ->Field("m_options", &ShaderOptionGroupLayout::m_options)
                     ->Field("m_options", &ShaderOptionGroupLayout::m_options)
                     ->Field("m_nameReflectionForOptions", &ShaderOptionGroupLayout::m_nameReflectionForOptions)
                     ->Field("m_nameReflectionForOptions", &ShaderOptionGroupLayout::m_nameReflectionForOptions)
                     ->Field("m_hash", &ShaderOptionGroupLayout::m_hash)
                     ->Field("m_hash", &ShaderOptionGroupLayout::m_hash)
+                    ->Field("m_isFullySpecialized", &ShaderOptionGroupLayout::m_isFullySpecialized)
+                    ->Field("m_useSpecializationConstants", &ShaderOptionGroupLayout::m_useSpecializationConstants)
                     ;
                     ;
             }
             }
 
 
@@ -486,6 +496,16 @@ namespace AZ
             return m_hash;
             return m_hash;
         }
         }
 
 
+        bool ShaderOptionGroupLayout::IsFullySpecialized() const
+        {
+            return m_isFullySpecialized;
+        }
+
+        bool ShaderOptionGroupLayout::UseSpecializationConstants() const
+        {
+            return m_useSpecializationConstants;
+        }
+
         void ShaderOptionGroupLayout::Clear()
         void ShaderOptionGroupLayout::Clear()
         {
         {
             m_options.clear();
             m_options.clear();
@@ -506,6 +526,20 @@ namespace AZ
                 hash = TypeHash64(option.GetHash(), hash);
                 hash = TypeHash64(option.GetHash(), hash);
             }
             }
             m_hash = hash;
             m_hash = hash;
+            m_isFullySpecialized = !AZStd::any_of(
+                m_options.begin(),
+                m_options.end(),
+                [](const ShaderOptionDescriptor& elem)
+                {
+                    return elem.GetSpecializationId() < 0;
+                });
+            m_useSpecializationConstants = AZStd::any_of(
+                m_options.begin(),
+                m_options.end(),
+                [](const ShaderOptionDescriptor& elem)
+                {
+                    return elem.GetSpecializationId() >= 0;
+                });
         }
         }
 
 
         bool ShaderOptionGroupLayout::ValidateIsFinalized() const
         bool ShaderOptionGroupLayout::ValidateIsFinalized() const

+ 212 - 5
Gems/Atom/RPI/Code/Tests/Shader/ShaderTests.cpp

@@ -133,6 +133,16 @@ namespace UnitTest
         : public RPITestFixture
         : public RPITestFixture
     {
     {
     protected:
     protected:
+        enum class SpecializationType
+        {
+            None = 0,
+            Partial,
+            Full,
+            Count
+        };
+
+        static const uint32_t SpecializationTypeCount = static_cast<uint32_t>(SpecializationType::Count);
+
         void SetUp() override
         void SetUp() override
         {
         {
             using namespace AZ;
             using namespace AZ;
@@ -219,10 +229,35 @@ namespace UnitTest
                                                           Name("Off") };
                                                           Name("Off") };
             bitOffset = m_bindings[3].GetBitOffset() + m_bindings[3].GetBitCount();
             bitOffset = m_bindings[3].GetBitOffset() + m_bindings[3].GetBitCount();
 
 
+            AZStd::vector<RPI::ShaderOptionValuePair> idList4;
+            idList4.push_back({ Name("True"), RPI::ShaderOptionValue(0) }); // 1+ bit
+            idList4.push_back({ Name("False"), RPI::ShaderOptionValue(1) }); // ...
+
+            for (uint32_t i = 0; i < m_bindingsFullSpecialization.size(); ++i)
+            {
+                m_bindingsFullSpecialization[i] = RPI::ShaderOptionDescriptor{
+                    Name{ AZStd::to_string(i) }, RPI::ShaderOptionType::Boolean, i, i, idList4, Name("True"), 0,
+                                                 aznumeric_caster(i) };
+            }
+
+            for (uint32_t i = 0; i < m_bindingsPartialSpecialization.size(); ++i)
+            {
+                m_bindingsPartialSpecialization[i] = RPI::ShaderOptionDescriptor{ Name{ AZStd::to_string(i) },
+                                                                                  RPI::ShaderOptionType::Boolean,
+                                                                                  i,
+                                                                                  i,
+                                                                                  idList4,
+                                                                                  Name("True"),
+                                                                                  0,
+                                                                                  aznumeric_caster(i % 2) ? aznumeric_caster(i) : -1 };
+            }
+
             m_name = Name("TestName");
             m_name = Name("TestName");
             m_drawListName = Name("DrawListTagName");
             m_drawListName = Name("DrawListTagName");
             m_pipelineLayoutDescriptor = TestPipelineLayoutDescriptor::Create();
             m_pipelineLayoutDescriptor = TestPipelineLayoutDescriptor::Create();
             m_shaderOptionGroupLayoutForAsset = CreateShaderOptionLayout();
             m_shaderOptionGroupLayoutForAsset = CreateShaderOptionLayout();
+            m_shaderOptionGroupLayoutForAssetPartialSpecialization = CreateShaderOptionLayout({}, SpecializationType::Partial);
+            m_shaderOptionGroupLayoutForAssetFullSpecialization = CreateShaderOptionLayout({}, SpecializationType::Full);
             m_shaderOptionGroupLayoutForVariants = m_shaderOptionGroupLayoutForAsset;
             m_shaderOptionGroupLayoutForVariants = m_shaderOptionGroupLayoutForAsset;
 
 
             // Just set up a couple values, not the whole struct, for some basic checking later that the struct is copied.
             // Just set up a couple values, not the whole struct, for some basic checking later that the struct is copied.
@@ -253,27 +288,46 @@ namespace UnitTest
             for (size_t i = 0; i < m_bindings.size(); ++i)
             for (size_t i = 0; i < m_bindings.size(); ++i)
             {
             {
                 m_bindings[i] = {};
                 m_bindings[i] = {};
+                m_bindingsFullSpecialization[i] = {};
+                m_bindingsPartialSpecialization[i] = {};
             }
             }
 
 
             m_srgLayouts.clear();
             m_srgLayouts.clear();
             m_pipelineLayoutDescriptor = nullptr;
             m_pipelineLayoutDescriptor = nullptr;
             m_shaderOptionGroupLayoutForAsset = nullptr;
             m_shaderOptionGroupLayoutForAsset = nullptr;
+            m_shaderOptionGroupLayoutForAssetPartialSpecialization = nullptr;
+            m_shaderOptionGroupLayoutForAssetFullSpecialization = nullptr;
             m_shaderOptionGroupLayoutForVariants = nullptr;
             m_shaderOptionGroupLayoutForVariants = nullptr;
 
 
             RPITestFixture::TearDown();
             RPITestFixture::TearDown();
         }
         }
 
 
-        AZ::RPI::Ptr<AZ::RPI::ShaderOptionGroupLayout> CreateShaderOptionLayout(AZ::RHI::Handle<size_t> indexToOmit = {})
+        const AZ::RPI::ShaderOptionDescriptor& GetShaderOptionDescriptor(SpecializationType specializationType, uint32_t index)
+        {
+            switch (specializationType)
+            {
+            case SpecializationType::Partial:
+                return m_bindingsPartialSpecialization[index];
+            case SpecializationType::Full:
+                return m_bindingsFullSpecialization[index];
+            case SpecializationType::None:
+            default:
+                return m_bindings[index];            
+            }
+        }
+
+        AZ::RPI::Ptr<AZ::RPI::ShaderOptionGroupLayout> CreateShaderOptionLayout(
+            AZ::RHI::Handle<size_t> indexToOmit = {}, SpecializationType specializationType = SpecializationType::None)
         {
         {
             using namespace AZ;
             using namespace AZ;
 
 
             RPI::Ptr<RPI::ShaderOptionGroupLayout> layout = RPI::ShaderOptionGroupLayout::Create();
             RPI::Ptr<RPI::ShaderOptionGroupLayout> layout = RPI::ShaderOptionGroupLayout::Create();
-            for (size_t i = 0; i < m_bindings.size(); ++i)
+            for (uint32_t i = 0; i < m_bindings.size(); ++i)
             {
             {
                 // Allows omitting a single option to test for missing options.
                 // Allows omitting a single option to test for missing options.
                 if (indexToOmit.GetIndex() != i)
                 if (indexToOmit.GetIndex() != i)
                 {
                 {
-                    layout->AddShaderOption(m_bindings[i]);
+                    layout->AddShaderOption(GetShaderOptionDescriptor(specializationType, i));
                 }
                 }
             }
             }
             layout->Finalize();
             layout->Finalize();
@@ -409,15 +463,31 @@ namespace UnitTest
             return shaderVariantAsset;
             return shaderVariantAsset;
         }
         }
 
 
+        AZ::RPI::Ptr<AZ::RPI::ShaderOptionGroupLayout> GetShaderOptionGroupForAssets(SpecializationType specializationType)
+        {
+            switch (specializationType)
+            {
+            case SpecializationType::None:
+                return m_shaderOptionGroupLayoutForAsset;
+            case SpecializationType::Partial:
+                return m_shaderOptionGroupLayoutForAssetPartialSpecialization;
+            case SpecializationType::Full:
+                return m_shaderOptionGroupLayoutForAssetFullSpecialization;
+            default:
+                return nullptr;
+            }
+        }
+
         void BeginCreatingTestShaderAsset(AZ::RPI::ShaderAssetCreator& creator,
         void BeginCreatingTestShaderAsset(AZ::RPI::ShaderAssetCreator& creator,
-            const AZStd::vector<RHI::ShaderStage>& stagesToActivate = {RHI::ShaderStage::Vertex, RHI::ShaderStage::Fragment} )
+            const AZStd::vector<RHI::ShaderStage>& stagesToActivate = {RHI::ShaderStage::Vertex, RHI::ShaderStage::Fragment},
+            SpecializationType specializationType = SpecializationType::None)
         {
         {
             using namespace AZ;
             using namespace AZ;
 
 
             creator.Begin(Uuid::CreateRandom());
             creator.Begin(Uuid::CreateRandom());
             creator.SetName(m_name);
             creator.SetName(m_name);
             creator.SetDrawListName(m_drawListName);
             creator.SetDrawListName(m_drawListName);
-            creator.SetShaderOptionGroupLayout(m_shaderOptionGroupLayoutForAsset);
+            creator.SetShaderOptionGroupLayout(GetShaderOptionGroupForAssets(specializationType));
 
 
             creator.BeginAPI(RHI::Factory::Get().GetType());
             creator.BeginAPI(RHI::Factory::Get().GetType());
 
 
@@ -430,6 +500,8 @@ namespace UnitTest
             creator.SetInputContract(CreateSimpleShaderInputContract());
             creator.SetInputContract(CreateSimpleShaderInputContract());
             creator.SetOutputContract(CreateSimpleShaderOutputContract());
             creator.SetOutputContract(CreateSimpleShaderOutputContract());
 
 
+            creator.SetUseSpecializationConstants(specializationType != SpecializationType::None);
+
             RHI::ShaderStageAttributeMapList attributeMaps;
             RHI::ShaderStageAttributeMapList attributeMaps;
             attributeMaps.resize(RHI::ShaderStageCount);
             attributeMaps.resize(RHI::ShaderStageCount);
             creator.SetShaderStageAttributeMapList(attributeMaps);
             creator.SetShaderStageAttributeMapList(attributeMaps);
@@ -569,11 +641,16 @@ namespace UnitTest
         }
         }
 
 
         AZStd::array<AZ::RPI::ShaderOptionDescriptor, 4> m_bindings;
         AZStd::array<AZ::RPI::ShaderOptionDescriptor, 4> m_bindings;
+        AZStd::array<AZ::RPI::ShaderOptionDescriptor, 4> m_bindingsFullSpecialization;
+        AZStd::array<AZ::RPI::ShaderOptionDescriptor, 4> m_bindingsPartialSpecialization;
+
 
 
         AZ::Name m_name;
         AZ::Name m_name;
         AZ::Name m_drawListName;
         AZ::Name m_drawListName;
         AZ::RHI::Ptr<AZ::RHI::PipelineLayoutDescriptor> m_pipelineLayoutDescriptor;
         AZ::RHI::Ptr<AZ::RHI::PipelineLayoutDescriptor> m_pipelineLayoutDescriptor;
         AZ::RPI::Ptr<AZ::RPI::ShaderOptionGroupLayout> m_shaderOptionGroupLayoutForAsset;
         AZ::RPI::Ptr<AZ::RPI::ShaderOptionGroupLayout> m_shaderOptionGroupLayoutForAsset;
+        AZ::RPI::Ptr<AZ::RPI::ShaderOptionGroupLayout> m_shaderOptionGroupLayoutForAssetPartialSpecialization;
+        AZ::RPI::Ptr<AZ::RPI::ShaderOptionGroupLayout> m_shaderOptionGroupLayoutForAssetFullSpecialization;
         AZ::RPI::Ptr<AZ::RPI::ShaderOptionGroupLayout> m_shaderOptionGroupLayoutForVariants;
         AZ::RPI::Ptr<AZ::RPI::ShaderOptionGroupLayout> m_shaderOptionGroupLayoutForVariants;
 
 
         AZ::RHI::RenderStates m_renderStates;
         AZ::RHI::RenderStates m_renderStates;
@@ -876,6 +953,62 @@ namespace UnitTest
         EXPECT_FALSE(shaderOptionGroupLayout->FindShaderOptionIndex(Name{ "Invalid" }).IsValid());
         EXPECT_FALSE(shaderOptionGroupLayout->FindShaderOptionIndex(Name{ "Invalid" }).IsValid());
     }
     }
 
 
+    TEST_F(ShaderTests, ShaderOptionGroupLayoutSpecializationTest)
+    {
+        using namespace AZ;
+        AZStd::vector<RPI::ShaderOptionValuePair> idList4;
+        idList4.push_back({ Name("True"), RPI::ShaderOptionValue(0) });
+        idList4.push_back({ Name("False"), RPI::ShaderOptionValue(1) });
+
+        {
+            RPI::Ptr<RPI::ShaderOptionGroupLayout> shaderOptionGroupLayout = RPI::ShaderOptionGroupLayout::Create();
+            bool success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized1" }, RPI::ShaderOptionType::Boolean, 0, 0, idList4, Name("False"), 0, 0 });
+            EXPECT_TRUE(success);
+            success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized2" }, RPI::ShaderOptionType::Boolean, 1, 1, idList4, Name("False"), 0, 1 });
+            EXPECT_TRUE(success);
+            success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized3" }, RPI::ShaderOptionType::Boolean, 2, 2, idList4, Name("False"), 0, 2 });
+            EXPECT_TRUE(success);
+            shaderOptionGroupLayout->Finalize();
+            EXPECT_TRUE(shaderOptionGroupLayout->IsFullySpecialized());
+            EXPECT_TRUE(shaderOptionGroupLayout->UseSpecializationConstants());
+        }
+
+        {
+            RPI::Ptr<RPI::ShaderOptionGroupLayout> shaderOptionGroupLayout = RPI::ShaderOptionGroupLayout::Create();
+            bool success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized1" }, RPI::ShaderOptionType::Boolean, 0, 0, idList4, Name("False"), 0, 0 });
+            EXPECT_TRUE(success);
+            success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized2" }, RPI::ShaderOptionType::Boolean, 1, 1, idList4, Name("False"), 0, -1 });
+            EXPECT_TRUE(success);
+            success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized3" }, RPI::ShaderOptionType::Boolean, 2, 2, idList4, Name("False"), 0, 1 });
+            EXPECT_TRUE(success);
+            shaderOptionGroupLayout->Finalize();
+            EXPECT_FALSE(shaderOptionGroupLayout->IsFullySpecialized());
+            EXPECT_TRUE(shaderOptionGroupLayout->UseSpecializationConstants());
+        }
+
+        {
+            RPI::Ptr<RPI::ShaderOptionGroupLayout> shaderOptionGroupLayout = RPI::ShaderOptionGroupLayout::Create();
+            bool success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized1" }, RPI::ShaderOptionType::Boolean, 0, 0, idList4, Name("False"), 0, -1 });
+            EXPECT_TRUE(success);
+            success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized2" }, RPI::ShaderOptionType::Boolean, 1, 1, idList4, Name("False"), 0, -1 });
+            EXPECT_TRUE(success);
+            success = shaderOptionGroupLayout->AddShaderOption(
+                RPI::ShaderOptionDescriptor{ Name{ "Specialized3" }, RPI::ShaderOptionType::Boolean, 2, 2, idList4, Name("False"), 0, -1 });
+            EXPECT_TRUE(success);
+            shaderOptionGroupLayout->Finalize();
+            EXPECT_FALSE(shaderOptionGroupLayout->IsFullySpecialized());
+            EXPECT_FALSE(shaderOptionGroupLayout->UseSpecializationConstants());
+        }
+    }
+
     TEST_F(ShaderTests, ImplicitDefaultValue)
     TEST_F(ShaderTests, ImplicitDefaultValue)
     {
     {
         // Add shader option with no default value.
         // Add shader option with no default value.
@@ -1826,6 +1959,41 @@ namespace UnitTest
         EXPECT_EQ(resultG.GetStableId().GetIndex(), stableId5);
         EXPECT_EQ(resultG.GetStableId().GetIndex(), stableId5);
     }
     }
 
 
+    TEST_F(ShaderTests, ShaderAsset_SpecializationConstants)
+    {
+        {
+            AZ::RPI::ShaderAssetCreator creator;
+            BeginCreatingTestShaderAsset(creator, { RHI::ShaderStage::Compute }, SpecializationType::None);
+            AZ::Data::Asset<AZ::RPI::ShaderAsset> shaderAsset = EndCreatingTestShaderAsset(creator);
+            EXPECT_FALSE(shaderAsset->UseSpecializationConstants());
+            EXPECT_FALSE(shaderAsset->IsFullySpecialized());
+        }
+
+        {
+            AZ::RPI::ShaderAssetCreator creator;
+            BeginCreatingTestShaderAsset(creator, { RHI::ShaderStage::Compute }, SpecializationType::Partial);
+            AZ::Data::Asset<AZ::RPI::ShaderAsset> shaderAsset = EndCreatingTestShaderAsset(creator);
+            EXPECT_TRUE(shaderAsset->UseSpecializationConstants());
+            EXPECT_FALSE(shaderAsset->IsFullySpecialized());
+        }
+
+        {
+            AZ::RPI::ShaderAssetCreator creator;
+            BeginCreatingTestShaderAsset(creator, { RHI::ShaderStage::Compute }, SpecializationType::Full);
+            AZ::Data::Asset<AZ::RPI::ShaderAsset> shaderAsset = EndCreatingTestShaderAsset(creator);
+            EXPECT_TRUE(shaderAsset->UseSpecializationConstants());
+            EXPECT_TRUE(shaderAsset->IsFullySpecialized());
+        }
+
+        m_shaderOptionGroupLayoutForAssetFullSpecialization = m_shaderOptionGroupLayoutForAsset;
+        {
+            AZ::RPI::ShaderAssetCreator creator;
+            BeginCreatingTestShaderAsset(creator, { RHI::ShaderStage::Compute }, SpecializationType::Full);
+            AZ::Data::Asset<AZ::RPI::ShaderAsset> shaderAsset = EndCreatingTestShaderAsset(creator);
+            EXPECT_FALSE(shaderAsset->UseSpecializationConstants());
+            EXPECT_FALSE(shaderAsset->IsFullySpecialized());
+        }
+    }
 
 
     TEST_F(ShaderTests, ShaderVariantAsset_IsFullyBaked)
     TEST_F(ShaderTests, ShaderVariantAsset_IsFullyBaked)
     {
     {
@@ -1853,5 +2021,44 @@ namespace UnitTest
         EXPECT_FALSE(shaderVariantAsset->IsFullyBaked());
         EXPECT_FALSE(shaderVariantAsset->IsFullyBaked());
         EXPECT_FALSE(ShaderOptionGroup(m_shaderOptionGroupLayoutForAsset, shaderVariantAsset->GetShaderVariantId()).IsFullySpecified());
         EXPECT_FALSE(ShaderOptionGroup(m_shaderOptionGroupLayoutForAsset, shaderVariantAsset->GetShaderVariantId()).IsFullySpecified());
     }
     }
+
+    TEST_F(ShaderTests, ShaderVariantAsset_IsFullySpecialized)
+    {
+        using namespace AZ;
+        using namespace AZ::RPI;
+
+         {
+            AZ::RPI::ShaderAssetCreator creator;
+            BeginCreatingTestShaderAsset(creator, { RHI::ShaderStage::Compute }, SpecializationType::None);
+            AZ::Data::Asset<AZ::RPI::ShaderAsset> shaderAsset = EndCreatingTestShaderAsset(creator);
+            Data::Instance<RPI::Shader> shader = RPI::Shader::FindOrCreate(shaderAsset);
+            const RPI::ShaderVariant& rootShaderVariant = shader->GetVariant(RPI::ShaderVariantStableId{ 0 });
+            EXPECT_TRUE(rootShaderVariant.UseKeyFallback());
+            EXPECT_FALSE(rootShaderVariant.IsFullySpecialized());
+            EXPECT_FALSE(rootShaderVariant.UseSpecializationConstants());
+         }
+
+         {
+             AZ::RPI::ShaderAssetCreator creator;
+             BeginCreatingTestShaderAsset(creator, { RHI::ShaderStage::Compute }, SpecializationType::Partial);
+             AZ::Data::Asset<AZ::RPI::ShaderAsset> shaderAsset = EndCreatingTestShaderAsset(creator);
+             Data::Instance<RPI::Shader> shader = RPI::Shader::FindOrCreate(shaderAsset);
+             const RPI::ShaderVariant& rootShaderVariant = shader->GetVariant(RPI::ShaderVariantStableId{ 0 });
+             EXPECT_TRUE(rootShaderVariant.UseKeyFallback());
+             EXPECT_FALSE(rootShaderVariant.IsFullySpecialized());
+             EXPECT_TRUE(rootShaderVariant.UseSpecializationConstants());
+         }
+
+         {
+             AZ::RPI::ShaderAssetCreator creator;
+             BeginCreatingTestShaderAsset(creator, { RHI::ShaderStage::Compute }, SpecializationType::Full);
+             AZ::Data::Asset<AZ::RPI::ShaderAsset> shaderAsset = EndCreatingTestShaderAsset(creator);
+             Data::Instance<RPI::Shader> shader = RPI::Shader::FindOrCreate(shaderAsset);
+             const RPI::ShaderVariant& rootShaderVariant = shader->GetVariant(RPI::ShaderVariantStableId{ 0 });
+             EXPECT_FALSE(rootShaderVariant.UseKeyFallback());
+             EXPECT_TRUE(rootShaderVariant.IsFullySpecialized());
+             EXPECT_TRUE(rootShaderVariant.UseSpecializationConstants());
+         }
+    }
 }
 }
 
 

+ 2 - 2
Gems/AtomLyIntegration/EditorModeFeedback/Code/Source/Draw/EditorStateMeshDrawPacket.cpp

@@ -190,7 +190,7 @@ namespace AZ::Render
             const RPI::ShaderVariant& variant = shader->GetVariant(finalVariantId);
             const RPI::ShaderVariant& variant = shader->GetVariant(finalVariantId);
 
 
             RHI::PipelineStateDescriptorForDraw pipelineStateDescriptor;
             RHI::PipelineStateDescriptorForDraw pipelineStateDescriptor;
-            variant.ConfigurePipelineState(pipelineStateDescriptor);
+            variant.ConfigurePipelineState(pipelineStateDescriptor, shaderOptions);
 
 
             // Render states need to merge the runtime variation.
             // Render states need to merge the runtime variation.
             // This allows materials to customize the render states that the shader uses.
             // This allows materials to customize the render states that the shader uses.
@@ -221,7 +221,7 @@ namespace AZ::Render
                 // If the DrawSrg exists we must create and bind it, otherwise the CommandList will fail validation for SRG being null
                 // If the DrawSrg exists we must create and bind it, otherwise the CommandList will fail validation for SRG being null
                 drawSrg = RPI::ShaderResourceGroup::Create(shader->GetAsset(), shader->GetSupervariantIndex(), drawSrgLayout->GetName());
                 drawSrg = RPI::ShaderResourceGroup::Create(shader->GetAsset(), shader->GetSupervariantIndex(), drawSrgLayout->GetName());
 
 
-                if (!variant.IsFullyBaked() && drawSrgLayout->HasShaderVariantKeyFallbackEntry())
+                if (variant.UseKeyFallback() && drawSrgLayout->HasShaderVariantKeyFallbackEntry())
                 {
                 {
                     drawSrg->SetShaderVariantKeyFallbackValue(shaderOptions.GetShaderVariantKeyFallbackValue());
                     drawSrg->SetShaderVariantKeyFallbackValue(shaderOptions.GetShaderVariantKeyFallbackValue());
                 }
                 }

+ 0 - 14
Gems/AtomTressFX/Assets/Passes/HairParentShortCutPass.pass

@@ -193,13 +193,6 @@
                                 "Attachment": "Depth"
                                 "Attachment": "Depth"
                             }
                             }
                         },
                         },
-                        {
-                            "LocalSlot": "InverseAlphaRTOutput",
-                            "AttachmentRef": {
-                                "Pass": "This",
-                                "Attachment": "InverseAlphaRTOutput"
-                            }
-                        },
                         {
                         {
                             "LocalSlot": "HairDepthsTextureArray",
                             "LocalSlot": "HairDepthsTextureArray",
                             "AttachmentRef": {
                             "AttachmentRef": {
@@ -237,13 +230,6 @@
                     "TemplateName": "HairShortCutGeometryShadingPassTemplate",
                     "TemplateName": "HairShortCutGeometryShadingPassTemplate",
                     "Enabled": true,
                     "Enabled": true,
                     "Connections": [
                     "Connections": [
-                        {
-                            "LocalSlot": "HairColorRenderTarget",
-                            "AttachmentRef": {
-                                "Pass": "This",
-                                "Attachment": "HairColorRenderTarget"
-                            }
-                        },
                         {
                         {
                             // The final render target - this is MSAA mode RT - would it be cheaper to
                             // The final render target - this is MSAA mode RT - would it be cheaper to
                             // use non-MSAA and then copy?
                             // use non-MSAA and then copy?

+ 10 - 1
Gems/AtomTressFX/Assets/Passes/HairShortCutGeometryDepthAlpha.pass

@@ -50,7 +50,7 @@
                 {
                 {
                     // This buffer is used as the render target and should be at non-MSAA screen resolution
                     // This buffer is used as the render target and should be at non-MSAA screen resolution
                     // to make sure no overwork is done.
                     // to make sure no overwork is done.
-                    "Name": "InverseAlphaRTOutput",
+                    "Name": "InverseAlphaRTImage",
                     "SizeSource": {
                     "SizeSource": {
                         "Source": {
                         "Source": {
                             "Pass": "Parent",
                             "Pass": "Parent",
@@ -85,6 +85,15 @@
                         ]
                         ]
                     }
                     }
                 }
                 }
+            ],
+             "Connections": [
+                {
+                    "LocalSlot": "InverseAlphaRTOutput",
+                    "AttachmentRef": {
+                        "Pass": "This",
+                        "Attachment": "InverseAlphaRTImage"
+                    }
+                }
             ],
             ],
             "PassData": {
             "PassData": {
                 "$type": "RasterPassData",
                 "$type": "RasterPassData",

+ 8 - 1
Gems/AtomTressFX/Assets/Passes/HairShortCutGeometryShading.pass

@@ -113,7 +113,7 @@
                 {
                 {
                     // The shader hair color render target - important to have at a non-MSAA mode
                     // The shader hair color render target - important to have at a non-MSAA mode
                     // so that no overwork is done on sampling.
                     // so that no overwork is done on sampling.
-                    "Name": "HairColorRenderTarget",
+                    "Name": "HairColorImage",
                     "SizeSource": {
                     "SizeSource": {
                         "Source": {
                         "Source": {
                             "Pass": "Parent",
                             "Pass": "Parent",
@@ -144,6 +144,13 @@
                         "Pass": "This",
                         "Pass": "This",
                         "Attachment": "BRDFTexture"
                         "Attachment": "BRDFTexture"
                     }
                     }
+                },
+                {
+                    "LocalSlot": "HairColorRenderTarget",
+                    "AttachmentRef": {
+                        "Pass": "This",
+                        "Attachment": "HairColorImage"
+                    }
                 }
                 }
             ],
             ],
             "PassData": {
             "PassData": {

+ 37 - 21
Gems/AtomTressFX/Code/Passes/HairGeometryRasterPass.cpp

@@ -90,6 +90,40 @@ namespace AZ
                 return (RPI::RasterPass::IsEnabled() && m_initialized) ? true : false;
                 return (RPI::RasterPass::IsEnabled() && m_initialized) ? true : false;
             }
             }
 
 
+            bool HairGeometryRasterPass::UpdateShaderOptions(const RPI::ShaderVariantId& variantId)
+            {
+                m_currentShaderVariantId = variantId;
+                const RPI::ShaderVariant& shaderVariant = m_shader->GetVariant(m_currentShaderVariantId);
+                RHI::PipelineStateDescriptorForDraw pipelineStateDescriptor;
+                shaderVariant.ConfigurePipelineState(pipelineStateDescriptor, m_currentShaderVariantId);
+
+                RPI::Scene* scene = GetScene();
+                if (!scene)
+                {
+                    AZ_Error("Hair Gem", false, "Scene could not be acquired");
+                    return false;
+                }
+                RHI::DrawListTag drawListTag = m_shader->GetDrawListTag();
+                scene->ConfigurePipelineState(drawListTag, pipelineStateDescriptor);
+
+                pipelineStateDescriptor.m_renderAttachmentConfiguration = GetRenderAttachmentConfiguration();
+                pipelineStateDescriptor.m_inputStreamLayout.SetTopology(AZ::RHI::PrimitiveTopology::TriangleList);
+                pipelineStateDescriptor.m_inputStreamLayout.Finalize();
+
+                m_pipelineState = m_shader->AcquirePipelineState(pipelineStateDescriptor);
+                if (!m_pipelineState)
+                {
+                    AZ_Error("Hair Gem", false, "Pipeline state could not be acquired");
+                    return false;
+                }
+
+                if (m_shaderResourceGroup->HasShaderVariantKeyFallbackEntry() && shaderVariant.UseKeyFallback())
+                {
+                    m_shaderResourceGroup->SetShaderVariantKeyFallbackValue(m_currentShaderVariantId.m_key);
+                }
+                return true;
+            }
+
             bool HairGeometryRasterPass::LoadShaderAndPipelineState()
             bool HairGeometryRasterPass::LoadShaderAndPipelineState()
             {
             {
                 RPI::ShaderReloadNotificationBus::Handler::BusDisconnect();
                 RPI::ShaderReloadNotificationBus::Handler::BusDisconnect();
@@ -135,27 +169,9 @@ namespace AZ
                     }
                     }
                 }
                 }
 
 
-                const RPI::ShaderVariant& shaderVariant = m_shader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
-                RHI::PipelineStateDescriptorForDraw pipelineStateDescriptor;
-                shaderVariant.ConfigurePipelineState(pipelineStateDescriptor);
-
-                RPI::Scene* scene = GetScene();
-                if (!scene)
-                {
-                    AZ_Error("Hair Gem", false, "Scene could not be acquired" );
-                    return false;
-                }
-                RHI::DrawListTag drawListTag = m_shader->GetDrawListTag();
-                scene->ConfigurePipelineState(drawListTag, pipelineStateDescriptor);
-
-                pipelineStateDescriptor.m_renderAttachmentConfiguration = GetRenderAttachmentConfiguration();
-                pipelineStateDescriptor.m_inputStreamLayout.SetTopology(AZ::RHI::PrimitiveTopology::TriangleList);
-                pipelineStateDescriptor.m_inputStreamLayout.Finalize();
-
-                m_pipelineState = m_shader->AcquirePipelineState(pipelineStateDescriptor);
-                if (!m_pipelineState)
+                if (!UpdateShaderOptions(m_shader->GetDefaultShaderOptions().GetShaderVariantId()))
                 {
                 {
-                    AZ_Error("Hair Gem", false, "Pipeline state could not be acquired");
+                    AZ_Error("Hair Gem", false, "Failed to create pipeline state");
                     return false;
                     return false;
                 }
                 }
 
 
@@ -177,6 +193,7 @@ namespace AZ
             void HairGeometryRasterPass::SchedulePacketBuild(HairRenderObject* hairObject)
             void HairGeometryRasterPass::SchedulePacketBuild(HairRenderObject* hairObject)
             {
             {
                 m_newRenderObjects.insert(hairObject);
                 m_newRenderObjects.insert(hairObject);
+                BuildDrawPacket(hairObject);
             }
             }
 
 
             bool HairGeometryRasterPass::BuildDrawPacket(HairRenderObject* hairObject)
             bool HairGeometryRasterPass::BuildDrawPacket(HairRenderObject* hairObject)
@@ -249,7 +266,6 @@ namespace AZ
                 for (HairRenderObject* newObject : m_newRenderObjects)
                 for (HairRenderObject* newObject : m_newRenderObjects)
                 {
                 {
                     newObject->BindPerObjectSrgForRaster();
                     newObject->BindPerObjectSrgForRaster();
-                    BuildDrawPacket(newObject);
                 }
                 }
 
 
                 // Clear the new added objects - BuildDrawPacket should only be carried out once per
                 // Clear the new added objects - BuildDrawPacket should only be carried out once per

+ 4 - 0
Gems/AtomTressFX/Code/Passes/HairGeometryRasterPass.h

@@ -81,6 +81,9 @@ namespace AZ
                 // Scope producer functions...
                 // Scope producer functions...
                 void CompileResources(const RHI::FrameGraphCompileContext& context) override;
                 void CompileResources(const RHI::FrameGraphCompileContext& context) override;
 
 
+                //! Updates the shader variant being used by the pass
+                bool UpdateShaderOptions(const RPI::ShaderVariantId& variantId);
+
             protected:
             protected:
                 HairFeatureProcessor* m_featureProcessor = nullptr;
                 HairFeatureProcessor* m_featureProcessor = nullptr;
 
 
@@ -104,6 +107,7 @@ namespace AZ
                 AZStd::unordered_set<HairRenderObject*> m_newRenderObjects;
                 AZStd::unordered_set<HairRenderObject*> m_newRenderObjects;
 
 
                 bool m_initialized = false;
                 bool m_initialized = false;
+                RPI::ShaderVariantId m_currentShaderVariantId;
             };
             };
 
 
         } // namespace Hair
         } // namespace Hair

+ 8 - 4
Gems/AtomTressFX/Code/Passes/HairPPLLRasterPass.cpp

@@ -51,20 +51,24 @@ namespace AZ
             //! Once supported, this will be done via data driven code and the method can be removed.
             //! Once supported, this will be done via data driven code and the method can be removed.
             void HairPPLLRasterPass::BuildInternal()
             void HairPPLLRasterPass::BuildInternal()
             {
             {
-                RasterPass::BuildInternal();    // change this to call parent if the method exists
+                HairGeometryRasterPass::BuildInternal(); // change this to call parent if the method exists
 
 
                 if (!AcquireFeatureProcessor())
                 if (!AcquireFeatureProcessor())
                 {
                 {
                     return;
                     return;
                 }
                 }
 
 
+                // Output
+                AttachBufferToSlot(Name{ "PerPixelLinkedList" }, m_featureProcessor->GetPerPixelListBuffer());
+            }
+
+            void HairPPLLRasterPass::InitializeInternal()
+            {
                 if (!LoadShaderAndPipelineState())
                 if (!LoadShaderAndPipelineState())
                 {
                 {
                     return;
                     return;
                 }
                 }
-
-                // Output
-                AttachBufferToSlot(Name{ "PerPixelLinkedList" }, m_featureProcessor->GetPerPixelListBuffer());
+                HairGeometryRasterPass::InitializeInternal();
             }
             }
 
 
         } // namespace Hair
         } // namespace Hair

+ 1 - 0
Gems/AtomTressFX/Code/Passes/HairPPLLRasterPass.h

@@ -47,6 +47,7 @@ namespace AZ
 
 
                 // Pass behavior overrides
                 // Pass behavior overrides
                 void BuildInternal() override;
                 void BuildInternal() override;
+                void InitializeInternal() override;                
             };
             };
 
 
         } // namespace Hair
         } // namespace Hair

+ 5 - 7
Gems/AtomTressFX/Code/Passes/HairPPLLResolvePass.cpp

@@ -60,8 +60,12 @@ namespace AZ
                 shaderOption.SetValue(o_enableMarschner_TT, AZ::RPI::ShaderOptionValue{ m_hairGlobalSettings.m_enableMarschner_TT });
                 shaderOption.SetValue(o_enableMarschner_TT, AZ::RPI::ShaderOptionValue{ m_hairGlobalSettings.m_enableMarschner_TT });
                 shaderOption.SetValue(o_enableLongtitudeCoeff, AZ::RPI::ShaderOptionValue{ m_hairGlobalSettings.m_enableLongtitudeCoeff });
                 shaderOption.SetValue(o_enableLongtitudeCoeff, AZ::RPI::ShaderOptionValue{ m_hairGlobalSettings.m_enableLongtitudeCoeff });
                 shaderOption.SetValue(o_enableAzimuthCoeff, AZ::RPI::ShaderOptionValue{ m_hairGlobalSettings.m_enableAzimuthCoeff });
                 shaderOption.SetValue(o_enableAzimuthCoeff, AZ::RPI::ShaderOptionValue{ m_hairGlobalSettings.m_enableAzimuthCoeff });
+                shaderOption.SetUnspecifiedToDefaultValues();
 
 
-                m_shaderOptions = shaderOption.GetShaderVariantKeyFallbackValue();
+                if (m_pipelineStateForDraw.GetShaderVariantId() != shaderOption.GetShaderVariantId())
+                {
+                    UpdateShaderOptions(shaderOption.GetShaderVariantId());
+                }
             }
             }
 
 
             RPI::Ptr<HairPPLLResolvePass> HairPPLLResolvePass::Create(const RPI::PassDescriptor& descriptor)
             RPI::Ptr<HairPPLLResolvePass> HairPPLLResolvePass::Create(const RPI::PassDescriptor& descriptor)
@@ -117,12 +121,6 @@ namespace AZ
 
 
                 UpdateGlobalShaderOptions();
                 UpdateGlobalShaderOptions();
 
 
-                if (m_shaderResourceGroup->HasShaderVariantKeyFallbackEntry())
-                {
-                    m_shaderResourceGroup->SetShaderVariantKeyFallbackValue(m_shaderOptions);
-                }
-
-
                 SrgBufferDescriptor descriptor = SrgBufferDescriptor(
                 SrgBufferDescriptor descriptor = SrgBufferDescriptor(
                     RPI::CommonBufferPoolType::ReadWrite, RHI::Format::Unknown,
                     RPI::CommonBufferPoolType::ReadWrite, RHI::Format::Unknown,
                     PPLL_NODE_SIZE, RESERVED_PIXELS_FOR_OIT,
                     PPLL_NODE_SIZE, RESERVED_PIXELS_FOR_OIT,

+ 0 - 1
Gems/AtomTressFX/Code/Passes/HairPPLLResolvePass.h

@@ -76,7 +76,6 @@ namespace AZ
 
 
                 HairGlobalSettings m_hairGlobalSettings;
                 HairGlobalSettings m_hairGlobalSettings;
                 HairFeatureProcessor* m_featureProcessor = nullptr;
                 HairFeatureProcessor* m_featureProcessor = nullptr;
-                AZ::RPI::ShaderVariantKey m_shaderOptions;
             };
             };
 
 
         } // namespace Hair
         } // namespace Hair

+ 5 - 1
Gems/AtomTressFX/Code/Passes/HairShortCutGeometryDepthAlphaPass.cpp

@@ -35,14 +35,18 @@ namespace AZ
 
 
             void HairShortCutGeometryDepthAlphaPass::BuildInternal()
             void HairShortCutGeometryDepthAlphaPass::BuildInternal()
             {
             {
-                RasterPass::BuildInternal();    // change this to call parent if the method exists
+                HairGeometryRasterPass::BuildInternal(); // change this to call parent if the method exists
 
 
                 if (!AcquireFeatureProcessor())
                 if (!AcquireFeatureProcessor())
                 {
                 {
                     return;
                     return;
                 }
                 }
+            }
 
 
+            void HairShortCutGeometryDepthAlphaPass::InitializeInternal()
+            {
                 LoadShaderAndPipelineState();
                 LoadShaderAndPipelineState();
+                HairGeometryRasterPass::InitializeInternal();
             }
             }
 
 
         } // namespace Hair
         } // namespace Hair

+ 1 - 0
Gems/AtomTressFX/Code/Passes/HairShortCutGeometryDepthAlphaPass.h

@@ -42,6 +42,7 @@ namespace AZ
 
 
                 // Pass behavior overrides
                 // Pass behavior overrides
                 void BuildInternal() override;
                 void BuildInternal() override;
+                void InitializeInternal() override;
             };
             };
 
 
         } // namespace Hair
         } // namespace Hair

Nem az összes módosított fájl került megjelenítésre, mert túl sok fájl változott