Bladeren bron

Use FrameGraphExecuteContext's device index in passes.

Signed-off-by: Joerg H. Mueller <[email protected]>
Joerg H. Mueller 1 jaar geleden
bovenliggende
commit
a9541b8cf9
28 gewijzigde bestanden met toevoegingen van 62 en 62 verwijderingen
  1. 1 1
      Gems/Atom/Feature/Common/Code/Source/CoreLights/DepthExponentiationPass.cpp
  2. 1 1
      Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.cpp
  3. 1 1
      Gems/Atom/Feature/Common/Code/Source/CoreLights/ShadowmapPass.cpp
  4. 1 1
      Gems/Atom/Feature/Common/Code/Source/DisplayMapper/OutputTransformPass.cpp
  5. 2 2
      Gems/Atom/Feature/Common/Code/Source/ImGui/ImGuiPass.cpp
  6. 1 1
      Gems/Atom/Feature/Common/Code/Source/PostProcessing/BlendColorGradingLutsPass.cpp
  7. 1 1
      Gems/Atom/Feature/Common/Code/Source/PostProcessing/DepthOfFieldBokehBlurPass.cpp
  8. 1 1
      Gems/Atom/Feature/Common/Code/Source/PostProcessing/DepthOfFieldCompositePass.cpp
  9. 1 1
      Gems/Atom/Feature/Common/Code/Source/PostProcessing/DepthOfFieldCopyFocusDepthToCpuPass.cpp
  10. 1 1
      Gems/Atom/Feature/Common/Code/Source/PostProcessing/LookModificationCompositePass.cpp
  11. 5 5
      Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingAccelerationStructurePass.cpp
  12. 8 8
      Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingPass.cpp
  13. 1 1
      Gems/Atom/RPI/Code/Source/RPI.Public/Pass/RasterPass.cpp
  14. 1 1
      Gems/Atom/RPI/Code/Source/RPI.Public/Pass/Specific/ImageAttachmentPreviewPass.cpp
  15. 1 1
      Gems/AtomTressFX/Code/Passes/HairSkinningComputePass.cpp
  16. 2 2
      Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridBlendDistancePass.cpp
  17. 2 2
      Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridBlendIrradiancePass.cpp
  18. 1 1
      Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridBorderUpdatePass.cpp
  19. 2 2
      Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridClassificationPass.cpp
  20. 2 2
      Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridPreparePass.cpp
  21. 4 4
      Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridQueryFullscreenPass.cpp
  22. 2 2
      Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridQueryPass.cpp
  23. 6 6
      Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridRayTracingPass.cpp
  24. 2 2
      Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridRelocationPass.cpp
  25. 3 3
      Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridVisualizationAccelerationStructurePass.cpp
  26. 2 2
      Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridVisualizationPreparePass.cpp
  27. 6 6
      Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridVisualizationRayTracingPass.cpp
  28. 1 1
      Gems/Meshlets/Code/Source/Meshlets/MeshletsRenderPass.cpp

+ 1 - 1
Gems/Atom/Feature/Common/Code/Source/CoreLights/DepthExponentiationPass.cpp

@@ -89,7 +89,7 @@ namespace AZ
         void DepthExponentiationPass::BuildCommandListInternal(const RHI::FrameGraphExecuteContext& context)
         {
             const uint32_t typeIndex = aznumeric_cast<uint32_t>(m_shadowmapType);
-            m_dispatchItem.m_pipelineState = m_shaderVariant[typeIndex].m_pipelineState->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+            m_dispatchItem.m_pipelineState = m_shaderVariant[typeIndex].m_pipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
 
             Base::BuildCommandListInternal(context);
         }

+ 1 - 1
Gems/Atom/Feature/Common/Code/Source/CoreLights/LightCullingTilePreparePass.cpp

@@ -54,7 +54,7 @@ namespace AZ
             m_dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsX = resolution.m_width;
             m_dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsY = resolution.m_height;
             m_dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsZ = 1;
-            m_dispatchItem.m_pipelineState = m_msaaPipelineState->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+            m_dispatchItem.m_pipelineState = m_msaaPipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
             commandList->Submit(m_dispatchItem);
         }
 

