Kaynağa Gözat

add payload size to Amplification Shader metadata to mirror MS metadata (#2359)

amarpMSFT 6 yıl önce
ebeveyn
işleme
37acf90723

+ 2 - 0
include/dxc/DXIL/DxilFunctionProps.h

@@ -79,6 +79,8 @@ struct DxilFunctionProps {
     // Amplification shader.
     struct {
       unsigned numThreads[3];
+      // The following doesn't go into metadata
+      unsigned payloadByteSize;
     } AS;
   } ShaderProps;
   DXIL::ShaderKind shaderKind;

+ 4 - 0
include/dxc/DxilContainer/DxilPipelineStateValidation.h

@@ -56,6 +56,9 @@ struct MSInfo {
   uint16_t MaxOutputVertices;
   uint16_t MaxOutputPrimitives;
 };
+struct ASInfo {
+  uint32_t PayloadSizeInBytes;
+};
 
 // Versioning is additive and based on size
 struct PSVRuntimeInfo0
@@ -67,6 +70,7 @@ struct PSVRuntimeInfo0
     GSInfo GS;
     PSInfo PS;
     MSInfo MS;
+    ASInfo AS;
   };
   uint32_t MinimumExpectedWaveLaneCount;  // minimum lane count required, 0 if unused
   uint32_t MaximumExpectedWaveLaneCount;  // maximum lane count required, 0xffffffff if unused

+ 31 - 10
lib/DXIL/DxilModule.cpp

@@ -683,20 +683,41 @@ void DxilModule::SetMeshOutputTopology(DXIL::MeshOutputTopology MeshOutputTopolo
 }
 
 unsigned DxilModule::GetPayloadByteSize() const {
-  if (!m_pSM->IsMS())
+  if (m_pSM->IsMS())
+  {
+    DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+    DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+    DXASSERT(props.IsMS(), "Must be MS profile");
+    return props.ShaderProps.MS.payloadByteSize;      
+  }
+  else if(m_pSM->IsAS())
+  {
+    DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
+    DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+    DXASSERT(props.IsAS(), "Must be AS profile");
+    return props.ShaderProps.AS.payloadByteSize;
+  }
+  else
+  {
     return 0;
-  DXASSERT(m_DxilEntryPropsMap.size() == 1, "should have one entry prop");
-  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
-  DXASSERT(props.IsMS(), "Must be MS profile");
-  return props.ShaderProps.MS.payloadByteSize;
+  }
 }
 
 void DxilModule::SetPayloadByteSize(unsigned Size) {
-  DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsMS(),
-           "only works for MS profile");
-  DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
-  DXASSERT(props.IsMS(), "Must be MS profile");
-  props.ShaderProps.MS.payloadByteSize = Size;
+  DXASSERT(m_DxilEntryPropsMap.size() == 1 && (m_pSM->IsMS() || m_pSM->IsAS()),
+           "only works for MS or AS profile");
+  if (m_pSM->IsMS())
+  {
+    DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+    DXASSERT(props.IsMS(), "Must be MS profile");
+    props.ShaderProps.MS.payloadByteSize = Size;
+  } 
+  else if (m_pSM->IsAS())
+  {
+    DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
+    DXASSERT(props.IsAS(), "Must be AS profile");
+    props.ShaderProps.AS.payloadByteSize = Size;
+  }
 }
 
 void DxilModule::SetAutoBindingSpace(uint32_t Space) {

+ 38 - 15
lib/DxilContainer/DxilContainerAssembler.cpp

@@ -25,6 +25,7 @@
 #include "dxc/DXIL/DxilUtil.h"
 #include "dxc/DXIL/DxilFunctionProps.h"
 #include "dxc/DXIL/DxilOperations.h"
+#include "dxc/DXIL/DxilInstructions.h"
 #include "dxc/Support/Global.h"
 #include "dxc/Support/Unicode.h"
 #include "dxc/Support/WinIncludes.h"
@@ -601,13 +602,12 @@ public:
         }
         break;
       }
-    case ShaderModel::Kind::Amplification:
     case ShaderModel::Kind::Compute:
     case ShaderModel::Kind::Library:
     case ShaderModel::Kind::Invalid:
-      // Amplification, Compute, Library, and Invalide not relevant to PSVRuntimeInfo0
+      // Compute, Library, and Invalid not relevant to PSVRuntimeInfo0
       break;
