2
0
Эх сурвалжийг харах

Close #62: strange error output on recursive function calls (#150)

* Traverse call graph to report recursive calls in the front-end.
* Emit function declarations only as needed.
* Fix extra echo output in hctbuild help.
* Remove unused test case.
Marcelo Lopez Ruiz 8 жил өмнө
parent
commit
843be2528a

+ 3 - 0
lib/HLSL/DxilGenerationPass.cpp

@@ -838,6 +838,9 @@ void DxilGenerationPass::CreateDxilSignatures() {
   if (SM->IsHS()) {
     HLFunctionProps &EntryProps = m_pHLModule->GetHLFunctionProps(EntryFunc);
     Function *patchConstantFunc = EntryProps.ShaderProps.HS.patchConstantFunc;
+    if (patchConstantFunc == nullptr) {
+      EntryFunc->getContext().emitError("Patch constant function is not specified.");
+    }
 
     DxilFunctionAnnotation *patchFuncAnnotation = m_pHLModule->GetFunctionAnnotation(patchConstantFunc);
     DXASSERT(patchFuncAnnotation, "must have function annotation for patch constant function");

+ 2 - 3
lib/HLSL/DxilValidation.cpp

@@ -3878,13 +3878,12 @@ CalculateCallDepth(CallGraphNode *node,
   funcSet.insert(node->getFunction());
   for (auto it = node->begin(), ei = node->end(); it != ei; it++) {
     CallGraphNode *toNode = it->second;
-    if (callStack.count(toNode) > 0) {
-      // Recursive
+    if (callStack.insert(toNode).second == false) {
+      // Recursive.
       return true;
     }
     if (depthMap[toNode] < depth)
       depthMap[toNode] = depth;
-    callStack.insert(toNode);
     if (CalculateCallDepth(toNode, depthMap, callStack, funcSet)) {
       // Recursive
       return true;

+ 2 - 0
tools/clang/include/clang/AST/ASTContext.h

@@ -1111,6 +1111,8 @@ public:
   QualType getFunctionType(QualType ResultTy, ArrayRef<QualType> Args,
                            const FunctionProtoType::ExtProtoInfo &EPI,
                            ArrayRef<hlsl::ParameterModifier> ParamMods) const; // HLSL Change
+  /// \brief Check whether the function declaration can be used as a patch constant function.
+  bool IsPatchConstantFunctionDecl(const FunctionDecl *FD) const; // HLSL Change
 
   /// \brief Return the unique reference to the type for the specified type
   /// declaration.

+ 1 - 0
tools/clang/include/clang/Basic/LangOptions.h

@@ -151,6 +151,7 @@ public:
   // MS Change Starts
   bool HLSL2015;  // Only supported for IntelliSense scenarios.
   bool HLSL2016;
+  std::string HLSLEntryFunction;
   unsigned RootSigMajor;
   unsigned RootSigMinor;
   // MS Change Ends

+ 2 - 0
tools/clang/include/clang/Sema/SemaHLSL.h

@@ -89,6 +89,8 @@ void DiagnoseRegisterType(
   clang::QualType type,
   char registerType);
 
+void DiagnoseTranslationUnit(clang::Sema* self);
+
 void DiagnoseUnusualAnnotationsForHLSL(
   clang::Sema& S,
   std::vector<hlsl::UnusualAnnotation *>& annotations);

+ 4 - 1
tools/clang/lib/AST/ASTContext.cpp

@@ -8461,7 +8461,10 @@ bool ASTContext::DeclMustBeEmitted(const Decl *D) {
     if (Linkage == GVA_Internal || Linkage == GVA_AvailableExternally ||
         Linkage == GVA_DiscardableODR)
       return false;
-    return true;
+    // HLSL Change Starts
+    // Don't just return true because of visibility, unless building a library (which is not currently implemented)
+    return FD->getName() == getLangOpts().HLSLEntryFunction || IsPatchConstantFunctionDecl(FD);
+    // HLSL Change Ends
   }
   
   const VarDecl *VD = cast<VarDecl>(D);

+ 54 - 1
tools/clang/lib/AST/ASTContextHLSL.cpp

@@ -21,9 +21,9 @@
 #include "clang/Sema/SemaDiagnostic.h"
 #include "clang/Sema/Sema.h"
 #include "clang/Sema/Overload.h"
-#include <map>
 #include "dxc/Support/Global.h"
 #include "dxc/HLSL/HLOperations.h"
+#include "dxc/HLSL/DxilSemantic.h"
 
 using namespace clang;
 using namespace hlsl;
@@ -928,3 +928,56 @@ UnusualAnnotation* hlsl::UnusualAnnotation::CopyToASTContext(ASTContext& Context
   memcpy(result, this, instanceSize);
   return (UnusualAnnotation*)result;
 }
+
+static bool HasTessFactorSemantic(const ValueDecl *decl) {
+  for (const UnusualAnnotation *it : decl->getUnusualAnnotations()) {
+    switch (it->getKind()) {
+    case UnusualAnnotation::UA_SemanticDecl: {
+      const SemanticDecl *sd = cast<SemanticDecl>(it);
+      const Semantic *pSemantic = Semantic::GetByName(sd->SemanticName);
+      if (pSemantic && pSemantic->GetKind() == Semantic::Kind::TessFactor)
+        return true;
+    }
+    }
+  }
+  return false;
+}
+
+static bool HasTessFactorSemanticRecurse(const ValueDecl *decl, QualType Ty) {
+  if (Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty))
+    return false;
+
+  if (const RecordType *RT = Ty->getAsStructureType()) {
+    RecordDecl *RD = RT->getDecl();
+    for (FieldDecl *fieldDecl : RD->fields()) {
+      if (HasTessFactorSemanticRecurse(fieldDecl, fieldDecl->getType()))
+        return true;
+    }
+    return false;
+  }
+
+  if (const clang::ArrayType *arrayTy = Ty->getAsArrayTypeUnsafe())
+    return HasTessFactorSemantic(decl);
+
+  return false;
+}
+
+bool ASTContext::IsPatchConstantFunctionDecl(const FunctionDecl *FD) const {
+  // This checks whether the function is structurally capable of being a patch
+  // constant function, not whether it is in fact the patch constant function
+  // for the entry point of a compiled hull shader (which may not have been
+  // seen yet). So the answer is conservative.
+  if (!FD->getReturnType()->isVoidType()) {
+    // Try to find TessFactor in return type.
+    if (HasTessFactorSemanticRecurse(FD, FD->getReturnType()))
+      return true;
+  }
+  // Try to find TessFactor in out param.
+  for (const ParmVarDecl *param : FD->params()) {
+    if (param->hasAttr<HLSLOutAttr>()) {
+      if (HasTessFactorSemanticRecurse(param, param->getType()))
+        return true;
+    }
+  }
+  return false;
+}

+ 1 - 50
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -422,55 +422,6 @@ CGMSHLSLRuntime::SetSemantic(const NamedDecl *decl,
   return SourceLocation();
 }
 
-static bool HasTessFactorSemantic(const ValueDecl *decl) {
-  for (const hlsl::UnusualAnnotation *it : decl->getUnusualAnnotations()) {
-    switch (it->getKind()) {
-    case hlsl::UnusualAnnotation::UA_SemanticDecl: {
-      const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
-      const Semantic *pSemantic = Semantic::GetByName(sd->SemanticName);
-      if (pSemantic && pSemantic->GetKind() == Semantic::Kind::TessFactor)
-        return true;
-    }
-    }
-  }
-  return false;
-}
-
-static bool HasTessFactorSemanticRecurse(const ValueDecl *decl, QualType Ty) {
-  if (Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty))
-    return false;
-
-  if (const RecordType *RT = Ty->getAsStructureType()) {
-    RecordDecl *RD = RT->getDecl();
-    for (FieldDecl *fieldDecl : RD->fields()) {
-      if (HasTessFactorSemanticRecurse(fieldDecl, fieldDecl->getType()))
-        return true;
-    }
-    return false;
-  }
-
-  if (const clang::ArrayType *arrayTy = Ty->getAsArrayTypeUnsafe())
-    return HasTessFactorSemantic(decl);
-
-  return false;
-}
-// TODO: get from type annotation.
-static bool IsPatchConstantFunctionDecl(const FunctionDecl *FD) {
-  if (!FD->getReturnType()->isVoidType()) {
-    // Try to find TessFactor in return type.
-    if (HasTessFactorSemanticRecurse(FD, FD->getReturnType()))
-      return true;
-  }
-  // Try to find TessFactor in out param.
-  for (ParmVarDecl *param : FD->params()) {
-    if (param->hasAttr<HLSLOutAttr>()) {
-      if (HasTessFactorSemanticRecurse(param, param->getType()))
-        return true;
-    }
-  }
-  return false;
-}
-
 static DXIL::TessellatorDomain StringToDomain(StringRef domain) {
   if (domain == "isoline")
     return DXIL::TessellatorDomain::IsoLine;
@@ -1049,7 +1000,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
 
   // Save patch constant function to patchConstantFunctionMap.
   bool isPatchConstantFunction = false;
-  if (IsPatchConstantFunctionDecl(FD)) {
+  if (CGM.getContext().IsPatchConstantFunctionDecl(FD)) {
     isPatchConstantFunction = true;
     if (patchConstantFunctionMap.count(FD->getName()) == 0)
       patchConstantFunctionMap[FD->getName()] = F;

+ 7 - 0
tools/clang/lib/Parse/ParseAST.cpp

@@ -22,6 +22,7 @@
 #include "clang/Sema/ExternalSemaSource.h"
 #include "clang/Sema/Sema.h"
 #include "clang/Sema/SemaConsumer.h"
+#include "clang/Sema/SemaHLSL.h" // HLSL Change
 #include "llvm/Support/CrashRecoveryContext.h"
 #include <cstdio>
 #include <memory>
@@ -150,6 +151,12 @@ void clang::ParseAST(Sema &S, bool PrintStats, bool SkipFunctionBodies) {
   for (Decl *D : S.WeakTopLevelDecls())
     Consumer->HandleTopLevelDecl(DeclGroupRef(D));
   
+  // HLSL Change Starts
+  // Provide the opportunity to generate translation-unit level validation
+  // errors in the front-end, without relying on code generation being
+  // available.
+  hlsl::DiagnoseTranslationUnit(&S);
+  // HLSL Change Ends
   Consumer->HandleTranslationUnit(S.getASTContext());
 
   std::swap(OldCollectStats, S.CollectStats);

+ 235 - 1
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -10,6 +10,8 @@
 //                                                                           //
 ///////////////////////////////////////////////////////////////////////////////
 
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/DenseMap.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Attr.h"
 #include "clang/AST/DeclCXX.h"
@@ -17,6 +19,7 @@
 #include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/ExternalASTSource.h"
+#include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/TypeLoc.h"
 #include "clang/AST/HlslTypes.h"
 #include "clang/Sema/Overload.h"
@@ -27,7 +30,6 @@
 #include "clang/Sema/Template.h"
 #include "clang/Sema/TemplateDeduction.h"
 #include "clang/Sema/SemaHLSL.h"
-#include <map>
 #include "dxc/Support/Global.h"
 #include "dxc/Support/WinIncludes.h"
 #include "dxc/dxcapi.internal.h"
@@ -2227,6 +2229,139 @@ static void AddHLSLSubscriptAttr(Decl *D, ASTContext &context, HLSubscriptOpcode
   D->addAttr(HLSLIntrinsicAttr::CreateImplicit(context, group, "", static_cast<unsigned>(opcode)));
 }
 
+//
+// This is similar to clang/Analysis/CallGraph, but the following differences
+// motivate this:
+//
+// - track traversed vs. observed nodes explicitly
+// - fully visit all reachable functions
+// - merge graph visiting with checking for recursion
+// - track global variables and types used (NYI)
+//
+namespace hlsl {
+  struct CallNode {
+    FunctionDecl *CallerFn;
+    ::llvm::SmallPtrSet<FunctionDecl *, 4> CalleeFns;
+  };
+  typedef ::llvm::DenseMap<FunctionDecl*, CallNode> CallNodes;
+  typedef ::llvm::SmallPtrSet<Decl *, 8> FnCallStack;
+  typedef ::llvm::SmallPtrSet<FunctionDecl*, 128> FunctionSet;
+  typedef ::llvm::SmallVector<FunctionDecl*, 32> PendingFunctions;
+
+  // Returns the definition of a function.
+  // This serves two purposes - ignore built-in functions, and pick
+  // a single Decl * to be used in maps and sets.
+  static FunctionDecl *getFunctionWithBody(FunctionDecl *F) {
+    if (!F) return nullptr;
+    if (F->doesThisDeclarationHaveABody()) return F;
+    F = F->getFirstDecl();
+    for (auto &&Candidate : F->redecls()) {
+      if (Candidate->doesThisDeclarationHaveABody()) {
+        return Candidate;
+      }
+    }
+    return nullptr;
+  }
+
+  // AST visitor that maintains visited and pending collections, as well
+  // as recording nodes of caller/callees.
+  class FnReferenceVisitor : public RecursiveASTVisitor<FnReferenceVisitor> {
+  private:
+    CallNodes &m_callNodes;
+    FunctionSet &m_visitedFunctions;
+    PendingFunctions &m_pendingFunctions;
+    FunctionDecl *m_source;
+    CallNodes::iterator m_sourceIt;
+
+  public:
+    FnReferenceVisitor(FunctionSet &visitedFunctions,
+      PendingFunctions &pendingFunctions, CallNodes &callNodes)
+      : m_visitedFunctions(visitedFunctions),
+      m_pendingFunctions(pendingFunctions), m_callNodes(callNodes) {}
+
+    void setSourceFn(FunctionDecl *F) {
+      F = getFunctionWithBody(F);
+      m_source = F;
+      m_sourceIt = m_callNodes.find(F);
+    }
+
+    bool VisitDeclRefExpr(DeclRefExpr *ref) {
+      ValueDecl *valueDecl = ref->getDecl();
+      FunctionDecl *fnDecl = dyn_cast_or_null<FunctionDecl>(valueDecl);
+      fnDecl = getFunctionWithBody(fnDecl);
+      if (fnDecl) {
+        if (m_sourceIt == m_callNodes.end()) {
+          auto result = m_callNodes.insert(
+            std::pair<FunctionDecl *, CallNode>(m_source, CallNode{ m_source }));
+          DXASSERT(result.second == true,
+            "else setSourceFn didn't assign m_sourceIt");
+          m_sourceIt = result.first;
+        }
+        m_sourceIt->second.CalleeFns.insert(fnDecl);
+        if (!m_visitedFunctions.count(fnDecl)) {
+          m_pendingFunctions.push_back(fnDecl);
+        }
+      }
+      return true;
+    }
+  };
+
+  // A call graph that can check for reachability and recursion efficiently.
+  class CallGraphWithRecurseGuard {
+  private:
+    CallNodes m_callNodes;
+    FunctionSet m_visitedFunctions;
+
+    FunctionDecl *CheckRecursion(FnCallStack &CallStack,
+      FunctionDecl *D) const {
+      if (CallStack.insert(D).second == false)
+        return D;
+      auto node = m_callNodes.find(D);
+      if (node != m_callNodes.end()) {
+        for (FunctionDecl *Callee : node->second.CalleeFns) {
+          FunctionDecl *pResult = CheckRecursion(CallStack, Callee);
+          if (pResult)
+            return pResult;
+        }
+      }
+      CallStack.erase(D);
+      return nullptr;
+    }
+
+  public:
+    void BuildForEntry(FunctionDecl *EntryFnDecl) {
+      DXASSERT_NOMSG(EntryFnDecl);
+      EntryFnDecl = getFunctionWithBody(EntryFnDecl);
+      PendingFunctions pendingFunctions;
+      FnReferenceVisitor visitor(m_visitedFunctions, pendingFunctions, m_callNodes);
+      pendingFunctions.push_back(EntryFnDecl);
+      while (!pendingFunctions.empty()) {
+        FunctionDecl *pendingDecl = pendingFunctions.pop_back_val();
+        if (m_visitedFunctions.insert(pendingDecl).second == true) {
+          visitor.setSourceFn(pendingDecl);
+          visitor.TraverseDecl(pendingDecl);
+        }
+      }
+    }
+
+    FunctionDecl *CheckRecursion(FunctionDecl *EntryFnDecl) const {
+      FnCallStack CallStack;
+      EntryFnDecl = getFunctionWithBody(EntryFnDecl);
+      return CheckRecursion(CallStack, EntryFnDecl);
+    }
+
+    void dump() const {
+      OutputDebugStringW(L"Call Nodes:\r\n");
+      for (auto &node : m_callNodes) {
+        OutputDebugFormatA("%s [%p]:\r\n", node.first->getName().str().c_str(), node.first);
+        for (auto callee : node.second.CalleeFns) {
+          OutputDebugFormatA("    %s [%p]\r\n", callee->getName().str().c_str(), callee);
+        }
+      }
+    }
+  };
+}
+
 class HLSLExternalSource : public ExternalSemaSource {
 private:
   // Inner types.
@@ -8568,6 +8703,105 @@ void hlsl::DiagnoseRegisterType(
   }
 }
 
+struct NameLookup {
+  FunctionDecl *Found;
+  FunctionDecl *Other;
+};
+
+static NameLookup GetSingleFunctionDeclByName(clang::Sema *self, StringRef Name, bool checkPatch) {
+  auto DN = DeclarationName(&self->getASTContext().Idents.get(Name));
+  FunctionDecl *pFoundDecl = nullptr;
+  for (auto idIter = self->IdResolver.begin(DN), idEnd = self->IdResolver.end(); idIter != idEnd; ++idIter) {
+    FunctionDecl *pFnDecl = dyn_cast<FunctionDecl>(*idIter);
+    if (!pFnDecl) continue;
+    if (checkPatch && !self->getASTContext().IsPatchConstantFunctionDecl(pFnDecl)) continue;
+    if (pFoundDecl) {
+      return NameLookup{ pFoundDecl, pFnDecl };
+    }
+    pFoundDecl = pFnDecl;
+  }
+  return NameLookup{ pFoundDecl, nullptr };
+}
+
+void hlsl::DiagnoseTranslationUnit(clang::Sema *self) {
+  DXASSERT_NOMSG(self != nullptr);
+
+  // Don't bother with global validation if compilation has already failed.
+  if (self->getDiagnostics().hasErrorOccurred()) {
+    return;
+  }
+
+  // TODO: make these error 'real' errors rather than on-the-fly things
+  // Validate that the entry point is available.
+  ASTContext &Ctx = self->getASTContext();
+  DiagnosticsEngine &Diags = self->getDiagnostics();
+  FunctionDecl *pEntryPointDecl = nullptr;
+  FunctionDecl *pPatchFnDecl = nullptr;
+  const std::string &EntryPointName = self->getLangOpts().HLSLEntryFunction;
+  if (!EntryPointName.empty()) {
+    NameLookup NL = GetSingleFunctionDeclByName(self, EntryPointName, /*checkPatch*/ false);
+    if (NL.Found && NL.Other) {
+      // NOTE: currently we cannot hit this codepath when CodeGen is enabled, because
+      // CodeGenModule::getMangledName will mangle the entry point name into the bare
+      // string, and so ambiguous points will produce an error earlier on.
+      unsigned id = Diags.getCustomDiagID(clang::DiagnosticsEngine::Level::Error,
+        "ambiguous entry point function");
+      Diags.Report(NL.Found->getSourceRange().getBegin(), id);
+      Diags.Report(NL.Other->getLocation(), diag::note_previous_definition);
+      return;
+    }
+    pEntryPointDecl = NL.Found;
+    if (!pEntryPointDecl || !pEntryPointDecl->hasBody()) {
+      unsigned id = Diags.getCustomDiagID(clang::DiagnosticsEngine::Level::Error,
+        "missing entry point definition");
+      Diags.Report(id);
+      return;
+    }
+  }
+
+  // Validate that there is no recursion; start with the entry function.
+  // NOTE: the information gathered here could be used to bypass code generation
+  // on functions that are unreachable (as an early form of dead code elimination).
+  if (pEntryPointDecl) {
+    if (const HLSLPatchConstantFuncAttr *Attr =
+            pEntryPointDecl->getAttr<HLSLPatchConstantFuncAttr>()) {
+      NameLookup NL = GetSingleFunctionDeclByName(self, Attr->getFunctionName(), /*checkPatch*/ true);
+      if (NL.Found && NL.Other) {
+        unsigned id = Diags.getCustomDiagID(clang::DiagnosticsEngine::Level::Error,
+          "ambiguous patch constant function");
+        Diags.Report(NL.Found->getSourceRange().getBegin(), id);
+        Diags.Report(NL.Other->getLocation(), diag::note_previous_definition);
+        return;
+      }
+      if (!NL.Found || !NL.Found->hasBody()) {
+        unsigned id = Diags.getCustomDiagID(clang::DiagnosticsEngine::Level::Error,
+          "missing patch function definition");
+        Diags.Report(id);
+        return;
+      }
+      pPatchFnDecl = NL.Found;
+    }
+
+    hlsl::CallGraphWithRecurseGuard CG;
+    CG.BuildForEntry(pEntryPointDecl);
+    Decl *pResult = CG.CheckRecursion(pEntryPointDecl);
+    if (pResult) {
+      unsigned id = Diags.getCustomDiagID(clang::DiagnosticsEngine::Level::Error,
+        "recursive functions not allowed");
+      Diags.Report(pResult->getSourceRange().getBegin(), id);
+    }
+    if (pPatchFnDecl) {
+      CG.BuildForEntry(pPatchFnDecl);
+      Decl *pPatchFnDecl = CG.CheckRecursion(pEntryPointDecl);
+      if (pPatchFnDecl) {
+        unsigned id = Diags.getCustomDiagID(clang::DiagnosticsEngine::Level::Error,
+          "recursive functions not allowed (via patch function)");
+        Diags.Report(pPatchFnDecl->getSourceRange().getBegin(), id);
+      }
+    }
+  }
+}
+
 void hlsl::DiagnoseUnusualAnnotationsForHLSL(
   Sema& S,
   std::vector<hlsl::UnusualAnnotation *>& annotations)

+ 2 - 2
tools/clang/test/CodeGenHLSL/parameter_types.hlsl

@@ -1,7 +1,6 @@
 // RUN: %dxc -E main -T cs_6_0 -fcgl %s  | FileCheck %s
 
 // CHECK: float %a, <4 x float> %b, %struct.T* %t, %class.matrix.float.2.3 %m, [3 x <2 x float>]* %n
-
 // CHECK: float* dereferenceable(4) %a, <4 x float>* dereferenceable(16) %b, %struct.T* %t, %class.matrix.float.2.3* dereferenceable(24) %m, [3 x <2 x float>]* %n
 
 struct T{
@@ -30,5 +29,6 @@ void main() {
   test(a, b, t, m, n);
   // TODO: report error on use float as out float4 in front-end.
   // FXC error message is "cannot convert output parameter from 'float4' to 'float'"
-  //test2(a, b, t, m, n);
+  float4 out_b = b;
+  test2(a, out_b, t, m, n);
 }

+ 0 - 19
tools/clang/test/CodeGenHLSL/recursive.hlsl

@@ -1,19 +0,0 @@
-// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
-
-// CHECK: Recursion is not permitted
-// CHECK: with parameter is not permitted
-
-void test_inout(inout float4 m, float4 a) 
-{
-    if (a.x > 1)
-      test_inout(m, a-1);
-    m = abs(m+a*a.yxxx);
-}
-
-float4 main(float4 a : A, float4 b:B) : SV_TARGET
-{
-  float4 x = b;
-  test_inout(x, a);
-  return x;
-}
-

+ 206 - 0
tools/clang/test/CodeGenHLSL/recursive.ll

@@ -0,0 +1,206 @@
+; RUN: %dxv %s | FileCheck %s
+
+; CHECK: Recursion is not permitted
+; CHECK: with parameter is not permitted
+
+; This test originally covered two validator error messages:
+; 1. recursion not allowed
+; 2. functions with parameters are not allowed
+;
+; The recursion error is now handled earlier in the pipeline, and so there
+; is no coverage for error #2. But we do need the validator to check for this
+; in case someone decides to remove this. So instead, we validate the assembly
+; we would have had, which we can generate with this command:
+;
+; dxc -Vd -T ps_6_0 recursive.hlsl
+;
+
+; void test_inout(inout float4 m, float4 a) {
+;    if (a.x > 1)
+;     test_inout(m, a-1);
+;   m = abs(m+a*a.yxxx);
+; }
+;
+; float4 main(float4 a : A, float4 b:B) : SV_TARGET {
+;  float4 x = b;
+;  test_inout(x, a);
+;  return x;
+; }
+
+;
+; Input signature:
+;
+; Name                 Index   Mask Register SysValue  Format   Used
+; -------------------- ----- ------ -------- -------- ------- ------
+; A                        0   xyzw        0     NONE   float
+; B                        0   xyzw        1     NONE   float
+;
+;
+; Output signature:
+;
+; Name                 Index   Mask Register SysValue  Format   Used
+; -------------------- ----- ------ -------- -------- ------- ------
+; SV_Target                0   xyzw        0   TARGET   float   xyzw
+;
+;
+; Pipeline Runtime Information:
+;
+; Pixel Shader
+; DepthOutput=0
+; SampleFrequency=0
+;
+;
+; Input signature:
+;
+; Name                 Index             InterpMode
+; -------------------- ----- ----------------------
+; A                        0                 linear
+; B                        0                 linear
+;
+; Output signature:
+;
+; Name                 Index             InterpMode
+; -------------------- ----- ----------------------
+; SV_Target                0
+;
+; Buffer Definitions:
+;
+;
+; Resource Bindings:
+;
+; Name                                 Type  Format         Dim      ID      HLSL Bind  Count
+; ------------------------------ ---------- ------- ----------- ------- -------------- ------
+;
+target datalayout = "e-m:e-p:32:32-i64:64-f80:32-n8:16:32-a:0:32-S32"
+target triple = "dxil-ms-dx"
+
+; Function Attrs: alwaysinline nounwind
+define internal fastcc void @"\01?test_inout@@YAXAAV?$vector@M$03@@V1@@Z"(<4 x float>* nocapture dereferenceable(16) %m, <4 x float> %a) #0 {
+entry:
+  %a.i0 = extractelement <4 x float> %a, i32 0
+  %a.i1 = extractelement <4 x float> %a, i32 1
+  %a.i2 = extractelement <4 x float> %a, i32 2
+  %a.i3 = extractelement <4 x float> %a, i32 3
+  %0 = load <4 x float>, <4 x float>* %m, align 4
+  %.i05 = extractelement <4 x float> %0, i32 0
+  %.i16 = extractelement <4 x float> %0, i32 1
+  %.i27 = extractelement <4 x float> %0, i32 2
+  %.i38 = extractelement <4 x float> %0, i32 3
+  %1 = alloca <4 x float>, align 4
+  %cmp = fcmp ogt float %a.i0, 1.000000e+00
+  br i1 %cmp, label %if.then, label %if.end
+
+if.then:                                          ; preds = %entry
+  store <4 x float> %0, <4 x float>* %1, align 4
+  %sub.i0 = fadd float %a.i0, -1.000000e+00
+  %sub.i1 = fadd float %a.i1, -1.000000e+00
+  %sub.i2 = fadd float %a.i2, -1.000000e+00
+  %sub.i3 = fadd float %a.i3, -1.000000e+00
+  %sub.upto0 = insertelement <4 x float> undef, float %sub.i0, i32 0
+  %sub.upto1 = insertelement <4 x float> %sub.upto0, float %sub.i1, i32 1
+  %sub.upto2 = insertelement <4 x float> %sub.upto1, float %sub.i2, i32 2
+  %sub = insertelement <4 x float> %sub.upto2, float %sub.i3, i32 3
+  call fastcc void @"\01?test_inout@@YAXAAV?$vector@M$03@@V1@@Z"(<4 x float>* nonnull dereferenceable(16) %1, <4 x float> %sub)
+  %2 = load <4 x float>, <4 x float>* %1, align 4
+  %.i0 = extractelement <4 x float> %2, i32 0
+  %.i1 = extractelement <4 x float> %2, i32 1
+  %.i2 = extractelement <4 x float> %2, i32 2
+  %.i3 = extractelement <4 x float> %2, i32 3
+  br label %if.end
+
+if.end:                                           ; preds = %if.then, %entry
+  %.0.i0 = phi float [ %.i0, %if.then ], [ %.i05, %entry ]
+  %.0.i1 = phi float [ %.i1, %if.then ], [ %.i16, %entry ]
+  %.0.i2 = phi float [ %.i2, %if.then ], [ %.i27, %entry ]
+  %.0.i3 = phi float [ %.i3, %if.then ], [ %.i38, %entry ]
+  %mul.i0 = fmul float %a.i0, %a.i1
+  %mul.i2 = fmul float %a.i2, %a.i0
+  %mul.i3 = fmul float %a.i3, %a.i0
+  %add.i0 = fadd float %mul.i0, %.0.i0
+  %add.i1 = fadd float %mul.i0, %.0.i1
+  %add.i2 = fadd float %mul.i2, %.0.i2
+  %add.i3 = fadd float %mul.i3, %.0.i3
+  %FAbs = call float @dx.op.unary.f32(i32 6, float %add.i0)  ; FAbs(value)
+  %3 = insertelement <4 x float> undef, float %FAbs, i64 0
+  %FAbs2 = call float @dx.op.unary.f32(i32 6, float %add.i1)  ; FAbs(value)
+  %4 = insertelement <4 x float> %3, float %FAbs2, i64 1
+  %FAbs3 = call float @dx.op.unary.f32(i32 6, float %add.i2)  ; FAbs(value)
+  %5 = insertelement <4 x float> %4, float %FAbs3, i64 2
+  %FAbs4 = call float @dx.op.unary.f32(i32 6, float %add.i3)  ; FAbs(value)
+  %6 = insertelement <4 x float> %5, float %FAbs4, i64 3
+  store <4 x float> %6, <4 x float>* %m, align 4
+  ret void
+}
+
+define void @main() {
+entry:
+  %0 = call float @dx.op.loadInput.f32(i32 4, i32 1, i32 0, i8 0, i32 undef)  ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)
+  %1 = insertelement <4 x float> undef, float %0, i64 0
+  %2 = call float @dx.op.loadInput.f32(i32 4, i32 1, i32 0, i8 1, i32 undef)  ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)
+  %3 = insertelement <4 x float> %1, float %2, i64 1
+  %4 = call float @dx.op.loadInput.f32(i32 4, i32 1, i32 0, i8 2, i32 undef)  ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)
+  %5 = insertelement <4 x float> %3, float %4, i64 2
+  %6 = call float @dx.op.loadInput.f32(i32 4, i32 1, i32 0, i8 3, i32 undef)  ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)
+  %7 = insertelement <4 x float> %5, float %6, i64 3
+  %8 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0, i32 undef)  ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)
+  %9 = insertelement <4 x float> undef, float %8, i64 0
+  %10 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1, i32 undef)  ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)
+  %11 = insertelement <4 x float> %9, float %10, i64 1
+  %12 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 2, i32 undef)  ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)
+  %13 = insertelement <4 x float> %11, float %12, i64 2
+  %14 = call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 3, i32 undef)  ; LoadInput(inputSigId,rowIndex,colIndex,gsVertexAxis)
+  %15 = insertelement <4 x float> %13, float %14, i64 3
+  %16 = alloca <4 x float>, align 4
+  store <4 x float> %7, <4 x float>* %16, align 4
+  call fastcc void @"\01?test_inout@@YAXAAV?$vector@M$03@@V1@@Z"(<4 x float>* nonnull dereferenceable(16) %16, <4 x float> %15)
+  %17 = load <4 x float>, <4 x float>* %16, align 4
+  %18 = extractelement <4 x float> %17, i64 0
+  call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float %18)  ; StoreOutput(outputtSigId,rowIndex,colIndex,value)
+  %19 = extractelement <4 x float> %17, i64 1
+  call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float %19)  ; StoreOutput(outputtSigId,rowIndex,colIndex,value)
+  %20 = extractelement <4 x float> %17, i64 2
+  call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float %20)  ; StoreOutput(outputtSigId,rowIndex,colIndex,value)
+  %21 = extractelement <4 x float> %17, i64 3
+  call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 3, float %21)  ; StoreOutput(outputtSigId,rowIndex,colIndex,value)
+  ret void
+}
+
+; Function Attrs: nounwind readnone
+declare float @dx.op.loadInput.f32(i32, i32, i32, i8, i32) #1
+
+; Function Attrs: nounwind
+declare void @dx.op.storeOutput.f32(i32, i32, i32, i8, float) #2
+
+; Function Attrs: nounwind readnone
+declare float @dx.op.unary.f32(i32, float) #1
+
+attributes #0 = { alwaysinline nounwind "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-realign-stack" "stack-protector-buffer-size"="0" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #1 = { nounwind readnone }
+attributes #2 = { nounwind }
+
+!llvm.ident = !{!0}
+!dx.version = !{!1}
+!dx.shaderModel = !{!2}
+!dx.typeAnnotations = !{!3}
+!dx.entryPoints = !{!12}
+
+!0 = !{!"clang version 3.7 (tags/RELEASE_370/final)"}
+!1 = !{i32 1, i32 0}
+!2 = !{!"ps", i32 6, i32 0}
+!3 = !{i32 1, void (<4 x float>*, <4 x float>)* @"\01?test_inout@@YAXAAV?$vector@M$03@@V1@@Z", !4, void ()* @main, !10}
+!4 = !{!5, !7, !9}
+!5 = !{i32 1, !6, !6}
+!6 = !{}
+!7 = !{i32 2, !8, !6}
+!8 = !{i32 7, i32 9}
+!9 = !{i32 0, !8, !6}
+!10 = !{!11}
+!11 = !{i32 0, !6, !6}
+!12 = !{void ()* @main, !"main", !13, null, null}
+!13 = !{!14, !18, null}
+!14 = !{!15, !17}
+!15 = !{i32 0, !"A", i8 9, i8 0, !16, i8 2, i32 1, i8 4, i32 0, i8 0, null}
+!16 = !{i32 0}
+!17 = !{i32 1, !"B", i8 9, i8 0, !16, i8 2, i32 1, i8 4, i32 1, i8 0, null}
+!18 = !{!19}
+!19 = !{i32 0, !"SV_Target", i8 9, i8 16, !16, i8 0, i32 1, i8 4, i32 0, i8 0, null}

