RayTracingPass.h 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #pragma once
  9. #include <AzCore/Memory/SystemAllocator.h>
  10. #include <Atom/RHI/RayTracingPipelineState.h>
  11. #include <Atom/RHI/RayTracingShaderTable.h>
  12. #include <Atom/RPI.Public/Pass/RenderPass.h>
  13. #include <Atom/RPI.Public/Shader/Shader.h>
  14. #include <Atom/RPI.Public/Shader/ShaderReloadNotificationBus.h>
  15. namespace AZ
  16. {
  17. namespace Render
  18. {
  19. struct RayTracingPassData;
  20. //! This pass executes a raytracing shader as specified in the PassData.
  21. class RayTracingPass
  22. : public RPI::RenderPass
  23. , private RPI::ShaderReloadNotificationBus::MultiHandler
  24. {
  25. AZ_RPI_PASS(RayTracingPass);
  26. public:
  27. AZ_RTTI(RayTracingPass, "{7A68A36E-956A-4258-93FE-38686042C4D9}", RPI::RenderPass);
  28. AZ_CLASS_ALLOCATOR(RayTracingPass, SystemAllocator);
  29. virtual ~RayTracingPass();
  30. //! Creates a RayTracingPass
  31. static RPI::Ptr<RayTracingPass> Create(const RPI::PassDescriptor& descriptor);
  32. void SetMaxRayLength(float maxRayLength) { m_maxRayLength = maxRayLength; }
  33. protected:
  34. RayTracingPass(const RPI::PassDescriptor& descriptor);
  35. // Pass overrides
  36. bool IsEnabled() const override;
  37. void FrameBeginInternal(FramePrepareParams params) override;
  38. // Scope producer functions
  39. void SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph) override;
  40. void CompileResources(const RHI::FrameGraphCompileContext& context) override;
  41. void BuildCommandListInternal(const RHI::FrameGraphExecuteContext& context) override;
  42. // ShaderReloadNotificationBus::Handler overrides
  43. void OnShaderReinitialized(const RPI::Shader& shader) override;
  44. void OnShaderAssetReinitialized(const Data::Asset<RPI::ShaderAsset>& shaderAsset) override;
  45. void OnShaderVariantReinitialized(const RPI::ShaderVariant& shaderVariant) override;
  46. // load the raytracing shaders and setup pipeline states
  47. void CreatePipelineState();
  48. // helper for loading a shader from a shader asset reference
  49. Data::Instance<RPI::Shader> LoadShader(const RPI::AssetReference& shaderAssetReference);
  50. // pass data
  51. RPI::PassDescriptor m_passDescriptor;
  52. const RayTracingPassData* m_passData = nullptr;
  53. // revision number of the ray tracing TLAS when the shader table was built
  54. uint32_t m_rayTracingRevision = 0;
  55. uint32_t m_proceduralGeometryTypeRevision = 0;
  56. // raytracing shaders, pipeline states, and shader table
  57. Data::Instance<RPI::Shader> m_rayGenerationShader;
  58. Data::Instance<RPI::Shader> m_missShader;
  59. Data::Instance<RPI::Shader> m_closestHitShader;
  60. Data::Instance<RPI::Shader> m_closestHitProceduralShader;
  61. Data::Instance<RPI::Shader> m_intersectionShader;
  62. RHI::Ptr<RHI::RayTracingPipelineState> m_rayTracingPipelineState;
  63. RHI::ConstPtr<RHI::PipelineState> m_globalPipelineState;
  64. RHI::Ptr<RHI::RayTracingShaderTable> m_rayTracingShaderTable;
  65. bool m_requiresViewSrg = false;
  66. bool m_requiresSceneSrg = false;
  67. bool m_requiresRayTracingMaterialSrg = false;
  68. bool m_requiresRayTracingSceneSrg = false;
  69. float m_maxRayLength = 1e27f;
  70. };
  71. } // namespace RPI
  72. } // namespace AZ