Browse Source

Add mesh shader support to RootSignature parsing/validation and fix PSV (#2363)

* Add mesh shader support to RootSignature parsing/validation
* Fix PSV so MS output topology doesn't overlap SigInputVectors
* Fix PSV version code to use the latest version when validation is disabled
Tex Riddell 6 years ago
parent
commit
f6ff322db3

+ 27 - 21
include/dxc/DxilContainer/DxilPipelineStateValidation.h

@@ -59,6 +59,10 @@ struct MSInfo {
 struct ASInfo {
   uint32_t PayloadSizeInBytes;
 };
+struct MSInfo1 {
+  uint8_t SigPrimVectors;     // Primitive output for MS
+  uint8_t MeshOutputTopology;
+};
 
 // Versioning is additive and based on size
 struct PSVRuntimeInfo0
@@ -102,7 +106,8 @@ struct PSVRuntimeInfo1 : public PSVRuntimeInfo0
   uint8_t UsesViewID;
   union {
     uint16_t MaxVertexCount;          // MaxVertexCount for GS only (max 1024)
-    uint8_t SigPatchConstOrPrimVectors;  // Output for HS; Input for DS; Primitive output for MS
+    uint8_t SigPatchConstOrPrimVectors;  // Output for HS; Input for DS; Primitive output for MS (overlaps MS1::SigPrimVectors)
+    struct MSInfo1 MS1;
   };
 
   // PSVSignatureElement counts
@@ -111,10 +116,7 @@ struct PSVRuntimeInfo1 : public PSVRuntimeInfo0
   uint8_t SigPatchConstOrPrimElements;
 
   // Number of packed vectors per signature
-  union {
-    uint8_t SigInputVectors;
-    uint8_t MeshOutputTopology;
-  };
+  uint8_t SigInputVectors;
   uint8_t SigOutputVectors[4];      // Array for GS Stream Out Index
 };
 
@@ -326,6 +328,8 @@ public:
   uint32_t GetDynamicIndexMask() const { return !m_pElement0 ? 0 : (uint32_t)m_pElement0->DynamicMaskAndStream & 0xF; }
 };
 
