Browse Source

Added way for caller to replace args in PDB utils (#3595)

Adam Yang 4 years ago
parent
commit
640c9af748

+ 7 - 0
include/dxc/dxcapi.h

@@ -582,6 +582,11 @@ struct IDxcVersionInfo3 : public IUnknown {
   ) = 0;
 };
 
+struct DxcArgPair {
+  const WCHAR *pName;
+  const WCHAR *pValue;
+};
+
 CROSS_PLATFORM_UUIDOF(IDxcPdbUtils, "E6C9647E-9D6A-4C3B-B94C-524B5A6C343D")
 struct IDxcPdbUtils : public IUnknown {
   virtual HRESULT STDMETHODCALLTYPE Load(_In_ IDxcBlob *pPdbOrDxil) = 0;
@@ -616,6 +621,8 @@ struct IDxcPdbUtils : public IUnknown {
 
   virtual HRESULT STDMETHODCALLTYPE SetCompiler(_In_ IDxcCompiler3 *pCompiler) = 0;
   virtual HRESULT STDMETHODCALLTYPE CompileForFullPDB(_COM_Outptr_ IDxcResult **ppResult) = 0;
+  virtual HRESULT STDMETHODCALLTYPE OverrideArgs(_In_ DxcArgPair *pArgPairs, UINT32 uNumArgPairs) = 0;
+  virtual HRESULT STDMETHODCALLTYPE OverrideRootSignature(_In_ const WCHAR *pRootSignature) = 0;
 };
 
 // Note: __declspec(selectany) requires 'extern'

+ 126 - 46
tools/clang/tools/dxcompiler/dxcpdbutils.cpp

@@ -267,9 +267,7 @@ private:
   CComPtr<IDxcBlob> m_pDebugProgramBlob;
   CComPtr<IDxcBlob> m_ContainerBlob;
   std::vector<Source_File> m_SourceFiles;
-  std::vector<std::wstring> m_Defines;
-  std::vector<std::wstring> m_Args;
-  std::vector<std::wstring> m_Flags;
+
   std::wstring m_EntryPoint;
   std::wstring m_TargetProfile;
   std::wstring m_Name;
@@ -290,17 +288,24 @@ private:
     std::wstring Value;
   };
   std::vector<ArgPair> m_ArgPairs;
+  std::vector<std::wstring> m_Defines;
+  std::vector<std::wstring> m_Args;
+  std::vector<std::wstring> m_Flags;
 
-  void Reset() {
-    m_pDebugProgramBlob = nullptr;
-    m_InputBlob = nullptr;
-    m_ContainerBlob = nullptr;
-    m_SourceFiles.clear();
+  void ResetAllArgs() {
+    m_ArgPairs.clear();
     m_Defines.clear();
     m_Args.clear();
     m_Flags.clear();
     m_EntryPoint.clear();
     m_TargetProfile.clear();
+  }
+
+  void Reset() {
+    m_pDebugProgramBlob = nullptr;
+    m_InputBlob = nullptr;
+    m_ContainerBlob = nullptr;
+    m_SourceFiles.clear();
     m_Name.clear();
     m_MainFileName.clear();
     m_HashBlob = nullptr;
@@ -308,8 +313,8 @@ private:
     m_VersionInfo = {};
     m_VersionCommitSha.clear();
     m_VersionString.clear();
-    m_ArgPairs.clear();
     m_pCachedRecompileResult = nullptr;
+    ResetAllArgs();
   }
 
   bool HasSources() const {
@@ -503,38 +508,7 @@ private:
             newPair.Name = ToWstring(pair.Name);
             newPair.Value = ToWstring(pair.Value);
           }
-
-          bool excludeFromFlags = false;
-          if (newPair.Name == L"E") {
-            m_EntryPoint = newPair.Value;
-            excludeFromFlags = true;
-          }
-          else if (newPair.Name == L"T") {
-            m_TargetProfile = newPair.Value;
-            excludeFromFlags = true;
-          }
-          else if (newPair.Name == L"D") {
-            m_Defines.push_back(newPair.Value);
-            excludeFromFlags = true;
-          }
-
-          std::wstring nameWithDash;
-          if (newPair.Name.size())
-            nameWithDash = std::wstring(L"-") + newPair.Name;
-
-          if (!excludeFromFlags) {
-            if (nameWithDash.size())
-              m_Flags.push_back(nameWithDash);
-            if (newPair.Value.size())
-              m_Flags.push_back(newPair.Value);
-          }
-
-          if (nameWithDash.size())
-            m_Args.push_back(nameWithDash);
-          if (newPair.Value.size())
-            m_Args.push_back(newPair.Value);
-
-          m_ArgPairs.push_back( std::move(newPair) );
+          AddArgPair(std::move(newPair));
         }
 
         // Entry point might have been omitted. Set it to main by default.
