RayTracingExampleComponent.cpp 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727
  1. /*
  2. * All or portions of this file Copyright (c) Amazon.com, Inc. or its affiliates or
  3. * its licensors.
  4. *
  5. * For complete copyright and license terms please see the LICENSE at the root of this
  6. * distribution (the "License"). All use of this software is governed by the License,
  7. * or, if provided, by the license below or the license accompanying this file. Do not
  8. * remove or modify any license notices. This file is distributed on an "AS IS" BASIS,
  9. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. *
  11. */
  12. #include <RHI/RayTracingExampleComponent.h>
  13. #include <Utils/Utils.h>
  14. #include <SampleComponentManager.h>
  15. #include <Atom/RHI/CommandList.h>
  16. #include <Atom/RHI/FrameGraphInterface.h>
  17. #include <Atom/RHI/RayTracingPipelineState.h>
  18. #include <Atom/RHI/RayTracingShaderTable.h>
  19. #include <Atom/RHI.Reflect/InputStreamLayoutBuilder.h>
  20. #include <Atom/RHI.Reflect/RenderAttachmentLayoutBuilder.h>
  21. #include <Atom/RPI.Public/Shader/Shader.h>
  22. #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
  23. #include <AzCore/Serialization/SerializeContext.h>
  // Sample name passed to shader loading, SRG creation and error reporting throughout this file.
  // constexpr so the pointer itself cannot be reassigned (the previous `static const char*` was mutable).
  static constexpr const char* RayTracingExampleName = "RayTracingExample";
  25. namespace AtomSampleViewer
  26. {
  27. void RayTracingExampleComponent::Reflect(AZ::ReflectContext* context)
  28. {
  29. if (auto* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  30. {
  31. serializeContext->Class<RayTracingExampleComponent, AZ::Component>()
  32. ->Version(0)
  33. ;
  34. }
  35. }
  36. RayTracingExampleComponent::RayTracingExampleComponent()
  37. {
  38. m_supportRHISamplePipeline = true;
  39. }
  40. void RayTracingExampleComponent::CreateResourcePools()
  41. {
  42. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  43. // create input assembly buffer pool
  44. {
  45. m_inputAssemblyBufferPool = RHI::Factory::Get().CreateBufferPool();
  46. RHI::BufferPoolDescriptor bufferPoolDesc;
  47. bufferPoolDesc.m_bindFlags = RHI::BufferBindFlags::InputAssembly;
  48. bufferPoolDesc.m_heapMemoryLevel = RHI::HeapMemoryLevel::Host;
  49. RHI::ResultCode resultCode = m_inputAssemblyBufferPool->Init(*device, bufferPoolDesc);
  50. AZ_Assert(resultCode == RHI::ResultCode::Success, "Failed to initialize input assembly buffer pool");
  51. }
  52. // create output image pool
  53. {
  54. RHI::ImagePoolDescriptor imagePoolDesc;
  55. imagePoolDesc.m_bindFlags = RHI::ImageBindFlags::ShaderReadWrite;
  56. m_imagePool = RHI::Factory::Get().CreateImagePool();
  57. RHI::ResultCode result = m_imagePool->Init(*device, imagePoolDesc);
  58. AZ_Assert(result == RHI::ResultCode::Success, "Failed to initialize output image pool");
  59. }
  60. // initialize ray tracing buffer pools
  61. m_rayTracingBufferPools = RHI::RayTracingBufferPools::CreateRHIRayTracingBufferPools();
  62. m_rayTracingBufferPools->Init(device);
  63. }
  64. void RayTracingExampleComponent::CreateGeometry()
  65. {
  66. // triangle
  67. {
  68. // vertex buffer
  69. SetVertexPosition(m_triangleVertices.data(), 0, 0.0f, 0.5f, 1.0);
  70. SetVertexPosition(m_triangleVertices.data(), 1, 0.5f, -0.5f, 1.0);
  71. SetVertexPosition(m_triangleVertices.data(), 2, -0.5f, -0.5f, 1.0);
  72. m_triangleVB = RHI::Factory::Get().CreateBuffer();
  73. m_triangleVB->SetName(AZ::Name("Triangle VB"));
  74. RHI::BufferInitRequest request;
  75. request.m_buffer = m_triangleVB.get();
  76. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_triangleVertices) };
  77. request.m_initialData = m_triangleVertices.data();
  78. m_inputAssemblyBufferPool->InitBuffer(request);
  79. // index buffer
  80. SetVertexIndexIncreasing(m_triangleIndices.data(), m_triangleIndices.size());
  81. m_triangleIB = RHI::Factory::Get().CreateBuffer();
  82. m_triangleIB->SetName(AZ::Name("Triangle IB"));
  83. request.m_buffer = m_triangleIB.get();
  84. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_triangleIndices) };
  85. request.m_initialData = m_triangleIndices.data();
  86. m_inputAssemblyBufferPool->InitBuffer(request);
  87. }
  88. // rectangle
  89. {
  90. // vertex buffer
  91. SetVertexPosition(m_rectangleVertices.data(), 0, -0.5f, 0.5f, 1.0);
  92. SetVertexPosition(m_rectangleVertices.data(), 1, 0.5f, 0.5f, 1.0);
  93. SetVertexPosition(m_rectangleVertices.data(), 2, 0.5f, -0.5f, 1.0);
  94. SetVertexPosition(m_rectangleVertices.data(), 3, -0.5f, -0.5f, 1.0);
  95. m_rectangleVB = RHI::Factory::Get().CreateBuffer();
  96. m_rectangleVB->SetName(AZ::Name("Rectangle VB"));
  97. RHI::BufferInitRequest request;
  98. request.m_buffer = m_rectangleVB.get();
  99. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_rectangleVertices) };
  100. request.m_initialData = m_rectangleVertices.data();
  101. m_inputAssemblyBufferPool->InitBuffer(request);
  102. // index buffer
  103. m_rectangleIndices[0] = 0;
  104. m_rectangleIndices[1] = 1;
  105. m_rectangleIndices[2] = 2;
  106. m_rectangleIndices[3] = 0;
  107. m_rectangleIndices[4] = 2;
  108. m_rectangleIndices[5] = 3;
  109. m_rectangleIB = RHI::Factory::Get().CreateBuffer();
  110. m_rectangleIB->SetName(AZ::Name("Rectangle IB"));
  111. request.m_buffer = m_rectangleIB.get();
  112. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_rectangleIndices) };
  113. request.m_initialData = m_rectangleIndices.data();
  114. m_inputAssemblyBufferPool->InitBuffer(request);
  115. }
  116. }
  117. void RayTracingExampleComponent::CreateFullScreenBuffer()
  118. {
  119. FullScreenBufferData bufferData;
  120. SetFullScreenRect(bufferData.m_positions.data(), bufferData.m_uvs.data(), bufferData.m_indices.data());
  121. m_fullScreenInputAssemblyBuffer = RHI::Factory::Get().CreateBuffer();
  122. RHI::BufferInitRequest request;
  123. request.m_buffer = m_fullScreenInputAssemblyBuffer.get();
  124. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(bufferData) };
  125. request.m_initialData = &bufferData;
  126. m_inputAssemblyBufferPool->InitBuffer(request);
  127. m_fullScreenStreamBufferViews[0] =
  128. {
  129. *m_fullScreenInputAssemblyBuffer,
  130. offsetof(FullScreenBufferData, m_positions),
  131. sizeof(FullScreenBufferData::m_positions),
  132. sizeof(VertexPosition)
  133. };
  134. m_fullScreenStreamBufferViews[1] =
  135. {
  136. *m_fullScreenInputAssemblyBuffer,
  137. offsetof(FullScreenBufferData, m_uvs),
  138. sizeof(FullScreenBufferData::m_uvs),
  139. sizeof(VertexUV)
  140. };
  141. m_fullScreenIndexBufferView =
  142. {
  143. *m_fullScreenInputAssemblyBuffer,
  144. offsetof(FullScreenBufferData, m_indices),
  145. sizeof(FullScreenBufferData::m_indices),
  146. RHI::IndexFormat::Uint16
  147. };
  148. RHI::InputStreamLayoutBuilder layoutBuilder;
  149. layoutBuilder.AddBuffer()->Channel("POSITION", RHI::Format::R32G32B32_FLOAT);
  150. layoutBuilder.AddBuffer()->Channel("UV", RHI::Format::R32G32_FLOAT);
  151. m_fullScreenInputStreamLayout = layoutBuilder.End();
  152. }
  153. void RayTracingExampleComponent::CreateOutputTexture()
  154. {
  155. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  156. // create output image
  157. m_outputImage = RHI::Factory::Get().CreateImage();
  158. uint32_t imageSize = m_imageWidth * m_imageHeight * RHI::GetFormatSize(RHI::Format::R8G8B8A8_UNORM);
  159. RHI::ImageInitRequest request;
  160. request.m_image = m_outputImage.get();
  161. request.m_descriptor = RHI::ImageDescriptor::Create2D(RHI::ImageBindFlags::ShaderReadWrite, m_imageWidth, m_imageHeight, RHI::Format::R8G8B8A8_UNORM);
  162. RHI::ResultCode result = m_imagePool->InitImage(request);
  163. AZ_Assert(result == RHI::ResultCode::Success, "Failed to initialize output image");
  164. m_outputImageViewDescriptor = RHI::ImageViewDescriptor::Create(RHI::Format::R8G8B8A8_UNORM, 0, 0);
  165. m_outputImageView = m_outputImage->GetImageView(m_outputImageViewDescriptor);
  166. AZ_Assert(m_outputImageView.get(), "Failed to create output image view");
  167. AZ_Assert(m_outputImageView->IsFullView(), "Image View initialization IsFullView() failed");
  168. }
  169. void RayTracingExampleComponent::CreateRayTracingAccelerationStructureObjects()
  170. {
  171. m_triangleRayTracingBlas = AZ::RHI::RayTracingBlas::CreateRHIRayTracingBlas();
  172. m_rectangleRayTracingBlas = AZ::RHI::RayTracingBlas::CreateRHIRayTracingBlas();
  173. m_rayTracingTlas = AZ::RHI::RayTracingTlas::CreateRHIRayTracingTlas();
  174. }
  175. void RayTracingExampleComponent::CreateRasterShader()
  176. {
  177. const char* shaderFilePath = "Shaders/RHI/RayTracingDraw.azshader";
  178. auto drawShader = LoadShader(shaderFilePath, RayTracingExampleName);
  179. AZ_Assert(drawShader, "Failed to load Draw shader");
  180. RHI::PipelineStateDescriptorForDraw pipelineDesc;
  181. drawShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId).ConfigurePipelineState(pipelineDesc);
  182. pipelineDesc.m_inputStreamLayout = m_fullScreenInputStreamLayout;
  183. RHI::RenderAttachmentLayoutBuilder attachmentsBuilder;
  184. attachmentsBuilder.AddSubpass()->RenderTargetAttachment(m_outputFormat);
  185. RHI::ResultCode result = attachmentsBuilder.End(pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout);
  186. AZ_Assert(result == RHI::ResultCode::Success, "Failed to create draw render attachment layout");
  187. m_drawPipelineState = drawShader->AcquirePipelineState(pipelineDesc);
  188. AZ_Assert(m_drawPipelineState, "Failed to acquire draw pipeline state");
  189. m_drawSRG = CreateShaderResourceGroup(drawShader, "BufferSrg", RayTracingExampleName);
  190. }
  191. void RayTracingExampleComponent::CreateRayTracingPipelineState()
  192. {
  193. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  194. // load ray generation shader
  195. const char* rayGenerationShaderFilePath = "Shaders/RHI/RayTracingDispatch.azshader";
  196. m_rayGenerationShader = LoadShader(rayGenerationShaderFilePath, RayTracingExampleName);
  197. AZ_Assert(m_rayGenerationShader, "Failed to load ray generation shader");
  198. auto rayGenerationShaderVariant = m_rayGenerationShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  199. RHI::PipelineStateDescriptorForRayTracing rayGenerationShaderDescriptor;
  200. rayGenerationShaderVariant.ConfigurePipelineState(rayGenerationShaderDescriptor);
  201. // load miss shader
  202. const char* missShaderFilePath = "Shaders/RHI/RayTracingMiss.azshader";
  203. m_missShader = LoadShader(missShaderFilePath, RayTracingExampleName);
  204. AZ_Assert(m_missShader, "Failed to load miss shader");
  205. auto missShaderVariant = m_missShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  206. RHI::PipelineStateDescriptorForRayTracing missShaderDescriptor;
  207. missShaderVariant.ConfigurePipelineState(missShaderDescriptor);
  208. // load closest hit gradient shader
  209. const char* closestHitGradientShaderFilePath = "Shaders/RHI/RayTracingClosestHitGradient.azshader";
  210. m_closestHitGradientShader = LoadShader(closestHitGradientShaderFilePath, RayTracingExampleName);
  211. AZ_Assert(m_closestHitGradientShader, "Failed to load closest hit gradient shader");
  212. auto closestHitGradientShaderVariant = m_closestHitGradientShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  213. RHI::PipelineStateDescriptorForRayTracing closestHitGradiantShaderDescriptor;
  214. closestHitGradientShaderVariant.ConfigurePipelineState(closestHitGradiantShaderDescriptor);
  215. // load closest hit solid shader
  216. const char* closestHitSolidShaderFilePath = "Shaders/RHI/RayTracingClosestHitSolid.azshader";
  217. m_closestHitSolidShader = LoadShader(closestHitSolidShaderFilePath, RayTracingExampleName);
  218. AZ_Assert(m_closestHitSolidShader, "Failed to load closest hit solid shader");
  219. auto closestHitSolidShaderVariant = m_closestHitSolidShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  220. RHI::PipelineStateDescriptorForRayTracing closestHitSolidShaderDescriptor;
  221. closestHitSolidShaderVariant.ConfigurePipelineState(closestHitSolidShaderDescriptor);
  222. // global pipeline state and srg
  223. m_globalPipelineState = m_rayGenerationShader->AcquirePipelineState(rayGenerationShaderDescriptor);
  224. AZ_Assert(m_globalPipelineState, "Failed to acquire ray tracing global pipeline state");
  225. m_globalSrg = CreateShaderResourceGroup(m_rayGenerationShader, "RayTracingGlobalSrg", RayTracingExampleName);
  226. // build the ray tracing pipeline state descriptor
  227. RHI::RayTracingPipelineStateDescriptor descriptor;
  228. descriptor.Build()
  229. ->PipelineState(m_globalPipelineState.get())
  230. ->ShaderLibrary(rayGenerationShaderDescriptor)
  231. ->RayGenerationShaderName(AZ::Name("RayGenerationShader"))
  232. ->ShaderLibrary(missShaderDescriptor)
  233. ->MissShaderName(AZ::Name("MissShader"))
  234. ->ShaderLibrary(closestHitGradiantShaderDescriptor)
  235. ->ClosestHitShaderName(AZ::Name("ClosestHitGradientShader"))
  236. ->ShaderLibrary(closestHitSolidShaderDescriptor)
  237. ->ClosestHitShaderName(AZ::Name("ClosestHitSolidShader"))
  238. ->HitGroup(AZ::Name("HitGroupGradient"))
  239. ->ClosestHitShaderName(AZ::Name("ClosestHitGradientShader"))
  240. ->HitGroup(AZ::Name("HitGroupSolid"))
  241. ->ClosestHitShaderName(AZ::Name("ClosestHitSolidShader"));
  242. // create the ray tracing pipeline state object
  243. m_rayTracingPipelineState = RHI::Factory::Get().CreateRayTracingPipelineState();
  244. m_rayTracingPipelineState->Init(*device.get(), &descriptor);
  245. }
  246. void RayTracingExampleComponent::CreateRayTracingShaderTableScope()
  247. {
  248. struct ScopeData
  249. {
  250. };
  251. const auto prepareFunction = [this]([[maybe_unused]] RHI::FrameGraphInterface& scopeBuilder, [[maybe_unused]] ScopeData& scopeData)
  252. {
  253. };
  254. const auto compileFunction = [this]([[maybe_unused]] const RHI::FrameGraphCompileContext& context, [[maybe_unused]] const ScopeData& scopeData)
  255. {
  256. };
  257. const auto executeFunction = [this]([[maybe_unused]] const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  258. {
  259. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  260. // build the ray tracing shader table descriptor
  261. RHI::RayTracingShaderTableDescriptor descriptor;
  262. descriptor.Build(AZ::Name("RayTracingExampleShaderTable"), m_rayTracingPipelineState)
  263. ->RayGenerationRecord(AZ::Name("RayGenerationShader"))
  264. ->MissRecord(AZ::Name("MissShader"))
  265. ->HitGroupRecord(AZ::Name("HitGroupGradient")) // triangle1
  266. ->HitGroupRecord(AZ::Name("HitGroupGradient")) // triangle2
  267. ->HitGroupRecord(AZ::Name("HitGroupSolid")) // triangle3
  268. ->HitGroupRecord(AZ::Name("HitGroupSolid")) // rectangle
  269. ;
  270. // initialize the ray tracing shader table object
  271. m_rayTracingShaderTable->Init(*device.get(), &descriptor, *m_rayTracingBufferPools);
  272. };
  273. // create the shader table once, outside of the scope
  274. m_rayTracingShaderTable = RHI::Factory::Get().CreateRayTracingShaderTable();
  275. m_scopeProducers.emplace_back(
  276. aznew RHI::ScopeProducerFunction<
  277. ScopeData,
  278. decltype(prepareFunction),
  279. decltype(compileFunction),
  280. decltype(executeFunction)>(
  281. RHI::ScopeId{ "RayTracingBuildShaderTable" },
  282. ScopeData{},
  283. prepareFunction,
  284. compileFunction,
  285. executeFunction));
  286. m_shaderTableScopeId = m_scopeProducers.back()->GetScopeId();
  287. }
  288. void RayTracingExampleComponent::CreateRayTracingAccelerationTableScope()
  289. {
  290. struct ScopeData
  291. {
  292. };
  293. const auto prepareFunction = [this]([[maybe_unused]] RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  294. {
  295. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  296. // create triangle BLAS buffer if necessary
  297. if (!m_triangleRayTracingBlas->IsValid())
  298. {
  299. RHI::StreamBufferView triangleVertexBufferView =
  300. {
  301. *m_triangleVB,
  302. 0,
  303. sizeof(m_triangleVertices),
  304. sizeof(VertexPosition)
  305. };
  306. RHI::IndexBufferView triangleIndexBufferView =
  307. {
  308. *m_triangleIB,
  309. 0,
  310. sizeof(m_triangleIndices),
  311. RHI::IndexFormat::Uint16
  312. };
  313. RHI::RayTracingBlasDescriptor triangleBlasDescriptor;
  314. triangleBlasDescriptor.Build()
  315. ->Geometry()
  316. ->VertexFormat(RHI::Format::R32G32B32_FLOAT)
  317. ->VertexBuffer(triangleVertexBufferView)
  318. ->IndexBuffer(triangleIndexBufferView)
  319. ;
  320. m_triangleRayTracingBlas->CreateBuffers(*device, &triangleBlasDescriptor, *m_rayTracingBufferPools);
  321. }
  322. // create rectangle BLAS if necessary
  323. if (!m_rectangleRayTracingBlas->IsValid())
  324. {
  325. RHI::StreamBufferView rectangleVertexBufferView =
  326. {
  327. *m_rectangleVB,
  328. 0,
  329. sizeof(m_rectangleVertices),
  330. sizeof(VertexPosition)
  331. };
  332. RHI::IndexBufferView rectangleIndexBufferView =
  333. {
  334. *m_rectangleIB,
  335. 0,
  336. sizeof(m_rectangleIndices),
  337. RHI::IndexFormat::Uint16
  338. };
  339. RHI::RayTracingBlasDescriptor rectangleBlasDescriptor;
  340. rectangleBlasDescriptor.Build()
  341. ->Geometry()
  342. ->VertexFormat(RHI::Format::R32G32B32_FLOAT)
  343. ->VertexBuffer(rectangleVertexBufferView)
  344. ->IndexBuffer(rectangleIndexBufferView)
  345. ;
  346. m_rectangleRayTracingBlas->CreateBuffers(*device, &rectangleBlasDescriptor, *m_rayTracingBufferPools);
  347. }
  348. m_time += 0.005f;
  349. // transforms
  350. AZ::Transform triangleTransform1 = AZ::Transform::CreateIdentity();
  351. triangleTransform1.SetTranslation(sinf(m_time) * -100.0f, cosf(m_time) * -100.0f, 1.0f);
  352. triangleTransform1.MultiplyByScale(Vector3(100.0f, 100.0f, 100.0f));
  353. AZ::Transform triangleTransform2 = AZ::Transform::CreateIdentity();
  354. triangleTransform2.SetTranslation(sinf(m_time) * -100.0f, cosf(m_time) * 100.0f, 2.0f);
  355. triangleTransform2.MultiplyByScale(Vector3(100.0f, 100.0f, 100.0f));
  356. AZ::Transform triangleTransform3 = AZ::Transform::CreateIdentity();
  357. triangleTransform3.SetTranslation(sinf(m_time) * 100.0f, cosf(m_time) * 100.0f, 3.0f);
  358. triangleTransform3.MultiplyByScale(Vector3(100.0f, 100.0f, 100.0f));
  359. AZ::Transform rectangleTransform = AZ::Transform::CreateIdentity();
  360. rectangleTransform.SetTranslation(sinf(m_time) * 100.0f, cosf(m_time) * -100.0f, 4.0f);
  361. rectangleTransform.MultiplyByScale(Vector3(100.0f, 100.0f, 100.0f));
  362. // create the TLAS
  363. RHI::RayTracingTlasDescriptor tlasDescriptor;
  364. tlasDescriptor.Build()
  365. ->Instance()
  366. ->InstanceID(0)
  367. ->HitGroupIndex(0)
  368. ->Blas(m_triangleRayTracingBlas)
  369. ->Transform(triangleTransform1)
  370. ->Instance()
  371. ->InstanceID(1)
  372. ->HitGroupIndex(1)
  373. ->Blas(m_triangleRayTracingBlas)
  374. ->Transform(triangleTransform2)
  375. ->Instance()
  376. ->InstanceID(2)
  377. ->HitGroupIndex(2)
  378. ->Blas(m_triangleRayTracingBlas)
  379. ->Transform(triangleTransform3)
  380. ->Instance()
  381. ->InstanceID(3)
  382. ->HitGroupIndex(3)
  383. ->Blas(m_rectangleRayTracingBlas)
  384. ->Transform(rectangleTransform)
  385. ;
  386. m_rayTracingTlas->CreateBuffers(*device, &tlasDescriptor, *m_rayTracingBufferPools);
  387. m_tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRaw(0, (uint32_t)m_rayTracingTlas->GetTlasBuffer()->GetDescriptor().m_byteCount);
  388. };
  389. RHI::EmptyCompileFunction<ScopeData> compileFunction;
  390. const auto executeFunction = [this]([[maybe_unused]] const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  391. {
  392. RHI::CommandList* commandList = context.GetCommandList();
  393. commandList->BuildBottomLevelAccelerationStructure(*m_triangleRayTracingBlas);
  394. commandList->BuildBottomLevelAccelerationStructure(*m_rectangleRayTracingBlas);
  395. commandList->BuildTopLevelAccelerationStructure(*m_rayTracingTlas);
  396. };
  397. m_scopeProducers.emplace_back(
  398. aznew RHI::ScopeProducerFunction<
  399. ScopeData,
  400. decltype(prepareFunction),
  401. decltype(compileFunction),
  402. decltype(executeFunction)>(
  403. RHI::ScopeId{ "RayTracingBuildAccelerationStructure" },
  404. ScopeData{},
  405. prepareFunction,
  406. compileFunction,
  407. executeFunction));
  408. }
  409. void RayTracingExampleComponent::CreateRayTracingDispatchScope()
  410. {
  411. struct ScopeData
  412. {
  413. };
  414. const auto prepareFunction = [this]([[maybe_unused]] RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  415. {
  416. // attach output image
  417. {
  418. RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportImage(m_outputImageAttachmentId, m_outputImage);
  419. AZ_Error(RayTracingExampleName, result == RHI::ResultCode::Success, "Failed to import output image with error %d", result);
  420. RHI::ImageScopeAttachmentDescriptor desc;
  421. desc.m_attachmentId = m_outputImageAttachmentId;
  422. desc.m_imageViewDescriptor = m_outputImageViewDescriptor;
  423. desc.m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(0.0f, 0.0f, 0.0f, 0.0f);
  424. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite);
  425. }
  426. // attach TLAS buffer
  427. if (m_rayTracingTlas->GetTlasBuffer())
  428. {
  429. RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(m_tlasBufferAttachmentId, m_rayTracingTlas->GetTlasBuffer());
  430. AZ_Error(RayTracingExampleName, result == RHI::ResultCode::Success, "Failed to import TLAS buffer with error %d", result);
  431. RHI::BufferScopeAttachmentDescriptor desc;
  432. desc.m_attachmentId = m_tlasBufferAttachmentId;
  433. desc.m_bufferViewDescriptor = m_tlasBufferViewDescriptor;
  434. desc.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  435. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::Read);
  436. }
  437. frameGraph.ExecuteAfter(m_shaderTableScopeId);
  438. frameGraph.SetEstimatedItemCount(1);
  439. };
  440. const auto compileFunction = [this]([[maybe_unused]] const RHI::FrameGraphCompileContext& context, [[maybe_unused]] const ScopeData& scopeData)
  441. {
  442. if (m_rayTracingTlas->GetTlasBuffer())
  443. {
  444. // set the TLAS and output image in the ray tracing global Srg
  445. RHI::ShaderInputBufferIndex tlasConstantIndex;
  446. FindShaderInputIndex(&tlasConstantIndex, m_globalSrg, AZ::Name{ "m_scene" }, RayTracingExampleName);
  447. uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(m_rayTracingTlas->GetTlasBuffer()->GetDescriptor().m_byteCount);
  448. RHI::BufferViewDescriptor bufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount);
  449. m_globalSrg->SetBufferView(tlasConstantIndex, m_rayTracingTlas->GetTlasBuffer()->GetBufferView(bufferViewDescriptor).get());
  450. RHI::ShaderInputImageIndex outputConstantIndex;
  451. FindShaderInputIndex(&outputConstantIndex, m_globalSrg, AZ::Name{ "m_output" }, RayTracingExampleName);
  452. m_globalSrg->SetImageView(outputConstantIndex, m_outputImageView.get());
  453. // set hit shader data, each array element corresponds to the InstanceIndex() of the geometry in the TLAS
  454. // Note: this method is used instead of LocalRootSignatures for compatibility with non-DX12 platforms
  455. // set HitGradient values
  456. RHI::ShaderInputConstantIndex hitGradientDataConstantIndex;
  457. FindShaderInputIndex(&hitGradientDataConstantIndex, m_globalSrg, AZ::Name{"m_hitGradientData"}, RayTracingExampleName);
  458. struct HitGradientData
  459. {
  460. AZ::Vector4 m_color;
  461. };
  462. AZStd::array<HitGradientData, 4> hitGradientData = {{
  463. {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f)}, // triangle1
  464. {AZ::Vector4(0.0f, 1.0f, 0.0f, 1.0f)}, // triangle2
  465. {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
  466. {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
  467. }};
  468. m_globalSrg->SetConstantArray(hitGradientDataConstantIndex, hitGradientData);
  469. // set HitSolid values
  470. RHI::ShaderInputConstantIndex hitSolidDataConstantIndex;
  471. FindShaderInputIndex(&hitSolidDataConstantIndex, m_globalSrg, AZ::Name{"m_hitSolidData"}, RayTracingExampleName);
  472. struct HitSolidData
  473. {
  474. AZ::Vector4 m_color1;
  475. float m_lerp;
  476. float m_pad[3];
  477. AZ::Vector4 m_color2;
  478. };
  479. AZStd::array<HitSolidData, 4> hitSolidData = {{
  480. {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f), 0.0f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
  481. {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f), 0.0f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
  482. {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f), 0.5f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 1.0f, 0.0f, 1.0f)}, // triangle3
  483. {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f), 0.5f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 0.0f, 1.0f, 1.0f)}, // rectangle
  484. }};
  485. m_globalSrg->SetConstantArray(hitSolidDataConstantIndex, hitSolidData);
  486. m_globalSrg->Compile();
  487. }
  488. };
  489. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  490. {
  491. if (!m_rayTracingTlas->GetTlasBuffer())
  492. {
  493. return;
  494. }
  495. RHI::CommandList* commandList = context.GetCommandList();
  496. RHI::DispatchRaysItem dispatchRaysItem;
  497. dispatchRaysItem.m_width = m_imageWidth;
  498. dispatchRaysItem.m_height = m_imageHeight;
  499. dispatchRaysItem.m_depth = 1;
  500. dispatchRaysItem.m_rayTracingPipelineState = m_rayTracingPipelineState.get();
  501. dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable.get();
  502. dispatchRaysItem.m_globalSrg = m_globalSrg->GetRHIShaderResourceGroup();
  503. dispatchRaysItem.m_globalPipelineState = m_globalPipelineState.get();
  504. // submit the DispatchRays item
  505. commandList->Submit(dispatchRaysItem);
  506. };
  507. m_scopeProducers.emplace_back(
  508. aznew RHI::ScopeProducerFunction<
  509. ScopeData,
  510. decltype(prepareFunction),
  511. decltype(compileFunction),
  512. decltype(executeFunction)>(
  513. RHI::ScopeId{ "RayTracingDispatch" },
  514. ScopeData{},
  515. prepareFunction,
  516. compileFunction,
  517. executeFunction));
  518. }
  519. void RayTracingExampleComponent::CreateRasterScope()
  520. {
  521. struct ScopeData
  522. {
  523. };
  524. const auto prepareFunction = [this](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  525. {
  526. // attach swapchain
  527. {
  528. RHI::ImageScopeAttachmentDescriptor descriptor;
  529. descriptor.m_attachmentId = m_outputAttachmentId;
  530. descriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::DontCare;
  531. frameGraph.UseColorAttachment(descriptor);
  532. }
  533. // attach output buffer
  534. {
  535. RHI::ImageScopeAttachmentDescriptor desc;
  536. desc.m_attachmentId = m_outputImageAttachmentId;
  537. desc.m_imageViewDescriptor = m_outputImageViewDescriptor;
  538. desc.m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(0.0f, 0.0f, 0.0f, 0.0f);
  539. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite);
  540. const Name outputImageId{ "m_output" };
  541. RHI::ShaderInputImageIndex outputImageIndex = m_drawSRG->FindShaderInputImageIndex(outputImageId);
  542. AZ_Error(RayTracingExampleName, outputImageIndex.IsValid(), "Failed to find shader input image %s.", outputImageId.GetCStr());
  543. m_drawSRG->SetImageView(outputImageIndex, m_outputImageView.get());
  544. m_drawSRG->Compile();
  545. }
  546. frameGraph.SetEstimatedItemCount(1);
  547. };
  548. RHI::EmptyCompileFunction<ScopeData> compileFunction;
  549. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  550. {
  551. RHI::CommandList* commandList = context.GetCommandList();
  552. commandList->SetViewports(&m_viewport, 1);
  553. commandList->SetScissors(&m_scissor, 1);
  554. RHI::DrawIndexed drawIndexed;
  555. drawIndexed.m_indexCount = 6;
  556. drawIndexed.m_instanceCount = 1;
  557. const RHI::ShaderResourceGroup* shaderResourceGroups[] = { m_drawSRG->GetRHIShaderResourceGroup() };
  558. RHI::DrawItem drawItem;
  559. drawItem.m_arguments = drawIndexed;
  560. drawItem.m_pipelineState = m_drawPipelineState.get();
  561. drawItem.m_indexBufferView = &m_fullScreenIndexBufferView;
  562. drawItem.m_streamBufferViewCount = static_cast<uint32_t>(m_fullScreenStreamBufferViews.size());
  563. drawItem.m_streamBufferViews = m_fullScreenStreamBufferViews.data();
  564. drawItem.m_shaderResourceGroupCount = RHI::ArraySize(shaderResourceGroups);
  565. drawItem.m_shaderResourceGroups = shaderResourceGroups;
  566. // submit the triangle draw item.
  567. commandList->Submit(drawItem);
  568. };
  569. m_scopeProducers.emplace_back(
  570. aznew RHI::ScopeProducerFunction<
  571. ScopeData,
  572. decltype(prepareFunction),
  573. decltype(compileFunction),
  574. decltype(executeFunction)>(
  575. RHI::ScopeId{ "Raster" },
  576. ScopeData{},
  577. prepareFunction,
  578. compileFunction,
  579. executeFunction));
  580. }
  581. void RayTracingExampleComponent::Activate()
  582. {
  583. CreateResourcePools();
  584. CreateGeometry();
  585. CreateFullScreenBuffer();
  586. CreateOutputTexture();
  587. CreateRasterShader();
  588. CreateRayTracingAccelerationStructureObjects();
  589. CreateRayTracingPipelineState();
  590. CreateRayTracingShaderTableScope();
  591. CreateRayTracingAccelerationTableScope();
  592. CreateRayTracingDispatchScope();
  593. CreateRasterScope();
  594. RHI::RHISystemNotificationBus::Handler::BusConnect();
  595. }
  596. void RayTracingExampleComponent::Deactivate()
  597. {
  598. RHI::RHISystemNotificationBus::Handler::BusDisconnect();
  599. m_windowContext = nullptr;
  600. m_scopeProducers.clear();
  601. }
  602. } // namespace AtomSampleViewer