2
0

RayTracingExampleComponent.cpp 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <RHI/RayTracingExampleComponent.h>
  9. #include <Utils/Utils.h>
  10. #include <SampleComponentManager.h>
  11. #include <Atom/RHI/CommandList.h>
  12. #include <Atom/RHI/FrameGraphInterface.h>
  13. #include <Atom/RHI/RayTracingPipelineState.h>
  14. #include <Atom/RHI/RayTracingShaderTable.h>
  15. #include <Atom/RHI.Reflect/InputStreamLayoutBuilder.h>
  16. #include <Atom/RHI.Reflect/RenderAttachmentLayoutBuilder.h>
  17. #include <Atom/RPI.Public/Shader/Shader.h>
  18. #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
  19. #include <AzCore/Serialization/SerializeContext.h>
  // Tag for this sample: passed to LoadShader / CreateShaderResourceGroup / FindShaderInputIndex
  // and used as the AZ_Error window name. Both the pointer and the pointee are constant so the
  // file-local global cannot be reassigned by accident.
  static constexpr const char* RayTracingExampleName = "RayTracingExample";
  21. namespace AtomSampleViewer
  22. {
  23. void RayTracingExampleComponent::Reflect(AZ::ReflectContext* context)
  24. {
  25. if (auto* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  26. {
  27. serializeContext->Class<RayTracingExampleComponent, AZ::Component>()
  28. ->Version(0)
  29. ;
  30. }
  31. }
  32. RayTracingExampleComponent::RayTracingExampleComponent()
  33. {
  34. m_supportRHISamplePipeline = true;
  35. }
  36. void RayTracingExampleComponent::Activate()
  37. {
  38. CreateResourcePools();
  39. CreateGeometry();
  40. CreateFullScreenBuffer();
  41. CreateOutputTexture();
  42. CreateRasterShader();
  43. CreateRayTracingAccelerationStructureObjects();
  44. CreateRayTracingPipelineState();
  45. CreateRayTracingShaderTable();
  46. CreateRayTracingAccelerationTableScope();
  47. CreateRayTracingDispatchScope();
  48. CreateRasterScope();
  49. RHI::RHISystemNotificationBus::Handler::BusConnect();
  50. }
  51. void RayTracingExampleComponent::Deactivate()
  52. {
  53. RHI::RHISystemNotificationBus::Handler::BusDisconnect();
  54. m_windowContext = nullptr;
  55. m_scopeProducers.clear();
  56. }
  57. void RayTracingExampleComponent::CreateResourcePools()
  58. {
  59. // create input assembly buffer pool
  60. {
  61. m_inputAssemblyBufferPool = aznew RHI::BufferPool();
  62. RHI::BufferPoolDescriptor bufferPoolDesc;
  63. bufferPoolDesc.m_bindFlags = RHI::BufferBindFlags::InputAssembly;
  64. bufferPoolDesc.m_heapMemoryLevel = RHI::HeapMemoryLevel::Host;
  65. [[maybe_unused]] RHI::ResultCode resultCode = m_inputAssemblyBufferPool->Init(bufferPoolDesc);
  66. AZ_Assert(resultCode == RHI::ResultCode::Success, "Failed to initialize input assembly buffer pool");
  67. }
  68. // create output image pool
  69. {
  70. RHI::ImagePoolDescriptor imagePoolDesc;
  71. imagePoolDesc.m_bindFlags = RHI::ImageBindFlags::ShaderReadWrite;
  72. m_imagePool = aznew RHI::ImagePool();
  73. [[maybe_unused]] RHI::ResultCode result = m_imagePool->Init(imagePoolDesc);
  74. AZ_Assert(result == RHI::ResultCode::Success, "Failed to initialize output image pool");
  75. }
  76. // initialize ray tracing buffer pools
  77. m_rayTracingBufferPools = aznew RHI::RayTracingBufferPools;
  78. m_rayTracingBufferPools->Init(RHI::MultiDevice::AllDevices);
  79. }
  80. void RayTracingExampleComponent::CreateGeometry()
  81. {
  82. // triangle
  83. {
  84. // vertex buffer
  85. SetVertexPosition(m_triangleVertices.data(), 0, 0.0f, 0.5f, 1.0);
  86. SetVertexPosition(m_triangleVertices.data(), 1, 0.5f, -0.5f, 1.0);
  87. SetVertexPosition(m_triangleVertices.data(), 2, -0.5f, -0.5f, 1.0);
  88. m_triangleVB = aznew RHI::Buffer();
  89. m_triangleVB->SetName(AZ::Name("Triangle VB"));
  90. RHI::BufferInitRequest request;
  91. request.m_buffer = m_triangleVB.get();
  92. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_triangleVertices) };
  93. request.m_initialData = m_triangleVertices.data();
  94. m_inputAssemblyBufferPool->InitBuffer(request);
  95. // index buffer
  96. SetVertexIndexIncreasing(m_triangleIndices.data(), m_triangleIndices.size());
  97. m_triangleIB = aznew RHI::Buffer();
  98. m_triangleIB->SetName(AZ::Name("Triangle IB"));
  99. request.m_buffer = m_triangleIB.get();
  100. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_triangleIndices) };
  101. request.m_initialData = m_triangleIndices.data();
  102. m_inputAssemblyBufferPool->InitBuffer(request);
  103. }
  104. // rectangle
  105. {
  106. // vertex buffer
  107. SetVertexPosition(m_rectangleVertices.data(), 0, -0.5f, 0.5f, 1.0);
  108. SetVertexPosition(m_rectangleVertices.data(), 1, 0.5f, 0.5f, 1.0);
  109. SetVertexPosition(m_rectangleVertices.data(), 2, 0.5f, -0.5f, 1.0);
  110. SetVertexPosition(m_rectangleVertices.data(), 3, -0.5f, -0.5f, 1.0);
  111. m_rectangleVB = aznew RHI::Buffer();
  112. m_rectangleVB->SetName(AZ::Name("Rectangle VB"));
  113. RHI::BufferInitRequest request;
  114. request.m_buffer = m_rectangleVB.get();
  115. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_rectangleVertices) };
  116. request.m_initialData = m_rectangleVertices.data();
  117. m_inputAssemblyBufferPool->InitBuffer(request);
  118. // index buffer
  119. m_rectangleIndices[0] = 0;
  120. m_rectangleIndices[1] = 1;
  121. m_rectangleIndices[2] = 2;
  122. m_rectangleIndices[3] = 0;
  123. m_rectangleIndices[4] = 2;
  124. m_rectangleIndices[5] = 3;
  125. m_rectangleIB = aznew RHI::Buffer();
  126. m_rectangleIB->SetName(AZ::Name("Rectangle IB"));
  127. request.m_buffer = m_rectangleIB.get();
  128. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_rectangleIndices) };
  129. request.m_initialData = m_rectangleIndices.data();
  130. m_inputAssemblyBufferPool->InitBuffer(request);
  131. }
  132. }
  133. void RayTracingExampleComponent::CreateFullScreenBuffer()
  134. {
  135. FullScreenBufferData bufferData;
  136. SetFullScreenRect(bufferData.m_positions.data(), bufferData.m_uvs.data(), bufferData.m_indices.data());
  137. m_fullScreenInputAssemblyBuffer = aznew RHI::Buffer();
  138. RHI::BufferInitRequest request;
  139. request.m_buffer = m_fullScreenInputAssemblyBuffer.get();
  140. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(bufferData) };
  141. request.m_initialData = &bufferData;
  142. m_inputAssemblyBufferPool->InitBuffer(request);
  143. m_geometryView.SetDrawArguments(RHI::DrawIndexed(0, 6, 0));
  144. m_geometryView.AddStreamBufferView({
  145. *m_fullScreenInputAssemblyBuffer,
  146. offsetof(FullScreenBufferData, m_positions),
  147. sizeof(FullScreenBufferData::m_positions),
  148. sizeof(VertexPosition)
  149. });
  150. m_geometryView.AddStreamBufferView({
  151. *m_fullScreenInputAssemblyBuffer,
  152. offsetof(FullScreenBufferData, m_uvs),
  153. sizeof(FullScreenBufferData::m_uvs),
  154. sizeof(VertexUV)
  155. });
  156. m_geometryView.SetIndexBufferView({
  157. *m_fullScreenInputAssemblyBuffer,
  158. offsetof(FullScreenBufferData, m_indices),
  159. sizeof(FullScreenBufferData::m_indices),
  160. RHI::IndexFormat::Uint16
  161. });
  162. RHI::InputStreamLayoutBuilder layoutBuilder;
  163. layoutBuilder.AddBuffer()->Channel("POSITION", RHI::Format::R32G32B32_FLOAT);
  164. layoutBuilder.AddBuffer()->Channel("UV", RHI::Format::R32G32_FLOAT);
  165. m_fullScreenInputStreamLayout = layoutBuilder.End();
  166. }
  167. void RayTracingExampleComponent::CreateOutputTexture()
  168. {
  169. // create output image
  170. m_outputImage = aznew RHI::Image();
  171. RHI::ImageInitRequest request;
  172. request.m_image = m_outputImage.get();
  173. request.m_descriptor = RHI::ImageDescriptor::Create2D(RHI::ImageBindFlags::ShaderReadWrite, m_imageWidth, m_imageHeight, RHI::Format::R8G8B8A8_UNORM);
  174. [[maybe_unused]] RHI::ResultCode result = m_imagePool->InitImage(request);
  175. AZ_Assert(result == RHI::ResultCode::Success, "Failed to initialize output image");
  176. m_outputImageViewDescriptor = RHI::ImageViewDescriptor::Create(RHI::Format::R8G8B8A8_UNORM, 0, 0);
  177. m_outputImageView = m_outputImage->BuildImageView(m_outputImageViewDescriptor);
  178. AZ_Assert(m_outputImageView.get(), "Failed to create output image view");
  179. AZ_Assert(m_outputImageView->GetDeviceImageView(RHI::MultiDevice::DefaultDeviceIndex)->IsFullView(), "Image View initialization IsFullView() failed");
  180. }
  181. void RayTracingExampleComponent::CreateRayTracingAccelerationStructureObjects()
  182. {
  183. m_triangleRayTracingBlas = aznew AZ::RHI::RayTracingBlas;
  184. m_rectangleRayTracingBlas = aznew AZ::RHI::RayTracingBlas;
  185. m_rayTracingTlas = aznew AZ::RHI::RayTracingTlas;
  186. }
  187. void RayTracingExampleComponent::CreateRasterShader()
  188. {
  189. const char* shaderFilePath = "Shaders/RHI/RayTracingDraw.azshader";
  190. auto drawShader = LoadShader(shaderFilePath, RayTracingExampleName);
  191. AZ_Assert(drawShader, "Failed to load Draw shader");
  192. RHI::PipelineStateDescriptorForDraw pipelineDesc;
  193. drawShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId).ConfigurePipelineState(pipelineDesc);
  194. pipelineDesc.m_inputStreamLayout = m_fullScreenInputStreamLayout;
  195. RHI::RenderAttachmentLayoutBuilder attachmentsBuilder;
  196. attachmentsBuilder.AddSubpass()->RenderTargetAttachment(m_outputFormat);
  197. [[maybe_unused]] RHI::ResultCode result = attachmentsBuilder.End(pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout);
  198. AZ_Assert(result == RHI::ResultCode::Success, "Failed to create draw render attachment layout");
  199. m_drawPipelineState = drawShader->AcquirePipelineState(pipelineDesc);
  200. AZ_Assert(m_drawPipelineState, "Failed to acquire draw pipeline state");
  201. m_drawSRG = CreateShaderResourceGroup(drawShader, "BufferSrg", RayTracingExampleName);
  202. }
  203. void RayTracingExampleComponent::CreateRayTracingPipelineState()
  204. {
  205. // load ray generation shader
  206. const char* rayGenerationShaderFilePath = "Shaders/RHI/RayTracingDispatch.azshader";
  207. m_rayGenerationShader = LoadShader(rayGenerationShaderFilePath, RayTracingExampleName);
  208. AZ_Assert(m_rayGenerationShader, "Failed to load ray generation shader");
  209. auto rayGenerationShaderVariant = m_rayGenerationShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  210. RHI::PipelineStateDescriptorForRayTracing rayGenerationShaderDescriptor;
  211. rayGenerationShaderVariant.ConfigurePipelineState(rayGenerationShaderDescriptor);
  212. // load miss shader
  213. const char* missShaderFilePath = "Shaders/RHI/RayTracingMiss.azshader";
  214. m_missShader = LoadShader(missShaderFilePath, RayTracingExampleName);
  215. AZ_Assert(m_missShader, "Failed to load miss shader");
  216. auto missShaderVariant = m_missShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  217. RHI::PipelineStateDescriptorForRayTracing missShaderDescriptor;
  218. missShaderVariant.ConfigurePipelineState(missShaderDescriptor);
  219. // load closest hit gradient shader
  220. const char* closestHitGradientShaderFilePath = "Shaders/RHI/RayTracingClosestHitGradient.azshader";
  221. m_closestHitGradientShader = LoadShader(closestHitGradientShaderFilePath, RayTracingExampleName);
  222. AZ_Assert(m_closestHitGradientShader, "Failed to load closest hit gradient shader");
  223. auto closestHitGradientShaderVariant = m_closestHitGradientShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  224. RHI::PipelineStateDescriptorForRayTracing closestHitGradiantShaderDescriptor;
  225. closestHitGradientShaderVariant.ConfigurePipelineState(closestHitGradiantShaderDescriptor);
  226. // load closest hit solid shader
  227. const char* closestHitSolidShaderFilePath = "Shaders/RHI/RayTracingClosestHitSolid.azshader";
  228. m_closestHitSolidShader = LoadShader(closestHitSolidShaderFilePath, RayTracingExampleName);
  229. AZ_Assert(m_closestHitSolidShader, "Failed to load closest hit solid shader");
  230. auto closestHitSolidShaderVariant = m_closestHitSolidShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  231. RHI::PipelineStateDescriptorForRayTracing closestHitSolidShaderDescriptor;
  232. closestHitSolidShaderVariant.ConfigurePipelineState(closestHitSolidShaderDescriptor);
  233. // global pipeline state and srg
  234. m_globalPipelineState = m_rayGenerationShader->AcquirePipelineState(rayGenerationShaderDescriptor);
  235. AZ_Assert(m_globalPipelineState, "Failed to acquire ray tracing global pipeline state");
  236. m_globalSrg = CreateShaderResourceGroup(m_rayGenerationShader, "RayTracingGlobalSrg", RayTracingExampleName);
  237. // build the ray tracing pipeline state descriptor
  238. RHI::RayTracingPipelineStateDescriptor descriptor;
  239. descriptor.Build()
  240. ->PipelineState(m_globalPipelineState.get())
  241. ->ShaderLibrary(rayGenerationShaderDescriptor)
  242. ->RayGenerationShaderName(AZ::Name("RayGenerationShader"))
  243. ->ShaderLibrary(missShaderDescriptor)
  244. ->MissShaderName(AZ::Name("MissShader"))
  245. ->ShaderLibrary(closestHitGradiantShaderDescriptor)
  246. ->ClosestHitShaderName(AZ::Name("ClosestHitGradientShader"))
  247. ->ShaderLibrary(closestHitSolidShaderDescriptor)
  248. ->ClosestHitShaderName(AZ::Name("ClosestHitSolidShader"))
  249. ->HitGroup(AZ::Name("HitGroupGradient"))
  250. ->ClosestHitShaderName(AZ::Name("ClosestHitGradientShader"))
  251. ->HitGroup(AZ::Name("HitGroupSolid"))
  252. ->ClosestHitShaderName(AZ::Name("ClosestHitSolidShader"));
  253. // create the ray tracing pipeline state object
  254. m_rayTracingPipelineState = aznew RHI::RayTracingPipelineState;
  255. m_rayTracingPipelineState->Init(RHI::MultiDevice::AllDevices, descriptor);
  256. }
  257. void RayTracingExampleComponent::CreateRayTracingShaderTable()
  258. {
  259. m_rayTracingShaderTable = aznew RHI::RayTracingShaderTable;
  260. m_rayTracingShaderTable->Init(RHI::MultiDevice::AllDevices, *m_rayTracingBufferPools);
  261. }
  262. void RayTracingExampleComponent::CreateRayTracingAccelerationTableScope()
  263. {
  264. struct ScopeData
  265. {
  266. };
  267. const auto prepareFunction = [this]([[maybe_unused]] RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  268. {
  269. // create triangle BLAS buffer if necessary
  270. if (!m_triangleRayTracingBlas->IsValid())
  271. {
  272. RHI::StreamBufferView triangleVertexBufferView =
  273. {
  274. *m_triangleVB,
  275. 0,
  276. sizeof(m_triangleVertices),
  277. sizeof(VertexPosition)
  278. };
  279. RHI::IndexBufferView triangleIndexBufferView =
  280. {
  281. *m_triangleIB,
  282. 0,
  283. sizeof(m_triangleIndices),
  284. RHI::IndexFormat::Uint16
  285. };
  286. RHI::RayTracingBlasDescriptor triangleBlasDescriptor;
  287. triangleBlasDescriptor.Build()
  288. ->Geometry()
  289. ->VertexFormat(RHI::Format::R32G32B32_FLOAT)
  290. ->VertexBuffer(triangleVertexBufferView)
  291. ->IndexBuffer(triangleIndexBufferView)
  292. ;
  293. m_triangleRayTracingBlas->CreateBuffers(RHI::MultiDevice::AllDevices, &triangleBlasDescriptor, *m_rayTracingBufferPools);
  294. }
  295. // create rectangle BLAS if necessary
  296. if (!m_rectangleRayTracingBlas->IsValid())
  297. {
  298. RHI::StreamBufferView rectangleVertexBufferView =
  299. {
  300. *m_rectangleVB,
  301. 0,
  302. sizeof(m_rectangleVertices),
  303. sizeof(VertexPosition)
  304. };
  305. RHI::IndexBufferView rectangleIndexBufferView =
  306. {
  307. *m_rectangleIB,
  308. 0,
  309. sizeof(m_rectangleIndices),
  310. RHI::IndexFormat::Uint16
  311. };
  312. RHI::RayTracingBlasDescriptor rectangleBlasDescriptor;
  313. rectangleBlasDescriptor.Build()
  314. ->Geometry()
  315. ->VertexFormat(RHI::Format::R32G32B32_FLOAT)
  316. ->VertexBuffer(rectangleVertexBufferView)
  317. ->IndexBuffer(rectangleIndexBufferView)
  318. ;
  319. m_rectangleRayTracingBlas->CreateBuffers(RHI::MultiDevice::AllDevices, &rectangleBlasDescriptor, *m_rayTracingBufferPools);
  320. }
  321. m_time += 0.005f;
  322. // transforms
  323. AZ::Transform triangleTransform1 = AZ::Transform::CreateIdentity();
  324. triangleTransform1.SetTranslation(sinf(m_time) * -100.0f, cosf(m_time) * -100.0f, 1.0f);
  325. triangleTransform1.MultiplyByUniformScale(100.0f);
  326. AZ::Transform triangleTransform2 = AZ::Transform::CreateIdentity();
  327. triangleTransform2.SetTranslation(sinf(m_time) * -100.0f, cosf(m_time) * 100.0f, 2.0f);
  328. triangleTransform2.MultiplyByUniformScale(100.0f);
  329. AZ::Transform triangleTransform3 = AZ::Transform::CreateIdentity();
  330. triangleTransform3.SetTranslation(sinf(m_time) * 100.0f, cosf(m_time) * 100.0f, 3.0f);
  331. triangleTransform3.MultiplyByUniformScale(100.0f);
  332. AZ::Transform rectangleTransform = AZ::Transform::CreateIdentity();
  333. rectangleTransform.SetTranslation(sinf(m_time) * 100.0f, cosf(m_time) * -100.0f, 4.0f);
  334. rectangleTransform.MultiplyByUniformScale(100.0f);
  335. // create the TLAS
  336. RHI::RayTracingTlasDescriptor tlasDescriptor;
  337. tlasDescriptor.Build()
  338. ->Instance()
  339. ->InstanceID(0)
  340. ->HitGroupIndex(0)
  341. ->Blas(m_triangleRayTracingBlas)
  342. ->Transform(triangleTransform1)
  343. ->Instance()
  344. ->InstanceID(1)
  345. ->HitGroupIndex(1)
  346. ->Blas(m_triangleRayTracingBlas)
  347. ->Transform(triangleTransform2)
  348. ->Instance()
  349. ->InstanceID(2)
  350. ->HitGroupIndex(2)
  351. ->Blas(m_triangleRayTracingBlas)
  352. ->Transform(triangleTransform3)
  353. ->Instance()
  354. ->InstanceID(3)
  355. ->HitGroupIndex(3)
  356. ->Blas(m_rectangleRayTracingBlas)
  357. ->Transform(rectangleTransform)
  358. ;
  359. m_rayTracingTlas->CreateBuffers(RHI::MultiDevice::AllDevices, &tlasDescriptor, *m_rayTracingBufferPools);
  360. m_tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRaw(0, (uint32_t)m_rayTracingTlas->GetTlasBuffer()->GetDescriptor().m_byteCount);
  361. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(m_tlasBufferAttachmentId, m_rayTracingTlas->GetTlasBuffer());
  362. AZ_Error(RayTracingExampleName, result == RHI::ResultCode::Success, "Failed to import TLAS buffer with error %d", result);
  363. RHI::BufferScopeAttachmentDescriptor desc;
  364. desc.m_attachmentId = m_tlasBufferAttachmentId;
  365. desc.m_bufferViewDescriptor = m_tlasBufferViewDescriptor;
  366. desc.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  367. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::AnyGraphics);
  368. };
  369. RHI::EmptyCompileFunction<ScopeData> compileFunction;
  370. const auto executeFunction = [this]([[maybe_unused]] const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  371. {
  372. RHI::CommandList* commandList = context.GetCommandList();
  373. commandList->BuildBottomLevelAccelerationStructure(*m_triangleRayTracingBlas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
  374. commandList->BuildBottomLevelAccelerationStructure(*m_rectangleRayTracingBlas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
  375. commandList->BuildTopLevelAccelerationStructure(
  376. *m_rayTracingTlas->GetDeviceRayTracingTlas(context.GetDeviceIndex()), { m_triangleRayTracingBlas->GetDeviceRayTracingBlas(context.GetDeviceIndex()).get(), m_rectangleRayTracingBlas->GetDeviceRayTracingBlas(context.GetDeviceIndex()).get() });
  377. };
  378. m_scopeProducers.emplace_back(
  379. aznew RHI::ScopeProducerFunction<
  380. ScopeData,
  381. decltype(prepareFunction),
  382. decltype(compileFunction),
  383. decltype(executeFunction)>(
  384. RHI::ScopeId{ "RayTracingBuildAccelerationStructure" },
  385. ScopeData{},
  386. prepareFunction,
  387. compileFunction,
  388. executeFunction));
  389. }
  390. void RayTracingExampleComponent::CreateRayTracingDispatchScope()
  391. {
  392. struct ScopeData
  393. {
  394. };
  395. const auto prepareFunction = [this]([[maybe_unused]] RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  396. {
  397. // attach output image
  398. {
  399. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportImage(m_outputImageAttachmentId, m_outputImage);
  400. AZ_Error(RayTracingExampleName, result == RHI::ResultCode::Success, "Failed to import output image with error %d", result);
  401. RHI::ImageScopeAttachmentDescriptor desc;
  402. desc.m_attachmentId = m_outputImageAttachmentId;
  403. desc.m_imageViewDescriptor = m_outputImageViewDescriptor;
  404. desc.m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(0.0f, 0.0f, 0.0f, 0.0f);
  405. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::RayTracingShader);
  406. }
  407. // attach TLAS buffer
  408. if (m_rayTracingTlas->GetTlasBuffer())
  409. {
  410. RHI::BufferScopeAttachmentDescriptor desc;
  411. desc.m_attachmentId = m_tlasBufferAttachmentId;
  412. desc.m_bufferViewDescriptor = m_tlasBufferViewDescriptor;
  413. desc.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  414. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::RayTracingShader);
  415. }
  416. frameGraph.SetEstimatedItemCount(1);
  417. };
  418. const auto compileFunction = [this]([[maybe_unused]] const RHI::FrameGraphCompileContext& context, [[maybe_unused]] const ScopeData& scopeData)
  419. {
  420. if (m_rayTracingTlas->GetTlasBuffer())
  421. {
  422. // set the TLAS and output image in the ray tracing global Srg
  423. RHI::ShaderInputBufferIndex tlasConstantIndex;
  424. FindShaderInputIndex(&tlasConstantIndex, m_globalSrg, AZ::Name{ "m_scene" }, RayTracingExampleName);
  425. uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(m_rayTracingTlas->GetTlasBuffer()->GetDescriptor().m_byteCount);
  426. RHI::BufferViewDescriptor bufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount);
  427. m_globalSrg->SetBufferView(tlasConstantIndex, m_rayTracingTlas->GetTlasBuffer()->BuildBufferView(bufferViewDescriptor).get());
  428. RHI::ShaderInputImageIndex outputConstantIndex;
  429. FindShaderInputIndex(&outputConstantIndex, m_globalSrg, AZ::Name{ "m_output" }, RayTracingExampleName);
  430. m_globalSrg->SetImageView(outputConstantIndex, m_outputImageView.get());
  431. // set hit shader data, each array element corresponds to the InstanceIndex() of the geometry in the TLAS
  432. // Note: this method is used instead of LocalRootSignatures for compatibility with non-DX12 platforms
  433. // set HitGradient values
  434. RHI::ShaderInputConstantIndex hitGradientDataConstantIndex;
  435. FindShaderInputIndex(&hitGradientDataConstantIndex, m_globalSrg, AZ::Name{"m_hitGradientData"}, RayTracingExampleName);
  436. struct HitGradientData
  437. {
  438. AZ::Vector4 m_color;
  439. };
  440. AZStd::array<HitGradientData, 4> hitGradientData = {{
  441. {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f)}, // triangle1
  442. {AZ::Vector4(0.0f, 1.0f, 0.0f, 1.0f)}, // triangle2
  443. {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
  444. {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
  445. }};
  446. m_globalSrg->SetConstantArray(hitGradientDataConstantIndex, hitGradientData);
  447. // set HitSolid values
  448. RHI::ShaderInputConstantIndex hitSolidDataConstantIndex;
  449. FindShaderInputIndex(&hitSolidDataConstantIndex, m_globalSrg, AZ::Name{"m_hitSolidData"}, RayTracingExampleName);
  450. struct HitSolidData
  451. {
  452. AZ::Vector4 m_color1;
  453. float m_lerp;
  454. float m_pad[3];
  455. AZ::Vector4 m_color2;
  456. };
  457. AZStd::array<HitSolidData, 4> hitSolidData = {{
  458. {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f), 0.0f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
  459. {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f), 0.0f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
  460. {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f), 0.5f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 1.0f, 0.0f, 1.0f)}, // triangle3
  461. {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f), 0.5f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 0.0f, 1.0f, 1.0f)}, // rectangle
  462. }};
  463. m_globalSrg->SetConstantArray(hitSolidDataConstantIndex, hitSolidData);
  464. m_globalSrg->Compile();
  465. // update the ray tracing shader table
  466. AZStd::shared_ptr<RHI::RayTracingShaderTableDescriptor> descriptor = AZStd::make_shared<RHI::RayTracingShaderTableDescriptor>();
  467. descriptor->Build(AZ::Name("RayTracingExampleShaderTable"), m_rayTracingPipelineState)
  468. ->RayGenerationRecord(AZ::Name("RayGenerationShader"))
  469. ->MissRecord(AZ::Name("MissShader"))
  470. ->HitGroupRecord(AZ::Name("HitGroupGradient")) // triangle1
  471. ->HitGroupRecord(AZ::Name("HitGroupGradient")) // triangle2
  472. ->HitGroupRecord(AZ::Name("HitGroupSolid")) // triangle3
  473. ->HitGroupRecord(AZ::Name("HitGroupSolid")) // rectangle
  474. ;
  475. m_rayTracingShaderTable->Build(descriptor);
  476. }
  477. };
  478. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  479. {
  480. if (!m_rayTracingTlas->GetTlasBuffer())
  481. {
  482. return;
  483. }
  484. RHI::CommandList* commandList = context.GetCommandList();
  485. const RHI::DeviceShaderResourceGroup* shaderResourceGroups[] = {
  486. m_globalSrg->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get()
  487. };
  488. RHI::DeviceDispatchRaysItem dispatchRaysItem;
  489. dispatchRaysItem.m_arguments.m_direct.m_width = m_imageWidth;
  490. dispatchRaysItem.m_arguments.m_direct.m_height = m_imageHeight;
  491. dispatchRaysItem.m_arguments.m_direct.m_depth = 1;
  492. dispatchRaysItem.m_rayTracingPipelineState = m_rayTracingPipelineState->GetDeviceRayTracingPipelineState(context.GetDeviceIndex()).get();
  493. dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable->GetDeviceRayTracingShaderTable(context.GetDeviceIndex()).get();
  494. dispatchRaysItem.m_shaderResourceGroupCount = RHI::ArraySize(shaderResourceGroups);
  495. dispatchRaysItem.m_shaderResourceGroups = shaderResourceGroups;
  496. dispatchRaysItem.m_globalPipelineState = m_globalPipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
  497. // submit the DispatchRays item
  498. commandList->Submit(dispatchRaysItem);
  499. };
  500. m_scopeProducers.emplace_back(
  501. aznew RHI::ScopeProducerFunction<
  502. ScopeData,
  503. decltype(prepareFunction),
  504. decltype(compileFunction),
  505. decltype(executeFunction)>(
  506. RHI::ScopeId{ "RayTracingDispatch" },
  507. ScopeData{},
  508. prepareFunction,
  509. compileFunction,
  510. executeFunction));
  511. }
  512. void RayTracingExampleComponent::CreateRasterScope()
  513. {
  514. struct ScopeData
  515. {
  516. };
  517. const auto prepareFunction = [this](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  518. {
  519. // attach swapchain
  520. {
  521. RHI::ImageScopeAttachmentDescriptor descriptor;
  522. descriptor.m_attachmentId = m_outputAttachmentId;
  523. descriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::DontCare;
  524. frameGraph.UseColorAttachment(descriptor);
  525. }
  526. // attach output buffer
  527. {
  528. RHI::ImageScopeAttachmentDescriptor desc;
  529. desc.m_attachmentId = m_outputImageAttachmentId;
  530. desc.m_imageViewDescriptor = m_outputImageViewDescriptor;
  531. desc.m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(0.0f, 0.0f, 0.0f, 0.0f);
  532. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::FragmentShader);
  533. const Name outputImageId{ "m_output" };
  534. RHI::ShaderInputImageIndex outputImageIndex = m_drawSRG->FindShaderInputImageIndex(outputImageId);
  535. AZ_Error(RayTracingExampleName, outputImageIndex.IsValid(), "Failed to find shader input image %s.", outputImageId.GetCStr());
  536. m_drawSRG->SetImageView(outputImageIndex, m_outputImageView.get());
  537. m_drawSRG->Compile();
  538. }
  539. frameGraph.SetEstimatedItemCount(1);
  540. };
  541. RHI::EmptyCompileFunction<ScopeData> compileFunction;
  542. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  543. {
  544. RHI::CommandList* commandList = context.GetCommandList();
  545. commandList->SetViewports(&m_viewport, 1);
  546. commandList->SetScissors(&m_scissor, 1);
  547. const RHI::DeviceShaderResourceGroup* shaderResourceGroups[] = {
  548. m_drawSRG->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get()
  549. };
  550. RHI::DeviceDrawItem drawItem;
  551. drawItem.m_geometryView = m_geometryView.GetDeviceGeometryView(context.GetDeviceIndex());
  552. drawItem.m_streamIndices = m_geometryView.GetFullStreamBufferIndices();
  553. drawItem.m_pipelineState = m_drawPipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
  554. drawItem.m_shaderResourceGroupCount = static_cast<uint8_t>(RHI::ArraySize(shaderResourceGroups));
  555. drawItem.m_shaderResourceGroups = shaderResourceGroups;
  556. // submit the triangle draw item.
  557. commandList->Submit(drawItem);
  558. };
  559. m_scopeProducers.emplace_back(
  560. aznew RHI::ScopeProducerFunction<
  561. ScopeData,
  562. decltype(prepareFunction),
  563. decltype(compileFunction),
  564. decltype(executeFunction)>(
  565. RHI::ScopeId{ "Raster" },
  566. ScopeData{},
  567. prepareFunction,
  568. compileFunction,
  569. executeFunction));
  570. }
  571. } // namespace AtomSampleViewer