Преглед изворни кода

Allows Patch Constant Function definition after HS entry point (#748)

* Allows Patch Constant Function definition after HS entry point

* Fixes SimpleHS8.hlsl
John Porto пре 7 година
родитељ
комит
3563ae9c5a

+ 149 - 91
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -114,16 +114,30 @@ private:
   hlsl::DxilResourceBase::Class TypeToClass(clang::QualType Ty);
 
   // Save the entryFunc so don't need to find it with original name.
-  llvm::Function *EntryFunc;
+  struct EntryFunctionInfo {
+    clang::SourceLocation SL = clang::SourceLocation();
+    llvm::Function *Func = nullptr;
+  };
+
+  EntryFunctionInfo Entry;
   
   // Map to save patch constant functions
-  StringMap<Function *> patchConstantFunctionMap;
+  struct PatchConstantInfo {
+    clang::SourceLocation SL = clang::SourceLocation();
+    llvm::Function *Func = nullptr;
+    std::uint32_t NumOverloads = 0;
+  };
+
+  StringMap<PatchConstantInfo> patchConstantFunctionMap;
   std::unordered_map<Function *, std::unique_ptr<DxilFunctionProps>>
       patchConstantFunctionPropsMap;
   bool IsPatchConstantFunction(const Function *F);
 
+  std::unordered_map<Function *, const clang::HLSLPatchConstantFuncAttr *>
+      HSEntryPatchConstantFuncAttr;
+
   // Map to save entry functions.
-  StringMap<Function *> entryFunctionMap;
+  StringMap<EntryFunctionInfo> entryFunctionMap;
 
   // Map to save static global init exp.
   std::unordered_map<Expr *, GlobalVariable *> staticConstGlobalInitMap;
@@ -215,6 +229,10 @@ public:
 
   /// Add resouce to the program
   void addResource(Decl *D) override;
+  void SetPatchConstantFunction(const EntryFunctionInfo &EntryFunc);
+  void SetPatchConstantFunctionWithAttr(
+      const EntryFunctionInfo &EntryFunc,
+      const clang::HLSLPatchConstantFuncAttr *PatchConstantFuncAttr);
   void FinishCodeGen() override;
   bool IsTrivalInitListExpr(CodeGenFunction &CGF, InitListExpr *E) override;
   Value *EmitHLSLInitListExpr(CodeGenFunction &CGF, InitListExpr *E, Value *DestPtr) override;
@@ -312,7 +330,7 @@ void clang::CompileRootSignature(
 // CGMSHLSLRuntime methods.
 //
 CGMSHLSLRuntime::CGMSHLSLRuntime(CodeGenModule &CGM)
-    : CGHLSLRuntime(CGM), Context(CGM.getLLVMContext()), EntryFunc(nullptr),
+    : CGHLSLRuntime(CGM), Context(CGM.getLLVMContext()), Entry(),
       TheModule(CGM.getModule()),
       dataLayout(CGM.getLangOpts().UseMinPrecision
                        ? hlsl::DXIL::kLegacyLayoutString
@@ -1120,8 +1138,10 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   // Set entry function
   const std::string &entryName = m_pHLModule->GetEntryFunctionName();
   bool isEntry = FD->getNameAsString() == entryName;
-  if (isEntry)
-    EntryFunc = F;
+  if (isEntry) {
+    Entry.Func = F;
+    Entry.SL = FD->getLocation();
+  }
 
   DiagnosticsEngine &Diags = CGM.getDiags();
 
@@ -1174,22 +1194,10 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   bool isPatchConstantFunction = false;
   if (!isEntry && CGM.getContext().IsPatchConstantFunctionDecl(FD)) {
     isPatchConstantFunction = true;
-    if (patchConstantFunctionMap.count(FD->getName()) == 0)
-      patchConstantFunctionMap[FD->getName()] = F;
-    else {
-      // TODO: This is not the same as how fxc handles patch constant functions.
-      //  This will fail if more than one function with the same name has a
-      //  SV_TessFactor semantic. Fxc just selects the last function defined
-      //  that has the matching name when referenced by the patchconstantfunc
-      //  attribute from the hull shader currently being compiled.
-      // Report error
-      DiagnosticsEngine &Diags = CGM.getDiags();
-      unsigned DiagID =
-          Diags.getCustomDiagID(DiagnosticsEngine::Error,
-                                "Multiple definitions for patchconstantfunc.");
-      Diags.Report(FD->getLocation(), DiagID);
-      return;
-    }
+    auto &PCI = patchConstantFunctionMap[FD->getName()];
+    PCI.SL = FD->getLocation();
+    PCI.Func = F;
+    ++PCI.NumOverloads;
 
     for (ParmVarDecl *parmDecl : FD->parameters()) {
       QualType Ty = parmDecl->getType();
@@ -1271,37 +1279,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
 
     isHS = true;
     funcProps->shaderKind = DXIL::ShaderKind::Hull;
-    StringRef funcName = Attr->getFunctionName();
-
-    if (patchConstantFunctionMap.count(funcName) == 1) {
-      Function *patchConstFunc = patchConstantFunctionMap[funcName];
-      funcProps->ShaderProps.HS.patchConstantFunc = patchConstFunc;
-      DXASSERT_NOMSG(patchConstantFunctionPropsMap.count(patchConstFunc));
-      // Check no inout parameter for patch constant function.
-      DxilFunctionAnnotation *patchConstFuncAnnotation =
-          m_pHLModule->GetFunctionAnnotation(patchConstFunc);
-      for (unsigned i = 0; i < patchConstFuncAnnotation->GetNumParameters();
-           i++) {
-        if (patchConstFuncAnnotation->GetParameterAnnotation(i)
-                .GetParamInputQual() == DxilParamInputQual::Inout) {
-          unsigned DiagID = Diags.getCustomDiagID(
-              DiagnosticsEngine::Error,
-              "Patch Constant function should not have inout param.");
-          Diags.Report(Attr->getLocation(), DiagID);
-          return;
-        }
-      }
-    } else {
-      // TODO: Bring this in line with fxc behavior.  In fxc, patchconstantfunc
-      //  selection is based only on name (last function with matching name),
-      //  not by whether it has SV_TessFactor output.
-      //// Report error
-      // DiagnosticsEngine &Diags = CGM.getDiags();
-      // unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
-      //                                        "Cannot find
-      //                                        patchconstantfunc.");
-      // Diags.Report(Attr->getLocation(), DiagID);
-    }
+    HSEntryPatchConstantFuncAttr[F] = Attr;
   }
 
   if (const HLSLOutputControlPointsAttr *Attr =
@@ -1692,37 +1670,6 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     AddTypeAnnotation(Ty, dxilTypeSys, arrayEltSize);
   }
 
-  if (isHS) {
-    // Check
-    Function *patchConstFunc = funcProps->ShaderProps.HS.patchConstantFunc;
-    if (patchConstantFunctionPropsMap.count(patchConstFunc)) {
-      DxilFunctionProps &patchProps =
-          *patchConstantFunctionPropsMap[patchConstFunc];
-      if (patchProps.ShaderProps.HS.outputControlPoints != 0 &&
-          patchProps.ShaderProps.HS.outputControlPoints !=
-              funcProps->ShaderProps.HS.outputControlPoints) {
-        unsigned DiagID = Diags.getCustomDiagID(
-            DiagnosticsEngine::Error,
-            "Patch constant function's output patch input "
-            "should have %0 elements, but has %1.");
-        Diags.Report(FD->getLocation(), DiagID)
-            << funcProps->ShaderProps.HS.outputControlPoints
-            << patchProps.ShaderProps.HS.outputControlPoints;
-      }
-      if (patchProps.ShaderProps.HS.inputControlPoints != 0 &&
-          patchProps.ShaderProps.HS.inputControlPoints !=
-              funcProps->ShaderProps.HS.inputControlPoints) {
-        unsigned DiagID =
-            Diags.getCustomDiagID(DiagnosticsEngine::Error,
-                                  "Patch constant function's input patch input "
-                                  "should have %0 elements, but has %1.");
-        Diags.Report(FD->getLocation(), DiagID)
-            << funcProps->ShaderProps.HS.inputControlPoints
-            << patchProps.ShaderProps.HS.inputControlPoints;
-      }
-    }
-  }
-
   // Only add functionProps when exist.
   if (profileAttributes || isEntry)
     m_pHLModule->AddDxilFunctionProps(F, funcProps);
@@ -1738,7 +1685,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
           "redefinition of %0");
       Diags.Report(FD->getLocStart(), DiagID) << FD->getName();
     }
-    entryFunctionMap[FD->getNameAsString()] = F;
+    auto &Entry = entryFunctionMap[FD->getNameAsString()];
+    Entry.SL = FD->getLocation();
+    Entry.Func= F;
   }
 }
 