+ 1 - 1
Gems/Atom/Feature/Common/Code/Source/CoreLights/ShadowmapPass.cpp

@@ -201,7 +201,7 @@ namespace AZ
                 if (startIndex == 0)
                 {
                     RHI::CommandList* commandList = context.GetCommandList();
-                    commandList->Submit(m_clearShadowDrawPacket->GetDrawItemProperties(0).m_mdItem->GetDeviceDrawItem(RHI::MultiDevice::DefaultDeviceIndex), 0);
+                    commandList->Submit(m_clearShadowDrawPacket->GetDrawItemProperties(0).m_mdItem->GetDeviceDrawItem(context.GetDeviceIndex()), 0);
                 }
                 else
                 {

+ 1 - 1
Gems/Atom/Feature/Common/Code/Source/DisplayMapper/OutputTransformPass.cpp

@@ -83,7 +83,7 @@ namespace AZ
 
             SetSrgsForDraw(commandList);
 
-            m_item.m_pipelineState = GetPipelineStateFromShaderVariant()->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+            m_item.m_pipelineState = GetPipelineStateFromShaderVariant()->GetDevicePipelineState(context.GetDeviceIndex()).get();
 
             commandList->Submit(m_item);
         }

+ 2 - 2
Gems/Atom/Feature/Common/Code/Source/ImGui/ImGuiPass.cpp

@@ -674,13 +674,13 @@ namespace AZ
             AZ_PROFILE_SCOPE(AzRender, "ImGuiPass: BuildCommandListInternal");
 
             context.GetCommandList()->SetViewport(m_viewportState);
-            context.GetCommandList()->SetShaderResourceGroupForDraw(*m_resourceGroup->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex));
+            context.GetCommandList()->SetShaderResourceGroupForDraw(*m_resourceGroup->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()));
 
             for (uint32_t i = context.GetSubmitRange().m_startIndex; i < context.GetSubmitRange().m_endIndex; ++i)
             {
                 RHI::SingleDeviceDrawItem drawItem;
                 drawItem.m_arguments = m_draws.at(i).m_drawIndexed;
-                drawItem.m_pipelineState = m_pipelineState->GetRHIPipelineState()->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+                drawItem.m_pipelineState = m_pipelineState->GetRHIPipelineState()->GetDevicePipelineState(context.GetDeviceIndex()).get();
                 drawItem.m_indexBufferView = &m_indexBufferView;
                 drawItem.m_streamBufferViewCount = 2;
                 drawItem.m_streamBufferViews = m_vertexBufferView.data();

+ 1 - 1
Gems/Atom/Feature/Common/Code/Source/PostProcessing/BlendColorGradingLutsPass.cpp

@@ -210,7 +210,7 @@ namespace AZ
         {
             if (m_needToUpdateLut && m_blendedLut.m_lutImage && m_currentShaderVariantIndex <= LookModificationSettings::MaxBlendLuts)
             {
-                m_dispatchItem.m_pipelineState = m_shaderVariant[m_currentShaderVariantIndex].m_pipelineState->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+                m_dispatchItem.m_pipelineState = m_shaderVariant[m_currentShaderVariantIndex].m_pipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
 
                 ComputePass::BuildCommandListInternal(context);
 

+ 1 - 1
Gems/Atom/Feature/Common/Code/Source/PostProcessing/DepthOfFieldBokehBlurPass.cpp

@@ -174,7 +174,7 @@ namespace AZ
 
             SetSrgsForDraw(commandList);
 
-            m_item.m_pipelineState = GetPipelineStateFromShaderVariant()->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+            m_item.m_pipelineState = GetPipelineStateFromShaderVariant()->GetDevicePipelineState(context.GetDeviceIndex()).get();
 
             commandList->Submit(m_item);
         }

+ 1 - 1
Gems/Atom/Feature/Common/Code/Source/PostProcessing/DepthOfFieldCompositePass.cpp

@@ -143,7 +143,7 @@ namespace AZ
 
             SetSrgsForDraw(commandList);
 
-            m_item.m_pipelineState = GetPipelineStateFromShaderVariant()->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+            m_item.m_pipelineState = GetPipelineStateFromShaderVariant()->GetDevicePipelineState(context.GetDeviceIndex()).get();
 
             commandList->Submit(m_item);
         }

+ 1 - 1
Gems/Atom/Feature/Common/Code/Source/PostProcessing/DepthOfFieldCopyFocusDepthToCpuPass.cpp

@@ -98,7 +98,7 @@ namespace AZ
 
         void DepthOfFieldCopyFocusDepthToCpuPass::BuildCommandList(const RHI::FrameGraphExecuteContext& context)
         {
-            context.GetCommandList()->Submit(m_copyDescriptor.GetDeviceCopyBufferDescriptor(RHI::MultiDevice::DefaultDeviceIndex));
+            context.GetCommandList()->Submit(m_copyDescriptor.GetDeviceCopyBufferDescriptor(context.GetDeviceIndex()));
         }
 
     }   // namespace RPI

