Browse Source

Merged PR 23: Merge user/xiagli/demangel2 to user/texr/rt-merge-rebase

Xiang_Li (XBox) 7 years ago
parent
commit
ab87235284

+ 4 - 2
include/dxc/HLSL/DxilLinker.h

@@ -14,6 +14,7 @@
 #include <unordered_map>
 #include <unordered_set>
 #include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringMap.h"
 #include <memory>
 #include "llvm/Support/ErrorOr.h"
 
@@ -43,8 +44,9 @@ public:
   virtual bool DetachLib(llvm::StringRef name) = 0;
   virtual void DetachAll() = 0;
 
-  virtual std::unique_ptr<llvm::Module> Link(llvm::StringRef entry,
-                                             llvm::StringRef profile) = 0;
+  virtual std::unique_ptr<llvm::Module>
+  Link(llvm::StringRef entry, llvm::StringRef profile,
+       llvm::StringMap<llvm::StringRef> &exportMap) = 0;
 
 protected:
   DxilLinker(llvm::LLVMContext &Ctx, unsigned valMajor, unsigned valMinor) : m_ctx(Ctx), m_valMajor(valMajor), m_valMinor(valMinor) {}

+ 2 - 0
include/dxc/HLSL/DxilUtil.h

@@ -38,6 +38,8 @@ namespace dxilutil {
   bool RemoveUnusedFunctions(llvm::Module &M, llvm::Function *EntryFunc,
                              llvm::Function *PatchConstantFunc, bool IsLib);
   void EmitResMappingError(llvm::Instruction *Res);
+  // Simple demangle just support case "\01?name@" pattern.
+  llvm::StringRef DemangleFunctionName(llvm::StringRef name);
   // Change select/phi on operands into select/phi on operation.
   // phi0 = phi a0, b0, c0
   // phi1 = phi a1, b1, c1

+ 16 - 0
include/dxc/dxcapi.h

@@ -220,6 +220,22 @@ public:
       _COM_Outptr_ IDxcOperationResult *
           *ppResult // Linker output status, buffer, and errors
   ) = 0;
+  // Links the shader with export and produces a shader blob that the Direct3D
+  // runtime can use.
+  virtual HRESULT STDMETHODCALLTYPE LinkWithExports(
+      _In_opt_ LPCWSTR pEntryName, // Entry point name
+      _In_ LPCWSTR pTargetProfile, // shader profile to link
+      _In_count_(libCount)
+          const LPCWSTR *pLibNames, // Array of library names to link
+      UINT32 libCount,              // Number of libraries to link
+      _In_count_(argCount)
+          const LPCWSTR *pArguments, // Array of pointers to arguments
+      _In_ UINT32 argCount,          // Number of arguments
+      _In_count_(exportCount) const DxcDefine *pExports, // Array of exports
+      _In_ UINT32 exportCount,                           // Number of exports
+      _COM_Outptr_ IDxcOperationResult *
+          *ppResult // Linker output status, buffer, and errors
+      ) = 0;
 };
 
 static const UINT32 DxcValidatorFlags_Default = 0;

+ 65 - 18
lib/HLSL/DxilLinker.cpp

@@ -14,6 +14,7 @@
 #include "dxc/HLSL/DxilOperations.h"
 #include "dxc/HLSL/DxilResource.h"
 #include "dxc/HLSL/DxilSampler.h"
+#include "dxc/HLSL/DxilUtil.h"
 #include "dxc/Support/Global.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/DenseSet.h"
@@ -135,8 +136,9 @@ public:
   bool DetachLib(StringRef name) override;
   void DetachAll() override;
 
-  std::unique_ptr<llvm::Module> Link(StringRef entry,
-                                     StringRef profile) override;
+  std::unique_ptr<llvm::Module>
+  Link(StringRef entry, StringRef profile,
+       llvm::StringMap<llvm::StringRef> &exportMap) override;
 
 private:
   bool AttachLib(DxilLib *lib);
