Browse Source

[dxil2spv] Add additional error checking (#4440)

Add some missing error checking for possible nullptrs, and use result of
diagnostic client for setting a succes/failure return value.
Natalie Chouinard 3 years ago
parent
commit
3c1918e21e

+ 7 - 4
tools/clang/tools/dxil2spv/dxil2spvmain.cpp

@@ -50,12 +50,12 @@ int main(int argc, const char **argv_) {
 #endif // _WIN32
   // Configure filesystem for llvm stdout and stderr handling.
   if (llvm::sys::fs::SetupPerThreadFileSystem())
-    return DXC_E_GENERAL_INTERNAL_ERROR;
+    return EXIT_FAILURE;
   llvm::sys::fs::AutoCleanupPerThreadFileSystem auto_cleanup_fs;
   llvm::sys::fs::MSFileSystem *msfPtr;
   HRESULT hr;
   if (!SUCCEEDED(hr = CreateMSFileSystemForDisk(&msfPtr)))
-    return DXC_E_GENERAL_INTERNAL_ERROR;
+    return EXIT_FAILURE;
   std::unique_ptr<llvm::sys::fs::MSFileSystem> msf(msfPtr);
   llvm::sys::fs::AutoPerThreadSystem pts(msf.get());
   llvm::STDStreamCloser stdStreamCloser;
@@ -63,7 +63,7 @@ int main(int argc, const char **argv_) {
   // Check input arguments.
   if (argc < 2) {
     llvm::errs() << "Required input file argument is missing\n";
-    return DXC_E_GENERAL_INTERNAL_ERROR;
+    return EXIT_FAILURE;
   }
 
   // Setup a compiler instance with diagnostics.
@@ -82,5 +82,8 @@ int main(int argc, const char **argv_) {
 
   // Run translator.
   clang::dxil2spv::Translator translator(instance);
-  return translator.Run();
+  translator.Run();
+
+  return instance.getDiagnosticClient().getNumErrors() > 0 ? EXIT_FAILURE
+                                                           : EXIT_SUCCESS;
 }

+ 58 - 29
tools/clang/tools/dxil2spv/lib/dxil2spv.cpp

@@ -45,14 +45,14 @@ Translator::Translator(CompilerInstance &instance)
       featureManager(diagnosticsEngine, spirvOptions),
       spvBuilder(spvContext, spirvOptions, featureManager) {}
 
-int Translator::Run() {
+void Translator::Run() {
   // Read input file to memory buffer.
   std::string filename = ci.getCodeGenOpts().MainFileName;
   auto errorOrInputFile = llvm::MemoryBuffer::getFileOrSTDIN(filename);
   if (!errorOrInputFile) {
     emitError("Error reading %0: %1")
         << filename << errorOrInputFile.getError().message();
-    return DXC_E_GENERAL_INTERNAL_ERROR;
+    return;
   }
   std::unique_ptr<llvm::MemoryBuffer> memoryBuffer =
       std::move(errorOrInputFile.get());
@@ -85,7 +85,7 @@ int Translator::Run() {
 
     if (module == nullptr) {
       emitError("Could not parse DXIL module from bitcode");
-      return DXC_E_GENERAL_INTERNAL_ERROR;
+      return;
     }
   }
   // Parse LLVM module from IR.
@@ -95,7 +95,7 @@ int Translator::Run() {
 
     if (module == nullptr) {
       emitError("Could not parse DXIL module from IR: %0") << err.getMessage();
-      return DXC_E_GENERAL_INTERNAL_ERROR;
+      return;
     }
   }
 
@@ -140,6 +140,12 @@ int Translator::Run() {
                                 {});
   }
 
+  // Don't attempt to emit SPIR-V module if errors were encountered in
+  // translation.
+  if (diagnosticsEngine.getClient()->getNumErrors() > 0) {
+    return;
+  }
+
   // Contsruct the SPIR-V module.
   std::vector<uint32_t> m = spvBuilder.takeModuleForDxilToSpv();
 
@@ -147,7 +153,6 @@ int Translator::Run() {
   std::string messages;
   if (!spirvToolsValidate(&m, &messages)) {
     emitError("Generated SPIR-V is invalid: %0") << messages;
-    // return DXC_E_GENERAL_INTERNAL_ERROR;
   }
 
   // Disassemble SPIR-V for output.
@@ -158,12 +163,10 @@ int Translator::Run() {
 
   if (!spirvTools.Disassemble(m, &assembly, spirvDisOpts)) {
     emitError("SPIR-V disassembly failed");
-    return DXC_E_GENERAL_INTERNAL_ERROR;
+    return;
   }
 
   *ci.getOutStream() << assembly;
-
-  return 0;
 }
 
 void Translator::createStageIOVariables(
@@ -188,6 +191,11 @@ void Translator::createStageIOVariables(
 
 void Translator::createStageIOVariable(hlsl::DxilSignatureElement *elem) {
   const spirv::SpirvType *spirvType = toSpirvType(elem);
+  if (!spirvType) {
+    emitError("Failed to translate DXIL signature element to SPIR-V: %0")
+        << elem->GetName();
+    return;
+  }
   spv::StorageClass storageClass =
       elem->IsInput() ? spv::StorageClass::Input : spv::StorageClass::Output;
   const unsigned id = elem->GetID();
@@ -223,8 +231,14 @@ void Translator::createModuleVariables(
     assert(hlslType->isPointerTy());
     llvm::Type *pointeeType =
         cast<llvm::PointerType>(hlslType)->getPointerElementType();
-    spirv::SpirvVariable *moduleVar = spvBuilder.addModuleVar(
-        toSpirvType(pointeeType), spv::StorageClass::Uniform, false);
+    const spirv::SpirvType *spirvType = toSpirvType(pointeeType);
+    if (!spirvType) {
+      emitError("Failed to translate DXIL resource to SPIR-V: %0")
+          << resource->GetID();
+      return;
+    }
+    spirv::SpirvVariable *moduleVar =
+        spvBuilder.addModuleVar(spirvType, spv::StorageClass::Uniform, false);
     spvBuilder.decorateDSetBinding(moduleVar, nextDescriptorSet,
                                    nextBindingNo++);
     resourceMap[{static_cast<unsigned>(resource->GetClass()),
@@ -336,11 +350,21 @@ void Translator::createLoadInputInstruction(llvm::CallInst &instruction) {
                                   hlsl::DXIL::OperandIndex::kLoadInputIDOpIdx))
           ->getLimitedValue();
   spirv::SpirvVariable *inputVar = inputSignatureElementMap[inputID];
+  if (!inputVar) {
+    emitError(
+        "No matching SPIR-V input variable found for load instruction: %0",
+        instruction);
+    return;
+  }
   const spirv::SpirvType *inputVarType = inputVar->getResultType();
 
   // TODO: Handle other input signature types. Only vector for initial
   // passthrough shader support.
-  assert(isa<spirv::VectorType>(inputVarType));
+  if (!isa<spirv::VectorType>(inputVarType)) {
+    emitError("Input signature type not yet supported for load instruction: %0",
+              instruction);
+    return;
+  }
   const spirv::SpirvType *elemType =
       cast<spirv::VectorType>(inputVarType)->getElementType();
 
@@ -366,11 +390,22 @@ void Translator::createStoreOutputInstruction(llvm::CallInst &instruction) {
                               hlsl::DXIL::OperandIndex::kStoreOutputIDOpIdx))
                           ->getLimitedValue();
   spirv::SpirvVariable *outputVar = outputSignatureElementMap[outputID];
+  if (!outputVar) {
+    emitError(
+        "No matching SPIR-V output variable found for store output ID: %0")
+        << outputID;
+    return;
+  }
   const spirv::SpirvType *outputVarType = outputVar->getResultType();
 
   // TODO: Handle other output signature types. Only vector for initial
   // passthrough shader support.
-  assert(isa<spirv::VectorType>(outputVarType));
+  if (!isa<spirv::VectorType>(outputVarType)) {
+    emitError(
+        "Output signature type not yet supported for store instruction: %0",
+        instruction);
+    return;
+  }
   const spirv::SpirvType *elemType =
       cast<spirv::VectorType>(outputVarType)->getElementType();
 
@@ -640,6 +675,11 @@ void Translator::createExtractValueInstruction(
 
   // Create access chain and save mapping.
   const spirv::SpirvType *returnType = toSpirvType(instruction.getType());
+  if (!returnType) {
+    emitError("Failed to translate return type to SPIR-V for instruction: %0",
+              instruction);
+    return;
+  }
   spirv::SpirvAccessChain *accessChain =
       spvBuilder.createAccessChain(returnType, spvInstruction, indices, {});
 
@@ -729,7 +769,12 @@ const spirv::SpirvType *Translator::toSpirvType(llvm::StructType *structType) {
   std::vector<spirv::StructType::FieldInfo> fields;
   fields.reserve(structType->getNumElements());
   for (llvm::Type *elemType : structType->elements()) {
-    fields.emplace_back(toSpirvType(elemType));
+    const spirv::SpirvType *spirvType = toSpirvType(elemType);
+    if (!spirvType) {
+      emitError("Failed to translate struct field to SPIR-V: %0", *elemType);
+      return nullptr;
+    }
+    fields.emplace_back(spirvType);
   }
   return spvContext.getStructType(fields, name);
 }
@@ -762,21 +807,5 @@ Translator::createSpirvConstant(llvm::Constant *instruction) {
   return nullptr;
 }
 
-template <unsigned N>
-DiagnosticBuilder Translator::emitError(const char (&message)[N]) {
-  const auto diagId =
-      diagnosticsEngine.getCustomDiagID(DiagnosticsEngine::Error, message);
-  return diagnosticsEngine.Report({}, diagId);
-}
-
-template <unsigned N>
-DiagnosticBuilder Translator::emitError(const char (&message)[N],
-                                        llvm::Value &value) {
-  std::string str;
-  llvm::raw_string_ostream os(str);
-  value.print(os);
-  return emitError(message) << os.str();
-}
-
 } // namespace dxil2spv
 } // namespace clang

+ 24 - 5
tools/clang/tools/dxil2spv/lib/dxil2spv.h

@@ -29,7 +29,7 @@ namespace dxil2spv {
 class Translator {
 public:
   Translator(CompilerInstance &instance);
-  int Run();
+  void Run();
 
 private:
   CompilerInstance &ci;
@@ -110,11 +110,30 @@ private:
   unsigned nextDescriptorSet = 0;
   unsigned nextBindingNo = 0;
 
-  // Helper diagnostic functions for emitting error messages.
-  template <unsigned N> DiagnosticBuilder emitError(const char (&message)[N]);
+  // Helper diagnostic functions for emitting error messages. message should be
+  // a fixed diagnostic format string using the syntax expected by the
+  // DiagnosticIDs interface.
+  template <unsigned N> DiagnosticBuilder emitError(const char (&message)[N]) {
+    const auto diagId =
+        diagnosticsEngine.getCustomDiagID(DiagnosticsEngine::Error, message);
+    return diagnosticsEngine.Report({}, diagId);
+  }
+
+  template <unsigned N>
+  DiagnosticBuilder emitError(const char (&message)[N], llvm::Value &value) {
+    std::string str;
+    llvm::raw_string_ostream os(str);
+    value.print(os);
+    return emitError(message) << os.str();
+  }
+
   template <unsigned N>
-  DiagnosticBuilder emitError(const char (&message)[N],
-                              llvm::Value &instruction);
+  DiagnosticBuilder emitError(const char (&message)[N], llvm::Type &type) {
+    std::string str;
+    llvm::raw_string_ostream os(str);
+    type.print(os);
+    return emitError(message) << os.str();
+  }
 };
 
 } // namespace dxil2spv