+ 1 - 1
Gems/Atom/Feature/Common/Code/Source/PostProcessing/LookModificationCompositePass.cpp

@@ -249,7 +249,7 @@ namespace AZ
 
             SetSrgsForDraw(commandList);
 
-            m_item.m_pipelineState = GetPipelineStateFromShaderVariant()->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+            m_item.m_pipelineState = GetPipelineStateFromShaderVariant()->GetDevicePipelineState(context.GetDeviceIndex()).get();
 
             commandList->Submit(m_item);
         }

+ 5 - 5
Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingAccelerationStructurePass.cpp

@@ -265,11 +265,11 @@ namespace AZ
                     for (auto submeshIndex = 0; submeshIndex < blasInstance.second.m_subMeshes.size(); ++submeshIndex)
                     {
                         auto& submeshBlasInstance = blasInstance.second.m_subMeshes[submeshIndex];
-                        changedBlasList.push_back(submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(RHI::MultiDevice::DefaultDeviceIndex).get());
+                        changedBlasList.push_back(submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(context.GetDeviceIndex()).get());
                         if (blasInstance.second.m_blasBuilt == false)
                         {
                             // Always build the BLAS, if it has not previously been built
-                            context.GetCommandList()->BuildBottomLevelAccelerationStructure(*submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(RHI::MultiDevice::DefaultDeviceIndex));
+                            context.GetCommandList()->BuildBottomLevelAccelerationStructure(*submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
                             continue;
                         }
 
@@ -282,13 +282,13 @@ namespace AZ
                         {
                             // Skinned mesh that simply needs an update
                             context.GetCommandList()->UpdateBottomLevelAccelerationStructure(
-                                *submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(RHI::MultiDevice::DefaultDeviceIndex));
+                                *submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
                         }
                         else
                         {
                             // Fall back to building the BLAS in any case
                             context.GetCommandList()->BuildBottomLevelAccelerationStructure(
-                                *submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(RHI::MultiDevice::DefaultDeviceIndex));
+                                *submeshBlasInstance.m_blas->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
                         }
                     }
 
@@ -297,7 +297,7 @@ namespace AZ
             }
 
             // build the TLAS object
-            context.GetCommandList()->BuildTopLevelAccelerationStructure(*rayTracingFeatureProcessor->GetTlas()->GetDeviceRayTracingTlas(RHI::MultiDevice::DefaultDeviceIndex), changedBlasList);
+            context.GetCommandList()->BuildTopLevelAccelerationStructure(*rayTracingFeatureProcessor->GetTlas()->GetDeviceRayTracingTlas(context.GetDeviceIndex()), changedBlasList);
 
             ++m_frameCount;
 

+ 8 - 8
Gems/Atom/Feature/Common/Code/Source/RayTracing/RayTracingPass.cpp

@@ -417,11 +417,11 @@ namespace AZ
 
             // bind RayTracingGlobal, RayTracingScene, and View Srgs
             // [GFX TODO][ATOM-15610] Add RenderPass::SetSrgsForRayTracingDispatch
-            AZStd::vector<RHI::SingleDeviceShaderResourceGroup*> shaderResourceGroups = { m_shaderResourceGroup->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex).get() };
+            AZStd::vector<RHI::SingleDeviceShaderResourceGroup*> shaderResourceGroups = { m_shaderResourceGroup->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get() };
 
             if (m_requiresRayTracingSceneSrg)
             {
-                shaderResourceGroups.push_back(rayTracingFeatureProcessor->GetRayTracingSceneSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex).get());
+                shaderResourceGroups.push_back(rayTracingFeatureProcessor->GetRayTracingSceneSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get());
             }
 
             if (m_requiresViewSrg)
