markv.cpp 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/comp/markv.h"
  15. #include "source/comp/markv_decoder.h"
  16. #include "source/comp/markv_encoder.h"
  17. namespace spvtools {
  18. namespace comp {
  19. namespace {
  20. spv_result_t EncodeHeader(void* user_data, spv_endianness_t endian,
  21. uint32_t magic, uint32_t version, uint32_t generator,
  22. uint32_t id_bound, uint32_t schema) {
  23. MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
  24. return encoder->EncodeHeader(endian, magic, version, generator, id_bound,
  25. schema);
  26. }
  27. spv_result_t EncodeInstruction(void* user_data,
  28. const spv_parsed_instruction_t* inst) {
  29. MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
  30. return encoder->EncodeInstruction(*inst);
  31. }
  32. } // namespace
  33. spv_result_t SpirvToMarkv(
  34. spv_const_context context, const std::vector<uint32_t>& spirv,
  35. const MarkvCodecOptions& options, const MarkvModel& markv_model,
  36. MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
  37. MarkvDebugConsumer debug_consumer, std::vector<uint8_t>* markv) {
  38. spv_context_t hijack_context = *context;
  39. SetContextMessageConsumer(&hijack_context, message_consumer);
  40. spv_validator_options validator_options =
  41. MarkvDecoder::GetValidatorOptions(options);
  42. if (validator_options) {
  43. spv_const_binary_t spirv_binary = {spirv.data(), spirv.size()};
  44. const spv_result_t result = spvValidateWithOptions(
  45. &hijack_context, validator_options, &spirv_binary, nullptr);
  46. if (result != SPV_SUCCESS) return result;
  47. }
  48. MarkvEncoder encoder(&hijack_context, options, &markv_model);
  49. spv_position_t position = {};
  50. if (log_consumer || debug_consumer) {
  51. encoder.CreateLogger(log_consumer, debug_consumer);
  52. spv_text text = nullptr;
  53. if (spvBinaryToText(&hijack_context, spirv.data(), spirv.size(),
  54. SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text,
  55. nullptr) != SPV_SUCCESS) {
  56. return DiagnosticStream(position, hijack_context.consumer, "",
  57. SPV_ERROR_INVALID_BINARY)
  58. << "Failed to disassemble SPIR-V binary.";
  59. }
  60. assert(text);
  61. encoder.SetDisassembly(std::string(text->str, text->length));
  62. spvTextDestroy(text);
  63. }
  64. if (spvBinaryParse(&hijack_context, &encoder, spirv.data(), spirv.size(),
  65. EncodeHeader, EncodeInstruction, nullptr) != SPV_SUCCESS) {
  66. return DiagnosticStream(position, hijack_context.consumer, "",
  67. SPV_ERROR_INVALID_BINARY)
  68. << "Unable to encode to MARK-V.";
  69. }
  70. *markv = encoder.GetMarkvBinary();
  71. return SPV_SUCCESS;
  72. }
  73. spv_result_t MarkvToSpirv(
  74. spv_const_context context, const std::vector<uint8_t>& markv,
  75. const MarkvCodecOptions& options, const MarkvModel& markv_model,
  76. MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
  77. MarkvDebugConsumer debug_consumer, std::vector<uint32_t>* spirv) {
  78. spv_position_t position = {};
  79. spv_context_t hijack_context = *context;
  80. SetContextMessageConsumer(&hijack_context, message_consumer);
  81. MarkvDecoder decoder(&hijack_context, markv, options, &markv_model);
  82. if (log_consumer || debug_consumer)
  83. decoder.CreateLogger(log_consumer, debug_consumer);
  84. if (decoder.DecodeModule(spirv) != SPV_SUCCESS) {
  85. return DiagnosticStream(position, hijack_context.consumer, "",
  86. SPV_ERROR_INVALID_BINARY)
  87. << "Unable to decode MARK-V.";
  88. }
  89. assert(!spirv->empty());
  90. return SPV_SUCCESS;
  91. }
  92. } // namespace comp
  93. } // namespace spvtools