Prechádzať zdrojové kódy

Introduce multi-device indirect types (#16681)

* Introduce multi-device indirect types

Co-authored-by: Martin Winter <[email protected]>
Co-authored-by: Joerg H. Mueller <[email protected]>
Signed-off-by: Joerg H. Mueller <[email protected]>

* Templatized IndirectArguments.

Signed-off-by: Joerg H. Mueller <[email protected]>

* Incorporate PR comments

Signed-off-by: Martin Winter <[email protected]>

* Windows fix

Signed-off-by: Martin Winter <[email protected]>

* Fix UnitTests

Signed-off-by: Martin Winter <[email protected]>

---------

Signed-off-by: Joerg H. Mueller <[email protected]>
Signed-off-by: Martin Winter <[email protected]>
Signed-off-by: Martin Winter <[email protected]>
Co-authored-by: Martin Winter <[email protected]>
Co-authored-by: Martin Winter <[email protected]>
jhmueller-huawei 2 rokov pred
rodič
commit
59c3460b6a

+ 10 - 7
Gems/Atom/RHI/Code/Include/Atom/RHI/IndirectArguments.h

@@ -13,15 +13,16 @@ namespace AZ::RHI
 {
     //! Encapsulates the arguments needed when doing an indirect call
     //! (draw or dispatch) into a command list.
-    struct IndirectArguments
+    template <typename BufferClass, typename IndirectBufferViewClass>
+    struct IndirectArgumentsTemplate
     {
-        IndirectArguments() = default;
+        IndirectArgumentsTemplate() = default;
 
-        IndirectArguments(
+        IndirectArgumentsTemplate(
             uint32_t maxSequenceCount,
             const IndirectBufferView& indirectBuffer,
             uint64_t indirectBufferByteOffset)
-            : IndirectArguments(
+            : IndirectArgumentsTemplate(
                 maxSequenceCount,
                 indirectBuffer,
                 indirectBufferByteOffset,
@@ -29,7 +30,7 @@ namespace AZ::RHI
                 0)
         {}
 
-        IndirectArguments(
+        IndirectArgumentsTemplate(
             uint32_t maxSequenceCount,
             const IndirectBufferView& indirectBuffer,
             uint64_t indirectBufferByteOffset,
@@ -55,9 +56,11 @@ namespace AZ::RHI
         uint64_t m_countBufferByteOffset = 0;
 
         //! View over the Indirect buffer that contains the commands.
-        const IndirectBufferView* m_indirectBufferView = nullptr;
+        const IndirectBufferViewClass* m_indirectBufferView = nullptr;
 
         //! Optional count buffer that contains the number of indirect commands in the indirect buffer.
-        const Buffer* m_countBuffer = nullptr;
+        const BufferClass* m_countBuffer = nullptr;
     };
+
+    using IndirectArguments = IndirectArgumentsTemplate<Buffer, IndirectBufferView>;
 }

+ 17 - 0
Gems/Atom/RHI/Code/Include/Atom/RHI/MultiDeviceIndirectArguments.h

@@ -0,0 +1,17 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#pragma once
+
+#include <Atom/RHI/IndirectArguments.h>
+#include <Atom/RHI/MultiDeviceBuffer.h>
+#include <Atom/RHI/MultiDeviceIndirectBufferView.h>
+
+namespace AZ::RHI
+{
+    using MultiDeviceIndirectArguments = IndirectArgumentsTemplate<MultiDeviceBuffer, MultiDeviceIndirectBufferView>;
+} // namespace AZ::RHI

+ 76 - 0
Gems/Atom/RHI/Code/Include/Atom/RHI/MultiDeviceIndirectBufferSignature.h

@@ -0,0 +1,76 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#pragma once
+
+#include <Atom/RHI.Reflect/Base.h>
+#include <Atom/RHI.Reflect/IndirectBufferLayout.h>
+#include <Atom/RHI/DeviceObject.h>
+#include <Atom/RHI/IndirectBufferSignature.h>
+
+#include <Atom/RHI/MultiDeviceObject.h>
+
+namespace AZ::RHI
+{
+    class Device;
+    class MultiDevicePipelineState;
+
+    //! A multi-device descriptor for the MultiDeviceIndirectBufferSignature, holding both an IndirectBufferLayout (identical across
+    //! devices) as well as a MultiDevicePipelineState
+    struct MultiDeviceIndirectBufferSignatureDescriptor
+    {
+        //! Returns the device-specific IndirectBufferSignatureDescriptor for the given index
+        IndirectBufferSignatureDescriptor GetDeviceIndirectBufferSignatureDescriptor(int deviceIndex) const;
+
+        const MultiDevicePipelineState* m_pipelineState{ nullptr };
+        IndirectBufferLayout m_layout;
+    };
+
+    //! The MultiDeviceIndirectBufferSignature is an implementation object that represents
+    //! the signature of the commands contained in an Indirect Buffer.
+    //! Indirect Buffers hold the commands that will be used for
+    //! doing Indirect Rendering.
+    //!
+    //! It also exposes implementation dependent offsets for the commands in
+    //! a layout. This information is useful when writing commands into a buffer.
+    class MultiDeviceIndirectBufferSignature : public MultiDeviceObject
+    {
+        using Base = RHI::MultiDeviceObject;
+
+    public:
+        AZ_CLASS_ALLOCATOR(MultiDeviceIndirectBufferSignature, AZ::SystemAllocator, 0);
+        AZ_RTTI(MultiDeviceIndirectBufferSignature, "{3CCFF81D-DC5E-4B12-AC05-DC26D5D0C65C}", Base);
+        AZ_RHI_MULTI_DEVICE_OBJECT_GETTER(IndirectBufferSignature);
+        MultiDeviceIndirectBufferSignature() = default;
+        virtual ~MultiDeviceIndirectBufferSignature() = default;
+
+        //! Initialize an IndirectBufferSignature object.
+        //! @param deviceMask The deviceMask denoting all devices that will contain the signature.
+        //! @param descriptor Descriptor with the necessary information for initializing the signature.
+        //! @return A result code denoting the status of the call. If successful, the MultiDeviceIndirectBufferSignature is considered
+        //! initialized and can be used. If failure, the MultiDeviceIndirectBufferSignature remains uninitialized.
+        ResultCode Init(MultiDevice::DeviceMask deviceMask, const MultiDeviceIndirectBufferSignatureDescriptor& descriptor);
+
+        //! Returns the stride in bytes of the command sequence defined by the provided layout.
+        uint32_t GetByteStride() const;
+
+        //! Returns the offset of the command in the position indicated by the index.
+        //! @param index The location in the layout of the command.
+        uint32_t GetOffset(IndirectCommandIndex index) const;
+
+        const MultiDeviceIndirectBufferSignatureDescriptor& GetDescriptor() const;
+
+        const IndirectBufferLayout& GetLayout() const;
+
+        void Shutdown() final;
+
+    private:
+        MultiDeviceIndirectBufferSignatureDescriptor m_mdDescriptor;
+        static constexpr uint32_t UNINITIALIZED_VALUE{ std::numeric_limits<uint32_t>::max() };
+        uint32_t m_byteStride{ UNINITIALIZED_VALUE };
+    };
+} // namespace AZ::RHI

+ 76 - 0
Gems/Atom/RHI/Code/Include/Atom/RHI/MultiDeviceIndirectBufferView.h

@@ -0,0 +1,76 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#pragma once
+
+#include <Atom/RHI.Reflect/Bits.h>
+#include <Atom/RHI/IndirectBufferView.h>
+#include <Atom/RHI/Buffer.h>
+#include <Atom/RHI/MultiDeviceIndirectBufferSignature.h>
+#include <AzCore/Utils/TypeHash.h>
+
+namespace AZ::RHI
+{
+    class MultiDeviceBuffer;
+    class MultiDeviceIndirectBufferSignature;
+
+    //! Provides a view into a multi-device buffer, to be used as an indirect buffer. The content of the view is a contiguous
+    //! list of command sequences. Its device-specific buffers are provided to the RHI back-end at draw time.
+    class alignas(8) MultiDeviceIndirectBufferView
+    {
+    public:
+        MultiDeviceIndirectBufferView() = default;
+
+        MultiDeviceIndirectBufferView(
+            const MultiDeviceBuffer& buffer,
+            const MultiDeviceIndirectBufferSignature& signature,
+            uint32_t byteOffset,
+            uint32_t byteCount,
+            uint32_t byteStride);
+
+        //! Returns the device-specific IndirectBufferView for the given index
+        IndirectBufferView GetDeviceIndirectBufferView(int deviceIndex) const
+        {
+            AZ_Error("MultiDeviceIndirectBufferView", m_mdSignature, "No MultiDeviceIndirectBufferSignature available\n");
+            AZ_Error("MultiDeviceIndirectBufferView", m_mdBuffer, "No MultiDeviceBuffer available\n");
+
+            return IndirectBufferView(
+                *m_mdBuffer->GetDeviceBuffer(deviceIndex),
+                *m_mdSignature->GetDeviceIndirectBufferSignature(deviceIndex),
+                m_byteOffset,
+                m_byteCount,
+                m_byteStride);
+        }
+
+        //! Returns the hash of the view. This hash is precomputed at creation time.
+        HashValue64 GetHash() const;
+
+        //! Returns the buffer associated with the view.
+        const MultiDeviceBuffer* GetBuffer() const;
+
+        //! Returns the byte offset into the buffer.
+        uint32_t GetByteOffset() const;
+
+        //! Returns the number of bytes in the view.
+        uint32_t GetByteCount() const;
+
+        //! Returns the distance in bytes between consecutive command sequences.
+        //! This must be greater than or equal to the stride specified by the signature.
+        uint32_t GetByteStride() const;
+
+        //! Returns the signature of the indirect buffer that is associated with the view.
+        const MultiDeviceIndirectBufferSignature* GetSignature() const;
+
+    private:
+        HashValue64 m_hash = HashValue64{ 0 };
+        const MultiDeviceIndirectBufferSignature* m_mdSignature = nullptr;
+        const MultiDeviceBuffer* m_mdBuffer = nullptr;
+        uint32_t m_byteOffset = 0;
+        uint32_t m_byteCount = 0;
+        uint32_t m_byteStride = 0;
+    };
+} // namespace AZ::RHI

+ 132 - 0
Gems/Atom/RHI/Code/Include/Atom/RHI/MultiDeviceIndirectBufferWriter.h

@@ -0,0 +1,132 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+#pragma once
+
+#include <Atom/RHI.Reflect/Base.h>
+#include <Atom/RHI.Reflect/IndirectBufferLayout.h>
+#include <Atom/RHI/IndirectBufferWriter.h>
+#include <Atom/RHI/Object.h>
+
+namespace AZ::RHI
+{
+    class MultiDeviceBuffer;
+    class MultiDeviceIndexBufferView;
+    class MultiDeviceStreamBufferView;
+    class MultiDeviceIndirectBufferSignature;
+
+    //! MultiDeviceIndirectBufferWriter is a helper class to write indirect commands
+    //! to a buffer or a memory location in a platform independent way. Different APIs may
+    //! have different layouts for the arguments of an indirect command. This class provides
+    //! a secure and simple way to write the commands without worrying about API differences.
+    //!
+    //! It also provides basic checks, like trying to write more commands than allowed, or
+    //! writing commands that are not specified in the layout.
+    class MultiDeviceIndirectBufferWriter : public Object
+    {
+        using Base = Object;
+
+    public:
+        AZ_CLASS_ALLOCATOR(MultiDeviceIndirectBufferWriter, AZ::SystemAllocator, 0);
+        AZ_RTTI(MultiDeviceIndirectBufferWriter, "{096CBDFF-AB05-4E8D-9EC1-04F12CFCD85D}");
+        virtual ~MultiDeviceIndirectBufferWriter() = default;
+
+        //! Returns the device-specific IndirectBufferWriter for the given index
+        inline Ptr<IndirectBufferWriter> GetDeviceIndirectBufferWriter(int deviceIndex) const
+        {
+            AZ_Error(
+                "MultiDeviceIndirectBufferWriter",
+                m_deviceIndirectBufferWriter.find(deviceIndex) != m_deviceIndirectBufferWriter.end(),
+                "No IndirectBufferWriter found for device index %d\n",
+                deviceIndex);
+            return m_deviceIndirectBufferWriter.at(deviceIndex);
+        }
+
+        //! Initialize the MultiDeviceIndirectBufferWriter to write commands into a buffer.
+        //! @param buffer The buffer where to write the commands. Any previous values for the specified range will be overwritten.
+        //!               The buffer must be big enough to contain the max number of sequences.
+        //! @param byteOffset The offset into the buffer.
+        //! @param byteStride The stride between command sequences. Must be larger than the stride calculated from the signature.
+        //! @param maxCommandSequences The max number of sequences that the MultiDeviceIndirectBufferWriter can write.
+        //! @param signature Signature of the indirect buffer.
+        //! @return A result code denoting the status of the call. If successful, the MultiDeviceIndirectBufferWriter is considered
+        //!      initialized and is able to service write requests. If failure, the MultiDeviceIndirectBufferWriter remains
+        //!      uninitialized.
+        ResultCode Init(
+            MultiDeviceBuffer& buffer,
+            size_t byteOffset,
+            uint32_t byteStride,
+            uint32_t maxCommandSequences,
+            const MultiDeviceIndirectBufferSignature& signature);
+
+        //! Initialize the MultiDeviceIndirectBufferWriter to write commands into a memory location.
+        //! @param memoryPtrs Map from device index to the memory location where that device's commands will be written. Must not be empty.
+        //! @param byteStride The stride between command sequences. Must be larger than the stride calculated from the signature.
+        //! @param maxCommandSequences The max number of sequences that the MultiDeviceIndirectBufferWriter can write.
+        //! @param signature Signature of the indirect buffer.
+        //! @return A result code denoting the status of the call. If successful, the MultiDeviceIndirectBufferWriter is considered
+        //!      initialized and is able to service write requests. If failure, the MultiDeviceIndirectBufferWriter remains
+        //!      uninitialized.
+        ResultCode Init(
+            const AZStd::unordered_map<int, void*>& memoryPtrs, uint32_t byteStride, uint32_t maxCommandSequences, const MultiDeviceIndirectBufferSignature& signature);
+
+        //! Writes a vertex buffer view command into the current sequence.
+        //! @param slot The stream buffer slot that the view will set.
+        //! @param view The MultiDeviceStreamBufferView that will be set.
+        //! @return A pointer to the MultiDeviceIndirectBufferWriter object (this).
+        MultiDeviceIndirectBufferWriter* SetVertexView(uint32_t slot, const MultiDeviceStreamBufferView& view);
+
+        //! Writes an index buffer view command into the current sequence.
+        //! @param view The MultiDeviceIndexBufferView that will be set.
+        //! @return A pointer to the MultiDeviceIndirectBufferWriter object (this).
+        MultiDeviceIndirectBufferWriter* SetIndexView(const MultiDeviceIndexBufferView& view);
+
+        //! Writes a draw command into the current sequence.
+        //! @param arguments The draw arguments that will be written.
+        //! @return A pointer to the MultiDeviceIndirectBufferWriter object (this).
+        MultiDeviceIndirectBufferWriter* Draw(const DrawLinear& arguments);
+
+        //! Writes a draw indexed command into the current sequence.
+        //! @param arguments The draw indexed arguments that will be written.
+        //! @return A pointer to the MultiDeviceIndirectBufferWriter object (this).
+        MultiDeviceIndirectBufferWriter* DrawIndexed(const DrawIndexed& arguments);
+
+        //! Writes a dispatch command into the current sequence.
+        //! @param arguments The dispatch arguments that will be written.
+        //! @return A pointer to the MultiDeviceIndirectBufferWriter object (this).
+        MultiDeviceIndirectBufferWriter* Dispatch(const DispatchDirect& arguments);
+
+        //! Writes an inline constants command into the current sequence. This command will set
+        //! the values of all inline constants of the Pipeline.
+        //! @param data A pointer to the data that contains the values that will be written.
+        //! @param byteSize The size of the data that will be written.
+        //! @return A pointer to the MultiDeviceIndirectBufferWriter object (this).
+        MultiDeviceIndirectBufferWriter* SetRootConstants(const uint8_t* data, uint32_t byteSize);
+
+        //! Advance the current sequence index by 1.
+        //! @return True if the sequence index was increased correctly. False otherwise.
+        bool NextSequence();
+
+        //! Move the current sequence index to a specified position.
+        //! @param sequenceIndex The index where the sequence index will be moved. Must be less than maxCommandSequences.
+        //! @return True if the sequence index was updated correctly. False otherwise and the current sequence index is not modified.
+        bool Seek(const uint32_t sequenceIndex);
+
+        //! Flush changes into the destination buffer. Only valid when using a buffer.
+        void Flush();
+
+        bool IsInitialized() const;
+
+        uint32_t GetCurrentSequenceIndex() const;
+
+        void Shutdown() override;
+
+    private:
+        //! A map of all device-specific IndirectBufferWriter, indexed by the device index
+        AZStd::unordered_map<int, Ptr<IndirectBufferWriter>> m_deviceIndirectBufferWriter;
+    };
+} // namespace AZ::RHI

+ 116 - 0
Gems/Atom/RHI/Code/Source/RHI/MultiDeviceIndirectBufferSignature.cpp

@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <Atom/RHI/Factory.h>
+#include <Atom/RHI/MultiDeviceIndirectBufferSignature.h>
+#include <Atom/RHI/MultiDevicePipelineState.h>
+#include <Atom/RHI/RHISystemInterface.h>
+
+namespace AZ::RHI
+{
+    IndirectBufferSignatureDescriptor MultiDeviceIndirectBufferSignatureDescriptor::GetDeviceIndirectBufferSignatureDescriptor(
+        int deviceIndex) const
+    {
+        AZ_Assert(m_pipelineState, "No MultiDevicePipelineState available\n");
+
+        IndirectBufferSignatureDescriptor descriptor{ m_layout };
+
+        if (m_pipelineState)
+        {
+            descriptor.m_pipelineState = m_pipelineState->GetDevicePipelineState(deviceIndex).get();
+        }
+
+        return descriptor;
+    }
+
+    //! Initializes the multi-device signature by creating and initializing one device-specific
+    //! IndirectBufferSignature for every device visited by IterateDevices for the given mask.
+    //! @param deviceMask The deviceMask denoting all devices that will contain the signature.
+    //! @param descriptor Layout and multi-device pipeline state used to build each device-specific descriptor.
+    //! @return ResultCode::Success if every device-specific signature was initialized. On failure the
+    //!         partially created state is rolled back so the object remains uninitialized.
+    ResultCode MultiDeviceIndirectBufferSignature::Init(
+        MultiDevice::DeviceMask deviceMask, const MultiDeviceIndirectBufferSignatureDescriptor& descriptor)
+    {
+        MultiDeviceObject::Init(deviceMask);
+
+        // Drop any stride cached by a previous initialization so it is re-read from the new device signatures.
+        m_byteStride = UNINITIALIZED_VALUE;
+
+        ResultCode resultCode{ ResultCode::Success };
+
+        IterateDevices(
+            [this, &descriptor, &resultCode](int deviceIndex)
+            {
+                auto device = RHISystemInterface::Get()->GetDevice(deviceIndex);
+
+                m_deviceObjects[deviceIndex] = Factory::Get().CreateIndirectBufferSignature();
+                resultCode = GetDeviceIndirectBufferSignature(deviceIndex)->Init(
+                    *device, descriptor.GetDeviceIndirectBufferSignatureDescriptor(deviceIndex));
+
+                if (resultCode != ResultCode::Success)
+                {
+                    // Report the failure without querying the stride of a signature that did not initialize.
+                    return false;
+                }
+
+                if (m_byteStride == UNINITIALIZED_VALUE)
+                {
+                    // Cache byteStride since it is the same for all devices
+                    m_byteStride = GetDeviceIndirectBufferSignature(deviceIndex)->GetByteStride();
+                }
+
+                return true;
+            });
+
+        if (resultCode != ResultCode::Success)
+        {
+            // Reset already initialized device-specific IndirectBufferSignatures and set the deviceMask to 0,
+            // so a failed Init does not leave a half-initialized object behind.
+            m_deviceObjects.clear();
+            m_byteStride = UNINITIALIZED_VALUE;
+            MultiDeviceObject::Init(static_cast<MultiDevice::DeviceMask>(0u));
+            return resultCode;
+        }
+
+        // Only keep the descriptor once every device-specific signature is usable.
+        m_mdDescriptor = descriptor;
+
+        return resultCode;
+    }
+
+    uint32_t MultiDeviceIndirectBufferSignature::GetByteStride() const
+    {
+        AZ_Assert(IsInitialized(), "Signature is not initialized");
+        return m_byteStride;
+    }
+
+    uint32_t MultiDeviceIndirectBufferSignature::GetOffset(IndirectCommandIndex index) const
+    {
+        AZ_Assert(IsInitialized(), "Signature is not initialized");
+        if (Validation::IsEnabled())
+        {
+            if (index.IsNull())
+            {
+                AZ_Assert(false, "Invalid index");
+                return 0;
+            }
+
+            if (index.GetIndex() >= m_mdDescriptor.m_layout.GetCommands().size())
+            {
+                AZ_Assert(false, "Index %d is greater than the number of commands on the layout", index.GetIndex());
+                return 0;
+            }
+        }
+
+        auto offset{ UNINITIALIZED_VALUE };
+
+        IterateObjects<IndirectBufferSignature>([&offset, &index]([[maybe_unused]] auto deviceIndex, auto deviceSignature)
+        {
+            auto deviceOffset{ deviceSignature->GetOffset(index) };
+
+            if (offset == UNINITIALIZED_VALUE)
+            {
+                offset = deviceOffset;
+            }
+
+            AZ_Assert(deviceOffset == offset, "Device Signature offsets do not match");
+        });
+
+        return offset;
+    }
+
+    const MultiDeviceIndirectBufferSignatureDescriptor& MultiDeviceIndirectBufferSignature::GetDescriptor() const
+    {
+        return m_mdDescriptor;
+    }
+
+    const AZ::RHI::IndirectBufferLayout& MultiDeviceIndirectBufferSignature::GetLayout() const
+    {
+        return m_mdDescriptor.m_layout;
+    }
+
+    void MultiDeviceIndirectBufferSignature::Shutdown()
+    {
+        MultiDeviceObject::Shutdown();
+    }
+} // namespace AZ::RHI

+ 65 - 0
Gems/Atom/RHI/Code/Source/RHI/MultiDeviceIndirectBufferView.cpp

@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <Atom/RHI/MultiDeviceBuffer.h>
+#include <Atom/RHI/MultiDeviceIndirectBufferView.h>
+#include <AzCore/std/hash.h>
+
+namespace AZ::RHI
+{
+    //! Builds a view over a range of a multi-device buffer and precomputes its hash.
+    //! @param buffer The multi-device buffer the view refers to; only its address is stored.
+    //! @param signature The multi-device indirect buffer signature; only its address is stored.
+    //! @param byteOffset Offset in bytes into the buffer.
+    //! @param byteCount Number of bytes covered by the view.
+    //! @param byteStride Distance in bytes between consecutive command sequences.
+    MultiDeviceIndirectBufferView::MultiDeviceIndirectBufferView(
+        const MultiDeviceBuffer& buffer,
+        const MultiDeviceIndirectBufferSignature& signature,
+        uint32_t byteOffset,
+        uint32_t byteCount,
+        uint32_t byteStride)
+        // Initializers are listed in member declaration order (the order in which they actually run),
+        // which avoids -Wreorder warnings; m_hash keeps its default until it is computed below.
+        : m_mdSignature(&signature)
+        , m_mdBuffer(&buffer)
+        , m_byteOffset(byteOffset)
+        , m_byteCount(byteCount)
+        , m_byteStride(byteStride)
+    {
+        // The combination order is kept unchanged so the resulting hash values stay the same.
+        size_t seed = 0;
+        AZStd::hash_combine(seed, m_mdBuffer);
+        AZStd::hash_combine(seed, m_byteOffset);
+        AZStd::hash_combine(seed, m_byteCount);
+        AZStd::hash_combine(seed, m_byteStride);
+        AZStd::hash_combine(seed, m_mdSignature);
+        m_hash = static_cast<HashValue64>(seed);
+    }
+
+    HashValue64 MultiDeviceIndirectBufferView::GetHash() const
+    {
+        return m_hash;
+    }
+
+    const MultiDeviceBuffer* MultiDeviceIndirectBufferView::GetBuffer() const
+    {
+        return m_mdBuffer;
+    }
+
+    uint32_t MultiDeviceIndirectBufferView::GetByteOffset() const
+    {
+        return m_byteOffset;
+    }
+
+    uint32_t MultiDeviceIndirectBufferView::GetByteCount() const
+    {
+        return m_byteCount;
+    }
+
+    uint32_t MultiDeviceIndirectBufferView::GetByteStride() const
+    {
+        return m_byteStride;
+    }
+
+    const MultiDeviceIndirectBufferSignature* MultiDeviceIndirectBufferView::GetSignature() const
+    {
+        return m_mdSignature;
+    }
+} // namespace AZ::RHI

+ 291 - 0
Gems/Atom/RHI/Code/Source/RHI/MultiDeviceIndirectBufferWriter.cpp

@@ -0,0 +1,291 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include <Atom/RHI.Reflect/IndirectBufferLayout.h>
+#include <Atom/RHI/Factory.h>
+#include <Atom/RHI/MultiDeviceBuffer.h>
+#include <Atom/RHI/MultiDeviceBufferPool.h>
+#include <Atom/RHI/MultiDeviceStreamBufferView.h>
+#include <Atom/RHI/MultiDeviceIndexBufferView.h>
+#include <Atom/RHI/MultiDeviceIndirectBufferSignature.h>
+#include <Atom/RHI/MultiDeviceIndirectBufferWriter.h>
+#include <Atom/RHI/RHISystemInterface.h>
+
+namespace AZ::RHI
+{
+    //! Initializes one device-specific IndirectBufferWriter for every device set in the buffer's device mask,
+    //! each writing into the corresponding device-specific buffer.
+    //! @param buffer The multi-device buffer that receives the commands.
+    //! @param byteOffset The offset into the buffer.
+    //! @param byteStride The stride between command sequences.
+    //! @param maxCommandSequences The max number of sequences that can be written.
+    //! @param signature Signature of the indirect buffer. If it is not initialized, a freshly created
+    //!        (uninitialized) device signature is handed to the device-specific writer instead.
+    //! @return ResultCode::Success if all device-specific writers were initialized. On failure all
+    //!         device-specific writers are discarded.
+    ResultCode MultiDeviceIndirectBufferWriter::Init(
+        MultiDeviceBuffer& buffer,
+        size_t byteOffset,
+        uint32_t byteStride,
+        uint32_t maxCommandSequences,
+        const MultiDeviceIndirectBufferSignature& signature)
+    {
+        if (Validation::IsEnabled())
+        {
+            if (IsInitialized())
+            {
+                AZ_Assert(false, "MultiDeviceIndirectBufferWriter cannot be initialized when calling this method.");
+                return ResultCode::InvalidOperation;
+            }
+
+            // Widen before multiplying: maxCommandSequences * byteStride would otherwise be evaluated in
+            // 32 bits and could wrap around, letting an undersized buffer pass this check.
+            const uint64_t requiredByteCount{ static_cast<uint64_t>(byteOffset) +
+                                              static_cast<uint64_t>(maxCommandSequences) * byteStride };
+            if (requiredByteCount > buffer.GetDescriptor().m_byteCount)
+            {
+                AZ_Assert(false, "MultiDeviceBuffer is too small to contain the required commands");
+                return ResultCode::InvalidArgument;
+            }
+        }
+
+        ResultCode result{ ResultCode::Success };
+        auto deviceMask{ AZStd::to_underlying(buffer.GetDeviceMask()) };
+
+        // Walk the mask bit by bit; bit N set means device index N takes part.
+        for (auto deviceIndex{ 0 }; deviceMask && (deviceIndex < RHI::RHISystemInterface::Get()->GetDeviceCount());
+             deviceMask >>= 1, ++deviceIndex)
+        {
+            if (CheckBitsAll(deviceMask, 1u))
+            {
+                m_deviceIndirectBufferWriter[deviceIndex] = Factory::Get().CreateIndirectBufferWriter();
+                auto deviceIndirectBufferSignature{ signature.IsInitialized() ? signature.GetDeviceIndirectBufferSignature(deviceIndex)
+                                                                              : Factory::Get().CreateIndirectBufferSignature() };
+                result = m_deviceIndirectBufferWriter[deviceIndex]->Init(
+                    *buffer.GetDeviceBuffer(deviceIndex).get(),
+                    byteOffset,
+                    byteStride,
+                    maxCommandSequences,
+                    *deviceIndirectBufferSignature);
+
+                if (result != ResultCode::Success)
+                {
+                    break;
+                }
+            }
+        }
+
+        if (result != ResultCode::Success)
+        {
+            // Reset already initialized device-specific IndirectBufferWriters
+            m_deviceIndirectBufferWriter.clear();
+        }
+
+        return result;
+    }
+
+    //! Initializes one device-specific IndirectBufferWriter for every device set in the signature's device
+    //! mask, each writing into the memory location registered for that device index.
+    //! @param memoryPtrs Map from device index to the target memory of that device. Must not be empty and
+    //!        must contain an entry for every device in the signature's device mask.
+    //! @param byteStride The stride between command sequences.
+    //! @param maxCommandSequences The max number of sequences that can be written.
+    //! @param signature Signature of the indirect buffer. If it is not initialized, a freshly created
+    //!        (uninitialized) device signature is handed to the device-specific writer instead.
+    //! @return ResultCode::Success if all device-specific writers were initialized. On failure all
+    //!         device-specific writers are discarded.
+    ResultCode MultiDeviceIndirectBufferWriter::Init(
+        const AZStd::unordered_map<int, void*>& memoryPtrs,
+        uint32_t byteStride,
+        uint32_t maxCommandSequences,
+        const MultiDeviceIndirectBufferSignature& signature)
+    {
+        if (Validation::IsEnabled())
+        {
+            // Same guard as the buffer-based overload: re-initialization is not allowed.
+            if (IsInitialized())
+            {
+                AZ_Assert(false, "MultiDeviceIndirectBufferWriter cannot be initialized when calling this method.");
+                return ResultCode::InvalidOperation;
+            }
+
+            if (memoryPtrs.empty())
+            {
+                AZ_Assert(false, "Null target memory");
+                return ResultCode::InvalidArgument;
+            }
+        }
+
+        ResultCode result{ ResultCode::Success };
+        auto deviceMask{ AZStd::to_underlying(signature.GetDeviceMask()) };
+
+        // Walk the mask bit by bit; bit N set means device index N takes part.
+        for (auto deviceIndex{ 0 }; deviceMask && (deviceIndex < RHI::RHISystemInterface::Get()->GetDeviceCount());
+             deviceMask >>= 1, ++deviceIndex)
+        {
+            if (deviceMask & 1)
+            {
+                // Look the target memory up explicitly instead of calling at() on a key that may be missing.
+                const auto memoryIt{ memoryPtrs.find(deviceIndex) };
+                if (memoryIt == memoryPtrs.end())
+                {
+                    AZ_Assert(false, "No target memory provided for device index %d", deviceIndex);
+                    result = ResultCode::InvalidArgument;
+                    break;
+                }
+
+                m_deviceIndirectBufferWriter[deviceIndex] = Factory::Get().CreateIndirectBufferWriter();
+                auto deviceIndirectBufferSignature{ signature.IsInitialized() ? signature.GetDeviceIndirectBufferSignature(deviceIndex)
+                                                                              : Factory::Get().CreateIndirectBufferSignature() };
+                result = m_deviceIndirectBufferWriter[deviceIndex]->Init(
+                    memoryIt->second, byteStride, maxCommandSequences, *deviceIndirectBufferSignature);
+
+                if (result != ResultCode::Success)
+                {
+                    break;
+                }
+            }
+        }
+
+        if (result != ResultCode::Success)
+        {
+            // Reset already initialized device-specific IndirectBufferWriters
+            m_deviceIndirectBufferWriter.clear();
+        }
+
+        return result;
+    }
+
+    bool MultiDeviceIndirectBufferWriter::NextSequence()
+    {
+        auto result{ false };
+        for (const auto& [deviceIndex, writer] : m_deviceIndirectBufferWriter)
+        {
+            result = writer->NextSequence();
+            if (!result)
+            {
+                break;
+            }
+        }
+        return result;
+    }
+
+    void MultiDeviceIndirectBufferWriter::Shutdown()
+    {
+        for (const auto& [deviceIndex, writer] : m_deviceIndirectBufferWriter)
+        {
+            writer->Shutdown();
+        }
+    }
+
+    MultiDeviceIndirectBufferWriter* MultiDeviceIndirectBufferWriter::SetVertexView(uint32_t slot, const MultiDeviceStreamBufferView& view)
+    {
+        if (Validation::IsEnabled() && !IsInitialized())
+        {
+            AZ_Assert(false, "MultiDeviceIndirectBufferWriter must be initialized when calling this method.");
+        }
+
+        for (const auto& [deviceIndex, writer] : m_deviceIndirectBufferWriter)
+        {
+            writer->SetVertexView(slot, view.GetDeviceStreamBufferView(deviceIndex));
+        }
+
+        return this;
+    }
+
+    MultiDeviceIndirectBufferWriter* MultiDeviceIndirectBufferWriter::SetIndexView(const MultiDeviceIndexBufferView& view)
+    {
+        if (Validation::IsEnabled() && !IsInitialized())
+        {
+            AZ_Assert(false, "MultiDeviceIndirectBufferWriter must be initialized when calling this method.");
+        }
+
+        for (const auto& [deviceIndex, writer] : m_deviceIndirectBufferWriter)
+        {
+            writer->SetIndexView(view.GetDeviceIndexBufferView(deviceIndex));
+        }
+
+        return this;
+    }
+
+    MultiDeviceIndirectBufferWriter* MultiDeviceIndirectBufferWriter::Draw(const DrawLinear& arguments)
+    {
+        if (Validation::IsEnabled() && !IsInitialized())
+        {
+            AZ_Assert(false, "MultiDeviceIndirectBufferWriter must be initialized when calling this method.");
+        }
+
+        for (const auto& [deviceIndex, writer] : m_deviceIndirectBufferWriter)
+        {
+            writer->Draw(arguments);
+        }
+
+        return this;
+    }
+
+    MultiDeviceIndirectBufferWriter* MultiDeviceIndirectBufferWriter::DrawIndexed(const RHI::DrawIndexed& arguments)
+    {
+        if (Validation::IsEnabled() && !IsInitialized())
+        {
+            AZ_Assert(false, "MultiDeviceIndirectBufferWriter must be initialized when calling this method.");
+        }
+
+        for (const auto& [deviceIndex, writer] : m_deviceIndirectBufferWriter)
+        {
+            writer->DrawIndexed(arguments);
+        }
+
+        return this;
+    }
+
+    MultiDeviceIndirectBufferWriter* MultiDeviceIndirectBufferWriter::Dispatch(const DispatchDirect& arguments)
+    {
+        if (Validation::IsEnabled() && !IsInitialized())
+        {
+            AZ_Assert(false, "MultiDeviceIndirectBufferWriter must be initialized when calling this method.");
+        }
+
+        for (const auto& [deviceIndex, writer] : m_deviceIndirectBufferWriter)
+        {
+            writer->Dispatch(arguments);
+        }
+
+        return this;
+    }
+
+    MultiDeviceIndirectBufferWriter* MultiDeviceIndirectBufferWriter::SetRootConstants(const uint8_t* data, uint32_t byteSize)
+    {
+        if (Validation::IsEnabled() && !IsInitialized())
+        {
+            AZ_Assert(false, "MultiDeviceIndirectBufferWriter must be initialized when calling this method.");
+        }
+
+        for (const auto& [deviceIndex, writer] : m_deviceIndirectBufferWriter)
+        {
+            writer->SetRootConstants(data, byteSize);
+        }
+
+        return this;
+    }
+
+    bool MultiDeviceIndirectBufferWriter::Seek(const uint32_t sequenceIndex)
+    {
+        auto result{ false };
+        for (const auto& [deviceIndex, writer] : m_deviceIndirectBufferWriter)
+        {
+            result = writer->Seek(sequenceIndex);
+            if (!result)
+            {
+                break;
+            }
+        }
+        return result;
+    }
+
+    void MultiDeviceIndirectBufferWriter::Flush()
+    {
+        // Unmap the buffer to force a flush of the changes into the buffer.
+        // The buffer will be remapped before writing new commands.
+        // We don't remap here because we can't leave a buffer mapped during the
+        // whole frame execution.
+        for (const auto& [deviceIndex, writer] : m_deviceIndirectBufferWriter)
+        {
+            writer->Flush();
+        }
+    }
+
+    bool MultiDeviceIndirectBufferWriter::IsInitialized() const
+    {
+        auto result{ false };
+        for (const auto& [deviceIndex, writer] : m_deviceIndirectBufferWriter)
+        {
+            result = writer->IsInitialized();
+            if (!result)
+            {
+                break;
+            }
+        }
+        return result;
+    }
+
+    uint32_t MultiDeviceIndirectBufferWriter::GetCurrentSequenceIndex() const
+    {
+        static constexpr uint32_t UNINITIALIZED_VALUE{ std::numeric_limits<uint32_t>::max() };
+        uint32_t currentSequenceIndex{ UNINITIALIZED_VALUE };
+
+        for (const auto& [deviceIndex, writer] : m_deviceIndirectBufferWriter)
+        {
+            auto deviceCurrentSequenceIndex{ writer->GetCurrentSequenceIndex() };
+
+            if (currentSequenceIndex == UNINITIALIZED_VALUE)
+            {
+                currentSequenceIndex = deviceCurrentSequenceIndex;
+            }
+
+            AZ_Assert(deviceCurrentSequenceIndex == currentSequenceIndex, "Device IndirectBufferWriter CurrentSequenceIndex do not match");
+        }
+
+        return currentSequenceIndex;
+    }
+} // namespace AZ::RHI

+ 633 - 0
Gems/Atom/RHI/Code/Tests/MultiDeviceIndirectBufferTests.cpp

@@ -0,0 +1,633 @@
+/*
+ * Copyright (c) Contributors to the Open 3D Engine Project.
+ * For complete copyright and license terms please see the LICENSE at the root of this distribution.
+ *
+ * SPDX-License-Identifier: Apache-2.0 OR MIT
+ *
+ */
+
+#include "RHITestFixture.h"
+#include <Atom/RHI/FrameEventBus.h>
+#include <Atom/RHI/MultiDeviceBufferPool.h>
+#include <Atom/RHI/MultiDeviceIndirectBufferSignature.h>
+#include <Atom/RHI/MultiDeviceIndirectBufferWriter.h>
+#include <Atom/RHI/MultiDeviceStreamBufferView.h>
+#include <Atom/RHI/MultiDeviceIndexBufferView.h>
+#include <Tests/Buffer.h>
+#include <Tests/Device.h>
+#include <Tests/IndirectBuffer.h>
+
+#include <Atom/RHI.Reflect/ReflectSystemComponent.h>
+
+#include <AzCore/Serialization/ObjectStream.h>
+#include <AzCore/Serialization/Utils.h>
+
+namespace UnitTest
+{
+    using namespace AZ;
+
+    //! Test fixture for the multi-device indirect buffer signature and writer.
+    //! SetUp() builds the command list used for the layout, a buffer pool with one buffer
+    //! sized for m_writerNumCommands sequences, and an initialized signature whose device
+    //! mocks report m_writerCommandStride as byte stride.
+    class MultiDeviceIndirectBufferTests : public MultiDeviceRHITestFixture
+    {
+    public:
+        MultiDeviceIndirectBufferTests()
+            : MultiDeviceRHITestFixture()
+        {
+        }
+
+        ~MultiDeviceIndirectBufferTests()
+        {
+        }
+
+    private:
+        //! Creates the serialize context, the command descriptors, the buffer pool/buffer and
+        //! the signature shared by all writer tests.
+        void SetUp() override
+        {
+            MultiDeviceRHITestFixture::SetUp();
+
+            m_serializeContext = AZStd::make_unique<SerializeContext>();
+            RHI::ReflectSystemComponent::Reflect(m_serializeContext.get());
+            AZ::Name::Reflect(m_serializeContext.get());
+
+            // Commands that make up the indirect buffer layout used throughout the tests.
+            m_commands.clear();
+            m_commands.push_back(RHI::IndirectCommandType::RootConstants);
+            m_commands.push_back(RHI::IndirectBufferViewArguments{ s_vertexSlotIndex });
+            m_commands.push_back(RHI::IndirectCommandType::IndexBufferView);
+            m_commands.push_back(RHI::IndirectCommandType::DrawIndexed);
+
+            m_bufferPool = aznew AZ::RHI::MultiDeviceBufferPool;
+            RHI::BufferPoolDescriptor poolDesc;
+            poolDesc.m_bindFlags = RHI::BufferBindFlags::ShaderReadWrite;
+            m_bufferPool->Init(DeviceMask, poolDesc);
+
+            // The buffer holds exactly m_writerNumCommands sequences of m_writerCommandStride bytes.
+            m_buffer = aznew AZ::RHI::MultiDeviceBuffer;
+            RHI::MultiDeviceBufferInitRequest initRequest;
+            initRequest.m_buffer = m_buffer.get();
+            initRequest.m_descriptor.m_byteCount = m_writerCommandStride * m_writerNumCommands;
+            initRequest.m_descriptor.m_bindFlags = poolDesc.m_bindFlags;
+            m_bufferPool->InitBuffer(initRequest);
+
+            // Creating the signature is run under trace suppression with an expected count of DeviceCount.
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            m_writerSignature = CreateInitializedSignature();
+            AZ_TEST_STOP_TRACE_SUPPRESSION(DeviceCount);
+            for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+            {
+                // Every device signature mock reports the stride the writers are initialized with.
+                EXPECT_CALL(
+                    *static_cast<IndirectBufferSignature*>(m_writerSignature->GetDeviceIndirectBufferSignature(deviceIndex).get()),
+                    GetByteStrideInternal())
+                    .WillRepeatedly(testing::Return(m_writerCommandStride));
+            }
+        }
+
+        //! Releases the resources created in SetUp() and expects exactly one shutdown call on
+        //! every device signature when the shared signature is released.
+        void TearDown() override
+        {
+            m_buffer.reset();
+            m_bufferPool.reset();
+            for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+            {
+                EXPECT_CALL(
+                    *static_cast<IndirectBufferSignature*>(m_writerSignature->GetDeviceIndirectBufferSignature(deviceIndex).get()),
+                    ShutdownInternal())
+                    .Times(1);
+            }
+
+            m_writerSignature.reset();
+
+            m_serializeContext.reset();
+            MultiDeviceRHITestFixture::TearDown();
+        }
+
+    protected:
+        //! Builds a layout containing all commands of m_commands without finalizing it.
+        RHI::IndirectBufferLayout CreateUnfinalizedLayout()
+        {
+            RHI::IndirectBufferLayout layout;
+            for (const auto& descriptor : m_commands)
+            {
+                EXPECT_TRUE(layout.AddIndirectCommand(descriptor));
+            }
+            return layout;
+        }
+
+        //! Builds the layout of m_commands and finalizes it.
+        RHI::IndirectBufferLayout CreateFinalizedLayout()
+        {
+            auto layout = CreateUnfinalizedLayout();
+            EXPECT_TRUE(layout.Finalize());
+            return layout;
+        }
+
+        //! Round-trips a layout through a binary object stream.
+        //! @param layout the layout to serialize.
+        //! @return the layout obtained by deserializing the written stream.
+        RHI::IndirectBufferLayout CreateSerializedLayout(const RHI::IndirectBufferLayout& layout)
+        {
+            AZStd::vector<char, AZ::OSStdAllocator> buffer;
+            AZ::IO::ByteContainerStream<AZStd::vector<char, AZ::OSStdAllocator>> outStream(&buffer);
+
+            {
+                AZ::ObjectStream* objStream = AZ::ObjectStream::Create(&outStream, *m_serializeContext.get(), AZ::ObjectStream::ST_BINARY);
+
+                bool writeOK = objStream->WriteClass(&layout);
+                EXPECT_TRUE(writeOK);
+
+                bool finalizeOK = objStream->Finalize();
+                EXPECT_TRUE(finalizeOK);
+            }
+
+            // Rewind so that the data just written can be read back.
+            outStream.Seek(0, IO::GenericStream::ST_SEEK_BEGIN);
+
+            AZ::ObjectStream::FilterDescriptor filterDesc;
+            RHI::IndirectBufferLayout deserializedLayout;
+            bool deserializedOK = AZ::Utils::LoadObjectFromStreamInPlace<RHI::IndirectBufferLayout>(
+                outStream, deserializedLayout, m_serializeContext.get(), filterDesc);
+            EXPECT_TRUE(deserializedOK);
+            return deserializedLayout;
+        }
+
+        //! Checks that a layout is finalized and contains exactly the commands of m_commands,
+        //! in the same order and retrievable through FindCommandIndex.
+        void ValidateLayout(const RHI::IndirectBufferLayout& layout)
+        {
+            EXPECT_TRUE(layout.IsFinalized());
+            auto layoutCommands = layout.GetCommands();
+            EXPECT_EQ(m_commands.size(), layoutCommands.size());
+            for (uint32_t i = 0; i < m_commands.size(); ++i)
+            {
+                EXPECT_EQ(m_commands[i], layoutCommands[i]);
+                EXPECT_EQ(layout.FindCommandIndex(m_commands[i]), RHI::IndirectCommandIndex(i));
+            }
+        }
+
+        //! Creates a signature initialized on DeviceMask with a finalized layout.
+        //! The used layout is also stored in m_signatureDescriptor.
+        RHI::Ptr<AZ::RHI::MultiDeviceIndirectBufferSignature> CreateInitializedSignature()
+        {
+            auto signature = aznew AZ::RHI::MultiDeviceIndirectBufferSignature;
+            m_signatureDescriptor.m_layout = CreateFinalizedLayout();
+            EXPECT_EQ(signature->Init(DeviceMask, m_signatureDescriptor), RHI::ResultCode::Success);
+
+            return signature;
+        }
+
+        //! Creates a signature object on which Init() has not been called.
+        RHI::Ptr<AZ::RHI::MultiDeviceIndirectBufferSignature> CreateUnInitializedSignature()
+        {
+            auto signature = aznew AZ::RHI::MultiDeviceIndirectBufferSignature;
+            return signature;
+        }
+
+        //! Creates a writer initialized on m_buffer with the fixture's offset, stride,
+        //! sequence count and signature.
+        RHI::Ptr<AZ::RHI::MultiDeviceIndirectBufferWriter> CreateInitializedWriter()
+        {
+            auto writer = aznew AZ::RHI::MultiDeviceIndirectBufferWriter;
+            EXPECT_EQ(
+                writer->Init(*m_buffer, m_writerOffset, m_writerCommandStride, m_writerNumCommands, *m_writerSignature),
+                RHI::ResultCode::Success);
+            return writer;
+        }
+
+        //! Checks that a signature is initialized and carries the expected layout.
+        void ValidateSignature(const RHI::MultiDeviceIndirectBufferSignature& signature)
+        {
+            ValidateLayout(signature.GetLayout());
+            EXPECT_TRUE(signature.IsInitialized());
+        }
+
+        //! Checks the state of a freshly initialized writer: it is positioned at sequence 0 and,
+        //! for every device, writes to the start of the mapped device buffer.
+        void ValidateWriter(const AZ::RHI::MultiDeviceIndirectBufferWriter& writer)
+        {
+            // The sequence index is the same for all devices, so it is checked once.
+            EXPECT_EQ(writer.GetCurrentSequenceIndex(), 0u);
+            for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+            {
+                EXPECT_EQ(
+                    static_cast<IndirectBufferWriter*>(writer.GetDeviceIndirectBufferWriter(deviceIndex).get())->GetData(),
+                    static_cast<const uint8_t*>(static_cast<Buffer*>(m_buffer->GetDeviceBuffer(deviceIndex).get())->GetData().data()));
+
+                EXPECT_TRUE(static_cast<Buffer*>(m_buffer->GetDeviceBuffer(deviceIndex).get())->IsMapped());
+            }
+        }
+
+        // Slot index used for the vertex buffer view command.
+        // constexpr makes the member implicitly inline, so no out-of-class definition is needed.
+        static constexpr uint32_t s_vertexSlotIndex = 3;
+        // Command descriptors added to the layout, in layout order.
+        AZStd::vector<RHI::IndirectCommandDescriptor> m_commands;
+
+        AZStd::unique_ptr<SerializeContext> m_serializeContext;
+        // Descriptor of the most recently created initialized signature.
+        RHI::MultiDeviceIndirectBufferSignatureDescriptor m_signatureDescriptor;
+
+        RHI::Ptr<AZ::RHI::MultiDeviceBufferPool> m_bufferPool;
+        RHI::Ptr<AZ::RHI::MultiDeviceBuffer> m_buffer;
+
+        // Default arguments used when initializing a writer.
+        size_t m_writerOffset = 0;
+        uint32_t m_writerCommandStride = 2;
+        uint32_t m_writerNumCommands = 1024;
+
+        // Signature shared by all writer tests; created in SetUp() and released in TearDown().
+        RHI::Ptr<AZ::RHI::MultiDeviceIndirectBufferSignature> m_writerSignature;
+    };
+
+    // Exercises MultiDeviceIndirectBufferSignature: initialization, byte stride and offset
+    // queries (valid, null and out-of-range indices) and shutdown on release.
+    // Each case creates its own signature; creation is wrapped in a trace suppression that is
+    // stopped with an expected count of DeviceCount.
+    TEST_F(MultiDeviceIndirectBufferTests, TestSignature)
+    {
+        // Normal initialization
+        {
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            auto signature = CreateInitializedSignature();
+            AZ_TEST_STOP_TRACE_SUPPRESSION(DeviceCount);
+            EXPECT_TRUE(signature != nullptr);
+            ValidateSignature(*signature);
+        }
+
+        // ! Cannot test this, as we do not have access to the device signatures here and cannot set up the mock call
+        // // Failure initializing.
+        // {
+        //     auto signature = CreateUnInitializedSignature();
+        //     RHI::MultiDeviceIndirectBufferSignatureDescriptor descriptor;
+        //     EXPECT_TRUE(signature->Init(DeviceMask, descriptor) == RHI::ResultCode::InvalidOperation);
+        //     EXPECT_FALSE(signature->IsInitialized());
+        // }
+
+        // ! Cannot test this, as we do not have access to the device signatures here and cannot set up the mock call
+        // // GetByteStride()
+        // {
+        //     AZ_TEST_START_TRACE_SUPPRESSION;
+        //     auto signature = CreateInitializedSignature();
+        //     AZ_TEST_STOP_TRACE_SUPPRESSION(DeviceCount);
+        //     uint32_t byteStride = 1337;
+        //     for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+        //     {
+        //         EXPECT_CALL(
+        //             *static_cast<IndirectBufferSignature*>(signature->GetDeviceIndirectBufferSignature(deviceIndex).get()),
+        //             GetByteStrideInternal())
+        //             .Times(1)
+        //             .WillOnce(testing::Return(byteStride));
+        //     }
+        //     EXPECT_EQ(signature->GetByteStride(), byteStride);
+        // }
+
+        // GetByteStride() on uninitialized signature.
+        {
+            auto signature = CreateUnInitializedSignature();
+            // ! We do not have access to the members here and cannot set up the mock call
+            // EXPECT_CALL(*signature, GetByteStrideInternal()).Times(1).WillOnce(testing::Return(0));
+            // The return value is ignored; only the single suppressed trace is checked.
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            signature->GetByteStride();
+            AZ_TEST_STOP_TRACE_SUPPRESSION(1);
+        }
+
+        // GetOffset()
+        {
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            auto signature = CreateInitializedSignature();
+            AZ_TEST_STOP_TRACE_SUPPRESSION(DeviceCount);
+            uint32_t offset = 1337;
+            // Query the last command of the layout.
+            RHI::IndirectCommandIndex index(m_commands.size() - 1);
+            for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+            {
+                // Every device signature is expected to be asked exactly once and reports the same offset.
+                EXPECT_CALL(
+                    *static_cast<IndirectBufferSignature*>(signature->GetDeviceIndirectBufferSignature(deviceIndex).get()),
+                    GetOffsetInternal(index))
+                    .Times(1)
+                    .WillOnce(testing::Return(offset));
+            }
+            EXPECT_EQ(signature->GetOffset(index), offset);
+        }
+
+        // GetOffset with null index
+        {
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            auto signature = CreateInitializedSignature();
+            AZ_TEST_STOP_TRACE_SUPPRESSION(DeviceCount);
+            RHI::IndirectCommandIndex index = RHI::IndirectCommandIndex::Null;
+            // The return value is ignored; only the single suppressed trace is checked.
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            signature->GetOffset(index);
+            AZ_TEST_STOP_TRACE_SUPPRESSION(1);
+        }
+
+        // GetOffset with invalid index
+        {
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            auto signature = CreateInitializedSignature();
+            AZ_TEST_STOP_TRACE_SUPPRESSION(DeviceCount);
+            // One past the last valid command index.
+            RHI::IndirectCommandIndex index(m_commands.size());
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            signature->GetOffset(index);
+            AZ_TEST_STOP_TRACE_SUPPRESSION(1);
+        }
+
+        // Shutdown
+        {
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            auto signature = CreateInitializedSignature();
+            AZ_TEST_STOP_TRACE_SUPPRESSION(DeviceCount);
+            for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+            {
+                // The expectation is set before `signature` goes out of scope at the end of this block.
+                EXPECT_CALL(
+                    *static_cast<IndirectBufferSignature*>(signature->GetDeviceIndirectBufferSignature(deviceIndex).get()),
+                    ShutdownInternal())
+                    .Times(1);
+            }
+        }
+    }
+
+    // Exercises MultiDeviceIndirectBufferWriter: initialization (valid and invalid arguments),
+    // seeking, advancing sequences, writing commands, flushing and shutdown.
+    // Every case creates its own writer on the fixture's buffer and signature.
+    TEST_F(MultiDeviceIndirectBufferTests, TestWriter)
+    {
+        // Normal Initialization
+        {
+            auto writer = CreateInitializedWriter();
+            EXPECT_TRUE(writer != nullptr);
+            ValidateWriter(*writer);
+        }
+
+        // Initialization with invalid size
+        {
+            RHI::Ptr<AZ::RHI::MultiDeviceIndirectBufferWriter> writer = aznew AZ::RHI::MultiDeviceIndirectBufferWriter;
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            // The buffer is sized for exactly m_writerNumCommands sequences, so a byte offset of 1
+            // no longer leaves room for all of them.
+            EXPECT_EQ(
+                writer->Init(*m_buffer, 1, m_writerCommandStride, m_writerNumCommands, *m_writerSignature),
+                RHI::ResultCode::InvalidArgument);
+            AZ_TEST_STOP_TRACE_SUPPRESSION(1);
+        }
+
+        // Initialization with invalid stride
+        {
+            RHI::Ptr<AZ::RHI::MultiDeviceIndirectBufferWriter> writer = aznew AZ::RHI::MultiDeviceIndirectBufferWriter;
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            EXPECT_EQ(
+                writer->Init(*m_buffer, m_writerOffset, 0, m_writerNumCommands, *m_writerSignature), RHI::ResultCode::InvalidArgument);
+            AZ_TEST_STOP_TRACE_SUPPRESSION(1);
+        }
+
+        // Initialization with invalid max num sequences
+        {
+            RHI::Ptr<AZ::RHI::MultiDeviceIndirectBufferWriter> writer = aznew AZ::RHI::MultiDeviceIndirectBufferWriter;
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            EXPECT_EQ(
+                writer->Init(*m_buffer, m_writerOffset, m_writerCommandStride, 0, *m_writerSignature), RHI::ResultCode::InvalidArgument);
+            AZ_TEST_STOP_TRACE_SUPPRESSION(1);
+        }
+
+        // Initialization with small invalid stride
+        {
+            RHI::Ptr<AZ::RHI::MultiDeviceIndirectBufferWriter> writer = aznew AZ::RHI::MultiDeviceIndirectBufferWriter;
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            // One byte less than the stride reported by the signature mocks.
+            EXPECT_EQ(
+                writer->Init(*m_buffer, m_writerOffset, m_writerCommandStride - 1, m_writerNumCommands, *m_writerSignature),
+                RHI::ResultCode::InvalidArgument);
+            AZ_TEST_STOP_TRACE_SUPPRESSION(1);
+        }
+
+        // Initialization with invalid signature
+        {
+            RHI::Ptr<AZ::RHI::MultiDeviceIndirectBufferWriter> writer = aznew AZ::RHI::MultiDeviceIndirectBufferWriter;
+            auto signature = CreateUnInitializedSignature();
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            EXPECT_EQ(
+                writer->Init(*m_buffer, m_writerOffset, m_writerCommandStride, m_writerNumCommands, *signature),
+                RHI::ResultCode::InvalidArgument);
+            AZ_TEST_STOP_TRACE_SUPPRESSION(1);
+        }
+
+        // Initialization with offset
+        {
+            RHI::Ptr<AZ::RHI::MultiDeviceIndirectBufferWriter> writer = aznew AZ::RHI::MultiDeviceIndirectBufferWriter;
+            size_t offset = 16;
+            EXPECT_EQ(writer->Init(*m_buffer, offset, m_writerCommandStride, 5, *m_writerSignature), RHI::ResultCode::Success);
+            for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+            {
+                // Each device writer must point `offset` bytes into its device buffer.
+                EXPECT_EQ(
+                    static_cast<IndirectBufferWriter*>(writer->GetDeviceIndirectBufferWriter(deviceIndex).get())->GetData(),
+                    static_cast<Buffer*>(m_buffer->GetDeviceBuffer(deviceIndex).get())->GetData().data() + offset);
+            }
+        }
+
+        // Initialization with memory pointer
+        {
+            RHI::Ptr<AZ::RHI::MultiDeviceIndirectBufferWriter> writer = aznew AZ::RHI::MultiDeviceIndirectBufferWriter;
+            // One raw memory pointer per device, taken from the device buffers.
+            AZStd::unordered_map<int, void*> memoryPtrs;
+            for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+            {
+                memoryPtrs[deviceIndex] = const_cast<uint8_t*>(
+                    static_cast<Buffer*>(m_buffer->GetDeviceBuffer(deviceIndex).get())->GetData().data());
+            }
+
+            EXPECT_EQ(writer->Init(memoryPtrs, m_writerCommandStride, m_writerNumCommands, *m_writerSignature), RHI::ResultCode::Success);
+            for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+            {
+                EXPECT_EQ(
+                    static_cast<IndirectBufferWriter*>(writer->GetDeviceIndirectBufferWriter(deviceIndex).get())->GetData(),
+                    static_cast<Buffer*>(m_buffer->GetDeviceBuffer(deviceIndex).get())->GetData().data());
+            }
+        }
+
+        // Double Init
+        {
+            auto writer = CreateInitializedWriter();
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            EXPECT_EQ(
+                writer->Init(*m_buffer, m_writerOffset, m_writerCommandStride, m_writerNumCommands, *m_writerSignature),
+                RHI::ResultCode::InvalidOperation);
+            AZ_TEST_STOP_TRACE_SUPPRESSION(1);
+        }
+
+        // Valid Seek
+        {
+            auto writer = CreateInitializedWriter();
+            uint32_t seekPos = 2;
+            EXPECT_TRUE(writer->Seek(seekPos));
+            {
+                auto currentSequenceIndex{ writer->GetCurrentSequenceIndex() };
+                EXPECT_EQ(currentSequenceIndex, seekPos);
+            }
+
+            seekPos += 6;
+            EXPECT_TRUE(writer->Seek(seekPos));
+            {
+                auto currentSequenceIndex{ writer->GetCurrentSequenceIndex() };
+                EXPECT_EQ(currentSequenceIndex, seekPos);
+            }
+        }
+
+        // Invalid Seek
+        {
+            auto writer = CreateInitializedWriter();
+            EXPECT_FALSE(writer->Seek(m_writerNumCommands + 1));
+            {
+                // A failed seek must leave the writer at its initial position.
+                auto currentSequenceIndex{ writer->GetCurrentSequenceIndex() };
+                EXPECT_EQ(currentSequenceIndex, 0u);
+            }
+        }
+
+        // Valid NextSequence
+        {
+            auto writer = CreateInitializedWriter();
+            EXPECT_TRUE(writer->NextSequence());
+            {
+                auto currentSequenceIndex{ writer->GetCurrentSequenceIndex() };
+                EXPECT_EQ(currentSequenceIndex, 1u);
+            }
+        }
+
+        // Invalid NextSequence
+        {
+            auto writer = CreateInitializedWriter();
+            // Move to the last valid sequence; advancing from there must fail and not move the writer.
+            EXPECT_TRUE(writer->Seek(m_writerNumCommands - 1));
+            EXPECT_FALSE(writer->NextSequence());
+            {
+                auto currentSequenceIndex{ writer->GetCurrentSequenceIndex() };
+                EXPECT_EQ(currentSequenceIndex, m_writerNumCommands - 1);
+            }
+        }
+
+        // Valid Command
+        {
+            auto writer = CreateInitializedWriter();
+            // For every command of the layout, expect exactly one call of the matching internal
+            // function on each device writer, then issue the command on the multi-device writer.
+            for (const auto& command : m_commands)
+            {
+                switch (command.m_type)
+                {
+                case RHI::IndirectCommandType::VertexBufferView:
+                    {
+                        auto index = m_signatureDescriptor.m_layout.FindCommandIndex(RHI::IndirectBufferViewArguments{ s_vertexSlotIndex });
+                        EXPECT_FALSE(index.IsNull());
+                        AZ::RHI::MultiDeviceStreamBufferView bufferView(*m_buffer, 0, 12, 10);
+                        for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+                        {
+                            EXPECT_CALL(
+                                *static_cast<IndirectBufferWriter*>(writer->GetDeviceIndirectBufferWriter(deviceIndex).get()),
+                                SetVertexViewInternal(index, testing::_))
+                                .Times(1);
+                        }
+                        writer->SetVertexView(s_vertexSlotIndex, bufferView);
+                        break;
+                    }
+                case RHI::IndirectCommandType::IndexBufferView:
+                    {
+                        auto index = m_signatureDescriptor.m_layout.FindCommandIndex(command.m_type);
+                        EXPECT_FALSE(index.IsNull());
+                        AZ::RHI::MultiDeviceIndexBufferView indexView(*m_buffer, 0, 12, RHI::IndexFormat::Uint16);
+                        for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+                        {
+                            EXPECT_CALL(
+                                *static_cast<IndirectBufferWriter*>(writer->GetDeviceIndirectBufferWriter(deviceIndex).get()),
+                                SetIndexViewInternal(index, testing::_))
+                                .Times(1);
+                        }
+                        writer->SetIndexView(indexView);
+                        break;
+                    }
+                case RHI::IndirectCommandType::DrawIndexed:
+                    {
+                        auto index = m_signatureDescriptor.m_layout.FindCommandIndex(command.m_type);
+                        EXPECT_FALSE(index.IsNull());
+                        AZ::RHI::DrawIndexed arguments(1, 2, 3, 4, 5);
+                        for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+                        {
+                            EXPECT_CALL(
+                                *static_cast<IndirectBufferWriter*>(writer->GetDeviceIndirectBufferWriter(deviceIndex).get()),
+                                DrawIndexedInternal(index, testing::_))
+                                .Times(1);
+                        }
+                        writer->DrawIndexed(arguments);
+                        break;
+                    }
+                case RHI::IndirectCommandType::RootConstants:
+                    {
+                        auto index = m_signatureDescriptor.m_layout.FindCommandIndex(command.m_type);
+                        EXPECT_FALSE(index.IsNull());
+                        // Initialized so that no indeterminate value is handed to the API under test.
+                        size_t rootConstant{ 0 };
+                        // The offsets reported for this and the next command differ by exactly
+                        // sizeof(rootConstant), matching the byte size passed to SetRootConstants below.
+                        for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+                        {
+                            EXPECT_CALL(
+                                *static_cast<IndirectBufferSignature*>(
+                                    m_writerSignature->GetDeviceIndirectBufferSignature(deviceIndex).get()),
+                                GetOffsetInternal(index))
+                                .Times(1)
+                                .WillOnce(testing::Return(0));
+                        }
+
+                        auto nextIndex = RHI::IndirectCommandIndex(index.GetIndex() + 1);
+                        for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+                        {
+                            EXPECT_CALL(
+                                *static_cast<IndirectBufferSignature*>(
+                                    m_writerSignature->GetDeviceIndirectBufferSignature(deviceIndex).get()),
+                                GetOffsetInternal(nextIndex))
+                                .Times(1)
+                                .WillOnce(testing::Return(static_cast<uint32_t>(sizeof(rootConstant))));
+                        }
+
+                        for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+                        {
+                            EXPECT_CALL(
+                                *static_cast<IndirectBufferWriter*>(writer->GetDeviceIndirectBufferWriter(deviceIndex).get()),
+                                SetRootConstantsInternal(index, reinterpret_cast<uint8_t*>(&rootConstant), sizeof(rootConstant)))
+                                .Times(1);
+                        }
+                        writer->SetRootConstants(reinterpret_cast<uint8_t*>(&rootConstant), sizeof(rootConstant));
+                        break;
+                    }
+                default:
+                    break;
+                }
+            }
+        }
+
+        // Invalid command
+        {
+            auto writer = CreateInitializedWriter();
+            // The layout built in SetUp() contains no dispatch command.
+            RHI::DispatchDirect args;
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            writer->Dispatch(args);
+            AZ_TEST_STOP_TRACE_SUPPRESSION(DeviceCount);
+        }
+
+        // Write command on uninitialized writer
+        {
+            RHI::Ptr<AZ::RHI::MultiDeviceIndirectBufferWriter> writer = aznew AZ::RHI::MultiDeviceIndirectBufferWriter;
+            AZ::RHI::DrawIndexed arguments(1, 2, 3, 4, 5);
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            writer->DrawIndexed(arguments);
+            AZ_TEST_STOP_TRACE_SUPPRESSION(1);
+        }
+
+        // Flush
+        {
+            auto writer = CreateInitializedWriter();
+            writer->Flush();
+            // After a flush no device buffer may remain mapped ...
+            for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+            {
+                EXPECT_FALSE(static_cast<Buffer*>(m_buffer->GetDeviceBuffer(deviceIndex).get())->IsMapped());
+            }
+            AZ::RHI::MultiDeviceIndexBufferView indexView(*m_buffer, 0, 12, RHI::IndexFormat::Uint16);
+            for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+            {
+                EXPECT_CALL(
+                    *static_cast<IndirectBufferWriter*>(writer->GetDeviceIndirectBufferWriter(deviceIndex).get()),
+                    SetIndexViewInternal(testing::_, testing::_))
+                    .Times(1);
+            }
+            // ... and writing the next command must map the buffers again.
+            writer->SetIndexView(indexView);
+            for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+            {
+                EXPECT_TRUE(static_cast<Buffer*>(m_buffer->GetDeviceBuffer(deviceIndex).get())->IsMapped());
+            }
+        }
+
+        // Inline Constants Command with incorrect size
+        {
+            auto writer = CreateInitializedWriter();
+            auto findIt = AZStd::find_if(
+                m_commands.begin(),
+                m_commands.end(),
+                [](const auto& element)
+                {
+                    return element.m_type == RHI::IndirectCommandType::RootConstants;
+                });
+            // Fatal assertion: the iterator is dereferenced right below.
+            ASSERT_NE(findIt, m_commands.end());
+            auto commandIndex = m_writerSignature->GetLayout().FindCommandIndex(*findIt);
+            EXPECT_FALSE(commandIndex.IsNull());
+            auto nextCommandIndex = RHI::IndirectCommandIndex(commandIndex.GetIndex() + 1);
+            // The reported offsets are 4 bytes apart, while 8 bytes (sizeof(uint64_t)) are written below.
+            uint32_t commandOffset = 12;
+            uint32_t nextCommandOffset = 16;
+
+            for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+            {
+                EXPECT_CALL(
+                    *static_cast<IndirectBufferSignature*>(m_writerSignature->GetDeviceIndirectBufferSignature(deviceIndex).get()),
+                    GetOffsetInternal(commandIndex))
+                    .Times(1)
+                    .WillOnce(testing::Return(commandOffset));
+
+                EXPECT_CALL(
+                    *static_cast<IndirectBufferSignature*>(m_writerSignature->GetDeviceIndirectBufferSignature(deviceIndex).get()),
+                    GetOffsetInternal(nextCommandIndex))
+                    .Times(1)
+                    .WillOnce(testing::Return(nextCommandOffset));
+            }
+
+            AZ_TEST_START_TRACE_SUPPRESSION;
+            // Initialized so that no indeterminate value is handed to the API under test.
+            uint64_t data{ 0 };
+            writer->SetRootConstants(reinterpret_cast<uint8_t*>(&data), sizeof(data));
+            AZ_TEST_STOP_TRACE_SUPPRESSION(DeviceCount);
+        }
+
+        // Shutdown
+        {
+            auto writer = CreateInitializedWriter();
+            writer->Shutdown();
+            // Shutting down must unmap every device buffer and clear the device writers' data pointers.
+            for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+            {
+                EXPECT_FALSE(static_cast<Buffer*>(m_buffer->GetDeviceBuffer(deviceIndex).get())->IsMapped());
+            }
+            for (auto deviceIndex{ 0 }; deviceIndex < DeviceCount; ++deviceIndex)
+            {
+                EXPECT_TRUE(
+                    static_cast<IndirectBufferWriter*>(writer->GetDeviceIndirectBufferWriter(deviceIndex).get())->GetData() == nullptr);
+            }
+        }
+    }
+} // namespace UnitTest

+ 7 - 0
Gems/Atom/RHI/Code/atom_rhi_public_files.cmake

@@ -51,6 +51,7 @@ set(FILES
     Include/Atom/RHI/DrawPacket.h
     Include/Atom/RHI/DrawPacketBuilder.h
     Include/Atom/RHI/IndirectArguments.h
+    Include/Atom/RHI/MultiDeviceIndirectArguments.h
     Source/RHI/CommandList.cpp
     Source/RHI/CommandListValidator.cpp
     Source/RHI/ConstantsData.cpp
@@ -125,11 +126,17 @@ set(FILES
     Source/RHI/StreamingImagePool.cpp
     Source/RHI/MultiDeviceStreamingImagePool.cpp
     Include/Atom/RHI/IndirectBufferSignature.h
+    Include/Atom/RHI/MultiDeviceIndirectBufferSignature.h
     Include/Atom/RHI/IndirectBufferView.h
+    Include/Atom/RHI/MultiDeviceIndirectBufferView.h
     Include/Atom/RHI/IndirectBufferWriter.h
+    Include/Atom/RHI/MultiDeviceIndirectBufferWriter.h
     Source/RHI/IndirectBufferSignature.cpp
+    Source/RHI/MultiDeviceIndirectBufferSignature.cpp
     Source/RHI/IndirectBufferView.cpp
+    Source/RHI/MultiDeviceIndirectBufferView.cpp
     Source/RHI/IndirectBufferWriter.cpp
+    Source/RHI/MultiDeviceIndirectBufferWriter.cpp
     Include/Atom/RHI/Object.h
     Include/Atom/RHI/ObjectCache.h
     Include/Atom/RHI/ObjectCollector.h

+ 1 - 0
Gems/Atom/RHI/Code/atom_rhi_tests_files.cmake

@@ -56,6 +56,7 @@ set(FILES
     Tests/MultiDeviceQueryTests.cpp
     Tests/MultiDeviceBufferTests.cpp
     Tests/MultiDeviceImageTests.cpp
+    Tests/MultiDeviceIndirectBufferTests.cpp
     Tests/MultiDeviceShaderResourceGroupTests.cpp
 )
 

+ 1 - 5
Gems/Atom/RHI/DX12/Code/Source/RHI/CommandList.h

@@ -18,6 +18,7 @@
 #include <Atom/RHI/CommandList.h>
 #include <Atom/RHI/CommandListValidator.h>
 #include <Atom/RHI/CommandListStates.h>
+#include <Atom/RHI/IndirectArguments.h>
 #include <Atom/RHI/ObjectPool.h>
 #include <AzCore/std/containers/span.h>
 #include <AzCore/Memory/SystemAllocator.h>
@@ -31,11 +32,6 @@
 
 namespace AZ
 {
-    namespace RHI
-    {
-        struct IndirectArguments;
-    }
-
     namespace DX12
     {
         class CommandQueue;