+ 1 - 1
tools/clang/test/CodeGenHLSL/recursive2.hlsl

@@ -1,6 +1,6 @@
 // RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
 
-// CHECK: Recursion is not permitted
+// CHECK: error: recursive functions not allowed
 
 struct M {
   float m;

+ 1 - 1
tools/clang/test/CodeGenHLSL/recursive3.hlsl

@@ -1,6 +1,6 @@
 // RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
 
-// CHECK: Recursion is not permitted
+// CHECK: error: recursive functions not allowed
 
 float test_ret()
 {

+ 1 - 0
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -2056,6 +2056,7 @@ public:
       compiler.WriteDefaultOutputDirectly = true;
       compiler.setOutStream(&outStream);
 
+      compiler.getLangOpts().HLSLEntryFunction =
       compiler.getCodeGenOpts().HLSLEntryFunction = pUtf8EntryPoint.m_psz;
       compiler.getCodeGenOpts().HLSLProfile = pUtf8TargetProfile.m_psz;
 

+ 77 - 14
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -294,6 +294,8 @@ public:
   TEST_METHOD(CompileWhenShaderModelMismatchAttributeThenFail)
   TEST_METHOD(CompileBadHlslThenFail)
   TEST_METHOD(CompileLegacyShaderModelThenFail)
+  TEST_METHOD(CompileWhenRecursiveAlbeitStaticTermThenFail)
+  TEST_METHOD(CompileWhenRecursiveThenFail)
 
   TEST_METHOD(CodeGenAbs1)
   TEST_METHOD(CodeGenAbs2)
@@ -1122,6 +1124,32 @@ public:
     }
   }
 