@@ -593,6 +567,40 @@ private:
     return S_OK;
   }
 
+  void AddArgPair(ArgPair &&newPair) {
+    bool excludeFromFlags = false;
+    if (newPair.Name == L"E") {
+      m_EntryPoint = newPair.Value;
+      excludeFromFlags = true;
+    }
+    else if (newPair.Name == L"T") {
+      m_TargetProfile = newPair.Value;
+      excludeFromFlags = true;
+    }
+    else if (newPair.Name == L"D") {
+      m_Defines.push_back(newPair.Value);
+      excludeFromFlags = true;
+    }
+
+    std::wstring nameWithDash;
+    if (newPair.Name.size())
+      nameWithDash = std::wstring(L"-") + newPair.Name;
+
+    if (!excludeFromFlags) {
+      if (nameWithDash.size())
+        m_Flags.push_back(nameWithDash);
+      if (newPair.Value.size())
+        m_Flags.push_back(newPair.Value);
+    }
+
+    if (nameWithDash.size())
+      m_Args.push_back(nameWithDash);
+    if (newPair.Value.size())
+      m_Args.push_back(newPair.Value);
+
+    m_ArgPairs.push_back( std::move(newPair) );
+  }
+
 public:
   DXC_MICROCOM_TM_ADDREF_RELEASE_IMPL()
   DXC_MICROCOM_TM_ALLOC(DxcPdbUtils)
@@ -737,6 +745,63 @@ public:
     return m_pDebugProgramBlob != nullptr;
   }
 
