RayTracingClusterExampleComponent.cpp 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <RHI/RayTracingClusterExampleComponent.h>
  9. #include <Utils/Utils.h>
  10. #include <SampleComponentManager.h>
  11. #include <Atom/RHI/CommandList.h>
  12. #include <Atom/RHI/FrameGraphInterface.h>
  13. #include <Atom/RHI/RayTracingPipelineState.h>
  14. #include <Atom/RHI/RayTracingShaderTable.h>
  15. #include <Atom/RHI.Reflect/InputStreamLayoutBuilder.h>
  16. #include <Atom/RHI.Reflect/RenderAttachmentLayoutBuilder.h>
  17. #include <Atom/RPI.Public/Shader/Shader.h>
  18. #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
  19. #include <AzCore/Serialization/SerializeContext.h>
  namespace
  {
      //! Sample name passed to LoadShader/CreateShaderResourceGroup and used as the AZ_Error window.
      //! constexpr makes the pointer itself immutable; `static` was redundant inside an anonymous namespace.
      constexpr const char* SampleName = "RayTracingClusterExample";
  }
  24. namespace AtomSampleViewer
  25. {
  26. void RayTracingClusterExampleComponent::Reflect(AZ::ReflectContext* context)
  27. {
  28. if (auto* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  29. {
  30. serializeContext->Class<RayTracingClusterExampleComponent, AZ::Component>()
  31. ->Version(0)
  32. ;
  33. }
  34. }
  35. RayTracingClusterExampleComponent::RayTracingClusterExampleComponent()
  36. {
  37. m_supportRHISamplePipeline = true;
  38. }
  void RayTracingClusterExampleComponent::Activate()
  {
      // Creates all GPU resources, pipelines and frame-graph scopes for the sample, then starts listening to
      // RHI system notifications.
      // The call order matters: later steps consume what earlier ones produced. For example, CreateGeometry and
      // CreateFullScreenBuffer allocate from m_inputAssemblyBufferPool (made in CreateResourcePools), and
      // CreateRasterShader uses m_fullScreenInputStreamLayout (built in CreateFullScreenBuffer).
      CreateResourcePools();
      CreateGeometry();
      CreateFullScreenBuffer();
      CreateOutputTexture();
      CreateRasterShader();
      CreateRayTracingAccelerationStructureObjects();
      CreateRayTracingPipelineState();
      CreateRayTracingShaderTable();
      // Scope creation comes after every resource the scopes reference has been created.
      CreateRayTracingAccelerationTableScope();
      CreateRayTracingDispatchScope();
      CreateRasterScope();
      // Connect last, once resources and scopes exist.
      RHI::RHISystemNotificationBus::Handler::BusConnect();
  }
  54. void RayTracingClusterExampleComponent::Deactivate()
  55. {
  56. RHI::RHISystemNotificationBus::Handler::BusDisconnect();
  57. m_windowContext = nullptr;
  58. m_scopeProducers.clear();
  59. }
  60. void RayTracingClusterExampleComponent::CreateResourcePools()
  61. {
  62. // create input assembly buffer pool
  63. {
  64. m_inputAssemblyBufferPool = aznew RHI::BufferPool();
  65. RHI::BufferPoolDescriptor bufferPoolDesc;
  66. bufferPoolDesc.m_bindFlags = RHI::BufferBindFlags::InputAssembly;
  67. bufferPoolDesc.m_heapMemoryLevel = RHI::HeapMemoryLevel::Host;
  68. [[maybe_unused]] RHI::ResultCode resultCode = m_inputAssemblyBufferPool->Init(bufferPoolDesc);
  69. AZ_Assert(resultCode == RHI::ResultCode::Success, "Failed to initialize input assembly buffer pool");
  70. }
  71. // create output image pool
  72. {
  73. RHI::ImagePoolDescriptor imagePoolDesc;
  74. imagePoolDesc.m_bindFlags = RHI::ImageBindFlags::ShaderReadWrite;
  75. m_imagePool = aznew RHI::ImagePool();
  76. [[maybe_unused]] RHI::ResultCode result = m_imagePool->Init(imagePoolDesc);
  77. AZ_Assert(result == RHI::ResultCode::Success, "Failed to initialize output image pool");
  78. }
  79. // initialize ray tracing buffer pools
  80. m_rayTracingBufferPools = aznew RHI::RayTracingBufferPools;
  81. m_rayTracingBufferPools->Init(RHI::MultiDevice::AllDevices);
  82. }
  83. void RayTracingClusterExampleComponent::CreateGeometry()
  84. {
  85. // triangle
  86. {
  87. // vertex buffer
  88. SetVertexPosition(m_triangleVertices.data(), 0, 0.0f, 0.5f, 1.0);
  89. SetVertexPosition(m_triangleVertices.data(), 1, 0.5f, -0.5f, 1.0);
  90. SetVertexPosition(m_triangleVertices.data(), 2, -0.5f, -0.5f, 1.0);
  91. m_triangleVB = aznew RHI::Buffer();
  92. m_triangleVB->SetName(AZ::Name("Triangle VB"));
  93. RHI::BufferInitRequest request;
  94. request.m_buffer = m_triangleVB.get();
  95. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_triangleVertices) };
  96. request.m_initialData = m_triangleVertices.data();
  97. m_inputAssemblyBufferPool->InitBuffer(request);
  98. // index buffer
  99. SetVertexIndexIncreasing(m_triangleIndices.data(), m_triangleIndices.size());
  100. m_triangleIB = aznew RHI::Buffer();
  101. m_triangleIB->SetName(AZ::Name("Triangle IB"));
  102. request.m_buffer = m_triangleIB.get();
  103. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_triangleIndices) };
  104. request.m_initialData = m_triangleIndices.data();
  105. m_inputAssemblyBufferPool->InitBuffer(request);
  106. }
  107. // rectangle
  108. {
  109. // vertex buffer
  110. SetVertexPosition(m_rectangleVertices.data(), 0, -0.5f, 0.5f, 1.0);
  111. SetVertexPosition(m_rectangleVertices.data(), 1, 0.5f, 0.5f, 1.0);
  112. SetVertexPosition(m_rectangleVertices.data(), 2, 0.5f, -0.5f, 1.0);
  113. SetVertexPosition(m_rectangleVertices.data(), 3, -0.5f, -0.5f, 1.0);
  114. m_rectangleVB = aznew RHI::Buffer();
  115. m_rectangleVB->SetName(AZ::Name("Rectangle VB"));
  116. RHI::BufferInitRequest request;
  117. request.m_buffer = m_rectangleVB.get();
  118. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_rectangleVertices) };
  119. request.m_initialData = m_rectangleVertices.data();
  120. m_inputAssemblyBufferPool->InitBuffer(request);
  121. // index buffer
  122. m_rectangleIndices[0] = 0;
  123. m_rectangleIndices[1] = 1;
  124. m_rectangleIndices[2] = 2;
  125. m_rectangleIndices[3] = 0;
  126. m_rectangleIndices[4] = 2;
  127. m_rectangleIndices[5] = 3;
  128. m_rectangleIB = aznew RHI::Buffer();
  129. m_rectangleIB->SetName(AZ::Name("Rectangle IB"));
  130. request.m_buffer = m_rectangleIB.get();
  131. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(m_rectangleIndices) };
  132. request.m_initialData = m_rectangleIndices.data();
  133. m_inputAssemblyBufferPool->InitBuffer(request);
  134. }
  135. // clusters
  136. {
  137. AZStd::vector<ClusterVertexType> clusterVertices;
  138. AZStd::vector<ClusterIndexType> clusterTriangles;
  139. RHI::RayTracingClasBuildTriangleClusterInfoExpanded commonClusterInfo;
  140. commonClusterInfo.m_clusterFlags = RHI::RayTracingClasClusterFlags::AllowDisableOpacityMicromaps;
  141. commonClusterInfo.m_positionTruncateBitCount = 0;
  142. commonClusterInfo.m_geometryFlags = RHI::RayTracingClasGeometryFlags::Opaque;
  143. commonClusterInfo.m_indexType = RHI::RayTracingClasIndexFormat::UINT32;
  144. commonClusterInfo.m_indexBufferStride = 0;
  145. commonClusterInfo.m_vertexBufferStride = 0;
  146. commonClusterInfo.m_geometryIndexAndFlagsBufferStride = 0;
  147. commonClusterInfo.m_opacityMicromapIndexBufferStride = 0;
  148. commonClusterInfo.m_opacityMicromapArrayAddress = 0;
  149. commonClusterInfo.m_opacityMicromapIndexBufferAddress = 0;
  150. commonClusterInfo.m_geometryIndexAndFlagsBufferAddress = 0;
  151. // Cluster 1: Quad with size 2x1 centered at (0,0)
  152. {
  153. auto& quadClusterInfo = m_clusterSourceInfosExpanded.emplace_back(commonClusterInfo);
  154. quadClusterInfo.m_clusterID = 0;
  155. quadClusterInfo.m_vertexCount = 4;
  156. quadClusterInfo.m_triangleCount = 2;
  157. quadClusterInfo.m_baseGeometryIndex = 0;
  158. clusterVertices.emplace_back(-1.f, -0.5f, 1.f);
  159. clusterVertices.emplace_back(-1.f, 0.5f, 1.f);
  160. clusterVertices.emplace_back(1.f, 0.5f, 1.f);
  161. clusterVertices.emplace_back(1.f, -0.5f, 1.f);
  162. clusterTriangles.emplace_back(0, 1, 2);
  163. clusterTriangles.emplace_back(0, 2, 3);
  164. }
  165. // Cluster 2: Regular pentagon with radius 1 centered at (3,0) and pointing upwards
  166. {
  167. auto& pentagonClusterInfo = m_clusterSourceInfosExpanded.emplace_back(commonClusterInfo);
  168. pentagonClusterInfo.m_clusterID = 1;
  169. pentagonClusterInfo.m_vertexCount = 5;
  170. pentagonClusterInfo.m_triangleCount = 3;
  171. pentagonClusterInfo.m_baseGeometryIndex = 1;
  172. clusterVertices.emplace_back(3.f + 0.f, 1.f, 1.f);
  173. clusterVertices.emplace_back(3.f + 0.951f, 0.309f, 1.f);
  174. clusterVertices.emplace_back(3.f + 0.588f, -0.809f, 1.f);
  175. clusterVertices.emplace_back(3.f - 0.588f, -0.809f, 1.f);
  176. clusterVertices.emplace_back(3.f - 0.951f, 0.309f, 1.f);
  177. clusterTriangles.emplace_back(4, 5, 6);
  178. clusterTriangles.emplace_back(4, 6, 7);
  179. clusterTriangles.emplace_back(4, 7, 8);
  180. }
  181. // Cluster 3: The text "CLAS" written in the pixel font "CG pixel 4x5" (16 rectangles -> 64 vertices, 32 triangles)
  182. // Font source: https://fontstruct.com/fontstructions/show/1404171/cg-pixel-4x5 (License: Public domain)
  183. // 0 5 9 14
  184. // 4 ## # ## ###
  185. // 3 # # # # # #
  186. // 2 # # #### ##
  187. // 1 # # # # # #
  188. // 0 ## ### # # ###
  189. {
  190. auto& textClusterInfo = m_clusterSourceInfosExpanded.emplace_back(commonClusterInfo);
  191. textClusterInfo.m_clusterID = 2;
  192. textClusterInfo.m_vertexCount = 64;
  193. textClusterInfo.m_triangleCount = 32;
  194. textClusterInfo.m_baseGeometryIndex = 2;
  195. auto AddRectangle = [&](int gridX, int gridY, int gridWidth, int gridHeight)
  196. {
  197. float scale = 0.2f;
  198. float x = static_cast<float>(gridX) * scale - 2.f;
  199. float y = static_cast<float>(gridY) * scale + 1.5f;
  200. float width = static_cast<float>(gridWidth) * scale;
  201. float height = static_cast<float>(gridHeight) * scale;
  202. uint32_t textIndexOffset{ aznumeric_cast<uint32_t>(clusterVertices.size()) };
  203. clusterVertices.emplace_back(x, y, 1.f);
  204. clusterVertices.emplace_back(x, y + height, 1.f);
  205. clusterVertices.emplace_back(x + width, y + height, 1.f);
  206. clusterVertices.emplace_back(x + width, y, 1.f);
  207. clusterTriangles.emplace_back(textIndexOffset, textIndexOffset + 1, textIndexOffset + 2);
  208. clusterTriangles.emplace_back(textIndexOffset, textIndexOffset + 2, textIndexOffset + 3);
  209. };
  210. // Letter "C"
  211. AddRectangle(3, 1, 1, 1);
  212. AddRectangle(1, 0, 2, 1);
  213. AddRectangle(0, 1, 1, 3);
  214. AddRectangle(1, 4, 2, 1);
  215. AddRectangle(3, 3, 1, 1);
  216. // Letter "L"
  217. AddRectangle(5, 1, 1, 4);
  218. AddRectangle(5, 0, 3, 1);
  219. // Letter "A"
  220. AddRectangle(9, 0, 1, 4);
  221. AddRectangle(10, 4, 2, 1);
  222. AddRectangle(12, 0, 1, 4);
  223. AddRectangle(10, 2, 2, 1);
  224. // Letter "S"
  225. AddRectangle(15, 4, 3, 1);
  226. AddRectangle(14, 3, 1, 1);
  227. AddRectangle(15, 2, 2, 1);
  228. AddRectangle(17, 1, 1, 1);
  229. AddRectangle(14, 0, 3, 1);
  230. }
  231. // Create cluster vertex buffer
  232. {
  233. m_clusterVertexBuffer = aznew RHI::Buffer();
  234. m_clusterVertexBuffer->SetName(Name("Cluster vertex buffer"));
  235. RHI::BufferInitRequest request;
  236. request.m_buffer = m_clusterVertexBuffer.get();
  237. request.m_descriptor.m_byteCount = clusterVertices.size() * sizeof(ClusterVertexType);
  238. request.m_descriptor.m_bindFlags = RHI::BufferBindFlags::InputAssembly;
  239. request.m_initialData = clusterVertices.data();
  240. m_inputAssemblyBufferPool->InitBuffer(request);
  241. }
  242. // Create cluster index buffer
  243. {
  244. m_clusterIndexBuffer = aznew RHI::Buffer();
  245. m_clusterIndexBuffer->SetName(Name("Cluster index buffer"));
  246. RHI::BufferInitRequest request;
  247. request.m_buffer = m_clusterIndexBuffer.get();
  248. request.m_descriptor.m_byteCount = clusterTriangles.size() * sizeof(ClusterIndexType);
  249. request.m_descriptor.m_bindFlags = RHI::BufferBindFlags::InputAssembly;
  250. request.m_initialData = clusterTriangles.data();
  251. m_inputAssemblyBufferPool->InitBuffer(request);
  252. }
  253. // Calculate upper bound data for CLAS
  254. for (const auto& clusterSourceInfoExpanded : m_clusterSourceInfosExpanded)
  255. {
  256. m_maxClusterTriangleCount = AZStd::max(m_maxClusterTriangleCount, clusterSourceInfoExpanded.m_triangleCount);
  257. m_maxClusterVertexCount = AZStd::max(m_maxClusterVertexCount, clusterSourceInfoExpanded.m_vertexCount);
  258. m_maxGeometryIndex = AZStd::max(m_maxGeometryIndex, clusterSourceInfoExpanded.m_baseGeometryIndex);
  259. }
  260. m_maxTotalTriangleCount = aznumeric_cast<uint32_t>(clusterTriangles.size());
  261. m_maxTotalVertexCount = aznumeric_cast<uint32_t>(clusterVertices.size());
  262. // Create srcInfosArray buffer
  263. {
  264. m_srcInfosArrayBuffer = aznew RHI::Buffer();
  265. m_srcInfosArrayBuffer->SetName(Name("Source infos array buffer"));
  266. RHI::BufferInitRequest request;
  267. request.m_buffer = m_srcInfosArrayBuffer.get();
  268. request.m_descriptor.m_byteCount =
  269. m_clusterSourceInfosExpanded.size() * sizeof(RHI::RayTracingClasBuildTriangleClusterInfo);
  270. request.m_descriptor.m_bindFlags = m_rayTracingBufferPools->GetSrcInfosArrayBufferPool()->GetDescriptor().m_bindFlags;
  271. // Buffer data is populated in CreateRayTracingAccelerationTableScope
  272. m_rayTracingBufferPools->GetSrcInfosArrayBufferPool()->InitBuffer(request);
  273. }
  274. }
  275. }
  276. void RayTracingClusterExampleComponent::CreateFullScreenBuffer()
  277. {
  278. FullScreenBufferData bufferData;
  279. SetFullScreenRect(bufferData.m_positions.data(), bufferData.m_uvs.data(), bufferData.m_indices.data());
  280. m_fullScreenInputAssemblyBuffer = aznew RHI::Buffer();
  281. RHI::BufferInitRequest request;
  282. request.m_buffer = m_fullScreenInputAssemblyBuffer.get();
  283. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(bufferData) };
  284. request.m_initialData = &bufferData;
  285. m_inputAssemblyBufferPool->InitBuffer(request);
  286. m_geometryView.SetDrawArguments(RHI::DrawIndexed(0, 6, 0));
  287. m_geometryView.AddStreamBufferView({
  288. *m_fullScreenInputAssemblyBuffer,
  289. offsetof(FullScreenBufferData, m_positions),
  290. sizeof(FullScreenBufferData::m_positions),
  291. sizeof(VertexPosition)
  292. });
  293. m_geometryView.AddStreamBufferView({
  294. *m_fullScreenInputAssemblyBuffer,
  295. offsetof(FullScreenBufferData, m_uvs),
  296. sizeof(FullScreenBufferData::m_uvs),
  297. sizeof(VertexUV)
  298. });
  299. m_geometryView.SetIndexBufferView({
  300. *m_fullScreenInputAssemblyBuffer,
  301. offsetof(FullScreenBufferData, m_indices),
  302. sizeof(FullScreenBufferData::m_indices),
  303. RHI::IndexFormat::Uint16
  304. });
  305. RHI::InputStreamLayoutBuilder layoutBuilder;
  306. layoutBuilder.AddBuffer()->Channel("POSITION", RHI::Format::R32G32B32_FLOAT);
  307. layoutBuilder.AddBuffer()->Channel("UV", RHI::Format::R32G32_FLOAT);
  308. m_fullScreenInputStreamLayout = layoutBuilder.End();
  309. }
  310. void RayTracingClusterExampleComponent::CreateOutputTexture()
  311. {
  312. // create output image
  313. m_outputImage = aznew RHI::Image();
  314. RHI::ImageInitRequest request;
  315. request.m_image = m_outputImage.get();
  316. request.m_descriptor = RHI::ImageDescriptor::Create2D(RHI::ImageBindFlags::ShaderReadWrite, m_imageWidth, m_imageHeight, RHI::Format::R8G8B8A8_UNORM);
  317. [[maybe_unused]] RHI::ResultCode result = m_imagePool->InitImage(request);
  318. AZ_Assert(result == RHI::ResultCode::Success, "Failed to initialize output image");
  319. m_outputImageViewDescriptor = RHI::ImageViewDescriptor::Create(RHI::Format::R8G8B8A8_UNORM, 0, 0);
  320. m_outputImageView = m_outputImage->GetImageView(m_outputImageViewDescriptor);
  321. AZ_Assert(m_outputImageView.get(), "Failed to create output image view");
  322. AZ_Assert(m_outputImageView->GetDeviceImageView(RHI::MultiDevice::DefaultDeviceIndex)->IsFullView(), "Image View initialization IsFullView() failed");
  323. }
  324. void RayTracingClusterExampleComponent::CreateRayTracingAccelerationStructureObjects()
  325. {
  326. m_triangleRayTracingBlas = aznew AZ::RHI::RayTracingBlas;
  327. m_rectangleRayTracingBlas = aznew AZ::RHI::RayTracingBlas;
  328. m_clusterRayTracingBlas = aznew AZ::RHI::RayTracingClusterBlas;
  329. m_rayTracingTlas = aznew AZ::RHI::RayTracingTlas;
  330. }
  331. void RayTracingClusterExampleComponent::CreateRasterShader()
  332. {
  333. const char* shaderFilePath = "Shaders/RHI/RayTracingDraw.azshader";
  334. auto drawShader = LoadShader(shaderFilePath, SampleName);
  335. AZ_Assert(drawShader, "Failed to load Draw shader");
  336. RHI::PipelineStateDescriptorForDraw pipelineDesc;
  337. drawShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId).ConfigurePipelineState(pipelineDesc);
  338. pipelineDesc.m_inputStreamLayout = m_fullScreenInputStreamLayout;
  339. RHI::RenderAttachmentLayoutBuilder attachmentsBuilder;
  340. attachmentsBuilder.AddSubpass()->RenderTargetAttachment(m_outputFormat);
  341. [[maybe_unused]] RHI::ResultCode result = attachmentsBuilder.End(pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout);
  342. AZ_Assert(result == RHI::ResultCode::Success, "Failed to create draw render attachment layout");
  343. m_drawPipelineState = drawShader->AcquirePipelineState(pipelineDesc);
  344. AZ_Assert(m_drawPipelineState, "Failed to acquire draw pipeline state");
  345. m_drawSRG = CreateShaderResourceGroup(drawShader, "BufferSrg", SampleName);
  346. }
  347. void RayTracingClusterExampleComponent::CreateRayTracingPipelineState()
  348. {
  349. // load ray generation shader
  350. const char* rayGenerationShaderFilePath = "Shaders/RHI/RayTracingDispatch.azshader";
  351. m_rayGenerationShader = LoadShader(rayGenerationShaderFilePath, SampleName);
  352. AZ_Assert(m_rayGenerationShader, "Failed to load ray generation shader");
  353. auto rayGenerationShaderVariant = m_rayGenerationShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  354. RHI::PipelineStateDescriptorForRayTracing rayGenerationShaderDescriptor;
  355. rayGenerationShaderVariant.ConfigurePipelineState(rayGenerationShaderDescriptor);
  356. // load miss shader
  357. const char* missShaderFilePath = "Shaders/RHI/RayTracingMiss.azshader";
  358. m_missShader = LoadShader(missShaderFilePath, SampleName);
  359. AZ_Assert(m_missShader, "Failed to load miss shader");
  360. auto missShaderVariant = m_missShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  361. RHI::PipelineStateDescriptorForRayTracing missShaderDescriptor;
  362. missShaderVariant.ConfigurePipelineState(missShaderDescriptor);
  363. // load closest hit gradient shader
  364. const char* closestHitGradientShaderFilePath = "Shaders/RHI/RayTracingClosestHitGradient.azshader";
  365. m_closestHitGradientShader = LoadShader(closestHitGradientShaderFilePath, SampleName);
  366. AZ_Assert(m_closestHitGradientShader, "Failed to load closest hit gradient shader");
  367. auto closestHitGradientShaderVariant = m_closestHitGradientShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  368. RHI::PipelineStateDescriptorForRayTracing closestHitGradiantShaderDescriptor;
  369. closestHitGradientShaderVariant.ConfigurePipelineState(closestHitGradiantShaderDescriptor);
  370. // load closest hit solid shader
  371. const char* closestHitSolidShaderFilePath = "Shaders/RHI/RayTracingClosestHitSolid.azshader";
  372. m_closestHitSolidShader = LoadShader(closestHitSolidShaderFilePath, SampleName);
  373. AZ_Assert(m_closestHitSolidShader, "Failed to load closest hit solid shader");
  374. auto closestHitSolidShaderVariant = m_closestHitSolidShader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId);
  375. RHI::PipelineStateDescriptorForRayTracing closestHitSolidShaderDescriptor;
  376. closestHitSolidShaderVariant.ConfigurePipelineState(closestHitSolidShaderDescriptor);
  377. // global pipeline state and srg
  378. m_globalPipelineState = m_rayGenerationShader->AcquirePipelineState(rayGenerationShaderDescriptor);
  379. AZ_Assert(m_globalPipelineState, "Failed to acquire ray tracing global pipeline state");
  380. m_globalSrg = CreateShaderResourceGroup(m_rayGenerationShader, "RayTracingGlobalSrg", SampleName);
  381. // build the ray tracing pipeline state descriptor
  382. RHI::RayTracingPipelineStateDescriptor descriptor;
  383. descriptor.m_pipelineState = m_globalPipelineState.get();
  384. descriptor.AddRayGenerationShaderLibrary(rayGenerationShaderDescriptor, Name("RayGenerationShader"));
  385. descriptor.AddMissShaderLibrary(missShaderDescriptor, Name("MissShader"));
  386. descriptor.AddClosestHitShaderLibrary(closestHitGradiantShaderDescriptor, Name("ClosestHitGradientShader"));
  387. descriptor.AddClosestHitShaderLibrary(closestHitSolidShaderDescriptor, Name("ClosestHitSolidShader"));
  388. descriptor.AddHitGroup(Name("HitGroupGradient"), Name("ClosestHitGradientShader"));
  389. descriptor.AddHitGroup(Name("HitGroupSolid"), Name("ClosestHitSolidShader"));
  390. // create the ray tracing pipeline state object
  391. m_rayTracingPipelineState = aznew RHI::RayTracingPipelineState;
  392. m_rayTracingPipelineState->Init(RHI::MultiDevice::AllDevices, descriptor);
  393. }
  394. void RayTracingClusterExampleComponent::CreateRayTracingShaderTable()
  395. {
  396. m_rayTracingShaderTable = aznew RHI::RayTracingShaderTable;
  397. m_rayTracingShaderTable->Init(RHI::MultiDevice::AllDevices, *m_rayTracingBufferPools);
  398. }
  399. void RayTracingClusterExampleComponent::CreateRayTracingAccelerationTableScope()
  400. {
  401. struct ScopeData
  402. {
  403. };
  404. const auto prepareFunction = [this]([[maybe_unused]] RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  405. {
  406. // create triangle BLAS buffer if necessary
  407. if (!m_triangleRayTracingBlas->IsValid())
  408. {
  409. RHI::StreamBufferView triangleVertexBufferView =
  410. {
  411. *m_triangleVB,
  412. 0,
  413. sizeof(m_triangleVertices),
  414. sizeof(VertexPosition)
  415. };
  416. RHI::IndexBufferView triangleIndexBufferView =
  417. {
  418. *m_triangleIB,
  419. 0,
  420. sizeof(m_triangleIndices),
  421. RHI::IndexFormat::Uint16
  422. };
  423. RHI::RayTracingBlasDescriptor triangleBlasDescriptor;
  424. RHI::RayTracingGeometry& triangleBlasGeometry = triangleBlasDescriptor.m_geometries.emplace_back();
  425. triangleBlasGeometry.m_vertexFormat = RHI::VertexFormat::R32G32B32_FLOAT;
  426. triangleBlasGeometry.m_vertexBuffer = triangleVertexBufferView;
  427. triangleBlasGeometry.m_indexBuffer = triangleIndexBufferView;
  428. m_triangleRayTracingBlas->CreateBuffers(RHI::MultiDevice::AllDevices, &triangleBlasDescriptor, *m_rayTracingBufferPools);
  429. }
  430. // create rectangle BLAS if necessary
  431. if (!m_rectangleRayTracingBlas->IsValid())
  432. {
  433. RHI::StreamBufferView rectangleVertexBufferView =
  434. {
  435. *m_rectangleVB,
  436. 0,
  437. sizeof(m_rectangleVertices),
  438. sizeof(VertexPosition)
  439. };
  440. RHI::IndexBufferView rectangleIndexBufferView =
  441. {
  442. *m_rectangleIB,
  443. 0,
  444. sizeof(m_rectangleIndices),
  445. RHI::IndexFormat::Uint16
  446. };
  447. RHI::RayTracingBlasDescriptor rectangleBlasDescriptor;
  448. RHI::RayTracingGeometry& rectangleBlasGeometry = rectangleBlasDescriptor.m_geometries.emplace_back();
  449. rectangleBlasGeometry.m_vertexFormat = RHI::VertexFormat::R32G32B32_FLOAT;
  450. rectangleBlasGeometry.m_vertexBuffer = rectangleVertexBufferView;
  451. rectangleBlasGeometry.m_indexBuffer = rectangleIndexBufferView;
  452. m_rectangleRayTracingBlas->CreateBuffers(RHI::MultiDevice::AllDevices, &rectangleBlasDescriptor, *m_rayTracingBufferPools);
  453. }
  454. if (!m_clusterRayTracingBlasInitialized)
  455. {
  456. m_clusterRayTracingBlasInitialized = true;
  457. RHI::RayTracingClusterBlasDescriptor clusterBlasDescriptor;
  458. clusterBlasDescriptor.m_vertexFormat = AZ::RHI::VertexFormat::R32G32B32_FLOAT;
  459. clusterBlasDescriptor.m_maxGeometryIndexValue = m_maxGeometryIndex;
  460. clusterBlasDescriptor.m_maxClusterUniqueGeometryCount = 1;
  461. clusterBlasDescriptor.m_maxClusterTriangleCount = m_maxClusterTriangleCount;
  462. clusterBlasDescriptor.m_maxClusterVertexCount = m_maxClusterVertexCount;
  463. clusterBlasDescriptor.m_maxTotalTriangleCount = m_maxTotalTriangleCount;
  464. clusterBlasDescriptor.m_maxTotalVertexCount = m_maxTotalVertexCount;
  465. clusterBlasDescriptor.m_minPositionTruncateBitCount = 0;
  466. clusterBlasDescriptor.m_maxClusterCount = aznumeric_cast<uint32_t>(m_clusterSourceInfosExpanded.size());
  467. clusterBlasDescriptor.m_srcInfosArrayBufferView = m_srcInfosArrayBuffer->GetBufferView(
  468. RHI::BufferViewDescriptor::CreateStructured(
  469. 0, aznumeric_cast<uint32_t>(m_clusterSourceInfosExpanded.size()),
  470. sizeof(RHI::RayTracingClasBuildTriangleClusterInfo)));
  471. m_clusterRayTracingBlas->CreateBuffers(RHI::MultiDevice::AllDevices, &clusterBlasDescriptor, *m_rayTracingBufferPools);
  472. m_clusterRayTracingBlas->IterateDevices(
  473. RHI::MultiDevice::AllDevices,
  474. [&](int deviceIndex)
  475. {
  476. auto deviceClusterBuffers = m_clusterRayTracingBlas->GetDeviceRayTracingClusterBlas(deviceIndex);
  477. auto deviceBufferPool = m_rayTracingBufferPools->GetSrcInfosArrayBufferPool()->GetDeviceBufferPool(deviceIndex);
  478. uint64_t deviceVertexBufferAddress = m_clusterVertexBuffer->GetDeviceBuffer(deviceIndex)->GetDeviceAddress();
  479. uint64_t deviceIndexBufferAddress = m_clusterIndexBuffer->GetDeviceBuffer(deviceIndex)->GetDeviceAddress();
  480. RHI::DeviceBufferMapRequest request;
  481. request.m_buffer = m_srcInfosArrayBuffer->GetDeviceBuffer(deviceIndex).get();
  482. request.m_byteCount = m_clusterSourceInfosExpanded.size() * sizeof(RHI::RayTracingClasBuildTriangleClusterInfo);
  483. request.m_byteOffset = 0;
  484. RHI::DeviceBufferMapResponse response;
  485. RHI::ResultCode result = deviceBufferPool->MapBuffer(request, response);
  486. AZ_Assert(result == AZ::RHI::ResultCode::Success, "Failed to map SrcInfosArrayBuffer");
  487. auto* gpuClusterInfo = reinterpret_cast<RHI::RayTracingClasBuildTriangleClusterInfo*>(response.m_data);
  488. for (auto clusterSourceInfoExpanded : m_clusterSourceInfosExpanded)
  489. {
  490. clusterSourceInfoExpanded.m_vertexBufferAddress = deviceVertexBufferAddress;
  491. clusterSourceInfoExpanded.m_indexBufferAddress = deviceIndexBufferAddress;
  492. *gpuClusterInfo = RHI::RayTracingClasConvertBuildTriangleClusterInfo(clusterSourceInfoExpanded);
  493. gpuClusterInfo++;
  494. deviceIndexBufferAddress += clusterSourceInfoExpanded.m_triangleCount * sizeof(ClusterIndexType);
  495. }
  496. deviceBufferPool->UnmapBuffer(*request.m_buffer);
  497. return true;
  498. });
  499. }
  500. m_time += 0.005f;
  501. // transforms
  502. AZ::Transform triangleTransform1 = AZ::Transform::CreateIdentity();
  503. triangleTransform1.SetTranslation(sinf(m_time) * -100.0f, cosf(m_time) * -100.0f, 1.0f);
  504. triangleTransform1.MultiplyByUniformScale(100.0f);
  505. AZ::Transform triangleTransform2 = AZ::Transform::CreateIdentity();
  506. triangleTransform2.SetTranslation(sinf(m_time) * -100.0f, cosf(m_time) * 100.0f, 2.0f);
  507. triangleTransform2.MultiplyByUniformScale(100.0f);
  508. AZ::Transform rectangleTransform = AZ::Transform::CreateIdentity();
  509. rectangleTransform.SetTranslation(sinf(m_time) * 100.0f, cosf(m_time) * -100.0f, 4.0f);
  510. rectangleTransform.MultiplyByUniformScale(100.0f);
  511. AZ::Transform clusterTransform = AZ::Transform::CreateIdentity();
  512. clusterTransform.SetTranslation(sinf(m_time) * 100.0f, cosf(m_time) * 100.0f, 3.0f);
  513. clusterTransform.MultiplyByUniformScale(100.0f);
  514. // create the TLAS
  515. auto deviceMask = RHI::MultiDevice::AllDevices;
  516. AZStd::unordered_map<int, RHI::DeviceRayTracingTlasDescriptor> tlasDescriptor;
  517. RHI::MultiDeviceObject::IterateDevices(
  518. deviceMask,
  519. [&](int deviceIndex)
  520. {
  521. {
  522. auto& tlasInstance = tlasDescriptor[deviceIndex].m_instances.emplace_back();
  523. tlasInstance.m_instanceID = 0;
  524. tlasInstance.m_hitGroupIndex = 0;
  525. tlasInstance.m_blas = m_triangleRayTracingBlas->GetDeviceRayTracingBlas(deviceIndex);
  526. tlasInstance.m_transform = triangleTransform1;
  527. }
  528. {
  529. auto& tlasInstance = tlasDescriptor[deviceIndex].m_instances.emplace_back();
  530. tlasInstance.m_instanceID = 1;
  531. tlasInstance.m_hitGroupIndex = 1;
  532. tlasInstance.m_blas = m_triangleRayTracingBlas->GetDeviceRayTracingBlas(deviceIndex);
  533. tlasInstance.m_transform = triangleTransform2;
  534. }
  535. {
  536. auto& tlasInstance = tlasDescriptor[deviceIndex].m_instances.emplace_back();
  537. tlasInstance.m_instanceID = 2;
  538. tlasInstance.m_hitGroupIndex = 2;
  539. tlasInstance.m_blas = m_rectangleRayTracingBlas->GetDeviceRayTracingBlas(deviceIndex);
  540. tlasInstance.m_transform = rectangleTransform;
  541. }
  542. {
  543. auto& tlasInstance = tlasDescriptor[deviceIndex].m_instances.emplace_back();
  544. tlasInstance.m_instanceID = 3;
  545. tlasInstance.m_hitGroupIndex = 3;
  546. tlasInstance.m_clusterBlas = m_clusterRayTracingBlas->GetDeviceRayTracingClusterBlas(deviceIndex);
  547. tlasInstance.m_transform = clusterTransform;
  548. }
  549. return true;
  550. });
  551. m_rayTracingTlas->CreateBuffers(deviceMask, tlasDescriptor, *m_rayTracingBufferPools);
  552. m_tlasBufferViewDescriptor = RHI::BufferViewDescriptor::CreateRaw(0, (uint32_t)m_rayTracingTlas->GetTlasBuffer()->GetDescriptor().m_byteCount);
  553. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(m_tlasBufferAttachmentId, m_rayTracingTlas->GetTlasBuffer());
  554. AZ_Error(SampleName, result == RHI::ResultCode::Success, "Failed to import TLAS buffer with error %d", result);
  555. RHI::BufferScopeAttachmentDescriptor desc;
  556. desc.m_attachmentId = m_tlasBufferAttachmentId;
  557. desc.m_bufferViewDescriptor = m_tlasBufferViewDescriptor;
  558. desc.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  559. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::AnyGraphics);
  560. };
  561. RHI::EmptyCompileFunction<ScopeData> compileFunction;
  562. const auto executeFunction = [this]([[maybe_unused]] const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  563. {
  564. RHI::CommandList* commandList = context.GetCommandList();
  565. commandList->BuildClusterAccelerationStructures(*m_clusterRayTracingBlas->GetDeviceRayTracingClusterBlas(context.GetDeviceIndex()));
  566. commandList->BuildBottomLevelAccelerationStructure(*m_triangleRayTracingBlas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
  567. commandList->BuildBottomLevelAccelerationStructure(*m_rectangleRayTracingBlas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
  568. commandList->BuildClusterBottomLevelAccelerationStructures({ m_clusterRayTracingBlas->GetDeviceRayTracingClusterBlas(context.GetDeviceIndex()).get() });
  569. commandList->BuildTopLevelAccelerationStructure(
  570. *m_rayTracingTlas->GetDeviceRayTracingTlas(context.GetDeviceIndex()),
  571. { m_triangleRayTracingBlas->GetDeviceRayTracingBlas(context.GetDeviceIndex()).get(), m_rectangleRayTracingBlas->GetDeviceRayTracingBlas(context.GetDeviceIndex()).get() },
  572. { m_clusterRayTracingBlas->GetDeviceRayTracingClusterBlas(context.GetDeviceIndex()).get() });
  573. };
  574. m_scopeProducers.emplace_back(
  575. aznew RHI::ScopeProducerFunction<
  576. ScopeData,
  577. decltype(prepareFunction),
  578. decltype(compileFunction),
  579. decltype(executeFunction)>(
  580. RHI::ScopeId{ "RayTracingBuildAccelerationStructure" },
  581. ScopeData{},
  582. prepareFunction,
  583. compileFunction,
  584. executeFunction));
  585. }
  586. void RayTracingClusterExampleComponent::CreateRayTracingDispatchScope()
  587. {
  588. struct ScopeData
  589. {
  590. };
  591. const auto prepareFunction = [this]([[maybe_unused]] RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  592. {
  593. // attach output image
  594. {
  595. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportImage(m_outputImageAttachmentId, m_outputImage);
  596. AZ_Error(SampleName, result == RHI::ResultCode::Success, "Failed to import output image with error %d", result);
  597. RHI::ImageScopeAttachmentDescriptor desc;
  598. desc.m_attachmentId = m_outputImageAttachmentId;
  599. desc.m_imageViewDescriptor = m_outputImageViewDescriptor;
  600. desc.m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(0.0f, 0.0f, 0.0f, 0.0f);
  601. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::RayTracingShader);
  602. }
  603. // attach TLAS buffer
  604. if (m_rayTracingTlas->GetTlasBuffer())
  605. {
  606. RHI::BufferScopeAttachmentDescriptor desc;
  607. desc.m_attachmentId = m_tlasBufferAttachmentId;
  608. desc.m_bufferViewDescriptor = m_tlasBufferViewDescriptor;
  609. desc.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  610. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::RayTracingShader);
  611. }
  612. frameGraph.SetEstimatedItemCount(1);
  613. };
  614. const auto compileFunction = [this]([[maybe_unused]] const RHI::FrameGraphCompileContext& context, [[maybe_unused]] const ScopeData& scopeData)
  615. {
  616. if (m_rayTracingTlas->GetTlasBuffer())
  617. {
  618. // set the TLAS and output image in the ray tracing global Srg
  619. RHI::ShaderInputBufferIndex tlasConstantIndex;
  620. FindShaderInputIndex(&tlasConstantIndex, m_globalSrg, AZ::Name{ "m_scene" }, SampleName);
  621. uint32_t tlasBufferByteCount = aznumeric_cast<uint32_t>(m_rayTracingTlas->GetTlasBuffer()->GetDescriptor().m_byteCount);
  622. RHI::BufferViewDescriptor bufferViewDescriptor = RHI::BufferViewDescriptor::CreateRayTracingTLAS(tlasBufferByteCount);
  623. m_globalSrg->SetBufferView(tlasConstantIndex, m_rayTracingTlas->GetTlasBuffer()->GetBufferView(bufferViewDescriptor).get());
  624. RHI::ShaderInputImageIndex outputConstantIndex;
  625. FindShaderInputIndex(&outputConstantIndex, m_globalSrg, AZ::Name{ "m_output" }, SampleName);
  626. m_globalSrg->SetImageView(outputConstantIndex, m_outputImageView.get());
  627. // set hit shader data, each array element corresponds to the InstanceIndex() of the geometry in the TLAS
  628. // Note: this method is used instead of LocalRootSignatures for compatibility with non-DX12 platforms
  629. // set HitGradient values
  630. RHI::ShaderInputConstantIndex hitGradientDataConstantIndex;
  631. FindShaderInputIndex(&hitGradientDataConstantIndex, m_globalSrg, AZ::Name{"m_hitGradientData"}, SampleName);
  632. struct HitGradientData
  633. {
  634. AZ::Vector4 m_color;
  635. };
  636. AZStd::array<HitGradientData, 4> hitGradientData = {{
  637. {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f)}, // unused
  638. {AZ::Vector4(0.0f, 1.0f, 0.0f, 1.0f)}, // unused
  639. {AZ::Vector4(0.0f, 0.0f, 0.0f, 0.0f)}, // unused
  640. {AZ::Vector4(1.0f, 1.0f, 1.0f, 0.0f)}, // cluster
  641. }};
  642. m_globalSrg->SetConstantArray(hitGradientDataConstantIndex, hitGradientData);
  643. // set HitSolid values
  644. RHI::ShaderInputConstantIndex hitSolidDataConstantIndex;
  645. FindShaderInputIndex(&hitSolidDataConstantIndex, m_globalSrg, AZ::Name{"m_hitSolidData"}, SampleName);
  646. struct HitSolidData
  647. {
  648. AZ::Vector4 m_color1;
  649. float m_lerp;
  650. float m_pad[3];
  651. AZ::Vector4 m_color2;
  652. };
  653. AZStd::array<HitSolidData, 4> hitSolidData = {{
  654. {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f), 0.5f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(1.0f, 0.0f, 0.0f, 0.0f)}, // triangle1
  655. {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f), 0.5f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 1.0f, 0.0f, 1.0f)}, // triangle2
  656. {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f), 0.5f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 0.0f, 1.0f, 1.0f)}, // rectangle
  657. {AZ::Vector4(1.0f, 0.0f, 0.0f, 1.0f), 0.5f, {0.0f, 0.0f, 0.0f}, AZ::Vector4(0.0f, 0.0f, 0.0f, 1.0f)}, // unused
  658. }};
  659. m_globalSrg->SetConstantArray(hitSolidDataConstantIndex, hitSolidData);
  660. m_globalSrg->Compile();
  661. // update the ray tracing shader table
  662. AZStd::shared_ptr<RHI::RayTracingShaderTableDescriptor> descriptor = AZStd::make_shared<RHI::RayTracingShaderTableDescriptor>();
  663. descriptor->m_name = AZ::Name("RayTracingExampleShaderTable");
  664. descriptor->m_rayTracingPipelineState = m_rayTracingPipelineState;
  665. descriptor->m_rayGenerationRecord.emplace_back(AZ::Name("RayGenerationShader"));
  666. descriptor->m_missRecords.emplace_back(AZ::Name("MissShader"));
  667. descriptor->m_hitGroupRecords.emplace_back(AZ::Name("HitGroupSolid")); // triangle1
  668. descriptor->m_hitGroupRecords.emplace_back(AZ::Name("HitGroupSolid")); // triangle2
  669. descriptor->m_hitGroupRecords.emplace_back(AZ::Name("HitGroupSolid")); // rectangle
  670. descriptor->m_hitGroupRecords.emplace_back(AZ::Name("HitGroupGradient")); // clusters
  671. m_rayTracingShaderTable->Build(descriptor);
  672. }
  673. };
  674. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  675. {
  676. if (!m_rayTracingTlas->GetTlasBuffer())
  677. {
  678. return;
  679. }
  680. RHI::CommandList* commandList = context.GetCommandList();
  681. const RHI::DeviceShaderResourceGroup* shaderResourceGroups[] = {
  682. m_globalSrg->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get()
  683. };
  684. RHI::DeviceDispatchRaysItem dispatchRaysItem;
  685. dispatchRaysItem.m_arguments.m_direct.m_width = m_imageWidth;
  686. dispatchRaysItem.m_arguments.m_direct.m_height = m_imageHeight;
  687. dispatchRaysItem.m_arguments.m_direct.m_depth = 1;
  688. dispatchRaysItem.m_rayTracingPipelineState = m_rayTracingPipelineState->GetDeviceRayTracingPipelineState(context.GetDeviceIndex()).get();
  689. dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable->GetDeviceRayTracingShaderTable(context.GetDeviceIndex()).get();
  690. dispatchRaysItem.m_shaderResourceGroupCount = RHI::ArraySize(shaderResourceGroups);
  691. dispatchRaysItem.m_shaderResourceGroups = shaderResourceGroups;
  692. dispatchRaysItem.m_globalPipelineState = m_globalPipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
  693. // submit the DispatchRays item
  694. commandList->Submit(dispatchRaysItem);
  695. };
  696. m_scopeProducers.emplace_back(
  697. aznew RHI::ScopeProducerFunction<
  698. ScopeData,
  699. decltype(prepareFunction),
  700. decltype(compileFunction),
  701. decltype(executeFunction)>(
  702. RHI::ScopeId{ "RayTracingDispatch" },
  703. ScopeData{},
  704. prepareFunction,
  705. compileFunction,
  706. executeFunction));
  707. }
  708. void RayTracingClusterExampleComponent::CreateRasterScope()
  709. {
  710. struct ScopeData
  711. {
  712. };
  713. const auto prepareFunction = [this](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  714. {
  715. // attach swapchain
  716. {
  717. RHI::ImageScopeAttachmentDescriptor descriptor;
  718. descriptor.m_attachmentId = m_outputAttachmentId;
  719. descriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::DontCare;
  720. frameGraph.UseColorAttachment(descriptor);
  721. }
  722. // attach output buffer
  723. {
  724. RHI::ImageScopeAttachmentDescriptor desc;
  725. desc.m_attachmentId = m_outputImageAttachmentId;
  726. desc.m_imageViewDescriptor = m_outputImageViewDescriptor;
  727. desc.m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(0.0f, 0.0f, 0.0f, 0.0f);
  728. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::FragmentShader);
  729. const Name outputImageId{ "m_output" };
  730. RHI::ShaderInputImageIndex outputImageIndex = m_drawSRG->FindShaderInputImageIndex(outputImageId);
  731. AZ_Error(SampleName, outputImageIndex.IsValid(), "Failed to find shader input image %s.", outputImageId.GetCStr());
  732. m_drawSRG->SetImageView(outputImageIndex, m_outputImageView.get());
  733. m_drawSRG->Compile();
  734. }
  735. frameGraph.SetEstimatedItemCount(1);
  736. };
  737. RHI::EmptyCompileFunction<ScopeData> compileFunction;
  738. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  739. {
  740. RHI::CommandList* commandList = context.GetCommandList();
  741. commandList->SetViewports(&m_viewport, 1);
  742. commandList->SetScissors(&m_scissor, 1);
  743. const RHI::DeviceShaderResourceGroup* shaderResourceGroups[] = {
  744. m_drawSRG->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get()
  745. };
  746. RHI::DeviceDrawItem drawItem;
  747. drawItem.m_geometryView = m_geometryView.GetDeviceGeometryView(context.GetDeviceIndex());
  748. drawItem.m_streamIndices = m_geometryView.GetFullStreamBufferIndices();
  749. drawItem.m_pipelineState = m_drawPipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
  750. drawItem.m_shaderResourceGroupCount = static_cast<uint8_t>(RHI::ArraySize(shaderResourceGroups));
  751. drawItem.m_shaderResourceGroups = shaderResourceGroups;
  752. // submit the triangle draw item.
  753. commandList->Submit(drawItem);
  754. };
  755. m_scopeProducers.emplace_back(
  756. aznew RHI::ScopeProducerFunction<
  757. ScopeData,
  758. decltype(prepareFunction),
  759. decltype(compileFunction),
  760. decltype(executeFunction)>(
  761. RHI::ScopeId{ "Raster" },
  762. ScopeData{},
  763. prepareFunction,
  764. compileFunction,
  765. executeFunction));
  766. }
  767. } // namespace AtomSampleViewer