| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213 |
- // Copyright (c) 2023 Google LLC.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- #include "extract_source.h"
- #include <cassert>
- #include <string>
- #include <unordered_map>
- #include <vector>
- #include "source/opt/log.h"
- #include "spirv-tools/libspirv.hpp"
- #include "spirv/unified1/spirv.hpp"
- #include "tools/util/cli_consumer.h"
- namespace {
// Target environment passed to the spvtools::SpirvTools context that
// ExtractSourceFromModule uses to parse the module binary.
- constexpr auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_6;
- // Extract a string literal from a given range.
- // Copies all the characters from `begin` to the first '\0' it encounters, while
- // removing escape patterns.
- // Not finding a '\0' before reaching `end` fails the extraction.
- //
- // Returns `true` if the extraction succeeded.
- // `output` value is undefined if false is returned.
- spv_result_t ExtractStringLiteral(const spv_position_t& loc, const char* begin,
- const char* end, std::string* output) {
- size_t sourceLength = std::distance(begin, end);
- std::string escapedString;
- escapedString.resize(sourceLength);
- size_t writeIndex = 0;
- size_t readIndex = 0;
- for (; readIndex < sourceLength; writeIndex++, readIndex++) {
- const char read = begin[readIndex];
- if (read == '\0') {
- escapedString.resize(writeIndex);
- output->append(escapedString);
- return SPV_SUCCESS;
- }
- if (read == '\\') {
- ++readIndex;
- }
- escapedString[writeIndex] = begin[readIndex];
- }
- spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
- "Missing NULL terminator for literal string.");
- return SPV_ERROR_INVALID_BINARY;
- }
- spv_result_t extractOpString(const spv_position_t& loc,
- const spv_parsed_instruction_t& instruction,
- std::string* output) {
- assert(output != nullptr);
- assert(instruction.opcode == spv::Op::OpString);
- if (instruction.num_operands != 2) {
- spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
- "Missing operands for OpString.");
- return SPV_ERROR_INVALID_BINARY;
- }
- const auto& operand = instruction.operands[1];
- const char* stringBegin =
- reinterpret_cast<const char*>(instruction.words + operand.offset);
- const char* stringEnd = reinterpret_cast<const char*>(
- instruction.words + operand.offset + operand.num_words);
- return ExtractStringLiteral(loc, stringBegin, stringEnd, output);
- }
- spv_result_t extractOpSourceContinued(
- const spv_position_t& loc, const spv_parsed_instruction_t& instruction,
- std::string* output) {
- assert(output != nullptr);
- assert(instruction.opcode == spv::Op::OpSourceContinued);
- if (instruction.num_operands != 1) {
- spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
- "Missing operands for OpSourceContinued.");
- return SPV_ERROR_INVALID_BINARY;
- }
- const auto& operand = instruction.operands[0];
- const char* stringBegin =
- reinterpret_cast<const char*>(instruction.words + operand.offset);
- const char* stringEnd = reinterpret_cast<const char*>(
- instruction.words + operand.offset + operand.num_words);
- return ExtractStringLiteral(loc, stringBegin, stringEnd, output);
- }
- spv_result_t extractOpSource(const spv_position_t& loc,
- const spv_parsed_instruction_t& instruction,
- spv::Id* filename, std::string* code) {
- assert(filename != nullptr && code != nullptr);
- assert(instruction.opcode == spv::Op::OpSource);
- // OpCode [ Source Language | Version | File (optional) | Source (optional) ]
- if (instruction.num_words < 3) {
- spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
- "Missing operands for OpSource.");
- return SPV_ERROR_INVALID_BINARY;
- }
- *filename = 0;
- *code = "";
- if (instruction.num_words < 4) {
- return SPV_SUCCESS;
- }
- *filename = instruction.words[3];
- if (instruction.num_words < 5) {
- return SPV_SUCCESS;
- }
- const char* stringBegin =
- reinterpret_cast<const char*>(instruction.words + 4);
- const char* stringEnd =
- reinterpret_cast<const char*>(instruction.words + instruction.num_words);
- return ExtractStringLiteral(loc, stringBegin, stringEnd, code);
- }
- } // namespace
- bool ExtractSourceFromModule(
- const std::vector<uint32_t>& binary,
- std::unordered_map<std::string, std::string>* output) {
- auto context = spvtools::SpirvTools(kDefaultEnvironment);
- context.SetMessageConsumer(spvtools::utils::CLIMessageConsumer);
- // There is nothing valuable in the header.
- spvtools::HeaderParser headerParser = [](const spv_endianness_t,
- const spv_parsed_header_t&) {
- return SPV_SUCCESS;
- };
- std::unordered_map<uint32_t, std::string> stringMap;
- std::vector<std::pair<spv::Id, std::string>> sources;
- spv::Op lastOpcode = spv::Op::OpMax;
- size_t instructionIndex = 0;
- spvtools::InstructionParser instructionParser =
- [&stringMap, &sources, &lastOpcode,
- &instructionIndex](const spv_parsed_instruction_t& instruction) {
- const spv_position_t loc = {0, 0, instructionIndex + 1};
- spv_result_t result = SPV_SUCCESS;
- if (instruction.opcode == spv::Op::OpString) {
- std::string content;
- result = extractOpString(loc, instruction, &content);
- if (result == SPV_SUCCESS) {
- stringMap.emplace(instruction.result_id, std::move(content));
- }
- } else if (instruction.opcode == spv::Op::OpSource) {
- spv::Id filenameId;
- std::string code;
- result = extractOpSource(loc, instruction, &filenameId, &code);
- if (result == SPV_SUCCESS) {
- sources.emplace_back(std::make_pair(filenameId, std::move(code)));
- }
- } else if (instruction.opcode == spv::Op::OpSourceContinued) {
- if (lastOpcode != spv::Op::OpSource) {
- spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
- "OpSourceContinued MUST follow an OpSource.");
- return SPV_ERROR_INVALID_BINARY;
- }
- assert(sources.size() > 0);
- result = extractOpSourceContinued(loc, instruction,
- &sources.back().second);
- }
- ++instructionIndex;
- lastOpcode = static_cast<spv::Op>(instruction.opcode);
- return result;
- };
- if (!context.Parse(binary, headerParser, instructionParser)) {
- return false;
- }
- std::string defaultName = "unnamed-";
- size_t unnamedCount = 0;
- for (auto & [ id, code ] : sources) {
- std::string filename;
- const auto it = stringMap.find(id);
- if (it == stringMap.cend() || it->second.empty()) {
- filename = "unnamed-" + std::to_string(unnamedCount) + ".hlsl";
- ++unnamedCount;
- } else {
- filename = it->second;
- }
- if (output->count(filename) != 0) {
- spvtools::Error(spvtools::utils::CLIMessageConsumer, "", {},
- "Source file name conflict.");
- return false;
- }
- output->insert({filename, code});
- }
- return true;
- }
|