RayTracingExampleComponent.h 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #pragma once
  9. #include <Atom/RHI/BufferPool.h>
  10. #include <Atom/RHI/Device.h>
  11. #include <Atom/RHI/DeviceDispatchRaysItem.h>
  12. #include <Atom/RHI/Factory.h>
  13. #include <Atom/RHI/FrameScheduler.h>
  14. #include <Atom/RHI/PipelineState.h>
  15. #include <Atom/RHI/RayTracingAccelerationStructure.h>
  16. #include <Atom/RHI/RayTracingBufferPools.h>
  17. #include <Atom/RHI/RayTracingPipelineState.h>
  18. #include <Atom/RHI/RayTracingShaderTable.h>
  19. #include <Atom/RPI.Public/Shader/ShaderResourceGroup.h>
  20. #include <AzCore/Component/Component.h>
  21. #include <AzCore/Math/Matrix4x4.h>
  22. #include <RHI/BasicRHIComponent.h>
  23. namespace AtomSampleViewer
  24. {
  25. using namespace AZ;
  26. // This sample demonstrates the use of Atom Ray Tracing through the RHI abstraction layer.
  27. // It creates three triangles and one rectangle in a scene, and ray traces that scene to
  28. // an output image and displays it.
  29. class RayTracingExampleComponent final
  30. : public BasicRHIComponent
  31. {
  32. public:
  33. AZ_COMPONENT(RayTracingExampleComponent, "{FC4636BC-9C5C-4D7D-8FEF-41A02C56B62D}", AZ::Component);
  34. AZ_DISABLE_COPY(RayTracingExampleComponent);
  35. static void Reflect(AZ::ReflectContext* context);
  36. RayTracingExampleComponent();
  37. ~RayTracingExampleComponent() override {}
  38. protected:
  39. // AZ::Component
  40. void Activate() override;
  41. void Deactivate() override;
  42. private:
  43. void CreateResourcePools();
  44. void CreateGeometry();
  45. void CreateFullScreenBuffer();
  46. void CreateOutputTexture();
  47. void CreateRasterShader();
  48. void CreateRayTracingAccelerationStructureObjects();
  49. void CreateRayTracingPipelineState();
  50. void CreateRayTracingShaderTable();
  51. void CreateRayTracingAccelerationTableScope();
  52. void CreateRayTracingDispatchScope();
  53. void CreateRasterScope();
  54. static const uint32_t m_imageWidth = 1920;
  55. static const uint32_t m_imageHeight = 1080;
  56. // resource pools
  57. RHI::Ptr<RHI::BufferPool> m_inputAssemblyBufferPool;
  58. RHI::Ptr<RHI::ImagePool> m_imagePool;
  59. RHI::Ptr<RHI::RayTracingBufferPools> m_rayTracingBufferPools;
  60. // triangle vertex/index buffers
  61. AZStd::array<VertexPosition, 3> m_triangleVertices;
  62. AZStd::array<uint16_t, 3> m_triangleIndices;
  63. RHI::Ptr<RHI::Buffer> m_triangleVB;
  64. RHI::Ptr<RHI::Buffer> m_triangleIB;
  65. // rectangle vertex/index buffers
  66. AZStd::array<VertexPosition, 4> m_rectangleVertices;
  67. AZStd::array<uint16_t, 6> m_rectangleIndices;
  68. RHI::Ptr<RHI::Buffer> m_rectangleVB;
  69. RHI::Ptr<RHI::Buffer> m_rectangleIB;
  70. // ray tracing acceleration structures
  71. RHI::Ptr<RHI::RayTracingBlas> m_triangleRayTracingBlas;
  72. RHI::Ptr<RHI::RayTracingBlas> m_rectangleRayTracingBlas;
  73. RHI::Ptr<RHI::RayTracingTlas> m_rayTracingTlas;
  74. RHI::BufferViewDescriptor m_tlasBufferViewDescriptor;
  75. RHI::AttachmentId m_tlasBufferAttachmentId = RHI::AttachmentId("tlasBufferAttachmentId");
  76. // ray tracing shaders
  77. Data::Instance<RPI::Shader> m_rayGenerationShader;
  78. Data::Instance<RPI::Shader> m_missShader;
  79. Data::Instance<RPI::Shader> m_closestHitGradientShader;
  80. Data::Instance<RPI::Shader> m_closestHitSolidShader;
  81. // ray tracing pipeline state
  82. RHI::Ptr<RHI::RayTracingPipelineState> m_rayTracingPipelineState;
  83. // ray tracing shader table
  84. RHI::Ptr<RHI::RayTracingShaderTable> m_rayTracingShaderTable;
  85. // ray tracing global shader resource group and pipeline state
  86. Data::Instance<RPI::ShaderResourceGroup> m_globalSrg;
  87. RHI::ConstPtr<RHI::PipelineState> m_globalPipelineState;
  88. // ray tracing local shader resource groups, one for each object in the scene
  89. enum LocalSrgs
  90. {
  91. Triangle1,
  92. Triangle2,
  93. Triangle3,
  94. Rectangle,
  95. Count
  96. };
  97. AZStd::array<Data::Instance<RPI::ShaderResourceGroup>, LocalSrgs::Count> m_localSrgs;
  98. bool m_buildLocalSrgs = true;
  99. // output image, written to by the ray tracing shader and displayed in the fullscreen draw shader
  100. RHI::Ptr<RHI::Image> m_outputImage;
  101. RHI::Ptr<RHI::ImageView> m_outputImageView;
  102. RHI::ImageViewDescriptor m_outputImageViewDescriptor;
  103. RHI::AttachmentId m_outputImageAttachmentId = RHI::AttachmentId("outputImageAttachmentId");
  104. // fullscreen buffer for the raster pass to display the output image
  105. struct FullScreenBufferData
  106. {
  107. AZStd::array<VertexPosition, 4> m_positions;
  108. AZStd::array<VertexUV, 4> m_uvs;
  109. AZStd::array<uint16_t, 6> m_indices;
  110. };
  111. RHI::Ptr<RHI::Buffer> m_fullScreenInputAssemblyBuffer;
  112. RHI::GeometryView m_geometryView{ AZ::RHI::MultiDevice::AllDevices };
  113. RHI::InputStreamLayout m_fullScreenInputStreamLayout;
  114. RHI::ConstPtr<RHI::PipelineState> m_drawPipelineState;
  115. Data::Instance<RPI::ShaderResourceGroup> m_drawSRG;
  116. RHI::ShaderInputConstantIndex m_drawDimensionConstantIndex;
  117. // time variable for moving the triangles and rectangle each frame
  118. float m_time = 0.0f;
  119. };
  120. } // namespace AtomSampleViewer