@@ -2600,13 +2549,13 @@ HLCBuffer &CGMSHLSLRuntime::GetOrCreateCBuffer(HLSLBufferDecl *D) {
 bool CGMSHLSLRuntime::IsPatchConstantFunction(const Function *F) {
   DXASSERT_NOMSG(F != nullptr);
   for (auto && p : patchConstantFunctionMap) {
-    if (p.second == F) return true;
+    if (p.second.Func == F) return true;
   }
   return false;
 }
 
 void CGMSHLSLRuntime::SetEntryFunction() {
-  if (EntryFunc == nullptr) {
+  if (Entry.Func == nullptr) {
     DiagnosticsEngine &Diags = CGM.getDiags();
     unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
                                             "cannot find entry function %0");
@@ -2614,7 +2563,7 @@ void CGMSHLSLRuntime::SetEntryFunction() {
     return;
   }
 
-  m_pHLModule->SetEntryFunction(EntryFunc);
+  m_pHLModule->SetEntryFunction(Entry.Func);
 }
 
 // Here the size is CB size. So don't need check type.
@@ -4096,6 +4045,106 @@ void ProcessCtorFunctions(llvm::Module &M, StringRef globalName,
   }
 }
 
+void CGMSHLSLRuntime::SetPatchConstantFunction(const EntryFunctionInfo &EntryFunc) {
+
+  auto AttrsIter = HSEntryPatchConstantFuncAttr.find(EntryFunc.Func);
+
+  if (AttrsIter == HSEntryPatchConstantFuncAttr.end()) {
+    DiagnosticsEngine &Diags = CGM.getDiags();
+    unsigned DiagID =
+      Diags.getCustomDiagID(DiagnosticsEngine::Error,
+        "HS entry is missing patchconstantfunc attribute.");
+    Diags.Report(EntryFunc.SL, DiagID);
+    return;
+  }
+
+  SetPatchConstantFunctionWithAttr(Entry, AttrsIter->second);
+}
+
+void CGMSHLSLRuntime::SetPatchConstantFunctionWithAttr(
+    const EntryFunctionInfo &EntryFunc,
+    const clang::HLSLPatchConstantFuncAttr *PatchConstantFuncAttr) {
+  StringRef funcName = PatchConstantFuncAttr->getFunctionName();
+
+  auto Entry = patchConstantFunctionMap.find(funcName);
+  if (Entry == patchConstantFunctionMap.end()) {
+    DiagnosticsEngine &Diags = CGM.getDiags();
+    unsigned DiagID =
+      Diags.getCustomDiagID(DiagnosticsEngine::Error,
+        "Cannot find patchconstantfunc %0.");
+    Diags.Report(PatchConstantFuncAttr->getLocation(), DiagID)
+      << funcName;
+    return;
+  }
+
+  if (Entry->second.NumOverloads != 1) {
+    DXASSERT(false,
+        "hlsl::DiagnoseTranslationUnit used to check for this condition.");
+    DiagnosticsEngine &Diags = CGM.getDiags();
+    unsigned DiagID =
+      Diags.getCustomDiagID(DiagnosticsEngine::Warning,
+        "Multiple functions match patchconstantfunc %0.");
+    unsigned NoteID =
+      Diags.getCustomDiagID(DiagnosticsEngine::Note,
+        "This overload was selected.");
+    Diags.Report(PatchConstantFuncAttr->getLocation(), DiagID)
+      << funcName;
+    Diags.Report(Entry->second.SL, NoteID);
+  }
+
+  Function *patchConstFunc = Entry->second.Func;
+  DxilFunctionProps *HSProps = &m_pHLModule->GetDxilFunctionProps(EntryFunc.Func);
+  DXASSERT(HSProps != nullptr,
+    " else AddHLSLFunctionInfo did not save the dxil function props for the "
+    "HS entry.");
+  HSProps->ShaderProps.HS.patchConstantFunc = patchConstFunc;
+  DXASSERT_NOMSG(patchConstantFunctionPropsMap.count(patchConstFunc));
+  // Check no inout parameter for patch constant function.
+  DxilFunctionAnnotation *patchConstFuncAnnotation =
+    m_pHLModule->GetFunctionAnnotation(patchConstFunc);
+  for (unsigned i = 0; i < patchConstFuncAnnotation->GetNumParameters(); i++) {
+    if (patchConstFuncAnnotation->GetParameterAnnotation(i)
+      .GetParamInputQual() == DxilParamInputQual::Inout) {
+      DiagnosticsEngine &Diags = CGM.getDiags();
+      unsigned DiagID = Diags.getCustomDiagID(
+        DiagnosticsEngine::Error,
+        "Patch Constant function %0 should not have inout param.");
+      Diags.Report(Entry->second.SL, DiagID) << funcName;
+    }
+  }
+  
+  // Input/Output control point validation.
+  if (patchConstantFunctionPropsMap.count(patchConstFunc)) {
+    const DxilFunctionProps &patchProps =
+      *patchConstantFunctionPropsMap[patchConstFunc];
+    if (patchProps.ShaderProps.HS.inputControlPoints != 0 &&
+      patchProps.ShaderProps.HS.inputControlPoints !=
+      HSProps->ShaderProps.HS.inputControlPoints) {
+      DiagnosticsEngine &Diags = CGM.getDiags();
+      unsigned DiagID =
+        Diags.getCustomDiagID(DiagnosticsEngine::Error,
+          "Patch constant function's input patch input "
+          "should have %0 elements, but has %1.");
+      Diags.Report(Entry->second.SL, DiagID)
+        << HSProps->ShaderProps.HS.inputControlPoints
+        << patchProps.ShaderProps.HS.inputControlPoints;
+    }
+    if (patchProps.ShaderProps.HS.outputControlPoints != 0 &&
+      patchProps.ShaderProps.HS.outputControlPoints !=
+      HSProps->ShaderProps.HS.outputControlPoints) {
+      DiagnosticsEngine &Diags = CGM.getDiags();
+      unsigned DiagID = Diags.getCustomDiagID(
+        DiagnosticsEngine::Error,
+        "Patch constant function's output patch input "
+        "should have %0 elements, but has %1.");
+      Diags.Report(Entry->second.SL, DiagID)
+        << HSProps->ShaderProps.HS.outputControlPoints
+        << patchProps.ShaderProps.HS.outputControlPoints;
+    }
+  }
+  
+}
+
 void CGMSHLSLRuntime::FinishCodeGen() {
   // Library don't have entry.
   if (!m_bIsLib) {
@@ -4107,9 +4156,18 @@ void CGMSHLSLRuntime::FinishCodeGen() {
              "else SetEntryFunction should have reported this condition");
       return;
     }
+
+    if (m_pHLModule->GetShaderModel()->IsHS()) {
+      SetPatchConstantFunction(Entry);
+    }
   } else {
     for (auto &it : entryFunctionMap) {
-      CloneShaderEntry(it.second, it.getKey(), *m_pHLModule);
+      CloneShaderEntry(it.second.Func, it.getKey(), *m_pHLModule);
+
+      auto AttrIter = HSEntryPatchConstantFuncAttr.find(it.second.Func);
+      if (AttrIter != HSEntryPatchConstantFuncAttr.end()) {
+        SetPatchConstantFunctionWithAttr(it.second, AttrIter->second);
+      }
     }
   }
 
@@ -4150,7 +4208,7 @@ void CGMSHLSLRuntime::FinishCodeGen() {
   if (!m_bIsLib) {
     // need this for "llvm.global_dtors"?
     ProcessCtorFunctions(TheModule ,"llvm.global_ctors",
-                  EntryFunc->getEntryBlock().getFirstInsertionPt());
+                  Entry.Func->getEntryBlock().getFirstInsertionPt());
   }
   // translate opcode into parameter for intrinsic functions
   AddOpcodeParamForIntrinsics(*m_pHLModule, m_IntrinsicMap, resMetadataMap);
@@ -5963,7 +6021,7 @@ void CGMSHLSLRuntime::EmitHLSLRootSignature(CodeGenFunction &CGF,
                                             HLSLRootSignatureAttr *RSA,
                                             Function *Fn) {
   // Only parse root signature for entry function.
-  if (Fn != EntryFunc)
+  if (Fn != Entry.Func)
     return;
 
   StringRef StrRef = RSA->getSignatureName();

+ 1 - 2
tools/clang/test/CodeGenHLSL/SimpleHs6.hlsl

@@ -1,7 +1,6 @@
 // RUN: %dxc -E main -T hs_6_0 -Zi %s | FileCheck %s
 
 // CHECK: may only have one InputPatch parameter
-// CHECK: Patch constant function's output patch input should have 3 elements, but has 5.
 
 //--------------------------------------------------------------------------------------
 // SimpleTessellation.hlsl
@@ -64,7 +63,7 @@ HSPerPatchData HSPerPatchFunc( const InputPatch< PSSceneIn, 3 > points,  OutputP
 [patchconstantfunc("HSPerPatchFunc")]
 [outputcontrolpoints(3)]
 HSPerVertexData main( const uint id : SV_OutputControlPointID,
-                               const InputPatch< PSSceneIn, 3 > points, const InputPatch< PSSceneIn, 3 > points2 )
+                      const InputPatch< PSSceneIn, 3 > points, const InputPatch< PSSceneIn, 3 > points2 )
 {
     HSPerVertexData v;
 

+ 5 - 3
tools/clang/test/CodeGenHLSL/SimpleHs8.hlsl

@@ -1,6 +1,8 @@
 // RUN: %dxc -E main -T hs_6_0  %s | FileCheck %s
 
-// CHECK: Patch Constant function should not have inout param
+// CHECK-DAG: Patch Constant function HSPerPatchFunc should not have inout param
+// CHECK-DAG: Patch constant function's input patch input should have 3 elements, but has 12.
+// CHECK-DAG: Patch constant function's output patch input should have 3 elements, but has 5.
 
 //--------------------------------------------------------------------------------------
 // SimpleTessellation.hlsl
@@ -44,7 +46,7 @@ float4 HSPerPatchFunc()
     return 1.8;
 }
 
-HSPerPatchData HSPerPatchFunc( const InputPatch< PSSceneIn, 3 > points, OutputPatch<HSPerVertexData, 5> outp, inout float x)
+HSPerPatchData HSPerPatchFunc( const InputPatch< PSSceneIn, 12 > points, OutputPatch<HSPerVertexData, 5> outp, inout float x)
 {
     HSPerPatchData d;
 
@@ -63,7 +65,7 @@ HSPerPatchData HSPerPatchFunc( const InputPatch< PSSceneIn, 3 > points, OutputPa
 [patchconstantfunc("HSPerPatchFunc")]
 [outputcontrolpoints(3)]
 HSPerVertexData main( const uint id : SV_OutputControlPointID,
-                               const InputPatch< PSSceneIn, 3 > points, const InputPatch< PSSceneIn, 3 > points2 )
+                      const InputPatch< PSSceneIn, 3 > points )
 {
     HSPerVertexData v;
 

+ 82 - 0
tools/clang/test/CodeGenHLSL/SimpleHs9.hlsl

@@ -0,0 +1,82 @@
+// RUN: %dxc -E main -T hs_6_0  %s | FileCheck %s
+
+// CHECK: SV_TessFactor 0
+// CHECK: SV_InsideTessFactor 0
+
+// CHECK: define void @main
+
+// CHECK: define void {{.*}}HSPerPatchFunc
+// CHECK: dx.op.storePatchConstant.f32{{.*}}float 1.0
+// CHECK: dx.op.storePatchConstant.f32{{.*}}float 2.0
+// CHECK: dx.op.storePatchConstant.f32{{.*}}float 3.0
+// CHECK: dx.op.storePatchConstant.f32{{.*}}float 4.0
+
+//--------------------------------------------------------------------------------------
+// SimpleTessellation.hlsl
+//
+// Advanced Technology Group (ATG)
+// Copyright (C) Microsoft Corporation. All rights reserved.
+//--------------------------------------------------------------------------------------
+
+
+struct PSSceneIn
+{
+    float4 pos     : SV_Position;
+    float2 tex     : TEXCOORD0;
+    float3 norm    : NORMAL;
+    uint   RTIndex : SV_RenderTargetArrayIndex;
+};
+
+
+//////////////////////////////////////////////////////////////////////////////////////////
+// Simple forwarding Tessellation shaders
+
+struct HSPerVertexData
+{
+    // This is just the original vertex verbatim. In many real life cases this would be a
+    // control point instead
+    PSSceneIn v;
+};
+
+struct HSPerPatchData
+{
+    // We at least have to specify tess factors per patch
+    // As we're tesselating triangles, there will be 4 tess factors
+    // In real life case this might contain face normal, for example
+	float	edges[3] : SV_TessFactor;
+	float	inside   : SV_InsideTessFactor;
+};
+
+float4 HSPerPatchFunc()
+{
+    return 1.8;
+}
+
+// hull per-control point shader
+[domain("tri")]
+[partitioning("fractional_odd")]
+[outputtopology("triangle_cw")]
+[patchconstantfunc("HSPerPatchFunc")]
+[outputcontrolpoints(3)]
+HSPerVertexData main( const uint id : SV_OutputControlPointID,
+                      const InputPatch< PSSceneIn, 3 > points )
+{
+    HSPerVertexData v;
+
+    // Just forward the vertex
+    v.v = points[ id ];
+
+	return v;
+}
+
+HSPerPatchData HSPerPatchFunc(const InputPatch< PSSceneIn, 3 > points)
+{
+  HSPerPatchData d;
+
+  d.edges[0] = 1;
+  d.edges[1] = 2;
+  d.edges[2] = 3;
+  d.inside = 4;
+
+  return d;
+}

+ 5 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -761,6 +761,7 @@ public:
   TEST_METHOD(CodeGenSimpleHS6)
   TEST_METHOD(CodeGenSimpleHS7)
   TEST_METHOD(CodeGenSimpleHS8)
+  TEST_METHOD(CodeGenSimpleHS9)
   TEST_METHOD(CodeGenSMFail)
   TEST_METHOD(CodeGenSrv_Ms_Load1)
   TEST_METHOD(CodeGenSrv_Ms_Load2)
@@ -4190,6 +4191,10 @@ TEST_F(CompilerTest, CodeGenSimpleHS8) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\SimpleHS8.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenSimpleHS9) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\SimpleHS9.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenSMFail) {
   CodeGenTestCheck(L"sm-fail.hlsl");
 }