Xiang Li пре 4 година
родитељ
комит
bd38504d5e

+ 2 - 0
include/dxc/DxilRootSignature/DxilRootSignature.h

@@ -325,6 +325,8 @@ struct DxilVersionedRootSignatureDesc {
   };
   };
 };
 };
 
 
+void printRootSignature(const DxilVersionedRootSignatureDesc &RS, llvm::raw_ostream &os);
+
 // Use this class to represent a root signature that may be in memory or serialized.
 // Use this class to represent a root signature that may be in memory or serialized.
 // There is just enough API surface to help callers not take a dependency on Windows headers.
 // There is just enough API surface to help callers not take a dependency on Windows headers.
 class RootSignatureHandle {
 class RootSignatureHandle {

+ 193 - 0
lib/DxilRootSignature/DxilRootSignature.cpp

@@ -190,4 +190,197 @@ void DeleteRootSignature(const DxilVersionedRootSignatureDesc * pRootSignature)
   delete pRootSignature;
   delete pRootSignature;
 }
 }
 
 
+namespace {
+// Dump root sig.
+
+void printRootSigFlags(DxilRootSignatureFlags Flags, raw_ostream &os) {
+  if (Flags == DxilRootSignatureFlags::None)
+    return;
+  unsigned UFlags = (unsigned)Flags;
+
+  std::pair<unsigned, std::string> FlagTable[] = {
+      {unsigned(DxilRootSignatureFlags::AllowInputAssemblerInputLayout),
+       "ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT"},
+      {unsigned(DxilRootSignatureFlags::DenyVertexShaderRootAccess),
+       "DenyVertexShaderRootAccess"},
+      {unsigned(DxilRootSignatureFlags::DenyHullShaderRootAccess),
+       "DenyHullShaderRootAccess"},
+      {unsigned(DxilRootSignatureFlags::DenyDomainShaderRootAccess),
+       "DenyDomainShaderRootAccess"},
+      {unsigned(DxilRootSignatureFlags::DenyGeometryShaderRootAccess),
+       "DenyGeometryShaderRootAccess"},
+      {unsigned(DxilRootSignatureFlags::DenyPixelShaderRootAccess),
+       "DenyPixelShaderRootAccess"},
+      {unsigned(DxilRootSignatureFlags::AllowStreamOutput),
+       "AllowStreamOutput"},
+      {unsigned(DxilRootSignatureFlags::LocalRootSignature),
+       "LocalRootSignature"},
+      {unsigned(DxilRootSignatureFlags::DenyAmplificationShaderRootAccess),
+       "DenyAmplificationShaderRootAccess"},
+      {unsigned(DxilRootSignatureFlags::DenyMeshShaderRootAccess),
+       "DenyMeshShaderRootAccess"},
+      {unsigned(DxilRootSignatureFlags::CBVSRVUAVHeapDirectlyIndexed),
+       "CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED"},
+      {unsigned(DxilRootSignatureFlags::SamplerHeapDirectlyIndexed),
+       "SAMPLER_HEAP_DIRECTLY_INDEXED"},
+      {unsigned(DxilRootSignatureFlags::AllowLowTierReservedHwCbLimit),
+       "AllowLowTierReservedHwCbLimit"},
+  };
+  os << "RootFlags(";
+  SmallVector<std::string, 4> FlagStrs;
+  for (auto &f : FlagTable) {
+    if (UFlags & f.first)
+      FlagStrs.emplace_back(f.second);
+  }
+  auto it = FlagStrs.begin();
+  os << *(it++);
+  for (; it != FlagStrs.end(); it++) {
+    os << "|" << *it;
+  }
+
+  os << "),";
+}
+
+void printDesc(unsigned Reg, unsigned Space, unsigned Size, raw_ostream &os) {
+  os << Reg;
+  if (Space)
+    os << ", space=" << Space;
+  if (Size && Size != 1)
+    os << ", numDescriptors =" << Size;
+}
+
+void printDescType(DxilDescriptorRangeType Ty, raw_ostream &os) {
+  switch (Ty) {
+  case DxilDescriptorRangeType::CBV: {
+    os << "CBV(b";
+  } break;
+  case DxilDescriptorRangeType::Sampler: {
+    os << "Sampler(s";
+  } break;
+  case DxilDescriptorRangeType::UAV: {
+    os << "UAV(u";
+  } break;
+  case DxilDescriptorRangeType::SRV: {
+    os << "SRV(t";
+  } break;
+  }
+}
+
+template <typename RangeTy> void printDescRange(RangeTy &R, raw_ostream &os) {
+  printDescType(R.RangeType, os);
+  printDesc(R.BaseShaderRegister, R.RegisterSpace, R.NumDescriptors, os);
+  os << ")";
+}
+
+template <typename TableTy> void printDescTable(TableTy &Tab, raw_ostream &os) {
+  for (unsigned i = 0; i < Tab.NumDescriptorRanges; i++) {
+    auto *pRange = Tab.pDescriptorRanges + i;
+    printDescRange(*pRange, os);
+    os << ",";
+  }
+}
+
+void printVisibility(DxilShaderVisibility v, raw_ostream &os) {
+  switch (v) {
+  default:
+    break;
+  case DxilShaderVisibility::Amplification:
+    os << ",visibility=SHADER_VISIBILITY_AMPLIFICATION";
+    break;
+  case DxilShaderVisibility::Domain:
+    os << ",visibility=SHADER_VISIBILITY_DOMAIN";
+    break;
+  case DxilShaderVisibility::Geometry:
+    os << ",visibility=SHADER_VISIBILITY_GEOMETRY";
+    break;
+  case DxilShaderVisibility::Hull:
+    os << ",visibility=SHADER_VISIBILITY_HULL";
+    break;
+  case DxilShaderVisibility::Mesh:
+    os << ",visibility=SHADER_VISIBILITY_MESH";
+    break;
+  case DxilShaderVisibility::Pixel:
+    os << ",visibility=SHADER_VISIBILITY_PIXEL";
+    break;
+  case DxilShaderVisibility::Vertex:
+    os << ",visibility=SHADER_VISIBILITY_VERTEX";
+    break;
+  }
+}
+
+template <typename ParamTy>
+void printRootParam(ParamTy &Param, raw_ostream &os) {
+  switch (Param.ParameterType) {
+  case DxilRootParameterType::CBV:
+    printDescType(DxilDescriptorRangeType::CBV, os);
+    printDesc(Param.Descriptor.ShaderRegister, Param.Descriptor.RegisterSpace, 0,
+             os);
+
+    break;
+  case DxilRootParameterType::SRV:
+    printDescType(DxilDescriptorRangeType::SRV, os);
+    printDesc(Param.Descriptor.ShaderRegister, Param.Descriptor.RegisterSpace, 0,
+             os);
+    break;
+  case DxilRootParameterType::UAV:
+    printDescType(DxilDescriptorRangeType::UAV, os);
+    printDesc(Param.Descriptor.ShaderRegister, Param.Descriptor.RegisterSpace, 0,
+             os);
+    break;
+  case DxilRootParameterType::Constants32Bit:
+    os << "RootConstants(num32BitConstants=" << Param.Constants.Num32BitValues
+       << "b";
+    printDesc(Param.Constants.ShaderRegister, Param.Constants.RegisterSpace, 0,
+             os);
+    break;
+  case DxilRootParameterType::DescriptorTable:
+    os << "DescriptorTable(";
+    printDescTable(Param.DescriptorTable, os);
+    break;
+  }
+
+  printVisibility(Param.ShaderVisibility, os);
+  os << ")";
+}
+
+void printSampler(DxilStaticSamplerDesc &Sampler, raw_ostream &os) {
+  // StaticSampler(s4, filter=FILTER_MIN_MAG_MIP_LINEAR)
+  os << "StaticSampler(s" << Sampler.ShaderRegister
+     << ", space=" << Sampler.RegisterSpace;
+  // TODO: set the fileds.
+  printVisibility(Sampler.ShaderVisibility, os);
+  os << ")";
+}
+
+template <typename DescTy> void printRootSig(DescTy &RS, raw_ostream &os) {
+  printRootSigFlags(RS.Flags, os);
+  for (unsigned i = 0; i < RS.NumParameters; i++) {
+    auto *pParam = RS.pParameters + i;
+    printRootParam(*pParam, os);
+    os << ",";
+  }
+  for (unsigned i = 0; i < RS.NumStaticSamplers; i++) {
+    auto *pSampler = RS.pStaticSamplers + i;
+    printSampler(*pSampler, os);
+    os << ",";
+  }
+}
+} // namespace
+
+void printRootSignature(const DxilVersionedRootSignatureDesc &RS, raw_ostream &os) {
+  switch (RS.Version) {
+  case DxilRootSignatureVersion::Version_1_0:
+    printRootSig(RS.Desc_1_0, os);
+    break;
+  case DxilRootSignatureVersion::Version_1_1:
+  default:
+    DXASSERT(RS.Version == DxilRootSignatureVersion::Version_1_1,
+             "else version is incorrect");
+    printRootSig(RS.Desc_1_1, os);
+    break;
+  }
+  os.flush();
+}
+
+
 } // namespace hlsl
 } // namespace hlsl

