Procházet zdrojové kódy

Merged PR 38: Cleanup RDAT Table, fixing resource dependency, and adding function dependency

Young Kim před 7 roky
rodič
revize
799331737d

+ 70 - 30
include/dxc/HLSL/DxilPipelineStateValidation.h

@@ -158,15 +158,23 @@ struct PSVResourceBindInfo0
   uint32_t UpperBound;
 };
 
-struct RuntimeDataResourceInfo : public PSVResourceBindInfo0
+struct RuntimeDataResourceInfo
 {
-  uint32_t Kind; // PSVResourceKind
-  uint32_t Name; // offset for string table
+  uint32_t ResType;      // PSVResourceType
+  uint32_t Space;
+  uint32_t LowerBound;
+  uint32_t UpperBound;
+  uint32_t Kind;         // PSVResourceKind
+  uint32_t Name;         // offset for string table
+  uint32_t ID;           // id per resource class
+  uint32_t flags;        // flag for resource.
 };
+
 struct RuntimeDataFunctionInfo {
   uint32_t Name;                 // offset for string table
   uint32_t UnmangledName;        // offset for string table
   uint32_t Resources;            // index to an index table
+  uint32_t FunctionDependencies; // index to a list of functions that function depends on
   uint32_t ShaderKind;           // shader kind
   uint32_t PayloadSizeInBytes;   // payload count for miss, closest hit, any hit
                                  // shader, or parameter size for call shader
@@ -219,7 +227,7 @@ struct RuntimeDataTableHeader {
   uint32_t offset;
 };
 
-enum RuntimeDataTableType : uint32_t {
+enum RuntimeDataPartType : uint32_t {
   Invalid = 0,
   String,
   Function,
@@ -436,6 +444,9 @@ public:
     return m_ResourceInfo->UpperBound;
   }
   PSVResourceKind GetResourceKind() { return (PSVResourceKind)m_ResourceInfo->Kind; }
+  uint32_t GetID() {
+    return m_ResourceInfo->ID;
+  }
   const char *GetName() {
     return m_Context->pStringTableReader->Get(m_ResourceInfo->Name);
   }
@@ -462,12 +473,34 @@ public:
         m_CBufferCount(CBufferCount), m_SamplerCount(SamplerCount),
         m_SRVCount(SRVCount), m_UAVCount(UAVCount){};
 
-  void SetResourceInfo(const RuntimeDataResourceInfo *ptr) { m_ResourceInfo = ptr; }
+  void SetResourceInfo(const RuntimeDataResourceInfo *ptr, uint32_t count) { 
+    m_ResourceInfo = ptr;
+    // Assuming that resources are in order of CBuffer, Sampler, SRV, and UAV,
+    // count the number for each resource class
+    m_CBufferCount = 0;
+    m_SamplerCount = 0;
+    m_SRVCount = 0;
+    m_UAVCount = 0;
+
+    for (uint32_t i = 0; i < count; ++i) {
+      const RuntimeDataResourceInfo *curPtr = &ptr[i];
+      if (curPtr->ResType == (uint32_t)PSVResourceType::CBV)
+        m_CBufferCount++;
+      else if (curPtr->ResType == (uint32_t)PSVResourceType::Sampler)
+        m_SamplerCount++;
+      else if (curPtr->ResType == (uint32_t)PSVResourceType::SRVRaw ||
+          curPtr->ResType == (uint32_t)PSVResourceType::SRVStructured ||
+          curPtr->ResType == (uint32_t)PSVResourceType::SRVTyped)
+        m_SRVCount++;
+      else if (curPtr->ResType == (uint32_t)PSVResourceType::UAVRaw ||
+          curPtr->ResType == (uint32_t)PSVResourceType::UAVStructured ||
+          curPtr->ResType == (uint32_t)PSVResourceType::UAVStructuredWithCounter ||
+          curPtr->ResType == (uint32_t)PSVResourceType::UAVTyped)
+        m_UAVCount++;
+    }
+  }
+
   void SetContext(RuntimeDataContext *context) { m_Context = context; }
-  void SetCBufferCount(uint32_t count) { m_CBufferCount = count; }
-  void SetSamplerCount(uint32_t count) { m_SamplerCount = count; }
-  void SetSRVCount(uint32_t count) { m_SRVCount = count; }
-  void SetUAVCount(uint32_t count) { m_UAVCount = count; }
 
   uint32_t GetNumResources() {
     return m_CBufferCount + m_SamplerCount + m_SRVCount + m_UAVCount;
@@ -477,7 +510,6 @@ public:
     return ResourceReader(&m_ResourceInfo[i], m_Context);
   }
 
-
   uint32_t GetNumCBuffers() { return m_CBufferCount; }
   ResourceReader GetCBuffer(uint32_t i) {
     _Analysis_assume_(i < m_CBufferCount);
@@ -528,7 +560,7 @@ public:
   }
   uint32_t GetShaderStageFlag() { return m_RuntimeDataFunctionInfo->ShaderStageFlag; }
   uint32_t GetMinShaderTarget() { return m_RuntimeDataFunctionInfo->MinShaderTarget; }
-  uint32_t FunctionReader::GetNumResources() {
+  uint32_t GetNumResources() {
     if (m_RuntimeDataFunctionInfo->Resources == UINT_MAX)
       return 0;
     return m_Context->pIndexTableReader->getRow(m_RuntimeDataFunctionInfo->Resources).Count();
@@ -538,6 +570,20 @@ public:
     return m_Context->pResourceTableReader->GetItem(resIndex);
   }
 
+  uint32_t GetNumDependencies() {
+    if (m_RuntimeDataFunctionInfo->FunctionDependencies == UINT_MAX)
+      return 0;
+    return m_Context->pIndexTableReader
+        ->getRow(m_RuntimeDataFunctionInfo->FunctionDependencies).Count();
+  }
+
+  const char *GetDependency(uint32_t i) {
+    uint32_t resIndex =
+        m_Context->pIndexTableReader
+            ->getRow(m_RuntimeDataFunctionInfo->FunctionDependencies).At(i);
+    return m_Context->pStringTableReader->Get(resIndex);
+  }
+
   uint32_t GetPayloadSizeInBytes() { return m_RuntimeDataFunctionInfo->PayloadSizeInBytes; }
   uint32_t GetAttributeSizeInBytes() { return m_RuntimeDataFunctionInfo->AttributeSizeInBytes; }
   // payload (hit shaders) and parameters (call shaders) are mutually exclusive
@@ -580,6 +626,8 @@ public:
         m_FunctionTableReader(), m_IndexTableReader(), m_Context() {
     m_Context = {&m_StringReader, &m_IndexTableReader, &m_ResourceTableReader,
                  &m_FunctionTableReader};
+    m_ResourceTableReader.SetContext(&m_Context);
+    m_FunctionTableReader.SetContext(&m_Context);
   }
   DxilRuntimeData(const char *ptr) {
     InitFromRDAT(ptr);
@@ -592,32 +640,24 @@ public:
       for (uint32_t i = 0; i < TableCount; ++i) {
         RuntimeDataTableHeader *curRecord = &records[i];
         switch (curRecord->tableType) {
-          case RuntimeDataTableType::Resource: {
-            uint32_t cBufferCount = *(uint32_t*)(ptr + curRecord->offset);
-            uint32_t samplerCount = *(uint32_t*)(ptr + curRecord->offset + 4);
-            uint32_t srvCount = *(uint32_t*)(ptr + curRecord->offset + 8);
-            uint32_t uavCount = *(uint32_t*)(ptr + curRecord->offset + 12);
-            m_ResourceTableReader.SetResourceInfo((RuntimeDataResourceInfo*)(ptr + curRecord->offset + 16));
-            m_ResourceTableReader.SetCBufferCount(cBufferCount);
-            m_ResourceTableReader.SetSamplerCount(samplerCount);
-            m_ResourceTableReader.SetSRVCount(srvCount);
-            m_ResourceTableReader.SetUAVCount(uavCount);
-            m_ResourceTableReader.SetContext(&m_Context);
+          case RuntimeDataPartType::Resource: {
+            m_ResourceTableReader.SetResourceInfo(
+                (RuntimeDataResourceInfo *)(ptr + curRecord->offset),
+                curRecord->size / sizeof(RuntimeDataResourceInfo));
             break;
           }
-          case RuntimeDataTableType::String: {
+          case RuntimeDataPartType::String: {
             m_StringReader = PSVStringTable(ptr + curRecord->offset, curRecord->size);
             break;
           }
-          case RuntimeDataTableType::Function: {
-            RuntimeDataFunctionInfo *funcInfo =
-                (RuntimeDataFunctionInfo *)(ptr + curRecord->offset);
-            m_FunctionTableReader.SetFunctionInfo(funcInfo);
-            m_FunctionTableReader.SetCount(curRecord->size / sizeof(RuntimeDataFunctionInfo));
-            m_FunctionTableReader.SetContext(&m_Context);
+          case RuntimeDataPartType::Function: {
+            m_FunctionTableReader.SetFunctionInfo(
+                (RuntimeDataFunctionInfo *)(ptr + curRecord->offset));
+            m_FunctionTableReader.SetCount(curRecord->size /
+                                           sizeof(RuntimeDataFunctionInfo));
             break;
           }
-          case RuntimeDataTableType::Index: {
+          case RuntimeDataPartType::Index: {
             m_IndexTableReader = IndexTableReader(
                 (uint32_t *)(ptr + curRecord->offset), curRecord->size / 4);
             break;

+ 3 - 0
include/dxc/HLSL/DxilUtil.h

@@ -20,6 +20,9 @@ class Module;
 class MemoryBuffer;
 class LLVMContext;
 class DiagnosticInfo;
+class Value;
+class Instruction;
+class StringRef;
 }
 
 namespace hlsl {

+ 165 - 185
lib/HLSL/DxilContainerAssembler.cpp

@@ -20,6 +20,7 @@
 #include "dxc/HLSL/DxilRootSignature.h"
 #include "dxc/HLSL/DxilUtil.h"
 #include "dxc/HLSL/DxilFunctionProps.h"
+#include "dxc/HLSL/DxilOperations.h"
 #include "dxc/Support/Global.h"
 #include "dxc/Support/Unicode.h"
 #include "dxc/Support/WinIncludes.h"
@@ -699,154 +700,54 @@ public:
   }
 };
 
-class RDATTable {
+// Like DXIL container, RDAT itself is a mini container that contains multiple RDAT parts
+class RDATPart {
 public:
-  virtual uint32_t GetBlobSize() const { return 0; }
-  virtual void write(void *ptr) {}
-  virtual RuntimeDataTableType GetType() const { return RuntimeDataTableType::Invalid; }
-  virtual ~RDATTable() {}
+  virtual uint32_t GetPartSize() const { return 0; }
+  virtual void Write(void *ptr) {}
+  virtual RuntimeDataPartType GetType() const { return RuntimeDataPartType::Invalid; }
+  virtual ~RDATPart() {}
 };
 
-class ResourceTable : public RDATTable {
-private:
-  uint32_t m_Version;
-  std::vector<std::pair<const DxilCBuffer*, uint32_t>> CBufferToOffset;
-  std::vector<std::pair<const DxilSampler*, uint32_t>> SamplerToOffset;
-  std::vector<std::pair<const DxilResource*, uint32_t>> SRVToOffset;
-  std::vector<std::pair<const DxilResource*, uint32_t>> UAVToOffset;
-
-  void UpdateResourceInfo(const DxilResourceBase *res, uint32_t offset,
-                          RuntimeDataResourceInfo *info, char **pCur) {
-    info->Kind = static_cast<uint32_t>(res->GetKind());
-    info->Space = res->GetSpaceID();
-    info->LowerBound = res->GetLowerBound();
-    info->UpperBound = res->GetUpperBound();
-    info->Name = offset;
-    memcpy(*pCur, info, sizeof(RuntimeDataResourceInfo));
-    *pCur += sizeof(RuntimeDataResourceInfo);
-  }
-
+// Most RDAT parts are tables each containing a list of structures of same type.
+// Exceptions are string table and index table because each string or list of
+// indicies can be of different sizes.
+template <class T>
+class RDATTable : public RDATPart {
+protected:
+  std::vector<T> m_rows;
 public:
-  ResourceTable(uint32_t version) : m_Version(version), CBufferToOffset(), SamplerToOffset(), SRVToOffset(), UAVToOffset() {}
-  void AddCBuffer(const DxilCBuffer *resource, uint32_t offset) {
-    CBufferToOffset.emplace_back(
-        std::pair<const DxilCBuffer *, uint32_t>(resource, offset));
-  }
-  void AddSampler(const DxilSampler *resource, uint32_t offset) {
-    SamplerToOffset.emplace_back(
-        std::pair<const DxilSampler *, uint32_t>(resource, offset));
-  }
-  void AddSRV(const DxilResource *resource, uint32_t offset) {
-    SRVToOffset.emplace_back(
-        std::pair<const DxilResource *, uint32_t>(resource, offset));
-  }
-  void AddUAV(const DxilResource *resource, uint32_t offset) {
-    UAVToOffset.emplace_back(
-        std::pair<const DxilResource *, uint32_t>(resource, offset));
-  }
-  uint32_t NumResources() const {
-    return CBufferToOffset.size() + SamplerToOffset.size() +
-           SRVToOffset.size() + UAVToOffset.size();
-  }
-  RuntimeDataTableType GetType() const { return RuntimeDataTableType::Resource; }
-  uint32_t GetBlobSize() const {
-    return NumResources() * sizeof(RuntimeDataResourceInfo) +
-           4 * sizeof(uint32_t);
+  virtual void Insert(T *data) {}
+  virtual ~RDATTable() {}
+
+  void Insert(const T &data) {
+    m_rows.push_back(data);
   }
-  void write(void *ptr) {
-    // Only impelemented for RDAT for now
-    if (m_Version == 0) {
-      char *pCur = (char*)ptr;
-      // count for each resource class
-      uint32_t cBufferCount = CBufferToOffset.size();
-      uint32_t samplerCount = SamplerToOffset.size();
-      uint32_t srvCount = SRVToOffset.size();
-      uint32_t uavCount = UAVToOffset.size();
-      memcpy(pCur, &cBufferCount, sizeof(uint32_t));
-      pCur += sizeof(uint32_t);
-      memcpy(pCur, &samplerCount, sizeof(uint32_t));
-      pCur += sizeof(uint32_t);
-      memcpy(pCur, &srvCount, sizeof(uint32_t));
-      pCur += sizeof(uint32_t);
-      memcpy(pCur, &uavCount, sizeof(uint32_t));
-      pCur += sizeof(uint32_t);
-
-      for (auto pair : CBufferToOffset) {
-        RuntimeDataResourceInfo info = {};
-        info.ResType = static_cast<uint32_t>(PSVResourceType::CBV);
-        UpdateResourceInfo(pair.first, pair.second, &info, &pCur);
-      }
-      for (auto pair : SamplerToOffset) {
-        RuntimeDataResourceInfo info = {};
-        info.ResType = static_cast<uint32_t>(PSVResourceType::Sampler);
-        UpdateResourceInfo(pair.first, pair.second, &info, &pCur);
-      }
-      for (auto pair : SRVToOffset) {
-        RuntimeDataResourceInfo info = {};
-        auto res = pair.first;
-        if (res->IsStructuredBuffer()) {
-          info.ResType = (UINT)PSVResourceType::SRVStructured;
-        } else if (res->IsRawBuffer()) {
-          info.ResType = (UINT)PSVResourceType::SRVRaw;
-        } else {
-          info.ResType = (UINT)PSVResourceType::SRVTyped;
-        }
-        UpdateResourceInfo(pair.first, pair.second, &info, &pCur);
-      }
-      for (auto pair : UAVToOffset) {
-        RuntimeDataResourceInfo info = {};
-        auto res = pair.first;
-        if (res->IsStructuredBuffer()) {
-          if (res->HasCounter())
-            info.ResType = (UINT)PSVResourceType::UAVStructuredWithCounter;
-          else
-            info.ResType = (UINT)PSVResourceType::UAVStructured;
-        } else if (res->IsRawBuffer()) {
-          info.ResType = (UINT)PSVResourceType::UAVRaw;
-        } else {
-          info.ResType = (UINT)PSVResourceType::UAVTyped;
-        }
-        UpdateResourceInfo(res, pair.second, &info, &pCur);
-      }
+
+  void Write(void *ptr) {
+    char *pCur = (char*)ptr;
+    for (auto row : m_rows) {
+      memcpy(pCur, &row, sizeof(T));
+      pCur += sizeof(T);
     }
-  }
+  };
+
+  uint32_t GetPartSize() const { return m_rows.size() * sizeof(T); }
 };
 
-class FunctionTable : public RDATTable {
-private:
-  std::vector<std::pair<const llvm::Function *, RuntimeDataFunctionInfo>> FuncToInfo;
+// Resource table will contain a list of RuntimeDataResourceInfo in order of
+// CBuffer, Sampler, SRV, and UAV resource classes.
+class ResourceTable : public RDATTable<RuntimeDataResourceInfo> {
 public:
-  FunctionTable(): FuncToInfo() {}
-  uint32_t NumFunctions() const { return FuncToInfo.size(); }
-  void AddFunction(const llvm::Function *func, uint32_t mangledOfffset,
-                   uint32_t unmangledOffset, uint32_t shaderKind, uint32_t resourceIndex,
-                   uint32_t payloadSizeInBytes, uint32_t attrSizeInBytes, ShaderFlags flags) {
-    RuntimeDataFunctionInfo info = {};
-    info.Name = mangledOfffset;
-    info.UnmangledName = unmangledOffset;
-    info.ShaderKind = shaderKind;
-    info.Resources = resourceIndex;
-    info.PayloadSizeInBytes = payloadSizeInBytes;
-    info.AttributeSizeInBytes = attrSizeInBytes;
-    uint64_t rawFlags = flags.GetShaderFlagsRaw();
-    info.FeatureInfo1 = rawFlags & 0xffffffff;
-    info.FeatureInfo2 = (rawFlags >> 32) & 0xffffffff;
-    FuncToInfo.push_back({ func, info });
-  }
+  RuntimeDataPartType GetType() const { return RuntimeDataPartType::Resource; }
+};
 
-  uint32_t GetBlobSize() const { return NumFunctions() * sizeof(RuntimeDataFunctionInfo); }
-  RuntimeDataTableType GetType() const { return RuntimeDataTableType::Function; }
-  void write(void *ptr) {
-    char *cur = (char *)ptr;
-    for (auto &&pair : FuncToInfo) {
-      auto offset = pair.second;
-      memcpy(cur, &offset, sizeof(RuntimeDataFunctionInfo));
-      cur += sizeof(RuntimeDataFunctionInfo);
-    }
-  }
+class FunctionTable : public RDATTable<RuntimeDataFunctionInfo> {
+public:
+  RuntimeDataPartType GetType() const { return RuntimeDataPartType::Function; }
 };
 
-class StringTable : public RDATTable {
+class StringTable : public RDATPart {
 private:
   SmallVector<char, 256> m_StringBuffer;
   uint32_t curIndex;
@@ -863,42 +764,41 @@ public:
     curIndex += name.size() + 1;
     return prevIndex;
   }
-  RuntimeDataTableType GetType() const { return RuntimeDataTableType::String; }
-  uint32_t GetBlobSize() const { return m_StringBuffer.size(); }
-  void write(void *ptr) { memcpy(ptr, m_StringBuffer.data(), m_StringBuffer.size()); }
+  RuntimeDataPartType GetType() const { return RuntimeDataPartType::String; }
+  uint32_t GetPartSize() const { return m_StringBuffer.size(); }
+  void Write(void *ptr) { memcpy(ptr, m_StringBuffer.data(), m_StringBuffer.size()); }
 };
 
-template <class T>
-struct IndexTable : public RDATTable {
+struct IndexTable : public RDATPart {
 private:
-  std::vector<std::vector<T>> m_IndicesList;
+  std::vector<std::vector<uint32_t>> m_IndicesList;
   uint32_t m_curOffset;
 
 public:
   IndexTable() : m_IndicesList(), m_curOffset(0) {}
-  uint32_t AddIndex(const std::vector<T> &Indices) {
+  uint32_t AddIndex(const std::vector<uint32_t> &Indices) {
     uint32_t prevOffset = m_curOffset;
     m_curOffset += Indices.size() + 1;
     m_IndicesList.emplace_back(std::move(Indices));
     return prevOffset;
   }
 
-  RuntimeDataTableType GetType() const { return RuntimeDataTableType::Index; }
-  uint32_t GetBlobSize() const {
+  RuntimeDataPartType GetType() const { return RuntimeDataPartType::Index; }
+  uint32_t GetPartSize() const {
     uint32_t size = 0;
     for (auto Indices : m_IndicesList) {
       size += Indices.size() + 1;
     }
-    return sizeof(T) * size;
+    return sizeof(uint32_t) * size;
   }
 
-  void write(void *ptr) {
-    T *cur = (T*)ptr;
+  void Write(void *ptr) {
+    uint32_t *cur = (uint32_t*)ptr;
     for (auto Indices : m_IndicesList) {
       uint32_t count = Indices.size();
       memcpy(cur, &count, 4);
       std::copy(Indices.data(), Indices.data() + Indices.size(), cur + 1);
-      cur += sizeof(T)/sizeof(4) + Indices.size();
+      cur += sizeof(uint32_t)/sizeof(4) + Indices.size();
     }
   }
 };
@@ -908,62 +808,130 @@ private:
   const DxilModule &m_Module;
   SmallVector<char, 1024> m_RDATBuffer;
 
-  std::vector<std::unique_ptr<RDATTable>> m_tables;
-  std::map<llvm::Function *, std::vector<uint32_t>> m_FuncToResNameOffset;
+  std::vector<std::unique_ptr<RDATPart>> m_tables;
+  typedef std::unordered_map<llvm::Function *, std::vector<uint32_t>> FunctionIndexMap;
+  FunctionIndexMap m_FuncToResNameOffset; // list of resources used
+  FunctionIndexMap m_FuncToDependencies;  // list of unresolved functions used
 
-  void UpdateFunctionToResourceInfo(const DxilResourceBase *resource, uint32_t offset) {
+  llvm::Function *FindUsingFunction(llvm::Value *User) {
+    if (llvm::Instruction *I = dyn_cast<llvm::Instruction>(User)) {
+      // Instruction should be inside a basic block, which is in a function
+      return cast<llvm::Function>(I->getParent()->getParent());
+    }
+    // User can be either instruction, constant, or operator. But User is an
+    // operator only if constant is a scalar value, not resource pointer.
+    llvm::Constant *CU = cast<llvm::Constant>(User);
+    return FindUsingFunction(*CU->user_begin());
+  }
+
+  void UpdateFunctionToResourceInfo(const DxilResourceBase *resource,
+                                    uint32_t offset) {
     Constant *var = resource->GetGlobalSymbol();
     if (var) {
       for (auto user : var->users()) {
-        if (llvm::Instruction *I = dyn_cast<llvm::Instruction>(user)) {
-          if (llvm::Function *F = dyn_cast<llvm::Function>(I->getParent()->getParent())) {
-            if (m_FuncToResNameOffset.find(F) != m_FuncToResNameOffset.end()) {
-              m_FuncToResNameOffset[F].emplace_back(offset);
-            }
-            else {
-              m_FuncToResNameOffset[F] = std::vector<uint32_t>({offset});
-            }
-          }
+        // Find the function.
+        llvm::Function *F = FindUsingFunction(user);
+        if (m_FuncToResNameOffset.find(F) != m_FuncToResNameOffset.end()) {
+          m_FuncToResNameOffset[F].emplace_back(offset);
+        }
+        else {
+          m_FuncToResNameOffset[F] = std::vector<uint32_t>({offset});
         }
       }
     }
   }
+
+  void InsertToResourceTable(DxilResourceBase &resource,
+                             PSVResourceType resType,
+                             ResourceTable &resourceTable,
+                             StringTable &stringTable,
+                             uint32_t &resourceIndex) {
+    uint32_t stringIndex = stringTable.Insert(resource.GetGlobalName());
+    UpdateFunctionToResourceInfo(&resource, resourceIndex++);
+    RuntimeDataResourceInfo info = {};
+    info.Kind = static_cast<uint32_t>(resource.GetKind());
+    info.ResType = (uint32_t)resType,
+    info.Space = resource.GetSpaceID();
+    info.LowerBound = resource.GetLowerBound();
+    info.UpperBound = resource.GetUpperBound();
+    info.Name = stringIndex;
+    info.ID = resource.GetID();
+    resourceTable.Insert(info);
+  }
+
   void UpdateResourceInfo(StringTable &stringTable) {
     // Try to allocate string table for resources. String table is a sequence
     // of strings delimited by \0
-    m_tables.emplace_back(std::make_unique<ResourceTable>(0));
+    m_tables.emplace_back(std::make_unique<ResourceTable>());
     ResourceTable &resourceTable = *(ResourceTable*)m_tables.back().get();
-    uint32_t stringIndex;
     uint32_t resourceIndex = 0;
     for (auto &resource : m_Module.GetCBuffers()) {
-      stringIndex = stringTable.Insert(resource->GetGlobalName());
-      UpdateFunctionToResourceInfo(resource.get(), resourceIndex++);
-      resourceTable.AddCBuffer(resource.get(), stringIndex);
+      InsertToResourceTable(*resource.get(), PSVResourceType::CBV, resourceTable, stringTable,
+                            resourceIndex);
+
     }
     for (auto &resource : m_Module.GetSamplers()) {
-      stringIndex = stringTable.Insert(resource->GetGlobalName());
-      UpdateFunctionToResourceInfo(resource.get(), resourceIndex++);
-      resourceTable.AddSampler(resource.get(), stringIndex);
+      InsertToResourceTable(*resource.get(), PSVResourceType::Sampler, resourceTable, stringTable,
+                            resourceIndex);
     }
     for (auto &resource : m_Module.GetSRVs()) {
-      stringIndex = stringTable.Insert(resource->GetGlobalName());
-      UpdateFunctionToResourceInfo(resource.get(), resourceIndex++);
-      resourceTable.AddSRV(resource.get(), stringIndex);
+      PSVResourceType resType = PSVResourceType::Invalid;
+      if (resource->IsStructuredBuffer()) {
+        resType = PSVResourceType::SRVStructured;
+      } else if (resource->IsRawBuffer()) {
+        resType = PSVResourceType::SRVRaw;
+      } else {
+        resType = PSVResourceType::SRVTyped;
+      }
+      InsertToResourceTable(*resource.get(), resType, resourceTable, stringTable,
+                            resourceIndex);
     }
     for (auto &resource : m_Module.GetUAVs()) {
-      stringIndex = stringTable.Insert(resource->GetGlobalName());
-      UpdateFunctionToResourceInfo(resource.get(), resourceIndex++);
-      resourceTable.AddUAV(resource.get(), stringIndex);
+      PSVResourceType resType = PSVResourceType::Invalid;
+      if (resource->IsStructuredBuffer()) {
+        if (resource->HasCounter())
+          resType = PSVResourceType::UAVStructuredWithCounter;
+        else
+          resType = PSVResourceType::UAVStructured;
+      } else if (resource->IsRawBuffer()) {
+        resType = PSVResourceType::UAVRaw;
+      } else {
+        resType = PSVResourceType::UAVTyped;
+      }
+      InsertToResourceTable(*resource.get(), resType, resourceTable, stringTable,
+                            resourceIndex);
     }
   }
+
+  void UpdateFunctionDependency(llvm::Function *F, StringTable &stringTable) {
+    for (const auto &user : F->users()) {
+      llvm::Function *userFunction = FindUsingFunction(user);
+      uint32_t index = stringTable.Insert(F->getName());
+      if (m_FuncToDependencies.find(userFunction) ==
+          m_FuncToDependencies.end()) {
+        m_FuncToDependencies[userFunction] =
+            std::vector<uint32_t>({index});
+      } else {
+        m_FuncToDependencies[userFunction].push_back(index);
+      }
+    }
+  }
+
   void UpdateFunctionInfo(StringTable &stringTable) {
-    // TODO: get a list of required features
     // TODO: get a list of valid shader flags
     // TODO: get a minimum shader version
+    std::unordered_map<llvm::Function *, std::vector<StringRef>>
+        FuncToUnresolvedDependencies;
     m_tables.emplace_back(std::make_unique<FunctionTable>());
     FunctionTable &functionTable = *(FunctionTable*)(m_tables.back().get());
-    m_tables.emplace_back(std::make_unique<IndexTable<uint32_t>>());
-    IndexTable<uint32_t> &indexTable = *(IndexTable<uint32_t>*)(m_tables.back().get());
+    m_tables.emplace_back(std::make_unique<IndexTable>());
+    IndexTable &indexTable = *(IndexTable*)(m_tables.back().get());
+    for (auto &function : m_Module.GetModule()->getFunctionList()) {
+      // If function is a declaration, it is an unresolved dependency in the library
+      if (function.isDeclaration() && !OP::IsDxilOpFunc(&function)) {
+        UpdateFunctionDependency(&function, stringTable);
+      }
+    }
     for (auto &function : m_Module.GetModule()->getFunctionList()) {
       if (!function.isDeclaration()) {
         StringRef mangled = function.getName();
@@ -972,11 +940,14 @@ private:
         uint32_t unmangledIndex = stringTable.Insert(unmangled);
         // Update resource Index
         uint32_t resourceIndex = UINT_MAX;
+        uint32_t functionDependencies = UINT_MAX;
         uint32_t payloadSizeInBytes = 0;
         uint32_t attrSizeInBytes = 0;
         uint32_t shaderKind = (uint32_t)PSVShaderKind::Library;
         if (m_FuncToResNameOffset.find(&function) != m_FuncToResNameOffset.end())
           resourceIndex = indexTable.AddIndex(m_FuncToResNameOffset[&function]);
+        if (m_FuncToDependencies.find(&function) != m_FuncToDependencies.end())
+          functionDependencies = indexTable.AddIndex(m_FuncToDependencies[&function]);
         if (m_Module.HasDxilFunctionProps(&function)) {
           auto props = m_Module.GetDxilFunctionProps(&function);
           if (props.IsClosestHit() || props.IsAnyHit()) {
@@ -992,9 +963,18 @@ private:
           shaderKind = (uint32_t)props.shaderKind;
         }
         ShaderFlags flags = ShaderFlags::CollectShaderFlags(&function, &m_Module);
-        functionTable.AddFunction(&function, mangledIndex, unmangledIndex,
-                                    shaderKind, resourceIndex,
-                                    payloadSizeInBytes, attrSizeInBytes, flags);
+        RuntimeDataFunctionInfo info = {};
+        info.Name = mangledIndex;
+        info.UnmangledName = unmangledIndex;
+        info.ShaderKind = shaderKind;
+        info.Resources = resourceIndex;
+        info.FunctionDependencies = functionDependencies;
+        info.PayloadSizeInBytes = payloadSizeInBytes;
+        info.AttributeSizeInBytes = attrSizeInBytes;
+        uint64_t rawFlags = flags.GetShaderFlagsRaw();
+        info.FeatureInfo1 = rawFlags & 0xffffffff;
+        info.FeatureInfo2 = (rawFlags >> 32) & 0xffffffff;
+        functionTable.Insert(info);
       }
     }
   }
@@ -1013,7 +993,7 @@ public:
     // one variable to count the number of blobs and two blobs
     uint32_t total = 4 + m_tables.size() * sizeof(RuntimeDataTableHeader);
     for (auto &&table : m_tables)
-      total += table->GetBlobSize();
+      total += table->GetPartSize();
     return total;
   }
 
@@ -1027,15 +1007,15 @@ public:
     // write records
     uint32_t curTableOffset = size * sizeof(RuntimeDataTableHeader) + 4;
     for (auto &&table : m_tables) {
-      RuntimeDataTableHeader record = { table->GetType(), table->GetBlobSize(), curTableOffset };
+      RuntimeDataTableHeader record = { table->GetType(), table->GetPartSize(), curTableOffset };
       memcpy(pCur, &record, sizeof(RuntimeDataTableHeader));
       pCur += sizeof(RuntimeDataTableHeader);
       curTableOffset += record.size;
     }
     // write tables
     for (auto &&table : m_tables) {
-      table->write(pCur);
-      pCur += table->GetBlobSize();
+      table->Write(pCur);
+      pCur += table->GetPartSize();
     }
 
     ULONG cbWritten;

+ 64 - 0
tools/clang/unittests/HLSL/DxilContainerTest.cpp

@@ -35,6 +35,7 @@
 #include "dxc/HLSL/DxilContainer.h"
 #include "dxc/HLSL/DxilPipelineStateValidation.h"
 #include "dxc/HLSL/DxilShaderFlags.h"
+#include "dxc/HLSL/DxilUtil.h"
 
 #include <fstream>
 #include <filesystem>
@@ -72,6 +73,7 @@ public:
 
   TEST_METHOD(CompileWhenDebugSourceThenSourceMatters)
   TEST_METHOD(CompileWhenOkThenCheckRDAT)
+  TEST_METHOD(CompileWhenOkThenCheckRDAT2)
   TEST_METHOD(CompileWhenOKThenIncludesFeatureInfo)
   TEST_METHOD(CompileWhenOKThenIncludesSignatures)
   TEST_METHOD(CompileWhenSigSquareThenIncludeSplit)
@@ -745,6 +747,68 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
   IFTBOOLMSG(blobFound, E_FAIL, "failed to find RDAT blob after compiling");
 }
 
+TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT2) {
+  if (m_ver.SkipDxilVersion(1, 3)) return;
+  // This is a case when the user of resource is a constant, not instruction.
+  // Compiler generates the following load instruction for texture.
+  // load %class.Texture2D, %class.Texture2D* getelementptr inbounds ([3 x
+  // %class.Texture2D], [3 x %class.Texture2D]*
+  // @"\01?ThreeTextures@@3PAV?$Texture2D@M@@A", i32 0, i32 0), align 4
+  const char *shader =
+      "SamplerState Sampler : register(s0); RWBuffer<float> Uav : "
+      "register(u0); Texture2D<float> ThreeTextures[3] : register(t0); "
+      "float function1();"
+      "[shader(\"raygeneration\")] void RayGenMain() { Uav[0] = "
+      "ThreeTextures[0].Sample(Sampler, float2(0, 0)) + function1(); }";
+  CComPtr<IDxcCompiler> pCompiler;
+  CComPtr<IDxcBlobEncoding> pSource;
+  CComPtr<IDxcBlob> pProgram;
+  CComPtr<IDxcBlobEncoding> pDisassembly;
+  CComPtr<IDxcOperationResult> pResult;
+  HRESULT status;
+
+  VERIFY_SUCCEEDED(CreateCompiler(&pCompiler));
+  CreateBlobFromText(shader, &pSource);
+  VERIFY_SUCCEEDED(pCompiler->Compile(pSource, L"hlsl.hlsl", L"main",
+                                      L"lib_6_3", nullptr, 0, nullptr, 0,
+                                      nullptr, &pResult));
+  VERIFY_SUCCEEDED(pResult->GetResult(&pProgram));
+  VERIFY_SUCCEEDED(pResult->GetStatus(&status));
+  VERIFY_SUCCEEDED(status);
+  CComPtr<IDxcContainerReflection> pReflection;
+  uint32_t partCount;
+  IFT(m_dllSupport.CreateInstance(CLSID_DxcContainerReflection, &pReflection));
+  IFT(pReflection->Load(pProgram));
+  IFT(pReflection->GetPartCount(&partCount));
+  bool blobFound = false;
+  for (uint32_t i = 0; i < partCount; ++i) {
+    uint32_t kind;
+    IFT(pReflection->GetPartKind(i, &kind));
+    if (kind == (uint32_t)hlsl::DxilFourCC::DFCC_RuntimeData) {
+      blobFound = true;
+      using namespace hlsl::DXIL::PSV;
+      CComPtr<IDxcBlob> pBlob;
+      IFT(pReflection->GetPartContent(i, &pBlob));
+      DxilRuntimeData context;
+      context.InitFromRDAT((char *)pBlob->GetBufferPointer());
+      FunctionTableReader *funcTableReader = context.GetFunctionTableReader();
+      ResourceTableReader *resTableReader = context.GetResourceTableReader();
+      VERIFY_IS_TRUE(funcTableReader->GetNumFunctions() == 1);
+      VERIFY_IS_TRUE(resTableReader->GetNumResources() == 3);
+      FunctionReader funcReader = funcTableReader->GetItem(0);
+      llvm::StringRef name(funcReader.GetUnmangledName());
+      VERIFY_IS_TRUE(name.compare("RayGenMain") == 0);
+      VERIFY_IS_TRUE(funcReader.GetShaderKind() == PSVShaderKind::RayGeneration);
+      VERIFY_IS_TRUE(funcReader.GetNumResources() == 3);
+      VERIFY_IS_TRUE(funcReader.GetNumDependencies() == 1);
+      llvm::StringRef dependencyName =
+          hlsl::dxilutil::DemangleFunctionName(funcReader.GetDependency(0));
+      VERIFY_IS_TRUE(dependencyName.compare("function1") == 0);
+    }
+  }
+  IFTBOOLMSG(blobFound, E_FAIL, "failed to find RDAT blob after compiling");
+}
+
 TEST_F(DxilContainerTest, CompileWhenOKThenIncludesFeatureInfo) {
   CComPtr<IDxcCompiler> pCompiler;
   CComPtr<IDxcBlobEncoding> pSource;