Ver código fonte

Optimized bitcode loading. Added function to only materialize named MD. (#2854)

Adam Yang 5 anos atrás
pai
commit
5c64108bcc

+ 2 - 0
include/dxc/DXIL/DxilUtil.h

@@ -118,6 +118,8 @@ namespace dxilutil {
     llvm::LLVMContext &Ctx, std::string &DiagStr);
   std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::MemoryBuffer *MB,
     llvm::LLVMContext &Ctx, std::string &DiagStr);
+  std::unique_ptr<llvm::Module> LoadModuleFromBitcodeLazy(std::unique_ptr<llvm::MemoryBuffer> &&MB,
+    llvm::LLVMContext &Ctx, std::string &DiagStr);
   void PrintDiagnosticHandler(const llvm::DiagnosticInfo &DI, void *Context);
   bool IsIntegerOrFloatingPointType(llvm::Type *Ty);
   // Returns true if type contains HLSL Object type (resource)

+ 21 - 2
include/llvm/Bitcode/BitstreamReader.h

@@ -55,7 +55,7 @@ public:
   struct ScopeTrack {
     BitstreamCursor *BC;
     uint64_t begin;
-    ~ScopeTrack();
+    inline ~ScopeTrack();
   };
   static ScopeTrack scope_track(BitstreamCursor *BC);
   static void track(BitstreamUseTracker *BT, uint64_t begin, uint64_t end);
@@ -238,6 +238,7 @@ class BitstreamCursor {
   /// This tracks the codesize of parent blocks.
   SmallVector<Block, 8> BlockScope;
 
+  template<typename T> inline void AddRecordElements(BitCodeAbbrevOp::Encoding enc, uint64_t encData, unsigned NumElts, SmallVectorImpl<T> &Vals); // HLSL Change
 
 public:
   static const size_t MaxChunkSize = sizeof(word_t) * 8;
@@ -478,6 +479,15 @@ public:
     return Read(CurCodeSize);
   }
 
+  // HLSL Change - begin
+  inline unsigned PeekCode() { // Returns the next CurCodeSize-bit code without consuming it.
+    auto BitPos = GetCurrentBitNo(); // Save the position so the read below can be undone.
+    unsigned result = Read(CurCodeSize);
+    JumpToBit(BitPos); // Rewind to where we started.
+    return result;
+  }
+  // HLSL Change - end
+
 
   // Block header:
   //    [ENTER_SUBBLOCK, blockid, newcodelen, <align4bytes>, blocklen]
@@ -547,7 +557,9 @@ public:
   void skipRecord(unsigned AbbrevID);
 
   unsigned readRecord(unsigned AbbrevID, SmallVectorImpl<uint64_t> &Vals,
-                      StringRef *Blob = nullptr);
+                      StringRef *Blob = nullptr,
+                      SmallVectorImpl<uint8_t> *Uint8Vals = nullptr); // HLSL Change
+  unsigned peekRecord(unsigned AbbrevID); // HLSL Change
 
   //===--------------------------------------------------------------------===//
   // Abbrev Processing
@@ -557,6 +569,13 @@ public:
   bool ReadBlockInfoBlock(unsigned *pCount = nullptr);
 };
 
+// HLSL Change - Begin
+BitstreamUseTracker::ScopeTrack::~ScopeTrack() { // Reports the bit range consumed during this scope to the reader's tracker.
+  if (auto *Tracker = BC->getBitStreamReader()->Tracker) // A null Tracker means there is nothing to report to.
+    Tracker->insert(begin, BC->GetCurrentBitNo()); // Range runs from the saved start bit to the cursor's current bit.
+}
+// HLSL Change - End
+
 } // End llvm namespace
 
 #endif

+ 4 - 2
include/llvm/Bitcode/ReaderWriter.h

@@ -37,7 +37,8 @@ namespace llvm {
   getLazyBitcodeModule(std::unique_ptr<MemoryBuffer> &&Buffer,
                        LLVMContext &Context,
                        DiagnosticHandlerFunction DiagnosticHandler = nullptr,
-                       bool ShouldLazyLoadMetadata = false);
+                       bool ShouldLazyLoadMetadata = false,
+                       bool ShouldTrackBitstreamUsage = false);
 
   /// Read the header of the specified stream and prepare for lazy
   /// deserialization and streaming of function bodies.
@@ -56,7 +57,8 @@ namespace llvm {
   /// Read the specified bitcode file, returning the module.
   ErrorOr<std::unique_ptr<Module>>
   parseBitcodeFile(MemoryBufferRef Buffer, LLVMContext &Context,
-                   DiagnosticHandlerFunction DiagnosticHandler = nullptr);
+                   DiagnosticHandlerFunction DiagnosticHandler = nullptr,
+                   bool ShouldTrackBitstreamUsage = false); // HLSL Change
 
   /// \brief Write the specified module to the specified raw output stream.
   ///

