2
0
Эх сурвалжийг харах

Add cache for preprocessed header in lib cache manager. (#460)

Xiang Li 8 жил өмнө
parent
commit
667d9e03da

+ 108 - 32
tools/clang/tools/dxlib-sample/lib_cache_manager.cpp

@@ -10,6 +10,7 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 #include "dxc/Support/WinIncludes.h"
+#include "dxc/Support/Global.h"
 #include "dxc/dxcapi.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Hashing.h"
@@ -18,10 +19,59 @@
 #include <unordered_map>
 #include "lib_share_helper.h"
 #include "llvm/ADT/STLExtras.h"
+#include "dxc/Support/FileIOHelper.h"
 
 using namespace llvm;
 using namespace libshare;
 
+#include "dxc/Support/dxcapi.use.h"
+#include "dxc/dxctools.h"
+namespace {
+class NoFuncBodyRewriter {
+public:
+  NoFuncBodyRewriter() {
+    if (!m_dllSupport.IsEnabled())
+      m_dllSupport.Initialize();
+    m_dllSupport.CreateInstance(CLSID_DxcRewriter, &m_pRewriter);
+  }
+  HRESULT RewriteToNoFuncBody(LPCWSTR pFilename,
+                              IDxcBlobEncoding *pSource,
+                              std::vector<DxcDefine> &m_defines,
+                              IDxcBlob **ppNoFuncBodySource);
+
+private:
+  CComPtr<IDxcRewriter> m_pRewriter;
+  dxc::DxcDllSupport m_dllSupport;
+};
+
+HRESULT NoFuncBodyRewriter::RewriteToNoFuncBody(
+    LPCWSTR pFilename, IDxcBlobEncoding *pSource,
+    std::vector<DxcDefine> &m_defines, IDxcBlob **ppNoFuncBodySource) {
+  // Create header with no function body.
+  CComPtr<IDxcOperationResult> pRewriteResult;
+  IFT(m_pRewriter->RewriteUnchangedWithInclude(
+      pSource, pFilename, m_defines.data(), m_defines.size(),
+      // Don't need include handler here, include already read in
+      // RewriteIncludesToSnippet
+      nullptr,
+      RewriterOptionMask::SkipFunctionBody | RewriterOptionMask::KeepUserMacro,
+      &pRewriteResult));
+  HRESULT status;
+  if (!SUCCEEDED(pRewriteResult->GetStatus(&status)) || !SUCCEEDED(status)) {
+    CComPtr<IDxcBlobEncoding> pErr;
+    IFT(pRewriteResult->GetErrorBuffer(&pErr));
+    std::string errString =
+        std::string((char *)pErr->GetBufferPointer(), pErr->GetBufferSize());
+    IFTMSG(E_FAIL, errString);
+    return E_FAIL;
+  };
+
+  // Get result.
+  IFT(pRewriteResult->GetResult(ppNoFuncBodySource));
+  return S_OK;
+}
+} // namespace
+
 namespace {
 
 struct KeyHash {
@@ -36,19 +86,26 @@ struct KeyEqual {
 class LibCacheManagerImpl : public libshare::LibCacheManager {
 public:
   ~LibCacheManagerImpl() {}
-  HRESULT AddLibBlob(IDxcBlob *pSource, CompileInput &compiler, size_t &hash,
+  HRESULT AddLibBlob(std::string &processedHeader, const std::string &snippet,
+                     CompileInput &compiler, size_t &hash,
                      IDxcBlob **pResultLib,
-                     std::function<void(void)> compileFn) override;
-  bool GetLibBlob(IDxcBlob *pSource, CompileInput &compiler, size_t &hash,
+                     std::function<void(IDxcBlob *pSource)> compileFn) override;
+  bool GetLibBlob(std::string &processedHeader, const std::string &snippet,
+                  CompileInput &compiler, size_t &hash,
                   IDxcBlob **pResultLib) override;
   void Release() { m_libCache.clear(); }
 
 private:
-  hash_code GetHash(IDxcBlob *pSource, CompileInput &compiler);
+  hash_code GetHash(const std::string &header, const std::string &snippet,
+                    CompileInput &compiler);
   using libCacheType =
       std::unordered_map<hash_code, CComPtr<IDxcBlob>, KeyHash, KeyEqual>;
   libCacheType m_libCache;
+  using headerCacheType =
+      std::unordered_map<hash_code, std::string, KeyHash, KeyEqual>;
+  headerCacheType m_headerCache;
   std::shared_mutex m_mutex;
+  NoFuncBodyRewriter  m_rewriter;
 };
 
 static hash_code CombineWStr(hash_code hash, LPCWSTR Arg) {
@@ -57,9 +114,10 @@ static hash_code CombineWStr(hash_code hash, LPCWSTR Arg) {
   return hash_combine(hash, StringRef(pUtf8Arg.m_psz, length));
 }
 
-hash_code LibCacheManagerImpl::GetHash(IDxcBlob *pSource, CompileInput &compiler) {
-  hash_code libHash = hash_value(
-      StringRef((char *)pSource->GetBufferPointer(), pSource->GetBufferSize()));
+hash_code LibCacheManagerImpl::GetHash(const std::string &header, const std::string &snippet,
+                    CompileInput &compiler) {
+  hash_code libHash = hash_value(header);
+  libHash = hash_combine(libHash, snippet);
   // Combine compile input.
   for (auto &Arg : compiler.arguments) {
     libHash = CombineWStr(libHash, Arg);
@@ -73,31 +131,14 @@ hash_code LibCacheManagerImpl::GetHash(IDxcBlob *pSource, CompileInput &compiler
   return libHash;
 }
 
-bool LibCacheManagerImpl::GetLibBlob(IDxcBlob *pSource, CompileInput &compiler,
-                                 size_t &hash, IDxcBlob **pResultLib) {
-  if (!pSource || !pResultLib) {
-    return false;
-  }
-  // Create hash from source.
-  hash_code libHash = GetHash(pSource, compiler);
-  hash = libHash;
-  // lock
-  std::shared_lock<std::shared_mutex> lk(m_mutex);
-
-  auto it = m_libCache.find(libHash);
-  if (it != m_libCache.end()) {
-    *pResultLib = it->second;
-    return true;
-  } else {
-    return false;
-  }
-}
+using namespace hlsl;
 
-HRESULT
-LibCacheManagerImpl::AddLibBlob(IDxcBlob *pSource, CompileInput &compiler,
-                            size_t &hash, IDxcBlob **pResultLib,
-                            std::function<void(void)> compileFn) {
-  if (!pSource || !pResultLib) {
+HRESULT LibCacheManagerImpl::AddLibBlob(std::string &processedHeader,
+                                        const std::string &snippet,
+                                        CompileInput &compiler, size_t &hash,
+                                        IDxcBlob **pResultLib,
+                                        std::function<void(IDxcBlob *pSource)> compileFn) {
+  if (!pResultLib) {
     return E_FAIL;
   }
 
@@ -106,16 +147,51 @@ LibCacheManagerImpl::AddLibBlob(IDxcBlob *pSource, CompileInput &compiler,
   auto it = m_libCache.find(hash);
   if (it != m_libCache.end()) {
     *pResultLib = it->second;
+    DXASSERT(m_headerCache.count(hash), "else mismatch header and lib");
+    processedHeader = m_headerCache[hash];
     return S_OK;
   }
+  std::string shader = processedHeader + snippet;
+  CComPtr<IDxcBlobEncoding> pSource;
+  IFT(DxcCreateBlobWithEncodingOnMallocCopy(GetGlobalHeapMalloc(), shader.data(), shader.size(), CP_UTF8, &pSource));
 
-  compileFn();
+  compileFn(pSource);
 
   m_libCache[hash] = *pResultLib;
 
+  // Rewrite curHeader to remove function body.
+  CComPtr<IDxcBlob> result;
+  IFT(m_rewriter.RewriteToNoFuncBody(L"input.hlsl", pSource, compiler.defines, &result));
+  processedHeader = std::string((char *)(result)->GetBufferPointer(),
+                                   (result)->GetBufferSize());
+  m_headerCache[hash] = processedHeader;
   return S_OK;
 }
 
+bool LibCacheManagerImpl::GetLibBlob(std::string &processedHeader,
+                                     const std::string &snippet,
+                                     CompileInput &compiler, size_t &hash,
+                                     IDxcBlob **pResultLib) {
+  if (!pResultLib) {
+    return false;
+  }
+  // Create hash from source.
+  hash_code libHash = GetHash(processedHeader, snippet, compiler);
+  hash = libHash;
+  // lock
+  std::shared_lock<std::shared_mutex> lk(m_mutex);
+
+  auto it = m_libCache.find(libHash);
+  if (it != m_libCache.end()) {
+    *pResultLib = it->second;
+    DXASSERT(m_headerCache.count(libHash), "else mismatch header and lib");
+    processedHeader = m_headerCache[libHash];
+    return true;
+  } else {
+    return false;
+  }
+}
+
 LibCacheManager *GetLibCacheManagerPtr(bool bFree) {
   static std::unique_ptr<LibCacheManagerImpl> g_LibCache =
       llvm::make_unique<LibCacheManagerImpl>();

+ 7 - 8
tools/clang/tools/dxlib-sample/lib_share_compile.cpp

@@ -152,26 +152,25 @@ HRESULT CompileFromBlob(IDxcBlobEncoding *pSource, LPCWSTR pSourceName,
     LibCacheManager &libCache = LibCacheManager::GetLibCacheManager();
     IFR(CreateLinker(&linker));
     IDxcIncludeHandler * const kNoIncHandler = nullptr;
-    const auto &headers = preprocessor->GetHeaders();
+    const auto &snippets = preprocessor->GetSnippets();
+
+    std::string processedHeader = "";
     std::vector<std::wstring> hashStrList;
     std::vector<LPCWSTR> hashList;
-    for (const auto &header : headers) {
+    for (const auto &snippet : snippets) {
       CComPtr<IDxcBlob> pOutputBlob;
-      CComPtr<IDxcBlobEncoding> pSource;
-      IFT(DxcCreateBlobWithEncodingOnMallocCopy(GetGlobalHeapMalloc(), header.data(), header.size(), CP_UTF8, &pSource));
       size_t hash;
-      if (!libCache.GetLibBlob(pSource, compilerInput, hash, &pOutputBlob)) {
+      if (!libCache.GetLibBlob(processedHeader, snippet, compilerInput, hash, &pOutputBlob)) {
         // Cannot find existing blob, create from pSource.
         IDxcBlob **ppCode = &pOutputBlob;
 
-        auto compileFn = [&]() {
+        auto compileFn = [&](IDxcBlob *pSource) {
           IFT(CompileToLib(pSource, defines, kNoIncHandler, arguments,
                            ppCode, nullptr));
         };
-        libCache.AddLibBlob(pSource, compilerInput, hash, &pOutputBlob,
+        libCache.AddLibBlob(processedHeader, snippet, compilerInput, hash, &pOutputBlob,
                             compileFn);
       }
-      pSource.Detach(); // Don't keep the ownership.
       hashStrList.emplace_back(std::to_wstring(hash));
       hashList.emplace_back(hashStrList.back().c_str());
       linker->RegisterLibrary(hashList.back(), pOutputBlob);

+ 8 - 6
tools/clang/tools/dxlib-sample/lib_share_helper.h

@@ -35,11 +35,13 @@ struct CompileInput {
 class LibCacheManager {
 public:
   virtual ~LibCacheManager() {}
-  virtual HRESULT AddLibBlob(IDxcBlob *pSource, CompileInput &compiler, size_t &hash,
-                     IDxcBlob **pResultLib,
-                     std::function<void(void)> compileFn) = 0;
-  virtual bool GetLibBlob(IDxcBlob *pSource, CompileInput &compiler, size_t &hash,
-                  IDxcBlob **pResultLib) = 0;
+  virtual HRESULT AddLibBlob(std::string &header, const std::string &snippet,
+                             CompileInput &compiler, size_t &hash,
+                             IDxcBlob **pResultLib,
+                             std::function<void(IDxcBlob *pSource)> compileFn) = 0;
+  virtual bool GetLibBlob(std::string &processedHeader, const std::string &snippet,
+                          CompileInput &compiler, size_t &hash,
+                          IDxcBlob **pResultLib) = 0;
   static LibCacheManager &GetLibCacheManager();
   static void ReleaseLibCacheManager();
 };
@@ -52,7 +54,7 @@ public:
   virtual void AddIncPath(llvm::StringRef path) = 0;
   virtual HRESULT Preprocess(IDxcBlob *pSource, LPCWSTR pFilename) = 0;
 
-  virtual const std::vector<std::string> &GetHeaders() const = 0;
+  virtual const std::vector<std::string> &GetSnippets() const = 0;
   static std::unique_ptr<IncludeToLibPreprocessor>
   CreateIncludeToLibPreprocessor(IDxcIncludeHandler *handler);
 };

+ 4 - 60
tools/clang/tools/dxlib-sample/lib_share_preprocessor.cpp

@@ -12,8 +12,6 @@
 #include "dxc/Support/WinIncludes.h"
 #include "dxc/Support/Global.h"
 #include "dxc/Support/microcom.h"
-#include "dxc/dxctools.h"
-#include "dxc/Support/dxcapi.use.h"
 #include "clang/Rewrite/Frontend/Rewriters.h"
 #include "dxc/Support/FileIOHelper.h"
 
@@ -123,27 +121,22 @@ public:
 
   IncludeToLibPreprocessorImpl(IDxcIncludeHandler *handler)
       : m_pIncludeHandler(handler) {
-    if (!m_dllSupport.IsEnabled())
-      m_dllSupport.Initialize();
   }
 
   void SetupDefines(const DxcDefine *pDefines, unsigned defineCount) override;
   void AddIncPath(StringRef path) override;
   HRESULT Preprocess(IDxcBlob *pSource, LPCWSTR pFilename) override;
 
-  const std::vector<std::string> &GetHeaders() const override { return m_headers; }
+  const std::vector<std::string> &GetSnippets() const override { return m_snippets; }
 
 private:
-  HRESULT RewriteToNoFuncBody(IDxcRewriter *pRewriter, LPCWSTR pFilename,
-    std::string &Source, IDxcBlob **ppNoFuncBodySource);
   IDxcIncludeHandler *m_pIncludeHandler;
-  // Processed header content.
-  std::vector<std::string> m_headers;
+  // Snippets split by #include.
+  std::vector<std::string> m_snippets;
   // Defines.
   std::vector<std::wstring> m_defineStrs;
   std::vector<DxcDefine> m_defines;
   std::vector<std::string> m_includePathList;
-  dxc::DxcDllSupport m_dllSupport;
 };
 
 void IncludeToLibPreprocessorImpl::SetupDefines(const DxcDefine *pDefines,
@@ -170,39 +163,6 @@ void IncludeToLibPreprocessorImpl::AddIncPath(StringRef path) {
   m_includePathList.emplace_back(path);
 }
 
-HRESULT IncludeToLibPreprocessorImpl::RewriteToNoFuncBody(IDxcRewriter *pRewriter, LPCWSTR pFilename,
-    std::string &Source, IDxcBlob **ppNoFuncBodySource) {
-  // Create header with no function body.
-  CComPtr<IDxcBlobEncoding> pEncodingIncludeSource;
-  IFR(DxcCreateBlobWithEncodingOnMalloc(
-      Source.data(), GetGlobalHeapMalloc(), Source.size(),
-      CP_UTF8, &pEncodingIncludeSource));
-
-  CComPtr<IDxcOperationResult> pRewriteResult;
-  IFT(pRewriter->RewriteUnchangedWithInclude(
-      pEncodingIncludeSource, pFilename, m_defines.data(), m_defines.size(),
-      // Don't need include handler here, include already read in
-      // RewriteIncludesToSnippet
-      nullptr,
-      RewriterOptionMask::SkipFunctionBody | RewriterOptionMask::KeepUserMacro,
-      &pRewriteResult));
-  // includeSource ownes the memory.
-  pEncodingIncludeSource.Detach();
-  HRESULT status;
-  if (!SUCCEEDED(pRewriteResult->GetStatus(&status)) || !SUCCEEDED(status)) {
-    CComPtr<IDxcBlobEncoding> pErr;
-    IFT(pRewriteResult->GetErrorBuffer(&pErr));
-    std::string errString =
-        std::string((char *)pErr->GetBufferPointer(), pErr->GetBufferSize());
-    IFTMSG(E_FAIL, errString);
-    return E_FAIL;
-  };
-
-  // Get result.
-  IFT(pRewriteResult->GetResult(ppNoFuncBodySource));
-  return S_OK;
-}
-
 using namespace clang;
 static
 void SetupCompilerForRewrite(CompilerInstance &compiler,
@@ -322,23 +282,7 @@ HRESULT IncludeToLibPreprocessorImpl::Preprocess(IDxcBlob *pSource,
   // AddRef to hold incPathIncludeHandler.
   // If not, DxcArgsFileSystem will kill it.
   incPathIncludeHandler.AddRef();
-  std::vector<std::string> Snippets;
-  RewriteToSnippets(pSource, pFilename, m_defines, &incPathIncludeHandler, Snippets);
-
-  // Combine Snippets.
-  CComPtr<IDxcRewriter> pRewriter;
-  m_dllSupport.CreateInstance(CLSID_DxcRewriter, &pRewriter);
-  std::string curHeader = "";
-  for (std::string &Snippet : Snippets) {
-    curHeader = curHeader + Snippet;
-    m_headers.emplace_back(curHeader);
-    // Rewrite curHeader to remove function body.
-    CComPtr<IDxcBlob> result;
-    IFT(RewriteToNoFuncBody(pRewriter, L"input.hlsl", curHeader, &result));
-    curHeader = std::string((char *)(result)->GetBufferPointer(),
-                                   (result)->GetBufferSize());
-  }
-
+  RewriteToSnippets(pSource, pFilename, m_defines, &incPathIncludeHandler, m_snippets);
   return S_OK;
 }
 } // namespace