Ver Fonte

Implement DXIL Container validation

- Implement and centralize container validation components in DxilValidation
- Strip RootSignature from module metadata before serializing to container
- Use existing DxilModule when serializing rather than constructing new one
- Add DxilModule::TryGetDxilModule for capturing diagnostics on metadata load
- Expose DxilPartWriters/DxilContainerWriter for use elsewhere (such as in validation)
Tex Riddell há 8 anos atrás
pai
commit
7beaa7ba54

+ 5 - 0
docs/DXIL.rst

@@ -2104,6 +2104,11 @@ The set of validation rules that are known to hold for a DXIL program is identif
 Rule Code                             Description
 ===================================== =======================================================================================================================================================================================================================================================================================================
 BITCODE.VALID                         TODO - Module must be bitcode-valid
+CONTAINER.PARTINVALID                 DXIL Container must not contain unknown parts
+CONTAINER.PARTMATCHES                 DXIL Container Parts must match Module
+CONTAINER.PARTMISSING                 DXIL Container requires certain parts, corresponding to module
+CONTAINER.PARTREPEATED                DXIL Container must have only one of each part type
+CONTAINER.ROOTSIGNATUREINCOMPATIBLE   Root Signature in DXIL Container must be compatible with shader
 DECL.DXILFNEXTERN                     External function must be a DXIL function
 DECL.DXILNSRESERVED                   The DXIL reserved prefixes must only be used by built-in functions and types
 DECL.FNFLATTENPARAM                   Function parameters must not use struct types

+ 25 - 2
include/dxc/HLSL/DxilContainer.h

@@ -16,6 +16,7 @@
 
 #include <stdint.h>
 #include <iterator>
+#include <functional>
 #include "dxc/HLSL/DxilConstants.h"
 
 struct IDxcContainerReflection;
