VariableRateShadingExampleComponent.cpp 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <RHI/VariableRateShadingExampleComponent.h>
  9. #include <Utils/Utils.h>
  10. #include <SampleComponentManager.h>
  11. #include <Atom/RHI/CommandList.h>
  12. #include <Atom/RHI.Reflect/InputStreamLayoutBuilder.h>
  13. #include <Atom/RHI.Reflect/RenderAttachmentLayoutBuilder.h>
  14. #include <Atom/RHI.Reflect/VariableRateShadingEnums.h>
  15. #include <Atom/RPI.Public/Shader/Shader.h>
  16. #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
  17. #include <AzCore/Serialization/SerializeContext.h>
  18. #include <AzCore/Math/MathUtils.h>
  19. #include <AzCore/std/containers/span.h>
  20. #include <AzCore/Math/MatrixUtils.h>
  21. #include <AzFramework/Input/Devices/Mouse/InputDeviceMouse.h>
  22. #include <AzFramework/Input/Devices/Touch/InputDeviceTouch.h>
  23. namespace AtomSampleViewer
  24. {
  25. using namespace AZ;
  namespace VariableRateShading
  {
      // Name used for log/error reporting and asset loading in this sample.
      // constexpr makes these immutable compile-time constants with internal linkage;
      // the previous `const char*` declarations were reassignable globals with external linkage.
      constexpr const char* SampleName = "VariableRateShadingExample";
      // Attachment id under which the shading rate image used for rendering is imported each frame.
      constexpr const char* ShadingRateAttachmentId = "ShadingRateAttachmentId";
      // Attachment id for the image the compute pass writes when the device cannot update
      // the shading rate image that is currently bound for rendering.
      constexpr const char* ShadingRateAttachmentUpdateId = "ShadingRateAttachmentUpdateId";
  }
  32. RHI::Format ConvertToUInt(RHI::Format format)
  33. {
  34. uint32_t count = GetFormatComponentCount(format);
  35. if (count == 1)
  36. {
  37. return RHI::Format::R8_UINT;
  38. }
  39. else if (count == 2)
  40. {
  41. return RHI::Format::R8G8_UINT;
  42. }
  43. return RHI::Format::R8G8B8A8_UINT;
  44. }
  45. const char* ToString(RHI::ShadingRate rate)
  46. {
  47. switch (rate)
  48. {
  49. case RHI::ShadingRate::Rate1x1: return "Rate1x1";
  50. case RHI::ShadingRate::Rate1x2: return "Rate1x2";
  51. case RHI::ShadingRate::Rate2x1: return "Rate2x1";
  52. case RHI::ShadingRate::Rate2x2: return "Rate2x2";
  53. case RHI::ShadingRate::Rate2x4: return "Rate2x4";
  54. case RHI::ShadingRate::Rate4x2: return "Rate4x2";
  55. case RHI::ShadingRate::Rate4x4: return "Rate4x4";
  56. case RHI::ShadingRate::Rate4x1: return "Rate4x1";
  57. case RHI::ShadingRate::Rate1x4: return "Rate1x4";
  58. default: return "";
  59. }
  60. }
  61. void VariableRateShadingExampleComponent::Reflect(AZ::ReflectContext* context)
  62. {
  63. if (auto* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  64. {
  65. serializeContext->Class<VariableRateShadingExampleComponent, AZ::Component>()
  66. ->Version(0)
  67. ;
  68. }
  69. }
  70. VariableRateShadingExampleComponent::VariableRateShadingExampleComponent()
  71. {
  72. m_supportRHISamplePipeline = true;
  73. }
  74. void VariableRateShadingExampleComponent::OnTick([[maybe_unused]] float deltaTime, [[maybe_unused]] AZ::ScriptTimePoint time)
  75. {
  76. if (m_imguiSidebar.Begin())
  77. {
  78. DrawSettings();
  79. }
  80. }
  81. void VariableRateShadingExampleComponent::OnFramePrepare(AZ::RHI::FrameGraphBuilder& frameGraphBuilder)
  82. {
  83. if (m_windowContext->GetSwapChainsSize() && m_windowContext->GetSwapChain())
  84. {
  85. if (m_useImageShadingRate)
  86. {
  87. frameGraphBuilder.GetAttachmentDatabase().ImportImage(RHI::AttachmentId{ VariableRateShading::ShadingRateAttachmentId }, m_shadingRateImages[m_frameCount % m_shadingRateImages.size()]);
  88. if (!Utils::GetRHIDevice()->GetFeatures().m_dynamicShadingRateImage)
  89. {
  90. // We cannot update and use the same shading rate image because "m_dynamicShadingRateImage" is not supported.
  91. frameGraphBuilder.GetAttachmentDatabase().ImportImage(RHI::AttachmentId{ VariableRateShading::ShadingRateAttachmentUpdateId }, m_shadingRateImages[(m_frameCount + m_shadingRateImages.size() - 1) % m_shadingRateImages.size()]);
  92. }
  93. }
  94. m_frameCount++;
  95. }
  96. BasicRHIComponent::OnFramePrepare(frameGraphBuilder);
  97. }
  98. void VariableRateShadingExampleComponent::Activate()
  99. {
  100. AZ::TickBus::Handler::BusConnect();
  101. AZ::RHI::RHISystemNotificationBus::Handler::BusConnect();
  102. AzFramework::InputChannelEventListener::Connect();
  103. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  104. const auto& deviceFeatures = device->GetFeatures();
  105. if (!RHI::CheckBitsAll(deviceFeatures.m_shadingRateTypeMask, RHI::ShadingRateTypeFlags::PerRegion))
  106. {
  107. m_useImageShadingRate = false;
  108. }
  109. if (!RHI::CheckBitsAll(deviceFeatures.m_shadingRateTypeMask, RHI::ShadingRateTypeFlags::PerDraw))
  110. {
  111. m_useDrawShadingRate = false;
  112. }
  113. if (m_useImageShadingRate)
  114. {
  115. for (uint32_t i = 0; i < static_cast<uint32_t>(RHI::Format::Count); ++i)
  116. {
  117. RHI::Format format = static_cast<RHI::Format>(i);
  118. RHI::FormatCapabilities capabilities = device->GetFormatCapabilities(format);
  119. if (RHI::CheckBitsAll(capabilities, RHI::FormatCapabilities::ShadingRate))
  120. {
  121. m_rateShadingImageFormat = format;
  122. break;
  123. }
  124. }
  125. AZ_Assert(m_rateShadingImageFormat != RHI::Format::Unknown, "Could not find a format for the shading rate image");
  126. }
  127. const auto& supportedMask = device->GetFeatures().m_shadingRateMask;
  128. for (uint32_t i = 0; i < static_cast<uint32_t>(RHI::ShadingRate::Count); ++i)
  129. {
  130. if (RHI::CheckBitsAll(supportedMask, static_cast<RHI::ShadingRateFlags>(AZ_BIT(i))))
  131. {
  132. m_supportedModes.push_back(static_cast<RHI::ShadingRate>(i));
  133. }
  134. }
  135. m_shadingRate = m_supportedModes[0];
  136. CreateShadingRateImage();
  137. LoadShaders();
  138. CreateInputAssemblyBuffersAndViews();
  139. CreateShaderResourceGroups();
  140. CreatePipelines();
  141. CreateComputeScope();
  142. CreateRenderScope();
  143. CreatImageDisplayScope();
  144. m_frameCount = 0;
  145. }
  146. void VariableRateShadingExampleComponent::CreateShadingRateImage()
  147. {
  148. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  149. const auto& tileSize = device->GetLimits().m_shadingRateTileSize;
  150. m_shadingRateImageSize = Vector2(ceil(static_cast<float>(m_outputWidth) / tileSize.m_width), ceil(static_cast<float>(m_outputHeight) / tileSize.m_height));
  151. m_imagePool = aznew RHI::ImagePool();
  152. RHI::ImagePoolDescriptor imagePoolDesc;
  153. imagePoolDesc.m_bindFlags = RHI::ImageBindFlags::ShadingRate | RHI::ImageBindFlags::ShaderReadWrite;
  154. m_imagePool->Init(imagePoolDesc);
  155. // Initialize the shading rate images with proper values. Invalid values may cause a crash.
  156. uint32_t width = static_cast<uint32_t>(m_shadingRateImageSize.GetX());
  157. uint32_t height = static_cast<uint32_t>(m_shadingRateImageSize.GetY());
  158. uint32_t formatSize = GetFormatSize(m_rateShadingImageFormat);
  159. uint32_t bufferSize = width * height * formatSize;
  160. AZStd::vector<uint8_t> shadingRatePatternData(bufferSize);
  161. if (m_useImageShadingRate)
  162. {
  163. // Use the lowest shading rate as the default value.
  164. RHI::ShadingRateImageValue defaultValue = device->ConvertShadingRate(m_supportedModes[m_supportedModes.size() - 1]);
  165. uint8_t* ptrData = shadingRatePatternData.data();
  166. for (uint32_t y = 0; y < height; y++)
  167. {
  168. for (uint32_t x = 0; x < width; x++)
  169. {
  170. ::memcpy(ptrData, &defaultValue, formatSize);
  171. ptrData += formatSize;
  172. }
  173. }
  174. }
  175. // Since the device may not support "Dynamic Shading Rate Image", we need to buffer the update of the shading rate image
  176. // because the CPU may be trying to read the image.
  177. m_shadingRateImages.resize(device->GetFeatures().m_dynamicShadingRateImage ? 1 : device->GetDescriptor().m_frameCountMax+3);
  178. for (auto& image : m_shadingRateImages)
  179. {
  180. image = aznew RHI::Image();
  181. RHI::ImageInitRequest initImageRequest;
  182. RHI::ClearValue clearValue = RHI::ClearValue::CreateVector4Float(1, 1, 1, 1);
  183. initImageRequest.m_image = image.get();
  184. initImageRequest.m_descriptor = RHI::ImageDescriptor::Create2D(
  185. imagePoolDesc.m_bindFlags,
  186. static_cast<uint32_t>(m_shadingRateImageSize.GetX()),
  187. static_cast<uint32_t>(m_shadingRateImageSize.GetY()),
  188. m_rateShadingImageFormat);
  189. initImageRequest.m_optimizedClearValue = &clearValue;
  190. m_imagePool->InitImage(initImageRequest);
  191. RHI::ImageUpdateRequest request;
  192. request.m_image = image.get();
  193. request.m_sourceData = shadingRatePatternData.data();
  194. request.m_sourceSubresourceLayout = RHI::ImageSubresourceLayout();
  195. request.m_sourceSubresourceLayout.Init(
  196. RHI::MultiDevice::AllDevices, { RHI::Size(width, height, 1), height, width * formatSize, bufferSize, 1, 1 });
  197. m_imagePool->UpdateImageContents(request);
  198. }
  199. }
  200. void VariableRateShadingExampleComponent::LoadShaders()
  201. {
  202. const char* shaders[] =
  203. {
  204. "Shaders/RHI/VariableRateShading.azshader",
  205. "Shaders/RHI/VariableRateShadingCompute.azshader",
  206. "Shaders/RHI/VariableRateShadingImage.azshader"
  207. };
  208. m_shaders.resize(AZ_ARRAY_SIZE(shaders));
  209. for (size_t i = 0; i < AZ_ARRAY_SIZE(shaders); ++i)
  210. {
  211. auto shader = LoadShader(shaders[i], VariableRateShading::SampleName);
  212. if (shader == nullptr)
  213. {
  214. return;
  215. }
  216. m_shaders[i] = shader;
  217. }
  218. const auto& numThreads = m_shaders[1]->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name("numthreads"));
  219. if (numThreads)
  220. {
  221. const RHI::ShaderStageAttributeArguments& args = *numThreads;
  222. m_numThreadsX = args[0].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[0]) : m_numThreadsX;
  223. m_numThreadsY = args[1].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[1]) : m_numThreadsY;
  224. m_numThreadsZ = args[2].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[2]) : m_numThreadsZ;
  225. }
  226. }
  227. void VariableRateShadingExampleComponent::CreateShaderResourceGroups()
  228. {
  229. const Name albedoId{ "m_texture" };
  230. auto textureIamge = LoadStreamingImage("textures/bricks.png.streamingimage", VariableRateShading::SampleName);
  231. AZ::RHI::ShaderInputImageIndex albedoIndex;
  232. m_modelShaderResourceGroup = CreateShaderResourceGroup(m_shaders[0], "InstanceSrg", VariableRateShading::SampleName);
  233. FindShaderInputIndex(&albedoIndex, m_modelShaderResourceGroup, albedoId, VariableRateShading::SampleName);
  234. m_modelShaderResourceGroup->SetImage(albedoIndex, textureIamge);
  235. m_modelShaderResourceGroup->Compile();
  236. const Name centerId{ "m_center" };
  237. const Name distancesId{ "m_distances" };
  238. const Name patternId{ "m_pattern" };
  239. const Name shadingRateImageId{ "m_shadingRateTexture" };
  240. AZ::RHI::ShaderInputConstantIndex patternIndex;
  241. m_computeShaderResourceGroup = CreateShaderResourceGroup(m_shaders[1], "ComputeSrg", VariableRateShading::SampleName);
  242. FindShaderInputIndex(&patternIndex, m_computeShaderResourceGroup, patternId, VariableRateShading::SampleName);
  243. FindShaderInputIndex(&m_centerIndex, m_computeShaderResourceGroup, centerId, VariableRateShading::SampleName);
  244. FindShaderInputIndex(&m_shadingRateIndex, m_computeShaderResourceGroup, shadingRateImageId, VariableRateShading::SampleName);
  245. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  246. constexpr uint32_t elementsCount = 4;
  247. struct Pattern
  248. {
  249. float m_distance[elementsCount];
  250. uint32_t m_rate[elementsCount];
  251. };
  252. struct Color
  253. {
  254. float m_color[4];
  255. uint32_t m_rate[4];
  256. };
  257. const float alpha = 0.3f;
  258. const uint32_t numRates = static_cast<uint32_t>(RHI::ShadingRate::Count);
  259. AZStd::array<Pattern, numRates> pattern;
  260. AZStd::array<Color, numRates> patternColors;
  261. AZStd::array<AZ::Color, numRates> colors =
  262. {{
  263. AZ::Color(0.f, 0.f, 1.f, alpha),
  264. AZ::Color(1.f, 0.f, 0.f, alpha),
  265. AZ::Color(0.f, 1.f, 0.f, alpha),
  266. AZ::Color(1.f, 0.f, 1.f, alpha),
  267. AZ::Color(1.f, 1.f, 0.f, alpha),
  268. AZ::Color(0.f, 1.f, 1.f, alpha),
  269. AZ::Color(1.f, 1.f, 1.f, alpha)
  270. }};
  271. float range = 60.0f / numRates;
  272. float currentRange = 8.0f;
  273. const auto& supportedMask = device->GetFeatures().m_shadingRateMask;
  274. for (uint32_t i = 0; i < pattern.size(); ++i)
  275. {
  276. RHI::ShadingRateImageValue rate = {};
  277. pattern[i].m_distance[0] = 0.0f;
  278. if (RHI::CheckBitsAll(supportedMask, static_cast<RHI::ShadingRateFlags>(AZ_BIT(i))))
  279. {
  280. rate = device->ConvertShadingRate(static_cast<RHI::ShadingRate>(i));
  281. pattern[i].m_distance[0] = currentRange;
  282. }
  283. pattern[i].m_rate[0] = rate.m_x;
  284. pattern[i].m_rate[1] = rate.m_y;
  285. currentRange += range;
  286. patternColors[i].m_rate[0] = pattern[i].m_rate[0];
  287. patternColors[i].m_rate[1] = pattern[i].m_rate[1];
  288. colors[i].StoreToFloat4(patternColors[i].m_color);
  289. }
  290. Vector2 center(static_cast<float>(m_shadingRateImageSize.GetX()) * 0.5f, static_cast<float>(m_shadingRateImageSize.GetY()) * 0.5f);
  291. m_computeShaderResourceGroup->SetConstant(m_centerIndex, center);
  292. m_computeShaderResourceGroup->SetConstantArray(patternIndex, pattern);
  293. const Name colorsId{ "m_colors" };
  294. const Name textureId{ "m_texture" };
  295. AZ::RHI::ShaderInputConstantIndex colorsIndex;
  296. m_imageShaderResourceGroup = CreateShaderResourceGroup(m_shaders[2], "InstanceSrg", VariableRateShading::SampleName);
  297. FindShaderInputIndex(&colorsIndex, m_imageShaderResourceGroup, colorsId, VariableRateShading::SampleName);
  298. FindShaderInputIndex(&m_shadingRateDisplayIndex, m_imageShaderResourceGroup, textureId, VariableRateShading::SampleName);
  299. m_imageShaderResourceGroup->SetConstantArray(colorsIndex, patternColors);
  300. }
  301. void VariableRateShadingExampleComponent::CreatePipelines()
  302. {
  303. {
  304. // We create one pipeline when using a shading rate attachment, and another one when we are not using it.
  305. RHI::RenderAttachmentLayoutBuilder shadingRateAttachmentsBuilder;
  306. shadingRateAttachmentsBuilder.AddSubpass()
  307. ->RenderTargetAttachment(m_outputFormat)
  308. ->ShadingRateAttachment(m_rateShadingImageFormat);
  309. RHI::RenderAttachmentLayout shadingRateRenderAttachmentLayout;
  310. [[maybe_unused]] RHI::ResultCode result = shadingRateAttachmentsBuilder.End(shadingRateRenderAttachmentLayout);
  311. AZ_Assert(result == RHI::ResultCode::Success, "Failed to create render attachment layout");
  312. const auto& shader = m_shaders[0];
  313. auto& variant = shader->GetVariant(AZ::RPI::ShaderAsset::RootShaderVariantStableId);
  314. RHI::PipelineStateDescriptorForDraw pipelineDesc;
  315. variant.ConfigurePipelineState(pipelineDesc);
  316. pipelineDesc.m_renderStates.m_depthStencilState = RHI::DepthStencilState::CreateDisabled();
  317. pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout = shadingRateRenderAttachmentLayout;
  318. pipelineDesc.m_renderAttachmentConfiguration.m_subpassIndex = 0;
  319. pipelineDesc.m_inputStreamLayout = m_inputStreamLayout;
  320. m_modelPipelineState[0] = shader->AcquirePipelineState(pipelineDesc);
  321. if (!m_modelPipelineState[0])
  322. {
  323. AZ_Error(VariableRateShading::SampleName, false, "Failed to acquire default pipeline state for shader");
  324. return;
  325. }
  326. RHI::RenderAttachmentLayoutBuilder attachmentsBuilder;
  327. attachmentsBuilder.AddSubpass()
  328. ->RenderTargetAttachment(m_outputFormat);
  329. RHI::RenderAttachmentLayout rateRenderAttachmentLayout;
  330. result = attachmentsBuilder.End(rateRenderAttachmentLayout);
  331. AZ_Assert(result == RHI::ResultCode::Success, "Failed to create render attachment layout");
  332. pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout = rateRenderAttachmentLayout;
  333. m_modelPipelineState[1] = shader->AcquirePipelineState(pipelineDesc);
  334. if (!m_modelPipelineState[1])
  335. {
  336. AZ_Error(VariableRateShading::SampleName, false, "Failed to acquire default pipeline state for shader");
  337. return;
  338. }
  339. }
  340. {
  341. RHI::PipelineStateDescriptorForDispatch pipelineDesc;
  342. const auto& shader = m_shaders[1];
  343. shader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId).ConfigurePipelineState(pipelineDesc);
  344. m_computePipelineState = shader->AcquirePipelineState(pipelineDesc);
  345. if (!m_computePipelineState)
  346. {
  347. AZ_Error(VariableRateShading::SampleName, false, "Failed to acquire default pipeline state for compute");
  348. return;
  349. }
  350. }
  351. {
  352. RHI::RenderAttachmentLayoutBuilder attachmentsBuilder;
  353. attachmentsBuilder.AddSubpass()
  354. ->RenderTargetAttachment(m_outputFormat);
  355. RHI::RenderAttachmentLayout renderAttachmentLayout;
  356. [[maybe_unused]] RHI::ResultCode result = attachmentsBuilder.End(renderAttachmentLayout);
  357. AZ_Assert(result == RHI::ResultCode::Success, "Failed to create render attachment layout");
  358. const auto& shader = m_shaders[2];
  359. auto& variant = shader->GetVariant(AZ::RPI::ShaderAsset::RootShaderVariantStableId);
  360. RHI::PipelineStateDescriptorForDraw pipelineDesc;
  361. variant.ConfigurePipelineState(pipelineDesc);
  362. pipelineDesc.m_renderStates.m_depthStencilState = RHI::DepthStencilState::CreateDisabled();
  363. pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout = renderAttachmentLayout;
  364. pipelineDesc.m_renderAttachmentConfiguration.m_subpassIndex = 0;
  365. pipelineDesc.m_inputStreamLayout = m_inputStreamLayout;
  366. RHI::TargetBlendState& targetBlendState = pipelineDesc.m_renderStates.m_blendState.m_targets[0];
  367. targetBlendState.m_enable = true;
  368. targetBlendState.m_blendSource = RHI::BlendFactor::AlphaSource;
  369. targetBlendState.m_blendDest = RHI::BlendFactor::AlphaSourceInverse;
  370. targetBlendState.m_blendOp = RHI::BlendOp::Add;
  371. m_imagePipelineState = shader->AcquirePipelineState(pipelineDesc);
  372. if (!m_imagePipelineState)
  373. {
  374. AZ_Error(VariableRateShading::SampleName, false, "Failed to acquire default pipeline state for shader");
  375. return;
  376. }
  377. }
  378. }
  379. void VariableRateShadingExampleComponent::CreateInputAssemblyBuffersAndViews()
  380. {
  381. m_bufferPool = aznew RHI::BufferPool();
  382. RHI::BufferPoolDescriptor bufferPoolDesc;
  383. bufferPoolDesc.m_bindFlags = RHI::BufferBindFlags::InputAssembly;
  384. bufferPoolDesc.m_heapMemoryLevel = RHI::HeapMemoryLevel::Device;
  385. m_bufferPool->Init(bufferPoolDesc);
  386. struct BufferData
  387. {
  388. AZStd::array<VertexPosition, 4> m_positions;
  389. AZStd::array<VertexUV, 4> m_uvs;
  390. AZStd::array<uint16_t, 6> m_indices;
  391. };
  392. BufferData bufferData;
  393. SetFullScreenRect(bufferData.m_positions.data(), bufferData.m_uvs.data(), bufferData.m_indices.data());
  394. m_inputAssemblyBuffer = aznew RHI::Buffer();
  395. RHI::ResultCode result = RHI::ResultCode::Success;
  396. RHI::BufferInitRequest request;
  397. request.m_buffer = m_inputAssemblyBuffer.get();
  398. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(bufferData) };
  399. request.m_initialData = &bufferData;
  400. result = m_bufferPool->InitBuffer(request);
  401. if (result != RHI::ResultCode::Success)
  402. {
  403. AZ_Error(VariableRateShading::SampleName, false, "Failed to initialize buffer with error code %d", result);
  404. return;
  405. }
  406. m_geometryView.SetDrawArguments(RHI::DrawIndexed(0, 6, 0));
  407. m_geometryView.AddStreamBufferView({
  408. *m_inputAssemblyBuffer,
  409. offsetof(BufferData, m_positions),
  410. sizeof(BufferData::m_positions),
  411. sizeof(VertexPosition)
  412. });
  413. m_geometryView.AddStreamBufferView({
  414. *m_inputAssemblyBuffer,
  415. offsetof(BufferData, m_uvs),
  416. sizeof(BufferData::m_uvs),
  417. sizeof(VertexUV)
  418. });
  419. m_geometryView.SetIndexBufferView({
  420. *m_inputAssemblyBuffer,
  421. offsetof(BufferData, m_indices),
  422. sizeof(BufferData::m_indices),
  423. RHI::IndexFormat::Uint16
  424. });
  425. RHI::InputStreamLayoutBuilder layoutBuilder;
  426. layoutBuilder.AddBuffer()->Channel("POSITION", RHI::Format::R32G32B32_FLOAT);
  427. layoutBuilder.AddBuffer()->Channel("UV", RHI::Format::R32G32_FLOAT);
  428. m_inputStreamLayout = layoutBuilder.End();
  429. RHI::ValidateStreamBufferViews(m_inputStreamLayout, m_geometryView, m_geometryView.GetFullStreamBufferIndices());
  430. }
  431. void VariableRateShadingExampleComponent::CreateRenderScope()
  432. {
  433. struct ScopeData
  434. {
  435. };
  436. const auto prepareFunction = [this](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  437. {
  438. {
  439. // Binds the swap chain as a color attachment.
  440. RHI::ImageScopeAttachmentDescriptor descriptor;
  441. descriptor.m_attachmentId = m_outputAttachmentId;
  442. descriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  443. frameGraph.UseColorAttachment(descriptor);
  444. }
  445. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  446. bool useImageShadingRate = m_useImageShadingRate && (device->GetFeatures().m_dynamicShadingRateImage || m_frameCount > device->GetDescriptor().m_frameCountMax);
  447. if (useImageShadingRate)
  448. {
  449. // Binds the shading rate image attachment
  450. AZ::RHI::ImageScopeAttachmentDescriptor dsDesc;
  451. dsDesc.m_attachmentId = VariableRateShading::ShadingRateAttachmentId;
  452. dsDesc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
  453. dsDesc.m_loadStoreAction.m_storeAction = AZ::RHI::AttachmentStoreAction::DontCare;
  454. frameGraph.UseAttachment(
  455. dsDesc, AZ::RHI::ScopeAttachmentAccess::Read, AZ::RHI::ScopeAttachmentUsage::ShadingRate,
  456. AZ::RHI::ScopeAttachmentStage::ShadingRate);
  457. }
  458. frameGraph.SetEstimatedItemCount(1);
  459. };
  460. RHI::EmptyCompileFunction<ScopeData> compileFunction;
  461. const auto executeFunction = [this]([[maybe_unused]] const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  462. {
  463. RHI::CommandList* commandList = context.GetCommandList();
  464. // Set persistent viewport and scissor state.
  465. commandList->SetViewports(&m_viewport, 1);
  466. commandList->SetScissors(&m_scissor, 1);
  467. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  468. if (m_useDrawShadingRate)
  469. {
  470. RHI::ShadingRateCombinators combinators = { RHI::ShadingRateCombinerOp::Passthrough, m_combinerOp };
  471. commandList->SetFragmentShadingRate(m_shadingRate, combinators);
  472. }
  473. const RHI::DeviceShaderResourceGroup* shaderResourceGroups[] = {
  474. m_modelShaderResourceGroup->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get()
  475. };
  476. // We have to wait until the updating of the initial contents of the shading rate image is done if
  477. // dynamic mode is not supported (since the CPU would try to read it while the GPU is updating the contents)
  478. bool useImageShadingRate = m_useImageShadingRate && (device->GetFeatures().m_dynamicShadingRateImage || m_frameCount > device->GetDescriptor().m_frameCountMax);
  479. RHI::DeviceDrawItem drawItem;
  480. drawItem.m_geometryView = m_geometryView.GetDeviceGeometryView(context.GetDeviceIndex());
  481. drawItem.m_streamIndices = m_geometryView.GetFullStreamBufferIndices();
  482. drawItem.m_pipelineState = m_modelPipelineState[useImageShadingRate ? 0 : 1]->GetDevicePipelineState(context.GetDeviceIndex()).get();
  483. drawItem.m_shaderResourceGroupCount = static_cast<uint8_t>(RHI::ArraySize(shaderResourceGroups));;
  484. drawItem.m_shaderResourceGroups = shaderResourceGroups;
  485. commandList->Submit(drawItem);
  486. if (m_useDrawShadingRate)
  487. {
  488. commandList->SetFragmentShadingRate(RHI::ShadingRate::Rate1x1);
  489. }
  490. };
  491. const RHI::ScopeId forwardScope("SceneScope");
  492. m_scopeProducers.emplace_back(
  493. aznew RHI::ScopeProducerFunction<
  494. ScopeData,
  495. decltype(prepareFunction),
  496. decltype(compileFunction),
  497. decltype(executeFunction)>(
  498. forwardScope,
  499. ScopeData{},
  500. prepareFunction,
  501. compileFunction,
  502. executeFunction));
  503. }
  504. void VariableRateShadingExampleComponent::CreateComputeScope()
  505. {
  506. struct ScopeData
  507. {
  508. };
  509. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  510. const auto& deviceFeatures = device->GetFeatures();
  511. // If "m_dynamicShadingRateImage" is not supported we cannot update the same image that is being used as shading rate this frame.
  512. // We use an "old" one that is not longer in used.
  513. const char* shadingRateAttachmentId = deviceFeatures.m_dynamicShadingRateImage ? VariableRateShading::ShadingRateAttachmentId : VariableRateShading::ShadingRateAttachmentUpdateId;
  514. const auto prepareFunction = [this, shadingRateAttachmentId](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  515. {
  516. if (m_useImageShadingRate)
  517. {
  518. RHI::ImageScopeAttachmentDescriptor shadingRateImageDesc;
  519. shadingRateImageDesc.m_attachmentId = shadingRateAttachmentId;
  520. shadingRateImageDesc.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::DontCare;
  521. shadingRateImageDesc.m_loadStoreAction.m_storeAction = RHI::AttachmentStoreAction::Store;
  522. shadingRateImageDesc.m_imageViewDescriptor.m_overrideFormat = ConvertToUInt(m_rateShadingImageFormat);
  523. frameGraph.UseShaderAttachment(
  524. shadingRateImageDesc, RHI::ScopeAttachmentAccess::Write, RHI::ScopeAttachmentStage::ComputeShader);
  525. }
  526. frameGraph.SetEstimatedItemCount(1);
  527. };
  528. const auto compileFunction = [this, shadingRateAttachmentId](const RHI::FrameGraphCompileContext& context, [[maybe_unused]] const ScopeData& scopeData)
  529. {
  530. if (m_useImageShadingRate)
  531. {
  532. Vector2 center = m_cursorPos * m_shadingRateImageSize;
  533. const auto* shadingRateImageView = context.GetImageView(RHI::AttachmentId(shadingRateAttachmentId));
  534. m_computeShaderResourceGroup->SetImageView(m_shadingRateIndex, shadingRateImageView);
  535. m_computeShaderResourceGroup->SetConstant(m_centerIndex, center);
  536. m_computeShaderResourceGroup->Compile();
  537. }
  538. };
  539. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  540. {
  541. if (!m_useImageShadingRate)
  542. {
  543. return;
  544. }
  545. RHI::CommandList* commandList = context.GetCommandList();
  546. RHI::DeviceDispatchItem dispatchItem;
  547. decltype(dispatchItem.m_shaderResourceGroups) shaderResourceGroups = { { m_computeShaderResourceGroup->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get() } };
  548. RHI::DispatchDirect dispatchArgs;
  549. dispatchArgs.m_totalNumberOfThreadsX = aznumeric_cast<uint32_t>(m_shadingRateImageSize.GetX());
  550. dispatchArgs.m_threadsPerGroupX = aznumeric_cast<uint16_t>(m_numThreadsX);
  551. dispatchArgs.m_totalNumberOfThreadsY = aznumeric_cast<uint32_t>(m_shadingRateImageSize.GetY());
  552. dispatchArgs.m_threadsPerGroupY = aznumeric_cast<uint16_t>(m_numThreadsY);
  553. dispatchArgs.m_totalNumberOfThreadsZ = 1;
  554. dispatchArgs.m_threadsPerGroupZ = aznumeric_cast<uint16_t>(m_numThreadsZ);
  555. AZ_Assert(dispatchArgs.m_threadsPerGroupX == dispatchArgs.m_threadsPerGroupY, "If the shader source changes, this logic should change too.");
  556. AZ_Assert(dispatchArgs.m_threadsPerGroupZ == 1, "If the shader source changes, this logic should change too.");
  557. dispatchItem.m_arguments = dispatchArgs;
  558. dispatchItem.m_pipelineState = m_computePipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
  559. dispatchItem.m_shaderResourceGroupCount = 1;
  560. dispatchItem.m_shaderResourceGroups = shaderResourceGroups;
  561. commandList->Submit(dispatchItem);
  562. };
  563. const RHI::ScopeId computeScope("ShadingRateImageCompute");
  564. m_scopeProducers.emplace_back(
  565. aznew RHI::ScopeProducerFunction<
  566. ScopeData,
  567. decltype(prepareFunction),
  568. decltype(compileFunction),
  569. decltype(executeFunction)>(
  570. computeScope,
  571. ScopeData{},
  572. prepareFunction,
  573. compileFunction,
  574. executeFunction));
  575. }
  576. void VariableRateShadingExampleComponent::CreatImageDisplayScope()
  577. {
  578. struct ScopeData
  579. {
  580. };
  581. const auto prepareFunction = [this](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  582. {
  583. {
  584. // Binds the swap chain as a color attachment.
  585. RHI::ImageScopeAttachmentDescriptor descriptor;
  586. descriptor.m_attachmentId = m_outputAttachmentId;
  587. descriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  588. frameGraph.UseColorAttachment(descriptor);
  589. }
  590. if (m_showShadingRateImage)
  591. {
  592. // Binds the shading rate image for reading (not as attachment)
  593. RHI::ImageScopeAttachmentDescriptor shadingRateImageDesc;
  594. shadingRateImageDesc.m_attachmentId = VariableRateShading::ShadingRateAttachmentId;
  595. shadingRateImageDesc.m_loadStoreAction.m_storeAction = RHI::AttachmentStoreAction::DontCare;
  596. shadingRateImageDesc.m_imageViewDescriptor.m_overrideFormat = ConvertToUInt(m_rateShadingImageFormat);
  597. frameGraph.UseShaderAttachment(
  598. shadingRateImageDesc, RHI::ScopeAttachmentAccess::Read, RHI::ScopeAttachmentStage::FragmentShader);
  599. }
  600. frameGraph.SetEstimatedItemCount(1);
  601. };
  602. const auto compileFunction = [this](const RHI::FrameGraphCompileContext& context, [[maybe_unused]] const ScopeData& scopeData)
  603. {
  604. if (m_showShadingRateImage)
  605. {
  606. const auto* shadingRateImageView = context.GetImageView(RHI::AttachmentId(VariableRateShading::ShadingRateAttachmentId));
  607. m_imageShaderResourceGroup->SetImageView(m_shadingRateDisplayIndex, shadingRateImageView);
  608. m_imageShaderResourceGroup->Compile();
  609. }
  610. };
  611. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  612. {
  613. if (!m_showShadingRateImage)
  614. {
  615. return;
  616. }
  617. RHI::CommandList* commandList = context.GetCommandList();
  618. // Set persistent viewport and scissor state.
  619. commandList->SetViewports(&m_viewport, 1);
  620. commandList->SetScissors(&m_scissor, 1);
  621. const RHI::DeviceShaderResourceGroup* shaderResourceGroups[] = {
  622. m_imageShaderResourceGroup->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get()
  623. };
  624. RHI::DeviceDrawItem drawItem;
  625. drawItem.m_geometryView = m_geometryView.GetDeviceGeometryView(context.GetDeviceIndex());
  626. drawItem.m_streamIndices = m_geometryView.GetFullStreamBufferIndices();
  627. drawItem.m_pipelineState = m_imagePipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
  628. drawItem.m_shaderResourceGroupCount = static_cast<uint8_t>(RHI::ArraySize(shaderResourceGroups));
  629. drawItem.m_shaderResourceGroups = shaderResourceGroups;
  630. commandList->Submit(drawItem);
  631. };
  632. const RHI::ScopeId forwardScope("ImageDisplayScope");
  633. m_scopeProducers.emplace_back(
  634. aznew RHI::ScopeProducerFunction<
  635. ScopeData,
  636. decltype(prepareFunction),
  637. decltype(compileFunction),
  638. decltype(executeFunction)>(
  639. forwardScope,
  640. ScopeData{},
  641. prepareFunction,
  642. compileFunction,
  643. executeFunction));
  644. }
  645. void VariableRateShadingExampleComponent::Deactivate()
  646. {
  647. m_imguiSidebar.Deactivate();
  648. AZ::RHI::RHISystemNotificationBus::Handler::BusDisconnect();
  649. AZ::TickBus::Handler::BusDisconnect();
  650. AzFramework::InputChannelEventListener::BusDisconnect();
  651. m_bufferPool = nullptr;
  652. m_inputAssemblyBuffer = nullptr;
  653. m_modelPipelineState[0] = nullptr;
  654. m_modelPipelineState[1] = nullptr;
  655. m_imagePipelineState = nullptr;
  656. m_modelShaderResourceGroup = nullptr;
  657. m_computeShaderResourceGroup = nullptr;
  658. m_imageShaderResourceGroup = nullptr;
  659. m_shaders.clear();
  660. m_supportedModes.clear();
  661. m_windowContext = nullptr;
  662. m_imagePool = nullptr;
  663. m_shadingRateImages.clear();
  664. m_scopeProducers.clear();
  665. }
  666. void VariableRateShadingExampleComponent::DrawSettings()
  667. {
  668. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  669. const auto& deviceFeatures = device->GetFeatures();
  670. ImGui::Spacing();
  671. if (RHI::CheckBitsAll(deviceFeatures.m_shadingRateTypeMask, RHI::ShadingRateTypeFlags::PerRegion))
  672. {
  673. ScriptableImGui::Checkbox("Image Shade Rate", &m_useImageShadingRate);
  674. if (m_useImageShadingRate)
  675. {
  676. ImGui::Indent();
  677. ScriptableImGui::Checkbox("Show Image", &m_showShadingRateImage);
  678. ScriptableImGui::Checkbox("Follow Pointer", &m_followPointer);
  679. ImGui::Unindent();
  680. }
  681. else
  682. {
  683. m_showShadingRateImage = false;
  684. m_followPointer = false;
  685. }
  686. }
  687. if (RHI::CheckBitsAll(deviceFeatures.m_shadingRateTypeMask, RHI::ShadingRateTypeFlags::PerDraw))
  688. {
  689. ScriptableImGui::Checkbox("Draw Shade Rate", &m_useDrawShadingRate);
  690. if (m_useDrawShadingRate)
  691. {
  692. ImGui::Indent();
  693. AZStd::vector<const char*> items;
  694. for(const auto rate : m_supportedModes)
  695. {
  696. items.push_back(ToString(rate));
  697. }
  698. int current_item = static_cast<int>(AZStd::distance(m_supportedModes.begin(), AZStd::find(m_supportedModes.begin(), m_supportedModes.end(), m_shadingRate)));
  699. ScriptableImGui::Combo("Shading Rates", &current_item, items.data(), static_cast<int>(items.size()));
  700. m_shadingRate = m_supportedModes[current_item];
  701. ImGui::Unindent();
  702. }
  703. }
  704. if (m_useDrawShadingRate && m_useImageShadingRate)
  705. {
  706. AZStd::vector<const char*> items = { "Passthrough", "Override", "Min", "Max" };
  707. int current_item = static_cast<int>(m_combinerOp);
  708. ScriptableImGui::Combo("Combiner Op", &current_item, items.data(), static_cast<int>(items.size()));
  709. m_combinerOp = static_cast<RHI::ShadingRateCombinerOp>(current_item);
  710. }
  711. else if(m_useDrawShadingRate)
  712. {
  713. m_combinerOp = RHI::ShadingRateCombinerOp::Passthrough;
  714. }
  715. if (!m_followPointer)
  716. {
  717. m_cursorPos = AZ::Vector2(0.5f, 0.5f);
  718. }
  719. m_imguiSidebar.End();
  720. }
  721. bool VariableRateShadingExampleComponent::OnInputChannelEventFiltered(const AzFramework::InputChannel& inputChannel)
  722. {
  723. if (m_followPointer)
  724. {
  725. const AzFramework::InputChannelId& inputChannelId = inputChannel.GetInputChannelId();
  726. switch (inputChannel.GetState())
  727. {
  728. case AzFramework::InputChannel::State::Began:
  729. case AzFramework::InputChannel::State::Updated: // update the camera rotation
  730. {
  731. const AzFramework::InputChannel::PositionData2D* position = nullptr;
  732. // Mouse or Touch Events
  733. if (inputChannelId == AzFramework::InputDeviceMouse::SystemCursorPosition ||
  734. inputChannelId == AzFramework::InputDeviceTouch::Touch::Index0)
  735. {
  736. position = inputChannel.GetCustomData<AzFramework::InputChannel::PositionData2D>();
  737. }
  738. if (position)
  739. {
  740. m_cursorPos = position->m_normalizedPosition;
  741. }
  742. break;
  743. }
  744. default:
  745. break;
  746. }
  747. }
  748. return false;
  749. }
  750. } // namespace AtomSampleViewer