MultiThreadComponent.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <RHI/MultiThreadComponent.h>
  9. #include <AzCore/Math/MatrixUtils.h>
  10. #include <Atom/RHI/DrawItem.h>
  11. #include <Atom/RHI.Reflect/RenderAttachmentLayoutBuilder.h>
  12. #include <Atom/RPI.Public/Shader/Shader.h>
  13. #include <AzCore/Math/Random.h>
  14. #include <SampleComponentManager.h>
  15. #include <Utils/Utils.h>
  16. namespace AtomSampleViewer
  17. {
  18. // static const variables.
  19. const AZ::Vector3 MultiThreadComponent::m_up = AZ::Vector3(0.0f, 1.0f, 0.0f);
  20. void MultiThreadComponent::Reflect(AZ::ReflectContext* context)
  21. {
  22. if (auto* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  23. {
  24. serializeContext->Class<MultiThreadComponent, AZ::Component>()
  25. ->Version(0)
  26. ;
  27. }
  28. }
  29. MultiThreadComponent::MultiThreadComponent()
  30. {
  31. m_depthStencilID = AZ::RHI::AttachmentId{ "DepthStencilID" };
  32. uint32_t index = 0;
  33. // Create positions for each cube
  34. for (uint32_t j = 0; j < s_cubesPerLine*s_cubeSpacing; j+= s_cubeSpacing)
  35. {
  36. for (uint32_t i = 0; i < s_cubesPerLine*s_cubeSpacing; i+= s_cubeSpacing)
  37. {
  38. m_cubeTransforms[index] = AZ::Matrix4x4::CreateTranslation(AZ::Vector3(static_cast<float>(i), static_cast<float>(j), 0.0f));
  39. ++index;
  40. }
  41. }
  42. m_supportRHISamplePipeline = true;
  43. }
  44. void MultiThreadComponent::OnFramePrepare(AZ::RHI::FrameGraphBuilder& frameGraphBuilder)
  45. {
  46. m_time += 0.005f;
  47. BasicRHIComponent::OnFramePrepare(frameGraphBuilder);
  48. }
  49. void MultiThreadComponent::Activate()
  50. {
  51. // This is done here instead of doing it in the constructor,
  52. // since m_windowContext might not be yet initialized at construction time.
  53. float fieldOfView = AZ::Constants::Pi / 4.0f;
  54. float screenAspect = GetViewportWidth() / GetViewportHeight();
  55. float heighOfCubePlane = static_cast<float>(s_cubesPerLine * s_cubeSpacing);
  56. float distanceFromCubePlane = 1.0f * (1 / tanf(fieldOfView/2)) * heighOfCubePlane/2;
  57. float centerOfScreen = heighOfCubePlane/2;
  58. AZ::Vector3 m_worldPosition = AZ::Vector3(centerOfScreen, centerOfScreen, distanceFromCubePlane);
  59. m_lookAt = AZ::Vector3(centerOfScreen, centerOfScreen, 0.0f);
  60. MakePerspectiveFovMatrixRH(m_viewProjMatrix, fieldOfView, screenAspect, m_zNear, m_zFar);
  61. m_viewProjMatrix = m_viewProjMatrix * CreateViewMatrix(m_worldPosition, m_up, m_lookAt);
  62. CreateInputAssemblyBuffer();
  63. CreatePipeline();
  64. CreateScope();
  65. AZ::RHI::RHISystemNotificationBus::Handler::BusConnect();
  66. }
  67. void MultiThreadComponent::Deactivate()
  68. {
  69. AZ::RHI::RHISystemNotificationBus::Handler::BusDisconnect();
  70. m_windowContext = nullptr;
  71. m_bufferPool = nullptr;
  72. m_inputAssemblyBuffer = nullptr;
  73. m_shaderResourceGroups.fill(nullptr);
  74. m_pipelineState = nullptr;
  75. m_scopeProducers.clear();
  76. }
  77. MultiThreadComponent::SingleCubeBufferData MultiThreadComponent::CreateSingleCubeBufferData(const AZ::Vector4 color)
  78. {
  79. // Create vertices, colors and normals for a cube and a plane
  80. SingleCubeBufferData bufferData;
  81. {
  82. AZStd::vector<AZ::Vector3> vertices =
  83. {
  84. //Front Face
  85. AZ::Vector3(1.0, 1.0, 1.0), AZ::Vector3(-1.0, 1.0, 1.0), AZ::Vector3(-1.0, -1.0, 1.0), AZ::Vector3(1.0, -1.0, 1.0),
  86. //Back Face
  87. AZ::Vector3(1.0, 1.0, -1.0), AZ::Vector3(-1.0, 1.0, -1.0), AZ::Vector3(-1.0, -1.0, -1.0), AZ::Vector3(1.0, -1.0, -1.0),
  88. //Left Face
  89. AZ::Vector3(-1.0, 1.0, 1.0), AZ::Vector3(-1.0, -1.0, 1.0), AZ::Vector3(-1.0, -1.0, -1.0), AZ::Vector3(-1.0, 1.0, -1.0),
  90. //Right Face
  91. AZ::Vector3(1.0, 1.0, 1.0), AZ::Vector3(1.0, -1.0, 1.0), AZ::Vector3(1.0, -1.0, -1.0), AZ::Vector3(1.0, 1.0, -1.0),
  92. //Top Face
  93. AZ::Vector3(1.0, 1.0, 1.0), AZ::Vector3(-1.0, 1.0, 1.0), AZ::Vector3(-1.0, 1.0, -1.0), AZ::Vector3(1.0, 1.0, -1.0),
  94. //Bottom Face
  95. AZ::Vector3(1.0, -1.0, 1.0), AZ::Vector3(-1.0, -1.0, 1.0), AZ::Vector3(-1.0, -1.0, -1.0), AZ::Vector3(1.0, -1.0, -1.0),
  96. };
  97. for (int i = 0; i < s_geometryVertexCount; ++i)
  98. {
  99. SetVertexPosition(bufferData.m_positions.data(), i, vertices[i]);
  100. SetVertexColor(bufferData.m_colors.data(), i, color);
  101. }
  102. bufferData.m_indices =
  103. {
  104. {
  105. //Back
  106. 2, 0, 1,
  107. 0, 2, 3,
  108. //Front
  109. 4, 6, 5,
  110. 6, 4, 7,
  111. //Left
  112. 8, 10, 9,
  113. 10, 8, 11,
  114. //Right
  115. 14, 12, 13,
  116. 15, 12, 14,
  117. //Top
  118. 16, 18, 17,
  119. 18, 16, 19,
  120. //Bottom
  121. 22, 20, 21,
  122. 23, 20, 22,
  123. }
  124. };
  125. }
  126. return bufferData;
  127. }
  128. void MultiThreadComponent::CreateInputAssemblyBuffer()
  129. {
  130. const AZ::RHI::Ptr<AZ::RHI::Device> device = Utils::GetRHIDevice();
  131. AZ::RHI::ResultCode result = AZ::RHI::ResultCode::Success;
  132. m_bufferPool = AZ::RHI::Factory::Get().CreateBufferPool();
  133. AZ::RHI::BufferPoolDescriptor bufferPoolDesc;
  134. bufferPoolDesc.m_bindFlags = AZ::RHI::BufferBindFlags::InputAssembly;
  135. bufferPoolDesc.m_heapMemoryLevel = AZ::RHI::HeapMemoryLevel::Device;
  136. result = m_bufferPool->Init(*device, bufferPoolDesc);
  137. if (result != AZ::RHI::ResultCode::Success)
  138. {
  139. AZ_Error("MultiThreadComponent", false, "Failed to initialize buffer pool with error code %d", result);
  140. return;
  141. }
  142. SingleCubeBufferData bufferData = CreateSingleCubeBufferData(AZ::Vector4(1.0f, 0.0f, 0.0f, 0.0f));
  143. m_inputAssemblyBuffer = AZ::RHI::Factory::Get().CreateBuffer();
  144. AZ::RHI::BufferInitRequest request;
  145. request.m_buffer = m_inputAssemblyBuffer.get();
  146. request.m_descriptor = AZ::RHI::BufferDescriptor{ AZ::RHI::BufferBindFlags::InputAssembly, sizeof(SingleCubeBufferData) };
  147. request.m_initialData = &bufferData;
  148. result = m_bufferPool->InitBuffer(request);
  149. if (result != AZ::RHI::ResultCode::Success)
  150. {
  151. AZ_Error("MultiThreadComponent", false, "Failed to initialize buffer with error code %d", result);
  152. return;
  153. }
  154. m_streamBufferViews[0] =
  155. {
  156. *m_inputAssemblyBuffer,
  157. offsetof(SingleCubeBufferData, m_positions),
  158. sizeof(SingleCubeBufferData::m_positions),
  159. sizeof(VertexPosition)
  160. };
  161. m_streamBufferViews[1] =
  162. {
  163. *m_inputAssemblyBuffer,
  164. offsetof(SingleCubeBufferData, m_colors),
  165. sizeof(SingleCubeBufferData::m_colors),
  166. sizeof(VertexColor)
  167. };
  168. m_indexBufferView =
  169. {
  170. *m_inputAssemblyBuffer,
  171. offsetof(SingleCubeBufferData, m_indices),
  172. sizeof(SingleCubeBufferData::m_indices),
  173. AZ::RHI::IndexFormat::Uint16
  174. };
  175. AZ::RHI::InputStreamLayoutBuilder layoutBuilder;
  176. layoutBuilder.SetTopology(AZ::RHI::PrimitiveTopology::TriangleList);
  177. layoutBuilder.AddBuffer()->Channel("POSITION", AZ::RHI::Format::R32G32B32_FLOAT);
  178. layoutBuilder.AddBuffer()->Channel("COLOR", AZ::RHI::Format::R32G32B32A32_FLOAT);
  179. m_streamLayoutDescriptor.Clear();
  180. m_streamLayoutDescriptor = layoutBuilder.End();
  181. AZ::RHI::ValidateStreamBufferViews(m_streamLayoutDescriptor, m_streamBufferViews);
  182. }
  183. void MultiThreadComponent::CreatePipeline()
  184. {
  185. const char* shaderFilePath = "Shaders/RHI/MultiThread.azshader";
  186. const char* sampleName = "MultiThreadComponent";
  187. auto shader = LoadShader(shaderFilePath, sampleName);
  188. if (shader == nullptr)
  189. return;
  190. const AZ::RHI::Ptr<AZ::RHI::Device> device = Utils::GetRHIDevice();
  191. AZ::RHI::PipelineStateDescriptorForDraw pipelineDesc;
  192. shader->GetVariant(AZ::RPI::ShaderAsset::RootShaderVariantStableId).ConfigurePipelineState(pipelineDesc);
  193. pipelineDesc.m_inputStreamLayout = m_streamLayoutDescriptor;
  194. pipelineDesc.m_renderStates.m_depthStencilState.m_depth.m_enable = 1;
  195. pipelineDesc.m_renderStates.m_depthStencilState.m_depth.m_func = AZ::RHI::ComparisonFunc::LessEqual;
  196. AZ::RHI::RenderAttachmentLayoutBuilder attachmentsBuilder;
  197. attachmentsBuilder.AddSubpass()
  198. ->RenderTargetAttachment(m_outputFormat)
  199. ->DepthStencilAttachment(device->GetNearestSupportedFormat(AZ::RHI::Format::D24_UNORM_S8_UINT, AZ::RHI::FormatCapabilities::DepthStencil));
  200. [[maybe_unused]] AZ::RHI::ResultCode result = attachmentsBuilder.End(pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout);
  201. AZ_Assert(result == AZ::RHI::ResultCode::Success, "Failed to create render attachment layout");
  202. m_pipelineState = shader->AcquirePipelineState(pipelineDesc);
  203. if (!m_pipelineState)
  204. {
  205. AZ_Error("MultiThreadComponent", false, "Failed to acquire default pipeline state for shader '%s'", shaderFilePath);
  206. return;
  207. }
  208. auto perInstanceSrgLayout = shader->FindShaderResourceGroupLayout(AZ::Name{ "MultiThreadInstanceSrg" });
  209. if (!perInstanceSrgLayout)
  210. {
  211. AZ_Error("MultiThreadComponent", false, "Failed to get shader resource group layout");
  212. return;
  213. }
  214. for (int i = 0; i < s_numberOfCubes; ++i)
  215. {
  216. m_shaderResourceGroups[i] = CreateShaderResourceGroup(shader, "MultiThreadInstanceSrg", sampleName);
  217. FindShaderInputIndex(&m_shaderIndexWorldMat, m_shaderResourceGroups[i], AZ::Name{"m_worldMatrix"}, "MultiThreadComponent");
  218. FindShaderInputIndex(&m_shaderIndexViewProj, m_shaderResourceGroups[i], AZ::Name{"m_viewProjMatrix"}, "MultiThreadComponent");
  219. }
  220. }
  221. void MultiThreadComponent::CreateScope()
  222. {
  223. const auto prepareFunction = [this](AZ::RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  224. {
  225. // Binds the swap chain as a color attachment. Clears it to black.
  226. {
  227. AZ::RHI::ImageScopeAttachmentDescriptor descriptor;
  228. descriptor.m_attachmentId = m_outputAttachmentId;
  229. descriptor.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
  230. frameGraph.UseColorAttachment(descriptor);
  231. }
  232. // Create & Binds DepthStencil image
  233. {
  234. const AZ::RHI::Ptr<AZ::RHI::Device> device = Utils::GetRHIDevice();
  235. const AZ::RHI::ImageDescriptor imageDescriptor = AZ::RHI::ImageDescriptor::Create2D(
  236. AZ::RHI::ImageBindFlags::DepthStencil,
  237. m_outputWidth,
  238. m_outputHeight,
  239. device->GetNearestSupportedFormat(AZ::RHI::Format::D24_UNORM_S8_UINT, AZ::RHI::FormatCapabilities::DepthStencil));
  240. const AZ::RHI::TransientImageDescriptor transientImageDescriptor(m_depthStencilID, imageDescriptor);
  241. frameGraph.GetAttachmentDatabase().CreateTransientImage(transientImageDescriptor);
  242. AZ::RHI::ImageScopeAttachmentDescriptor dsDesc;
  243. dsDesc.m_attachmentId = m_depthStencilID;
  244. dsDesc.m_imageViewDescriptor.m_overrideFormat = device->GetNearestSupportedFormat(AZ::RHI::Format::D24_UNORM_S8_UINT, AZ::RHI::FormatCapabilities::DepthStencil);
  245. dsDesc.m_loadStoreAction.m_clearValue = AZ::RHI::ClearValue::CreateDepthStencil(1.0f, 0);
  246. dsDesc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Clear;
  247. frameGraph.UseDepthStencilAttachment(dsDesc, AZ::RHI::ScopeAttachmentAccess::Write);
  248. }
  249. // We will submit s_numberOfCubes draw items.
  250. frameGraph.SetEstimatedItemCount(s_numberOfCubes);
  251. };
  252. const auto compileFunction = [this]([[maybe_unused]] const AZ::RHI::FrameGraphCompileContext& context, [[maybe_unused]] const ScopeData& scopeData)
  253. {
  254. AZ::Matrix4x4 rotation = AZ::Matrix4x4::CreateRotationY(m_time);
  255. for(int i = 0; i<s_numberOfCubes; ++i)
  256. {
  257. AZ::Matrix4x4 transform = m_cubeTransforms[i] * rotation;
  258. m_shaderResourceGroups[i]->SetConstant(m_shaderIndexWorldMat, transform);
  259. m_shaderResourceGroups[i]->SetConstant(m_shaderIndexViewProj, m_viewProjMatrix);
  260. m_shaderResourceGroups[i]->Compile();
  261. }
  262. };
  263. const auto executeFunction = [this](const AZ::RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  264. {
  265. AZ::RHI::CommandList* commandList = context.GetCommandList();
  266. // Set persistent viewport and scissor state.
  267. commandList->SetViewports(&m_viewport, 1);
  268. commandList->SetScissors(&m_scissor, 1);
  269. AZ::RHI::DrawIndexed drawIndexed;
  270. drawIndexed.m_indexCount = s_geometryIndexCount;
  271. drawIndexed.m_instanceCount = 1;
  272. // Dividing s_numberOfCubes by context.GetCommandListCount() to balance to number
  273. // of draw call equally between each thread.
  274. uint32_t numberOfCubesPerCommandList = s_numberOfCubes / context.GetCommandListCount();
  275. uint32_t indexStart = context.GetCommandListIndex() * numberOfCubesPerCommandList;
  276. uint32_t indexEnd = indexStart + numberOfCubesPerCommandList;
  277. if (context.GetCommandListIndex() == context.GetCommandListCount() - 1)
  278. {
  279. indexEnd = s_numberOfCubes;
  280. #if defined(AZ_DEBUG_BUILD)
  281. AZ_Printf("MultiThread", "Draw Calls: %d \n", s_numberOfCubes);
  282. AZ_Printf("MultiThread", "Num CommandLists: %d \n", context.GetCommandListCount());
  283. #endif
  284. }
  285. for (uint32_t i = indexStart; i < indexEnd; ++i)
  286. {
  287. const AZ::RHI::ShaderResourceGroup* shaderResourceGroups[] = { m_shaderResourceGroups[i]->GetRHIShaderResourceGroup() };
  288. AZ::RHI::DrawItem drawItem;
  289. drawItem.m_arguments = drawIndexed;
  290. drawItem.m_pipelineState = m_pipelineState.get();
  291. drawItem.m_indexBufferView = &m_indexBufferView;
  292. drawItem.m_shaderResourceGroupCount = static_cast<uint8_t>(AZ::RHI::ArraySize(shaderResourceGroups));
  293. drawItem.m_shaderResourceGroups = shaderResourceGroups;
  294. drawItem.m_streamBufferViewCount = static_cast<uint8_t>(m_streamBufferViews.size());
  295. drawItem.m_streamBufferViews = m_streamBufferViews.data();
  296. commandList->Submit(drawItem);
  297. }
  298. };
  299. m_scopeProducers.emplace_back(
  300. aznew AZ::RHI::ScopeProducerFunction<
  301. ScopeData,
  302. decltype(prepareFunction),
  303. decltype(compileFunction),
  304. decltype(executeFunction)>(
  305. AZ::RHI::ScopeId{"MultiThreadMain"},
  306. ScopeData{},
  307. prepareFunction,
  308. compileFunction,
  309. executeFunction));
  310. }
  311. }// namespace AtomSampleViewer