@@ -23,6 +24,10 @@ namespace llvm { class Module; }
 
 namespace hlsl {
 
+class AbstractMemoryStream;
+class RootSignatureHandle;
+class DxilModule;
+
 #pragma pack(push, 1)
 
 static const size_t DxilContainerHashSize = 16;
@@ -368,8 +373,26 @@ inline uint32_t EncodeVersion(DXIL::ShaderKind shaderType, uint32_t major,
   return ((unsigned)shaderType << 16) | (major << 4) | minor;
 }
 
-class AbstractMemoryStream;
-void SerializeDxilContainerForModule(llvm::Module *pModule,
+class DxilPartWriter {
+public:
+  virtual uint32_t size() const = 0;
+  virtual void write(AbstractMemoryStream *pStream) = 0;
+};
+
+DxilPartWriter *NewProgramSignatureWriter(const DxilModule &M, DXIL::SignatureKind Kind);
+DxilPartWriter *NewRootSignatureWriter(const RootSignatureHandle &S);
+DxilPartWriter *NewFeatureInfoWriter(const DxilModule &M);
+DxilPartWriter *NewPSVWriter(const DxilModule &M);
+
+class DxilContainerWriter : public DxilPartWriter  {
+public:
+  typedef std::function<void(AbstractMemoryStream*)> WriteFn;
+  virtual void AddPart(uint32_t FourCC, uint32_t Size, WriteFn Write) = 0;
+};
+
+DxilContainerWriter *NewDxilContainerWriter();
+
+void SerializeDxilContainerForModule(hlsl::DxilModule *pModule,
                                      AbstractMemoryStream *pModuleBitcode,
                                      AbstractMemoryStream *pStream);
 void CreateDxcContainerReflection(IDxcContainerReflection **ppResult);

+ 7 - 1
include/dxc/HLSL/DxilModule.h

@@ -97,6 +97,9 @@ public:
   const DxilSignature &GetPatchConstantSignature() const;
   const RootSignatureHandle &GetRootSignature() const;
 
+  // Remove Root Signature from module metadata
+  void StripRootSignatureFromMetadata();
+
   // DXIL type system.
   DxilTypeSystem &GetTypeSystem();
 
@@ -121,6 +124,9 @@ public:
 
   void StripDebugRelatedCode();
   llvm::DebugInfoFinder &GetOrCreateDebugInfoFinder();
+
+  static DxilModule *TryGetDxilModule(llvm::Module *pModule);
+
 public:
   // Shader properties.
   class ShaderFlags {
@@ -281,7 +287,7 @@ private:
   // DXIL metadata serialization/deserialization.
   llvm::MDTuple *EmitDxilResources();
   void LoadDxilResources(const llvm::MDOperand &MDO);
-  llvm::MDTuple *EmitDxilShaderProperties();
+  llvm::MDTuple *EmitDxilShaderProperties(bool bStripRootSignature);
   void LoadDxilShaderProperties(const llvm::MDOperand &MDO);
 
   // Helpers.

+ 58 - 3
include/dxc/HLSL/DxilValidation.h

@@ -11,10 +11,14 @@
 
 #pragma once
 
-#include <system_error>
+#include <memory>
+#include "dxc/Support/Global.h"
+#include "dxc/HLSL/DxilConstants.h"
 
 namespace llvm {
 class Module;
+class LLVMContext;
+class raw_ostream;
 }
 
 namespace hlsl {
@@ -26,6 +30,13 @@ enum class ValidationRule : unsigned {
   // Bitcode
   BitcodeValid, // TODO - Module must be bitcode-valid
 
+  // Container
+  ContainerPartInvalid, // DXIL Container must not contain unknown parts
+  ContainerPartMatches, // DXIL Container Parts must match Module
+  ContainerPartMissing, // DXIL Container requires certain parts, corresponding to module
+  ContainerPartRepeated, // DXIL Container must have only one of each part type
+  ContainerRootSignatureIncompatible, // Root Signature in DXIL Container must be compatible with shader
+
   // Declaration
   DeclDxilFnExtern, // External function must be a DXIL function
   DeclDxilNsReserved, // The DXIL reserved prefixes must only be used by built-in functions and types
@@ -218,7 +229,51 @@ enum class ValidationRule : unsigned {
 
 const char *GetValidationRuleText(ValidationRule value);
 void GetValidationVersion(_Out_ unsigned *pMajor, _Out_ unsigned *pMinor);
-std::error_code ValidateDxilModule(_In_ llvm::Module *pModule,
-                                   _In_opt_ llvm::Module *pDebugModule);
+HRESULT ValidateDxilModule(_In_ llvm::Module *pModule,
+                           _In_opt_ llvm::Module *pDebugModule);
+
+// DXIL Container Verification Functions (return false on failure)
+
+bool VerifySignatureMatches(_In_ llvm::Module *pModule,
+                            hlsl::DXIL::SignatureKind SigKind,
+                            _In_reads_bytes_(SigSize) void *pSigData,
+                            _In_ unsigned SigSize);
+
+//bool VerifyRootSignatureMatches(_In_ llvm::Module *pModule,
+//                                  _In_reads_bytes_(RSSize) void *pRSData,
+//                                  _In_ unsigned RSSize);
+
+// PSV = data for Pipeline State Validation
+bool VerifyPSVMatches(_In_ llvm::Module *pModule,
+                      _In_reads_bytes_(PSVSize) void *pPSVData,
+                      _In_ unsigned PSVSize);
+
+bool VerifyFeatureInfoMatches(_In_ llvm::Module *pModule,
+                              _In_reads_bytes_(FeatureInfoSize) const void *pFeatureInfoData,
+                              _In_ unsigned FeatureInfoSize);
+
+// Validate the container parts, assuming supplied module is valid, loaded from the container provided
+struct DxilContainerHeader;
+HRESULT ValidateDxilContainerParts(_In_ llvm::Module *pModule,
+                                   _In_opt_ llvm::Module *pDebugModule,
+                                   _In_reads_bytes_(ContainerSize) const DxilContainerHeader *pContainer,
+                                   _In_ unsigned ContainerSize);
+
+// Loads module, validating load, but not module.
+HRESULT ValidateLoadModule(_In_reads_bytes_(ILLength) const char *pIL,
+                           _In_ uint32_t ILLength,
+                           _In_ std::unique_ptr<llvm::Module> &pModule,
+                           _In_ llvm::LLVMContext &Ctx,
+                           _In_ llvm::raw_ostream &DiagStream);
+
+// Load and validate Dxil module from bitcode.
+HRESULT ValidateDxilBitcode(_In_reads_bytes_(ILLength) const char *pIL,
+                            _In_ uint32_t ILLength,
+                            _In_ llvm::raw_ostream &DiagStream);
+
+// Full container validation, including ValidateDxilModule
+HRESULT ValidateDxilContainer(_In_reads_bytes_(ContainerSize) const void *pContainer,
+                              _In_ unsigned ContainerSize,
+                              _In_ llvm::raw_ostream &DiagStream);
 
 }

+ 7 - 1
include/dxc/Support/ErrorCodes.h

@@ -82,4 +82,10 @@
 #define DXC_E_DUPLICATE_PART                          DXC_MAKE_HRESULT(DXC_SEVERITY_ERROR,FACILITY_DXC,(0x0011))
 
 // 0x80AA0012 - Error finding part in dxil container.
-#define DXC_E_MISSING_PART                            DXC_MAKE_HRESULT(DXC_SEVERITY_ERROR,FACILITY_DXC,(0x0012))
+#define DXC_E_MISSING_PART                            DXC_MAKE_HRESULT(DXC_SEVERITY_ERROR,FACILITY_DXC,(0x0012))
+
+// 0x80AA0013 - Malformed DXIL Container.
+#define DXC_E_MALFORMED_CONTAINER                     DXC_MAKE_HRESULT(DXC_SEVERITY_ERROR,FACILITY_DXC,(0x0013))
+
+// 0x80AA0014 - Incorrect Root Signature for shader.
+#define DXC_E_INCORRECT_ROOT_SIGNATURE                DXC_MAKE_HRESULT(DXC_SEVERITY_ERROR,FACILITY_DXC,(0x0014))

+ 80 - 51
lib/HLSL/DxilContainerAssembler.cpp

@@ -132,7 +132,7 @@ struct sort_sig {
   }
 };
 
-class DxilProgramSignatureWriter {
+class DxilProgramSignatureWriter : public DxilPartWriter {
 private:
   const DxilSignature &m_signature;
   DXIL::TessellatorDomain m_domain;
@@ -229,11 +229,11 @@ public:
     calcSizes();
   }
 
-  uint32_t size() const {
+  __override uint32_t size() const {
     return m_lastOffset;
   }
 
-  void write(AbstractMemoryStream *pStream) {
+  __override void write(AbstractMemoryStream *pStream) {
     UINT64 startPos = pStream->GetPosition();
     const std::vector<std::unique_ptr<hlsl::DxilSignatureElement>> &elements = m_signature.GetElements();
 
@@ -274,7 +274,22 @@ public:
   }
 };
 
-class DxilProgramRootSignatureWriter {
+DxilPartWriter *hlsl::NewProgramSignatureWriter(const DxilModule &M, DXIL::SignatureKind Kind) {
+  switch (Kind) {
+  case DXIL::SignatureKind::Input:
+    return new DxilProgramSignatureWriter(M.GetInputSignature(),
+      M.GetTessellatorDomain(), true);
+  case DXIL::SignatureKind::Output:
+    return new DxilProgramSignatureWriter(M.GetOutputSignature(),
+      M.GetTessellatorDomain(), false);
+  case DXIL::SignatureKind::PatchConstant:
+    return new DxilProgramSignatureWriter(M.GetPatchConstantSignature(),
+      M.GetTessellatorDomain(), /*IsInput*/ M.GetShaderModel()->IsDS());
+  }
+  return nullptr;
+}
+
+class DxilProgramRootSignatureWriter : public DxilPartWriter {
 private:
   const RootSignatureHandle &m_Sig;
 public:
@@ -288,7 +303,11 @@ public:
   }
 };
 
-class DxilFeatureInfoWriter {
+DxilPartWriter *hlsl::NewRootSignatureWriter(const RootSignatureHandle &S) {
+  return new DxilProgramRootSignatureWriter(S);
+}
+
+class DxilFeatureInfoWriter : public DxilPartWriter  {
 private:
   // Only save the shader properties after create class for it.
   DxilShaderFeatureInfo featureInfo;
@@ -296,24 +315,28 @@ public:
   DxilFeatureInfoWriter(const DxilModule &M) {
     featureInfo.FeatureFlags = M.m_ShaderFlags.GetFeatureInfo();
   }
-  uint32_t size() const {
+  __override uint32_t size() const {
     return sizeof(DxilShaderFeatureInfo);
   }
-  void write(AbstractMemoryStream *pStream) {
+  __override void write(AbstractMemoryStream *pStream) {
     IFT(WriteStreamValue(pStream, featureInfo.FeatureFlags));
   }
 };
 
-class DxilPSVWriter {
+DxilPartWriter *hlsl::NewFeatureInfoWriter(const DxilModule &M) {
+  return new DxilFeatureInfoWriter(M);
+}
+
+class DxilPSVWriter : public DxilPartWriter  {
 private:
-  DxilModule &m_Module;
+  const DxilModule &m_Module;
   UINT m_uTotalResources;
   DxilPipelineStateValidation m_PSV;
   uint32_t m_PSVBufferSize;
   SmallVector<char, 512> m_PSVBuffer;
 
 public:
-  DxilPSVWriter(DxilModule &module) : m_Module(module) {
+  DxilPSVWriter(const DxilModule &module) : m_Module(module) {
     UINT uCBuffers = m_Module.GetCBuffers().size();
     UINT uSamplers = m_Module.GetSamplers().size();
     UINT uSRVs = m_Module.GetSRVs().size();
@@ -321,11 +344,11 @@ public:
     m_uTotalResources = uCBuffers + uSamplers + uSRVs + uUAVs;
     m_PSV.InitNew(m_uTotalResources, nullptr, &m_PSVBufferSize);
   }
-  size_t size() {
+  __override uint32_t size() const {
     return m_PSVBufferSize;
   }
 
-  void write(AbstractMemoryStream *pStream) {
+  __override void write(AbstractMemoryStream *pStream) {
     m_PSVBuffer.resize(m_PSVBufferSize);
     m_PSV.InitNew(m_uTotalResources, m_PSVBuffer.data(), &m_PSVBufferSize);
     DXASSERT_NOMSG(m_PSVBuffer.size() == m_PSVBufferSize);
@@ -339,7 +362,7 @@ public:
     switch (SM->GetKind()) {
       case ShaderModel::Kind::Vertex: {
         pInfo->VS.OutputPositionPresent = 0;
-        DxilSignature &S = m_Module.GetOutputSignature();
+        const DxilSignature &S = m_Module.GetOutputSignature();
         for (auto &&E : S.GetElements()) {
           if (E->GetKind() == Semantic::Kind::Position) {
             // Ideally, we might check never writes mask here,
@@ -360,7 +383,7 @@ public:
       case ShaderModel::Kind::Domain: {
         pInfo->DS.InputControlPointCount = (UINT)m_Module.GetInputControlPointCount();
         pInfo->DS.OutputPositionPresent = 0;
-        DxilSignature &S = m_Module.GetOutputSignature();
+        const DxilSignature &S = m_Module.GetOutputSignature();
         for (auto &&E : S.GetElements()) {
           if (E->GetKind() == Semantic::Kind::Position) {
             // Ideally, we might check never writes mask here,
@@ -382,7 +405,7 @@ public:
           pInfo->GS.OutputStreamMask = 1; // This is what runtime expects.
         }
         pInfo->GS.OutputPositionPresent = 0;
-        DxilSignature &S = m_Module.GetOutputSignature();
+        const DxilSignature &S = m_Module.GetOutputSignature();
         for (auto &&E : S.GetElements()) {
           if (E->GetKind() == Semantic::Kind::Position) {
             // Ideally, we might check never writes mask here,
@@ -397,7 +420,7 @@ public:
         pInfo->PS.DepthOutput = 0;
         pInfo->PS.SampleFrequency = 0;
         {
-          DxilSignature &S = m_Module.GetInputSignature();
+          const DxilSignature &S = m_Module.GetInputSignature();
           for (auto &&E : S.GetElements()) {
             if (E->GetInterpolationMode()->IsAnySample() ||
                 E->GetKind() == Semantic::Kind::SampleIndex) {
@@ -406,7 +429,7 @@ public:
           }
         }
         {
-          DxilSignature &S = m_Module.GetOutputSignature();
+          const DxilSignature &S = m_Module.GetOutputSignature();
           for (auto &&E : S.GetElements()) {
             if (E->IsAnyDepth()) {
               pInfo->PS.DepthOutput = 1;
@@ -483,10 +506,11 @@ public:
   }
 };
 
-class DxilContainerWriter {
-public:
-  typedef std::function<void(AbstractMemoryStream*)> WriteFn;
+DxilPartWriter *hlsl::NewPSVWriter(const DxilModule &M) {
+  return new DxilPSVWriter(M);
+}
 
+class DxilContainerWriter_impl : public DxilContainerWriter  {
 private:
   class DxilPart {
   public:
@@ -501,24 +525,26 @@ private:
   llvm::SmallVector<DxilPart, 8> m_Parts;
 
 public:
-  void AddPart(uint32_t FourCC, uint32_t Size, WriteFn Write) {
+  __override void AddPart(uint32_t FourCC, uint32_t Size, WriteFn Write) {
     m_Parts.emplace_back(FourCC, Size, Write);
   }
 
-  void write(AbstractMemoryStream *pStream) {
-    DxilContainerHeader header;
-    const uint32_t PartCount = (uint32_t)m_Parts.size();
-    const uint32_t OffsetTableSize = sizeof(uint32_t) * PartCount;
-    uint32_t containerSizeInBytes =
-      (uint32_t)sizeof(DxilContainerHeader) + OffsetTableSize +
-      (uint32_t)sizeof(DxilPartHeader) * PartCount;
+  __override uint32_t size() const {
+    uint32_t partSize = 0;
     for (auto &&part : m_Parts) {
-      containerSizeInBytes += part.Header.PartSize;
+      partSize += part.Header.PartSize;
     }
+    return (uint32_t)GetDxilContainerSizeFromParts((uint32_t)m_Parts.size(), partSize);
+  }
+
+  __override void write(AbstractMemoryStream *pStream) {
+    DxilContainerHeader header;
+    const uint32_t PartCount = (uint32_t)m_Parts.size();
+    uint32_t containerSizeInBytes = size();
     InitDxilContainer(&header, PartCount, containerSizeInBytes);
     IFT(pStream->Reserve(header.ContainerSizeInBytes));
     IFT(WriteStreamValue(pStream, header));
-    uint32_t offset = sizeof(header) + OffsetTableSize;
+    uint32_t offset = sizeof(header) + (uint32_t)GetOffsetTableSize(PartCount);
     for (auto &&part : m_Parts) {
       IFT(WriteStreamValue(pStream, offset));
       offset += sizeof(DxilPartHeader) + part.Header.PartSize;
@@ -533,6 +559,10 @@ public:
   }
 };
 
+DxilContainerWriter *hlsl::NewDxilContainerWriter() {
+  return new DxilContainerWriter_impl();
+}
+
 static bool HasDebugInfo(const Module &M) {
   for (Module::const_named_metadata_iterator NMI = M.named_metadata_begin(),
                                              NME = M.named_metadata_end();
@@ -575,7 +605,7 @@ static void WriteProgramPart(const ShaderModel *pModel,
   }
 }
 
-void hlsl::SerializeDxilContainerForModule(Module *pModule,
+void hlsl::SerializeDxilContainerForModule(DxilModule *pModule,
                                            AbstractMemoryStream *pModuleBitcode,
                                            AbstractMemoryStream *pFinalStream) {
   // TODO: add a flag to update the module and remove information that is not part
@@ -586,20 +616,18 @@ void hlsl::SerializeDxilContainerForModule(Module *pModule,
   DXASSERT_NOMSG(pFinalStream != nullptr);
 
   CComPtr<AbstractMemoryStream> pProgramStream;
-  DxilModule dxilModule(pModule);
-  dxilModule.LoadDxilMetadata();
 
-  DxilProgramSignatureWriter inputSigWriter(dxilModule.GetInputSignature(),
-                                            dxilModule.GetTessellatorDomain(),
+  DxilProgramSignatureWriter inputSigWriter(pModule->GetInputSignature(),
+                                            pModule->GetTessellatorDomain(),
                                             /*IsInput*/ true);
-  DxilProgramSignatureWriter outputSigWriter(dxilModule.GetOutputSignature(),
-                                             dxilModule.GetTessellatorDomain(),
+  DxilProgramSignatureWriter outputSigWriter(pModule->GetOutputSignature(),
+                                             pModule->GetTessellatorDomain(),
                                              /*IsInput*/ false);
-  DxilPSVWriter PSVWriter(dxilModule);
-  DxilContainerWriter writer;
+  DxilPSVWriter PSVWriter(*pModule);
+  DxilContainerWriter_impl writer;
 
   // Write the feature part.
-  DxilFeatureInfoWriter featureInfoWriter(dxilModule);
+  DxilFeatureInfoWriter featureInfoWriter(*pModule);
   writer.AddPart(DFCC_FeatureInfo, featureInfoWriter.size(), [&](AbstractMemoryStream *pStream) {
     featureInfoWriter.write(pStream);
   });
@@ -613,10 +641,10 @@ void hlsl::SerializeDxilContainerForModule(Module *pModule,
   });
 
   DxilProgramSignatureWriter patchConstantSigWriter(
-      dxilModule.GetPatchConstantSignature(), dxilModule.GetTessellatorDomain(),
-      /*IsInput*/ dxilModule.GetShaderModel()->IsDS());
+      pModule->GetPatchConstantSignature(), pModule->GetTessellatorDomain(),
+      /*IsInput*/ pModule->GetShaderModel()->IsDS());
 
-  if (dxilModule.GetPatchConstantSignature().GetElements().size()) {
+  if (pModule->GetPatchConstantSignature().GetElements().size()) {
     writer.AddPart(DFCC_PatchConstantSignature, patchConstantSigWriter.size(),
                    [&](AbstractMemoryStream *pStream) {
                      patchConstantSigWriter.write(pStream);
@@ -629,32 +657,33 @@ void hlsl::SerializeDxilContainerForModule(Module *pModule,
   });
 
   // Write the root signature (RTS0) part.
-  DxilProgramRootSignatureWriter rootSigWriter(dxilModule.GetRootSignature());
-  if (!dxilModule.GetRootSignature().IsEmpty()) {
+  DxilProgramRootSignatureWriter rootSigWriter(pModule->GetRootSignature());
+  if (!pModule->GetRootSignature().IsEmpty()) {
     writer.AddPart(
         DFCC_RootSignature, rootSigWriter.size(),
         [&](AbstractMemoryStream *pStream) { rootSigWriter.write(pStream); });
+    pModule->StripRootSignatureFromMetadata();
   }
 
   // If we have debug information present, serialize it to a debug part, then use the stripped version as the canonical program version.
   pProgramStream = pModuleBitcode;
-  if (HasDebugInfo(*pModule)) {
+  if (HasDebugInfo(*pModule->GetModule())) {
     uint32_t debugInUInt32, debugPaddingBytes;
     GetPaddedProgramPartSize(pModuleBitcode, debugInUInt32, debugPaddingBytes);
     writer.AddPart(DFCC_ShaderDebugInfoDXIL, debugInUInt32 * sizeof(uint32_t) + sizeof(DxilProgramHeader), [&](AbstractMemoryStream *pStream) {
-      WriteProgramPart(dxilModule.GetShaderModel(), pModuleBitcode, pStream);
+      WriteProgramPart(pModule->GetShaderModel(), pModuleBitcode, pStream);
     });
 
     pProgramStream.Release();
 
-    llvm::StripDebugInfo(*pModule);
-    dxilModule.StripDebugRelatedCode();
+    llvm::StripDebugInfo(*pModule->GetModule());
+    pModule->StripDebugRelatedCode();
 
     CComPtr<IMalloc> pMalloc;
     IFT(CoGetMalloc(1, &pMalloc));
     IFT(CreateMemoryStream(pMalloc, &pProgramStream));
     raw_stream_ostream outStream(pProgramStream.p);
-    WriteBitcodeToFile(pModule, outStream, true);
+    WriteBitcodeToFile(pModule->GetModule(), outStream, true);
   }
 
   // Compute padded bitcode size.
@@ -663,7 +692,7 @@ void hlsl::SerializeDxilContainerForModule(Module *pModule,
 
   // Write the program part.
   writer.AddPart(DFCC_DXIL, programInUInt32 * sizeof(uint32_t) + sizeof(DxilProgramHeader), [&](AbstractMemoryStream *pStream) {
-    WriteProgramPart(dxilModule.GetShaderModel(), pProgramStream, pStream);
+    WriteProgramPart(pModule->GetShaderModel(), pProgramStream, pStream);
   });
 
   writer.write(pFinalStream);

+ 69 - 3
lib/HLSL/DxilModule.cpp

@@ -22,6 +22,8 @@
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/DebugInfo.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/Support/raw_ostream.h"
 #include <unordered_set>
 
@@ -31,6 +33,21 @@ using std::vector;
 using std::unique_ptr;
 
 
+namespace {
+class DxilErrorDiagnosticInfo : public DiagnosticInfo {
+private:
+  const char *m_message;
+public:
+  DxilErrorDiagnosticInfo(const char *str)
+    : DiagnosticInfo(DK_FirstPluginKind, DiagnosticSeverity::DS_Error),
+    m_message(str) { }
+
+  __override void print(DiagnosticPrinter &DP) const {
+    DP << m_message;
+  }
+};
+} // anon namespace
+
 namespace hlsl {
 
 //------------------------------------------------------------------------------
@@ -808,6 +825,27 @@ const RootSignatureHandle &DxilModule::GetRootSignature() const {
   return *m_RootSignature;
 }
 
+void DxilModule::StripRootSignatureFromMetadata() {
+  const llvm::NamedMDNode *pEntries = m_pMDHelper->GetDxilEntryPoints();
+  IFTBOOL(pEntries->getNumOperands() == 1, DXC_E_INCORRECT_DXIL_METADATA);
+
+  Function *pEntryFunc;
+  string EntryName;
+  const llvm::MDOperand *pSignatures, *pResources, *pProperties;
+  m_pMDHelper->GetDxilEntryPoint(pEntries->getOperand(0), pEntryFunc, EntryName, pSignatures, pResources, pProperties);
+
+  MDTuple *pMDSignatures = pSignatures->get() ? dyn_cast<MDTuple>(pSignatures->get()) : nullptr;
+  MDTuple *pMDResources = pResources->get() ? dyn_cast<MDTuple>(pResources->get()) : nullptr;
+
+  MDTuple *pMDProperties = EmitDxilShaderProperties(/*bStripRootSignature*/true);
+  MDTuple *pEntry = m_pMDHelper->EmitDxilEntryPointTuple(pEntryFunc, m_EntryName, pMDSignatures, pMDResources, pMDProperties);
+  vector<MDNode *> Entries;
+  Entries.emplace_back(pEntry);
+  NamedMDNode *pEntryPointsNamedMD = GetModule()->getNamedMetadata(DxilMDHelper::kDxilEntryPointsMDName);
+  GetModule()->eraseNamedMetadata(pEntryPointsNamedMD);
+  m_pMDHelper->EmitDxilEntryPoints(Entries);
+}
+
 void DxilModule::ResetInputSignature(DxilSignature *pValue) {
   m_InputSignature.reset(pValue);
 }
@@ -880,7 +918,7 @@ void DxilModule::EmitDxilMetadata() {
                                                            *m_OutputSignature,
                                                            *m_PatchConstantSignature);
   MDTuple *pMDResources = EmitDxilResources();
-  MDTuple *pMDProperties = EmitDxilShaderProperties();
+  MDTuple *pMDProperties = EmitDxilShaderProperties(/*bStripRootSignature*/false);
   m_pMDHelper->EmitDxilTypeSystem(GetTypeSystem(), m_LLVMUsed);
   EmitLLVMUsed();
   MDTuple *pEntry = m_pMDHelper->EmitDxilEntryPointTuple(GetEntryFunction(), m_EntryName, pMDSignatures, pMDResources, pMDProperties);
@@ -1010,7 +1048,7 @@ void DxilModule::LoadDxilResources(const llvm::MDOperand &MDO) {
   }
 }
 
-MDTuple *DxilModule::EmitDxilShaderProperties() {
+MDTuple *DxilModule::EmitDxilShaderProperties(bool bStripRootSignature) {
   vector<Metadata *> MDVals;
 
   // DXIL shader flags.
@@ -1062,7 +1100,7 @@ MDTuple *DxilModule::EmitDxilShaderProperties() {
     MDVals.emplace_back(pMDTuple);
   }
 
-  if (!m_RootSignature->IsEmpty()) {
+  if (!bStripRootSignature && !m_RootSignature->IsEmpty()) {
     MDVals.emplace_back(m_pMDHelper->Uint32ToConstMD(DxilMDHelper::kDxilRootSignatureTag));
     MDVals.emplace_back(m_pMDHelper->EmitRootSignature(*m_RootSignature.get()));
   }
@@ -1188,6 +1226,34 @@ DebugInfoFinder &DxilModule::GetOrCreateDebugInfoFinder() {
   }
   return *m_pDebugInfoFinder;
 }
+
+hlsl::DxilModule *hlsl::DxilModule::TryGetDxilModule(llvm::Module *pModule) {
+  LLVMContext &Ctx = pModule->getContext();
+  std::string diagStr;
+  raw_string_ostream diagStream(diagStr);
+
+  hlsl::DxilModule *pDxilModule = nullptr;
+  // TODO: add detail error in DxilMDHelper.
+  try {
+    pDxilModule = &pModule->GetOrCreateDxilModule();
+  } catch (const ::hlsl::Exception &hlslException) {
+    diagStream << "load dxil metadata failed -";
+    try {
+      const char *msg = hlslException.what();
+      if (msg == nullptr || *msg == '\0')
+        diagStream << " error code " << hlslException.hr << "\n";
+      else
+        diagStream << msg;
+    } catch (...) {
+      diagStream << " unable to retrieve error message.\n";
+    }
+    Ctx.diagnose(DxilErrorDiagnosticInfo(diagStream.str().c_str()));
+  } catch (...) {
+    Ctx.diagnose(DxilErrorDiagnosticInfo("load dxil metadata failed - unknown error.\n"));
+  }
+  return pDxilModule;
+}
+
 } // namespace hlsl
 
 namespace llvm {

+ 503 - 27
lib/HLSL/DxilValidation.cpp

@@ -19,6 +19,8 @@
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/HLSL/DxilInstructions.h"
 #include "dxc/HLSL/ReducibilityAnalysis.h"
+#include "dxc/Support/WinIncludes.h"
+#include "dxc/Support/FileIOHelper.h"
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Analysis/CallGraph.h"
@@ -34,12 +36,15 @@
 #include "llvm/ADT/BitVector.h"
 #include <winerror.h>
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Bitcode/ReaderWriter.h"
 #include <unordered_set>
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/Analysis/PostDominators.h"
 #include "dxc/HLSL/DxilSpanAllocator.h"
 #include "dxc/HLSL/DxilSignatureAllocator.h"
+#include "dxc/HLSL/DxilRootSignature.h"
 #include <algorithm>
 
 
@@ -54,6 +59,11 @@ const char *hlsl::GetValidationRuleText(ValidationRule value) {
   // VALRULE-TEXT:BEGIN
   switch(value) {
     case hlsl::ValidationRule::BitcodeValid: return "Module bitcode is invalid";
+    case hlsl::ValidationRule::ContainerPartMatches: return "Container part '%0' does not match expected for module.";
+    case hlsl::ValidationRule::ContainerPartRepeated: return "More than one container part '%0'.";
+    case hlsl::ValidationRule::ContainerPartMissing: return "Missing part '%0' required by module.";
+    case hlsl::ValidationRule::ContainerPartInvalid: return "Unknown part '%0' found in DXIL container.";
+    case hlsl::ValidationRule::ContainerRootSignatureIncompatible: return "Root Signature in DXIL container is not compatible with shader.";
     case hlsl::ValidationRule::MetaRequired: return "TODO - Required metadata missing";
     case hlsl::ValidationRule::MetaKnown: return "Named metadata '%0' is unknown";
     case hlsl::ValidationRule::MetaUsed: return "All metadata must be used by dxil";
@@ -234,6 +244,57 @@ const char *hlsl::GetValidationRuleText(ValidationRule value) {
   return "<unknown>";
 }
 
+namespace {
+
+class PrintDiagnosticContext {
+private:
+  DiagnosticPrinter &m_Printer;
+  bool m_errorsFound;
+  bool m_warningsFound;
+public:
+  PrintDiagnosticContext(DiagnosticPrinter &printer)
+      : m_Printer(printer), m_errorsFound(false), m_warningsFound(false) {}
+
+  bool HasErrors() const {
+    return m_errorsFound;
+  }
+  bool HasWarnings() const {
+    return m_warningsFound;
+  }
+  void Handle(const DiagnosticInfo &DI) {
+    DI.print(m_Printer);
+    switch (DI.getSeverity()) {
+    case llvm::DiagnosticSeverity::DS_Error:
+      m_errorsFound = true;
+      break;
+    case llvm::DiagnosticSeverity::DS_Warning:
+      m_warningsFound = true;
+      break;
+    }
+    m_Printer << "\n";
+  }
+};
+
+static void PrintDiagnosticHandler(const DiagnosticInfo &DI, void *Context) {
+  reinterpret_cast<PrintDiagnosticContext *>(Context)->Handle(DI);
+}
+
+// Utility class for setting and restoring the diagnostic context so we may capture errors/warnings
+struct DiagRestore {
+  LLVMContext &Ctx;
+  void *OrigDiagContext;
+  LLVMContext::DiagnosticHandlerTy OrigHandler;
+
+  DiagRestore(llvm::LLVMContext &Ctx, void *DiagContext) : Ctx(Ctx) {
+    OrigHandler = Ctx.getDiagnosticHandler();
+    OrigDiagContext = Ctx.getDiagnosticContext();
+    Ctx.setDiagnosticHandler(PrintDiagnosticHandler, DiagContext);
+  }
+  ~DiagRestore() {
+    Ctx.setDiagnosticHandler(OrigHandler, OrigDiagContext);
+  }
+};
+
 class DxilErrorDiagnosticInfo : public DiagnosticInfo {
 private:
   const char *m_message;
@@ -262,6 +323,8 @@ static inline DiagnosticPrinter &operator<<(DiagnosticPrinter &OS, Type &T) {
   return OS;
 }
 
+} // anon namespace
+
 namespace hlsl {
 
 struct PSExecutionInfo {
@@ -3906,33 +3969,15 @@ void GetValidationVersion(_Out_ unsigned *pMajor, _Out_ unsigned *pMinor) {
   *pMinor = 0;
 }
 
-_Use_decl_annotations_ std::error_code
+_Use_decl_annotations_ HRESULT
 ValidateDxilModule(llvm::Module *pModule, llvm::Module *pDebugModule) {
-  const LLVMContext &Ctx = pModule->getContext();
   std::string diagStr;
   raw_string_ostream diagStream(diagStr);
   DiagnosticPrinterRawOStream DiagPrinter(diagStream);
 
-  DxilModule *pDxilModule;
-  // TODO: add detail error in DxilMDHelper.
-  try {
-    pDxilModule = &pModule->GetOrCreateDxilModule();
-  } catch (const ::hlsl::Exception &hlslException) {
-    DiagPrinter << "load dxil metadata failed -";
-    try {
-      const char *msg = hlslException.what();
-      if (msg == nullptr || *msg == '\0')
-        DiagPrinter << " error code " << hlslException.hr << "\n";
-      else
-        DiagPrinter << msg;
-    } catch (...) {
-      DiagPrinter << " unable to retrieve error message.\n";
-    }
-    emitDxilDiag(Ctx, diagStr.c_str());
-    return std::error_code(ERROR_INVALID_DATA, std::system_category());
-  } catch (...) {
-    emitDxilDiag(Ctx, "load dxil metadata failed - unknown error.\n");
-    return std::error_code(ERROR_INVALID_DATA, std::system_category());
+  DxilModule *pDxilModule = DxilModule::TryGetDxilModule(pModule);
+  if (!pDxilModule) {
+    return DXC_E_IR_VERIFICATION_FAILED;
   }
 
   ValidationContext ValCtx(*pModule, pDebugModule, *pDxilModule, DiagPrinter);
@@ -3978,13 +4023,444 @@ ValidateDxilModule(llvm::Module *pModule, llvm::Module *pDebugModule) {
 
   // Ensure error messages are flushed out on error.
   if (ValCtx.Failed) {
-    diagStream.flush();
-    emitDxilDiag(Ctx, diagStr.c_str());
+    emitDxilDiag(pModule->getContext(), diagStream.str().c_str());
+    return DXC_E_IR_VERIFICATION_FAILED;
+  }
+  return S_OK;
+}
+
+// DXIL Container Verification Functions
+
+static void VerifyBlobPartMatches(_In_ ValidationContext &ValCtx,
+                                  _In_ LPCSTR pName,
+                                  DxilPartWriter *pWriter,
+                                  _In_reads_bytes_opt_(Size) const void *pData,
+                                  _In_ unsigned Size) {
+  if (!pData && pWriter->size()) {
+    // No blob part, but writer says non-zero size is expected.
+    ValCtx.EmitFormatError(ValidationRule::ContainerPartMissing, pName);
+    return;
+  }
+
+  // Compare sizes
+  if (pWriter->size() != Size) {
+    ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, pName);
+    return;
+  }
+
+  CComPtr<IMalloc> pMalloc;
+  IFT(CoGetMalloc(1, &pMalloc));
+  CComPtr<AbstractMemoryStream> pOutputStream;
+  IFT(CreateMemoryStream(pMalloc, &pOutputStream));
+  pOutputStream->Reserve(Size);
+
+  pWriter->write(pOutputStream);
+  DXASSERT(pOutputStream->GetPtrSize() == Size, "otherwise, DxilPartWriter misreported size");
+
+  if (memcmp(pData, pOutputStream->GetPtr(), Size)) {
+    ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, pName);
+    return;
+  }
+
+  return;
+}
+
+static void VerifySignatureMatches(_In_ ValidationContext &ValCtx,
+                                   DXIL::SignatureKind SigKind,
+                                   _In_reads_bytes_opt_(SigSize) const void *pSigData,
+                                   _In_ unsigned SigSize) {
+  // Generate corresponding signature from module and memcmp
+
+  const char *pName = nullptr;
+  switch (SigKind)
+  {
+  case hlsl::DXIL::SignatureKind::Input:
+    pName = "Program Input Signature";
+    break;
+  case hlsl::DXIL::SignatureKind::Output:
+    pName = "Program Output Signature";
+    break;
+  case hlsl::DXIL::SignatureKind::PatchConstant:
+    pName = "Program Patch Constant Signature";
+    break;
+  }
+
+  unique_ptr<DxilPartWriter> pWriter(NewProgramSignatureWriter(ValCtx.DxilMod, SigKind));
+  VerifyBlobPartMatches(ValCtx, pName, pWriter.get(), pSigData, SigSize);
+}
+
+_Use_decl_annotations_
+bool VerifySignatureMatches(llvm::Module *pModule,
+                            DXIL::SignatureKind SigKind,
+                            const void *pSigData,
+                            unsigned SigSize) {
+  std::string diagStr;
+  raw_string_ostream diagStream(diagStr);
+  DiagnosticPrinterRawOStream DiagPrinter(diagStream);
+  ValidationContext ValCtx(*pModule, nullptr, pModule->GetOrCreateDxilModule(), DiagPrinter);
+  VerifySignatureMatches(ValCtx, SigKind, pSigData, SigSize);
+  if (ValCtx.Failed) {
+    emitDxilDiag(pModule->getContext(), diagStream.str().c_str());
+  }
+  return !ValCtx.Failed;
+}
+
+//static void VerifyRootSignatureMatches(_In_ ValidationContext &ValCtx,
+//                                             _In_reads_bytes_(RSSize) const void *pRSData,
+//                                             _In_ unsigned RSSize) {
+//  // Write root signature from module and memcmp
+//  unique_ptr<DxilPartWriter> pWriter(NewRootSignatureWriter(ValCtx.DxilMod.GetRootSignature()));
+//  VerifyBlobPartMatches(ValCtx, "Root Signature", pWriter.get(), pRSData, RSSize);
+//}
+//
+//_Use_decl_annotations_
+//bool VerifyRootSignatureMatches(llvm::Module *pModule,
+//                                const void *pRSData,
+//                                unsigned RSSize) {
+//  std::string diagStr;
+//  raw_string_ostream diagStream(diagStr);
+//  DiagnosticPrinterRawOStream DiagPrinter(diagStream);
+//  ValidationContext ValCtx(*pModule, nullptr, pModule->GetOrCreateDxilModule(), DiagPrinter);
+//  VerifyRootSignatureMatches(ValCtx, pRSData, RSSize);
+//  if (ValCtx.Failed) {
+//    emitDxilDiag(pModule->getContext(), diagStream.str().c_str());
+//  }
+//  return !ValCtx.Failed;
+//}
+
+static void VerifyPSVMatches(_In_ ValidationContext &ValCtx,
+                             _In_reads_bytes_(PSVSize) const void *pPSVData,
+                             _In_ unsigned PSVSize) {
+  // generate PSV data from module and memcmp
+  unique_ptr<DxilPartWriter> pWriter(NewPSVWriter(ValCtx.DxilMod));
+  VerifyBlobPartMatches(ValCtx, "Pipeline State Validation", pWriter.get(), pPSVData, PSVSize);
+}
+
+_Use_decl_annotations_
+bool VerifyPSVMatches(llvm::Module *pModule,
+                      const void *pPSVData,
+                      unsigned PSVSize) {
+  std::string diagStr;
+  raw_string_ostream diagStream(diagStr);
+  DiagnosticPrinterRawOStream DiagPrinter(diagStream);
+  ValidationContext ValCtx(*pModule, nullptr, pModule->GetOrCreateDxilModule(), DiagPrinter);
+  VerifyPSVMatches(ValCtx, pPSVData, PSVSize);
+  if (ValCtx.Failed) {
+    emitDxilDiag(pModule->getContext(), diagStream.str().c_str());
+  }
+  return !ValCtx.Failed;
+}
+
+static void VerifyFeatureInfoMatches(_In_ ValidationContext &ValCtx,
+                                     _In_reads_bytes_(FeatureInfoSize) const void *pFeatureInfoData,
+                                     _In_ unsigned FeatureInfoSize) {
+  // generate Feature Info data from module and memcmp
+  unique_ptr<DxilPartWriter> pWriter(NewFeatureInfoWriter(ValCtx.DxilMod));
+  VerifyBlobPartMatches(ValCtx, "Feature Info", pWriter.get(), pFeatureInfoData, FeatureInfoSize);
+}
+
+_Use_decl_annotations_
+bool VerifyFeatureInfoMatches(llvm::Module *pModule,
+                              const void *pFeatureInfoData,
+                              unsigned FeatureInfoSize) {
+  std::string diagStr;
+  raw_string_ostream diagStream(diagStr);
+  DiagnosticPrinterRawOStream DiagPrinter(diagStream);
+  ValidationContext ValCtx(*pModule, nullptr, pModule->GetOrCreateDxilModule(), DiagPrinter);
+  VerifyFeatureInfoMatches(ValCtx, pFeatureInfoData, FeatureInfoSize);
+  if (ValCtx.Failed) {
+    emitDxilDiag(pModule->getContext(), diagStream.str().c_str());
+  }
+  return !ValCtx.Failed;
+}
+
+_Use_decl_annotations_
+HRESULT ValidateDxilContainerParts(llvm::Module *pModule,
+                                   llvm::Module *pDebugModule,
+                                   const DxilContainerHeader *pContainer,
+                                   unsigned ContainerSize) {
+
+  DXASSERT_NOMSG(pModule);
+  if (!pContainer || !IsValidDxilContainer(pContainer, ContainerSize)) {
+    return DXC_E_CONTAINER_INVALID;
+  }
+
+  std::string diagStr;
+  raw_string_ostream DiagStream(diagStr);
+  DiagnosticPrinterRawOStream DiagPrinter(DiagStream);
+
+  DxilModule *pDxilModule = DxilModule::TryGetDxilModule(pModule);
+  if (!pDxilModule) {
+    return DXC_E_IR_VERIFICATION_FAILED;
+  }
+
+  ValidationContext ValCtx(*pModule, pDebugModule, *pDxilModule, DiagPrinter);
+
+  DXIL::ShaderKind ShaderKind = pDxilModule->GetShaderModel()->GetKind();
+  bool bTess = ShaderKind == DXIL::ShaderKind::Hull || ShaderKind == DXIL::ShaderKind::Domain;
+
+  std::unordered_set<uint32_t> FourCCFound;
+  const DxilPartHeader *pRootSignaturePart = nullptr;
+  const DxilPartHeader *pPSVPart = nullptr;
+
+  for (auto it = begin(pContainer), itEnd = end(pContainer); it != itEnd; ++it) {
+    const DxilPartHeader *pPart = *it;
+
+    char szFourCC[5];
+    PartKindToCharArray(pPart->PartFourCC, szFourCC);
+    if (FourCCFound.find(pPart->PartFourCC) != FourCCFound.end()) {
+      // Two parts with same FourCC found
+      ValCtx.EmitFormatError(ValidationRule::ContainerPartRepeated, szFourCC);
+      continue;
+    }
+    FourCCFound.insert(pPart->PartFourCC);
+
+    switch (pPart->PartFourCC)
+    {
+    case DFCC_InputSignature:
+      VerifySignatureMatches(ValCtx, DXIL::SignatureKind::Input, GetDxilPartData(pPart), pPart->PartSize);
+      break;
+    case DFCC_OutputSignature:
+      VerifySignatureMatches(ValCtx, DXIL::SignatureKind::Output, GetDxilPartData(pPart), pPart->PartSize);
+      break;
+    case DFCC_PatchConstantSignature:
+      if (bTess) {
+        VerifySignatureMatches(ValCtx, DXIL::SignatureKind::PatchConstant, GetDxilPartData(pPart), pPart->PartSize);
+      } else {
+        ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, "Program Patch Constant Signature");
+      }
+      break;
+    case DFCC_FeatureInfo:
+      VerifyFeatureInfoMatches(ValCtx, GetDxilPartData(pPart), pPart->PartSize);
+      break;
+    case DFCC_RootSignature:
+      pRootSignaturePart = pPart;
+      break;
+    case DFCC_PipelineStateValidation:
+      pPSVPart = pPart;
+      VerifyPSVMatches(ValCtx, GetDxilPartData(pPart), pPart->PartSize);
+      break;
+
+    // Skip these
+    case DFCC_ResourceDef:
+    case DFCC_ShaderStatistics:
+    case DFCC_PrivateData:
+    case DFCC_DXIL:
+    case DFCC_ShaderDebugInfoDXIL:
+      continue;
+
+    case DFCC_Container:
+    default:
+      ValCtx.EmitFormatError(ValidationRule::ContainerPartInvalid, szFourCC);
+      break;
+    }
+  }
+
+  // Verify required parts found
+  if (FourCCFound.find(DFCC_InputSignature) == FourCCFound.end()) {
+    VerifySignatureMatches(ValCtx, DXIL::SignatureKind::Input, nullptr, 0);
+  }
+  if (FourCCFound.find(DFCC_OutputSignature) == FourCCFound.end()) {
+    VerifySignatureMatches(ValCtx, DXIL::SignatureKind::Output, nullptr, 0);
+  }
+  if (bTess && FourCCFound.find(DFCC_PatchConstantSignature) == FourCCFound.end())
+  {
+    VerifySignatureMatches(ValCtx, DXIL::SignatureKind::PatchConstant, nullptr, 0);
+  }
+  if (FourCCFound.find(DFCC_FeatureInfo) == FourCCFound.end()) {
+    // Could be optional, but RS1 runtime doesn't handle this case properly.
+    ValCtx.EmitFormatError(ValidationRule::ContainerPartMissing, "Feature Info");
+  }
+
+  // Validate Root Signature
+  if (pPSVPart) {
+    if (pRootSignaturePart) {
+      DxilVersionedRootSignatureDesc* pDesc = nullptr;
+      try {
+        DeserializeRootSignature(GetDxilPartData(pRootSignaturePart), pRootSignaturePart->PartSize, &pDesc);
+      } catch (...) {
+        pDesc = nullptr;
+      }
+      if (pDesc) {
+        try {
+          VerifyRootSignatureWithShaderPSV(pDesc,
+                                            pDxilModule->GetShaderModel()->GetKind(),
+                                            GetDxilPartData(pPSVPart), pPSVPart->PartSize,
+                                            DiagStream);
+        } catch (...) {
+          DeleteRootSignature(pDesc);
+          ValCtx.EmitError(ValidationRule::ContainerRootSignatureIncompatible);
+        }
+      } else {
+        ValCtx.EmitError(ValidationRule::ContainerRootSignatureIncompatible);
+      }
+    }
+  } else {
+    ValCtx.EmitFormatError(ValidationRule::ContainerPartMissing, "Pipeline State Validation");
+  }
+
+  if (ValCtx.Failed) {
+    emitDxilDiag(pModule->getContext(), DiagStream.str().c_str());
+    return DXC_E_MALFORMED_CONTAINER;
+  }
+  return S_OK;
+}
+
+static HRESULT FindDxilPart(_In_reads_bytes_(ContainerSize) const void *pContainerBytes,
+                            _In_ uint32_t ContainerSize,
+                            _In_ DxilFourCC FourCC,
+                            _In_ const DxilPartHeader **ppPart) {
+
+  const DxilContainerHeader *pContainer =
+    IsDxilContainerLike(pContainerBytes, ContainerSize);
+
+  if (!pContainer) {
+    IFR(DXC_E_CONTAINER_INVALID);
+  }
+  if (!IsValidDxilContainer(pContainer, ContainerSize)) {
+    IFR(DXC_E_CONTAINER_INVALID);
+  }
+
+  DxilPartIterator it = std::find_if(begin(pContainer), end(pContainer),
+    DxilPartIsType(FourCC));
+  if (it == end(pContainer)) {
+    IFR(DXC_E_CONTAINER_MISSING_DXIL);
+  }
+
+  const DxilProgramHeader *pProgramHeader =
+    reinterpret_cast<const DxilProgramHeader *>(GetDxilPartData(*it));
+  if (!IsValidDxilProgramHeader(pProgramHeader, (*it)->PartSize)) {
+    IFR(DXC_E_CONTAINER_INVALID);
+  }
+
+  *ppPart = *it;
+  return S_OK;
+}
+
+_Use_decl_annotations_
+HRESULT ValidateLoadModule(const char *pIL,
+                           uint32_t ILLength,
+                           unique_ptr<llvm::Module> &pModule,
+                           LLVMContext &Ctx,
+                           llvm::raw_ostream &DiagStream) {
+
+  llvm::DiagnosticPrinterRawOStream DiagPrinter(DiagStream);
+  PrintDiagnosticContext DiagContext(DiagPrinter);
+  DiagRestore DR(Ctx, &DiagContext);
+
+  std::unique_ptr<llvm::MemoryBuffer> pBitcodeBuf;
+  pBitcodeBuf.reset(llvm::MemoryBuffer::getMemBuffer(
+      llvm::StringRef(pIL, ILLength), "", false).release());
+  ErrorOr<std::unique_ptr<llvm::Module>> loadedModuleResult(llvm::parseBitcodeFile(
+    pBitcodeBuf->getMemBufferRef(), Ctx));
+
+  // DXIL disallows some LLVM bitcode constructs, like unaccounted-for sub-blocks.
+  // These appear as warnings, which the validator should reject.
+  if (DiagContext.HasErrors() || DiagContext.HasWarnings() || loadedModuleResult.getError())
+    return DXC_E_IR_VERIFICATION_FAILED;
+
+  pModule = std::move(loadedModuleResult.get());
+  return S_OK;
+}
+
+HRESULT ValidateDxilBitcode(
+  _In_reads_bytes_(ILLength) const char *pIL,
+  _In_ uint32_t ILLength,
+  _In_ llvm::raw_ostream &DiagStream) {
+
+  LLVMContext Ctx;
+  std::unique_ptr<llvm::Module> pModule;
+
+  llvm::DiagnosticPrinterRawOStream DiagPrinter(DiagStream);
+  PrintDiagnosticContext DiagContext(DiagPrinter);
+  Ctx.setDiagnosticHandler(PrintDiagnosticHandler, &DiagContext, true);
+
+  HRESULT hr;
+  if (FAILED(hr = ValidateLoadModule(pIL, ILLength, pModule, Ctx, DiagStream)))
+    return hr;
+
+  if (FAILED(hr = ValidateDxilModule(pModule.get(), nullptr)))
+    return hr;
+
+  DxilModule &dxilModule = pModule->GetDxilModule();
+  if (!dxilModule.GetRootSignature().IsEmpty()) {
+    const RootSignatureHandle &RS = dxilModule.GetRootSignature();
+    unique_ptr<DxilPartWriter> pWriter(NewPSVWriter(dxilModule));
+    DXASSERT_NOMSG(pWriter->size());
+    unique_ptr<unsigned char[]> pPSVData(new unsigned char[pWriter->size()]);
+    DxilVersionedRootSignatureDesc* pDesc = nullptr;
+    if (!RS.GetDesc()) {
+      DeserializeRootSignature(RS.GetSerializedBytes(), RS.GetSerializedSize(), &pDesc);
+      if (!pDesc)
+        return DXC_E_INCORRECT_ROOT_SIGNATURE;
+    }
+    try {
+      VerifyRootSignatureWithShaderPSV(pDesc ? pDesc : RS.GetDesc(),
+                                       dxilModule.GetShaderModel()->GetKind(),
+                                       pPSVData.get(), pWriter->size(),
+                                       DiagStream);
+    } catch (...) {
+      DeleteRootSignature(pDesc);
+      return DXC_E_INCORRECT_ROOT_SIGNATURE;
+    }
+  }
+
+  if (DiagContext.HasErrors() || DiagContext.HasWarnings()) {
+    return DXC_E_IR_VERIFICATION_FAILED;
+  }
+
+  return S_OK;
+}
+
+_Use_decl_annotations_
+HRESULT ValidateDxilContainer(const void *pContainer,
+                              unsigned ContainerSize,
+                              llvm::raw_ostream &DiagStream) {
+
+  LLVMContext Ctx, DbgCtx;
+  std::unique_ptr<llvm::Module> pModule, pDebugModule;
+
+  llvm::DiagnosticPrinterRawOStream DiagPrinter(DiagStream);
+  PrintDiagnosticContext DiagContext(DiagPrinter);
+  Ctx.setDiagnosticHandler(PrintDiagnosticHandler, &DiagContext, true);
+  DbgCtx.setDiagnosticHandler(PrintDiagnosticHandler, &DiagContext, true);
+
+  HRESULT hr;
+  const DxilPartHeader *pPart = nullptr;
+  IFR(FindDxilPart(pContainer, ContainerSize, DFCC_DXIL, &pPart));
+
+  const char *pIL = nullptr;
+  uint32_t ILLength = 0;
+  GetDxilProgramBitcode(
+    reinterpret_cast<const DxilProgramHeader *>(GetDxilPartData(pPart)),
+    &pIL, &ILLength);
+
+  IFR(ValidateLoadModule(pIL, ILLength, pModule, Ctx, DiagStream));
+
+  const DxilPartHeader *pDbgPart = nullptr;
+  if (FAILED(hr = FindDxilPart(pContainer, ContainerSize, DFCC_ShaderDebugInfoDXIL, &pDbgPart)) &&
+      hr != DXC_E_CONTAINER_MISSING_DXIL) {
+    return hr;
+  }
+
+  if (pDbgPart) {
+    GetDxilProgramBitcode(
+      reinterpret_cast<const DxilProgramHeader *>(GetDxilPartData(pPart)),
+      &pIL, &ILLength);
+    if (FAILED(hr = ValidateLoadModule(pIL, ILLength, pDebugModule, DbgCtx, DiagStream))) {
+      return hr;
+    }
+  }
+
+  // Validate DXIL Module
+  IFR(ValidateDxilModule(pModule.get(), pDebugModule.get()));
+
+  if (DiagContext.HasErrors() || DiagContext.HasWarnings()) {
+    return DXC_E_IR_VERIFICATION_FAILED;
   }
 
-  if (ValCtx.Failed)
-    return std::error_code(ERROR_INVALID_DATA, std::system_category());
-  return std::error_code();
+  return ValidateDxilContainerParts(pModule.get(), pDebugModule.get(),
+    IsDxilContainerLike(pContainer, ContainerSize), ContainerSize);
 }
 
 } // namespace hlsl

+ 1 - 1
tools/clang/tools/dxcompiler/dxcassembler.cpp

@@ -120,7 +120,7 @@ HRESULT STDMETHODCALLTYPE DxcAssembler::AssembleToContainer(
     CComPtr<AbstractMemoryStream> pFinalStream;
     IFT(CreateMemoryStream(pMalloc, &pFinalStream));
 
-    SerializeDxilContainerForModule(M.get(), pOutputStream, pFinalStream);
+    SerializeDxilContainerForModule(&M->GetOrCreateDxilModule(), pOutputStream, pFinalStream);
 
     CComPtr<IDxcBlob> pResultBlob;
     IFT(pFinalStream->QueryInterface(&pResultBlob));

+ 1 - 1
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -1858,7 +1858,7 @@ public:
  void WrapModuleInDxilContainer(IMalloc *pMalloc,  AbstractMemoryStream *pModuleBitcode, CComPtr<IDxcBlob> &pDxilContainerBlob) {
     CComPtr<AbstractMemoryStream> pContainerStream;
     IFT(CreateMemoryStream(pMalloc, &pContainerStream));
-    SerializeDxilContainerForModule(m_llvmModule.get(), pModuleBitcode, pContainerStream);
+    SerializeDxilContainerForModule(&m_llvmModule->GetOrCreateDxilModule(), pModuleBitcode, pContainerStream);
 
     pDxilContainerBlob.Release();
     IFT(pContainerStream.QueryInterface(&pDxilContainerBlob));

+ 36 - 86
tools/clang/tools/dxcompiler/dxcvalidator.cpp

@@ -61,6 +61,22 @@ static void PrintDiagnosticHandler(const DiagnosticInfo &DI, void *Context) {
   reinterpret_cast<PrintDiagnosticContext *>(Context)->Handle(DI);
 }
 
+// Utility class for setting and restoring the diagnostic context so we may capture errors/warnings
+struct DiagRestore {
+  LLVMContext &Ctx;
+  void *OrigDiagContext;
+  LLVMContext::DiagnosticHandlerTy OrigHandler;
+
+  DiagRestore(llvm::LLVMContext &Ctx, void *DiagContext) : Ctx(Ctx) {
+    OrigHandler = Ctx.getDiagnosticHandler();
+    OrigDiagContext = Ctx.getDiagnosticContext();
+    Ctx.setDiagnosticHandler(PrintDiagnosticHandler, DiagContext);
+  }
+  ~DiagRestore() {
+    Ctx.setDiagnosticHandler(OrigHandler, OrigDiagContext);
+  }
+};
+
 class DxcValidator : public IDxcValidator, public IDxcVersionInfo {
 private:
   DXC_MICROCOM_REF_FIELD(m_dwRef)
@@ -68,7 +84,7 @@ private:
   HRESULT RunValidation(
     _In_ IDxcBlob *pShader,                       // Shader to validate.
     _In_ llvm::Module *pModule,                   // Module to validate, if available.
-    _In_ llvm::Module *pDiagModule,               // Diag module to validate, if available
+    _In_ llvm::Module *pDebugModule,              // Debug module to validate, if available
     _In_ AbstractMemoryStream *pDiagStream);
 
 public:
@@ -84,7 +100,7 @@ public:
     _In_ IDxcBlob *pShader,                       // Shader to validate.
     _In_ UINT32 Flags,                            // Validation flags.
     _In_ llvm::Module *pModule,                   // Module to validate, if available.
-    _In_ llvm::Module *pDiagModule,               // Diag module to validate, if available
+    _In_ llvm::Module *pDebugModule,              // Debug module to validate, if available
     _COM_Outptr_ IDxcOperationResult **ppResult   // Validation output status, buffer, and errors
   );
 
@@ -115,7 +131,7 @@ HRESULT DxcValidator::ValidateWithOptModules(
   _In_ IDxcBlob *pShader,                       // Shader to validate.
   _In_ UINT32 Flags,                            // Validation flags.
   _In_ llvm::Module *pModule,                   // Module to validate, if available.
-  _In_ llvm::Module *pDiagModule,               // Diag module to validate, if available
+  _In_ llvm::Module *pDebugModule,              // Debug module to validate, if available
   _COM_Outptr_ IDxcOperationResult **ppResult   // Validation output status, buffer, and errors
 ) {
   *ppResult = nullptr;
@@ -130,7 +146,7 @@ HRESULT DxcValidator::ValidateWithOptModules(
 
     // Run validation may throw, but that indicates an inability to validate,
     // not that the validation failed (eg out of memory).
-    validationStatus = RunValidation(pShader, pModule, pDiagModule, pDiagStream);
+    validationStatus = RunValidation(pShader, pModule, pDebugModule, pDiagStream);
 
     // Assemble the result object.
     CComPtr<IDxcBlob> pDiagBlob;
@@ -173,96 +189,30 @@ HRESULT DxcValidator::RunValidation(
   // not that the validation failed (eg out of memory). That is indicated
   // by a failing HRESULT, and possibly error messages in the diagnostics stream.
 
-  llvm::LLVMContext llvmContext;
-  std::unique_ptr<llvm::MemoryBuffer> pBitcodeBuf;
-  std::unique_ptr<llvm::Module> pLoadedModule;
-  std::unique_ptr<llvm::MemoryBuffer> pDbgBitcodeBuf;
-  std::unique_ptr<llvm::Module> pLoadedDbgModule;
   raw_stream_ostream DiagStream(pDiagStream);
-  llvm::DiagnosticPrinterRawOStream DiagPrinter(DiagStream);
-  PrintDiagnosticContext DiagContext(DiagPrinter);
-  if (pModule == nullptr) {
+
+  if (!pModule) {
     DXASSERT_NOMSG(pDebugModule == nullptr);
-    // Accept a bitcode buffer or a DXIL container.
-    const char *pIL = (const char*)pShader->GetBufferPointer();
-    uint32_t pILLength = pShader->GetBufferSize();
-    const char *pDbgIL = nullptr;
-    uint32_t pDbgILLength = 0;
-    if (const DxilContainerHeader *pContainer =
-      IsDxilContainerLike(pIL, pILLength)) {
-      if (!IsValidDxilContainer(pContainer, pILLength)) {
-        IFR(DXC_E_CONTAINER_INVALID);
-      }
-
-      DxilPartIterator it = std::find_if(begin(pContainer), end(pContainer),
-        DxilPartIsType(DFCC_DXIL));
-      if (it == end(pContainer)) {
-        IFR(DXC_E_CONTAINER_MISSING_DXIL);
-      }
-
-      const DxilProgramHeader *pProgramHeader =
-        reinterpret_cast<const DxilProgramHeader *>(GetDxilPartData(*it));
-      if (!IsValidDxilProgramHeader(pProgramHeader, (*it)->PartSize)) {
-        IFR(DXC_E_CONTAINER_INVALID);
-      }
-      GetDxilProgramBitcode(pProgramHeader, &pIL, &pILLength);
-
-      // Look for an optional debug version of the module. If it's there,
-      // it should be valid.
-      DxilPartIterator dbgit = std::find_if(begin(pContainer), end(pContainer),
-        DxilPartIsType(DFCC_ShaderDebugInfoDXIL));
-      if (dbgit != end(pContainer)) {
-        const DxilProgramHeader *pDbgHeader =
-          reinterpret_cast<const DxilProgramHeader *>(GetDxilPartData(*dbgit));
-        if (!IsValidDxilProgramHeader(pDbgHeader, (*dbgit)->PartSize)) {
-          IFR(DXC_E_CONTAINER_INVALID);
-        }
-        GetDxilProgramBitcode(pDbgHeader, &pDbgIL, &pDbgILLength);
-      }
+    if (IsDxilContainerLike(pShader->GetBufferPointer(), pShader->GetBufferSize())) {
+      return ValidateDxilContainer(pShader->GetBufferPointer(), pShader->GetBufferSize(), DiagStream);
+    } else {
+      return ValidateDxilBitcode((const char*)pShader->GetBufferPointer(), (uint32_t)pShader->GetBufferSize(), DiagStream);
     }
+  }
 
-    llvmContext.setDiagnosticHandler(PrintDiagnosticHandler, &DiagContext, true);
-    pBitcodeBuf.reset(llvm::MemoryBuffer::getMemBuffer(
-        llvm::StringRef(pIL, pILLength), "", false).release());
-    ErrorOr<std::unique_ptr<llvm::Module>> loadedModuleResult(llvm::parseBitcodeFile(
-      pBitcodeBuf->getMemBufferRef(), llvmContext));
+  llvm::DiagnosticPrinterRawOStream DiagPrinter(DiagStream);
+  PrintDiagnosticContext DiagContext(DiagPrinter);
+  DiagRestore DR(pModule->getContext(), &DiagContext);
 
-    // DXIL disallows some LLVM bitcode constructs, like unaccounted-for sub-blocks.
-    // These appear as warnings, which the validator should reject.
-    if (DiagContext.HasErrors() || DiagContext.HasWarnings()) {
-      IFR(DXC_E_IR_VERIFICATION_FAILED);
-    }
-    if (std::error_code ec = loadedModuleResult.getError()) {
-      IFR(DXC_E_IR_VERIFICATION_FAILED);
-    }
-    pLoadedModule.swap(loadedModuleResult.get());
-
-    if (pDbgIL != nullptr) {
-      pDbgBitcodeBuf.reset(llvm::MemoryBuffer::getMemBuffer(
-          llvm::StringRef(pDbgIL, pDbgILLength), "", false).release());
-      ErrorOr<std::unique_ptr<llvm::Module>> loadedDbgModuleResult(
-          llvm::parseBitcodeFile(pDbgBitcodeBuf->getMemBufferRef(), llvmContext));
-      if (std::error_code ec = loadedDbgModuleResult.getError()) {
-        IFR(DXC_E_IR_VERIFICATION_FAILED);
-      }
-      pLoadedDbgModule.swap(loadedDbgModuleResult.get());
-    }
-    pModule = pLoadedModule.get();
-    pDebugModule = pLoadedDbgModule.get();
-  }
-  else {
-    // Install the diagnostic handler on the
-    pModule->getContext().setDiagnosticHandler(PrintDiagnosticHandler,
-                                               &DiagContext, true);
-  }
+  IFR(hlsl::ValidateDxilModule(pModule, pDebugModule));
+  IFR(ValidateDxilContainerParts(pModule, pDebugModule,
+                    IsDxilContainerLike(pShader->GetBufferPointer(), pShader->GetBufferSize()),
+                    (uint32_t)pShader->GetBufferSize()));
 
-  if (std::error_code ec = hlsl::ValidateDxilModule(pModule, pDebugModule)) {
-    IFR(DXC_E_IR_VERIFICATION_FAILED);
+  if (DiagContext.HasErrors() || DiagContext.HasWarnings()) {
+    return DXC_E_IR_VERIFICATION_FAILED;
   }
 
-  // TODO: run validation that cross-references other parts in the
-  // DxilContainer if available.
-
   return S_OK;
 }
 

+ 7 - 0
utils/hct/hctdb.py

@@ -1443,6 +1443,12 @@ class db_dxil(object):
     def build_valrules(self):
         self.add_valrule_msg("Bitcode.Valid", "TODO - Module must be bitcode-valid", "Module bitcode is invalid")
 
+        self.add_valrule_msg("Container.PartMatches", "DXIL Container Parts must match Module", "Container part '%0' does not match expected for module.")
+        self.add_valrule_msg("Container.PartRepeated", "DXIL Container must have only one of each part type", "More than one container part '%0'.")
+        self.add_valrule_msg("Container.PartMissing", "DXIL Container requires certain parts, corresponding to module", "Missing part '%0' required by module.")
+        self.add_valrule_msg("Container.PartInvalid", "DXIL Container must not contain unknown parts", "Unknown part '%0' found in DXIL container.")
+        self.add_valrule_msg("Container.RootSignatureIncompatible", "Root Signature in DXIL Container must be compatible with shader", "Root Signature in DXIL container is not compatible with shader.")
+
         self.add_valrule("Meta.Required", "TODO - Required metadata missing")
         self.add_valrule_msg("Meta.Known", "Named metadata should be known", "Named metadata '%0' is unknown")
         self.add_valrule("Meta.Used", "All metadata must be used by dxil")
@@ -1649,6 +1655,7 @@ class db_dxil(object):
         
         # Assign sensible category names and build up an enumeration description
         cat_names = {
+            "CONTAINER": "Container",
             "BITCODE": "Bitcode",
             "META": "Metadata",
             "INSTR": "Instruction",