dxclinker.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // dxclinker.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Implements the Dxil Linker. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/Support/WinIncludes.h"
  12. #include "dxc/DxilContainer/DxilContainer.h"
  13. #include "dxc/Support/ErrorCodes.h"
  14. #include "dxc/Support/Global.h"
  15. #include "dxc/Support/FileIOHelper.h"
  16. #include "dxc/Support/dxcapi.impl.h"
  17. #include "dxc/Support/microcom.h"
  18. #include "dxc/dxcapi.h"
  19. #include "dxillib.h"
  20. #include "llvm/ADT/SmallVector.h"
  21. #include <algorithm>
  22. #include "dxc/HLSL/DxilLinker.h"
  23. #include "dxc/HLSL/DxilValidation.h"
  24. #include "dxc/Support/Unicode.h"
  25. #include "dxc/Support/microcom.h"
  26. #include "dxc/dxcapi.internal.h"
  27. #include "dxcutil.h"
  28. #include "clang/Basic/Diagnostic.h"
  29. #include "llvm/Bitcode/ReaderWriter.h"
  30. #include "llvm/IR/DiagnosticPrinter.h"
  31. #include "llvm/IR/LLVMContext.h"
  32. #include "llvm/IR/Module.h"
  33. #include "llvm/Support/raw_ostream.h"
  34. #include "clang/Frontend/TextDiagnosticPrinter.h"
  35. #include "dxc/Support/HLSLOptions.h"
  36. using namespace hlsl;
  37. using namespace llvm;
  38. // This declaration is used for the locally-linked validator.
  39. HRESULT CreateDxcValidator(_In_ REFIID riid, _Out_ LPVOID *ppv);
  40. class DxcLinker : public IDxcLinker, public IDxcContainerEvent {
  41. public:
  42. DXC_MICROCOM_TM_ADDREF_RELEASE_IMPL()
  43. DXC_MICROCOM_TM_CTOR(DxcLinker)
  44. // Register a library with name to ref it later.
  45. HRESULT RegisterLibrary(
  46. _In_opt_ LPCWSTR pLibName, // Name of the library.
  47. _In_ IDxcBlob *pLib // Library to add.
  48. ) override;
  49. // Links the shader and produces a shader blob that the Direct3D runtime can
  50. // use.
  51. HRESULT STDMETHODCALLTYPE Link(
  52. _In_opt_ LPCWSTR pEntryName, // Entry point name
  53. _In_ LPCWSTR pTargetProfile, // shader profile to link
  54. _In_count_(libCount)
  55. const LPCWSTR *pLibNames, // Array of library names to link
  56. UINT32 libCount, // Number of libraries to link
  57. _In_count_(argCount)
  58. const LPCWSTR *pArguments, // Array of pointers to arguments
  59. _In_ UINT32 argCount, // Number of arguments
  60. _COM_Outptr_ IDxcOperationResult *
  61. *ppResult // Linker output status, buffer, and errors
  62. ) override;
  63. HRESULT STDMETHODCALLTYPE RegisterDxilContainerEventHandler(
  64. IDxcContainerEventsHandler *pHandler, UINT64 *pCookie) override {
  65. DxcThreadMalloc TM(m_pMalloc);
  66. DXASSERT(m_pDxcContainerEventsHandler == nullptr,
  67. "else events handler is already registered");
  68. *pCookie = 1; // Only one EventsHandler supported
  69. m_pDxcContainerEventsHandler = pHandler;
  70. return S_OK;
  71. };
  72. HRESULT STDMETHODCALLTYPE
  73. UnRegisterDxilContainerEventHandler(UINT64 cookie) override {
  74. DxcThreadMalloc TM(m_pMalloc);
  75. DXASSERT(m_pDxcContainerEventsHandler != nullptr,
  76. "else unregister should not have been called");
  77. m_pDxcContainerEventsHandler.Release();
  78. return S_OK;
  79. }
  80. HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void **ppvObject) {
  81. return DoBasicQueryInterface<IDxcLinker>(this, riid, ppvObject);
  82. }
  83. void Initialize() {
  84. UINT32 valMajor, valMinor;
  85. dxcutil::GetValidatorVersion(&valMajor, &valMinor);
  86. m_pLinker.reset(DxilLinker::CreateLinker(m_Ctx, valMajor, valMinor));
  87. }
  88. ~DxcLinker() {
  89. // Make sure DxilLinker is released before LLVMContext.
  90. m_pLinker.reset();
  91. }
  92. private:
  93. DXC_MICROCOM_TM_REF_FIELDS()
  94. LLVMContext m_Ctx;
  95. std::unique_ptr<DxilLinker> m_pLinker;
  96. CComPtr<IDxcContainerEventsHandler> m_pDxcContainerEventsHandler;
  97. std::vector<CComPtr<IDxcBlob>> m_blobs; // Keep blobs live for lazy load.
  98. };
  99. HRESULT
  100. DxcLinker::RegisterLibrary(_In_opt_ LPCWSTR pLibName, // Name of the library.
  101. _In_ IDxcBlob *pBlob // Library to add.
  102. ) {
  103. if (!pLibName || !pBlob)
  104. return E_INVALIDARG;
  105. DXASSERT(m_pLinker.get(), "else Initialize() not called or failed silently");
  106. DxcThreadMalloc TM(m_pMalloc);
  107. // Prepare UTF8-encoded versions of API values.
  108. CW2A pUtf8LibName(pLibName, CP_UTF8);
  109. // Already exist lib with same name.
  110. if (m_pLinker->HasLibNameRegistered(pUtf8LibName.m_psz))
  111. return E_INVALIDARG;
  112. try {
  113. std::unique_ptr<llvm::Module> pModule, pDebugModule;
  114. CComPtr<IMalloc> pMalloc;
  115. CComPtr<AbstractMemoryStream> pDiagStream;
  116. IFT(CoGetMalloc(1, &pMalloc));
  117. IFT(CreateMemoryStream(pMalloc, &pDiagStream));
  118. raw_stream_ostream DiagStream(pDiagStream);
  119. IFR(ValidateLoadModuleFromContainerLazy(
  120. pBlob->GetBufferPointer(), pBlob->GetBufferSize(), pModule,
  121. pDebugModule, m_Ctx, m_Ctx, DiagStream));
  122. if (m_pLinker->RegisterLib(pUtf8LibName.m_psz, std::move(pModule),
  123. std::move(pDebugModule))) {
  124. m_blobs.emplace_back(pBlob);
  125. return S_OK;
  126. } else {
  127. return E_INVALIDARG;
  128. }
  129. } catch (hlsl::Exception &) {
  130. return E_INVALIDARG;
  131. }
  132. }
  133. // Links the shader and produces a shader blob that the Direct3D runtime can
  134. // use.
  135. HRESULT STDMETHODCALLTYPE DxcLinker::Link(
  136. _In_opt_ LPCWSTR pEntryName, // Entry point name
  137. _In_ LPCWSTR pTargetProfile, // shader profile to link
  138. _In_count_(libCount)
  139. const LPCWSTR *pLibNames, // Array of library names to link
  140. UINT32 libCount, // Number of libraries to link
  141. _In_count_(argCount)
  142. const LPCWSTR *pArguments, // Array of pointers to arguments
  143. _In_ UINT32 argCount, // Number of arguments
  144. _COM_Outptr_ IDxcOperationResult *
  145. *ppResult // Linker output status, buffer, and errors
  146. ) {
  147. if (!pTargetProfile || !pLibNames || libCount == 0 || !ppResult)
  148. return E_INVALIDARG;
  149. DxcThreadMalloc TM(m_pMalloc);
  150. // Prepare UTF8-encoded versions of API values.
  151. CW2A pUtf8TargetProfile(pTargetProfile, CP_UTF8);
  152. CW2A pUtf8EntryPoint(pEntryName, CP_UTF8);
  153. CComPtr<AbstractMemoryStream> pOutputStream;
  154. // Detach previous libraries.
  155. m_pLinker->DetachAll();
  156. HRESULT hr = S_OK;
  157. try {
  158. CComPtr<IMalloc> pMalloc;
  159. CComPtr<IDxcBlob> pOutputBlob;
  160. CComPtr<AbstractMemoryStream> pDiagStream;
  161. IFT(CoGetMalloc(1, &pMalloc));
  162. IFT(CreateMemoryStream(pMalloc, &pOutputStream));
  163. // Read and validate options.
  164. int argCountInt;
  165. IFT(UIntToInt(argCount, &argCountInt));
  166. hlsl::options::MainArgs mainArgs(argCountInt,
  167. const_cast<LPCWSTR *>(pArguments), 0);
  168. hlsl::options::DxcOpts opts;
  169. CW2A pUtf8TargetProfile(pTargetProfile, CP_UTF8);
  170. // Set target profile before reading options and validate
  171. opts.TargetProfile = pUtf8TargetProfile.m_psz;
  172. bool finished;
  173. dxcutil::ReadOptsAndValidate(mainArgs, opts, pOutputStream, ppResult,
  174. finished);
  175. if (pEntryName)
  176. opts.EntryPoint = pUtf8EntryPoint.m_psz;
  177. if (finished) {
  178. return S_OK;
  179. }
  180. std::string warnings;
  181. //llvm::raw_string_ostream w(warnings);
  182. IFT(CreateMemoryStream(pMalloc, &pDiagStream));
  183. raw_stream_ostream DiagStream(pDiagStream);
  184. llvm::DiagnosticPrinterRawOStream DiagPrinter(DiagStream);
  185. PrintDiagnosticContext DiagContext(DiagPrinter);
  186. m_Ctx.setDiagnosticHandler(PrintDiagnosticContext::PrintDiagnosticHandler,
  187. &DiagContext, true);
  188. if (opts.ValVerMajor != UINT32_MAX) {
  189. m_pLinker->SetValidatorVersion(opts.ValVerMajor, opts.ValVerMinor);
  190. }
  191. // Attach libraries.
  192. bool bSuccess = true;
  193. for (unsigned i = 0; i < libCount; i++) {
  194. CW2A pUtf8LibName(pLibNames[i], CP_UTF8);
  195. bSuccess &= m_pLinker->AttachLib(pUtf8LibName.m_psz);
  196. }
  197. dxilutil::ExportMap exportMap;
  198. bSuccess = exportMap.ParseExports(opts.Exports, DiagStream);
  199. bool hasErrorOccurred = !bSuccess;
  200. if (bSuccess) {
  201. std::unique_ptr<Module> pM = m_pLinker->Link(
  202. opts.EntryPoint, pUtf8TargetProfile.m_psz, exportMap);
  203. if (pM) {
  204. const IntrusiveRefCntPtr<clang::DiagnosticIDs> Diags(
  205. new clang::DiagnosticIDs);
  206. IntrusiveRefCntPtr<clang::DiagnosticOptions> DiagOpts =
  207. new clang::DiagnosticOptions();
  208. // Construct our diagnostic client.
  209. clang::TextDiagnosticPrinter *DiagClient =
  210. new clang::TextDiagnosticPrinter(DiagStream, &*DiagOpts);
  211. clang::DiagnosticsEngine Diag(Diags, &*DiagOpts, DiagClient);
  212. raw_stream_ostream outStream(pOutputStream.p);
  213. // Create bitcode of M.
  214. WriteBitcodeToFile(pM.get(), outStream);
  215. outStream.flush();
  216. // Always save debug info. If lib has debug info, the link result will
  217. // have debug info.
  218. SerializeDxilFlags SerializeFlags =
  219. SerializeDxilFlags::IncludeDebugNamePart;
  220. // Unless we want to strip it right away, include it in the container.
  221. if (!opts.StripDebug) {
  222. SerializeFlags |= SerializeDxilFlags::IncludeDebugInfoPart;
  223. }
  224. if (opts.DebugNameForSource) {
  225. SerializeFlags |= SerializeDxilFlags::DebugNameDependOnSource;
  226. }
  227. // Validation.
  228. HRESULT valHR = S_OK;
  229. // Skip validation on lib for now.
  230. if (!opts.TargetProfile.startswith("lib_")) {
  231. valHR = dxcutil::ValidateAndAssembleToContainer(
  232. std::move(pM), pOutputBlob, pMalloc, SerializeFlags,
  233. pOutputStream,
  234. /*bDebugInfo*/ false, llvm::StringRef(), Diag);
  235. } else {
  236. dxcutil::AssembleToContainer(std::move(pM), pOutputBlob, m_pMalloc,
  237. SerializeFlags, pOutputStream);
  238. }
  239. // Callback after valid DXIL is produced
  240. if (SUCCEEDED(valHR)) {
  241. CComPtr<IDxcBlob> pTargetBlob;
  242. if (m_pDxcContainerEventsHandler != nullptr) {
  243. HRESULT hr = m_pDxcContainerEventsHandler->OnDxilContainerBuilt(
  244. pOutputBlob, &pTargetBlob);
  245. if (SUCCEEDED(hr) && pTargetBlob != nullptr) {
  246. std::swap(pOutputBlob, pTargetBlob);
  247. }
  248. }
  249. // TODO: DFCC_ShaderDebugName
  250. }
  251. hasErrorOccurred = Diag.hasErrorOccurred();
  252. } else {
  253. hasErrorOccurred = true;
  254. }
  255. }
  256. DiagStream.flush();
  257. CComPtr<IStream> pStream = pDiagStream;
  258. dxcutil::CreateOperationResultFromOutputs(pOutputBlob, pStream, warnings,
  259. hasErrorOccurred, ppResult);
  260. }
  261. CATCH_CPP_ASSIGN_HRESULT();
  262. return hr;
  263. }
  264. HRESULT CreateDxcLinker(_In_ REFIID riid, _Out_ LPVOID *ppv) {
  265. *ppv = nullptr;
  266. try {
  267. CComPtr<DxcLinker> result(DxcLinker::Alloc(DxcGetThreadMallocNoRef()));
  268. IFROOM(result.p);
  269. result->Initialize();
  270. return result.p->QueryInterface(riid, ppv);
  271. }
  272. CATCH_CPP_RETURN_HRESULT();
  273. }