ComputeExampleComponent.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/RHI/CommandList.h>
  9. #include <Atom/RHI/Factory.h>
  10. #include <Atom/RHI/FrameScheduler.h>
  11. #include <Atom/RHI/ScopeProducerFunction.h>
  12. #include <Atom/RHI.Reflect/InputStreamLayoutBuilder.h>
  13. #include <Atom/RHI.Reflect/RenderAttachmentLayoutBuilder.h>
  14. #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
  15. #include <AzCore/Math/Vector2.h>
  16. #include <AzCore/Math/Vector4.h>
  17. #include <AzCore/Serialization/SerializeContext.h>
  18. #include <RHI/ComputeExampleComponent.h>
  19. #include <SampleComponentConfig.h>
  20. #include <SampleComponentManager.h>
  21. #include <Utils/Utils.h>
  22. namespace AtomSampleViewer
  23. {
  24. const char* ComputeExampleComponent::s_computeExampleName = "ComputeExample";
  25. namespace ShaderInputs
  26. {
  27. static const char* const ShaderInputDimension{ "dimension" };
  28. static const char* const ShaderInputSeed{ "seed" };
  29. }
  30. void ComputeExampleComponent::Reflect(AZ::ReflectContext* context)
  31. {
  32. if (auto* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  33. {
  34. serializeContext->Class<ComputeExampleComponent, AZ::Component>()->Version(0);
  35. }
  36. }
  37. ComputeExampleComponent::ComputeExampleComponent()
  38. {
  39. m_supportRHISamplePipeline = true;
  40. }
  41. void ComputeExampleComponent::OnFramePrepare(AZ::RHI::FrameGraphBuilder& frameGraphBuilder)
  42. {
  43. m_dispatchSRGs[0]->SetConstant(m_dispatchSeedConstantIndex, AZ::Vector2(cosf(m_time), sin(m_time)));
  44. m_dispatchSRGs[0]->Compile();
  45. BasicRHIComponent::OnFramePrepare(frameGraphBuilder);
  46. }
  47. void ComputeExampleComponent::OnTick(float deltaTime, AZ::ScriptTimePoint time)
  48. {
  49. AZ_UNUSED(time);
  50. m_time += deltaTime;
  51. }
  52. void ComputeExampleComponent::Activate()
  53. {
  54. CreateInputAssemblyBuffersAndViews();
  55. CreateComputeBuffer();
  56. LoadComputeShader();
  57. LoadRasterShader();
  58. CreateComputeScope();
  59. CreateRasterScope();
  60. AZ::TickBus::Handler::BusConnect();
  61. AZ::RHI::RHISystemNotificationBus::Handler::BusConnect();
  62. }
  63. void ComputeExampleComponent::Deactivate()
  64. {
  65. m_inputAssemblyBuffer = nullptr;
  66. m_inputAssemblyBufferPool = nullptr;
  67. m_dispatchPipelineState = nullptr;
  68. m_drawPipelineState = nullptr;
  69. m_dispatchSRGs.fill(nullptr);
  70. m_drawSRGs.fill(nullptr);
  71. m_computeBufferPool = nullptr;
  72. m_computeBuffer = nullptr;
  73. m_computeBufferView = nullptr;
  74. m_scopeProducers.clear();
  75. m_windowContext = nullptr;
  76. AZ::TickBus::Handler::BusDisconnect();
  77. AZ::RHI::RHISystemNotificationBus::Handler::BusDisconnect();
  78. }
  79. void ComputeExampleComponent::CreateInputAssemblyBuffersAndViews()
  80. {
  81. using namespace AZ;
  82. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  83. m_inputAssemblyBufferPool = aznew RHI::BufferPool();
  84. RHI::BufferPoolDescriptor bufferPoolDesc;
  85. bufferPoolDesc.m_bindFlags = RHI::BufferBindFlags::InputAssembly;
  86. bufferPoolDesc.m_heapMemoryLevel = RHI::HeapMemoryLevel::Device;
  87. m_inputAssemblyBufferPool->Init(bufferPoolDesc);
  88. BufferData bufferData;
  89. SetFullScreenRect(bufferData.m_positions.data(), bufferData.m_uvs.data(), bufferData.m_indices.data());
  90. m_inputAssemblyBuffer = aznew RHI::Buffer();
  91. RHI::BufferInitRequest request;
  92. request.m_buffer = m_inputAssemblyBuffer.get();
  93. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(bufferData) };
  94. request.m_initialData = &bufferData;
  95. m_inputAssemblyBufferPool->InitBuffer(request);
  96. m_geometryView.SetDrawArguments(RHI::DrawIndexed(0, 6, 0));
  97. m_geometryView.AddStreamBufferView({
  98. *m_inputAssemblyBuffer,
  99. offsetof(BufferData, m_positions),
  100. sizeof(BufferData::m_positions),
  101. sizeof(VertexPosition)
  102. });
  103. m_geometryView.AddStreamBufferView({
  104. *m_inputAssemblyBuffer,
  105. offsetof(BufferData, m_uvs),
  106. sizeof(BufferData::m_uvs),
  107. sizeof(VertexUV)
  108. });
  109. m_geometryView.SetIndexBufferView({
  110. *m_inputAssemblyBuffer,
  111. offsetof(BufferData, m_indices),
  112. sizeof(BufferData::m_indices),
  113. RHI::IndexFormat::Uint16
  114. });
  115. RHI::InputStreamLayoutBuilder layoutBuilder;
  116. layoutBuilder.AddBuffer()->Channel("POSITION", RHI::Format::R32G32B32_FLOAT);
  117. layoutBuilder.AddBuffer()->Channel("UV", RHI::Format::R32G32_FLOAT);
  118. m_inputStreamLayout = layoutBuilder.End();
  119. RHI::ValidateStreamBufferViews(m_inputStreamLayout, m_geometryView, m_geometryView.GetFullStreamBufferIndices());
  120. }
  121. void ComputeExampleComponent::LoadComputeShader()
  122. {
  123. using namespace AZ;
  124. const char* shaderFilePath = "Shaders/RHI/ComputeDispatch.azshader";
  125. const auto shader = LoadShader(shaderFilePath, s_computeExampleName);
  126. if (shader == nullptr)
  127. return;
  128. RHI::PipelineStateDescriptorForDispatch pipelineDesc;
  129. shader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId).ConfigurePipelineState(pipelineDesc);
  130. const auto& numThreads = shader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name("numthreads"));
  131. if (numThreads)
  132. {
  133. const RHI::ShaderStageAttributeArguments& args = *numThreads;
  134. m_numThreadsX = args[0].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[0]) : m_numThreadsX;
  135. m_numThreadsY = args[1].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[1]) : m_numThreadsY;
  136. m_numThreadsZ = args[2].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[2]) : m_numThreadsZ;
  137. }
  138. else
  139. {
  140. AZ_Error(s_computeExampleName, false, "Did not find expected numthreads attribute");
  141. }
  142. m_dispatchPipelineState = shader->AcquirePipelineState(pipelineDesc);
  143. if (!m_dispatchPipelineState)
  144. {
  145. AZ_Error(s_computeExampleName, false, "Failed to acquire default pipeline state for shader '%s'", shaderFilePath);
  146. return;
  147. }
  148. m_dispatchSRGs[0] = CreateShaderResourceGroup(shader, "ConstantSrg", s_computeExampleName);
  149. m_dispatchSRGs[1] = CreateShaderResourceGroup(shader, "BufferSrg", s_computeExampleName);
  150. FindShaderInputIndex(&m_dispatchDimensionConstantIndex, m_dispatchSRGs[0], AZ::Name{ShaderInputs::ShaderInputDimension}, s_computeExampleName);
  151. FindShaderInputIndex(&m_dispatchSeedConstantIndex, m_dispatchSRGs[0], AZ::Name{ShaderInputs::ShaderInputSeed}, s_computeExampleName);
  152. // This SRG will be compiled during the OnFramePrepare
  153. m_dispatchSRGs[0]->SetConstant(m_dispatchDimensionConstantIndex, AZ::Vector2(static_cast<float>(m_bufferWidth), static_cast<float>(m_bufferHeight)));
  154. }
  155. void ComputeExampleComponent::LoadRasterShader()
  156. {
  157. using namespace AZ;
  158. const char* shaderFilePath = "Shaders/RHI/ComputeDraw.azshader";
  159. auto shader = LoadShader(shaderFilePath, s_computeExampleName);
  160. if (shader == nullptr)
  161. return;
  162. RHI::PipelineStateDescriptorForDraw pipelineDesc;
  163. shader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId).ConfigurePipelineState(pipelineDesc);
  164. pipelineDesc.m_inputStreamLayout = m_inputStreamLayout;
  165. RHI::RenderAttachmentLayoutBuilder attachmentsBuilder;
  166. attachmentsBuilder.AddSubpass()
  167. ->RenderTargetAttachment(m_outputFormat);
  168. [[maybe_unused]] RHI::ResultCode result = attachmentsBuilder.End(pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout);
  169. AZ_Assert(result == RHI::ResultCode::Success, "Failed to create render attachment layout");
  170. m_drawPipelineState = shader->AcquirePipelineState(pipelineDesc);
  171. if (!m_drawPipelineState)
  172. {
  173. AZ_Error(s_computeExampleName, false, "Failed to acquire default pipeline state for shader '%s'", shaderFilePath);
  174. return;
  175. }
  176. m_drawSRGs[0] = CreateShaderResourceGroup(shader, "ConstantSrg", s_computeExampleName);
  177. m_drawSRGs[1] = CreateShaderResourceGroup(shader, "BufferSrg", s_computeExampleName);
  178. FindShaderInputIndex(&m_drawDimensionConstantIndex, m_drawSRGs[0], AZ::Name{ShaderInputs::ShaderInputDimension}, s_computeExampleName);
  179. m_drawSRGs[0]->SetConstant(m_drawDimensionConstantIndex, AZ::Vector2(static_cast<float>(m_bufferWidth), static_cast<float>(m_bufferHeight)));
  180. m_drawSRGs[0]->Compile();
  181. }
  182. void ComputeExampleComponent::CreateComputeBuffer()
  183. {
  184. using namespace AZ;
  185. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  186. [[maybe_unused]] RHI::ResultCode result = RHI::ResultCode::Success;
  187. m_computeBufferPool = aznew RHI::BufferPool();
  188. RHI::BufferPoolDescriptor bufferPoolDesc;
  189. bufferPoolDesc.m_bindFlags = RHI::BufferBindFlags::ShaderReadWrite;
  190. bufferPoolDesc.m_heapMemoryLevel = RHI::HeapMemoryLevel::Device;
  191. bufferPoolDesc.m_hostMemoryAccess = RHI::HostMemoryAccess::Write;
  192. result = m_computeBufferPool->Init(bufferPoolDesc);
  193. AZ_Assert(result == RHI::ResultCode::Success, "Failed to initialized compute buffer pool");
  194. m_computeBuffer = aznew RHI::Buffer();
  195. uint32_t bufferSize = m_bufferWidth * m_bufferHeight * RHI::GetFormatSize(RHI::Format::R32G32B32A32_FLOAT);
  196. RHI::BufferInitRequest request;
  197. request.m_buffer = m_computeBuffer.get();
  198. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::ShaderReadWrite, bufferSize };
  199. result = m_computeBufferPool->InitBuffer(request);
  200. AZ_Assert(result == RHI::ResultCode::Success, "Failed to initialized compute buffer");
  201. m_bufferViewDescriptor = RHI::BufferViewDescriptor::CreateStructured(0, m_bufferWidth * m_bufferHeight, RHI::GetFormatSize(RHI::Format::R32G32B32A32_FLOAT));
  202. m_computeBufferView = m_computeBuffer->BuildBufferView(m_bufferViewDescriptor);
  203. if(!m_computeBufferView.get())
  204. {
  205. AZ_Assert(false, "Failed to initialized compute buffer view");
  206. }
  207. AZ_Assert(m_computeBufferView->GetDeviceBufferView(RHI::MultiDevice::DefaultDeviceIndex)->IsFullView(), "compute Buffer View initialization failed to cover in full the Compute Buffer");
  208. }
  209. void ComputeExampleComponent::CreateComputeScope()
  210. {
  211. using namespace AZ;
  212. struct ScopeData
  213. {
  214. //UserDataParam - Empty for this samples
  215. };
  216. const auto prepareFunction = [this](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  217. {
  218. // attach compute buffer
  219. {
  220. [[maybe_unused]] RHI::ResultCode result = frameGraph.GetAttachmentDatabase().ImportBuffer(m_bufferAttachmentId, m_computeBuffer);
  221. AZ_Error(s_computeExampleName, result == RHI::ResultCode::Success, "Failed to import compute buffer with error %d", result);
  222. RHI::BufferScopeAttachmentDescriptor desc;
  223. desc.m_attachmentId = m_bufferAttachmentId;
  224. desc.m_bufferViewDescriptor = m_bufferViewDescriptor;
  225. desc.m_loadStoreAction.m_clearValue = AZ::RHI::ClearValue::CreateVector4Float(0.0f, 0.0f, 0.0f, 0.0f);
  226. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::ComputeShader);
  227. const Name computeBufferId{ "m_computeBuffer" };
  228. RHI::ShaderInputBufferIndex computeBufferIndex = m_dispatchSRGs[1]->FindShaderInputBufferIndex(computeBufferId);
  229. AZ_Error(s_computeExampleName, computeBufferIndex.IsValid(), "Failed to find shader input buffer %s.", computeBufferId.GetCStr());
  230. m_dispatchSRGs[1]->SetBufferView(computeBufferIndex, m_computeBufferView.get());
  231. m_dispatchSRGs[1]->Compile();
  232. }
  233. frameGraph.SetEstimatedItemCount(1);
  234. };
  235. RHI::EmptyCompileFunction<ScopeData> compileFunction;
  236. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  237. {
  238. RHI::CommandList* commandList = context.GetCommandList();
  239. // Set persistent viewport and scissor state.
  240. commandList->SetViewports(&m_viewport, 1);
  241. commandList->SetScissors(&m_scissor, 1);
  242. AZStd::array<const RHI::DeviceShaderResourceGroup*, 8> shaderResourceGroups;
  243. shaderResourceGroups[0] = m_dispatchSRGs[0]->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get();
  244. shaderResourceGroups[1] = m_dispatchSRGs[1]->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get();
  245. RHI::DeviceDispatchItem dispatchItem;
  246. RHI::DispatchDirect dispatchArgs;
  247. dispatchArgs.m_threadsPerGroupX = aznumeric_cast<uint16_t>(m_numThreadsX);
  248. dispatchArgs.m_threadsPerGroupY = aznumeric_cast<uint16_t>(m_numThreadsY);
  249. dispatchArgs.m_threadsPerGroupZ = aznumeric_cast<uint16_t>(m_numThreadsZ);
  250. dispatchArgs.m_totalNumberOfThreadsX = m_bufferWidth;
  251. dispatchArgs.m_totalNumberOfThreadsY = m_bufferHeight;
  252. dispatchArgs.m_totalNumberOfThreadsZ = 1;
  253. dispatchItem.m_arguments = dispatchArgs;
  254. dispatchItem.m_pipelineState = m_dispatchPipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
  255. dispatchItem.m_shaderResourceGroupCount = 2;
  256. dispatchItem.m_shaderResourceGroups = shaderResourceGroups;
  257. commandList->Submit(dispatchItem);
  258. };
  259. m_scopeProducers.emplace_back(
  260. aznew RHI::ScopeProducerFunction<
  261. ScopeData,
  262. decltype(prepareFunction),
  263. decltype(compileFunction),
  264. decltype(executeFunction)>(
  265. RHI::ScopeId{"Compute"},
  266. ScopeData{},
  267. prepareFunction,
  268. compileFunction,
  269. executeFunction));
  270. }
  271. void ComputeExampleComponent::CreateRasterScope()
  272. {
  273. using namespace AZ;
  274. struct ScopeData
  275. {
  276. RPI::WindowContext* m_windowContext;
  277. };
  278. const auto prepareFunction = [this](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  279. {
  280. // Binds the swap chain as a color attachment. Clears it to white.
  281. {
  282. RHI::ImageScopeAttachmentDescriptor descriptor;
  283. descriptor.m_attachmentId = m_outputAttachmentId;
  284. descriptor.m_loadStoreAction.m_clearValue = RHI::ClearValue::CreateVector4Float(1.0f, 1.0, 1.0, 0.0);
  285. descriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Clear;
  286. frameGraph.UseColorAttachment(descriptor);
  287. }
  288. // attach compute buffer
  289. {
  290. RHI::BufferScopeAttachmentDescriptor desc;
  291. desc.m_attachmentId = m_bufferAttachmentId;
  292. desc.m_bufferViewDescriptor = m_bufferViewDescriptor;
  293. desc.m_loadStoreAction.m_clearValue = AZ::RHI::ClearValue::CreateVector4Float(0.0f, 0.0f, 0.0f, 0.0f);
  294. frameGraph.UseShaderAttachment(desc, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::FragmentShader);
  295. const Name computeBufferId{ "m_computeBuffer" };
  296. RHI::ShaderInputBufferIndex computeBufferIndex = m_drawSRGs[1]->FindShaderInputBufferIndex(computeBufferId);
  297. AZ_Error(s_computeExampleName, computeBufferIndex.IsValid(), "Failed to find shader input buffer %s.", computeBufferId.GetCStr());
  298. m_drawSRGs[1]->SetBufferView(computeBufferIndex, m_computeBufferView.get());
  299. m_drawSRGs[1]->Compile();
  300. }
  301. // We will submit a single draw item.
  302. frameGraph.SetEstimatedItemCount(1);
  303. };
  304. RHI::EmptyCompileFunction<ScopeData> compileFunction;
  305. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  306. {
  307. RHI::CommandList* commandList = context.GetCommandList();
  308. // Set persistent viewport and scissor state.
  309. commandList->SetViewports(&m_viewport, 1);
  310. commandList->SetScissors(&m_scissor, 1);
  311. const RHI::DeviceShaderResourceGroup* shaderResourceGroups[] = {
  312. m_drawSRGs[0]->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get(),
  313. m_drawSRGs[1]->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get()
  314. };
  315. RHI::DeviceDrawItem drawItem;
  316. drawItem.m_geometryView = m_geometryView.GetDeviceGeometryView(context.GetDeviceIndex());
  317. drawItem.m_streamIndices = m_geometryView.GetFullStreamBufferIndices();
  318. drawItem.m_pipelineState = m_drawPipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
  319. drawItem.m_shaderResourceGroupCount = static_cast<uint8_t>(RHI::ArraySize(shaderResourceGroups));
  320. drawItem.m_shaderResourceGroups = shaderResourceGroups;
  321. // Submit the triangle draw item.
  322. commandList->Submit(drawItem);
  323. };
  324. m_scopeProducers.emplace_back(
  325. aznew RHI::ScopeProducerFunction<
  326. ScopeData,
  327. decltype(prepareFunction),
  328. decltype(compileFunction),
  329. decltype(executeFunction)>(
  330. RHI::ScopeId{"Raster"},
  331. ScopeData{},
  332. prepareFunction,
  333. compileFunction,
  334. executeFunction));
  335. }
  336. }