浏览代码

Add dxcutil to share code. (#350)

Xiang Li 8 年之前
父节点
当前提交
e380b8e915

+ 26 - 0
include/dxc/HLSL/DxilValidation.h

@@ -19,6 +19,8 @@ namespace llvm {
 class Module;
 class LLVMContext;
 class raw_ostream;
+class DiagnosticPrinter;
+class DiagnosticInfo;
 }
 
 namespace hlsl {
@@ -269,6 +271,14 @@ HRESULT ValidateLoadModule(_In_reads_bytes_(ILLength) const char *pIL,
                            _In_ llvm::LLVMContext &Ctx,
                            _In_ llvm::raw_ostream &DiagStream);
 
+// Loads module from container, validating load, but not module.
+HRESULT ValidateLoadModuleFromContainer(
+    _In_reads_bytes_(ContainerSize) const void *pContainer,
+    _In_ uint32_t ContainerSize, _In_ std::unique_ptr<llvm::Module> &pModule,
+    _In_ std::unique_ptr<llvm::Module> &pDebugModule,
+    _In_ llvm::LLVMContext &Ctx, llvm::LLVMContext &DbgCtx,
+    _In_ llvm::raw_ostream &DiagStream);
+
 // Load and validate Dxil module from bitcode.
 HRESULT ValidateDxilBitcode(_In_reads_bytes_(ILLength) const char *pIL,
                             _In_ uint32_t ILLength,
@@ -279,4 +289,20 @@ HRESULT ValidateDxilContainer(_In_reads_bytes_(ContainerSize) const void *pConta
                               _In_ uint32_t ContainerSize,
                               _In_ llvm::raw_ostream &DiagStream);
 
+class PrintDiagnosticContext {
+private:
+  llvm::DiagnosticPrinter &m_Printer;
+  bool m_errorsFound;
+  bool m_warningsFound;
+
+public:
+  PrintDiagnosticContext(llvm::DiagnosticPrinter &printer);
+
+  bool HasErrors() const;
+  bool HasWarnings() const;
+  void Handle(const llvm::DiagnosticInfo &DI);
+
+  static void PrintDiagnosticHandler(const llvm::DiagnosticInfo &DI,
+                                     void *Context);
+};
 }

+ 64 - 51
lib/HLSL/DxilValidation.cpp

@@ -252,39 +252,6 @@ const char *hlsl::GetValidationRuleText(ValidationRule value) {
 
 namespace {
 
-class PrintDiagnosticContext {
-private:
-  DiagnosticPrinter &m_Printer;
-  bool m_errorsFound;
-  bool m_warningsFound;
-public:
-  PrintDiagnosticContext(DiagnosticPrinter &printer)
-      : m_Printer(printer), m_errorsFound(false), m_warningsFound(false) {}
-
-  bool HasErrors() const {
-    return m_errorsFound;
-  }
-  bool HasWarnings() const {
-    return m_warningsFound;
-  }
-  void Handle(const DiagnosticInfo &DI) {
-    DI.print(m_Printer);
-    switch (DI.getSeverity()) {
-    case llvm::DiagnosticSeverity::DS_Error:
-      m_errorsFound = true;
-      break;
-    case llvm::DiagnosticSeverity::DS_Warning:
-      m_warningsFound = true;
-      break;
-    }
-    m_Printer << "\n";
-  }
-};
-
-static void PrintDiagnosticHandler(const DiagnosticInfo &DI, void *Context) {
-  reinterpret_cast<PrintDiagnosticContext *>(Context)->Handle(DI);
-}
-
 // Utility class for setting and restoring the diagnostic context so we may capture errors/warnings
 struct DiagRestore {
   LLVMContext &Ctx;
@@ -294,7 +261,8 @@ struct DiagRestore {
   DiagRestore(llvm::LLVMContext &Ctx, void *DiagContext) : Ctx(Ctx) {
     OrigHandler = Ctx.getDiagnosticHandler();
     OrigDiagContext = Ctx.getDiagnosticContext();
-    Ctx.setDiagnosticHandler(PrintDiagnosticHandler, DiagContext);
+    Ctx.setDiagnosticHandler(
+        hlsl::PrintDiagnosticContext::PrintDiagnosticHandler, DiagContext);
   }
   ~DiagRestore() {
     Ctx.setDiagnosticHandler(OrigHandler, OrigDiagContext);
@@ -333,6 +301,29 @@ static inline DiagnosticPrinter &operator<<(DiagnosticPrinter &OS, Type &T) {
 
 namespace hlsl {
 
+// PrintDiagnosticContext methods.
+PrintDiagnosticContext::PrintDiagnosticContext(DiagnosticPrinter &printer)
+    : m_Printer(printer), m_errorsFound(false), m_warningsFound(false) {}
+
+bool PrintDiagnosticContext::HasErrors() const { return m_errorsFound; }
+bool PrintDiagnosticContext::HasWarnings() const { return m_warningsFound; }
+void PrintDiagnosticContext::Handle(const DiagnosticInfo &DI) {
+  DI.print(m_Printer);
+  switch (DI.getSeverity()) {
+  case llvm::DiagnosticSeverity::DS_Error:
+    m_errorsFound = true;
+    break;
+  case llvm::DiagnosticSeverity::DS_Warning:
+    m_warningsFound = true;
+    break;
+  }
+  m_Printer << "\n";
+}
+
+void PrintDiagnosticContext::PrintDiagnosticHandler(const DiagnosticInfo &DI, void *Context) {
+  reinterpret_cast<hlsl::PrintDiagnosticContext *>(Context)->Handle(DI);
+}
+
 struct PSExecutionInfo {
   bool SuperSampling = false;
   DXIL::SemanticKind OutputDepthKind = DXIL::SemanticKind::Invalid;
@@ -4499,7 +4490,8 @@ HRESULT ValidateDxilBitcode(
 
   llvm::DiagnosticPrinterRawOStream DiagPrinter(DiagStream);
   PrintDiagnosticContext DiagContext(DiagPrinter);
-  Ctx.setDiagnosticHandler(PrintDiagnosticHandler, &DiagContext, true);
+  Ctx.setDiagnosticHandler(PrintDiagnosticContext::PrintDiagnosticHandler,
+                           &DiagContext, true);
 
   HRESULT hr;
   if (FAILED(hr = ValidateLoadModule(pIL, ILLength, pModule, Ctx, DiagStream)))
@@ -4545,45 +4537,66 @@ HRESULT ValidateDxilBitcode(
 }
 
 _Use_decl_annotations_
-HRESULT ValidateDxilContainer(const void *pContainer,
-                              uint32_t ContainerSize,
-                              llvm::raw_ostream &DiagStream) {
-
-  LLVMContext Ctx, DbgCtx;
-  std::unique_ptr<llvm::Module> pModule, pDebugModule;
-
+HRESULT ValidateLoadModuleFromContainer(
+    _In_reads_bytes_(ILLength) const void *pContainer,
+    _In_ uint32_t ContainerSize, _In_ std::unique_ptr<llvm::Module> &pModule,
+    _In_ std::unique_ptr<llvm::Module> &pDebugModule,
+    _In_ llvm::LLVMContext &Ctx, LLVMContext &DbgCtx,
+    _In_ llvm::raw_ostream &DiagStream) {
   llvm::DiagnosticPrinterRawOStream DiagPrinter(DiagStream);
   PrintDiagnosticContext DiagContext(DiagPrinter);
-  Ctx.setDiagnosticHandler(PrintDiagnosticHandler, &DiagContext, true);
-  DbgCtx.setDiagnosticHandler(PrintDiagnosticHandler, &DiagContext, true);
+  DiagRestore DR(Ctx, &DiagContext);
+  DiagRestore DR2(DbgCtx, &DiagContext);
 
-  HRESULT hr;
   const DxilPartHeader *pPart = nullptr;
   IFR(FindDxilPart(pContainer, ContainerSize, DFCC_DXIL, &pPart));
 
   const char *pIL = nullptr;
   uint32_t ILLength = 0;
   GetDxilProgramBitcode(
-    reinterpret_cast<const DxilProgramHeader *>(GetDxilPartData(pPart)),
-    &pIL, &ILLength);
+      reinterpret_cast<const DxilProgramHeader *>(GetDxilPartData(pPart)), &pIL,
+      &ILLength);
 
   IFR(ValidateLoadModule(pIL, ILLength, pModule, Ctx, DiagStream));
 
+  HRESULT hr;
   const DxilPartHeader *pDbgPart = nullptr;
-  if (FAILED(hr = FindDxilPart(pContainer, ContainerSize, DFCC_ShaderDebugInfoDXIL, &pDbgPart)) &&
+  if (FAILED(hr = FindDxilPart(pContainer, ContainerSize,
+                               DFCC_ShaderDebugInfoDXIL, &pDbgPart)) &&
       hr != DXC_E_CONTAINER_MISSING_DXIL) {
     return hr;
   }
 
   if (pDbgPart) {
     GetDxilProgramBitcode(
-      reinterpret_cast<const DxilProgramHeader *>(GetDxilPartData(pDbgPart)),
-      &pIL, &ILLength);
-    if (FAILED(hr = ValidateLoadModule(pIL, ILLength, pDebugModule, DbgCtx, DiagStream))) {
+        reinterpret_cast<const DxilProgramHeader *>(GetDxilPartData(pDbgPart)),
+        &pIL, &ILLength);
+    if (FAILED(hr = ValidateLoadModule(pIL, ILLength, pDebugModule, DbgCtx,
+                                       DiagStream))) {
       return hr;
     }
   }
 
+  return S_OK;
+}
+
+_Use_decl_annotations_
+HRESULT ValidateDxilContainer(const void *pContainer,
+                              uint32_t ContainerSize,
+                              llvm::raw_ostream &DiagStream) {
+  LLVMContext Ctx, DbgCtx;
+  std::unique_ptr<llvm::Module> pModule, pDebugModule;
+
+  llvm::DiagnosticPrinterRawOStream DiagPrinter(DiagStream);
+  PrintDiagnosticContext DiagContext(DiagPrinter);
+  Ctx.setDiagnosticHandler(PrintDiagnosticContext::PrintDiagnosticHandler,
+                           &DiagContext, true);
+  DbgCtx.setDiagnosticHandler(PrintDiagnosticContext::PrintDiagnosticHandler,
+                              &DiagContext, true);
+
+  IFR(ValidateLoadModuleFromContainer(pContainer, ContainerSize, pModule, pDebugModule,
+      Ctx, DbgCtx, DiagStream));
+
   // Validate DXIL Module
   IFR(ValidateDxilModule(pModule.get(), pDebugModule.get()));
 

+ 1 - 0
tools/clang/tools/dxcompiler/CMakeLists.txt

@@ -49,6 +49,7 @@ set(SOURCES
   DXCompiler.def
   dxillib.cpp
   dxcontainerbuilder.cpp
+  dxcutil.cpp
   )
 
 set(LIBRARIES

+ 16 - 30
tools/clang/tools/dxcompiler/dxcassembler.cpp

@@ -18,6 +18,7 @@
 #include "dxc/HLSL/DxilModule.h"
 #include "dxc/Support/dxcapi.impl.h"
 #include "dxillib.h"
+#include "dxcutil.h"
 
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/IRReader/IRReader.h"
@@ -108,48 +109,33 @@ HRESULT STDMETHODCALLTYPE DxcAssembler::AssembleToContainer(
       return S_OK;
     }
 
-    DxilModule program(M.get());
+    // Upgrade Validator Version if necessary.
     try {
-      program.LoadDxilMetadata();
-    }
-    catch (hlsl::Exception &e) {
-      CComPtr<IDxcBlobEncoding> pErrorBlob;
-      IFT(DxcCreateBlobWithEncodingOnHeapCopy(e.msg.c_str(), e.msg.size(), CP_UTF8, &pErrorBlob));
-      IFT(DxcOperationResult::CreateFromResultErrorStatus(nullptr, pErrorBlob, e.hr, ppResult));
-      return S_OK;
-    }
+      DxilModule &program = M->GetOrCreateDxilModule();
 
-    // Upgrade Validator Version if necessary:
-    {
-      CComPtr<IDxcValidator> pValidator;
-      if (DxilLibIsEnabled()) {
-        DxilLibCreateInstance(CLSID_DxcValidator, &pValidator);
-      }
-      if (pValidator == nullptr) {
-        CreateDxcValidator(IID_PPV_ARGS(&pValidator));
-      }
-      CComPtr<IDxcVersionInfo> pVersionInfo;
-      if (pValidator && SUCCEEDED(pValidator.QueryInterface(&pVersionInfo))) {
+      {
         UINT32 majorVer, minorVer;
-        IFT(pVersionInfo->GetVersion(&majorVer, &minorVer));
+        dxcutil::GetValidatorVersion(&majorVer, &minorVer);
         if (program.UpgradeValidatorVersion(majorVer, minorVer)) {
           program.UpdateValidatorVersionMetadata();
         }
       }
+    } catch (hlsl::Exception &e) {
+      CComPtr<IDxcBlobEncoding> pErrorBlob;
+      IFT(DxcCreateBlobWithEncodingOnHeapCopy(e.msg.c_str(), e.msg.size(),
+                                              CP_UTF8, &pErrorBlob));
+      IFT(DxcOperationResult::CreateFromResultErrorStatus(nullptr, pErrorBlob,
+                                                          e.hr, ppResult));
+      return S_OK;
     }
-
+    // Create bitcode of M.
     WriteBitcodeToFile(M.get(), outStream);
     outStream.flush();
 
-    CComPtr<AbstractMemoryStream> pFinalStream;
-    IFT(CreateMemoryStream(pMalloc, &pFinalStream));
-
-    SerializeDxilContainerForModule(&M->GetOrCreateDxilModule(), pOutputStream,
-                                    pFinalStream,
-                                    SerializeDxilFlags::IncludeDebugNamePart);
-
     CComPtr<IDxcBlob> pResultBlob;
-    IFT(pFinalStream->QueryInterface(&pResultBlob));
+    dxcutil::AssembleToContainer(std::move(M), pResultBlob,
+                                         pMalloc, SerializeDxilFlags::IncludeDebugNamePart,
+                                         pOutputStream);
 
     IFT(DxcOperationResult::CreateFromResultErrorStatus(pResultBlob, nullptr, S_OK, ppResult));
   }

+ 16 - 97
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -40,6 +40,8 @@
 #include "dxc/HLSL/DxilPipelineStateValidation.h"
 #include "dxc/HLSL/HLSLExtensionsCodegenHelper.h"
 #include "dxc/HLSL/DxilRootSignature.h"
+#include "dxcutil.h"
+
 // SPIRV change starts
 #ifdef ENABLE_SPIRV_CODEGEN
 #include "clang/SPIRV/EmitSPIRVAction.h"
@@ -2052,39 +2054,6 @@ public:
   }
 };
 
-// Class to manage lifetime of llvm module and provide some utility
-// functions used for generating compiler output.
-class DxilCompilerLLVMModuleOutput {
-public:
-  DxilCompilerLLVMModuleOutput(std::unique_ptr<llvm::Module> module)
-    : m_llvmModule(std::move(module))
-  { }
-
-  void CloneForDebugInfo() {
-    m_llvmModuleWithDebugInfo.reset(llvm::CloneModule(m_llvmModule.get()));
-  }
-
-  void WrapModuleInDxilContainer(IMalloc *pMalloc,
-                                 AbstractMemoryStream *pModuleBitcode,
-                                 CComPtr<IDxcBlob> &pDxilContainerBlob,
-                                 SerializeDxilFlags Flags) {
-    CComPtr<AbstractMemoryStream> pContainerStream;
-    IFT(CreateMemoryStream(pMalloc, &pContainerStream));
-    SerializeDxilContainerForModule(&m_llvmModule->GetOrCreateDxilModule(),
-                                    pModuleBitcode, pContainerStream, Flags);
-
-    pDxilContainerBlob.Release();
-    IFT(pContainerStream.QueryInterface(&pDxilContainerBlob));
-  }
-
-  llvm::Module *get() { return m_llvmModule.get(); }
-  llvm::Module *getWithDebugInfo() { return m_llvmModuleWithDebugInfo.get(); }
-
-private:
-  std::unique_ptr<llvm::Module> m_llvmModule;
-  std::unique_ptr<llvm::Module> m_llvmModuleWithDebugInfo;
-};
-
 class DxcCompiler : public IDxcCompiler2, public IDxcLangExtensions, public IDxcContainerEvent, public IDxcVersionInfo {
 private:
   DXC_MICROCOM_REF_FIELD(m_dwRef)
@@ -2295,26 +2264,12 @@ public:
       // validator can be used as a fallback.
       bool produceFullContainer = !opts.CodeGenHighLevel && !opts.AstDump && !opts.OptDump && rootSigMajor == 0;
       bool needsValidation = produceFullContainer && !opts.DisableValidation;
-      bool internalValidator = false;
-      CComPtr<IDxcValidator> pValidator;
-      CComPtr<IDxcOperationResult> pValResult;
+
       if (needsValidation) {
-        if (DxilLibIsEnabled()) {
-          if (FAILED(DxilLibCreateInstance(CLSID_DxcValidator, &pValidator))) {
-            w << "Unable to create validator from dxil.dll, fallback to built-in.";
-          }
-        }
-        if (pValidator == nullptr) {
-          IFT(CreateDxcValidator(IID_PPV_ARGS(&pValidator)));
-          internalValidator = true;
-        }
-        CComPtr<IDxcVersionInfo> pVersionInfo;
-        if (SUCCEEDED(pValidator.QueryInterface(&pVersionInfo))) {
-          UINT32 majorVer, minorVer;
-          IFT(pVersionInfo->GetVersion(&majorVer, &minorVer));
-          compiler.getCodeGenOpts().HLSLValidatorMajorVer = majorVer;
-          compiler.getCodeGenOpts().HLSLValidatorMinorVer = minorVer;
-        }
+        UINT32 majorVer, minorVer;
+        dxcutil::GetValidatorVersion(&majorVer, &minorVer);
+        compiler.getCodeGenOpts().HLSLValidatorMajorVer = majorVer;
+        compiler.getCodeGenOpts().HLSLValidatorMinorVer = minorVer;
       }
 
       if (opts.AstDump) {
@@ -2399,54 +2354,18 @@ public:
         }
 
         // Don't do work to put in a container if an error has occurred
-        if (compileOK) {
+        // Do not create a container when there is only a a high-level representation in the module.
+        if (compileOK && !opts.CodeGenHighLevel) {
           HRESULT valHR = S_OK;
 
-          // Take ownership of the module from the action.
-          DxilCompilerLLVMModuleOutput llvmModule(action.takeModule());
-
-          // If using the internal validator, we'll use the modules directly.
-          // In this case, we'll want to make a clone to avoid SerializeDxilContainerForModule
-          // stripping all the debug info. The debug info will be stripped from the orginal
-          // module, but preserved in the cloned module.
-          if (internalValidator && opts.DebugInfo)
-            llvmModule.CloneForDebugInfo();
-
-          // Do not create a container when there is only a a high-level representation in the module.
-          if (!opts.CodeGenHighLevel)
-            llvmModule.WrapModuleInDxilContainer(pMalloc, pOutputStream, pOutputBlob, SerializeFlags);
-
           if (needsValidation) {
-            // Important: in-place edit is required so the blob is reused and thus
-            // dxil.dll can be released.
-            if (internalValidator) {
-              IFT(RunInternalValidator(
-                pValidator, llvmModule.get(), llvmModule.getWithDebugInfo(), pOutputBlob,
-                DxcValidatorFlags_InPlaceEdit, &pValResult));
-            }
-            else {
-              IFT(pValidator->Validate(
-                pOutputBlob, DxcValidatorFlags_InPlaceEdit, &pValResult));
-            }
-            IFT(pValResult->GetStatus(&valHR));
-            if (FAILED(valHR)) {
-              CComPtr<IDxcBlobEncoding> pErrors;
-              CComPtr<IDxcBlobEncoding> pErrorsUtf8;
-              IFT(pValResult->GetErrorBuffer(&pErrors));
-              IFT(hlsl::DxcGetBlobAsUtf8(pErrors, &pErrorsUtf8));
-              StringRef errRef((const char *)pErrorsUtf8->GetBufferPointer(),
-                pErrorsUtf8->GetBufferSize());
-              DiagnosticsEngine &D = compiler.getDiagnostics();
-              unsigned DiagID = D.getCustomDiagID(DiagnosticsEngine::Error,
-                "validation errors\r\n%0");
-              D.Report(DiagID) << errRef;
-            }
-            CComPtr<IDxcBlob> pValidatedBlob;
-            IFT(pValResult->GetResult(&pValidatedBlob));
-            if (pValidatedBlob != nullptr) {
-              std::swap(pOutputBlob, pValidatedBlob);
-            }
-            pValidator.Release();
+            valHR = dxcutil::ValidateAndAssembleToContainer(
+                action.takeModule(), pOutputBlob, pMalloc, SerializeFlags,
+                pOutputStream, opts.DebugInfo, compiler.getDiagnostics());
+          } else {
+            dxcutil::AssembleToContainer(action.takeModule(),
+                                                 pOutputBlob, pMalloc,
+                                                 SerializeFlags, pOutputStream);
           }
 
           // Callback after valid DXIL is produced

+ 175 - 0
tools/clang/tools/dxcompiler/dxcutil.cpp

@@ -0,0 +1,175 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// dxcutil.cpp                                                               //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+// Provides helper class for dxcompiler.                                     //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#pragma once
+
+#include "dxc/Support/WinIncludes.h"
+#include "dxc/HLSL/DxilContainer.h"
+#include "dxc/Support/FileIOHelper.h"
+#include "dxc/Support/Global.h"
+#include "dxc/dxcapi.h"
+#include "dxcutil.h"
+#include "dxillib.h"
+#include "clang/Basic/Diagnostic.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+
+using namespace llvm;
+using namespace hlsl;
+
+// This declaration is used for the locally-linked validator.
+HRESULT CreateDxcValidator(_In_ REFIID riid, _Out_ LPVOID *ppv);
+// This internal call allows the validator to avoid having to re-deserialize
+// the module. It trusts that the caller didn't make any changes and is
+// kept internal because the layout of the module class may change based
+// on changes across modules, or picking a different compiler version or CRT.
+HRESULT RunInternalValidator(_In_ IDxcValidator *pValidator,
+                             _In_ llvm::Module *pModule,
+                             _In_ llvm::Module *pDebugModule,
+                             _In_ IDxcBlob *pShader, UINT32 Flags,
+                             _In_ IDxcOperationResult **ppResult);
+
+namespace {
+
+bool CreateValidator(CComPtr<IDxcValidator> &pValidator) {
+  if (DxilLibIsEnabled()) {
+    DxilLibCreateInstance(CLSID_DxcValidator, &pValidator);
+  }
+  bool bInternalValidator = false;
+  if (pValidator == nullptr) {
+    IFT(CreateDxcValidator(IID_PPV_ARGS(&pValidator)));
+    bInternalValidator = true;
+  }
+  return bInternalValidator;
+}
+
+// Class to manage lifetime of llvm module and provide some utility
+// functions used for generating compiler output.
+class DxilCompilerLLVMModuleOutput {
+public:
+  DxilCompilerLLVMModuleOutput(std::unique_ptr<llvm::Module> module)
+      : m_llvmModule(std::move(module)) {}
+
+  void CloneForDebugInfo() {
+    m_llvmModuleWithDebugInfo.reset(llvm::CloneModule(m_llvmModule.get()));
+  }
+
+  void WrapModuleInDxilContainer(IMalloc *pMalloc,
+                                 AbstractMemoryStream *pModuleBitcode,
+                                 CComPtr<IDxcBlob> &pDxilContainerBlob,
+                                 SerializeDxilFlags Flags) {
+    CComPtr<AbstractMemoryStream> pContainerStream;
+    IFT(CreateMemoryStream(pMalloc, &pContainerStream));
+    SerializeDxilContainerForModule(&m_llvmModule->GetOrCreateDxilModule(),
+                                    pModuleBitcode, pContainerStream, Flags);
+
+    pDxilContainerBlob.Release();
+    IFT(pContainerStream.QueryInterface(&pDxilContainerBlob));
+  }
+
+  llvm::Module *get() { return m_llvmModule.get(); }
+  llvm::Module *getWithDebugInfo() { return m_llvmModuleWithDebugInfo.get(); }
+
+private:
+  std::unique_ptr<llvm::Module> m_llvmModule;
+  std::unique_ptr<llvm::Module> m_llvmModuleWithDebugInfo;
+};
+
+} // namespace
+
+namespace dxcutil {
+void GetValidatorVersion(unsigned *pMajor, unsigned *pMinor) {
+  if (pMajor == nullptr || pMinor == nullptr)
+    return;
+
+  CComPtr<IDxcValidator> pValidator;
+  CreateValidator(pValidator);
+
+  CComPtr<IDxcVersionInfo> pVersionInfo;
+  if (SUCCEEDED(pValidator.QueryInterface(&pVersionInfo))) {
+    IFT(pVersionInfo->GetVersion(pMajor, pMinor));
+  } else {
+    // Default to 1.0
+    *pMajor = 1;
+    *pMinor = 0;
+  }
+}
+
+void AssembleToContainer(std::unique_ptr<llvm::Module> pM,
+                         CComPtr<IDxcBlob> &pOutputBlob,
+                         CComPtr<IMalloc> &pMalloc,
+                         SerializeDxilFlags SerializeFlags,
+                         CComPtr<AbstractMemoryStream> &pOutputStream) {
+  // Take ownership of the module from the action.
+  DxilCompilerLLVMModuleOutput llvmModule(std::move(pM));
+
+  llvmModule.WrapModuleInDxilContainer(pMalloc, pOutputStream, pOutputBlob,
+                                       SerializeFlags);
+}
+
+HRESULT ValidateAndAssembleToContainer(
+    std::unique_ptr<llvm::Module> pM, CComPtr<IDxcBlob> &pOutputBlob,
+    CComPtr<IMalloc> &pMalloc, SerializeDxilFlags SerializeFlags,
+    CComPtr<AbstractMemoryStream> &pOutputStream, bool bDebugInfo,
+    clang::DiagnosticsEngine &Diag) {
+  HRESULT valHR = S_OK;
+
+  // Take ownership of the module from the action.
+  DxilCompilerLLVMModuleOutput llvmModule(std::move(pM));
+
+  CComPtr<IDxcValidator> pValidator;
+  bool bInternalValidator = CreateValidator(pValidator);
+  // If using the internal validator, we'll use the modules directly.
+  // In this case, we'll want to make a clone to avoid
+  // SerializeDxilContainerForModule stripping all the debug info. The debug
+  // info will be stripped from the orginal module, but preserved in the cloned
+  // module.
+  if (bInternalValidator && bDebugInfo)
+    llvmModule.CloneForDebugInfo();
+
+  llvmModule.WrapModuleInDxilContainer(pMalloc, pOutputStream, pOutputBlob,
+                                       SerializeFlags);
+
+  CComPtr<IDxcOperationResult> pValResult;
+  // Important: in-place edit is required so the blob is reused and thus
+  // dxil.dll can be released.
+  if (bInternalValidator) {
+    IFT(RunInternalValidator(pValidator, llvmModule.get(),
+                             llvmModule.getWithDebugInfo(), pOutputBlob,
+                             DxcValidatorFlags_InPlaceEdit, &pValResult));
+  } else {
+    IFT(pValidator->Validate(pOutputBlob, DxcValidatorFlags_InPlaceEdit,
+                             &pValResult));
+  }
+  IFT(pValResult->GetStatus(&valHR));
+  if (FAILED(valHR)) {
+    CComPtr<IDxcBlobEncoding> pErrors;
+    CComPtr<IDxcBlobEncoding> pErrorsUtf8;
+    IFT(pValResult->GetErrorBuffer(&pErrors));
+    IFT(hlsl::DxcGetBlobAsUtf8(pErrors, &pErrorsUtf8));
+    StringRef errRef((const char *)pErrorsUtf8->GetBufferPointer(),
+                     pErrorsUtf8->GetBufferSize());
+    unsigned DiagID = Diag.getCustomDiagID(clang::DiagnosticsEngine::Error,
+                                           "validation errors\r\n%0");
+    Diag.Report(DiagID) << errRef;
+  }
+  CComPtr<IDxcBlob> pValidatedBlob;
+  IFT(pValResult->GetResult(&pValidatedBlob));
+  if (pValidatedBlob != nullptr) {
+    std::swap(pOutputBlob, pValidatedBlob);
+  }
+  pValidator.Release();
+
+  return valHR;
+}
+
+} // namespace dxcutil

+ 45 - 0
tools/clang/tools/dxcompiler/dxcutil.h

@@ -0,0 +1,45 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// dxcutil.h                                                                 //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+// Provides helper class for dxcompiler.                                     //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#pragma once
+
+
+#include "dxc/dxcapi.h"
+#include "dxc/Support/microcom.h"
+#include <memory>
+
+namespace clang {
+class DiagnosticsEngine;
+}
+
+namespace llvm {
+class Module;
+class raw_string_ostream;
+}
+
+namespace hlsl {
+enum class SerializeDxilFlags : uint32_t;
+class AbstractMemoryStream;
+}
+
+namespace dxcutil {
+HRESULT ValidateAndAssembleToContainer(
+    std::unique_ptr<llvm::Module> pM, CComPtr<IDxcBlob> &pOutputContainerBlob,
+    CComPtr<IMalloc> &pMalloc, hlsl::SerializeDxilFlags SerializeFlags,
+    CComPtr<hlsl::AbstractMemoryStream> &pModuleBitcode, bool bDebugInfo,
+    clang::DiagnosticsEngine &Diag);
+void GetValidatorVersion(unsigned *pMajor, unsigned *pMinor);
+void AssembleToContainer(std::unique_ptr<llvm::Module> pM,
+                         CComPtr<IDxcBlob> &pOutputContainerBlob,
+                         CComPtr<IMalloc> &pMalloc,
+                         hlsl::SerializeDxilFlags SerializeFlags,
+                         CComPtr<hlsl::AbstractMemoryStream> &pModuleBitcode);
+} // namespace dxcutil

+ 2 - 34
tools/clang/tools/dxcompiler/dxcvalidator.cpp

@@ -29,39 +29,6 @@
 using namespace llvm;
 using namespace hlsl;
 
-class PrintDiagnosticContext {
-private:
-  DiagnosticPrinter &m_Printer;
-  bool m_errorsFound;
-  bool m_warningsFound;
-public:
-  PrintDiagnosticContext(DiagnosticPrinter &printer)
-      : m_Printer(printer), m_errorsFound(false), m_warningsFound(false) {}
-
-  bool HasErrors() const {
-    return m_errorsFound;
-  }
-  bool HasWarnings() const {
-    return m_warningsFound;
-  }
-  void Handle(const DiagnosticInfo &DI) {
-    DI.print(m_Printer);
-    switch (DI.getSeverity()) {
-    case llvm::DiagnosticSeverity::DS_Error:
-      m_errorsFound = true;
-      break;
-    case llvm::DiagnosticSeverity::DS_Warning:
-      m_warningsFound = true;
-      break;
-    }
-    m_Printer << "\n";
-  }
-};
-
-static void PrintDiagnosticHandler(const DiagnosticInfo &DI, void *Context) {
-  reinterpret_cast<PrintDiagnosticContext *>(Context)->Handle(DI);
-}
-
 // Utility class for setting and restoring the diagnostic context so we may capture errors/warnings
 struct DiagRestore {
   LLVMContext &Ctx;
@@ -71,7 +38,8 @@ struct DiagRestore {
   DiagRestore(llvm::LLVMContext &Ctx, void *DiagContext) : Ctx(Ctx) {
     OrigHandler = Ctx.getDiagnosticHandler();
     OrigDiagContext = Ctx.getDiagnosticContext();
-    Ctx.setDiagnosticHandler(PrintDiagnosticHandler, DiagContext);
+    Ctx.setDiagnosticHandler(PrintDiagnosticContext::PrintDiagnosticHandler,
+                             DiagContext);
   }
   ~DiagRestore() {
     Ctx.setDiagnosticHandler(OrigHandler, OrigDiagContext);