2
0
Эх сурвалжийг харах

Merged PR 28: Add RDAT blob for library targets and add its reader

Young Kim 7 жил өмнө
parent
commit
0a098d7cbb

+ 1 - 0
include/dxc/HLSL/DxilContainer.h

@@ -82,6 +82,7 @@ enum DxilFourCC {
   DFCC_RootSignature            = DXIL_FOURCC('R', 'T', 'S', '0'),
   DFCC_DXIL                     = DXIL_FOURCC('D', 'X', 'I', 'L'),
   DFCC_PipelineStateValidation  = DXIL_FOURCC('P', 'S', 'V', '0'),
+  DFCC_RuntimeData              = DXIL_FOURCC('R', 'D', 'A', 'T'),
 };
 
 #undef DXIL_FOURCC

+ 1 - 1
include/dxc/HLSL/DxilMetadataHelper.h

@@ -343,7 +343,7 @@ public:
 
   // Function props.
   llvm::MDTuple *EmitDxilFunctionProps(const hlsl::DxilFunctionProps *props,
-                                       llvm::Function *F);
+                                       const llvm::Function *F);
   llvm::Function *LoadDxilFunctionProps(llvm::MDTuple *pProps,
                                         hlsl::DxilFunctionProps *props);
 

+ 15 - 18
include/dxc/HLSL/DxilModule.h

@@ -45,6 +45,8 @@ class OP;
 class RootSignatureHandle;
 struct DxilFunctionProps;
 
