DxcTestUtils.cpp 17 KB


  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxcTestUtils.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Utility function implementations for testing dxc APIs //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/Test/CompilationResult.h"
  12. #include "dxc/Test/DxcTestUtils.h"
  13. #include "dxc/Test/HlslTestUtils.h"
  14. #include "dxc/Support/HLSLOptions.h"
  15. #include "dxc/Support/Global.h"
  16. #include "llvm/ADT/StringRef.h"
  17. #include "llvm/ADT/APInt.h"
  18. #include "llvm/Support/ManagedStatic.h"
  19. #include "llvm/Support/Regex.h"
  20. #include "llvm/Support/FileSystem.h"
  21. using namespace std;
  22. using namespace hlsl_test;
  // Hand the two module-level hooks defined below to the MODULE_SETUP /
  // MODULE_CLEANUP macros (declared outside this file; presumably the test
  // framework's module fixture registration - confirm).
  23. MODULE_SETUP(TestModuleSetup);
  24. MODULE_CLEANUP(TestModuleCleanup);
  25. bool TestModuleSetup() {
  26. // Use this module-level function to set up LLVM dependencies.
  27. if (llvm::sys::fs::SetupPerThreadFileSystem())
  28. return false;
  29. if (FAILED(DxcInitThreadMalloc()))
  30. return false;
  31. DxcSetThreadMallocToDefault();
  32. if (hlsl::options::initHlslOptTable()) {
  33. return false;
  34. }
  35. return true;
  36. }
  37. bool TestModuleCleanup() {
  38. // Use this module-level function to set up LLVM dependencies.
  39. // In particular, clean up managed static allocations used by
  40. // parsing options with the LLVM library.
  41. ::hlsl::options::cleanupHlslOptTable();
  42. ::llvm::llvm_shutdown();
  43. DxcClearThreadMalloc();
  44. DxcCleanupThreadMalloc();
  45. llvm::sys::fs::CleanupPerThreadFileSystem();
  46. return true;
  47. }
  // Out-of-class definition of the static member
  // CompilationResult::DefaultHlslSupport; it is default-constructed here,
  // i.e. it starts out as an empty shared_ptr.
  48. std::shared_ptr<HlslIntellisenseSupport> CompilationResult::DefaultHlslSupport;
  49. void CheckOperationSucceeded(IDxcOperationResult *pResult, IDxcBlob **ppBlob) {
  50. HRESULT status;
  51. VERIFY_SUCCEEDED(pResult->GetStatus(&status));
  52. VERIFY_SUCCEEDED(status);
  53. VERIFY_SUCCEEDED(pResult->GetResult(ppBlob));
  54. }
  55. static bool CheckMsgs(llvm::StringRef text, llvm::ArrayRef<LPCSTR> pMsgs,
  56. bool bRegex) {
  57. const char *pStart = !text.empty() ? text.begin() : nullptr;
  58. const char *pEnd = !text.empty() ? text.end() : nullptr;
  59. for (auto pMsg : pMsgs) {
  60. if (bRegex) {
  61. llvm::Regex RE(pMsg);
  62. std::string reErrors;
  63. VERIFY_IS_TRUE(RE.isValid(reErrors));
  64. if (!RE.match(text)) {
  65. WEX::Logging::Log::Comment(WEX::Common::String().Format(
  66. L"Unable to find regex '%S' in text:\r\n%.*S", pMsg, (pEnd - pStart),
  67. pStart));
  68. VERIFY_IS_TRUE(false);
  69. }
  70. } else {
  71. const char *pMatch = std::search(pStart, pEnd, pMsg, pMsg + strlen(pMsg));
  72. if (pEnd == pMatch) {
  73. WEX::Logging::Log::Comment(WEX::Common::String().Format(
  74. L"Unable to find '%S' in text:\r\n%.*S", pMsg, (pEnd - pStart),
  75. pStart));
  76. }
  77. VERIFY_IS_FALSE(pEnd == pMatch);
  78. }
  79. }
  80. return true;
  81. }
  82. bool CheckMsgs(const LPCSTR pText, size_t TextCount, const LPCSTR *pErrorMsgs,
  83. size_t errorMsgCount, bool bRegex) {
  84. return CheckMsgs(llvm::StringRef(pText, TextCount),
  85. llvm::ArrayRef<LPCSTR>(pErrorMsgs, errorMsgCount), bRegex);
  86. }
  87. static bool CheckNotMsgs(llvm::StringRef text, llvm::ArrayRef<LPCSTR> pMsgs,
  88. bool bRegex) {
  89. const char *pStart = !text.empty() ? text.begin() : nullptr;
  90. const char *pEnd = !text.empty() ? text.end() : nullptr;
  91. for (auto pMsg : pMsgs) {
  92. if (bRegex) {
  93. llvm::Regex RE(pMsg);
  94. std::string reErrors;
  95. VERIFY_IS_TRUE(RE.isValid(reErrors));
  96. if (RE.match(text)) {
  97. WEX::Logging::Log::Comment(WEX::Common::String().Format(
  98. L"Unexpectedly found regex '%S' in text:\r\n%.*S", pMsg, (pEnd - pStart),
  99. pStart));
  100. VERIFY_IS_TRUE(false);
  101. }
  102. }
  103. else {
  104. const char *pMatch = std::search(pStart, pEnd, pMsg, pMsg + strlen(pMsg));
  105. if (pEnd != pMatch) {
  106. WEX::Logging::Log::Comment(WEX::Common::String().Format(
  107. L"Unexpectedly found '%S' in text:\r\n%.*S", pMsg, (pEnd - pStart),
  108. pStart));
  109. }
  110. VERIFY_IS_TRUE(pEnd == pMatch);
  111. }
  112. }
  113. return true;
  114. }
  115. bool CheckNotMsgs(const LPCSTR pText, size_t TextCount, const LPCSTR *pErrorMsgs,
  116. size_t errorMsgCount, bool bRegex) {
  117. return CheckNotMsgs(llvm::StringRef(pText, TextCount),
  118. llvm::ArrayRef<LPCSTR>(pErrorMsgs, errorMsgCount), bRegex);
  119. }
  120. bool CheckOperationResultMsgs(IDxcOperationResult *pResult,
  121. llvm::ArrayRef<LPCSTR> pErrorMsgs,
  122. bool maySucceedAnyway, bool bRegex) {
  123. HRESULT status;
  124. CComPtr<IDxcBlobEncoding> textBlob;
  125. if (!pResult)
  126. return true;
  127. VERIFY_SUCCEEDED(pResult->GetStatus(&status));
  128. VERIFY_SUCCEEDED(pResult->GetErrorBuffer(&textBlob));
  129. std::string textUtf8 = BlobToUtf8(textBlob);
  130. const char *pStart = !textUtf8.empty() ? textUtf8.c_str() : nullptr;
  131. const char *pEnd = !textUtf8.empty() ? pStart + textUtf8.length() : nullptr;
  132. if (pErrorMsgs.empty() || (pErrorMsgs.size() == 1 && !pErrorMsgs[0])) {
  133. if (FAILED(status) && pStart) {
  134. WEX::Logging::Log::Comment(WEX::Common::String().Format(
  135. L"Expected success but found errors\r\n%.*S", (pEnd - pStart),
  136. pStart));
  137. }
  138. VERIFY_SUCCEEDED(status);
  139. } else {
  140. if (SUCCEEDED(status) && maySucceedAnyway) {
  141. return false;
  142. }
  143. CheckMsgs(textUtf8, pErrorMsgs, bRegex);
  144. }
  145. return true;
  146. }
  147. bool CheckOperationResultMsgs(IDxcOperationResult *pResult,
  148. const LPCSTR *pErrorMsgs, size_t errorMsgCount,
  149. bool maySucceedAnyway, bool bRegex) {
  150. return CheckOperationResultMsgs(
  151. pResult, llvm::ArrayRef<LPCSTR>(pErrorMsgs, errorMsgCount),
  152. maySucceedAnyway, bRegex);
  153. }
  154. std::string DisassembleProgram(dxc::DxcDllSupport &dllSupport,
  155. IDxcBlob *pProgram) {
  156. CComPtr<IDxcCompiler> pCompiler;
  157. CComPtr<IDxcBlobEncoding> pDisassembly;
  158. if (!dllSupport.IsEnabled()) {
  159. VERIFY_SUCCEEDED(dllSupport.Initialize());
  160. }
  161. VERIFY_SUCCEEDED(dllSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler));
  162. VERIFY_SUCCEEDED(pCompiler->Disassemble(pProgram, &pDisassembly));
  163. return BlobToUtf8(pDisassembly);
  164. }
  165. void AssembleToContainer(dxc::DxcDllSupport &dllSupport, IDxcBlob *pModule,
  166. IDxcBlob **pContainer) {
  167. CComPtr<IDxcAssembler> pAssembler;
  168. CComPtr<IDxcOperationResult> pResult;
  169. VERIFY_SUCCEEDED(dllSupport.CreateInstance(CLSID_DxcAssembler, &pAssembler));
  170. VERIFY_SUCCEEDED(pAssembler->AssembleToContainer(pModule, &pResult));
  171. CheckOperationSucceeded(pResult, pContainer);
  172. }
  173. ///////////////////////////////////////////////////////////////////////////////
  174. // Helper functions to deal with passes.
  175. void SplitPassList(LPWSTR pPassesBuffer, std::vector<LPCWSTR> &passes) {
  176. while (*pPassesBuffer) {
  177. // Skip comment lines.
  178. if (*pPassesBuffer == L'#') {
  179. while (*pPassesBuffer && *pPassesBuffer != '\n' &&
  180. *pPassesBuffer != '\r') {
  181. ++pPassesBuffer;
  182. }
  183. while (*pPassesBuffer == '\n' || *pPassesBuffer == '\r') {
  184. ++pPassesBuffer;
  185. }
  186. continue;
  187. }
  188. // Every other line is an option. Find the end of the line/buffer and
  189. // terminate it.
  190. passes.push_back(pPassesBuffer);
  191. while (*pPassesBuffer && *pPassesBuffer != '\n' && *pPassesBuffer != '\r') {
  192. ++pPassesBuffer;
  193. }
  194. while (*pPassesBuffer == '\n' || *pPassesBuffer == '\r') {
  195. *pPassesBuffer = L'\0';
  196. ++pPassesBuffer;
  197. }
  198. }
  199. }
  200. std::string BlobToUtf8(_In_ IDxcBlob *pBlob) {
  201. if (!pBlob)
  202. return std::string();
  203. const UINT CP_UTF16 = 1200;
  204. CComPtr<IDxcBlobUtf8> pBlobUtf8;
  205. if (SUCCEEDED(pBlob->QueryInterface(&pBlobUtf8)))
  206. return std::string(pBlobUtf8->GetStringPointer(), pBlobUtf8->GetStringLength());
  207. CComPtr<IDxcBlobEncoding> pBlobEncoding;
  208. IFT(pBlob->QueryInterface(&pBlobEncoding));
  209. //if (FAILED(pBlob->QueryInterface(&pBlobEncoding))) {
  210. // // Assume it is already UTF-8
  211. // return std::string((const char*)pBlob->GetBufferPointer(),
  212. // pBlob->GetBufferSize());
  213. //}
  214. BOOL known;
  215. UINT32 codePage;
  216. IFT(pBlobEncoding->GetEncoding(&known, &codePage));
  217. if (!known) {
  218. throw std::runtime_error("unknown codepage for blob.");
  219. }
  220. std::string result;
  221. if (codePage == CP_UTF16) {
  222. const wchar_t* text = (const wchar_t *)pBlob->GetBufferPointer();
  223. size_t length = pBlob->GetBufferSize() / 2;
  224. if (length >= 1 && text[length-1] == L'\0')
  225. length -= 1; // Exclude null-terminator
  226. Unicode::UTF16ToUTF8String(text, length, &result);
  227. return result;
  228. } else if (codePage == CP_UTF8) {
  229. const char* text = (const char *)pBlob->GetBufferPointer();
  230. size_t length = pBlob->GetBufferSize();
  231. if (length >= 1 && text[length-1] == '\0')
  232. length -= 1; // Exclude null-terminator
  233. result.resize(length);
  234. memcpy((void *)result.data(), text, length);
  235. return result;
  236. } else {
  237. throw std::runtime_error("Unsupported codepage.");
  238. }
  239. }
  240. std::wstring BlobToUtf16(_In_ IDxcBlob *pBlob) {
  241. if (!pBlob)
  242. return std::wstring();
  243. const UINT CP_UTF16 = 1200;
  244. CComPtr<IDxcBlobUtf16> pBlobUtf16;
  245. if (SUCCEEDED(pBlob->QueryInterface(&pBlobUtf16)))
  246. return std::wstring(pBlobUtf16->GetStringPointer(), pBlobUtf16->GetStringLength());
  247. CComPtr<IDxcBlobEncoding> pBlobEncoding;
  248. IFT(pBlob->QueryInterface(&pBlobEncoding));
  249. BOOL known;
  250. UINT32 codePage;
  251. IFT(pBlobEncoding->GetEncoding(&known, &codePage));
  252. if (!known) {
  253. throw std::runtime_error("unknown codepage for blob.");
  254. }
  255. std::wstring result;
  256. if (codePage == CP_UTF16) {
  257. const wchar_t* text = (const wchar_t *)pBlob->GetBufferPointer();
  258. size_t length = pBlob->GetBufferSize() / 2;
  259. if (length >= 1 && text[length-1] == L'\0')
  260. length -= 1; // Exclude null-terminator
  261. result.resize(length);
  262. memcpy((void *)result.data(), text, length);
  263. return result;
  264. } else if (codePage == CP_UTF8) {
  265. const char* text = (const char *)pBlob->GetBufferPointer();
  266. size_t length = pBlob->GetBufferSize();
  267. if (length >= 1 && text[length-1] == '\0')
  268. length -= 1; // Exclude null-terminator
  269. Unicode::UTF8ToUTF16String(text, length, &result);
  270. return result;
  271. } else {
  272. throw std::runtime_error("Unsupported codepage.");
  273. }
  274. }
  275. void Utf8ToBlob(dxc::DxcDllSupport &dllSupport, const char *pVal,
  276. _Outptr_ IDxcBlobEncoding **ppBlob) {
  277. CComPtr<IDxcLibrary> library;
  278. IFT(dllSupport.CreateInstance(CLSID_DxcLibrary, &library));
  279. IFT(library->CreateBlobWithEncodingOnHeapCopy(pVal, strlen(pVal), CP_UTF8,
  280. ppBlob));
  281. }
  282. void MultiByteStringToBlob(dxc::DxcDllSupport &dllSupport, const std::string &val,
  283. UINT32 codePage, _Outptr_ IDxcBlobEncoding **ppBlob) {
  284. CComPtr<IDxcLibrary> library;
  285. IFT(dllSupport.CreateInstance(CLSID_DxcLibrary, &library));
  286. IFT(library->CreateBlobWithEncodingOnHeapCopy(val.data(), val.size(),
  287. codePage, ppBlob));
  288. }
  289. void MultiByteStringToBlob(dxc::DxcDllSupport &dllSupport, const std::string &val,
  290. UINT32 codePage, _Outptr_ IDxcBlob **ppBlob) {
  291. MultiByteStringToBlob(dllSupport, val, codePage, (IDxcBlobEncoding **)ppBlob);
  292. }
  293. void Utf8ToBlob(dxc::DxcDllSupport &dllSupport, const std::string &val,
  294. _Outptr_ IDxcBlobEncoding **ppBlob) {
  295. MultiByteStringToBlob(dllSupport, val, CP_UTF8, ppBlob);
  296. }
  297. void Utf8ToBlob(dxc::DxcDllSupport &dllSupport, const std::string &val,
  298. _Outptr_ IDxcBlob **ppBlob) {
  299. Utf8ToBlob(dllSupport, val, (IDxcBlobEncoding **)ppBlob);
  300. }
  301. void Utf16ToBlob(dxc::DxcDllSupport &dllSupport, const std::wstring &val,
  302. _Outptr_ IDxcBlobEncoding **ppBlob) {
  303. const UINT32 CP_UTF16 = 1200;
  304. CComPtr<IDxcLibrary> library;
  305. IFT(dllSupport.CreateInstance(CLSID_DxcLibrary, &library));
  306. IFT(library->CreateBlobWithEncodingOnHeapCopy(
  307. val.data(), val.size() * sizeof(wchar_t), CP_UTF16, ppBlob));
  308. }
  309. void Utf16ToBlob(dxc::DxcDllSupport &dllSupport, const std::wstring &val,
  310. _Outptr_ IDxcBlob **ppBlob) {
  311. Utf16ToBlob(dllSupport, val, (IDxcBlobEncoding **)ppBlob);
  312. }
  313. void VerifyCompileOK(dxc::DxcDllSupport &dllSupport, LPCSTR pText,
  314. LPWSTR pTargetProfile, LPCWSTR pArgs,
  315. _Outptr_ IDxcBlob **ppResult) {
  316. std::vector<std::wstring> argsW;
  317. std::vector<LPCWSTR> args;
  318. if (pArgs) {
  319. wistringstream argsS(pArgs);
  320. copy(istream_iterator<wstring, wchar_t>(argsS),
  321. istream_iterator<wstring, wchar_t>(), back_inserter(argsW));
  322. transform(argsW.begin(), argsW.end(), back_inserter(args),
  323. [](const wstring &w) { return w.data(); });
  324. }
  325. VerifyCompileOK(dllSupport, pText, pTargetProfile, args, ppResult);
  326. }
  327. void VerifyCompileOK(dxc::DxcDllSupport &dllSupport, LPCSTR pText,
  328. LPWSTR pTargetProfile, std::vector<LPCWSTR> &args,
  329. _Outptr_ IDxcBlob **ppResult) {
  330. CComPtr<IDxcCompiler> pCompiler;
  331. CComPtr<IDxcBlobEncoding> pSource;
  332. CComPtr<IDxcOperationResult> pResult;
  333. HRESULT hrCompile;
  334. *ppResult = nullptr;
  335. VERIFY_SUCCEEDED(dllSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler));
  336. Utf8ToBlob(dllSupport, pText, &pSource);
  337. VERIFY_SUCCEEDED(pCompiler->Compile(pSource, L"source.hlsl", L"main",
  338. pTargetProfile, args.data(), args.size(),
  339. nullptr, 0, nullptr, &pResult));
  340. VERIFY_SUCCEEDED(pResult->GetStatus(&hrCompile));
  341. VERIFY_SUCCEEDED(hrCompile);
  342. VERIFY_SUCCEEDED(pResult->GetResult(ppResult));
  343. }
  344. HRESULT GetVersion(dxc::DxcDllSupport& DllSupport, REFCLSID clsid, unsigned &Major, unsigned &Minor) {
  345. CComPtr<IUnknown> pUnk;
  346. if (SUCCEEDED(DllSupport.CreateInstance(clsid, &pUnk))) {
  347. CComPtr<IDxcVersionInfo> pVersionInfo;
  348. IFR(pUnk.QueryInterface(&pVersionInfo));
  349. IFR(pVersionInfo->GetVersion(&Major, &Minor));
  350. }
  351. return S_OK;
  352. }
  353. bool ParseTargetProfile(llvm::StringRef targetProfile, llvm::StringRef &outStage, unsigned &outMajor, unsigned &outMinor) {
  354. auto stage_model = targetProfile.split("_");
  355. auto major_minor = stage_model.second.split("_");
  356. llvm::APInt major;
  357. if (major_minor.first.getAsInteger(16, major))
  358. return false;
  359. if (major_minor.second.compare("x") == 0) {
  360. outMinor = 0xF; // indicates offline target
  361. } else {
  362. llvm::APInt minor;
  363. if (major_minor.second.getAsInteger(16, minor))
  364. return false;
  365. outMinor = (unsigned)minor.getLimitedValue();
  366. }
  367. outStage = stage_model.first;
  368. outMajor = (unsigned)major.getLimitedValue();
  369. return true;
  370. }
  371. // VersionSupportInfo Implementation
  372. VersionSupportInfo::VersionSupportInfo()
  373. : m_CompilerIsDebugBuild(false), m_InternalValidator(false), m_DxilMajor(0),
  374. m_DxilMinor(0), m_ValMajor(0), m_ValMinor(0) {}
  375. void VersionSupportInfo::Initialize(dxc::DxcDllSupport &dllSupport) {
  376. VERIFY_IS_TRUE(dllSupport.IsEnabled());
  377. // Default to Dxil 1.0 and internal Val 1.0
  378. m_DxilMajor = m_ValMajor = 1;
  379. m_DxilMinor = m_ValMinor = 0;
  380. m_InternalValidator = true;
  381. CComPtr<IDxcVersionInfo> pVersionInfo;
  382. UINT32 VersionFlags = 0;
  383. // If the following fails, we have Dxil 1.0 compiler
  384. if (SUCCEEDED(dllSupport.CreateInstance(CLSID_DxcCompiler, &pVersionInfo))) {
  385. VERIFY_SUCCEEDED(pVersionInfo->GetVersion(&m_DxilMajor, &m_DxilMinor));
  386. VERIFY_SUCCEEDED(pVersionInfo->GetFlags(&VersionFlags));
  387. m_CompilerIsDebugBuild =
  388. (VersionFlags & DxcVersionInfoFlags_Debug) ? true : false;
  389. pVersionInfo.Release();
  390. }
  391. if (SUCCEEDED(dllSupport.CreateInstance(CLSID_DxcValidator, &pVersionInfo))) {
  392. VERIFY_SUCCEEDED(pVersionInfo->GetVersion(&m_ValMajor, &m_ValMinor));
  393. VERIFY_SUCCEEDED(pVersionInfo->GetFlags(&VersionFlags));
  394. if (m_ValMinor > 0) {
  395. // flag only exists on newer validator, assume internal otherwise.
  396. m_InternalValidator =
  397. (VersionFlags & DxcVersionInfoFlags_Internal) ? true : false;
  398. } else {
  399. // With old compiler, validator is the only way to get this
  400. m_CompilerIsDebugBuild =
  401. (VersionFlags & DxcVersionInfoFlags_Debug) ? true : false;
  402. }
  403. } else {
  404. // If create instance of IDxcVersionInfo on validator failed, we have an old
  405. // validator from dxil.dll
  406. m_InternalValidator = false;
  407. }
  408. }
  409. bool VersionSupportInfo::SkipIRSensitiveTest() {
  410. // Only debug builds preserve BB names.
  411. if (!m_CompilerIsDebugBuild) {
  412. WEX::Logging::Log::Comment(
  413. L"Test skipped due to name preservation requirement.");
  414. return true;
  415. }
  416. return false;
  417. }
  418. bool VersionSupportInfo::SkipDxilVersion(unsigned major, unsigned minor) {
  419. if (m_DxilMajor < major || (m_DxilMajor == major && m_DxilMinor < minor) ||
  420. m_ValMajor < major || (m_ValMajor == major && m_ValMinor < minor)) {
  421. WEX::Logging::Log::Comment(WEX::Common::String().Format(
  422. L"Test skipped because it requires Dxil %u.%u and Validator %u.%u.",
  423. major, minor, major, minor));
  424. return true;
  425. }
  426. return false;
  427. }
  428. bool VersionSupportInfo::SkipOutOfMemoryTest() { return false; }