InputAssemblyExampleComponent.cpp 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/RHI/CommandList.h>
  9. #include <Atom/RHI/Factory.h>
  10. #include <Atom/RHI/FrameScheduler.h>
  11. #include <Atom/RHI/ScopeProducerFunction.h>
  12. #include <Atom/RHI.Reflect/InputStreamLayoutBuilder.h>
  13. #include <Atom/RHI.Reflect/RenderAttachmentLayoutBuilder.h>
  14. #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
  15. #include <AzCore/Math/Matrix4x4.h>
  16. #include <AzCore/Math/Vector4.h>
  17. #include <AzCore/Serialization/SerializeContext.h>
  18. #include <RHI/InputAssemblyExampleComponent.h>
  19. #include <SampleComponentConfig.h>
  20. #include <SampleComponentManager.h>
  21. #include <Utils/Utils.h>
  22. namespace AtomSampleViewer
  23. {
  namespace InputAssembly
  {
      // Window/sample name used for AZ_Error reporting and passed to the shader/SRG helper functions.
      // constexpr gives these constants internal linkage and makes the pointers immutable.
      constexpr const char* SampleName = "InputAssemblyExample";

      // Names of shader inputs looked up in the dispatch SRG (time, vertex buffer) and draw SRG (matrix, color).
      constexpr const char* ShaderInputTime = "m_time";
      constexpr const char* ShaderInputIABuffer = "m_IABuffer";
      // Legacy misspelled name kept as an alias so existing references keep compiling.
      constexpr const char* ShaderInpuIABuffer = ShaderInputIABuffer;
      constexpr const char* ShaderInputMatrix = "m_matrix";
      constexpr const char* ShaderInputColor = "m_color";

      // Frame-graph attachment ids: one for the per-frame transient vertex buffer,
      // one for the persistently owned buffer that is imported every frame.
      constexpr const char* InputAssemblyBufferAttachmentId = "InputAssemblyBufferAttachmentId";
      constexpr const char* ImportedInputAssemblyBufferAttachmentId = "ImportedInputAssemblyBufferAttachmentId";
  }
  34. void InputAssemblyExampleComponent::Reflect(AZ::ReflectContext* context)
  35. {
  36. if (auto* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  37. {
  38. serializeContext->Class<InputAssemblyExampleComponent, AZ::Component>()->Version(0);
  39. }
  40. }
  41. InputAssemblyExampleComponent::InputAssemblyExampleComponent()
  42. {
  43. m_supportRHISamplePipeline = true;
  44. }
  45. void InputAssemblyExampleComponent::FrameBeginInternal(AZ::RHI::FrameGraphBuilder& frameGraphBuilder)
  46. {
  47. using namespace AZ;
  48. if (m_windowContext->GetSwapChain())
  49. {
  50. // Create Input Assembly buffer
  51. {
  52. RHI::TransientBufferDescriptor bufferDesc;
  53. bufferDesc.m_attachmentId = InputAssembly::InputAssemblyBufferAttachmentId;
  54. bufferDesc.m_bufferDescriptor = RHI::BufferDescriptor(
  55. RHI::BufferBindFlags::InputAssembly | RHI::BufferBindFlags::ShaderReadWrite,
  56. sizeof(BufferData));
  57. frameGraphBuilder.GetAttachmentDatabase().CreateTransientBuffer(bufferDesc);
  58. }
  59. {
  60. frameGraphBuilder.GetAttachmentDatabase().ImportBuffer(AZ::Name{ InputAssembly::ImportedInputAssemblyBufferAttachmentId }, m_inputAssemblyBuffer);
  61. }
  62. float aspectRatio = static_cast<float>(m_outputWidth / m_outputHeight);
  63. AZ::Vector2 scale(AZStd::min(1.0f / aspectRatio, 1.0f), AZStd::min(aspectRatio, 1.0f));
  64. {
  65. AZ::Matrix4x4 scaleTranslate =
  66. AZ::Matrix4x4::CreateTranslation(AZ::Vector3(0.4f, 0.4f, 0)) *
  67. AZ::Matrix4x4::CreateScale(AZ::Vector3(scale.GetX() * 0.6f, scale.GetY() * 0.6f, 1.0f));
  68. m_drawSRG[0]->SetConstant(m_drawMatrixIndex, scaleTranslate);
  69. m_drawSRG[0]->SetConstant(m_drawColorIndex, AZ::Vector4(1.0, 0, 0, 1.0f));
  70. m_drawSRG[0]->Compile();
  71. }
  72. {
  73. AZ::Matrix4x4 scaleTranslate =
  74. AZ::Matrix4x4::CreateTranslation(AZ::Vector3(-0.4f, -0.4f, 0)) *
  75. AZ::Matrix4x4::CreateScale(AZ::Vector3(scale.GetX() * 0.4f, scale.GetY() * 0.4f, 1.0f));
  76. m_drawSRG[1]->SetConstant(m_drawMatrixIndex, scaleTranslate);
  77. m_drawSRG[1]->SetConstant(m_drawColorIndex, AZ::Vector4(0.0, 1, 0, 1.0f));
  78. m_drawSRG[1]->Compile();
  79. }
  80. }
  81. }
  82. void InputAssemblyExampleComponent::OnTick(float deltaTime, AZ::ScriptTimePoint time)
  83. {
  84. AZ_UNUSED(time);
  85. m_time += deltaTime;
  86. }
  87. void InputAssemblyExampleComponent::Activate()
  88. {
  89. CreateInputAssemblyLayout();
  90. CreateBuffers();
  91. LoadComputeShader();
  92. LoadRasterShader();
  93. CreateComputeScope();
  94. CreateRasterScope();
  95. AZ::TickBus::Handler::BusConnect();
  96. AZ::RHI::RHISystemNotificationBus::Handler::BusConnect();
  97. }
  98. void InputAssemblyExampleComponent::Deactivate()
  99. {
  100. m_dispatchPipelineState = nullptr;
  101. m_drawPipelineState = nullptr;
  102. m_dispatchSRG[0] = nullptr;
  103. m_dispatchSRG[1] = nullptr;
  104. m_drawSRG[0] = nullptr;
  105. m_drawSRG[1] = nullptr;
  106. m_inputAssemblyBuffer = nullptr;
  107. m_inputAssemblyBufferPool = nullptr;
  108. m_scopeProducers.clear();
  109. m_windowContext = nullptr;
  110. AZ::TickBus::Handler::BusDisconnect();
  111. AZ::RHI::RHISystemNotificationBus::Handler::BusDisconnect();
  112. }
  113. void InputAssemblyExampleComponent::CreateInputAssemblyLayout()
  114. {
  115. using namespace AZ;
  116. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  117. RHI::InputStreamLayoutBuilder layoutBuilder;
  118. layoutBuilder.AddBuffer()->Channel("POSITION", RHI::Format::R32G32B32A32_FLOAT);
  119. layoutBuilder.SetTopology(RHI::PrimitiveTopology::TriangleStrip);
  120. m_inputStreamLayout = layoutBuilder.End();
  121. }
  122. void InputAssemblyExampleComponent::CreateBuffers()
  123. {
  124. using namespace AZ;
  125. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  126. m_inputAssemblyBufferPool = aznew RHI::BufferPool();
  127. RHI::BufferPoolDescriptor bufferPoolDesc;
  128. bufferPoolDesc.m_bindFlags = RHI::BufferBindFlags::InputAssembly | RHI::BufferBindFlags::ShaderReadWrite;
  129. bufferPoolDesc.m_heapMemoryLevel = RHI::HeapMemoryLevel::Device;
  130. m_inputAssemblyBufferPool->Init(bufferPoolDesc);
  131. m_inputAssemblyBuffer = aznew RHI::Buffer();
  132. RHI::BufferInitRequest request;
  133. request.m_buffer = m_inputAssemblyBuffer.get();
  134. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly | RHI::BufferBindFlags::ShaderReadWrite, sizeof(BufferData) };
  135. request.m_initialData = nullptr;
  136. m_inputAssemblyBufferPool->InitBuffer(request);
  137. }
  138. void InputAssemblyExampleComponent::LoadComputeShader()
  139. {
  140. using namespace AZ;
  141. const char* shaderFilePath = "Shaders/RHI/InputAssemblyCompute.azshader";
  142. const auto shader = LoadShader(shaderFilePath, InputAssembly::SampleName);
  143. if (shader == nullptr)
  144. return;
  145. RHI::PipelineStateDescriptorForDispatch pipelineDesc;
  146. shader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId).ConfigurePipelineState(pipelineDesc);
  147. const auto& numThreads = shader->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name("numthreads"));
  148. if (numThreads)
  149. {
  150. const RHI::ShaderStageAttributeArguments& args = *numThreads;
  151. m_numThreadsX = args[0].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[0]) : m_numThreadsX;
  152. m_numThreadsY = args[1].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[1]) : m_numThreadsY;
  153. m_numThreadsZ = args[2].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[2]) : m_numThreadsZ;
  154. }
  155. else
  156. {
  157. AZ_Error(InputAssembly::SampleName, false, "Did not find expected numthreads attribute");
  158. }
  159. m_dispatchPipelineState = shader->AcquirePipelineState(pipelineDesc);
  160. if (!m_dispatchPipelineState)
  161. {
  162. AZ_Error(InputAssembly::SampleName, false, "Failed to acquire default pipeline state for shader '%s'", shaderFilePath);
  163. return;
  164. }
  165. m_dispatchSRG[0] = CreateShaderResourceGroup(shader, "DispatchSRG", InputAssembly::SampleName);
  166. m_dispatchSRG[1] = CreateShaderResourceGroup(shader, "DispatchSRG", InputAssembly::SampleName);
  167. FindShaderInputIndex(&m_dispatchTimeConstantIndex, m_dispatchSRG[0], AZ::Name{ InputAssembly::ShaderInputTime }, InputAssembly::SampleName);
  168. FindShaderInputIndex(&m_dispatchIABufferIndex, m_dispatchSRG[0], AZ::Name{ InputAssembly::ShaderInpuIABuffer }, InputAssembly::SampleName);
  169. }
  170. void InputAssemblyExampleComponent::LoadRasterShader()
  171. {
  172. using namespace AZ;
  173. const char* shaderFilePath = "Shaders/RHI/InputAssemblyDraw.azshader";
  174. auto shader = LoadShader(shaderFilePath, InputAssembly::SampleName);
  175. if (shader == nullptr)
  176. return;
  177. RHI::PipelineStateDescriptorForDraw pipelineDesc;
  178. shader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId).ConfigurePipelineState(pipelineDesc);
  179. pipelineDesc.m_inputStreamLayout = m_inputStreamLayout;
  180. RHI::RenderAttachmentLayoutBuilder attachmentsBuilder;
  181. attachmentsBuilder.AddSubpass()
  182. ->RenderTargetAttachment(m_outputFormat);
  183. [[maybe_unused]] RHI::ResultCode result = attachmentsBuilder.End(pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout);
  184. AZ_Assert(result == RHI::ResultCode::Success, "Failed to create render attachment layout");
  185. m_drawPipelineState = shader->AcquirePipelineState(pipelineDesc);
  186. if (!m_drawPipelineState)
  187. {
  188. AZ_Error(InputAssembly::SampleName, false, "Failed to acquire default pipeline state for shader '%s'", shaderFilePath);
  189. return;
  190. }
  191. m_drawSRG[0] = CreateShaderResourceGroup(shader, "DrawSRG", InputAssembly::SampleName);
  192. m_drawSRG[1] = CreateShaderResourceGroup(shader, "DrawSRG", InputAssembly::SampleName);
  193. FindShaderInputIndex(&m_drawMatrixIndex, m_drawSRG[0], AZ::Name{ InputAssembly::ShaderInputMatrix }, InputAssembly::SampleName);
  194. FindShaderInputIndex(&m_drawColorIndex, m_drawSRG[0], AZ::Name{ InputAssembly::ShaderInputColor }, InputAssembly::SampleName);
  195. }
  196. void InputAssemblyExampleComponent::CreateComputeScope()
  197. {
  198. using namespace AZ;
  199. struct ScopeData
  200. {
  201. //UserDataParam - Empty for this samples
  202. };
  203. const auto prepareFunction = [](RHI::FrameGraphInterface frameGraph, ScopeData& scopeData)
  204. {
  205. AZ_UNUSED(scopeData);
  206. // Declare usage of the vertex buffer as UAV
  207. {
  208. RHI::BufferScopeAttachmentDescriptor attachmentDescriptor;
  209. attachmentDescriptor.m_attachmentId = InputAssembly::InputAssemblyBufferAttachmentId;
  210. attachmentDescriptor.m_bufferViewDescriptor = RHI::BufferViewDescriptor::CreateStructured(0, BufferData::array_size, sizeof(BufferData::value_type));
  211. attachmentDescriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::DontCare;
  212. frameGraph.UseShaderAttachment(
  213. attachmentDescriptor, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::ComputeShader);
  214. }
  215. {
  216. RHI::BufferScopeAttachmentDescriptor attachmentDescriptor;
  217. attachmentDescriptor.m_attachmentId = InputAssembly::ImportedInputAssemblyBufferAttachmentId;
  218. attachmentDescriptor.m_bufferViewDescriptor = RHI::BufferViewDescriptor::CreateStructured(0, BufferData::array_size, sizeof(BufferData::value_type));
  219. attachmentDescriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::DontCare;
  220. frameGraph.UseShaderAttachment(
  221. attachmentDescriptor, RHI::ScopeAttachmentAccess::ReadWrite, RHI::ScopeAttachmentStage::ComputeShader);
  222. }
  223. frameGraph.SetEstimatedItemCount(2);
  224. };
  225. const auto compileFunction = [this](const RHI::FrameGraphCompileContext& context, const ScopeData& scopeData)
  226. {
  227. AZ_UNUSED(scopeData);
  228. {
  229. const auto* inputAssemblyBufferView = context.GetBufferView(RHI::AttachmentId{ InputAssembly::InputAssemblyBufferAttachmentId });
  230. m_dispatchSRG[0]->SetBufferView(m_dispatchIABufferIndex, inputAssemblyBufferView);
  231. m_dispatchSRG[0]->SetConstant(m_dispatchTimeConstantIndex, m_time);
  232. m_dispatchSRG[0]->Compile();
  233. }
  234. {
  235. const auto* inputAssemblyBufferView = context.GetBufferView(RHI::AttachmentId{ InputAssembly::ImportedInputAssemblyBufferAttachmentId });
  236. m_dispatchSRG[1]->SetBufferView(m_dispatchIABufferIndex, inputAssemblyBufferView);
  237. m_dispatchSRG[1]->SetConstant(m_dispatchTimeConstantIndex, m_time);
  238. m_dispatchSRG[1]->Compile();
  239. }
  240. };
  241. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, const ScopeData& scopeData)
  242. {
  243. AZ_UNUSED(scopeData);
  244. RHI::CommandList* commandList = context.GetCommandList();
  245. RHI::DeviceDispatchItem dispatchItem;
  246. RHI::DispatchDirect dispatchArgs;
  247. dispatchArgs.m_threadsPerGroupX = aznumeric_cast<uint16_t>(m_numThreadsX);
  248. dispatchArgs.m_threadsPerGroupY = aznumeric_cast<uint16_t>(m_numThreadsY);
  249. dispatchArgs.m_threadsPerGroupZ = aznumeric_cast<uint16_t>(m_numThreadsZ);
  250. dispatchArgs.m_totalNumberOfThreadsX = 1;
  251. dispatchArgs.m_totalNumberOfThreadsY = 1;
  252. dispatchArgs.m_totalNumberOfThreadsZ = 1;
  253. dispatchItem.m_arguments = dispatchArgs;
  254. dispatchItem.m_pipelineState = m_dispatchPipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
  255. dispatchItem.m_shaderResourceGroupCount = 1;
  256. for (uint32_t index = context.GetSubmitRange().m_startIndex; index < context.GetSubmitRange().m_endIndex; ++index)
  257. {
  258. dispatchItem.m_shaderResourceGroups[0] = m_dispatchSRG[index]->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get();
  259. commandList->Submit(dispatchItem, index);
  260. }
  261. };
  262. m_scopeProducers.emplace_back(
  263. aznew RHI::ScopeProducerFunction<
  264. ScopeData,
  265. decltype(prepareFunction),
  266. decltype(compileFunction),
  267. decltype(executeFunction)>(
  268. RHI::ScopeId{"IACompute"},
  269. ScopeData{},
  270. prepareFunction,
  271. compileFunction,
  272. executeFunction));
  273. }
  274. void InputAssemblyExampleComponent::CreateRasterScope()
  275. {
  276. using namespace AZ;
  277. struct ScopeData
  278. {
  279. };
  280. const auto prepareFunction = [this](RHI::FrameGraphInterface frameGraph, ScopeData& scopeData)
  281. {
  282. AZ_UNUSED(scopeData);
  283. // Binds the swap chain as a color attachment.
  284. {
  285. RHI::ImageScopeAttachmentDescriptor descriptor;
  286. descriptor.m_attachmentId = m_outputAttachmentId;
  287. descriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  288. frameGraph.UseColorAttachment(descriptor);
  289. }
  290. // Declare the usage of the vertex buffer as Input Assembly. This is needed because we modify the vertex buffer in the GPU
  291. // and it needs synchronization.
  292. {
  293. RHI::BufferScopeAttachmentDescriptor attachmentDescriptor;
  294. attachmentDescriptor.m_attachmentId = InputAssembly::InputAssemblyBufferAttachmentId;
  295. attachmentDescriptor.m_bufferViewDescriptor = RHI::BufferViewDescriptor::CreateStructured(0, BufferData::array_size, sizeof(BufferData::value_type));
  296. attachmentDescriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  297. attachmentDescriptor.m_loadStoreAction.m_storeAction = RHI::AttachmentStoreAction::DontCare;
  298. frameGraph.UseInputAssemblyAttachment(attachmentDescriptor);
  299. }
  300. {
  301. RHI::BufferScopeAttachmentDescriptor attachmentDescriptor;
  302. attachmentDescriptor.m_attachmentId = InputAssembly::ImportedInputAssemblyBufferAttachmentId;
  303. attachmentDescriptor.m_bufferViewDescriptor = RHI::BufferViewDescriptor::CreateStructured(0, BufferData::array_size, sizeof(BufferData::value_type));
  304. attachmentDescriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  305. attachmentDescriptor.m_loadStoreAction.m_storeAction = RHI::AttachmentStoreAction::DontCare;
  306. frameGraph.UseInputAssemblyAttachment(attachmentDescriptor);
  307. }
  308. // We will submit a single draw item.
  309. frameGraph.SetEstimatedItemCount(2);
  310. };
  311. const auto compileFunction = [this](const RHI::FrameGraphCompileContext& context, const ScopeData& scopeData)
  312. {
  313. AZ_UNUSED(scopeData);
  314. {
  315. const auto* inputAssemblyBufferView = context.GetBufferView(RHI::AttachmentId{ InputAssembly::InputAssemblyBufferAttachmentId });
  316. if (inputAssemblyBufferView)
  317. {
  318. m_geometryView[0].ClearStreamBufferViews();
  319. m_geometryView[0].AddStreamBufferView({*(inputAssemblyBufferView->GetBuffer()), 0, sizeof(BufferData), sizeof(BufferData::value_type)});
  320. RHI::ValidateStreamBufferViews(m_inputStreamLayout, m_geometryView[0], m_geometryView[0].GetFullStreamBufferIndices());
  321. }
  322. }
  323. {
  324. const auto* inputAssemblyBufferView = context.GetBufferView(RHI::AttachmentId{ InputAssembly::ImportedInputAssemblyBufferAttachmentId });
  325. if (inputAssemblyBufferView)
  326. {
  327. m_geometryView[1].ClearStreamBufferViews();
  328. m_geometryView[1].AddStreamBufferView({*(inputAssemblyBufferView->GetBuffer()), 0, sizeof(BufferData), sizeof(BufferData::value_type)});
  329. RHI::ValidateStreamBufferViews(m_inputStreamLayout, m_geometryView[1], m_geometryView[1].GetFullStreamBufferIndices());
  330. }
  331. }
  332. };
  333. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, const ScopeData& scopeData)
  334. {
  335. AZ_UNUSED(scopeData);
  336. RHI::CommandList* commandList = context.GetCommandList();
  337. // Set persistent viewport and scissor state.
  338. commandList->SetViewports(&m_viewport, 1);
  339. commandList->SetScissors(&m_scissor, 1);
  340. RHI::DrawLinear drawLinear;
  341. drawLinear.m_vertexCount = BufferData::array_size;
  342. RHI::DeviceDrawItem drawItem;
  343. drawItem.m_pipelineState = m_drawPipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
  344. drawItem.m_shaderResourceGroupCount = 1;
  345. for (uint32_t index = context.GetSubmitRange().m_startIndex; index < context.GetSubmitRange().m_endIndex; ++index)
  346. {
  347. m_geometryView[index].SetDrawArguments(drawLinear);
  348. drawItem.m_geometryView = m_geometryView[index].GetDeviceGeometryView(context.GetDeviceIndex());
  349. drawItem.m_streamIndices = m_geometryView[index].GetFullStreamBufferIndices();
  350. RHI::DeviceShaderResourceGroup* rhiSRGS[] = {
  351. m_drawSRG[index]->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get()
  352. };
  353. drawItem.m_shaderResourceGroups = rhiSRGS;
  354. commandList->Submit(drawItem, index);
  355. }
  356. };
  357. m_scopeProducers.emplace_back(
  358. aznew RHI::ScopeProducerFunction<
  359. ScopeData,
  360. decltype(prepareFunction),
  361. decltype(compileFunction),
  362. decltype(executeFunction)>(
  363. RHI::ScopeId{"IARaster"},
  364. ScopeData{},
  365. prepareFunction,
  366. compileFunction,
  367. executeFunction));
  368. }
  369. }