+ 3 - 0
include/llvm/IR/GVMaterializer.h

@@ -20,6 +20,8 @@
 
 #include <system_error>
 #include <vector>
+#include "llvm/ADT/ArrayRef.h" // HLSL Change
+#include "llvm/ADT/StringRef.h" // HLSL Change
 
 namespace llvm {
 class Function;
@@ -54,6 +56,7 @@ public:
   virtual std::error_code materializeModule(Module *M) = 0;
 
   virtual std::error_code materializeMetadata() = 0;
+  virtual std::error_code materializeSelectNamedMetadata(llvm::ArrayRef<llvm::StringRef>) = 0; // HLSL Change
   virtual void setStripDebugInfo() = 0;
 
   virtual std::vector<StructType *> getIdentifiedStructTypes() const = 0;

+ 1 - 0
include/llvm/IR/Module.h

@@ -515,6 +515,7 @@ public:
   std::error_code materializeAllPermanently();
 
   std::error_code materializeMetadata();
+  std::error_code materializeSelectNamedMetadata(ArrayRef<StringRef> NamedMetadata); // HLSL Change
 
 /// @}
 /// @name Direct access to the globals list, functions list, and symbol table

+ 315 - 7
lib/Bitcode/Reader/BitcodeReader.cpp

@@ -34,6 +34,7 @@
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/raw_ostream.h"
 #include <deque>
+#include <unordered_set> // HLSL Change
 #include "dxc/DXIL/DxilOperations.h"   // HLSL Change
 using namespace llvm;
 
@@ -234,6 +235,7 @@ public:
 
   void releaseBuffer();
 
+  bool ShouldTrackBitstreamUsage = false; // HLSL Change
   BitstreamUseTracker Tracker; // HLSL Change
 
   bool isDematerializable(const GlobalValue *GV) const override;
@@ -369,6 +371,8 @@ private:
   std::error_code globalCleanup();
   std::error_code resolveGlobalAndAliasInits();
   std::error_code parseMetadata();
+  std::error_code parseSelectNamedMetadata(ArrayRef<StringRef> NamedMetadata); // HLSL Change
+  std::error_code materializeSelectNamedMetadata(ArrayRef<StringRef> NamedMetadata); // HLSL Change
   std::error_code parseMetadataAttachment(Function &F);
   ErrorOr<std::string> parseModuleTriple();
   std::error_code parseUseLists();
@@ -1630,6 +1634,278 @@ std::error_code BitcodeReader::parseValueSymbolTable() {
 
 static int64_t unrotateSign(uint64_t U) { return U & 1 ? ~(U >> 1) : U >> 1; }
 
+// HLSL Change - Begin
+// This function takes a list of strings that corresponds to the list of named
+// metadata that we want to materialize, and materialize them efficiently.
+//
+// Note: This function will only materialize metadata that are the following
+// types:
+//
+//    MDString        e.g. !"my metadata string"
+//    MDNode          e.g. !10 = !{ !"my node", !32, !48 }
+//    distinct MDNode e.g. !10 = distinct !{ !"my node", !32, !48 }
+//    ValueAsMetadata e.g. true, 0, 10
+//
+// Everything else will appear as !<temporary>
+//
+// We first skip through the whole METADATA_BLOCK_ID block. As we do, we take
+// note of the named metadata we want, and push their operands into a queue.
+// We also record the bit offsets of all String, Node, and Value metadata records.
+//
+// Next, we go through the queue, and skip to their bit offsets and load their
+// data (but only if they're the types listed above). If the metadata has their
+// own operands, we insert them into the queue as well.
+//
+// Returns an error for a malformed block or an invalid record, success otherwise.
+//
+std::error_code BitcodeReader::parseSelectNamedMetadata(ArrayRef<StringRef> NamedMetadata) {
+  // Remember our bit position right at the start, because we're going to
+  // jump back to it later.
+  uint64_t OriginalBitPos = Stream.GetCurrentBitNo();
+
+  // Buffer used to read record operands.
+  SmallVector<uint64_t, 64> Record;
+  SmallVector<uint8_t, 32> Uint8Record;
+
+  // Indexed by metadata value number: where each supported record was seen.
+  struct Info {
+    uint64_t BitPos; // Bit offset of the record; 0 means "not recorded".
+    uint64_t ID;     // Abbreviation ID the record has to be read with.
+    bool IsString;   // True for METADATA_STRING records.
+  };
+  std::vector<Info> NodePositions;
+
+  unsigned NextMDValueNo = MDValueList.size();
+
+  if (Stream.EnterSubBlock(bitc::METADATA_BLOCK_ID))
+    return error("Invalid record");
+
+  SmallVector<uint64_t, 1> AbbrevDefines;
+  std::vector<uint64_t> NodeQueue; // Work list of metadata IDs, in discovery order.
+  std::unordered_set<uint64_t> NodeQueueSet; // IDs already queued, so each is visited once.
+
+  auto add_to_queue = [&NodeQueueSet, &NodeQueue](uint64_t Val) {
+    if (NodeQueueSet.insert(Val).second)
+      NodeQueue.push_back(Val);
+  };
+
+  // Read all the records.
+  while (1) {
+    // If we encounter a DEFINE_ABBREV record, record where it is.
+    // We need to go back to them to recover the abbreviation list.
+    // There shouldn't be more than one... but just in case.
+    if (Stream.PeekCode() == bitc::DEFINE_ABBREV) {
+      AbbrevDefines.push_back(Stream.GetCurrentBitNo());
+    }
+
+    // HLSL Change Starts - count skipped blocks
+    unsigned skipCount = 0;
+    BitstreamEntry Entry = Stream.advanceSkippingSubblocks(0, &skipCount);
+    if (skipCount) ReportWarning(DiagnosticHandler, "Unrecognized subblock");
+    // HLSL Change End
+
+    bool Stop = false;
+    switch (Entry.Kind) {
+    case BitstreamEntry::SubBlock: // Handled for us already.
+    case BitstreamEntry::Error:
+      return error("Malformed block");
+    case BitstreamEntry::EndBlock:
+      MDValueList.tryToResolveCycles();
+      Stop = true;
+      break;
+    case BitstreamEntry::Record:
+      // The interesting case.
+      break;
+    }
+    if (Stop)
+      break;
+
+    // Peek the record type without changing bit-position and loading
+    // the record's operands.
+    unsigned PeekCode = Stream.peekRecord(Entry.ID);
+
+    // For the first pass, we're only interested in named metadata.
+    if (PeekCode != bitc::METADATA_NAME) {
+      // If it's one of these types of things, remember where we actually
+      // found it.
+      switch (PeekCode) {
+      case bitc::METADATA_DISTINCT_NODE:
+      case bitc::METADATA_NODE:
+      case bitc::METADATA_STRING:
+      case bitc::METADATA_VALUE:
+        auto old_size = NodePositions.size();
+        NodePositions.resize(NextMDValueNo + 1); // Slots left zeroed (BitPos == 0) are skipped in the second pass.
+        memset(NodePositions.data() + old_size, 0, sizeof(NodePositions[0]) * (NextMDValueNo + 1 - old_size));
+        NodePositions[NextMDValueNo] = { Stream.GetCurrentBitNo(), Entry.ID, PeekCode == bitc::METADATA_STRING };
+        break;
+      }
+      // Skip the record without loading anything.
+      Stream.skipRecord(Entry.ID);
+      NextMDValueNo++; // NOTE(review): counts every skipped record as one metadata ID; presumably no un-numbered records (e.g. METADATA_KIND) share this block - confirm.
+      continue;
+    }
+
+    // Read a record.
+    Record.clear();
+    unsigned Code = Stream.readRecord(Entry.ID, Record);
+
+    // Read name of the named metadata.
+    SmallString<8> Name(Record.begin(), Record.end());
+    Record.clear();
+    Code = Stream.ReadCode(); // Abbreviation ID of the record that follows the name.
+
+    // Figure out if it's one of the named metadata that we actually want.
+    bool found = false;
+    for (unsigned i = 0; i < NamedMetadata.size(); i++) {
+      if (Name == NamedMetadata[i]) {
+        found = true;
+        break;
+      }
+    }
+
+    // If it's not interesting to us, then just skip.
+    if (!found) {
+      Stream.skipRecord(Code);
+    }
+    else {
+      unsigned NextBitCode = Stream.readRecord(Code, Record);
+      if (NextBitCode != bitc::METADATA_NAMED_NODE)
+        return error("METADATA_NAME not followed by METADATA_NAMED_NODE");
+
+      // Read named metadata elements.
+      unsigned Size = Record.size();
+      NamedMDNode *NMD = TheModule->getOrInsertNamedMetadata(Name);
+      for (unsigned i = 0; i != Size; ++i) {
+        MDNode *MD = dyn_cast_or_null<MDNode>(MDValueList.getValueFwdRef(Record[i]));
+        if (!MD)
+          return error("Invalid record");
+
+        add_to_queue(Record[i]); // Add this MD number to our queue, so we know to try to read it later.
+        NMD->addOperand(MD);
+      }
+    }
+  }
+
+  // Now that we have gathered all the metadata operands that the named
+  // metadata need...
+
+  // Go back to the beginning.
+  Stream.JumpToBit(OriginalBitPos);
+
+  // Re-enter the metadata block.
+  if (Stream.EnterSubBlock(bitc::METADATA_BLOCK_ID))
+    return error("Invalid record");
+
+  // Load all the abbreviations again, since exiting the block and re-entering
+  // the block has wiped them clean.
+  for (unsigned i = 0; i < AbbrevDefines.size(); i++) {
+    Stream.JumpToBit(AbbrevDefines[i]);
+    while (1) {
+      unsigned Code = Stream.ReadCode();
+      if (Code == bitc::DEFINE_ABBREV) {
+        Stream.ReadAbbrevRecord();
+        continue;
+      }
+      else {
+        break; // First non-DEFINE_ABBREV code ends this run; the cursor is repositioned by the next jump.
+      }
+    }
+  }
+
+  std::string String; // String buffer used to read MD string
+
+  // Go through the queue and read all the metadata we want.
+  for (unsigned i = 0; i < NodeQueue.size(); i++) { // Index loop on purpose: the queue grows while we walk it.
+    uint64_t MDNumber = NodeQueue[i];
+
+    // If we never memorized the location for this MD No, it means
+    // it wasn't one of the MD types that we care about.
+    if (MDNumber >= NodePositions.size())
+      continue;
+    Info I = NodePositions[MDNumber];
+    if (I.BitPos == 0)
+      continue;
+
+    // Go back to the bit where we read the record
+    Stream.JumpToBit(I.BitPos);
+
+    // Read a record
+    Record.clear();
+    unsigned Code = 0;
+
+    // If it's a string, use our special Uint8Buffer to speed up the reading.
+    if (I.IsString) {
+      Uint8Record.clear();
+      Code = Stream.readRecord(I.ID, Record, nullptr, &Uint8Record);
+      assert(!Uint8Record.empty() || (Record.empty() && Uint8Record.empty()));
+    }
+    else {
+      Code = Stream.readRecord(I.ID, Record);
+    }
+
+    // Read the actual data. This code is largely copied from parseMetadata
+    bool IsDistinct = false;
+    switch (Code) {
+    default:
+      llvm_unreachable("Can't actually be anything else.");
+      break;
+    case bitc::METADATA_VALUE: {
+      if (Record.size() != 2)
+        return error("Invalid record");
+
+      Type *Ty = getTypeByID(Record[0]); // Presumably null when the type ID cannot be resolved - hence the guard below.
+      if (!Ty || Ty->isMetadataTy() || Ty->isVoidTy())
+        return error("Invalid record");
+
+      MDValueList.assignValue(
+          ValueAsMetadata::get(ValueList.getValueFwdRef(Record[1], Ty)),
+          MDNumber);
+      break;
+    }
+    case bitc::METADATA_DISTINCT_NODE:
+      IsDistinct = true;
+      // fallthrough...
+    case bitc::METADATA_NODE: {
+      SmallVector<Metadata *, 8> Elts;
+      Elts.reserve(Record.size());
+      for (unsigned ID : Record) { // Operand IDs are stored biased by one; 0 encodes a null operand.
+        Elts.push_back(ID ? MDValueList.getValueFwdRef(ID - 1) : nullptr);
+        // If this ID is not a null MD, add to queue.
+        if (ID)
+          add_to_queue(ID - 1);
+      }
+      MDValueList.assignValue(IsDistinct ? MDNode::getDistinct(Context, Elts)
+                                         : MDNode::get(Context, Elts),
+                              MDNumber);
+      break;
+    }
+    case bitc::METADATA_STRING: {
+      String.clear();
+      String.resize(Uint8Record.size());
+      memcpy(&String[0], Uint8Record.data(), Uint8Record.size());
+      llvm::UpgradeMDStringConstant(String);
+      Metadata *MD = MDString::get(Context, String);
+      MDValueList.assignValue(MD, MDNumber);
+      break;
+    }
+    }
+  }
+
+  return std::error_code();
+}
+
+std::error_code BitcodeReader::materializeSelectNamedMetadata(ArrayRef<StringRef> NamedMetadata) { // Runs parseSelectNamedMetadata at each bit position saved in DeferredMetadataInfo.
+  for (uint64_t BitPos : DeferredMetadataInfo) {
+    // Move the bit stream to the saved position.
+    Stream.JumpToBit(BitPos);
+    if (std::error_code EC = parseSelectNamedMetadata(NamedMetadata))
+      return EC; // Stop at the first failure; DeferredMetadataInfo is left as it was.
+  }
+  DeferredMetadataInfo.clear(); // Every saved position has been visited, so nothing remains deferred.
+  return std::error_code();
+}
+
+// HLSL Change - end
+
 std::error_code BitcodeReader::parseMetadata() {
   IsMetadataMaterialized = true;
   unsigned NextMDValueNo = MDValueList.size();
@@ -1638,6 +1914,7 @@ std::error_code BitcodeReader::parseMetadata() {
     return error("Invalid record");
 
   SmallVector<uint64_t, 64> Record;
+  SmallVector<uint8_t, 64> Uint8Record; // HLSL Change
 
   auto getMD =
       [&](unsigned ID) -> Metadata *{ return MDValueList.getValueFwdRef(ID); };
@@ -1675,9 +1952,27 @@ std::error_code BitcodeReader::parseMetadata() {
       break;
     }
 
+#if 1 // HLSL Change
+    // If it's a string metadata, use our special Uint8Record to speed
+    // up reading.
+    unsigned PeekCode = Stream.peekRecord(Entry.ID);
+    unsigned Code = 0;
+    Record.clear();
+    if (PeekCode == bitc::METADATA_STRING) {
+      Uint8Record.clear();
+      Code = Stream.readRecord(Entry.ID, Record, nullptr, &Uint8Record);
+      assert(!Uint8Record.empty() || (Record.empty() && Uint8Record.empty()));
+    }
+    else {
+      Code = Stream.readRecord(Entry.ID, Record);
+    }
+#else // HLSL Change
     // Read a record.
     Record.clear();
     unsigned Code = Stream.readRecord(Entry.ID, Record);
+#endif // HLSL Change
+
+    std::string String; // HLSL Change - Reuse buffer for loading string.
     bool IsDistinct = false;
     switch (Code) {
     default:  // Default behavior: ignore.
@@ -2064,7 +2359,12 @@ std::error_code BitcodeReader::parseMetadata() {
       break;
     }
     case bitc::METADATA_STRING: {
+#if 0
       std::string String(Record.begin(), Record.end());
+#else
+      String.resize(Uint8Record.size());
+      memcpy(&String[0], Uint8Record.data(), Uint8Record.size());
+#endif
       llvm::UpgradeMDStringConstant(String);
       Metadata *MD = MDString::get(Context, String);
       MDValueList.assignValue(MD, NextMDValueNo++);
@@ -4637,7 +4937,7 @@ std::error_code BitcodeReader::materializeModule(Module *M) {
   UpgradeDebugInfo(*M);
 
   // HLSL Change Starts
-  if (!Tracker.isDense((uint64_t)(Buffer->getBufferSize()) * 8)) {
+  if (ShouldTrackBitstreamUsage && !Tracker.isDense((uint64_t)(Buffer->getBufferSize()) * 8)) {
     ReportWarning(DiagnosticHandler, "Unused bits in buffer.");
   }
   // HLSL Change Ends
@@ -4670,7 +4970,7 @@ std::error_code BitcodeReader::initStreamFromBuffer() {
       return error("Invalid bitcode wrapper header");
 
   StreamFile.reset(new BitstreamReader(BufPtr, BufEnd));
-  StreamFile->Tracker = &Tracker; // HLSL Change
+  if (ShouldTrackBitstreamUsage) StreamFile->Tracker = &Tracker; // HLSL Change
   Stream.init(&*StreamFile);
 
   return std::error_code();
@@ -4770,7 +5070,9 @@ static ErrorOr<std::unique_ptr<Module>>
 getLazyBitcodeModuleImpl(std::unique_ptr<MemoryBuffer> &&Buffer,
                          LLVMContext &Context, bool MaterializeAll,
                          DiagnosticHandlerFunction DiagnosticHandler,
-                         bool ShouldLazyLoadMetadata = false) {
+                         bool ShouldLazyLoadMetadata = false,
+                         bool ShouldTrackBitstreamUsage = false) // HLSL Change
+{
   // HLSL Change Begin: Proper memory management with unique_ptr
   // Get the buffer identifier before we transfer the ownership to the bitcode reader,
   // this is ugly but safe as long as it keeps the buffer, and hence identifier string, alive.
@@ -4778,6 +5080,7 @@ getLazyBitcodeModuleImpl(std::unique_ptr<MemoryBuffer> &&Buffer,
   std::unique_ptr<BitcodeReader> R = llvm::make_unique<BitcodeReader>(
     std::move(Buffer), Context, DiagnosticHandler);
 
+  if (R) R->ShouldTrackBitstreamUsage = ShouldTrackBitstreamUsage; // HLSL Change
   ErrorOr<std::unique_ptr<Module>> Ret =
       getBitcodeModuleImpl(nullptr, BufferIdentifier, std::move(R), Context,
                            MaterializeAll, ShouldLazyLoadMetadata);
@@ -4790,9 +5093,11 @@ getLazyBitcodeModuleImpl(std::unique_ptr<MemoryBuffer> &&Buffer,
 
 ErrorOr<std::unique_ptr<Module>> llvm::getLazyBitcodeModule(
     std::unique_ptr<MemoryBuffer> &&Buffer, LLVMContext &Context,
-    DiagnosticHandlerFunction DiagnosticHandler, bool ShouldLazyLoadMetadata) {
+    DiagnosticHandlerFunction DiagnosticHandler, bool ShouldLazyLoadMetadata,
+    bool ShouldTrackBitstreamUsage) {
   return getLazyBitcodeModuleImpl(std::move(Buffer), Context, false,
-                                  DiagnosticHandler, ShouldLazyLoadMetadata);
+                                  DiagnosticHandler, ShouldLazyLoadMetadata,
+                                  ShouldTrackBitstreamUsage); // HLSL Change
 }
 
 ErrorOr<std::unique_ptr<Module>> llvm::getStreamedBitcodeModule(
@@ -4823,7 +5128,9 @@ void report_fatal_error_handler(void *user_datam, const std::string &reason,
 
 ErrorOr<std::unique_ptr<Module>>
 llvm::parseBitcodeFile(MemoryBufferRef Buffer, LLVMContext &Context,
-                       DiagnosticHandlerFunction DiagnosticHandler) {
+                       DiagnosticHandlerFunction DiagnosticHandler,
+                       bool ShouldTrackBitstreamUsage) // HLSL Change
+{
   // HLSL Change Starts - introduce a ScopedFatalErrorHandler to handle
   // report_fatal_error from readers.
   report_fatal_error_data data(DiagnosticHandler);
@@ -4831,7 +5138,8 @@ llvm::parseBitcodeFile(MemoryBufferRef Buffer, LLVMContext &Context,
   // HLSL Change Ends
   std::unique_ptr<MemoryBuffer> Buf = MemoryBuffer::getMemBuffer(Buffer, false);
   return getLazyBitcodeModuleImpl(std::move(Buf), Context, true,
-                                  DiagnosticHandler);
+                                  DiagnosticHandler,
+                                  false, ShouldTrackBitstreamUsage); // HLSL Change
   // TODO: Restore the use-lists to the in-memory state when the bitcode was
   // written.  We must defer until the Module has been fully materialized.
 }

+ 131 - 9
lib/Bitcode/Reader/BitstreamReader.cpp

@@ -130,9 +130,35 @@ void BitstreamCursor::skipRecord(unsigned AbbrevID) {
       assert(i+2 == e && "array op not second to last?");
       const BitCodeAbbrevOp &EltEnc = Abbv->getOperandInfo(++i);
 
+#if 1 // HLSL Change - Make skipping go brrrrrrrrrrr
+      {
+        const auto &Op = EltEnc;
+        auto &Cursor = *this;
+        auto CurBit = Cursor.GetCurrentBitNo();
+        // Decode the value as we are commanded.
+        switch (EltEnc.getEncoding()) {
+        case BitCodeAbbrevOp::Array:
+        case BitCodeAbbrevOp::Blob:
+          llvm_unreachable("Should not reach here");
+        case BitCodeAbbrevOp::Fixed:
+          assert((unsigned)Op.getEncodingData() <= Cursor.MaxChunkSize);
+          Cursor.JumpToBit(CurBit + NumElts * Op.getEncodingData());
+          break;
+        case BitCodeAbbrevOp::VBR:
+          assert((unsigned)Op.getEncodingData() <= Cursor.MaxChunkSize);
+          for (; NumElts; --NumElts)
+            Cursor.ReadVBR64((unsigned)Op.getEncodingData());
+          break;
+        case BitCodeAbbrevOp::Char6:
+          Cursor.JumpToBit(CurBit + NumElts * 6);
+          break;
+        }
+      }
+#else
       // Read all the elements.
       for (; NumElts; --NumElts)
         skipAbbreviatedField(*this, EltEnc);
+#endif
       continue;
     }
 
@@ -156,14 +182,70 @@ void BitstreamCursor::skipRecord(unsigned AbbrevID) {
   }
 }
 
+// HLSL Change - Begin
+unsigned BitstreamCursor::peekRecord(unsigned AbbrevID) { // Returns the code of the record at the cursor, then rewinds the cursor.
+  auto last_bit_pos = GetCurrentBitNo(); // Position to restore before returning.
+  if (AbbrevID == bitc::UNABBREV_RECORD) {
+    unsigned Code = ReadVBR(6); // Same first read that readRecord performs for an unabbreviated record.
+    this->JumpToBit(last_bit_pos);
+    return Code;
+  }
+
+  const BitCodeAbbrev *Abbv = getAbbrev(AbbrevID);
+
+  // Read the record code first.
+  assert(Abbv->getNumOperandInfos() != 0 && "no record code in abbreviation?");
+  const BitCodeAbbrevOp &CodeOp = Abbv->getOperandInfo(0);
+  unsigned Code;
+  if (CodeOp.isLiteral())
+    Code = CodeOp.getLiteralValue(); // Literal code: nothing is read from the stream.
+  else {
+    if (CodeOp.getEncoding() == BitCodeAbbrevOp::Array ||
+        CodeOp.getEncoding() == BitCodeAbbrevOp::Blob)
+      report_fatal_error("Abbreviation starts with an Array or a Blob");
+    Code = readAbbreviatedField(*this, CodeOp);
+  }
+  this->JumpToBit(last_bit_pos); // Undo whatever was consumed while decoding the code.
+  return Code;
+}
+
+template<typename T>
+void BitstreamCursor::AddRecordElements(BitCodeAbbrevOp::Encoding enc, uint64_t encData, unsigned NumElts, SmallVectorImpl<T> &Vals) { // Appends NumElts VBR- or Char6-encoded elements read from the stream to Vals.
+  const unsigned size = (unsigned)encData; // Width passed to ReadVBR64; not used in the Char6 branch.
+  if (enc == BitCodeAbbrevOp::VBR) {
+    assert((unsigned)encData <= MaxChunkSize);
+    for (; NumElts; --NumElts) {
+      Vals.push_back((T)ReadVBR64(size)); // The cast narrows the value when T is smaller than 64 bits.
+    }
+  }
+  else if (enc == BitCodeAbbrevOp::Char6) {
+    assert((unsigned)encData <= MaxChunkSize);
+    for (; NumElts; --NumElts) {
+      Vals.push_back(BitCodeAbbrevOp::DecodeChar6(Read(6))); // Each Char6 element occupies 6 bits.
+    }
+  }
+  else {
+    llvm_unreachable("Unknown kind of thing"); // Other encodings (e.g. Fixed) are not handled by this helper.
+  }
+}
+// HLSL Change - End
+
 unsigned BitstreamCursor::readRecord(unsigned AbbrevID,
                                      SmallVectorImpl<uint64_t> &Vals,
-                                     StringRef *Blob) {
+                                     StringRef *Blob,
+                                     SmallVectorImpl<uint8_t> *Uint8Vals // HLSL Change
+  ) {
   if (AbbrevID == bitc::UNABBREV_RECORD) {
     unsigned Code = ReadVBR(6);
     unsigned NumElts = ReadVBR(6);
-    for (unsigned i = 0; i != NumElts; ++i)
-      Vals.push_back(ReadVBR64(6));
+    if (Uint8Vals) {
+      for (unsigned i = 0; i != NumElts; ++i)
+        Uint8Vals->push_back((uint8_t)ReadVBR64(6));
+    }
+    else {
+      for (unsigned i = 0; i != NumElts; ++i)
+        Vals.push_back(ReadVBR64(6));
+    }
     return Code;
   }
 
@@ -210,9 +292,54 @@ unsigned BitstreamCursor::readRecord(unsigned AbbrevID,
           EltEnc.getEncoding() == BitCodeAbbrevOp::Blob)
         report_fatal_error("Array element type can't be an Array or a Blob");
 
+#if 1 // HLSL Change
+      // Read all the elements a little faster.
+      {
+        BitCodeAbbrevOp::Encoding enc = EltEnc.getEncoding();
+        uint64_t encData = 0;
+        if (EltEnc.hasEncodingData())
+          encData = EltEnc.getEncodingData();
+        unsigned size = (unsigned)encData;
+        if (Uint8Vals) {
+          if (enc == BitCodeAbbrevOp::Fixed) {
+            assert((unsigned)encData <= MaxChunkSize);
+            assert((unsigned)encData == 8);
+            // Special optimization for fixed elements that are 8 bits
+            Uint8Vals->resize(NumElts);
+            uint8_t *ptr = Uint8Vals->data();
+            unsigned i = 0;
+            constexpr unsigned BytesInWord = sizeof(size_t);
+            // First, read word by word instead of byte by byte
+            for (; NumElts >= BytesInWord; NumElts -= BytesInWord) {
+              const size_t e = Read(BytesInWord * 8);
+              memcpy(ptr + i, &e, sizeof(e));
+              i += BytesInWord;
+            }
+            for (; NumElts; --NumElts)
+              Uint8Vals->operator[](i++) = (uint8_t)Read(8);
+          }
+          else {
+            AddRecordElements(enc, encData, NumElts, *Uint8Vals);
+          }
+        }
+        else {
+          if (enc == BitCodeAbbrevOp::Fixed) {
+            assert((unsigned)encData <= MaxChunkSize);
+            Vals.reserve(Vals.size() + NumElts);
+            for (; NumElts; --NumElts)
+              Vals.push_back(Read(size));
+          }
+          else {
+            AddRecordElements(enc, encData, NumElts, Vals);
+          }
+        }
+      }
+#else // HLSL Change
       // Read all the elements.
       for (; NumElts; --NumElts)
         Vals.push_back(readAbbreviatedField(*this, EltEnc));
+
+#endif // HLSL Change
       continue;
     }
 
@@ -404,7 +531,7 @@ bool BitstreamUseTracker::considerMergeRight(size_t idx) {
 
 void BitstreamUseTracker::insert(uint64_t begin, uint64_t end) {
   UseRange IR(begin, end);
-  for (size_t i = 0; i < Ranges.size(); ++i) {
+  for (size_t i = 0, E = Ranges.size(); i < E; ++i) {
     ExtendResult ER = extendRange(Ranges[i], IR);
     switch (ER) {
     case Included:
@@ -437,11 +564,6 @@ void BitstreamUseTracker::insert(uint64_t begin, uint64_t end) {
   Ranges.push_back(IR);
 }
 
-BitstreamUseTracker::ScopeTrack::~ScopeTrack() {
-  if (BC->getBitStreamReader()->Tracker != nullptr)
-    BC->getBitStreamReader()->Tracker->insert(begin, BC->GetCurrentBitNo());
-}
-
 BitstreamUseTracker::ScopeTrack
 BitstreamUseTracker::scope_track(BitstreamCursor *BC) {
   ScopeTrack Result;

+ 11 - 0
lib/DXIL/DxilUtil.cpp

@@ -202,6 +202,17 @@ std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::MemoryBuffer *MB,
   return std::unique_ptr<llvm::Module>(pModule.get().release());
 }
 
+std::unique_ptr<llvm::Module> LoadModuleFromBitcodeLazy(std::unique_ptr<llvm::MemoryBuffer> &&MB,
+  llvm::LLVMContext &Ctx, std::string &DiagStr)
+{
+  // Note: DiagStr is not used; no diagnostic handler is installed (nullptr is passed below).
+  auto pModule = llvm::getLazyBitcodeModule(std::move(MB), Ctx, nullptr, true); // true = ShouldLazyLoadMetadata
+  if (!pModule) {
+    return nullptr; // Loading failed; the error code is not reported to the caller.
+  }
+  return std::unique_ptr<llvm::Module>(pModule.get().release()); // Transfer ownership of the module out of the ErrorOr.
+}
+
 std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::StringRef BC,
   llvm::LLVMContext &Ctx,
   std::string &DiagStr) {

+ 2 - 2
lib/HLSL/DxilValidation.cpp

@@ -6052,8 +6052,8 @@ HRESULT ValidateLoadModule(const char *pIL,
 
   ErrorOr<std::unique_ptr<Module>> loadedModuleResult =
       bLazyLoad == 0?
-      llvm::parseBitcodeFile(pBitcodeBuf->getMemBufferRef(), Ctx) :
-      llvm::getLazyBitcodeModule(std::move(pBitcodeBuf), Ctx);
+      llvm::parseBitcodeFile(pBitcodeBuf->getMemBufferRef(), Ctx, nullptr, true /*Track Bitstream*/) :
+      llvm::getLazyBitcodeModule(std::move(pBitcodeBuf), Ctx, nullptr, false, true /*Track Bitstream*/);
 
   // DXIL disallows some LLVM bitcode constructs, like unaccounted-for sub-blocks.
   // These appear as warnings, which the validator should reject.

+ 6 - 0
lib/IR/Module.cpp

@@ -426,6 +426,12 @@ std::error_code Module::materializeMetadata() {
   return Materializer->materializeMetadata();
 }
 
+std::error_code Module::materializeSelectNamedMetadata(ArrayRef<StringRef> NamedMetadata) { // Forwards the request to the attached materializer, if any.
+  if (!Materializer)
+    return std::error_code(); // No materializer attached: report success without doing anything.
+  return Materializer->materializeSelectNamedMetadata(NamedMetadata);
+}
+
 //===----------------------------------------------------------------------===//
 // Other module related stuff.
 //