extract_source.cpp 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. // Copyright (c) 2023 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "extract_source.h"
  15. #include <cassert>
  16. #include <string>
  17. #include <unordered_map>
  18. #include <vector>
  19. #include "source/opt/log.h"
  20. #include "spirv-tools/libspirv.hpp"
  21. #include "spirv/unified1/spirv.hpp"
  22. #include "tools/util/cli_consumer.h"
  23. namespace {
  24. constexpr auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_6;
  25. // Extract a string literal from a given range.
  26. // Copies all the characters from `begin` to the first '\0' it encounters, while
  27. // removing escape patterns.
  28. // Not finding a '\0' before reaching `end` fails the extraction.
  29. //
  30. // Returns `true` if the extraction succeeded.
  31. // `output` value is undefined if false is returned.
  32. spv_result_t ExtractStringLiteral(const spv_position_t& loc, const char* begin,
  33. const char* end, std::string* output) {
  34. size_t sourceLength = std::distance(begin, end);
  35. std::string escapedString;
  36. escapedString.resize(sourceLength);
  37. size_t writeIndex = 0;
  38. size_t readIndex = 0;
  39. for (; readIndex < sourceLength; writeIndex++, readIndex++) {
  40. const char read = begin[readIndex];
  41. if (read == '\0') {
  42. escapedString.resize(writeIndex);
  43. output->append(escapedString);
  44. return SPV_SUCCESS;
  45. }
  46. if (read == '\\') {
  47. ++readIndex;
  48. }
  49. escapedString[writeIndex] = begin[readIndex];
  50. }
  51. spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
  52. "Missing NULL terminator for literal string.");
  53. return SPV_ERROR_INVALID_BINARY;
  54. }
  55. spv_result_t extractOpString(const spv_position_t& loc,
  56. const spv_parsed_instruction_t& instruction,
  57. std::string* output) {
  58. assert(output != nullptr);
  59. assert(instruction.opcode == spv::Op::OpString);
  60. if (instruction.num_operands != 2) {
  61. spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
  62. "Missing operands for OpString.");
  63. return SPV_ERROR_INVALID_BINARY;
  64. }
  65. const auto& operand = instruction.operands[1];
  66. const char* stringBegin =
  67. reinterpret_cast<const char*>(instruction.words + operand.offset);
  68. const char* stringEnd = reinterpret_cast<const char*>(
  69. instruction.words + operand.offset + operand.num_words);
  70. return ExtractStringLiteral(loc, stringBegin, stringEnd, output);
  71. }
  72. spv_result_t extractOpSourceContinued(
  73. const spv_position_t& loc, const spv_parsed_instruction_t& instruction,
  74. std::string* output) {
  75. assert(output != nullptr);
  76. assert(instruction.opcode == spv::Op::OpSourceContinued);
  77. if (instruction.num_operands != 1) {
  78. spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
  79. "Missing operands for OpSourceContinued.");
  80. return SPV_ERROR_INVALID_BINARY;
  81. }
  82. const auto& operand = instruction.operands[0];
  83. const char* stringBegin =
  84. reinterpret_cast<const char*>(instruction.words + operand.offset);
  85. const char* stringEnd = reinterpret_cast<const char*>(
  86. instruction.words + operand.offset + operand.num_words);
  87. return ExtractStringLiteral(loc, stringBegin, stringEnd, output);
  88. }
  89. spv_result_t extractOpSource(const spv_position_t& loc,
  90. const spv_parsed_instruction_t& instruction,
  91. spv::Id* filename, std::string* code) {
  92. assert(filename != nullptr && code != nullptr);
  93. assert(instruction.opcode == spv::Op::OpSource);
  94. // OpCode [ Source Language | Version | File (optional) | Source (optional) ]
  95. if (instruction.num_words < 3) {
  96. spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
  97. "Missing operands for OpSource.");
  98. return SPV_ERROR_INVALID_BINARY;
  99. }
  100. *filename = 0;
  101. *code = "";
  102. if (instruction.num_words < 4) {
  103. return SPV_SUCCESS;
  104. }
  105. *filename = instruction.words[3];
  106. if (instruction.num_words < 5) {
  107. return SPV_SUCCESS;
  108. }
  109. const char* stringBegin =
  110. reinterpret_cast<const char*>(instruction.words + 4);
  111. const char* stringEnd =
  112. reinterpret_cast<const char*>(instruction.words + instruction.num_words);
  113. return ExtractStringLiteral(loc, stringBegin, stringEnd, code);
  114. }
  115. } // namespace
  116. bool ExtractSourceFromModule(
  117. const std::vector<uint32_t>& binary,
  118. std::unordered_map<std::string, std::string>* output) {
  119. auto context = spvtools::SpirvTools(kDefaultEnvironment);
  120. context.SetMessageConsumer(spvtools::utils::CLIMessageConsumer);
  121. // There is nothing valuable in the header.
  122. spvtools::HeaderParser headerParser = [](const spv_endianness_t,
  123. const spv_parsed_header_t&) {
  124. return SPV_SUCCESS;
  125. };
  126. std::unordered_map<uint32_t, std::string> stringMap;
  127. std::vector<std::pair<spv::Id, std::string>> sources;
  128. spv::Op lastOpcode = spv::Op::OpMax;
  129. size_t instructionIndex = 0;
  130. spvtools::InstructionParser instructionParser =
  131. [&stringMap, &sources, &lastOpcode,
  132. &instructionIndex](const spv_parsed_instruction_t& instruction) {
  133. const spv_position_t loc = {0, 0, instructionIndex + 1};
  134. spv_result_t result = SPV_SUCCESS;
  135. if (instruction.opcode == spv::Op::OpString) {
  136. std::string content;
  137. result = extractOpString(loc, instruction, &content);
  138. if (result == SPV_SUCCESS) {
  139. stringMap.emplace(instruction.result_id, std::move(content));
  140. }
  141. } else if (instruction.opcode == spv::Op::OpSource) {
  142. spv::Id filenameId;
  143. std::string code;
  144. result = extractOpSource(loc, instruction, &filenameId, &code);
  145. if (result == SPV_SUCCESS) {
  146. sources.emplace_back(std::make_pair(filenameId, std::move(code)));
  147. }
  148. } else if (instruction.opcode == spv::Op::OpSourceContinued) {
  149. if (lastOpcode != spv::Op::OpSource) {
  150. spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
  151. "OpSourceContinued MUST follow an OpSource.");
  152. return SPV_ERROR_INVALID_BINARY;
  153. }
  154. assert(sources.size() > 0);
  155. result = extractOpSourceContinued(loc, instruction,
  156. &sources.back().second);
  157. }
  158. ++instructionIndex;
  159. lastOpcode = static_cast<spv::Op>(instruction.opcode);
  160. return result;
  161. };
  162. if (!context.Parse(binary, headerParser, instructionParser)) {
  163. return false;
  164. }
  165. std::string defaultName = "unnamed-";
  166. size_t unnamedCount = 0;
  167. for (auto & [ id, code ] : sources) {
  168. std::string filename;
  169. const auto it = stringMap.find(id);
  170. if (it == stringMap.cend() || it->second.empty()) {
  171. filename = "unnamed-" + std::to_string(unnamedCount) + ".hlsl";
  172. ++unnamedCount;
  173. } else {
  174. filename = it->second;
  175. }
  176. if (output->count(filename) != 0) {
  177. spvtools::Error(spvtools::utils::CLIMessageConsumer, "", {},
  178. "Source file name conflict.");
  179. return false;
  180. }
  181. output->insert({filename, code});
  182. }
  183. return true;
  184. }