+typedef std::unordered_map<const llvm::Function *, std::unique_ptr<DxilFunctionProps>> DxilFunctionPropsMap;
+typedef std::unordered_map<llvm::Function *, std::unique_ptr<DxilEntrySignature>> DxilEntrySignatureMap;
 /// Use this class to manipulate DXIL of a shader.
 class DxilModule {
 public:
@@ -130,20 +132,21 @@ public:
   void ReplaceDxilEntrySignature(llvm::Function *F, llvm::Function *NewF);
 
   // DxilFunctionProps.
-  bool HasDxilFunctionProps(llvm::Function *F) const;
-  DxilFunctionProps &GetDxilFunctionProps(llvm::Function *F);
-  void AddDxilFunctionProps(llvm::Function *F, std::unique_ptr<DxilFunctionProps> &info);
+  bool HasDxilFunctionProps(const llvm::Function *F) const;
+  DxilFunctionProps &GetDxilFunctionProps(const llvm::Function *F);
+  const DxilFunctionProps &GetDxilFunctionProps(const llvm::Function *F) const;
+  void AddDxilFunctionProps(const llvm::Function *F, std::unique_ptr<DxilFunctionProps> &info);
 
   // Move DxilFunctionProps of F to NewF.
   void ReplaceDxilFunctionProps(llvm::Function *F, llvm::Function *NewF);
   void SetPatchConstantFunctionForHS(llvm::Function *hullShaderFunc, llvm::Function *patchConstantFunc);
-  bool IsGraphicsShader(llvm::Function *F); // vs,hs,ds,gs,ps
-  bool IsPatchConstantShader(llvm::Function *F);
-  bool IsComputeShader(llvm::Function *F);
+  bool IsGraphicsShader(const llvm::Function *F) const; // vs,hs,ds,gs,ps
+  bool IsPatchConstantShader(const llvm::Function *F) const;
+  bool IsComputeShader(const llvm::Function *F) const;
 
   // Is an entry function that uses input/output signature conventions?
   // Includes: vs/hs/ds/gs/ps/cs as well as the patch constant function.
-  bool IsEntryThatUsesSignatures(llvm::Function *F);
+  bool IsEntryThatUsesSignatures(const llvm::Function *F) const ;
 
   // Remove Root Signature from module metadata
   void StripRootSignatureFromMetadata();
@@ -179,12 +182,8 @@ public:
   void ResetRootSignature(RootSignatureHandle *pValue);
   void ResetTypeSystem(DxilTypeSystem *pValue);
   void ResetOP(hlsl::OP *hlslOP);
-  void ResetFunctionPropsMap(
-      std::unordered_map<llvm::Function *, std::unique_ptr<DxilFunctionProps>>
-          &&propsMap);
-  void ResetEntrySignatureMap(
-      std::unordered_map<llvm::Function *, std::unique_ptr<DxilEntrySignature>>
-          &&SigMap);
+  void ResetFunctionPropsMap(DxilFunctionPropsMap &&propsMap);
+  void ResetEntrySignatureMap(DxilEntrySignatureMap &&SigMap);
 
   void StripDebugRelatedCode();
   llvm::DebugInfoFinder &GetOrCreateDebugInfoFinder();
@@ -430,14 +429,12 @@ private:
   std::unique_ptr<DxilTypeSystem> m_pTypeSystem;
 
   // Function properties for shader functions.
-  std::unordered_map<llvm::Function *, std::unique_ptr<DxilFunctionProps>>
-      m_DxilFunctionPropsMap;
+  DxilFunctionPropsMap m_DxilFunctionPropsMap;
   // EntrySig for shader functions.
-  std::unordered_map<llvm::Function *, std::unique_ptr<DxilEntrySignature>>
-      m_DxilEntrySignatureMap;
+  DxilEntrySignatureMap m_DxilEntrySignatureMap;
 
   // Keeps track of patch constant functions used by hull shaders
-  std::unordered_set<llvm::Function *>  m_PatchConstantFunctions;
+  std::unordered_set<const llvm::Function *>  m_PatchConstantFunctions;
 
   // ViewId state.
   std::unique_ptr<DxilViewIdState> m_pViewIdState;

+ 355 - 15
include/dxc/HLSL/DxilPipelineStateValidation.h

@@ -14,7 +14,13 @@
 
 #include <stdint.h>
 #include <string.h>
+namespace hlsl {
+namespace DXIL {
+namespace PSV {
 
+#ifndef UINT_MAX
+#define UINT_MAX 0xffffffff
+#endif
 // How many dwords are required for mask with one bit per component, 4 components per vector
 inline uint32_t PSVComputeMaskDwordsFromVectors(uint32_t Vectors) { return (Vectors + 7) >> 3; }
 inline uint32_t PSVComputeInputOutputTableSize(uint32_t InputVectors, uint32_t OutputVectors) {
@@ -64,6 +70,13 @@ enum class PSVShaderKind : uint8_t    // DXIL::ShaderKind
   Hull,
   Domain,
   Compute,
+  Library,
+  RayGeneration,
+  Intersection,
+  AnyHit,
+  ClosestHit,
+  Miss,
+  Callable,
   Invalid,
 };
 
@@ -99,10 +112,43 @@ enum class PSVResourceType
   UAVRaw,
   UAVStructured,
   UAVStructuredWithCounter,
+  NumEntries
+};
 
+enum class PSVResourceKind
+{
+  Invalid = 0,
+  Texture1D,
+  Texture2D,
+  Texture2DMS,
+  Texture3D,
+  TextureCube,
+  Texture1DArray,
+  Texture2DArray,
+  Texture2DMSArray,
+  TextureCubeArray,
+  TypedBuffer,
+  RawBuffer,
+  StructuredBuffer,
+  CBuffer,
+  Sampler,
+  TBuffer,
+  RTAccelerationStructure,
   NumEntries
 };
 
+// Table of null-terminated strings, overall size aligned to dword boundary, last byte must be null
+struct PSVStringTable {
+  const char *Table;
+  uint32_t Size;
+  PSVStringTable() : Table(nullptr), Size(0) {}
+  PSVStringTable(const char *table, uint32_t size) : Table(table), Size(size) {}
+  const char *Get(uint32_t offset) const {
+    _Analysis_assume_(offset < Size && Table && Table[Size-1] == '\0');
+    return Table + offset;
+  }
+};
+
 // Versioning is additive and based on size
 struct PSVResourceBindInfo0
 {
@@ -111,7 +157,75 @@ struct PSVResourceBindInfo0
   uint32_t LowerBound;
   uint32_t UpperBound;
 };
-// PSVResourceBindInfo1 would derive and extend
+
+struct RuntimeDataResourceInfo : public PSVResourceBindInfo0
+{
+  uint32_t Kind; // PSVResourceKind
+  uint32_t Name; // offset for string table
+};
+struct RuntimeDataFunctionInfo {
+  uint32_t Name;                 // offset for string table
+  uint32_t UnmangledName;        // offset for string table
+  uint32_t Resources;            // index to an index table
+  uint32_t ShaderKind;           // shader kind
+  uint32_t PayloadSizeInBytes;   // payload count for miss, closest hit, any hit
+                                 // shader, or parameter size for call shader
+  uint32_t AttributeSizeInBytes; // attribute size for closest hit and any hit
+  uint32_t FeatureInfo1;         // required feature info
+  uint32_t FeatureInfo2;         // required feature info
+  uint32_t ShaderStageFlag;      // valid shader stage flag
+  uint32_t MinShaderTarget;      // minimum shader target
+};
+
+// Index table is a sequence of rows, where each row has a count as a first
+// element followed by the count number of elements pre computing values
+struct IndexTableReader {
+private:
+  const uint32_t *m_table;
+  uint32_t m_size;
+
+public:
+  class IndexRow {
+  private:
+    const uint32_t *m_values;
+    const uint32_t m_count;
+  public:
+    IndexRow(const uint32_t *values, uint32_t count)
+      : m_values(values), m_count(count) {}
+    uint32_t Count() { return m_count; }
+    uint32_t At(uint32_t i) { return m_values[i]; }
+  };
+
+  IndexTableReader() : m_table(nullptr), m_size(0) {}
+  IndexTableReader(const uint32_t *table, uint32_t size)
+    : m_table(table), m_size(size) {}
+
+  void SetTable(const uint32_t *table) {
+    m_table = table;
+  }
+
+  void SetSize(uint32_t size) {
+    m_size = size;
+  }
+
+  IndexRow getRow(uint32_t i) {
+    return IndexRow(&m_table[i] + 1, m_table[i]);
+  }
+};
+
+struct RuntimeDataTableHeader {
+  uint32_t tableType; // DataTableType
+  uint32_t size;
+  uint32_t offset;
+};
+
+enum RuntimeDataTableType : uint32_t {
+  Invalid = 0,
+  String,
+  Function,
+  Resource,
+  Index
+};
 
 // Helpers for output dependencies (ViewID and Input-Output tables)
 struct PSVComponentMask {
@@ -165,17 +279,6 @@ struct PSVDependencyTable {
   bool IsValid() { return Table != nullptr; }
 };
 
-// Table of null-terminated strings, overall size aligned to dword boundary, last byte must be null
-struct PSVStringTable {
-  const char *Table;
-  uint32_t Size;
-  PSVStringTable() : Table(nullptr), Size(0) {}
-  PSVStringTable(const char *table, uint32_t size) : Table(table), Size(size) {}
-  const char *Get(uint32_t offset) const {
-    _Analysis_assume_(offset < Size && Table && Table[Size-1] == '\0');
-    return Table + offset;
-  }
-};
 struct PSVString {
   uint32_t Offset;
   PSVString() : Offset(0) {}
@@ -237,7 +340,7 @@ enum class PSVSemanticKind : uint8_t    // DXIL::SemanticKind
 
 struct PSVSignatureElement0
 {
-  uint32_t SemanticName;          // Offset into PSVStringTable
+  uint32_t SemanticName;          // Offset into StringTable
   uint32_t SemanticIndexes;       // Offset into PSVSemanticIndexTable, count == Rows
   uint8_t Rows;                   // Number of rows this element occupies
   uint8_t StartRow;               // Starting row of packing location if allocated
@@ -302,6 +405,240 @@ struct PSVInitInfo
   uint8_t SigOutputVectors[4] = {0, 0, 0, 0};
 };
 
+struct ResourceReader {
+private:
+  const RuntimeDataResourceInfo *m_ResourceInfo;
+  PSVStringTable *m_StringReader;
+public:
+  ResourceReader() : m_ResourceInfo(nullptr), m_StringReader(nullptr) {}
+  ResourceReader(const RuntimeDataResourceInfo *resInfo,
+                 PSVStringTable *stringReader)
+      : m_ResourceInfo(resInfo), m_StringReader(stringReader) {}
+  PSVResourceType GetResourceType() {
+    return (PSVResourceType)m_ResourceInfo->ResType;
+  }
+  uint32_t GetSpace() {
+    return m_ResourceInfo->Space;
+  }
+  uint32_t GetLowerBound() {
+    return m_ResourceInfo->LowerBound;
+  }
+  uint32_t GetUpperBound() {
+    return m_ResourceInfo->UpperBound;
+  }
+  PSVResourceKind GetResourceKind() { return (PSVResourceKind)m_ResourceInfo->Kind; }
+  const char* GetName() { return m_StringReader->Get(m_ResourceInfo->Name); }
+};
+
+struct ResourceTableReader {
+private:
+  const RuntimeDataResourceInfo *m_ResourceInfo; // pointer to an array of resource bind infos
+  PSVStringTable *m_StringReader;
+  uint32_t m_CBufferCount;
+  uint32_t m_SamplerCount;
+  uint32_t m_SRVCount;
+  uint32_t m_UAVCount;
+
+public:
+  ResourceTableReader()
+      : m_ResourceInfo(nullptr), m_StringReader(nullptr), m_CBufferCount(0),
+        m_SamplerCount(0), m_SRVCount(0), m_UAVCount(0){};
+  ResourceTableReader(const RuntimeDataResourceInfo *info1,
+                      PSVStringTable *stringTable, uint32_t CBufferCount,
+                      uint32_t SamplerCount, uint32_t SRVCount,
+                      uint32_t UAVCount)
+      : m_ResourceInfo(info1), m_StringReader(stringTable),
+        m_CBufferCount(CBufferCount), m_SamplerCount(SamplerCount),
+        m_SRVCount(SRVCount), m_UAVCount(UAVCount){};
+
+  void SetResourceInfo(const RuntimeDataResourceInfo *ptr) { m_ResourceInfo = ptr; }
+  void SetStringReader(PSVStringTable *ptr) { m_StringReader = ptr; }
+  void SetCBufferCount(uint32_t count) { m_CBufferCount = count; }
+  void SetSamplerCount(uint32_t count) { m_SamplerCount = count; }
+  void SetSRVCount(uint32_t count) { m_SRVCount = count; }
+  void SetUAVCount(uint32_t count) { m_UAVCount = count; }
+
+  uint32_t GetNumResources() {
+    return m_CBufferCount + m_SamplerCount + m_SRVCount + m_UAVCount;
+  }
+  ResourceReader GetItem(uint32_t i) {
+    _Analysis_assume_(i < GetNumResources());
+    return ResourceReader(&m_ResourceInfo[i], m_StringReader);
+  }
+
+
+  uint32_t GetNumCBuffers() { return m_CBufferCount; }
+  ResourceReader GetCBuffer(uint32_t i) {
+    _Analysis_assume_(i < m_CBufferCount);
+    return ResourceReader(&m_ResourceInfo[i], m_StringReader);
+  }
+
+  uint32_t GetNumSamplers() { return m_SamplerCount; }
+  ResourceReader GetSampler(uint32_t i) {
+    _Analysis_assume_(i < m_SamplerCount);
+    uint32_t offset = (m_CBufferCount + i);
+    return ResourceReader(&m_ResourceInfo[offset], m_StringReader);
+  }
+
+  uint32_t GetNumSRVs() { return m_SRVCount; }
+  ResourceReader GetSRV(uint32_t i) {
+    _Analysis_assume_(i < m_SRVCount);
+    uint32_t offset = (m_CBufferCount + m_SamplerCount + i);
+    return ResourceReader(&m_ResourceInfo[offset], m_StringReader);
+  }
+
+  uint32_t GetNumUAVs() { return m_UAVCount; }
+  ResourceReader GetUAV(uint32_t i) {
+    _Analysis_assume_(i < m_UAVCount);
+    uint32_t offset = (m_CBufferCount + m_SamplerCount + m_SRVCount + i);
+    return ResourceReader(&m_ResourceInfo[offset], m_StringReader);
+  }
+};
+
+struct FunctionReader {
+private:
+  const RuntimeDataFunctionInfo *m_RuntimeDataFunctionInfo;
+  PSVStringTable *m_StringReader;
+  IndexTableReader *m_IndexTableReader;
+  ResourceTableReader *m_ResourceTableReader;
+public:
+  FunctionReader()
+      : m_RuntimeDataFunctionInfo(nullptr), m_StringReader(nullptr),
+        m_IndexTableReader(nullptr), m_ResourceTableReader(nullptr) {}
+  FunctionReader(const RuntimeDataFunctionInfo *functionInfo,
+                 PSVStringTable *stringReader,
+                 IndexTableReader *indexTableReader,
+                 ResourceTableReader *resourceTableReader)
+      : m_RuntimeDataFunctionInfo(functionInfo), m_StringReader(stringReader),
+        m_IndexTableReader(indexTableReader),
+        m_ResourceTableReader(resourceTableReader) {}
+
+  const char *GetName() { return m_StringReader->Get(m_RuntimeDataFunctionInfo->Name); }
+  const char *GetUnmangledName() { return m_StringReader->Get(m_RuntimeDataFunctionInfo->UnmangledName); }
+  uint64_t GetFeatureFlag() {
+    uint64_t flag = static_cast<uint64_t>(m_RuntimeDataFunctionInfo->FeatureInfo2) << 32;
+    flag |= static_cast<uint64_t>(m_RuntimeDataFunctionInfo->FeatureInfo1);
+    return flag;
+  }
+  uint32_t GetShaderStageFlag() { return m_RuntimeDataFunctionInfo->ShaderStageFlag; }
+  uint32_t GetMinShaderTarget() { return m_RuntimeDataFunctionInfo->MinShaderTarget; }
+  uint32_t FunctionReader::GetNumResources() {
+    if (m_RuntimeDataFunctionInfo->Resources == UINT_MAX)
+      return 0;
+    return m_IndexTableReader->getRow(m_RuntimeDataFunctionInfo->Resources).Count();
+  }
+  ResourceReader GetResource(uint32_t i) {
+    uint32_t resIndex = m_IndexTableReader->getRow(m_RuntimeDataFunctionInfo->Resources).At(i);
+    return m_ResourceTableReader->GetItem(resIndex);
+  }
+
+  uint32_t GetPayloadSizeInBytes() { return m_RuntimeDataFunctionInfo->PayloadSizeInBytes; }
+  uint32_t GetAttributeSizeInBytes() { return m_RuntimeDataFunctionInfo->AttributeSizeInBytes; }
+  // payload (hit shaders) and parameters (call shaders) are mutually exclusive
+  uint32_t GetParameterSizeInBytes() { return m_RuntimeDataFunctionInfo->PayloadSizeInBytes; }
+  PSVShaderKind GetShaderKind() { return (PSVShaderKind) m_RuntimeDataFunctionInfo->ShaderKind; }
+};
+
+struct FunctionTableReader {
+private:
+  const RuntimeDataFunctionInfo *m_infos;
+  uint32_t m_count;
+  PSVStringTable *m_StringReader;
+  IndexTableReader *m_IndexTableReader;
+  ResourceTableReader *m_ResourceTableReader;
+public:
+  FunctionTableReader()
+      : m_infos(nullptr), m_count(0), m_StringReader(nullptr),
+        m_IndexTableReader(nullptr), m_ResourceTableReader(nullptr) {}
+  FunctionTableReader(const RuntimeDataFunctionInfo *functionInfos,
+                      uint32_t count, PSVStringTable *stringReader = nullptr,
+                      IndexTableReader *indexTableReader = nullptr,
+                      ResourceTableReader *resourceTableReader = nullptr)
+      : m_infos(functionInfos), m_count(count), m_StringReader(stringReader),
+        m_IndexTableReader(indexTableReader),
+        m_ResourceTableReader(resourceTableReader) {}
+
+  FunctionReader GetItem(uint32_t i) {
+    return FunctionReader(&m_infos[i], m_StringReader, m_IndexTableReader,
+                          m_ResourceTableReader);
+  }
+  uint32_t GetNumFunctions() { return m_count; }
+
+  void SetStringReader(PSVStringTable *ptr) { m_StringReader = ptr; }
+  void SetIndexTableReader(IndexTableReader *ptr) { m_IndexTableReader = ptr; }
+  void SetResourceTableReader(ResourceTableReader *ptr) { m_ResourceTableReader = ptr; }
+  void SetFunctionInfo(const RuntimeDataFunctionInfo *ptr) { m_infos = ptr; }
+  void SetCount(uint32_t count) { m_count = count; }
+};
+
+class DxilRuntimeData {
+private:
+  uint32_t m_TableCount;
+  PSVStringTable m_StringReader;
+  IndexTableReader m_IndexTableReader;
+  ResourceTableReader m_ResourceTableReader;
+  FunctionTableReader m_FunctionTableReader;
+  friend struct FunctionReader;
+  friend struct ResourceReader;
+public:
+  DxilRuntimeData()
+      : m_TableCount(0), m_StringReader(), m_ResourceTableReader(),
+        m_FunctionTableReader(), m_IndexTableReader() {}
+  DxilRuntimeData(const char *ptr) {
+    InitFromRDAT(ptr);
+  }
+  // initializing reader from RDAT. return true if no error has occured.
+  bool InitFromRDAT(const char *ptr) {
+    if (ptr) {
+      uint32_t TableCount = (uint32_t)*ptr;
+      RuntimeDataTableHeader *records = (RuntimeDataTableHeader *)(ptr + 4);
+      for (uint32_t i = 0; i < TableCount; ++i) {
+        RuntimeDataTableHeader *curRecord = &records[i];
+        switch (curRecord->tableType) {
+          case RuntimeDataTableType::Resource: {
+            uint32_t cBufferCount = *(uint32_t*)(ptr + curRecord->offset);
+            uint32_t samplerCount = *(uint32_t*)(ptr + curRecord->offset + 4);
+            uint32_t srvCount = *(uint32_t*)(ptr + curRecord->offset + 8);
+            uint32_t uavCount = *(uint32_t*)(ptr + curRecord->offset + 12);
+            m_ResourceTableReader.SetResourceInfo((RuntimeDataResourceInfo*)(ptr + curRecord->offset + 16));
+            m_ResourceTableReader.SetCBufferCount(cBufferCount);
+            m_ResourceTableReader.SetSamplerCount(samplerCount);
+            m_ResourceTableReader.SetSRVCount(srvCount);
+            m_ResourceTableReader.SetUAVCount(uavCount);
+            m_FunctionTableReader.SetResourceTableReader(&m_ResourceTableReader);
+            break;
+          }
+          case RuntimeDataTableType::String: {
+            m_StringReader = PSVStringTable(ptr + curRecord->offset, curRecord->size);
+            m_ResourceTableReader.SetStringReader(&m_StringReader);
+            m_FunctionTableReader.SetStringReader(&m_StringReader);
+            break;
+          }
+          case RuntimeDataTableType::Function: {
+            RuntimeDataFunctionInfo *funcInfo =
+                (RuntimeDataFunctionInfo *)(ptr + curRecord->offset);
+            m_FunctionTableReader.SetFunctionInfo(funcInfo);
+            m_FunctionTableReader.SetCount(curRecord->size / sizeof(RuntimeDataFunctionInfo));
+            break;
+          }
+          case RuntimeDataTableType::Index: {
+            m_IndexTableReader = IndexTableReader(
+                (uint32_t *)(ptr + curRecord->offset), curRecord->size / 4);
+            m_FunctionTableReader.SetIndexTableReader(&m_IndexTableReader);
+            break;
+          }
+          default:
+            return false;
+        }
+      }
+      return true;
+    }
+    return false;
+  }
+  FunctionTableReader *GetFunctionTableReader() { return &m_FunctionTableReader; }
+  ResourceTableReader *GetResourceTableReader() { return &m_ResourceTableReader; }
+};
+
 class DxilPipelineStateValidation
 {
   uint32_t m_uPSVRuntimeInfoSize;
@@ -323,7 +660,7 @@ class DxilPipelineStateValidation
   uint32_t* m_pPCInputToOutputTable;
 
 public:
-  DxilPipelineStateValidation() : 
+  DxilPipelineStateValidation() :
     m_uPSVRuntimeInfoSize(0),
     m_pPSVRuntimeInfo0(nullptr),
     m_pPSVRuntimeInfo1(nullptr),
@@ -792,6 +1129,9 @@ public:
     return PSVDependencyTable();
   }
 };
+} // namespace PSV
+} // namespace DXIL
+} // namespace hlsl
 
 namespace hlsl {
 
@@ -810,7 +1150,7 @@ namespace hlsl {
       InvalidPSV,
     };
     virtual ~ViewIDValidator() {}
-    virtual Result ValidateStage(const DxilPipelineStateValidation &PSV,
+    virtual Result ValidateStage(const DXIL::PSV::DxilPipelineStateValidation &PSV,
                                  bool bFinalStage,
                                  bool bExpandInputOnly,
                                  unsigned &mismatchElementId) = 0;

+ 4 - 3
include/dxc/HLSL/HLModule.h

@@ -64,6 +64,8 @@ struct HLOptions {
   unsigned unused                  : 24;
 };
 
+typedef std::unordered_map<const llvm::Function *, std::unique_ptr<DxilFunctionProps>> DxilFunctionPropsMap;
+
 /// Use this class to manipulate HLDXIR of a shader.
 class HLModule {
 public:
@@ -213,8 +215,7 @@ public:
   DxilTypeSystem *ReleaseTypeSystem();
   OP *ReleaseOP();
   RootSignatureHandle *ReleaseRootSignature();
-  std::unordered_map<llvm::Function *, std::unique_ptr<DxilFunctionProps>> &&
-  ReleaseFunctionPropsMap();
+  DxilFunctionPropsMap &&ReleaseFunctionPropsMap();
 
   llvm::DebugInfoFinder &GetOrCreateDebugInfoFinder();
   static llvm::DIGlobalVariable *
@@ -246,7 +247,7 @@ private:
   std::vector<llvm::GlobalVariable*>  m_TGSMVariables;
 
   // High level function info.
-  std::unordered_map<llvm::Function *, std::unique_ptr<DxilFunctionProps>>  m_DxilFunctionPropsMap;
+  std::unordered_map<const llvm::Function *, std::unique_ptr<DxilFunctionProps>>  m_DxilFunctionPropsMap;
   std::unordered_set<llvm::Function *>  m_PatchConstantFunctions;
 
   // Resource type annotation.

+ 28 - 28
include/dxc/HLSL/ViewIDPipelineValidation.inl

@@ -16,7 +16,7 @@ namespace {
 
 typedef std::vector<DxilSignatureAllocator::DummyElement> ElementVec;
 
-struct ComponentMask : public PSVComponentMask {
+struct ComponentMask : public DXIL::PSV::PSVComponentMask {
   uint32_t Data[4];
   ComponentMask() : PSVComponentMask(Data, 0) {
     memset(Data, 0, sizeof(Data));
@@ -30,7 +30,7 @@ struct ComponentMask : public PSVComponentMask {
   ComponentMask &operator=(const PSVComponentMask &other) {
     NumVectors = other.NumVectors;
     if (other.Mask && NumVectors) {
-      memcpy(Data, other.Mask, sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors(NumVectors));
+      memcpy(Data, other.Mask, sizeof(uint32_t) * DXIL::PSV::PSVComputeMaskDwordsFromVectors(NumVectors));
     }
     else {
       memset(Data, 0, sizeof(Data));
@@ -49,7 +49,7 @@ struct ComponentMask : public PSVComponentMask {
 };
 
 static void InitElement(DxilSignatureAllocator::DummyElement &eOut,
-                        const PSVSignatureElement &eIn,
+                        const DXIL::PSV::PSVSignatureElement &eIn,
                         DXIL::SigPointKind sigPoint) {
   eOut.rows = eIn.GetRows();
   eOut.cols = eIn.GetCols();
@@ -69,7 +69,7 @@ static void CopyElements( ElementVec &outElements,
                           DXIL::SigPointKind sigPoint,
                           unsigned numElements,
                           unsigned streamIndex,
-                          std::function<PSVSignatureElement(unsigned)> getElement) {
+                          std::function<DXIL::PSV::PSVSignatureElement(unsigned)> getElement) {
   outElements.clear();
   outElements.reserve(numElements);
   for (unsigned i = 0; i < numElements; i++) {
@@ -84,7 +84,7 @@ static void CopyElements( ElementVec &outElements,
 
 static void AddViewIDElements(ElementVec &outElements,
                               ElementVec &inElements,
-                              PSVComponentMask &mask,
+                              DXIL::PSV::PSVComponentMask &mask,
                               unsigned viewIDCount) {
   // Compute needed elements
   for (unsigned adding = 0; adding < 2; adding++) {
@@ -185,7 +185,7 @@ static bool MergeElements(const ElementVec &priorElements,
 static void PropagateMask(const ComponentMask &priorMask,
                           ElementVec &inputElements,
                           ComponentMask &outMask,
-                          std::function<PSVComponentMask(unsigned)> getMask) {
+                          std::function<DXIL::PSV::PSVComponentMask(unsigned)> getMask) {
   // Iterate elements
   for (auto &E : inputElements) {
     for (unsigned row = 0; row < E.GetRows(); row++) {
@@ -238,7 +238,7 @@ public:
       m_GSRastStreamIndex(gsRastStreamIndex)
   {}
   virtual ~ViewIDValidator_impl() {}
-  __override Result ValidateStage(const DxilPipelineStateValidation &PSV,
+  __override Result ValidateStage(const DXIL::PSV::DxilPipelineStateValidation &PSV,
                                   bool bFinalStage,
                                   bool bExpandInputOnly,
                                   unsigned &mismatchElementId) {
@@ -248,7 +248,7 @@ public:
       return Result::InvalidPSVVersion;
 
     switch (PSV.GetShaderKind()) {
-    case PSVShaderKind::Vertex: {
+    case DXIL::PSV::PSVShaderKind::Vertex: {
       if (bExpandInputOnly)
         return Result::InvalidUsage;
 
@@ -258,7 +258,7 @@ public:
       // capture output signature
       ElementVec outSig;
       CopyElements( outSig, DXIL::SigPointKind::VSOut, PSV.GetSigOutputElements(), 0,
-                    [&](unsigned i) -> PSVSignatureElement {
+                    [&](unsigned i) -> DXIL::PSV::PSVSignatureElement {
                       return PSV.GetSignatureElement(PSV.GetOutputElement0(i));
                     });
 
@@ -270,7 +270,7 @@ public:
 
       break;
     }
-    case PSVShaderKind::Hull: {
+    case DXIL::PSV::PSVShaderKind::Hull: {
       if (bFinalStage)
         return Result::InvalidUsage;
 
@@ -281,7 +281,7 @@ public:
       // capture signatures
       ElementVec inSig, outSig, pcSig;
       CopyElements( inSig, DXIL::SigPointKind::HSCPIn, PSV.GetSigInputElements(), 0,
-                    [&](unsigned i) -> PSVSignatureElement {
+                    [&](unsigned i) -> DXIL::PSV::PSVSignatureElement {
                       return PSV.GetSignatureElement(PSV.GetInputElement0(i));
                     });
 
@@ -303,22 +303,22 @@ public:
       }
 
       CopyElements(outSig, DXIL::SigPointKind::HSCPOut, PSV.GetSigOutputElements(), 0,
-        [&](unsigned i) -> PSVSignatureElement {
+        [&](unsigned i) -> DXIL::PSV::PSVSignatureElement {
         return PSV.GetSignatureElement(PSV.GetOutputElement0(i));
       });
       CopyElements(pcSig, DXIL::SigPointKind::PCOut, PSV.GetSigPatchConstantElements(), 0,
-        [&](unsigned i) -> PSVSignatureElement {
+        [&](unsigned i) -> DXIL::PSV::PSVSignatureElement {
         return PSV.GetSignatureElement(PSV.GetPatchConstantElement0(i));
       });
 
       // Propagate prior mask through input-output dependencies
       if (PSV.GetInputToOutputTable(0).IsValid()) {
         PropagateMask(m_PriorOutputMask, inSig, outputMask,
-                      [&](unsigned i) -> PSVComponentMask { return PSV.GetInputToOutputTable(0).GetMaskForInput(i); });
+                      [&](unsigned i) -> DXIL::PSV::PSVComponentMask { return PSV.GetInputToOutputTable(0).GetMaskForInput(i); });
       }
       if (PSV.GetInputToPCOutputTable().IsValid()) {
         PropagateMask(m_PriorOutputMask, inSig, pcMask,
-                      [&](unsigned i) -> PSVComponentMask { return PSV.GetInputToPCOutputTable().GetMaskForInput(i); });
+                      [&](unsigned i) -> DXIL::PSV::PSVComponentMask { return PSV.GetInputToPCOutputTable().GetMaskForInput(i); });
       }
 
       // Copy mask to prior mask
@@ -335,18 +335,18 @@ public:
 
       break;
     }
-    case PSVShaderKind::Domain: {
+    case DXIL::PSV::PSVShaderKind::Domain: {
       // Initialize mask with direct ViewID dependent outputs
       ComponentMask mask(PSV.GetViewIDOutputMask(0));
 
       // capture signatures
       ElementVec inSig, pcSig, outSig;
       CopyElements( inSig, DXIL::SigPointKind::DSCPIn, PSV.GetSigInputElements(), 0,
-                    [&](unsigned i) -> PSVSignatureElement {
+                    [&](unsigned i) -> DXIL::PSV::PSVSignatureElement {
                       return PSV.GetSignatureElement(PSV.GetInputElement0(i));
                     });
       CopyElements( pcSig, DXIL::SigPointKind::DSIn, PSV.GetSigPatchConstantElements(), 0,
-                    [&](unsigned i) -> PSVSignatureElement {
+                    [&](unsigned i) -> DXIL::PSV::PSVSignatureElement {
                       return PSV.GetSignatureElement(PSV.GetPatchConstantElement0(i));
                     });
 
@@ -382,18 +382,18 @@ public:
       }
 
       CopyElements(outSig, DXIL::SigPointKind::DSOut, PSV.GetSigOutputElements(), 0,
-        [&](unsigned i) -> PSVSignatureElement {
+        [&](unsigned i) -> DXIL::PSV::PSVSignatureElement {
         return PSV.GetSignatureElement(PSV.GetOutputElement0(i));
       });
 
       // Propagate prior mask through input-output dependencies
       if (PSV.GetInputToOutputTable(0).IsValid()) {
         PropagateMask(m_PriorOutputMask, inSig, mask,
-                      [&](unsigned i) -> PSVComponentMask { return PSV.GetInputToOutputTable(0).GetMaskForInput(i); });
+                      [&](unsigned i) -> DXIL::PSV::PSVComponentMask { return PSV.GetInputToOutputTable(0).GetMaskForInput(i); });
       }
       if (PSV.GetPCInputToOutputTable().IsValid()) {
         PropagateMask(m_PriorPCMask, pcSig, mask,
-                      [&](unsigned i) -> PSVComponentMask { return PSV.GetPCInputToOutputTable().GetMaskForInput(i); });
+                      [&](unsigned i) -> DXIL::PSV::PSVComponentMask { return PSV.GetPCInputToOutputTable().GetMaskForInput(i); });
       }
 
       // Copy mask to prior mask
@@ -406,11 +406,11 @@ public:
 
       break;
     }
-    case PSVShaderKind::Geometry: {
+    case DXIL::PSV::PSVShaderKind::Geometry: {
       // capture signatures
       ElementVec inSig, outSig[4];
       CopyElements( inSig, DXIL::SigPointKind::GSVIn, PSV.GetSigInputElements(), 0,
-                    [&](unsigned i) -> PSVSignatureElement {
+                    [&](unsigned i) -> DXIL::PSV::PSVSignatureElement {
                       return PSV.GetSignatureElement(PSV.GetInputElement0(i));
                     });
 
@@ -436,7 +436,7 @@ public:
         ComponentMask mask(PSV.GetViewIDOutputMask(streamIndex));
 
         CopyElements( outSig[streamIndex], DXIL::SigPointKind::GSOut, PSV.GetSigOutputElements(), streamIndex,
-                      [&](unsigned i) -> PSVSignatureElement {
+                      [&](unsigned i) -> DXIL::PSV::PSVSignatureElement {
                         return PSV.GetSignatureElement(PSV.GetOutputElement0(i));
                       });
 
@@ -444,7 +444,7 @@ public:
           // Propagate prior mask through input-output dependencies
           if (PSV.GetInputToOutputTable(streamIndex).IsValid()) {
             PropagateMask(m_PriorOutputMask, inSig, mask,
-              [&](unsigned i) -> PSVComponentMask { return PSV.GetInputToOutputTable(streamIndex).GetMaskForInput(i); });
+              [&](unsigned i) -> DXIL::PSV::PSVComponentMask { return PSV.GetInputToOutputTable(streamIndex).GetMaskForInput(i); });
           }
 
           // Create new version with ViewID elements from prior signature
@@ -473,11 +473,11 @@ public:
 
       return Result::Success;
     }
-    case PSVShaderKind::Pixel: {
+    case DXIL::PSV::PSVShaderKind::Pixel: {
       // capture signatures
       ElementVec inSig;
       CopyElements( inSig, DXIL::SigPointKind::PSIn, PSV.GetSigInputElements(), 0,
-                    [&](unsigned i) -> PSVSignatureElement {
+                    [&](unsigned i) -> DXIL::PSV::PSVSignatureElement {
                       return PSV.GetSignatureElement(PSV.GetInputElement0(i));
                     });
 
@@ -500,7 +500,7 @@ public:
       // PS has to be the last stage, so return.
       return Result::Success;
     }
-    case PSVShaderKind::Compute:
+    case DXIL::PSV::PSVShaderKind::Compute:
     default:
       return Result::InvalidUsage;
     }

+ 357 - 5
lib/HLSL/DxilContainerAssembler.cpp

@@ -18,6 +18,8 @@
 #include "dxc/HLSL/DxilModule.h"
 #include "dxc/HLSL/DxilShaderModel.h"
 #include "dxc/HLSL/DxilRootSignature.h"
+#include "dxc/HLSL/DxilUtil.h"
+#include "dxc/HLSL/DxilFunctionProps.h"
 #include "dxc/Support/Global.h"
 #include "dxc/Support/Unicode.h"
 #include "dxc/Support/WinIncludes.h"
@@ -29,6 +31,7 @@
 
 using namespace llvm;
 using namespace hlsl;
+using namespace hlsl::DXIL::PSV;
 
 static DxilProgramSigSemantic KindToSystemValue(Semantic::Kind kind, DXIL::TessellatorDomain domain) {
   switch (kind) {
@@ -436,6 +439,7 @@ public:
     UINT uSRVs = m_Module.GetSRVs().size();
     UINT uUAVs = m_Module.GetUAVs().size();
     m_PSVInitInfo.ResourceCount = uCBuffers + uSamplers + uSRVs + uUAVs;
+    // TODO: for >= 6.2 version, create more efficient structure
     if (m_PSVInitInfo.PSVVersion > 0) {
       m_PSVInitInfo.ShaderStage = (PSVShaderKind)SM->GetKind();
       // Copy Dxil Signatures
@@ -695,6 +699,347 @@ public:
   }
 };
 
+class RDATTable {
+public:
+  virtual uint32_t GetBlobSize() const { return 0; }
+  virtual void write(void *ptr) {}
+  virtual RuntimeDataTableType GetType() const { return RuntimeDataTableType::Invalid; }
+  virtual ~RDATTable() {}
+};
+
+class ResourceTable : public RDATTable {
+private:
+  uint32_t m_Version;
+  std::vector<std::pair<const DxilCBuffer*, uint32_t>> CBufferToOffset;
+  std::vector<std::pair<const DxilSampler*, uint32_t>> SamplerToOffset;
+  std::vector<std::pair<const DxilResource*, uint32_t>> SRVToOffset;
+  std::vector<std::pair<const DxilResource*, uint32_t>> UAVToOffset;
+
+  void UpdateResourceInfo(const DxilResourceBase *res, uint32_t offset,
+                          RuntimeDataResourceInfo *info, char **pCur) {
+    info->Kind = static_cast<uint32_t>(res->GetKind());
+    info->Space = res->GetSpaceID();
+    info->LowerBound = res->GetLowerBound();
+    info->UpperBound = res->GetUpperBound();
+    info->Name = offset;
+    memcpy(*pCur, info, sizeof(RuntimeDataResourceInfo));
+    *pCur += sizeof(RuntimeDataResourceInfo);
+  }
+
+public:
+  ResourceTable(uint32_t version) : m_Version(version) {}
+  void AddCBuffer(const DxilCBuffer *resource, uint32_t offset) {
+    CBufferToOffset.emplace_back(
+        std::pair<const DxilCBuffer *, uint32_t>(resource, offset));
+  }
+  void AddSampler(const DxilSampler *resource, uint32_t offset) {
+    SamplerToOffset.emplace_back(
+        std::pair<const DxilSampler *, uint32_t>(resource, offset));
+  }
+  void AddSRV(const DxilResource *resource, uint32_t offset) {
+    SRVToOffset.emplace_back(
+        std::pair<const DxilResource *, uint32_t>(resource, offset));
+  }
+  void AddUAV(const DxilResource *resource, uint32_t offset) {
+    UAVToOffset.emplace_back(
+        std::pair<const DxilResource *, uint32_t>(resource, offset));
+  }
+  uint32_t NumResources() const {
+    return CBufferToOffset.size() + SamplerToOffset.size() +
+           SRVToOffset.size() + UAVToOffset.size();
+  }
+  RuntimeDataTableType GetType() const { return RuntimeDataTableType::Resource; }
+  uint32_t GetBlobSize() const {
+    return NumResources() * sizeof(RuntimeDataResourceInfo) +
+           4 * sizeof(uint32_t);
+  }
+  void write(void *ptr) {
+    // Only impelemented for RDAT for now
+    if (m_Version == 0) {
+      char *pCur = (char*)ptr;
+      // count for each resource class
+      uint32_t cBufferCount = CBufferToOffset.size();
+      uint32_t samplerCount = SamplerToOffset.size();
+      uint32_t srvCount = SRVToOffset.size();
+      uint32_t uavCount = UAVToOffset.size();
+      memcpy(pCur, &cBufferCount, sizeof(uint32_t));
+      pCur += sizeof(uint32_t);
+      memcpy(pCur, &samplerCount, sizeof(uint32_t));
+      pCur += sizeof(uint32_t);
+      memcpy(pCur, &srvCount, sizeof(uint32_t));
+      pCur += sizeof(uint32_t);
+      memcpy(pCur, &uavCount, sizeof(uint32_t));
+      pCur += sizeof(uint32_t);
+
+      for (auto pair : CBufferToOffset) {
+        RuntimeDataResourceInfo info = {};
+        info.ResType = static_cast<uint32_t>(PSVResourceType::CBV);
+        UpdateResourceInfo(pair.first, pair.second, &info, &pCur);
+      }
+      for (auto pair : SamplerToOffset) {
+        RuntimeDataResourceInfo info = {};
+        info.ResType = static_cast<uint32_t>(PSVResourceType::Sampler);
+        UpdateResourceInfo(pair.first, pair.second, &info, &pCur);
+      }
+      for (auto pair : SRVToOffset) {
+        RuntimeDataResourceInfo info = {};
+        auto res = pair.first;
+        if (res->IsStructuredBuffer()) {
+          info.ResType = (UINT)PSVResourceType::SRVStructured;
+        } else if (res->IsRawBuffer()) {
+          info.ResType = (UINT)PSVResourceType::SRVRaw;
+        } else {
+          info.ResType = (UINT)PSVResourceType::SRVTyped;
+        }
+        UpdateResourceInfo(pair.first, pair.second, &info, &pCur);
+      }
+      for (auto pair : UAVToOffset) {
+        RuntimeDataResourceInfo info = {};
+        auto res = pair.first;
+        if (res->IsStructuredBuffer()) {
+          if (res->HasCounter())
+            info.ResType = (UINT)PSVResourceType::UAVStructuredWithCounter;
+          else
+            info.ResType = (UINT)PSVResourceType::UAVStructured;
+        } else if (res->IsRawBuffer()) {
+          info.ResType = (UINT)PSVResourceType::UAVRaw;
+        } else {
+          info.ResType = (UINT)PSVResourceType::UAVTyped;
+        }
+        UpdateResourceInfo(res, pair.second, &info, &pCur);
+      }
+    }
+  }
+};
+
+class FunctionTable : public RDATTable {
+private:
+  std::unordered_map<const llvm::Function*, RuntimeDataFunctionInfo> FuncToInfo;
+public:
+  FunctionTable(): FuncToInfo() {}
+  uint32_t NumFunctions() const { return FuncToInfo.size(); }
+  void AddFunction(const llvm::Function *func, uint32_t mangledOfffset,
+                   uint32_t unmangledOffset, uint32_t shaderKind, uint32_t resourceIndex,
+                   uint32_t payloadSizeInBytes, uint32_t attrSizeInBytes) {
+    RuntimeDataFunctionInfo info = {};
+    info.Name = mangledOfffset;
+    info.UnmangledName = unmangledOffset;
+    info.ShaderKind = shaderKind;
+    info.Resources = resourceIndex;
+    info.PayloadSizeInBytes = payloadSizeInBytes;
+    info.AttributeSizeInBytes = attrSizeInBytes;
+    FuncToInfo.insert({func, info});
+  }
+
+  uint32_t GetBlobSize() const { return NumFunctions() * sizeof(RuntimeDataFunctionInfo); }
+  RuntimeDataTableType GetType() const { return RuntimeDataTableType::Function; }
+  void write(void *ptr) {
+    char *cur = (char *)ptr;
+    for (auto &&pair : FuncToInfo) {
+      auto offset = pair.second;
+      memcpy(cur, &offset, sizeof(RuntimeDataFunctionInfo));
+      cur += sizeof(RuntimeDataFunctionInfo);
+    }
+  }
+};
+
+class StringTable : public RDATTable {
+private:
+  SmallVector<char, 256> m_StringBuffer;
+  uint32_t curIndex;
+public:
+  StringTable() : m_StringBuffer(), curIndex(0) {}
+  // returns the offset of the name inserted
+  uint32_t Insert(StringRef name) {
+    for (auto iter = name.begin(), End = name.end(); iter != End; ++iter) {
+        m_StringBuffer.push_back(*iter);
+    }
+    m_StringBuffer.push_back('\0');
+
+    uint32_t prevIndex = curIndex;
+    curIndex += name.size() + 1;
+    return prevIndex;
+  }
+  RuntimeDataTableType GetType() const { return RuntimeDataTableType::String; }
+  uint32_t GetBlobSize() const { return m_StringBuffer.size(); }
+  void write(void *ptr) { memcpy(ptr, m_StringBuffer.data(), m_StringBuffer.size()); }
+};
+
+template <class T>
+struct IndexTable : public RDATTable {
+private:
+  std::vector<std::vector<T>> m_IndicesList;
+  uint32_t m_curOffset;
+
+public:
+  IndexTable() : m_IndicesList(), m_curOffset(0) {}
+  uint32_t AddIndex(const std::vector<T> &Indices) {
+    uint32_t prevOffset = m_curOffset;
+    m_curOffset += Indices.size() + 1;
+    m_IndicesList.emplace_back(std::move(Indices));
+    return prevOffset;
+  }
+
+  RuntimeDataTableType GetType() const { return RuntimeDataTableType::Index; }
+  uint32_t GetBlobSize() const {
+    uint32_t size = 0;
+    for (auto Indices : m_IndicesList) {
+      size += Indices.size() + 1;
+    }
+    return sizeof(T) * size;
+  }
+
+  void write(void *ptr) {
+    T *cur = (T*)ptr;
+    for (auto Indices : m_IndicesList) {
+      uint32_t count = Indices.size();
+      memcpy(cur, &count, 4);
+      std::copy(Indices.data(), Indices.data() + Indices.size(), cur + 1);
+      cur += sizeof(T)/sizeof(4) + Indices.size();
+    }
+  }
+};
+
+class DxilRDATWriter : public DxilPartWriter {
+private:
+  const DxilModule &m_Module;
+  SmallVector<char, 1024> m_RDATBuffer;
+
+  std::vector<std::unique_ptr<RDATTable>> tables;
+  std::map<llvm::Function *, std::vector<uint32_t>> m_FuncToResNameOffset;
+
+  void UpdateFunctionToResourceInfo(const DxilResourceBase *resource, uint32_t offset) {
+    Constant *var = resource->GetGlobalSymbol();
+    if (var) {
+      for (auto user : var->users()) {
+        if (llvm::Instruction *I = dyn_cast<llvm::Instruction>(user)) {
+          if (llvm::Function *F = dyn_cast<llvm::Function>(I->getParent()->getParent())) {
+            if (m_FuncToResNameOffset.find(F) != m_FuncToResNameOffset.end()) {
+              m_FuncToResNameOffset[F].emplace_back(offset);
+            }
+            else {
+              m_FuncToResNameOffset[F] = std::vector<uint32_t>({offset});
+            }
+          }
+        }
+      }
+    }
+  }
+  void UpdateResourceInfo(StringTable &stringTable) {
+    // Try to allocate string table for resources. String table is a sequence
+    // of strings delimited by \0
+    tables.emplace_back(std::make_unique<ResourceTable>(0));
+    ResourceTable &resourceTable = *(ResourceTable*)tables.back().get();
+    uint32_t stringIndex;
+    uint32_t resourceIndex = 0;
+    for (auto &resource : m_Module.GetCBuffers()) {
+      stringIndex = stringTable.Insert(resource->GetGlobalName());
+      UpdateFunctionToResourceInfo(resource.get(), resourceIndex++);
+      resourceTable.AddCBuffer(resource.get(), stringIndex);
+    }
+    for (auto &resource : m_Module.GetSamplers()) {
+      stringIndex = stringTable.Insert(resource->GetGlobalName());
+      UpdateFunctionToResourceInfo(resource.get(), resourceIndex++);
+      resourceTable.AddSampler(resource.get(), stringIndex);
+    }
+    for (auto &resource : m_Module.GetSRVs()) {
+      stringIndex = stringTable.Insert(resource->GetGlobalName());
+      UpdateFunctionToResourceInfo(resource.get(), resourceIndex++);
+      resourceTable.AddSRV(resource.get(), stringIndex);
+    }
+    for (auto &resource : m_Module.GetUAVs()) {
+      stringIndex = stringTable.Insert(resource->GetGlobalName());
+      UpdateFunctionToResourceInfo(resource.get(), resourceIndex++);
+      resourceTable.AddUAV(resource.get(), stringIndex);
+    }
+  }
+  void UpdateFunctionInfo(StringTable &stringTable) {
+    // TODO: get a list of required features
+    // TODO: get a list of valid shader flags
+    // TODO: get a minimum shader version
+    tables.emplace_back(std::make_unique<FunctionTable>());
+    FunctionTable &functionTable = *(FunctionTable*)(tables.back().get());
+    tables.emplace_back(std::make_unique<IndexTable<uint32_t>>());
+    IndexTable<uint32_t> &indexTable = *(IndexTable<uint32_t>*)(tables.back().get());
+    for (auto &function : m_Module.GetModule()->getFunctionList()) {
+      if (!function.isDeclaration()) {
+        StringRef mangled = function.getName();
+        StringRef unmangled = hlsl::dxilutil::DemangleFunctionName(function.getName());
+        uint32_t mangledIndex = stringTable.Insert(mangled);
+        uint32_t unmangledIndex = stringTable.Insert(unmangled);
+        // Update resource Index
+        uint32_t resourceIndex = UINT_MAX;
+        uint32_t payloadSizeInBytes = 0;
+        uint32_t attrSizeInBytes = 0;
+        uint32_t shaderKind = (uint32_t)PSVShaderKind::Library;
+        if (m_FuncToResNameOffset.find(&function) != m_FuncToResNameOffset.end())
+          resourceIndex = indexTable.AddIndex(m_FuncToResNameOffset[&function]);
+        if (m_Module.HasDxilFunctionProps(&function)) {
+          auto props = m_Module.GetDxilFunctionProps(&function);
+          if (props.IsClosestHit() || props.IsAnyHit()) {
+            payloadSizeInBytes = props.ShaderProps.Ray.payloadSizeInBytes;
+            attrSizeInBytes = props.ShaderProps.Ray.attributeSizeInBytes;
+          }
+          else if (props.IsMiss()) {
+            payloadSizeInBytes = props.ShaderProps.Ray.payloadSizeInBytes;
+          }
+          else if (props.IsCallable()) {
+            payloadSizeInBytes = props.ShaderProps.Ray.paramSizeInBytes;
+          }
+          shaderKind = (uint32_t)props.shaderKind;
+        }
+        functionTable.AddFunction(&function, mangledIndex, unmangledIndex,
+                                    shaderKind, resourceIndex,
+                                    payloadSizeInBytes, attrSizeInBytes);
+      }
+    }
+  }
+
+public:
+  DxilRDATWriter(const DxilModule &module, uint32_t InfoVersion = 0)
+      : m_Module(module), m_RDATBuffer() {
+    // It's important to keep the order of this update
+    tables.emplace_back(std::make_unique<StringTable>());
+    StringTable &stringTable = *(StringTable*)tables.back().get();
+    UpdateResourceInfo(stringTable);
+    UpdateFunctionInfo(stringTable);
+  }
+
+  __override uint32_t size() const {
+    // one variable to count the number of blobs and two blobs
+    uint32_t total = 4 + tables.size() * sizeof(RuntimeDataTableHeader);
+    for (auto &&table : tables)
+      total += table->GetBlobSize();
+    return total;
+  }
+
+  __override void write(AbstractMemoryStream *pStream) {
+    m_RDATBuffer.resize(size());
+    char *pCur = m_RDATBuffer.data();
+    // write number of tables
+    uint32_t size = tables.size();
+    memcpy(pCur, &size, sizeof(uint32_t));
+    pCur += sizeof(uint32_t);
+    // write records
+    uint32_t curTableOffset = size * sizeof(RuntimeDataTableHeader) + 4;
+    for (auto &&table : tables) {
+      RuntimeDataTableHeader record = { table->GetType(), table->GetBlobSize(), curTableOffset };
+      memcpy(pCur, &record, sizeof(RuntimeDataTableHeader));
+      pCur += sizeof(RuntimeDataTableHeader);
+      curTableOffset += record.size;
+    }
+    // write tables
+    for (auto &&table : tables) {
+      table->write(pCur);
+      pCur += table->GetBlobSize();
+    }
+
+    ULONG cbWritten;
+    IFT(pStream->Write(m_RDATBuffer.data(), m_RDATBuffer.size(), &cbWritten));
+    DXASSERT_NOMSG(cbWritten == m_RDATBuffer.size());
+  }
+};
+
 DxilPartWriter *hlsl::NewPSVWriter(const DxilModule &M, uint32_t PSVVersion) {
   return new DxilPSVWriter(M, PSVVersion);
 }
@@ -821,7 +1166,6 @@ void hlsl::SerializeDxilContainerForModule(DxilModule *pModule,
       pModule->GetOutputSignature(), pModule->GetTessellatorDomain(),
       /*IsInput*/ false,
       /*UseMinPrecision*/ !pModule->m_ShaderFlags.GetUseNativeLowPrecision());
-  DxilPSVWriter PSVWriter(*pModule);
   DxilContainerWriter_impl writer;
 
   // Write the feature part.
@@ -850,10 +1194,18 @@ void hlsl::SerializeDxilContainerForModule(DxilModule *pModule,
   }
 
   // Write the DxilPipelineStateValidation (PSV0) part.
-  writer.AddPart(DFCC_PipelineStateValidation, PSVWriter.size(), [&](AbstractMemoryStream *pStream) {
-    PSVWriter.write(pStream);
-  });
-
+  DxilRDATWriter RDATWriter(*pModule);
+  DxilPSVWriter PSVWriter(*pModule);
+  if (pModule->GetShaderModel()->IsLib()) {
+    writer.AddPart(DFCC_RuntimeData, RDATWriter.size(), [&](AbstractMemoryStream *pStream) {
+        RDATWriter.write(pStream);
+    });
+  }
+  else {
+    writer.AddPart(DFCC_PipelineStateValidation, PSVWriter.size(), [&](AbstractMemoryStream *pStream) {
+        PSVWriter.write(pStream);
+    });
+  }
   // Write the root signature (RTS0) part.
   DxilProgramRootSignatureWriter rootSigWriter(pModule->GetRootSignature());
   CComPtr<AbstractMemoryStream> pInputProgramStream = pModuleBitcode;

+ 2 - 2
lib/HLSL/DxilMetadataHelper.cpp

@@ -1028,12 +1028,12 @@ Function *DxilMDHelper::LoadDxilFunctionProps(MDTuple *pProps,
 
 MDTuple *
 DxilMDHelper::EmitDxilFunctionProps(const hlsl::DxilFunctionProps *props,
-                                    Function *F) {
+                                   const Function *F) {
   bool bRayAttributes = false;
   Metadata *MDVals[30];
   std::fill(MDVals, MDVals + _countof(MDVals), nullptr);
   unsigned valIdx = 0;
-  MDVals[valIdx++] = ValueAsMetadata::get(F);
+  MDVals[valIdx++] = ValueAsMetadata::get(const_cast<Function*>(F));
   MDVals[valIdx++] = Uint32ToConstMD(static_cast<unsigned>(props->shaderKind));
   switch (props->shaderKind) {
   case DXIL::ShaderKind::Compute:

+ 17 - 14
lib/HLSL/DxilModule.cpp

@@ -1067,15 +1067,22 @@ void DxilModule::ReplaceDxilEntrySignature(llvm::Function *F,
   m_DxilEntrySignatureMap[NewF] = std::move(Sig);
 }
 
-bool DxilModule::HasDxilFunctionProps(llvm::Function *F) const {
+bool DxilModule::HasDxilFunctionProps(const llvm::Function *F) const {
   return m_DxilFunctionPropsMap.find(F) != m_DxilFunctionPropsMap.end();
 }
-DxilFunctionProps &DxilModule::GetDxilFunctionProps(llvm::Function *F) {
+DxilFunctionProps &DxilModule::GetDxilFunctionProps(const llvm::Function *F) {
+  return const_cast<DxilFunctionProps &>(
+      static_cast<const DxilModule *>(this)->GetDxilFunctionProps(F));
+}
+
+const DxilFunctionProps &
+DxilModule::GetDxilFunctionProps(const llvm::Function *F) const {
   DXASSERT(m_DxilFunctionPropsMap.count(F) != 0, "cannot find F in map");
-  return *m_DxilFunctionPropsMap[F];
+  return *(m_DxilFunctionPropsMap.find(F))->second.get();
 }
+
 void DxilModule::AddDxilFunctionProps(
-    llvm::Function *F, std::unique_ptr<DxilFunctionProps> &info) {
+    const llvm::Function *F, std::unique_ptr<DxilFunctionProps> &info) {
   DXASSERT(m_DxilFunctionPropsMap.count(F) == 0,
            "F already in map, info will be overwritten");
   DXASSERT_NOMSG(info->shaderKind != DXIL::ShaderKind::Invalid);
@@ -1100,16 +1107,16 @@ void DxilModule::SetPatchConstantFunctionForHS(llvm::Function *hullShaderFunc, l
   if (patchConstantFunc)
     m_PatchConstantFunctions.insert(patchConstantFunc);
 }
-bool DxilModule::IsGraphicsShader(llvm::Function *F) {
+bool DxilModule::IsGraphicsShader(const llvm::Function *F) const {
   return HasDxilFunctionProps(F) && GetDxilFunctionProps(F).IsGraphics();
 }
-bool DxilModule::IsPatchConstantShader(llvm::Function *F) {
+bool DxilModule::IsPatchConstantShader(const llvm::Function *F) const {
   return m_PatchConstantFunctions.count(F) != 0;
 }
-bool DxilModule::IsComputeShader(llvm::Function *F) {
+bool DxilModule::IsComputeShader(const llvm::Function *F) const {
   return HasDxilFunctionProps(F) && GetDxilFunctionProps(F).IsCS();
 }
-bool DxilModule::IsEntryThatUsesSignatures(llvm::Function *F) {
+bool DxilModule::IsEntryThatUsesSignatures(const llvm::Function *F) const {
   auto propIter = m_DxilFunctionPropsMap.find(F);
   if (propIter != m_DxilFunctionPropsMap.end()) {
     DxilFunctionProps &props = *(propIter->second);
@@ -1155,15 +1162,11 @@ void DxilModule::ResetTypeSystem(DxilTypeSystem *pValue) {
 
 void DxilModule::ResetOP(hlsl::OP *hlslOP) { m_pOP.reset(hlslOP); }
 
-void DxilModule::ResetFunctionPropsMap(
-    std::unordered_map<llvm::Function *, std::unique_ptr<DxilFunctionProps>>
-        &&propsMap) {
+void DxilModule::ResetFunctionPropsMap(DxilFunctionPropsMap &&propsMap) {
   m_DxilFunctionPropsMap = std::move(propsMap);
 }
 
-void DxilModule::ResetEntrySignatureMap(
-    std::unordered_map<llvm::Function *, std::unique_ptr<DxilEntrySignature>>
-        &&SigMap) {
+void DxilModule::ResetEntrySignatureMap(DxilEntrySignatureMap &&SigMap) {
   m_DxilEntrySignatureMap = std::move(SigMap);
 }
 

+ 2 - 1
lib/HLSL/DxilRootSignature.cpp

@@ -769,7 +769,7 @@ void RootSignatureVerifier::VerifyShader(DxilShaderVisibility VisType,
                                          const void *pPSVData,
                                          uint32_t PSVSize,
                                          DiagnosticPrinter &DiagPrinter) {
-  DxilPipelineStateValidation PSV;
+  DXIL::PSV::DxilPipelineStateValidation PSV;
   IFTBOOL(PSV.InitFromPSV0(pPSVData, PSVSize), E_INVALIDARG);
 
   bool bShaderDeniedByRootSig = false;
@@ -806,6 +806,7 @@ void RootSignatureVerifier::VerifyShader(DxilShaderVisibility VisType,
   bool bShaderHasRootBindings = false;
 
   for (unsigned iResource = 0; iResource < PSV.GetBindCount(); iResource++) {
+    using namespace DXIL::PSV;
     const PSVResourceBindInfo0 *pBindInfo0 = PSV.GetPSVResourceBindInfo0(iResource);
     DXASSERT_NOMSG(pBindInfo0);
 

+ 1 - 2
lib/HLSL/HLModule.cpp

@@ -309,8 +309,7 @@ RootSignatureHandle *HLModule::ReleaseRootSignature() {
   return m_RootSignature.release();
 }
 
-std::unordered_map<llvm::Function *, std::unique_ptr<DxilFunctionProps>> &&
-HLModule::ReleaseFunctionPropsMap() {
+DxilFunctionPropsMap &&HLModule::ReleaseFunctionPropsMap() {
   return std::move(m_DxilFunctionPropsMap);
 }
 

+ 1 - 1
tools/clang/tools/dxcompiler/dxcdisassembler.cpp

@@ -1032,7 +1032,7 @@ void PrintPipelineStateValidationRuntimeInfo(const char *pBuffer,
      << comment << "\n";
 
   const unsigned offset = sizeof(unsigned);
-  const PSVRuntimeInfo0 *pInfo = (PSVRuntimeInfo0 *)(pBuffer + offset);
+  const hlsl::DXIL::PSV::PSVRuntimeInfo0 *pInfo = (hlsl::DXIL::PSV::PSVRuntimeInfo0 *)(pBuffer + offset);
 
   switch (shaderKind) {
   case DXIL::ShaderKind::Hull: {