3
0

RasterPass.cpp 11 KB


  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <Atom/RHI/CommandList.h>
  9. #include <Atom/RHI/DrawListTagRegistry.h>
  10. #include <Atom/RHI/RHISystemInterface.h>
  11. #include <Atom/RHI/ShaderResourceGroup.h>
  12. #include <Atom/RPI.Public/DynamicDraw/DynamicDrawInterface.h>
  13. #include <Atom/RPI.Public/Pass/RasterPass.h>
  14. #include <Atom/RPI.Public/RenderPipeline.h>
  15. #include <Atom/RPI.Public/RPISystemInterface.h>
  16. #include <Atom/RPI.Public/Scene.h>
  17. #include <Atom/RPI.Public/View.h>
  18. #include <Atom/RPI.Reflect/Asset/AssetUtils.h>
  19. #include <Atom/RPI.Reflect/Pass/RasterPassData.h>
  20. namespace AZ
  21. {
  22. namespace RPI
  23. {
  24. // --- Creation & Initialization ---
  25. Ptr<RasterPass> RasterPass::Create(const PassDescriptor& descriptor)
  26. {
  27. Ptr<RasterPass> pass = aznew RasterPass(descriptor);
  28. return pass;
  29. }
  30. RasterPass::RasterPass(const PassDescriptor& descriptor)
  31. : RenderPass(descriptor)
  32. {
  33. const RasterPassData* rasterData = PassUtils::GetPassData<RasterPassData>(descriptor);
  34. // If we successfully retrieved our custom data, use it to set the DrawListTag
  35. if (rasterData == nullptr)
  36. {
  37. return;
  38. }
  39. SetDrawListTag(rasterData->m_drawListTag);
  40. m_drawListSortType = rasterData->m_drawListSortType;
  41. RHI::RHISystemInterface::Get()->SetDrawListTagEnabledByDefault(m_drawListTag, rasterData->m_enableDrawItemsByDefault);
  42. // Get the shader asset that contains the SRG Layout.
  43. Data::Asset<ShaderAsset> shaderAsset;
  44. if (rasterData->m_passSrgShaderReference.m_assetId.IsValid())
  45. {
  46. shaderAsset = AssetUtils::LoadAssetById<ShaderAsset>(rasterData->m_passSrgShaderReference.m_assetId, AssetUtils::TraceLevel::Error);
  47. }
  48. else if (!rasterData->m_passSrgShaderReference.m_filePath.empty())
  49. {
  50. shaderAsset = AssetUtils::LoadAssetByProductPath<ShaderAsset>(
  51. rasterData->m_passSrgShaderReference.m_filePath.c_str(), AssetUtils::TraceLevel::Error);
  52. }
  53. if (shaderAsset)
  54. {
  55. const auto srgLayout = shaderAsset->FindShaderResourceGroupLayout(SrgBindingSlot::Pass);
  56. if (srgLayout)
  57. {
  58. m_shaderResourceGroup = ShaderResourceGroup::Create(shaderAsset, srgLayout->GetName());
  59. AZ_Assert(
  60. m_shaderResourceGroup, "[RasterPass '%s']: Failed to create SRG from shader asset '%s'", GetPathName().GetCStr(),
  61. rasterData->m_passSrgShaderReference.m_filePath.data());
  62. PassUtils::BindDataMappingsToSrg(descriptor, m_shaderResourceGroup.get());
  63. }
  64. }
  65. if (!rasterData->m_overrideScissor.IsNull())
  66. {
  67. m_scissorState = rasterData->m_overrideScissor;
  68. m_overrideScissorSate = true;
  69. }
  70. if (!rasterData->m_overrideViewport.IsNull())
  71. {
  72. m_viewportState = rasterData->m_overrideViewport;
  73. m_overrideViewportState = true;
  74. }
  75. m_viewportAndScissorTargetOutputIndex = rasterData->m_viewportAndScissorTargetOutputIndex;
  76. }
  77. RasterPass::~RasterPass()
  78. {
  79. if (m_drawListTag != RHI::DrawListTag::Null)
  80. {
  81. RHI::RHISystemInterface* rhiSystem = RHI::RHISystemInterface::Get();
  82. rhiSystem->GetDrawListTagRegistry()->ReleaseTag(m_drawListTag);
  83. }
  84. }
  85. void RasterPass::SetDrawListTag(Name drawListName)
  86. {
  87. // Use AcquireTag to register a draw list tag if it doesn't exist.
  88. RHI::RHISystemInterface* rhiSystem = RHI::RHISystemInterface::Get();
  89. m_drawListTag = rhiSystem->GetDrawListTagRegistry()->AcquireTag(drawListName);
  90. m_flags.m_hasDrawListTag = true;
  91. }
  92. void RasterPass::SetPipelineStateDataIndex(uint32_t index)
  93. {
  94. m_pipelineStateDataIndex.m_index = index;
  95. }
  96. ShaderResourceGroup* RasterPass::GetShaderResourceGroup()
  97. {
  98. return m_shaderResourceGroup.get();
  99. }
  100. uint32_t RasterPass::GetDrawItemCount()
  101. {
  102. return m_drawItemCount;
  103. }
  104. // --- Pass behaviour overrides ---
// Reports configuration errors for this pass, then runs the base RenderPass validation.
// Checks that the draw list tag is valid and that a pipeline view tag is set.
// NOTE(review): AZ_RPI_PASS_ERROR is a macro; it appears to rely on the parameter being named
// `validationResults`, so keep that name when editing.
void RasterPass::Validate(PassValidationResults& validationResults)
{
AZ_RPI_PASS_ERROR(m_drawListTag.IsValid(), "DrawListTag for RasterPass [%s] is invalid", GetPathName().GetCStr());
AZ_RPI_PASS_ERROR(!GetPipelineViewTag().IsEmpty(), "ViewTag for RasterPass [%s] is invalid", GetPathName().GetCStr());
RenderPass::Validate(validationResults);
}
  111. void RasterPass::FrameBeginInternal(FramePrepareParams params)
  112. {
  113. // Binding to use for viewport and scissor calculations
  114. PassAttachmentBinding* viewportTarget = nullptr;
  115. // If a target binding for viewport calculation is specified
  116. if (m_viewportAndScissorTargetOutputIndex >= 0)
  117. {
  118. u32 idx = u32(m_viewportAndScissorTargetOutputIndex);
  119. // First check outputs
  120. if (GetOutputCount() > idx)
  121. {
  122. viewportTarget = &GetOutputBinding(idx);
  123. }
  124. // If not an output, check input/outputs
  125. else if (GetInputOutputCount() > idx)
  126. {
  127. viewportTarget = &GetInputOutputBinding(idx);
  128. }
  129. }
  130. // Build viewport and scissor from target binding if specified
  131. if (viewportTarget)
  132. {
  133. u32 targetWidth = viewportTarget->GetAttachment()->m_descriptor.m_image.m_size.m_width;
  134. u32 targetHeight = viewportTarget->GetAttachment()->m_descriptor.m_image.m_size.m_height;
  135. m_scissorState = RHI::Scissor(0, 0, targetWidth, targetHeight);
  136. m_viewportState = RHI::Viewport(0, static_cast<float>(targetWidth), 0, static_cast<float>(targetHeight));
  137. }
  138. // Otherwise check whether viewport/scissor overrides were manually provided
  139. else
  140. {
  141. if (!m_overrideScissorSate)
  142. {
  143. m_scissorState = params.m_scissorState;
  144. }
  145. if (!m_overrideViewportState)
  146. {
  147. m_viewportState = params.m_viewportState;
  148. }
  149. }
  150. UpdateDrawList();
  151. RenderPass::FrameBeginInternal(params);
  152. }
  153. void RasterPass::UpdateDrawList()
  154. {
  155. // DrawLists from dynamic draw
  156. AZStd::vector<RHI::DrawListView> drawLists = DynamicDrawInterface::Get()->GetDrawListsForPass(this);
  157. // Get DrawList from view
  158. const AZStd::vector<ViewPtr>& views = m_pipeline->GetViews(GetPipelineViewTag());
  159. RHI::DrawListView viewDrawList;
  160. if (!views.empty())
  161. {
  162. const ViewPtr& view = views.front();
  163. // Assert the view has our draw list (the view's DrawlistTags are collected from passes using its viewTag)
  164. AZ_Assert(view->HasDrawListTag(m_drawListTag), "View's DrawListTags out of sync with pass'. ");
  165. // Draw List
  166. viewDrawList = view->GetDrawList(m_drawListTag);
  167. }
  168. // clean up data
  169. m_drawListView = {};
  170. m_combinedDrawList.clear();
  171. m_drawItemCount = 0;
  172. // draw list from view was sorted and if it's the only draw list then we can use it directly
  173. if (viewDrawList.size() > 0 && drawLists.size() == 0)
  174. {
  175. m_drawListView = viewDrawList;
  176. m_drawItemCount += static_cast<uint32_t>(viewDrawList.size());
  177. PassSystemInterface::Get()->IncrementFrameDrawItemCount(m_drawItemCount);
  178. return;
  179. }
  180. // add view's draw list to drawLists too
  181. drawLists.push_back(viewDrawList);
  182. // combine draw items from mutiple draw lists to one draw list and sort it.
  183. for (auto drawList : drawLists)
  184. {
  185. m_drawItemCount += static_cast<uint32_t>(drawList.size());
  186. }
  187. PassSystemInterface::Get()->IncrementFrameDrawItemCount(m_drawItemCount);
  188. m_combinedDrawList.resize(m_drawItemCount);
  189. RHI::DrawItemProperties* currentBuffer = m_combinedDrawList.data();
  190. for (auto drawList : drawLists)
  191. {
  192. memcpy(currentBuffer, drawList.data(), drawList.size()*sizeof(RHI::DrawItemProperties));
  193. currentBuffer += drawList.size();
  194. }
  195. SortDrawList(m_combinedDrawList);
  196. // have the final draw list point to the combined draw list.
  197. m_drawListView = m_combinedDrawList;
  198. }
  199. // --- DrawList and PipelineView Tags ---
  200. RHI::DrawListTag RasterPass::GetDrawListTag() const
  201. {
  202. return m_drawListTag;
  203. }
  204. // --- Scope producer functions ---
  205. void RasterPass::SetupFrameGraphDependencies(RHI::FrameGraphInterface frameGraph)
  206. {
  207. RenderPass::SetupFrameGraphDependencies(frameGraph);
  208. frameGraph.SetEstimatedItemCount(static_cast<uint32_t>(m_drawListView.size()));
  209. }
  210. void RasterPass::CompileResources(const RHI::FrameGraphCompileContext& context)
  211. {
  212. if (m_shaderResourceGroup == nullptr)
  213. {
  214. return;
  215. }
  216. BindPassSrg(context, m_shaderResourceGroup);
  217. m_shaderResourceGroup->Compile();
  218. }
  219. void RasterPass::SubmitDrawItems(const RHI::FrameGraphExecuteContext& context, uint32_t startIndex, uint32_t endIndex, uint32_t indexOffset) const
  220. {
  221. RHI::CommandList* commandList = context.GetCommandList();
  222. uint32_t clampedEndIndex = AZStd::GetMin<uint32_t>(endIndex, static_cast<uint32_t>(m_drawListView.size()));
  223. for (uint32_t index = startIndex; index < clampedEndIndex; ++index)
  224. {
  225. const RHI::DrawItemProperties& drawItemProperties = m_drawListView[index];
  226. if (drawItemProperties.m_drawFilterMask & m_pipeline->GetDrawFilterMask())
  227. {
  228. commandList->Submit(*drawItemProperties.m_item, index + indexOffset);
  229. }
  230. }
  231. }
  232. void RasterPass::BuildCommandListInternal(const RHI::FrameGraphExecuteContext& context)
  233. {
  234. RHI::CommandList* commandList = context.GetCommandList();
  235. if (context.GetSubmitRange().m_startIndex != context.GetSubmitRange().m_endIndex)
  236. {
  237. commandList->SetViewport(m_viewportState);
  238. commandList->SetScissor(m_scissorState);
  239. SetSrgsForDraw(commandList);
  240. SubmitDrawItems(context, context.GetSubmitRange().m_startIndex, context.GetSubmitRange().m_endIndex, 0);
  241. }
  242. }
  243. } // namespace RPI
  244. } // namespace AZ