@@ -429,26 +429,26 @@ namespace AZ
                 RPI::ViewPtr view = m_pipeline->GetFirstView(GetPipelineViewTag());
                 if (view)
                 {
-                    shaderResourceGroups.push_back(view->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex).get());
+                    shaderResourceGroups.push_back(view->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get());
                 }
             }
 
             if (m_requiresSceneSrg)
             {
-                shaderResourceGroups.push_back(scene->GetShaderResourceGroup()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex).get());
+                shaderResourceGroups.push_back(scene->GetShaderResourceGroup()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get());
             }
 
             if (m_requiresRayTracingMaterialSrg)
             {
-                shaderResourceGroups.push_back(rayTracingFeatureProcessor->GetRayTracingMaterialSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex).get());
+                shaderResourceGroups.push_back(rayTracingFeatureProcessor->GetRayTracingMaterialSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get());
             }
 
             dispatchRaysItem.m_shaderResourceGroupCount = aznumeric_cast<uint32_t>(shaderResourceGroups.size());
             dispatchRaysItem.m_shaderResourceGroups = shaderResourceGroups.data();
             dispatchRaysItem.m_rayTracingPipelineState =
-                m_rayTracingPipelineState->GetDeviceRayTracingPipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
-            dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable->GetDeviceRayTracingShaderTable(RHI::MultiDevice::DefaultDeviceIndex).get();
-            dispatchRaysItem.m_globalPipelineState = m_globalPipelineState->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+                m_rayTracingPipelineState->GetDeviceRayTracingPipelineState(context.GetDeviceIndex()).get();
+            dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable->GetDeviceRayTracingShaderTable(context.GetDeviceIndex()).get();
+            dispatchRaysItem.m_globalPipelineState = m_globalPipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
 
             // submit the DispatchRays item
             context.GetCommandList()->Submit(dispatchRaysItem);

+ 1 - 1
Gems/Atom/RPI/Code/Source/RPI.Public/Pass/RasterPass.cpp

@@ -264,7 +264,7 @@ namespace AZ
                 const RHI::MultiDeviceDrawItemProperties& drawItemProperties = m_drawListView[index];
                 if (drawItemProperties.m_drawFilterMask & m_pipeline->GetDrawFilterMask())
                 {
-                    commandList->Submit(drawItemProperties.m_mdItem->GetDeviceDrawItem(RHI::MultiDevice::DefaultDeviceIndex), index + indexOffset);
+                    commandList->Submit(drawItemProperties.m_mdItem->GetDeviceDrawItem(context.GetDeviceIndex()), index + indexOffset);
                 }
             }
         }

+ 1 - 1
Gems/Atom/RPI/Code/Source/RPI.Public/Pass/Specific/ImageAttachmentPreviewPass.cpp

@@ -544,7 +544,7 @@ namespace AZ
             commandList->SetScissor(m_scissor);
 
             // submit srg
-            commandList->SetShaderResourceGroupForDraw(*m_passSrg->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex));
+            commandList->SetShaderResourceGroupForDraw(*m_passSrg->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()));
 
             // submit draw call
             for (uint32_t index = context.GetSubmitRange().m_startIndex; index < context.GetSubmitRange().m_endIndex; ++index)

+ 1 - 1
Gems/AtomTressFX/Code/Passes/HairSkinningComputePass.cpp

@@ -237,7 +237,7 @@ namespace AZ
 
                 for (uint32_t index = context.GetSubmitRange().m_startIndex; index < context.GetSubmitRange().m_endIndex; ++index, ++it)
                 {
-                    commandList->Submit((*it)->GetDeviceDispatchItem(RHI::MultiDevice::DefaultDeviceIndex), index);
+                    commandList->Submit((*it)->GetDeviceDispatchItem(context.GetDeviceIndex()), index);
                 }
 
                 // Clear the dispatch items. They will need to be re-populated next frame

+ 2 - 2
Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridBlendDistancePass.cpp