+ 1 - 0
tools/clang/tools/dxa/CMakeLists.txt

@@ -10,6 +10,7 @@ set( LLVM_LINK_COMPONENTS
   ${LLVM_TARGETS_TO_BUILD}
   ${LLVM_TARGETS_TO_BUILD}
   DXIL
   DXIL
   DxilContainer
   DxilContainer
+  DxilRootSignature
   HLSL
   HLSL
   dxcsupport
   dxcsupport
   Option     # option library
   Option     # option library

+ 99 - 45
tools/clang/tools/dxa/dxa.cpp

@@ -17,10 +17,12 @@
 #include "dxc/Support/dxcapi.use.h"
 #include "dxc/Support/dxcapi.use.h"
 #include "dxc/Support/HLSLOptions.h"
 #include "dxc/Support/HLSLOptions.h"
 #include "dxc/DxilContainer/DxilContainer.h"
 #include "dxc/DxilContainer/DxilContainer.h"
+#include "dxc/DxilRootSignature/DxilRootSignature.h"
 
 
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support//MSFileSystem.h"
 #include "llvm/Support//MSFileSystem.h"
 #include "llvm/Support/FileSystem.h"
 #include "llvm/Support/FileSystem.h"
+#include "llvm/Support/raw_ostream.h"
 #include <dia2.h>
 #include <dia2.h>
 #include <intsafe.h>
 #include <intsafe.h>
 
 
@@ -53,12 +55,16 @@ static cl::opt<bool> ListFiles("listfiles",
 static cl::opt<std::string>
 static cl::opt<std::string>
     ExtractFile("extractfile", cl::desc("Extract file from debug information (use '*' for all files)"));
     ExtractFile("extractfile", cl::desc("Extract file from debug information (use '*' for all files)"));
 
 
+static cl::opt<bool> DumpRootSig("dumprs",
+                               cl::desc("Dump root signature"),
+                               cl::init(false));
 
 
 class DxaContext {
 class DxaContext {
 
 
 private:
 private:
   DxcDllSupport &m_dxcSupport;
   DxcDllSupport &m_dxcSupport;
   HRESULT FindModule(hlsl::DxilFourCC fourCC, IDxcBlob *pSource, IDxcLibrary *pLibrary, IDxcBlob **ppTarget);
   HRESULT FindModule(hlsl::DxilFourCC fourCC, IDxcBlob *pSource, IDxcLibrary *pLibrary, IDxcBlob **ppTarget);
+  bool ExtractPart(uint32_t Part, IDxcBlob **ppTargetBlob);
 public:
 public:
   DxaContext(DxcDllSupport &dxcSupport) : m_dxcSupport(dxcSupport) {}
   DxaContext(DxcDllSupport &dxcSupport) : m_dxcSupport(dxcSupport) {}
 
 
@@ -67,6 +73,7 @@ public:
   bool ExtractPart(const char *pName);
   bool ExtractPart(const char *pName);
   void ListFiles();
   void ListFiles();
   void ListParts();
   void ListParts();
+  void DumpRS();
 };
 };
 
 
 void DxaContext::Assemble() {
 void DxaContext::Assemble() {
@@ -189,11 +196,29 @@ bool DxaContext::ExtractFile(const char *pName) {
   return printedAny;
   return printedAny;
 }
 }
 
 
-bool DxaContext::ExtractPart(const char *pName) {
+bool DxaContext::ExtractPart(uint32_t PartKind, IDxcBlob **ppTargetBlob) {
   CComPtr<IDxcContainerReflection> pReflection;
   CComPtr<IDxcContainerReflection> pReflection;
   CComPtr<IDxcBlobEncoding> pSource;
   CComPtr<IDxcBlobEncoding> pSource;
   UINT32 partCount;
   UINT32 partCount;
+  ReadFileIntoBlob(m_dxcSupport, StringRefUtf16(InputFilename), &pSource);
+  IFT(m_dxcSupport.CreateInstance(CLSID_DxcContainerReflection, &pReflection));
+  IFT(pReflection->Load(pSource));
+  IFT(pReflection->GetPartCount(&partCount));
+
+    for (UINT32 i = 0; i < partCount; ++i) {
+    UINT32 curPartKind;
+    IFT(pReflection->GetPartKind(i, &curPartKind));
+    if (curPartKind == PartKind) {
+      CComPtr<IDxcBlob> pContent;
+      IFT(pReflection->GetPartContent(i, ppTargetBlob));
+      return true;
+    }
+  }
+  return false;
 
 
+}
+
+bool DxaContext::ExtractPart(const char *pName) {
   // If the part name is 'module', don't just extract the part,
   // If the part name is 'module', don't just extract the part,
   // but also skip the appropriate header.
   // but also skip the appropriate header.
   bool extractModule = strcmp("module", pName) == 0;
   bool extractModule = strcmp("module", pName) == 0;
@@ -205,56 +230,50 @@ bool DxaContext::ExtractPart(const char *pName) {
     extractModule = true;
     extractModule = true;
   }
   }
 
 
-  ReadFileIntoBlob(m_dxcSupport, StringRefUtf16(InputFilename), &pSource);
-  IFT(m_dxcSupport.CreateInstance(CLSID_DxcContainerReflection, &pReflection));
-  IFT(pReflection->Load(pSource));
-  IFT(pReflection->GetPartCount(&partCount));
   IFTARG(strlen(pName) == 4);
   IFTARG(strlen(pName) == 4);
 
 
-  const UINT32 matchName = ((UINT32)pName[0] | ((UINT32)pName[1] << 8) | ((UINT32)pName[2] << 16) | ((UINT32)pName[3] << 24));
-  for (UINT32 i = 0; i < partCount; ++i) {
-    UINT32 partKind;
-    IFT(pReflection->GetPartKind(i, &partKind));
-    if (partKind == matchName) {
-      CComPtr<IDxcBlob> pContent;
-      IFT(pReflection->GetPartContent(i, &pContent));
-      if (OutputFilename.empty()) {
-        if (InputFilename == "-") {
-          OutputFilename = "-";
-        }
-        else {
-          OutputFilename = InputFilename.getValue();
-          OutputFilename += ".";
-          if (extractModule) {
-            OutputFilename += "ll";
-          }
-          else {
-            OutputFilename += pName;
-          }
-        }
-      }
-
+  const UINT32 matchName =
+      ((UINT32)pName[0] | ((UINT32)pName[1] << 8) | ((UINT32)pName[2] << 16) |
+       ((UINT32)pName[3] << 24));
+  CComPtr<IDxcBlob> pContent;
+  if (!ExtractPart(matchName, &pContent))
+    return false;
+
+  if (OutputFilename.empty()) {
+    if (InputFilename == "-") {
+      OutputFilename = "-";
+    } else {
+      OutputFilename = InputFilename.getValue();
+      OutputFilename += ".";
       if (extractModule) {
       if (extractModule) {
-        char *pDxilPart = (char *)pContent->GetBufferPointer();
-        hlsl::DxilProgramHeader *pProgramHdr = (hlsl::DxilProgramHeader *)pDxilPart;
-        const char *pBitcode;
-        uint32_t bitcodeLength;
-        GetDxilProgramBitcode(pProgramHdr, &pBitcode, &bitcodeLength);
-        uint32_t offset = pBitcode - pDxilPart;
-
-        CComPtr<IDxcLibrary> pLib;
-        CComPtr<IDxcBlob> pModuleBlob;
-        IFT(m_dxcSupport.CreateInstance(CLSID_DxcLibrary, &pLib));
-        IFT(pLib->CreateBlobFromBlob(pContent, offset, bitcodeLength, &pModuleBlob));
-        std::swap(pModuleBlob, pContent);
+        OutputFilename += "ll";
+      } else {
+        OutputFilename += pName;
       }
       }
-
-      WriteBlobToFile(pContent, StringRefUtf16(OutputFilename), DXC_CP_UTF8); // TODO: Support DefaultTextCodePage
-      printf("%Iu bytes written to %s\n", pContent->GetBufferSize(), OutputFilename.c_str());
-      return true;
     }
     }
   }
   }
-  return false;
+
+  if (extractModule) {
+    char *pDxilPart = (char *)pContent->GetBufferPointer();
+    hlsl::DxilProgramHeader *pProgramHdr = (hlsl::DxilProgramHeader *)pDxilPart;
+    const char *pBitcode;
+    uint32_t bitcodeLength;
+    GetDxilProgramBitcode(pProgramHdr, &pBitcode, &bitcodeLength);
+    uint32_t offset = pBitcode - pDxilPart;
+
+    CComPtr<IDxcLibrary> pLib;
+    CComPtr<IDxcBlob> pModuleBlob;
+    IFT(m_dxcSupport.CreateInstance(CLSID_DxcLibrary, &pLib));
+    IFT(pLib->CreateBlobFromBlob(pContent, offset, bitcodeLength,
+                                 &pModuleBlob));
+    std::swap(pModuleBlob, pContent);
+  }
+
+  WriteBlobToFile(pContent, StringRefUtf16(OutputFilename),
+                  DXC_CP_UTF8); // TODO: Support DefaultTextCodePage
+  printf("%Iu bytes written to %s\n", pContent->GetBufferSize(),
+         OutputFilename.c_str());
+  return true;
 }
 }
 
 
 void DxaContext::ListParts() {
 void DxaContext::ListParts() {
@@ -283,6 +302,37 @@ void DxaContext::ListParts() {
   }
   }
 }
 }
 
 
+void DxaContext::DumpRS() {
+  const char *pName = "RTS0";
+  const UINT32 matchName =
+      ((UINT32)pName[0] | ((UINT32)pName[1] << 8) | ((UINT32)pName[2] << 16) |
+       ((UINT32)pName[3] << 24));
+  CComPtr<IDxcBlob> pContent;
+  if (!ExtractPart(matchName, &pContent)) {
+    printf("cannot find root signature part");
+    return;
+  }
+
+  const void *serializedData = pContent->GetBufferPointer();
+  uint32_t serializedSize = pContent->GetBufferSize();
+  hlsl::RootSignatureHandle rootsig;
+  rootsig.LoadSerialized(static_cast<const uint8_t *>(serializedData),
+                         serializedSize);
+  try {
+    rootsig.Deserialize();
+  } catch (const hlsl::Exception &e) {
+    printf("fail to deserialize root sig %s", e.msg.c_str());
+    return;
+  }
+
+  if (const hlsl::DxilVersionedRootSignatureDesc *pRS = rootsig.GetDesc()) {
+    std::string str;
+    llvm::raw_string_ostream os(str);
+    hlsl::printRootSignature(*pRS, os);
+    printf("%s", str.c_str());
+  }
+}
+
 using namespace hlsl::options;
 using namespace hlsl::options;
 
 
 int __cdecl main(int argc, _In_reads_z_(argc) char **argv) {
 int __cdecl main(int argc, _In_reads_z_(argc) char **argv) {
@@ -311,6 +361,7 @@ int __cdecl main(int argc, _In_reads_z_(argc) char **argv) {
       return 2;
       return 2;
     }
     }
 
 
+
     DxcDllSupport dxcSupport;
     DxcDllSupport dxcSupport;
     dxc::EnsureEnabled(dxcSupport);
     dxc::EnsureEnabled(dxcSupport);
     DxaContext context(dxcSupport);
     DxaContext context(dxcSupport);
@@ -333,6 +384,9 @@ int __cdecl main(int argc, _In_reads_z_(argc) char **argv) {
       if (!context.ExtractFile(ExtractFile.c_str())) {
       if (!context.ExtractFile(ExtractFile.c_str())) {
         return 1;
         return 1;
       }
       }
+    } else if (DumpRootSig) {
+      pStage = "Dump root sig";
+      context.DumpRS();
     }
     }
     else {
     else {
       pStage = "Assembling";
       pStage = "Assembling";