markv_codec.cpp 99 KB


  1. // Copyright (c) 2017 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Contains
  15. // - SPIR-V to MARK-V encoder
  16. // - MARK-V to SPIR-V decoder
  17. //
  18. // MARK-V is a compression format for SPIR-V binaries. It strips away
  19. // non-essential information (such as result ids which can be regenerated) and
  // uses various bit reduction techniques to reduce the size of the binary.
  #include <algorithm>
  #include <cassert>
  #include <cstdint>
  #include <cstring>
  #include <functional>
  #include <iostream>
  #include <iterator>
  #include <limits>
  #include <list>
  #include <map>
  #include <memory>
  #include <numeric>
  #include <sstream>
  #include <string>
  #include <unordered_map>
  #include <unordered_set>
  #include <utility>
  #include <vector>

  #include "latest_version_glsl_std_450_header.h"
  #include "latest_version_opencl_std_header.h"
  #include "latest_version_spirv_header.h"
  #include "binary.h"
  #include "diagnostic.h"
  #include "enum_string_mapping.h"
  #include "ext_inst.h"
  #include "extensions.h"
  #include "id_descriptor.h"
  #include "instruction.h"
  #include "markv.h"
  #include "markv_model.h"
  #include "opcode.h"
  #include "operand.h"
  #include "spirv-tools/libspirv.h"
  #include "spirv_endian.h"
  #include "spirv_validator_options.h"
  #include "util/bit_stream.h"
  #include "util/huffman_codec.h"
  #include "util/move_to_front.h"
  #include "util/parse_number.h"
  #include "val/instruction.h"
  #include "val/validation_state.h"
  #include "validate.h"
  58. using libspirv::DiagnosticStream;
  59. using libspirv::IdDescriptorCollection;
  60. using libspirv::Instruction;
  61. using libspirv::ValidationState_t;
  62. using spvutils::BitReaderWord64;
  63. using spvutils::BitWriterWord64;
  64. using spvutils::HuffmanCodec;
  65. using MoveToFront = spvutils::MoveToFront<uint32_t>;
  66. using MultiMoveToFront = spvutils::MultiMoveToFront<uint32_t>;
  67. namespace spvtools {
  68. namespace {
  69. const uint32_t kSpirvMagicNumber = SpvMagicNumber;
  70. const uint32_t kMarkvMagicNumber = 0x07230303;
  // Handles for move-to-front sequences.
  //
  // Small values (below 0x10000) name individual, fixed sequences. Enumerators
  // ending in "Begin" (and kMtfIdGeneratedByOpcode) mark the base of a handle
  // space: a concrete handle is the base plus a key such as a type id, an
  // opcode or a descriptor. These spaces are meant to be 16 bits wide, except
  // the long id descriptor space, which is 32 bits wide.
  //
  // Values are spelled out explicitly so that the handle layout is obvious at
  // a glance.
  enum : uint64_t {
    kMtfNone = 0,
    // All ids.
    kMtfAll = 1,
    // All forward declared ids.
    kMtfForwardDeclared = 2,
    // All type ids except for those generated by OpTypeFunction.
    kMtfTypeNonFunction = 3,
    // All labels.
    kMtfLabel = 4,
    // All ids created by instructions which had type_id.
    kMtfObject = 5,
    // All types generated by OpTypeFloat, OpTypeInt, OpTypeBool.
    kMtfTypeScalar = 6,
    // All composite types.
    kMtfTypeComposite = 7,
    // Boolean type or any vector type of it.
    kMtfTypeBoolScalarOrVector = 8,
    // All float types or any vector of floats type.
    kMtfTypeFloatScalarOrVector = 9,
    // All int types or any vector of ints type.
    kMtfTypeIntScalarOrVector = 10,
    // All types declared as return types in OpTypeFunction.
    kMtfTypeReturnedByFunction = 11,
    // All composite objects.
    kMtfComposite = 12,
    // All bool objects or vectors of bools.
    kMtfBoolScalarOrVector = 13,
    // All float objects or vectors of floats.
    kMtfFloatScalarOrVector = 14,
    // All int objects or vectors of ints.
    kMtfIntScalarOrVector = 15,
    // All pointer types which point to composites.
    kMtfTypePointerToComposite = 16,
    // Used by EncodeMtfRankHuffman.
    kMtfGenericNonZeroRank = 17,

    // Handle space for ids of a specific type.
    kMtfIdOfTypeBegin = 0x10000,
    // Handle space for ids generated by a specific opcode.
    kMtfIdGeneratedByOpcode = 0x20000,
    // Handle space for ids of objects whose type is generated by a specific
    // opcode.
    kMtfIdWithTypeGeneratedByOpcodeBegin = 0x30000,
    // All vectors of a specific component type.
    kMtfVectorOfComponentTypeBegin = 0x40000,
    // All vector types of a specific size.
    kMtfTypeVectorOfSizeBegin = 0x50000,
    // All pointer types to a specific type.
    kMtfPointerToTypeBegin = 0x60000,
    // All function types which return a specific type.
    kMtfFunctionTypeWithReturnTypeBegin = 0x70000,
    // All function objects which return a specific type.
    kMtfFunctionWithReturnTypeBegin = 0x80000,
    // Short id descriptor space (max 16-bit).
    kMtfShortIdDescriptorSpaceBegin = 0x90000,
    // Long id descriptor space (32-bit).
    kMtfLongIdDescriptorSpaceBegin = 0x100000000ULL,
  };
  130. // Signals that the value is not in the coding scheme and a fallback method
  131. // needs to be used.
  132. const uint64_t kMarkvNoneOfTheAbove = MarkvModel::GetMarkvNoneOfTheAbove();
  133. // Mtf ranks smaller than this are encoded with Huffman coding.
  134. const uint32_t kMtfSmallestRankEncodedByValue = 10;
  135. // Signals that the mtf rank is too large to be encoded with Huffman.
  136. const uint32_t kMtfRankEncodedByValueSignal =
  137. std::numeric_limits<uint32_t>::max();
  138. const size_t kCommentNumWhitespaces = 2;
  139. const size_t kByteBreakAfterInstIfLessThanUntilNextByte = 8;
  140. const uint32_t kShortDescriptorNumBits = 8;
  141. // Custom hash function used to produce short descriptors.
  142. uint32_t ShortHashU32Array(const std::vector<uint32_t>& words) {
  143. // The hash function is a sum of hashes of each word seeded by word index.
  144. // Knuth's multiplicative hash is used to hash the words.
  145. const uint32_t kKnuthMulHash = 2654435761;
  146. uint32_t val = 0;
  147. for (uint32_t i = 0; i < words.size(); ++i) {
  148. val += (words[i] + i + 123) * kKnuthMulHash;
  149. }
  150. return 1 + val % ((1 << kShortDescriptorNumBits) - 1);
  151. }
  152. // Returns a set of mtf rank codecs based on a plausible hand-coded
  153. // distribution.
  154. std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
  155. GetMtfHuffmanCodecs() {
  156. std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>> codecs;
  157. std::unique_ptr<HuffmanCodec<uint32_t>> codec;
  158. codec.reset(new HuffmanCodec<uint32_t>(std::map<uint32_t, uint32_t>({
  159. {0, 5},
  160. {1, 40},
  161. {2, 10},
  162. {3, 5},
  163. {4, 5},
  164. {5, 5},
  165. {6, 3},
  166. {7, 3},
  167. {8, 3},
  168. {9, 3},
  169. {kMtfRankEncodedByValueSignal, 10},
  170. })));
  171. codecs.emplace(kMtfAll, std::move(codec));
  172. codec.reset(new HuffmanCodec<uint32_t>(std::map<uint32_t, uint32_t>({
  173. {1, 50},
  174. {2, 20},
  175. {3, 5},
  176. {4, 5},
  177. {5, 2},
  178. {6, 1},
  179. {7, 1},
  180. {8, 1},
  181. {9, 1},
  182. {kMtfRankEncodedByValueSignal, 10},
  183. })));
  184. codecs.emplace(kMtfGenericNonZeroRank, std::move(codec));
  185. return codecs;
  186. }
  187. // Returns true if the opcode has a fixed number of operands. May return a
  188. // false negative.
  189. bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) {
  190. switch (opcode) {
  191. // TODO([email protected]) This is not a complete list.
  192. case SpvOpNop:
  193. case SpvOpName:
  194. case SpvOpUndef:
  195. case SpvOpSizeOf:
  196. case SpvOpLine:
  197. case SpvOpNoLine:
  198. case SpvOpDecorationGroup:
  199. case SpvOpExtension:
  200. case SpvOpExtInstImport:
  201. case SpvOpMemoryModel:
  202. case SpvOpCapability:
  203. case SpvOpTypeVoid:
  204. case SpvOpTypeBool:
  205. case SpvOpTypeInt:
  206. case SpvOpTypeFloat:
  207. case SpvOpTypeVector:
  208. case SpvOpTypeMatrix:
  209. case SpvOpTypeSampler:
  210. case SpvOpTypeSampledImage:
  211. case SpvOpTypeArray:
  212. case SpvOpTypePointer:
  213. case SpvOpConstantTrue:
  214. case SpvOpConstantFalse:
  215. case SpvOpLabel:
  216. case SpvOpBranch:
  217. case SpvOpFunction:
  218. case SpvOpFunctionParameter:
  219. case SpvOpFunctionEnd:
  220. case SpvOpBitcast:
  221. case SpvOpCopyObject:
  222. case SpvOpTranspose:
  223. case SpvOpSNegate:
  224. case SpvOpFNegate:
  225. case SpvOpIAdd:
  226. case SpvOpFAdd:
  227. case SpvOpISub:
  228. case SpvOpFSub:
  229. case SpvOpIMul:
  230. case SpvOpFMul:
  231. case SpvOpUDiv:
  232. case SpvOpSDiv:
  233. case SpvOpFDiv:
  234. case SpvOpUMod:
  235. case SpvOpSRem:
  236. case SpvOpSMod:
  237. case SpvOpFRem:
  238. case SpvOpFMod:
  239. case SpvOpVectorTimesScalar:
  240. case SpvOpMatrixTimesScalar:
  241. case SpvOpVectorTimesMatrix:
  242. case SpvOpMatrixTimesVector:
  243. case SpvOpMatrixTimesMatrix:
  244. case SpvOpOuterProduct:
  245. case SpvOpDot:
  246. return true;
  247. default:
  248. break;
  249. }
  250. return false;
  251. }
  // Returns how many bits are left until the next byte boundary after bit
  // position |bit_pos|; returns 0 if |bit_pos| is already byte-aligned.
  size_t GetNumBitsToNextByte(size_t bit_pos) {
    const size_t bits_into_byte = bit_pos % 8;
    return bits_into_byte == 0 ? 0 : 8 - bits_into_byte;
  }
  // Returns the current MARK-V version, packed as (major << 16) | minor.
  uint32_t GetMarkvVersion() {
    constexpr uint32_t kMajor = 1;
    constexpr uint32_t kMinor = 4;
    return (kMajor << 16) | kMinor;
  }
  259. class MarkvLogger {
  260. public:
  261. MarkvLogger(MarkvLogConsumer log_consumer, MarkvDebugConsumer debug_consumer)
  262. : log_consumer_(log_consumer), debug_consumer_(debug_consumer) {}
  263. void AppendText(const std::string& str) {
  264. Append(str);
  265. use_delimiter_ = false;
  266. }
  267. void AppendTextNewLine(const std::string& str) {
  268. Append(str);
  269. Append("\n");
  270. use_delimiter_ = false;
  271. }
  272. void AppendBitSequence(const std::string& str) {
  273. if (debug_consumer_) instruction_bits_ << str;
  274. if (use_delimiter_) Append("-");
  275. Append(str);
  276. use_delimiter_ = true;
  277. }
  278. void AppendWhitespaces(size_t num) {
  279. Append(std::string(num, ' '));
  280. use_delimiter_ = false;
  281. }
  282. void NewLine() {
  283. Append("\n");
  284. use_delimiter_ = false;
  285. }
  286. bool DebugInstruction(const spv_parsed_instruction_t& inst) {
  287. bool result = true;
  288. if (debug_consumer_) {
  289. result = debug_consumer_(
  290. std::vector<uint32_t>(inst.words, inst.words + inst.num_words),
  291. instruction_bits_.str(), instruction_comment_.str());
  292. instruction_bits_.str(std::string());
  293. instruction_comment_.str(std::string());
  294. }
  295. return result;
  296. }
  297. private:
  298. MarkvLogger(const MarkvLogger&) = delete;
  299. MarkvLogger(MarkvLogger&&) = delete;
  300. MarkvLogger& operator=(const MarkvLogger&) = delete;
  301. MarkvLogger& operator=(MarkvLogger&&) = delete;
  302. void Append(const std::string& str) {
  303. if (log_consumer_) log_consumer_(str);
  304. if (debug_consumer_) instruction_comment_ << str;
  305. }
  306. MarkvLogConsumer log_consumer_;
  307. MarkvDebugConsumer debug_consumer_;
  308. std::stringstream instruction_bits_;
  309. std::stringstream instruction_comment_;
  310. // If true a delimiter will be appended before the next bit sequence.
  311. // Used to generate outputs like: 1100-0 1110-1-1100-1-1111-0 110-0.
  312. bool use_delimiter_ = false;
  313. };
  314. // Base class for MARK-V encoder and decoder. Contains common functionality
  315. // such as:
  316. // - Validator connection and validation state.
  317. // - SPIR-V grammar and helper functions.
  318. class MarkvCodecBase {
  319. public:
  320. virtual ~MarkvCodecBase() { spvValidatorOptionsDestroy(validator_options_); }
  321. MarkvCodecBase() = delete;
  322. protected:
  323. struct MarkvHeader {
  324. MarkvHeader() {
  325. magic_number = kMarkvMagicNumber;
  326. markv_version = GetMarkvVersion();
  327. markv_model = 0;
  328. markv_length_in_bits = 0;
  329. spirv_version = 0;
  330. spirv_generator = 0;
  331. }
  332. uint32_t magic_number;
  333. uint32_t markv_version;
  334. // Magic number to identify or verify MarkvModel used for encoding.
  335. uint32_t markv_model;
  336. uint32_t markv_length_in_bits;
  337. uint32_t spirv_version;
  338. uint32_t spirv_generator;
  339. };
  340. // |model| is owned by the caller, must be not null and valid during the
  341. // lifetime of the codec.
  342. explicit MarkvCodecBase(spv_const_context context,
  343. spv_validator_options validator_options,
  344. const MarkvModel* model)
  345. : validator_options_(validator_options),
  346. grammar_(context),
  347. model_(model),
  348. short_id_descriptors_(ShortHashU32Array),
  349. mtf_huffman_codecs_(GetMtfHuffmanCodecs()),
  350. context_(context),
  351. vstate_(validator_options
  352. ? new ValidationState_t(context, validator_options_)
  353. : nullptr) {}
  354. // Validates a single instruction and updates validation state of the module.
  355. // Does nothing and returns SPV_SUCCESS if validator was not created.
  356. spv_result_t UpdateValidationState(const spv_parsed_instruction_t& inst) {
  357. if (!vstate_) return SPV_SUCCESS;
  358. return ValidateInstructionAndUpdateValidationState(vstate_.get(), &inst);
  359. }
  360. // Returns instruction which created |id| or nullptr if such instruction was
  361. // not registered.
  362. const Instruction* FindDef(uint32_t id) const {
  363. const auto it = id_to_def_instruction_.find(id);
  364. if (it == id_to_def_instruction_.end()) return nullptr;
  365. return it->second;
  366. }
  367. // Returns type id of vector type component.
  368. uint32_t GetVectorComponentType(uint32_t vector_type_id) const {
  369. const Instruction* type_inst = FindDef(vector_type_id);
  370. assert(type_inst);
  371. assert(type_inst->opcode() == SpvOpTypeVector);
  372. const uint32_t component_type =
  373. type_inst->word(type_inst->operands()[1].offset);
  374. return component_type;
  375. }
  376. // Returns mtf handle for ids of given type.
  377. uint64_t GetMtfIdOfType(uint32_t type_id) const {
  378. return kMtfIdOfTypeBegin + type_id;
  379. }
  380. // Returns mtf handle for ids generated by given opcode.
  381. uint64_t GetMtfIdGeneratedByOpcode(SpvOp opcode) const {
  382. return kMtfIdGeneratedByOpcode + opcode;
  383. }
  384. // Returns mtf handle for ids of type generated by given opcode.
  385. uint64_t GetMtfIdWithTypeGeneratedByOpcode(SpvOp opcode) const {
  386. return kMtfIdWithTypeGeneratedByOpcodeBegin + opcode;
  387. }
  388. // Returns mtf handle for vectors of specific component type.
  389. uint64_t GetMtfVectorOfComponentType(uint32_t type_id) const {
  390. return kMtfVectorOfComponentTypeBegin + type_id;
  391. }
  392. // Returns mtf handle for vector type of specific size.
  393. uint64_t GetMtfTypeVectorOfSize(uint32_t size) const {
  394. return kMtfTypeVectorOfSizeBegin + size;
  395. }
  396. // Returns mtf handle for pointers to specific size.
  397. uint64_t GetMtfPointerToType(uint32_t type_id) const {
  398. return kMtfPointerToTypeBegin + type_id;
  399. }
  400. // Returns mtf handle for function types with given return type.
  401. uint64_t GetMtfFunctionTypeWithReturnType(uint32_t type_id) const {
  402. return kMtfFunctionTypeWithReturnTypeBegin + type_id;
  403. }
  404. // Returns mtf handle for functions with given return type.
  405. uint64_t GetMtfFunctionWithReturnType(uint32_t type_id) const {
  406. return kMtfFunctionWithReturnTypeBegin + type_id;
  407. }
  408. // Returns mtf handle for the given long id descriptor.
  409. uint64_t GetMtfLongIdDescriptor(uint32_t descriptor) const {
  410. return kMtfLongIdDescriptorSpaceBegin + descriptor;
  411. }
  412. // Returns mtf handle for the given short id descriptor.
  413. uint64_t GetMtfShortIdDescriptor(uint32_t descriptor) const {
  414. return kMtfShortIdDescriptorSpaceBegin + descriptor;
  415. }
  416. // Process data from the current instruction. This would update MTFs and
  417. // other data containers.
  418. void ProcessCurInstruction();
  419. // Returns move-to-front handle to be used for the current operand slot.
  420. // Mtf handle is chosen based on a set of rules defined by SPIR-V grammar.
  421. uint64_t GetRuleBasedMtf();
  422. // Returns words of the current instruction. Decoder has a different
  423. // implementation and the array is valid only until the previously decoded
  424. // word.
  425. virtual const uint32_t* GetInstWords() const { return inst_.words; }
  426. // Returns the opcode of the previous instruction.
  427. SpvOp GetPrevOpcode() const {
  428. if (instructions_.empty()) return SpvOpNop;
  429. return instructions_.back()->opcode();
  430. }
  431. // Returns diagnostic stream, position index is set to instruction number.
  432. DiagnosticStream Diag(spv_result_t error_code) const {
  433. return DiagnosticStream({0, 0, instructions_.size()}, context_->consumer,
  434. error_code);
  435. }
  436. // Returns current id bound.
  437. uint32_t GetIdBound() const { return id_bound_; }
  438. // Sets current id bound, expected to be no lower than the previous one.
  439. void SetIdBound(uint32_t id_bound) {
  440. assert(id_bound >= id_bound_);
  441. id_bound_ = id_bound;
  442. if (vstate_) vstate_->setIdBound(id_bound);
  443. }
  444. // Returns Huffman codec for ranks of the mtf with given |handle|.
  445. // Different mtfs can use different rank distributions.
  446. // May return nullptr if the codec doesn't exist.
  447. const spvutils::HuffmanCodec<uint32_t>* GetMtfHuffmanCodec(
  448. uint64_t handle) const {
  449. const auto it = mtf_huffman_codecs_.find(handle);
  450. if (it == mtf_huffman_codecs_.end()) return nullptr;
  451. return it->second.get();
  452. }
  453. // Promotes id in all move-to-front sequences if ids can be shared by multiple
  454. // sequences.
  455. void PromoteIfNeeded(uint32_t id) {
  456. if (!model_->AnyDescriptorHasCodingScheme() &&
  457. model_->id_fallback_strategy() ==
  458. MarkvModel::IdFallbackStrategy::kShortDescriptor) {
  459. // Move-to-front sequences do not share ids. Nothing to do.
  460. return;
  461. }
  462. multi_mtf_.Promote(id);
  463. }
  464. spv_validator_options validator_options_ = nullptr;
  465. const libspirv::AssemblyGrammar grammar_;
  466. MarkvHeader header_;
  467. // MARK-V model, not owned.
  468. const MarkvModel* model_ = nullptr;
  469. // Current instruction, current operand and current operand index.
  470. spv_parsed_instruction_t inst_;
  471. spv_parsed_operand_t operand_;
  472. uint32_t operand_index_;
  473. // Maps a result ID to its type ID. By convention:
  474. // - a result ID that is a type definition maps to itself.
  475. // - a result ID without a type maps to 0. (E.g. for OpLabel)
  476. std::unordered_map<uint32_t, uint32_t> id_to_type_id_;
  477. // Container for all move-to-front sequences.
  478. MultiMoveToFront multi_mtf_;
  479. // Id of the current function or zero if outside of function.
  480. uint32_t cur_function_id_ = 0;
  481. // Return type of the current function.
  482. uint32_t cur_function_return_type_ = 0;
  483. // Remaining function parameter types. This container is filled on OpFunction,
  484. // and drained on OpFunctionParameter.
  485. std::list<uint32_t> remaining_function_parameter_types_;
  486. // List of ids local to the current function.
  487. std::vector<uint32_t> ids_local_to_cur_function_;
  488. // List of instructions in the order they are given in the module.
  489. std::vector<std::unique_ptr<const Instruction>> instructions_;
  490. // Container/computer for long (32-bit) id descriptors.
  491. IdDescriptorCollection long_id_descriptors_;
  492. // Container/computer for short id descriptors.
  493. // Short descriptors are stored in uint32_t, but their actual bit width is
  494. // defined with kShortDescriptorNumBits.
  495. // It doesn't seem logical to have a different computer for short id
  496. // descriptors, since one could actually map/truncate long descriptors.
  497. // But as short descriptors have collisions, the efficiency of
  498. // compression depends on the collision pattern, and short descriptors
  499. // produced by function ShortHashU32Array have been empirically proven to
  500. // produce better results.
  501. IdDescriptorCollection short_id_descriptors_;
  502. // Huffman codecs for move-to-front ranks. The map key is mtf handle. Doesn't
  503. // need to contain a different codec for every handle as most use one and the
  504. // same.
  505. std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
  506. mtf_huffman_codecs_;
  507. // If not nullptr, codec will log comments on the compression process.
  508. std::unique_ptr<MarkvLogger> logger_;
  509. private:
  510. spv_const_context context_ = nullptr;
  511. std::unique_ptr<ValidationState_t> vstate_;
  512. // Maps result id to the instruction which defined it.
  513. std::unordered_map<uint32_t, const Instruction*> id_to_def_instruction_;
  514. uint32_t id_bound_ = 1;
  515. };
  516. // SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and
  517. // EncodeInstruction which can be used as callback by spvBinaryParse.
  518. // Encoded binary is written to an internally maintained bitstream.
  519. // After the last instruction is encoded, the resulting MARK-V binary can be
  520. // acquired by calling GetMarkvBinary().
  521. // The encoder uses SPIR-V validator to keep internal state, therefore
  522. // SPIR-V binary needs to be able to pass validator checks.
  523. // CreateCommentsLogger() can be used to enable the encoder to write comments
  524. // on how encoding was done, which can later be accessed with GetComments().
  525. class MarkvEncoder : public MarkvCodecBase {
  526. public:
  527. // |model| is owned by the caller, must be not null and valid during the
  528. // lifetime of MarkvEncoder.
  529. MarkvEncoder(spv_const_context context, const MarkvCodecOptions& options,
  530. const MarkvModel* model)
  531. : MarkvCodecBase(context, GetValidatorOptions(options), model),
  532. options_(options) {
  533. (void)options_;
  534. }
  535. // Writes data from SPIR-V header to MARK-V header.
  536. spv_result_t EncodeHeader(spv_endianness_t /* endian */, uint32_t /* magic */,
  537. uint32_t version, uint32_t generator,
  538. uint32_t id_bound, uint32_t /* schema */) {
  539. SetIdBound(id_bound);
  540. header_.spirv_version = version;
  541. header_.spirv_generator = generator;
  542. return SPV_SUCCESS;
  543. }
  544. // Creates an internal logger which writes comments on the encoding process.
  545. void CreateLogger(MarkvLogConsumer log_consumer,
  546. MarkvDebugConsumer debug_consumer) {
  547. logger_.reset(new MarkvLogger(log_consumer, debug_consumer));
  548. writer_.SetCallback(
  549. [this](const std::string& str) { logger_->AppendBitSequence(str); });
  550. }
  551. // Encodes SPIR-V instruction to MARK-V and writes to bit stream.
  552. // Operation can fail if the instruction fails to pass the validator or if
  553. // the encoder stubmles on something unexpected.
  554. spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst);
  555. // Concatenates MARK-V header and the bit stream with encoded instructions
  556. // into a single buffer and returns it as spv_markv_binary. The returned
  557. // value is owned by the caller and needs to be destroyed with
  558. // spvMarkvBinaryDestroy().
  559. std::vector<uint8_t> GetMarkvBinary() {
  560. header_.markv_length_in_bits =
  561. static_cast<uint32_t>(sizeof(header_) * 8 + writer_.GetNumBits());
  562. header_.markv_model =
  563. (model_->model_type() << 16) | model_->model_version();
  564. const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes();
  565. std::vector<uint8_t> markv(num_bytes);
  566. assert(writer_.GetData());
  567. std::memcpy(markv.data(), &header_, sizeof(header_));
  568. std::memcpy(markv.data() + sizeof(header_), writer_.GetData(),
  569. writer_.GetDataSizeBytes());
  570. return markv;
  571. }
  572. // Optionally adds disassembly to the comments.
  573. // Disassembly should contain all instructions in the module separated by
  574. // \n, and no header.
  575. void SetDisassembly(std::string&& disassembly) {
  576. disassembly_.reset(new std::stringstream(std::move(disassembly)));
  577. }
  578. // Extracts the next instruction line from the disassembly and logs it.
  579. void LogDisassemblyInstruction() {
  580. if (logger_ && disassembly_) {
  581. std::string line;
  582. std::getline(*disassembly_, line, '\n');
  583. logger_->AppendTextNewLine(line);
  584. }
  585. }
  586. private:
  587. // Creates and returns validator options. Returned value owned by the caller.
  588. static spv_validator_options GetValidatorOptions(
  589. const MarkvCodecOptions& options) {
  590. return options.validate_spirv_binary ? spvValidatorOptionsCreate()
  591. : nullptr;
  592. }
  593. // Writes a single word to bit stream. operand_.type determines if the word is
  594. // encoded and how.
  595. spv_result_t EncodeNonIdWord(uint32_t word);
  596. // Writes both opcode and num_operands as a single code.
  597. // Returns SPV_UNSUPPORTED iff no suitable codec was found.
  598. spv_result_t EncodeOpcodeAndNumOperands(uint32_t opcode,
  599. uint32_t num_operands);
  600. // Writes mtf rank to bit stream. |mtf| is used to determine the codec
  601. // scheme. |fallback_method| is used if no codec defined for |mtf|.
  602. spv_result_t EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
  603. uint64_t fallback_method);
  604. // Writes id using coding based on mtf associated with the id descriptor.
  605. // Returns SPV_UNSUPPORTED iff fallback method needs to be used.
  606. spv_result_t EncodeIdWithDescriptor(uint32_t id);
  607. // Writes id using coding based on the given |mtf|, which is expected to
  608. // contain the given |id|.
  609. spv_result_t EncodeExistingId(uint64_t mtf, uint32_t id);
  610. // Writes type id of the current instruction if can't be inferred.
  611. spv_result_t EncodeTypeId();
  612. // Writes result id of the current instruction if can't be inferred.
  613. spv_result_t EncodeResultId();
  614. // Writes ids which are neither type nor result ids.
  615. spv_result_t EncodeRefId(uint32_t id);
  616. // Writes bits to the stream until the beginning of the next byte if the
  617. // number of bits until the next byte is less than |byte_break_if_less_than|.
  618. void AddByteBreak(size_t byte_break_if_less_than);
  619. // Encodes a literal number operand and writes it to the bit stream.
  620. spv_result_t EncodeLiteralNumber(const spv_parsed_operand_t& operand);
  621. MarkvCodecOptions options_;
  622. // Bit stream where encoded instructions are written.
  623. BitWriterWord64 writer_;
  624. // If not nullptr, disassembled instruction lines will be written to comments.
  625. // Format: \n separated instruction lines, no header.
  626. std::unique_ptr<std::stringstream> disassembly_;
  627. };
  628. // Decodes MARK-V buffers written by MarkvEncoder.
  629. class MarkvDecoder : public MarkvCodecBase {
  630. public:
  631. // |model| is owned by the caller, must be not null and valid during the
  632. // lifetime of MarkvEncoder.
  633. MarkvDecoder(spv_const_context context, const std::vector<uint8_t>& markv,
  634. const MarkvCodecOptions& options, const MarkvModel* model)
  635. : MarkvCodecBase(context, GetValidatorOptions(options), model),
  636. options_(options),
  637. reader_(markv) {
  638. (void)options_;
  639. SetIdBound(1);
  640. parsed_operands_.reserve(25);
  641. inst_words_.reserve(25);
  642. }
  643. // Creates an internal logger which writes comments on the decoding process.
  644. void CreateLogger(MarkvLogConsumer log_consumer,
  645. MarkvDebugConsumer debug_consumer) {
  646. logger_.reset(new MarkvLogger(log_consumer, debug_consumer));
  647. }
  648. // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
  649. // Can be called only once. Fails if data of wrong format or ends prematurely,
  650. // of if validation fails.
  651. spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary);
  652. private:
  653. // Describes the format of a typed literal number.
  654. struct NumberType {
  655. spv_number_kind_t type;
  656. uint32_t bit_width;
  657. };
  658. // Creates and returns validator options. Returned value owned by the caller.
  659. static spv_validator_options GetValidatorOptions(
  660. const MarkvCodecOptions& options) {
  661. return options.validate_spirv_binary ? spvValidatorOptionsCreate()
  662. : nullptr;
  663. }
  664. // Reads a single bit from reader_. The read bit is stored in |bit|.
  665. // Returns false iff reader_ fails.
  666. bool ReadBit(bool* bit) {
  667. uint64_t bits = 0;
  668. const bool result = reader_.ReadBits(&bits, 1);
  669. if (result) *bit = bits ? true : false;
  670. return result;
  671. };
  672. // Returns ReadBit bound to the class object.
  673. std::function<bool(bool*)> GetReadBitCallback() {
  674. return std::bind(&MarkvDecoder::ReadBit, this, std::placeholders::_1);
  675. }
  676. // Reads a single non-id word from bit stream. operand_.type determines if
  677. // the word needs to be decoded and how.
  678. spv_result_t DecodeNonIdWord(uint32_t* word);
  679. // Reads and decodes both opcode and num_operands as a single code.
  680. // Returns SPV_UNSUPPORTED iff no suitable codec was found.
  681. spv_result_t DecodeOpcodeAndNumberOfOperands(uint32_t* opcode,
  682. uint32_t* num_operands);
  683. // Reads mtf rank from bit stream. |mtf| is used to determine the codec
  684. // scheme. |fallback_method| is used if no codec defined for |mtf|.
  685. spv_result_t DecodeMtfRankHuffman(uint64_t mtf, uint32_t fallback_method,
  686. uint32_t* rank);
  687. // Reads id using coding based on mtf associated with the id descriptor.
  688. // Returns SPV_UNSUPPORTED iff fallback method needs to be used.
  689. spv_result_t DecodeIdWithDescriptor(uint32_t* id);
  690. // Reads id using coding based on the given |mtf|, which is expected to
  691. // contain the needed |id|.
  692. spv_result_t DecodeExistingId(uint64_t mtf, uint32_t* id);
  693. // Reads type id of the current instruction if can't be inferred.
  694. spv_result_t DecodeTypeId();
  695. // Reads result id of the current instruction if can't be inferred.
  696. spv_result_t DecodeResultId();
  697. // Reads id which is neither type nor result id.
  698. spv_result_t DecodeRefId(uint32_t* id);
  699. // Reads and discards bits until the beginning of the next byte if the
  700. // number of bits until the next byte is less than |byte_break_if_less_than|.
  701. bool ReadToByteBreak(size_t byte_break_if_less_than);
  702. // Returns instruction words decoded up to this point.
  703. const uint32_t* GetInstWords() const override { return inst_words_.data(); }
  704. // Reads a literal number as it is described in |operand| from the bit stream,
  705. // decodes and writes it to spirv_.
  706. spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);
  707. // Reads instruction from bit stream, decodes and validates it.
  708. // Decoded instruction is valid until the next call of DecodeInstruction().
  709. spv_result_t DecodeInstruction();
  710. // Read operand from the stream decodes and validates it.
  711. spv_result_t DecodeOperand(size_t operand_offset,
  712. const spv_operand_type_t type,
  713. spv_operand_pattern_t* expected_operands);
  714. // Records the numeric type for an operand according to the type information
  715. // associated with the given non-zero type Id. This can fail if the type Id
  716. // is not a type Id, or if the type Id does not reference a scalar numeric
  717. // type. On success, return SPV_SUCCESS and populates the num_words,
  718. // number_kind, and number_bit_width fields of parsed_operand.
  719. spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
  720. uint32_t type_id);
  721. // Records the number type for the current instruction, if it generates a
  722. // type. For types that aren't scalar numbers, record something with number
  723. // kind SPV_NUMBER_NONE.
  724. void RecordNumberType();
  725. MarkvCodecOptions options_;
  726. // Temporary sink where decoded SPIR-V words are written. Once it contains the
  727. // entire module, the container is moved and returned.
  728. std::vector<uint32_t> spirv_;
  729. // Bit stream containing encoded data.
  730. BitReaderWord64 reader_;
  731. // Temporary storage for operands of the currently parsed instruction.
  732. // Valid until next DecodeInstruction call.
  733. std::vector<spv_parsed_operand_t> parsed_operands_;
  734. // Temporary storage for current instruction words.
  735. // Valid until next DecodeInstruction call.
  736. std::vector<uint32_t> inst_words_;
  737. // Maps a type ID to its number type description.
  738. std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_;
  739. // Maps an ExtInstImport id to the extended instruction type.
  740. std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_;
  741. };