@@ -187,7 +187,7 @@ namespace AZ
                 DiffuseProbeGridShader& shader = m_shaders[diffuseProbeGrid->GetNumRaysPerProbe().m_index];
 
                 const RHI::MultiDeviceShaderResourceGroup* shaderResourceGroup = diffuseProbeGrid->GetBlendDistanceSrg()->GetRHIShaderResourceGroup();
-                commandList->SetShaderResourceGroupForDispatch(*shaderResourceGroup->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex));
+                commandList->SetShaderResourceGroupForDispatch(*shaderResourceGroup->GetDeviceShaderResourceGroup(context.GetDeviceIndex()));
 
                 uint32_t probeCountX;
                 uint32_t probeCountY;
@@ -197,7 +197,7 @@ namespace AZ
 
                 RHI::SingleDeviceDispatchItem dispatchItem;
                 dispatchItem.m_arguments = shader.m_dispatchArgs;
-                dispatchItem.m_pipelineState = shader.m_pipelineState->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+                dispatchItem.m_pipelineState = shader.m_pipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsX = probeCountX * dispatchItem.m_arguments.m_direct.m_threadsPerGroupX;
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsY = probeCountY * dispatchItem.m_arguments.m_direct.m_threadsPerGroupY;
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsZ = 1;

+ 2 - 2
Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridBlendIrradiancePass.cpp

@@ -177,7 +177,7 @@ namespace AZ
                 DiffuseProbeGridShader& shader = m_shaders[diffuseProbeGrid->GetNumRaysPerProbe().m_index];
 
                 const RHI::MultiDeviceShaderResourceGroup* shaderResourceGroup = diffuseProbeGrid->GetBlendIrradianceSrg()->GetRHIShaderResourceGroup();
-                commandList->SetShaderResourceGroupForDispatch(*shaderResourceGroup->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex).get());
+                commandList->SetShaderResourceGroupForDispatch(*shaderResourceGroup->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get());
 
                 uint32_t probeCountX;
                 uint32_t probeCountY;
@@ -187,7 +187,7 @@ namespace AZ
 
                 RHI::SingleDeviceDispatchItem dispatchItem;
                 dispatchItem.m_arguments = shader.m_dispatchArgs;
-                dispatchItem.m_pipelineState = shader.m_pipelineState->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+                dispatchItem.m_pipelineState = shader.m_pipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsX = probeCountX * dispatchItem.m_arguments.m_direct.m_threadsPerGroupX;
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsY = probeCountY * dispatchItem.m_arguments.m_direct.m_threadsPerGroupY;
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsZ = 1;

+ 1 - 1
Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridBorderUpdatePass.cpp

@@ -226,7 +226,7 @@ namespace AZ
             {
                 SubmitItem& submitItem = m_submitItems[index];
 
-                commandList->SetShaderResourceGroupForDispatch(*submitItem.m_shaderResourceGroup->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex));
+                commandList->SetShaderResourceGroupForDispatch(*submitItem.m_shaderResourceGroup->GetDeviceShaderResourceGroup(context.GetDeviceIndex()));
                 commandList->Submit(submitItem.m_dispatchItem, index++);
             }
         }

+ 2 - 2
Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridClassificationPass.cpp

@@ -179,11 +179,11 @@ namespace AZ
                 DiffuseProbeGridShader& shader = m_shaders[diffuseProbeGrid->GetNumRaysPerProbe().m_index];
 
                 const RHI::MultiDeviceShaderResourceGroup* shaderResourceGroup = diffuseProbeGrid->GetClassificationSrg()->GetRHIShaderResourceGroup();
-                commandList->SetShaderResourceGroupForDispatch(*shaderResourceGroup->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex));
+                commandList->SetShaderResourceGroupForDispatch(*shaderResourceGroup->GetDeviceShaderResourceGroup(context.GetDeviceIndex()));
 
                 RHI::SingleDeviceDispatchItem dispatchItem;
                 dispatchItem.m_arguments = shader.m_dispatchArgs;
-                dispatchItem.m_pipelineState = shader.m_pipelineState->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+                dispatchItem.m_pipelineState = shader.m_pipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsX = AZ::DivideAndRoundUp(diffuseProbeGrid->GetTotalProbeCount(), diffuseProbeGrid->GetFrameUpdateCount());
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsY = 1;
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsZ = 1;

