2
0
Эх сурвалжийг харах

add DXIL tests to verify mesh shader's output size and payload plus output size

czw831024 6 жил өмнө
parent
commit
5fca2b49e1

+ 1 - 1
docs/DXIL.rst

@@ -3142,7 +3142,7 @@ SM.MAXMSSMSIZE                            Total Thread Group Shared Memory stora
 SM.MAXTGSMSIZE                            Total Thread Group Shared Memory storage is %0, exceeded %1
 SM.MAXTHEADGROUP                          Declared Thread Group Count %0 (X*Y*Z) is beyond the valid maximum of %1
 SM.MESHPSIGROWCOUNT                       For shader '%0', primitive output signatures are taking up more than %1 rows
-SM.MESHSHADERINOUTSIZE                    For shader '%0', input plus output size is greater than %1
+SM.MESHSHADERINOUTSIZE                    For shader '%0', payload plus output size is greater than %1
 SM.MESHSHADERMAXPRIMITIVECOUNT            MS max primitive output count must be [0..%0].  %1 specified
 SM.MESHSHADERMAXVERTEXCOUNT               MS max vertex output count must be [0..%0].  %1 specified
 SM.MESHSHADEROUTPUTSIZE                   For shader '%0', vertex plus primitive output size is greater than %1

+ 3 - 3
include/dxc/DXIL/DxilConstants.h

