ShaderVariant.cpp 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/RPI.Public/Shader/ShaderVariant.h>
  9. #include <Atom/RPI.Public/Shader/ShaderReloadNotificationBus.h>
  10. #include <Atom/RPI.Public/Shader/ShaderReloadDebugTracker.h>
  11. #include <Atom/RHI/DrawListTagRegistry.h>
  12. #include <Atom/RHI/RHISystemInterface.h>
  13. #include <Atom/RHI.Reflect/ShaderStageFunction.h>
  14. namespace AZ
  15. {
  16. namespace RPI
  17. {
  18. bool ShaderVariant::Init(
  19. const Data::Asset<ShaderAsset>& shaderAsset,
  20. const Data::Asset<ShaderVariantAsset>& shaderVariantAsset,
  21. SupervariantIndex supervariantIndex)
  22. {
  23. m_shaderAsset = shaderAsset;
  24. m_shaderVariantAsset = shaderVariantAsset;
  25. m_supervariantIndex = supervariantIndex;
  26. m_pipelineStateType = shaderAsset->GetPipelineStateType();
  27. m_pipelineLayoutDescriptor = shaderAsset->GetPipelineLayoutDescriptor(supervariantIndex);
  28. m_renderStates = &shaderAsset->GetRenderStates(supervariantIndex);
  29. m_useSpecializationConstants = shaderAsset->UseSpecializationConstants(supervariantIndex);
  30. return true;
  31. }
  32. ShaderVariant::~ShaderVariant()
  33. {
  34. }
  35. void ShaderVariant::ConfigurePipelineState(
  36. RHI::PipelineStateDescriptor& descriptor,
  37. const ShaderVariantId& specialization) const
  38. {
  39. ConfigurePipelineState(descriptor, ShaderOptionGroup(m_shaderAsset->GetShaderOptionGroupLayout(), specialization));
  40. }
  41. void ShaderVariant::ConfigurePipelineState(
  42. RHI::PipelineStateDescriptor& descriptor,
  43. const ShaderOptionGroup& specialization) const
  44. {
  45. descriptor.m_pipelineLayoutDescriptor = m_pipelineLayoutDescriptor;
  46. switch (descriptor.GetType())
  47. {
  48. case RHI::PipelineStateType::Draw:
  49. {
  50. AZ_Assert(m_pipelineStateType == RHI::PipelineStateType::Draw, "ShaderVariant is not intended for the raster pipeline.");
  51. AZ_Assert(m_renderStates, "Invalid RenderStates");
  52. RHI::PipelineStateDescriptorForDraw& descriptorForDraw = static_cast<RHI::PipelineStateDescriptorForDraw&>(descriptor);
  53. descriptorForDraw.m_vertexFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Vertex);
  54. descriptorForDraw.m_geometryFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Geometry);
  55. descriptorForDraw.m_fragmentFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Fragment);
  56. descriptorForDraw.m_renderStates = *m_renderStates;
  57. break;
  58. }
  59. case RHI::PipelineStateType::Dispatch:
  60. {
  61. AZ_Assert(m_pipelineStateType == RHI::PipelineStateType::Dispatch, "ShaderVariant is not intended for the compute pipeline.");
  62. RHI::PipelineStateDescriptorForDispatch& descriptorForDispatch = static_cast<RHI::PipelineStateDescriptorForDispatch&>(descriptor);
  63. descriptorForDispatch.m_computeFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Compute);
  64. break;
  65. }
  66. case RHI::PipelineStateType::RayTracing:
  67. {
  68. AZ_Assert(m_pipelineStateType == RHI::PipelineStateType::RayTracing, "ShaderVariant is not intended for the ray tracing pipeline.");
  69. RHI::PipelineStateDescriptorForRayTracing& descriptorForRayTracing = static_cast<RHI::PipelineStateDescriptorForRayTracing&>(descriptor);
  70. descriptorForRayTracing.m_rayTracingFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::RayTracing);
  71. break;
  72. }
  73. default:
  74. AZ_Assert(false, "Unexpected PipelineStateType");
  75. break;
  76. }
  77. if (m_useSpecializationConstants)
  78. {
  79. // Configure specialization data for the shader
  80. AZ_Assert(
  81. specialization.GetShaderOptionLayout() == m_shaderAsset->GetShaderOptionGroupLayout(),
  82. "OptionGroup for specialization is different to the one in the ShaderAsset");
  83. descriptor.m_specializationData.clear();
  84. ShaderOptionGroup options = specialization;
  85. options.SetUnspecifiedToDefaultValues();
  86. for (auto& option : options.GetShaderOptionLayout()->GetShaderOptions())
  87. {
  88. if (option.GetSpecializationId() >= 0)
  89. {
  90. descriptor.m_specializationData.emplace_back();
  91. auto& specializationData = descriptor.m_specializationData.back();
  92. specializationData.m_name = option.GetName();
  93. specializationData.m_id = option.GetSpecializationId();
  94. specializationData.m_value = RHI::SpecializationValue(option.Get(options).GetIndex());
  95. switch (option.GetType())
  96. {
  97. case ShaderOptionType::Boolean:
  98. specializationData.m_type = RHI::SpecializationType::Bool;
  99. break;
  100. case ShaderOptionType::Enumeration:
  101. case ShaderOptionType::IntegerRange:
  102. specializationData.m_type = RHI::SpecializationType::Integer;
  103. break;
  104. default:
  105. break;
  106. }
  107. }
  108. }
  109. }
  110. }
  111. void ShaderVariant::ConfigurePipelineState(RHI::PipelineStateDescriptor& descriptor) const
  112. {
  113. auto layout = m_shaderAsset->GetShaderOptionGroupLayout();
  114. for ([[maybe_unused]] auto& option : layout->GetShaderOptions())
  115. {
  116. if (m_useSpecializationConstants && option.GetSpecializationId() >= 0)
  117. {
  118. AZ_Error(
  119. "ConfigurePipelineState",
  120. !m_useSpecializationConstants || option.GetSpecializationId() < 0,
  121. "Configuring PipelineStateDescriptor without specializing option %s.\
  122. Call ConfigurePipelineState with specialization data. Default value will be used.",
  123. option.GetName().GetCStr());
  124. }
  125. }
  126. ConfigurePipelineState(descriptor, ShaderOptionGroup(layout));
  127. }
  128. bool ShaderVariant::IsFullySpecialized() const
  129. {
  130. return m_shaderAsset->IsFullySpecialized(m_supervariantIndex);
  131. }
  132. bool ShaderVariant::UseSpecializationConstants() const
  133. {
  134. return m_shaderAsset->UseSpecializationConstants(m_supervariantIndex);
  135. }
  136. bool ShaderVariant::UseKeyFallback() const
  137. {
  138. return !(IsFullyBaked() || IsFullySpecialized());
  139. }
  140. } // namespace RPI
  141. } // namespace AZ