@@ -314,7 +316,10 @@ DxilResourceBase *DxilLib::GetResource(const llvm::Constant *GV) {
 namespace {
 // Create module from link defines.
 struct DxilLinkJob {
-  DxilLinkJob(LLVMContext &Ctx, unsigned valMajor, unsigned valMinor) : m_ctx(Ctx), m_valMajor(valMajor), m_valMinor(valMinor) {}
+  DxilLinkJob(LLVMContext &Ctx, llvm::StringMap<llvm::StringRef> &exportMap,
+              unsigned valMajor, unsigned valMinor)
+      : m_ctx(Ctx), m_exportMap(exportMap), m_valMajor(valMajor),
+        m_valMinor(valMinor) {}
   std::unique_ptr<llvm::Module>
   Link(std::pair<DxilFunctionLinkInfo *, DxilLib *> &entryLinkPair,
        const ShaderModel *pSM);
@@ -342,6 +347,7 @@ private:
   llvm::StringMap<std::pair<DxilResourceBase *, llvm::GlobalVariable *>>
       m_resourceMap;
   LLVMContext &m_ctx;
+  llvm::StringMap<llvm::StringRef> &m_exportMap;
   unsigned m_valMajor, m_valMinor;
 };
 } // namespace
@@ -351,6 +357,7 @@ const char kUndefFunction[] = "Cannot find definition of function ";
 const char kRedefineFunction[] = "Definition already exists for function ";
 const char kRedefineGlobal[] = "Definition already exists for global variable ";
 const char kInvalidProfile[] = " is invalid profile to link";
+const char kExportOnlyForLib[] = "export map is only for library";
 const char kShaderKindMismatch[] =
     "Profile mismatch between entry function and target profile:";
 const char kNoEntryProps[] =
@@ -841,6 +848,23 @@ DxilLinkJob::LinkToLib(const ShaderModel *pSM) {
 
   RunPreparePass(*pM);
 
+  if (!m_exportMap.empty()) {
+    DM.ClearDxilMetadata(*pM);
+    for (auto it = pM->begin(); it != pM->end();) {
+      Function *F = it++;
+      if (F->isDeclaration())
+        continue;
+      StringRef name = F->getName();
+      name = dxilutil::DemangleFunctionName(name);
+      // Remove Function not in exportMap.
+      if (m_exportMap.find(name) == m_exportMap.end()) {
+        DM.RemoveFunction(F);
+        F->eraseFromParent();
+      }
+    }
+    DM.EmitDxilMetadata();
+  }
+
   return pM;
 }
 
@@ -1039,8 +1063,9 @@ bool DxilLinkerImpl::AddFunctions(SmallVector<StringRef, 4> &workList,
   return true;
 }
 
-std::unique_ptr<llvm::Module> DxilLinkerImpl::Link(StringRef entry,
-                                               StringRef profile) {
+std::unique_ptr<llvm::Module>
+DxilLinkerImpl::Link(StringRef entry, StringRef profile,
+                     llvm::StringMap<llvm::StringRef> &exportMap) {
   const ShaderModel *pSM = ShaderModel::GetByName(profile.data());
   DXIL::ShaderKind kind = pSM->GetKind();
   if (kind == DXIL::ShaderKind::Invalid ||
@@ -1051,6 +1076,11 @@ std::unique_ptr<llvm::Module> DxilLinkerImpl::Link(StringRef entry,
     return nullptr;
   }
 
+  if (!exportMap.empty() && kind != DXIL::ShaderKind::Library) {
+    m_ctx.emitError(Twine(kExportOnlyForLib));
+    return nullptr;
+  }
+
   // Skip validation for lib target until implemented.
   if (!pSM->IsLib()) {
     // Verifying validator version supports the requested profile
@@ -1063,7 +1093,7 @@ std::unique_ptr<llvm::Module> DxilLinkerImpl::Link(StringRef entry,
     }
   }
 
-  DxilLinkJob linkJob(m_ctx, m_valMajor, m_valMinor);
+  DxilLinkJob linkJob(m_ctx, exportMap, m_valMajor, m_valMinor);
 
   DenseSet<DxilLib *> libSet;
   StringSet<> addedFunctionSet;
@@ -1078,21 +1108,38 @@ std::unique_ptr<llvm::Module> DxilLinkerImpl::Link(StringRef entry,
       return nullptr;
 
   } else {
-    // Add every function for lib profile.
-    for (auto &it : m_functionNameMap) {
-      StringRef name = it.getKey();
-      std::pair<DxilFunctionLinkInfo *, DxilLib *> &linkPair = it.second;
-      DxilFunctionLinkInfo *linkInfo = linkPair.first;
-      DxilLib *pLib = linkPair.second;
-
-      Function *F = linkInfo->func;
-      pLib->LazyLoadFunction(F);
+    if (exportMap.empty()) {
+      // Add every function for lib profile.
+      for (auto &it : m_functionNameMap) {
+        StringRef name = it.getKey();
+        std::pair<DxilFunctionLinkInfo *, DxilLib *> &linkPair = it.second;
+        DxilFunctionLinkInfo *linkInfo = linkPair.first;
+        DxilLib *pLib = linkPair.second;
+
+        Function *F = linkInfo->func;
+        pLib->LazyLoadFunction(F);
+
+        linkJob.AddFunction(linkPair);
 
-      linkJob.AddFunction(linkPair);
+        libSet.insert(pLib);
 
-      libSet.insert(pLib);
+        addedFunctionSet.insert(name);
+      }
+    } else {
+      SmallVector<StringRef, 4> workList;
+
+      // Only add exported functions.
+      for (auto &it : m_functionNameMap) {
+        StringRef name = it.getKey();
+        StringRef demangledName = dxilutil::DemangleFunctionName(name);
+        // Only add names exist in exportMap.
+        if (exportMap.find(demangledName) != exportMap.end())
+          workList.emplace_back(name);
+      }
 
-      addedFunctionSet.insert(name);
+      if (!AddFunctions(workList, libSet, addedFunctionSet, linkJob,
+                        /*bLazyLoadDone*/ false))
+        return nullptr;
     }
     // Add every dxil functions and llvm intrinsic.
     for (auto *pLib : libSet) {

+ 13 - 0
lib/HLSL/DxilUtil.cpp

@@ -24,6 +24,7 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/IRBuilder.h"
+#include "dxc/Support/Global.h"
 
 using namespace llvm;
 using namespace hlsl;
@@ -127,6 +128,18 @@ void PrintDiagnosticHandler(const llvm::DiagnosticInfo &DI, void *Context) {
   DI.print(*printer);
 }
 
+StringRef DemangleFunctionName(StringRef name) {
+  if (!name.startswith("\01?")) {
+    // Name don't mangled.
+    return name;
+  }
+
+  size_t nameEnd = name.find_first_of("@");
+  DXASSERT(nameEnd != StringRef::npos, "else Name don't mangled but has \01?");
+
+  return name.substr(2, nameEnd - 2);
+}
+
 std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::MemoryBuffer *MB,
   llvm::LLVMContext &Ctx,
   std::string &DiagStr) {

+ 50 - 4
tools/clang/tools/dxcompiler/dxclinker.cpp

@@ -69,6 +69,23 @@ public:
           *ppResult // Linker output status, buffer, and errors
   );
 
+  // Links the shader with export and produces a shader blob that the Direct3D
+  // runtime can use.
+  __override HRESULT STDMETHODCALLTYPE LinkWithExports(
+      _In_opt_ LPCWSTR pEntryName, // Entry point name
+      _In_ LPCWSTR pTargetProfile, // shader profile to link
+      _In_count_(libCount)
+          const LPCWSTR *pLibNames, // Array of library names to link
+      UINT32 libCount,              // Number of libraries to link
+      _In_count_(argCount)
+          const LPCWSTR *pArguments, // Array of pointers to arguments
+      _In_ UINT32 argCount,          // Number of arguments
+      _In_count_(exportCount) const DxcDefine *pExports, // Array of exports
+      _In_ UINT32 exportCount,                           // Number of exports
+      _COM_Outptr_ IDxcOperationResult *
+          *ppResult // Linker output status, buffer, and errors
+      );
+
   __override HRESULT STDMETHODCALLTYPE RegisterDxilContainerEventHandler(
       IDxcContainerEventsHandler *pHandler, UINT64 *pCookie) {
     DxcThreadMalloc TM(m_pMalloc);
@@ -149,8 +166,6 @@ DxcLinker::RegisterLibrary(_In_opt_ LPCWSTR pLibName, // Name of the library.
   }
 }
 
-// Links the shader and produces a shader blob that the Direct3D runtime can
-// use.
 HRESULT STDMETHODCALLTYPE DxcLinker::Link(
     _In_opt_ LPCWSTR pEntryName, // Entry point name
     _In_ LPCWSTR pTargetProfile, // shader profile to link
@@ -162,6 +177,27 @@ HRESULT STDMETHODCALLTYPE DxcLinker::Link(
     _In_ UINT32 argCount,          // Number of arguments
     _COM_Outptr_ IDxcOperationResult *
         *ppResult // Linker output status, buffer, and errors
+) {
+  return LinkWithExports(pEntryName, pTargetProfile, pLibNames, libCount,
+                         pArguments, argCount, /*pExorts*/ nullptr,
+                         /*exportCount*/ 0, ppResult);
+}
+
+// Links the shader with export and produces a shader blob that the Direct3D
+// runtime can use.
+__override HRESULT STDMETHODCALLTYPE DxcLinker::LinkWithExports(
+    _In_opt_ LPCWSTR pEntryName, // Entry point name
+    _In_ LPCWSTR pTargetProfile, // shader profile to link
+    _In_count_(libCount)
+        const LPCWSTR *pLibNames, // Array of library names to link
+    UINT32 libCount,              // Number of libraries to link
+    _In_count_(argCount)
+        const LPCWSTR *pArguments, // Array of pointers to arguments
+    _In_ UINT32 argCount,          // Number of arguments
+    _In_count_(exportCount) const DxcDefine *pExports, // Array of exports
+    _In_ UINT32 exportCount,                           // Number of exports
+    _COM_Outptr_ IDxcOperationResult *
+        *ppResult // Linker output status, buffer, and errors
 ) {
   DxcThreadMalloc TM(m_pMalloc);
   // Prepare UTF8-encoded versions of API values.
@@ -194,6 +230,8 @@ HRESULT STDMETHODCALLTYPE DxcLinker::Link(
     bool finished;
     dxcutil::ReadOptsAndValidate(mainArgs, opts, pOutputStream, ppResult,
                                  finished);
+    if (pEntryName)
+      opts.EntryPoint = pUtf8EntryPoint.m_psz;
     if (finished) {
       return S_OK;
     }
@@ -216,8 +254,16 @@ HRESULT STDMETHODCALLTYPE DxcLinker::Link(
 
     bool hasErrorOccurred = !bSuccess;
     if (bSuccess) {
-      std::unique_ptr<Module> pM =
-          m_pLinker->Link(pUtf8EntryPoint.m_psz, pUtf8TargetProfile.m_psz);
+      StringMap<StringRef> exportMap;
+      std::vector<std::string> names(exportCount);
+      for (unsigned i=0;i<exportCount;i++) {
+        const DxcDefine &pExport = pExports[i];
+        names[i] = CW2A(pExport.Name);
+        exportMap[names[i]] = "";
+      }
+
+      std::unique_ptr<Module> pM = m_pLinker->Link(
+          opts.EntryPoint, pUtf8TargetProfile.m_psz, exportMap);
       if (pM) {
         const IntrusiveRefCntPtr<clang::DiagnosticIDs> Diags(
             new clang::DiagnosticIDs);

+ 46 - 0
tools/clang/unittests/HLSL/LinkerTest.cpp

@@ -46,6 +46,7 @@ public:
   TEST_METHOD(RunLinkNoAlloca);
   TEST_METHOD(RunLinkResRet);
   TEST_METHOD(RunLinkToLib);
+  TEST_METHOD(RunLinkToLibExport);
   TEST_METHOD(RunLinkFailReDefineGlobal);
   TEST_METHOD(RunLinkFailProfileMismatch);
   TEST_METHOD(RunLinkFailEntryNoProps);
@@ -114,6 +115,30 @@ public:
     }
   }
 
+  void LinkWithExports(IDxcLinker *pLinker, ArrayRef<LPCWSTR> libNames,
+                       ArrayRef<DxcDefine> exportNames,
+                       llvm::ArrayRef<LPCSTR> pCheckMsgs,
+                       llvm::ArrayRef<LPCSTR> pCheckNotMsgs) {
+    CComPtr<IDxcOperationResult> pResult;
+    VERIFY_SUCCEEDED(pLinker->LinkWithExports(
+        /*pEntryName*/ nullptr, /*pShaderModel*/ L"lib_6_2", libNames.data(),
+        libNames.size(), nullptr, 0, exportNames.data(), exportNames.size(),
+        &pResult));
+    CComPtr<IDxcBlob> pProgram;
+    CheckOperationSucceeded(pResult, &pProgram);
+
+    CComPtr<IDxcCompiler> pCompiler;
+    CComPtr<IDxcBlobEncoding> pDisassembly;
+
+    VERIFY_SUCCEEDED(
+        m_dllSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler));
+    VERIFY_SUCCEEDED(pCompiler->Disassemble(pProgram, &pDisassembly));
+    std::string IR = BlobToUtf8(pDisassembly);
+    CheckMsgs(IR.c_str(), IR.size(), pCheckMsgs.data(), pCheckMsgs.size(), false);
+    for (auto notMsg : pCheckNotMsgs) {
+      VERIFY_IS_TRUE(IR.find(notMsg) == std::string::npos);
+    }
+  }
   void LinkCheckMsg(LPCWSTR pEntryName, LPCWSTR pShaderModel, IDxcLinker *pLinker,
             ArrayRef<LPCWSTR> libNames, llvm::ArrayRef<LPCSTR> pErrorMsgs) {
     CComPtr<IDxcOperationResult> pResult;
@@ -328,6 +353,27 @@ TEST_F(LinkerTest, RunLinkToLib) {
   Link(L"", L"lib_6_2", pLinker, {libName, libName2}, {"!llvm.dbg.cu"}, {});
 }
 
+TEST_F(LinkerTest, RunLinkToLibExport) {
+  CComPtr<IDxcBlob> pEntryLib;
+  CompileLib(L"..\\CodeGenHLSL\\shader-compat-suite\\lib_out_param_res.hlsl",
+             &pEntryLib);
+  CComPtr<IDxcBlob> pLib;
+  CompileLib(
+      L"..\\CodeGenHLSL\\shader-compat-suite\\lib_out_param_res_imp.hlsl",
+      &pLib);
+
+  CComPtr<IDxcLinker> pLinker;
+  CreateLinker(&pLinker);
+
+  LPCWSTR libName = L"ps_main";
+  RegisterDxcModule(libName, pEntryLib, pLinker);
+
+  LPCWSTR libName2 = L"test";
+  RegisterDxcModule(libName2, pLib, pLinker);
+  DxcDefine exports[] = { {L"test", L""} };
+  LinkWithExports(pLinker, {libName, libName2}, exports, {"@\"\\01?test@@","@test"}, {"@\"\\01?GetBuf"});
+}
+
 TEST_F(LinkerTest, RunLinkFailSelectRes) {
   if (m_ver.SkipDxilVersion(1, 3)) return;
   CComPtr<IDxcBlob> pEntryLib;