Explorar o código

Mesh shader output size calculation fix (#2440)

missed align32 validation for mesh outputs
amarpMSFT %!s(int64=6) %!d(string=hai) anos
pai
achega
63cbac780c

+ 9 - 7
lib/HLSL/DxilValidation.cpp

@@ -4974,24 +4974,26 @@ static void ValidateEntrySignatures(ValidationContext &ValCtx,
         { F.getName(), std::to_string(DXIL::kMaxMSTotalSigRows) });
     }
 
-    unsigned maxVertexCount = props.ShaderProps.MS.maxVertexCount;
-    unsigned maxPrimitiveCount = props.ShaderProps.MS.maxPrimitiveCount;
+    const unsigned kScalarSizeForMSAttributes = 4;
+    #define ALIGN32(n) (((n) + 31) & ~31)
+    unsigned maxAlign32VertexCount = ALIGN32(props.ShaderProps.MS.maxVertexCount);
+    unsigned maxAlign32PrimitiveCount = ALIGN32(props.ShaderProps.MS.maxPrimitiveCount);
     unsigned totalOutputScalars = 0;
     for (auto &SE : S.OutputSignature.GetElements()) {
-      totalOutputScalars += SE->GetRows() * SE->GetCols() * maxVertexCount;
+      totalOutputScalars += SE->GetRows() * SE->GetCols() * maxAlign32VertexCount;
     }
     for (auto &SE : S.PatchConstOrPrimSignature.GetElements()) {
-      totalOutputScalars += SE->GetRows() * SE->GetCols() * maxPrimitiveCount;
+      totalOutputScalars += SE->GetRows() * SE->GetCols() * maxAlign32PrimitiveCount;
     }
 
-    if (totalOutputScalars * 4 > DXIL::kMaxMSOutputTotalBytes) {
+    if (totalOutputScalars*kScalarSizeForMSAttributes > DXIL::kMaxMSOutputTotalBytes) {
       ValCtx.EmitFormatError(
         ValidationRule::SmMeshShaderOutputSize,
         { F.getName(), std::to_string(DXIL::kMaxMSOutputTotalBytes) });
     }
 
-    unsigned totalInputOutputScalars = totalOutputScalars + props.ShaderProps.MS.payloadSizeInBytes;
-    if (totalInputOutputScalars * 4 > DXIL::kMaxMSInputOutputTotalBytes) {
+    unsigned totalInputOutputBytes = totalOutputScalars*kScalarSizeForMSAttributes + props.ShaderProps.MS.payloadSizeInBytes;
+    if (totalInputOutputBytes > DXIL::kMaxMSInputOutputTotalBytes) {
       ValCtx.EmitFormatError(
         ValidationRule::SmMeshShaderInOutSize,
         { F.getName(), std::to_string(DXIL::kMaxMSInputOutputTotalBytes) });

+ 7 - 5
tools/clang/test/CodeGenHLSL/mesh-val/msOversizePayloadOutput.hlsl

@@ -2,8 +2,8 @@
 
 // CHECK: payload plus output size is greater than 48128
 
-#define MAX_VERT 32
-#define MAX_PRIM 110
+#define MAX_VERT 230
+#define MAX_PRIM 93
 #define NUM_THREADS 32
 struct MeshPerVertex {
     float4 position : SV_Position;
@@ -12,13 +12,13 @@ struct MeshPerVertex {
 
 struct MeshPerPrimitive {
     float normal : NORMAL;
-    float4 malnor[16] : MALNOR;
-    int layer[4] : LAYER;
+    float4 malnor[14] : MALNOR;
+    int layer[6] : LAYER;
 };
 
 struct MeshPayload {
     float normal;
-    float malnor[4000];
+    float malnor[4091];
     int layer[4];
 };
 
@@ -57,6 +57,8 @@ void main(
       op.layer[1] = mpl.layer[1];
       op.layer[2] = mpl.layer[2];
       op.layer[3] = mpl.layer[3];
+      op.layer[4] = mpl.layer[3];
+      op.layer[5] = mpl.layer[3];
       prims[tig / 3] = op;
     }
     verts[tig] = ov;