3
0

ComputePass.cpp 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <AzCore/Asset/AssetCommon.h>
  9. #include <AzCore/Asset/AssetManagerBus.h>
  10. #include <Atom/RHI/CommandList.h>
  11. #include <Atom/RHI/Factory.h>
  12. #include <Atom/RHI/FrameScheduler.h>
  13. #include <Atom/RHI/PipelineState.h>
  14. #include <Atom/RPI.Reflect/Pass/ComputePassData.h>
  15. #include <Atom/RPI.Reflect/Pass/PassTemplate.h>
  16. #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
  17. #include <Atom/RPI.Public/Pass/ComputePass.h>
  18. #include <Atom/RPI.Public/Pass/PassUtils.h>
  19. #include <Atom/RPI.Public/RPIUtils.h>
  20. #include <Atom/RPI.Public/Shader/Shader.h>
  21. #include <Atom/RPI.Public/Shader/ShaderResourceGroup.h>
  22. namespace AZ
  23. {
  24. namespace RPI
  25. {
  26. ComputePass::~ComputePass()
  27. {
  28. ShaderReloadNotificationBus::Handler::BusDisconnect();
  29. }
  30. Ptr<ComputePass> ComputePass::Create(const PassDescriptor& descriptor)
  31. {
  32. Ptr<ComputePass> pass = aznew ComputePass(descriptor);
  33. return pass;
  34. }
  35. ComputePass::ComputePass(const PassDescriptor& descriptor, AZ::Name supervariant)
  36. : RenderPass(descriptor)
  37. , m_passDescriptor(descriptor)
  38. {
  39. const ComputePassData* passData = PassUtils::GetPassData<ComputePassData>(m_passDescriptor);
  40. if (passData == nullptr)
  41. {
  42. AZ_Error(
  43. "PassSystem", false, "[ComputePass '%s']: Trying to construct without valid ComputePassData!", GetPathName().GetCStr());
  44. return;
  45. }
  46. RHI::DispatchDirect dispatchArgs;
  47. dispatchArgs.m_totalNumberOfThreadsX = passData->m_totalNumberOfThreadsX;
  48. dispatchArgs.m_totalNumberOfThreadsY = passData->m_totalNumberOfThreadsY;
  49. dispatchArgs.m_totalNumberOfThreadsZ = passData->m_totalNumberOfThreadsZ;
  50. m_dispatchItem.m_arguments = dispatchArgs;
  51. LoadShader(supervariant);
  52. }
  53. void ComputePass::LoadShader(AZ::Name supervariant)
  54. {
  55. // Load ComputePassData...
  56. const ComputePassData* passData = PassUtils::GetPassData<ComputePassData>(m_passDescriptor);
  57. if (passData == nullptr)
  58. {
  59. AZ_Error("PassSystem", false, "[ComputePass '%s']: Trying to construct without valid ComputePassData!",
  60. GetPathName().GetCStr());
  61. return;
  62. }
  63. // Hardware Queue Class
  64. if (passData->m_useAsyncCompute)
  65. {
  66. m_hardwareQueueClass = RHI::HardwareQueueClass::Compute;
  67. }
  68. // Load Shader
  69. Data::Asset<ShaderAsset> shaderAsset;
  70. if (passData->m_shaderReference.m_assetId.IsValid())
  71. {
  72. shaderAsset = RPI::FindShaderAsset(passData->m_shaderReference.m_assetId, passData->m_shaderReference.m_filePath);
  73. }
  74. if (!shaderAsset.IsReady())
  75. {
  76. AZ_Error("PassSystem", false, "[ComputePass '%s']: Failed to load shader '%s'!",
  77. GetPathName().GetCStr(),
  78. passData->m_shaderReference.m_filePath.data());
  79. return;
  80. }
  81. m_shader = Shader::FindOrCreate(shaderAsset, supervariant);
  82. if (m_shader == nullptr)
  83. {
  84. AZ_Error("PassSystem", false, "[ComputePass '%s']: Failed to create shader instance from asset '%s'!",
  85. GetPathName().GetCStr(),
  86. passData->m_shaderReference.m_filePath.data());
  87. return;
  88. }
  89. // Load Pass SRG...
  90. const auto passSrgLayout = m_shader->FindShaderResourceGroupLayout(SrgBindingSlot::Pass);
  91. if (passSrgLayout)
  92. {
  93. m_shaderResourceGroup = ShaderResourceGroup::Create(shaderAsset, m_shader->GetSupervariantIndex(), passSrgLayout->GetName());
  94. AZ_Assert(m_shaderResourceGroup, "[ComputePass '%s']: Failed to create SRG from shader asset '%s'",
  95. GetPathName().GetCStr(),
  96. passData->m_shaderReference.m_filePath.data());
  97. PassUtils::BindDataMappingsToSrg(m_passDescriptor, m_shaderResourceGroup.get());
  98. }
  99. // Load Draw SRG...
  100. const bool compileDrawSrg = false; // The SRG will be compiled in CompileResources()
  101. m_drawSrg = m_shader->CreateDefaultDrawSrg(compileDrawSrg);
  102. if (m_dispatchItem.m_arguments.m_type == RHI::DispatchType::Direct)
  103. {
  104. const auto outcome = RPI::GetComputeShaderNumThreads(m_shader->GetAsset(), m_dispatchItem.m_arguments.m_direct);
  105. if (!outcome.IsSuccess())
  106. {
  107. AZ_Error(
  108. "PassSystem",
  109. false,
  110. "[ComputePass '%s']: Shader '%.*s' contains invalid numthreads arguments:\n%s",
  111. GetPathName().GetCStr(),
  112. passData->m_shaderReference.m_filePath.size(),
  113. passData->m_shaderReference.m_filePath.data(),
  114. outcome.GetError().c_str());
  115. }
  116. }
  117. m_isFullscreenPass = passData->m_makeFullscreenPass;
  118. // Setup pipeline state...
  119. RHI::PipelineStateDescriptorForDispatch pipelineStateDescriptor;
  120. m_shader->GetDefaultVariant().ConfigurePipelineState(pipelineStateDescriptor);
  121. m_dispatchItem.m_pipelineState = m_shader->AcquirePipelineState(pipelineStateDescriptor);
  122. ShaderReloadNotificationBus::Handler::BusDisconnect();
  123. ShaderReloadNotificationBus::Handler::BusConnect(passData->m_shaderReference.m_assetId);
  124. }
  125. // Scope producer functions
  126. void ComputePass::CompileResources(const RHI::FrameGraphCompileContext& context)
  127. {
  128. if (m_shaderResourceGroup != nullptr)
  129. {
  130. BindPassSrg(context, m_shaderResourceGroup);
  131. m_shaderResourceGroup->Compile();
  132. }
  133. if (m_drawSrg != nullptr)
  134. {
  135. BindSrg(m_drawSrg->GetRHIShaderResourceGroup());
  136. m_drawSrg->Compile();
  137. }
  138. }
  139. void ComputePass::BuildCommandListInternal(const RHI::FrameGraphExecuteContext& context)
  140. {
  141. RHI::CommandList* commandList = context.GetCommandList();
  142. SetSrgsForDispatch(commandList);
  143. commandList->Submit(m_dispatchItem);
  144. }
  145. void ComputePass::MatchDimensionsToOutput()
  146. {
  147. PassAttachment* outputAttachment = nullptr;
  148. if (GetOutputCount() > 0)
  149. {
  150. outputAttachment = GetOutputBinding(0).GetAttachment().get();
  151. }
  152. else if (GetInputOutputCount() > 0)
  153. {
  154. outputAttachment = GetInputOutputBinding(0).GetAttachment().get();
  155. }
  156. AZ_Assert(outputAttachment != nullptr, "[ComputePass '%s']: A fullscreen compute pass must have a valid output or input/output.",
  157. GetPathName().GetCStr());
  158. AZ_Assert(outputAttachment->GetAttachmentType() == RHI::AttachmentType::Image,
  159. "[ComputePass '%s']: The output of a fullscreen compute pass must be an image.",
  160. GetPathName().GetCStr());
  161. RHI::Size targetImageSize = outputAttachment->m_descriptor.m_image.m_size;
  162. SetTargetThreadCounts(targetImageSize.m_width, targetImageSize.m_height, targetImageSize.m_depth);
  163. }
  164. void ComputePass::SetTargetThreadCounts(uint32_t targetThreadCountX, uint32_t targetThreadCountY, uint32_t targetThreadCountZ)
  165. {
  166. auto& arguments = m_dispatchItem.m_arguments.m_direct;
  167. arguments.m_totalNumberOfThreadsX = targetThreadCountX;
  168. arguments.m_totalNumberOfThreadsY = targetThreadCountY;
  169. arguments.m_totalNumberOfThreadsZ = targetThreadCountZ;
  170. }
  171. Data::Instance<ShaderResourceGroup> ComputePass::GetShaderResourceGroup() const
  172. {
  173. return m_shaderResourceGroup;
  174. }
  175. Data::Instance<Shader> ComputePass::GetShader() const
  176. {
  177. return m_shader;
  178. }
  179. void ComputePass::FrameBeginInternal(FramePrepareParams params)
  180. {
  181. if (m_isFullscreenPass)
  182. {
  183. MatchDimensionsToOutput();
  184. }
  185. RenderPass::FrameBeginInternal(params);
  186. }
  187. void ComputePass::OnShaderReinitialized(const Shader& shader)
  188. {
  189. AZ_UNUSED(shader);
  190. LoadShader();
  191. }
  192. void ComputePass::OnShaderAssetReinitialized(const Data::Asset<ShaderAsset>& shaderAsset)
  193. {
  194. AZ_UNUSED(shaderAsset);
  195. LoadShader();
  196. }
  197. void ComputePass::OnShaderVariantReinitialized(const ShaderVariant&)
  198. {
  199. LoadShader();
  200. }
  201. } // namespace RPI
  202. } // namespace AZ