3
0

ShaderVariant.cpp 3.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/RPI.Public/Shader/ShaderVariant.h>
  9. #include <Atom/RPI.Public/Shader/ShaderReloadNotificationBus.h>
  10. #include <Atom/RPI.Public/Shader/ShaderReloadDebugTracker.h>
  11. #include <Atom/RHI/DrawListTagRegistry.h>
  12. #include <Atom/RHI/RHISystemInterface.h>
  13. #include <Atom/RHI.Reflect/ShaderStageFunction.h>
  14. namespace AZ
  15. {
  16. namespace RPI
  17. {
  18. bool ShaderVariant::Init(
  19. const Data::Asset<ShaderAsset>& shaderAsset,
  20. const Data::Asset<ShaderVariantAsset>& shaderVariantAsset,
  21. SupervariantIndex supervariantIndex)
  22. {
  23. m_shaderAsset = shaderAsset;
  24. m_shaderVariantAsset = shaderVariantAsset;
  25. m_supervariantIndex = supervariantIndex;
  26. m_pipelineStateType = shaderAsset->GetPipelineStateType();
  27. m_pipelineLayoutDescriptor = shaderAsset->GetPipelineLayoutDescriptor(supervariantIndex);
  28. m_renderStates = &shaderAsset->GetRenderStates(supervariantIndex);
  29. return true;
  30. }
  31. ShaderVariant::~ShaderVariant()
  32. {
  33. }
  34. void ShaderVariant::ConfigurePipelineState(RHI::PipelineStateDescriptor& descriptor) const
  35. {
  36. descriptor.m_pipelineLayoutDescriptor = m_pipelineLayoutDescriptor;
  37. switch (descriptor.GetType())
  38. {
  39. case RHI::PipelineStateType::Draw:
  40. {
  41. AZ_Assert(m_pipelineStateType == RHI::PipelineStateType::Draw, "ShaderVariant is not intended for the raster pipeline.");
  42. AZ_Assert(m_renderStates, "Invalid RenderStates");
  43. RHI::PipelineStateDescriptorForDraw& descriptorForDraw = static_cast<RHI::PipelineStateDescriptorForDraw&>(descriptor);
  44. descriptorForDraw.m_vertexFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Vertex);
  45. descriptorForDraw.m_geometryFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Geometry);
  46. descriptorForDraw.m_fragmentFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Fragment);
  47. descriptorForDraw.m_renderStates = *m_renderStates;
  48. break;
  49. }
  50. case RHI::PipelineStateType::Dispatch:
  51. {
  52. AZ_Assert(m_pipelineStateType == RHI::PipelineStateType::Dispatch, "ShaderVariant is not intended for the compute pipeline.");
  53. RHI::PipelineStateDescriptorForDispatch& descriptorForDispatch = static_cast<RHI::PipelineStateDescriptorForDispatch&>(descriptor);
  54. descriptorForDispatch.m_computeFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::Compute);
  55. break;
  56. }
  57. case RHI::PipelineStateType::RayTracing:
  58. {
  59. AZ_Assert(m_pipelineStateType == RHI::PipelineStateType::RayTracing, "ShaderVariant is not intended for the ray tracing pipeline.");
  60. RHI::PipelineStateDescriptorForRayTracing& descriptorForRayTracing = static_cast<RHI::PipelineStateDescriptorForRayTracing&>(descriptor);
  61. descriptorForRayTracing.m_rayTracingFunction = m_shaderVariantAsset->GetShaderStageFunction(RHI::ShaderStage::RayTracing);
  62. break;
  63. }
  64. default:
  65. AZ_Assert(false, "Unexpected PipelineStateType");
  66. break;
  67. }
  68. }
  69. } // namespace RPI
  70. } // namespace AZ