Просмотр исходного кода

Fix bad pointers for DXIL_FUNCTION::Resources

- Add all resources up front
- Add resource map to look up DXIL_RESOURCE ptr from Class&ID
- Add resource pointers looked up to function's resource ptr list
- Use pointer to function's resource ptr list, rather than the original
  array used for global storage.
Tex Riddell 7 лет назад
Родитель
Сommit
817d526250

+ 26 - 3
include/dxc/HLSL/DxilRuntimeReflection.h

@@ -15,6 +15,28 @@
 #include <memory>
 #include "DxilConstants.h"
 
+namespace hlsl { namespace DXIL { namespace RDAT {
+  struct ResourceKey {
+    uint32_t Class, ID;
+    ResourceKey(uint32_t Class, uint32_t ID) : Class(Class), ID(ID) {}
+    bool operator==(const ResourceKey& other) const {
+      return other.Class == Class && other.ID == ID;
+    }
+  };
+} } }
+
+template<>
+struct std::hash<hlsl::DXIL::RDAT::ResourceKey> {
+public:
+  size_t operator()(const hlsl::DXIL::RDAT::ResourceKey& key) const throw() {
+    //static_assert(sizeof(hlsl::DXIL::RDAT::ResourceKey) == sizeof(uint64_t),
+    //              "otherwise, hash function is incorrect");
+    return (std::hash<uint32_t>()(key.Class) * (size_t)16777619U)
+           ^ std::hash<uint32_t>()(key.ID);
+    //return std::hash<uint64_t>()(*reinterpret_cast<const uint64_t*>(&key));
+  }
+};
+
 namespace hlsl {
 namespace DXIL {
 namespace RDAT {
@@ -362,7 +384,7 @@ typedef struct DXIL_FUNCTION {
   LPCWSTR Name;
   LPCWSTR UnmangledName;
   uint32_t NumResources;
-  const DXIL_RESOURCE *Resources;
+  const DXIL_RESOURCE * const*Resources;
   uint32_t NumFunctionDependencies;
   const LPCWSTR *FunctionDependencies;
   uint32_t ShaderKind;
@@ -399,6 +421,7 @@ private:
   StringMap m_StringMap;
   ResourceList m_Resources;
   FunctionList m_Functions;
+  std::unordered_map<ResourceKey, DXIL_RESOURCE *> m_ResourceMap;
   std::unordered_map<DXIL_FUNCTION *, ResourceRefList> m_FuncToResMap;
   std::unordered_map<DXIL_FUNCTION *, WStringList> m_FuncToStringMap;
   bool m_initialized;
@@ -406,8 +429,8 @@ private:
   const wchar_t *GetWideString(const char *ptr);
   void AddString(const char *ptr);
   void InitializeReflection();
-  DXIL_RESOURCE *GetResourcesForFunction(DXIL_FUNCTION &function,
-                                         const FunctionReader &functionReader);
+  const DXIL_RESOURCE * const*GetResourcesForFunction(DXIL_FUNCTION &function,
+                             const FunctionReader &functionReader);
   const wchar_t **GetDependenciesForFunction(DXIL_FUNCTION &function,
                              const FunctionReader &functionReader);
   DXIL_RESOURCE *AddResource(const ResourceReader &resourceReader);

+ 44 - 26
include/dxc/HLSL/DxilRuntimeReflection.inl

@@ -119,12 +119,21 @@ const DXIL_LIBRARY_DESC DxilRuntimeReflection::GetLibraryReflection() {
 void DxilRuntimeReflection::InitializeReflection() {
   // First need to reserve spaces for resources because functions will need to
   // reference them via pointers.
-  m_Resources.reserve(
-      m_RuntimeData.GetResourceTableReader()->GetNumResources());
-  const FunctionTableReader *tableReader =
-      m_RuntimeData.GetFunctionTableReader();
-  for (uint32_t i = 0; i < tableReader->GetNumFunctions(); ++i) {
-    FunctionReader functionReader = tableReader->GetItem(i);
+  const ResourceTableReader *resourceTableReader = m_RuntimeData.GetResourceTableReader();
+  m_Resources.reserve(resourceTableReader->GetNumResources());
+  for (uint32_t i = 0; i < resourceTableReader->GetNumResources(); ++i) {
+    ResourceReader resourceReader = resourceTableReader->GetItem(i);
+    AddString(resourceReader.GetName());
+    DXIL_RESOURCE *pResource = AddResource(resourceReader);
+    if (pResource) {
+      ResourceKey key(pResource->Class, pResource->ID);
+      m_ResourceMap[key] = pResource;
+    }
+  }
+  const FunctionTableReader *functionTableReader = m_RuntimeData.GetFunctionTableReader();
+  m_Functions.reserve(functionTableReader->GetNumFunctions());
+  for (uint32_t i = 0; i < functionTableReader->GetNumFunctions(); ++i) {
+    FunctionReader functionReader = functionTableReader->GetItem(i);
     AddString(functionReader.GetName());
     AddFunction(functionReader);
   }
@@ -132,34 +141,40 @@ void DxilRuntimeReflection::InitializeReflection() {
 
 DXIL_RESOURCE *
 DxilRuntimeReflection::AddResource(const ResourceReader &resourceReader) {
-  if (m_Resources.size() < m_Resources.capacity()) {
-    m_Resources.emplace_back(DXIL_RESOURCE({0}));
-    DXIL_RESOURCE &resource = m_Resources.back();
-    resource.Class = (uint32_t)resourceReader.GetResourceClass();
-    resource.Kind = (uint32_t)resourceReader.GetResourceKind();
-    resource.Space = resourceReader.GetSpace();
-    resource.LowerBound = resourceReader.GetLowerBound();
-    resource.UpperBound = resourceReader.GetUpperBound();
-    resource.ID = resourceReader.GetID();
-    resource.Flags = resourceReader.GetFlags();
-    resource.Name = GetWideString(resourceReader.GetName());
-    return &resource;
-  }
-  // TODO: assert here?
-  return nullptr;
+  assert(m_Resources.size() < m_Resources.capacity() && "Otherwise, number of resources was incorrect");
+  if (!(m_Resources.size() < m_Resources.capacity()))
+    return nullptr;
+  m_Resources.emplace_back(DXIL_RESOURCE({0}));
+  DXIL_RESOURCE &resource = m_Resources.back();
+  resource.Class = (uint32_t)resourceReader.GetResourceClass();
+  resource.Kind = (uint32_t)resourceReader.GetResourceKind();
+  resource.Space = resourceReader.GetSpace();
+  resource.LowerBound = resourceReader.GetLowerBound();
+  resource.UpperBound = resourceReader.GetUpperBound();
+  resource.ID = resourceReader.GetID();
+  resource.Flags = resourceReader.GetFlags();
+  resource.Name = GetWideString(resourceReader.GetName());
+  return &resource;
 }
 
-DXIL_RESOURCE *DxilRuntimeReflection::GetResourcesForFunction(
+const DXIL_RESOURCE * const*DxilRuntimeReflection::GetResourcesForFunction(
     DXIL_FUNCTION &function, const FunctionReader &functionReader) {
   if (m_FuncToResMap.find(&function) == m_FuncToResMap.end())
     m_FuncToResMap.insert(std::pair<DXIL_FUNCTION *, ResourceRefList>(
         &function, ResourceRefList()));
   ResourceRefList &resourceList = m_FuncToResMap.at(&function);
-  for (uint32_t i = 0; i < functionReader.GetNumResources(); ++i) {
-    const ResourceReader resourceReader = functionReader.GetResource(i);
-    resourceList.emplace_back(AddResource(resourceReader));
+  if (resourceList.empty()) {
+    resourceList.reserve(functionReader.GetNumResources());
+    for (uint32_t i = 0; i < functionReader.GetNumResources(); ++i) {
+      const ResourceReader resourceReader = functionReader.GetResource(i);
+      ResourceKey key((uint32_t)resourceReader.GetResourceClass(),
+                      resourceReader.GetID());
+      auto it = m_ResourceMap.find(key);
+      assert(it != m_ResourceMap.end() && it->second && "Otherwise, resource was not in map, or was null");
+      resourceList.emplace_back(it->second);
+    }
   }
-  return resourceList.empty() ? nullptr : *resourceList.data();
+  return resourceList.empty() ? nullptr : resourceList.data();
 }
 
 const wchar_t **DxilRuntimeReflection::GetDependenciesForFunction(
@@ -176,6 +191,9 @@ const wchar_t **DxilRuntimeReflection::GetDependenciesForFunction(
 
 DXIL_FUNCTION *
 DxilRuntimeReflection::AddFunction(const FunctionReader &functionReader) {
+  assert(m_Functions.size() < m_Functions.capacity() && "Otherwise, number of functions was incorrect");
+  if (!(m_Functions.size() < m_Functions.capacity()))
+    return nullptr;
   m_Functions.emplace_back(DXIL_FUNCTION({0}));
   DXIL_FUNCTION &function = m_Functions.back();
   function.Name = GetWideString(functionReader.GetName());

+ 2 - 2
tools/clang/unittests/HLSL/DxilContainerTest.cpp

@@ -767,7 +767,7 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
           VERIFY_IS_TRUE(featureFlag == rawFlag);
           VERIFY_IS_TRUE(function.NumResources == 1);
           VERIFY_IS_TRUE(function.NumFunctionDependencies == 0);
-          const DXIL_RESOURCE resource = function.Resources[0];
+          const DXIL_RESOURCE &resource = *function.Resources[0];
           VERIFY_IS_TRUE(resource.Class == (uint32_t)hlsl::DXIL::ResourceClass::UAV);
           VERIFY_IS_TRUE(resource.Kind == (uint32_t)hlsl::DXIL::ResourceKind::Texture1D);
           std::wstring wName = resource.Name;
@@ -784,7 +784,7 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
           VERIFY_IS_TRUE(function.NumFunctionDependencies == 0);
           std::unordered_set<std::wstring> stringSet = { L"$Globals", L"b_buf", L"tex2" };
           for (uint32_t j = 0; j < 3; ++j) {
-            const DXIL_RESOURCE resource = function.Resources[j];
+            const DXIL_RESOURCE &resource = *function.Resources[j];
             std::wstring compareName = resource.Name;
             VERIFY_IS_TRUE(stringSet.find(compareName) != stringSet.end());
           }