+ 2 - 2
Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridPreparePass.cpp

@@ -148,11 +148,11 @@ namespace AZ
                 AZStd::shared_ptr<DiffuseProbeGrid> diffuseProbeGrid = diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids()[index];
 
                 const RHI::MultiDeviceShaderResourceGroup* shaderResourceGroup = diffuseProbeGrid->GetPrepareSrg()->GetRHIShaderResourceGroup();
-                commandList->SetShaderResourceGroupForDispatch(*shaderResourceGroup->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex));
+                commandList->SetShaderResourceGroupForDispatch(*shaderResourceGroup->GetDeviceShaderResourceGroup(context.GetDeviceIndex()));
 
                 RHI::SingleDeviceDispatchItem dispatchItem;
                 dispatchItem.m_arguments = m_dispatchArgs;
-                dispatchItem.m_pipelineState = m_pipelineState->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+                dispatchItem.m_pipelineState = m_pipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsX = 1;
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsY = 1;
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsZ = 1;

+ 4 - 4
Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridQueryFullscreenPass.cpp

@@ -220,13 +220,13 @@ namespace AZ
                 const uint8_t srgCount = 3;
                 AZStd::array<const RHI::SingleDeviceShaderResourceGroup*, 8> shaderResourceGroups =
                 {
-                    diffuseProbeGrid->GetQuerySrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex).get(),
-                    m_shaderResourceGroup->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex).get(),
-                    views[0]->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex).get()
+                    diffuseProbeGrid->GetQuerySrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get(),
+                    m_shaderResourceGroup->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get(),
+                    views[0]->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get()
                 };
 
                 RHI::SingleDeviceDispatchItem dispatchItem;
-                dispatchItem.m_pipelineState = m_pipelineState->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+                dispatchItem.m_pipelineState = m_pipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
                 dispatchItem.m_arguments = m_dispatchArgs;
                 dispatchItem.m_shaderResourceGroupCount = srgCount;
                 dispatchItem.m_shaderResourceGroups = shaderResourceGroups;

+ 2 - 2
Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridQueryPass.cpp

@@ -251,11 +251,11 @@ namespace AZ
                 AZStd::shared_ptr<DiffuseProbeGrid> diffuseProbeGrid = diffuseProbeGridFeatureProcessor->GetVisibleProbeGrids()[index];
 
                 const RHI::MultiDeviceShaderResourceGroup* shaderResourceGroup = diffuseProbeGrid->GetQuerySrg()->GetRHIShaderResourceGroup();
-                commandList->SetShaderResourceGroupForDispatch(*shaderResourceGroup->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex));
+                commandList->SetShaderResourceGroupForDispatch(*shaderResourceGroup->GetDeviceShaderResourceGroup(context.GetDeviceIndex()));
 
                 RHI::SingleDeviceDispatchItem dispatchItem;
                 dispatchItem.m_arguments = m_dispatchArgs;
-                dispatchItem.m_pipelineState = m_pipelineState->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+                dispatchItem.m_pipelineState = m_pipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsX = diffuseProbeGridFeatureProcessor->GetIrradianceQueryCount();
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsY = 1;
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsZ = 1;

+ 6 - 6
Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridRayTracingPass.cpp

@@ -358,20 +358,20 @@ namespace AZ
                     AZStd::shared_ptr<DiffuseProbeGrid> diffuseProbeGrid = diffuseProbeGridFeatureProcessor->GetVisibleRealTimeProbeGrids()[index];
 
                     const RHI::SingleDeviceShaderResourceGroup* shaderResourceGroups[] = {
-                        diffuseProbeGrid->GetRayTraceSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex).get(),
-                        rayTracingFeatureProcessor->GetRayTracingSceneSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex).get(),
-                        rayTracingFeatureProcessor->GetRayTracingMaterialSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex).get()
+                        diffuseProbeGrid->GetRayTraceSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get(),
+                        rayTracingFeatureProcessor->GetRayTracingSceneSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get(),
+                        rayTracingFeatureProcessor->GetRayTracingMaterialSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get()
                     };
 
                     RHI::SingleDeviceDispatchRaysItem dispatchRaysItem;
                     dispatchRaysItem.m_arguments.m_direct.m_width = diffuseProbeGrid->GetNumRaysPerProbe().m_rayCount;
                     dispatchRaysItem.m_arguments.m_direct.m_height = AZ::DivideAndRoundUp(diffuseProbeGrid->GetTotalProbeCount(), diffuseProbeGrid->GetFrameUpdateCount());
                     dispatchRaysItem.m_arguments.m_direct.m_depth = 1;
