Răsfoiți Sursa

Convert DxilRuntimeReflection to pure virtual interface

- removes stl dependency from header
- change namespace from hlsl::DXIL::RDAT to hlsl::RDAT
- use enum class for RuntimeDataPartType to avoid name collisions
- bit of cleanup
Tex Riddell 7 ani în urmă
părinte
comite
8335fe4251

+ 9 - 63
include/dxc/HLSL/DxilRuntimeReflection.h

@@ -9,36 +9,10 @@
 //                                                                           //
 ///////////////////////////////////////////////////////////////////////////////
 
-#include <windows.h>
-#include <unordered_map>
-#include <vector>
-#include <memory>
+#pragma once
 #include "DxilConstants.h"
 
-namespace hlsl { namespace DXIL { namespace RDAT {
-  struct ResourceKey {
-    uint32_t Class, ID;
-    ResourceKey(uint32_t Class, uint32_t ID) : Class(Class), ID(ID) {}
-    bool operator==(const ResourceKey& other) const {
-      return other.Class == Class && other.ID == ID;
-    }
-  };
-} } }
-
-template<>
-struct std::hash<hlsl::DXIL::RDAT::ResourceKey> {
-public:
-  size_t operator()(const hlsl::DXIL::RDAT::ResourceKey& key) const throw() {
-    //static_assert(sizeof(hlsl::DXIL::RDAT::ResourceKey) == sizeof(uint64_t),
-    //              "otherwise, hash function is incorrect");
-    return (std::hash<uint32_t>()(key.Class) * (size_t)16777619U)
-           ^ std::hash<uint32_t>()(key.ID);
-    //return std::hash<uint64_t>()(*reinterpret_cast<const uint64_t*>(&key));
-  }
-};
-
 namespace hlsl {
-namespace DXIL {
 namespace RDAT {
 
 struct RuntimeDataTableHeader {
@@ -47,7 +21,7 @@ struct RuntimeDataTableHeader {
   uint32_t offset;
 };
 
-enum RuntimeDataPartType : uint32_t {
+enum class RuntimeDataPartType : uint32_t {
   Invalid = 0,
   String,
   Function,
@@ -410,43 +384,15 @@ typedef struct DXIL_LIBRARY_DESC {
 } DXIL_LIBRARY_DESC;
 
 class DxilRuntimeReflection {
-private:
-  typedef std::unordered_map<const char *, std::unique_ptr<wchar_t[]>> StringMap;
-  typedef std::vector<DXIL_RESOURCE> ResourceList;
-  typedef std::vector<DXIL_RESOURCE *> ResourceRefList;
-  typedef std::vector<DXIL_FUNCTION> FunctionList;
-  typedef std::vector<const wchar_t *> WStringList;
-
-  DxilRuntimeData m_RuntimeData;
-  StringMap m_StringMap;
-  ResourceList m_Resources;
-  FunctionList m_Functions;
-  std::unordered_map<ResourceKey, DXIL_RESOURCE *> m_ResourceMap;
-  std::unordered_map<DXIL_FUNCTION *, ResourceRefList> m_FuncToResMap;
-  std::unordered_map<DXIL_FUNCTION *, WStringList> m_FuncToStringMap;
-  bool m_initialized;
-
-  const wchar_t *GetWideString(const char *ptr);
-  void AddString(const char *ptr);
-  void InitializeReflection();
-  const DXIL_RESOURCE * const*GetResourcesForFunction(DXIL_FUNCTION &function,
-                             const FunctionReader &functionReader);
-  const wchar_t **GetDependenciesForFunction(DXIL_FUNCTION &function,
-                             const FunctionReader &functionReader);
-  DXIL_RESOURCE *AddResource(const ResourceReader &resourceReader);
-  DXIL_FUNCTION *AddFunction(const FunctionReader &functionReader);
-
 public:
-  // TODO: Implement pipeline state validation with runtime data
-  // TODO: Update BlobContainer.h to recognize 'RDAT' blob
-  DxilRuntimeReflection()
-      : m_RuntimeData(), m_StringMap(), m_Resources(), m_Functions(),
-        m_FuncToResMap(), m_FuncToStringMap(), m_initialized(false) {}
+  virtual ~DxilRuntimeReflection() {}
   // This call will allocate memory for GetLibraryReflection call
-  bool InitFromRDAT(const void *pRDAT);
-  const DXIL_LIBRARY_DESC GetLibraryReflection();
+  virtual bool InitFromRDAT(const void *pRDAT) = 0;
+  // DxilRuntimeReflection owns the memory pointed to by DXIL_LIBRARY_DESC
+  virtual const DXIL_LIBRARY_DESC GetLibraryReflection() = 0;
 };
 
-} // namespace LIB
-} // namespace DXIL
+DxilRuntimeReflection *CreateDxilRuntimeReflection();
+
+} // namespace RDAT
 } // namespace hlsl

+ 85 - 13
include/dxc/HLSL/DxilRuntimeReflection.inl

@@ -10,11 +10,22 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 #include "dxc/hlsl/DxilRuntimeReflection.h"
+#include <windows.h>
+#include <unordered_map>
+#include <vector>
+#include <memory>
 
 namespace hlsl {
-namespace DXIL {
 namespace RDAT {
 
+struct ResourceKey {
+  uint32_t Class, ID;
+  ResourceKey(uint32_t Class, uint32_t ID) : Class(Class), ID(ID) {}
+  bool operator==(const ResourceKey& other) const {
+    return other.Class == Class && other.ID == ID;
+  }
+};
+
 DxilRuntimeData::DxilRuntimeData() : DxilRuntimeData(nullptr) {}
 
 DxilRuntimeData::DxilRuntimeData(const char *ptr)
@@ -35,7 +46,7 @@ bool DxilRuntimeData::InitFromRDAT(const void *pRDAT) {
     RuntimeDataTableHeader *records = (RuntimeDataTableHeader *)(ptr + 4);
     for (uint32_t i = 0; i < TableCount; ++i) {
       RuntimeDataTableHeader *curRecord = &records[i];
-      switch (curRecord->tableType) {
+      switch (static_cast<RuntimeDataPartType>(curRecord->tableType)) {
       case RuntimeDataPartType::Resource: {
         m_ResourceTableReader.SetResourceInfo(
             (RuntimeDataResourceInfo *)(ptr + curRecord->offset),
@@ -76,7 +87,63 @@ ResourceTableReader *DxilRuntimeData::GetResourceTableReader() {
   return &m_ResourceTableReader;
 }
 
-void DxilRuntimeReflection::AddString(const char *ptr) {
+}} // hlsl::RDAT
+
+using namespace hlsl;
+using namespace RDAT;
+
+template<>
+struct std::hash<ResourceKey> {
+public:
+  size_t operator()(const ResourceKey& key) const throw() {
+    return (std::hash<uint32_t>()(key.Class) * (size_t)16777619U)
+      ^ std::hash<uint32_t>()(key.ID);
+  }
+};
+
+namespace {
+
+class DxilRuntimeReflection_impl : public DxilRuntimeReflection {
+private:
+  typedef std::unordered_map<const char *, std::unique_ptr<wchar_t[]>> StringMap;
+  typedef std::vector<DXIL_RESOURCE> ResourceList;
+  typedef std::vector<DXIL_RESOURCE *> ResourceRefList;
+  typedef std::vector<DXIL_FUNCTION> FunctionList;
+  typedef std::vector<const wchar_t *> WStringList;
+
+  DxilRuntimeData m_RuntimeData;
+  StringMap m_StringMap;
+  ResourceList m_Resources;
+  FunctionList m_Functions;
+  std::unordered_map<ResourceKey, DXIL_RESOURCE *> m_ResourceMap;
+  std::unordered_map<DXIL_FUNCTION *, ResourceRefList> m_FuncToResMap;
+  std::unordered_map<DXIL_FUNCTION *, WStringList> m_FuncToStringMap;
+  bool m_initialized;
+
+  const wchar_t *GetWideString(const char *ptr);
+  void AddString(const char *ptr);
+  void InitializeReflection();
+  const DXIL_RESOURCE * const*GetResourcesForFunction(DXIL_FUNCTION &function,
+                             const FunctionReader &functionReader);
+  const wchar_t **GetDependenciesForFunction(DXIL_FUNCTION &function,
+                             const FunctionReader &functionReader);
+  DXIL_RESOURCE *AddResource(const ResourceReader &resourceReader);
+  DXIL_FUNCTION *AddFunction(const FunctionReader &functionReader);
+
+public:
+  // TODO: Implement pipeline state validation with runtime data
+  // TODO: Update BlobContainer.h to recognize 'RDAT' blob
+  // TODO: Add size and verification to InitFromRDAT and DxilRuntimeData
+  DxilRuntimeReflection_impl()
+      : m_RuntimeData(), m_StringMap(), m_Resources(), m_Functions(),
+        m_FuncToResMap(), m_FuncToStringMap(), m_initialized(false) {}
+  virtual ~DxilRuntimeReflection_impl() {}
+  // This call will allocate memory for GetLibraryReflection call
+  bool InitFromRDAT(const void *pRDAT) override;
+  const DXIL_LIBRARY_DESC GetLibraryReflection() override;
+};
+
+void DxilRuntimeReflection_impl::AddString(const char *ptr) {
   if (m_StringMap.find(ptr) == m_StringMap.end()) {
     int size = ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, ptr, -1,
                                      nullptr, 0);
@@ -89,22 +156,22 @@ void DxilRuntimeReflection::AddString(const char *ptr) {
   }
 }
 
-const wchar_t *DxilRuntimeReflection::GetWideString(const char *ptr) {
+const wchar_t *DxilRuntimeReflection_impl::GetWideString(const char *ptr) {
   if (m_StringMap.find(ptr) == m_StringMap.end()) {
     AddString(ptr);
   }
   return m_StringMap.at(ptr).get();
 }
 
-bool DxilRuntimeReflection::InitFromRDAT(const void *pRDAT) {
+bool DxilRuntimeReflection_impl::InitFromRDAT(const void *pRDAT) {
   m_initialized = m_RuntimeData.InitFromRDAT(pRDAT);
   if (m_initialized)
     InitializeReflection();
   return m_initialized;
 }
 
-const DXIL_LIBRARY_DESC DxilRuntimeReflection::GetLibraryReflection() {
-  DXIL_LIBRARY_DESC reflection;
+const DXIL_LIBRARY_DESC DxilRuntimeReflection_impl::GetLibraryReflection() {
+  DXIL_LIBRARY_DESC reflection = {};
   if (m_initialized) {
     reflection.NumResources =
         m_RuntimeData.GetResourceTableReader()->GetNumResources();
@@ -116,7 +183,7 @@ const DXIL_LIBRARY_DESC DxilRuntimeReflection::GetLibraryReflection() {
   return reflection;
 }
 
-void DxilRuntimeReflection::InitializeReflection() {
+void DxilRuntimeReflection_impl::InitializeReflection() {
   // First need to reserve spaces for resources because functions will need to
   // reference them via pointers.
   const ResourceTableReader *resourceTableReader = m_RuntimeData.GetResourceTableReader();
@@ -140,7 +207,7 @@ void DxilRuntimeReflection::InitializeReflection() {
 }
 
 DXIL_RESOURCE *
-DxilRuntimeReflection::AddResource(const ResourceReader &resourceReader) {
+DxilRuntimeReflection_impl::AddResource(const ResourceReader &resourceReader) {
   assert(m_Resources.size() < m_Resources.capacity() && "Otherwise, number of resources was incorrect");
   if (!(m_Resources.size() < m_Resources.capacity()))
     return nullptr;
@@ -157,7 +224,7 @@ DxilRuntimeReflection::AddResource(const ResourceReader &resourceReader) {
   return &resource;
 }
 
-const DXIL_RESOURCE * const*DxilRuntimeReflection::GetResourcesForFunction(
+const DXIL_RESOURCE * const*DxilRuntimeReflection_impl::GetResourcesForFunction(
     DXIL_FUNCTION &function, const FunctionReader &functionReader) {
   if (m_FuncToResMap.find(&function) == m_FuncToResMap.end())
     m_FuncToResMap.insert(std::pair<DXIL_FUNCTION *, ResourceRefList>(
@@ -177,7 +244,7 @@ const DXIL_RESOURCE * const*DxilRuntimeReflection::GetResourcesForFunction(
   return resourceList.empty() ? nullptr : resourceList.data();
 }
 
-const wchar_t **DxilRuntimeReflection::GetDependenciesForFunction(
+const wchar_t **DxilRuntimeReflection_impl::GetDependenciesForFunction(
     DXIL_FUNCTION &function, const FunctionReader &functionReader) {
   if (m_FuncToStringMap.find(&function) == m_FuncToStringMap.end())
     m_FuncToStringMap.insert(
@@ -190,7 +257,7 @@ const wchar_t **DxilRuntimeReflection::GetDependenciesForFunction(
 }
 
 DXIL_FUNCTION *
-DxilRuntimeReflection::AddFunction(const FunctionReader &functionReader) {
+DxilRuntimeReflection_impl::AddFunction(const FunctionReader &functionReader) {
   assert(m_Functions.size() < m_Functions.capacity() && "Otherwise, number of functions was incorrect");
   if (!(m_Functions.size() < m_Functions.capacity()))
     return nullptr;
@@ -212,4 +279,9 @@ DxilRuntimeReflection::AddFunction(const FunctionReader &functionReader) {
   function.MinShaderTarget = functionReader.GetMinShaderTarget();
   return &function;
 }
-}}}
+
+} // namespace anon
+
+DxilRuntimeReflection *hlsl::RDAT::CreateDxilRuntimeReflection() {
+  return new DxilRuntimeReflection_impl();
+}

+ 2 - 2
lib/HLSL/DxilContainerAssembler.cpp

@@ -34,7 +34,7 @@
 
 using namespace llvm;
 using namespace hlsl;
-using namespace hlsl::DXIL::RDAT;
+using namespace hlsl::RDAT;
 
 static DxilProgramSigSemantic KindToSystemValue(Semantic::Kind kind, DXIL::TessellatorDomain domain) {
   switch (kind) {
@@ -1013,7 +1013,7 @@ public:
     // write records
     uint32_t curTableOffset = size * sizeof(RuntimeDataTableHeader) + 4;
     for (auto &&table : m_tables) {
-      RuntimeDataTableHeader record = { table->GetType(), table->GetPartSize(), curTableOffset };
+      RuntimeDataTableHeader record = { static_cast<uint32_t>(table->GetType()), table->GetPartSize(), curTableOffset };
       memcpy(pCur, &record, sizeof(RuntimeDataTableHeader));
       pCur += sizeof(RuntimeDataTableHeader);
       curTableOffset += record.size;

+ 2 - 1
lib/HLSL/DxilContainerReflection.cpp

@@ -23,7 +23,6 @@
 #include "dxc/Support/microcom.h"
 #include "dxc/Support/FileIOHelper.h"
 #include "dxc/Support/dxcapi.impl.h"
-#include "dxc/HLSL/DxilRuntimeReflection.inl"
 #include "dxc/HLSL/DxilFunctionProps.h"
 
 #include <unordered_set>
@@ -2317,3 +2316,5 @@ ID3D12FunctionReflection *DxilLibraryReflection::GetFunctionByIndex(INT Function
   return ((m_FunctionMap.begin() + FunctionIndex)->second).get();
 }
 
+// DxilRuntimeReflection implementation
+#include "dxc/HLSL/DxilRuntimeReflection.inl"

+ 5 - 5
tools/clang/unittests/HLSL/DxilContainerTest.cpp

@@ -703,7 +703,7 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
     IFT(containerReflection->GetPartKind(i, &kind));
     if (kind == (uint32_t)hlsl::DxilFourCC::DFCC_RuntimeData) {
       blobFound = true;
-      using namespace hlsl::DXIL::RDAT;
+      using namespace hlsl::RDAT;
       CComPtr<IDxcBlob> pBlob;
       IFT(containerReflection->GetPartContent(i, &pBlob));
       // Validate using DxilRuntimeData
@@ -749,9 +749,9 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
       }
       VERIFY_IS_TRUE(resTableReader->GetNumResources() == 4);
       // This is validation test for DxilRuntimeReflection implemented on DxilRuntimeReflection.inl
-      DxilRuntimeReflection reflection;
-      VERIFY_IS_TRUE(reflection.InitFromRDAT(pBlob->GetBufferPointer()));
-      DXIL_LIBRARY_DESC lib_reflection = reflection.GetLibraryReflection();
+      unique_ptr<DxilRuntimeReflection> pReflection(CreateDxilRuntimeReflection());
+      VERIFY_IS_TRUE(pReflection->InitFromRDAT(pBlob->GetBufferPointer()));
+      DXIL_LIBRARY_DESC lib_reflection = pReflection->GetLibraryReflection();
       VERIFY_IS_TRUE(lib_reflection.NumFunctions == 3);
       for (uint32_t j = 0; j < 3; ++j) {
         DXIL_FUNCTION function = lib_reflection.pFunction[j];
@@ -847,7 +847,7 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT2) {
     IFT(pReflection->GetPartKind(i, &kind));
     if (kind == (uint32_t)hlsl::DxilFourCC::DFCC_RuntimeData) {
       blobFound = true;
-      using namespace hlsl::DXIL::RDAT;
+      using namespace hlsl::RDAT;
       CComPtr<IDxcBlob> pBlob;
       IFT(pReflection->GetPartContent(i, &pBlob));
       DxilRuntimeData context;