validation_state.cpp 40 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323
  1. // Copyright (c) 2015-2016 The Khronos Group Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/val/validation_state.h"
  15. #include <cassert>
  16. #include <stack>
  17. #include <utility>
  18. #include "source/opcode.h"
  19. #include "source/spirv_constant.h"
  20. #include "source/spirv_target_env.h"
  21. #include "source/val/basic_block.h"
  22. #include "source/val/construct.h"
  23. #include "source/val/function.h"
  24. #include "spirv-tools/libspirv.h"
  25. namespace spvtools {
  26. namespace val {
  27. namespace {
  28. bool IsInstructionInLayoutSection(ModuleLayoutSection layout, SpvOp op) {
  29. // See Section 2.4
  30. bool out = false;
  31. // clang-format off
  32. switch (layout) {
  33. case kLayoutCapabilities: out = op == SpvOpCapability; break;
  34. case kLayoutExtensions: out = op == SpvOpExtension; break;
  35. case kLayoutExtInstImport: out = op == SpvOpExtInstImport; break;
  36. case kLayoutMemoryModel: out = op == SpvOpMemoryModel; break;
  37. case kLayoutEntryPoint: out = op == SpvOpEntryPoint; break;
  38. case kLayoutExecutionMode:
  39. out = op == SpvOpExecutionMode || op == SpvOpExecutionModeId;
  40. break;
  41. case kLayoutDebug1:
  42. switch (op) {
  43. case SpvOpSourceContinued:
  44. case SpvOpSource:
  45. case SpvOpSourceExtension:
  46. case SpvOpString:
  47. out = true;
  48. break;
  49. default: break;
  50. }
  51. break;
  52. case kLayoutDebug2:
  53. switch (op) {
  54. case SpvOpName:
  55. case SpvOpMemberName:
  56. out = true;
  57. break;
  58. default: break;
  59. }
  60. break;
  61. case kLayoutDebug3:
  62. // Only OpModuleProcessed is allowed here.
  63. out = (op == SpvOpModuleProcessed);
  64. break;
  65. case kLayoutAnnotations:
  66. switch (op) {
  67. case SpvOpDecorate:
  68. case SpvOpMemberDecorate:
  69. case SpvOpGroupDecorate:
  70. case SpvOpGroupMemberDecorate:
  71. case SpvOpDecorationGroup:
  72. case SpvOpDecorateId:
  73. case SpvOpDecorateStringGOOGLE:
  74. case SpvOpMemberDecorateStringGOOGLE:
  75. out = true;
  76. break;
  77. default: break;
  78. }
  79. break;
  80. case kLayoutTypes:
  81. if (spvOpcodeGeneratesType(op) || spvOpcodeIsConstant(op)) {
  82. out = true;
  83. break;
  84. }
  85. switch (op) {
  86. case SpvOpTypeForwardPointer:
  87. case SpvOpVariable:
  88. case SpvOpLine:
  89. case SpvOpNoLine:
  90. case SpvOpUndef:
  91. // SpvOpExtInst is only allowed here for certain extended instruction
  92. // sets. This will be checked separately
  93. case SpvOpExtInst:
  94. out = true;
  95. break;
  96. default: break;
  97. }
  98. break;
  99. case kLayoutFunctionDeclarations:
  100. case kLayoutFunctionDefinitions:
  101. // NOTE: These instructions should NOT be in these layout sections
  102. if (spvOpcodeGeneratesType(op) || spvOpcodeIsConstant(op)) {
  103. out = false;
  104. break;
  105. }
  106. switch (op) {
  107. case SpvOpCapability:
  108. case SpvOpExtension:
  109. case SpvOpExtInstImport:
  110. case SpvOpMemoryModel:
  111. case SpvOpEntryPoint:
  112. case SpvOpExecutionMode:
  113. case SpvOpExecutionModeId:
  114. case SpvOpSourceContinued:
  115. case SpvOpSource:
  116. case SpvOpSourceExtension:
  117. case SpvOpString:
  118. case SpvOpName:
  119. case SpvOpMemberName:
  120. case SpvOpModuleProcessed:
  121. case SpvOpDecorate:
  122. case SpvOpMemberDecorate:
  123. case SpvOpGroupDecorate:
  124. case SpvOpGroupMemberDecorate:
  125. case SpvOpDecorationGroup:
  126. case SpvOpTypeForwardPointer:
  127. out = false;
  128. break;
  129. default:
  130. out = true;
  131. break;
  132. }
  133. }
  134. // clang-format on
  135. return out;
  136. }
  137. // Counts the number of instructions and functions in the file.
  138. spv_result_t CountInstructions(void* user_data,
  139. const spv_parsed_instruction_t* inst) {
  140. ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
  141. if (inst->opcode == SpvOpFunction) _.increment_total_functions();
  142. _.increment_total_instructions();
  143. return SPV_SUCCESS;
  144. }
  145. spv_result_t setHeader(void* user_data, spv_endianness_t, uint32_t,
  146. uint32_t version, uint32_t generator, uint32_t id_bound,
  147. uint32_t) {
  148. ValidationState_t& vstate =
  149. *(reinterpret_cast<ValidationState_t*>(user_data));
  150. vstate.setIdBound(id_bound);
  151. vstate.setGenerator(generator);
  152. vstate.setVersion(version);
  153. return SPV_SUCCESS;
  154. }
  155. // Add features based on SPIR-V core version number.
  156. void UpdateFeaturesBasedOnSpirvVersion(ValidationState_t::Feature* features,
  157. uint32_t version) {
  158. assert(features);
  159. if (version >= SPV_SPIRV_VERSION_WORD(1, 4)) {
  160. features->select_between_composites = true;
  161. features->copy_memory_permits_two_memory_accesses = true;
  162. features->uconvert_spec_constant_op = true;
  163. features->nonwritable_var_in_function_or_private = true;
  164. }
  165. }
  166. } // namespace
  167. ValidationState_t::ValidationState_t(const spv_const_context ctx,
  168. const spv_const_validator_options opt,
  169. const uint32_t* words,
  170. const size_t num_words,
  171. const uint32_t max_warnings)
  172. : context_(ctx),
  173. options_(opt),
  174. words_(words),
  175. num_words_(num_words),
  176. unresolved_forward_ids_{},
  177. operand_names_{},
  178. current_layout_section_(kLayoutCapabilities),
  179. module_functions_(),
  180. module_capabilities_(),
  181. module_extensions_(),
  182. ordered_instructions_(),
  183. all_definitions_(),
  184. global_vars_(),
  185. local_vars_(),
  186. struct_nesting_depth_(),
  187. struct_has_nested_blockorbufferblock_struct_(),
  188. grammar_(ctx),
  189. addressing_model_(SpvAddressingModelMax),
  190. memory_model_(SpvMemoryModelMax),
  191. pointer_size_and_alignment_(0),
  192. in_function_(false),
  193. num_of_warnings_(0),
  194. max_num_of_warnings_(max_warnings) {
  195. assert(opt && "Validator options may not be Null.");
  196. const auto env = context_->target_env;
  197. if (spvIsVulkanEnv(env)) {
  198. // Vulkan 1.1 includes VK_KHR_relaxed_block_layout in core.
  199. if (env != SPV_ENV_VULKAN_1_0) {
  200. features_.env_relaxed_block_layout = true;
  201. }
  202. }
  203. // Only attempt to count if we have words, otherwise let the other validation
  204. // fail and generate an error.
  205. if (num_words > 0) {
  206. // Count the number of instructions in the binary.
  207. // This parse should not produce any error messages. Hijack the context and
  208. // replace the message consumer so that we do not pollute any state in input
  209. // consumer.
  210. spv_context_t hijacked_context = *ctx;
  211. hijacked_context.consumer = [](spv_message_level_t, const char*,
  212. const spv_position_t&, const char*) {};
  213. spvBinaryParse(&hijacked_context, this, words, num_words, setHeader,
  214. CountInstructions,
  215. /* diagnostic = */ nullptr);
  216. preallocateStorage();
  217. }
  218. UpdateFeaturesBasedOnSpirvVersion(&features_, version_);
  219. friendly_mapper_ = spvtools::MakeUnique<spvtools::FriendlyNameMapper>(
  220. context_, words_, num_words_);
  221. name_mapper_ = friendly_mapper_->GetNameMapper();
  222. }
  223. void ValidationState_t::preallocateStorage() {
  224. ordered_instructions_.reserve(total_instructions_);
  225. module_functions_.reserve(total_functions_);
  226. }
  227. spv_result_t ValidationState_t::ForwardDeclareId(uint32_t id) {
  228. unresolved_forward_ids_.insert(id);
  229. return SPV_SUCCESS;
  230. }
  231. spv_result_t ValidationState_t::RemoveIfForwardDeclared(uint32_t id) {
  232. unresolved_forward_ids_.erase(id);
  233. return SPV_SUCCESS;
  234. }
  235. spv_result_t ValidationState_t::RegisterForwardPointer(uint32_t id) {
  236. forward_pointer_ids_.insert(id);
  237. return SPV_SUCCESS;
  238. }
  239. bool ValidationState_t::IsForwardPointer(uint32_t id) const {
  240. return (forward_pointer_ids_.find(id) != forward_pointer_ids_.end());
  241. }
  242. void ValidationState_t::AssignNameToId(uint32_t id, std::string name) {
  243. operand_names_[id] = name;
  244. }
  245. std::string ValidationState_t::getIdName(uint32_t id) const {
  246. const std::string id_name = name_mapper_(id);
  247. std::stringstream out;
  248. out << id << "[%" << id_name << "]";
  249. return out.str();
  250. }
  251. size_t ValidationState_t::unresolved_forward_id_count() const {
  252. return unresolved_forward_ids_.size();
  253. }
  254. std::vector<uint32_t> ValidationState_t::UnresolvedForwardIds() const {
  255. std::vector<uint32_t> out(std::begin(unresolved_forward_ids_),
  256. std::end(unresolved_forward_ids_));
  257. return out;
  258. }
  259. bool ValidationState_t::IsDefinedId(uint32_t id) const {
  260. return all_definitions_.find(id) != std::end(all_definitions_);
  261. }
  262. const Instruction* ValidationState_t::FindDef(uint32_t id) const {
  263. auto it = all_definitions_.find(id);
  264. if (it == all_definitions_.end()) return nullptr;
  265. return it->second;
  266. }
  267. Instruction* ValidationState_t::FindDef(uint32_t id) {
  268. auto it = all_definitions_.find(id);
  269. if (it == all_definitions_.end()) return nullptr;
  270. return it->second;
  271. }
  272. ModuleLayoutSection ValidationState_t::current_layout_section() const {
  273. return current_layout_section_;
  274. }
  275. void ValidationState_t::ProgressToNextLayoutSectionOrder() {
  276. // Guard against going past the last element(kLayoutFunctionDefinitions)
  277. if (current_layout_section_ <= kLayoutFunctionDefinitions) {
  278. current_layout_section_ =
  279. static_cast<ModuleLayoutSection>(current_layout_section_ + 1);
  280. }
  281. }
  282. bool ValidationState_t::IsOpcodeInCurrentLayoutSection(SpvOp op) {
  283. return IsInstructionInLayoutSection(current_layout_section_, op);
  284. }
  285. DiagnosticStream ValidationState_t::diag(spv_result_t error_code,
  286. const Instruction* inst) {
  287. if (error_code == SPV_WARNING) {
  288. if (num_of_warnings_ == max_num_of_warnings_) {
  289. DiagnosticStream({0, 0, 0}, context_->consumer, "", error_code)
  290. << "Other warnings have been suppressed.\n";
  291. }
  292. if (num_of_warnings_ >= max_num_of_warnings_) {
  293. return DiagnosticStream({0, 0, 0}, nullptr, "", error_code);
  294. }
  295. ++num_of_warnings_;
  296. }
  297. std::string disassembly;
  298. if (inst) disassembly = Disassemble(*inst);
  299. return DiagnosticStream({0, 0, inst ? inst->LineNum() : 0},
  300. context_->consumer, disassembly, error_code);
  301. }
  302. std::vector<Function>& ValidationState_t::functions() {
  303. return module_functions_;
  304. }
  305. Function& ValidationState_t::current_function() {
  306. assert(in_function_body());
  307. return module_functions_.back();
  308. }
  309. const Function& ValidationState_t::current_function() const {
  310. assert(in_function_body());
  311. return module_functions_.back();
  312. }
  313. const Function* ValidationState_t::function(uint32_t id) const {
  314. const auto it = id_to_function_.find(id);
  315. if (it == id_to_function_.end()) return nullptr;
  316. return it->second;
  317. }
  318. Function* ValidationState_t::function(uint32_t id) {
  319. auto it = id_to_function_.find(id);
  320. if (it == id_to_function_.end()) return nullptr;
  321. return it->second;
  322. }
  323. bool ValidationState_t::in_function_body() const { return in_function_; }
  324. bool ValidationState_t::in_block() const {
  325. return module_functions_.empty() == false &&
  326. module_functions_.back().current_block() != nullptr;
  327. }
  328. void ValidationState_t::RegisterCapability(SpvCapability cap) {
  329. // Avoid redundant work. Otherwise the recursion could induce work
  330. // quadrdatic in the capability dependency depth. (Ok, not much, but
  331. // it's something.)
  332. if (module_capabilities_.Contains(cap)) return;
  333. module_capabilities_.Add(cap);
  334. spv_operand_desc desc;
  335. if (SPV_SUCCESS ==
  336. grammar_.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY, cap, &desc)) {
  337. CapabilitySet(desc->numCapabilities, desc->capabilities)
  338. .ForEach([this](SpvCapability c) { RegisterCapability(c); });
  339. }
  340. switch (cap) {
  341. case SpvCapabilityKernel:
  342. features_.group_ops_reduce_and_scans = true;
  343. break;
  344. case SpvCapabilityInt8:
  345. features_.use_int8_type = true;
  346. features_.declare_int8_type = true;
  347. break;
  348. case SpvCapabilityStorageBuffer8BitAccess:
  349. case SpvCapabilityUniformAndStorageBuffer8BitAccess:
  350. case SpvCapabilityStoragePushConstant8:
  351. features_.declare_int8_type = true;
  352. break;
  353. case SpvCapabilityInt16:
  354. features_.declare_int16_type = true;
  355. break;
  356. case SpvCapabilityFloat16:
  357. case SpvCapabilityFloat16Buffer:
  358. features_.declare_float16_type = true;
  359. break;
  360. case SpvCapabilityStorageUniformBufferBlock16:
  361. case SpvCapabilityStorageUniform16:
  362. case SpvCapabilityStoragePushConstant16:
  363. case SpvCapabilityStorageInputOutput16:
  364. features_.declare_int16_type = true;
  365. features_.declare_float16_type = true;
  366. features_.free_fp_rounding_mode = true;
  367. break;
  368. case SpvCapabilityVariablePointers:
  369. features_.variable_pointers = true;
  370. features_.variable_pointers_storage_buffer = true;
  371. break;
  372. case SpvCapabilityVariablePointersStorageBuffer:
  373. features_.variable_pointers_storage_buffer = true;
  374. break;
  375. default:
  376. break;
  377. }
  378. }
  379. void ValidationState_t::RegisterExtension(Extension ext) {
  380. if (module_extensions_.Contains(ext)) return;
  381. module_extensions_.Add(ext);
  382. switch (ext) {
  383. case kSPV_AMD_gpu_shader_half_float:
  384. case kSPV_AMD_gpu_shader_half_float_fetch:
  385. // SPV_AMD_gpu_shader_half_float enables float16 type.
  386. // https://github.com/KhronosGroup/SPIRV-Tools/issues/1375
  387. features_.declare_float16_type = true;
  388. break;
  389. case kSPV_AMD_gpu_shader_int16:
  390. // This is not yet in the extension, but it's recommended for it.
  391. // See https://github.com/KhronosGroup/glslang/issues/848
  392. features_.uconvert_spec_constant_op = true;
  393. break;
  394. case kSPV_AMD_shader_ballot:
  395. // The grammar doesn't encode the fact that SPV_AMD_shader_ballot
  396. // enables the use of group operations Reduce, InclusiveScan,
  397. // and ExclusiveScan. Enable it manually.
  398. // https://github.com/KhronosGroup/SPIRV-Tools/issues/991
  399. features_.group_ops_reduce_and_scans = true;
  400. break;
  401. default:
  402. break;
  403. }
  404. }
  405. bool ValidationState_t::HasAnyOfCapabilities(
  406. const CapabilitySet& capabilities) const {
  407. return module_capabilities_.HasAnyOf(capabilities);
  408. }
  409. bool ValidationState_t::HasAnyOfExtensions(
  410. const ExtensionSet& extensions) const {
  411. return module_extensions_.HasAnyOf(extensions);
  412. }
  413. void ValidationState_t::set_addressing_model(SpvAddressingModel am) {
  414. addressing_model_ = am;
  415. switch (am) {
  416. case SpvAddressingModelPhysical32:
  417. pointer_size_and_alignment_ = 4;
  418. break;
  419. default:
  420. // fall through
  421. case SpvAddressingModelPhysical64:
  422. case SpvAddressingModelPhysicalStorageBuffer64EXT:
  423. pointer_size_and_alignment_ = 8;
  424. break;
  425. }
  426. }
  427. SpvAddressingModel ValidationState_t::addressing_model() const {
  428. return addressing_model_;
  429. }
  430. void ValidationState_t::set_memory_model(SpvMemoryModel mm) {
  431. memory_model_ = mm;
  432. }
  433. SpvMemoryModel ValidationState_t::memory_model() const { return memory_model_; }
  434. spv_result_t ValidationState_t::RegisterFunction(
  435. uint32_t id, uint32_t ret_type_id, SpvFunctionControlMask function_control,
  436. uint32_t function_type_id) {
  437. assert(in_function_body() == false &&
  438. "RegisterFunction can only be called when parsing the binary outside "
  439. "of another function");
  440. in_function_ = true;
  441. module_functions_.emplace_back(id, ret_type_id, function_control,
  442. function_type_id);
  443. id_to_function_.emplace(id, &current_function());
  444. // TODO(umar): validate function type and type_id
  445. return SPV_SUCCESS;
  446. }
  447. spv_result_t ValidationState_t::RegisterFunctionEnd() {
  448. assert(in_function_body() == true &&
  449. "RegisterFunctionEnd can only be called when parsing the binary "
  450. "inside of another function");
  451. assert(in_block() == false &&
  452. "RegisterFunctionParameter can only be called when parsing the binary "
  453. "ouside of a block");
  454. current_function().RegisterFunctionEnd();
  455. in_function_ = false;
  456. return SPV_SUCCESS;
  457. }
  458. Instruction* ValidationState_t::AddOrderedInstruction(
  459. const spv_parsed_instruction_t* inst) {
  460. ordered_instructions_.emplace_back(inst);
  461. ordered_instructions_.back().SetLineNum(ordered_instructions_.size());
  462. return &ordered_instructions_.back();
  463. }
  464. // Improves diagnostic messages by collecting names of IDs
  465. void ValidationState_t::RegisterDebugInstruction(const Instruction* inst) {
  466. switch (inst->opcode()) {
  467. case SpvOpName: {
  468. const auto target = inst->GetOperandAs<uint32_t>(0);
  469. const auto* str = reinterpret_cast<const char*>(inst->words().data() +
  470. inst->operand(1).offset);
  471. AssignNameToId(target, str);
  472. break;
  473. }
  474. case SpvOpMemberName: {
  475. const auto target = inst->GetOperandAs<uint32_t>(0);
  476. const auto* str = reinterpret_cast<const char*>(inst->words().data() +
  477. inst->operand(2).offset);
  478. AssignNameToId(target, str);
  479. break;
  480. }
  481. case SpvOpSourceContinued:
  482. case SpvOpSource:
  483. case SpvOpSourceExtension:
  484. case SpvOpString:
  485. case SpvOpLine:
  486. case SpvOpNoLine:
  487. default:
  488. break;
  489. }
  490. }
  491. void ValidationState_t::RegisterInstruction(Instruction* inst) {
  492. if (inst->id()) all_definitions_.insert(std::make_pair(inst->id(), inst));
  493. // If the instruction is using an OpTypeSampledImage as an operand, it should
  494. // be recorded. The validator will ensure that all usages of an
  495. // OpTypeSampledImage and its definition are in the same basic block.
  496. for (uint16_t i = 0; i < inst->operands().size(); ++i) {
  497. const spv_parsed_operand_t& operand = inst->operand(i);
  498. if (SPV_OPERAND_TYPE_ID == operand.type) {
  499. const uint32_t operand_word = inst->word(operand.offset);
  500. Instruction* operand_inst = FindDef(operand_word);
  501. if (operand_inst && SpvOpSampledImage == operand_inst->opcode()) {
  502. RegisterSampledImageConsumer(operand_word, inst);
  503. }
  504. }
  505. }
  506. }
  507. std::vector<Instruction*> ValidationState_t::getSampledImageConsumers(
  508. uint32_t sampled_image_id) const {
  509. std::vector<Instruction*> result;
  510. auto iter = sampled_image_consumers_.find(sampled_image_id);
  511. if (iter != sampled_image_consumers_.end()) {
  512. result = iter->second;
  513. }
  514. return result;
  515. }
  516. void ValidationState_t::RegisterSampledImageConsumer(uint32_t sampled_image_id,
  517. Instruction* consumer) {
  518. sampled_image_consumers_[sampled_image_id].push_back(consumer);
  519. }
  520. uint32_t ValidationState_t::getIdBound() const { return id_bound_; }
  521. void ValidationState_t::setIdBound(const uint32_t bound) { id_bound_ = bound; }
  522. bool ValidationState_t::RegisterUniqueTypeDeclaration(const Instruction* inst) {
  523. std::vector<uint32_t> key;
  524. key.push_back(static_cast<uint32_t>(inst->opcode()));
  525. for (size_t index = 0; index < inst->operands().size(); ++index) {
  526. const spv_parsed_operand_t& operand = inst->operand(index);
  527. if (operand.type == SPV_OPERAND_TYPE_RESULT_ID) continue;
  528. const int words_begin = operand.offset;
  529. const int words_end = words_begin + operand.num_words;
  530. assert(words_end <= static_cast<int>(inst->words().size()));
  531. key.insert(key.end(), inst->words().begin() + words_begin,
  532. inst->words().begin() + words_end);
  533. }
  534. return unique_type_declarations_.insert(std::move(key)).second;
  535. }
  536. uint32_t ValidationState_t::GetTypeId(uint32_t id) const {
  537. const Instruction* inst = FindDef(id);
  538. return inst ? inst->type_id() : 0;
  539. }
  540. SpvOp ValidationState_t::GetIdOpcode(uint32_t id) const {
  541. const Instruction* inst = FindDef(id);
  542. return inst ? inst->opcode() : SpvOpNop;
  543. }
  544. uint32_t ValidationState_t::GetComponentType(uint32_t id) const {
  545. const Instruction* inst = FindDef(id);
  546. assert(inst);
  547. switch (inst->opcode()) {
  548. case SpvOpTypeFloat:
  549. case SpvOpTypeInt:
  550. case SpvOpTypeBool:
  551. return id;
  552. case SpvOpTypeVector:
  553. return inst->word(2);
  554. case SpvOpTypeMatrix:
  555. return GetComponentType(inst->word(2));
  556. case SpvOpTypeCooperativeMatrixNV:
  557. return inst->word(2);
  558. default:
  559. break;
  560. }
  561. if (inst->type_id()) return GetComponentType(inst->type_id());
  562. assert(0);
  563. return 0;
  564. }
  565. uint32_t ValidationState_t::GetDimension(uint32_t id) const {
  566. const Instruction* inst = FindDef(id);
  567. assert(inst);
  568. switch (inst->opcode()) {
  569. case SpvOpTypeFloat:
  570. case SpvOpTypeInt:
  571. case SpvOpTypeBool:
  572. return 1;
  573. case SpvOpTypeVector:
  574. case SpvOpTypeMatrix:
  575. return inst->word(3);
  576. case SpvOpTypeCooperativeMatrixNV:
  577. // Actual dimension isn't known, return 0
  578. return 0;
  579. default:
  580. break;
  581. }
  582. if (inst->type_id()) return GetDimension(inst->type_id());
  583. assert(0);
  584. return 0;
  585. }
  586. uint32_t ValidationState_t::GetBitWidth(uint32_t id) const {
  587. const uint32_t component_type_id = GetComponentType(id);
  588. const Instruction* inst = FindDef(component_type_id);
  589. assert(inst);
  590. if (inst->opcode() == SpvOpTypeFloat || inst->opcode() == SpvOpTypeInt)
  591. return inst->word(2);
  592. if (inst->opcode() == SpvOpTypeBool) return 1;
  593. assert(0);
  594. return 0;
  595. }
  596. bool ValidationState_t::IsVoidType(uint32_t id) const {
  597. const Instruction* inst = FindDef(id);
  598. assert(inst);
  599. return inst->opcode() == SpvOpTypeVoid;
  600. }
  601. bool ValidationState_t::IsFloatScalarType(uint32_t id) const {
  602. const Instruction* inst = FindDef(id);
  603. assert(inst);
  604. return inst->opcode() == SpvOpTypeFloat;
  605. }
  606. bool ValidationState_t::IsFloatVectorType(uint32_t id) const {
  607. const Instruction* inst = FindDef(id);
  608. assert(inst);
  609. if (inst->opcode() == SpvOpTypeVector) {
  610. return IsFloatScalarType(GetComponentType(id));
  611. }
  612. return false;
  613. }
  614. bool ValidationState_t::IsFloatScalarOrVectorType(uint32_t id) const {
  615. const Instruction* inst = FindDef(id);
  616. assert(inst);
  617. if (inst->opcode() == SpvOpTypeFloat) {
  618. return true;
  619. }
  620. if (inst->opcode() == SpvOpTypeVector) {
  621. return IsFloatScalarType(GetComponentType(id));
  622. }
  623. return false;
  624. }
  625. bool ValidationState_t::IsIntScalarType(uint32_t id) const {
  626. const Instruction* inst = FindDef(id);
  627. assert(inst);
  628. return inst->opcode() == SpvOpTypeInt;
  629. }
  630. bool ValidationState_t::IsIntVectorType(uint32_t id) const {
  631. const Instruction* inst = FindDef(id);
  632. assert(inst);
  633. if (inst->opcode() == SpvOpTypeVector) {
  634. return IsIntScalarType(GetComponentType(id));
  635. }
  636. return false;
  637. }
  638. bool ValidationState_t::IsIntScalarOrVectorType(uint32_t id) const {
  639. const Instruction* inst = FindDef(id);
  640. assert(inst);
  641. if (inst->opcode() == SpvOpTypeInt) {
  642. return true;
  643. }
  644. if (inst->opcode() == SpvOpTypeVector) {
  645. return IsIntScalarType(GetComponentType(id));
  646. }
  647. return false;
  648. }
  649. bool ValidationState_t::IsUnsignedIntScalarType(uint32_t id) const {
  650. const Instruction* inst = FindDef(id);
  651. assert(inst);
  652. return inst->opcode() == SpvOpTypeInt && inst->word(3) == 0;
  653. }
  654. bool ValidationState_t::IsUnsignedIntVectorType(uint32_t id) const {
  655. const Instruction* inst = FindDef(id);
  656. assert(inst);
  657. if (inst->opcode() == SpvOpTypeVector) {
  658. return IsUnsignedIntScalarType(GetComponentType(id));
  659. }
  660. return false;
  661. }
  662. bool ValidationState_t::IsSignedIntScalarType(uint32_t id) const {
  663. const Instruction* inst = FindDef(id);
  664. assert(inst);
  665. return inst->opcode() == SpvOpTypeInt && inst->word(3) == 1;
  666. }
  667. bool ValidationState_t::IsSignedIntVectorType(uint32_t id) const {
  668. const Instruction* inst = FindDef(id);
  669. assert(inst);
  670. if (inst->opcode() == SpvOpTypeVector) {
  671. return IsSignedIntScalarType(GetComponentType(id));
  672. }
  673. return false;
  674. }
  675. bool ValidationState_t::IsBoolScalarType(uint32_t id) const {
  676. const Instruction* inst = FindDef(id);
  677. assert(inst);
  678. return inst->opcode() == SpvOpTypeBool;
  679. }
  680. bool ValidationState_t::IsBoolVectorType(uint32_t id) const {
  681. const Instruction* inst = FindDef(id);
  682. assert(inst);
  683. if (inst->opcode() == SpvOpTypeVector) {
  684. return IsBoolScalarType(GetComponentType(id));
  685. }
  686. return false;
  687. }
  688. bool ValidationState_t::IsBoolScalarOrVectorType(uint32_t id) const {
  689. const Instruction* inst = FindDef(id);
  690. assert(inst);
  691. if (inst->opcode() == SpvOpTypeBool) {
  692. return true;
  693. }
  694. if (inst->opcode() == SpvOpTypeVector) {
  695. return IsBoolScalarType(GetComponentType(id));
  696. }
  697. return false;
  698. }
  699. bool ValidationState_t::IsFloatMatrixType(uint32_t id) const {
  700. const Instruction* inst = FindDef(id);
  701. assert(inst);
  702. if (inst->opcode() == SpvOpTypeMatrix) {
  703. return IsFloatScalarType(GetComponentType(id));
  704. }
  705. return false;
  706. }
  707. bool ValidationState_t::GetMatrixTypeInfo(uint32_t id, uint32_t* num_rows,
  708. uint32_t* num_cols,
  709. uint32_t* column_type,
  710. uint32_t* component_type) const {
  711. if (!id) return false;
  712. const Instruction* mat_inst = FindDef(id);
  713. assert(mat_inst);
  714. if (mat_inst->opcode() != SpvOpTypeMatrix) return false;
  715. const uint32_t vec_type = mat_inst->word(2);
  716. const Instruction* vec_inst = FindDef(vec_type);
  717. assert(vec_inst);
  718. if (vec_inst->opcode() != SpvOpTypeVector) {
  719. assert(0);
  720. return false;
  721. }
  722. *num_cols = mat_inst->word(3);
  723. *num_rows = vec_inst->word(3);
  724. *column_type = mat_inst->word(2);
  725. *component_type = vec_inst->word(2);
  726. return true;
  727. }
  728. bool ValidationState_t::GetStructMemberTypes(
  729. uint32_t struct_type_id, std::vector<uint32_t>* member_types) const {
  730. member_types->clear();
  731. if (!struct_type_id) return false;
  732. const Instruction* inst = FindDef(struct_type_id);
  733. assert(inst);
  734. if (inst->opcode() != SpvOpTypeStruct) return false;
  735. *member_types =
  736. std::vector<uint32_t>(inst->words().cbegin() + 2, inst->words().cend());
  737. if (member_types->empty()) return false;
  738. return true;
  739. }
  740. bool ValidationState_t::IsPointerType(uint32_t id) const {
  741. const Instruction* inst = FindDef(id);
  742. assert(inst);
  743. return inst->opcode() == SpvOpTypePointer;
  744. }
  745. bool ValidationState_t::GetPointerTypeInfo(uint32_t id, uint32_t* data_type,
  746. uint32_t* storage_class) const {
  747. if (!id) return false;
  748. const Instruction* inst = FindDef(id);
  749. assert(inst);
  750. if (inst->opcode() != SpvOpTypePointer) return false;
  751. *storage_class = inst->word(2);
  752. *data_type = inst->word(3);
  753. return true;
  754. }
  755. bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
  756. const Instruction* inst = FindDef(id);
  757. assert(inst);
  758. return inst->opcode() == SpvOpTypeCooperativeMatrixNV;
  759. }
  760. bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const {
  761. if (!IsCooperativeMatrixType(id)) return false;
  762. return IsFloatScalarType(FindDef(id)->word(2));
  763. }
  764. bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const {
  765. if (!IsCooperativeMatrixType(id)) return false;
  766. return IsIntScalarType(FindDef(id)->word(2));
  767. }
  768. bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
  769. if (!IsCooperativeMatrixType(id)) return false;
  770. return IsUnsignedIntScalarType(FindDef(id)->word(2));
  771. }
  772. spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
  773. const Instruction* inst, uint32_t m1, uint32_t m2) {
  774. const auto m1_type = FindDef(m1);
  775. const auto m2_type = FindDef(m2);
  776. if (m1_type->opcode() != SpvOpTypeCooperativeMatrixNV ||
  777. m2_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
  778. return diag(SPV_ERROR_INVALID_DATA, inst)
  779. << "Expected cooperative matrix types";
  780. }
  781. uint32_t m1_scope_id = m1_type->GetOperandAs<uint32_t>(2);
  782. uint32_t m1_rows_id = m1_type->GetOperandAs<uint32_t>(3);
  783. uint32_t m1_cols_id = m1_type->GetOperandAs<uint32_t>(4);
  784. uint32_t m2_scope_id = m2_type->GetOperandAs<uint32_t>(2);
  785. uint32_t m2_rows_id = m2_type->GetOperandAs<uint32_t>(3);
  786. uint32_t m2_cols_id = m2_type->GetOperandAs<uint32_t>(4);
  787. bool m1_is_int32 = false, m1_is_const_int32 = false, m2_is_int32 = false,
  788. m2_is_const_int32 = false;
  789. uint32_t m1_value = 0, m2_value = 0;
  790. std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
  791. EvalInt32IfConst(m1_scope_id);
  792. std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
  793. EvalInt32IfConst(m2_scope_id);
  794. if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
  795. return diag(SPV_ERROR_INVALID_DATA, inst)
  796. << "Expected scopes of Matrix and Result Type to be "
  797. << "identical";
  798. }
  799. std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
  800. EvalInt32IfConst(m1_rows_id);
  801. std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
  802. EvalInt32IfConst(m2_rows_id);
  803. if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
  804. return diag(SPV_ERROR_INVALID_DATA, inst)
  805. << "Expected rows of Matrix type and Result Type to be "
  806. << "identical";
  807. }
  808. std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
  809. EvalInt32IfConst(m1_cols_id);
  810. std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
  811. EvalInt32IfConst(m2_cols_id);
  812. if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
  813. return diag(SPV_ERROR_INVALID_DATA, inst)
  814. << "Expected columns of Matrix type and Result Type to be "
  815. << "identical";
  816. }
  817. return SPV_SUCCESS;
  818. }
  819. uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst,
  820. size_t operand_index) const {
  821. return GetTypeId(inst->GetOperandAs<uint32_t>(operand_index));
  822. }
  823. bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const {
  824. const Instruction* inst = FindDef(id);
  825. if (!inst) {
  826. assert(0 && "Instruction not found");
  827. return false;
  828. }
  829. if (inst->opcode() != SpvOpConstant && inst->opcode() != SpvOpSpecConstant)
  830. return false;
  831. if (!IsIntScalarType(inst->type_id())) return false;
  832. if (inst->words().size() == 4) {
  833. *val = inst->word(3);
  834. } else {
  835. assert(inst->words().size() == 5);
  836. *val = inst->word(3);
  837. *val |= uint64_t(inst->word(4)) << 32;
  838. }
  839. return true;
  840. }
  841. std::tuple<bool, bool, uint32_t> ValidationState_t::EvalInt32IfConst(
  842. uint32_t id) const {
  843. const Instruction* const inst = FindDef(id);
  844. assert(inst);
  845. const uint32_t type = inst->type_id();
  846. if (type == 0 || !IsIntScalarType(type) || GetBitWidth(type) != 32) {
  847. return std::make_tuple(false, false, 0);
  848. }
  849. // Spec constant values cannot be evaluated so don't consider constant for
  850. // the purpose of this method.
  851. if (!spvOpcodeIsConstant(inst->opcode()) ||
  852. spvOpcodeIsSpecConstant(inst->opcode())) {
  853. return std::make_tuple(true, false, 0);
  854. }
  855. if (inst->opcode() == SpvOpConstantNull) {
  856. return std::make_tuple(true, true, 0);
  857. }
  858. assert(inst->words().size() == 4);
  859. return std::make_tuple(true, true, inst->word(3));
  860. }
  861. void ValidationState_t::ComputeFunctionToEntryPointMapping() {
  862. for (const uint32_t entry_point : entry_points()) {
  863. std::stack<uint32_t> call_stack;
  864. std::set<uint32_t> visited;
  865. call_stack.push(entry_point);
  866. while (!call_stack.empty()) {
  867. const uint32_t called_func_id = call_stack.top();
  868. call_stack.pop();
  869. if (!visited.insert(called_func_id).second) continue;
  870. function_to_entry_points_[called_func_id].push_back(entry_point);
  871. const Function* called_func = function(called_func_id);
  872. if (called_func) {
  873. // Other checks should error out on this invalid SPIR-V.
  874. for (const uint32_t new_call : called_func->function_call_targets()) {
  875. call_stack.push(new_call);
  876. }
  877. }
  878. }
  879. }
  880. }
  881. void ValidationState_t::ComputeRecursiveEntryPoints() {
  882. for (const Function& func : functions()) {
  883. std::stack<uint32_t> call_stack;
  884. std::set<uint32_t> visited;
  885. for (const uint32_t new_call : func.function_call_targets()) {
  886. call_stack.push(new_call);
  887. }
  888. while (!call_stack.empty()) {
  889. const uint32_t called_func_id = call_stack.top();
  890. call_stack.pop();
  891. if (!visited.insert(called_func_id).second) continue;
  892. if (called_func_id == func.id()) {
  893. for (const uint32_t entry_point :
  894. function_to_entry_points_[called_func_id])
  895. recursive_entry_points_.insert(entry_point);
  896. break;
  897. }
  898. const Function* called_func = function(called_func_id);
  899. if (called_func) {
  900. // Other checks should error out on this invalid SPIR-V.
  901. for (const uint32_t new_call : called_func->function_call_targets()) {
  902. call_stack.push(new_call);
  903. }
  904. }
  905. }
  906. }
  907. }
  908. const std::vector<uint32_t>& ValidationState_t::FunctionEntryPoints(
  909. uint32_t func) const {
  910. auto iter = function_to_entry_points_.find(func);
  911. if (iter == function_to_entry_points_.end()) {
  912. return empty_ids_;
  913. } else {
  914. return iter->second;
  915. }
  916. }
  917. std::set<uint32_t> ValidationState_t::EntryPointReferences(uint32_t id) const {
  918. std::set<uint32_t> referenced_entry_points;
  919. const auto inst = FindDef(id);
  920. if (!inst) return referenced_entry_points;
  921. std::vector<const Instruction*> stack;
  922. stack.push_back(inst);
  923. while (!stack.empty()) {
  924. const auto current_inst = stack.back();
  925. stack.pop_back();
  926. if (const auto func = current_inst->function()) {
  927. // Instruction lives in a function, we can stop searching.
  928. const auto function_entry_points = FunctionEntryPoints(func->id());
  929. referenced_entry_points.insert(function_entry_points.begin(),
  930. function_entry_points.end());
  931. } else {
  932. // Instruction is in the global scope, keep searching its uses.
  933. for (auto pair : current_inst->uses()) {
  934. const auto next_inst = pair.first;
  935. stack.push_back(next_inst);
  936. }
  937. }
  938. }
  939. return referenced_entry_points;
  940. }
  941. std::string ValidationState_t::Disassemble(const Instruction& inst) const {
  942. const spv_parsed_instruction_t& c_inst(inst.c_inst());
  943. return Disassemble(c_inst.words, c_inst.num_words);
  944. }
  945. std::string ValidationState_t::Disassemble(const uint32_t* words,
  946. uint16_t num_words) const {
  947. uint32_t disassembly_options = SPV_BINARY_TO_TEXT_OPTION_NO_HEADER |
  948. SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES;
  949. return spvInstructionBinaryToText(context()->target_env, words, num_words,
  950. words_, num_words_, disassembly_options);
  951. }
  952. bool ValidationState_t::LogicallyMatch(const Instruction* lhs,
  953. const Instruction* rhs,
  954. bool check_decorations) {
  955. if (lhs->opcode() != rhs->opcode()) {
  956. return false;
  957. }
  958. if (check_decorations) {
  959. const auto& dec_a = id_decorations(lhs->id());
  960. const auto& dec_b = id_decorations(rhs->id());
  961. for (const auto& dec : dec_b) {
  962. if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
  963. return false;
  964. }
  965. }
  966. }
  967. if (lhs->opcode() == SpvOpTypeArray) {
  968. // Size operands must match.
  969. if (lhs->GetOperandAs<uint32_t>(2u) != rhs->GetOperandAs<uint32_t>(2u)) {
  970. return false;
  971. }
  972. // Elements must match or logically match.
  973. const auto lhs_ele_id = lhs->GetOperandAs<uint32_t>(1u);
  974. const auto rhs_ele_id = rhs->GetOperandAs<uint32_t>(1u);
  975. if (lhs_ele_id == rhs_ele_id) {
  976. return true;
  977. }
  978. const auto lhs_ele = FindDef(lhs_ele_id);
  979. const auto rhs_ele = FindDef(rhs_ele_id);
  980. if (!lhs_ele || !rhs_ele) {
  981. return false;
  982. }
  983. return LogicallyMatch(lhs_ele, rhs_ele, check_decorations);
  984. } else if (lhs->opcode() == SpvOpTypeStruct) {
  985. // Number of elements must match.
  986. if (lhs->operands().size() != rhs->operands().size()) {
  987. return false;
  988. }
  989. for (size_t i = 1u; i < lhs->operands().size(); ++i) {
  990. const auto lhs_ele_id = lhs->GetOperandAs<uint32_t>(i);
  991. const auto rhs_ele_id = rhs->GetOperandAs<uint32_t>(i);
  992. // Elements must match or logically match.
  993. if (lhs_ele_id == rhs_ele_id) {
  994. continue;
  995. }
  996. const auto lhs_ele = FindDef(lhs_ele_id);
  997. const auto rhs_ele = FindDef(rhs_ele_id);
  998. if (!lhs_ele || !rhs_ele) {
  999. return false;
  1000. }
  1001. if (!LogicallyMatch(lhs_ele, rhs_ele, check_decorations)) {
  1002. return false;
  1003. }
  1004. }
  1005. // All checks passed.
  1006. return true;
  1007. }
  1008. // No other opcodes are acceptable at this point. Arrays and structs are
  1009. // caught above and if they're elements are not arrays or structs they are
  1010. // required to match exactly.
  1011. return false;
  1012. }
  1013. const Instruction* ValidationState_t::TracePointer(
  1014. const Instruction* inst) const {
  1015. auto base_ptr = inst;
  1016. while (base_ptr->opcode() == SpvOpAccessChain ||
  1017. base_ptr->opcode() == SpvOpInBoundsAccessChain ||
  1018. base_ptr->opcode() == SpvOpPtrAccessChain ||
  1019. base_ptr->opcode() == SpvOpInBoundsPtrAccessChain ||
  1020. base_ptr->opcode() == SpvOpCopyObject) {
  1021. base_ptr = FindDef(base_ptr->GetOperandAs<uint32_t>(2u));
  1022. }
  1023. return base_ptr;
  1024. }
  1025. bool ValidationState_t::ContainsSizedIntOrFloatType(uint32_t id, SpvOp type,
  1026. uint32_t width) const {
  1027. if (type != SpvOpTypeInt && type != SpvOpTypeFloat) return false;
  1028. const auto inst = FindDef(id);
  1029. if (!inst) return false;
  1030. if (inst->opcode() == type) {
  1031. return inst->GetOperandAs<uint32_t>(1u) == width;
  1032. }
  1033. switch (inst->opcode()) {
  1034. case SpvOpTypeArray:
  1035. case SpvOpTypeRuntimeArray:
  1036. case SpvOpTypeVector:
  1037. case SpvOpTypeMatrix:
  1038. case SpvOpTypeImage:
  1039. case SpvOpTypeSampledImage:
  1040. case SpvOpTypeCooperativeMatrixNV:
  1041. return ContainsSizedIntOrFloatType(inst->GetOperandAs<uint32_t>(1u), type,
  1042. width);
  1043. case SpvOpTypePointer:
  1044. if (IsForwardPointer(id)) return false;
  1045. return ContainsSizedIntOrFloatType(inst->GetOperandAs<uint32_t>(2u), type,
  1046. width);
  1047. case SpvOpTypeFunction:
  1048. case SpvOpTypeStruct: {
  1049. for (uint32_t i = 1; i < inst->operands().size(); ++i) {
  1050. if (ContainsSizedIntOrFloatType(inst->GetOperandAs<uint32_t>(i), type,
  1051. width))
  1052. return true;
  1053. }
  1054. return false;
  1055. }
  1056. default:
  1057. return false;
  1058. }
  1059. }
  1060. bool ValidationState_t::ContainsLimitedUseIntOrFloatType(uint32_t id) const {
  1061. if ((!HasCapability(SpvCapabilityInt16) &&
  1062. ContainsSizedIntOrFloatType(id, SpvOpTypeInt, 16)) ||
  1063. (!HasCapability(SpvCapabilityInt8) &&
  1064. ContainsSizedIntOrFloatType(id, SpvOpTypeInt, 8)) ||
  1065. (!HasCapability(SpvCapabilityFloat16) &&
  1066. ContainsSizedIntOrFloatType(id, SpvOpTypeFloat, 16))) {
  1067. return true;
  1068. }
  1069. return false;
  1070. }
  1071. bool ValidationState_t::IsValidStorageClass(
  1072. SpvStorageClass storage_class) const {
  1073. if (spvIsWebGPUEnv(context()->target_env)) {
  1074. switch (storage_class) {
  1075. case SpvStorageClassUniformConstant:
  1076. case SpvStorageClassUniform:
  1077. case SpvStorageClassStorageBuffer:
  1078. case SpvStorageClassInput:
  1079. case SpvStorageClassOutput:
  1080. case SpvStorageClassImage:
  1081. case SpvStorageClassWorkgroup:
  1082. case SpvStorageClassPrivate:
  1083. case SpvStorageClassFunction:
  1084. return true;
  1085. default:
  1086. return false;
  1087. }
  1088. }
  1089. if (spvIsVulkanEnv(context()->target_env)) {
  1090. switch (storage_class) {
  1091. case SpvStorageClassUniformConstant:
  1092. case SpvStorageClassUniform:
  1093. case SpvStorageClassStorageBuffer:
  1094. case SpvStorageClassInput:
  1095. case SpvStorageClassOutput:
  1096. case SpvStorageClassImage:
  1097. case SpvStorageClassWorkgroup:
  1098. case SpvStorageClassPrivate:
  1099. case SpvStorageClassFunction:
  1100. case SpvStorageClassPushConstant:
  1101. case SpvStorageClassPhysicalStorageBuffer:
  1102. case SpvStorageClassRayPayloadNV:
  1103. case SpvStorageClassIncomingRayPayloadNV:
  1104. case SpvStorageClassHitAttributeNV:
  1105. case SpvStorageClassCallableDataNV:
  1106. case SpvStorageClassIncomingCallableDataNV:
  1107. case SpvStorageClassShaderRecordBufferNV:
  1108. return true;
  1109. default:
  1110. return false;
  1111. }
  1112. }
  1113. return true;
  1114. }
  1115. } // namespace val
  1116. } // namespace spvtools