VariableRateShadingExampleComponent.cpp 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include <RHI/VariableRateShadingExampleComponent.h>
  9. #include <Utils/Utils.h>
  10. #include <SampleComponentManager.h>
  11. #include <Atom/RHI/CommandList.h>
  12. #include <Atom/RHI.Reflect/InputStreamLayoutBuilder.h>
  13. #include <Atom/RHI.Reflect/RenderAttachmentLayoutBuilder.h>
  14. #include <Atom/RHI.Reflect/VariableRateShadingEnums.h>
  15. #include <Atom/RPI.Public/Shader/Shader.h>
  16. #include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
  17. #include <AzCore/Serialization/SerializeContext.h>
  18. #include <AzCore/Math/MathUtils.h>
  19. #include <AzCore/std/containers/span.h>
  20. #include <AzCore/Math/MatrixUtils.h>
  21. #include <AzFramework/Input/Devices/Mouse/InputDeviceMouse.h>
  22. #include <AzFramework/Input/Devices/Touch/InputDeviceTouch.h>
  23. namespace AtomSampleViewer
  24. {
  25. using namespace AZ;
  namespace VariableRateShading
  {
      // Category passed to AZ_Error and to the asset/SRG loading helpers by this sample.
      // constexpr makes the pointers immutable (they were reassignable globals before).
      constexpr const char* SampleName = "VariableRateShadingExample";
      // Attachment id under which the shading rate image read by the render scope is imported each frame.
      constexpr const char* ShadingRateAttachmentId = "ShadingRateAttachmentId";
      // Attachment id of the image the compute scope writes when the device lacks
      // "m_dynamicShadingRateImage" support (the image in use cannot be updated then).
      constexpr const char* ShadingRateAttachmentUpdateId = "ShadingRateAttachmentUpdateId";
  }
  32. RHI::Format ConvertToUInt(RHI::Format format)
  33. {
  34. uint32_t count = GetFormatComponentCount(format);
  35. if (count == 1)
  36. {
  37. return RHI::Format::R8_UINT;
  38. }
  39. else if (count == 2)
  40. {
  41. return RHI::Format::R8G8_UINT;
  42. }
  43. return RHI::Format::R8G8B8A8_UINT;
  44. }
  45. const char* ToString(RHI::ShadingRate rate)
  46. {
  47. switch (rate)
  48. {
  49. case RHI::ShadingRate::Rate1x1: return "Rate1x1";
  50. case RHI::ShadingRate::Rate1x2: return "Rate1x2";
  51. case RHI::ShadingRate::Rate2x1: return "Rate2x1";
  52. case RHI::ShadingRate::Rate2x2: return "Rate2x2";
  53. case RHI::ShadingRate::Rate2x4: return "Rate2x4";
  54. case RHI::ShadingRate::Rate4x2: return "Rate4x2";
  55. case RHI::ShadingRate::Rate4x4: return "Rate4x4";
  56. default: return "";
  57. }
  58. }
  59. void VariableRateShadingExampleComponent::Reflect(AZ::ReflectContext* context)
  60. {
  61. if (auto* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
  62. {
  63. serializeContext->Class<VariableRateShadingExampleComponent, AZ::Component>()
  64. ->Version(0)
  65. ;
  66. }
  67. }
  68. VariableRateShadingExampleComponent::VariableRateShadingExampleComponent()
  69. {
  70. m_supportRHISamplePipeline = true;
  71. }
  72. void VariableRateShadingExampleComponent::OnTick([[maybe_unused]] float deltaTime, [[maybe_unused]] AZ::ScriptTimePoint time)
  73. {
  74. if (m_imguiSidebar.Begin())
  75. {
  76. DrawSettings();
  77. }
  78. }
  79. void VariableRateShadingExampleComponent::OnFramePrepare(AZ::RHI::FrameGraphBuilder& frameGraphBuilder)
  80. {
  81. if (m_windowContext->GetSwapChainsSize() && m_windowContext->GetSwapChain())
  82. {
  83. if (m_useImageShadingRate)
  84. {
  85. frameGraphBuilder.GetAttachmentDatabase().ImportImage(RHI::AttachmentId{ VariableRateShading::ShadingRateAttachmentId }, m_shadingRateImages[m_frameCount % m_shadingRateImages.size()]);
  86. if (!Utils::GetRHIDevice()->GetFeatures().m_dynamicShadingRateImage)
  87. {
  88. // We cannot update and use the same shading rate image because "m_dynamicShadingRateImage" is not supported.
  89. frameGraphBuilder.GetAttachmentDatabase().ImportImage(RHI::AttachmentId{ VariableRateShading::ShadingRateAttachmentUpdateId }, m_shadingRateImages[(m_frameCount + m_shadingRateImages.size() - 1) % m_shadingRateImages.size()]);
  90. }
  91. }
  92. m_frameCount++;
  93. }
  94. BasicRHIComponent::OnFramePrepare(frameGraphBuilder);
  95. }
  96. void VariableRateShadingExampleComponent::Activate()
  97. {
  98. AZ::TickBus::Handler::BusConnect();
  99. AZ::RHI::RHISystemNotificationBus::Handler::BusConnect();
  100. AzFramework::InputChannelEventListener::Connect();
  101. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  102. const auto& deviceFeatures = device->GetFeatures();
  103. if (!RHI::CheckBitsAll(deviceFeatures.m_shadingRateTypeMask, RHI::ShadingRateTypeFlags::PerImage))
  104. {
  105. m_useImageShadingRate = false;
  106. }
  107. if (!RHI::CheckBitsAll(deviceFeatures.m_shadingRateTypeMask, RHI::ShadingRateTypeFlags::PerDraw))
  108. {
  109. m_useDrawShadingRate = false;
  110. }
  111. if (m_useImageShadingRate)
  112. {
  113. for (uint32_t i = 0; i < static_cast<uint32_t>(RHI::Format::Count); ++i)
  114. {
  115. RHI::Format format = static_cast<RHI::Format>(i);
  116. RHI::FormatCapabilities capabilities = device->GetFormatCapabilities(format);
  117. if (RHI::CheckBitsAll(capabilities, RHI::FormatCapabilities::ShadingRate))
  118. {
  119. m_rateShadingImageFormat = format;
  120. break;
  121. }
  122. }
  123. AZ_Assert(m_rateShadingImageFormat != RHI::Format::Unknown, "Could not find a format for the shading rate image");
  124. }
  125. const auto& supportedMask = device->GetFeatures().m_shadingRateMask;
  126. for (uint32_t i = 0; i < static_cast<uint32_t>(RHI::ShadingRate::Count); ++i)
  127. {
  128. if (RHI::CheckBitsAll(supportedMask, static_cast<RHI::ShadingRateFlags>(AZ_BIT(i))))
  129. {
  130. m_supportedModes.push_back(static_cast<RHI::ShadingRate>(i));
  131. }
  132. }
  133. m_shadingRate = m_supportedModes[0];
  134. CreateShadingRateImage();
  135. LoadShaders();
  136. CreateInputAssemblyBuffersAndViews();
  137. CreateShaderResourceGroups();
  138. CreatePipelines();
  139. CreateComputeScope();
  140. CreateRenderScope();
  141. CreatImageDisplayScope();
  142. m_frameCount = 0;
  143. }
  144. void VariableRateShadingExampleComponent::CreateShadingRateImage()
  145. {
  146. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  147. const auto& tileSize = device->GetLimits().m_shadingRateTileSize;
  148. m_shadingRateImageSize = Vector2(ceil(static_cast<float>(m_outputWidth) / tileSize.m_width), ceil(static_cast<float>(m_outputHeight) / tileSize.m_height));
  149. m_imagePool = RHI::Factory::Get().CreateImagePool();
  150. RHI::ImagePoolDescriptor imagePoolDesc;
  151. imagePoolDesc.m_bindFlags = RHI::ImageBindFlags::ShadingRate | RHI::ImageBindFlags::ShaderReadWrite;
  152. m_imagePool->Init(*device, imagePoolDesc);
  153. // Initialize the shading rate images with proper values. Invalid values may cause a crash.
  154. uint32_t width = static_cast<uint32_t>(m_shadingRateImageSize.GetX());
  155. uint32_t height = static_cast<uint32_t>(m_shadingRateImageSize.GetY());
  156. uint32_t formatSize = GetFormatSize(m_rateShadingImageFormat);
  157. uint32_t bufferSize = width * height * formatSize;
  158. AZStd::vector<uint8_t> shadingRatePatternData(bufferSize);
  159. if (m_useImageShadingRate)
  160. {
  161. // Use the lowest shading rate as the default value.
  162. RHI::ShadingRateImageValue defaultValue = device->ConvertShadingRate(m_supportedModes[m_supportedModes.size() - 1]);
  163. uint8_t* ptrData = shadingRatePatternData.data();
  164. for (uint32_t y = 0; y < height; y++)
  165. {
  166. for (uint32_t x = 0; x < width; x++)
  167. {
  168. ::memcpy(ptrData, &defaultValue, formatSize);
  169. ptrData += formatSize;
  170. }
  171. }
  172. }
  173. // Since the device may not support "Dynamic Shading Rate Image", we need to buffer the update of the shading rate image
  174. // because the CPU may be trying to read the image.
  175. m_shadingRateImages.resize(device->GetFeatures().m_dynamicShadingRateImage ? 1 : device->GetDescriptor().m_frameCountMax+3);
  176. for (auto& image : m_shadingRateImages)
  177. {
  178. image = RHI::Factory::Get().CreateImage();
  179. RHI::ImageInitRequest initImageRequest;
  180. RHI::ClearValue clearValue = RHI::ClearValue::CreateVector4Float(1, 1, 1, 1);
  181. initImageRequest.m_image = image.get();
  182. initImageRequest.m_descriptor = RHI::ImageDescriptor::Create2D(
  183. imagePoolDesc.m_bindFlags,
  184. static_cast<uint32_t>(m_shadingRateImageSize.GetX()),
  185. static_cast<uint32_t>(m_shadingRateImageSize.GetY()),
  186. m_rateShadingImageFormat);
  187. initImageRequest.m_optimizedClearValue = &clearValue;
  188. m_imagePool->InitImage(initImageRequest);
  189. RHI::ImageUpdateRequest request;
  190. request.m_image = image.get();
  191. request.m_sourceData = shadingRatePatternData.data();
  192. request.m_sourceSubresourceLayout = RHI::ImageSubresourceLayout(
  193. RHI::Size(width, height, 1),
  194. height,
  195. width * formatSize,
  196. bufferSize,
  197. 1,
  198. 1
  199. );
  200. m_imagePool->UpdateImageContents(request);
  201. }
  202. }
  203. void VariableRateShadingExampleComponent::LoadShaders()
  204. {
  205. const char* shaders[] =
  206. {
  207. "Shaders/RHI/VariableRateShading.azshader",
  208. "Shaders/RHI/VariableRateShadingCompute.azshader",
  209. "Shaders/RHI/VariableRateShadingImage.azshader"
  210. };
  211. m_shaders.resize(AZ_ARRAY_SIZE(shaders));
  212. for (size_t i = 0; i < AZ_ARRAY_SIZE(shaders); ++i)
  213. {
  214. auto shader = LoadShader(shaders[i], VariableRateShading::SampleName);
  215. if (shader == nullptr)
  216. {
  217. return;
  218. }
  219. m_shaders[i] = shader;
  220. }
  221. const auto& numThreads = m_shaders[1]->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name("numthreads"));
  222. if (numThreads)
  223. {
  224. const RHI::ShaderStageAttributeArguments& args = *numThreads;
  225. m_numThreadsX = args[0].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[0]) : m_numThreadsX;
  226. m_numThreadsY = args[1].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[1]) : m_numThreadsY;
  227. m_numThreadsZ = args[2].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[2]) : m_numThreadsZ;
  228. }
  229. }
  230. void VariableRateShadingExampleComponent::CreateShaderResourceGroups()
  231. {
  232. const Name albedoId{ "m_texture" };
  233. auto textureIamge = LoadStreamingImage("textures/bricks.png.streamingimage", VariableRateShading::SampleName);
  234. AZ::RHI::ShaderInputImageIndex albedoIndex;
  235. m_modelShaderResourceGroup = CreateShaderResourceGroup(m_shaders[0], "InstanceSrg", VariableRateShading::SampleName);
  236. FindShaderInputIndex(&albedoIndex, m_modelShaderResourceGroup, albedoId, VariableRateShading::SampleName);
  237. m_modelShaderResourceGroup->SetImage(albedoIndex, textureIamge);
  238. m_modelShaderResourceGroup->Compile();
  239. const Name centerId{ "m_center" };
  240. const Name distancesId{ "m_distances" };
  241. const Name patternId{ "m_pattern" };
  242. const Name shadingRateImageId{ "m_shadingRateTexture" };
  243. AZ::RHI::ShaderInputConstantIndex patternIndex;
  244. m_computeShaderResourceGroup = CreateShaderResourceGroup(m_shaders[1], "ComputeSrg", VariableRateShading::SampleName);
  245. FindShaderInputIndex(&patternIndex, m_computeShaderResourceGroup, patternId, VariableRateShading::SampleName);
  246. FindShaderInputIndex(&m_centerIndex, m_computeShaderResourceGroup, centerId, VariableRateShading::SampleName);
  247. FindShaderInputIndex(&m_shadingRateIndex, m_computeShaderResourceGroup, shadingRateImageId, VariableRateShading::SampleName);
  248. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  249. constexpr uint32_t elementsCount = 4;
  250. struct Pattern
  251. {
  252. float m_distance[elementsCount];
  253. uint32_t m_rate[elementsCount];
  254. };
  255. struct Color
  256. {
  257. float m_color[4];
  258. uint32_t m_rate[4];
  259. };
  260. const float alpha = 0.3f;
  261. const uint32_t numRates = static_cast<uint32_t>(RHI::ShadingRate::Count);
  262. AZStd::array<Pattern, numRates> pattern;
  263. AZStd::array<Color, numRates> patternColors;
  264. AZStd::array<AZ::Color, numRates> colors =
  265. {{
  266. AZ::Color(0.f, 0.f, 1.f, alpha),
  267. AZ::Color(1.f, 0.f, 0.f, alpha),
  268. AZ::Color(0.f, 1.f, 0.f, alpha),
  269. AZ::Color(1.f, 0.f, 1.f, alpha),
  270. AZ::Color(1.f, 1.f, 0.f, alpha),
  271. AZ::Color(0.f, 1.f, 1.f, alpha),
  272. AZ::Color(1.f, 1.f, 1.f, alpha)
  273. }};
  274. float range = 60.0f / numRates;
  275. float currentRange = 8.0f;
  276. const auto& supportedMask = device->GetFeatures().m_shadingRateMask;
  277. for (uint32_t i = 0; i < pattern.size(); ++i)
  278. {
  279. RHI::ShadingRateImageValue rate = {};
  280. pattern[i].m_distance[0] = 0.0f;
  281. if (RHI::CheckBitsAll(supportedMask, static_cast<RHI::ShadingRateFlags>(AZ_BIT(i))))
  282. {
  283. rate = device->ConvertShadingRate(static_cast<RHI::ShadingRate>(i));
  284. pattern[i].m_distance[0] = currentRange;
  285. }
  286. pattern[i].m_rate[0] = rate.m_x;
  287. pattern[i].m_rate[1] = rate.m_y;
  288. currentRange += range;
  289. patternColors[i].m_rate[0] = pattern[i].m_rate[0];
  290. patternColors[i].m_rate[1] = pattern[i].m_rate[1];
  291. colors[i].StoreToFloat4(patternColors[i].m_color);
  292. }
  293. Vector2 center(static_cast<float>(m_shadingRateImageSize.GetX()) * 0.5f, static_cast<float>(m_shadingRateImageSize.GetY()) * 0.5f);
  294. m_computeShaderResourceGroup->SetConstant(m_centerIndex, center);
  295. m_computeShaderResourceGroup->SetConstantArray(patternIndex, pattern);
  296. const Name colorsId{ "m_colors" };
  297. const Name textureId{ "m_texture" };
  298. AZ::RHI::ShaderInputConstantIndex colorsIndex;
  299. m_imageShaderResourceGroup = CreateShaderResourceGroup(m_shaders[2], "InstanceSrg", VariableRateShading::SampleName);
  300. FindShaderInputIndex(&colorsIndex, m_imageShaderResourceGroup, colorsId, VariableRateShading::SampleName);
  301. FindShaderInputIndex(&m_shadingRateDisplayIndex, m_imageShaderResourceGroup, textureId, VariableRateShading::SampleName);
  302. m_imageShaderResourceGroup->SetConstantArray(colorsIndex, patternColors);
  303. }
  304. void VariableRateShadingExampleComponent::CreatePipelines()
  305. {
  306. {
  307. // We create one pipeline when using a shading rate attachment, and another one when we are not using it.
  308. RHI::RenderAttachmentLayoutBuilder shadingRateAttachmentsBuilder;
  309. shadingRateAttachmentsBuilder.AddSubpass()
  310. ->RenderTargetAttachment(m_outputFormat)
  311. ->ShadingRateAttachment(m_rateShadingImageFormat);
  312. RHI::RenderAttachmentLayout shadingRateRenderAttachmentLayout;
  313. [[maybe_unused]] RHI::ResultCode result = shadingRateAttachmentsBuilder.End(shadingRateRenderAttachmentLayout);
  314. AZ_Assert(result == RHI::ResultCode::Success, "Failed to create render attachment layout");
  315. const auto& shader = m_shaders[0];
  316. auto& variant = shader->GetVariant(AZ::RPI::ShaderAsset::RootShaderVariantStableId);
  317. RHI::PipelineStateDescriptorForDraw pipelineDesc;
  318. variant.ConfigurePipelineState(pipelineDesc);
  319. pipelineDesc.m_renderStates.m_depthStencilState = RHI::DepthStencilState::CreateDisabled();
  320. pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout = shadingRateRenderAttachmentLayout;
  321. pipelineDesc.m_renderAttachmentConfiguration.m_subpassIndex = 0;
  322. pipelineDesc.m_inputStreamLayout = m_inputStreamLayout;
  323. m_modelPipelineState[0] = shader->AcquirePipelineState(pipelineDesc);
  324. if (!m_modelPipelineState[0])
  325. {
  326. AZ_Error(VariableRateShading::SampleName, false, "Failed to acquire default pipeline state for shader");
  327. return;
  328. }
  329. RHI::RenderAttachmentLayoutBuilder attachmentsBuilder;
  330. attachmentsBuilder.AddSubpass()
  331. ->RenderTargetAttachment(m_outputFormat);
  332. RHI::RenderAttachmentLayout rateRenderAttachmentLayout;
  333. result = attachmentsBuilder.End(rateRenderAttachmentLayout);
  334. AZ_Assert(result == RHI::ResultCode::Success, "Failed to create render attachment layout");
  335. pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout = rateRenderAttachmentLayout;
  336. m_modelPipelineState[1] = shader->AcquirePipelineState(pipelineDesc);
  337. if (!m_modelPipelineState[1])
  338. {
  339. AZ_Error(VariableRateShading::SampleName, false, "Failed to acquire default pipeline state for shader");
  340. return;
  341. }
  342. }
  343. {
  344. RHI::PipelineStateDescriptorForDispatch pipelineDesc;
  345. const auto& shader = m_shaders[1];
  346. shader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId).ConfigurePipelineState(pipelineDesc);
  347. m_computePipelineState = shader->AcquirePipelineState(pipelineDesc);
  348. if (!m_computePipelineState)
  349. {
  350. AZ_Error(VariableRateShading::SampleName, false, "Failed to acquire default pipeline state for compute");
  351. return;
  352. }
  353. }
  354. {
  355. RHI::RenderAttachmentLayoutBuilder attachmentsBuilder;
  356. attachmentsBuilder.AddSubpass()
  357. ->RenderTargetAttachment(m_outputFormat);
  358. RHI::RenderAttachmentLayout renderAttachmentLayout;
  359. [[maybe_unused]] RHI::ResultCode result = attachmentsBuilder.End(renderAttachmentLayout);
  360. AZ_Assert(result == RHI::ResultCode::Success, "Failed to create render attachment layout");
  361. const auto& shader = m_shaders[2];
  362. auto& variant = shader->GetVariant(AZ::RPI::ShaderAsset::RootShaderVariantStableId);
  363. RHI::PipelineStateDescriptorForDraw pipelineDesc;
  364. variant.ConfigurePipelineState(pipelineDesc);
  365. pipelineDesc.m_renderStates.m_depthStencilState = RHI::DepthStencilState::CreateDisabled();
  366. pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout = renderAttachmentLayout;
  367. pipelineDesc.m_renderAttachmentConfiguration.m_subpassIndex = 0;
  368. pipelineDesc.m_inputStreamLayout = m_inputStreamLayout;
  369. RHI::TargetBlendState& targetBlendState = pipelineDesc.m_renderStates.m_blendState.m_targets[0];
  370. targetBlendState.m_enable = true;
  371. targetBlendState.m_blendSource = RHI::BlendFactor::AlphaSource;
  372. targetBlendState.m_blendDest = RHI::BlendFactor::AlphaSourceInverse;
  373. targetBlendState.m_blendOp = RHI::BlendOp::Add;
  374. m_imagePipelineState = shader->AcquirePipelineState(pipelineDesc);
  375. if (!m_imagePipelineState)
  376. {
  377. AZ_Error(VariableRateShading::SampleName, false, "Failed to acquire default pipeline state for shader");
  378. return;
  379. }
  380. }
  381. }
  382. void VariableRateShadingExampleComponent::CreateInputAssemblyBuffersAndViews()
  383. {
  384. const RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  385. m_bufferPool = RHI::Factory::Get().CreateBufferPool();
  386. RHI::BufferPoolDescriptor bufferPoolDesc;
  387. bufferPoolDesc.m_bindFlags = RHI::BufferBindFlags::InputAssembly;
  388. bufferPoolDesc.m_heapMemoryLevel = RHI::HeapMemoryLevel::Device;
  389. m_bufferPool->Init(*device, bufferPoolDesc);
  390. struct BufferData
  391. {
  392. AZStd::array<VertexPosition, 4> m_positions;
  393. AZStd::array<VertexUV, 4> m_uvs;
  394. AZStd::array<uint16_t, 6> m_indices;
  395. };
  396. BufferData bufferData;
  397. SetFullScreenRect(bufferData.m_positions.data(), bufferData.m_uvs.data(), bufferData.m_indices.data());
  398. m_inputAssemblyBuffer = RHI::Factory::Get().CreateBuffer();
  399. RHI::ResultCode result = RHI::ResultCode::Success;
  400. RHI::BufferInitRequest request;
  401. request.m_buffer = m_inputAssemblyBuffer.get();
  402. request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(bufferData) };
  403. request.m_initialData = &bufferData;
  404. result = m_bufferPool->InitBuffer(request);
  405. if (result != RHI::ResultCode::Success)
  406. {
  407. AZ_Error(VariableRateShading::SampleName, false, "Failed to initialize buffer with error code %d", result);
  408. return;
  409. }
  410. m_streamBufferViews[0] =
  411. {
  412. *m_inputAssemblyBuffer,
  413. offsetof(BufferData, m_positions),
  414. sizeof(BufferData::m_positions),
  415. sizeof(VertexPosition)
  416. };
  417. m_streamBufferViews[1] =
  418. {
  419. *m_inputAssemblyBuffer,
  420. offsetof(BufferData, m_uvs),
  421. sizeof(BufferData::m_uvs),
  422. sizeof(VertexUV)
  423. };
  424. m_indexBufferView =
  425. {
  426. *m_inputAssemblyBuffer,
  427. offsetof(BufferData, m_indices),
  428. sizeof(BufferData::m_indices),
  429. RHI::IndexFormat::Uint16
  430. };
  431. RHI::InputStreamLayoutBuilder layoutBuilder;
  432. layoutBuilder.AddBuffer()->Channel("POSITION", RHI::Format::R32G32B32_FLOAT);
  433. layoutBuilder.AddBuffer()->Channel("UV", RHI::Format::R32G32_FLOAT);
  434. m_inputStreamLayout = layoutBuilder.End();
  435. RHI::ValidateStreamBufferViews(m_inputStreamLayout, m_streamBufferViews);
  436. }
  437. void VariableRateShadingExampleComponent::CreateRenderScope()
  438. {
  439. struct ScopeData
  440. {
  441. };
  442. const auto prepareFunction = [this](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  443. {
  444. {
  445. // Binds the swap chain as a color attachment.
  446. RHI::ImageScopeAttachmentDescriptor descriptor;
  447. descriptor.m_attachmentId = m_outputAttachmentId;
  448. descriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  449. frameGraph.UseColorAttachment(descriptor);
  450. }
  451. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  452. bool useImageShadingRate = m_useImageShadingRate && (device->GetFeatures().m_dynamicShadingRateImage || m_frameCount > device->GetDescriptor().m_frameCountMax);
  453. if (useImageShadingRate)
  454. {
  455. // Binds the shading rate image attachment
  456. AZ::RHI::ImageScopeAttachmentDescriptor dsDesc;
  457. dsDesc.m_attachmentId = VariableRateShading::ShadingRateAttachmentId;
  458. dsDesc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
  459. dsDesc.m_loadStoreAction.m_storeAction = AZ::RHI::AttachmentStoreAction::DontCare;
  460. frameGraph.UseAttachment(dsDesc, AZ::RHI::ScopeAttachmentAccess::Read, AZ::RHI::ScopeAttachmentUsage::ShadingRate);
  461. }
  462. frameGraph.SetEstimatedItemCount(1);
  463. };
  464. RHI::EmptyCompileFunction<ScopeData> compileFunction;
  465. const auto executeFunction = [this]([[maybe_unused]] const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  466. {
  467. RHI::CommandList* commandList = context.GetCommandList();
  468. // Set persistent viewport and scissor state.
  469. commandList->SetViewports(&m_viewport, 1);
  470. commandList->SetScissors(&m_scissor, 1);
  471. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  472. if (m_useDrawShadingRate)
  473. {
  474. RHI::ShadingRateCombinators combinators = { RHI::ShadingRateCombinerOp::Passthrough, m_combinerOp };
  475. commandList->SetFragmentShadingRate(m_shadingRate, combinators);
  476. }
  477. const RHI::ShaderResourceGroup* shaderResourceGroups[] = { m_modelShaderResourceGroup->GetRHIShaderResourceGroup() };
  478. // We have to wait until the updating of the initial contents of the shading rate image is done if
  479. // dynamic mode is not supported (since the CPU would try to read it while the GPU is updating the contents)
  480. bool useImageShadingRate = m_useImageShadingRate && (device->GetFeatures().m_dynamicShadingRateImage || m_frameCount > device->GetDescriptor().m_frameCountMax);
  481. RHI::DrawIndexed drawIndexed;
  482. drawIndexed.m_indexCount = 6;
  483. drawIndexed.m_instanceCount = 1;
  484. RHI::DrawItem drawItem;
  485. drawItem.m_arguments = drawIndexed;
  486. drawItem.m_pipelineState = m_modelPipelineState[useImageShadingRate ? 0 : 1].get();
  487. drawItem.m_indexBufferView = &m_indexBufferView;
  488. drawItem.m_shaderResourceGroupCount = static_cast<uint8_t>(RHI::ArraySize(shaderResourceGroups));;
  489. drawItem.m_shaderResourceGroups = shaderResourceGroups;
  490. drawItem.m_streamBufferViewCount = static_cast<uint8_t>(m_streamBufferViews.size());
  491. drawItem.m_streamBufferViews = m_streamBufferViews.data();
  492. commandList->Submit(drawItem);
  493. };
  494. const RHI::ScopeId forwardScope("SceneScope");
  495. m_scopeProducers.emplace_back(
  496. aznew RHI::ScopeProducerFunction<
  497. ScopeData,
  498. decltype(prepareFunction),
  499. decltype(compileFunction),
  500. decltype(executeFunction)>(
  501. forwardScope,
  502. ScopeData{},
  503. prepareFunction,
  504. compileFunction,
  505. executeFunction));
  506. }
  507. void VariableRateShadingExampleComponent::CreateComputeScope()
  508. {
  509. struct ScopeData
  510. {
  511. };
  512. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  513. const auto& deviceFeatures = device->GetFeatures();
  514. // If "m_dynamicShadingRateImage" is not supported we cannot update the same image that is being used as shading rate this frame.
  515. // We use an "old" one that is not longer in used.
  516. const char* shadingRateAttachmentId = deviceFeatures.m_dynamicShadingRateImage ? VariableRateShading::ShadingRateAttachmentId : VariableRateShading::ShadingRateAttachmentUpdateId;
  517. const auto prepareFunction = [this, shadingRateAttachmentId](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  518. {
  519. if (m_useImageShadingRate)
  520. {
  521. RHI::ImageScopeAttachmentDescriptor shadingRateImageDesc;
  522. shadingRateImageDesc.m_attachmentId = shadingRateAttachmentId;
  523. shadingRateImageDesc.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::DontCare;
  524. shadingRateImageDesc.m_loadStoreAction.m_storeAction = RHI::AttachmentStoreAction::Store;
  525. shadingRateImageDesc.m_imageViewDescriptor.m_overrideFormat = ConvertToUInt(m_rateShadingImageFormat);
  526. frameGraph.UseShaderAttachment(shadingRateImageDesc, RHI::ScopeAttachmentAccess::Write);
  527. }
  528. frameGraph.SetEstimatedItemCount(1);
  529. };
  530. const auto compileFunction = [this, shadingRateAttachmentId](const RHI::FrameGraphCompileContext& context, [[maybe_unused]] const ScopeData& scopeData)
  531. {
  532. if (m_useImageShadingRate)
  533. {
  534. Vector2 center = m_cursorPos * m_shadingRateImageSize;
  535. const RHI::ImageView* shadingRateImageView = context.GetImageView(RHI::AttachmentId(shadingRateAttachmentId));
  536. m_computeShaderResourceGroup->SetImageView(m_shadingRateIndex, shadingRateImageView);
  537. m_computeShaderResourceGroup->SetConstant(m_centerIndex, center);
  538. m_computeShaderResourceGroup->Compile();
  539. }
  540. };
  541. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  542. {
  543. if (!m_useImageShadingRate)
  544. {
  545. return;
  546. }
  547. RHI::CommandList* commandList = context.GetCommandList();
  548. RHI::DispatchItem dispatchItem;
  549. decltype(dispatchItem.m_shaderResourceGroups) shaderResourceGroups = { { m_computeShaderResourceGroup->GetRHIShaderResourceGroup() } };
  550. RHI::DispatchDirect dispatchArgs;
  551. dispatchArgs.m_totalNumberOfThreadsX = aznumeric_cast<uint32_t>(m_shadingRateImageSize.GetX());
  552. dispatchArgs.m_threadsPerGroupX = aznumeric_cast<uint16_t>(m_numThreadsX);
  553. dispatchArgs.m_totalNumberOfThreadsY = aznumeric_cast<uint32_t>(m_shadingRateImageSize.GetY());
  554. dispatchArgs.m_threadsPerGroupY = aznumeric_cast<uint16_t>(m_numThreadsY);
  555. dispatchArgs.m_totalNumberOfThreadsZ = 1;
  556. dispatchArgs.m_threadsPerGroupZ = aznumeric_cast<uint16_t>(m_numThreadsZ);
  557. AZ_Assert(dispatchArgs.m_threadsPerGroupX == dispatchArgs.m_threadsPerGroupY, "If the shader source changes, this logic should change too.");
  558. AZ_Assert(dispatchArgs.m_threadsPerGroupZ == 1, "If the shader source changes, this logic should change too.");
  559. dispatchItem.m_arguments = dispatchArgs;
  560. dispatchItem.m_pipelineState = m_computePipelineState.get();
  561. dispatchItem.m_shaderResourceGroupCount = 1;
  562. dispatchItem.m_shaderResourceGroups = shaderResourceGroups;
  563. commandList->Submit(dispatchItem);
  564. };
  565. const RHI::ScopeId computeScope("ShadingRateImageCompute");
  566. m_scopeProducers.emplace_back(
  567. aznew RHI::ScopeProducerFunction<
  568. ScopeData,
  569. decltype(prepareFunction),
  570. decltype(compileFunction),
  571. decltype(executeFunction)>(
  572. computeScope,
  573. ScopeData{},
  574. prepareFunction,
  575. compileFunction,
  576. executeFunction));
  577. }
  578. void VariableRateShadingExampleComponent::CreatImageDisplayScope()
  579. {
  580. struct ScopeData
  581. {
  582. };
  583. const auto prepareFunction = [this](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
  584. {
  585. {
  586. // Binds the swap chain as a color attachment.
  587. RHI::ImageScopeAttachmentDescriptor descriptor;
  588. descriptor.m_attachmentId = m_outputAttachmentId;
  589. descriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
  590. frameGraph.UseColorAttachment(descriptor);
  591. }
  592. if (m_showShadingRateImage)
  593. {
  594. // Binds the shading rate image for reading (not as attachment)
  595. RHI::ImageScopeAttachmentDescriptor shadingRateImageDesc;
  596. shadingRateImageDesc.m_attachmentId = VariableRateShading::ShadingRateAttachmentId;
  597. shadingRateImageDesc.m_loadStoreAction.m_storeAction = RHI::AttachmentStoreAction::DontCare;
  598. shadingRateImageDesc.m_imageViewDescriptor.m_overrideFormat = ConvertToUInt(m_rateShadingImageFormat);
  599. frameGraph.UseShaderAttachment(shadingRateImageDesc, RHI::ScopeAttachmentAccess::Read);
  600. }
  601. frameGraph.SetEstimatedItemCount(1);
  602. };
  603. const auto compileFunction = [this](const RHI::FrameGraphCompileContext& context, [[maybe_unused]] const ScopeData& scopeData)
  604. {
  605. if (m_showShadingRateImage)
  606. {
  607. const RHI::ImageView* shadingRateImageView = context.GetImageView(RHI::AttachmentId(VariableRateShading::ShadingRateAttachmentId));
  608. m_imageShaderResourceGroup->SetImageView(m_shadingRateDisplayIndex, shadingRateImageView);
  609. m_imageShaderResourceGroup->Compile();
  610. }
  611. };
  612. const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
  613. {
  614. if (!m_showShadingRateImage)
  615. {
  616. return;
  617. }
  618. RHI::CommandList* commandList = context.GetCommandList();
  619. // Set persistent viewport and scissor state.
  620. commandList->SetViewports(&m_viewport, 1);
  621. commandList->SetScissors(&m_scissor, 1);
  622. const RHI::ShaderResourceGroup* shaderResourceGroups[] = { m_imageShaderResourceGroup->GetRHIShaderResourceGroup() };
  623. RHI::DrawIndexed drawIndexed;
  624. drawIndexed.m_indexCount = 6;
  625. drawIndexed.m_instanceCount = 1;
  626. RHI::DrawItem drawItem;
  627. drawItem.m_arguments = drawIndexed;
  628. drawItem.m_pipelineState = m_imagePipelineState.get();
  629. drawItem.m_indexBufferView = &m_indexBufferView;
  630. drawItem.m_shaderResourceGroupCount = static_cast<uint8_t>(RHI::ArraySize(shaderResourceGroups));
  631. drawItem.m_shaderResourceGroups = shaderResourceGroups;
  632. drawItem.m_streamBufferViewCount = static_cast<uint8_t>(m_streamBufferViews.size());
  633. drawItem.m_streamBufferViews = m_streamBufferViews.data();
  634. commandList->Submit(drawItem);
  635. };
  636. const RHI::ScopeId forwardScope("ImageDisplayScope");
  637. m_scopeProducers.emplace_back(
  638. aznew RHI::ScopeProducerFunction<
  639. ScopeData,
  640. decltype(prepareFunction),
  641. decltype(compileFunction),
  642. decltype(executeFunction)>(
  643. forwardScope,
  644. ScopeData{},
  645. prepareFunction,
  646. compileFunction,
  647. executeFunction));
  648. }
  649. void VariableRateShadingExampleComponent::Deactivate()
  650. {
  651. m_imguiSidebar.Deactivate();
  652. AZ::RHI::RHISystemNotificationBus::Handler::BusDisconnect();
  653. AZ::TickBus::Handler::BusDisconnect();
  654. AzFramework::InputChannelEventListener::BusDisconnect();
  655. m_bufferPool = nullptr;
  656. m_inputAssemblyBuffer = nullptr;
  657. m_modelPipelineState[0] = nullptr;
  658. m_modelPipelineState[1] = nullptr;
  659. m_imagePipelineState = nullptr;
  660. m_modelShaderResourceGroup = nullptr;
  661. m_computeShaderResourceGroup = nullptr;
  662. m_imageShaderResourceGroup = nullptr;
  663. m_shaders.clear();
  664. m_supportedModes.clear();
  665. m_windowContext = nullptr;
  666. m_imagePool = nullptr;
  667. m_shadingRateImages.clear();
  668. m_scopeProducers.clear();
  669. }
  670. void VariableRateShadingExampleComponent::DrawSettings()
  671. {
  672. RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
  673. const auto& deviceFeatures = device->GetFeatures();
  674. ImGui::Spacing();
  675. if (RHI::CheckBitsAll(deviceFeatures.m_shadingRateTypeMask, RHI::ShadingRateTypeFlags::PerImage))
  676. {
  677. ScriptableImGui::Checkbox("Image Shade Rate", &m_useImageShadingRate);
  678. if (m_useImageShadingRate)
  679. {
  680. ImGui::Indent();
  681. ScriptableImGui::Checkbox("Show Image", &m_showShadingRateImage);
  682. ScriptableImGui::Checkbox("Follow Pointer", &m_followPointer);
  683. ImGui::Unindent();
  684. }
  685. else
  686. {
  687. m_showShadingRateImage = false;
  688. m_followPointer = false;
  689. }
  690. }
  691. if (RHI::CheckBitsAll(deviceFeatures.m_shadingRateTypeMask, RHI::ShadingRateTypeFlags::PerDraw))
  692. {
  693. ScriptableImGui::Checkbox("Draw Shade Rate", &m_useDrawShadingRate);
  694. if (m_useDrawShadingRate)
  695. {
  696. ImGui::Indent();
  697. AZStd::vector<const char*> items;
  698. for(const auto rate : m_supportedModes)
  699. {
  700. items.push_back(ToString(rate));
  701. }
  702. int current_item = static_cast<int>(AZStd::distance(m_supportedModes.begin(), AZStd::find(m_supportedModes.begin(), m_supportedModes.end(), m_shadingRate)));
  703. ScriptableImGui::Combo("Shading Rates", &current_item, items.data(), static_cast<int>(items.size()));
  704. m_shadingRate = m_supportedModes[current_item];
  705. ImGui::Unindent();
  706. }
  707. }
  708. if (m_useDrawShadingRate && m_useImageShadingRate)
  709. {
  710. AZStd::vector<const char*> items = { "Passthrough", "Override", "Min", "Max" };
  711. int current_item = static_cast<int>(m_combinerOp);
  712. ScriptableImGui::Combo("Combiner Op", &current_item, items.data(), static_cast<int>(items.size()));
  713. m_combinerOp = static_cast<RHI::ShadingRateCombinerOp>(current_item);
  714. }
  715. else if(m_useDrawShadingRate)
  716. {
  717. m_combinerOp = RHI::ShadingRateCombinerOp::Passthrough;
  718. }
  719. if (!m_followPointer)
  720. {
  721. m_cursorPos = AZ::Vector2(0.5f, 0.5f);
  722. }
  723. m_imguiSidebar.End();
  724. }
  725. bool VariableRateShadingExampleComponent::OnInputChannelEventFiltered(const AzFramework::InputChannel& inputChannel)
  726. {
  727. if (m_followPointer)
  728. {
  729. const AzFramework::InputChannelId& inputChannelId = inputChannel.GetInputChannelId();
  730. switch (inputChannel.GetState())
  731. {
  732. case AzFramework::InputChannel::State::Began:
  733. case AzFramework::InputChannel::State::Updated: // update the camera rotation
  734. {
  735. const AzFramework::InputChannel::PositionData2D* position = nullptr;
  736. // Mouse or Touch Events
  737. if (inputChannelId == AzFramework::InputDeviceMouse::SystemCursorPosition ||
  738. inputChannelId == AzFramework::InputDeviceTouch::Touch::Index0)
  739. {
  740. position = inputChannel.GetCustomData<AzFramework::InputChannel::PositionData2D>();
  741. }
  742. if (position)
  743. {
  744. m_cursorPos = position->m_normalizedPosition;
  745. }
  746. break;
  747. }
  748. default:
  749. break;
  750. }
  751. }
  752. return false;
  753. }
  754. } // namespace AtomSampleViewer