dxcapi.use.cpp 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // dxcapi.use.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Provides support for DXC API users. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
#include "dxc/Support/WinIncludes.h"
#include "dxc/Support/dxcapi.use.h"
#include "dxc/Support/Global.h"
#include "dxc/Support/Unicode.h"
#include "dxc/Support/FileIOHelper.h"
#include "dxc/Support/WinFunctions.h"

#include <memory>
  17. namespace dxc {
  18. #ifdef _WIN32
  19. static void TrimEOL(_Inout_z_ char *pMsg) {
  20. char *pEnd = pMsg + strlen(pMsg);
  21. --pEnd;
  22. while (pEnd > pMsg && (*pEnd == '\r' || *pEnd == '\n')) {
  23. --pEnd;
  24. }
  25. pEnd[1] = '\0';
  26. }
  27. static std::string GetWin32ErrorMessage(DWORD err) {
  28. char formattedMsg[200];
  29. DWORD formattedMsgLen =
  30. FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
  31. nullptr, err, 0, formattedMsg, _countof(formattedMsg), 0);
  32. if (formattedMsg > 0 && formattedMsgLen < _countof(formattedMsg)) {
  33. TrimEOL(formattedMsg);
  34. return std::string(formattedMsg);
  35. }
  36. return std::string();
  37. }
  38. #else
  39. static std::string GetWin32ErrorMessage(DWORD err) {
  40. // Since we use errno for handling messages, we use strerror to get the error
  41. // message.
  42. return std::string(std::strerror(err));
  43. }
  44. #endif // _WIN32
  45. void IFT_Data(HRESULT hr, LPCWSTR data) {
  46. if (SUCCEEDED(hr)) return;
  47. CW2A pData(data, CP_UTF8);
  48. std::string errMsg;
  49. if (HRESULT_IS_WIN32ERR(hr)) {
  50. DWORD err = HRESULT_AS_WIN32ERR(hr);
  51. errMsg.append(GetWin32ErrorMessage(err));
  52. if (data != nullptr) {
  53. errMsg.append(" ", 1);
  54. }
  55. }
  56. if (data != nullptr) {
  57. errMsg.append(pData);
  58. }
  59. throw ::hlsl::Exception(hr, errMsg);
  60. }
  61. void EnsureEnabled(DxcDllSupport &dxcSupport) {
  62. if (!dxcSupport.IsEnabled()) {
  63. IFT(dxcSupport.Initialize());
  64. }
  65. }
  66. void ReadFileIntoBlob(DxcDllSupport &dxcSupport, _In_ LPCWSTR pFileName,
  67. _COM_Outptr_ IDxcBlobEncoding **ppBlobEncoding) {
  68. CComPtr<IDxcLibrary> library;
  69. IFT(dxcSupport.CreateInstance(CLSID_DxcLibrary, &library));
  70. IFT_Data(library->CreateBlobFromFile(pFileName, nullptr, ppBlobEncoding),
  71. pFileName);
  72. }
  73. void WriteOperationErrorsToConsole(_In_ IDxcOperationResult *pResult,
  74. bool outputWarnings) {
  75. HRESULT status;
  76. IFT(pResult->GetStatus(&status));
  77. if (FAILED(status) || outputWarnings) {
  78. CComPtr<IDxcBlobEncoding> pErrors;
  79. IFT(pResult->GetErrorBuffer(&pErrors));
  80. if (pErrors.p != nullptr) {
  81. WriteBlobToConsole(pErrors, STD_ERROR_HANDLE);
  82. }
  83. }
  84. }
  85. void WriteOperationResultToConsole(_In_ IDxcOperationResult *pRewriteResult,
  86. bool outputWarnings) {
  87. WriteOperationErrorsToConsole(pRewriteResult, outputWarnings);
  88. CComPtr<IDxcBlob> pBlob;
  89. IFT(pRewriteResult->GetResult(&pBlob));
  90. WriteBlobToConsole(pBlob, STD_OUTPUT_HANDLE);
  91. }
  92. static void WriteUtf16NullTermToConsole(_In_opt_count_(charCount) const wchar_t *pText,
  93. DWORD streamType) {
  94. if (pText == nullptr) {
  95. return;
  96. }
  97. bool lossy; // Note: even if there was loss, print anyway
  98. std::string consoleMessage;
  99. Unicode::UTF16ToConsoleString(pText, &consoleMessage, &lossy);
  100. if (streamType == STD_OUTPUT_HANDLE) {
  101. fprintf(stdout, "%s\n", consoleMessage.c_str());
  102. }
  103. else if (streamType == STD_ERROR_HANDLE) {
  104. fprintf(stderr, "%s\n", consoleMessage.c_str());
  105. }
  106. else {
  107. throw hlsl::Exception(E_INVALIDARG);
  108. }
  109. }
  110. static HRESULT BlobToUtf8IfText(_In_opt_ IDxcBlob *pBlob, IDxcBlobUtf8 **ppBlobUtf8) {
  111. CComPtr<IDxcBlobEncoding> pBlobEncoding;
  112. if (SUCCEEDED(pBlob->QueryInterface(&pBlobEncoding))) {
  113. BOOL known;
  114. UINT32 cp = 0;
  115. IFT(pBlobEncoding->GetEncoding(&known, &cp));
  116. if (known) {
  117. return hlsl::DxcGetBlobAsUtf8(pBlob, nullptr, ppBlobUtf8);
  118. }
  119. }
  120. return S_OK;
  121. }
  122. static HRESULT BlobToUtf16IfText(_In_opt_ IDxcBlob *pBlob, IDxcBlobUtf16 **ppBlobUtf16) {
  123. CComPtr<IDxcBlobEncoding> pBlobEncoding;
  124. if (SUCCEEDED(pBlob->QueryInterface(&pBlobEncoding))) {
  125. BOOL known;
  126. UINT32 cp = 0;
  127. IFT(pBlobEncoding->GetEncoding(&known, &cp));
  128. if (known) {
  129. return hlsl::DxcGetBlobAsUtf16(pBlob, nullptr, ppBlobUtf16);
  130. }
  131. }
  132. return S_OK;
  133. }
  134. void WriteBlobToConsole(_In_opt_ IDxcBlob *pBlob, DWORD streamType) {
  135. if (pBlob == nullptr) {
  136. return;
  137. }
  138. // Try to get as UTF-16 or UTF-8
  139. BOOL known;
  140. UINT32 cp = 0;
  141. CComPtr<IDxcBlobEncoding> pBlobEncoding;
  142. IFT(pBlob->QueryInterface(&pBlobEncoding));
  143. IFT(pBlobEncoding->GetEncoding(&known, &cp));
  144. if (cp == DXC_CP_UTF16) {
  145. CComPtr<IDxcBlobUtf16> pUtf16;
  146. IFT(hlsl::DxcGetBlobAsUtf16(pBlob, nullptr, &pUtf16));
  147. WriteUtf16NullTermToConsole(pUtf16->GetStringPointer(), streamType);
  148. } else if (cp == CP_UTF8) {
  149. CComPtr<IDxcBlobUtf8> pUtf8;
  150. IFT(hlsl::DxcGetBlobAsUtf8(pBlob, nullptr, &pUtf8));
  151. WriteUtf8ToConsoleSizeT(pUtf8->GetStringPointer(), pUtf8->GetStringLength(), streamType);
  152. }
  153. }
  154. void WriteBlobToFile(_In_opt_ IDxcBlob *pBlob, _In_ LPCWSTR pFileName, _In_ UINT32 textCodePage) {
  155. if (pBlob == nullptr) {
  156. return;
  157. }
  158. CHandle file(CreateFileW(pFileName, GENERIC_WRITE, FILE_SHARE_READ, nullptr,
  159. CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, nullptr));
  160. if (file == INVALID_HANDLE_VALUE) {
  161. IFT_Data(HRESULT_FROM_WIN32(GetLastError()), pFileName);
  162. }
  163. WriteBlobToHandle(pBlob, file, pFileName, textCodePage);
  164. }
  165. void WriteBlobToHandle(_In_opt_ IDxcBlob *pBlob, _In_ HANDLE hFile, _In_opt_ LPCWSTR pFileName, _In_ UINT32 textCodePage) {
  166. if (pBlob == nullptr) {
  167. return;
  168. }
  169. LPCVOID pPtr = pBlob->GetBufferPointer();
  170. SIZE_T size = pBlob->GetBufferSize();
  171. std::string BOM;
  172. CComPtr<IDxcBlobUtf8> pBlobUtf8;
  173. CComPtr<IDxcBlobUtf16> pBlobUtf16;
  174. if (textCodePage == DXC_CP_UTF8) {
  175. IFT_Data(BlobToUtf8IfText(pBlob, &pBlobUtf8), pFileName);
  176. if (pBlobUtf8) {
  177. pPtr = pBlobUtf8->GetStringPointer();
  178. size = pBlobUtf8->GetStringLength();
  179. // TBD: Should we write UTF-8 BOM?
  180. //BOM = "\xef\xbb\xbf"; // UTF-8
  181. }
  182. } else if (textCodePage == DXC_CP_UTF16) {
  183. IFT_Data(BlobToUtf16IfText(pBlob, &pBlobUtf16), pFileName);
  184. if (pBlobUtf16) {
  185. pPtr = pBlobUtf16->GetStringPointer();
  186. size = pBlobUtf16->GetStringLength() * sizeof(wchar_t);
  187. BOM = "\xff\xfe"; // UTF-16 LE
  188. }
  189. }
  190. IFT_Data(size > (SIZE_T)UINT32_MAX ? E_OUTOFMEMORY : S_OK , pFileName);
  191. DWORD written;
  192. if (!BOM.empty()) {
  193. if (FALSE == WriteFile(hFile, BOM.data(), BOM.length(), &written, nullptr)) {
  194. IFT_Data(HRESULT_FROM_WIN32(GetLastError()), pFileName);
  195. }
  196. }
  197. if (FALSE == WriteFile(hFile, pPtr, (DWORD)size, &written, nullptr)) {
  198. IFT_Data(HRESULT_FROM_WIN32(GetLastError()), pFileName);
  199. }
  200. }
  201. void WriteUtf8ToConsole(_In_opt_count_(charCount) const char *pText,
  202. int charCount, DWORD streamType) {
  203. if (charCount == 0 || pText == nullptr) {
  204. return;
  205. }
  206. std::string resultToPrint;
  207. wchar_t *utf16Message = nullptr;
  208. size_t utf16MessageLen;
  209. Unicode::UTF8BufferToUTF16Buffer(pText, charCount, &utf16Message,
  210. &utf16MessageLen);
  211. WriteUtf16NullTermToConsole(utf16Message, streamType);
  212. delete[] utf16Message;
  213. }
  214. void WriteUtf8ToConsoleSizeT(_In_opt_count_(charCount) const char *pText,
  215. size_t charCount, DWORD streamType) {
  216. if (charCount == 0) {
  217. return;
  218. }
  219. int charCountInt = 0;
  220. IFT(SizeTToInt(charCount, &charCountInt));
  221. WriteUtf8ToConsole(pText, charCountInt, streamType);
  222. }
  223. } // namespace dxc