-    case ShaderModel::Kind::Mesh:
+    case ShaderModel::Kind::Mesh: {
       pInfo->MS.MaxOutputVertices = (UINT)m_Module.GetMaxOutputVertices();
       pInfo->MS.MaxOutputPrimitives = (UINT)m_Module.GetMaxOutputPrimitives();
       pInfo1->MeshOutputTopology = (UINT)m_Module.GetMeshOutputTopology();
@@ -634,18 +634,11 @@ public:
           // Calls to external functions.
           const CallInst *CI = dyn_cast<CallInst>(&I);
           if (CI) {
-            Function *FCalled = CI->getCalledFunction();
-            if (FCalled->isDeclaration()) {
-              Value *opcodeVal = CI->getOperand(0);
-              ConstantInt *OpcodeConst = dyn_cast<ConstantInt>(opcodeVal);
-              unsigned opcode = OpcodeConst->getLimitedValue();
-              DXIL::OpCode dxilOpcode = (DXIL::OpCode)opcode;
-              if (dxilOpcode == DXIL::OpCode::GetMeshPayload) {
-                PointerType *payloadPTy = cast<PointerType>(CI->getType());
-                Type *payloadTy = payloadPTy->getPointerElementType();
-                payloadByteSize = DL.getTypeAllocSize(payloadTy);
-                break;
-              }
+            if (hlsl::OP::IsDxilOpFuncCallInst(CI,DXIL::OpCode::GetMeshPayload)) {
+              PointerType *payloadPTy = cast<PointerType>(CI->getType());
+              Type *payloadTy = payloadPTy->getPointerElementType();
+              payloadByteSize = DL.getTypeAllocSize(payloadTy);
+              break;
             }
           }
         }
@@ -655,6 +648,36 @@ public:
       pInfo->MS.PayloadSizeInBytes = payloadByteSize;
       break;
     }
+    case ShaderModel::Kind::Amplification: {
+      const Function *entryFunc = m_Module.GetEntryFunction();
+      unsigned payloadByteSize = 0;
+      Module *mod = m_Module.GetModule();
+      const DataLayout &DL = mod->getDataLayout();
+      for (auto b = entryFunc->begin(), bend = entryFunc->end(); b != bend;
+           ++b) {
+        auto i = b->begin(), iend = b->end();
+        for (; i != iend; ++i) {
+          const Instruction &I = *i;
+
+          // Calls to external functions.
+          const CallInst *CI = dyn_cast<CallInst>(&I);
+          if (CI) {
+            if (hlsl::OP::IsDxilOpFuncCallInst(CI,DXIL::OpCode::DispatchMesh)) {
+              DxilInst_DispatchMesh dispatchMeshCall(const_cast<CallInst*>(CI));
+              Value *operandVal = dispatchMeshCall.get_payload();
+              Type *payloadTy = operandVal->getType();
+              payloadByteSize = DL.getTypeAllocSize(payloadTy);
+              break;
+            }
+          }
+        }
+        if (i != iend)
+          break;
+      }
+      pInfo->AS.PayloadSizeInBytes = payloadByteSize;
+      break;
+    }
+    }
 
     // Set resource binding information
     UINT uResIndex = 0;

+ 18 - 0
lib/HLSL/DxilValidation.cpp

@@ -2861,6 +2861,24 @@ static void ValidateAsIntrinsics(Function *F, ValidationContext &ValCtx, CallIns
     DXIL::ShaderKind shaderKind = ValCtx.DxilMod.GetDxilFunctionProps(F).shaderKind;
     if (shaderKind != DXIL::ShaderKind::Amplification)
       return;
+
+    if (dispatchMesh) {
+      DxilInst_DispatchMesh dispatchMeshCall(dispatchMesh);
+      Value *operandVal = dispatchMeshCall.get_payload();
+      Type *payloadTy = operandVal->getType();
+      const DataLayout &DL = F->getParent()->getDataLayout();
+      unsigned payloadSize = DL.getTypeAllocSize(payloadTy);
+
+      if (payloadSize > DXIL::kMaxMSASPayloadSize) {
+        ValCtx.EmitFormatError(
+            ValidationRule::SmAmplificationShaderPayloadSize,
+            {F->getName(), std::to_string(DXIL::kMaxMSASPayloadSize)});
+      }
+
+      DxilFunctionProps &prop = ValCtx.DxilMod.GetDxilFunctionProps(F);
+      prop.ShaderProps.AS.payloadByteSize = payloadSize;
+    }
+
   }
   else {
     return;