+#define MAX_PSV_VERSION 1
+
 struct PSVInitInfo
 {
   PSVInitInfo(uint32_t psvVersion)
@@ -564,7 +568,7 @@ public:
 
   bool InitNew(const PSVInitInfo &initInfo, void *pBuffer, uint32_t *pSize) {
     if(!(pSize)) return false;
-    if (initInfo.PSVVersion > 1) return false;
+    if (initInfo.PSVVersion > MAX_PSV_VERSION) return false;
 
     // Versioned structure sizes
     m_uPSVRuntimeInfoSize = sizeof(PSVRuntimeInfo0);
@@ -600,9 +604,9 @@ public:
         if (initInfo.ShaderStage == PSVShaderKind::Hull || initInfo.ShaderStage == PSVShaderKind::Mesh)
           size += sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(initInfo.SigPatchConstOrPrimVectors);
       }
-      if (initInfo.ShaderStage != PSVShaderKind::Mesh && initInfo.ShaderStage != PSVShaderKind::Amplification) {
+      if (initInfo.SigInputVectors > 0) {
         for (unsigned i = 0; i < 4; i++) {
-          if (initInfo.SigOutputVectors[i] > 0 && initInfo.SigInputVectors > 0) {
+          if (initInfo.SigOutputVectors[i] > 0) {
             size += PSVComputeInputOutputTableSize(initInfo.SigInputVectors, initInfo.SigOutputVectors[i]);
             if (initInfo.ShaderStage != PSVShaderKind::Geometry)
               break;
@@ -611,9 +615,9 @@ public:
         if (initInfo.ShaderStage == PSVShaderKind::Hull && initInfo.SigPatchConstOrPrimVectors > 0 && initInfo.SigInputVectors > 0) {
           size += PSVComputeInputOutputTableSize(initInfo.SigInputVectors, initInfo.SigPatchConstOrPrimVectors);
         }
-        if (initInfo.ShaderStage == PSVShaderKind::Domain && initInfo.SigOutputVectors[0] > 0 && initInfo.SigPatchConstOrPrimVectors > 0) {
-          size += PSVComputeInputOutputTableSize(initInfo.SigPatchConstOrPrimVectors, initInfo.SigOutputVectors[0]);
-        }
+      }
+      if (initInfo.ShaderStage == PSVShaderKind::Domain && initInfo.SigOutputVectors[0] > 0 && initInfo.SigPatchConstOrPrimVectors > 0) {
+        size += PSVComputeInputOutputTableSize(initInfo.SigPatchConstOrPrimVectors, initInfo.SigOutputVectors[0]);
       }
     }
 
@@ -710,17 +714,19 @@ public:
       }
 
       // Input to Output dependencies
-      for (unsigned i = 0; i < 4; i++) {
-        if (m_pPSVRuntimeInfo1->SigOutputVectors[i] > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
-          m_pInputToOutputTable = (uint32_t*)pCurBits;
-          pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigOutputVectors[i]);
+      if (m_pPSVRuntimeInfo1->SigInputVectors > 0) {
+        for (unsigned i = 0; i < 4; i++) {
+          if (m_pPSVRuntimeInfo1->SigOutputVectors[i] > 0) {
+            m_pInputToOutputTable = (uint32_t*)pCurBits;
+            pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigOutputVectors[i]);
+          }
+          if (!IsGS())
+            break;
+        }
+        if (IsHS() && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
+          m_pInputToPCOutputTable = (uint32_t*)pCurBits;
+          pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
         }
-        if (!IsGS())
-          break;
-      }
-      if ((IsHS() || IsMS()) && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors > 0 && m_pPSVRuntimeInfo1->SigInputVectors > 0) {
-        m_pInputToPCOutputTable = (uint32_t*)pCurBits;
-        pCurBits += PSVComputeInputOutputTableSize(m_pPSVRuntimeInfo1->SigInputVectors, m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors);
       }
       if (IsDS() && m_pPSVRuntimeInfo1->SigOutputVectors[0] > 0 && m_pPSVRuntimeInfo1->SigPatchConstOrPrimVectors > 0) {
         m_pPCInputToOutputTable = (uint32_t*)pCurBits;

+ 11 - 4
include/dxc/DxilRootSignature/DxilRootSignature.h

@@ -77,7 +77,8 @@ enum class DxilDescriptorRangeType : unsigned {
   SRV = 0,
   UAV = 1,
   CBV = 2,
-  Sampler = 3
+  Sampler = 3,
+  MaxValue = 3
 };
 enum class DxilRootDescriptorFlags : unsigned {
   None = 0,
@@ -106,15 +107,18 @@ enum class DxilRootSignatureFlags : uint32_t {
   DenyPixelShaderRootAccess = 0x20,
   AllowStreamOutput = 0x40,
   LocalRootSignature = 0x80,
+  DenyAmplificationShaderRootAccess = 0x100,
+  DenyMeshShaderRootAccess = 0x200,
   AllowLowTierReservedHwCbLimit = 0x80000000,
-  ValidFlags = 0x800000ff
+  ValidFlags = 0x800003ff
 };
 enum class DxilRootParameterType {
   DescriptorTable = 0,
   Constants32Bit = 1,
   CBV = 2,
   SRV = 3,
-  UAV = 4
+  UAV = 4,
+  MaxValue = 4
 };
 enum class DxilFilter {
   // TODO: make these names consistent with code convention
@@ -161,7 +165,10 @@ enum class DxilShaderVisibility {
   Hull = 2,
   Domain = 3,
   Geometry = 4,
-  Pixel = 5
+  Pixel = 5,
+  Amplification = 6,
+  Mesh = 7,
+  MaxValue = 7
 };
 enum class DxilStaticBorderColor {
   TransparentBlack = 0,

+ 6 - 2
lib/DxilContainer/DxilContainerAssembler.cpp

@@ -452,8 +452,12 @@ public:
     unsigned ValMajor, ValMinor;
     m_Module.GetValidatorVersion(ValMajor, ValMinor);
     // Allow PSVVersion to be upgraded
-    if (m_PSVInitInfo.PSVVersion < 1 && (ValMajor > 1 || (ValMajor == 1 && ValMinor >= 1)))
+    if (ValMajor == 0 && ValMinor == 0) {
+      // Validation disabled upgrades to maximum PSVVersion
+      m_PSVInitInfo.PSVVersion = MAX_PSV_VERSION;
+    } else if (m_PSVInitInfo.PSVVersion < 1 && (ValMajor > 1 || (ValMajor == 1 && ValMinor >= 1))) {
       m_PSVInitInfo.PSVVersion = 1;
+    }
 
     const ShaderModel *SM = m_Module.GetShaderModel();
     UINT uCBuffers = m_Module.GetCBuffers().size();
@@ -610,7 +614,7 @@ public:
     case ShaderModel::Kind::Mesh: {
       pInfo->MS.MaxOutputVertices = (UINT)m_Module.GetMaxOutputVertices();
       pInfo->MS.MaxOutputPrimitives = (UINT)m_Module.GetMaxOutputPrimitives();
-      pInfo1->MeshOutputTopology = (UINT)m_Module.GetMeshOutputTopology();
+      pInfo1->MS1.MeshOutputTopology = (UINT)m_Module.GetMeshOutputTopology();
       Module *mod = m_Module.GetModule();
       const DataLayout &DL = mod->getDataLayout();
       unsigned totalByteSize = 0;

+ 36 - 13
lib/DxilRootSignature/DxilRootSignatureValidator.cpp

@@ -98,9 +98,9 @@ public:
 
 private:
   static const unsigned kMinVisType = (unsigned)DxilShaderVisibility::All;
-  static const unsigned kMaxVisType = (unsigned)DxilShaderVisibility::Pixel;
+  static const unsigned kMaxVisType = (unsigned)DxilShaderVisibility::MaxValue;
   static const unsigned kMinDescType = (unsigned)DxilDescriptorRangeType::SRV;
-  static const unsigned kMaxDescType = (unsigned)DxilDescriptorRangeType::Sampler;
+  static const unsigned kMaxDescType = (unsigned)DxilDescriptorRangeType::MaxValue;
 
   struct RegisterRange {
     NODE_TYPE nt;
@@ -171,6 +171,8 @@ void DescriptorTableVerifier::Verify(const DxilDescriptorRange1 *pRanges,
       bHasSamplers = true;
       break;
     default:
+      static_assert(DxilDescriptorRangeType::Sampler == DxilDescriptorRangeType::MaxValue,
+                    "otherwise, need to update cases here");
       EAT(DiagPrinter << "Unsupported RangeType value " << (uint32_t)pRange->RangeType
                       << " (descriptor table slot [" << iDTS << "], root parameter [" << iRP << "]).\n");
     }
@@ -240,19 +242,24 @@ void RootSignatureVerifier::AllowReservedRegisterSpace(bool bAllow) {
 const char* RangeTypeString(DxilDescriptorRangeType rt)
 {
   static const char *RangeType[] = {"SRV", "UAV", "CBV", "SAMPLER"};
-  return (rt <= DxilDescriptorRangeType::Sampler) ? RangeType[(unsigned)rt]
-                                                  : "unknown";
+  static_assert(_countof(RangeType) == ((unsigned)DxilDescriptorRangeType::MaxValue + 1),
+                "otherwise, need to update name array");
+  return (rt <= DxilDescriptorRangeType::MaxValue) ? RangeType[(unsigned)rt]
+                                                   : "unknown";
 }
 
 const char *VisTypeString(DxilShaderVisibility vis) {
   static const char *Vis[] = {"ALL",    "VERTEX",   "HULL",
-                              "DOMAIN", "GEOMETRY", "PIXEL"};
+                              "DOMAIN", "GEOMETRY", "PIXEL",
+                              "AMPLIFICATION", "MESH"};
+  static_assert(_countof(Vis) == ((unsigned)DxilShaderVisibility::MaxValue + 1),
+                "otherwise, need to update name array");
   unsigned idx = (unsigned)vis;
-  return vis <= DxilShaderVisibility::Pixel ? Vis[idx] : "unknown";
+  return vis <= DxilShaderVisibility::MaxValue ? Vis[idx] : "unknown";
 }
 
 static bool IsDxilShaderVisibility(DxilShaderVisibility v) {
-  return v <= DxilShaderVisibility::Pixel;
+  return v <= DxilShaderVisibility::MaxValue;
 }
 
 void RootSignatureVerifier::AddRegisterRange(unsigned iRP,
@@ -390,6 +397,8 @@ static DxilDescriptorRangeType GetRangeType(DxilRootParameterType RPT) {
   case DxilRootParameterType::SRV: return DxilDescriptorRangeType::SRV;
   case DxilRootParameterType::UAV: return DxilDescriptorRangeType::UAV;
   default:
+    static_assert(DxilRootParameterType::UAV == DxilRootParameterType::MaxValue,
+                  "otherwise, need to add cases here.");
     break;
   }
 
@@ -535,6 +544,8 @@ void RootSignatureVerifier::VerifyRootSignature(
     }
 
     default:
+      static_assert(DxilRootParameterType::UAV == DxilRootParameterType::MaxValue,
+                    "otherwise, need to add cases here.");
       EAT(DiagPrinter << "Unsupported ParameterType value " << (uint32_t)ParameterType
                       << " (root parameter " << iRP << ")\n");
     }
@@ -591,6 +602,16 @@ void RootSignatureVerifier::VerifyShader(DxilShaderVisibility VisType,
       bShaderDeniedByRootSig = true;
     }
     break;
+  case DxilShaderVisibility::Amplification:
+    if ((m_RootSignatureFlags & DxilRootSignatureFlags::DenyAmplificationShaderRootAccess) != DxilRootSignatureFlags::None) {
+      bShaderDeniedByRootSig = true;
+    }
+    break;
+  case DxilShaderVisibility::Mesh:
+    if ((m_RootSignatureFlags & DxilRootSignatureFlags::DenyMeshShaderRootAccess) != DxilRootSignatureFlags::None) {
+      bShaderDeniedByRootSig = true;
+    }
+    break;
   default:
     break;
   }
@@ -800,12 +821,14 @@ void StaticSamplerVerifier::Verify(const DxilStaticSamplerDesc* pDesc,
 
 static DxilShaderVisibility GetVisibilityType(DXIL::ShaderKind ShaderKind) {
   switch(ShaderKind) {
-  case DXIL::ShaderKind::Pixel:       return DxilShaderVisibility::Pixel;
-  case DXIL::ShaderKind::Vertex:      return DxilShaderVisibility::Vertex;
-  case DXIL::ShaderKind::Geometry:    return DxilShaderVisibility::Geometry;
-  case DXIL::ShaderKind::Hull:        return DxilShaderVisibility::Hull;
-  case DXIL::ShaderKind::Domain:      return DxilShaderVisibility::Domain;
-  default:                            return DxilShaderVisibility::All;
+  case DXIL::ShaderKind::Pixel:         return DxilShaderVisibility::Pixel;
+  case DXIL::ShaderKind::Vertex:        return DxilShaderVisibility::Vertex;
+  case DXIL::ShaderKind::Geometry:      return DxilShaderVisibility::Geometry;
+  case DXIL::ShaderKind::Hull:          return DxilShaderVisibility::Hull;
+  case DXIL::ShaderKind::Domain:        return DxilShaderVisibility::Domain;
+  case DXIL::ShaderKind::Amplification: return DxilShaderVisibility::Amplification;
+  case DXIL::ShaderKind::Mesh:          return DxilShaderVisibility::Mesh;
+  default:                              return DxilShaderVisibility::All;
   }
 }
 

+ 13 - 0
tools/clang/lib/Parse/HLSLRootSignature.cpp

@@ -287,6 +287,8 @@ void RootSignatureTokenizer::ReadNextToken(uint32_t BufferIdx)
               KW(DENY_DOMAIN_SHADER_ROOT_ACCESS) || 
               KW(DENY_GEOMETRY_SHADER_ROOT_ACCESS) || 
               KW(DENY_PIXEL_SHADER_ROOT_ACCESS) ||
+              KW(DENY_AMPLIFICATION_SHADER_ROOT_ACCESS) ||
+              KW(DENY_MESH_SHADER_ROOT_ACCESS) ||
               KW(DESCRIPTORS_VOLATILE) ||
               KW(DATA_VOLATILE) ||
               KW(DATA_STATIC) ||
@@ -360,6 +362,7 @@ void RootSignatureTokenizer::ReadNextToken(uint32_t BufferIdx)
               KW(SHADER_VISIBILITY_ALL)      ||  KW(SHADER_VISIBILITY_VERTEX) || 
               KW(SHADER_VISIBILITY_HULL)     || KW(SHADER_VISIBILITY_DOMAIN)  ||
               KW(SHADER_VISIBILITY_GEOMETRY) || KW(SHADER_VISIBILITY_PIXEL) ||
+              KW(SHADER_VISIBILITY_AMPLIFICATION) || KW(SHADER_VISIBILITY_MESH) ||
               KW(STATIC_BORDER_COLOR_TRANSPARENT_BLACK) ||
               KW(STATIC_BORDER_COLOR_OPAQUE_BLACK) ||
               KW(STATIC_BORDER_COLOR_OPAQUE_WHITE);
@@ -678,6 +681,8 @@ HRESULT RootSignatureParser::ParseRootSignatureFlags(DxilRootSignatureFlags & Fl
     //  DENY_DOMAIN_SHADER_ROOT_ACCESS
     //  DENY_GEOMETRY_SHADER_ROOT_ACCESS
     //  DENY_PIXEL_SHADER_ROOT_ACCESS
+    //  DENY_AMPLIFICATION_SHADER_ROOT_ACCESS
+    //  DENY_MESH_SHADER_ROOT_ACCESS
     //  ALLOW_STREAM_OUTPUT
     //  LOCAL_ROOT_SIGNATURE
 
@@ -724,6 +729,12 @@ HRESULT RootSignatureParser::ParseRootSignatureFlags(DxilRootSignatureFlags & Fl
             case TokenType::DENY_PIXEL_SHADER_ROOT_ACCESS:
                 Flags |= DxilRootSignatureFlags::DenyPixelShaderRootAccess;
                 break;
+            case TokenType::DENY_AMPLIFICATION_SHADER_ROOT_ACCESS:
+                Flags |= DxilRootSignatureFlags::DenyAmplificationShaderRootAccess;
+                break;
+            case TokenType::DENY_MESH_SHADER_ROOT_ACCESS:
+                Flags |= DxilRootSignatureFlags::DenyMeshShaderRootAccess;
+                break;
             case TokenType::ALLOW_STREAM_OUTPUT:
                 Flags |= DxilRootSignatureFlags::AllowStreamOutput;
                 break;
@@ -1290,6 +1301,8 @@ HRESULT RootSignatureParser::ParseVisibility(DxilShaderVisibility & Vis)
     case TokenType::SHADER_VISIBILITY_DOMAIN:   Vis = DxilShaderVisibility::Domain;   break;
     case TokenType::SHADER_VISIBILITY_GEOMETRY: Vis = DxilShaderVisibility::Geometry; break;
     case TokenType::SHADER_VISIBILITY_PIXEL:    Vis = DxilShaderVisibility::Pixel;    break;
+    case TokenType::SHADER_VISIBILITY_AMPLIFICATION:  Vis = DxilShaderVisibility::Amplification;  break;
+    case TokenType::SHADER_VISIBILITY_MESH:     Vis = DxilShaderVisibility::Mesh;     break;
     default:
         IFC(Error(ERR_RS_UNEXPECTED_TOKEN, 
                  "Unexpected visibility value: '%s'.", Token.GetStr()));

+ 4 - 0
tools/clang/lib/Parse/HLSLRootSignature.h

@@ -69,6 +69,8 @@ public:
             SHADER_VISIBILITY_DOMAIN,
             SHADER_VISIBILITY_GEOMETRY,
             SHADER_VISIBILITY_PIXEL,
+            SHADER_VISIBILITY_AMPLIFICATION,
+            SHADER_VISIBILITY_MESH,
 
             // Root signature flags
             RootFlags,
@@ -78,6 +80,8 @@ public:
             DENY_DOMAIN_SHADER_ROOT_ACCESS,
             DENY_GEOMETRY_SHADER_ROOT_ACCESS,
             DENY_PIXEL_SHADER_ROOT_ACCESS,
+            DENY_AMPLIFICATION_SHADER_ROOT_ACCESS,
+            DENY_MESH_SHADER_ROOT_ACCESS,
             ALLOW_STREAM_OUTPUT,
             LOCAL_ROOT_SIGNATURE,
 

+ 78 - 0
tools/clang/test/CodeGenHLSL/mesh/mesh-rootsig.hlsl

@@ -0,0 +1,78 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: dx.op.getMeshPayload.struct.MeshPayload
+// CHECK: dx.op.setMeshOutputCounts(i32 168, i32 30, i32 10)
+// CHECK: dx.op.emitIndices
+// CHECK: dx.op.storeVertexOutput
+// CHECK: dx.op.storePrimitiveOutput
+
+#define MAX_VERT 30
+#define MAX_PRIM 10
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float4 color : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float malnor : MALNOR;
+    float alnorm : ALNORM;
+    float ormaln : ORMALN;
+    int4 layer0 : LAYER0;
+    int2 layer1 : LAYER1;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    float alnorm;
+    float ormaln;
+    int layer[6];
+};
+
+groupshared float gsMem[MAX_PRIM];
+
+cbuffer CB1 : register(b1, space2)
+{
+  uint idx;
+}
+
+[RootSignature("CBV(b1, space=2, visibility=SHADER_VISIBILITY_MESH)")]
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    if (tig < MAX_PRIM) {
+      uint3 indices = (tig * 3) + uint3(0, 1, 2);
+      primIndices[tig] = indices;
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      gsMem[tig] = op.normal;
+      op.malnor = gsMem[(tig + idx) % MAX_PRIM];
+      op.alnorm = mpl.alnorm;
+      op.ormaln = mpl.ormaln;
+      op.layer0 = int4(mpl.layer[0], mpl.layer[1], mpl.layer[2], mpl.layer[3]);
+      op.layer1 = int2(mpl.layer[4], mpl.layer[5]);
+      prims[tig] = op;
+    }
+    if (tig < MAX_VERT) {
+      MeshPerVertex ov;
+      if (vid % 2) {
+          ov.position = float4(0.0, 1.0, 2.0, 3.0);
+          ov.color = float4(4.0, 5.0, 6.0, 7.0);
+      } else {
+          ov.position = float4(10.0, 11.0, 12.0, 13.0);
+          ov.color = float4(14.0, 15.0, 16.0, 17.0);
+      }
+      verts[tig] = ov;
+    }
+}