+  std::string VerifyCompileFailed(LPCSTR pText, LPWSTR pTargetProfile, LPCSTR pErrorMsg) {
+    return VerifyCompileFailed(pText, pTargetProfile, pErrorMsg, L"main");
+  }
+
+  std::string VerifyCompileFailed(LPCSTR pText, LPWSTR pTargetProfile, LPCSTR pErrorMsg, LPCWSTR pEntryPoint) {
+    CComPtr<IDxcCompiler> pCompiler;
+    CComPtr<IDxcOperationResult> pResult;
+    CComPtr<IDxcBlobEncoding> pSource;
+    CComPtr<IDxcBlobEncoding> pErrors;
+
+    VERIFY_SUCCEEDED(CreateCompiler(&pCompiler));
+    CreateBlobFromText(pText, &pSource);
+
+    VERIFY_SUCCEEDED(pCompiler->Compile(pSource, L"source.hlsl", pEntryPoint,
+      pTargetProfile, nullptr, 0, nullptr, 0, nullptr, &pResult));
+
+    HRESULT status;
+    VERIFY_SUCCEEDED(pResult->GetStatus(&status));
+    VERIFY_FAILED(status);
+    VERIFY_SUCCEEDED(pResult->GetErrorBuffer(&pErrors));
+    if (pErrorMsg && *pErrorMsg) {
+      CheckOperationResultMsgs(pResult, &pErrorMsg, 1, false, false);
+    }
+    return BlobToUtf8(pErrors);
+  }
+
   void VerifyOperationSucceeded(IDxcOperationResult *pResult) {
     HRESULT result;
     VERIFY_SUCCEEDED(pResult->GetStatus(&result));
@@ -1889,20 +1917,55 @@ TEST_F(CompilerTest, CompileBadHlslThenFail) {
 }
 
 TEST_F(CompilerTest, CompileLegacyShaderModelThenFail) {
-  CComPtr<IDxcCompiler> pCompiler;
-  CComPtr<IDxcOperationResult> pResult;
-  CComPtr<IDxcBlobEncoding> pSource;
-
-  VERIFY_SUCCEEDED(CreateCompiler(&pCompiler));
-  CreateBlobFromText(
-    "float4 main(float4 pos : SV_Position) : SV_Target { return pos; }", &pSource);
-
-  VERIFY_SUCCEEDED(pCompiler->Compile(pSource, L"source.hlsl", L"main",
-    L"ps_5_1", nullptr, 0, nullptr, 0, nullptr, &pResult));
-
-  HRESULT status;
-  VERIFY_SUCCEEDED(pResult->GetStatus(&status));
-  VERIFY_FAILED(status);
+  VerifyCompileFailed(
+    "float4 main(float4 pos : SV_Position) : SV_Target { return pos; }", L"ps_5_1", nullptr);
+}
+
+TEST_F(CompilerTest, CompileWhenRecursiveAlbeitStaticTermThenFail) {
+  // This shader will compile under fxc because if execution is
+  // simulated statically, it does terminate. dxc changes this behavior
+  // to avoid imposing the requirement on the compiler.
+  const char ShaderText[] =
+    "static int i = 10;\r\n"
+    "float4 f(); // Forward declaration\r\n"
+    "float4 g() { if (i > 10) { i--; return f(); } else return 0; } // Recursive call to 'f'\r\n"
+    "float4 f() { return g(); } // First call to 'g'\r\n"
+    "float4 VS() : SV_Position{\r\n"
+    "  return f(); // First call to 'f'\r\n"
+    "}\r\n";
+  VerifyCompileFailed(ShaderText, L"vs_6_0", "recursive functions not allowed", L"VS");
+}
+
+TEST_F(CompilerTest, CompileWhenRecursiveThenFail) {
+  const char ShaderTextSimple[] =
+    "float4 f(); // Forward declaration\r\n"
+    "float4 g() { return f(); } // Recursive call to 'f'\r\n"
+    "float4 f() { return g(); } // First call to 'g'\r\n"
+    "float4 main() : SV_Position{\r\n"
+    "  return f(); // First call to 'f'\r\n"
+    "}\r\n";
+  VerifyCompileFailed(ShaderTextSimple, L"vs_6_0", "recursive functions not allowed");
+
+  const char ShaderTextIndirect[] =
+    "float4 f(); // Forward declaration\r\n"
+    "float4 g() { return f(); } // Recursive call to 'f'\r\n"
+    "float4 f() { return g(); } // First call to 'g'\r\n"
+    "float4 main() : SV_Position{\r\n"
+    "  return f(); // First call to 'f'\r\n"
+    "}\r\n";
+  VerifyCompileFailed(ShaderTextIndirect, L"vs_6_0", "recursive functions not allowed");
+
+  const char ShaderTextSelf[] =
+    "float4 main() : SV_Position{\r\n"
+    "  return main();\r\n"
+    "}\r\n";
+  VerifyCompileFailed(ShaderTextSelf, L"vs_6_0", "recursive functions not allowed");
+
+  const char ShaderTextMissing[] =
+    "float4 mainz() : SV_Position{\r\n"
+    "  return 1;\r\n"
+    "}\r\n";
+  VerifyCompileFailed(ShaderTextMissing, L"vs_6_0", "missing entry point definition");
 }
 
 

+ 3 - 0
tools/clang/unittests/HLSL/DxcTestUtils.h

@@ -100,6 +100,9 @@ inline std::string BlobToUtf8(_In_ IDxcBlob *pBlob) {
 
 std::wstring BlobToUtf16(_In_ IDxcBlob *pBlob);
 void CheckOperationSucceeded(IDxcOperationResult *pResult, IDxcBlob **ppBlob);
+bool CheckOperationResultMsgs(IDxcOperationResult *pResult,
+                              LPCSTR *pErrorMsgs, size_t errorMsgCount,
+                              bool maySucceedAnyway, bool bRegex);
 std::string DisassembleProgram(dxc::DxcDllSupport &dllSupport, IDxcBlob *pProgram);
 void Utf8ToBlob(dxc::DxcDllSupport &dllSupport, const std::string &val, _Outptr_ IDxcBlob **ppBlob);
 void Utf8ToBlob(dxc::DxcDllSupport &dllSupport, const std::string &val, _Outptr_ IDxcBlobEncoding **ppBlob);

+ 54 - 47
tools/clang/unittests/HLSL/ValidationTest.cpp

@@ -35,6 +35,57 @@ void CheckOperationSucceeded(IDxcOperationResult *pResult, IDxcBlob **ppBlob) {
   VERIFY_SUCCEEDED(pResult->GetResult(ppBlob));
 }
 
+static
+bool CheckOperationResultMsgs(IDxcOperationResult *pResult,
+                              llvm::ArrayRef<LPCSTR> pErrorMsgs,
+                              bool maySucceedAnyway, bool bRegex) {
+  HRESULT status;
+  CComPtr<IDxcBlobEncoding> text;
+  VERIFY_SUCCEEDED(pResult->GetStatus(&status));
+  VERIFY_SUCCEEDED(pResult->GetErrorBuffer(&text));
+  const char *pStart = text ? (const char *)text->GetBufferPointer() : nullptr;
+  const char *pEnd = text ? pStart + text->GetBufferSize() : nullptr;
+  if (pErrorMsgs.empty() || (pErrorMsgs.size() == 1 && !pErrorMsgs[0])) {
+    if (FAILED(status) && pStart) {
+      WEX::Logging::Log::Comment(WEX::Common::String().Format(
+          L"Expected success but found errors\r\n%.*S", (pEnd - pStart),
+          pStart));
+    }
+    VERIFY_SUCCEEDED(status);
+  }
+  else {
+    if (SUCCEEDED(status) && maySucceedAnyway) {
+      return false;
+    }
+    for (auto pErrorMsg : pErrorMsgs) {
+      if (bRegex) {
+        llvm::Regex RE(pErrorMsg);
+        std::string reErrors;
+        VERIFY_IS_TRUE(RE.isValid(reErrors));
+        VERIFY_IS_TRUE(RE.match(llvm::StringRef((const char *)text->GetBufferPointer(), text->GetBufferSize())));
+      }
+      else {
+        const char *pMatch = std::search(pStart, pEnd, pErrorMsg, pErrorMsg + strlen(pErrorMsg));
+        if (pEnd == pMatch) {
+          WEX::Logging::Log::Comment(WEX::Common::String().Format(
+            L"Unable to find '%S' in text:\r\n%.*S", pErrorMsg, (pEnd - pStart),
+            pStart));
+        }
+        VERIFY_ARE_NOT_EQUAL(pEnd, pMatch);
+      }
+    }
+  }
+  return true;
+}
+
+bool CheckOperationResultMsgs(IDxcOperationResult *pResult, LPCSTR *pErrorMsgs,
+                              size_t errorMsgCount, bool maySucceedAnyway,
+                              bool bRegex) {
+  return CheckOperationResultMsgs(
+      pResult, llvm::ArrayRef<LPCSTR>(pErrorMsgs, errorMsgCount),
+      maySucceedAnyway, bRegex);
+}
+
 std::string DisassembleProgram(dxc::DxcDllSupport &dllSupport,
                                IDxcBlob *pProgram) {
   CComPtr<IDxcCompiler> pCompiler;
@@ -82,7 +133,6 @@ public:
   TEST_METHOD(Recursive)
   TEST_METHOD(Recursive2)
   TEST_METHOD(Recursive3)
-  TEST_METHOD(UserDefineFunction)
   TEST_METHOD(ResourceRangeOverlap0)
   TEST_METHOD(ResourceRangeOverlap1)
   TEST_METHOD(ResourceRangeOverlap2)
@@ -249,44 +299,6 @@ public:
     }
   }
 
-  bool CheckOperationResultMsgs(IDxcOperationResult *pResult,
-                               llvm::ArrayRef<LPCSTR> pErrorMsgs, bool maySucceedAnyway,
-                               bool bRegex) {
-    HRESULT status;
-    VERIFY_SUCCEEDED(pResult->GetStatus(&status));
-    if (pErrorMsgs.empty() || 
-        (pErrorMsgs.size() == 1 && !pErrorMsgs[0])) {
-      VERIFY_SUCCEEDED(status);
-    }
-    else {
-      if (SUCCEEDED(status) && maySucceedAnyway) {
-        return false;
-      }
-      //VERIFY_FAILED(status);
-      CComPtr<IDxcBlobEncoding> text;
-      VERIFY_SUCCEEDED(pResult->GetErrorBuffer(&text));
-      for (auto pErrorMsg : pErrorMsgs) {
-        if (bRegex) {
-          llvm::Regex RE(pErrorMsg);
-          std::string reErrors;
-          VERIFY_IS_TRUE(RE.isValid(reErrors));
-          VERIFY_IS_TRUE(RE.match(llvm::StringRef((const char *)text->GetBufferPointer(), text->GetBufferSize())));
-        } else {
-          const char *pStart = (const char *)text->GetBufferPointer();
-          const char *pEnd = pStart + text->GetBufferSize();
-          const char *pMatch = std::search(pStart, pEnd, pErrorMsg, pErrorMsg + strlen(pErrorMsg));
-          if (pEnd == pMatch) {
-            WEX::Logging::Log::Comment(WEX::Common::String().Format(
-                L"Unable to find '%S' in text:\r\n%.*S", pErrorMsg, (pEnd - pStart),
-                pStart));
-          }
-          VERIFY_ARE_NOT_EQUAL(pEnd, pMatch);
-        }
-      }
-    }
-    return true;
-  }
-
   void CheckValidationMsgs(IDxcBlob *pBlob, llvm::ArrayRef<LPCSTR> pErrorMsgs, bool bRegex = false) {
     CComPtr<IDxcValidator> pValidator;
     CComPtr<IDxcOperationResult> pResult;
@@ -328,9 +340,7 @@ public:
     VERIFY_SUCCEEDED(pCompiler->Compile(pSource, L"hlsl.hlsl", L"main",
                                         shWide, nullptr, 0, nullptr, 0, nullptr,
                                         &pResult));
-    HRESULT hr;
-    VERIFY_SUCCEEDED(pResult->GetStatus(&hr));
-    VERIFY_SUCCEEDED(hr);
+    CheckOperationResultMsgs(pResult, nullptr, false, false);
     VERIFY_SUCCEEDED(pResult->GetResult(pResultBlob));
   }
 
@@ -1037,7 +1047,8 @@ TEST_F(ValidationTest, TypedUAVStoreFullMask1) {
 }
 
 TEST_F(ValidationTest, Recursive) {
-    TestCheck(L"..\\CodeGenHLSL\\recursive.hlsl");
+  // Includes coverage for user-defined functions.
+  TestCheck(L"..\\CodeGenHLSL\\recursive.ll");
 }
 
 TEST_F(ValidationTest, Recursive2) {
@@ -1048,10 +1059,6 @@ TEST_F(ValidationTest, Recursive3) {
     TestCheck(L"..\\CodeGenHLSL\\recursive3.hlsl");
 }
 
-TEST_F(ValidationTest, UserDefineFunction) {
-    TestCheck(L"..\\CodeGenHLSL\\recursive2.hlsl");
-}
-
 TEST_F(ValidationTest, ResourceRangeOverlap0) {
     RewriteAssemblyCheckMsg(
       L"..\\CodeGenHLSL\\resource_overlap.hlsl", "ps_6_0",

+ 2 - 2
utils/hct/hctbuild.cmd

@@ -165,8 +165,8 @@ echo current BUILD_ARCH=%BUILD_ARCH%.  Override with:
 echo   -x86 targets an x86 build (aka. Win32)
 echo   -x64 targets an x64 build (aka. Win64)
 echo   -arm targets an ARM build
-echo
-echo   AppVeyor Support
+echo.
+echo AppVeyor Support
 echo   -Release builds release
 echo   -Debug builds debug
 echo   -vs2017 uses Visual Studio 2017 to build