FileTestUtils.cpp 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. //===- FileTestUtils.cpp ---- Implementation of FileTestUtils -------------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  #include "FileTestUtils.h"
  #include <algorithm>
  #include <cstdio>
  #include <cstring>
  #include <iterator>
  #include <sstream>
  #include "dxc/Support/HLSLOptions.h"
  #include "SPIRVTestOptions.h"
  #include "gtest/gtest.h"
  15. namespace clang {
  16. namespace spirv {
  17. namespace utils {
  18. bool disassembleSpirvBinary(std::vector<uint32_t> &binary,
  19. std::string *generatedSpirvAsm,
  20. bool generateHeader) {
  21. spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
  22. spirvTools.SetMessageConsumer(
  23. [](spv_message_level_t, const char *, const spv_position_t &,
  24. const char *message) { fprintf(stdout, "%s\n", message); });
  25. uint32_t options = SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES;
  26. if (!generateHeader)
  27. options |= SPV_BINARY_TO_TEXT_OPTION_NO_HEADER;
  28. return spirvTools.Disassemble(binary, generatedSpirvAsm, options);
  29. }
  30. bool validateSpirvBinary(spv_target_env env, std::vector<uint32_t> &binary,
  31. bool relaxLogicalPointer, bool glLayout, bool dxLayout,
  32. std::string *message) {
  33. spvtools::ValidatorOptions options;
  34. options.SetRelaxLogicalPointer(relaxLogicalPointer);
  35. options.SetRelaxBlockLayout(!glLayout && !dxLayout);
  36. options.SetSkipBlockLayout(dxLayout);
  37. spvtools::SpirvTools spirvTools(env);
  38. spirvTools.SetMessageConsumer([message](spv_message_level_t, const char *,
  39. const spv_position_t &,
  40. const char *msg) {
  41. if (message)
  42. *message = msg;
  43. else
  44. fprintf(stdout, "%s\n", msg);
  45. });
  46. return spirvTools.Validate(binary.data(), binary.size(), options);
  47. }
  48. bool processRunCommandArgs(const llvm::StringRef runCommandLine,
  49. std::string *targetProfile, std::string *entryPoint,
  50. std::vector<std::string> *restArgs) {
  51. std::istringstream buf(runCommandLine);
  52. std::istream_iterator<std::string> start(buf), end;
  53. std::vector<std::string> tokens(start, end);
  54. if (tokens.size() < 3 || tokens[1].find("Run") == std::string::npos ||
  55. tokens[2].find("%dxc") == std::string::npos) {
  56. fprintf(stderr, "The only supported format is: \"// Run: %%dxc -T "
  57. "<profile> -E <entry>\"\n");
  58. return false;
  59. }
  60. std::ostringstream rest;
  61. for (uint32_t i = 3; i < tokens.size(); ++i) {
  62. if (tokens[i] == "-T" && (++i) < tokens.size())
  63. *targetProfile = tokens[i];
  64. else if (tokens[i] == "-E" && (++i) < tokens.size())
  65. *entryPoint = tokens[i];
  66. else
  67. restArgs->push_back(tokens[i]);
  68. }
  69. if (targetProfile->empty()) {
  70. fprintf(stderr, "Error: Missing target profile argument (-T).\n");
  71. return false;
  72. }
  73. if (entryPoint->empty()) {
  74. fprintf(stderr, "Error: Missing entry point argument (-E).\n");
  75. return false;
  76. }
  77. return true;
  78. }
  79. void convertIDxcBlobToUint32(const CComPtr<IDxcBlob> &blob,
  80. std::vector<uint32_t> *binaryWords) {
  81. size_t num32BitWords = (blob->GetBufferSize() + 3) / 4;
  82. std::string binaryStr((char *)blob->GetBufferPointer(),
  83. blob->GetBufferSize());
  84. binaryStr.resize(num32BitWords * 4, 0);
  85. binaryWords->resize(num32BitWords, 0);
  86. memcpy(binaryWords->data(), binaryStr.data(), binaryStr.size());
  87. }
  88. std::string getAbsPathOfInputDataFile(const llvm::StringRef filename) {
  89. std::string path = clang::spirv::testOptions::inputDataDir;
  90. #ifdef _WIN32
  91. const char sep = '\\';
  92. std::replace(path.begin(), path.end(), '/', '\\');
  93. #else
  94. const char sep = '/';
  95. #endif
  96. if (path[path.size() - 1] != sep) {
  97. path = path + sep;
  98. }
  99. path += filename;
  100. return path;
  101. }
  102. bool runCompilerWithSpirvGeneration(const llvm::StringRef inputFilePath,
  103. const llvm::StringRef entryPoint,
  104. const llvm::StringRef targetProfile,
  105. const std::vector<std::string> &restArgs,
  106. std::vector<uint32_t> *generatedBinary,
  107. std::string *errorMessages) {
  108. std::wstring srcFile(inputFilePath.begin(), inputFilePath.end());
  109. std::wstring entry(entryPoint.begin(), entryPoint.end());
  110. std::wstring profile(targetProfile.begin(), targetProfile.end());
  111. std::vector<std::wstring> rest;
  112. for (const auto &arg : restArgs)
  113. rest.emplace_back(arg.begin(), arg.end());
  114. bool success = true;
  115. try {
  116. dxc::DxcDllSupport dllSupport;
  117. IFT(dllSupport.Initialize());
  118. DxcInitThreadMalloc();
  119. if (hlsl::options::initHlslOptTable())
  120. throw std::bad_alloc();
  121. CComPtr<IDxcLibrary> pLibrary;
  122. CComPtr<IDxcCompiler> pCompiler;
  123. CComPtr<IDxcOperationResult> pResult;
  124. CComPtr<IDxcBlobEncoding> pSource;
  125. CComPtr<IDxcBlobEncoding> pErrorBuffer;
  126. CComPtr<IDxcBlob> pCompiledBlob;
  127. CComPtr<IDxcIncludeHandler> pIncludeHandler;
  128. HRESULT resultStatus;
  129. bool requires_opt = false;
  130. for (const auto &arg : rest)
  131. if (arg == L"-O3" || arg.substr(0, 8) == L"-Oconfig")
  132. requires_opt = true;
  133. std::vector<LPCWSTR> flags;
  134. flags.push_back(L"-E");
  135. flags.push_back(entry.c_str());
  136. flags.push_back(L"-T");
  137. flags.push_back(profile.c_str());
  138. flags.push_back(L"-spirv");
  139. // Disable legalization and optimization for testing, unless the caller
  140. // wants to run a specific optimization recipe (with -Oconfig).
  141. if (!requires_opt)
  142. flags.push_back(L"-fcgl");
  143. // Disable validation. We'll run it manually.
  144. flags.push_back(L"-Vd");
  145. for (const auto &arg : rest)
  146. flags.push_back(arg.c_str());
  147. IFT(dllSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary));
  148. IFT(pLibrary->CreateBlobFromFile(srcFile.c_str(), nullptr, &pSource));
  149. IFT(pLibrary->CreateIncludeHandler(&pIncludeHandler));
  150. IFT(dllSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler));
  151. IFT(pCompiler->Compile(pSource, srcFile.c_str(), entry.c_str(),
  152. profile.c_str(), flags.data(), flags.size(), nullptr,
  153. 0, pIncludeHandler, &pResult));
  154. // Compilation is done. We can clean up the HlslOptTable.
  155. hlsl::options::cleanupHlslOptTable();
  156. // Get compilation results.
  157. IFT(pResult->GetStatus(&resultStatus));
  158. // Get diagnostics string.
  159. IFT(pResult->GetErrorBuffer(&pErrorBuffer));
  160. const std::string diagnostics((char *)pErrorBuffer->GetBufferPointer(),
  161. pErrorBuffer->GetBufferSize());
  162. *errorMessages = diagnostics;
  163. if (SUCCEEDED(resultStatus)) {
  164. CComPtr<IDxcBlobEncoding> pStdErr;
  165. IFT(pResult->GetResult(&pCompiledBlob));
  166. convertIDxcBlobToUint32(pCompiledBlob, generatedBinary);
  167. success = true;
  168. } else {
  169. success = false;
  170. }
  171. } catch (...) {
  172. // An exception has occured while running the compiler with SPIR-V
  173. // Generation
  174. success = false;
  175. }
  176. DxcCleanupThreadMalloc();
  177. return success;
  178. }
  179. } // end namespace utils
  180. } // end namespace spirv
  181. } // end namespace clang