-                    dispatchRaysItem.m_rayTracingPipelineState = m_rayTracingPipelineState->GetDeviceRayTracingPipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
-                    dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable->GetDeviceRayTracingShaderTable(RHI::MultiDevice::DefaultDeviceIndex).get();
+                    dispatchRaysItem.m_rayTracingPipelineState = m_rayTracingPipelineState->GetDeviceRayTracingPipelineState(context.GetDeviceIndex()).get();
+                    dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable->GetDeviceRayTracingShaderTable(context.GetDeviceIndex()).get();
                     dispatchRaysItem.m_shaderResourceGroupCount = RHI::ArraySize(shaderResourceGroups);
                     dispatchRaysItem.m_shaderResourceGroups = shaderResourceGroups;
-                    dispatchRaysItem.m_globalPipelineState = m_globalPipelineState->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+                    dispatchRaysItem.m_globalPipelineState = m_globalPipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
 
                     // submit the DispatchRays item
                     context.GetCommandList()->Submit(dispatchRaysItem, index);

+ 2 - 2
Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridRelocationPass.cpp

@@ -210,11 +210,11 @@ namespace AZ
                 AZStd::shared_ptr<DiffuseProbeGrid> diffuseProbeGrid = diffuseProbeGridFeatureProcessor->GetVisibleRealTimeProbeGrids()[index];
 
                 const RHI::MultiDeviceShaderResourceGroup* shaderResourceGroup = diffuseProbeGrid->GetRelocationSrg()->GetRHIShaderResourceGroup();
-                commandList->SetShaderResourceGroupForDispatch(*shaderResourceGroup->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex));
+                commandList->SetShaderResourceGroupForDispatch(*shaderResourceGroup->GetDeviceShaderResourceGroup(context.GetDeviceIndex()));
 
                 RHI::SingleDeviceDispatchItem dispatchItem;
                 dispatchItem.m_arguments = m_dispatchArgs;
-                dispatchItem.m_pipelineState = m_pipelineState->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+                dispatchItem.m_pipelineState = m_pipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsX = AZ::DivideAndRoundUp(diffuseProbeGrid->GetTotalProbeCount(), diffuseProbeGrid->GetFrameUpdateCount());
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsY = 1;
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsZ = 1;

+ 3 - 3
Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridVisualizationAccelerationStructurePass.cpp

@@ -154,10 +154,10 @@ namespace AZ
             AZStd::vector<const RHI::SingleDeviceRayTracingBlas*> changedBlasList;
             if (m_visualizationBlasBuilt == false)
             {
-                context.GetCommandList()->BuildBottomLevelAccelerationStructure(*diffuseProbeGridFeatureProcessor->GetVisualizationBlas()->GetDeviceRayTracingBlas(RHI::MultiDevice::DefaultDeviceIndex));
+                context.GetCommandList()->BuildBottomLevelAccelerationStructure(*diffuseProbeGridFeatureProcessor->GetVisualizationBlas()->GetDeviceRayTracingBlas(context.GetDeviceIndex()));
                 m_visualizationBlasBuilt = true;
                 changedBlasList.push_back(diffuseProbeGridFeatureProcessor->GetVisualizationBlas()
-                                              ->GetDeviceRayTracingBlas(RHI::MultiDevice::DefaultDeviceIndex)
+                                              ->GetDeviceRayTracingBlas(context.GetDeviceIndex())
                                               .get());
             }
 
@@ -176,7 +176,7 @@ namespace AZ
                 }
 
                 // build the TLAS object
