Kaynağa Gözat

VariableRateShading RHI sample

Resolves #13388

Signed-off-by: Akio Gaule <[email protected]>
Akio Gaule 2 yıl önce
ebeveyn
işleme
e66b463d8d

+ 3 - 3
Gem/Code/Source/RHI/SubpassExampleComponent.cpp

@@ -231,9 +231,9 @@ namespace AtomSampleViewer
             ->DepthStencilAttachment(AZ::RHI::Format::D32_FLOAT, m_depthStencilAttachmentId);
         // Composition Subpass
         attachmentsBuilder.AddSubpass()
-            ->SubpassInputAttachment(m_positionAttachmentId)
-            ->SubpassInputAttachment(m_normalAttachmentId)
-            ->SubpassInputAttachment(m_albedoAttachmentId)
+            ->SubpassInputAttachment(m_positionAttachmentId, RHI::ImageAspectFlags::Color)
+            ->SubpassInputAttachment(m_normalAttachmentId, RHI::ImageAspectFlags::Color)
+            ->SubpassInputAttachment(m_albedoAttachmentId, RHI::ImageAspectFlags::Color)
             ->RenderTargetAttachment(m_outputAttachmentId)
             ->DepthStencilAttachment(m_depthStencilAttachmentId);
 

+ 871 - 0
Gem/Code/Source/RHI/VariableRateShadingExampleComponent.cpp