+  virtual HRESULT STDMETHODCALLTYPE OverrideArgs(_In_ DxcArgPair *pArgPairs, UINT32 uNumArgPairs) override {
+    try {
+      DxcThreadMalloc TM(m_pMalloc);
+
+      ResetAllArgs();
+
+      for (UINT32 i = 0; i < uNumArgPairs; i++) {
+        ArgPair newPair;
+        newPair.Name  = pArgPairs[i].pName ? pArgPairs[i].pName : L"";
+        newPair.Value = pArgPairs[i].pValue ? pArgPairs[i].pValue : L"";
+        AddArgPair(std::move(newPair));
+      }
+
+      // Clear the cached compile result
+      m_pCachedRecompileResult = nullptr;
+    }
+    CATCH_CPP_RETURN_HRESULT()
+
+    return S_OK;
+  }
+
+  virtual HRESULT STDMETHODCALLTYPE OverrideRootSignature(_In_ const WCHAR *pRootSignature) override {
+    try {
+      DxcThreadMalloc TM(m_pMalloc);
+
+      std::vector<ArgPair> newArgPairs;
+      for (ArgPair &pair : m_ArgPairs) {
+        if (pair.Name == L"rootsig-define") {
+          continue;
+        }
+        newArgPairs.push_back(pair);
+      }
+
+      ResetAllArgs();
+
+      for (ArgPair &newArg : newArgPairs) {
+        AddArgPair(std::move(newArg));
+      }
+
+      ArgPair rsPair;
+      rsPair.Name = L"rootsig-define";
+      rsPair.Value = L"__DXC_RS_DEFINE";
+      AddArgPair(std::move(rsPair));
+
+      ArgPair defPair;
+      defPair.Name = L"D";
+      defPair.Value = std::wstring(L"__DXC_RS_DEFINE=") + pRootSignature;
+      AddArgPair(std::move(defPair));
+
+      // Clear the cached compile result
+      m_pCachedRecompileResult = nullptr;
+    }
+    CATCH_CPP_RETURN_HRESULT()
+
+    return S_OK;
+  }
+
   virtual HRESULT STDMETHODCALLTYPE CompileForFullPDB(_COM_Outptr_ IDxcResult **ppResult) {
     if (!ppResult) return E_POINTER;
     *ppResult = nullptr;
@@ -752,13 +817,28 @@ public:
 
     DxcThreadMalloc TM(m_pMalloc);
 
+    std::vector<std::wstring> new_args_storage;
+    for (unsigned i = 0; i < m_ArgPairs.size(); i++) {
+      std::wstring name  = m_ArgPairs[i].Name;
+      std::wstring value = m_ArgPairs[i].Value;
+
+      if (name == L"Zs") continue;
+      if (name == L"Zi") continue;
+
+      if (name.size()) {
+        name.insert(name.begin(), L'-');
+        new_args_storage.push_back(std::move(name));
+      }
+      if (value.size()) {
+        new_args_storage.push_back(std::move(value));
+      }
+    }
+    new_args_storage.push_back(L"-Zi");
+
     std::vector<const WCHAR *> new_args;
-    for (unsigned i = 0; i < m_Args.size(); i++) {
-      if (m_Args[i] == L"/Zs" || m_Args[i] == L"-Zs")
-        continue;
-      new_args.push_back(m_Args[i].c_str());
+    for (std::wstring &arg : new_args_storage) {
+      new_args.push_back(arg.c_str());
     }
-    new_args.push_back(L"-Zi");
 
     assert(m_MainFileName.size());
     if (m_MainFileName.size())

+ 48 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -1288,6 +1288,54 @@ static void VerifyPdbUtil(dxc::DxcDllSupport &dllSupport,
     CComPtr<IDxcBlob> pFullPdb;
     VERIFY_SUCCEEDED(pPdbUtils->GetFullPDB(&pFullPdb));
 
+    // Save a copy of the arg pairs
+    std::vector<std::pair< std::wstring, std::wstring> > pairsStorage;
+    UINT32 uNumArgsPairs = 0;
+    VERIFY_SUCCEEDED(pPdbUtils->GetArgPairCount(&uNumArgsPairs));
+    for (UINT32 i = 0; i < uNumArgsPairs; i++) {
+      CComBSTR pName, pValue;
+      VERIFY_SUCCEEDED(pPdbUtils->GetArgPair(i, &pName, &pValue));
+      std::pair< std::wstring, std::wstring> pairStorage;
+      pairStorage.first  = pName  ? pName  : L"";
+      pairStorage.second = pValue ? pValue : L"";
+      pairsStorage.push_back(pairStorage);
+    }
+
+    // Set an obviously wrong RS and verify compilation fails
+    {
+      VERIFY_SUCCEEDED(pPdbUtils->OverrideRootSignature(L""));
+      CComPtr<IDxcResult> pResult;
+      VERIFY_SUCCEEDED(pPdbUtils->CompileForFullPDB(&pResult));
+
+      HRESULT result = S_OK;
+      VERIFY_SUCCEEDED(pResult->GetStatus(&result));
+      VERIFY_FAILED(result);
+
+      CComPtr<IDxcBlobEncoding> pErr;
+      VERIFY_SUCCEEDED(pResult->GetErrorBuffer(&pErr));
+    }
+
+    // Set an obviously wrong set of args and verify compilation fails
+    {
+
+      std::vector<DxcArgPair> pairs;
+      for (auto &p : pairsStorage) {
+        DxcArgPair pair = {};
+        pair.pName = p.first.c_str();
+        pair.pValue = p.second.c_str();
+        pairs.push_back(pair);
+      }
+
+      VERIFY_SUCCEEDED(pPdbUtils->OverrideArgs(pairs.data(), pairs.size()));
+
+      CComPtr<IDxcResult> pResult;
+      VERIFY_SUCCEEDED(pPdbUtils->CompileForFullPDB(&pResult));
+
+      HRESULT result = S_OK;
+      VERIFY_SUCCEEDED(pResult->GetStatus(&result));
+      VERIFY_SUCCEEDED(result);
+    }
+
     auto ReplaceDebugFlagPair = [](const std::vector<std::pair<const WCHAR *, const WCHAR *> > &List) -> std::vector<std::pair<const WCHAR *, const WCHAR *> > {
       std::vector<std::pair<const WCHAR *, const WCHAR *> > ret;
       for (unsigned i = 0; i < List.size(); i++) {