-                context.GetCommandList()->BuildTopLevelAccelerationStructure(*diffuseProbeGrid->GetVisualizationTlas()->GetDeviceRayTracingTlas(RHI::MultiDevice::DefaultDeviceIndex), changedBlasList);
+                context.GetCommandList()->BuildTopLevelAccelerationStructure(*diffuseProbeGrid->GetVisualizationTlas()->GetDeviceRayTracingTlas(context.GetDeviceIndex()), changedBlasList);
             }
         }
 

+ 2 - 2
Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridVisualizationPreparePass.cpp

@@ -268,12 +268,12 @@ namespace AZ
                     continue;
                 }
 
-                const RHI::SingleDeviceShaderResourceGroup* shaderResourceGroup = diffuseProbeGrid->GetVisualizationPrepareSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex).get();
+                const RHI::SingleDeviceShaderResourceGroup* shaderResourceGroup = diffuseProbeGrid->GetVisualizationPrepareSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get();
                 commandList->SetShaderResourceGroupForDispatch(*shaderResourceGroup);
 
                 RHI::SingleDeviceDispatchItem dispatchItem;
                 dispatchItem.m_arguments = m_dispatchArgs;
-                dispatchItem.m_pipelineState = m_pipelineState->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+                dispatchItem.m_pipelineState = m_pipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsX = diffuseProbeGrid->GetTotalProbeCount();
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsY = 1;
                 dispatchItem.m_arguments.m_direct.m_totalNumberOfThreadsZ = 1;

+ 6 - 6
Gems/DiffuseProbeGrid/Code/Source/Render/DiffuseProbeGridVisualizationRayTracingPass.cpp

@@ -296,9 +296,9 @@ namespace AZ
                 }
 
                 const RHI::SingleDeviceShaderResourceGroup* shaderResourceGroups[] = {
-                    diffuseProbeGrid->GetVisualizationRayTraceSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex).get(),
-                    rayTracingFeatureProcessor->GetRayTracingSceneSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex).get(),
-                    views[0]->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(RHI::MultiDevice::DefaultDeviceIndex).get(),
+                    diffuseProbeGrid->GetVisualizationRayTraceSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get(),
+                    rayTracingFeatureProcessor->GetRayTracingSceneSrg()->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get(),
+                    views[0]->GetRHIShaderResourceGroup()->GetDeviceShaderResourceGroup(context.GetDeviceIndex()).get(),
                 };
 
                 RHI::SingleDeviceDispatchRaysItem dispatchRaysItem;
@@ -306,11 +306,11 @@ namespace AZ
                 dispatchRaysItem.m_arguments.m_direct.m_height = m_outputAttachmentSize.m_height;
                 dispatchRaysItem.m_arguments.m_direct.m_depth = 1;
                 dispatchRaysItem.m_rayTracingPipelineState =
-                    m_rayTracingPipelineState->GetDeviceRayTracingPipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
-                dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable->GetDeviceRayTracingShaderTable(RHI::MultiDevice::DefaultDeviceIndex).get();
+                    m_rayTracingPipelineState->GetDeviceRayTracingPipelineState(context.GetDeviceIndex()).get();
+                dispatchRaysItem.m_rayTracingShaderTable = m_rayTracingShaderTable->GetDeviceRayTracingShaderTable(context.GetDeviceIndex()).get();
                 dispatchRaysItem.m_shaderResourceGroupCount = RHI::ArraySize(shaderResourceGroups);
                 dispatchRaysItem.m_shaderResourceGroups = shaderResourceGroups;
-                dispatchRaysItem.m_globalPipelineState = m_globalPipelineState->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+                dispatchRaysItem.m_globalPipelineState = m_globalPipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
 
                 // submit the DispatchRays item
                 context.GetCommandList()->Submit(dispatchRaysItem, index);

+ 1 - 1
Gems/Meshlets/Code/Source/Meshlets/MeshletsRenderPass.cpp

@@ -180,7 +180,7 @@ namespace AZ
             }
 
             drawRequest.m_listTag = m_drawListTag;
-            drawRequest.m_pipelineState = m_pipelineState->GetDevicePipelineState(RHI::MultiDevice::DefaultDeviceIndex).get();
+            drawRequest.m_pipelineState = m_pipelineState->GetDevicePipelineState(context.GetDeviceIndex()).get();
 
             return true;
         }