@@ -86,11 +86,11 @@ namespace DXIL {
   const unsigned kMinMSASThreadGroupX = 1;
   const unsigned kMinMSASThreadGroupY = 1;
   const unsigned kMinMSASThreadGroupZ = 1;
-  const unsigned kMaxMSASPayloadSize = 16384;
+  const unsigned kMaxMSASPayloadBytes = 1024 * 16;
   const unsigned kMaxMSOutputPrimitiveCount = 256;
   const unsigned kMaxMSOutputVertexCount = 256;
-  const unsigned kMaxMSOutputTotalScalars = 32768;
-  const unsigned kMaxMSInputOutputTotalScalars = 41984;
+  const unsigned kMaxMSOutputTotalBytes = 1024 * 32;
+  const unsigned kMaxMSInputOutputTotalBytes = 1024 * 47;
   const unsigned kMaxMSVSigRows = 32;
   const unsigned kMaxMSPSigRows = 32;
   const unsigned kMaxMSTotalSigRows = 32;

+ 1 - 1
include/dxc/HLSL/DxilValidation.h

@@ -227,7 +227,7 @@ enum class ValidationRule : unsigned {
   SmMaxTGSMSize, // Total Thread Group Shared Memory storage is %0, exceeded %1
   SmMaxTheadGroup, // Declared Thread Group Count %0 (X*Y*Z) is beyond the valid maximum of %1
   SmMeshPSigRowCount, // For shader '%0', primitive output signatures are taking up more than %1 rows
-  SmMeshShaderInOutSize, // For shader '%0', input plus output size is greater than %1
+  SmMeshShaderInOutSize, // For shader '%0', payload plus output size is greater than %1
   SmMeshShaderMaxPrimitiveCount, // MS max primitive output count must be [0..%0].  %1 specified
   SmMeshShaderMaxVertexCount, // MS max vertex output count must be [0..%0].  %1 specified
   SmMeshShaderOutputSize, // For shader '%0', vertex plus primitive output size is greater than %1

+ 13 - 13
lib/HLSL/DxilValidation.cpp

@@ -266,7 +266,7 @@ const char *hlsl::GetValidationRuleText(ValidationRule value) {
     case hlsl::ValidationRule::SmMeshShaderPayloadSize: return "For shader '%0', payload size is greater than %1";
     case hlsl::ValidationRule::SmMeshShaderPayloadSizeDeclared: return "For shader '%0', payload size %1 is greater than declared size of %2 bytes";
     case hlsl::ValidationRule::SmMeshShaderOutputSize: return "For shader '%0', vertex plus primitive output size is greater than %1";
-    case hlsl::ValidationRule::SmMeshShaderInOutSize: return "For shader '%0', input plus output size is greater than %1";
+    case hlsl::ValidationRule::SmMeshShaderInOutSize: return "For shader '%0', payload plus output size is greater than %1";
     case hlsl::ValidationRule::SmMeshVSigRowCount: return "For shader '%0', vertex output signatures are taking up more than %1 rows";
     case hlsl::ValidationRule::SmMeshPSigRowCount: return "For shader '%0', primitive output signatures are taking up more than %1 rows";
     case hlsl::ValidationRule::SmMeshTotalSigRowCount: return "For shader '%0', vertex and primitive output signatures are taking up more than %1 rows";
@@ -2890,10 +2890,10 @@ static void ValidateMsIntrinsics(Function *F,
 
     DxilFunctionProps &prop = ValCtx.DxilMod.GetDxilFunctionProps(F);
 
-    if (payloadSize > DXIL::kMaxMSASPayloadSize ||
-        prop.ShaderProps.MS.payloadSizeInBytes > DXIL::kMaxMSASPayloadSize) {
+    if (payloadSize > DXIL::kMaxMSASPayloadBytes ||
+        prop.ShaderProps.MS.payloadSizeInBytes > DXIL::kMaxMSASPayloadBytes) {
       ValCtx.EmitFormatError(ValidationRule::SmMeshShaderPayloadSize,
-        { F->getName(), std::to_string(DXIL::kMaxMSASPayloadSize) });
+        { F->getName(), std::to_string(DXIL::kMaxMSASPayloadBytes) });
     }
 
     if (prop.ShaderProps.MS.payloadSizeInBytes < payloadSize) {
@@ -2919,11 +2919,11 @@ static void ValidateAsIntrinsics(Function *F, ValidationContext &ValCtx, CallIns
 
       DxilFunctionProps &prop = ValCtx.DxilMod.GetDxilFunctionProps(F);
 
-      if (payloadSize > DXIL::kMaxMSASPayloadSize ||
-          prop.ShaderProps.AS.payloadSizeInBytes > DXIL::kMaxMSASPayloadSize) {
+      if (payloadSize > DXIL::kMaxMSASPayloadBytes ||
+          prop.ShaderProps.AS.payloadSizeInBytes > DXIL::kMaxMSASPayloadBytes) {
         ValCtx.EmitFormatError(
             ValidationRule::SmAmplificationShaderPayloadSize,
-            {F->getName(), std::to_string(DXIL::kMaxMSASPayloadSize)});
+            {F->getName(), std::to_string(DXIL::kMaxMSASPayloadBytes)});
       }
 
       if (prop.ShaderProps.AS.payloadSizeInBytes < payloadSize) {
@@ -2957,9 +2957,9 @@ static void ValidateAsIntrinsics(Function *F, ValidationContext &ValCtx, CallIns
   const DataLayout &DL = F->getParent()->getDataLayout();
   unsigned payloadSize = DL.getTypeAllocSize(payloadTy);
 
-  if (payloadSize > DXIL::kMaxMSASPayloadSize) {
+  if (payloadSize > DXIL::kMaxMSASPayloadBytes) {
     ValCtx.EmitFormatError(ValidationRule::SmAmplificationShaderPayloadSize,
-      { F->getName(), std::to_string(DXIL::kMaxMSASPayloadSize) });
+      { F->getName(), std::to_string(DXIL::kMaxMSASPayloadBytes) });
   }
 }
 
@@ -4985,17 +4985,17 @@ static void ValidateEntrySignatures(ValidationContext &ValCtx,
       totalOutputScalars += SE->GetRows() * SE->GetCols() * maxPrimitiveCount;
     }
 
-    if (totalOutputScalars > DXIL::kMaxMSOutputTotalScalars) {
+    if (totalOutputScalars * 4 > DXIL::kMaxMSOutputTotalBytes) {
       ValCtx.EmitFormatError(
         ValidationRule::SmMeshShaderOutputSize,
-        { F.getName(), std::to_string(DXIL::kMaxMSOutputTotalScalars) });
+        { F.getName(), std::to_string(DXIL::kMaxMSOutputTotalBytes) });
     }
 
     unsigned totalInputOutputScalars = totalOutputScalars + props.ShaderProps.MS.payloadSizeInBytes;
-    if (totalInputOutputScalars > DXIL::kMaxMSInputOutputTotalScalars) {
+    if (totalInputOutputScalars * 4 > DXIL::kMaxMSInputOutputTotalBytes) {
       ValCtx.EmitFormatError(
         ValidationRule::SmMeshShaderInOutSize,
-        { F.getName(), std::to_string(DXIL::kMaxMSInputOutputTotalScalars) });
+        { F.getName(), std::to_string(DXIL::kMaxMSInputOutputTotalBytes) });
     }
   }
 }

+ 63 - 0
tools/clang/test/CodeGenHLSL/mesh-val/msOversizeOutput.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: vertex plus primitive output size is greater than 32768
+
+#define MAX_VERT 32
+#define MAX_PRIM 128
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float4 malnor[16] : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor;
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor[0].x = mpl.malnor;
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 63 - 0
tools/clang/test/CodeGenHLSL/mesh-val/msOversizePayloadOutput.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: payload plus output size is greater than 48128
+
+#define MAX_VERT 32
+#define MAX_PRIM 110
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+    float4 malnor[16] : MALNOR;
+    int layer[4] : LAYER;
+};
+
+struct MeshPayload {
+    float normal;
+    float malnor[4000];
+    int layer[4];
+};
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+      primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+      MeshPerPrimitive op;
+      op.normal = mpl.normal;
+      op.malnor[0].x = mpl.malnor[0];
+      op.layer[0] = mpl.layer[0];
+      op.layer[1] = mpl.layer[1];
+      op.layer[2] = mpl.layer[2];
+      op.layer[3] = mpl.layer[3];
+      prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 10 - 0
tools/clang/unittests/HLSL/ValidationTest.cpp

@@ -262,6 +262,8 @@ public:
   TEST_METHOD(MeshMissingSetMeshOutputCounts)
   TEST_METHOD(MeshNonDominatingSetMeshOutputCounts)
   TEST_METHOD(MeshOversizePayload)
+  TEST_METHOD(MeshOversizeOutput)
+  TEST_METHOD(MeshOversizePayloadOutput)
   TEST_METHOD(MeshMultipleGetMeshPayload)
   TEST_METHOD(MeshOutofRangeMaxVertexCount)
   TEST_METHOD(MeshOutofRangeMaxPrimitiveCount)
@@ -3563,6 +3565,14 @@ TEST_F(ValidationTest, MeshOversizePayload) {
   TestCheck(L"..\\CodeGenHLSL\\mesh-val\\msOversizePayload.hlsl");
 }
 
+TEST_F(ValidationTest, MeshOversizeOutput) {
+  TestCheck(L"..\\CodeGenHLSL\\mesh-val\\msOversizeOutput.hlsl");
+}
+
+TEST_F(ValidationTest, MeshOversizePayloadOutput) {
+  TestCheck(L"..\\CodeGenHLSL\\mesh-val\\msOversizePayloadOutput.hlsl");
+}
+
 TEST_F(ValidationTest, MeshMultipleGetMeshPayload) {
   RewriteAssemblyCheckMsg(L"..\\CodeGenHLSL\\mesh-val\\mesh.hlsl", "ms_6_5",
                           "%([0-9]+) = call %struct.MeshPayload\\* @dx.op.getMeshPayload.struct.MeshPayload\\(i32 170\\)  ; GetMeshPayload\\(\\)",

+ 1 - 1
utils/hct/hctdb.py

@@ -2456,7 +2456,7 @@ class db_dxil(object):
         self.add_valrule("Sm.MeshShaderPayloadSize", "For shader '%0', payload size is greater than %1")
         self.add_valrule("Sm.MeshShaderPayloadSizeDeclared", "For shader '%0', payload size %1 is greater than declared size of %2 bytes")
         self.add_valrule("Sm.MeshShaderOutputSize", "For shader '%0', vertex plus primitive output size is greater than %1")
-        self.add_valrule("Sm.MeshShaderInOutSize", "For shader '%0', input plus output size is greater than %1")
+        self.add_valrule("Sm.MeshShaderInOutSize", "For shader '%0', payload plus output size is greater than %1")
         self.add_valrule("Sm.MeshVSigRowCount", "For shader '%0', vertex output signatures are taking up more than %1 rows")
         self.add_valrule("Sm.MeshPSigRowCount", "For shader '%0', primitive output signatures are taking up more than %1 rows")
         self.add_valrule("Sm.MeshTotalSigRowCount", "For shader '%0', vertex and primitive output signatures are taking up more than %1 rows")