// Updates the codec state shared by encoder and decoder after the current
// instruction |inst_| has been processed:
//  - stores an owned copy of the instruction and indexes it by result id,
//  - tracks function boundaries (current function id, return type, expected
//    parameter types, ids local to the function),
//  - records the result-id -> type-id mapping,
//  - inserts the result id into the MTF sequences used for id coding
//    (rule-based MTFs, long-descriptor MTFs and short-descriptor MTFs,
//    depending on the model).
// NOTE(review): this is a MarkvCodecBase method, so presumably both the encoder
// and the decoder run it; the order of the multi_mtf_ Insert calls below then
// determines MTF ranks on both sides and should be treated as part of the
// format.
void MarkvCodecBase::ProcessCurInstruction() {
  // Keep an owned copy so FindDef() can return it for later instructions.
  instructions_.emplace_back(new Instruction(&inst_));

  const SpvOp opcode = SpvOp(inst_.opcode);

  if (inst_.result_id) {
    id_to_def_instruction_.emplace(inst_.result_id, instructions_.back().get());

    // Collect ids local to the current function.
    // This check runs before cur_function_id_ is set for OpFunction below, so
    // an OpFunction's own result id is not recorded as function-local.
    if (cur_function_id_) {
      ids_local_to_cur_function_.push_back(inst_.result_id);
    }

    // Starting new function.
    if (opcode == SpvOpFunction) {
      cur_function_id_ = inst_.result_id;
      cur_function_return_type_ = inst_.type_id;
      if (model_->id_fallback_strategy() ==
          MarkvModel::IdFallbackStrategy::kRuleBased) {
        multi_mtf_.Insert(GetMtfFunctionWithReturnType(inst_.type_id),
                          inst_.result_id);
      }

      // Store function parameter types in a queue, so that we know which types
      // to expect in the following OpFunctionParameter instructions.
      // Word 4 of OpFunction is the function type id (asserted below to be an
      // OpTypeFunction); words 3.. of OpTypeFunction are the parameter types.
      const Instruction* def_inst = FindDef(inst_.words[4]);
      assert(def_inst);
      assert(def_inst->opcode() == SpvOpTypeFunction);
      for (uint32_t i = 3; i < def_inst->words().size(); ++i) {
        remaining_function_parameter_types_.push_back(def_inst->word(i));
      }
    }
  }

  // Remove local ids from MTFs if function end.
  if (opcode == SpvOpFunctionEnd) {
    cur_function_id_ = 0;
    for (uint32_t id : ids_local_to_cur_function_) multi_mtf_.RemoveFromAll(id);
    ids_local_to_cur_function_.clear();
    // Every declared parameter type should have been consumed by an
    // OpFunctionParameter before the function ends.
    assert(remaining_function_parameter_types_.empty());
  }

  // Everything below only concerns instructions that define an id.
  if (!inst_.result_id) return;

  {
    // Save the result ID to type ID mapping.
    // In the grammar, type ID always appears before result ID.
    // A regular value maps to its type. Some instructions (e.g. OpLabel)
    // have no type Id, and will map to 0. The result Id for a
    // type-generating instruction (e.g. OpTypeInt) maps to itself.
    auto insertion_result = id_to_type_id_.emplace(
        inst_.result_id, spvOpcodeGeneratesType(SpvOp(inst_.opcode))
                             ? inst_.result_id
                             : inst_.type_id);
    // Only read by the assert; the cast silences unused-variable warnings in
    // release builds.
    (void)insertion_result;
    // Each result id must be defined exactly once.
    assert(insertion_result.second);
  }

  // Add result_id to MTFs.
  if (model_->id_fallback_strategy() ==
      MarkvModel::IdFallbackStrategy::kRuleBased) {
    // Ids generated by a selected set of opcodes get a per-opcode MTF.
    switch (opcode) {
      case SpvOpTypeFloat:
      case SpvOpTypeInt:
      case SpvOpTypeBool:
      case SpvOpTypeVector:
      case SpvOpTypePointer:
      case SpvOpExtInstImport:
      case SpvOpTypeSampledImage:
      case SpvOpTypeImage:
      case SpvOpTypeSampler:
        multi_mtf_.Insert(GetMtfIdGeneratedByOpcode(opcode), inst_.result_id);
        break;
      default:
        break;
    }

    if (spvOpcodeIsComposite(opcode)) {
      multi_mtf_.Insert(kMtfTypeComposite, inst_.result_id);
    }

    if (opcode == SpvOpLabel) {
      // Labels may already be present (e.g. referenced before definition),
      // hence InsertOrPromote rather than Insert.
      // NOTE(review): the forward-reference reason is inferred; confirm.
      multi_mtf_.InsertOrPromote(kMtfLabel, inst_.result_id);
    }

    // Scalar types: record both "scalar type" and the matching
    // "<kind> scalar-or-vector type" category.
    if (opcode == SpvOpTypeInt) {
      multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
      multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
    }

    if (opcode == SpvOpTypeFloat) {
      multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
      multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
    }

    if (opcode == SpvOpTypeBool) {
      multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
      multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
    }

    if (opcode == SpvOpTypeVector) {
      // OpTypeVector: word 2 is the component type, word 3 the component
      // count. The vector inherits the scalar-or-vector category of its
      // component type.
      const uint32_t component_type_id = inst_.words[2];
      const uint32_t size = inst_.words[3];
      if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeFloat),
                              component_type_id)) {
        multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
      } else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeInt),
                                     component_type_id)) {
        multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
      } else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeBool),
                                     component_type_id)) {
        multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
      }

      multi_mtf_.Insert(GetMtfTypeVectorOfSize(size), inst_.result_id);
    }

    if (inst_.opcode == SpvOpTypeFunction) {
      // OpTypeFunction: word 2 is the return type.
      const uint32_t return_type = inst_.words[2];
      multi_mtf_.Insert(kMtfTypeReturnedByFunction, return_type);
      multi_mtf_.Insert(GetMtfFunctionTypeWithReturnType(return_type),
                        inst_.result_id);
    }

    // Values (instructions with a result type): categorize by their type.
    if (inst_.type_id) {
      const Instruction* type_inst = FindDef(inst_.type_id);
      assert(type_inst);

      multi_mtf_.Insert(kMtfObject, inst_.result_id);

      multi_mtf_.Insert(GetMtfIdOfType(inst_.type_id), inst_.result_id);

      if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, inst_.type_id)) {
        multi_mtf_.Insert(kMtfFloatScalarOrVector, inst_.result_id);
      }

      if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, inst_.type_id))
        multi_mtf_.Insert(kMtfIntScalarOrVector, inst_.result_id);

      if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, inst_.type_id))
        multi_mtf_.Insert(kMtfBoolScalarOrVector, inst_.result_id);

      if (multi_mtf_.HasValue(kMtfTypeComposite, inst_.type_id))
        multi_mtf_.Insert(kMtfComposite, inst_.result_id);

      switch (type_inst->opcode()) {
        case SpvOpTypeInt:
        case SpvOpTypeBool:
        case SpvOpTypePointer:
        case SpvOpTypeVector:
        case SpvOpTypeImage:
        case SpvOpTypeSampledImage:
        case SpvOpTypeSampler:
          multi_mtf_.Insert(
              GetMtfIdWithTypeGeneratedByOpcode(type_inst->opcode()),
              inst_.result_id);
          break;
        default:
          break;
      }

      if (type_inst->opcode() == SpvOpTypeVector) {
        const uint32_t component_type = type_inst->word(2);
        multi_mtf_.Insert(GetMtfVectorOfComponentType(component_type),
                          inst_.result_id);
      }

      if (type_inst->opcode() == SpvOpTypePointer) {
        // Operand 2 of OpTypePointer is the pointee type.
        assert(type_inst->operands().size() > 2);
        assert(type_inst->words().size() > type_inst->operands()[2].offset);
        const uint32_t data_type =
            type_inst->word(type_inst->operands()[2].offset);
        multi_mtf_.Insert(GetMtfPointerToType(data_type), inst_.result_id);

        if (multi_mtf_.HasValue(kMtfTypeComposite, data_type))
          multi_mtf_.Insert(kMtfTypePointerToComposite, inst_.result_id);
      }
    }

    if (spvOpcodeGeneratesType(opcode)) {
      if (opcode != SpvOpTypeFunction) {
        multi_mtf_.Insert(kMtfTypeNonFunction, inst_.result_id);
      }
    }
  }

  // Long id descriptors: only ids whose descriptor has a coding scheme in the
  // model get an MTF.
  if (model_->AnyDescriptorHasCodingScheme()) {
    const uint32_t long_descriptor =
        long_id_descriptors_.ProcessInstruction(inst_);
    if (model_->DescriptorHasCodingScheme(long_descriptor))
      multi_mtf_.Insert(GetMtfLongIdDescriptor(long_descriptor),
                        inst_.result_id);
  }

  // Short id descriptors: every id goes into the MTF of its short descriptor.
  if (model_->id_fallback_strategy() ==
      MarkvModel::IdFallbackStrategy::kShortDescriptor) {
    const uint32_t short_descriptor =
        short_id_descriptors_.ProcessInstruction(inst_);
    multi_mtf_.Insert(GetMtfShortIdDescriptor(short_descriptor),
                      inst_.result_id);
  }
}
// Returns the handle of the MTF sequence that is expected to contain the id
// for the current operand slot (|operand_index_| of |inst_|). It uses
// hand-written per-opcode rules for the rule-based id fallback strategy.
// Returns kMtfNone if no rule applies to the slot.
//
// Conventions visible in the rules below: for opcodes with a result type,
// operand 0 is the result type and operand 1 is the result id. GetInstWords()[N]
// is word N of the current instruction, so the first "real" operand of such
// opcodes sits at word 3.
uint64_t MarkvCodecBase::GetRuleBasedMtf() {
  // This function is only called for id operands (but not result ids).
  assert(spvIsIdType(operand_.type) ||
         operand_.type == SPV_OPERAND_TYPE_OPTIONAL_ID);
  assert(operand_.type != SPV_OPERAND_TYPE_RESULT_ID);

  const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);

  // All operand slots which expect label id.
  // For OpPhi the (value, parent label) pairs start at operand 2, so labels
  // are at the odd indices >= 3. OpSwitch literal operands are not ids and so
  // never reach this function (see the asserts above).
  if ((inst_.opcode == SpvOpLoopMerge && operand_index_ <= 1) ||
      (inst_.opcode == SpvOpSelectionMerge && operand_index_ == 0) ||
      (inst_.opcode == SpvOpBranch && operand_index_ == 0) ||
      (inst_.opcode == SpvOpBranchConditional &&
       (operand_index_ == 1 || operand_index_ == 2)) ||
      (inst_.opcode == SpvOpPhi && operand_index_ >= 3 &&
       operand_index_ % 2 == 1) ||
      (inst_.opcode == SpvOpSwitch && operand_index_ > 0)) {
    return kMtfLabel;
  }

  switch (opcode) {
    // Float arithmetic: operands have the same type as the result.
    case SpvOpFAdd:
    case SpvOpFSub:
    case SpvOpFMul:
    case SpvOpFDiv:
    case SpvOpFRem:
    case SpvOpFMod:
    case SpvOpFNegate: {
      if (operand_index_ == 0) return kMtfTypeFloatScalarOrVector;
      return GetMtfIdOfType(inst_.type_id);
    }

    // Integer arithmetic: operands are some integer scalar or vector.
    case SpvOpISub:
    case SpvOpIAdd:
    case SpvOpIMul:
    case SpvOpSDiv:
    case SpvOpUDiv:
    case SpvOpSMod:
    case SpvOpUMod:
    case SpvOpSRem:
    case SpvOpSNegate: {
      if (operand_index_ == 0) return kMtfTypeIntScalarOrVector;

      return kMtfIntScalarOrVector;
    }

    // TODO(atgoo@github.com) Add OpConvertFToU and other opcodes.

    // Float comparisons: the result is bool; the second operand has the same
    // type as the first (which is at word 3).
    case SpvOpFOrdEqual:
    case SpvOpFUnordEqual:
    case SpvOpFOrdNotEqual:
    case SpvOpFUnordNotEqual:
    case SpvOpFOrdLessThan:
    case SpvOpFUnordLessThan:
    case SpvOpFOrdGreaterThan:
    case SpvOpFUnordGreaterThan:
    case SpvOpFOrdLessThanEqual:
    case SpvOpFUnordLessThanEqual:
    case SpvOpFOrdGreaterThanEqual:
    case SpvOpFUnordGreaterThanEqual: {
      if (operand_index_ == 0) return kMtfTypeBoolScalarOrVector;
      if (operand_index_ == 2) return kMtfFloatScalarOrVector;
      if (operand_index_ == 3) {
        const uint32_t first_operand_id = GetInstWords()[3];
        const uint32_t first_operand_type = id_to_type_id_.at(first_operand_id);
        return GetMtfIdOfType(first_operand_type);
      }
      break;
    }

    case SpvOpVectorShuffle: {
      if (operand_index_ == 0) {
        // Result vector size equals the number of component literals, i.e.
        // all operands after result type, result id and the two vectors.
        assert(inst_.num_operands > 4);
        return GetMtfTypeVectorOfSize(inst_.num_operands - 4);
      }

      assert(inst_.type_id);
      if (operand_index_ == 2 || operand_index_ == 3)
        return GetMtfVectorOfComponentType(
            GetVectorComponentType(inst_.type_id));
      break;
    }

    case SpvOpVectorTimesScalar: {
      if (operand_index_ == 0) {
        // TODO(atgoo@github.com) Could be narrowed to vector of floats.
        return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
      }

      assert(inst_.type_id);
      if (operand_index_ == 2) return GetMtfIdOfType(inst_.type_id);
      if (operand_index_ == 3)
        return GetMtfIdOfType(GetVectorComponentType(inst_.type_id));
      break;
    }

    case SpvOpDot: {
      if (operand_index_ == 0) return GetMtfIdGeneratedByOpcode(SpvOpTypeFloat);

      assert(inst_.type_id);
      // First vector: any vector whose component type is the result type.
      if (operand_index_ == 2)
        return GetMtfVectorOfComponentType(inst_.type_id);
      // Second vector: same type as the first one (at word 3).
      if (operand_index_ == 3) {
        const uint32_t vector_id = GetInstWords()[3];
        const uint32_t vector_type = id_to_type_id_.at(vector_id);
        return GetMtfIdOfType(vector_type);
      }
      break;
    }

    case SpvOpTypeVector: {
      if (operand_index_ == 1) {
        return kMtfTypeScalar;
      }
      break;
    }

    case SpvOpTypeMatrix: {
      if (operand_index_ == 1) {
        return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
      }
      break;
    }

    case SpvOpTypePointer: {
      if (operand_index_ == 2) {
        return kMtfTypeNonFunction;
      }
      break;
    }

    case SpvOpTypeStruct: {
      if (operand_index_ >= 1) {
        return kMtfTypeNonFunction;
      }
      break;
    }

    case SpvOpTypeFunction: {
      // Operand 1 is the return type, operands 2+ the parameter types; both
      // are drawn from the non-function types.
      if (operand_index_ == 1) {
        return kMtfTypeNonFunction;
      }

      if (operand_index_ >= 2) {
        return kMtfTypeNonFunction;
      }
      break;
    }

    case SpvOpLoad: {
      if (operand_index_ == 0) return kMtfTypeNonFunction;

      if (operand_index_ == 2) {
        assert(inst_.type_id);
        return GetMtfPointerToType(inst_.type_id);
      }
      break;
    }

    case SpvOpStore: {
      // OpStore has no result type: operand 0 is the pointer (word 1),
      // operand 1 the object, which has the pointee type of the pointer.
      if (operand_index_ == 0)
        return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypePointer);
      if (operand_index_ == 1) {
        const uint32_t pointer_id = GetInstWords()[1];
        const uint32_t pointer_type = id_to_type_id_.at(pointer_id);
        const Instruction* pointer_inst = FindDef(pointer_type);
        assert(pointer_inst);
        assert(pointer_inst->opcode() == SpvOpTypePointer);
        const uint32_t data_type =
            pointer_inst->word(pointer_inst->operands()[2].offset);
        return GetMtfIdOfType(data_type);
      }
      break;
    }

    case SpvOpVariable: {
      if (operand_index_ == 0)
        return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
      break;
    }

    case SpvOpAccessChain: {
      if (operand_index_ == 0)
        return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
      if (operand_index_ == 2) return kMtfTypePointerToComposite;
      // Indices are integer values.
      if (operand_index_ >= 3)
        return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeInt);
      break;
    }

    case SpvOpCompositeConstruct: {
      if (operand_index_ == 0) return kMtfTypeComposite;
      if (operand_index_ >= 2) {
        // Constituents share the scalar kind (float/int/bool) of the result
        // type when it is a scalar-or-vector type.
        const uint32_t composite_type = GetInstWords()[1];
        if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, composite_type))
          return kMtfFloatScalarOrVector;
        if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, composite_type))
          return kMtfIntScalarOrVector;
        if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, composite_type))
          return kMtfBoolScalarOrVector;
      }
      break;
    }

    case SpvOpCompositeExtract: {
      if (operand_index_ == 2) return kMtfComposite;
      break;
    }

    case SpvOpConstantComposite: {
      if (operand_index_ == 0) return kMtfTypeComposite;
      if (operand_index_ >= 2) {
        // For vector constants, constituents have the vector's component type
        // (word 2 of OpTypeVector).
        const Instruction* composite_type_inst = FindDef(inst_.type_id);
        assert(composite_type_inst);
        if (composite_type_inst->opcode() == SpvOpTypeVector) {
          return GetMtfIdOfType(composite_type_inst->word(2));
        }
      }
      break;
    }

    case SpvOpExtInst: {
      // Operand 2 is the extended instruction set import, operand 3 the
      // instruction number (word 4), operands 4+ the arguments.
      if (operand_index_ == 2)
        return GetMtfIdGeneratedByOpcode(SpvOpExtInstImport);
      if (operand_index_ >= 4) {
        const uint32_t return_type = GetInstWords()[1];
        const uint32_t ext_inst_type = inst_.ext_inst_type;
        const uint32_t ext_inst_index = GetInstWords()[4];
        // TODO(atgoo@github.com) The list of extended instructions is
        // incomplete. Only common instructions and low-hanging fruits listed.
        if (ext_inst_type == SPV_EXT_INST_TYPE_GLSL_STD_450) {
          switch (ext_inst_index) {
            // Arguments have the same type as the result.
            case GLSLstd450FAbs:
            case GLSLstd450FClamp:
            case GLSLstd450FMax:
            case GLSLstd450FMin:
            case GLSLstd450FMix:
            case GLSLstd450Step:
            case GLSLstd450SmoothStep:
            case GLSLstd450Fma:
            case GLSLstd450Pow:
            case GLSLstd450Exp:
            case GLSLstd450Exp2:
            case GLSLstd450Log:
            case GLSLstd450Log2:
            case GLSLstd450Sqrt:
            case GLSLstd450InverseSqrt:
            case GLSLstd450Fract:
            case GLSLstd450Floor:
            case GLSLstd450Ceil:
            case GLSLstd450Radians:
            case GLSLstd450Degrees:
            case GLSLstd450Sin:
            case GLSLstd450Cos:
            case GLSLstd450Tan:
            case GLSLstd450Sinh:
            case GLSLstd450Cosh:
            case GLSLstd450Tanh:
            case GLSLstd450Asin:
            case GLSLstd450Acos:
            case GLSLstd450Atan:
            case GLSLstd450Atan2:
            case GLSLstd450Asinh:
            case GLSLstd450Acosh:
            case GLSLstd450Atanh:
            case GLSLstd450MatrixInverse:
            case GLSLstd450Cross:
            case GLSLstd450Normalize:
            case GLSLstd450Reflect:
            case GLSLstd450FaceForward:
              return GetMtfIdOfType(return_type);
            // Arguments are float scalars/vectors, but not necessarily of
            // the result type.
            case GLSLstd450Length:
            case GLSLstd450Distance:
            case GLSLstd450Refract:
              return kMtfFloatScalarOrVector;
            default:
              break;
          }
        } else if (ext_inst_type == SPV_EXT_INST_TYPE_OPENCL_STD) {
          switch (ext_inst_index) {
            // Arguments have the same type as the result.
            case OpenCLLIB::Fabs:
            case OpenCLLIB::FClamp:
            case OpenCLLIB::Fmax:
            case OpenCLLIB::Fmin:
            case OpenCLLIB::Step:
            case OpenCLLIB::Smoothstep:
            case OpenCLLIB::Fma:
            case OpenCLLIB::Pow:
            case OpenCLLIB::Exp:
            case OpenCLLIB::Exp2:
            case OpenCLLIB::Log:
            case OpenCLLIB::Log2:
            case OpenCLLIB::Sqrt:
            case OpenCLLIB::Rsqrt:
            case OpenCLLIB::Fract:
            case OpenCLLIB::Floor:
            case OpenCLLIB::Ceil:
            case OpenCLLIB::Radians:
            case OpenCLLIB::Degrees:
            case OpenCLLIB::Sin:
            case OpenCLLIB::Cos:
            case OpenCLLIB::Tan:
            case OpenCLLIB::Sinh:
            case OpenCLLIB::Cosh:
            case OpenCLLIB::Tanh:
            case OpenCLLIB::Asin:
            case OpenCLLIB::Acos:
            case OpenCLLIB::Atan:
            case OpenCLLIB::Atan2:
            case OpenCLLIB::Asinh:
            case OpenCLLIB::Acosh:
            case OpenCLLIB::Atanh:
            case OpenCLLIB::Cross:
            case OpenCLLIB::Normalize:
              return GetMtfIdOfType(return_type);
            // Arguments are float scalars/vectors of any width.
            case OpenCLLIB::Length:
            case OpenCLLIB::Distance:
              return kMtfFloatScalarOrVector;
            default:
              break;
          }
        }
      }
      break;
    }

    case SpvOpFunction: {
      if (operand_index_ == 0) return kMtfTypeReturnedByFunction;

      // Operand 3 is the function type; it must return the result type.
      if (operand_index_ == 3) {
        const uint32_t return_type = GetInstWords()[1];
        return GetMtfFunctionTypeWithReturnType(return_type);
      }
      break;
    }

    case SpvOpFunctionCall: {
      if (operand_index_ == 0) return kMtfTypeReturnedByFunction;

      if (operand_index_ == 2) {
        const uint32_t return_type = GetInstWords()[1];
        return GetMtfFunctionWithReturnType(return_type);
      }

      if (operand_index_ >= 3) {
        // Argument k (operand 3+k) has the type of parameter k, which is word
        // 3+k of the callee's OpTypeFunction, so operand_index_ indexes the
        // function type words directly.
        const uint32_t function_id = GetInstWords()[3];
        const Instruction* function_inst = FindDef(function_id);
        // Callee not defined yet (forward reference): no precise type known.
        if (!function_inst) return kMtfObject;

        assert(function_inst->opcode() == SpvOpFunction);

        const uint32_t function_type_id = function_inst->word(4);
        const Instruction* function_type_inst = FindDef(function_type_id);
        assert(function_type_inst);
        assert(function_type_inst->opcode() == SpvOpTypeFunction);

        const uint32_t argument_type = function_type_inst->word(operand_index_);
        return GetMtfIdOfType(argument_type);
      }
      break;
    }

    case SpvOpReturnValue: {
      if (operand_index_ == 0) return GetMtfIdOfType(cur_function_return_type_);
      break;
    }

    case SpvOpBranchConditional: {
      if (operand_index_ == 0)
        return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeBool);
      break;
    }

    case SpvOpSampledImage: {
      if (operand_index_ == 0)
        return GetMtfIdGeneratedByOpcode(SpvOpTypeSampledImage);
      if (operand_index_ == 2)
        return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeImage);
      if (operand_index_ == 3)
        return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampler);
      break;
    }

    case SpvOpImageSampleImplicitLod: {
      if (operand_index_ == 0)
        return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
      if (operand_index_ == 2)
        return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampledImage);
      if (operand_index_ == 3)
        return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeVector);
      break;
    }

    default:
      break;
  }

  // No rule applies to this operand slot.
  return kMtfNone;
}
  1270. spv_result_t MarkvEncoder::EncodeNonIdWord(uint32_t word) {
  1271. auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);
  1272. if (codec) {
  1273. uint64_t bits = 0;
  1274. size_t num_bits = 0;
  1275. if (codec->Encode(word, &bits, &num_bits)) {
  1276. // Encoding successful.
  1277. writer_.WriteBits(bits, num_bits);
  1278. return SPV_SUCCESS;
  1279. } else {
  1280. // Encoding failed, write kMarkvNoneOfTheAbove flag.
  1281. if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits))
  1282. return Diag(SPV_ERROR_INTERNAL)
  1283. << "Non-id word Huffman table for "
  1284. << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index "
  1285. << operand_index_ << " is missing kMarkvNoneOfTheAbove";
  1286. writer_.WriteBits(bits, num_bits);
  1287. }
  1288. }
  1289. // Fallback encoding.
  1290. const size_t chunk_length =
  1291. model_->GetOperandVariableWidthChunkLength(operand_.type);
  1292. if (chunk_length) {
  1293. writer_.WriteVariableWidthU32(word, chunk_length);
  1294. } else {
  1295. writer_.WriteUnencoded(word);
  1296. }
  1297. return SPV_SUCCESS;
  1298. }
  1299. spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) {
  1300. auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);
  1301. if (codec) {
  1302. uint64_t decoded_value = 0;
  1303. if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
  1304. return Diag(SPV_ERROR_INVALID_BINARY)
  1305. << "Failed to decode non-id word with Huffman";
  1306. if (decoded_value != kMarkvNoneOfTheAbove) {
  1307. // The word decoded successfully.
  1308. *word = uint32_t(decoded_value);
  1309. assert(*word == decoded_value);
  1310. return SPV_SUCCESS;
  1311. }
  1312. // Received kMarkvNoneOfTheAbove signal, use fallback decoding.
  1313. }
  1314. const size_t chunk_length =
  1315. model_->GetOperandVariableWidthChunkLength(operand_.type);
  1316. if (chunk_length) {
  1317. if (!reader_.ReadVariableWidthU32(word, chunk_length))
  1318. return Diag(SPV_ERROR_INVALID_BINARY)
  1319. << "Failed to decode non-id word with varint";
  1320. } else {
  1321. if (!reader_.ReadUnencoded(word))
  1322. return Diag(SPV_ERROR_INVALID_BINARY)
  1323. << "Failed to read unencoded non-id word";
  1324. }
  1325. return SPV_SUCCESS;
  1326. }
  1327. spv_result_t MarkvEncoder::EncodeOpcodeAndNumOperands(uint32_t opcode,
  1328. uint32_t num_operands) {
  1329. uint64_t bits = 0;
  1330. size_t num_bits = 0;
  1331. const uint32_t word = opcode | (num_operands << 16);
  1332. // First try to use the Markov chain codec.
  1333. auto* codec =
  1334. model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
  1335. if (codec) {
  1336. if (codec->Encode(word, &bits, &num_bits)) {
  1337. // The word was successfully encoded into bits/num_bits.
  1338. writer_.WriteBits(bits, num_bits);
  1339. return SPV_SUCCESS;
  1340. } else {
  1341. // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove
  1342. // and use fallback encoding.
  1343. if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits))
  1344. return Diag(SPV_ERROR_INTERNAL)
  1345. << "opcode_and_num_operands Huffman table for "
  1346. << spvOpcodeString(GetPrevOpcode())
  1347. << "is missing kMarkvNoneOfTheAbove";
  1348. writer_.WriteBits(bits, num_bits);
  1349. }
  1350. }
  1351. // Fallback to base-rate codec.
  1352. codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
  1353. assert(codec);
  1354. if (codec->Encode(word, &bits, &num_bits)) {
  1355. // The word was successfully encoded into bits/num_bits.
  1356. writer_.WriteBits(bits, num_bits);
  1357. return SPV_SUCCESS;
  1358. } else {
  1359. // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove
  1360. // and return false.
  1361. if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits))
  1362. return Diag(SPV_ERROR_INTERNAL)
  1363. << "Global opcode_and_num_operands Huffman table is missing "
  1364. << "kMarkvNoneOfTheAbove";
  1365. writer_.WriteBits(bits, num_bits);
  1366. return SPV_UNSUPPORTED;
  1367. }
  1368. }
  1369. spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands(
  1370. uint32_t* opcode, uint32_t* num_operands) {
  1371. // First try to use the Markov chain codec.
  1372. auto* codec =
  1373. model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
  1374. if (codec) {
  1375. uint64_t decoded_value = 0;
  1376. if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
  1377. return Diag(SPV_ERROR_INTERNAL)
  1378. << "Failed to decode opcode_and_num_operands, previous opcode is "
  1379. << spvOpcodeString(GetPrevOpcode());
  1380. if (decoded_value != kMarkvNoneOfTheAbove) {
  1381. // The word was successfully decoded.
  1382. *opcode = uint32_t(decoded_value & 0xFFFF);
  1383. *num_operands = uint32_t(decoded_value >> 16);
  1384. return SPV_SUCCESS;
  1385. }
  1386. // Received kMarkvNoneOfTheAbove signal, use fallback decoding.
  1387. }
  1388. // Fallback to base-rate codec.
  1389. codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
  1390. assert(codec);
  1391. uint64_t decoded_value = 0;
  1392. if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
  1393. return Diag(SPV_ERROR_INTERNAL)
  1394. << "Failed to decode opcode_and_num_operands with global codec";
  1395. if (decoded_value == kMarkvNoneOfTheAbove) {
  1396. // Received kMarkvNoneOfTheAbove signal, fallback further.
  1397. return SPV_UNSUPPORTED;
  1398. }
  1399. *opcode = uint32_t(decoded_value & 0xFFFF);
  1400. *num_operands = uint32_t(decoded_value >> 16);
  1401. return SPV_SUCCESS;
  1402. }
  1403. spv_result_t MarkvEncoder::EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
  1404. uint64_t fallback_method) {
  1405. const auto* codec = GetMtfHuffmanCodec(mtf);
  1406. if (!codec) {
  1407. assert(fallback_method != kMtfNone);
  1408. codec = GetMtfHuffmanCodec(fallback_method);
  1409. }
  1410. if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to encode MTF rank";
  1411. uint64_t bits = 0;
  1412. size_t num_bits = 0;
  1413. if (rank < kMtfSmallestRankEncodedByValue) {
  1414. // Encode using Huffman coding.
  1415. if (!codec->Encode(rank, &bits, &num_bits))
  1416. return Diag(SPV_ERROR_INTERNAL)
  1417. << "Failed to encode MTF rank with Huffman";
  1418. writer_.WriteBits(bits, num_bits);
  1419. } else {
  1420. // Encode by value.
  1421. if (!codec->Encode(kMtfRankEncodedByValueSignal, &bits, &num_bits))
  1422. return Diag(SPV_ERROR_INTERNAL)
  1423. << "Failed to encode kMtfRankEncodedByValueSignal";
  1424. writer_.WriteBits(bits, num_bits);
  1425. writer_.WriteVariableWidthU32(rank - kMtfSmallestRankEncodedByValue,
  1426. model_->mtf_rank_chunk_length());
  1427. }
  1428. return SPV_SUCCESS;
  1429. }
  1430. spv_result_t MarkvDecoder::DecodeMtfRankHuffman(uint64_t mtf,
  1431. uint32_t fallback_method,
  1432. uint32_t* rank) {
  1433. const auto* codec = GetMtfHuffmanCodec(mtf);
  1434. if (!codec) {
  1435. assert(fallback_method != kMtfNone);
  1436. codec = GetMtfHuffmanCodec(fallback_method);
  1437. }
  1438. if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank";
  1439. uint32_t decoded_value = 0;
  1440. if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
  1441. return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman";
  1442. if (decoded_value == kMtfRankEncodedByValueSignal) {
  1443. // Decode by value.
  1444. if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length()))
  1445. return Diag(SPV_ERROR_INTERNAL)
  1446. << "Failed to decode MTF rank with varint";
  1447. *rank += kMtfSmallestRankEncodedByValue;
  1448. } else {
  1449. // Decode using Huffman coding.
  1450. assert(decoded_value < kMtfSmallestRankEncodedByValue);
  1451. *rank = decoded_value;
  1452. }
  1453. return SPV_SUCCESS;
  1454. }
  1455. spv_result_t MarkvEncoder::EncodeIdWithDescriptor(uint32_t id) {
  1456. // Get the descriptor for id.
  1457. const uint32_t long_descriptor = long_id_descriptors_.GetDescriptor(id);
  1458. auto* codec =
  1459. model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
  1460. uint64_t bits = 0;
  1461. size_t num_bits = 0;
  1462. uint64_t mtf = kMtfNone;
  1463. if (long_descriptor && codec &&
  1464. codec->Encode(long_descriptor, &bits, &num_bits)) {
  1465. // If the descriptor exists and is in the table, write the descriptor and
  1466. // proceed to encoding the rank.
  1467. writer_.WriteBits(bits, num_bits);
  1468. mtf = GetMtfLongIdDescriptor(long_descriptor);
  1469. } else {
  1470. if (codec) {
  1471. // The descriptor doesn't exist or we have no coding for it. Write
  1472. // kMarkvNoneOfTheAbove and go to fallback method.
  1473. if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits))
  1474. return Diag(SPV_ERROR_INTERNAL)
  1475. << "Descriptor Huffman table for "
  1476. << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index "
  1477. << operand_index_ << " is missing kMarkvNoneOfTheAbove";
  1478. writer_.WriteBits(bits, num_bits);
  1479. }
  1480. if (model_->id_fallback_strategy() !=
  1481. MarkvModel::IdFallbackStrategy::kShortDescriptor) {
  1482. return SPV_UNSUPPORTED;
  1483. }
  1484. const uint32_t short_descriptor = short_id_descriptors_.GetDescriptor(id);
  1485. writer_.WriteBits(short_descriptor, kShortDescriptorNumBits);
  1486. if (short_descriptor == 0) {
  1487. // Forward declared id.
  1488. return SPV_UNSUPPORTED;
  1489. }
  1490. mtf = GetMtfShortIdDescriptor(short_descriptor);
  1491. }
  1492. // Descriptor has been encoded. Now encode the rank of the id in the
  1493. // associated mtf sequence.
  1494. return EncodeExistingId(mtf, id);
  1495. }
  1496. spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) {
  1497. auto* codec =
  1498. model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
  1499. uint64_t mtf = kMtfNone;
  1500. if (codec) {
  1501. uint64_t decoded_value = 0;
  1502. if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
  1503. return Diag(SPV_ERROR_INTERNAL)
  1504. << "Failed to decode descriptor with Huffman";
  1505. if (decoded_value != kMarkvNoneOfTheAbove) {
  1506. const uint32_t long_descriptor = uint32_t(decoded_value);
  1507. mtf = GetMtfLongIdDescriptor(long_descriptor);
  1508. }
  1509. }
  1510. if (mtf == kMtfNone) {
  1511. if (model_->id_fallback_strategy() !=
  1512. MarkvModel::IdFallbackStrategy::kShortDescriptor) {
  1513. return SPV_UNSUPPORTED;
  1514. }
  1515. uint64_t decoded_value = 0;
  1516. if (!reader_.ReadBits(&decoded_value, kShortDescriptorNumBits))
  1517. return Diag(SPV_ERROR_INTERNAL) << "Failed to read short descriptor";
  1518. const uint32_t short_descriptor = uint32_t(decoded_value);
  1519. if (short_descriptor == 0) {
  1520. // Forward declared id.
  1521. return SPV_UNSUPPORTED;
  1522. }
  1523. mtf = GetMtfShortIdDescriptor(short_descriptor);
  1524. }
  1525. return DecodeExistingId(mtf, id);
  1526. }
  1527. spv_result_t MarkvEncoder::EncodeExistingId(uint64_t mtf, uint32_t id) {
  1528. assert(multi_mtf_.GetSize(mtf) > 0);
  1529. if (multi_mtf_.GetSize(mtf) == 1) {
  1530. // If the sequence has only one element no need to write rank, the decoder
  1531. // would make the same decision.
  1532. return SPV_SUCCESS;
  1533. }
  1534. uint32_t rank = 0;
  1535. if (!multi_mtf_.RankFromValue(mtf, id, &rank))
  1536. return Diag(SPV_ERROR_INTERNAL) << "Id is not in the MTF sequence";
  1537. return EncodeMtfRankHuffman(rank, mtf, kMtfGenericNonZeroRank);
  1538. }
  1539. spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) {
  1540. assert(multi_mtf_.GetSize(mtf) > 0);
  1541. *id = 0;
  1542. uint32_t rank = 0;
  1543. if (multi_mtf_.GetSize(mtf) == 1) {
  1544. rank = 1;
  1545. } else {
  1546. const spv_result_t result =
  1547. DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank);
  1548. if (result != SPV_SUCCESS) return result;
  1549. }
  1550. assert(rank);
  1551. if (!multi_mtf_.ValueFromRank(mtf, rank, id))
  1552. return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds";
  1553. return SPV_SUCCESS;
  1554. }
  1555. spv_result_t MarkvEncoder::EncodeRefId(uint32_t id) {
  1556. {
  1557. // Try to encode using id descriptor mtfs.
  1558. const spv_result_t result = EncodeIdWithDescriptor(id);
  1559. if (result != SPV_UNSUPPORTED) return result;
  1560. // If can't be done continue with other methods.
  1561. }
  1562. const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
  1563. SpvOp(inst_.opcode))(operand_index_);
  1564. uint32_t rank = 0;
  1565. if (model_->id_fallback_strategy() ==
  1566. MarkvModel::IdFallbackStrategy::kRuleBased) {
  1567. // Encode using rule-based mtf.
  1568. uint64_t mtf = GetRuleBasedMtf();
  1569. if (mtf != kMtfNone && !can_forward_declare) {
  1570. assert(multi_mtf_.HasValue(kMtfAll, id));
  1571. return EncodeExistingId(mtf, id);
  1572. }
  1573. if (mtf == kMtfNone) mtf = kMtfAll;
  1574. if (!multi_mtf_.RankFromValue(mtf, id, &rank)) {
  1575. // This is the first occurrence of a forward declared id.
  1576. multi_mtf_.Insert(kMtfAll, id);
  1577. multi_mtf_.Insert(kMtfForwardDeclared, id);
  1578. if (mtf != kMtfAll) multi_mtf_.Insert(mtf, id);
  1579. rank = 0;
  1580. }
  1581. return EncodeMtfRankHuffman(rank, mtf, kMtfAll);
  1582. } else {
  1583. assert(can_forward_declare);
  1584. if (!multi_mtf_.RankFromValue(kMtfForwardDeclared, id, &rank)) {
  1585. // This is the first occurrence of a forward declared id.
  1586. multi_mtf_.Insert(kMtfForwardDeclared, id);
  1587. rank = 0;
  1588. }
  1589. writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length());
  1590. return SPV_SUCCESS;
  1591. }
  1592. }
  1593. spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) {
  1594. {
  1595. const spv_result_t result = DecodeIdWithDescriptor(id);
  1596. if (result != SPV_UNSUPPORTED) return result;
  1597. }
  1598. const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
  1599. SpvOp(inst_.opcode))(operand_index_);
  1600. uint32_t rank = 0;
  1601. *id = 0;
  1602. if (model_->id_fallback_strategy() ==
  1603. MarkvModel::IdFallbackStrategy::kRuleBased) {
  1604. uint64_t mtf = GetRuleBasedMtf();
  1605. if (mtf != kMtfNone && !can_forward_declare) {
  1606. return DecodeExistingId(mtf, id);
  1607. }
  1608. if (mtf == kMtfNone) mtf = kMtfAll;
  1609. {
  1610. const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank);
  1611. if (result != SPV_SUCCESS) return result;
  1612. }
  1613. if (rank == 0) {
  1614. // This is the first occurrence of a forward declared id.
  1615. *id = GetIdBound();
  1616. SetIdBound(*id + 1);
  1617. multi_mtf_.Insert(kMtfAll, *id);
  1618. multi_mtf_.Insert(kMtfForwardDeclared, *id);
  1619. if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id);
  1620. } else {
  1621. if (!multi_mtf_.ValueFromRank(mtf, rank, id))
  1622. return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
  1623. }
  1624. } else {
  1625. assert(can_forward_declare);
  1626. if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
  1627. return Diag(SPV_ERROR_INTERNAL)
  1628. << "Failed to decode MTF rank with varint";
  1629. if (rank == 0) {
  1630. // This is the first occurrence of a forward declared id.
  1631. *id = GetIdBound();
  1632. SetIdBound(*id + 1);
  1633. multi_mtf_.Insert(kMtfForwardDeclared, *id);
  1634. } else {
  1635. if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, id))
  1636. return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
  1637. }
  1638. }
  1639. assert(*id);
  1640. return SPV_SUCCESS;
  1641. }
  1642. spv_result_t MarkvEncoder::EncodeTypeId() {
  1643. if (inst_.opcode == SpvOpFunctionParameter) {
  1644. assert(!remaining_function_parameter_types_.empty());
  1645. assert(inst_.type_id == remaining_function_parameter_types_.front());
  1646. remaining_function_parameter_types_.pop_front();
  1647. return SPV_SUCCESS;
  1648. }
  1649. {
  1650. // Try to encode using id descriptor mtfs.
  1651. const spv_result_t result = EncodeIdWithDescriptor(inst_.type_id);
  1652. if (result != SPV_UNSUPPORTED) return result;
  1653. // If can't be done continue with other methods.
  1654. }
  1655. assert(model_->id_fallback_strategy() ==
  1656. MarkvModel::IdFallbackStrategy::kRuleBased);
  1657. uint64_t mtf = GetRuleBasedMtf();
  1658. assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
  1659. operand_index_));
  1660. if (mtf == kMtfNone) {
  1661. mtf = kMtfTypeNonFunction;
  1662. // Function types should have been handled by GetRuleBasedMtf.
  1663. assert(inst_.opcode != SpvOpFunction);
  1664. }
  1665. return EncodeExistingId(mtf, inst_.type_id);
  1666. }
  1667. spv_result_t MarkvDecoder::DecodeTypeId() {
  1668. if (inst_.opcode == SpvOpFunctionParameter) {
  1669. assert(!remaining_function_parameter_types_.empty());
  1670. inst_.type_id = remaining_function_parameter_types_.front();
  1671. remaining_function_parameter_types_.pop_front();
  1672. return SPV_SUCCESS;
  1673. }
  1674. {
  1675. const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id);
  1676. if (result != SPV_UNSUPPORTED) return result;
  1677. }
  1678. assert(model_->id_fallback_strategy() ==
  1679. MarkvModel::IdFallbackStrategy::kRuleBased);
  1680. uint64_t mtf = GetRuleBasedMtf();
  1681. assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
  1682. operand_index_));
  1683. if (mtf == kMtfNone) {
  1684. mtf = kMtfTypeNonFunction;
  1685. // Function types should have been handled by GetRuleBasedMtf.
  1686. assert(inst_.opcode != SpvOpFunction);
  1687. }
  1688. return DecodeExistingId(mtf, &inst_.type_id);
  1689. }
  1690. spv_result_t MarkvEncoder::EncodeResultId() {
  1691. uint32_t rank = 0;
  1692. const uint64_t num_still_forward_declared =
  1693. multi_mtf_.GetSize(kMtfForwardDeclared);
  1694. if (num_still_forward_declared) {
  1695. // We write the rank only if kMtfForwardDeclared is not empty. If it is
  1696. // empty the decoder knows that there are no forward declared ids to expect.
  1697. if (multi_mtf_.RankFromValue(kMtfForwardDeclared, inst_.result_id, &rank)) {
  1698. // This is a definition of a forward declared id. We can remove the id
  1699. // from kMtfForwardDeclared.
  1700. if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
  1701. return Diag(SPV_ERROR_INTERNAL)
  1702. << "Failed to remove id from kMtfForwardDeclared";
  1703. writer_.WriteBits(1, 1);
  1704. writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length());
  1705. } else {
  1706. rank = 0;
  1707. writer_.WriteBits(0, 1);
  1708. }
  1709. }
  1710. if (model_->id_fallback_strategy() ==
  1711. MarkvModel::IdFallbackStrategy::kRuleBased) {
  1712. if (!rank) {
  1713. multi_mtf_.Insert(kMtfAll, inst_.result_id);
  1714. }
  1715. }
  1716. return SPV_SUCCESS;
  1717. }
  1718. spv_result_t MarkvDecoder::DecodeResultId() {
  1719. uint32_t rank = 0;
  1720. const uint64_t num_still_forward_declared =
  1721. multi_mtf_.GetSize(kMtfForwardDeclared);
  1722. if (num_still_forward_declared) {
  1723. // Some ids were forward declared. Check if this id is one of them.
  1724. uint64_t id_was_forward_declared;
  1725. if (!reader_.ReadBits(&id_was_forward_declared, 1))
  1726. return Diag(SPV_ERROR_INVALID_BINARY)
  1727. << "Failed to read id_was_forward_declared flag";
  1728. if (id_was_forward_declared) {
  1729. if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
  1730. return Diag(SPV_ERROR_INVALID_BINARY)
  1731. << "Failed to read MTF rank of forward declared id";
  1732. if (rank) {
  1733. // The id was forward declared, recover it from kMtfForwardDeclared.
  1734. if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank,
  1735. &inst_.result_id))
  1736. return Diag(SPV_ERROR_INTERNAL)
  1737. << "Forward declared MTF rank is out of bounds";
  1738. // We can now remove the id from kMtfForwardDeclared.
  1739. if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
  1740. return Diag(SPV_ERROR_INTERNAL)
  1741. << "Failed to remove id from kMtfForwardDeclared";
  1742. }
  1743. }
  1744. }
  1745. if (inst_.result_id == 0) {
  1746. // The id was not forward declared, issue a new id.
  1747. inst_.result_id = GetIdBound();
  1748. SetIdBound(inst_.result_id + 1);
  1749. }
  1750. if (model_->id_fallback_strategy() ==
  1751. MarkvModel::IdFallbackStrategy::kRuleBased) {
  1752. if (!rank) {
  1753. multi_mtf_.Insert(kMtfAll, inst_.result_id);
  1754. }
  1755. }
  1756. return SPV_SUCCESS;
  1757. }
  1758. spv_result_t MarkvEncoder::EncodeLiteralNumber(
  1759. const spv_parsed_operand_t& operand) {
  1760. if (operand.number_bit_width <= 32) {
  1761. const uint32_t word = inst_.words[operand.offset];
  1762. return EncodeNonIdWord(word);
  1763. } else {
  1764. assert(operand.number_bit_width <= 64);
  1765. const uint64_t word = uint64_t(inst_.words[operand.offset]) |
  1766. (uint64_t(inst_.words[operand.offset + 1]) << 32);
  1767. if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
  1768. writer_.WriteVariableWidthU64(word, model_->u64_chunk_length());
  1769. } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
  1770. int64_t val = 0;
  1771. std::memcpy(&val, &word, 8);
  1772. writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(),
  1773. model_->s64_block_exponent());
  1774. } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
  1775. writer_.WriteUnencoded(word);
  1776. } else {
  1777. return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
  1778. }
  1779. }
  1780. return SPV_SUCCESS;
  1781. }
  1782. spv_result_t MarkvDecoder::DecodeLiteralNumber(
  1783. const spv_parsed_operand_t& operand) {
  1784. if (operand.number_bit_width <= 32) {
  1785. uint32_t word = 0;
  1786. const spv_result_t result = DecodeNonIdWord(&word);
  1787. if (result != SPV_SUCCESS) return result;
  1788. inst_words_.push_back(word);
  1789. } else {
  1790. assert(operand.number_bit_width <= 64);
  1791. uint64_t word = 0;
  1792. if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
  1793. if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
  1794. return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64";
  1795. } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
  1796. int64_t val = 0;
  1797. if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(),
  1798. model_->s64_block_exponent()))
  1799. return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64";
  1800. std::memcpy(&word, &val, 8);
  1801. } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
  1802. if (!reader_.ReadUnencoded(&word))
  1803. return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64";
  1804. } else {
  1805. return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
  1806. }
  1807. inst_words_.push_back(static_cast<uint32_t>(word));
  1808. inst_words_.push_back(static_cast<uint32_t>(word >> 32));
  1809. }
  1810. return SPV_SUCCESS;
  1811. }
  1812. void MarkvEncoder::AddByteBreak(size_t byte_break_if_less_than) {
  1813. const size_t num_bits_to_next_byte =
  1814. GetNumBitsToNextByte(writer_.GetNumBits());
  1815. if (num_bits_to_next_byte == 0 ||
  1816. num_bits_to_next_byte > byte_break_if_less_than)
  1817. return;
  1818. if (logger_) {
  1819. logger_->AppendWhitespaces(kCommentNumWhitespaces);
  1820. logger_->AppendText("<byte break>");
  1821. }
  1822. writer_.WriteBits(0, num_bits_to_next_byte);
  1823. }
  1824. bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) {
  1825. const size_t num_bits_to_next_byte =
  1826. GetNumBitsToNextByte(reader_.GetNumReadBits());
  1827. if (num_bits_to_next_byte == 0 ||
  1828. num_bits_to_next_byte > byte_break_if_less_than)
  1829. return true;
  1830. uint64_t bits = 0;
  1831. if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false;
  1832. assert(bits == 0);
  1833. if (bits != 0) return false;
  1834. return true;
  1835. }
  1836. spv_result_t MarkvEncoder::EncodeInstruction(
  1837. const spv_parsed_instruction_t& inst) {
  1838. SpvOp opcode = SpvOp(inst.opcode);
  1839. inst_ = inst;
  1840. const spv_result_t validation_result = UpdateValidationState(inst);
  1841. if (validation_result != SPV_SUCCESS) return validation_result;
  1842. LogDisassemblyInstruction();
  1843. const spv_result_t opcode_encodig_result =
  1844. EncodeOpcodeAndNumOperands(opcode, inst.num_operands);
  1845. if (opcode_encodig_result < 0) return opcode_encodig_result;
  1846. if (opcode_encodig_result != SPV_SUCCESS) {
  1847. // Fallback encoding for opcode and num_operands.
  1848. writer_.WriteVariableWidthU32(opcode, model_->opcode_chunk_length());
  1849. if (!OpcodeHasFixedNumberOfOperands(opcode)) {
  1850. // If the opcode has a variable number of operands, encode the number of
  1851. // operands with the instruction.
  1852. if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces);
  1853. writer_.WriteVariableWidthU16(inst.num_operands,
  1854. model_->num_operands_chunk_length());
  1855. }
  1856. }
  1857. // Write operands.
  1858. const uint32_t num_operands = inst_.num_operands;
  1859. for (operand_index_ = 0; operand_index_ < num_operands; ++operand_index_) {
  1860. operand_ = inst_.operands[operand_index_];
  1861. if (logger_) {
  1862. logger_->AppendWhitespaces(kCommentNumWhitespaces);
  1863. logger_->AppendText("<");
  1864. logger_->AppendText(spvOperandTypeStr(operand_.type));
  1865. logger_->AppendText(">");
  1866. }
  1867. switch (operand_.type) {
  1868. case SPV_OPERAND_TYPE_RESULT_ID:
  1869. case SPV_OPERAND_TYPE_TYPE_ID:
  1870. case SPV_OPERAND_TYPE_ID:
  1871. case SPV_OPERAND_TYPE_OPTIONAL_ID:
  1872. case SPV_OPERAND_TYPE_SCOPE_ID:
  1873. case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
  1874. const uint32_t id = inst_.words[operand_.offset];
  1875. if (operand_.type == SPV_OPERAND_TYPE_TYPE_ID) {
  1876. const spv_result_t result = EncodeTypeId();
  1877. if (result != SPV_SUCCESS) return result;
  1878. } else if (operand_.type == SPV_OPERAND_TYPE_RESULT_ID) {
  1879. const spv_result_t result = EncodeResultId();
  1880. if (result != SPV_SUCCESS) return result;
  1881. } else {
  1882. const spv_result_t result = EncodeRefId(id);
  1883. if (result != SPV_SUCCESS) return result;
  1884. }
  1885. PromoteIfNeeded(id);
  1886. break;
  1887. }
  1888. case SPV_OPERAND_TYPE_LITERAL_INTEGER: {
  1889. const spv_result_t result =
  1890. EncodeNonIdWord(inst_.words[operand_.offset]);
  1891. if (result != SPV_SUCCESS) return result;
  1892. break;
  1893. }
  1894. case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: {
  1895. const spv_result_t result = EncodeLiteralNumber(operand_);
  1896. if (result != SPV_SUCCESS) return result;
  1897. break;
  1898. }
  1899. case SPV_OPERAND_TYPE_LITERAL_STRING: {
  1900. const char* src =
  1901. reinterpret_cast<const char*>(&inst_.words[operand_.offset]);
  1902. auto* codec = model_->GetLiteralStringHuffmanCodec(opcode);
  1903. if (codec) {
  1904. uint64_t bits = 0;
  1905. size_t num_bits = 0;
  1906. const std::string str = src;
  1907. if (codec->Encode(str, &bits, &num_bits)) {
  1908. writer_.WriteBits(bits, num_bits);
  1909. break;
  1910. } else {
  1911. bool result =
  1912. codec->Encode("kMarkvNoneOfTheAbove", &bits, &num_bits);
  1913. (void)result;
  1914. assert(result);
  1915. writer_.WriteBits(bits, num_bits);
  1916. }
  1917. }
  1918. const size_t length = spv_strnlen_s(src, operand_.num_words * 4);
  1919. if (length == operand_.num_words * 4)
  1920. return Diag(SPV_ERROR_INVALID_BINARY)
  1921. << "Failed to find terminal character of literal string";
  1922. for (size_t i = 0; i < length + 1; ++i) writer_.WriteUnencoded(src[i]);
  1923. break;
  1924. }
  1925. default: {
  1926. for (int i = 0; i < operand_.num_words; ++i) {
  1927. const uint32_t word = inst_.words[operand_.offset + i];
  1928. const spv_result_t result = EncodeNonIdWord(word);
  1929. if (result != SPV_SUCCESS) return result;
  1930. }
  1931. break;
  1932. }
  1933. }
  1934. }
  1935. AddByteBreak(kByteBreakAfterInstIfLessThanUntilNextByte);
  1936. if (logger_) {
  1937. logger_->NewLine();
  1938. logger_->NewLine();
  1939. if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
  1940. }
  1941. ProcessCurInstruction();
  1942. return SPV_SUCCESS;
  1943. }
  1944. spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
  1945. const bool header_read_success =
  1946. reader_.ReadUnencoded(&header_.magic_number) &&
  1947. reader_.ReadUnencoded(&header_.markv_version) &&
  1948. reader_.ReadUnencoded(&header_.markv_model) &&
  1949. reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
  1950. reader_.ReadUnencoded(&header_.spirv_version) &&
  1951. reader_.ReadUnencoded(&header_.spirv_generator);
  1952. if (!header_read_success)
  1953. return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header";
  1954. if (header_.markv_length_in_bits == 0)
  1955. return Diag(SPV_ERROR_INVALID_BINARY)
  1956. << "Header markv_length_in_bits field is zero";
  1957. if (header_.magic_number != kMarkvMagicNumber)
  1958. return Diag(SPV_ERROR_INVALID_BINARY)
  1959. << "MARK-V binary has incorrect magic number";
  1960. // TODO([email protected]): Print version strings.
  1961. if (header_.markv_version != GetMarkvVersion())
  1962. return Diag(SPV_ERROR_INVALID_BINARY)
  1963. << "MARK-V binary and the codec have different versions";
  1964. const uint32_t model_type = header_.markv_model >> 16;
  1965. const uint32_t model_version = header_.markv_model & 0xFFFF;
  1966. if (model_type != model_->model_type())
  1967. return Diag(SPV_ERROR_INVALID_BINARY)
  1968. << "MARK-V binary and the codec use different MARK-V models";
  1969. if (model_version != model_->model_version())
  1970. return Diag(SPV_ERROR_INVALID_BINARY)
  1971. << "MARK-V binary and the codec use different versions if the same "
  1972. << "MARK-V model";
  1973. spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic.
  1974. spirv_.resize(5, 0);
  1975. spirv_[0] = kSpirvMagicNumber;
  1976. spirv_[1] = header_.spirv_version;
  1977. spirv_[2] = header_.spirv_generator;
  1978. if (logger_) {
  1979. reader_.SetCallback(
  1980. [this](const std::string& str) { logger_->AppendBitSequence(str); });
  1981. }
  1982. while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
  1983. inst_ = {};
  1984. const spv_result_t decode_result = DecodeInstruction();
  1985. if (decode_result != SPV_SUCCESS) return decode_result;
  1986. const spv_result_t validation_result = UpdateValidationState(inst_);
  1987. if (validation_result != SPV_SUCCESS) return validation_result;
  1988. }
  1989. if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
  1990. !reader_.OnlyZeroesLeft()) {
  1991. return Diag(SPV_ERROR_INVALID_BINARY)
  1992. << "MARK-V binary has wrong stated bit length "
  1993. << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
  1994. }
  1995. // Decoding of the module is finished, validation state should have correct
  1996. // id bound.
  1997. spirv_[3] = GetIdBound();
  1998. *spirv_binary = std::move(spirv_);
  1999. return SPV_SUCCESS;
  2000. }
  2001. // TODO([email protected]): The implementation borrows heavily from
  2002. // Parser::parseOperand.
  2003. // Consider coupling them together in some way once MARK-V codec is more mature.
  2004. // For now it's better to keep the code independent for experimentation
  2005. // purposes.
  2006. spv_result_t MarkvDecoder::DecodeOperand(
  2007. size_t operand_offset, const spv_operand_type_t type,
  2008. spv_operand_pattern_t* expected_operands) {
  2009. const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
  2010. memset(&operand_, 0, sizeof(operand_));
  2011. assert((operand_offset >> 16) == 0);
  2012. operand_.offset = static_cast<uint16_t>(operand_offset);
  2013. operand_.type = type;
  2014. // Set default values, may be updated later.
  2015. operand_.number_kind = SPV_NUMBER_NONE;
  2016. operand_.number_bit_width = 0;
  2017. const size_t first_word_index = inst_words_.size();
  2018. switch (type) {
  2019. case SPV_OPERAND_TYPE_RESULT_ID: {
  2020. const spv_result_t result = DecodeResultId();
  2021. if (result != SPV_SUCCESS) return result;
  2022. inst_words_.push_back(inst_.result_id);
  2023. SetIdBound(std::max(GetIdBound(), inst_.result_id + 1));
  2024. PromoteIfNeeded(inst_.result_id);
  2025. break;
  2026. }
  2027. case SPV_OPERAND_TYPE_TYPE_ID: {
  2028. const spv_result_t result = DecodeTypeId();
  2029. if (result != SPV_SUCCESS) return result;
  2030. inst_words_.push_back(inst_.type_id);
  2031. SetIdBound(std::max(GetIdBound(), inst_.type_id + 1));
  2032. PromoteIfNeeded(inst_.type_id);
  2033. break;
  2034. }
  2035. case SPV_OPERAND_TYPE_ID:
  2036. case SPV_OPERAND_TYPE_OPTIONAL_ID:
  2037. case SPV_OPERAND_TYPE_SCOPE_ID:
  2038. case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
  2039. uint32_t id = 0;
  2040. const spv_result_t result = DecodeRefId(&id);
  2041. if (result != SPV_SUCCESS) return result;
  2042. if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";
  2043. if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
  2044. operand_.type = SPV_OPERAND_TYPE_ID;
  2045. if (opcode == SpvOpExtInst && operand_.offset == 3) {
  2046. // The current word is the extended instruction set id.
  2047. // Set the extended instruction set type for the current
  2048. // instruction.
  2049. auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
  2050. if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
  2051. return Diag(SPV_ERROR_INVALID_ID)
  2052. << "OpExtInst set id " << id
  2053. << " does not reference an OpExtInstImport result Id";
  2054. }
  2055. inst_.ext_inst_type = ext_inst_type_iter->second;
  2056. }
  2057. }
  2058. inst_words_.push_back(id);
  2059. SetIdBound(std::max(GetIdBound(), id + 1));
  2060. PromoteIfNeeded(id);
  2061. break;
  2062. }
  2063. case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
  2064. uint32_t word = 0;
  2065. const spv_result_t result = DecodeNonIdWord(&word);
  2066. if (result != SPV_SUCCESS) return result;
  2067. inst_words_.push_back(word);
  2068. assert(SpvOpExtInst == opcode);
  2069. assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE);
  2070. spv_ext_inst_desc ext_inst;
  2071. if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst))
  2072. return Diag(SPV_ERROR_INVALID_BINARY)
  2073. << "Invalid extended instruction number: " << word;
  2074. spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
  2075. break;
  2076. }
  2077. case SPV_OPERAND_TYPE_LITERAL_INTEGER:
  2078. case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
  2079. // These are regular single-word literal integer operands.
  2080. // Post-parsing validation should check the range of the parsed value.
  2081. operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
  2082. // It turns out they are always unsigned integers!
  2083. operand_.number_kind = SPV_NUMBER_UNSIGNED_INT;
  2084. operand_.number_bit_width = 32;
  2085. uint32_t word = 0;
  2086. const spv_result_t result = DecodeNonIdWord(&word);
  2087. if (result != SPV_SUCCESS) return result;
  2088. inst_words_.push_back(word);
  2089. break;
  2090. }
  2091. case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
  2092. case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: {
  2093. operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
  2094. if (opcode == SpvOpSwitch) {
  2095. // The literal operands have the same type as the value
  2096. // referenced by the selector Id.
  2097. const uint32_t selector_id = inst_words_.at(1);
  2098. const auto type_id_iter = id_to_type_id_.find(selector_id);
  2099. if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) {
  2100. return Diag(SPV_ERROR_INVALID_BINARY)
  2101. << "Invalid OpSwitch: selector id " << selector_id
  2102. << " has no type";
  2103. }
  2104. uint32_t type_id = type_id_iter->second;
  2105. if (selector_id == type_id) {
  2106. // Recall that by convention, a result ID that is a type definition
  2107. // maps to itself.
  2108. return Diag(SPV_ERROR_INVALID_BINARY)
  2109. << "Invalid OpSwitch: selector id " << selector_id
  2110. << " is a type, not a value";
  2111. }
  2112. if (auto error = SetNumericTypeInfoForType(&operand_, type_id))
  2113. return error;
  2114. if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT &&
  2115. operand_.number_kind != SPV_NUMBER_SIGNED_INT) {
  2116. return Diag(SPV_ERROR_INVALID_BINARY)
  2117. << "Invalid OpSwitch: selector id " << selector_id
  2118. << " is not a scalar integer";
  2119. }
  2120. } else {
  2121. assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
  2122. // The literal number type is determined by the type Id for the
  2123. // constant.
  2124. assert(inst_.type_id);
  2125. if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id))
  2126. return error;
  2127. }
  2128. if (auto error = DecodeLiteralNumber(operand_)) return error;
  2129. break;
  2130. }
  2131. case SPV_OPERAND_TYPE_LITERAL_STRING:
  2132. case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
  2133. operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING;
  2134. std::vector<char> str;
  2135. auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode);
  2136. if (codec) {
  2137. std::string decoded_string;
  2138. const bool huffman_result =
  2139. codec->DecodeFromStream(GetReadBitCallback(), &decoded_string);
  2140. assert(huffman_result);
  2141. if (!huffman_result)
  2142. return Diag(SPV_ERROR_INVALID_BINARY)
  2143. << "Failed to read literal string";
  2144. if (decoded_string != "kMarkvNoneOfTheAbove") {
  2145. std::copy(decoded_string.begin(), decoded_string.end(),
  2146. std::back_inserter(str));
  2147. str.push_back('\0');
  2148. }
  2149. }
  2150. // The loop is expected to terminate once we encounter '\0' or exhaust
  2151. // the bit stream.
  2152. if (str.empty()) {
  2153. while (true) {
  2154. char ch = 0;
  2155. if (!reader_.ReadUnencoded(&ch))
  2156. return Diag(SPV_ERROR_INVALID_BINARY)
  2157. << "Failed to read literal string";
  2158. str.push_back(ch);
  2159. if (ch == '\0') break;
  2160. }
  2161. }
  2162. while (str.size() % 4 != 0) str.push_back('\0');
  2163. inst_words_.resize(inst_words_.size() + str.size() / 4);
  2164. std::memcpy(&inst_words_[first_word_index], str.data(), str.size());
  2165. if (SpvOpExtInstImport == opcode) {
  2166. // Record the extended instruction type for the ID for this import.
  2167. // There is only one string literal argument to OpExtInstImport,
  2168. // so it's sufficient to guard this just on the opcode.
  2169. const spv_ext_inst_type_t ext_inst_type =
  2170. spvExtInstImportTypeGet(str.data());
  2171. if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
  2172. return Diag(SPV_ERROR_INVALID_BINARY)
  2173. << "Invalid extended instruction import '" << str.data()
  2174. << "'";
  2175. }
  2176. // We must have parsed a valid result ID. It's a condition
  2177. // of the grammar, and we only accept non-zero result Ids.
  2178. assert(inst_.result_id);
  2179. const bool inserted =
  2180. import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type)
  2181. .second;
  2182. (void)inserted;
  2183. assert(inserted);
  2184. }
  2185. break;
  2186. }
  2187. case SPV_OPERAND_TYPE_CAPABILITY:
  2188. case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
  2189. case SPV_OPERAND_TYPE_EXECUTION_MODEL:
  2190. case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
  2191. case SPV_OPERAND_TYPE_MEMORY_MODEL:
  2192. case SPV_OPERAND_TYPE_EXECUTION_MODE:
  2193. case SPV_OPERAND_TYPE_STORAGE_CLASS:
  2194. case SPV_OPERAND_TYPE_DIMENSIONALITY:
  2195. case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
  2196. case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
  2197. case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
  2198. case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
  2199. case SPV_OPERAND_TYPE_LINKAGE_TYPE:
  2200. case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
  2201. case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
  2202. case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
  2203. case SPV_OPERAND_TYPE_DECORATION:
  2204. case SPV_OPERAND_TYPE_BUILT_IN:
  2205. case SPV_OPERAND_TYPE_GROUP_OPERATION:
  2206. case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
  2207. case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
  2208. // A single word that is a plain enum value.
  2209. uint32_t word = 0;
  2210. const spv_result_t result = DecodeNonIdWord(&word);
  2211. if (result != SPV_SUCCESS) return result;
  2212. inst_words_.push_back(word);
  2213. // Map an optional operand type to its corresponding concrete type.
  2214. if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
  2215. operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;
  2216. spv_operand_desc entry;
  2217. if (grammar_.lookupOperand(type, word, &entry)) {
  2218. return Diag(SPV_ERROR_INVALID_BINARY)
  2219. << "Invalid " << spvOperandTypeStr(operand_.type)
  2220. << " operand: " << word;
  2221. }
  2222. // Prepare to accept operands to this operand, if needed.
  2223. spvPushOperandTypes(entry->operandTypes, expected_operands);
  2224. break;
  2225. }
  2226. case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
  2227. case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
  2228. case SPV_OPERAND_TYPE_LOOP_CONTROL:
  2229. case SPV_OPERAND_TYPE_IMAGE:
  2230. case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
  2231. case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
  2232. case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
  2233. // This operand is a mask.
  2234. uint32_t word = 0;
  2235. const spv_result_t result = DecodeNonIdWord(&word);
  2236. if (result != SPV_SUCCESS) return result;
  2237. inst_words_.push_back(word);
  2238. // Map an optional operand type to its corresponding concrete type.
  2239. if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
  2240. operand_.type = SPV_OPERAND_TYPE_IMAGE;
  2241. else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
  2242. operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
  2243. // Check validity of set mask bits. Also prepare for operands for those
  2244. // masks if they have any. To get operand order correct, scan from
  2245. // MSB to LSB since we can only prepend operands to a pattern.
  2246. // The only case in the grammar where you have more than one mask bit
  2247. // having an operand is for image operands. See SPIR-V 3.14 Image
  2248. // Operands.
  2249. uint32_t remaining_word = word;
  2250. for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
  2251. if (remaining_word & mask) {
  2252. spv_operand_desc entry;
  2253. if (grammar_.lookupOperand(type, mask, &entry)) {
  2254. return Diag(SPV_ERROR_INVALID_BINARY)
  2255. << "Invalid " << spvOperandTypeStr(operand_.type)
  2256. << " operand: " << word << " has invalid mask component "
  2257. << mask;
  2258. }
  2259. remaining_word ^= mask;
  2260. spvPushOperandTypes(entry->operandTypes, expected_operands);
  2261. }
  2262. }
  2263. if (word == 0) {
  2264. // An all-zeroes mask *might* also be valid.
  2265. spv_operand_desc entry;
  2266. if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
  2267. // Prepare for its operands, if any.
  2268. spvPushOperandTypes(entry->operandTypes, expected_operands);
  2269. }
  2270. }
  2271. break;
  2272. }
  2273. default:
  2274. return Diag(SPV_ERROR_INVALID_BINARY)
  2275. << "Internal error: Unhandled operand type: " << type;
  2276. }
  2277. operand_.num_words = uint16_t(inst_words_.size() - first_word_index);
  2278. assert(spvOperandIsConcrete(operand_.type));
  2279. parsed_operands_.push_back(operand_);
  2280. return SPV_SUCCESS;
  2281. }
  2282. spv_result_t MarkvDecoder::DecodeInstruction() {
  2283. parsed_operands_.clear();
  2284. inst_words_.clear();
  2285. // Opcode/num_words placeholder, the word will be filled in later.
  2286. inst_words_.push_back(0);
  2287. bool num_operands_still_unknown = true;
  2288. {
  2289. uint32_t opcode = 0;
  2290. uint32_t num_operands = 0;
  2291. const spv_result_t opcode_decoding_result =
  2292. DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands);
  2293. if (opcode_decoding_result < 0) return opcode_decoding_result;
  2294. if (opcode_decoding_result == SPV_SUCCESS) {
  2295. inst_.num_operands = static_cast<uint16_t>(num_operands);
  2296. num_operands_still_unknown = false;
  2297. } else {
  2298. if (!reader_.ReadVariableWidthU32(&opcode,
  2299. model_->opcode_chunk_length())) {
  2300. return Diag(SPV_ERROR_INVALID_BINARY)
  2301. << "Failed to read opcode of instruction";
  2302. }
  2303. }
  2304. inst_.opcode = static_cast<uint16_t>(opcode);
  2305. }
  2306. const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
  2307. spv_opcode_desc opcode_desc;
  2308. if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) {
  2309. return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
  2310. }
  2311. spv_operand_pattern_t expected_operands;
  2312. expected_operands.reserve(opcode_desc->numTypes);
  2313. for (auto i = 0; i < opcode_desc->numTypes; i++) {
  2314. expected_operands.push_back(
  2315. opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);
  2316. }
  2317. if (num_operands_still_unknown) {
  2318. if (!OpcodeHasFixedNumberOfOperands(opcode)) {
  2319. if (!reader_.ReadVariableWidthU16(&inst_.num_operands,
  2320. model_->num_operands_chunk_length()))
  2321. return Diag(SPV_ERROR_INVALID_BINARY)
  2322. << "Failed to read num_operands of instruction";
  2323. } else {
  2324. inst_.num_operands = static_cast<uint16_t>(expected_operands.size());
  2325. }
  2326. }
  2327. for (operand_index_ = 0;
  2328. operand_index_ < static_cast<size_t>(inst_.num_operands);
  2329. ++operand_index_) {
  2330. assert(!expected_operands.empty());
  2331. const spv_operand_type_t type =
  2332. spvTakeFirstMatchableOperand(&expected_operands);
  2333. const size_t operand_offset = inst_words_.size();
  2334. const spv_result_t decode_result =
  2335. DecodeOperand(operand_offset, type, &expected_operands);
  2336. if (decode_result != SPV_SUCCESS) return decode_result;
  2337. }
  2338. assert(inst_.num_operands == parsed_operands_.size());
  2339. // Only valid while inst_words_ and parsed_operands_ remain unchanged (until
  2340. // next DecodeInstruction call).
  2341. inst_.words = inst_words_.data();
  2342. inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
  2343. inst_.num_words = static_cast<uint16_t>(inst_words_.size());
  2344. inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode));
  2345. std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_));
  2346. assert(inst_.num_words ==
  2347. std::accumulate(
  2348. parsed_operands_.begin(), parsed_operands_.end(), 1,
  2349. [](int num_words, const spv_parsed_operand_t& operand) {
  2350. return num_words += operand.num_words;
  2351. }) &&
  2352. "num_words in instruction doesn't correspond to the sum of num_words"
  2353. "in the operands");
  2354. RecordNumberType();
  2355. ProcessCurInstruction();
  2356. if (!ReadToByteBreak(kByteBreakAfterInstIfLessThanUntilNextByte))
  2357. return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break";
  2358. if (logger_) {
  2359. logger_->NewLine();
  2360. std::stringstream ss;
  2361. ss << spvOpcodeString(opcode) << " ";
  2362. for (size_t index = 1; index < inst_words_.size(); ++index)
  2363. ss << inst_words_[index] << " ";
  2364. logger_->AppendText(ss.str());
  2365. logger_->NewLine();
  2366. logger_->NewLine();
  2367. if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
  2368. }
  2369. return SPV_SUCCESS;
  2370. }
  2371. spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
  2372. spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
  2373. assert(type_id != 0);
  2374. auto type_info_iter = type_id_to_number_type_info_.find(type_id);
  2375. if (type_info_iter == type_id_to_number_type_info_.end()) {
  2376. return Diag(SPV_ERROR_INVALID_BINARY)
  2377. << "Type Id " << type_id << " is not a type";
  2378. }
  2379. const NumberType& info = type_info_iter->second;
  2380. if (info.type == SPV_NUMBER_NONE) {
  2381. // This is a valid type, but for something other than a scalar number.
  2382. return Diag(SPV_ERROR_INVALID_BINARY)
  2383. << "Type Id " << type_id << " is not a scalar numeric type";
  2384. }
  2385. parsed_operand->number_kind = info.type;
  2386. parsed_operand->number_bit_width = info.bit_width;
  2387. // Round up the word count.
  2388. parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32);
  2389. return SPV_SUCCESS;
  2390. }
  2391. void MarkvDecoder::RecordNumberType() {
  2392. const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
  2393. if (spvOpcodeGeneratesType(opcode)) {
  2394. NumberType info = {SPV_NUMBER_NONE, 0};
  2395. if (SpvOpTypeInt == opcode) {
  2396. info.bit_width = inst_.words[inst_.operands[1].offset];
  2397. info.type = inst_.words[inst_.operands[2].offset]
  2398. ? SPV_NUMBER_SIGNED_INT
  2399. : SPV_NUMBER_UNSIGNED_INT;
  2400. } else if (SpvOpTypeFloat == opcode) {
  2401. info.bit_width = inst_.words[inst_.operands[1].offset];
  2402. info.type = SPV_NUMBER_FLOATING;
  2403. }
  2404. // The *result* Id of a type generating instruction is the type Id.
  2405. type_id_to_number_type_info_[inst_.result_id] = info;
  2406. }
  2407. }
  2408. spv_result_t EncodeHeader(void* user_data, spv_endianness_t endian,
  2409. uint32_t magic, uint32_t version, uint32_t generator,
  2410. uint32_t id_bound, uint32_t schema) {
  2411. MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
  2412. return encoder->EncodeHeader(endian, magic, version, generator, id_bound,
  2413. schema);
  2414. }
  2415. spv_result_t EncodeInstruction(void* user_data,
  2416. const spv_parsed_instruction_t* inst) {
  2417. MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
  2418. return encoder->EncodeInstruction(*inst);
  2419. }
  2420. } // namespace
  2421. spv_result_t SpirvToMarkv(
  2422. spv_const_context context, const std::vector<uint32_t>& spirv,
  2423. const MarkvCodecOptions& options, const MarkvModel& markv_model,
  2424. MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
  2425. MarkvDebugConsumer debug_consumer, std::vector<uint8_t>* markv) {
  2426. spv_context_t hijack_context = *context;
  2427. libspirv::SetContextMessageConsumer(&hijack_context, message_consumer);
  2428. spv_const_binary_t spirv_binary = {spirv.data(), spirv.size()};
  2429. spv_endianness_t endian;
  2430. spv_position_t position = {};
  2431. if (spvBinaryEndianness(&spirv_binary, &endian)) {
  2432. return DiagnosticStream(position, hijack_context.consumer,
  2433. SPV_ERROR_INVALID_BINARY)
  2434. << "Invalid SPIR-V magic number.";
  2435. }
  2436. spv_header_t header;
  2437. if (spvBinaryHeaderGet(&spirv_binary, endian, &header)) {
  2438. return DiagnosticStream(position, hijack_context.consumer,
  2439. SPV_ERROR_INVALID_BINARY)
  2440. << "Invalid SPIR-V header.";
  2441. }
  2442. MarkvEncoder encoder(&hijack_context, options, &markv_model);
  2443. if (log_consumer || debug_consumer) {
  2444. encoder.CreateLogger(log_consumer, debug_consumer);
  2445. spv_text text = nullptr;
  2446. if (spvBinaryToText(&hijack_context, spirv.data(), spirv.size(),
  2447. SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text,
  2448. nullptr) != SPV_SUCCESS) {
  2449. return DiagnosticStream(position, hijack_context.consumer,
  2450. SPV_ERROR_INVALID_BINARY)
  2451. << "Failed to disassemble SPIR-V binary.";
  2452. }
  2453. assert(text);
  2454. encoder.SetDisassembly(std::string(text->str, text->length));
  2455. spvTextDestroy(text);
  2456. }
  2457. if (spvBinaryParse(&hijack_context, &encoder, spirv.data(), spirv.size(),
  2458. EncodeHeader, EncodeInstruction, nullptr) != SPV_SUCCESS) {
  2459. return DiagnosticStream(position, hijack_context.consumer,
  2460. SPV_ERROR_INVALID_BINARY)
  2461. << "Unable to encode to MARK-V.";
  2462. }
  2463. *markv = encoder.GetMarkvBinary();
  2464. return SPV_SUCCESS;
  2465. }
  2466. spv_result_t MarkvToSpirv(
  2467. spv_const_context context, const std::vector<uint8_t>& markv,
  2468. const MarkvCodecOptions& options, const MarkvModel& markv_model,
  2469. MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
  2470. MarkvDebugConsumer debug_consumer, std::vector<uint32_t>* spirv) {
  2471. spv_position_t position = {};
  2472. spv_context_t hijack_context = *context;
  2473. libspirv::SetContextMessageConsumer(&hijack_context, message_consumer);
  2474. MarkvDecoder decoder(&hijack_context, markv, options, &markv_model);
  2475. if (log_consumer || debug_consumer)
  2476. decoder.CreateLogger(log_consumer, debug_consumer);
  2477. if (decoder.DecodeModule(spirv) != SPV_SUCCESS) {
  2478. return DiagnosticStream(position, hijack_context.consumer,
  2479. SPV_ERROR_INVALID_BINARY)
  2480. << "Unable to decode MARK-V.";
  2481. }
  2482. assert(!spirv->empty());
  2483. return SPV_SUCCESS;
  2484. }
  2485. } // namespace spvtools