FileTestUtils.cpp 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. //===- FileTestUtils.cpp ---- Implementation of FileTestUtils -------------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  #include "FileTestUtils.h"

  #include <algorithm>
  #include <cstdio>
  #include <cstring>
  #include <exception>
  #include <iterator>
  #include <new>
  #include <sstream>

  #include "dxc/Support/HLSLOptions.h"
  #include "SpirvTestOptions.h"
  #include "gtest/gtest.h"
  15. namespace clang {
  16. namespace spirv {
  17. namespace utils {
  18. bool disassembleSpirvBinary(std::vector<uint32_t> &binary,
  19. std::string *generatedSpirvAsm,
  20. bool generateHeader) {
  21. spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
  22. spirvTools.SetMessageConsumer(
  23. [](spv_message_level_t, const char *, const spv_position_t &,
  24. const char *message) { fprintf(stdout, "%s\n", message); });
  25. uint32_t options = SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES;
  26. if (!generateHeader)
  27. options |= SPV_BINARY_TO_TEXT_OPTION_NO_HEADER;
  28. return spirvTools.Disassemble(binary, generatedSpirvAsm, options);
  29. }
  30. bool validateSpirvBinary(spv_target_env env, std::vector<uint32_t> &binary,
  31. bool beforeHlslLegalization, bool glLayout,
  32. bool dxLayout, bool scalarLayout,
  33. std::string *message) {
  34. spvtools::ValidatorOptions options;
  35. options.SetBeforeHlslLegalization(beforeHlslLegalization);
  36. if (dxLayout || scalarLayout) {
  37. options.SetScalarBlockLayout(true);
  38. } else if (glLayout) {
  39. // The default for spirv-val.
  40. } else {
  41. options.SetRelaxBlockLayout(true);
  42. }
  43. spvtools::SpirvTools spirvTools(env);
  44. spirvTools.SetMessageConsumer([message](spv_message_level_t, const char *,
  45. const spv_position_t &,
  46. const char *msg) {
  47. if (message)
  48. *message = msg;
  49. else
  50. fprintf(stdout, "%s\n", msg);
  51. });
  52. return spirvTools.Validate(binary.data(), binary.size(), options);
  53. }
  54. bool processRunCommandArgs(const llvm::StringRef runCommandLine,
  55. std::string *targetProfile, std::string *entryPoint,
  56. std::vector<std::string> *restArgs) {
  57. std::istringstream buf(runCommandLine);
  58. std::istream_iterator<std::string> start(buf), end;
  59. std::vector<std::string> tokens(start, end);
  60. if (tokens.size() < 3 || tokens[1].find("Run") == std::string::npos ||
  61. tokens[2].find("%dxc") == std::string::npos) {
  62. fprintf(stderr, "The only supported format is: \"// Run: %%dxc -T "
  63. "<profile> -E <entry>\"\n");
  64. return false;
  65. }
  66. std::ostringstream rest;
  67. for (uint32_t i = 3; i < tokens.size(); ++i) {
  68. if (tokens[i] == "-T" && (++i) < tokens.size())
  69. *targetProfile = tokens[i];
  70. else if (tokens[i] == "-E" && (++i) < tokens.size())
  71. *entryPoint = tokens[i];
  72. else
  73. restArgs->push_back(tokens[i]);
  74. }
  75. if (targetProfile->empty()) {
  76. fprintf(stderr, "Error: Missing target profile argument (-T).\n");
  77. return false;
  78. }
  79. // lib_6_* profile doesn't need an entryPoint
  80. if (targetProfile->c_str()[0] != 'l' && entryPoint->empty()) {
  81. fprintf(stderr, "Error: Missing entry point argument (-E).\n");
  82. return false;
  83. }
  84. return true;
  85. }
  86. void convertIDxcBlobToUint32(const CComPtr<IDxcBlob> &blob,
  87. std::vector<uint32_t> *binaryWords) {
  88. size_t num32BitWords = (blob->GetBufferSize() + 3) / 4;
  89. std::string binaryStr((char *)blob->GetBufferPointer(),
  90. blob->GetBufferSize());
  91. binaryStr.resize(num32BitWords * 4, 0);
  92. binaryWords->resize(num32BitWords, 0);
  93. memcpy(binaryWords->data(), binaryStr.data(), binaryStr.size());
  94. }
  95. std::string getAbsPathOfInputDataFile(const llvm::StringRef filename) {
  96. std::string path = clang::spirv::testOptions::inputDataDir;
  97. #ifdef _WIN32
  98. const char sep = '\\';
  99. std::replace(path.begin(), path.end(), '/', '\\');
  100. #else
  101. const char sep = '/';
  102. #endif
  103. if (path[path.size() - 1] != sep) {
  104. path = path + sep;
  105. }
  106. path += filename;
  107. return path;
  108. }
  109. bool runCompilerWithSpirvGeneration(const llvm::StringRef inputFilePath,
  110. const llvm::StringRef entryPoint,
  111. const llvm::StringRef targetProfile,
  112. const std::vector<std::string> &restArgs,
  113. std::vector<uint32_t> *generatedBinary,
  114. std::string *errorMessages) {
  115. std::wstring srcFile(inputFilePath.begin(), inputFilePath.end());
  116. std::wstring entry(entryPoint.begin(), entryPoint.end());
  117. std::wstring profile(targetProfile.begin(), targetProfile.end());
  118. std::vector<std::wstring> rest;
  119. for (const auto &arg : restArgs)
  120. rest.emplace_back(arg.begin(), arg.end());
  121. bool success = true;
  122. try {
  123. dxc::DxcDllSupport dllSupport;
  124. IFT(dllSupport.Initialize());
  125. if (hlsl::options::initHlslOptTable())
  126. throw std::bad_alloc();
  127. CComPtr<IDxcLibrary> pLibrary;
  128. CComPtr<IDxcCompiler> pCompiler;
  129. CComPtr<IDxcOperationResult> pResult;
  130. CComPtr<IDxcBlobEncoding> pSource;
  131. CComPtr<IDxcBlobEncoding> pErrorBuffer;
  132. CComPtr<IDxcBlob> pCompiledBlob;
  133. CComPtr<IDxcIncludeHandler> pIncludeHandler;
  134. HRESULT resultStatus;
  135. bool requires_opt = false;
  136. for (const auto &arg : rest)
  137. if (arg == L"-O3" || arg.substr(0, 8) == L"-Oconfig")
  138. requires_opt = true;
  139. std::vector<LPCWSTR> flags;
  140. // lib_6_* profile doesn't need an entryPoint
  141. if (profile.c_str()[0] != 'l') {
  142. flags.push_back(L"-E");
  143. flags.push_back(entry.c_str());
  144. }
  145. flags.push_back(L"-T");
  146. flags.push_back(profile.c_str());
  147. flags.push_back(L"-spirv");
  148. // Disable legalization and optimization for testing, unless the caller
  149. // wants to run a specific optimization recipe (with -Oconfig).
  150. if (!requires_opt)
  151. flags.push_back(L"-fcgl");
  152. // Disable validation. We'll run it manually.
  153. flags.push_back(L"-Vd");
  154. for (const auto &arg : rest)
  155. flags.push_back(arg.c_str());
  156. IFT(dllSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary));
  157. IFT(pLibrary->CreateBlobFromFile(srcFile.c_str(), nullptr, &pSource));
  158. IFT(pLibrary->CreateIncludeHandler(&pIncludeHandler));
  159. IFT(dllSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler));
  160. IFT(pCompiler->Compile(pSource, srcFile.c_str(), entry.c_str(),
  161. profile.c_str(), flags.data(), flags.size(), nullptr,
  162. 0, pIncludeHandler, &pResult));
  163. // Compilation is done. We can clean up the HlslOptTable.
  164. hlsl::options::cleanupHlslOptTable();
  165. // Get compilation results.
  166. IFT(pResult->GetStatus(&resultStatus));
  167. // Get diagnostics string.
  168. IFT(pResult->GetErrorBuffer(&pErrorBuffer));
  169. const std::string diagnostics((char *)pErrorBuffer->GetBufferPointer(),
  170. pErrorBuffer->GetBufferSize());
  171. *errorMessages = diagnostics;
  172. if (SUCCEEDED(resultStatus)) {
  173. CComPtr<IDxcBlobEncoding> pStdErr;
  174. IFT(pResult->GetResult(&pCompiledBlob));
  175. convertIDxcBlobToUint32(pCompiledBlob, generatedBinary);
  176. success = true;
  177. } else {
  178. success = false;
  179. }
  180. } catch (...) {
  181. // An exception has occured while running the compiler with SPIR-V
  182. // Generation
  183. success = false;
  184. }
  185. return success;
  186. }
  187. } // end namespace utils
  188. } // end namespace spirv
  189. } // end namespace clang