Quellcode durchsuchen

Add validation checks for HS/GS attributes after parsing (#768)

This commit add checking for three HS/GS rules:

* HS entry point should have the patchconstantfunc attribute
* GS entry point should have the maxvertexcount attribute
* GS stream-output objects should be inout parameters
Lei Zhang vor 7 Jahren
Ursprung
Commit
406fe38220

+ 6 - 0
tools/clang/include/clang/Basic/DiagnosticSemaKinds.td

@@ -7658,6 +7658,12 @@ def err_hlsl_intrinsic_template_arg_requires_2018: Error<
 def err_hlsl_intrinsic_template_arg_scalar_vector_16: Error<
    "Explicit template arguments on intrinsic %0 are limited one to scalar or vector type up to 16 bytes in size.">;
 }
+def err_hlsl_missing_maxvertexcount_attr: Error<
+   "GS entry point must have the maxvertexcount attribute">;
+def err_hlsl_missing_patchconstantfunc_attr: Error<
+   "HS entry point must have the patchconstantfunc attribute">;
+def err_hlsl_missing_inout_attr: Error<
+   "stream-output object must be an inout parameter">;
 // HLSL Change Ends
 
 let CategoryName = "OpenMP Issue" in {

+ 1 - 0
tools/clang/include/clang/Basic/LangOptions.h

@@ -151,6 +151,7 @@ public:
   // MS Change Starts
   unsigned HLSLVersion;  // Only supported for IntelliSense scenarios.
   std::string HLSLEntryFunction;
+  std::string HLSLProfile;
   unsigned RootSigMajor;
   unsigned RootSigMinor;
   bool IsHLSLLibrary;

+ 12 - 8
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -1291,6 +1291,16 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     isHS = true;
     funcProps->shaderKind = DXIL::ShaderKind::Hull;
     HSEntryPatchConstantFuncAttr[F] = Attr;
+  } else {
+    // TODO: This is a duplicate check. We also have this check in
+    // hlsl::DiagnoseTranslationUnit(clang::Sema*).
+    if (isEntry && SM->IsHS()) {
+      unsigned DiagID = Diags.getCustomDiagID(
+          DiagnosticsEngine::Error,
+          "HS entry point must have the patchconstantfunc attribute");
+      Diags.Report(FD->getLocation(), DiagID);
+      return;
+    }
   }
 
   if (const HLSLOutputControlPointsAttr *Attr =
@@ -4065,14 +4075,8 @@ void CGMSHLSLRuntime::SetPatchConstantFunction(const EntryFunctionInfo &EntryFun
 
   auto AttrsIter = HSEntryPatchConstantFuncAttr.find(EntryFunc.Func);
 
-  if (AttrsIter == HSEntryPatchConstantFuncAttr.end()) {
-    DiagnosticsEngine &Diags = CGM.getDiags();
-    unsigned DiagID =
-      Diags.getCustomDiagID(DiagnosticsEngine::Error,
-        "HS entry is missing patchconstantfunc attribute.");
-    Diags.Report(EntryFunc.SL, DiagID);
-    return;
-  }
+  DXASSERT(AttrsIter != HSEntryPatchConstantFuncAttr.end(),
+           "we have checked this in AddHLSLFunctionInfo()");
 
   SetPatchConstantFunctionWithAttr(Entry, AttrsIter->second);
 }

+ 39 - 10
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -36,6 +36,7 @@
 #include "dxc/HlslIntrinsicOp.h"
 #include "gen_intrin_main_tables_15.h"
 #include "dxc/HLSL/HLOperations.h"
+#include "dxc/HLSL/DxilShaderModel.h"
 #include <array>
 
 enum ArBasicKind {
@@ -9084,16 +9085,34 @@ void hlsl::DiagnoseTranslationUnit(clang::Sema *self) {
   // NOTE: the information gathered here could be used to bypass code generation
   // on functions that are unreachable (as an early form of dead code elimination).
   if (pEntryPointDecl) {
-    if (const HLSLPatchConstantFuncAttr *Attr =
-            pEntryPointDecl->getAttr<HLSLPatchConstantFuncAttr>()) {
-      NameLookup NL = GetSingleFunctionDeclByName(self, Attr->getFunctionName(), /*checkPatch*/ true);
-      if (!NL.Found || !NL.Found->hasBody()) {
-        unsigned id = Diags.getCustomDiagID(clang::DiagnosticsEngine::Level::Error,
-          "missing patch function definition");
-        Diags.Report(id);
+    const auto *shaderModel =
+        hlsl::ShaderModel::GetByName(self->getLangOpts().HLSLProfile.c_str());
+
+    if (shaderModel->IsGS()) {
+      // Validate that GS has the maxvertexcount attribute
+      if (!pEntryPointDecl->hasAttr<HLSLMaxVertexCountAttr>()) {
+        self->Diag(pEntryPointDecl->getLocation(),
+                   diag::err_hlsl_missing_maxvertexcount_attr);
+        return;
+      }
+    } else if (shaderModel->IsHS()) {
+      if (const HLSLPatchConstantFuncAttr *Attr =
+              pEntryPointDecl->getAttr<HLSLPatchConstantFuncAttr>()) {
+        NameLookup NL = GetSingleFunctionDeclByName(
+            self, Attr->getFunctionName(), /*checkPatch*/ true);
+        if (!NL.Found || !NL.Found->hasBody()) {
+          unsigned id =
+              Diags.getCustomDiagID(clang::DiagnosticsEngine::Level::Error,
+                                    "missing patch function definition");
+          Diags.Report(id);
+          return;
+        }
+        pPatchFnDecl = NL.Found;
+      } else {
+        self->Diag(pEntryPointDecl->getLocation(),
+                   diag::err_hlsl_missing_patchconstantfunc_attr);
         return;
       }
-      pPatchFnDecl = NL.Found;
     }
 
     hlsl::CallGraphWithRecurseGuard CG;
@@ -10769,10 +10788,11 @@ bool Sema::DiagnoseHLSLDecl(Declarator &D, DeclContext *DC,
     }
   }
 
+  HLSLExternalSource *hlslSource = HLSLExternalSource::FromSema(this);
+  ArBasicKind basicKind = hlslSource->GetTypeElementKind(qt);
+
   if (hasSignSpec) {
-     HLSLExternalSource *hlslSource = HLSLExternalSource::FromSema(this);
      ArTypeObjectKind objKind = hlslSource->GetTypeObjectKind(qt);
-     ArBasicKind basicKind = hlslSource->GetTypeElementKind(qt);
      // vectors or matrices can only have unsigned integer types.
      if (objKind == AR_TOBJ_MATRIX || objKind == AR_TOBJ_VECTOR || objKind == AR_TOBJ_BASIC || objKind == AR_TOBJ_ARRAY) {
          if (!IS_BASIC_UNSIGNABLE(basicKind)) {
@@ -11006,6 +11026,15 @@ bool Sema::DiagnoseHLSLDecl(Declarator &D, DeclContext *DC,
     }
   }
 
+  // Validate that stream-ouput objects are marked as inout
+  if (isParameter && !(usageIn && usageOut) &&
+      (basicKind == ArBasicKind::AR_OBJECT_LINESTREAM ||
+       basicKind == ArBasicKind::AR_OBJECT_POINTSTREAM ||
+       basicKind == ArBasicKind::AR_OBJECT_TRIANGLESTREAM)) {
+    Diag(D.getLocStart(), diag::err_hlsl_missing_inout_attr);
+    result = false;
+  }
+
   // Validate unusual annotations.
   hlsl::DiagnoseUnusualAnnotationsForHLSL(*this, D.UnusualAnnotations);
   auto && unusualIter = D.UnusualAnnotations.begin();

+ 26 - 0
tools/clang/test/CodeGenHLSL/attributes-gs-no-inout-main.hlsl

@@ -0,0 +1,26 @@
+// RUN: %dxc -E main -T gs_6_0 %s | FileCheck %s
+
+// CHECK: 18:11: error: stream-output object must be an inout parameter
+
+struct GsOut {
+    float4 pos : SV_Position;
+};
+
+void foo(inout LineStream<GsOut> param) {
+    GsOut vertex;
+    vertex = (GsOut)0;
+    param.Append(vertex);
+}
+
+// Missing inout on outData
+[maxvertexcount(3)]
+void main(in triangle float4 pos[3] : SV_Position,
+          LineStream<GsOut> outData) {
+    GsOut vertex;
+    vertex.pos = pos[0];
+    outData.Append(vertex);
+
+    foo(outData);
+
+    outData.RestartStrip();
+}

+ 26 - 0
tools/clang/test/CodeGenHLSL/attributes-gs-no-inout-other.hlsl

@@ -0,0 +1,26 @@
+// RUN: %dxc -E main -T gs_6_0 %s | FileCheck %s
+
+// CHECK: 10:10: error: stream-output object must be an inout parameter
+
+struct GsOut {
+    float4 pos : SV_Position;
+};
+
+// Missing inout on param
+void foo(LineStream<GsOut> param) {
+    GsOut vertex;
+    vertex = (GsOut)0;
+    param.Append(vertex);
+}
+
+[maxvertexcount(3)]
+void main(in triangle float4 pos[3] : SV_Position,
+          inout LineStream<GsOut> outData) {
+    GsOut vertex;
+    vertex.pos = pos[0];
+    outData.Append(vertex);
+
+    foo(outData);
+
+    outData.RestartStrip();
+}

+ 25 - 0
tools/clang/test/CodeGenHLSL/attributes-gs-no-maxvertexcount.hlsl

@@ -0,0 +1,25 @@
+// RUN: %dxc -E main -T gs_6_0 %s | FileCheck %s
+
+// CHECK: :16:6: error: GS entry point must have the maxvertexcount attribute
+
+struct GsOut {
+    float4 pos : SV_Position;
+};
+
+void foo(inout LineStream<GsOut> param) {
+    GsOut vertex;
+    vertex = (GsOut)0;
+    param.Append(vertex);
+}
+
+// Missing maxvertexcount attribute
+void main(in triangle float4 pos[3] : SV_Position,
+          inout LineStream<GsOut> outData) {
+    GsOut vertex;
+    vertex.pos = pos[0];
+    outData.Append(vertex);
+
+    foo(outData);
+
+    outData.RestartStrip();
+}

+ 38 - 0
tools/clang/test/CodeGenHLSL/attributes-hs-no-pcf.hlsl

@@ -0,0 +1,38 @@
+// RUN: %dxc -E main -T hs_6_0 %s | FileCheck %s
+
+// CHECK: :32:9: error: HS entry point must have the patchconstantfunc attribute
+
+#define NumOutPoints 2
+
+struct HsCpIn {
+    float4 pos : SV_Position;
+};
+
+struct HsCpOut {
+    float4 pos : SV_Position;
+};
+
+struct HsPcfOut
+{
+  float tessOuter[4] : SV_TessFactor;
+  float tessInner[2] : SV_InsideTessFactor;
+};
+
+HsPcfOut pcf(InputPatch<HsCpIn, NumOutPoints> patch, uint patchId : SV_PrimitiveID) {
+  HsPcfOut output;
+  output = (HsPcfOut)0;
+  return output;
+}
+
+// Missing patchconstantfunc attribute
+[domain("quad")]
+[partitioning("fractional_odd")]
+[outputtopology("triangle_ccw")]
+[outputcontrolpoints(NumOutPoints)]
+HsCpOut main(InputPatch<HsCpIn, NumOutPoints> patch,
+             uint cpId : SV_OutputControlPointID,
+             uint patchId : SV_PrimitiveID) {
+    HsCpOut output;
+    output = (HsCpOut)0;
+    return output;
+}

+ 2 - 2
tools/clang/test/CodeGenSPIRV/hs.pcf.void.hlsl

@@ -20,9 +20,9 @@ HS_CONSTANT_DATA_OUTPUT PCF() {
   return Output;
 }
 
-[domain("isoline")]
+[domain("quad")]
 [partitioning("fractional_odd")]
-[outputtopology("line")]
+[outputtopology("triangle_cw")]
 [outputcontrolpoints(16)]
 [patchconstantfunc("PCF")]
 BEZIER_CONTROL_POINT main(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POINTS> ip, uint i : SV_OutputControlPointID, uint PatchID : SV_PrimitiveID) {

+ 1 - 0
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -391,6 +391,7 @@ public:
 
       compiler.getLangOpts().HLSLEntryFunction =
       compiler.getCodeGenOpts().HLSLEntryFunction = pUtf8EntryPoint.m_psz;
+      compiler.getLangOpts().HLSLProfile =
       compiler.getCodeGenOpts().HLSLProfile = pUtf8TargetProfile.m_psz;
 
       unsigned rootSigMajor = 0;

+ 21 - 1
tools/clang/unittests/HLSL/Objects.cpp

@@ -611,11 +611,21 @@ public:
 
     FormatTypeNameAndPreamble(sod, typeName, &preambleDecl);
 
+    std::string parmType = typeName;
+    // Stream-output objects must be declared as inout.
+    switch (sod.Kind) {
+    case SOK_StreamOutputLine:
+    case SOK_StreamOutputPoint:
+    case SOK_StreamOutputTriangle:
+      parmType = "inout " + parmType;
+      break;
+    }
+
     sprintf_s(result, _countof(result),
               "%s"
               "void f(%s parameter) { }\n"
               "float ps(float4 color : COLOR) { %s localVar; f(localVar); return 0; }",
-              preambleDecl, typeName, typeName);
+              preambleDecl, parmType.c_str(), typeName);
 
     return std::string(result);
   }
@@ -769,6 +779,16 @@ TEST_F(ObjectTest, PassToInoutArgs) {
     std::stringstream programText;
     unsigned uniqueId = 0;
     for (const auto &iop : InOutParameterModifierData) {
+
+      switch (sod.Kind) {
+      case SOK_StreamOutputLine:
+      case SOK_StreamOutputPoint:
+      case SOK_StreamOutputTriangle:
+        // Stream-output objects can only be inout. Skip other cases.
+        if (std::string(iop.Keyword) != "inout")
+          continue;
+      }
+
       char typeName[64];
       const char* preambleDecl;
 

+ 20 - 0
tools/clang/unittests/HLSL/ValidationTest.cpp

@@ -226,6 +226,10 @@ public:
   TEST_METHOD(ClipCullMaxComponents)
   TEST_METHOD(ClipCullMaxRows)
   TEST_METHOD(DuplicateSysValue)
+  TEST_METHOD(GSMainMissingAttributeFail)
+  TEST_METHOD(GSOtherMissingAttributeFail)
+  TEST_METHOD(GSMissingMaxVertexCountFail)
+  TEST_METHOD(HSMissingPCFFail)
   TEST_METHOD(GetAttributeAtVertexInVSFail)
   TEST_METHOD(GetAttributeAtVertexIn60Fail)
   TEST_METHOD(GetAttributeAtVertexInterpFail)
@@ -3034,6 +3038,22 @@ float4 main(uint vid : SV_ViewID, float3 In[31] : INPUT) : SV_Target \
     /*bRegex*/true);
 }
 
+TEST_F(ValidationTest, GSMainMissingAttributeFail) {
+  TestCheck(L"..\\CodeGenHLSL\\attributes-gs-no-inout-main.hlsl");
+}
+
+TEST_F(ValidationTest, GSOtherMissingAttributeFail) {
+  TestCheck(L"..\\CodeGenHLSL\\attributes-gs-no-inout-other.hlsl");
+}
+
+TEST_F(ValidationTest, GSMissingMaxVertexCountFail) {
+  TestCheck(L"..\\CodeGenHLSL\\attributes-gs-no-maxvertexcount.hlsl");
+}
+
+TEST_F(ValidationTest, HSMissingPCFFail) {
+  TestCheck(L"..\\CodeGenHLSL\\attributes-hs-no-pcf.hlsl");
+}
+
 TEST_F(ValidationTest, GetAttributeAtVertexInVSFail) {
   if (m_ver.SkipDxilVersion(1,1)) return;
   RewriteAssemblyCheckMsg(