@@ -0,0 +1,871 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <RHI/VariableRateShadingExampleComponent.h>
+#include <Utils/Utils.h>
+
+#include <SampleComponentManager.h>
+
+#include <Atom/RHI/CommandList.h>
+#include <Atom/RHI.Reflect/InputStreamLayoutBuilder.h>
+#include <Atom/RHI.Reflect/RenderAttachmentLayoutBuilder.h>
+#include <Atom/RHI.Reflect/VariableRateShadingEnums.h>
+#include <Atom/RPI.Public/Shader/Shader.h>
+#include <Atom/RPI.Reflect/Shader/ShaderAsset.h>
+#include <AzCore/Serialization/SerializeContext.h>
+#include <AzCore/Math/MathUtils.h>
+#include <AzCore/std/containers/span.h>
+#include <AzCore/Math/MatrixUtils.h>
+#include <AzFramework/Input/Devices/Mouse/InputDeviceMouse.h>
+#include <AzFramework/Input/Devices/Touch/InputDeviceTouch.h>
+
+namespace AtomSampleViewer
+{
+    using namespace AZ;
+
+    namespace VariableRateShading
+    {
+        const char* SampleName = "VariableRateShadingExample";
+        const char* ShadingRateAttachmentId = "ShadingRateAttachmentId";
+        const char* ShadingRateAttachmentUpdateId = "ShadingRateAttachmentUpdateId";
+    }
+
+    RHI::Format ConvertToUInt(RHI::Format format)
+    {
+        uint32_t count = GetFormatComponentCount(format);
+        if (count == 1)
+        {
+            return RHI::Format::R8_UINT;
+        }
+        else if (count == 2)
+        {
+            return RHI::Format::R8G8_UINT;
+        }
+        return RHI::Format::R8G8B8A8_UINT;
+    }
+
+    const char* ToString(RHI::ShadingRate rate)
+    {
+        switch (rate)
+        {
+        case RHI::ShadingRate::Rate1x1: return "Rate1x1";
+        case RHI::ShadingRate::Rate1x2: return "Rate1x2";
+        case RHI::ShadingRate::Rate2x1: return "Rate2x1";
+        case RHI::ShadingRate::Rate2x2: return "Rate2x2";
+        case RHI::ShadingRate::Rate2x4: return "Rate2x4";
+        case RHI::ShadingRate::Rate4x2: return "Rate4x2";
+        case RHI::ShadingRate::Rate4x4: return "Rate4x4";
+        default: return "";
+        }
+    }
+
+    void VariableRateShadingExampleComponent::Reflect(AZ::ReflectContext* context)
+    {
+        if (auto* serializeContext = azrtti_cast<AZ::SerializeContext*>(context))
+        {
+            serializeContext->Class<VariableRateShadingExampleComponent, AZ::Component>()
+                ->Version(0)
+                ;
+        }
+    }
+
+    VariableRateShadingExampleComponent::VariableRateShadingExampleComponent()
+    {
+        m_supportRHISamplePipeline = true;
+    }
+
+    void VariableRateShadingExampleComponent::OnTick([[maybe_unused]] float deltaTime, [[maybe_unused]] AZ::ScriptTimePoint time)
+    {
+        if (m_imguiSidebar.Begin())
+        {
+            DrawSettings();
+        }
+    }
+
+    void VariableRateShadingExampleComponent::OnFramePrepare(AZ::RHI::FrameGraphBuilder& frameGraphBuilder)
+    {
+        if (m_windowContext->GetSwapChainsSize() && m_windowContext->GetSwapChain())
+        {
+            if (m_useImageShadingRate)
+            {
+                frameGraphBuilder.GetAttachmentDatabase().ImportImage(RHI::AttachmentId{ VariableRateShading::ShadingRateAttachmentId }, m_shadingRateImages[m_frameCount % m_shadingRateImages.size()]);
+                if (!Utils::GetRHIDevice()->GetFeatures().m_dynamicShadingRateImage)
+                {
+                    // We cannot update and use the same shading rate image because "m_dynamicShadingRateImage" is not supported.
+                    frameGraphBuilder.GetAttachmentDatabase().ImportImage(RHI::AttachmentId{ VariableRateShading::ShadingRateAttachmentUpdateId }, m_shadingRateImages[(m_frameCount + m_shadingRateImages.size() - 1) % m_shadingRateImages.size()]);
+                }
+            }
+            m_frameCount++;
+        }
+
+        BasicRHIComponent::OnFramePrepare(frameGraphBuilder);
+    }
+
+    void VariableRateShadingExampleComponent::Activate()
+    {
+        AZ::TickBus::Handler::BusConnect();
+        AZ::RHI::RHISystemNotificationBus::Handler::BusConnect();
+        AzFramework::InputChannelEventListener::Connect();
+
+
+        RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
+        const auto& deviceFeatures = device->GetFeatures();
+        if (!RHI::CheckBitsAll(deviceFeatures.m_shadingRateTypeMask, RHI::ShadingRateTypeFlags::PerImage))
+        {
+            m_useImageShadingRate = false;
+        }
+
+        if (!RHI::CheckBitsAll(deviceFeatures.m_shadingRateTypeMask, RHI::ShadingRateTypeFlags::PerDraw))
+        {
+            m_useDrawShadingRate = false;
+        }
+
+        if (RHI::CheckBitsAll(deviceFeatures.m_shadingRateTypeMask, RHI::ShadingRateTypeFlags::PerImage))
+        {
+            for (uint32_t i = 0; i < static_cast<uint32_t>(RHI::Format::Count); ++i)
+            {
+                RHI::Format format = static_cast<RHI::Format>(i);
+                RHI::FormatCapabilities capabilities = device->GetFormatCapabilities(format);
+                if (RHI::CheckBitsAll(capabilities, RHI::FormatCapabilities::ShadingRate))
+                {
+                    m_rateShadingImageFormat = format;
+                    break;
+                }
+            }
+            AZ_Assert(m_rateShadingImageFormat != RHI::Format::Unknown, "Could not find a format for the shading rate image");
+        }
+
+        const auto& supportedMask = device->GetFeatures().m_shadingRateMask;
+        for (uint32_t i = 0; i < static_cast<uint32_t>(RHI::ShadingRate::Count); ++i)
+        {
+            if (RHI::CheckBitsAll(supportedMask, static_cast<RHI::ShadingRateFlags>(AZ_BIT(i))))
+            {
+                m_supportedModes.push_back(static_cast<RHI::ShadingRate>(i));
+            }
+        }
+        m_shadingRate = m_supportedModes[0];
+
+        CreateShadingRateImage();
+        LoadShaders();
+        CreateInputAssemblyBuffersAndViews();
+        CreateShaderResourceGroups();
+        CreatePipelines();
+        CreateComputeScope();
+        CreateRenderScope();
+        CreatImageDisplayScope();
+        m_frameCount = 0;
+    }
+
+    void VariableRateShadingExampleComponent::CreateShadingRateImage()
+    {
+        RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
+
+        const auto& tileSize = device->GetLimits().m_shadingRateTileSize;
+        m_shadingRateImageSize = Vector2(ceil(static_cast<float>(m_outputWidth) / tileSize.m_width), ceil(static_cast<float>(m_outputHeight) / tileSize.m_height));
+
+        m_imagePool = RHI::Factory::Get().CreateImagePool();
+        RHI::ImagePoolDescriptor imagePoolDesc;
+        imagePoolDesc.m_bindFlags = RHI::ImageBindFlags::ShadingRate | RHI::ImageBindFlags::ShaderReadWrite;
+        m_imagePool->Init(*device, imagePoolDesc);
+
+        // Initialize the shading rate images with proper values. Invalid values may cause a crash.
+        uint32_t width = static_cast<uint32_t>(m_shadingRateImageSize.GetX());
+        uint32_t height = static_cast<uint32_t>(m_shadingRateImageSize.GetY());
+        uint32_t formatSize = GetFormatSize(m_rateShadingImageFormat);
+        uint32_t bufferSize = width * height * formatSize;
+        AZStd::vector<uint8_t> shadingRatePatternData(bufferSize);
+        // Use the lowest shading rate as the default value.
+        RHI::ShadingRateImageValue defaultValue = device->ConvertShadingRate(m_supportedModes[m_supportedModes.size() - 1]);
+        uint8_t* ptrData = shadingRatePatternData.data();
+        for (uint32_t y = 0; y < height; y++)
+        {
+            for (uint32_t x = 0; x < width; x++)
+            {
+                ::memcpy(ptrData, &defaultValue, formatSize);
+                ptrData += formatSize;
+            }
+        }
+
+        // Since the device may not support "Dynamic Shading Rate Image", we need to buffer the update of the shading rate image
+        // because the CPU may be trying to read the image.
+        m_shadingRateImages.resize(device->GetFeatures().m_dynamicShadingRateImage ? 1 : device->GetDescriptor().m_frameCountMax+3);
+        for (auto& image : m_shadingRateImages)
+        {
+            image = RHI::Factory::Get().CreateImage();
+            RHI::ImageInitRequest initImageRequest;
+            RHI::ClearValue clearValue = RHI::ClearValue::CreateVector4Float(1, 1, 1, 1);
+            initImageRequest.m_image = image.get();
+            initImageRequest.m_descriptor = RHI::ImageDescriptor::Create2D(
+                imagePoolDesc.m_bindFlags,
+                static_cast<uint32_t>(m_shadingRateImageSize.GetX()),
+                static_cast<uint32_t>(m_shadingRateImageSize.GetY()),
+                m_rateShadingImageFormat);
+            initImageRequest.m_optimizedClearValue = &clearValue;
+            m_imagePool->InitImage(initImageRequest);
+
+            RHI::ImageUpdateRequest request;
+            request.m_image = image.get();
+            request.m_sourceData = shadingRatePatternData.data();
+            request.m_sourceSubresourceLayout = RHI::ImageSubresourceLayout(
+                RHI::Size(width, height, 1),
+                height,
+                width * formatSize,
+                bufferSize,
+                1,
+                1
+            );
+
+            m_imagePool->UpdateImageContents(request);
+        }        
+    }  
+
+    void VariableRateShadingExampleComponent::LoadShaders()
+    {
+        const char* shaders[] =
+        {
+            "Shaders/RHI/VariableRateShading.azshader",
+            "Shaders/RHI/VariableRateShadingCompute.azshader",
+            "Shaders/RHI/VariableRateShadingImage.azshader"
+        };
+
+        m_shaders.resize(AZ_ARRAY_SIZE(shaders));
+        for (size_t i = 0; i < AZ_ARRAY_SIZE(shaders); ++i)
+        {
+            auto shader = LoadShader(shaders[i], VariableRateShading::SampleName);
+            if (shader == nullptr)
+            {
+                return;
+            }
+
+            m_shaders[i] = shader;
+        }
+
+        const auto& numThreads = m_shaders[1]->GetAsset()->GetAttribute(RHI::ShaderStage::Compute, Name("numthreads"));
+        if (numThreads)
+        {
+            const RHI::ShaderStageAttributeArguments& args = *numThreads;
+            m_numThreadsX = args[0].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[0]) : m_numThreadsX;
+            m_numThreadsY = args[1].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[1]) : m_numThreadsY;
+            m_numThreadsZ = args[2].type() == azrtti_typeid<int>() ? AZStd::any_cast<int>(args[2]) : m_numThreadsZ;
+        }
+    }
+
+    void VariableRateShadingExampleComponent::CreateShaderResourceGroups()
+    {
+        const Name albedoId{ "m_texture" };
+        auto textureIamge = LoadStreamingImage("textures/bricks.png.streamingimage", VariableRateShading::SampleName);
+
+        AZ::RHI::ShaderInputImageIndex albedoIndex;
+        m_modelShaderResourceGroup = CreateShaderResourceGroup(m_shaders[0], "InstanceSrg", VariableRateShading::SampleName);
+        FindShaderInputIndex(&albedoIndex, m_modelShaderResourceGroup, albedoId, VariableRateShading::SampleName);
+        m_modelShaderResourceGroup->SetImage(albedoIndex, textureIamge);
+        m_modelShaderResourceGroup->Compile();
+
+        const Name centerId{ "m_center" };
+        const Name distancesId{ "m_distances" };
+        const Name patternId{ "m_pattern" };
+        const Name shadingRateImageId{ "m_shadingRateTexture" };
+
+        AZ::RHI::ShaderInputConstantIndex patternIndex;
+
+        m_computeShaderResourceGroup = CreateShaderResourceGroup(m_shaders[1], "ComputeSrg", VariableRateShading::SampleName);
+        FindShaderInputIndex(&patternIndex, m_computeShaderResourceGroup, patternId, VariableRateShading::SampleName);
+        FindShaderInputIndex(&m_centerIndex, m_computeShaderResourceGroup, centerId, VariableRateShading::SampleName);
+        FindShaderInputIndex(&m_shadingRateIndex, m_computeShaderResourceGroup, shadingRateImageId, VariableRateShading::SampleName);
+
+        RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
+
+        struct Pattern
+        {
+            float m_distance[4];
+            uint32_t m_rate[4];
+        };
+
+        struct Color
+        {
+            float m_color[4];
+            uint32_t m_rate[4];
+        };
+
+        const float alpha = 0.3f;
+        const uint32_t numRates = static_cast<uint32_t>(RHI::ShadingRate::Count);
+        AZStd::array<Pattern, numRates> pattern;
+        AZStd::array<Color, numRates> patternColors;
+        AZStd::array<AZ::Color, numRates> colors = 
+        {{
+            AZ::Color(0.f, 0.f, 1.f, alpha),
+            AZ::Color(1.f, 0.f, 0.f, alpha),
+            AZ::Color(0.f, 1.f, 0.f, alpha),
+            AZ::Color(1.f, 0.f, 1.f, alpha),
+            AZ::Color(1.f, 1.f, 0.f, alpha),
+            AZ::Color(0.f, 1.f, 1.f, alpha),
+            AZ::Color(1.f, 1.f, 1.f, alpha)
+        }};
+
+        float range = 60.0f / numRates;
+        float currentRange = 8.0f;
+
+        const auto& supportedMask = device->GetFeatures().m_shadingRateMask;
+        for (uint32_t i = 0; i < pattern.size(); ++i)
+        {
+            RHI::ShadingRateImageValue rate = {};
+            pattern[i].m_distance[0] = 0.0f;
+            if (RHI::CheckBitsAll(supportedMask, static_cast<RHI::ShadingRateFlags>(AZ_BIT(i))))
+            {
+                rate = device->ConvertShadingRate(static_cast<RHI::ShadingRate>(i));
+                pattern[i].m_distance[0] = currentRange;
+            }
+            pattern[i].m_rate[0] = rate.m_x;
+            pattern[i].m_rate[1] = rate.m_y;
+            currentRange += range;
+
+            patternColors[i].m_rate[0] = pattern[i].m_rate[0];
+            patternColors[i].m_rate[1] = pattern[i].m_rate[1];
+            colors[i].StoreToFloat4(patternColors[i].m_color);
+        }       
+
+        Vector2 center(static_cast<float>(m_shadingRateImageSize.GetX()) * 0.5f, static_cast<float>(m_shadingRateImageSize.GetY()) * 0.5f);
+        m_computeShaderResourceGroup->SetConstant(m_centerIndex, center);
+        m_computeShaderResourceGroup->SetConstantArray(patternIndex, pattern);
+
+        const Name colorsId{ "m_colors" };
+        const Name textureId{ "m_texture" };
+        AZ::RHI::ShaderInputConstantIndex colorsIndex;
+
+        m_imageShaderResourceGroup = CreateShaderResourceGroup(m_shaders[2], "InstanceSrg", VariableRateShading::SampleName);
+        FindShaderInputIndex(&colorsIndex, m_imageShaderResourceGroup, colorsId, VariableRateShading::SampleName);
+        FindShaderInputIndex(&m_shadingRateDisplayIndex, m_imageShaderResourceGroup, textureId, VariableRateShading::SampleName);
+        m_imageShaderResourceGroup->SetConstantArray(colorsIndex, patternColors);
+    }
+
+    void VariableRateShadingExampleComponent::CreatePipelines()
+    {        
+        {
+            // We create one pipeline when using a shading rate attachment, and another one when we are not using it.
+            RHI::RenderAttachmentLayoutBuilder shadingRateAttachmentsBuilder;
+            shadingRateAttachmentsBuilder.AddSubpass()
+                ->RenderTargetAttachment(m_outputFormat)
+                ->ShadingRateAttachment(m_rateShadingImageFormat);
+
+            RHI::RenderAttachmentLayout shadingRateRenderAttachmentLayout;
+            [[maybe_unused]] RHI::ResultCode result = shadingRateAttachmentsBuilder.End(shadingRateRenderAttachmentLayout);
+            AZ_Assert(result == RHI::ResultCode::Success, "Failed to create render attachment layout");
+
+            const auto& shader = m_shaders[0];
+            auto& variant = shader->GetVariant(AZ::RPI::ShaderAsset::RootShaderVariantStableId);
+
+            RHI::PipelineStateDescriptorForDraw pipelineDesc;
+            variant.ConfigurePipelineState(pipelineDesc);
+            pipelineDesc.m_renderStates.m_depthStencilState = RHI::DepthStencilState::CreateDisabled();
+            pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout = shadingRateRenderAttachmentLayout;
+            pipelineDesc.m_renderAttachmentConfiguration.m_subpassIndex = 0;
+            pipelineDesc.m_inputStreamLayout = m_inputStreamLayout;
+
+            m_modelPipelineState[0] = shader->AcquirePipelineState(pipelineDesc);
+            if (!m_modelPipelineState[0])
+            {
+                AZ_Error(VariableRateShading::SampleName, false, "Failed to acquire default pipeline state for shader");
+                return;
+            }
+
+            RHI::RenderAttachmentLayoutBuilder attachmentsBuilder;
+            attachmentsBuilder.AddSubpass()
+                ->RenderTargetAttachment(m_outputFormat);
+
+            RHI::RenderAttachmentLayout rateRenderAttachmentLayout;
+            result = attachmentsBuilder.End(rateRenderAttachmentLayout);
+            AZ_Assert(result == RHI::ResultCode::Success, "Failed to create render attachment layout");
+            pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout = rateRenderAttachmentLayout;
+
+            m_modelPipelineState[1] = shader->AcquirePipelineState(pipelineDesc);
+            if (!m_modelPipelineState[1])
+            {
+                AZ_Error(VariableRateShading::SampleName, false, "Failed to acquire default pipeline state for shader");
+                return;
+            }
+        }
+
+        {
+            RHI::PipelineStateDescriptorForDispatch pipelineDesc;
+            const auto& shader = m_shaders[1];
+            shader->GetVariant(RPI::ShaderAsset::RootShaderVariantStableId).ConfigurePipelineState(pipelineDesc);
+
+            m_computePipelineState = shader->AcquirePipelineState(pipelineDesc);
+            if (!m_computePipelineState)
+            {
+                AZ_Error(VariableRateShading::SampleName, false, "Failed to acquire default pipeline state for compute");
+                return;
+            }
+        }
+
+        {
+            RHI::RenderAttachmentLayoutBuilder attachmentsBuilder;
+            attachmentsBuilder.AddSubpass()
+                ->RenderTargetAttachment(m_outputFormat);
+
+            RHI::RenderAttachmentLayout renderAttachmentLayout;
+            [[maybe_unused]] RHI::ResultCode result = attachmentsBuilder.End(renderAttachmentLayout);
+            AZ_Assert(result == RHI::ResultCode::Success, "Failed to create render attachment layout");
+
+            const auto& shader = m_shaders[2];
+            auto& variant = shader->GetVariant(AZ::RPI::ShaderAsset::RootShaderVariantStableId);
+
+            RHI::PipelineStateDescriptorForDraw pipelineDesc;
+            variant.ConfigurePipelineState(pipelineDesc);
+            pipelineDesc.m_renderStates.m_depthStencilState = RHI::DepthStencilState::CreateDisabled();
+            pipelineDesc.m_renderAttachmentConfiguration.m_renderAttachmentLayout = renderAttachmentLayout;
+            pipelineDesc.m_renderAttachmentConfiguration.m_subpassIndex = 0;
+            pipelineDesc.m_inputStreamLayout = m_inputStreamLayout;
+
+            RHI::TargetBlendState& targetBlendState = pipelineDesc.m_renderStates.m_blendState.m_targets[0];
+            targetBlendState.m_enable = true;
+            targetBlendState.m_blendSource = RHI::BlendFactor::AlphaSource;
+            targetBlendState.m_blendDest = RHI::BlendFactor::AlphaSourceInverse;
+            targetBlendState.m_blendOp = RHI::BlendOp::Add;
+
+            m_imagePipelineState = shader->AcquirePipelineState(pipelineDesc);
+            if (!m_imagePipelineState)
+            {
+                AZ_Error(VariableRateShading::SampleName, false, "Failed to acquire default pipeline state for shader");
+                return;
+            }
+        }
+    }
+
+    void VariableRateShadingExampleComponent::CreateInputAssemblyBuffersAndViews()
+    {
+        const RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
+
+        m_bufferPool = RHI::Factory::Get().CreateBufferPool();
+        RHI::BufferPoolDescriptor bufferPoolDesc;
+        bufferPoolDesc.m_bindFlags = RHI::BufferBindFlags::InputAssembly;
+        bufferPoolDesc.m_heapMemoryLevel = RHI::HeapMemoryLevel::Device;
+        m_bufferPool->Init(*device, bufferPoolDesc);
+
+        struct BufferData
+        {
+            AZStd::array<VertexPosition, 4> m_positions;
+            AZStd::array<VertexUV, 4> m_uvs;
+            AZStd::array<uint16_t, 6> m_indices;
+        };
+
+        BufferData bufferData;
+        SetFullScreenRect(bufferData.m_positions.data(), bufferData.m_uvs.data(), bufferData.m_indices.data());
+
+        m_inputAssemblyBuffer = RHI::Factory::Get().CreateBuffer();
+        RHI::ResultCode result = RHI::ResultCode::Success;
+        RHI::BufferInitRequest request;
+
+        request.m_buffer = m_inputAssemblyBuffer.get();
+        request.m_descriptor = RHI::BufferDescriptor{ RHI::BufferBindFlags::InputAssembly, sizeof(bufferData) };
+        request.m_initialData = &bufferData;
+        result = m_bufferPool->InitBuffer(request);
+        if (result != RHI::ResultCode::Success)
+        {
+            AZ_Error(VariableRateShading::SampleName, false, "Failed to initialize buffer with error code %d", result);
+            return;
+        }
+
+        m_streamBufferViews[0] =
+        {
+            *m_inputAssemblyBuffer,
+            offsetof(BufferData, m_positions),
+            sizeof(BufferData::m_positions),
+            sizeof(VertexPosition)
+        };
+
+        m_streamBufferViews[1] =
+        {
+            *m_inputAssemblyBuffer,
+            offsetof(BufferData, m_uvs),
+            sizeof(BufferData::m_uvs),
+            sizeof(VertexUV)
+        };
+
+        m_indexBufferView =
+        {
+            *m_inputAssemblyBuffer,
+            offsetof(BufferData, m_indices),
+            sizeof(BufferData::m_indices),
+            RHI::IndexFormat::Uint16
+        };
+
+        RHI::InputStreamLayoutBuilder layoutBuilder;
+        layoutBuilder.AddBuffer()->Channel("POSITION", RHI::Format::R32G32B32_FLOAT);
+        layoutBuilder.AddBuffer()->Channel("UV", RHI::Format::R32G32_FLOAT);
+        m_inputStreamLayout = layoutBuilder.End();
+
+        RHI::ValidateStreamBufferViews(m_inputStreamLayout, m_streamBufferViews);
+    }
+
+    void VariableRateShadingExampleComponent::CreateRenderScope()
+    {
+        struct ScopeData
+        {
+        };
+
+        const auto prepareFunction = [this](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
+        {
+            {
+                // Binds the swap chain as a color attachment.
+                RHI::ImageScopeAttachmentDescriptor descriptor;
+                descriptor.m_attachmentId = m_outputAttachmentId;
+                descriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
+                frameGraph.UseColorAttachment(descriptor);
+            }
+
+            RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
+            bool useImageShadingRate = m_useImageShadingRate && (device->GetFeatures().m_dynamicShadingRateImage || m_frameCount > device->GetDescriptor().m_frameCountMax);
+            if (useImageShadingRate)
+            {
+                // Binds the shading rate image attachment
+                AZ::RHI::ImageScopeAttachmentDescriptor dsDesc;
+                dsDesc.m_attachmentId = VariableRateShading::ShadingRateAttachmentId;
+                dsDesc.m_loadStoreAction.m_loadAction = AZ::RHI::AttachmentLoadAction::Load;
+                dsDesc.m_loadStoreAction.m_storeAction = AZ::RHI::AttachmentStoreAction::DontCare;
+                frameGraph.UseAttachment(dsDesc, AZ::RHI::ScopeAttachmentAccess::Read, AZ::RHI::ScopeAttachmentUsage::ShadingRate);
+            }
+
+            frameGraph.SetEstimatedItemCount(1);
+        };
+
+        RHI::EmptyCompileFunction<ScopeData> compileFunction;
+
+        const auto executeFunction = [this]([[maybe_unused]] const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
+        {
+            RHI::CommandList* commandList = context.GetCommandList();
+
+            // Set persistent viewport and scissor state.
+            commandList->SetViewports(&m_viewport, 1);
+            commandList->SetScissors(&m_scissor, 1);
+
+            RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
+            if (m_useDrawShadingRate)
+            {
+                RHI::ShadingRateCombinators combinators = { RHI::ShadingRateCombinerOp::Passthrough, m_combinerOp };
+                commandList->SetFragmentShadingRate(m_shadingRate, combinators);
+            }
+
+            const RHI::ShaderResourceGroup* shaderResourceGroups[] = { m_modelShaderResourceGroup->GetRHIShaderResourceGroup() };
+            // We have to wait until the updating of the initial contents of the shading rate image is done if
+            // dynamic mode is not supported (since the CPU would try to read it while the GPU is updating the contents)
+            bool useImageShadingRate = m_useImageShadingRate && (device->GetFeatures().m_dynamicShadingRateImage || m_frameCount > device->GetDescriptor().m_frameCountMax);
+
+            RHI::DrawIndexed drawIndexed;
+            drawIndexed.m_indexCount = 6;
+            drawIndexed.m_instanceCount = 1;
+
+            RHI::DrawItem drawItem;
+            drawItem.m_arguments = drawIndexed;
+            drawItem.m_pipelineState = m_modelPipelineState[useImageShadingRate ? 0 : 1].get();
+            drawItem.m_indexBufferView = &m_indexBufferView;
+            drawItem.m_shaderResourceGroupCount = static_cast<uint8_t>(RHI::ArraySize(shaderResourceGroups));;
+            drawItem.m_shaderResourceGroups = shaderResourceGroups;
+            drawItem.m_streamBufferViewCount = static_cast<uint8_t>(m_streamBufferViews.size());
+            drawItem.m_streamBufferViews = m_streamBufferViews.data();
+
+            commandList->Submit(drawItem);
+        };
+
+        const RHI::ScopeId forwardScope("SceneScope");
+        m_scopeProducers.emplace_back(
+            aznew RHI::ScopeProducerFunction<
+            ScopeData,
+            decltype(prepareFunction),
+            decltype(compileFunction),
+            decltype(executeFunction)>(
+                forwardScope,
+                ScopeData{},
+                prepareFunction,
+                compileFunction,
+                executeFunction));
+    }
+
+    void VariableRateShadingExampleComponent::CreateComputeScope()
+    {
+        struct ScopeData
+        {
+        };
+
+        RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
+        const auto& deviceFeatures = device->GetFeatures();
+        // If "m_dynamicShadingRateImage" is not supported we cannot update the same image that is being used as shading rate this frame.
+        // We use an "old" one that is not longer in used.
+        const char* shadingRateAttachmentId = deviceFeatures.m_dynamicShadingRateImage ? VariableRateShading::ShadingRateAttachmentId : VariableRateShading::ShadingRateAttachmentUpdateId;
+        const auto prepareFunction = [this, shadingRateAttachmentId](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
+        {
+            if (m_useImageShadingRate)
+            {
+                RHI::ImageScopeAttachmentDescriptor shadingRateImageDesc;
+                shadingRateImageDesc.m_attachmentId = shadingRateAttachmentId;
+                shadingRateImageDesc.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::DontCare;
+                shadingRateImageDesc.m_loadStoreAction.m_storeAction = RHI::AttachmentStoreAction::Store;
+                shadingRateImageDesc.m_imageViewDescriptor.m_overrideFormat = ConvertToUInt(m_rateShadingImageFormat);
+                frameGraph.UseShaderAttachment(shadingRateImageDesc, RHI::ScopeAttachmentAccess::Write);
+            }
+
+            frameGraph.SetEstimatedItemCount(1);
+        };
+
+        const auto compileFunction = [this, shadingRateAttachmentId](const RHI::FrameGraphCompileContext& context, [[maybe_unused]] const ScopeData& scopeData)
+        {
+            if (m_useImageShadingRate)
+            {
+                Vector2 center = m_cursorPos * m_shadingRateImageSize;
+                const RHI::ImageView* shadingRateImageView = context.GetImageView(RHI::AttachmentId(shadingRateAttachmentId));
+                m_computeShaderResourceGroup->SetImageView(m_shadingRateIndex, shadingRateImageView);
+                m_computeShaderResourceGroup->SetConstant(m_centerIndex, center);
+                m_computeShaderResourceGroup->Compile();
+            }
+        };
+
+        const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
+        {
+            if (!m_useImageShadingRate)
+            {
+                return;
+            }
+
+            RHI::CommandList* commandList = context.GetCommandList();
+
+            RHI::DispatchItem dispatchItem;
+            decltype(dispatchItem.m_shaderResourceGroups) shaderResourceGroups = { { m_computeShaderResourceGroup->GetRHIShaderResourceGroup() } };
+
+            RHI::DispatchDirect dispatchArgs;
+
+            dispatchArgs.m_totalNumberOfThreadsX = aznumeric_cast<uint32_t>(m_shadingRateImageSize.GetX());
+            dispatchArgs.m_threadsPerGroupX = aznumeric_cast<uint16_t>(m_numThreadsX);
+            dispatchArgs.m_totalNumberOfThreadsY = aznumeric_cast<uint32_t>(m_shadingRateImageSize.GetY());
+            dispatchArgs.m_threadsPerGroupY = aznumeric_cast<uint16_t>(m_numThreadsY);
+            dispatchArgs.m_totalNumberOfThreadsZ = 1;
+            dispatchArgs.m_threadsPerGroupZ = aznumeric_cast<uint16_t>(m_numThreadsZ);
+
+            AZ_Assert(dispatchArgs.m_threadsPerGroupX == dispatchArgs.m_threadsPerGroupY, "If the shader source changes, this logic should change too.");
+            AZ_Assert(dispatchArgs.m_threadsPerGroupZ == 1, "If the shader source changes, this logic should change too.");
+
+            dispatchItem.m_arguments = dispatchArgs;
+            dispatchItem.m_pipelineState = m_computePipelineState.get();
+            dispatchItem.m_shaderResourceGroupCount = 1;
+            dispatchItem.m_shaderResourceGroups = shaderResourceGroups;
+
+            commandList->Submit(dispatchItem);
+        };
+
+        const RHI::ScopeId computeScope("ShadingRateImageCompute");
+        m_scopeProducers.emplace_back(
+            aznew RHI::ScopeProducerFunction<
+            ScopeData,
+            decltype(prepareFunction),
+            decltype(compileFunction),
+            decltype(executeFunction)>(
+                computeScope,
+                ScopeData{},
+                prepareFunction,
+                compileFunction,
+                executeFunction));
+    }
+
+    void VariableRateShadingExampleComponent::CreatImageDisplayScope()
+    {
+        struct ScopeData
+        {
+        };
+
+        const auto prepareFunction = [this](RHI::FrameGraphInterface frameGraph, [[maybe_unused]] ScopeData& scopeData)
+        {
+            {
+                // Binds the swap chain as a color attachment.
+                RHI::ImageScopeAttachmentDescriptor descriptor;
+                descriptor.m_attachmentId = m_outputAttachmentId;
+                descriptor.m_loadStoreAction.m_loadAction = RHI::AttachmentLoadAction::Load;
+                frameGraph.UseColorAttachment(descriptor);
+            }
+
+            if (m_showShadingRateImage)
+            {
+                // Binds the shading rate image for reading (not as attachment)
+                RHI::ImageScopeAttachmentDescriptor shadingRateImageDesc;
+                shadingRateImageDesc.m_attachmentId = VariableRateShading::ShadingRateAttachmentId;
+                shadingRateImageDesc.m_loadStoreAction.m_storeAction = RHI::AttachmentStoreAction::DontCare;
+                shadingRateImageDesc.m_imageViewDescriptor.m_overrideFormat = ConvertToUInt(m_rateShadingImageFormat);
+                frameGraph.UseShaderAttachment(shadingRateImageDesc, RHI::ScopeAttachmentAccess::Read);
+            }
+
+            frameGraph.SetEstimatedItemCount(1);
+        };
+
+        const auto compileFunction = [this](const RHI::FrameGraphCompileContext& context, [[maybe_unused]] const ScopeData& scopeData)
+        {
+            if (m_showShadingRateImage)
+            {
+                const RHI::ImageView* shadingRateImageView = context.GetImageView(RHI::AttachmentId(VariableRateShading::ShadingRateAttachmentId));
+                m_imageShaderResourceGroup->SetImageView(m_shadingRateDisplayIndex, shadingRateImageView);
+                m_imageShaderResourceGroup->Compile();
+            }
+        };
+
+        const auto executeFunction = [this](const RHI::FrameGraphExecuteContext& context, [[maybe_unused]] const ScopeData& scopeData)
+        {
+            if (!m_showShadingRateImage)
+            {
+                return;
+            }
+
+            RHI::CommandList* commandList = context.GetCommandList();
+
+            // Set persistent viewport and scissor state.
+            commandList->SetViewports(&m_viewport, 1);
+            commandList->SetScissors(&m_scissor, 1);
+
+            const RHI::ShaderResourceGroup* shaderResourceGroups[] = { m_imageShaderResourceGroup->GetRHIShaderResourceGroup() };
+
+            RHI::DrawIndexed drawIndexed;
+            drawIndexed.m_indexCount = 6;
+            drawIndexed.m_instanceCount = 1;
+
+            RHI::DrawItem drawItem;
+            drawItem.m_arguments = drawIndexed;
+            drawItem.m_pipelineState = m_imagePipelineState.get();
+            drawItem.m_indexBufferView = &m_indexBufferView;
+            drawItem.m_shaderResourceGroupCount = static_cast<uint8_t>(RHI::ArraySize(shaderResourceGroups));
+            drawItem.m_shaderResourceGroups = shaderResourceGroups;
+            drawItem.m_streamBufferViewCount = static_cast<uint8_t>(m_streamBufferViews.size());
+            drawItem.m_streamBufferViews = m_streamBufferViews.data();
+
+            commandList->Submit(drawItem);
+        };
+
+        const RHI::ScopeId forwardScope("ImageDisplayScope");
+        m_scopeProducers.emplace_back(
+            aznew RHI::ScopeProducerFunction<
+            ScopeData,
+            decltype(prepareFunction),
+            decltype(compileFunction),
+            decltype(executeFunction)>(
+                forwardScope,
+                ScopeData{},
+                prepareFunction,
+                compileFunction,
+                executeFunction));
+    }
+
+    void VariableRateShadingExampleComponent::Deactivate()
+    {
+        m_imguiSidebar.Deactivate();
+        AZ::RHI::RHISystemNotificationBus::Handler::BusDisconnect();
+        AZ::TickBus::Handler::BusDisconnect();
+        AzFramework::InputChannelEventListener::BusDisconnect();
+
+        m_bufferPool = nullptr;
+        m_inputAssemblyBuffer = nullptr;
+        m_modelPipelineState[0] = nullptr;
+        m_modelPipelineState[1] = nullptr;
+        m_imagePipelineState = nullptr;
+        m_modelShaderResourceGroup = nullptr;
+        m_computeShaderResourceGroup = nullptr;
+        m_imageShaderResourceGroup = nullptr;
+        m_shaders.clear();
+        m_supportedModes.clear();
+        m_windowContext = nullptr;
+        m_imagePool = nullptr;
+        m_shadingRateImages.clear();
+        m_scopeProducers.clear();
+    }
+
+    void VariableRateShadingExampleComponent::DrawSettings()
+    {
+        RHI::Ptr<RHI::Device> device = Utils::GetRHIDevice();
+        const auto& deviceFeatures = device->GetFeatures();
+
+        ImGui::Spacing();
+        if (RHI::CheckBitsAll(deviceFeatures.m_shadingRateTypeMask, RHI::ShadingRateTypeFlags::PerImage))
+        {
+            ScriptableImGui::Checkbox("Image Shade Rate", &m_useImageShadingRate);
+            if (m_useImageShadingRate)
+            {
+                ImGui::Indent();
+                ScriptableImGui::Checkbox("Show Image", &m_showShadingRateImage);
+                ScriptableImGui::Checkbox("Follow Pointer", &m_followPointer);
+                ImGui::Unindent();
+            }
+            else
+            {
+                m_showShadingRateImage = false;
+                m_followPointer = false;
+            }
+        }
+
+        if (RHI::CheckBitsAll(deviceFeatures.m_shadingRateTypeMask, RHI::ShadingRateTypeFlags::PerDraw))
+        {
+            ScriptableImGui::Checkbox("Draw Shade Rate", &m_useDrawShadingRate);
+            if (m_useDrawShadingRate)
+            {
+                ImGui::Indent();
+                AZStd::vector<const char*> items;
+                for(const auto rate : m_supportedModes)
+                {
+                    items.push_back(ToString(rate));
+                }
+                int current_item = static_cast<int>(AZStd::distance(m_supportedModes.begin(), AZStd::find(m_supportedModes.begin(), m_supportedModes.end(), m_shadingRate)));
+                ScriptableImGui::Combo("Shading Rates", &current_item, items.data(), static_cast<int>(items.size()));
+                m_shadingRate = m_supportedModes[current_item];
+                ImGui::Unindent();
+            }
+        }
+
+        if (m_useDrawShadingRate && m_useImageShadingRate)
+        {
+            AZStd::vector<const char*> items = { "Passthrough", "Override", "Min", "Max" };
+            int current_item = static_cast<int>(m_combinerOp);
+            ScriptableImGui::Combo("Combiner Op", &current_item, items.data(), static_cast<int>(items.size()));
+            m_combinerOp = static_cast<RHI::ShadingRateCombinerOp>(current_item);
+        }
+        else if(m_useDrawShadingRate)
+        {
+            m_combinerOp = RHI::ShadingRateCombinerOp::Passthrough;
+        }
+
+        if (!m_followPointer)
+        {
+            m_cursorPos = AZ::Vector2(0.5f, 0.5f);
+        }
+
+        m_imguiSidebar.End();
+    }    
+
+    bool VariableRateShadingExampleComponent::OnInputChannelEventFiltered(const AzFramework::InputChannel& inputChannel)
+    {
+        if (m_followPointer)
+        {
+            const AzFramework::InputChannelId& inputChannelId = inputChannel.GetInputChannelId();
+            switch (inputChannel.GetState())
+            {
+            case AzFramework::InputChannel::State::Began:
+            case AzFramework::InputChannel::State::Updated: // update the camera rotation
+            {
+                const AzFramework::InputChannel::PositionData2D* position = nullptr;
+                // Mouse or Touch Events
+                if (inputChannelId == AzFramework::InputDeviceMouse::SystemCursorPosition ||
+                    inputChannelId == AzFramework::InputDeviceTouch::Touch::Index0)
+                {
+                    position = inputChannel.GetCustomData<AzFramework::InputChannel::PositionData2D>();
+                }
+
+                if (position)
+                {
+                    m_cursorPos = position->m_normalizedPosition;
+                }
+                break;
+            }
+            default:
+                break;
+            }
+        }
+        return false;
+    }
+} // namespace AtomSampleViewer

+ 160 - 0
Gem/Code/Source/RHI/VariableRateShadingExampleComponent.h

@@ -0,0 +1,160 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#pragma once
+
+#include <AzCore/Component/Component.h>
+#include <AzCore/Component/TickBus.h>
+#include <AzCore/Math/Color.h>
+#include <AzCore/std/containers/array.h>
+
+#include <AzFramework/Input/Events/InputChannelEventListener.h>
+#include <AzFramework/Windowing/WindowBus.h>
+#include <AzFramework/Windowing/NativeWindow.h>
+
+#include <Atom/RPI.Public/Shader/Shader.h>
+#include <Atom/RPI.Public/Shader/ShaderResourceGroup.h>
+
+#include <Atom/RHI/BufferPool.h>
+#include <Atom/RHI/DrawItem.h>
+#include <Atom/RHI/Device.h>
+#include <Atom/RHI/Factory.h>
+#include <Atom/RHI/CopyItem.h>
+#include <Atom/RHI/FrameScheduler.h>
+#include <Atom/RHI/PipelineState.h>
+
+#include <RHI/BasicRHIComponent.h>
+#include <Utils/ImGuiSidebar.h>
+#include <Utils/ImGuiProgressList.h>
+#include <Atom/Utils/AssetCollectionAsyncLoader.h>
+
+namespace AtomSampleViewer
+{
+    //! This samples demonstrates the use of Variable Rate Shading on the RHI.
+    //! Shading rates can be specified in 3 different ways: PerDraw, PerPrimtive and PerImage.
+    //! This samples only uses the PerDraw and PerImage modes.
+    //! The samples render a full screen quad using different shading rates.
+    //! When the PerImage mode is used, an image is generated using a compute shader with different
+    //! rates in a circular pattern from the center (or the pointer position).
+    //! When a PerDraw mode is used, the rate is applied equally to the whole quad. The rate can be changed
+    //! using the GUI of the sample.
+    //! Combinator operations are also exposed when both PerDraw and PerImage are being used.
+    class VariableRateShadingExampleComponent final
+        : public BasicRHIComponent
+        , public AZ::TickBus::Handler
+        , public AzFramework::InputChannelEventListener
+    {
+    public:
+        AZ_COMPONENT(VariableRateShadingExampleComponent, "{B98E1C6A-8C23-4AA4-82E6-4B652F6151DD}", AZ::Component);
+        AZ_DISABLE_COPY(VariableRateShadingExampleComponent);
+
+        static void Reflect(AZ::ReflectContext* context);
+
+        VariableRateShadingExampleComponent();
+        ~VariableRateShadingExampleComponent() override = default;
+
+    private:
+        // AZ::Component
+        void Activate() override;
+        void Deactivate() override;
+
+        // AZ::TickBus::Handler
+        void OnTick(float deltaTime, AZ::ScriptTimePoint time);
+
+        // RHISystemNotificationBus::Handler
+        void OnFramePrepare(AZ::RHI::FrameGraphBuilder& frameGraphBuilder) override;
+
+        // AzFramework::InputChannelEventListener
+        bool OnInputChannelEventFiltered(const AzFramework::InputChannel& inputChannel) override;
+
+        // Draw the ImGUI settings
+        void DrawSettings();
+        // Loads the compute and graphics shaders.
+        void LoadShaders();
+        // Creates the image pool and images used for shading rate attachments.
+        void CreateShadingRateImage();
+        // Creates the shading resource groups used by the compute and graphic scopes.
+        void CreateShaderResourceGroups();
+        // Creates the IA resources for the full screen quad.
+        void CreateInputAssemblyBuffersAndViews();
+        // Creates the necessary pipelines.
+        void CreatePipelines();
+        // Creates the scope used for rendering the full screen quad.
+        void CreateRenderScope();
+        // Creates the scope for showing the shading rate image.
+        void CreatImageDisplayScope();
+        // Creates the compute used for updating the shading rate image.
+        void CreateComputeScope();
+
+        // ImGUI sidebar that handles the options of the sample.
+        ImGuiSidebar m_imguiSidebar;
+        // Whether to use a shading rate image.
+        bool m_useImageShadingRate = true;
+        // Whether to show the shading rate image.
+        bool m_showShadingRateImage = false;
+        // Whether the center of the shading rate image follows the position of the pointer.
+        bool m_followPointer = false;
+        // Whether to use the PerDraw mode.
+        bool m_useDrawShadingRate = false;
+        // Combinator operation to applied between the PerDraw and PerImage shading rate.
+        AZ::RHI::ShadingRateCombinerOp m_combinerOp = AZ::RHI::ShadingRateCombinerOp::Passthrough;
+        // Shading rate when using the PerDraw mode.
+        AZ::RHI::ShadingRate m_shadingRate = AZ::RHI::ShadingRate::Rate1x1;
+
+        // Pipelines used for rendering the full screen quad with and without a shading rate attachments.
+        AZ::RHI::ConstPtr<AZ::RHI::PipelineState> m_modelPipelineState[2];
+        // Pipeline used for updating the shading rate image.
+        AZ::RHI::ConstPtr<AZ::RHI::PipelineState> m_computePipelineState;
+        // Pipeline used for showing the shading rate image.
+        AZ::RHI::ConstPtr<AZ::RHI::PipelineState> m_imagePipelineState;
+        // Compute and graphics shaders.
+        AZStd::vector<AZ::Data::Instance<AZ::RPI::Shader>> m_shaders;
+        // SRG with information when rendering the full screen quad.
+        AZ::Data::Instance<AZ::RPI::ShaderResourceGroup> m_modelShaderResourceGroup;
+        // SRG with information when updating the shading rate iamge.
+        AZ::Data::Instance<AZ::RPI::ShaderResourceGroup> m_computeShaderResourceGroup;
+        // SRG with information when displaying the shading rate image.
+        AZ::Data::Instance<AZ::RPI::ShaderResourceGroup> m_imageShaderResourceGroup;
+        
+        // Indices into the SRGs for properties that are updated.
+        AZ::RHI::ShaderInputConstantIndex m_centerIndex;
+        AZ::RHI::ShaderInputImageIndex m_shadingRateIndex;
+        AZ::RHI::ShaderInputImageIndex m_shadingRateDisplayIndex;
+
+        // Size of the shading rate image tile (in pixels)
+        AZ::Vector2 m_shadingRateImageSize;
+        int m_numThreadsX = 8;
+        int m_numThreadsY = 8;
+        int m_numThreadsZ = 1;
+
+        // Bufferpool for creating the IA buffer
+        AZ::RHI::Ptr<AZ::RHI::BufferPool> m_bufferPool;
+        // Buffer for the IA of the full screen quad.
+        AZ::RHI::Ptr<AZ::RHI::Buffer> m_inputAssemblyBuffer;
+        // Bufferviews into the full screen quad IA
+        AZStd::array<AZ::RHI::StreamBufferView, 2> m_streamBufferViews;
+        // Indexview of the full screen quad index buffer
+        AZ::RHI::IndexBufferView m_indexBufferView;
+        // Layout of the full screen quad.
+        AZ::RHI::InputStreamLayout m_inputStreamLayout;
+
+        // Image pool containing the shading rate images.
+        AZ::RHI::Ptr<AZ::RHI::ImagePool> m_imagePool;
+        // List of shading rate images used as attachments.
+        AZStd::fixed_vector<AZ::RHI::Ptr<AZ::RHI::Image>, AZ::RHI::Limits::Device::FrameCountMax> m_shadingRateImages;
+
+        // Cursor position (mouse or touch)
+        AZ::Vector2 m_cursorPos;
+        // Selected format to be used for the shading rate image.
+        AZ::RHI::Format m_rateShadingImageFormat = AZ::RHI::Format::Unknown;
+        // Frame counter for selecting the proper shading rate image. 
+        uint32_t m_frameCount = 0;
+        // List of supported shading rate values.
+        AZStd::fixed_vector<AZ::RHI::ShadingRate, static_cast<uint32_t>(AZ::RHI::ShadingRate::Count)> m_supportedModes;
+    };
+} // namespace AtomSampleViewer

+ 2 - 0
Gem/Code/Source/SampleComponentManager.cpp

@@ -64,6 +64,7 @@
 #include <RHI/BindlessPrototypeExampleComponent.h>
 #include <RHI/RayTracingExampleComponent.h>
 #include <RHI/MatrixAlignmentTestExampleComponent.h>
+#include <RHI/VariableRateShadingExampleComponent.h>
 
 #include <Performance/100KDrawable_SingleView_ExampleComponent.h>
 #include <Performance/100KDraw_10KDrawable_MultiView_ExampleComponent.h>
@@ -306,6 +307,7 @@ namespace AtomSampleViewer
             NewRHISample<TriangleExampleComponent>("Triangle"),
             NewRHISample<TrianglesConstantBufferExampleComponent>("TrianglesConstantBuffer"),
             NewRHISample<MatrixAlignmentTestExampleComponent>("MatrixAlignmentTest"),
+            NewRHISample<VariableRateShadingExampleComponent>("VariableRateShading", []() { return Utils::GetRHIDevice()->GetFeatures().m_shadingRateTypeMask != RHI::ShadingRateTypeFlags::None; }),
             NewRPISample<AssetLoadTestComponent>("AssetLoadTest"),
             NewRPISample<AuxGeomExampleComponent>("AuxGeom"),
             NewRPISample<BakedShaderVariantExampleComponent>("BakedShaderVariant"),

+ 2 - 0
Gem/Code/atomsampleviewergem_private_files.cmake

@@ -85,6 +85,8 @@ set(FILES
     Source/RHI/RayTracingExampleComponent.h
     Source/RHI/MatrixAlignmentTestExampleComponent.cpp
     Source/RHI/MatrixAlignmentTestExampleComponent.h
+    Source/RHI/VariableRateShadingExampleComponent.cpp
+    Source/RHI/VariableRateShadingExampleComponent.h
     Source/Performance/HighInstanceExampleComponent.cpp
     Source/Performance/HighInstanceExampleComponent.h
     Source/Performance/100KDrawable_SingleView_ExampleComponent.cpp

+ 204 - 0
Gem/Code/atomsampleviewergem_private_files.cmake.orig

@@ -0,0 +1,204 @@
+#
+# Copyright (c) Contributors to the Open 3D Engine Project.
+# For complete copyright and license terms please see the LICENSE at the root of this distribution.
+#
+# SPDX-License-Identifier: Apache-2.0 OR MIT
+#
+#
+
+set(FILES
+    Source/AtomSampleComponent.cpp
+    Source/AtomSampleComponent.h
+    Source/AtomSampleViewerOptions.h
+    Source/AtomSampleViewerSystemComponent.cpp
+    Source/AtomSampleViewerSystemComponent.h
+    Source/AtomSampleViewerRequestBus.h
+    Source/SampleComponentManager.cpp
+    Source/SampleComponentManager.h
+    Source/SampleComponentManagerBus.h
+    Source/SampleComponentConfig.cpp
+    Source/SampleComponentConfig.h
+    Source/Automation/AssetStatusTracker.cpp
+    Source/Automation/AssetStatusTracker.h
+    Source/Automation/ImageComparisonConfig.h
+    Source/Automation/ImageComparisonConfig.cpp
+    Source/Automation/PrecommitWizardSettings.h
+    Source/Automation/ScriptableImGui.cpp
+    Source/Automation/ScriptableImGui.h
+    Source/Automation/ScriptManager.cpp
+    Source/Automation/ScriptManager.h
+    Source/Automation/ScriptRepeaterBus.h
+    Source/Automation/ScriptRunnerBus.h
+    Source/Automation/ScriptReporter.cpp
+    Source/Automation/ScriptReporter.h
+    Source/RHI/AlphaToCoverageExampleComponent.cpp
+    Source/RHI/AlphaToCoverageExampleComponent.h
+    Source/RHI/AsyncComputeExampleComponent.h
+    Source/RHI/AsyncComputeExampleComponent.cpp
+    Source/RHI/BasicRHIComponent.cpp
+    Source/RHI/BasicRHIComponent.h
+    Source/RHI/BindlessPrototypeExampleComponent.h
+    Source/RHI/BindlessPrototypeExampleComponent.cpp
+    Source/RHI/ComputeExampleComponent.cpp
+    Source/RHI/ComputeExampleComponent.h
+    Source/RHI/CopyQueueComponent.cpp
+    Source/RHI/CopyQueueComponent.h
+    Source/RHI/DualSourceBlendingComponent.cpp
+    Source/RHI/DualSourceBlendingComponent.h
+    Source/RHI/IndirectRenderingExampleComponent.cpp
+    Source/RHI/IndirectRenderingExampleComponent.h
+    Source/RHI/InputAssemblyExampleComponent.cpp
+    Source/RHI/InputAssemblyExampleComponent.h
+    Source/RHI/MRTExampleComponent.h
+    Source/RHI/MRTExampleComponent.cpp
+    Source/RHI/MSAAExampleComponent.h
+    Source/RHI/MSAAExampleComponent.cpp
+    Source/RHI/MultiThreadComponent.cpp
+    Source/RHI/MultiThreadComponent.h
+    Source/RHI/MultipleViewsComponent.cpp
+    Source/RHI/MultipleViewsComponent.h
+    Source/RHI/MultiViewportSwapchainComponent.cpp
+    Source/RHI/MultiViewportSwapchainComponent.h
+    Source/RHI/QueryExampleComponent.h
+    Source/RHI/QueryExampleComponent.cpp
+    Source/RHI/StencilExampleComponent.cpp
+    Source/RHI/StencilExampleComponent.h
+    Source/RHI/SwapchainExampleComponent.cpp
+    Source/RHI/SwapchainExampleComponent.h
+    Source/RHI/SphericalHarmonicsExampleComponent.cpp
+    Source/RHI/SphericalHarmonicsExampleComponent.h
+    Source/RHI/SubpassExampleComponent.cpp
+    Source/RHI/SubpassExampleComponent.h
+    Source/RHI/Texture3dExampleComponent.cpp
+    Source/RHI/Texture3dExampleComponent.h
+    Source/RHI/TextureArrayExampleComponent.cpp
+    Source/RHI/TextureArrayExampleComponent.h
+    Source/RHI/TextureExampleComponent.cpp
+    Source/RHI/TextureExampleComponent.h
+    Source/RHI/TextureMapExampleComponent.cpp
+    Source/RHI/TextureMapExampleComponent.h
+    Source/RHI/TriangleExampleComponent.cpp
+    Source/RHI/TriangleExampleComponent.h
+    Source/RHI/TrianglesConstantBufferExampleComponent.h
+    Source/RHI/TrianglesConstantBufferExampleComponent.cpp
+    Source/RHI/RayTracingExampleComponent.cpp
+    Source/RHI/RayTracingExampleComponent.h
+    Source/RHI/MatrixAlignmentTestExampleComponent.cpp
+    Source/RHI/MatrixAlignmentTestExampleComponent.h
+<<<<<<< HEAD
+=======
+    Source/RHI/XRExampleComponent.cpp
+    Source/RHI/XRExampleComponent.h
+    Source/RHI/VariableRateShadingExampleComponent.cpp
+    Source/RHI/VariableRateShadingExampleComponent.h
+>>>>>>> 16ebd78 (VariableRateShading RHI sample)
+    Source/Performance/HighInstanceExampleComponent.cpp
+    Source/Performance/HighInstanceExampleComponent.h
+    Source/Performance/100KDrawable_SingleView_ExampleComponent.cpp
+    Source/Performance/100KDrawable_SingleView_ExampleComponent.h
+    Source/Performance/100KDraw_10KDrawable_MultiView_ExampleComponent.cpp
+    Source/Performance/100KDraw_10KDrawable_MultiView_ExampleComponent.h
+    Source/AreaLightExampleComponent.cpp
+    Source/AreaLightExampleComponent.h
+    Source/AssetLoadTestComponent.cpp
+    Source/AssetLoadTestComponent.h
+    Source/AuxGeomExampleComponent.cpp
+    Source/AuxGeomExampleComponent.h
+    Source/AuxGeomSharedDrawFunctions.cpp
+    Source/AuxGeomSharedDrawFunctions.h
+    Source/BakedShaderVariantExampleComponent.h
+    Source/BakedShaderVariantExampleComponent.cpp
+    Source/SponzaBenchmarkComponent.cpp
+    Source/SponzaBenchmarkComponent.h
+    Source/BloomExampleComponent.cpp
+    Source/BloomExampleComponent.h
+    Source/CheckerboardExampleComponent.h
+    Source/CheckerboardExampleComponent.cpp
+    Source/CommonSampleComponentBase.cpp
+    Source/CommonSampleComponentBase.h
+    Source/CullingAndLodExampleComponent.cpp
+    Source/CullingAndLodExampleComponent.h
+    Source/DecalExampleComponent.cpp
+    Source/DecalExampleComponent.h
+    Source/DecalContainer.cpp
+    Source/DecalContainer.h
+    Source/DepthOfFieldExampleComponent.h
+    Source/DepthOfFieldExampleComponent.cpp
+    Source/DiffuseGIExampleComponent.cpp
+    Source/DiffuseGIExampleComponent.h
+    Source/DynamicDrawExampleComponent.h
+    Source/DynamicDrawExampleComponent.cpp
+    Source/DynamicMaterialTestComponent.cpp
+    Source/DynamicMaterialTestComponent.h
+    Source/EntityLatticeTestComponent.cpp
+    Source/EntityLatticeTestComponent.h
+    Source/EntityUtilityFunctions.cpp
+    Source/EntityUtilityFunctions.h
+    Source/ExposureExampleComponent.cpp
+    Source/ExposureExampleComponent.h
+    Source/EyeMaterialExampleComponent.h
+    Source/EyeMaterialExampleComponent.cpp
+    Source/LightCullingExampleComponent.cpp
+    Source/LightCullingExampleComponent.h
+    Source/MeshExampleComponent.cpp
+    Source/MeshExampleComponent.h
+    Source/MSAA_RPI_ExampleComponent.cpp
+    Source/MSAA_RPI_ExampleComponent.h
+    Source/MultiRenderPipelineExampleComponent.cpp
+    Source/MultiRenderPipelineExampleComponent.h
+    Source/MultiSceneExampleComponent.cpp
+    Source/MultiSceneExampleComponent.h
+    Source/MultiViewSingleSceneAuxGeomExampleComponent.cpp
+    Source/MultiViewSingleSceneAuxGeomExampleComponent.h
+    Source/ParallaxMappingExampleComponent.cpp
+    Source/ParallaxMappingExampleComponent.h
+    Source/Passes/RayTracingAmbientOcclusionPass.cpp
+    Source/Passes/RayTracingAmbientOcclusionPass.h
+    Source/ParallaxMappingExampleComponent.h
+    Source/ProceduralSkinnedMesh.cpp
+    Source/ProceduralSkinnedMesh.h
+    Source/ProceduralSkinnedMeshUtils.cpp
+    Source/ProceduralSkinnedMeshUtils.h
+    Source/ReadbackExampleComponent.cpp
+    Source/ReadbackExampleComponent.h
+    Source/RenderTargetTextureExampleComponent.cpp
+    Source/RenderTargetTextureExampleComponent.h
+    Source/RootConstantsExampleComponent.h
+    Source/RootConstantsExampleComponent.cpp
+    Source/SceneReloadSoakTestComponent.cpp
+    Source/SceneReloadSoakTestComponent.h
+    Source/ShadowExampleComponent.cpp
+    Source/ShadowExampleComponent.h
+    Source/ShadowedSponzaExampleComponent.cpp
+    Source/ShadowedSponzaExampleComponent.h
+    Source/SkinnedMeshContainer.cpp
+    Source/SkinnedMeshContainer.h
+    Source/SkinnedMeshExampleComponent.cpp
+    Source/SkinnedMeshExampleComponent.h
+    Source/SsaoExampleComponent.cpp
+    Source/SsaoExampleComponent.h
+    Source/SSRExampleComponent.cpp
+    Source/SSRExampleComponent.h
+    Source/StreamingImageExampleComponent.cpp
+    Source/StreamingImageExampleComponent.h
+    Source/TonemappingExampleComponent.cpp
+    Source/TonemappingExampleComponent.h
+    Source/TransparencyExampleComponent.cpp
+    Source/TransparencyExampleComponent.h
+    Source/ShaderReloadTestComponent.cpp
+    Source/ShaderReloadTestComponent.h
+    Source/Utils/ImGuiAssetBrowser.cpp
+    Source/Utils/ImGuiAssetBrowser.h
+    Source/Utils/ImGuiHistogramQueue.cpp
+    Source/Utils/ImGuiHistogramQueue.h
+    Source/Utils/ImGuiMessageBox.cpp
+    Source/Utils/ImGuiMessageBox.h
+    Source/Utils/ImGuiSaveFilePath.cpp
+    Source/Utils/ImGuiSaveFilePath.h
+    Source/Utils/ImGuiSidebar.cpp
+    Source/Utils/ImGuiSidebar.h
+    Source/Utils/Utils.cpp
+    Source/Utils/Utils.h
+    Source/Utils/ImGuiProgressList.cpp
+    Source/Utils/ImGuiProgressList.h
+)

+ 53 - 0
Shaders/RHI/VariableRateShading.azsl

@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <Atom/Features/SrgSemantics.azsli>
+ShaderResourceGroup InstanceSrg : SRG_PerObject
+{
+    Texture2D m_texture;
+    Sampler m_sampler
+    {
+        MaxAnisotropy = 16;
+        AddressU = Wrap;
+        AddressV = Wrap;
+        AddressW = Wrap;
+    };
+}
+
+struct VSInput
+{
+    float3 m_position : POSITION;
+    float2 m_uv : UV0;
+};
+
+struct VSOutput
+{
+    float4 m_position : SV_POSITION;
+    float2 m_uv : UV0;
+};
+
+VSOutput MainVS(VSInput vsInput)
+{
+    VSOutput OUT;
+    OUT.m_uv = vsInput.m_uv;
+    OUT.m_position = float4(vsInput.m_position.x, vsInput.m_position.y, vsInput.m_position.z, 1);
+    return OUT;
+}
+
+
+struct PSOutput
+{
+    float4 m_color : SV_Target0;
+};
+
+PSOutput MainPS(VSOutput vsOutput)
+{
+    PSOutput OUT;
+    OUT.m_color = InstanceSrg::m_texture.Sample(InstanceSrg::m_sampler, vsOutput.m_uv);
+    return OUT;
+}

+ 24 - 0
Shaders/RHI/VariableRateShading.shader

@@ -0,0 +1,24 @@
+{
+    "Source" : "VariableRateShading.azsl",
+
+    "DepthStencilState" : { 
+        "Depth" : { "Enable" : false, "CompareFunc" : "GreaterEqual" }
+    },
+
+    "DrawList" : "forward",
+
+    "ProgramSettings":
+    {
+      "EntryPoints":
+      [
+        {
+          "name": "MainVS",
+          "type": "Vertex"
+        },
+        {
+          "name": "MainPS",
+          "type": "Fragment"
+        }
+      ]
+    }
+}

+ 52 - 0
Shaders/RHI/VariableRateShadingCompute.azsl

@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <Atom/Features/SrgSemantics.azsli>
+ 
+#define NumShadingRates 7
+
+ShaderResourceGroup ComputeSrg : SRG_PerObject
+{
+    float2 m_center;
+    RWTexture2D<uint2> m_shadingRateTexture;
+    struct Pattern
+    {
+        float4 m_distance;
+        uint4 m_rate;
+    };
+
+    Pattern m_pattern[NumShadingRates];
+};
+
+#define ThreadGroupSize 16
+[numthreads(ThreadGroupSize, ThreadGroupSize, 1)]
+void MainCS(uint3 groupID : SV_GroupID, uint3 groupThreadID : SV_GroupThreadID)
+{
+    uint2 samplePos = groupID.xy * ThreadGroupSize + groupThreadID.xy;
+    uint2 dimensions;
+    ComputeSrg::m_shadingRateTexture.GetDimensions(dimensions.x, dimensions.y);
+    if (samplePos.x >= dimensions.x || samplePos.y >= dimensions.y)
+    {
+        return;
+    }
+
+    // Set a default value
+    ComputeSrg::m_shadingRateTexture[samplePos] = uint2(ComputeSrg::m_pattern[NumShadingRates - 1].m_rate.x, ComputeSrg::m_pattern[NumShadingRates - 1].m_rate.y);
+    // Create a circular pattern from the "center" with decreasing shading rates.
+    for (int i = 0; i < NumShadingRates; ++i)
+    {
+        float deltaX = ((float)ComputeSrg::m_center.x - (float)samplePos.x) / dimensions.x * 100.0f;
+        float deltaY = ((float)ComputeSrg::m_center.y - (float)samplePos.y) / dimensions.y * 100.0f;
+        float distance = sqrt(deltaX * deltaX + deltaY * deltaY);
+        if (distance < ComputeSrg::m_pattern[i].m_distance.x)
+        {
+            ComputeSrg::m_shadingRateTexture[samplePos] = uint2(ComputeSrg::m_pattern[i].m_rate.x, ComputeSrg::m_pattern[i].m_rate.y);
+            break;
+        }
+    }
+} 

+ 15 - 0
Shaders/RHI/VariableRateShadingCompute.shader

@@ -0,0 +1,15 @@
+{   
+    "Source": "VariableRateShadingCompute.azsl",
+  
+    "ProgramSettings":
+    {
+      "EntryPoints":
+      [
+        {
+          "name": "MainCS",
+          "type": "Compute"
+        }
+      ]
+    }
+
+}

+ 67 - 0
Shaders/RHI/VariableRateShadingImage.azsl

@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <Atom/Features/SrgSemantics.azsli>
+
+#define NumShadingRates 7
+ShaderResourceGroup InstanceSrg : SRG_PerObject
+{
+    Texture2D<uint2> m_texture;
+    struct PatternColor
+    {
+        float4 m_color;
+        uint4 m_rate;
+    };
+    PatternColor m_colors[NumShadingRates];
+}
+
+struct VSInput
+{
+    float3 m_position : POSITION;
+    float2 m_uv : UV0;
+};
+
+struct VSOutput
+{
+    float4 m_position : SV_POSITION;
+    float2 m_uv : UV0;
+};
+
+VSOutput MainVS(VSInput vsInput)
+{
+    VSOutput OUT;
+    OUT.m_uv = vsInput.m_uv;
+    OUT.m_uv.y = 1.0f - OUT.m_uv.y;
+    OUT.m_position = float4(vsInput.m_position.x, vsInput.m_position.y, vsInput.m_position.z, 1);
+    return OUT;
+}
+
+struct PSOutput
+{
+    float4 m_color : SV_Target0;
+};
+
+PSOutput MainPS(VSOutput vsOutput)
+{
+    uint2 dimensions;
+    InstanceSrg::m_texture.GetDimensions(dimensions.x, dimensions.y);
+    uint2 rate = InstanceSrg::m_texture.Load(int3(vsOutput.m_uv.x * (dimensions.x - 1), vsOutput.m_uv.y * (dimensions.y - 1), 0));
+
+    PSOutput OUT;
+    float4 color = float4(0, 0, 0, 1);
+    for (int i = 0; i < NumShadingRates; ++i)
+    {
+        InstanceSrg::PatternColor pattern = InstanceSrg::m_colors[i];
+        if (pattern.m_rate.x == rate.x && pattern.m_rate.y == rate.y)
+        {
+            color = pattern.m_color;
+        }
+    }
+    OUT.m_color = color;
+    return OUT;
+}

+ 24 - 0
Shaders/RHI/VariableRateShadingImage.shader

@@ -0,0 +1,24 @@
+{
+    "Source" : "VariableRateShadingImage.azsl",
+
+    "DepthStencilState" : { 
+        "Depth" : { "Enable" : false, "CompareFunc" : "GreaterEqual" }
+    },
+
+    "DrawList" : "forward",
+
+    "ProgramSettings":
+    {
+      "EntryPoints":
+      [
+        {
+          "name": "MainVS",
+          "type": "Vertex"
+        },
+        {
+          "name": "MainPS",
+          "type": "Fragment"
+        }
+      ]
+    }
+}

+ 3 - 0
Textures/bricks.png

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f34b38f6e205b6bffec72ddfd1305a66d0146d61bcc2ef901fd146a0ec2708de
+size 1934899