Browse Source

Merged PR 79: DxilRuntimeData updates for final data layout

- Revise DxilRuntimeData for versioning support and robustness.
- Rename DxilRuntimeReflection desc structures.
- Add size checking to DxilRuntimeData/Reflection
- CheckedReader/CheckedWriter to catch potential problems
- Backward compatible reader for prerelease RDAT format
- Rewrite IndexArraysPart and avoid duplicate index arrays
- Always start string buffer with null for empty string at offset 0
Tex Riddell 7 years ago
parent
commit
df033577ee

+ 166 - 96
include/dxc/HLSL/DxilRuntimeReflection.h

@@ -15,20 +15,85 @@
 namespace hlsl {
 namespace RDAT {
 
+// Data Layout:
+// -start:
+//  RuntimeDataHeader header;
+//  uint32_t offsets[header.PartCount];
+//  - for each i in [0, header.PartCount):
+//    - at &header + offsets[i]:
+//      RuntimeDataPartHeader part;
+//    - if part.Type is a table (ResourceTable or FunctionTable):
+//      RuntimeDataTableHeader table;
+//      byte TableData[table.RecordCount][table.RecordStride];
+//    - else if part.Type is StringBuffer:
+//      byte UTF8Data[part.Size];
+//    - else if part.Type is IndexArrays:
+//      uint32_t IndexData[part.Size / 4];
+
+enum class RuntimeDataPartType : uint32_t { // TODO: Rename: PartType
+  Invalid = 0,
+  StringBuffer = 1,
+  IndexArrays = 2,
+  ResourceTable = 3,
+  FunctionTable = 4,
+};
+
+enum RuntimeDataVersion {
+  // Cannot be mistaken for part count from prerelease version
+  RDAT_Version_0 = 0x10,
+};
+
+struct RuntimeDataHeader {
+  uint32_t Version;
+  uint32_t PartCount;
+  // Followed by uint32_t array of offsets to parts
+  // offsets are relative to the beginning of this header
+  // offsets must be 4-byte aligned
+  //  uint32_t offsets[];
+};
+struct RuntimeDataPartHeader {
+  RuntimeDataPartType Type;
+  uint32_t Size;  // Not including this header.  Must be 4-byte aligned.
+  // Followed by part data
+  //  byte Data[ALIGN4(Size)];
+};
+
+// For tables of records, such as Function and Resource tables
+// Stride allows for extending records, with forward and backward compatibility
 struct RuntimeDataTableHeader {
-  uint32_t tableType; // RuntimeDataPartType
-  uint32_t size;
-  uint32_t offset;
+  uint32_t RecordCount;
+  uint32_t RecordStride;  // Must be 4-byte aligned.
+  // Followed by RecordCount records, each RecordStride bytes long:
+  // byte TableData[RecordCount * RecordStride];
 };
 
-enum class RuntimeDataPartType : uint32_t {
-  Invalid = 0,
-  String,
-  Function,
-  Resource,
-  Index
+// General purpose strided table reader with casting Row() operation that
+// returns nullptr if stride is smaller than type, for record expansion.
+class TableReader {
+  const char *m_table;   // first byte of record 0
+  uint32_t m_count;      // number of records
+  uint32_t m_stride;     // distance in bytes between consecutive records
+
+public:
+  TableReader() : TableReader(nullptr, 0, 0) {}
+  TableReader(const char *table, uint32_t count, uint32_t stride)
+    : m_table(table), m_count(count), m_stride(stride) {}
+  // Re-points the reader at a different table.
+  void Init(const char *table, uint32_t count, uint32_t stride) {
+    m_table = table; m_count = count; m_stride = stride;
+  }
+  const char *Data() const { return m_table; }
+  uint32_t Count() const { return m_count; }
+  uint32_t Stride() const { return m_stride; }
+
+  // Returns record `index` viewed as a T, or nullptr when there is no table,
+  // the index is out of range, or the stride is too small to hold a T.
+  template<typename T>
+  const T *Row(uint32_t index) const {
+    if (!m_table || index >= m_count || sizeof(T) > m_stride)
+      return nullptr;
+    // Form the byte offset as size_t rather than uint32_t so the product is
+    // not truncated to 32 bits where size_t is wider.
+    const size_t offset = static_cast<size_t>(m_stride) * index;
+    return reinterpret_cast<const T*>(m_table + offset);
+  }
 };
 
+
 // Index table is a sequence of rows, where each row has a count as a first
 // element followed by the count number of elements pre computing values
 class IndexTableReader {
@@ -105,8 +170,8 @@ struct RuntimeDataFunctionInfo {
   uint32_t AttributeSizeInBytes; // attribute size for closest hit and any hit
   uint32_t FeatureInfo1;         // first 32 bits of feature flag
   uint32_t FeatureInfo2;         // second 32 bits of feature flag
-  uint32_t ShaderStageFlag;      // valid shader stage flag. Not implemented yet.
-  uint32_t MinShaderTarget;      // minimum shader target. Not implemented yet.
+  uint32_t ShaderStageFlag;      // valid shader stage flag.
+  uint32_t MinShaderTarget;      // minimum shader target.
 };
 
 class ResourceTableReader;
@@ -129,25 +194,27 @@ public:
                  RuntimeDataContext *context)
       : m_ResourceInfo(resInfo), m_Context(context) {}
   hlsl::DXIL::ResourceClass GetResourceClass() const {
-    return (hlsl::DXIL::ResourceClass)m_ResourceInfo->Class;
+    return !m_ResourceInfo ? hlsl::DXIL::ResourceClass::Invalid
+                           : (hlsl::DXIL::ResourceClass)m_ResourceInfo->Class;
   }
-  uint32_t GetSpace() const { return m_ResourceInfo->Space; }
-  uint32_t GetLowerBound() const { return m_ResourceInfo->LowerBound; }
-  uint32_t GetUpperBound() const { return m_ResourceInfo->UpperBound; }
+  uint32_t GetSpace() const { return !m_ResourceInfo ? 0 : m_ResourceInfo->Space; }
+  uint32_t GetLowerBound() const { return !m_ResourceInfo ? 0 : m_ResourceInfo->LowerBound; }
+  uint32_t GetUpperBound() const { return !m_ResourceInfo ? 0 : m_ResourceInfo->UpperBound; }
   hlsl::DXIL::ResourceKind GetResourceKind() const {
-    return (hlsl::DXIL::ResourceKind)m_ResourceInfo->Kind;
+    return !m_ResourceInfo ? hlsl::DXIL::ResourceKind::Invalid
+                           : (hlsl::DXIL::ResourceKind)m_ResourceInfo->Kind;
   }
-  uint32_t GetID() const { return m_ResourceInfo->ID; }
+  uint32_t GetID() const { return !m_ResourceInfo ? 0 : m_ResourceInfo->ID; }
   const char *GetName() const {
-    return m_Context->pStringTableReader->Get(m_ResourceInfo->Name);
+    return !m_ResourceInfo ? ""
+           : m_Context->pStringTableReader->Get(m_ResourceInfo->Name);
   }
-  uint32_t GetFlags() const { return m_ResourceInfo->Flags; }
+  uint32_t GetFlags() const { return !m_ResourceInfo ? 0 : m_ResourceInfo->Flags; }
 };
 
 class ResourceTableReader {
 private:
-  const RuntimeDataResourceInfo
-      *m_ResourceInfo; // pointer to an array of resource bind infos
+  TableReader m_Table;
   RuntimeDataContext *m_Context;
   uint32_t m_CBufferCount;
   uint32_t m_SamplerCount;
@@ -156,18 +223,11 @@ private:
 
 public:
   ResourceTableReader()
-      : m_ResourceInfo(nullptr), m_Context(nullptr), m_CBufferCount(0),
+      : m_Context(nullptr), m_CBufferCount(0),
         m_SamplerCount(0), m_SRVCount(0), m_UAVCount(0){};
-  ResourceTableReader(const RuntimeDataResourceInfo *info1,
-                      RuntimeDataContext *context, uint32_t CBufferCount,
-                      uint32_t SamplerCount, uint32_t SRVCount,
-                      uint32_t UAVCount)
-      : m_ResourceInfo(info1), m_Context(context), m_CBufferCount(CBufferCount),
-        m_SamplerCount(SamplerCount), m_SRVCount(SRVCount),
-        m_UAVCount(UAVCount){};
-
-  void SetResourceInfo(const RuntimeDataResourceInfo *ptr, uint32_t count) {
-    m_ResourceInfo = ptr;
+
+  void SetResourceInfo(const char *ptr, uint32_t count, uint32_t recordStride) {
+    m_Table.Init(ptr, count, recordStride);
     // Assuming that resources are in order of CBuffer, Sampler, SRV, and UAV,
     // count the number for each resource class
     m_CBufferCount = 0;
@@ -176,7 +236,8 @@ public:
     m_UAVCount = 0;
 
     for (uint32_t i = 0; i < count; ++i) {
-      const RuntimeDataResourceInfo *curPtr = &ptr[i];
+      const RuntimeDataResourceInfo *curPtr =
+        m_Table.Row<RuntimeDataResourceInfo>(i);
       if (curPtr->Class == (uint32_t)hlsl::DXIL::ResourceClass::CBuffer)
         m_CBufferCount++;
       else if (curPtr->Class == (uint32_t)hlsl::DXIL::ResourceClass::Sampler)
@@ -195,34 +256,34 @@ public:
   }
   ResourceReader GetItem(uint32_t i) const {
     _Analysis_assume_(i < GetNumResources());
-    return ResourceReader(&m_ResourceInfo[i], m_Context);
+    return ResourceReader(m_Table.Row<RuntimeDataResourceInfo>(i), m_Context);
   }
 
   uint32_t GetNumCBuffers() const { return m_CBufferCount; }
   ResourceReader GetCBuffer(uint32_t i) {
     _Analysis_assume_(i < m_CBufferCount);
-    return ResourceReader(&m_ResourceInfo[i], m_Context);
+    return ResourceReader(m_Table.Row<RuntimeDataResourceInfo>(i), m_Context);
   }
 
   uint32_t GetNumSamplers() const { return m_SamplerCount; }
   ResourceReader GetSampler(uint32_t i) {
     _Analysis_assume_(i < m_SamplerCount);
     uint32_t offset = (m_CBufferCount + i);
-    return ResourceReader(&m_ResourceInfo[offset], m_Context);
+    return ResourceReader(m_Table.Row<RuntimeDataResourceInfo>(offset), m_Context);
   }
 
   uint32_t GetNumSRVs() const { return m_SRVCount; }
   ResourceReader GetSRV(uint32_t i) {
     _Analysis_assume_(i < m_SRVCount);
     uint32_t offset = (m_CBufferCount + m_SamplerCount + i);
-    return ResourceReader(&m_ResourceInfo[offset], m_Context);
+    return ResourceReader(m_Table.Row<RuntimeDataResourceInfo>(offset), m_Context);
   }
 
   uint32_t GetNumUAVs() const { return m_UAVCount; }
   ResourceReader GetUAV(uint32_t i) {
     _Analysis_assume_(i < m_UAVCount);
     uint32_t offset = (m_CBufferCount + m_SamplerCount + m_SRVCount + i);
-    return ResourceReader(&m_ResourceInfo[offset], m_Context);
+    return ResourceReader(m_Table.Row<RuntimeDataResourceInfo>(offset), m_Context);
   }
 };
 
@@ -238,95 +299,102 @@ public:
       : m_RuntimeDataFunctionInfo(functionInfo), m_Context(context) {}
 
   const char *GetName() const {
-    return m_Context->pStringTableReader->Get(m_RuntimeDataFunctionInfo->Name);
+    return !m_RuntimeDataFunctionInfo ? ""
+      : m_Context->pStringTableReader->Get(m_RuntimeDataFunctionInfo->Name);
   }
   const char *GetUnmangledName() const {
-    return m_Context->pStringTableReader->Get(
-        m_RuntimeDataFunctionInfo->UnmangledName);
+    return !m_RuntimeDataFunctionInfo ? ""
+      : m_Context->pStringTableReader->Get(
+          m_RuntimeDataFunctionInfo->UnmangledName);
   }
   uint64_t GetFeatureFlag() const {
-    uint64_t flag =
-        static_cast<uint64_t>(m_RuntimeDataFunctionInfo->FeatureInfo2) << 32;
-    flag |= static_cast<uint64_t>(m_RuntimeDataFunctionInfo->FeatureInfo1);
-    return flag;
+    return (static_cast<uint64_t>(GetFeatureInfo2()) << 32)
+           | static_cast<uint64_t>(GetFeatureInfo1());
   }
   uint32_t GetFeatureInfo1() const {
-    return m_RuntimeDataFunctionInfo->FeatureInfo1;
+    return !m_RuntimeDataFunctionInfo ? 0
+      : m_RuntimeDataFunctionInfo->FeatureInfo1;
   }
   uint32_t GetFeatureInfo2() const {
-    return m_RuntimeDataFunctionInfo->FeatureInfo2;
+    return !m_RuntimeDataFunctionInfo ? 0
+      : m_RuntimeDataFunctionInfo->FeatureInfo2;
   }
 
   uint32_t GetShaderStageFlag() const {
-    return m_RuntimeDataFunctionInfo->ShaderStageFlag;
+    return !m_RuntimeDataFunctionInfo ? 0
+      : m_RuntimeDataFunctionInfo->ShaderStageFlag;
   }
   uint32_t GetMinShaderTarget() const {
-    return m_RuntimeDataFunctionInfo->MinShaderTarget;
+    return !m_RuntimeDataFunctionInfo ? 0
+      : m_RuntimeDataFunctionInfo->MinShaderTarget;
   }
   uint32_t GetNumResources() const {
-    if (m_RuntimeDataFunctionInfo->Resources == UINT_MAX)
+    if (!m_RuntimeDataFunctionInfo ||
+        m_RuntimeDataFunctionInfo->Resources == UINT_MAX)
       return 0;
-    return m_Context->pIndexTableReader
-      ->getRow(m_RuntimeDataFunctionInfo->Resources)
-      .Count();
+    return m_Context->pIndexTableReader->getRow(
+      m_RuntimeDataFunctionInfo->Resources).Count();
   }
   ResourceReader GetResource(uint32_t i) const {
-    uint32_t resIndex = m_Context->pIndexTableReader
-      ->getRow(m_RuntimeDataFunctionInfo->Resources)
-      .At(i);
+    // Apply the same guard as GetNumResources(): with no record, or with the
+    // UINT_MAX value for which GetNumResources() reports 0, there is no index
+    // row to look up, so hand back a reader over a null record.
+    if (!m_RuntimeDataFunctionInfo ||
+        m_RuntimeDataFunctionInfo->Resources == UINT_MAX)
+      return ResourceReader(nullptr, m_Context);
+    uint32_t resIndex = m_Context->pIndexTableReader->getRow(
+      m_RuntimeDataFunctionInfo->Resources).At(i);
     return m_Context->pResourceTableReader->GetItem(resIndex);
   }
   uint32_t GetNumDependencies() const {
-    if (m_RuntimeDataFunctionInfo->FunctionDependencies == UINT_MAX)
+    if (!m_RuntimeDataFunctionInfo ||
+        m_RuntimeDataFunctionInfo->FunctionDependencies == UINT_MAX)
       return 0;
-    return m_Context->pIndexTableReader
-      ->getRow(m_RuntimeDataFunctionInfo->FunctionDependencies)
-      .Count();
+    return m_Context->pIndexTableReader->getRow(
+      m_RuntimeDataFunctionInfo->FunctionDependencies).Count();
   }
   const char *GetDependency(uint32_t i) const {
-    uint32_t resIndex =
-      m_Context->pIndexTableReader
-      ->getRow(m_RuntimeDataFunctionInfo->FunctionDependencies)
-      .At(i);
+    // Apply the same guard as GetNumDependencies(): with no record, or with
+    // the UINT_MAX value for which GetNumDependencies() reports 0, there is
+    // no index row to look up, so return the empty string.
+    if (!m_RuntimeDataFunctionInfo ||
+        m_RuntimeDataFunctionInfo->FunctionDependencies == UINT_MAX)
+      return "";
+    uint32_t resIndex = m_Context->pIndexTableReader->getRow(
+      m_RuntimeDataFunctionInfo->FunctionDependencies).At(i);
     return m_Context->pStringTableReader->Get(resIndex);
   }
 
   uint32_t GetPayloadSizeInBytes() const {
-    return m_RuntimeDataFunctionInfo->PayloadSizeInBytes;
+    return !m_RuntimeDataFunctionInfo ? 0
+      : m_RuntimeDataFunctionInfo->PayloadSizeInBytes;
   }
   uint32_t GetAttributeSizeInBytes() const {
-    return m_RuntimeDataFunctionInfo->AttributeSizeInBytes;
+    return !m_RuntimeDataFunctionInfo ? 0
+      : m_RuntimeDataFunctionInfo->AttributeSizeInBytes;
   }
   // payload (hit shaders) and parameters (call shaders) are mutually exclusive
   uint32_t GetParameterSizeInBytes() const {
-    return m_RuntimeDataFunctionInfo->PayloadSizeInBytes;
+    return !m_RuntimeDataFunctionInfo ? 0
+      : m_RuntimeDataFunctionInfo->PayloadSizeInBytes;
   }
   hlsl::DXIL::ShaderKind GetShaderKind() const {
-    return (hlsl::DXIL::ShaderKind)m_RuntimeDataFunctionInfo->ShaderKind;
+    return !m_RuntimeDataFunctionInfo ? hlsl::DXIL::ShaderKind::Invalid
+      : (hlsl::DXIL::ShaderKind)m_RuntimeDataFunctionInfo->ShaderKind;
   }
 };
 
 class FunctionTableReader {
 private:
-  const RuntimeDataFunctionInfo *m_infos;
-  uint32_t m_count;
+  TableReader m_Table;
   RuntimeDataContext *m_context;
 
 public:
-  FunctionTableReader() : m_infos(nullptr), m_count(0), m_context(nullptr) {}
+  FunctionTableReader() : m_context(nullptr) {}
   FunctionTableReader(const RuntimeDataFunctionInfo *functionInfos,
                       uint32_t count, RuntimeDataContext *context)
-      : m_infos(functionInfos), m_count(count), m_context(context) {}
+      // Forward the array to the table reader; without this the arguments
+      // were dropped and the reader always reported zero functions.
+      : m_Table(reinterpret_cast<const char *>(functionInfos), count,
+                sizeof(RuntimeDataFunctionInfo)),
+        m_context(context) {}
 
   FunctionReader GetItem(uint32_t i) const {
-    return FunctionReader(&m_infos[i], m_context);
+    return FunctionReader(m_Table.Row<RuntimeDataFunctionInfo>(i), m_context);
   }
-  uint32_t GetNumFunctions() const { return m_count; }
+  uint32_t GetNumFunctions() const { return m_Table.Count(); }
 
-  void SetFunctionInfo(const RuntimeDataFunctionInfo *ptr) {
-    m_infos = ptr;
+  void SetFunctionInfo(const char *ptr, uint32_t count, uint32_t recordStride) {
+    m_Table.Init(ptr, count, recordStride);
   }
-  void SetCount(uint32_t count) { m_count = count; }
   void SetContext(RuntimeDataContext *context) { m_context = context; }
 };
 
@@ -341,9 +409,11 @@ private:
 
 public:
   DxilRuntimeData();
-  DxilRuntimeData(const char *ptr);
+  DxilRuntimeData(const char *ptr, size_t size);
   // initializing reader from RDAT. return true if no error has occurred.
-  bool InitFromRDAT(const void *pRDAT);
+  bool InitFromRDAT(const void *pRDAT, size_t size);
+  // read prerelease data:
+  bool InitFromRDAT_Prerelease(const void *pRDAT, size_t size);
   FunctionTableReader *GetFunctionTableReader();
   ResourceTableReader *GetResourceTableReader();
 };
@@ -351,7 +421,7 @@ public:
 //////////////////////////////////
 /// structures for library runtime
 
-typedef struct DXIL_RESOURCE {
+typedef struct DxilResourceDesc {
   uint32_t Class; // hlsl::DXIL::ResourceClass
   uint32_t Kind;  // hlsl::DXIL::ResourceKind
   uint32_t ID;    // id per class
@@ -360,13 +430,13 @@ typedef struct DXIL_RESOURCE {
   uint32_t LowerBound;
   LPCWSTR Name;
   uint32_t Flags; // hlsl::RDAT::DxilResourceFlag
-} DXIL_RESOURCE;
+} DxilResourceDesc;
 
-typedef struct DXIL_FUNCTION {
+typedef struct DxilFunctionDesc {
   LPCWSTR Name;
   LPCWSTR UnmangledName;
   uint32_t NumResources;
-  const DXIL_RESOURCE * const*Resources;
+  const DxilResourceDesc * const*Resources;
   uint32_t NumFunctionDependencies;
   const LPCWSTR *FunctionDependencies;
   uint32_t ShaderKind;
@@ -375,29 +445,29 @@ typedef struct DXIL_FUNCTION {
   uint32_t AttributeSizeInBytes; // attribute size for closest hit and any hit
   uint32_t FeatureInfo1;         // first 32 bits of feature flag
   uint32_t FeatureInfo2;         // second 32 bits of feature flag
-  uint32_t ShaderStageFlag;      // valid shader stage flag. Not implemented yet.
-  uint32_t MinShaderTarget;      // minimum shader target. Not implemented yet.
-} DXIL_FUNCITON;
+  uint32_t ShaderStageFlag;      // valid shader stage flag.
+  uint32_t MinShaderTarget;      // minimum shader target.
+} DxilFunctionDesc;
 
-typedef struct DXIL_SUBOBJECT {
-} DXIL_SUBOBJECT;
+typedef struct DxilSubobjectDesc {
+} DxilSubobjectDesc;
 
-typedef struct DXIL_LIBRARY_DESC {
+typedef struct DxilLibraryDesc {
   uint32_t NumFunctions;
-  DXIL_FUNCITON *pFunction;
+  DxilFunctionDesc *pFunction;
   uint32_t NumResources;
-  DXIL_RESOURCE *pResource;
+  DxilResourceDesc *pResource;
   uint32_t NumSubobjects;
-  DXIL_SUBOBJECT *pSubobjects;
-} DXIL_LIBRARY_DESC;
+  DxilSubobjectDesc *pSubobjects;
+} DxilLibraryDesc;
 
 class DxilRuntimeReflection {
 public:
   virtual ~DxilRuntimeReflection() {}
   // This call will allocate memory for GetLibraryReflection call
-  virtual bool InitFromRDAT(const void *pRDAT) = 0;
-  // DxilRuntimeReflection owns the memory pointed to by DXIL_LIBRARY_DESC
-  virtual const DXIL_LIBRARY_DESC GetLibraryReflection() = 0;
+  virtual bool InitFromRDAT(const void *pRDAT, size_t size) = 0;
+  // DxilRuntimeReflection owns the memory pointed to by DxilLibraryDesc
+  virtual const DxilLibraryDesc GetLibraryReflection() = 0;
 };
 
 DxilRuntimeReflection *CreateDxilRuntimeReflection();

+ 197 - 63
include/dxc/HLSL/DxilRuntimeReflection.inl

@@ -26,55 +26,190 @@ struct ResourceKey {
   }
 };
 
-DxilRuntimeData::DxilRuntimeData() : DxilRuntimeData(nullptr) {}
+// Size-checked reader over the byte range [ptr, ptr + size).
+//  on overrun: throw buffer_overrun{};
+//  on overlap: throw buffer_overlap{};
+class CheckedReader {
+  const char *Ptr;   // start of the buffer being read
+  size_t Size;       // number of readable bytes at Ptr
+  size_t Offset;     // current read position, relative to Ptr
 
-DxilRuntimeData::DxilRuntimeData(const char *ptr)
+public:
+  class exception : public std::exception {};
+  class buffer_overrun : public exception {
+  public:
+    buffer_overrun() noexcept {}
+    virtual const char * what() const noexcept override {
+      return ("buffer_overrun");
+    }
+  };
+  class buffer_overlap : public exception {
+  public:
+    buffer_overlap() noexcept {}
+    virtual const char * what() const noexcept override {
+      return ("buffer_overlap");
+    }
+  };
+
+  CheckedReader(const void *ptr, size_t size) :
+    Ptr(reinterpret_cast<const char*>(ptr)), Size(size), Offset(0) {}
+  // Moves the read position to an absolute offset, which may be before the
+  // current one; throws buffer_overrun unless offset < Size.
+  void Reset(size_t offset = 0) {
+    if (offset >= Size) throw buffer_overrun{};
+    Offset = offset;
+  }
+  // offset is absolute, ensure offset is >= current offset
+  void Advance(size_t offset = 0) {
+    if (offset < Offset) throw buffer_overlap{};
+    if (offset >= Size) throw buffer_overrun{};
+    Offset = offset;
+  }
+  // Throws buffer_overrun unless `size` bytes remain at the current position.
+  void CheckBounds(size_t size) const {
+    assert(Offset <= Size && "otherwise, offset larger than size");
+    if (size > Size - Offset)
+      throw buffer_overrun{};
+  }
+  // Returns the current position viewed as a T after checking that `size`
+  // bytes (sizeof(T) when size is 0) are available; does not advance.
+  template <typename T>
+  const T *Cast(size_t size = 0) {
+    if (0 == size) size = sizeof(T);
+    CheckBounds(size);
+    return reinterpret_cast<const T*>(Ptr + Offset);
+  }
+  // Returns the T at the current position and advances past it.
+  template <typename T>
+  const T &Read() {
+    const size_t size = sizeof(T);
+    const T* p = Cast<T>(size);
+    Offset += size;
+    return *p;
+  }
+  // Returns a pointer to `count` consecutive T values at the current position
+  // and advances past them. A count of 0 is valid and consumes nothing.
+  template <typename T>
+  const T *ReadArray(size_t count = 1) {
+    // Reject counts whose byte size would wrap size_t; a wrapped size could
+    // pass CheckBounds() while the caller indexes far past the buffer.
+    if (count > static_cast<size_t>(-1) / sizeof(T))
+      throw buffer_overrun{};
+    const size_t size = sizeof(T) * count;
+    // Check the real byte size here instead of going through Cast(): Cast()
+    // treats 0 as "use sizeof(T)", which would reject a valid empty array
+    // located at the very end of the buffer.
+    CheckBounds(size);
+    const T* p = reinterpret_cast<const T*>(Ptr + Offset);
+    Offset += size;
+    return p;
+  }
+};
+
+DxilRuntimeData::DxilRuntimeData() : DxilRuntimeData(nullptr, 0) {}
+
+DxilRuntimeData::DxilRuntimeData(const char *ptr, size_t size)
     : m_TableCount(0), m_StringReader(), m_ResourceTableReader(),
       m_FunctionTableReader(), m_IndexTableReader(), m_Context() {
   m_Context = {&m_StringReader, &m_IndexTableReader, &m_ResourceTableReader,
                &m_FunctionTableReader};
   m_ResourceTableReader.SetContext(&m_Context);
   m_FunctionTableReader.SetContext(&m_Context);
-  InitFromRDAT(ptr);
+  InitFromRDAT(ptr, size);
 }
 
 // initializing reader from RDAT. return true if no error has occurred.
-bool DxilRuntimeData::InitFromRDAT(const void *pRDAT) {
+bool DxilRuntimeData::InitFromRDAT(const void *pRDAT, size_t size) {
+  // Reads the header and each part of the `size`-byte blob at pRDAT through
+  // CheckedReader and points the string/index/resource/function readers at
+  // their parts. Returns false when pRDAT is null, defers to
+  // InitFromRDAT_Prerelease() when the version is below RDAT_Version_0, and
+  // rethrows any CheckedReader failure as
+  // hlsl::Exception(DXC_E_MALFORMED_CONTAINER).
   if (pRDAT) {
-    const char *ptr = static_cast<const char *>(pRDAT);
-    uint32_t TableCount = (uint32_t)*ptr;
-    RuntimeDataTableHeader *records = (RuntimeDataTableHeader *)(ptr + 4);
-    for (uint32_t i = 0; i < TableCount; ++i) {
-      RuntimeDataTableHeader *curRecord = &records[i];
-      switch (static_cast<RuntimeDataPartType>(curRecord->tableType)) {
-      case RuntimeDataPartType::Resource: {
-        m_ResourceTableReader.SetResourceInfo(
-            (RuntimeDataResourceInfo *)(ptr + curRecord->offset),
-            curRecord->size / sizeof(RuntimeDataResourceInfo));
-        break;
-      }
-      case RuntimeDataPartType::String: {
-        m_StringReader =
-            StringTableReader(ptr + curRecord->offset, curRecord->size);
-        break;
+    try {
+      CheckedReader Reader(pRDAT, size);
+      RuntimeDataHeader RDATHeader = Reader.Read<RuntimeDataHeader>();
+      if (RDATHeader.Version < RDAT_Version_0) {
+        // Prerelease version, fallback to that Init
+        return InitFromRDAT_Prerelease(pRDAT, size);
       }
-      case RuntimeDataPartType::Function: {
-        m_FunctionTableReader.SetFunctionInfo(
-            (RuntimeDataFunctionInfo *)(ptr + curRecord->offset));
-        m_FunctionTableReader.SetCount(curRecord->size /
-                                       sizeof(RuntimeDataFunctionInfo));
-        break;
+      const uint32_t *offsets = Reader.ReadArray<uint32_t>(RDATHeader.PartCount);
+      for (uint32_t i = 0; i < RDATHeader.PartCount; ++i) {
+        Reader.Advance(offsets[i]);
+        RuntimeDataPartHeader part = Reader.Read<RuntimeDataPartHeader>();
+        CheckedReader PR(Reader.ReadArray<char>(part.Size), part.Size);
+        switch (part.Type) {
+        case RuntimeDataPartType::StringBuffer: {
+          m_StringReader = StringTableReader(
+            PR.ReadArray<char>(part.Size), part.Size);
+          break;
+        }
+        case RuntimeDataPartType::IndexArrays: {
+          uint32_t count = part.Size / sizeof(uint32_t);
+          m_IndexTableReader = IndexTableReader(
+            PR.ReadArray<uint32_t>(count), count);
+          break;
+        }
+        case RuntimeDataPartType::ResourceTable: {
+          RuntimeDataTableHeader table = PR.Read<RuntimeDataTableHeader>();
+          // Multiply in 64 bits: a 32-bit product could wrap to a small value
+          // that passes the bounds check while the rows lie outside the part.
+          uint64_t tableSize =
+            static_cast<uint64_t>(table.RecordCount) * table.RecordStride;
+          if (tableSize > part.Size)
+            throw CheckedReader::buffer_overrun{};
+          m_ResourceTableReader.SetResourceInfo(
+            PR.ReadArray<char>(static_cast<size_t>(tableSize)),
+            table.RecordCount, table.RecordStride);
+          break;
+        }
+        case RuntimeDataPartType::FunctionTable: {
+          RuntimeDataTableHeader table = PR.Read<RuntimeDataTableHeader>();
+          // Same 64-bit size computation as for the resource table above.
+          uint64_t tableSize =
+            static_cast<uint64_t>(table.RecordCount) * table.RecordStride;
+          if (tableSize > part.Size)
+            throw CheckedReader::buffer_overrun{};
+          m_FunctionTableReader.SetFunctionInfo(
+            PR.ReadArray<char>(static_cast<size_t>(tableSize)),
+            table.RecordCount, table.RecordStride);
+          break;
+        }
+        default:
+          continue; // Skip unrecognized parts
+        }
       }
-      case RuntimeDataPartType::Index: {
-        m_IndexTableReader = IndexTableReader(
-            (uint32_t *)(ptr + curRecord->offset), curRecord->size / 4);
-        break;
-      }
-      default:
-        return false;
+      return true;
+    } catch(const CheckedReader::exception &e) {
+      // Catch by reference: catching the base class by value slices the
+      // object, so what() would no longer name the overrun/overlap failure.
+      throw hlsl::Exception(DXC_E_MALFORMED_CONTAINER, e.what());
+    }
+  }
+  return false;
+}
+
+bool DxilRuntimeData::InitFromRDAT_Prerelease(const void *pRDAT, size_t size) {
+  // Reader for the prerelease layout: a uint32_t part count followed by that
+  // many {tableType, size, offset} headers. Returns false when pRDAT is null
+  // or a part type is not recognized, and rethrows any CheckedReader failure
+  // as hlsl::Exception(DXC_E_MALFORMED_CONTAINER).
+  enum class RuntimeDataPartType_Prerelease : uint32_t {
+    Invalid = 0,
+    String,
+    Function,
+    Resource,
+    Index
+  };
+  struct RuntimeDataTableHeader_Prerelease {
+    uint32_t tableType; // RuntimeDataPartType
+    uint32_t size;
+    uint32_t offset;
+  };
+  if (pRDAT) {
+    try {
+      CheckedReader Reader(pRDAT, size);
+      uint32_t partCount = Reader.Read<uint32_t>();
+      const RuntimeDataTableHeader_Prerelease *tableHeaders =
+        Reader.ReadArray<RuntimeDataTableHeader_Prerelease>(partCount);
+      for (uint32_t i = 0; i < partCount; ++i) {
+        uint32_t partSize = tableHeaders[i].size;
+        Reader.Advance(tableHeaders[i].offset);
+        CheckedReader PR(Reader.ReadArray<char>(partSize), partSize);
+        switch ((RuntimeDataPartType_Prerelease)(tableHeaders[i].tableType)) {
+        case RuntimeDataPartType_Prerelease::String: {
+          m_StringReader = StringTableReader(
+            PR.ReadArray<char>(partSize), partSize);
+          break;
+        }
+        case RuntimeDataPartType_Prerelease::Index: {
+          uint32_t count = partSize / sizeof(uint32_t);
+          m_IndexTableReader = IndexTableReader(
+            PR.ReadArray<uint32_t>(count), count);
+          break;
+        }
+        case RuntimeDataPartType_Prerelease::Resource: {
+          uint32_t count = partSize / sizeof(RuntimeDataResourceInfo);
+          m_ResourceTableReader.SetResourceInfo(PR.ReadArray<char>(partSize),
+            count, sizeof(RuntimeDataResourceInfo));
+          break;
+        }
+        case RuntimeDataPartType_Prerelease::Function: {
+          uint32_t count = partSize / sizeof(RuntimeDataFunctionInfo);
+          // The stride must be the function record size, matching the count
+          // computed above; the resource record size was used here by mistake.
+          m_FunctionTableReader.SetFunctionInfo(PR.ReadArray<char>(partSize),
+            count, sizeof(RuntimeDataFunctionInfo));
+          break;
+        }
+        default:
+          return false; // There should be no unrecognized parts
+        }
       }
+      return true;
+    } catch(const CheckedReader::exception &e) {
+      // Catch by reference: catching the base class by value slices the
+      // object, so what() would no longer name the overrun/overlap failure.
+      throw hlsl::Exception(DXC_E_MALFORMED_CONTAINER, e.what());
     }
-    return true;
   }
   return false;
 }
@@ -106,41 +241,40 @@ namespace {
 class DxilRuntimeReflection_impl : public DxilRuntimeReflection {
 private:
   typedef std::unordered_map<const char *, std::unique_ptr<wchar_t[]>> StringMap;
-  typedef std::vector<DXIL_RESOURCE> ResourceList;
-  typedef std::vector<DXIL_RESOURCE *> ResourceRefList;
-  typedef std::vector<DXIL_FUNCTION> FunctionList;
+  typedef std::vector<DxilResourceDesc> ResourceList;
+  typedef std::vector<DxilResourceDesc *> ResourceRefList;
+  typedef std::vector<DxilFunctionDesc> FunctionList;
   typedef std::vector<const wchar_t *> WStringList;
 
   DxilRuntimeData m_RuntimeData;
   StringMap m_StringMap;
   ResourceList m_Resources;
   FunctionList m_Functions;
-  std::unordered_map<ResourceKey, DXIL_RESOURCE *> m_ResourceMap;
-  std::unordered_map<DXIL_FUNCTION *, ResourceRefList> m_FuncToResMap;
-  std::unordered_map<DXIL_FUNCTION *, WStringList> m_FuncToStringMap;
+  std::unordered_map<ResourceKey, DxilResourceDesc *> m_ResourceMap;
+  std::unordered_map<DxilFunctionDesc *, ResourceRefList> m_FuncToResMap;
+  std::unordered_map<DxilFunctionDesc *, WStringList> m_FuncToStringMap;
   bool m_initialized;
 
   const wchar_t *GetWideString(const char *ptr);
   void AddString(const char *ptr);
   void InitializeReflection();
-  const DXIL_RESOURCE * const*GetResourcesForFunction(DXIL_FUNCTION &function,
+  const DxilResourceDesc * const*GetResourcesForFunction(DxilFunctionDesc &function,
                              const FunctionReader &functionReader);
-  const wchar_t **GetDependenciesForFunction(DXIL_FUNCTION &function,
+  const wchar_t **GetDependenciesForFunction(DxilFunctionDesc &function,
                              const FunctionReader &functionReader);
-  DXIL_RESOURCE *AddResource(const ResourceReader &resourceReader);
-  DXIL_FUNCTION *AddFunction(const FunctionReader &functionReader);
+  DxilResourceDesc *AddResource(const ResourceReader &resourceReader);
+  DxilFunctionDesc *AddFunction(const FunctionReader &functionReader);
 
 public:
   // TODO: Implement pipeline state validation with runtime data
   // TODO: Update BlobContainer.h to recognize 'RDAT' blob
-  // TODO: Add size and verification to InitFromRDAT and DxilRuntimeData
   DxilRuntimeReflection_impl()
       : m_RuntimeData(), m_StringMap(), m_Resources(), m_Functions(),
         m_FuncToResMap(), m_FuncToStringMap(), m_initialized(false) {}
   virtual ~DxilRuntimeReflection_impl() {}
   // This call will allocate memory for GetLibraryReflection call
-  bool InitFromRDAT(const void *pRDAT) override;
-  const DXIL_LIBRARY_DESC GetLibraryReflection() override;
+  bool InitFromRDAT(const void *pRDAT, size_t size) override;
+  const DxilLibraryDesc GetLibraryReflection() override;
 };
 
 void DxilRuntimeReflection_impl::AddString(const char *ptr) {
@@ -163,15 +297,15 @@ const wchar_t *DxilRuntimeReflection_impl::GetWideString(const char *ptr) {
   return m_StringMap.at(ptr).get();
 }
 
-bool DxilRuntimeReflection_impl::InitFromRDAT(const void *pRDAT) {
-  m_initialized = m_RuntimeData.InitFromRDAT(pRDAT);
+bool DxilRuntimeReflection_impl::InitFromRDAT(const void *pRDAT, size_t size) {
+  m_initialized = m_RuntimeData.InitFromRDAT(pRDAT, size);
   if (m_initialized)
     InitializeReflection();
   return m_initialized;
 }
 
-const DXIL_LIBRARY_DESC DxilRuntimeReflection_impl::GetLibraryReflection() {
-  DXIL_LIBRARY_DESC reflection = {};
+const DxilLibraryDesc DxilRuntimeReflection_impl::GetLibraryReflection() {
+  DxilLibraryDesc reflection = {};
   if (m_initialized) {
     reflection.NumResources =
         m_RuntimeData.GetResourceTableReader()->GetNumResources();
@@ -191,7 +325,7 @@ void DxilRuntimeReflection_impl::InitializeReflection() {
   for (uint32_t i = 0; i < resourceTableReader->GetNumResources(); ++i) {
     ResourceReader resourceReader = resourceTableReader->GetItem(i);
     AddString(resourceReader.GetName());
-    DXIL_RESOURCE *pResource = AddResource(resourceReader);
+    DxilResourceDesc *pResource = AddResource(resourceReader);
     if (pResource) {
       ResourceKey key(pResource->Class, pResource->ID);
       m_ResourceMap[key] = pResource;
@@ -206,13 +340,13 @@ void DxilRuntimeReflection_impl::InitializeReflection() {
   }
 }
 
-DXIL_RESOURCE *
+DxilResourceDesc *
 DxilRuntimeReflection_impl::AddResource(const ResourceReader &resourceReader) {
   assert(m_Resources.size() < m_Resources.capacity() && "Otherwise, number of resources was incorrect");
   if (!(m_Resources.size() < m_Resources.capacity()))
     return nullptr;
-  m_Resources.emplace_back(DXIL_RESOURCE({0}));
-  DXIL_RESOURCE &resource = m_Resources.back();
+  m_Resources.emplace_back(DxilResourceDesc({0}));
+  DxilResourceDesc &resource = m_Resources.back();
   resource.Class = (uint32_t)resourceReader.GetResourceClass();
   resource.Kind = (uint32_t)resourceReader.GetResourceKind();
   resource.Space = resourceReader.GetSpace();
@@ -224,10 +358,10 @@ DxilRuntimeReflection_impl::AddResource(const ResourceReader &resourceReader) {
   return &resource;
 }
 
-const DXIL_RESOURCE * const*DxilRuntimeReflection_impl::GetResourcesForFunction(
-    DXIL_FUNCTION &function, const FunctionReader &functionReader) {
+const DxilResourceDesc * const*DxilRuntimeReflection_impl::GetResourcesForFunction(
+    DxilFunctionDesc &function, const FunctionReader &functionReader) {
   if (m_FuncToResMap.find(&function) == m_FuncToResMap.end())
-    m_FuncToResMap.insert(std::pair<DXIL_FUNCTION *, ResourceRefList>(
+    m_FuncToResMap.insert(std::pair<DxilFunctionDesc *, ResourceRefList>(
         &function, ResourceRefList()));
   ResourceRefList &resourceList = m_FuncToResMap.at(&function);
   if (resourceList.empty()) {
@@ -245,10 +379,10 @@ const DXIL_RESOURCE * const*DxilRuntimeReflection_impl::GetResourcesForFunction(
 }
 
 const wchar_t **DxilRuntimeReflection_impl::GetDependenciesForFunction(
-    DXIL_FUNCTION &function, const FunctionReader &functionReader) {
+    DxilFunctionDesc &function, const FunctionReader &functionReader) {
   if (m_FuncToStringMap.find(&function) == m_FuncToStringMap.end())
     m_FuncToStringMap.insert(
-        std::pair<DXIL_FUNCTION *, WStringList>(&function, WStringList()));
+        std::pair<DxilFunctionDesc *, WStringList>(&function, WStringList()));
   WStringList &wStringList = m_FuncToStringMap.at(&function);
   for (uint32_t i = 0; i < functionReader.GetNumDependencies(); ++i) {
     wStringList.emplace_back(GetWideString(functionReader.GetDependency(i)));
@@ -256,13 +390,13 @@ const wchar_t **DxilRuntimeReflection_impl::GetDependenciesForFunction(
   return wStringList.empty() ? nullptr : wStringList.data();
 }
 
-DXIL_FUNCTION *
+DxilFunctionDesc *
 DxilRuntimeReflection_impl::AddFunction(const FunctionReader &functionReader) {
   assert(m_Functions.size() < m_Functions.capacity() && "Otherwise, number of functions was incorrect");
   if (!(m_Functions.size() < m_Functions.capacity()))
     return nullptr;
-  m_Functions.emplace_back(DXIL_FUNCTION({0}));
-  DXIL_FUNCTION &function = m_Functions.back();
+  m_Functions.emplace_back(DxilFunctionDesc({0}));
+  DxilFunctionDesc &function = m_Functions.back();
   function.Name = GetWideString(functionReader.GetName());
   function.UnmangledName = GetWideString(functionReader.GetUnmangledName());
   function.NumResources = functionReader.GetNumResources();

+ 207 - 91
lib/HLSL/DxilContainerAssembler.cpp

@@ -702,6 +702,86 @@ public:
   }
 };
 
+// Size-checked writer over a caller-supplied buffer; it never allocates or
+// frees, so the buffer must stay valid for as long as the writer is used.
+//  on overrun: throw buffer_overrun{};
+//  on overlap: throw buffer_overlap{};
+class CheckedWriter {
+  char *Ptr;      // start of the destination buffer
+  size_t Size;    // total buffer size in bytes
+  size_t Offset;  // current write position, in bytes from Ptr
+
+  // Returns the byte size of `count` elements of T after verifying that they
+  // fit in the space remaining past Offset.  The check divides rather than
+  // multiplies so that a huge count cannot wrap size_t and pass as small.
+  template <typename T>
+  size_t CheckedArraySize(size_t count) const {
+    assert(Offset <= Size && "otherwise, offset larger than size");
+    if (count > (Size - Offset) / sizeof(T))
+      throw buffer_overrun{};
+    return sizeof(T) * count;
+  }
+
+public:
+  // Common base of the errors thrown below.  Catch it by reference: catching
+  // by value slices the object and loses the derived what() message.
+  class exception : public std::exception {};
+  class buffer_overrun : public exception {
+  public:
+    buffer_overrun() noexcept {}
+    virtual const char * what() const noexcept override {
+      return ("buffer_overrun");
+    }
+  };
+  class buffer_overlap : public exception {
+  public:
+    buffer_overlap() noexcept {}
+    virtual const char * what() const noexcept override {
+      return ("buffer_overlap");
+    }
+  };
+
+  CheckedWriter(void *ptr, size_t size) :
+    Ptr(reinterpret_cast<char*>(ptr)), Size(size), Offset(0) {}
+
+  size_t GetOffset() const { return Offset; }
+  // Moves to an absolute offset in either direction.  The offset must be
+  // strictly inside the buffer; offset == Size throws buffer_overrun.
+  void Reset(size_t offset = 0) {
+    if (offset >= Size) throw buffer_overrun{};
+    Offset = offset;
+  }
+  // offset is absolute, ensure offset is >= current offset
+  // (moving backwards throws buffer_overlap; offset >= Size throws
+  // buffer_overrun).
+  void Advance(size_t offset = 0) {
+    if (offset < Offset) throw buffer_overlap{};
+    if (offset >= Size) throw buffer_overrun{};
+    Offset = offset;
+  }
+  // Throws buffer_overrun unless `size` bytes fit between Offset and the end.
+  void CheckBounds(size_t size) const {
+    assert(Offset <= Size && "otherwise, offset larger than size");
+    if (size > Size - Offset)
+      throw buffer_overrun{};
+  }
+  // Returns the current position typed as T* after bounds-checking `size`
+  // bytes (a size of 0 means sizeof(T)).  Does not advance Offset.
+  template <typename T>
+  T *Cast(size_t size = 0) {
+    if (0 == size) size = sizeof(T);
+    CheckBounds(size);
+    return reinterpret_cast<T*>(Ptr + Offset);
+  }
+
+  // Map and Write advance Offset:
+  // Map: reserves sizeof(T) bytes and returns them as a T to fill in place.
+  template <typename T>
+  T &Map() {
+    const size_t size = sizeof(T);
+    T * p = Cast<T>(size);
+    Offset += size;
+    return *p;
+  }
+  // MapArray: reserves `count` elements of T and returns the first one.
+  // A count of 0 reserves nothing and is valid even at the end of the buffer;
+  // the bounds check is done here because Cast treats a size of 0 as
+  // sizeof(T).
+  template <typename T>
+  T *MapArray(size_t count = 1) {
+    const size_t size = CheckedArraySize<T>(count);
+    T *p = reinterpret_cast<T*>(Ptr + Offset);
+    Offset += size;
+    return p;
+  }
+  // Write: copies obj to the current position.
+  template <typename T>
+  void Write(const T &obj) {
+    const size_t size = sizeof(T);
+    *Cast<T>(size) = obj;
+    Offset += size;
+  }
+  // WriteArray: copies `count` elements from pArray.  A count of 0 is a
+  // no-op, so pArray is never handed to memcpy for an empty array.
+  template <typename T>
+  void WriteArray(const T *pArray, size_t count = 1) {
+    const size_t size = CheckedArraySize<T>(count);
+    if (size == 0)
+      return;
+    memcpy(Ptr + Offset, pArray, size);
+    Offset += size;
+  }
+};
+
+
 // Like DXIL container, RDAT itself is a mini container that contains multiple RDAT parts
 class RDATPart {
 public:
@@ -728,92 +808,112 @@ public:
 
+  // Serializes the table at ptr: a RuntimeDataTableHeader (record count and
+  // per-record stride) followed directly by all rows copied as one block.
+  // Nothing here checks the destination size; presumably the caller provides
+  // at least GetPartSize() bytes -- TODO confirm at call sites.
   void Write(void *ptr) {
     char *pCur = (char*)ptr;
-    for (auto row : m_rows) {
-      memcpy(pCur, &row, sizeof(T));
-      pCur += sizeof(T);
-    }
+    RuntimeDataTableHeader &header = *reinterpret_cast<RuntimeDataTableHeader*>(pCur);
+    header.RecordCount = m_rows.size();
+    header.RecordStride = sizeof(T);
+    pCur += sizeof(RuntimeDataTableHeader);
+    memcpy(pCur, m_rows.data(), header.RecordCount * header.RecordStride);
   };
 
-  uint32_t GetPartSize() const { return m_rows.size() * sizeof(T); }
+  // Size in bytes of the serialized table: the table header plus every row.
+  // A table without rows reports 0, i.e. not even the header is counted.
+  uint32_t GetPartSize() const {
+    if (m_rows.empty())
+      return 0;
+    return sizeof(RuntimeDataTableHeader) + m_rows.size() * sizeof(T);
+  }
 };
 
 // Resource table will contain a list of RuntimeDataResourceInfo in order of
 // CBuffer, Sampler, SRV, and UAV resource classes.
 class ResourceTable : public RDATTable<RuntimeDataResourceInfo> {
 public:
-  RuntimeDataPartType GetType() const { return RuntimeDataPartType::Resource; }
+  // Part type recorded in this part's RuntimeDataPartHeader.
+  RuntimeDataPartType GetType() const { return RuntimeDataPartType::ResourceTable; }
 };
 
+// Function table: a list of RuntimeDataFunctionInfo records.
 class FunctionTable : public RDATTable<RuntimeDataFunctionInfo> {
 public:
-  RuntimeDataPartType GetType() const { return RuntimeDataPartType::Function; }
+  // Part type recorded in this part's RuntimeDataPartHeader.
+  RuntimeDataPartType GetType() const { return RuntimeDataPartType::FunctionTable; }
 };
 
-class StringTable : public RDATPart {
+// Part that stores strings back to back, each null-terminated, in a single
+// byte buffer.  Insert hands out the byte offset of a string in that buffer.
+class StringBufferPart : public RDATPart {
 private:
   StringMap<uint32_t> m_StringMap;
   SmallVector<char, 256> m_StringBuffer;
+  // NOTE(review): curIndex is still initialized by the constructor but Insert
+  // no longer reads or updates it; it looks removable.
   uint32_t curIndex;
 public:
-  StringTable() : m_StringMap(), m_StringBuffer(), curIndex(0) {}
+  StringBufferPart() : m_StringMap(), m_StringBuffer(), curIndex(0) {
+    // Always start string table with null so empty/null strings have offset of zero
+    m_StringBuffer.push_back('\0');
+  }
   // returns the offset of the name inserted
+  // (an empty name maps to offset 0; a name seen before returns the offset
+  // it was first stored at, so the buffer never holds duplicates)
   uint32_t Insert(StringRef name) {
+    if (name.empty())
+      return 0;
+
     // Don't add duplicate strings
     auto found = m_StringMap.find(name);
     if (found != m_StringMap.end())
       return found->second;
-    m_StringMap[name] = curIndex;
 
+    // The new string starts at the current end of the buffer.
+    uint32_t prevIndex = (uint32_t)m_StringBuffer.size();
+    m_StringMap[name] = prevIndex;
     m_StringBuffer.reserve(m_StringBuffer.size() + name.size() + 1);
     m_StringBuffer.append(name.begin(), name.end());
     m_StringBuffer.push_back('\0');
-
-    uint32_t prevIndex = curIndex;
-    curIndex += (uint32_t)name.size() + 1;
     return prevIndex;
   }
-  RuntimeDataPartType GetType() const { return RuntimeDataPartType::String; }
+  RuntimeDataPartType GetType() const { return RuntimeDataPartType::StringBuffer; }
+  // Size of the buffer in bytes, including the leading null and every
+  // terminator.
   uint32_t GetPartSize() const { return m_StringBuffer.size(); }
+  // Copies the whole buffer to ptr without any bounds checking.
   void Write(void *ptr) { memcpy(ptr, m_StringBuffer.data(), m_StringBuffer.size()); }
 };
 
-struct IndexTable : public RDATPart {
+// Part that stores arrays of uint32_t indices in one flat buffer.  Each array
+// is laid out as [count, element 0, ..., element count-1], and is identified
+// by the position of its count within m_IndexBuffer (in uint32_t units, not
+// bytes).
+struct IndexArraysPart : public RDATPart {
 private:
-  typedef llvm::SmallVector<uint32_t, 8> Indices;
-  std::vector<Indices> m_IndicesList;
-  uint32_t m_curOffset;
+  std::vector<uint32_t> m_IndexBuffer;
+
+  // Use m_IndexSet with CmpIndices to avoid duplicate index arrays
+  // The set holds array offsets; CmpIndices orders two offsets by the arrays
+  // they point at in m_IndexBuffer: shorter array first, then element by
+  // element.  Arrays with equal contents compare equivalent.
+  // NOTE(review): CmpIndices keeps a reference to the owning part, so a
+  // copied or moved IndexArraysPart would compare against the source object's
+  // buffer -- presumably parts are only ever held through unique_ptr; confirm.
+  struct CmpIndices {
+    const IndexArraysPart &Table;
+    CmpIndices(const IndexArraysPart &table) : Table(table) {}
+    bool operator()(uint32_t left, uint32_t right) const {
+      const uint32_t *pLeft = Table.m_IndexBuffer.data() + left;
+      const uint32_t *pRight = Table.m_IndexBuffer.data() + right;
+      if (*pLeft != *pRight)
+        return (*pLeft < *pRight);
+      uint32_t count = *pLeft;
+      for (unsigned i = 0; i < count; i++) {
+        ++pLeft; ++pRight;
+        if (*pLeft != *pRight)
+          return (*pLeft < *pRight);
+      }
+      return false;
+    }
+  };
+  std::set<uint32_t, CmpIndices> m_IndexSet;
 
 public:
-  IndexTable() : m_IndicesList(), m_curOffset(0) {}
+  IndexArraysPart() : m_IndexBuffer(), m_IndexSet(*this) {}
+  // Adds the array [begin, end) and returns its offset in m_IndexBuffer.
+  // The array is appended first so that CmpIndices can read it from the
+  // buffer; if an identical array is already present, the appended copy is
+  // dropped again and the offset of the existing array is returned.
   template <class iterator>
   uint32_t AddIndex(iterator begin, iterator end) {
-    uint32_t prevOffset = m_curOffset;
-    m_IndicesList.emplace_back(Indices());
-    auto &curIndices = m_IndicesList.back();
-    for (iterator it = begin; it != end; ++it) {
-      curIndices.emplace_back(*it);
-    }
-    m_curOffset += curIndices.size() + 1;
-    return prevOffset;
+    uint32_t newOffset = m_IndexBuffer.size();
+    m_IndexBuffer.push_back(0); // Size: update after insertion
+    m_IndexBuffer.insert(m_IndexBuffer.end(), begin, end);
+    m_IndexBuffer[newOffset] = (m_IndexBuffer.size() - newOffset) - 1;
+    // Check for duplicate, return new offset if not duplicate
+    auto insertResult = m_IndexSet.insert(newOffset);
+    if (insertResult.second)
+      return newOffset;
+    // Otherwise it was a duplicate, so chop off the size and return the original
+    m_IndexBuffer.resize(newOffset);
+    return *insertResult.first;
   }
 
-  RuntimeDataPartType GetType() const { return RuntimeDataPartType::Index; }
+  RuntimeDataPartType GetType() const { return RuntimeDataPartType::IndexArrays; }
+  // Size of the flat index buffer in bytes.
   uint32_t GetPartSize() const {
-    uint32_t size = 0;
-    for (auto Indices : m_IndicesList) {
-      size += Indices.size() + 1;
-    }
-    return sizeof(uint32_t) * size;
+    return sizeof(uint32_t) * m_IndexBuffer.size();
   }
 
+  // Copies the flat index buffer to ptr without any bounds checking.
   void Write(void *ptr) {
-    uint32_t *cur = (uint32_t*)ptr;
-    for (auto Indices : m_IndicesList) {
-      uint32_t count = Indices.size();
-      memcpy(cur, &count, 4);
-      std::copy(Indices.begin(), Indices.end(), cur + 1);
-      cur += sizeof(uint32_t)/sizeof(4) + Indices.size();
-    }
+    memcpy(ptr, m_IndexBuffer.data(), m_IndexBuffer.size() * sizeof(uint32_t));
   }
 };
 
@@ -824,7 +924,7 @@ private:
   const DxilModule &m_Module;
   SmallVector<char, 1024> m_RDATBuffer;
 
-  std::vector<std::unique_ptr<RDATPart>> m_tables;
+  std::vector<std::unique_ptr<RDATPart>> m_Parts;
   typedef llvm::SmallSetVector<uint32_t, 8> Indices;
   typedef std::unordered_map<llvm::Function *, Indices> FunctionIndexMap;
   FunctionIndexMap m_FuncToResNameOffset; // list of resources used
@@ -864,9 +964,9 @@ private:
   void InsertToResourceTable(DxilResourceBase &resource,
                              ResourceClass resourceClass,
                              ResourceTable &resourceTable,
-                             StringTable &stringTable,
+                             StringBufferPart &stringBufferPart,
                              uint32_t &resourceIndex) {
-    uint32_t stringIndex = stringTable.Insert(resource.GetGlobalName());
+    uint32_t stringIndex = stringBufferPart.Insert(resource.GetGlobalName());
     UpdateFunctionToResourceInfo(&resource, resourceIndex++);
     RuntimeDataResourceInfo info = {};
     info.ID = resource.GetID();
@@ -890,35 +990,35 @@ private:
     resourceTable.Insert(info);
   }
 
-  void UpdateResourceInfo(StringTable &stringTable) {
+  // Appends a ResourceTable part to m_Parts and fills it with one record per
+  // module resource, visiting CBuffers, Samplers, SRVs and UAVs in that
+  // order.  Resource names are stored in stringBufferPart.
+  void UpdateResourceInfo(StringBufferPart &stringBufferPart) {
     // Try to allocate string table for resources. String table is a sequence
     // of strings delimited by \0
-    m_tables.emplace_back(std::make_unique<ResourceTable>());
-    ResourceTable &resourceTable = *(ResourceTable*)m_tables.back().get();
+    m_Parts.emplace_back(std::make_unique<ResourceTable>());
+    // Downcast from RDATPart* with static_cast (what the previous C-style
+    // cast did); reinterpret_cast does not perform a base-to-derived
+    // conversion.
+    ResourceTable &resourceTable = *static_cast<ResourceTable*>(m_Parts.back().get());
     uint32_t resourceIndex = 0;
     for (auto &resource : m_Module.GetCBuffers()) {
-      InsertToResourceTable(*resource.get(), ResourceClass::CBuffer, resourceTable, stringTable,
+      InsertToResourceTable(*resource.get(), ResourceClass::CBuffer, resourceTable, stringBufferPart,
                             resourceIndex);
 
     }
     for (auto &resource : m_Module.GetSamplers()) {
-      InsertToResourceTable(*resource.get(), ResourceClass::Sampler, resourceTable, stringTable,
+      InsertToResourceTable(*resource.get(), ResourceClass::Sampler, resourceTable, stringBufferPart,
                             resourceIndex);
     }
     for (auto &resource : m_Module.GetSRVs()) {
-      InsertToResourceTable(*resource.get(), ResourceClass::SRV, resourceTable, stringTable,
+      InsertToResourceTable(*resource.get(), ResourceClass::SRV, resourceTable, stringBufferPart,
                             resourceIndex);
     }
     for (auto &resource : m_Module.GetUAVs()) {
-      InsertToResourceTable(*resource.get(), ResourceClass::UAV, resourceTable, stringTable,
+      InsertToResourceTable(*resource.get(), ResourceClass::UAV, resourceTable, stringBufferPart,
                             resourceIndex);
     }
   }
 
-  void UpdateFunctionDependency(llvm::Function *F, StringTable &stringTable) {
+  void UpdateFunctionDependency(llvm::Function *F, StringBufferPart &stringBufferPart) {
     for (const auto &user : F->users()) {
       llvm::Function *userFunction = FindUsingFunction(user);
-      uint32_t index = stringTable.Insert(F->getName());
+      uint32_t index = stringBufferPart.Insert(F->getName());
       if (m_FuncToDependencies.find(userFunction) ==
           m_FuncToDependencies.end()) {
         m_FuncToDependencies[userFunction] =
@@ -928,27 +1028,27 @@ private:
     }
   }
 
-  void UpdateFunctionInfo(StringTable &stringTable) {
+  void UpdateFunctionInfo(StringBufferPart &stringBufferPart) {
     // TODO: get a list of valid shader flags
     // TODO: get a minimum shader version
     std::unordered_map<llvm::Function *, std::vector<StringRef>>
         FuncToUnresolvedDependencies;
-    m_tables.emplace_back(std::make_unique<FunctionTable>());
-    FunctionTable &functionTable = *(FunctionTable*)(m_tables.back().get());
-    m_tables.emplace_back(std::make_unique<IndexTable>());
-    IndexTable &indexTable = *(IndexTable*)(m_tables.back().get());
+    m_Parts.emplace_back(std::make_unique<FunctionTable>());
+    FunctionTable &functionTable = *reinterpret_cast<FunctionTable*>(m_Parts.back().get());
+    m_Parts.emplace_back(std::make_unique<IndexArraysPart>());
+    IndexArraysPart &indexArraysPart = *reinterpret_cast<IndexArraysPart*>(m_Parts.back().get());
     for (auto &function : m_Module.GetModule()->getFunctionList()) {
       // If function is a declaration, it is an unresolved dependency in the library
       if (function.isDeclaration() && !OP::IsDxilOpFunc(&function)) {
-        UpdateFunctionDependency(&function, stringTable);
+        UpdateFunctionDependency(&function, stringBufferPart);
       }
     }
     for (auto &function : m_Module.GetModule()->getFunctionList()) {
       if (!function.isDeclaration()) {
         StringRef mangled = function.getName();
         StringRef unmangled = hlsl::dxilutil::DemangleFunctionName(function.getName());
-        uint32_t mangledIndex = stringTable.Insert(mangled);
-        uint32_t unmangledIndex = stringTable.Insert(unmangled);
+        uint32_t mangledIndex = stringBufferPart.Insert(mangled);
+        uint32_t unmangledIndex = stringBufferPart.Insert(unmangled);
         // Update resource Index
         uint32_t resourceIndex = UINT_MAX;
         uint32_t functionDependencies = UINT_MAX;
@@ -958,11 +1058,11 @@ private:
 
         if (m_FuncToResNameOffset.find(&function) != m_FuncToResNameOffset.end())
           resourceIndex =
-              indexTable.AddIndex(m_FuncToResNameOffset[&function].begin(),
+              indexArraysPart.AddIndex(m_FuncToResNameOffset[&function].begin(),
                                   m_FuncToResNameOffset[&function].end());
         if (m_FuncToDependencies.find(&function) != m_FuncToDependencies.end())
           functionDependencies =
-              indexTable.AddIndex(m_FuncToDependencies[&function].begin(),
+              indexArraysPart.AddIndex(m_FuncToDependencies[&function].begin(),
                                   m_FuncToDependencies[&function].end());
         if (m_Module.HasDxilFunctionProps(&function)) {
           auto props = m_Module.GetDxilFunctionProps(&function);
@@ -987,9 +1087,9 @@ private:
         info.FunctionDependencies = functionDependencies;
         info.PayloadSizeInBytes = payloadSizeInBytes;
         info.AttributeSizeInBytes = attrSizeInBytes;
-        uint64_t rawFlags = flags.GetShaderFlagsRaw();
-        info.FeatureInfo1 = rawFlags & 0xffffffff;
-        info.FeatureInfo2 = (rawFlags >> 32) & 0xffffffff;
+        uint64_t featureFlags = flags.GetFeatureInfo();
+        info.FeatureInfo1 = featureFlags & 0xffffffff;
+        info.FeatureInfo2 = (featureFlags >> 32) & 0xffffffff;
         functionTable.Insert(info);
       }
     }
@@ -997,41 +1097,57 @@ private:
 
 public:
   DxilRDATWriter(const DxilModule &module, uint32_t InfoVersion = 0)
-      : m_Module(module), m_RDATBuffer(), m_tables(), m_FuncToResNameOffset() {
+      : m_Module(module), m_RDATBuffer(), m_Parts(), m_FuncToResNameOffset() {
     // It's important to keep the order of this update
-    m_tables.emplace_back(std::make_unique<StringTable>());
-    StringTable &stringTable = *(StringTable*)m_tables.back().get();
-    UpdateResourceInfo(stringTable);
-    UpdateFunctionInfo(stringTable);
+    m_Parts.emplace_back(std::make_unique<StringBufferPart>());
+    StringBufferPart &stringBufferPart = *reinterpret_cast<StringBufferPart*>(m_Parts.back().get());
+    UpdateResourceInfo(stringBufferPart);
+    UpdateFunctionInfo(stringBufferPart);
+
+    // Delete any empty parts:
+    std::vector<std::unique_ptr<RDATPart>>::iterator it = m_Parts.begin();
+    while (it != m_Parts.end()) {
+      if (it->get()->GetPartSize() == 0) {
+        it = m_Parts.erase(it);
+      }
+      else
+        it++;
+    }
   }
 
+  // Total size in bytes of the serialized RDAT blob; write() sizes its
+  // buffer from this value.  Each part's data is padded with PSVALIGN4.
   __override uint32_t size() const {
-    // one variable to count the number of blobs and two blobs
-    uint32_t total = 4 + m_tables.size() * sizeof(RuntimeDataTableHeader);
-    for (auto &&table : m_tables)
-      total += table->GetPartSize();
+    // header + offset array
+    uint32_t total = sizeof(RuntimeDataHeader) + m_Parts.size() * sizeof(uint32_t);
+    // For each part: part header + part size
+    for (auto &part : m_Parts)
+      total += sizeof(RuntimeDataPartHeader) + PSVALIGN4(part->GetPartSize());
     return total;
   }
 
   __override void write(AbstractMemoryStream *pStream) {
-    m_RDATBuffer.resize(size());
-    char *pCur = m_RDATBuffer.data();
-    // write number of tables
-    uint32_t size = m_tables.size();
-    memcpy(pCur, &size, sizeof(uint32_t));
-    pCur += sizeof(uint32_t);
-    // write records
-    uint32_t curTableOffset = size * sizeof(RuntimeDataTableHeader) + 4;
-    for (auto &&table : m_tables) {
-      RuntimeDataTableHeader record = { static_cast<uint32_t>(table->GetType()), table->GetPartSize(), curTableOffset };
-      memcpy(pCur, &record, sizeof(RuntimeDataTableHeader));
-      pCur += sizeof(RuntimeDataTableHeader);
-      curTableOffset += record.size;
+    try {
+      m_RDATBuffer.resize(size(), 0);
+      CheckedWriter W(m_RDATBuffer.data(), m_RDATBuffer.size());
+      // write RDAT header
+      RuntimeDataHeader &header = W.Map<RuntimeDataHeader>();
+      header.Version = RDAT_Version_0;
+      header.PartCount = m_Parts.size();
+      // map offsets
+      uint32_t *offsets = W.MapArray<uint32_t>(header.PartCount);
+      // write parts
+      unsigned i = 0;
+      for (auto &part : m_Parts) {
+        offsets[i++] = W.GetOffset();
+        RuntimeDataPartHeader &partHeader = W.Map<RuntimeDataPartHeader>();
+        partHeader.Type = part->GetType();
+        partHeader.Size = PSVALIGN4(part->GetPartSize());
+        DXASSERT(partHeader.Size, "otherwise, failed to remove empty part");
+        char *bytes = W.MapArray<char>(partHeader.Size);
+        part->Write(bytes);
+      }
     }
-    // write tables
-    for (auto &&table : m_tables) {
-      table->Write(pCur);
-      pCur += table->GetPartSize();
+    catch (CheckedWriter::exception e) {
+      throw hlsl::Exception(DXC_E_GENERAL_INTERNAL_ERROR, e.what());
     }
 
     ULONG cbWritten;

+ 12 - 12
tools/clang/unittests/HLSL/DxilContainerTest.cpp

@@ -732,7 +732,7 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
       IFT(containerReflection->GetPartContent(i, &pBlob));
       // Validate using DxilRuntimeData
       DxilRuntimeData context;
-      context.InitFromRDAT((char *)pBlob->GetBufferPointer());
+      context.InitFromRDAT((char *)pBlob->GetBufferPointer(), pBlob->GetBufferSize());
       FunctionTableReader *funcTableReader = context.GetFunctionTableReader();
       ResourceTableReader *resTableReader = context.GetResourceTableReader();
       VERIFY_ARE_EQUAL(funcTableReader->GetNumFunctions(), 4);
@@ -748,7 +748,7 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
           hlsl::ShaderFlags flag;
           flag.SetUAVLoadAdditionalFormats(true);
           flag.SetLowPrecisionPresent(true);
-          uint64_t rawFlag = flag.GetShaderFlagsRaw();
+          uint64_t rawFlag = flag.GetFeatureInfo();
           VERIFY_ARE_EQUAL(funcReader.GetFeatureFlag(), rawFlag);
           ResourceReader resReader = funcReader.GetResource(0);
           VERIFY_ARE_EQUAL(resReader.GetResourceClass(), hlsl::DXIL::ResourceClass::UAV);
@@ -757,7 +757,7 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
         else if (cur_str.compare("function1") == 0) {
           hlsl::ShaderFlags flag;
           flag.SetLowPrecisionPresent(true);
-          uint64_t rawFlag = flag.GetShaderFlagsRaw();
+          uint64_t rawFlag = flag.GetFeatureInfo();
           VERIFY_ARE_EQUAL(funcReader.GetFeatureFlag(), rawFlag);
           VERIFY_ARE_EQUAL(funcReader.GetNumResources(), 3);
         }
@@ -790,24 +790,24 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
       VERIFY_ARE_EQUAL(resTableReader->GetNumResources(), 8);
       // This is validation test for DxilRuntimeReflection implemented on DxilRuntimeReflection.inl
       unique_ptr<DxilRuntimeReflection> pReflection(CreateDxilRuntimeReflection());
-      VERIFY_IS_TRUE(pReflection->InitFromRDAT(pBlob->GetBufferPointer()));
-      DXIL_LIBRARY_DESC lib_reflection = pReflection->GetLibraryReflection();
+      VERIFY_IS_TRUE(pReflection->InitFromRDAT(pBlob->GetBufferPointer(), pBlob->GetBufferSize()));
+      DxilLibraryDesc lib_reflection = pReflection->GetLibraryReflection();
       VERIFY_ARE_EQUAL(lib_reflection.NumFunctions, 4);
       for (uint32_t j = 0; j < 3; ++j) {
-        DXIL_FUNCTION function = lib_reflection.pFunction[j];
+        DxilFunctionDesc function = lib_reflection.pFunction[j];
         std::string cur_str = str;
         cur_str.push_back('0' + j);
         if (cur_str.compare("function0") == 0) {
           hlsl::ShaderFlags flag;
           flag.SetUAVLoadAdditionalFormats(true);
           flag.SetLowPrecisionPresent(true);
-          uint64_t rawFlag = flag.GetShaderFlagsRaw();
+          uint64_t rawFlag = flag.GetFeatureInfo();
           uint64_t featureFlag = static_cast<uint64_t>(function.FeatureInfo2) << 32;
           featureFlag |= static_cast<uint64_t>(function.FeatureInfo1);
           VERIFY_ARE_EQUAL(featureFlag, rawFlag);
           VERIFY_ARE_EQUAL(function.NumResources, 1);
           VERIFY_ARE_EQUAL(function.NumFunctionDependencies, 0);
-          const DXIL_RESOURCE &resource = *function.Resources[0];
+          const DxilResourceDesc &resource = *function.Resources[0];
           VERIFY_ARE_EQUAL(resource.Class, (uint32_t)hlsl::DXIL::ResourceClass::UAV);
           VERIFY_ARE_EQUAL(resource.Kind, (uint32_t)hlsl::DXIL::ResourceKind::Texture1D);
           std::wstring wName = resource.Name;
@@ -816,7 +816,7 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
         else if (cur_str.compare("function1") == 0) {
           hlsl::ShaderFlags flag;
           flag.SetLowPrecisionPresent(true);
-          uint64_t rawFlag = flag.GetShaderFlagsRaw();
+          uint64_t rawFlag = flag.GetFeatureInfo();
           uint64_t featureFlag = static_cast<uint64_t>(function.FeatureInfo2) << 32;
           featureFlag |= static_cast<uint64_t>(function.FeatureInfo1);
           VERIFY_ARE_EQUAL(featureFlag, rawFlag);
@@ -824,7 +824,7 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
           VERIFY_ARE_EQUAL(function.NumFunctionDependencies, 0);
           std::unordered_set<std::wstring> stringSet = { L"$Globals", L"b_buf", L"tex2" };
           for (uint32_t j = 0; j < 3; ++j) {
-            const DXIL_RESOURCE &resource = *function.Resources[j];
+            const DxilResourceDesc &resource = *function.Resources[j];
             std::wstring compareName = resource.Name;
             VERIFY_IS_TRUE(stringSet.find(compareName) != stringSet.end());
           }
@@ -843,7 +843,7 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
           VERIFY_ARE_EQUAL(function.NumResources, numResFlagCheck);
           VERIFY_ARE_EQUAL(function.NumFunctionDependencies, 0);
           for (unsigned i = 0; i < function.NumResources; ++i) {
-            const DXIL_RESOURCE *res = function.Resources[i];
+            const DxilResourceDesc *res = function.Resources[i];
             VERIFY_ARE_EQUAL(res->Class, static_cast<uint32_t>(hlsl::DXIL::ResourceClass::UAV));
             unsigned j = 0;
             for (; j < numResFlagCheck; ++j) {
@@ -911,7 +911,7 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT2) {
       CComPtr<IDxcBlob> pBlob;
       IFT(pReflection->GetPartContent(i, &pBlob));
       DxilRuntimeData context;
-      context.InitFromRDAT((char *)pBlob->GetBufferPointer());
+      context.InitFromRDAT((char *)pBlob->GetBufferPointer(), pBlob->GetBufferSize());
       FunctionTableReader *funcTableReader = context.GetFunctionTableReader();
       ResourceTableReader *resTableReader = context.GetResourceTableReader();
       VERIFY_IS_TRUE(funcTableReader->GetNumFunctions() == 1);