ComputeExampleComponent.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/RHI/CommandList.h>
  9. #include <Atom/RHI/Factory.h>
  10. #include <Atom/RHI/FrameScheduler.h>
  11. #include <Atom/RHI/Image.h>
  12. #include <Atom/RHI/ImagePool.h>
  13. #include <Atom/RHI/ScopeProducerFunction.h>
  14. #include <Atom/RHI.Reflect/InputStreamLayoutBuilder.h>
  15. #include <Atom/RHI.Reflect/RenderAttachmentLayoutBuilder.h>
  16. #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
  17. #include <AzCore/Math/Vector2.h>
  18. #include <AzCore/Math/Vector4.h>
  19. #include <AzCore/Serialization/SerializeContext.h>
  20. #include <RHI/ComputeExampleComponent.h>
  21. #include <SampleComponentConfig.h>
  22. #include <SampleComponentManager.h>
  23. #include <Utils/Utils.h>
  24. namespace AtomSampleViewer
  25. {
  26. const char* ComputeExampleComponent::s_computeExampleName = "ComputeExample";
  // Names of the shader constants looked up in the dispatch/draw "ConstantSrg" instances.
  namespace ShaderInputs
  {
      static constexpr const char* ShaderInputDimension = "dimension";
      static constexpr const char* ShaderInputSeed = "seed";
  } // namespace ShaderInputs
  32. void ComputeExampleComponent::Reflect(AZ::ReflectContext* context)
  33. {
  34. if (auto* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  35. {
  36. serializeContext->Class<ComputeExampleComponent, AZ::Component>()->Version(0);
  37. }
  38. }
  39. ComputeExampleComponent::ComputeExampleComponent()
  40. {
  41. m_supportRHISamplePipeline = true;
  42. }
  43. void ComputeExampleComponent::OnFramePrepare(AZ::RHI::FrameGraphBuilder& frameGraphBuilder)
  44. {
  45. m_dispatchSRGs[0]->SetConstant(m_dispatchSeedConstantIndex, AZ::Vector2(cosf(m_time), sin(m_time)));
  46. m_dispatchSRGs[0]->Compile();
  47. BasicRHIComponent::OnFramePrepare(frameGraphBuilder);
  48. }
  49. void ComputeExampleComponent::OnTick(float deltaTime, AZ::ScriptTimePoint time)
  50. {
  51. AZ_UNUSED(time);
  52. m_time += deltaTime;
  53. }
  54. void ComputeExampleComponent::Activate()
  55. {
  56. CreateInputAssemblyBuffersAndViews();
  57. CreateComputeBuffer();
  58. LoadComputeShader();
  59. LoadRasterShader();
  60. CreateComputeScope();
  61. CreateRasterScope();
  62. AZ::TickBus::Handler::BusConnect();
  63. AZ::RHI::RHISystemNotificationBus::Handler::BusConnect();
  64. }
  65. void ComputeExampleComponent::Deactivate()
  66. {
  67. m_inputAssemblyBuffer = nullptr;
  68. m_inputAssemblyBufferPool = nullptr;
  69. m_dispatchPipelineState = nullptr;
  70. m_drawPipelineState = nullptr;
  71. m_dispatchSRGs.fill(nullptr);
  72. m_drawSRGs.fill(nullptr);
  73. m_computeBufferPool = nullptr;
  74. m_computeBuffer = nullptr;
  75. m_computeBufferView = nullptr;
  76. m_scopeProducers.clear();
  77. m_windowContext = nullptr;
  78. AZ::TickBus::Handler::BusDisconnect();
  79. AZ::RHI::RHISystemNotificationBus::Handler::BusDisconnect();
  80. }
  81. void ComputeExampleComponent::CreateInputAssemblyBuffersAndViews()
  82. {
  83. using namespace AZ;
  84. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  85. m_inputAssemblyBufferPool = RHI::Factory::Get().CreateBufferPool();
  86. RHI::BufferPoolDescriptor bufferPoolDesc;
  87. bufferPoolDesc.m_bindFlags = RHI::BufferBindFlags::InputAssembly;
  88. bufferPoolDesc.m_heapMemoryLevel = RHI::HeapMemoryLevel::Device;
  89. m_inputAssemblyBufferPool->Init(*device, bufferPoolDesc);
  90. BufferData bufferData;
  91. SetFullScreenRect(bufferData.m_positions.data(), bufferData.m_uvs.data(), bufferData.m_indices.data());
  92. m_inputAssemblyBuffer = RHI::Factory::Get().CreateBuffer();
  93. RHI::BufferInitRequest request;
  94. request.m_buffer = m_inputAssemblyBuffer.get();
  95. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(bufferData) };
  96. request.m_initialData = &bufferData;
  97. m_inputAssemblyBufferPool->InitBuffer(request);
  98. m_streamBufferViews[0] =
  99. {
  100. *m_inputAssemblyBuffer,
  101. offsetof(BufferData, m_positions),
  102. sizeof(BufferData::m_positions),
  103. sizeof(VertexPosition)
  104. };
  105. m_streamBufferViews[1] =
  106. {
  107. *m_inputAssemblyBuffer,
  108. offsetof(BufferData, m_uvs),
  109. sizeof(BufferData::m_uvs),
  110. sizeof(VertexUV)
  111. };
  112. m_indexBufferView =
  113. {
  114. *m_inputAssemblyBuffer,
  115. offsetof(BufferData, m_indices),
  116. sizeof(BufferData::m_indices),
  117. RHI::IndexFormat::Uint16
  118. };
  119. RHI::InputStreamLayoutBuilder layoutBuilder;
  120. layoutBuilder.AddBuffer()->Channel("POSITION", RHI::Format::R32G32B32_FLOAT);
  121. layoutBuilder.AddBuffer()->Channel("UV", RHI::Format::R32G32_FLOAT);
  122. m_inputStreamLayout = layoutBuilder.End();
  123. RHI::ValidateStreamBufferViews(m_inputStreamLayout, m_streamBufferViews);
  124. }
  125. void ComputeExampleComponent::LoadComputeShader()
  126. {
  127. using namespace AZ;
  128. const char* shaderFilePath = "Shaders/RHI/ComputeDispatch.azshader";
  129. const auto shader = LoadShader(shaderFilePath, s_computeExampleName);
  130. if (shader == nullptr)
  131. return;
  132. RHI::PipelineStateDescriptorForDispatch pipelineDesc;
  133. shader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId).ConfigurePipelineState(pipelineDesc);
  134. const auto& numThreads = shader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name("numthreads"));
  135. if (numThreads)
  136. {
  137. const RHI::ShaderStageAttributeArguments& args = *numThreads;
  138. m_numThreadsX = args[0].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[0]) : m_numThreadsX;
  139. m_numThreadsY = args[1].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[1]) : m_numThreadsY;
  140. m_numThreadsZ = args[2].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[2]) : m_numThreadsZ;
  141. }
  142. else
  143. {
  144. AZ_Error(s_computeExampleName, false, "Did not find expected numthreads attribute");
  145. }
  146. m_dispatchPipelineState = shader->AcquirePipelineState(pipelineDesc);
  147. if (!m_dispatchPipelineState)
  148. {
  149. AZ_Error(s_computeExampleName, false, "Failed to acquire default pipeline state for shader '%s'", shaderFilePath);
  150. return;
  151. }
  152. m_dispatchSRGs[0] = CreateShaderResourceGroup(shader, "ConstantSrg", s_computeExampleName);
  153. m_dispatchSRGs[1] = CreateShaderResourceGroup(shader, "BufferSrg", s_computeExampleName);
  154. FindShaderInputIndex(&m_dispatchDimensionConstantIndex, m_dispatchSRGs[0], AZ::Name{ShaderInputs::ShaderInputDimension}, s_computeExampleName);
  155. FindShaderInputIndex(&m_dispatchSeedConstantIndex, m_dispatchSRGs[0], AZ::Name{ShaderInputs::ShaderInputSeed}, s_computeExampleName);
  156. // This SRG will be compiled during the OnFramePrepare
  157. m_dispatchSRGs[0]->SetConstant(m_dispatchDimensionConstantIndex, AZ::Vector2(static_cast<float>(m_bufferWidth), static_cast<float>(m_bufferHeight)));
  158. }
  159. void ComputeExampleComponent::LoadRasterShader()
  160. {
  161. using namespace AZ;
  162. const char* shaderFilePath = "Shaders/RHI/ComputeDraw.azshader";
  163. auto shader = LoadShader(shaderFilePath, s_computeExampleName);
  164. if (shader == nullptr)
  165. return;
  166. RHI::PipelineStateDescriptorForDraw pipelineDesc;
  167. shader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId).ConfigurePipelineState(pipelineDesc);
  168. pipelineDesc.m_inputStreamLayout = m_inputStreamLayout;
  169. RHI::RenderAttachmentLayoutBuilder attachmentsBuilder;
  170. attachmentsBuilder.AddSubpass()
  171. ->RenderTargetAttachment(m_outputFormat);
  172. [[maybe_unused]] RHI::ResultCode result = attachmentsBuilder.End(pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout);
  173. AZ_Assert(result == RHI::ResultCode::Success, "Failed to create render attachment layout");
  174. m_drawPipelineState = shader->AcquirePipelineState(pipelineDesc);
  175. if (!m_drawPipelineState)
  176. {
  177. AZ_Error(s_computeExampleName, false, "Failed to acquire default pipeline state for shader '%s'", shaderFilePath);
  178. return;
  179. }
  180. m_drawSRGs[0] = CreateShaderResourceGroup(shader, "ConstantSrg", s_computeExampleName);
  181. m_drawSRGs[1] = CreateShaderResourceGroup(shader, "BufferSrg", s_computeExampleName);
  182. FindShaderInputIndex(&m_drawDimensionConstantIndex, m_drawSRGs[0], AZ::Name{ShaderInputs::ShaderInputDimension}, s_computeExampleName);
  183. m_drawSRGs[0]->SetConstant(m_drawDimensionConstantIndex, AZ::Vector2(static_cast<float>(m_bufferWidth), static_cast<float>(m_bufferHeight)));
  184. m_drawSRGs[0]->Compile();
  185. }
  186. void ComputeExampleComponent::CreateComputeBuffer()
  187. {
  188. using namespace AZ;
  189. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  190. RHI::ResultCode result = RHI::ResultCode::Success;
  191. m_computeBufferPool = RHI::Factory::Get().CreateBufferPool();
  192. RHI::BufferPoolDescriptor bufferPoolDesc;
  193. bufferPoolDesc.m_bindFlags = RHI::BufferBindFlags::ShaderReadWrite;
  194. bufferPoolDesc.m_heapMemoryLevel = RHI::HeapMemoryLevel::Device;
  195. bufferPoolDesc.m_hostMemoryAccess = RHI::HostMemoryAccess::Write;
  196. result = m_computeBufferPool->Init(*device, bufferPoolDesc);
  197. AZ_Assert(result == RHI::ResultCode::Success, "Failed to initialized compute buffer pool");
  198. m_computeBuffer = RHI::Factory::Get().CreateBuffer();
  199. uint32_t bufferSize = m_bufferWidth * m_bufferHeight * RHI::GetFormatSize(RHI::Format::R32G32B32A32_FLOAT);
  200. RHI::BufferInitRequest request;
  201. request.m_buffer = m_computeBuffer.get();
  202. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::ShaderReadWrite, bufferSize };
  203. result = m_computeBufferPool->InitBuffer(request);
  204. AZ_Assert(result == RHI::ResultCode::Success, "Failed to initialized compute buffer");
  205. m_bufferViewDescriptor = RHI::BufferViewDescriptor::CreateStructured(0, m_bufferWidth * m_bufferHeight, RHI::GetFormatSize(RHI::Format::R32G32B32A32_FLOAT));
  206. m_computeBufferView = m_computeBuffer->GetBufferView(m_bufferViewDescriptor);
  207. if(!m_computeBufferView.get())
  208. {
  209. AZ_Assert(false, "Failed to initialized compute buffer view");
  210. }
  211. AZ_Assert(m_computeBufferView->IsFullView(), "compute Buffer View initialization failed to cover in full the Compute Buffer");
  212. }
  213. void ComputeExampleComponent::CreateComputeScope()
  214. {
  215. using namespace AZ;
  216. struct ScopeData
  217. {
  218. //UserDataParam - Empty for this samples
  219. };
  220. const auto prepareFunction = [this](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  221. {
  222. // attach compute buffer
  223. {
  224. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(m_bufferAttachmentId, m_computeBuffer);
  225. AZ_Error(s_computeExampleName, result == RHI::ResultCode::Success, "Failed to import compute buffer with error %d", result);
  226. RHI::BufferScopeAttachmentDescriptor desc;
  227. desc.m_attachmentId = m_bufferAttachmentId;
  228. desc.m_bufferViewDescriptor = m_bufferViewDescriptor;
  229. desc.m_loadStoreAction.m_clearValue = AZ::RHI::ClearValue::CreateVector4Float(0.0f, 0.0f, 0.0f, 0.0f);
  230. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite);
  231. const Name computeBufferId{ "m_computeBuffer" };
  232. RHI::ShaderInputBufferIndex computeBufferIndex = m_dispatchSRGs[1]->FindShaderInputBufferIndex(computeBufferId);
  233. AZ_Error(s_computeExampleName, computeBufferIndex.IsValid(), "Failed to find shader input buffer %s.", computeBufferId.GetCStr());
  234. m_dispatchSRGs[1]->SetBufferView(computeBufferIndex, m_computeBufferView.get());
  235. m_dispatchSRGs[1]->Compile();
  236. }
  237. frameGraph.SetEstimatedItemCount(1);
  238. };
  239. RHI::EmptyCompileFunction<ScopeData> compileFunction;
  240. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  241. {
  242. RHI::CommandList* commandList = context.GetCommandList();
  243. // Set persistent viewport and scissor state.
  244. commandList->SetViewports(&m_viewport, 1);
  245. commandList->SetScissors(&m_scissor, 1);
  246. AZStd::array <const RHI::ShaderResourceGroup*, 8> shaderResourceGroups;
  247. shaderResourceGroups[0] = m_dispatchSRGs[0]->GetRHIShaderResourceGroup();
  248. shaderResourceGroups[1] = m_dispatchSRGs[1]->GetRHIShaderResourceGroup();
  249. RHI::DispatchItem dispatchItem;
  250. RHI::DispatchDirect dispatchArgs;
  251. dispatchArgs.m_threadsPerGroupX = aznumeric_cast<uint16_t>(m_numThreadsX);
  252. dispatchArgs.m_threadsPerGroupY = aznumeric_cast<uint16_t>(m_numThreadsY);
  253. dispatchArgs.m_threadsPerGroupZ = aznumeric_cast<uint16_t>(m_numThreadsZ);
  254. dispatchArgs.m_totalNumberOfThreadsX = m_bufferWidth;
  255. dispatchArgs.m_totalNumberOfThreadsY = m_bufferHeight;
  256. dispatchArgs.m_totalNumberOfThreadsZ = 1;
  257. dispatchItem.m_arguments = dispatchArgs;
  258. dispatchItem.m_pipelineState = m_dispatchPipelineState.get();
  259. dispatchItem.m_shaderResourceGroupCount = 2;
  260. dispatchItem.m_shaderResourceGroups = shaderResourceGroups;
  261. commandList->Submit(dispatchItem);
  262. };
  263. m_scopeProducers.emplace_back(
  264. aznew RHI::ScopeProducerFunction<
  265. ScopeData,
  266. decltype(prepareFunction),
  267. decltype(compileFunction),
  268. decltype(executeFunction)>(
  269. RHI::ScopeId{"Compute"},
  270. ScopeData{},
  271. prepareFunction,
  272. compileFunction,
  273. executeFunction));
  274. }
  275. void ComputeExampleComponent::CreateRasterScope()
  276. {
  277. using namespace AZ;
  278. struct ScopeData
  279. {
  280. RPI::WindowContext* m_windowContext;
  281. };
  282. const auto prepareFunction = [this](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  283. {
  284. // Binds the swap chain as a color attachment. Clears it to white.
  285. {
  286. RHI::ImageScopeAttachmentDescriptor descriptor;
  287. descriptor.m_attachmentId = m_outputAttachmentId;
  288. descriptor.m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(1.0f, 1.0, 1.0, 0.0);
  289. descriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Clear;
  290. frameGraph.UseColorAttachment(descriptor);
  291. }
  292. // attach compute buffer
  293. {
  294. RHI::BufferScopeAttachmentDescriptor desc;
  295. desc.m_attachmentId = m_bufferAttachmentId;
  296. desc.m_bufferViewDescriptor = m_bufferViewDescriptor;
  297. desc.m_loadStoreAction.m_clearValue = AZ::RHI::ClearValue::CreateVector4Float(0.0f, 0.0f, 0.0f, 0.0f);
  298. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite);
  299. const Name computeBufferId{ "m_computeBuffer" };
  300. RHI::ShaderInputBufferIndex computeBufferIndex = m_drawSRGs[1]->FindShaderInputBufferIndex(computeBufferId);
  301. AZ_Error(s_computeExampleName, computeBufferIndex.IsValid(), "Failed to find shader input buffer %s.", computeBufferId.GetCStr());
  302. m_drawSRGs[1]->SetBufferView(computeBufferIndex, m_computeBufferView.get());
  303. m_drawSRGs[1]->Compile();
  304. }
  305. // We will submit a single draw item.
  306. frameGraph.SetEstimatedItemCount(1);
  307. };
  308. RHI::EmptyCompileFunction<ScopeData> compileFunction;
  309. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  310. {
  311. RHI::CommandList* commandList = context.GetCommandList();
  312. // Set persistent viewport and scissor state.
  313. commandList->SetViewports(&m_viewport, 1);
  314. commandList->SetScissors(&m_scissor, 1);
  315. RHI::DrawIndexed drawIndexed;
  316. drawIndexed.m_indexCount = 6;
  317. drawIndexed.m_instanceCount = 1;
  318. const RHI::ShaderResourceGroup* shaderResourceGroups[] = { m_drawSRGs[0]->GetRHIShaderResourceGroup(), m_drawSRGs[1]->GetRHIShaderResourceGroup() };
  319. RHI::DrawItem drawItem;
  320. drawItem.m_arguments = drawIndexed;
  321. drawItem.m_pipelineState = m_drawPipelineState.get();
  322. drawItem.m_indexBufferView = &m_indexBufferView;
  323. drawItem.m_streamBufferViewCount = static_cast<uint8_t>(m_streamBufferViews.size());
  324. drawItem.m_streamBufferViews = m_streamBufferViews.data();
  325. drawItem.m_shaderResourceGroupCount = static_cast<uint8_t>(RHI::ArraySize(shaderResourceGroups));
  326. drawItem.m_shaderResourceGroups = shaderResourceGroups;
  327. // Submit the triangle draw item.
  328. commandList->Submit(drawItem);
  329. };
  330. m_scopeProducers.emplace_back(
  331. aznew RHI::ScopeProducerFunction<
  332. ScopeData,
  333. decltype(prepareFunction),
  334. decltype(compileFunction),
  335. decltype(executeFunction)>(
  336. RHI::ScopeId{"Raster"},
  337. ScopeData{},
  338. prepareFunction,
  339. compileFunction,
  340. executeFunction));
  341. }
  342. }