RayTracingExampleComponent.cpp 32 KB


  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <RHI/RayTracingExampleComponent.h>
  9. #include <Utils/Utils.h>
  10. #include <SampleComponentManager.h>
  11. #include <Atom/RHI/CommandList.h>
  12. #include <Atom/RHI/FrameGraphInterface.h>
  13. #include <Atom/RHI/RayTracingPipelineState.h>
  14. #include <Atom/RHI/RayTracingShaderTable.h>
  15. #include <Atom/RHI.Reflect/InputStreamLayoutBuilder.h>
  16. #include <Atom/RHI.Reflect/RenderAttachmentLayoutBuilder.h>
  17. #include <Atom/RPI.Public/Shader/Shader.h>
  18. #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
  19. #include <AzCore/Serialization/SerializeContext.h>
  // Name tag for this sample. It is passed to the shader/SRG helper calls and to the AZ_Error
  // calls in this file. Declared constexpr so the pointer itself cannot be reassigned.
  static constexpr const char* RayTracingExampleName = "RayTracingExample";
  21. namespace AtomSampleViewer
  22. {
  23. void RayTracingExampleComponent::Reflect(AZ::ReflectContext* context)
  24. {
  25. if (auto* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  26. {
  27. serializeContext->Class<RayTracingExampleComponent, AZ::Component>()
  28. ->Version(0)
  29. ;
  30. }
  31. }
  32. RayTracingExampleComponent::RayTracingExampleComponent()
  33. {
  34. m_supportRHISamplePipeline = true;
  35. }
  36. void RayTracingExampleComponent::Activate()
  37. {
  38. CreateResourcePools();
  39. CreateGeometry();
  40. CreateFullScreenBuffer();
  41. CreateOutputTexture();
  42. CreateRasterShader();
  43. CreateRayTracingAccelerationStructureObjects();
  44. CreateRayTracingPipelineState();
  45. CreateRayTracingShaderTable();
  46. CreateRayTracingAccelerationTableScope();
  47. CreateRayTracingDispatchScope();
  48. CreateRasterScope();
  49. RHI::RHISystemNotificationBus::Handler::BusConnect();
  50. }
  51. void RayTracingExampleComponent::Deactivate()
  52. {
  53. RHI::RHISystemNotificationBus::Handler::BusDisconnect();
  54. m_windowContext = nullptr;
  55. m_scopeProducers.clear();
  56. }
  57. void RayTracingExampleComponent::CreateResourcePools()
  58. {
  59. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  60. // create input assembly buffer pool
  61. {
  62. m_inputAssemblyBufferPool = RHI::Factory::Get().CreateBufferPool();
  63. RHI::BufferPoolDescriptor bufferPoolDesc;
  64. bufferPoolDesc.m_bindFlags = RHI::BufferBindFlags::InputAssembly;
  65. bufferPoolDesc.m_heapMemoryLevel = RHI::HeapMemoryLevel::Host;
  66. [[maybe_unused]] RHI::ResultCode resultCode = m_inputAssemblyBufferPool->Init(*device, bufferPoolDesc);
  67. AZ_Assert(resultCode == RHI::ResultCode::Success, "Failed to initialize input assembly buffer pool");
  68. }
  69. // create output image pool
  70. {
  71. RHI::ImagePoolDescriptor imagePoolDesc;
  72. imagePoolDesc.m_bindFlags = RHI::ImageBindFlags::ShaderReadWrite;
  73. m_imagePool = RHI::Factory::Get().CreateImagePool();
  74. [[maybe_unused]] RHI::ResultCode result = m_imagePool->Init(*device, imagePoolDesc);
  75. AZ_Assert(result == RHI::ResultCode::Success, "Failed to initialize output image pool");
  76. }
  77. // initialize ray tracing buffer pools
  78. m_rayTracingBufferPools = RHI::RayTracingBufferPools::CreateRHIRayTracingBufferPools();
  79. m_rayTracingBufferPools->Init(device);
  80. }
  81. void RayTracingExampleComponent::CreateGeometry()
  82. {
  83. // triangle
  84. {
  85. // vertex buffer
  86. SetVertexPosition(m_triangleVertices.data(), 0, 0.0f, 0.5f, 1.0);
  87. SetVertexPosition(m_triangleVertices.data(), 1, 0.5f, -0.5f, 1.0);
  88. SetVertexPosition(m_triangleVertices.data(), 2, -0.5f, -0.5f, 1.0);
  89. m_triangleVB = RHI::Factory::Get().CreateBuffer();
  90. m_triangleVB->SetName(AZ::Name("Triangle VB"));
  91. RHI::BufferInitRequest request;
  92. request.m_buffer = m_triangleVB.get();
  93. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_triangleVertices) };
  94. request.m_initialData = m_triangleVertices.data();
  95. m_inputAssemblyBufferPool->InitBuffer(request);
  96. // index buffer
  97. SetVertexIndexIncreasing(m_triangleIndices.data(), m_triangleIndices.size());
  98. m_triangleIB = RHI::Factory::Get().CreateBuffer();
  99. m_triangleIB->SetName(AZ::Name("Triangle IB"));
  100. request.m_buffer = m_triangleIB.get();
  101. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_triangleIndices) };
  102. request.m_initialData = m_triangleIndices.data();
  103. m_inputAssemblyBufferPool->InitBuffer(request);
  104. }
  105. // rectangle
  106. {
  107. // vertex buffer
  108. SetVertexPosition(m_rectangleVertices.data(), 0, -0.5f, 0.5f, 1.0);
  109. SetVertexPosition(m_rectangleVertices.data(), 1, 0.5f, 0.5f, 1.0);
  110. SetVertexPosition(m_rectangleVertices.data(), 2, 0.5f, -0.5f, 1.0);
  111. SetVertexPosition(m_rectangleVertices.data(), 3, -0.5f, -0.5f, 1.0);
  112. m_rectangleVB = RHI::Factory::Get().CreateBuffer();
  113. m_rectangleVB->SetName(AZ::Name("Rectangle VB"));
  114. RHI::BufferInitRequest request;
  115. request.m_buffer = m_rectangleVB.get();
  116. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_rectangleVertices) };
  117. request.m_initialData = m_rectangleVertices.data();
  118. m_inputAssemblyBufferPool->InitBuffer(request);
  119. // index buffer
  120. m_rectangleIndices[0] = 0;
  121. m_rectangleIndices[1] = 1;
  122. m_rectangleIndices[2] = 2;
  123. m_rectangleIndices[3] = 0;
  124. m_rectangleIndices[4] = 2;
  125. m_rectangleIndices[5] = 3;
  126. m_rectangleIB = RHI::Factory::Get().CreateBuffer();
  127. m_rectangleIB->SetName(AZ::Name("Rectangle IB"));
  128. request.m_buffer = m_rectangleIB.get();
  129. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_rectangleIndices) };
  130. request.m_initialData = m_rectangleIndices.data();
  131. m_inputAssemblyBufferPool->InitBuffer(request);
  132. }
  133. }
  134. void RayTracingExampleComponent::CreateFullScreenBuffer()
  135. {
  136. FullScreenBufferData bufferData;
  137. SetFullScreenRect(bufferData.m_positions.data(), bufferData.m_uvs.data(), bufferData.m_indices.data());
  138. m_fullScreenInputAssemblyBuffer = RHI::Factory::Get().CreateBuffer();
  139. RHI::BufferInitRequest request;
  140. request.m_buffer = m_fullScreenInputAssemblyBuffer.get();
  141. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(bufferData) };
  142. request.m_initialData = &bufferData;
  143. m_inputAssemblyBufferPool->InitBuffer(request);
  144. m_fullScreenStreamBufferViews[0] =
  145. {
  146. *m_fullScreenInputAssemblyBuffer,
  147. offsetof(FullScreenBufferData, m_positions),
  148. sizeof(FullScreenBufferData::m_positions),
  149. sizeof(VertexPosition)
  150. };
  151. m_fullScreenStreamBufferViews[1] =
  152. {
  153. *m_fullScreenInputAssemblyBuffer,
  154. offsetof(FullScreenBufferData, m_uvs),
  155. sizeof(FullScreenBufferData::m_uvs),
  156. sizeof(VertexUV)
  157. };
  158. m_fullScreenIndexBufferView =
  159. {
  160. *m_fullScreenInputAssemblyBuffer,
  161. offsetof(FullScreenBufferData, m_indices),
  162. sizeof(FullScreenBufferData::m_indices),
  163. RHI::IndexFormat::Uint16
  164. };
  165. RHI::InputStreamLayoutBuilder layoutBuilder;
  166. layoutBuilder.AddBuffer()->Channel("POSITION", RHI::Format::R32G32B32_FLOAT);
  167. layoutBuilder.AddBuffer()->Channel("UV", RHI::Format::R32G32_FLOAT);
  168. m_fullScreenInputStreamLayout = layoutBuilder.End();
  169. }
  170. void RayTracingExampleComponent::CreateOutputTexture()
  171. {
  172. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  173. // create output image
  174. m_outputImage = RHI::Factory::Get().CreateImage();
  175. RHI::ImageInitRequest request;
  176. request.m_image = m_outputImage.get();
  177. request.m_descriptor = RHI::ImageDescriptor::Create2D(RHI::ImageBindFlags::ShaderReadWrite, m_imageWidth, m_imageHeight, RHI::Format::R8G8B8A8_UNORM);
  178. [[maybe_unused]] RHI::ResultCode result = m_imagePool->InitImage(request);
  179. AZ_Assert(result == RHI::ResultCode::Success, "Failed to initialize output image");
  180. m_outputImageViewDescriptor = RHI::ImageViewDescriptor::Create(RHI::Format::R8G8B8A8_UNORM, 0, 0);
  181. m_outputImageView = m_outputImage->GetImageView(m_outputImageViewDescriptor);
  182. AZ_Assert(m_outputImageView.get(), "Failed to create output image view");
  183. AZ_Assert(m_outputImageView->IsFullView(), "Image View initialization IsFullView() failed");
  184. }
  185. void RayTracingExampleComponent::CreateRayTracingAccelerationStructureObjects()
  186. {
  187. m_triangleRayTracingBlas = AZ::RHI::RayTracingBlas::CreateRHIRayTracingBlas();
  188. m_rectangleRayTracingBlas = AZ::RHI::RayTracingBlas::CreateRHIRayTracingBlas();
  189. m_rayTracingTlas = AZ::RHI::RayTracingTlas::CreateRHIRayTracingTlas();
  190. }
  191. void RayTracingExampleComponent::CreateRasterShader()
  192. {
  193. const char* shaderFilePath = "Shaders/RHI/RayTracingDraw.azshader";
  194. auto drawShader = LoadShader(shaderFilePath, RayTracingExampleName);
  195. AZ_Assert(drawShader, "Failed to load Draw shader");
  196. RHI::PipelineStateDescriptorForDraw pipelineDesc;
  197. drawShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId).ConfigurePipelineState(pipelineDesc);
  198. pipelineDesc.m_inputStreamLayout = m_fullScreenInputStreamLayout;
  199. RHI::RenderAttachmentLayoutBuilder attachmentsBuilder;
  200. attachmentsBuilder.AddSubpass()->RenderTargetAttachment(m_outputFormat);
  201. [[maybe_unused]] RHI::ResultCode result = attachmentsBuilder.End(pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout);
  202. AZ_Assert(result == RHI::ResultCode::Success, "Failed to create draw render attachment layout");
  203. m_drawPipelineState = drawShader->AcquirePipelineState(pipelineDesc);
  204. AZ_Assert(m_drawPipelineState, "Failed to acquire draw pipeline state");
  205. m_drawSRG = CreateShaderResourceGroup(drawShader, "BufferSrg", RayTracingExampleName);
  206. }
  207. void RayTracingExampleComponent::CreateRayTracingPipelineState()
  208. {
  209. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  210. // load ray generation shader
  211. const char* rayGenerationShaderFilePath = "Shaders/RHI/RayTracingDispatch.azshader";
  212. m_rayGenerationShader = LoadShader(rayGenerationShaderFilePath, RayTracingExampleName);
  213. AZ_Assert(m_rayGenerationShader, "Failed to load ray generation shader");
  214. auto rayGenerationShaderVariant = m_rayGenerationShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  215. RHI::PipelineStateDescriptorForRayTracing rayGenerationShaderDescriptor;
  216. rayGenerationShaderVariant.ConfigurePipelineState(rayGenerationShaderDescriptor);
  217. // load miss shader
  218. const char* missShaderFilePath = "Shaders/RHI/RayTracingMiss.azshader";
  219. m_missShader = LoadShader(missShaderFilePath, RayTracingExampleName);
  220. AZ_Assert(m_missShader, "Failed to load miss shader");
  221. auto missShaderVariant = m_missShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  222. RHI::PipelineStateDescriptorForRayTracing missShaderDescriptor;
  223. missShaderVariant.ConfigurePipelineState(missShaderDescriptor);
  224. // load closest hit gradient shader
  225. const char* closestHitGradientShaderFilePath = "Shaders/RHI/RayTracingClosestHitGradient.azshader";
  226. m_closestHitGradientShader = LoadShader(closestHitGradientShaderFilePath, RayTracingExampleName);
  227. AZ_Assert(m_closestHitGradientShader, "Failed to load closest hit gradient shader");
  228. auto closestHitGradientShaderVariant = m_closestHitGradientShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  229. RHI::PipelineStateDescriptorForRayTracing closestHitGradiantShaderDescriptor;
  230. closestHitGradientShaderVariant.ConfigurePipelineState(closestHitGradiantShaderDescriptor);
  231. // load closest hit solid shader
  232. const char* closestHitSolidShaderFilePath = "Shaders/RHI/RayTracingClosestHitSolid.azshader";
  233. m_closestHitSolidShader = LoadShader(closestHitSolidShaderFilePath, RayTracingExampleName);
  234. AZ_Assert(m_closestHitSolidShader, "Failed to load closest hit solid shader");
  235. auto closestHitSolidShaderVariant = m_closestHitSolidShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  236. RHI::PipelineStateDescriptorForRayTracing closestHitSolidShaderDescriptor;
  237. closestHitSolidShaderVariant.ConfigurePipelineState(closestHitSolidShaderDescriptor);
  238. // global pipeline state and srg
  239. m_globalPipelineState = m_rayGenerationShader->AcquirePipelineState(rayGenerationShaderDescriptor);
  240. AZ_Assert(m_globalPipelineState, "Failed to acquire ray tracing global pipeline state");
  241. m_globalSrg = CreateShaderResourceGroup(m_rayGenerationShader, "RayTracingGlobalSrg", RayTracingExampleName);
  242. // build the ray tracing pipeline state descriptor
  243. RHI::RayTracingPipelineStateDescriptor descriptor;
  244. descriptor.Build()
  245. ->PipelineState(m_globalPipelineState.get())
  246. ->ShaderLibrary(rayGenerationShaderDescriptor)
  247. ->RayGenerationShaderName(AZ::Name("RayGenerationShader"))
  248. ->ShaderLibrary(missShaderDescriptor)
  249. ->MissShaderName(AZ::Name("MissShader"))
  250. ->ShaderLibrary(closestHitGradiantShaderDescriptor)
  251. ->ClosestHitShaderName(AZ::Name("ClosestHitGradientShader"))
  252. ->ShaderLibrary(closestHitSolidShaderDescriptor)
  253. ->ClosestHitShaderName(AZ::Name("ClosestHitSolidShader"))
  254. ->HitGroup(AZ::Name("HitGroupGradient"))
  255. ->ClosestHitShaderName(AZ::Name("ClosestHitGradientShader"))
  256. ->HitGroup(AZ::Name("HitGroupSolid"))
  257. ->ClosestHitShaderName(AZ::Name("ClosestHitSolidShader"));
  258. // create the ray tracing pipeline state object
  259. m_rayTracingPipelineState = RHI::Factory::Get().CreateRayTracingPipelineState();
  260. m_rayTracingPipelineState->Init(*device.get(), &descriptor);
  261. }
  262. void RayTracingExampleComponent::CreateRayTracingShaderTable()
  263. {
  264. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  265. m_rayTracingShaderTable = RHI::Factory::Get().CreateRayTracingShaderTable();
  266. m_rayTracingShaderTable->Init(*device.get(), *m_rayTracingBufferPools);
  267. }
  268. void RayTracingExampleComponent::CreateRayTracingAccelerationTableScope()
  269. {
  270. struct ScopeData
  271. {
  272. };
  273. const auto prepareFunction = [this]([[maybe_unused]] RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  274. {
  275. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  276. // create triangle BLAS buffer if necessary
  277. if (!m_triangleRayTracingBlas->IsValid())
  278. {
  279. RHI::StreamBufferView triangleVertexBufferView =
  280. {
  281. *m_triangleVB,
  282. 0,
  283. sizeof(m_triangleVertices),
  284. sizeof(VertexPosition)
  285. };
  286. RHI::IndexBufferView triangleIndexBufferView =
  287. {
  288. *m_triangleIB,
  289. 0,
  290. sizeof(m_triangleIndices),
  291. RHI::IndexFormat::Uint16
  292. };
  293. RHI::RayTracingBlasDescriptor triangleBlasDescriptor;
  294. triangleBlasDescriptor.Build()
  295. ->Geometry()
  296. ->VertexFormat(RHI::Format::R32G32B32_FLOAT)
  297. ->VertexBuffer(triangleVertexBufferView)
  298. ->IndexBuffer(triangleIndexBufferView)
  299. ;
  300. m_triangleRayTracingBlas->CreateBuffers(*device, &triangleBlasDescriptor, *m_rayTracingBufferPools);
  301. }
  302. // create rectangle BLAS if necessary
  303. if (!m_rectangleRayTracingBlas->IsValid())
  304. {
  305. RHI::StreamBufferView rectangleVertexBufferView =
  306. {
  307. *m_rectangleVB,
  308. 0,
  309. sizeof(m_rectangleVertices),
  310. sizeof(VertexPosition)
  311. };
  312. RHI::IndexBufferView rectangleIndexBufferView =
  313. {
  314. *m_rectangleIB,
  315. 0,
  316. sizeof(m_rectangleIndices),
  317. RHI::IndexFormat::Uint16
  318. };
  319. RHI::RayTracingBlasDescriptor rectangleBlasDescriptor;
  320. rectangleBlasDescriptor.Build()
  321. ->Geometry()
  322. ->VertexFormat(RHI::Format::R32G32B32_FLOAT)
  323. ->VertexBuffer(rectangleVertexBufferView)
  324. ->IndexBuffer(rectangleIndexBufferView)
  325. ;
  326. m_rectangleRayTracingBlas->CreateBuffers(*device, &rectangleBlasDescriptor, *m_rayTracingBufferPools);
  327. }
  328. m_time += 0.005f;
  329. // transforms
  330. AZ::Transform triangleTransform1 = AZ::Transform::CreateIdentity();
  331. triangleTransform1.SetTranslation(sinf(m_time) * -100.0f, cosf(m_time) * -100.0f, 1.0f);
  332. triangleTransform1.MultiplyByUniformScale(100.0f);
  333. AZ::Transform triangleTransform2 = AZ::Transform::CreateIdentity();
  334. triangleTransform2.SetTranslation(sinf(m_time) * -100.0f, cosf(m_time) * 100.0f, 2.0f);
  335. triangleTransform2.MultiplyByUniformScale(100.0f);
  336. AZ::Transform triangleTransform3 = AZ::Transform::CreateIdentity();
  337. triangleTransform3.SetTranslation(sinf(m_time) * 100.0f, cosf(m_time) * 100.0f, 3.0f);
  338. triangleTransform3.MultiplyByUniformScale(100.0f);
  339. AZ::Transform rectangleTransform = AZ::Transform::CreateIdentity();
  340. rectangleTransform.SetTranslation(sinf(m_time) * 100.0f, cosf(m_time) * -100.0f, 4.0f);
  341. rectangleTransform.MultiplyByUniformScale(100.0f);
  342. // create the TLAS
  343. RHI::RayTracingTlasDescriptor tlasDescriptor;
  344. tlasDescriptor.Build()
  345. ->Instance()
  346. ->InstanceID(0)
  347. ->HitGroupIndex(0)
  348. ->Blas(m_triangleRayTracingBlas)
  349. ->Transform(triangleTransform1)
  350. ->Instance()
  351. ->InstanceID(1)
  352. ->HitGroupIndex(1)
  353. ->Blas(m_triangleRayTracingBlas)
  354. ->Transform(triangleTransform2)
  355. ->Instance()
  356. ->InstanceID(2)
  357. ->HitGroupIndex(2)
  358. ->Blas(m_triangleRayTracingBlas)
  359. ->Transform(triangleTransform3)
  360. ->Instance()
  361. ->InstanceID(3)
  362. ->HitGroupIndex(3)
  363. ->Blas(m_rectangleRayTracingBlas)
  364. ->Transform(rectangleTransform)
  365. ;
  366. m_rayTracingTlas->CreateBuffers(*device, &tlasDescriptor, *m_rayTracingBufferPools);
  367. m_tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRaw(0, (uint32_t)m_rayTracingTlas->GetTlasBuffer()->GetDescriptor().m_byteCount);
  368. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(m_tlasBufferAttachmentId, m_rayTracingTlas->GetTlasBuffer());
  369. AZ_Error(RayTracingExampleName, result == RHI::ResultCode::Success, "Failed to import TLAS buffer with error %d", result);
  370. RHI::BufferScopeAttachmentDescriptor desc;
  371. desc.m_attachmentId = m_tlasBufferAttachmentId;
  372. desc.m_bufferViewDescriptor = m_tlasBufferViewDescriptor;
  373. desc.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  374. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite);
  375. };
  376. RHI::EmptyCompileFunction<ScopeData> compileFunction;
  377. const auto executeFunction = [this]([[maybe_unused]] const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  378. {
  379. RHI::CommandList* commandList = context.GetCommandList();
  380. commandList->BuildBottomLevelAccelerationStructure(*m_triangleRayTracingBlas);
  381. commandList->BuildBottomLevelAccelerationStructure(*m_rectangleRayTracingBlas);
  382. commandList->BuildTopLevelAccelerationStructure(
  383. *m_rayTracingTlas, { m_triangleRayTracingBlas.get(), m_rectangleRayTracingBlas.get() });
  384. };
  385. m_scopeProducers.emplace_back(
  386. aznew RHI::ScopeProducerFunction<
  387. ScopeData,
  388. decltype(prepareFunction),
  389. decltype(compileFunction),
  390. decltype(executeFunction)>(
  391. RHI::ScopeId{ "RayTracingBuildAccelerationStructure" },
  392. ScopeData{},
  393. prepareFunction,
  394. compileFunction,
  395. executeFunction));
  396. }
  397. void RayTracingExampleComponent::CreateRayTracingDispatchScope()
  398. {
  399. struct ScopeData
  400. {
  401. };
  402. const auto prepareFunction = [this]([[maybe_unused]] RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  403. {
  404. // attach output image
  405. {
  406. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportImage(m_outputImageAttachmentId, m_outputImage);
  407. AZ_Error(RayTracingExampleName, result == RHI::ResultCode::Success, "Failed to import output image with error %d", result);
  408. RHI::ImageScopeAttachmentDescriptor desc;
  409. desc.m_attachmentId = m_outputImageAttachmentId;
  410. desc.m_imageViewDescriptor = m_outputImageViewDescriptor;
  411. desc.m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(0.0f, 0.0f, 0.0f, 0.0f);
  412. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite);
  413. }
  414. // attach TLAS buffer
  415. if (m_rayTracingTlas->GetTlasBuffer())
  416. {
  417. RHI::BufferScopeAttachmentDescriptor desc;
  418. desc.m_attachmentId = m_tlasBufferAttachmentId;
  419. desc.m_bufferViewDescriptor = m_tlasBufferViewDescriptor;
  420. desc.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  421. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite);
  422. }
  423. frameGraph.SetEstimatedItemCount(1);
  424. };
  425. const auto compileFunction = [this]([[maybe_unused]] const RHI::FrameGraphCompileContext& context, [[maybe_unused]] const ScopeData& scopeData)
  426. {
  427. if (m_rayTracingTlas->GetTlasBuffer())
  428. {
  429. // set the TLAS and output image in the ray tracing global Srg
  430. RHI::ShaderInputBufferIndex tlasConstantIndex;
  431. FindShaderInputIndex(&tlasConstantIndex, m_globalSrg, AZ::Name{ "m_scene" }, RayTracingExampleName);
  432. uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(m_rayTracingTlas->GetTlasBuffer()->GetDescriptor().m_byteCount);
  433. RHI::BufferViewDescriptor bufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount);
  434. m_globalSrg->SetBufferView(tlasConstantIndex, m_rayTracingTlas->GetTlasBuffer()->GetBufferView(bufferViewDescriptor).get());
  435. RHI::ShaderInputImageIndex outputConstantIndex;
  436. FindShaderInputIndex(&outputConstantIndex, m_globalSrg, AZ::Name{ "m_output" }, RayTracingExampleName);
  437. m_globalSrg->SetImageView(outputConstantIndex, m_outputImageView.get());
  438. // set hit shader data, each array element corresponds to the InstanceIndex() of the geometry in the TLAS
  439. // Note: this method is used instead of LocalRootSignatures for compatibility with non-DX12 platforms
  440. // set HitGradient values
  441. RHI::ShaderInputConstantIndex hitGradientDataConstantIndex;
  442. FindShaderInputIndex(&hitGradientDataConstantIndex, m_globalSrg, AZ::Name{"m_hitGradientData"}, RayTracingExampleName);
  443. struct HitGradientData
  444. {
  445. AZ::Vector4 m_color;
  446. };
  447. AZStd::array<HitGradientData, 4> hitGradientData = {{
  448. {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f)}, // triangle1
  449. {AZ::Vector4(0.0f, 1.0f, 0.0f, 1.0f)}, // triangle2
  450. {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
  451. {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
  452. }};
  453. m_globalSrg->SetConstantArray(hitGradientDataConstantIndex, hitGradientData);
  454. // set HitSolid values
  455. RHI::ShaderInputConstantIndex hitSolidDataConstantIndex;
  456. FindShaderInputIndex(&hitSolidDataConstantIndex, m_globalSrg, AZ::Name{"m_hitSolidData"}, RayTracingExampleName);
  457. struct HitSolidData
  458. {
  459. AZ::Vector4 m_color1;
  460. float m_lerp;
  461. float m_pad[3];
  462. AZ::Vector4 m_color2;
  463. };
  464. AZStd::array<HitSolidData, 4> hitSolidData = {{
  465. {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f), 0.0f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
  466. {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f), 0.0f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
  467. {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f), 0.5f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 1.0f, 0.0f, 1.0f)}, // triangle3
  468. {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f), 0.5f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 0.0f, 1.0f, 1.0f)}, // rectangle
  469. }};
  470. m_globalSrg->SetConstantArray(hitSolidDataConstantIndex, hitSolidData);
  471. m_globalSrg->Compile();
  472. // update the ray tracing shader table
  473. AZStd::shared_ptr<RHI::RayTracingShaderTableDescriptor> descriptor = AZStd::make_shared<RHI::RayTracingShaderTableDescriptor>();
  474. descriptor->Build(AZ::Name("RayTracingExampleShaderTable"), m_rayTracingPipelineState)
  475. ->RayGenerationRecord(AZ::Name("RayGenerationShader"))
  476. ->MissRecord(AZ::Name("MissShader"))
  477. ->HitGroupRecord(AZ::Name("HitGroupGradient")) // triangle1
  478. ->HitGroupRecord(AZ::Name("HitGroupGradient")) // triangle2
  479. ->HitGroupRecord(AZ::Name("HitGroupSolid")) // triangle3
  480. ->HitGroupRecord(AZ::Name("HitGroupSolid")) // rectangle
  481. ;
  482. m_rayTracingShaderTable->Build(descriptor);
  483. }
  484. };
  485. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  486. {
  487. if (!m_rayTracingTlas->GetTlasBuffer())
  488. {
  489. return;
  490. }
  491. RHI::CommandList* commandList = context.GetCommandList();
  492. const RHI::ShaderResourceGroup* shaderResourceGroups[] = {
  493. m_globalSrg->GetRHIShaderResourceGroup()
  494. };
  495. RHI::DispatchRaysItem dispatchRaysItem;
  496. dispatchRaysItem.m_arguments.m_direct.m_width = m_imageWidth;
  497. dispatchRaysItem.m_arguments.m_direct.m_height = m_imageHeight;
  498. dispatchRaysItem.m_arguments.m_direct.m_depth = 1;
  499. dispatchRaysItem.m_rayTracingPipelineState = m_rayTracingPipelineState.get();
  500. dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable.get();
  501. dispatchRaysItem.m_shaderResourceGroupCount = RHI::ArraySize(shaderResourceGroups);
  502. dispatchRaysItem.m_shaderResourceGroups = shaderResourceGroups;
  503. dispatchRaysItem.m_globalPipelineState = m_globalPipelineState.get();
  504. // submit the DispatchRays item
  505. commandList->Submit(dispatchRaysItem);
  506. };
  507. m_scopeProducers.emplace_back(
  508. aznew RHI::ScopeProducerFunction<
  509. ScopeData,
  510. decltype(prepareFunction),
  511. decltype(compileFunction),
  512. decltype(executeFunction)>(
  513. RHI::ScopeId{ "RayTracingDispatch" },
  514. ScopeData{},
  515. prepareFunction,
  516. compileFunction,
  517. executeFunction));
  518. }
  519. void RayTracingExampleComponent::CreateRasterScope()
  520. {
  521. struct ScopeData
  522. {
  523. };
  524. const auto prepareFunction = [this](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  525. {
  526. // attach swapchain
  527. {
  528. RHI::ImageScopeAttachmentDescriptor descriptor;
  529. descriptor.m_attachmentId = m_outputAttachmentId;
  530. descriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::DontCare;
  531. frameGraph.UseColorAttachment(descriptor);
  532. }
  533. // attach output buffer
  534. {
  535. RHI::ImageScopeAttachmentDescriptor desc;
  536. desc.m_attachmentId = m_outputImageAttachmentId;
  537. desc.m_imageViewDescriptor = m_outputImageViewDescriptor;
  538. desc.m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(0.0f, 0.0f, 0.0f, 0.0f);
  539. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite);
  540. const Name outputImageId{ "m_output" };
  541. RHI::ShaderInputImageIndex outputImageIndex = m_drawSRG->FindShaderInputImageIndex(outputImageId);
  542. AZ_Error(RayTracingExampleName, outputImageIndex.IsValid(), "Failed to find shader input image %s.", outputImageId.GetCStr());
  543. m_drawSRG->SetImageView(outputImageIndex, m_outputImageView.get());
  544. m_drawSRG->Compile();
  545. }
  546. frameGraph.SetEstimatedItemCount(1);
  547. };
  548. RHI::EmptyCompileFunction<ScopeData> compileFunction;
  549. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  550. {
  551. RHI::CommandList* commandList = context.GetCommandList();
  552. commandList->SetViewports(&m_viewport, 1);
  553. commandList->SetScissors(&m_scissor, 1);
  554. RHI::DrawIndexed drawIndexed;
  555. drawIndexed.m_indexCount = 6;
  556. drawIndexed.m_instanceCount = 1;
  557. const RHI::ShaderResourceGroup* shaderResourceGroups[] = { m_drawSRG->GetRHIShaderResourceGroup() };
  558. RHI::DrawItem drawItem;
  559. drawItem.m_arguments = drawIndexed;
  560. drawItem.m_pipelineState = m_drawPipelineState.get();
  561. drawItem.m_indexBufferView = &m_fullScreenIndexBufferView;
  562. drawItem.m_streamBufferViewCount = static_cast<uint8_t>(m_fullScreenStreamBufferViews.size());
  563. drawItem.m_streamBufferViews = m_fullScreenStreamBufferViews.data();
  564. drawItem.m_shaderResourceGroupCount = static_cast<uint8_t>(RHI::ArraySize(shaderResourceGroups));
  565. drawItem.m_shaderResourceGroups = shaderResourceGroups;
  566. // submit the triangle draw item.
  567. commandList->Submit(drawItem);
  568. };
  569. m_scopeProducers.emplace_back(
  570. aznew RHI::ScopeProducerFunction<
  571. ScopeData,
  572. decltype(prepareFunction),
  573. decltype(compileFunction),
  574. decltype(executeFunction)>(
  575. RHI::ScopeId{ "Raster" },
  576. ScopeData{},
  577. prepareFunction,
  578. compileFunction,
  579. executeFunction));
  580. }
  581. } // namespace AtomSampleViewer