Browse Source

Fix crashing from null elements producing errors (#3487)

Under certain rare circumstances, it is possible for errors to be
reported for function or global variables that are null.

This eliminates the absolute necessity for a non-null function or GV by
passing in the context needed to report any error. It also avoids some
of the cases where trying to get a debug function would give a null
which would replace the non-null function.
Greg Roth 4 years ago
parent
commit
72e46b3709

+ 4 - 4
include/dxc/DXIL/DxilUtil.h

@@ -82,10 +82,10 @@ namespace dxilutil {
 
   void EmitErrorOnInstruction(llvm::Instruction *I, llvm::Twine Msg);
   void EmitWarningOnInstruction(llvm::Instruction *I, llvm::Twine Msg);
-  void EmitErrorOnFunction(llvm::Function *F, llvm::Twine Msg);
-  void EmitWarningOnFunction(llvm::Function *F, llvm::Twine Msg);
-  void EmitErrorOnGlobalVariable(llvm::GlobalVariable *GV, llvm::Twine Msg);
-  void EmitWarningOnGlobalVariable(llvm::GlobalVariable *GV, llvm::Twine Msg);
+  void EmitErrorOnFunction(llvm::LLVMContext &Ctx, llvm::Function *F, llvm::Twine Msg);
+  void EmitWarningOnFunction(llvm::LLVMContext &Ctx, llvm::Function *F, llvm::Twine Msg);
+  void EmitErrorOnGlobalVariable(llvm::LLVMContext &Ctx, llvm::GlobalVariable *GV, llvm::Twine Msg);
+  void EmitWarningOnGlobalVariable(llvm::LLVMContext &Ctx, llvm::GlobalVariable *GV, llvm::Twine Msg);
   void EmitErrorOnContext(llvm::LLVMContext &Ctx, llvm::Twine Msg);
   void EmitWarningOnContext(llvm::LLVMContext &Ctx, llvm::Twine Msg);
   void EmitNoteOnContext(llvm::LLVMContext &Ctx, llvm::Twine Msg);

+ 29 - 28
lib/DXIL/DxilUtil.cpp

@@ -287,56 +287,57 @@ void EmitWarningOnInstruction(Instruction *I, Twine Msg) {
   EmitWarningOrErrorOnInstruction(I, Msg, DiagnosticSeverity::DS_Warning);
 }
 
-static void EmitWarningOrErrorOnFunction(Function *F, Twine Msg,
+static void EmitWarningOrErrorOnFunction(llvm::LLVMContext &Ctx, Function *F, Twine Msg,
                                          DiagnosticSeverity severity) {
-  DISubprogram *DISP = getDISubprogram(F);
   DILocation *DLoc = nullptr;
-  if (DISP) {
+
+  if (DISubprogram *DISP = getDISubprogram(F)) {
     DLoc = DILocation::get(F->getContext(), DISP->getLine(), 0,
                            DISP, nullptr /*InlinedAt*/);
   }
-  F->getContext().diagnose(DiagnosticInfoDxil(F, DLoc, Msg, severity));
+  Ctx.diagnose(DiagnosticInfoDxil(F, DLoc, Msg, severity));
 }
 
-void EmitErrorOnFunction(Function *F, Twine Msg) {
-  EmitWarningOrErrorOnFunction(F, Msg, DiagnosticSeverity::DS_Error);
+void EmitErrorOnFunction(llvm::LLVMContext &Ctx, Function *F, Twine Msg) {
+  EmitWarningOrErrorOnFunction(Ctx, F, Msg, DiagnosticSeverity::DS_Error);
 }
 
-void EmitWarningOnFunction(Function *F, Twine Msg) {
-  EmitWarningOrErrorOnFunction(F, Msg, DiagnosticSeverity::DS_Warning);
+void EmitWarningOnFunction(llvm::LLVMContext &Ctx, Function *F, Twine Msg) {
+  EmitWarningOrErrorOnFunction(Ctx, F, Msg, DiagnosticSeverity::DS_Warning);
 }
 
-static void EmitWarningOrErrorOnGlobalVariable(GlobalVariable *GV,
+static void EmitWarningOrErrorOnGlobalVariable(llvm::LLVMContext &Ctx, GlobalVariable *GV,
                                                Twine Msg, DiagnosticSeverity severity) {
   DIVariable *DIV = nullptr;
-  if (!GV) return;
 
-  Module &M = *GV->getParent();
   DILocation *DLoc = nullptr;
 
-  if (getDebugMetadataVersionFromModule(M) != 0) {
-    DebugInfoFinder FinderObj;
-    DebugInfoFinder &Finder = FinderObj;
-    // Debug modules have no dxil modules. Use it if you got it.
-    if (M.HasDxilModule())
-      Finder = M.GetDxilModule().GetOrCreateDebugInfoFinder();
-    else
-      Finder.processModule(M);
-    DIV = FindGlobalVariableDebugInfo(GV, Finder);
-    if (DIV)
-      DLoc = DILocation::get(GV->getContext(), DIV->getLine(), 0,
-                             DIV->getScope(), nullptr /*InlinedAt*/);
+  if (GV) {
+    Module &M = *GV->getParent();
+    if (getDebugMetadataVersionFromModule(M) != 0) {
+      DebugInfoFinder FinderObj;
+      DebugInfoFinder &Finder = FinderObj;
+      // Debug modules have no dxil modules. Use it if you got it.
+      if (M.HasDxilModule())
+        Finder = M.GetDxilModule().GetOrCreateDebugInfoFinder();
+      else
+        Finder.processModule(M);
+      DIV = FindGlobalVariableDebugInfo(GV, Finder);
+      if (DIV)
+        DLoc = DILocation::get(GV->getContext(), DIV->getLine(), 0,
+                               DIV->getScope(), nullptr /*InlinedAt*/);
+    }
   }
 
-  GV->getContext().diagnose(DiagnosticInfoDxil(nullptr /*Function*/, DLoc, Msg, severity));
+  Ctx.diagnose(DiagnosticInfoDxil(nullptr /*Function*/, DLoc, Msg, severity));
 }
 
-void EmitErrorOnGlobalVariable(GlobalVariable *GV, Twine Msg) {
-  EmitWarningOrErrorOnGlobalVariable(GV, Msg, DiagnosticSeverity::DS_Error);
+void EmitErrorOnGlobalVariable(llvm::LLVMContext &Ctx, GlobalVariable *GV, Twine Msg) {
+  EmitWarningOrErrorOnGlobalVariable(Ctx, GV, Msg, DiagnosticSeverity::DS_Error);
 }
 
-void EmitWarningOnGlobalVariable(GlobalVariable *GV, Twine Msg) {
-  EmitWarningOrErrorOnGlobalVariable(GV, Msg, DiagnosticSeverity::DS_Warning);
+void EmitWarningOnGlobalVariable(llvm::LLVMContext &Ctx, GlobalVariable *GV, Twine Msg) {
+  EmitWarningOrErrorOnGlobalVariable(Ctx, GV, Msg, DiagnosticSeverity::DS_Warning);
 }
 
 const char *kResourceMapErrorMsg =

+ 9 - 9
lib/HLSL/DxilCondenseResources.cpp

@@ -117,7 +117,7 @@ private:
 
   template <typename T>
   static bool
-  AllocateRegisters(const std::vector<std::unique_ptr<T>> &resourceList,
+  AllocateRegisters(LLVMContext &Ctx, const std::vector<std::unique_ptr<T>> &resourceList,
     SpacesAllocator<unsigned, T> &ReservedRegisters,
     unsigned AutoBindingSpace) {
     bool bChanged = false;
@@ -135,7 +135,7 @@ private:
         if (res->IsUnbounded()) {
           const T *unbounded = alloc.GetUnbounded();
           if (unbounded) {
-            dxilutil::EmitErrorOnGlobalVariable(dyn_cast<GlobalVariable>(res->GetGlobalSymbol()),
+            dxilutil::EmitErrorOnGlobalVariable(Ctx, dyn_cast<GlobalVariable>(res->GetGlobalSymbol()),
                                                 Twine("more than one unbounded resource (") +
                                                 unbounded->GetGlobalName() + (" and ") +
                                                 res->GetGlobalName() + (") in space ") + Twine(space));
@@ -152,7 +152,7 @@ private:
           conflict = alloc.Insert(res.get(), reg, res->GetUpperBound());
         }
         if (conflict) {
-          dxilutil::EmitErrorOnGlobalVariable(dyn_cast<GlobalVariable>(res->GetGlobalSymbol()), 
+          dxilutil::EmitErrorOnGlobalVariable(Ctx, dyn_cast<GlobalVariable>(res->GetGlobalSymbol()), 
                                               ((res->IsUnbounded()) ? Twine("unbounded ") : Twine("")) +
                                               Twine("resource ") + res->GetGlobalName() +
                                               Twine(" at register ") + Twine(reg) +
@@ -184,7 +184,7 @@ private:
       if (res->IsUnbounded()) {
         if (alloc.GetUnbounded() != nullptr) {
           const T *unbounded = alloc.GetUnbounded();
-          dxilutil::EmitErrorOnGlobalVariable(dyn_cast<GlobalVariable>(res->GetGlobalSymbol()),
+          dxilutil::EmitErrorOnGlobalVariable(Ctx, dyn_cast<GlobalVariable>(res->GetGlobalSymbol()),
                                               Twine("more than one unbounded resource (") +
                                               unbounded->GetGlobalName() + Twine(" and ") +
                                               res->GetGlobalName() + Twine(") in space ") +
@@ -218,7 +218,7 @@ private:
         res->SetSpaceID(space);
         bChanged = true;
       } else {
-        dxilutil::EmitErrorOnGlobalVariable(dyn_cast<GlobalVariable>(res->GetGlobalSymbol()),
+        dxilutil::EmitErrorOnGlobalVariable(Ctx, dyn_cast<GlobalVariable>(res->GetGlobalSymbol()),
                                             ((res->IsUnbounded()) ? Twine("unbounded ") : Twine("")) +
                                             Twine("resource ") + res->GetGlobalName() +
                                             Twine(" could not be allocated"));
@@ -251,10 +251,10 @@ public:
     }
 
     bool bChanged = false;
-    bChanged |= AllocateRegisters(DM.GetCBuffers(), m_reservedCBufferRegisters, AutoBindingSpace);
-    bChanged |= AllocateRegisters(DM.GetSamplers(), m_reservedSamplerRegisters, AutoBindingSpace);
-    bChanged |= AllocateRegisters(DM.GetUAVs(), m_reservedUAVRegisters, AutoBindingSpace);
-    bChanged |= AllocateRegisters(DM.GetSRVs(), m_reservedSRVRegisters, AutoBindingSpace);
+    bChanged |= AllocateRegisters(DM.GetCtx(), DM.GetCBuffers(), m_reservedCBufferRegisters, AutoBindingSpace);
+    bChanged |= AllocateRegisters(DM.GetCtx(), DM.GetSamplers(), m_reservedSamplerRegisters, AutoBindingSpace);
+    bChanged |= AllocateRegisters(DM.GetCtx(), DM.GetUAVs(), m_reservedUAVRegisters, AutoBindingSpace);
+    bChanged |= AllocateRegisters(DM.GetCtx(), DM.GetSRVs(), m_reservedSRVRegisters, AutoBindingSpace);
     return bChanged;
   }
 };

+ 2 - 2
lib/HLSL/DxilGenerationPass.cpp

@@ -209,7 +209,7 @@ public:
     if (!SM->IsLib()) {
       Function *EntryFn = m_pHLModule->GetEntryFunction();
       if (!m_pHLModule->HasDxilFunctionProps(EntryFn)) {
-        dxilutil::EmitErrorOnFunction(EntryFn, "Entry function don't have property.");
+        dxilutil::EmitErrorOnFunction(M.getContext(), EntryFn, "Entry function don't have property.");
         return false;
       }
       DxilFunctionProps &props = m_pHLModule->GetDxilFunctionProps(EntryFn);
@@ -261,7 +261,7 @@ public:
           if (F.user_empty()) {
             F.eraseFromParent();
           } else {
-            dxilutil::EmitErrorOnFunction(&F, "Fail to lower createHandle.");
+            dxilutil::EmitErrorOnFunction(M.getContext(), &F, "Fail to lower createHandle.");
           }
         }
       }

+ 5 - 5
lib/HLSL/DxilLinker.cpp

@@ -487,7 +487,7 @@ bool DxilLinkJob::AddResource(DxilResourceBase *res, llvm::GlobalVariable *GV) {
     bool bMatch = IsMatchedType(Ty0, Ty);
     if (!bMatch) {
       // Report error.
-      dxilutil::EmitErrorOnGlobalVariable(dyn_cast<GlobalVariable>(res->GetGlobalSymbol()),
+      dxilutil::EmitErrorOnGlobalVariable(m_ctx, dyn_cast<GlobalVariable>(res->GetGlobalSymbol()),
                                           Twine(kRedefineResource) + res->GetResClassName() + " for " +
                                           res->GetGlobalName());
       return false;
@@ -636,7 +636,7 @@ bool DxilLinkJob::AddGlobals(DxilModule &DM, ValueToValueMapTy &vmap) {
           }
 
           // Redefine of global.
-          dxilutil::EmitErrorOnGlobalVariable(GV, Twine(kRedefineGlobal) + GV->getName());
+          dxilutil::EmitErrorOnGlobalVariable(m_ctx, GV, Twine(kRedefineGlobal) + GV->getName());
           bSuccess = false;
         }
         continue;
@@ -724,7 +724,7 @@ DxilLinkJob::Link(std::pair<DxilFunctionLinkInfo *, DxilLib *> &entryLinkPair,
   DxilModule &entryDM = entryLinkPair.second->GetDxilModule();
   if (!entryDM.HasDxilFunctionProps(entryFunc)) {
     // Cannot get function props.
-    dxilutil::EmitErrorOnFunction(entryFunc, Twine(kNoEntryProps) + entryFunc->getName());
+    dxilutil::EmitErrorOnFunction(m_ctx, entryFunc, Twine(kNoEntryProps) + entryFunc->getName());
     return nullptr;
   }
 
@@ -732,7 +732,7 @@ DxilLinkJob::Link(std::pair<DxilFunctionLinkInfo *, DxilLib *> &entryLinkPair,
 
   if (pSM->GetKind() != props.shaderKind) {
     // Shader kind mismatch.
-    dxilutil::EmitErrorOnFunction(entryFunc, Twine(kShaderKindMismatch) +
+    dxilutil::EmitErrorOnFunction(m_ctx, entryFunc, Twine(kShaderKindMismatch) +
                                   ShaderModel::GetKindName(pSM->GetKind()) + " and " +
                                   ShaderModel::GetKindName(props.shaderKind));
     return nullptr;
@@ -1331,7 +1331,7 @@ bool DxilLinkerImpl::AttachLib(DxilLib *lib) {
     if (m_functionNameMap.count(name)) {
       // Redefine of function.
       const DxilFunctionLinkInfo *DFLI = it->getValue().get();
-      dxilutil::EmitErrorOnFunction(DFLI->func, Twine(kRedefineFunction) + name);
+      dxilutil::EmitErrorOnFunction(m_ctx, DFLI->func, Twine(kRedefineFunction) + name);
       bSuccess = false;
       continue;
     }

+ 7 - 5
lib/HLSL/DxilValidation.cpp

@@ -636,7 +636,7 @@ struct ValidationContext {
     FormatRuleText(ruleText, args);
     if (pDebugModule)
       GV = pDebugModule->getGlobalVariable(GV->getName());
-    dxilutil::EmitErrorOnGlobalVariable(GV, ruleText);
+    dxilutil::EmitErrorOnGlobalVariable(M.getContext(), GV, ruleText);
     Failed = true;
   }
 
@@ -805,8 +805,9 @@ struct ValidationContext {
 
   void EmitFnError(Function *F, ValidationRule rule) {
     if (pDebugModule)
-      F = pDebugModule->getFunction(F->getName());
-    dxilutil::EmitErrorOnFunction(F, GetValidationRuleText(rule));
+      if (Function *dbgF = pDebugModule->getFunction(F->getName()))
+        F = dbgF;
+    dxilutil::EmitErrorOnFunction(M.getContext(), F, GetValidationRuleText(rule));
     Failed = true;
   }
 
@@ -814,8 +815,9 @@ struct ValidationContext {
     std::string ruleText = GetValidationRuleText(rule);
     FormatRuleText(ruleText, args);
     if (pDebugModule)
-      F = pDebugModule->getFunction(F->getName());
-    dxilutil::EmitErrorOnFunction(F, ruleText);
+      if (Function *dbgF = pDebugModule->getFunction(F->getName()))
+        F = dbgF;
+    dxilutil::EmitErrorOnFunction(M.getContext(), F, ruleText);
     Failed = true;
   }
 

+ 14 - 14
lib/HLSL/HLSignatureLower.cpp

@@ -248,7 +248,7 @@ void HLSignatureLower::ProcessArgument(Function *func,
   if (sigPoint->GetKind() == DXIL::SigPointKind::MSPOut) {
     if (interpMode != InterpolationMode::Kind::Undefined &&
         interpMode != InterpolationMode::Kind::Constant) {
-      dxilutil::EmitErrorOnFunction(func,
+      dxilutil::EmitErrorOnFunction(HLM.GetModule()->getContext(), func,
         "Mesh shader's primitive outputs' interpolation mode must be constant or undefined.");
     }
     interpMode = InterpolationMode::Kind::Constant;
@@ -270,7 +270,7 @@ void HLSignatureLower::ProcessArgument(Function *func,
 
   llvm::StringRef semanticStr = paramAnnotation.GetSemanticString();
   if (semanticStr.empty()) {
-    dxilutil::EmitErrorOnFunction(func,
+    dxilutil::EmitErrorOnFunction(HLM.GetModule()->getContext(), func,
         "Semantic must be defined for all parameters of an entry function or "
         "patch constant function");
     return;
@@ -302,7 +302,7 @@ void HLSignatureLower::ProcessArgument(Function *func,
       auto &SemanticIndexSet = SemanticUseMap[(unsigned)pSemantic->GetKind()];
       for (unsigned idx : paramAnnotation.GetSemanticIndexVec()) {
         if (SemanticIndexSet.count(idx) > 0) {
-          dxilutil::EmitErrorOnFunction(func, "Parameter with semantic " + semanticStr +
+          dxilutil::EmitErrorOnFunction(HLM.GetModule()->getContext(), func, "Parameter with semantic " + semanticStr +
             " has overlapping semantic index at " + std::to_string(idx) + ".");
           return;
         }
@@ -319,7 +319,7 @@ void HLSignatureLower::ProcessArgument(Function *func,
                0) ||
           (pSemantic->GetKind() == DXIL::SemanticKind::InnerCoverage &&
            SemanticUseMap.count((unsigned)DXIL::SemanticKind::Coverage) > 0)) {
-        dxilutil::EmitErrorOnFunction(func,
+        dxilutil::EmitErrorOnFunction(HLM.GetModule()->getContext(), func,
             "Pixel shader inputs SV_Coverage and SV_InnerCoverage are mutually "
             "exclusive.");
         return;
@@ -332,7 +332,7 @@ void HLSignatureLower::ProcessArgument(Function *func,
   {
     switch (interpretation) {
     case DXIL::SemanticInterpretationKind::NA: {
-      dxilutil::EmitErrorOnFunction(func, Twine("Semantic ") + semanticStr +
+      dxilutil::EmitErrorOnFunction(HLM.GetModule()->getContext(), func, Twine("Semantic ") + semanticStr +
                                     Twine(" is invalid for shader model: ") +
                                     ShaderModel::GetKindName(props.shaderKind));
 
@@ -393,7 +393,7 @@ void HLSignatureLower::ProcessArgument(Function *func,
       pSE = FindArgInSignature(arg, paramAnnotation.GetSemanticString(),
                                interpMode, sigPoint->GetKind(), *pSig);
       if (!pSE) {
-        dxilutil::EmitErrorOnFunction(func, Twine("Signature element ") + semanticStr +
+        dxilutil::EmitErrorOnFunction(HLM.GetModule()->getContext(), func, Twine("Signature element ") + semanticStr +
                                       Twine(", referred to by patch constant function, is not found in "
                                             "corresponding hull shader ") +
                                       (sigKind == DXIL::SignatureKind::Input ? "input." : "output."));
@@ -457,7 +457,7 @@ void HLSignatureLower::CreateDxilSignatures() {
   }
 
   if (bHasClipPlane) {
-    dxilutil::EmitErrorOnFunction(Entry, "Cannot use clipplanes attribute without "
+    dxilutil::EmitErrorOnFunction(HLM.GetModule()->getContext(), Entry, "Cannot use clipplanes attribute without "
                                   "specifying a 4-component SV_Position "
                                   "output");
   }
@@ -467,7 +467,7 @@ void HLSignatureLower::CreateDxilSignatures() {
   if (props.shaderKind == DXIL::ShaderKind::Hull) {
     Function *patchConstantFunc = props.ShaderProps.HS.patchConstantFunc;
     if (patchConstantFunc == nullptr) {
-      dxilutil::EmitErrorOnFunction(Entry,
+      dxilutil::EmitErrorOnFunction(HLM.GetModule()->getContext(), Entry,
           "Patch constant function is not specified.");
     }
 
@@ -496,14 +496,14 @@ void HLSignatureLower::AllocateDxilInputOutputs() {
 
   hlsl::PackDxilSignature(EntrySig.InputSignature, packing);
   if (!EntrySig.InputSignature.IsFullyAllocated()) {
-    dxilutil::EmitErrorOnFunction(Entry,
+    dxilutil::EmitErrorOnFunction(HLM.GetModule()->getContext(), Entry,
         "Failed to allocate all input signature elements in available space.");
   }
 
   if (props.shaderKind != DXIL::ShaderKind::Amplification) {
     hlsl::PackDxilSignature(EntrySig.OutputSignature, packing);
     if (!EntrySig.OutputSignature.IsFullyAllocated()) {
-      dxilutil::EmitErrorOnFunction(Entry,
+      dxilutil::EmitErrorOnFunction(HLM.GetModule()->getContext(), Entry,
           "Failed to allocate all output signature elements in available space.");
     }
   }
@@ -513,7 +513,7 @@ void HLSignatureLower::AllocateDxilInputOutputs() {
       props.shaderKind == DXIL::ShaderKind::Mesh) {
     hlsl::PackDxilSignature(EntrySig.PatchConstOrPrimSignature, packing);
     if (!EntrySig.PatchConstOrPrimSignature.IsFullyAllocated()) {
-      dxilutil::EmitErrorOnFunction(Entry,
+      dxilutil::EmitErrorOnFunction(HLM.GetModule()->getContext(), Entry,
                              "Failed to allocate all patch constant signature "
                              "elements in available space.");
     }
@@ -1152,7 +1152,7 @@ void HLSignatureLower::GenerateDxilInputsOutputs(DXIL::SignatureKind SK) {
       OSS << "(type for " << SE->GetName() << ")";
       OSS << " cannot be used as shader inputs or outputs.";
       OSS.flush();
-      dxilutil::EmitErrorOnFunction(Entry, O);
+      dxilutil::EmitErrorOnFunction(M.getContext(), Entry, O);
       continue;
     }
     Function *dxilFunc = hlslOP->GetOpFunc(opcode, Ty);
@@ -1223,7 +1223,7 @@ void HLSignatureLower::GenerateDxilCSInputs() {
 
     llvm::StringRef semanticStr = paramAnnotation.GetSemanticString();
     if (semanticStr.empty()) {
-      dxilutil::EmitErrorOnFunction(Entry, "Semantic must be defined for all "
+      dxilutil::EmitErrorOnFunction(HLM.GetModule()->getContext(), Entry, "Semantic must be defined for all "
                                     "parameters of an entry function or patch "
                                     "constant function.");
       return;
@@ -1248,7 +1248,7 @@ void HLSignatureLower::GenerateDxilCSInputs() {
     default:
       DXASSERT(semantic->IsInvalid(),
                "else compute shader semantics out-of-date");
-      dxilutil::EmitErrorOnFunction(Entry, "invalid semantic found in CS");
+      dxilutil::EmitErrorOnFunction(HLM.GetModule()->getContext(), Entry, "invalid semantic found in CS");
       return;
     }