| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920 |
- // Copyright (c) 2017 Google Inc.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- // Contains
- // - SPIR-V to MARK-V encoder
- // - MARK-V to SPIR-V decoder
- //
- // MARK-V is a compression format for SPIR-V binaries. It strips away
- // non-essential information (such as result ids which can be regenerated) and
- // uses various bit reduction techniques to reduce the size of the binary.
- #include <algorithm>
- #include <cassert>
- #include <cstring>
- #include <functional>
- #include <iostream>
- #include <iterator>
- #include <list>
- #include <memory>
- #include <numeric>
- #include <string>
- #include <unordered_map>
- #include <unordered_set>
- #include <vector>
- #include "latest_version_glsl_std_450_header.h"
- #include "latest_version_opencl_std_header.h"
- #include "latest_version_spirv_header.h"
- #include "binary.h"
- #include "diagnostic.h"
- #include "enum_string_mapping.h"
- #include "ext_inst.h"
- #include "extensions.h"
- #include "id_descriptor.h"
- #include "instruction.h"
- #include "markv.h"
- #include "markv_model.h"
- #include "opcode.h"
- #include "operand.h"
- #include "spirv-tools/libspirv.h"
- #include "spirv_endian.h"
- #include "spirv_validator_options.h"
- #include "util/bit_stream.h"
- #include "util/huffman_codec.h"
- #include "util/move_to_front.h"
- #include "util/parse_number.h"
- #include "val/instruction.h"
- #include "val/validation_state.h"
- #include "validate.h"
- using libspirv::DiagnosticStream;
- using libspirv::IdDescriptorCollection;
- using libspirv::Instruction;
- using libspirv::ValidationState_t;
- using spvutils::BitReaderWord64;
- using spvutils::BitWriterWord64;
- using spvutils::HuffmanCodec;
- using MoveToFront = spvutils::MoveToFront<uint32_t>;
- using MultiMoveToFront = spvutils::MultiMoveToFront<uint32_t>;
- namespace spvtools {
- namespace {
- const uint32_t kSpirvMagicNumber = SpvMagicNumber;
- const uint32_t kMarkvMagicNumber = 0x07230303;
// Handles identifying move-to-front (mtf) sequences.
//
// The handles below kMtfIdOfTypeBegin each name one fixed sequence. They are
// numbered explicitly so that their values can be read off directly.
// Enumerators ending with "Begin" (and kMtfIdGeneratedByOpcode, which follows
// the same scheme) start a handle space: a concrete handle is the start value
// plus an offset (type id, opcode, size, descriptor, ...). Each space spans
// 16 or 32 bits.
enum : uint64_t {
  kMtfNone = 0,
  // Every id.
  kMtfAll = 1,
  // Every forward declared id.
  kMtfForwardDeclared = 2,
  // Every type id, except those produced by OpTypeFunction.
  kMtfTypeNonFunction = 3,
  // Every label.
  kMtfLabel = 4,
  // Every id created by an instruction that has a type_id.
  kMtfObject = 5,
  // Every type produced by OpTypeFloat, OpTypeInt or OpTypeBool.
  kMtfTypeScalar = 6,
  // Every composite type.
  kMtfTypeComposite = 7,
  // The boolean type and any vector type of it.
  kMtfTypeBoolScalarOrVector = 8,
  // Every float type and any vector of floats type.
  kMtfTypeFloatScalarOrVector = 9,
  // Every int type and any vector of ints type.
  kMtfTypeIntScalarOrVector = 10,
  // Every type used as a return type in OpTypeFunction.
  kMtfTypeReturnedByFunction = 11,
  // Every composite object.
  kMtfComposite = 12,
  // Every bool object or vector of bools.
  kMtfBoolScalarOrVector = 13,
  // Every float object or vector of floats.
  kMtfFloatScalarOrVector = 14,
  // Every int object or vector of ints.
  kMtfIntScalarOrVector = 15,
  // Every pointer type pointing to a composite.
  kMtfTypePointerToComposite = 16,
  // Used by EncodeMtfRankHuffman.
  kMtfGenericNonZeroRank = 17,

  // Space: ids of a specific type.
  kMtfIdOfTypeBegin = 0x10000,
  // Space: ids generated by a specific opcode.
  kMtfIdGeneratedByOpcode = 0x20000,
  // Space: ids of objects whose type was generated by a specific opcode.
  kMtfIdWithTypeGeneratedByOpcodeBegin = 0x30000,
  // Space: vectors of a specific component type.
  kMtfVectorOfComponentTypeBegin = 0x40000,
  // Space: vector types of a specific size.
  kMtfTypeVectorOfSizeBegin = 0x50000,
  // Space: pointer types to a specific type.
  kMtfPointerToTypeBegin = 0x60000,
  // Space: function types returning a specific type.
  kMtfFunctionTypeWithReturnTypeBegin = 0x70000,
  // Space: function objects returning a specific type.
  kMtfFunctionWithReturnTypeBegin = 0x80000,
  // Space: short id descriptors (max 16-bit).
  kMtfShortIdDescriptorSpaceBegin = 0x90000,
  // Space: long id descriptors (32-bit).
  kMtfLongIdDescriptorSpaceBegin = 0x100000000,
};
- // Signals that the value is not in the coding scheme and a fallback method
- // needs to be used.
- const uint64_t kMarkvNoneOfTheAbove = MarkvModel::GetMarkvNoneOfTheAbove();
// Mtf ranks below this threshold are Huffman coded; from this rank on, ranks
// are encoded by value.
const uint32_t kMtfSmallestRankEncodedByValue = 10;
// Huffman symbol announcing that the mtf rank was too large to be Huffman
// coded and follows by value.
const uint32_t kMtfRankEncodedByValueSignal =
    std::numeric_limits<uint32_t>::max();
const size_t kCommentNumWhitespaces = 2;
const size_t kByteBreakAfterInstIfLessThanUntilNextByte = 8;
// Number of bits actually occupied by a short id descriptor.
const uint32_t kShortDescriptorNumBits = 8;

// Hash function producing short id descriptors from descriptor words.
// Each word is offset by its position (plus a fixed seed of 123) and
// multiplied by Knuth's multiplicative hash constant. The per-word hashes are
// summed with 32-bit wrap-around. The result lies in
// [1, 2^kShortDescriptorNumBits - 1], so 0 is never returned.
uint32_t ShortHashU32Array(const std::vector<uint32_t>& words) {
  const uint32_t kKnuthMulHash = 2654435761;
  uint32_t sum = 0;
  uint32_t position = 0;
  for (const uint32_t word : words) {
    sum += (word + position + 123) * kKnuthMulHash;
    ++position;
  }
  const uint32_t modulus = (1u << kShortDescriptorNumBits) - 1;
  return 1 + sum % modulus;
}
- // Returns a set of mtf rank codecs based on a plausible hand-coded
- // distribution.
- std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
- GetMtfHuffmanCodecs() {
- std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>> codecs;
- std::unique_ptr<HuffmanCodec<uint32_t>> codec;
- codec.reset(new HuffmanCodec<uint32_t>(std::map<uint32_t, uint32_t>({
- {0, 5},
- {1, 40},
- {2, 10},
- {3, 5},
- {4, 5},
- {5, 5},
- {6, 3},
- {7, 3},
- {8, 3},
- {9, 3},
- {kMtfRankEncodedByValueSignal, 10},
- })));
- codecs.emplace(kMtfAll, std::move(codec));
- codec.reset(new HuffmanCodec<uint32_t>(std::map<uint32_t, uint32_t>({
- {1, 50},
- {2, 20},
- {3, 5},
- {4, 5},
- {5, 2},
- {6, 1},
- {7, 1},
- {8, 1},
- {9, 1},
- {kMtfRankEncodedByValueSignal, 10},
- })));
- codecs.emplace(kMtfGenericNonZeroRank, std::move(codec));
- return codecs;
- }
- // Returns true if the opcode has a fixed number of operands. May return a
- // false negative.
- bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) {
- switch (opcode) {
- // TODO([email protected]) This is not a complete list.
- case SpvOpNop:
- case SpvOpName:
- case SpvOpUndef:
- case SpvOpSizeOf:
- case SpvOpLine:
- case SpvOpNoLine:
- case SpvOpDecorationGroup:
- case SpvOpExtension:
- case SpvOpExtInstImport:
- case SpvOpMemoryModel:
- case SpvOpCapability:
- case SpvOpTypeVoid:
- case SpvOpTypeBool:
- case SpvOpTypeInt:
- case SpvOpTypeFloat:
- case SpvOpTypeVector:
- case SpvOpTypeMatrix:
- case SpvOpTypeSampler:
- case SpvOpTypeSampledImage:
- case SpvOpTypeArray:
- case SpvOpTypePointer:
- case SpvOpConstantTrue:
- case SpvOpConstantFalse:
- case SpvOpLabel:
- case SpvOpBranch:
- case SpvOpFunction:
- case SpvOpFunctionParameter:
- case SpvOpFunctionEnd:
- case SpvOpBitcast:
- case SpvOpCopyObject:
- case SpvOpTranspose:
- case SpvOpSNegate:
- case SpvOpFNegate:
- case SpvOpIAdd:
- case SpvOpFAdd:
- case SpvOpISub:
- case SpvOpFSub:
- case SpvOpIMul:
- case SpvOpFMul:
- case SpvOpUDiv:
- case SpvOpSDiv:
- case SpvOpFDiv:
- case SpvOpUMod:
- case SpvOpSRem:
- case SpvOpSMod:
- case SpvOpFRem:
- case SpvOpFMod:
- case SpvOpVectorTimesScalar:
- case SpvOpMatrixTimesScalar:
- case SpvOpVectorTimesMatrix:
- case SpvOpMatrixTimesVector:
- case SpvOpMatrixTimesMatrix:
- case SpvOpOuterProduct:
- case SpvOpDot:
- return true;
- default:
- break;
- }
- return false;
- }
// Returns the number of bits between |bit_pos| and the next byte boundary,
// or 0 when |bit_pos| is already byte aligned.
size_t GetNumBitsToNextByte(size_t bit_pos) {
  const size_t bits_into_byte = bit_pos % 8;
  return bits_into_byte == 0 ? 0 : 8 - bits_into_byte;
}
// Returns the current MARK-V version, packed as (major << 16) | minor.
uint32_t GetMarkvVersion() {
  const uint32_t major = 1;
  const uint32_t minor = 4;
  const uint32_t packed_major = major << 16;
  return packed_major | minor;
}
- class MarkvLogger {
- public:
- MarkvLogger(MarkvLogConsumer log_consumer, MarkvDebugConsumer debug_consumer)
- : log_consumer_(log_consumer), debug_consumer_(debug_consumer) {}
- void AppendText(const std::string& str) {
- Append(str);
- use_delimiter_ = false;
- }
- void AppendTextNewLine(const std::string& str) {
- Append(str);
- Append("\n");
- use_delimiter_ = false;
- }
- void AppendBitSequence(const std::string& str) {
- if (debug_consumer_) instruction_bits_ << str;
- if (use_delimiter_) Append("-");
- Append(str);
- use_delimiter_ = true;
- }
- void AppendWhitespaces(size_t num) {
- Append(std::string(num, ' '));
- use_delimiter_ = false;
- }
- void NewLine() {
- Append("\n");
- use_delimiter_ = false;
- }
- bool DebugInstruction(const spv_parsed_instruction_t& inst) {
- bool result = true;
- if (debug_consumer_) {
- result = debug_consumer_(
- std::vector<uint32_t>(inst.words, inst.words + inst.num_words),
- instruction_bits_.str(), instruction_comment_.str());
- instruction_bits_.str(std::string());
- instruction_comment_.str(std::string());
- }
- return result;
- }
- private:
- MarkvLogger(const MarkvLogger&) = delete;
- MarkvLogger(MarkvLogger&&) = delete;
- MarkvLogger& operator=(const MarkvLogger&) = delete;
- MarkvLogger& operator=(MarkvLogger&&) = delete;
- void Append(const std::string& str) {
- if (log_consumer_) log_consumer_(str);
- if (debug_consumer_) instruction_comment_ << str;
- }
- MarkvLogConsumer log_consumer_;
- MarkvDebugConsumer debug_consumer_;
- std::stringstream instruction_bits_;
- std::stringstream instruction_comment_;
- // If true a delimiter will be appended before the next bit sequence.
- // Used to generate outputs like: 1100-0 1110-1-1100-1-1111-0 110-0.
- bool use_delimiter_ = false;
- };
- // Base class for MARK-V encoder and decoder. Contains common functionality
- // such as:
- // - Validator connection and validation state.
- // - SPIR-V grammar and helper functions.
- class MarkvCodecBase {
- public:
- virtual ~MarkvCodecBase() { spvValidatorOptionsDestroy(validator_options_); }
- MarkvCodecBase() = delete;
- protected:
- struct MarkvHeader {
- MarkvHeader() {
- magic_number = kMarkvMagicNumber;
- markv_version = GetMarkvVersion();
- markv_model = 0;
- markv_length_in_bits = 0;
- spirv_version = 0;
- spirv_generator = 0;
- }
- uint32_t magic_number;
- uint32_t markv_version;
- // Magic number to identify or verify MarkvModel used for encoding.
- uint32_t markv_model;
- uint32_t markv_length_in_bits;
- uint32_t spirv_version;
- uint32_t spirv_generator;
- };
- // |model| is owned by the caller, must be not null and valid during the
- // lifetime of the codec.
- explicit MarkvCodecBase(spv_const_context context,
- spv_validator_options validator_options,
- const MarkvModel* model)
- : validator_options_(validator_options),
- grammar_(context),
- model_(model),
- short_id_descriptors_(ShortHashU32Array),
- mtf_huffman_codecs_(GetMtfHuffmanCodecs()),
- context_(context),
- vstate_(validator_options
- ? new ValidationState_t(context, validator_options_)
- : nullptr) {}
- // Validates a single instruction and updates validation state of the module.
- // Does nothing and returns SPV_SUCCESS if validator was not created.
- spv_result_t UpdateValidationState(const spv_parsed_instruction_t& inst) {
- if (!vstate_) return SPV_SUCCESS;
- return ValidateInstructionAndUpdateValidationState(vstate_.get(), &inst);
- }
- // Returns instruction which created |id| or nullptr if such instruction was
- // not registered.
- const Instruction* FindDef(uint32_t id) const {
- const auto it = id_to_def_instruction_.find(id);
- if (it == id_to_def_instruction_.end()) return nullptr;
- return it->second;
- }
- // Returns type id of vector type component.
- uint32_t GetVectorComponentType(uint32_t vector_type_id) const {
- const Instruction* type_inst = FindDef(vector_type_id);
- assert(type_inst);
- assert(type_inst->opcode() == SpvOpTypeVector);
- const uint32_t component_type =
- type_inst->word(type_inst->operands()[1].offset);
- return component_type;
- }
- // Returns mtf handle for ids of given type.
- uint64_t GetMtfIdOfType(uint32_t type_id) const {
- return kMtfIdOfTypeBegin + type_id;
- }
- // Returns mtf handle for ids generated by given opcode.
- uint64_t GetMtfIdGeneratedByOpcode(SpvOp opcode) const {
- return kMtfIdGeneratedByOpcode + opcode;
- }
- // Returns mtf handle for ids of type generated by given opcode.
- uint64_t GetMtfIdWithTypeGeneratedByOpcode(SpvOp opcode) const {
- return kMtfIdWithTypeGeneratedByOpcodeBegin + opcode;
- }
- // Returns mtf handle for vectors of specific component type.
- uint64_t GetMtfVectorOfComponentType(uint32_t type_id) const {
- return kMtfVectorOfComponentTypeBegin + type_id;
- }
- // Returns mtf handle for vector type of specific size.
- uint64_t GetMtfTypeVectorOfSize(uint32_t size) const {
- return kMtfTypeVectorOfSizeBegin + size;
- }
- // Returns mtf handle for pointers to specific size.
- uint64_t GetMtfPointerToType(uint32_t type_id) const {
- return kMtfPointerToTypeBegin + type_id;
- }
- // Returns mtf handle for function types with given return type.
- uint64_t GetMtfFunctionTypeWithReturnType(uint32_t type_id) const {
- return kMtfFunctionTypeWithReturnTypeBegin + type_id;
- }
- // Returns mtf handle for functions with given return type.
- uint64_t GetMtfFunctionWithReturnType(uint32_t type_id) const {
- return kMtfFunctionWithReturnTypeBegin + type_id;
- }
- // Returns mtf handle for the given long id descriptor.
- uint64_t GetMtfLongIdDescriptor(uint32_t descriptor) const {
- return kMtfLongIdDescriptorSpaceBegin + descriptor;
- }
- // Returns mtf handle for the given short id descriptor.
- uint64_t GetMtfShortIdDescriptor(uint32_t descriptor) const {
- return kMtfShortIdDescriptorSpaceBegin + descriptor;
- }
- // Process data from the current instruction. This would update MTFs and
- // other data containers.
- void ProcessCurInstruction();
- // Returns move-to-front handle to be used for the current operand slot.
- // Mtf handle is chosen based on a set of rules defined by SPIR-V grammar.
- uint64_t GetRuleBasedMtf();
- // Returns words of the current instruction. Decoder has a different
- // implementation and the array is valid only until the previously decoded
- // word.
- virtual const uint32_t* GetInstWords() const { return inst_.words; }
- // Returns the opcode of the previous instruction.
- SpvOp GetPrevOpcode() const {
- if (instructions_.empty()) return SpvOpNop;
- return instructions_.back()->opcode();
- }
- // Returns diagnostic stream, position index is set to instruction number.
- DiagnosticStream Diag(spv_result_t error_code) const {
- return DiagnosticStream({0, 0, instructions_.size()}, context_->consumer,
- error_code);
- }
- // Returns current id bound.
- uint32_t GetIdBound() const { return id_bound_; }
- // Sets current id bound, expected to be no lower than the previous one.
- void SetIdBound(uint32_t id_bound) {
- assert(id_bound >= id_bound_);
- id_bound_ = id_bound;
- if (vstate_) vstate_->setIdBound(id_bound);
- }
- // Returns Huffman codec for ranks of the mtf with given |handle|.
- // Different mtfs can use different rank distributions.
- // May return nullptr if the codec doesn't exist.
- const spvutils::HuffmanCodec<uint32_t>* GetMtfHuffmanCodec(
- uint64_t handle) const {
- const auto it = mtf_huffman_codecs_.find(handle);
- if (it == mtf_huffman_codecs_.end()) return nullptr;
- return it->second.get();
- }
- // Promotes id in all move-to-front sequences if ids can be shared by multiple
- // sequences.
- void PromoteIfNeeded(uint32_t id) {
- if (!model_->AnyDescriptorHasCodingScheme() &&
- model_->id_fallback_strategy() ==
- MarkvModel::IdFallbackStrategy::kShortDescriptor) {
- // Move-to-front sequences do not share ids. Nothing to do.
- return;
- }
- multi_mtf_.Promote(id);
- }
- spv_validator_options validator_options_ = nullptr;
- const libspirv::AssemblyGrammar grammar_;
- MarkvHeader header_;
- // MARK-V model, not owned.
- const MarkvModel* model_ = nullptr;
- // Current instruction, current operand and current operand index.
- spv_parsed_instruction_t inst_;
- spv_parsed_operand_t operand_;
- uint32_t operand_index_;
- // Maps a result ID to its type ID. By convention:
- // - a result ID that is a type definition maps to itself.
- // - a result ID without a type maps to 0. (E.g. for OpLabel)
- std::unordered_map<uint32_t, uint32_t> id_to_type_id_;
- // Container for all move-to-front sequences.
- MultiMoveToFront multi_mtf_;
- // Id of the current function or zero if outside of function.
- uint32_t cur_function_id_ = 0;
- // Return type of the current function.
- uint32_t cur_function_return_type_ = 0;
- // Remaining function parameter types. This container is filled on OpFunction,
- // and drained on OpFunctionParameter.
- std::list<uint32_t> remaining_function_parameter_types_;
- // List of ids local to the current function.
- std::vector<uint32_t> ids_local_to_cur_function_;
- // List of instructions in the order they are given in the module.
- std::vector<std::unique_ptr<const Instruction>> instructions_;
- // Container/computer for long (32-bit) id descriptors.
- IdDescriptorCollection long_id_descriptors_;
- // Container/computer for short id descriptors.
- // Short descriptors are stored in uint32_t, but their actual bit width is
- // defined with kShortDescriptorNumBits.
- // It doesn't seem logical to have a different computer for short id
- // descriptors, since one could actually map/truncate long descriptors.
- // But as short descriptors have collisions, the efficiency of
- // compression depends on the collision pattern, and short descriptors
- // produced by function ShortHashU32Array have been empirically proven to
- // produce better results.
- IdDescriptorCollection short_id_descriptors_;
- // Huffman codecs for move-to-front ranks. The map key is mtf handle. Doesn't
- // need to contain a different codec for every handle as most use one and the
- // same.
- std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
- mtf_huffman_codecs_;
- // If not nullptr, codec will log comments on the compression process.
- std::unique_ptr<MarkvLogger> logger_;
- private:
- spv_const_context context_ = nullptr;
- std::unique_ptr<ValidationState_t> vstate_;
- // Maps result id to the instruction which defined it.
- std::unordered_map<uint32_t, const Instruction*> id_to_def_instruction_;
- uint32_t id_bound_ = 1;
- };
- // SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and
- // EncodeInstruction which can be used as callback by spvBinaryParse.
- // Encoded binary is written to an internally maintained bitstream.
- // After the last instruction is encoded, the resulting MARK-V binary can be
- // acquired by calling GetMarkvBinary().
- // The encoder uses SPIR-V validator to keep internal state, therefore
- // SPIR-V binary needs to be able to pass validator checks.
- // CreateLogger() can be used to enable the encoder to write comments on how
- // encoding was done; the comments are delivered to the consumers passed to it.
- class MarkvEncoder : public MarkvCodecBase {
- public:
- // |model| is owned by the caller, must be not null and valid during the
- // lifetime of MarkvEncoder.
- MarkvEncoder(spv_const_context context, const MarkvCodecOptions& options,
- const MarkvModel* model)
- : MarkvCodecBase(context, GetValidatorOptions(options), model),
- options_(options) {
- (void)options_;
- }
- // Writes data from SPIR-V header to MARK-V header.
- // |version| and |generator| are stored in the MARK-V header and |id_bound|
- // is forwarded to SetIdBound(). Endianness, magic number and schema are
- // ignored.
- spv_result_t EncodeHeader(spv_endianness_t /* endian */, uint32_t /* magic */,
- uint32_t version, uint32_t generator,
- uint32_t id_bound, uint32_t /* schema */) {
- SetIdBound(id_bound);
- header_.spirv_version = version;
- header_.spirv_generator = generator;
- return SPV_SUCCESS;
- }
- // Creates an internal logger which writes comments on the encoding process.
- // Also hooks the bit writer so that every bit sequence it writes is appended
- // to the log.
- void CreateLogger(MarkvLogConsumer log_consumer,
- MarkvDebugConsumer debug_consumer) {
- logger_.reset(new MarkvLogger(log_consumer, debug_consumer));
- writer_.SetCallback(
- [this](const std::string& str) { logger_->AppendBitSequence(str); });
- }
- // Encodes SPIR-V instruction to MARK-V and writes to bit stream.
- // Operation can fail if the instruction fails to pass the validator or if
- // the encoder stumbles on something unexpected.
- spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst);
- // Concatenates MARK-V header and the bit stream with encoded instructions
- // into a single byte buffer and returns it by value.
- // Before copying, fills in header_.markv_length_in_bits (header size plus
- // payload, in bits) and header_.markv_model (model type shifted left by 16,
- // OR-ed with the model version).
- std::vector<uint8_t> GetMarkvBinary() {
- header_.markv_length_in_bits =
- static_cast<uint32_t>(sizeof(header_) * 8 + writer_.GetNumBits());
- header_.markv_model =
- (model_->model_type() << 16) | model_->model_version();
- const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes();
- std::vector<uint8_t> markv(num_bytes);
- assert(writer_.GetData());
- std::memcpy(markv.data(), &header_, sizeof(header_));
- std::memcpy(markv.data() + sizeof(header_), writer_.GetData(),
- writer_.GetDataSizeBytes());
- return markv;
- }
- // Optionally adds disassembly to the comments.
- // Disassembly should contain all instructions in the module separated by
- // \n, and no header.
- void SetDisassembly(std::string&& disassembly) {
- disassembly_.reset(new std::stringstream(std::move(disassembly)));
- }
- // Extracts the next instruction line from the disassembly and logs it.
- // Does nothing unless both a logger and a disassembly have been set.
- void LogDisassemblyInstruction() {
- if (logger_ && disassembly_) {
- std::string line;
- std::getline(*disassembly_, line, '\n');
- logger_->AppendTextNewLine(line);
- }
- }
- private:
- // Creates and returns validator options, or nullptr if validation is
- // disabled in |options|. Returned value owned by the caller.
- static spv_validator_options GetValidatorOptions(
- const MarkvCodecOptions& options) {
- return options.validate_spirv_binary ? spvValidatorOptionsCreate()
- : nullptr;
- }
- // Writes a single word to bit stream. operand_.type determines if the word is
- // encoded and how.
- spv_result_t EncodeNonIdWord(uint32_t word);
- // Writes both opcode and num_operands as a single code.
- // Returns SPV_UNSUPPORTED iff no suitable codec was found.
- spv_result_t EncodeOpcodeAndNumOperands(uint32_t opcode,
- uint32_t num_operands);
- // Writes mtf rank to bit stream. |mtf| is used to determine the codec
- // scheme. |fallback_method| is used if no codec defined for |mtf|.
- spv_result_t EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
- uint64_t fallback_method);
- // Writes id using coding based on mtf associated with the id descriptor.
- // Returns SPV_UNSUPPORTED iff fallback method needs to be used.
- spv_result_t EncodeIdWithDescriptor(uint32_t id);
- // Writes id using coding based on the given |mtf|, which is expected to
- // contain the given |id|.
- spv_result_t EncodeExistingId(uint64_t mtf, uint32_t id);
- // Writes type id of the current instruction if can't be inferred.
- spv_result_t EncodeTypeId();
- // Writes result id of the current instruction if can't be inferred.
- spv_result_t EncodeResultId();
- // Writes ids which are neither type nor result ids.
- spv_result_t EncodeRefId(uint32_t id);
- // Writes bits to the stream until the beginning of the next byte if the
- // number of bits until the next byte is less than |byte_break_if_less_than|.
- void AddByteBreak(size_t byte_break_if_less_than);
- // Encodes a literal number operand and writes it to the bit stream.
- spv_result_t EncodeLiteralNumber(const spv_parsed_operand_t& operand);
- MarkvCodecOptions options_;
- // Bit stream where encoded instructions are written.
- BitWriterWord64 writer_;
- // If not nullptr, disassembled instruction lines will be written to comments.
- // Format: \n separated instruction lines, no header.
- std::unique_ptr<std::stringstream> disassembly_;
- };
- // Decodes MARK-V buffers written by MarkvEncoder.
- class MarkvDecoder : public MarkvCodecBase {
- public:
- // |markv| is the MARK-V buffer to decode (handed to reader_).
- // |model| is owned by the caller, must be not null and valid during the
- // lifetime of MarkvDecoder.
- MarkvDecoder(spv_const_context context, const std::vector<uint8_t>& markv,
- const MarkvCodecOptions& options, const MarkvModel* model)
- : MarkvCodecBase(context, GetValidatorOptions(options), model),
- options_(options),
- reader_(markv) {
- (void)options_;
- SetIdBound(1);
- // Pre-allocate the per-instruction scratch buffers so that instructions
- // with up to 25 operands/words do not trigger a reallocation.
- parsed_operands_.reserve(25);
- inst_words_.reserve(25);
- }
- // Creates an internal logger which writes comments on the decoding process.
- void CreateLogger(MarkvLogConsumer log_consumer,
- MarkvDebugConsumer debug_consumer) {
- logger_.reset(new MarkvLogger(log_consumer, debug_consumer));
- }
- // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
- // Can be called only once. Fails if the data has the wrong format or ends
- // prematurely, or if validation fails.
- spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary);
- private:
- // Describes the format of a typed literal number.
- struct NumberType {
- spv_number_kind_t type;
- uint32_t bit_width;
- };
- // Creates and returns validator options, or nullptr if validation is
- // disabled in |options|. Returned value owned by the caller.
- static spv_validator_options GetValidatorOptions(
- const MarkvCodecOptions& options) {
- return options.validate_spirv_binary ? spvValidatorOptionsCreate()
- : nullptr;
- }
- // Reads a single bit from reader_. The read bit is stored in |bit|.
- // Returns false iff reader_ fails; |bit| is left untouched in that case.
- bool ReadBit(bool* bit) {
- uint64_t bits = 0;
- const bool result = reader_.ReadBits(&bits, 1);
- if (result) *bit = bits != 0;
- return result;
- }
- // Returns a callable that forwards to ReadBit() on this object. The callable
- // captures |this| and must not outlive the decoder.
- std::function<bool(bool*)> GetReadBitCallback() {
- return [this](bool* bit) { return ReadBit(bit); };
- }
- // Reads a single non-id word from bit stream. operand_.type determines if
- // the word needs to be decoded and how.
- spv_result_t DecodeNonIdWord(uint32_t* word);
- // Reads and decodes both opcode and num_operands as a single code.
- // Returns SPV_UNSUPPORTED iff no suitable codec was found.
- spv_result_t DecodeOpcodeAndNumberOfOperands(uint32_t* opcode,
- uint32_t* num_operands);
- // Reads mtf rank from bit stream. |mtf| is used to determine the codec
- // scheme. |fallback_method| is used if no codec defined for |mtf|.
- spv_result_t DecodeMtfRankHuffman(uint64_t mtf, uint32_t fallback_method,
- uint32_t* rank);
- // Reads id using coding based on mtf associated with the id descriptor.
- // Returns SPV_UNSUPPORTED iff fallback method needs to be used.
- spv_result_t DecodeIdWithDescriptor(uint32_t* id);
- // Reads id using coding based on the given |mtf|, which is expected to
- // contain the needed |id|.
- spv_result_t DecodeExistingId(uint64_t mtf, uint32_t* id);
- // Reads type id of the current instruction if can't be inferred.
- spv_result_t DecodeTypeId();
- // Reads result id of the current instruction if can't be inferred.
- spv_result_t DecodeResultId();
- // Reads id which is neither type nor result id.
- spv_result_t DecodeRefId(uint32_t* id);
- // Reads and discards bits until the beginning of the next byte if the
- // number of bits until the next byte is less than |byte_break_if_less_than|.
- bool ReadToByteBreak(size_t byte_break_if_less_than);
- // Returns instruction words decoded up to this point.
- const uint32_t* GetInstWords() const override { return inst_words_.data(); }
- // Reads a literal number as it is described in |operand| from the bit stream,
- // decodes and writes it to spirv_.
- spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);
- // Reads instruction from bit stream, decodes and validates it.
- // Decoded instruction is valid until the next call of DecodeInstruction().
- spv_result_t DecodeInstruction();
- // Reads an operand from the stream, then decodes and validates it.
- spv_result_t DecodeOperand(size_t operand_offset,
- const spv_operand_type_t type,
- spv_operand_pattern_t* expected_operands);
- // Records the numeric type for an operand according to the type information
- // associated with the given non-zero type Id. This can fail if the type Id
- // is not a type Id, or if the type Id does not reference a scalar numeric
- // type. On success, returns SPV_SUCCESS and populates the num_words,
- // number_kind, and number_bit_width fields of parsed_operand.
- spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
- uint32_t type_id);
- // Records the number type for the current instruction, if it generates a
- // type. For types that aren't scalar numbers, record something with number
- // kind SPV_NUMBER_NONE.
- void RecordNumberType();
- MarkvCodecOptions options_;
- // Temporary sink where decoded SPIR-V words are written. Once it contains the
- // entire module, the container is moved and returned.
- std::vector<uint32_t> spirv_;
- // Bit stream containing encoded data.
- BitReaderWord64 reader_;
- // Temporary storage for operands of the currently parsed instruction.
- // Valid until next DecodeInstruction call.
- std::vector<spv_parsed_operand_t> parsed_operands_;
- // Temporary storage for current instruction words.
- // Valid until next DecodeInstruction call.
- std::vector<uint32_t> inst_words_;
- // Maps a type ID to its number type description.
- std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_;
- // Maps an ExtInstImport id to the extended instruction type.
- std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_;
- };
- // Registers the current instruction (inst_) in the codec state. The steps are:
- // - keep a copy of it in instructions_ and index it by result id;
- // - track the enclosing function and the ids local to it;
- // - record the result id -> type id mapping;
- // - depending on the model, insert the result id into the move-to-front
- //   sequences that later id operands are coded against.
- void MarkvCodecBase::ProcessCurInstruction() {
- instructions_.emplace_back(new Instruction(&inst_));
- const SpvOp opcode = SpvOp(inst_.opcode);
- if (inst_.result_id) {
- id_to_def_instruction_.emplace(inst_.result_id, instructions_.back().get());
- // Collect ids local to the current function.
- // Note: an OpFunction's own id is not collected, because cur_function_id_
- // is only updated below.
- if (cur_function_id_) {
- ids_local_to_cur_function_.push_back(inst_.result_id);
- }
- // Starting new function.
- if (opcode == SpvOpFunction) {
- cur_function_id_ = inst_.result_id;
- cur_function_return_type_ = inst_.type_id;
- if (model_->id_fallback_strategy() ==
- MarkvModel::IdFallbackStrategy::kRuleBased) {
- multi_mtf_.Insert(GetMtfFunctionWithReturnType(inst_.type_id),
- inst_.result_id);
- }
- // Store function parameter types in a queue, so that we know which types
- // to expect in the following OpFunctionParameter instructions.
- // inst_.words[4] is the function type id. The OpTypeFunction parameter
- // types start at its word 3.
- const Instruction* def_inst = FindDef(inst_.words[4]);
- assert(def_inst);
- assert(def_inst->opcode() == SpvOpTypeFunction);
- for (uint32_t i = 3; i < def_inst->words().size(); ++i) {
- remaining_function_parameter_types_.push_back(def_inst->word(i));
- }
- }
- }
- // Remove local ids from MTFs if function end.
- if (opcode == SpvOpFunctionEnd) {
- cur_function_id_ = 0;
- for (uint32_t id : ids_local_to_cur_function_) multi_mtf_.RemoveFromAll(id);
- ids_local_to_cur_function_.clear();
- // Every parameter type queued at OpFunction must have been consumed.
- assert(remaining_function_parameter_types_.empty());
- }
- // Everything below only concerns instructions that define a result id.
- if (!inst_.result_id) return;
- {
- // Save the result ID to type ID mapping.
- // In the grammar, type ID always appears before result ID.
- // A regular value maps to its type. Some instructions (e.g. OpLabel)
- // have no type Id, and will map to 0. The result Id for a
- // type-generating instruction (e.g. OpTypeInt) maps to itself.
- auto insertion_result = id_to_type_id_.emplace(
- inst_.result_id, spvOpcodeGeneratesType(SpvOp(inst_.opcode))
- ? inst_.result_id
- : inst_.type_id);
- (void)insertion_result;
- assert(insertion_result.second);
- }
- // Add result_id to MTFs.
- if (model_->id_fallback_strategy() ==
- MarkvModel::IdFallbackStrategy::kRuleBased) {
- switch (opcode) {
- case SpvOpTypeFloat:
- case SpvOpTypeInt:
- case SpvOpTypeBool:
- case SpvOpTypeVector:
- case SpvOpTypePointer:
- case SpvOpExtInstImport:
- case SpvOpTypeSampledImage:
- case SpvOpTypeImage:
- case SpvOpTypeSampler:
- multi_mtf_.Insert(GetMtfIdGeneratedByOpcode(opcode), inst_.result_id);
- break;
- default:
- break;
- }
- if (spvOpcodeIsComposite(opcode)) {
- multi_mtf_.Insert(kMtfTypeComposite, inst_.result_id);
- }
- if (opcode == SpvOpLabel) {
- multi_mtf_.InsertOrPromote(kMtfLabel, inst_.result_id);
- }
- if (opcode == SpvOpTypeInt) {
- multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
- multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
- }
- if (opcode == SpvOpTypeFloat) {
- multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
- multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
- }
- if (opcode == SpvOpTypeBool) {
- multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
- multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
- }
- if (opcode == SpvOpTypeVector) {
- // Classify the vector type by its component type (word 2). Its size
- // (word 3) selects a per-size sequence.
- const uint32_t component_type_id = inst_.words[2];
- const uint32_t size = inst_.words[3];
- if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeFloat),
- component_type_id)) {
- multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
- } else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeInt),
- component_type_id)) {
- multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
- } else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeBool),
- component_type_id)) {
- multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
- }
- multi_mtf_.Insert(GetMtfTypeVectorOfSize(size), inst_.result_id);
- }
- if (inst_.opcode == SpvOpTypeFunction) {
- // Besides the function type itself, remember its return type (word 2)
- // as a type returned by some function.
- const uint32_t return_type = inst_.words[2];
- multi_mtf_.Insert(kMtfTypeReturnedByFunction, return_type);
- multi_mtf_.Insert(GetMtfFunctionTypeWithReturnType(return_type),
- inst_.result_id);
- }
- // Values (instructions with a type id) are bucketed by properties of
- // their type.
- if (inst_.type_id) {
- const Instruction* type_inst = FindDef(inst_.type_id);
- assert(type_inst);
- multi_mtf_.Insert(kMtfObject, inst_.result_id);
- multi_mtf_.Insert(GetMtfIdOfType(inst_.type_id), inst_.result_id);
- if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, inst_.type_id)) {
- multi_mtf_.Insert(kMtfFloatScalarOrVector, inst_.result_id);
- }
- if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, inst_.type_id))
- multi_mtf_.Insert(kMtfIntScalarOrVector, inst_.result_id);
- if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, inst_.type_id))
- multi_mtf_.Insert(kMtfBoolScalarOrVector, inst_.result_id);
- if (multi_mtf_.HasValue(kMtfTypeComposite, inst_.type_id))
- multi_mtf_.Insert(kMtfComposite, inst_.result_id);
- switch (type_inst->opcode()) {
- case SpvOpTypeInt:
- case SpvOpTypeBool:
- case SpvOpTypePointer:
- case SpvOpTypeVector:
- case SpvOpTypeImage:
- case SpvOpTypeSampledImage:
- case SpvOpTypeSampler:
- multi_mtf_.Insert(
- GetMtfIdWithTypeGeneratedByOpcode(type_inst->opcode()),
- inst_.result_id);
- break;
- default:
- break;
- }
- if (type_inst->opcode() == SpvOpTypeVector) {
- const uint32_t component_type = type_inst->word(2);
- multi_mtf_.Insert(GetMtfVectorOfComponentType(component_type),
- inst_.result_id);
- }
- if (type_inst->opcode() == SpvOpTypePointer) {
- // Operand 2 of OpTypePointer is the pointee (data) type.
- assert(type_inst->operands().size() > 2);
- assert(type_inst->words().size() > type_inst->operands()[2].offset);
- const uint32_t data_type =
- type_inst->word(type_inst->operands()[2].offset);
- multi_mtf_.Insert(GetMtfPointerToType(data_type), inst_.result_id);
- if (multi_mtf_.HasValue(kMtfTypeComposite, data_type))
- multi_mtf_.Insert(kMtfTypePointerToComposite, inst_.result_id);
- }
- }
- if (spvOpcodeGeneratesType(opcode)) {
- if (opcode != SpvOpTypeFunction) {
- multi_mtf_.Insert(kMtfTypeNonFunction, inst_.result_id);
- }
- }
- }
- // Long descriptors are only computed when the model codes at least one of
- // them. Only descriptors that have a coding scheme get their own sequence.
- if (model_->AnyDescriptorHasCodingScheme()) {
- const uint32_t long_descriptor =
- long_id_descriptors_.ProcessInstruction(inst_);
- if (model_->DescriptorHasCodingScheme(long_descriptor))
- multi_mtf_.Insert(GetMtfLongIdDescriptor(long_descriptor),
- inst_.result_id);
- }
- if (model_->id_fallback_strategy() ==
- MarkvModel::IdFallbackStrategy::kShortDescriptor) {
- const uint32_t short_descriptor =
- short_id_descriptors_.ProcessInstruction(inst_);
- multi_mtf_.Insert(GetMtfShortIdDescriptor(short_descriptor),
- inst_.result_id);
- }
- }
- // Selects the move-to-front sequence used to code the id operand at
- // operand_index_ of the current instruction (inst_), using hand-written
- // per-opcode rules. Returns kMtfNone when no rule applies.
- //
- // operand_index_ counts every operand, including the result type and the
- // result id. For value-producing instructions, index 0 is the result type and
- // index 1 is the result id. Index 2 is the first regular operand, stored in
- // GetInstWords()[3]. In the decoder, GetInstWords() only holds the words
- // decoded so far, and the rules below only read words of earlier operands.
- uint64_t MarkvCodecBase::GetRuleBasedMtf() {
- // This function is only called for id operands (but not result ids).
- assert(spvIsIdType(operand_.type) ||
- operand_.type == SPV_OPERAND_TYPE_OPTIONAL_ID);
- assert(operand_.type != SPV_OPERAND_TYPE_RESULT_ID);
- const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
- // All operand slots which expect label id.
- // For OpPhi, the odd operands starting at index 3 are parent labels. For
- // OpSwitch, every id operand after the selector is a label.
- if ((inst_.opcode == SpvOpLoopMerge && operand_index_ <= 1) ||
- (inst_.opcode == SpvOpSelectionMerge && operand_index_ == 0) ||
- (inst_.opcode == SpvOpBranch && operand_index_ == 0) ||
- (inst_.opcode == SpvOpBranchConditional &&
- (operand_index_ == 1 || operand_index_ == 2)) ||
- (inst_.opcode == SpvOpPhi && operand_index_ >= 3 &&
- operand_index_ % 2 == 1) ||
- (inst_.opcode == SpvOpSwitch && operand_index_ > 0)) {
- return kMtfLabel;
- }
- switch (opcode) {
- case SpvOpFAdd:
- case SpvOpFSub:
- case SpvOpFMul:
- case SpvOpFDiv:
- case SpvOpFRem:
- case SpvOpFMod:
- case SpvOpFNegate: {
- if (operand_index_ == 0) return kMtfTypeFloatScalarOrVector;
- return GetMtfIdOfType(inst_.type_id);
- }
- case SpvOpISub:
- case SpvOpIAdd:
- case SpvOpIMul:
- case SpvOpSDiv:
- case SpvOpUDiv:
- case SpvOpSMod:
- case SpvOpUMod:
- case SpvOpSRem:
- case SpvOpSNegate: {
- if (operand_index_ == 0) return kMtfTypeIntScalarOrVector;
- return kMtfIntScalarOrVector;
- }
- // TODO([email protected]) Add OpConvertFToU and other opcodes.
- case SpvOpFOrdEqual:
- case SpvOpFUnordEqual:
- case SpvOpFOrdNotEqual:
- case SpvOpFUnordNotEqual:
- case SpvOpFOrdLessThan:
- case SpvOpFUnordLessThan:
- case SpvOpFOrdGreaterThan:
- case SpvOpFUnordGreaterThan:
- case SpvOpFOrdLessThanEqual:
- case SpvOpFUnordLessThanEqual:
- case SpvOpFOrdGreaterThanEqual:
- case SpvOpFUnordGreaterThanEqual: {
- if (operand_index_ == 0) return kMtfTypeBoolScalarOrVector;
- if (operand_index_ == 2) return kMtfFloatScalarOrVector;
- if (operand_index_ == 3) {
- // The second operand has the same type as the first (word 3).
- const uint32_t first_operand_id = GetInstWords()[3];
- const uint32_t first_operand_type = id_to_type_id_.at(first_operand_id);
- return GetMtfIdOfType(first_operand_type);
- }
- break;
- }
- case SpvOpVectorShuffle: {
- if (operand_index_ == 0) {
- // Operands beyond type, result and the two vectors are the component
- // literals, whose count is the size of the result vector.
- assert(inst_.num_operands > 4);
- return GetMtfTypeVectorOfSize(inst_.num_operands - 4);
- }
- assert(inst_.type_id);
- if (operand_index_ == 2 || operand_index_ == 3)
- return GetMtfVectorOfComponentType(
- GetVectorComponentType(inst_.type_id));
- break;
- }
- case SpvOpVectorTimesScalar: {
- if (operand_index_ == 0) {
- // TODO([email protected]) Could be narrowed to vector of floats.
- return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
- }
- assert(inst_.type_id);
- if (operand_index_ == 2) return GetMtfIdOfType(inst_.type_id);
- if (operand_index_ == 3)
- return GetMtfIdOfType(GetVectorComponentType(inst_.type_id));
- break;
- }
- case SpvOpDot: {
- if (operand_index_ == 0) return GetMtfIdGeneratedByOpcode(SpvOpTypeFloat);
- assert(inst_.type_id);
- if (operand_index_ == 2)
- return GetMtfVectorOfComponentType(inst_.type_id);
- if (operand_index_ == 3) {
- // The second vector has the same type as the first (word 3).
- const uint32_t vector_id = GetInstWords()[3];
- const uint32_t vector_type = id_to_type_id_.at(vector_id);
- return GetMtfIdOfType(vector_type);
- }
- break;
- }
- case SpvOpTypeVector: {
- if (operand_index_ == 1) {
- return kMtfTypeScalar;
- }
- break;
- }
- case SpvOpTypeMatrix: {
- if (operand_index_ == 1) {
- return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
- }
- break;
- }
- case SpvOpTypePointer: {
- if (operand_index_ == 2) {
- return kMtfTypeNonFunction;
- }
- break;
- }
- case SpvOpTypeStruct: {
- if (operand_index_ >= 1) {
- return kMtfTypeNonFunction;
- }
- break;
- }
- case SpvOpTypeFunction: {
- // The return type (operand 1) and the parameter types (operands 2+)
- // currently use the same sequence.
- if (operand_index_ == 1) {
- return kMtfTypeNonFunction;
- }
- if (operand_index_ >= 2) {
- return kMtfTypeNonFunction;
- }
- break;
- }
- case SpvOpLoad: {
- if (operand_index_ == 0) return kMtfTypeNonFunction;
- if (operand_index_ == 2) {
- assert(inst_.type_id);
- return GetMtfPointerToType(inst_.type_id);
- }
- break;
- }
- case SpvOpStore: {
- // OpStore has no result, so operand 0 is the pointer (word 1).
- if (operand_index_ == 0)
- return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypePointer);
- if (operand_index_ == 1) {
- const uint32_t pointer_id = GetInstWords()[1];
- const uint32_t pointer_type = id_to_type_id_.at(pointer_id);
- const Instruction* pointer_inst = FindDef(pointer_type);
- assert(pointer_inst);
- assert(pointer_inst->opcode() == SpvOpTypePointer);
- const uint32_t data_type =
- pointer_inst->word(pointer_inst->operands()[2].offset);
- return GetMtfIdOfType(data_type);
- }
- break;
- }
- case SpvOpVariable: {
- if (operand_index_ == 0)
- return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
- break;
- }
- case SpvOpAccessChain: {
- if (operand_index_ == 0)
- return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
- if (operand_index_ == 2) return kMtfTypePointerToComposite;
- if (operand_index_ >= 3)
- return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeInt);
- break;
- }
- case SpvOpCompositeConstruct: {
- if (operand_index_ == 0) return kMtfTypeComposite;
- if (operand_index_ >= 2) {
- const uint32_t composite_type = GetInstWords()[1];
- if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, composite_type))
- return kMtfFloatScalarOrVector;
- if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, composite_type))
- return kMtfIntScalarOrVector;
- if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, composite_type))
- return kMtfBoolScalarOrVector;
- }
- break;
- }
- case SpvOpCompositeExtract: {
- if (operand_index_ == 2) return kMtfComposite;
- break;
- }
- case SpvOpConstantComposite: {
- if (operand_index_ == 0) return kMtfTypeComposite;
- if (operand_index_ >= 2) {
- const Instruction* composite_type_inst = FindDef(inst_.type_id);
- assert(composite_type_inst);
- if (composite_type_inst->opcode() == SpvOpTypeVector) {
- return GetMtfIdOfType(composite_type_inst->word(2));
- }
- }
- break;
- }
- case SpvOpExtInst: {
- // Operand 2 is the instruction set import, operand 3 the instruction
- // number (word 4), operands 4+ the arguments.
- if (operand_index_ == 2)
- return GetMtfIdGeneratedByOpcode(SpvOpExtInstImport);
- if (operand_index_ >= 4) {
- const uint32_t return_type = GetInstWords()[1];
- const uint32_t ext_inst_type = inst_.ext_inst_type;
- const uint32_t ext_inst_index = GetInstWords()[4];
- // TODO([email protected]) The list of extended instructions is
- // incomplete. Only common instructions and low-hanging fruits listed.
- if (ext_inst_type == SPV_EXT_INST_TYPE_GLSL_STD_450) {
- switch (ext_inst_index) {
- case GLSLstd450FAbs:
- case GLSLstd450FClamp:
- case GLSLstd450FMax:
- case GLSLstd450FMin:
- case GLSLstd450FMix:
- case GLSLstd450Step:
- case GLSLstd450SmoothStep:
- case GLSLstd450Fma:
- case GLSLstd450Pow:
- case GLSLstd450Exp:
- case GLSLstd450Exp2:
- case GLSLstd450Log:
- case GLSLstd450Log2:
- case GLSLstd450Sqrt:
- case GLSLstd450InverseSqrt:
- case GLSLstd450Fract:
- case GLSLstd450Floor:
- case GLSLstd450Ceil:
- case GLSLstd450Radians:
- case GLSLstd450Degrees:
- case GLSLstd450Sin:
- case GLSLstd450Cos:
- case GLSLstd450Tan:
- case GLSLstd450Sinh:
- case GLSLstd450Cosh:
- case GLSLstd450Tanh:
- case GLSLstd450Asin:
- case GLSLstd450Acos:
- case GLSLstd450Atan:
- case GLSLstd450Atan2:
- case GLSLstd450Asinh:
- case GLSLstd450Acosh:
- case GLSLstd450Atanh:
- case GLSLstd450MatrixInverse:
- case GLSLstd450Cross:
- case GLSLstd450Normalize:
- case GLSLstd450Reflect:
- case GLSLstd450FaceForward:
- return GetMtfIdOfType(return_type);
- case GLSLstd450Length:
- case GLSLstd450Distance:
- case GLSLstd450Refract:
- return kMtfFloatScalarOrVector;
- default:
- break;
- }
- } else if (ext_inst_type == SPV_EXT_INST_TYPE_OPENCL_STD) {
- switch (ext_inst_index) {
- case OpenCLLIB::Fabs:
- case OpenCLLIB::FClamp:
- case OpenCLLIB::Fmax:
- case OpenCLLIB::Fmin:
- case OpenCLLIB::Step:
- case OpenCLLIB::Smoothstep:
- case OpenCLLIB::Fma:
- case OpenCLLIB::Pow:
- case OpenCLLIB::Exp:
- case OpenCLLIB::Exp2:
- case OpenCLLIB::Log:
- case OpenCLLIB::Log2:
- case OpenCLLIB::Sqrt:
- case OpenCLLIB::Rsqrt:
- case OpenCLLIB::Fract:
- case OpenCLLIB::Floor:
- case OpenCLLIB::Ceil:
- case OpenCLLIB::Radians:
- case OpenCLLIB::Degrees:
- case OpenCLLIB::Sin:
- case OpenCLLIB::Cos:
- case OpenCLLIB::Tan:
- case OpenCLLIB::Sinh:
- case OpenCLLIB::Cosh:
- case OpenCLLIB::Tanh:
- case OpenCLLIB::Asin:
- case OpenCLLIB::Acos:
- case OpenCLLIB::Atan:
- case OpenCLLIB::Atan2:
- case OpenCLLIB::Asinh:
- case OpenCLLIB::Acosh:
- case OpenCLLIB::Atanh:
- case OpenCLLIB::Cross:
- case OpenCLLIB::Normalize:
- return GetMtfIdOfType(return_type);
- case OpenCLLIB::Length:
- case OpenCLLIB::Distance:
- return kMtfFloatScalarOrVector;
- default:
- break;
- }
- }
- }
- break;
- }
- case SpvOpFunction: {
- if (operand_index_ == 0) return kMtfTypeReturnedByFunction;
- if (operand_index_ == 3) {
- const uint32_t return_type = GetInstWords()[1];
- return GetMtfFunctionTypeWithReturnType(return_type);
- }
- break;
- }
- case SpvOpFunctionCall: {
- if (operand_index_ == 0) return kMtfTypeReturnedByFunction;
- if (operand_index_ == 2) {
- const uint32_t return_type = GetInstWords()[1];
- return GetMtfFunctionWithReturnType(return_type);
- }
- if (operand_index_ >= 3) {
- // Argument k is at operand index 3 + k. Its type is word 3 + k of the
- // callee's OpTypeFunction, so the operand index is also the word index.
- const uint32_t function_id = GetInstWords()[3];
- const Instruction* function_inst = FindDef(function_id);
- if (!function_inst) return kMtfObject;
- assert(function_inst->opcode() == SpvOpFunction);
- const uint32_t function_type_id = function_inst->word(4);
- const Instruction* function_type_inst = FindDef(function_type_id);
- assert(function_type_inst);
- assert(function_type_inst->opcode() == SpvOpTypeFunction);
- const uint32_t argument_type = function_type_inst->word(operand_index_);
- return GetMtfIdOfType(argument_type);
- }
- break;
- }
- case SpvOpReturnValue: {
- if (operand_index_ == 0) return GetMtfIdOfType(cur_function_return_type_);
- break;
- }
- case SpvOpBranchConditional: {
- if (operand_index_ == 0)
- return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeBool);
- break;
- }
- case SpvOpSampledImage: {
- if (operand_index_ == 0)
- return GetMtfIdGeneratedByOpcode(SpvOpTypeSampledImage);
- if (operand_index_ == 2)
- return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeImage);
- if (operand_index_ == 3)
- return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampler);
- break;
- }
- case SpvOpImageSampleImplicitLod: {
- if (operand_index_ == 0)
- return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
- if (operand_index_ == 2)
- return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampledImage);
- if (operand_index_ == 3)
- return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeVector);
- break;
- }
- default:
- break;
- }
- return kMtfNone;
- }
- // Writes a non-id operand word.
- // If the model has a Huffman codec for the current opcode/operand index and
- // |word| is in its table, only the Huffman code is written. Otherwise the
- // kMarkvNoneOfTheAbove escape is written first (when a codec exists), and
- // the word follows in fallback form: variable-width chunks if the model
- // defines a chunk length for operand_.type, or raw 32 bits if not.
- spv_result_t MarkvEncoder::EncodeNonIdWord(uint32_t word) {
- auto* huffman =
- model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);
- if (huffman != nullptr) {
- uint64_t code = 0;
- size_t code_len = 0;
- if (huffman->Encode(word, &code, &code_len)) {
- writer_.WriteBits(code, code_len);
- return SPV_SUCCESS;
- }
- // The word has no Huffman code: signal the escape, then fall through.
- if (!huffman->Encode(kMarkvNoneOfTheAbove, &code, &code_len)) {
- return Diag(SPV_ERROR_INTERNAL)
- << "Non-id word Huffman table for "
- << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index "
- << operand_index_ << " is missing kMarkvNoneOfTheAbove";
- }
- writer_.WriteBits(code, code_len);
- }
- const size_t chunk_length =
- model_->GetOperandVariableWidthChunkLength(operand_.type);
- if (chunk_length == 0) {
- writer_.WriteUnencoded(word);
- } else {
- writer_.WriteVariableWidthU32(word, chunk_length);
- }
- return SPV_SUCCESS;
- }
- // Reads a non-id operand word into |word|, mirroring
- // MarkvEncoder::EncodeNonIdWord. A Huffman-decoded value other than
- // kMarkvNoneOfTheAbove is the word itself. Otherwise, or when no codec
- // exists, the word is read in fallback form: variable-width chunks if the
- // model defines a chunk length for operand_.type, or raw 32 bits if not.
- // Returns SPV_ERROR_INVALID_BINARY if the stream cannot be read.
- spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) {
- auto* huffman =
- model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);
- if (huffman != nullptr) {
- uint64_t value = 0;
- if (!huffman->DecodeFromStream(GetReadBitCallback(), &value)) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to decode non-id word with Huffman";
- }
- if (value != kMarkvNoneOfTheAbove) {
- *word = static_cast<uint32_t>(value);
- assert(*word == value);
- return SPV_SUCCESS;
- }
- // Escape received: the actual word follows in fallback encoding.
- }
- const size_t chunk_length =
- model_->GetOperandVariableWidthChunkLength(operand_.type);
- if (chunk_length == 0) {
- if (!reader_.ReadUnencoded(word)) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to read unencoded non-id word";
- }
- return SPV_SUCCESS;
- }
- if (!reader_.ReadVariableWidthU32(word, chunk_length)) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to decode non-id word with varint";
- }
- return SPV_SUCCESS;
- }
- // Writes opcode and num_operands as a single Huffman code. The two values are
- // packed into one word: opcode in the low 16 bits, num_operands shifted
- // left by 16.
- // The Markov codec conditioned on the previous opcode is tried first, then
- // the global codec (keyed by SpvOpNop). A codec that cannot represent the
- // word emits its kMarkvNoneOfTheAbove escape before the next attempt.
- // Returns SPV_UNSUPPORTED if the global codec cannot encode the word either,
- // which tells the caller to use its own fallback. Returns SPV_ERROR_INTERNAL
- // if a table lacks the kMarkvNoneOfTheAbove escape.
- spv_result_t MarkvEncoder::EncodeOpcodeAndNumOperands(uint32_t opcode,
- uint32_t num_operands) {
- uint64_t bits = 0;
- size_t num_bits = 0;
- const uint32_t word = opcode | (num_operands << 16);
- // First try to use the Markov chain codec.
- auto* codec =
- model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
- if (codec) {
- if (codec->Encode(word, &bits, &num_bits)) {
- // The word was successfully encoded into bits/num_bits.
- writer_.WriteBits(bits, num_bits);
- return SPV_SUCCESS;
- }
- // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove
- // and use fallback encoding.
- if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits))
- return Diag(SPV_ERROR_INTERNAL)
- << "opcode_and_num_operands Huffman table for "
- << spvOpcodeString(GetPrevOpcode())
- << " is missing kMarkvNoneOfTheAbove";
- writer_.WriteBits(bits, num_bits);
- }
- // Fallback to base-rate codec.
- codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
- assert(codec);
- if (codec->Encode(word, &bits, &num_bits)) {
- // The word was successfully encoded into bits/num_bits.
- writer_.WriteBits(bits, num_bits);
- return SPV_SUCCESS;
- }
- // The word is not in the global table either. Write kMarkvNoneOfTheAbove
- // and return SPV_UNSUPPORTED so that the caller falls back further.
- if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits))
- return Diag(SPV_ERROR_INTERNAL)
- << "Global opcode_and_num_operands Huffman table is missing "
- << "kMarkvNoneOfTheAbove";
- writer_.WriteBits(bits, num_bits);
- return SPV_UNSUPPORTED;
- }
- // Reads the combined opcode/num_operands code written by
- // MarkvEncoder::EncodeOpcodeAndNumOperands. The Markov codec conditioned on
- // the previous opcode is tried first. Its kMarkvNoneOfTheAbove escape
- // defers to the global codec (keyed by SpvOpNop). The global codec's escape
- // yields SPV_UNSUPPORTED, and the outputs are left untouched in that case.
- spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands(
- uint32_t* opcode, uint32_t* num_operands) {
- // Decoded values pack the opcode in the low 16 bits and the operand count
- // above them.
- const auto unpack = [opcode, num_operands](uint64_t packed) {
- *opcode = static_cast<uint32_t>(packed & 0xFFFF);
- *num_operands = static_cast<uint32_t>(packed >> 16);
- };
- if (auto* markov =
- model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode())) {
- uint64_t packed = 0;
- if (!markov->DecodeFromStream(GetReadBitCallback(), &packed)) {
- return Diag(SPV_ERROR_INTERNAL)
- << "Failed to decode opcode_and_num_operands, previous opcode is "
- << spvOpcodeString(GetPrevOpcode());
- }
- if (packed != kMarkvNoneOfTheAbove) {
- unpack(packed);
- return SPV_SUCCESS;
- }
- }
- auto* global = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
- assert(global);
- uint64_t packed = 0;
- if (!global->DecodeFromStream(GetReadBitCallback(), &packed)) {
- return Diag(SPV_ERROR_INTERNAL)
- << "Failed to decode opcode_and_num_operands with global codec";
- }
- if (packed == kMarkvNoneOfTheAbove) return SPV_UNSUPPORTED;
- unpack(packed);
- return SPV_SUCCESS;
- }
- // Writes |rank| using the Huffman codec of |mtf|, or of |fallback_method|
- // when |mtf| has none.
- // Ranks below kMtfSmallestRankEncodedByValue are written as Huffman codes.
- // Larger ranks are written as the kMtfRankEncodedByValueSignal code,
- // followed by (rank - kMtfSmallestRankEncodedByValue) in variable-width
- // chunks of the model's mtf_rank_chunk_length().
- spv_result_t MarkvEncoder::EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
- uint64_t fallback_method) {
- const auto* codec = GetMtfHuffmanCodec(mtf);
- if (codec == nullptr) {
- assert(fallback_method != kMtfNone);
- codec = GetMtfHuffmanCodec(fallback_method);
- }
- if (codec == nullptr)
- return Diag(SPV_ERROR_INTERNAL) << "No codec to encode MTF rank";
- uint64_t bits = 0;
- size_t num_bits = 0;
- const bool encode_by_value = rank >= kMtfSmallestRankEncodedByValue;
- if (encode_by_value) {
- if (!codec->Encode(kMtfRankEncodedByValueSignal, &bits, &num_bits)) {
- return Diag(SPV_ERROR_INTERNAL)
- << "Failed to encode kMtfRankEncodedByValueSignal";
- }
- writer_.WriteBits(bits, num_bits);
- writer_.WriteVariableWidthU32(rank - kMtfSmallestRankEncodedByValue,
- model_->mtf_rank_chunk_length());
- return SPV_SUCCESS;
- }
- if (!codec->Encode(rank, &bits, &num_bits)) {
- return Diag(SPV_ERROR_INTERNAL)
- << "Failed to encode MTF rank with Huffman";
- }
- writer_.WriteBits(bits, num_bits);
- return SPV_SUCCESS;
- }
- // Reads an MTF rank written by MarkvEncoder::EncodeMtfRankHuffman into
- // |rank|. The codec of |mtf| is used, or the codec of |fallback_method| when
- // |mtf| has none. A kMtfRankEncodedByValueSignal code means the rank follows
- // as a variable-width value, offset by kMtfSmallestRankEncodedByValue.
- spv_result_t MarkvDecoder::DecodeMtfRankHuffman(uint64_t mtf,
- uint32_t fallback_method,
- uint32_t* rank) {
- const auto* codec = GetMtfHuffmanCodec(mtf);
- if (codec == nullptr) {
- assert(fallback_method != kMtfNone);
- codec = GetMtfHuffmanCodec(fallback_method);
- }
- if (codec == nullptr)
- return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank";
- uint32_t symbol = 0;
- if (!codec->DecodeFromStream(GetReadBitCallback(), &symbol))
- return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman";
- if (symbol != kMtfRankEncodedByValueSignal) {
- // Small ranks are Huffman symbols themselves.
- assert(symbol < kMtfSmallestRankEncodedByValue);
- *rank = symbol;
- return SPV_SUCCESS;
- }
- if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length())) {
- return Diag(SPV_ERROR_INTERNAL)
- << "Failed to decode MTF rank with varint";
- }
- *rank += kMtfSmallestRankEncodedByValue;
- return SPV_SUCCESS;
- }
- spv_result_t MarkvEncoder::EncodeIdWithDescriptor(uint32_t id) {
- // Get the descriptor for id.
- const uint32_t long_descriptor = long_id_descriptors_.GetDescriptor(id);
- auto* codec =
- model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
- uint64_t bits = 0;
- size_t num_bits = 0;
- uint64_t mtf = kMtfNone;
- if (long_descriptor && codec &&
- codec->Encode(long_descriptor, &bits, &num_bits)) {
- // If the descriptor exists and is in the table, write the descriptor and
- // proceed to encoding the rank.
- writer_.WriteBits(bits, num_bits);
- mtf = GetMtfLongIdDescriptor(long_descriptor);
- } else {
- if (codec) {
- // The descriptor doesn't exist or we have no coding for it. Write
- // kMarkvNoneOfTheAbove and go to fallback method.
- if (!codec->Encode(kMarkvNoneOfTheAbove, &bits, &num_bits))
- return Diag(SPV_ERROR_INTERNAL)
- << "Descriptor Huffman table for "
- << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index "
- << operand_index_ << " is missing kMarkvNoneOfTheAbove";
- writer_.WriteBits(bits, num_bits);
- }
- if (model_->id_fallback_strategy() !=
- MarkvModel::IdFallbackStrategy::kShortDescriptor) {
- return SPV_UNSUPPORTED;
- }
- const uint32_t short_descriptor = short_id_descriptors_.GetDescriptor(id);
- writer_.WriteBits(short_descriptor, kShortDescriptorNumBits);
- if (short_descriptor == 0) {
- // Forward declared id.
- return SPV_UNSUPPORTED;
- }
- mtf = GetMtfShortIdDescriptor(short_descriptor);
- }
- // Descriptor has been encoded. Now encode the rank of the id in the
- // associated mtf sequence.
- return EncodeExistingId(mtf, id);
- }
- spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) {
- auto* codec =
- model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
- uint64_t mtf = kMtfNone;
- if (codec) {
- uint64_t decoded_value = 0;
- if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
- return Diag(SPV_ERROR_INTERNAL)
- << "Failed to decode descriptor with Huffman";
- if (decoded_value != kMarkvNoneOfTheAbove) {
- const uint32_t long_descriptor = uint32_t(decoded_value);
- mtf = GetMtfLongIdDescriptor(long_descriptor);
- }
- }
- if (mtf == kMtfNone) {
- if (model_->id_fallback_strategy() !=
- MarkvModel::IdFallbackStrategy::kShortDescriptor) {
- return SPV_UNSUPPORTED;
- }
- uint64_t decoded_value = 0;
- if (!reader_.ReadBits(&decoded_value, kShortDescriptorNumBits))
- return Diag(SPV_ERROR_INTERNAL) << "Failed to read short descriptor";
- const uint32_t short_descriptor = uint32_t(decoded_value);
- if (short_descriptor == 0) {
- // Forward declared id.
- return SPV_UNSUPPORTED;
- }
- mtf = GetMtfShortIdDescriptor(short_descriptor);
- }
- return DecodeExistingId(mtf, id);
- }
- spv_result_t MarkvEncoder::EncodeExistingId(uint64_t mtf, uint32_t id) {
- assert(multi_mtf_.GetSize(mtf) > 0);
- if (multi_mtf_.GetSize(mtf) == 1) {
- // If the sequence has only one element no need to write rank, the decoder
- // would make the same decision.
- return SPV_SUCCESS;
- }
- uint32_t rank = 0;
- if (!multi_mtf_.RankFromValue(mtf, id, &rank))
- return Diag(SPV_ERROR_INTERNAL) << "Id is not in the MTF sequence";
- return EncodeMtfRankHuffman(rank, mtf, kMtfGenericNonZeroRank);
- }
- spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) {
- assert(multi_mtf_.GetSize(mtf) > 0);
- *id = 0;
- uint32_t rank = 0;
- if (multi_mtf_.GetSize(mtf) == 1) {
- rank = 1;
- } else {
- const spv_result_t result =
- DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank);
- if (result != SPV_SUCCESS) return result;
- }
- assert(rank);
- if (!multi_mtf_.ValueFromRank(mtf, rank, id))
- return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds";
- return SPV_SUCCESS;
- }
- spv_result_t MarkvEncoder::EncodeRefId(uint32_t id) {
- {
- // Try to encode using id descriptor mtfs.
- const spv_result_t result = EncodeIdWithDescriptor(id);
- if (result != SPV_UNSUPPORTED) return result;
- // If can't be done continue with other methods.
- }
- const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
- SpvOp(inst_.opcode))(operand_index_);
- uint32_t rank = 0;
- if (model_->id_fallback_strategy() ==
- MarkvModel::IdFallbackStrategy::kRuleBased) {
- // Encode using rule-based mtf.
- uint64_t mtf = GetRuleBasedMtf();
- if (mtf != kMtfNone && !can_forward_declare) {
- assert(multi_mtf_.HasValue(kMtfAll, id));
- return EncodeExistingId(mtf, id);
- }
- if (mtf == kMtfNone) mtf = kMtfAll;
- if (!multi_mtf_.RankFromValue(mtf, id, &rank)) {
- // This is the first occurrence of a forward declared id.
- multi_mtf_.Insert(kMtfAll, id);
- multi_mtf_.Insert(kMtfForwardDeclared, id);
- if (mtf != kMtfAll) multi_mtf_.Insert(mtf, id);
- rank = 0;
- }
- return EncodeMtfRankHuffman(rank, mtf, kMtfAll);
- } else {
- assert(can_forward_declare);
- if (!multi_mtf_.RankFromValue(kMtfForwardDeclared, id, &rank)) {
- // This is the first occurrence of a forward declared id.
- multi_mtf_.Insert(kMtfForwardDeclared, id);
- rank = 0;
- }
- writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length());
- return SPV_SUCCESS;
- }
- }
- spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) {
- {
- const spv_result_t result = DecodeIdWithDescriptor(id);
- if (result != SPV_UNSUPPORTED) return result;
- }
- const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
- SpvOp(inst_.opcode))(operand_index_);
- uint32_t rank = 0;
- *id = 0;
- if (model_->id_fallback_strategy() ==
- MarkvModel::IdFallbackStrategy::kRuleBased) {
- uint64_t mtf = GetRuleBasedMtf();
- if (mtf != kMtfNone && !can_forward_declare) {
- return DecodeExistingId(mtf, id);
- }
- if (mtf == kMtfNone) mtf = kMtfAll;
- {
- const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank);
- if (result != SPV_SUCCESS) return result;
- }
- if (rank == 0) {
- // This is the first occurrence of a forward declared id.
- *id = GetIdBound();
- SetIdBound(*id + 1);
- multi_mtf_.Insert(kMtfAll, *id);
- multi_mtf_.Insert(kMtfForwardDeclared, *id);
- if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id);
- } else {
- if (!multi_mtf_.ValueFromRank(mtf, rank, id))
- return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
- }
- } else {
- assert(can_forward_declare);
- if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
- return Diag(SPV_ERROR_INTERNAL)
- << "Failed to decode MTF rank with varint";
- if (rank == 0) {
- // This is the first occurrence of a forward declared id.
- *id = GetIdBound();
- SetIdBound(*id + 1);
- multi_mtf_.Insert(kMtfForwardDeclared, *id);
- } else {
- if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, id))
- return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
- }
- }
- assert(*id);
- return SPV_SUCCESS;
- }
- spv_result_t MarkvEncoder::EncodeTypeId() {
- if (inst_.opcode == SpvOpFunctionParameter) {
- assert(!remaining_function_parameter_types_.empty());
- assert(inst_.type_id == remaining_function_parameter_types_.front());
- remaining_function_parameter_types_.pop_front();
- return SPV_SUCCESS;
- }
- {
- // Try to encode using id descriptor mtfs.
- const spv_result_t result = EncodeIdWithDescriptor(inst_.type_id);
- if (result != SPV_UNSUPPORTED) return result;
- // If can't be done continue with other methods.
- }
- assert(model_->id_fallback_strategy() ==
- MarkvModel::IdFallbackStrategy::kRuleBased);
- uint64_t mtf = GetRuleBasedMtf();
- assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
- operand_index_));
- if (mtf == kMtfNone) {
- mtf = kMtfTypeNonFunction;
- // Function types should have been handled by GetRuleBasedMtf.
- assert(inst_.opcode != SpvOpFunction);
- }
- return EncodeExistingId(mtf, inst_.type_id);
- }
- spv_result_t MarkvDecoder::DecodeTypeId() {
- if (inst_.opcode == SpvOpFunctionParameter) {
- assert(!remaining_function_parameter_types_.empty());
- inst_.type_id = remaining_function_parameter_types_.front();
- remaining_function_parameter_types_.pop_front();
- return SPV_SUCCESS;
- }
- {
- const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id);
- if (result != SPV_UNSUPPORTED) return result;
- }
- assert(model_->id_fallback_strategy() ==
- MarkvModel::IdFallbackStrategy::kRuleBased);
- uint64_t mtf = GetRuleBasedMtf();
- assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
- operand_index_));
- if (mtf == kMtfNone) {
- mtf = kMtfTypeNonFunction;
- // Function types should have been handled by GetRuleBasedMtf.
- assert(inst_.opcode != SpvOpFunction);
- }
- return DecodeExistingId(mtf, &inst_.type_id);
- }
- spv_result_t MarkvEncoder::EncodeResultId() {
- uint32_t rank = 0;
- const uint64_t num_still_forward_declared =
- multi_mtf_.GetSize(kMtfForwardDeclared);
- if (num_still_forward_declared) {
- // We write the rank only if kMtfForwardDeclared is not empty. If it is
- // empty the decoder knows that there are no forward declared ids to expect.
- if (multi_mtf_.RankFromValue(kMtfForwardDeclared, inst_.result_id, &rank)) {
- // This is a definition of a forward declared id. We can remove the id
- // from kMtfForwardDeclared.
- if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
- return Diag(SPV_ERROR_INTERNAL)
- << "Failed to remove id from kMtfForwardDeclared";
- writer_.WriteBits(1, 1);
- writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length());
- } else {
- rank = 0;
- writer_.WriteBits(0, 1);
- }
- }
- if (model_->id_fallback_strategy() ==
- MarkvModel::IdFallbackStrategy::kRuleBased) {
- if (!rank) {
- multi_mtf_.Insert(kMtfAll, inst_.result_id);
- }
- }
- return SPV_SUCCESS;
- }
- spv_result_t MarkvDecoder::DecodeResultId() {
- uint32_t rank = 0;
- const uint64_t num_still_forward_declared =
- multi_mtf_.GetSize(kMtfForwardDeclared);
- if (num_still_forward_declared) {
- // Some ids were forward declared. Check if this id is one of them.
- uint64_t id_was_forward_declared;
- if (!reader_.ReadBits(&id_was_forward_declared, 1))
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to read id_was_forward_declared flag";
- if (id_was_forward_declared) {
- if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to read MTF rank of forward declared id";
- if (rank) {
- // The id was forward declared, recover it from kMtfForwardDeclared.
- if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank,
- &inst_.result_id))
- return Diag(SPV_ERROR_INTERNAL)
- << "Forward declared MTF rank is out of bounds";
- // We can now remove the id from kMtfForwardDeclared.
- if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
- return Diag(SPV_ERROR_INTERNAL)
- << "Failed to remove id from kMtfForwardDeclared";
- }
- }
- }
- if (inst_.result_id == 0) {
- // The id was not forward declared, issue a new id.
- inst_.result_id = GetIdBound();
- SetIdBound(inst_.result_id + 1);
- }
- if (model_->id_fallback_strategy() ==
- MarkvModel::IdFallbackStrategy::kRuleBased) {
- if (!rank) {
- multi_mtf_.Insert(kMtfAll, inst_.result_id);
- }
- }
- return SPV_SUCCESS;
- }
- spv_result_t MarkvEncoder::EncodeLiteralNumber(
- const spv_parsed_operand_t& operand) {
- if (operand.number_bit_width <= 32) {
- const uint32_t word = inst_.words[operand.offset];
- return EncodeNonIdWord(word);
- } else {
- assert(operand.number_bit_width <= 64);
- const uint64_t word = uint64_t(inst_.words[operand.offset]) |
- (uint64_t(inst_.words[operand.offset + 1]) << 32);
- if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
- writer_.WriteVariableWidthU64(word, model_->u64_chunk_length());
- } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
- int64_t val = 0;
- std::memcpy(&val, &word, 8);
- writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(),
- model_->s64_block_exponent());
- } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
- writer_.WriteUnencoded(word);
- } else {
- return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
- }
- }
- return SPV_SUCCESS;
- }
- spv_result_t MarkvDecoder::DecodeLiteralNumber(
- const spv_parsed_operand_t& operand) {
- if (operand.number_bit_width <= 32) {
- uint32_t word = 0;
- const spv_result_t result = DecodeNonIdWord(&word);
- if (result != SPV_SUCCESS) return result;
- inst_words_.push_back(word);
- } else {
- assert(operand.number_bit_width <= 64);
- uint64_t word = 0;
- if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
- if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
- return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64";
- } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
- int64_t val = 0;
- if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(),
- model_->s64_block_exponent()))
- return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64";
- std::memcpy(&word, &val, 8);
- } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
- if (!reader_.ReadUnencoded(&word))
- return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64";
- } else {
- return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
- }
- inst_words_.push_back(static_cast<uint32_t>(word));
- inst_words_.push_back(static_cast<uint32_t>(word >> 32));
- }
- return SPV_SUCCESS;
- }
- void MarkvEncoder::AddByteBreak(size_t byte_break_if_less_than) {
- const size_t num_bits_to_next_byte =
- GetNumBitsToNextByte(writer_.GetNumBits());
- if (num_bits_to_next_byte == 0 ||
- num_bits_to_next_byte > byte_break_if_less_than)
- return;
- if (logger_) {
- logger_->AppendWhitespaces(kCommentNumWhitespaces);
- logger_->AppendText("<byte break>");
- }
- writer_.WriteBits(0, num_bits_to_next_byte);
- }
- bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) {
- const size_t num_bits_to_next_byte =
- GetNumBitsToNextByte(reader_.GetNumReadBits());
- if (num_bits_to_next_byte == 0 ||
- num_bits_to_next_byte > byte_break_if_less_than)
- return true;
- uint64_t bits = 0;
- if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false;
- assert(bits == 0);
- if (bits != 0) return false;
- return true;
- }
- spv_result_t MarkvEncoder::EncodeInstruction(
- const spv_parsed_instruction_t& inst) {
- SpvOp opcode = SpvOp(inst.opcode);
- inst_ = inst;
- const spv_result_t validation_result = UpdateValidationState(inst);
- if (validation_result != SPV_SUCCESS) return validation_result;
- LogDisassemblyInstruction();
- const spv_result_t opcode_encodig_result =
- EncodeOpcodeAndNumOperands(opcode, inst.num_operands);
- if (opcode_encodig_result < 0) return opcode_encodig_result;
- if (opcode_encodig_result != SPV_SUCCESS) {
- // Fallback encoding for opcode and num_operands.
- writer_.WriteVariableWidthU32(opcode, model_->opcode_chunk_length());
- if (!OpcodeHasFixedNumberOfOperands(opcode)) {
- // If the opcode has a variable number of operands, encode the number of
- // operands with the instruction.
- if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces);
- writer_.WriteVariableWidthU16(inst.num_operands,
- model_->num_operands_chunk_length());
- }
- }
- // Write operands.
- const uint32_t num_operands = inst_.num_operands;
- for (operand_index_ = 0; operand_index_ < num_operands; ++operand_index_) {
- operand_ = inst_.operands[operand_index_];
- if (logger_) {
- logger_->AppendWhitespaces(kCommentNumWhitespaces);
- logger_->AppendText("<");
- logger_->AppendText(spvOperandTypeStr(operand_.type));
- logger_->AppendText(">");
- }
- switch (operand_.type) {
- case SPV_OPERAND_TYPE_RESULT_ID:
- case SPV_OPERAND_TYPE_TYPE_ID:
- case SPV_OPERAND_TYPE_ID:
- case SPV_OPERAND_TYPE_OPTIONAL_ID:
- case SPV_OPERAND_TYPE_SCOPE_ID:
- case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
- const uint32_t id = inst_.words[operand_.offset];
- if (operand_.type == SPV_OPERAND_TYPE_TYPE_ID) {
- const spv_result_t result = EncodeTypeId();
- if (result != SPV_SUCCESS) return result;
- } else if (operand_.type == SPV_OPERAND_TYPE_RESULT_ID) {
- const spv_result_t result = EncodeResultId();
- if (result != SPV_SUCCESS) return result;
- } else {
- const spv_result_t result = EncodeRefId(id);
- if (result != SPV_SUCCESS) return result;
- }
- PromoteIfNeeded(id);
- break;
- }
- case SPV_OPERAND_TYPE_LITERAL_INTEGER: {
- const spv_result_t result =
- EncodeNonIdWord(inst_.words[operand_.offset]);
- if (result != SPV_SUCCESS) return result;
- break;
- }
- case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: {
- const spv_result_t result = EncodeLiteralNumber(operand_);
- if (result != SPV_SUCCESS) return result;
- break;
- }
- case SPV_OPERAND_TYPE_LITERAL_STRING: {
- const char* src =
- reinterpret_cast<const char*>(&inst_.words[operand_.offset]);
- auto* codec = model_->GetLiteralStringHuffmanCodec(opcode);
- if (codec) {
- uint64_t bits = 0;
- size_t num_bits = 0;
- const std::string str = src;
- if (codec->Encode(str, &bits, &num_bits)) {
- writer_.WriteBits(bits, num_bits);
- break;
- } else {
- bool result =
- codec->Encode("kMarkvNoneOfTheAbove", &bits, &num_bits);
- (void)result;
- assert(result);
- writer_.WriteBits(bits, num_bits);
- }
- }
- const size_t length = spv_strnlen_s(src, operand_.num_words * 4);
- if (length == operand_.num_words * 4)
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to find terminal character of literal string";
- for (size_t i = 0; i < length + 1; ++i) writer_.WriteUnencoded(src[i]);
- break;
- }
- default: {
- for (int i = 0; i < operand_.num_words; ++i) {
- const uint32_t word = inst_.words[operand_.offset + i];
- const spv_result_t result = EncodeNonIdWord(word);
- if (result != SPV_SUCCESS) return result;
- }
- break;
- }
- }
- }
- AddByteBreak(kByteBreakAfterInstIfLessThanUntilNextByte);
- if (logger_) {
- logger_->NewLine();
- logger_->NewLine();
- if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
- }
- ProcessCurInstruction();
- return SPV_SUCCESS;
- }
- spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
- const bool header_read_success =
- reader_.ReadUnencoded(&header_.magic_number) &&
- reader_.ReadUnencoded(&header_.markv_version) &&
- reader_.ReadUnencoded(&header_.markv_model) &&
- reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
- reader_.ReadUnencoded(&header_.spirv_version) &&
- reader_.ReadUnencoded(&header_.spirv_generator);
- if (!header_read_success)
- return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header";
- if (header_.markv_length_in_bits == 0)
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Header markv_length_in_bits field is zero";
- if (header_.magic_number != kMarkvMagicNumber)
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "MARK-V binary has incorrect magic number";
- // TODO([email protected]): Print version strings.
- if (header_.markv_version != GetMarkvVersion())
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "MARK-V binary and the codec have different versions";
- const uint32_t model_type = header_.markv_model >> 16;
- const uint32_t model_version = header_.markv_model & 0xFFFF;
- if (model_type != model_->model_type())
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "MARK-V binary and the codec use different MARK-V models";
- if (model_version != model_->model_version())
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "MARK-V binary and the codec use different versions if the same "
- << "MARK-V model";
- spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic.
- spirv_.resize(5, 0);
- spirv_[0] = kSpirvMagicNumber;
- spirv_[1] = header_.spirv_version;
- spirv_[2] = header_.spirv_generator;
- if (logger_) {
- reader_.SetCallback(
- [this](const std::string& str) { logger_->AppendBitSequence(str); });
- }
- while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
- inst_ = {};
- const spv_result_t decode_result = DecodeInstruction();
- if (decode_result != SPV_SUCCESS) return decode_result;
- const spv_result_t validation_result = UpdateValidationState(inst_);
- if (validation_result != SPV_SUCCESS) return validation_result;
- }
- if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
- !reader_.OnlyZeroesLeft()) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "MARK-V binary has wrong stated bit length "
- << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
- }
- // Decoding of the module is finished, validation state should have correct
- // id bound.
- spirv_[3] = GetIdBound();
- *spirv_binary = std::move(spirv_);
- return SPV_SUCCESS;
- }
- // TODO([email protected]): The implementation borrows heavily from
- // Parser::parseOperand.
- // Consider coupling them together in some way once MARK-V codec is more mature.
- // For now it's better to keep the code independent for experimentation
- // purposes.
- spv_result_t MarkvDecoder::DecodeOperand(
- size_t operand_offset, const spv_operand_type_t type,
- spv_operand_pattern_t* expected_operands) {
- const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
- memset(&operand_, 0, sizeof(operand_));
- assert((operand_offset >> 16) == 0);
- operand_.offset = static_cast<uint16_t>(operand_offset);
- operand_.type = type;
- // Set default values, may be updated later.
- operand_.number_kind = SPV_NUMBER_NONE;
- operand_.number_bit_width = 0;
- const size_t first_word_index = inst_words_.size();
- switch (type) {
- case SPV_OPERAND_TYPE_RESULT_ID: {
- const spv_result_t result = DecodeResultId();
- if (result != SPV_SUCCESS) return result;
- inst_words_.push_back(inst_.result_id);
- SetIdBound(std::max(GetIdBound(), inst_.result_id + 1));
- PromoteIfNeeded(inst_.result_id);
- break;
- }
- case SPV_OPERAND_TYPE_TYPE_ID: {
- const spv_result_t result = DecodeTypeId();
- if (result != SPV_SUCCESS) return result;
- inst_words_.push_back(inst_.type_id);
- SetIdBound(std::max(GetIdBound(), inst_.type_id + 1));
- PromoteIfNeeded(inst_.type_id);
- break;
- }
- case SPV_OPERAND_TYPE_ID:
- case SPV_OPERAND_TYPE_OPTIONAL_ID:
- case SPV_OPERAND_TYPE_SCOPE_ID:
- case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
- uint32_t id = 0;
- const spv_result_t result = DecodeRefId(&id);
- if (result != SPV_SUCCESS) return result;
- if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";
- if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
- operand_.type = SPV_OPERAND_TYPE_ID;
- if (opcode == SpvOpExtInst && operand_.offset == 3) {
- // The current word is the extended instruction set id.
- // Set the extended instruction set type for the current
- // instruction.
- auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
- if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
- return Diag(SPV_ERROR_INVALID_ID)
- << "OpExtInst set id " << id
- << " does not reference an OpExtInstImport result Id";
- }
- inst_.ext_inst_type = ext_inst_type_iter->second;
- }
- }
- inst_words_.push_back(id);
- SetIdBound(std::max(GetIdBound(), id + 1));
- PromoteIfNeeded(id);
- break;
- }
- case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
- uint32_t word = 0;
- const spv_result_t result = DecodeNonIdWord(&word);
- if (result != SPV_SUCCESS) return result;
- inst_words_.push_back(word);
- assert(SpvOpExtInst == opcode);
- assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE);
- spv_ext_inst_desc ext_inst;
- if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst))
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Invalid extended instruction number: " << word;
- spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
- break;
- }
- case SPV_OPERAND_TYPE_LITERAL_INTEGER:
- case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
- // These are regular single-word literal integer operands.
- // Post-parsing validation should check the range of the parsed value.
- operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
- // It turns out they are always unsigned integers!
- operand_.number_kind = SPV_NUMBER_UNSIGNED_INT;
- operand_.number_bit_width = 32;
- uint32_t word = 0;
- const spv_result_t result = DecodeNonIdWord(&word);
- if (result != SPV_SUCCESS) return result;
- inst_words_.push_back(word);
- break;
- }
- case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
- case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: {
- operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
- if (opcode == SpvOpSwitch) {
- // The literal operands have the same type as the value
- // referenced by the selector Id.
- const uint32_t selector_id = inst_words_.at(1);
- const auto type_id_iter = id_to_type_id_.find(selector_id);
- if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Invalid OpSwitch: selector id " << selector_id
- << " has no type";
- }
- uint32_t type_id = type_id_iter->second;
- if (selector_id == type_id) {
- // Recall that by convention, a result ID that is a type definition
- // maps to itself.
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Invalid OpSwitch: selector id " << selector_id
- << " is a type, not a value";
- }
- if (auto error = SetNumericTypeInfoForType(&operand_, type_id))
- return error;
- if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT &&
- operand_.number_kind != SPV_NUMBER_SIGNED_INT) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Invalid OpSwitch: selector id " << selector_id
- << " is not a scalar integer";
- }
- } else {
- assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
- // The literal number type is determined by the type Id for the
- // constant.
- assert(inst_.type_id);
- if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id))
- return error;
- }
- if (auto error = DecodeLiteralNumber(operand_)) return error;
- break;
- }
- case SPV_OPERAND_TYPE_LITERAL_STRING:
- case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
- operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING;
- std::vector<char> str;
- auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode);
- if (codec) {
- std::string decoded_string;
- const bool huffman_result =
- codec->DecodeFromStream(GetReadBitCallback(), &decoded_string);
- assert(huffman_result);
- if (!huffman_result)
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to read literal string";
- if (decoded_string != "kMarkvNoneOfTheAbove") {
- std::copy(decoded_string.begin(), decoded_string.end(),
- std::back_inserter(str));
- str.push_back('\0');
- }
- }
- // The loop is expected to terminate once we encounter '\0' or exhaust
- // the bit stream.
- if (str.empty()) {
- while (true) {
- char ch = 0;
- if (!reader_.ReadUnencoded(&ch))
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to read literal string";
- str.push_back(ch);
- if (ch == '\0') break;
- }
- }
- while (str.size() % 4 != 0) str.push_back('\0');
- inst_words_.resize(inst_words_.size() + str.size() / 4);
- std::memcpy(&inst_words_[first_word_index], str.data(), str.size());
- if (SpvOpExtInstImport == opcode) {
- // Record the extended instruction type for the ID for this import.
- // There is only one string literal argument to OpExtInstImport,
- // so it's sufficient to guard this just on the opcode.
- const spv_ext_inst_type_t ext_inst_type =
- spvExtInstImportTypeGet(str.data());
- if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Invalid extended instruction import '" << str.data()
- << "'";
- }
- // We must have parsed a valid result ID. It's a condition
- // of the grammar, and we only accept non-zero result Ids.
- assert(inst_.result_id);
- const bool inserted =
- import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type)
- .second;
- (void)inserted;
- assert(inserted);
- }
- break;
- }
- case SPV_OPERAND_TYPE_CAPABILITY:
- case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
- case SPV_OPERAND_TYPE_EXECUTION_MODEL:
- case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
- case SPV_OPERAND_TYPE_MEMORY_MODEL:
- case SPV_OPERAND_TYPE_EXECUTION_MODE:
- case SPV_OPERAND_TYPE_STORAGE_CLASS:
- case SPV_OPERAND_TYPE_DIMENSIONALITY:
- case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
- case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
- case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
- case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
- case SPV_OPERAND_TYPE_LINKAGE_TYPE:
- case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
- case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
- case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
- case SPV_OPERAND_TYPE_DECORATION:
- case SPV_OPERAND_TYPE_BUILT_IN:
- case SPV_OPERAND_TYPE_GROUP_OPERATION:
- case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
- case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
- // A single word that is a plain enum value.
- uint32_t word = 0;
- const spv_result_t result = DecodeNonIdWord(&word);
- if (result != SPV_SUCCESS) return result;
- inst_words_.push_back(word);
- // Map an optional operand type to its corresponding concrete type.
- if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
- operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;
- spv_operand_desc entry;
- if (grammar_.lookupOperand(type, word, &entry)) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Invalid " << spvOperandTypeStr(operand_.type)
- << " operand: " << word;
- }
- // Prepare to accept operands to this operand, if needed.
- spvPushOperandTypes(entry->operandTypes, expected_operands);
- break;
- }
- case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
- case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
- case SPV_OPERAND_TYPE_LOOP_CONTROL:
- case SPV_OPERAND_TYPE_IMAGE:
- case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
- case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
- case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
- // This operand is a mask.
- uint32_t word = 0;
- const spv_result_t result = DecodeNonIdWord(&word);
- if (result != SPV_SUCCESS) return result;
- inst_words_.push_back(word);
- // Map an optional operand type to its corresponding concrete type.
- if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
- operand_.type = SPV_OPERAND_TYPE_IMAGE;
- else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
- operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
- // Check validity of set mask bits. Also prepare for operands for those
- // masks if they have any. To get operand order correct, scan from
- // MSB to LSB since we can only prepend operands to a pattern.
- // The only case in the grammar where you have more than one mask bit
- // having an operand is for image operands. See SPIR-V 3.14 Image
- // Operands.
- uint32_t remaining_word = word;
- for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
- if (remaining_word & mask) {
- spv_operand_desc entry;
- if (grammar_.lookupOperand(type, mask, &entry)) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Invalid " << spvOperandTypeStr(operand_.type)
- << " operand: " << word << " has invalid mask component "
- << mask;
- }
- remaining_word ^= mask;
- spvPushOperandTypes(entry->operandTypes, expected_operands);
- }
- }
- if (word == 0) {
- // An all-zeroes mask *might* also be valid.
- spv_operand_desc entry;
- if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
- // Prepare for its operands, if any.
- spvPushOperandTypes(entry->operandTypes, expected_operands);
- }
- }
- break;
- }
- default:
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Internal error: Unhandled operand type: " << type;
- }
- operand_.num_words = uint16_t(inst_words_.size() - first_word_index);
- assert(spvOperandIsConcrete(operand_.type));
- parsed_operands_.push_back(operand_);
- return SPV_SUCCESS;
- }
- spv_result_t MarkvDecoder::DecodeInstruction() {
- parsed_operands_.clear();
- inst_words_.clear();
- // Opcode/num_words placeholder, the word will be filled in later.
- inst_words_.push_back(0);
- bool num_operands_still_unknown = true;
- {
- uint32_t opcode = 0;
- uint32_t num_operands = 0;
- const spv_result_t opcode_decoding_result =
- DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands);
- if (opcode_decoding_result < 0) return opcode_decoding_result;
- if (opcode_decoding_result == SPV_SUCCESS) {
- inst_.num_operands = static_cast<uint16_t>(num_operands);
- num_operands_still_unknown = false;
- } else {
- if (!reader_.ReadVariableWidthU32(&opcode,
- model_->opcode_chunk_length())) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to read opcode of instruction";
- }
- }
- inst_.opcode = static_cast<uint16_t>(opcode);
- }
- const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
- spv_opcode_desc opcode_desc;
- if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) {
- return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
- }
- spv_operand_pattern_t expected_operands;
- expected_operands.reserve(opcode_desc->numTypes);
- for (auto i = 0; i < opcode_desc->numTypes; i++) {
- expected_operands.push_back(
- opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);
- }
- if (num_operands_still_unknown) {
- if (!OpcodeHasFixedNumberOfOperands(opcode)) {
- if (!reader_.ReadVariableWidthU16(&inst_.num_operands,
- model_->num_operands_chunk_length()))
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to read num_operands of instruction";
- } else {
- inst_.num_operands = static_cast<uint16_t>(expected_operands.size());
- }
- }
- for (operand_index_ = 0;
- operand_index_ < static_cast<size_t>(inst_.num_operands);
- ++operand_index_) {
- assert(!expected_operands.empty());
- const spv_operand_type_t type =
- spvTakeFirstMatchableOperand(&expected_operands);
- const size_t operand_offset = inst_words_.size();
- const spv_result_t decode_result =
- DecodeOperand(operand_offset, type, &expected_operands);
- if (decode_result != SPV_SUCCESS) return decode_result;
- }
- assert(inst_.num_operands == parsed_operands_.size());
- // Only valid while inst_words_ and parsed_operands_ remain unchanged (until
- // next DecodeInstruction call).
- inst_.words = inst_words_.data();
- inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
- inst_.num_words = static_cast<uint16_t>(inst_words_.size());
- inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode));
- std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_));
- assert(inst_.num_words ==
- std::accumulate(
- parsed_operands_.begin(), parsed_operands_.end(), 1,
- [](int num_words, const spv_parsed_operand_t& operand) {
- return num_words += operand.num_words;
- }) &&
- "num_words in instruction doesn't correspond to the sum of num_words"
- "in the operands");
- RecordNumberType();
- ProcessCurInstruction();
- if (!ReadToByteBreak(kByteBreakAfterInstIfLessThanUntilNextByte))
- return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break";
- if (logger_) {
- logger_->NewLine();
- std::stringstream ss;
- ss << spvOpcodeString(opcode) << " ";
- for (size_t index = 1; index < inst_words_.size(); ++index)
- ss << inst_words_[index] << " ";
- logger_->AppendText(ss.str());
- logger_->NewLine();
- logger_->NewLine();
- if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
- }
- return SPV_SUCCESS;
- }
- spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
- spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
- assert(type_id != 0);
- auto type_info_iter = type_id_to_number_type_info_.find(type_id);
- if (type_info_iter == type_id_to_number_type_info_.end()) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Type Id " << type_id << " is not a type";
- }
- const NumberType& info = type_info_iter->second;
- if (info.type == SPV_NUMBER_NONE) {
- // This is a valid type, but for something other than a scalar number.
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Type Id " << type_id << " is not a scalar numeric type";
- }
- parsed_operand->number_kind = info.type;
- parsed_operand->number_bit_width = info.bit_width;
- // Round up the word count.
- parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32);
- return SPV_SUCCESS;
- }
- void MarkvDecoder::RecordNumberType() {
- const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
- if (spvOpcodeGeneratesType(opcode)) {
- NumberType info = {SPV_NUMBER_NONE, 0};
- if (SpvOpTypeInt == opcode) {
- info.bit_width = inst_.words[inst_.operands[1].offset];
- info.type = inst_.words[inst_.operands[2].offset]
- ? SPV_NUMBER_SIGNED_INT
- : SPV_NUMBER_UNSIGNED_INT;
- } else if (SpvOpTypeFloat == opcode) {
- info.bit_width = inst_.words[inst_.operands[1].offset];
- info.type = SPV_NUMBER_FLOATING;
- }
- // The *result* Id of a type generating instruction is the type Id.
- type_id_to_number_type_info_[inst_.result_id] = info;
- }
- }
- spv_result_t EncodeHeader(void* user_data, spv_endianness_t endian,
- uint32_t magic, uint32_t version, uint32_t generator,
- uint32_t id_bound, uint32_t schema) {
- MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
- return encoder->EncodeHeader(endian, magic, version, generator, id_bound,
- schema);
- }
- spv_result_t EncodeInstruction(void* user_data,
- const spv_parsed_instruction_t* inst) {
- MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
- return encoder->EncodeInstruction(*inst);
- }
- } // namespace
- spv_result_t SpirvToMarkv(
- spv_const_context context, const std::vector<uint32_t>& spirv,
- const MarkvCodecOptions& options, const MarkvModel& markv_model,
- MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
- MarkvDebugConsumer debug_consumer, std::vector<uint8_t>* markv) {
- spv_context_t hijack_context = *context;
- libspirv::SetContextMessageConsumer(&hijack_context, message_consumer);
- spv_const_binary_t spirv_binary = {spirv.data(), spirv.size()};
- spv_endianness_t endian;
- spv_position_t position = {};
- if (spvBinaryEndianness(&spirv_binary, &endian)) {
- return DiagnosticStream(position, hijack_context.consumer,
- SPV_ERROR_INVALID_BINARY)
- << "Invalid SPIR-V magic number.";
- }
- spv_header_t header;
- if (spvBinaryHeaderGet(&spirv_binary, endian, &header)) {
- return DiagnosticStream(position, hijack_context.consumer,
- SPV_ERROR_INVALID_BINARY)
- << "Invalid SPIR-V header.";
- }
- MarkvEncoder encoder(&hijack_context, options, &markv_model);
- if (log_consumer || debug_consumer) {
- encoder.CreateLogger(log_consumer, debug_consumer);
- spv_text text = nullptr;
- if (spvBinaryToText(&hijack_context, spirv.data(), spirv.size(),
- SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text,
- nullptr) != SPV_SUCCESS) {
- return DiagnosticStream(position, hijack_context.consumer,
- SPV_ERROR_INVALID_BINARY)
- << "Failed to disassemble SPIR-V binary.";
- }
- assert(text);
- encoder.SetDisassembly(std::string(text->str, text->length));
- spvTextDestroy(text);
- }
- if (spvBinaryParse(&hijack_context, &encoder, spirv.data(), spirv.size(),
- EncodeHeader, EncodeInstruction, nullptr) != SPV_SUCCESS) {
- return DiagnosticStream(position, hijack_context.consumer,
- SPV_ERROR_INVALID_BINARY)
- << "Unable to encode to MARK-V.";
- }
- *markv = encoder.GetMarkvBinary();
- return SPV_SUCCESS;
- }
- spv_result_t MarkvToSpirv(
- spv_const_context context, const std::vector<uint8_t>& markv,
- const MarkvCodecOptions& options, const MarkvModel& markv_model,
- MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
- MarkvDebugConsumer debug_consumer, std::vector<uint32_t>* spirv) {
- spv_position_t position = {};
- spv_context_t hijack_context = *context;
- libspirv::SetContextMessageConsumer(&hijack_context, message_consumer);
- MarkvDecoder decoder(&hijack_context, markv, options, &markv_model);
- if (log_consumer || debug_consumer)
- decoder.CreateLogger(log_consumer, debug_consumer);
- if (decoder.DecodeModule(spirv) != SPV_SUCCESS) {
- return DiagnosticStream(position, hijack_context.consumer,
- SPV_ERROR_INVALID_BINARY)
- << "Unable to decode MARK-V.";
- }
- assert(!spirv->empty());
- return SPV_SUCCESS;
- }
- } // namespace spvtools
|