validation_state.cpp 40 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320
  1. // Copyright (c) 2015-2016 The Khronos Group Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/val/validation_state.h"
  15. #include <cassert>
  16. #include <stack>
  17. #include <utility>
  18. #include "source/opcode.h"
  19. #include "source/spirv_constant.h"
  20. #include "source/spirv_target_env.h"
  21. #include "source/val/basic_block.h"
  22. #include "source/val/construct.h"
  23. #include "source/val/function.h"
  24. #include "spirv-tools/libspirv.h"
  25. namespace spvtools {
  26. namespace val {
  27. namespace {
  28. bool IsInstructionInLayoutSection(ModuleLayoutSection layout, SpvOp op) {
  29. // See Section 2.4
  30. bool out = false;
  31. // clang-format off
  32. switch (layout) {
  33. case kLayoutCapabilities: out = op == SpvOpCapability; break;
  34. case kLayoutExtensions: out = op == SpvOpExtension; break;
  35. case kLayoutExtInstImport: out = op == SpvOpExtInstImport; break;
  36. case kLayoutMemoryModel: out = op == SpvOpMemoryModel; break;
  37. case kLayoutEntryPoint: out = op == SpvOpEntryPoint; break;
  38. case kLayoutExecutionMode:
  39. out = op == SpvOpExecutionMode || op == SpvOpExecutionModeId;
  40. break;
  41. case kLayoutDebug1:
  42. switch (op) {
  43. case SpvOpSourceContinued:
  44. case SpvOpSource:
  45. case SpvOpSourceExtension:
  46. case SpvOpString:
  47. out = true;
  48. break;
  49. default: break;
  50. }
  51. break;
  52. case kLayoutDebug2:
  53. switch (op) {
  54. case SpvOpName:
  55. case SpvOpMemberName:
  56. out = true;
  57. break;
  58. default: break;
  59. }
  60. break;
  61. case kLayoutDebug3:
  62. // Only OpModuleProcessed is allowed here.
  63. out = (op == SpvOpModuleProcessed);
  64. break;
  65. case kLayoutAnnotations:
  66. switch (op) {
  67. case SpvOpDecorate:
  68. case SpvOpMemberDecorate:
  69. case SpvOpGroupDecorate:
  70. case SpvOpGroupMemberDecorate:
  71. case SpvOpDecorationGroup:
  72. case SpvOpDecorateId:
  73. case SpvOpDecorateStringGOOGLE:
  74. case SpvOpMemberDecorateStringGOOGLE:
  75. out = true;
  76. break;
  77. default: break;
  78. }
  79. break;
  80. case kLayoutTypes:
  81. if (spvOpcodeGeneratesType(op) || spvOpcodeIsConstant(op)) {
  82. out = true;
  83. break;
  84. }
  85. switch (op) {
  86. case SpvOpTypeForwardPointer:
  87. case SpvOpVariable:
  88. case SpvOpLine:
  89. case SpvOpNoLine:
  90. case SpvOpUndef:
  91. out = true;
  92. break;
  93. default: break;
  94. }
  95. break;
  96. case kLayoutFunctionDeclarations:
  97. case kLayoutFunctionDefinitions:
  98. // NOTE: These instructions should NOT be in these layout sections
  99. if (spvOpcodeGeneratesType(op) || spvOpcodeIsConstant(op)) {
  100. out = false;
  101. break;
  102. }
  103. switch (op) {
  104. case SpvOpCapability:
  105. case SpvOpExtension:
  106. case SpvOpExtInstImport:
  107. case SpvOpMemoryModel:
  108. case SpvOpEntryPoint:
  109. case SpvOpExecutionMode:
  110. case SpvOpExecutionModeId:
  111. case SpvOpSourceContinued:
  112. case SpvOpSource:
  113. case SpvOpSourceExtension:
  114. case SpvOpString:
  115. case SpvOpName:
  116. case SpvOpMemberName:
  117. case SpvOpModuleProcessed:
  118. case SpvOpDecorate:
  119. case SpvOpMemberDecorate:
  120. case SpvOpGroupDecorate:
  121. case SpvOpGroupMemberDecorate:
  122. case SpvOpDecorationGroup:
  123. case SpvOpTypeForwardPointer:
  124. out = false;
  125. break;
  126. default:
  127. out = true;
  128. break;
  129. }
  130. }
  131. // clang-format on
  132. return out;
  133. }
  134. // Counts the number of instructions and functions in the file.
  135. spv_result_t CountInstructions(void* user_data,
  136. const spv_parsed_instruction_t* inst) {
  137. ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
  138. if (inst->opcode == SpvOpFunction) _.increment_total_functions();
  139. _.increment_total_instructions();
  140. return SPV_SUCCESS;
  141. }
  142. spv_result_t setHeader(void* user_data, spv_endianness_t, uint32_t,
  143. uint32_t version, uint32_t generator, uint32_t id_bound,
  144. uint32_t) {
  145. ValidationState_t& vstate =
  146. *(reinterpret_cast<ValidationState_t*>(user_data));
  147. vstate.setIdBound(id_bound);
  148. vstate.setGenerator(generator);
  149. vstate.setVersion(version);
  150. return SPV_SUCCESS;
  151. }
  152. // Add features based on SPIR-V core version number.
  153. void UpdateFeaturesBasedOnSpirvVersion(ValidationState_t::Feature* features,
  154. uint32_t version) {
  155. assert(features);
  156. if (version >= SPV_SPIRV_VERSION_WORD(1, 4)) {
  157. features->select_between_composites = true;
  158. features->copy_memory_permits_two_memory_accesses = true;
  159. features->uconvert_spec_constant_op = true;
  160. features->nonwritable_var_in_function_or_private = true;
  161. }
  162. }
  163. } // namespace
  164. ValidationState_t::ValidationState_t(const spv_const_context ctx,
  165. const spv_const_validator_options opt,
  166. const uint32_t* words,
  167. const size_t num_words,
  168. const uint32_t max_warnings)
  169. : context_(ctx),
  170. options_(opt),
  171. words_(words),
  172. num_words_(num_words),
  173. unresolved_forward_ids_{},
  174. operand_names_{},
  175. current_layout_section_(kLayoutCapabilities),
  176. module_functions_(),
  177. module_capabilities_(),
  178. module_extensions_(),
  179. ordered_instructions_(),
  180. all_definitions_(),
  181. global_vars_(),
  182. local_vars_(),
  183. struct_nesting_depth_(),
  184. struct_has_nested_blockorbufferblock_struct_(),
  185. grammar_(ctx),
  186. addressing_model_(SpvAddressingModelMax),
  187. memory_model_(SpvMemoryModelMax),
  188. pointer_size_and_alignment_(0),
  189. in_function_(false),
  190. num_of_warnings_(0),
  191. max_num_of_warnings_(max_warnings) {
  192. assert(opt && "Validator options may not be Null.");
  193. const auto env = context_->target_env;
  194. if (spvIsVulkanEnv(env)) {
  195. // Vulkan 1.1 includes VK_KHR_relaxed_block_layout in core.
  196. if (env != SPV_ENV_VULKAN_1_0) {
  197. features_.env_relaxed_block_layout = true;
  198. }
  199. }
  200. // Only attempt to count if we have words, otherwise let the other validation
  201. // fail and generate an error.
  202. if (num_words > 0) {
  203. // Count the number of instructions in the binary.
  204. // This parse should not produce any error messages. Hijack the context and
  205. // replace the message consumer so that we do not pollute any state in input
  206. // consumer.
  207. spv_context_t hijacked_context = *ctx;
  208. hijacked_context.consumer = [](spv_message_level_t, const char*,
  209. const spv_position_t&, const char*) {};
  210. spvBinaryParse(&hijacked_context, this, words, num_words, setHeader,
  211. CountInstructions,
  212. /* diagnostic = */ nullptr);
  213. preallocateStorage();
  214. }
  215. UpdateFeaturesBasedOnSpirvVersion(&features_, version_);
  216. friendly_mapper_ = spvtools::MakeUnique<spvtools::FriendlyNameMapper>(
  217. context_, words_, num_words_);
  218. name_mapper_ = friendly_mapper_->GetNameMapper();
  219. }
  220. void ValidationState_t::preallocateStorage() {
  221. ordered_instructions_.reserve(total_instructions_);
  222. module_functions_.reserve(total_functions_);
  223. }
  224. spv_result_t ValidationState_t::ForwardDeclareId(uint32_t id) {
  225. unresolved_forward_ids_.insert(id);
  226. return SPV_SUCCESS;
  227. }
  228. spv_result_t ValidationState_t::RemoveIfForwardDeclared(uint32_t id) {
  229. unresolved_forward_ids_.erase(id);
  230. return SPV_SUCCESS;
  231. }
  232. spv_result_t ValidationState_t::RegisterForwardPointer(uint32_t id) {
  233. forward_pointer_ids_.insert(id);
  234. return SPV_SUCCESS;
  235. }
  236. bool ValidationState_t::IsForwardPointer(uint32_t id) const {
  237. return (forward_pointer_ids_.find(id) != forward_pointer_ids_.end());
  238. }
  239. void ValidationState_t::AssignNameToId(uint32_t id, std::string name) {
  240. operand_names_[id] = name;
  241. }
  242. std::string ValidationState_t::getIdName(uint32_t id) const {
  243. const std::string id_name = name_mapper_(id);
  244. std::stringstream out;
  245. out << id << "[%" << id_name << "]";
  246. return out.str();
  247. }
  248. size_t ValidationState_t::unresolved_forward_id_count() const {
  249. return unresolved_forward_ids_.size();
  250. }
  251. std::vector<uint32_t> ValidationState_t::UnresolvedForwardIds() const {
  252. std::vector<uint32_t> out(std::begin(unresolved_forward_ids_),
  253. std::end(unresolved_forward_ids_));
  254. return out;
  255. }
  256. bool ValidationState_t::IsDefinedId(uint32_t id) const {
  257. return all_definitions_.find(id) != std::end(all_definitions_);
  258. }
  259. const Instruction* ValidationState_t::FindDef(uint32_t id) const {
  260. auto it = all_definitions_.find(id);
  261. if (it == all_definitions_.end()) return nullptr;
  262. return it->second;
  263. }
  264. Instruction* ValidationState_t::FindDef(uint32_t id) {
  265. auto it = all_definitions_.find(id);
  266. if (it == all_definitions_.end()) return nullptr;
  267. return it->second;
  268. }
  269. ModuleLayoutSection ValidationState_t::current_layout_section() const {
  270. return current_layout_section_;
  271. }
  272. void ValidationState_t::ProgressToNextLayoutSectionOrder() {
  273. // Guard against going past the last element(kLayoutFunctionDefinitions)
  274. if (current_layout_section_ <= kLayoutFunctionDefinitions) {
  275. current_layout_section_ =
  276. static_cast<ModuleLayoutSection>(current_layout_section_ + 1);
  277. }
  278. }
  279. bool ValidationState_t::IsOpcodeInCurrentLayoutSection(SpvOp op) {
  280. return IsInstructionInLayoutSection(current_layout_section_, op);
  281. }
  282. DiagnosticStream ValidationState_t::diag(spv_result_t error_code,
  283. const Instruction* inst) {
  284. if (error_code == SPV_WARNING) {
  285. if (num_of_warnings_ == max_num_of_warnings_) {
  286. DiagnosticStream({0, 0, 0}, context_->consumer, "", error_code)
  287. << "Other warnings have been suppressed.\n";
  288. }
  289. if (num_of_warnings_ >= max_num_of_warnings_) {
  290. return DiagnosticStream({0, 0, 0}, nullptr, "", error_code);
  291. }
  292. ++num_of_warnings_;
  293. }
  294. std::string disassembly;
  295. if (inst) disassembly = Disassemble(*inst);
  296. return DiagnosticStream({0, 0, inst ? inst->LineNum() : 0},
  297. context_->consumer, disassembly, error_code);
  298. }
  299. std::vector<Function>& ValidationState_t::functions() {
  300. return module_functions_;
  301. }
  302. Function& ValidationState_t::current_function() {
  303. assert(in_function_body());
  304. return module_functions_.back();
  305. }
  306. const Function& ValidationState_t::current_function() const {
  307. assert(in_function_body());
  308. return module_functions_.back();
  309. }
  310. const Function* ValidationState_t::function(uint32_t id) const {
  311. const auto it = id_to_function_.find(id);
  312. if (it == id_to_function_.end()) return nullptr;
  313. return it->second;
  314. }
  315. Function* ValidationState_t::function(uint32_t id) {
  316. auto it = id_to_function_.find(id);
  317. if (it == id_to_function_.end()) return nullptr;
  318. return it->second;
  319. }
  320. bool ValidationState_t::in_function_body() const { return in_function_; }
  321. bool ValidationState_t::in_block() const {
  322. return module_functions_.empty() == false &&
  323. module_functions_.back().current_block() != nullptr;
  324. }
  325. void ValidationState_t::RegisterCapability(SpvCapability cap) {
  326. // Avoid redundant work. Otherwise the recursion could induce work
  327. // quadrdatic in the capability dependency depth. (Ok, not much, but
  328. // it's something.)
  329. if (module_capabilities_.Contains(cap)) return;
  330. module_capabilities_.Add(cap);
  331. spv_operand_desc desc;
  332. if (SPV_SUCCESS ==
  333. grammar_.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY, cap, &desc)) {
  334. CapabilitySet(desc->numCapabilities, desc->capabilities)
  335. .ForEach([this](SpvCapability c) { RegisterCapability(c); });
  336. }
  337. switch (cap) {
  338. case SpvCapabilityKernel:
  339. features_.group_ops_reduce_and_scans = true;
  340. break;
  341. case SpvCapabilityInt8:
  342. features_.use_int8_type = true;
  343. features_.declare_int8_type = true;
  344. break;
  345. case SpvCapabilityStorageBuffer8BitAccess:
  346. case SpvCapabilityUniformAndStorageBuffer8BitAccess:
  347. case SpvCapabilityStoragePushConstant8:
  348. features_.declare_int8_type = true;
  349. break;
  350. case SpvCapabilityInt16:
  351. features_.declare_int16_type = true;
  352. break;
  353. case SpvCapabilityFloat16:
  354. case SpvCapabilityFloat16Buffer:
  355. features_.declare_float16_type = true;
  356. break;
  357. case SpvCapabilityStorageUniformBufferBlock16:
  358. case SpvCapabilityStorageUniform16:
  359. case SpvCapabilityStoragePushConstant16:
  360. case SpvCapabilityStorageInputOutput16:
  361. features_.declare_int16_type = true;
  362. features_.declare_float16_type = true;
  363. features_.free_fp_rounding_mode = true;
  364. break;
  365. case SpvCapabilityVariablePointers:
  366. features_.variable_pointers = true;
  367. features_.variable_pointers_storage_buffer = true;
  368. break;
  369. case SpvCapabilityVariablePointersStorageBuffer:
  370. features_.variable_pointers_storage_buffer = true;
  371. break;
  372. default:
  373. break;
  374. }
  375. }
  376. void ValidationState_t::RegisterExtension(Extension ext) {
  377. if (module_extensions_.Contains(ext)) return;
  378. module_extensions_.Add(ext);
  379. switch (ext) {
  380. case kSPV_AMD_gpu_shader_half_float:
  381. case kSPV_AMD_gpu_shader_half_float_fetch:
  382. // SPV_AMD_gpu_shader_half_float enables float16 type.
  383. // https://github.com/KhronosGroup/SPIRV-Tools/issues/1375
  384. features_.declare_float16_type = true;
  385. break;
  386. case kSPV_AMD_gpu_shader_int16:
  387. // This is not yet in the extension, but it's recommended for it.
  388. // See https://github.com/KhronosGroup/glslang/issues/848
  389. features_.uconvert_spec_constant_op = true;
  390. break;
  391. case kSPV_AMD_shader_ballot:
  392. // The grammar doesn't encode the fact that SPV_AMD_shader_ballot
  393. // enables the use of group operations Reduce, InclusiveScan,
  394. // and ExclusiveScan. Enable it manually.
  395. // https://github.com/KhronosGroup/SPIRV-Tools/issues/991
  396. features_.group_ops_reduce_and_scans = true;
  397. break;
  398. default:
  399. break;
  400. }
  401. }
  402. bool ValidationState_t::HasAnyOfCapabilities(
  403. const CapabilitySet& capabilities) const {
  404. return module_capabilities_.HasAnyOf(capabilities);
  405. }
  406. bool ValidationState_t::HasAnyOfExtensions(
  407. const ExtensionSet& extensions) const {
  408. return module_extensions_.HasAnyOf(extensions);
  409. }
  410. void ValidationState_t::set_addressing_model(SpvAddressingModel am) {
  411. addressing_model_ = am;
  412. switch (am) {
  413. case SpvAddressingModelPhysical32:
  414. pointer_size_and_alignment_ = 4;
  415. break;
  416. default:
  417. // fall through
  418. case SpvAddressingModelPhysical64:
  419. case SpvAddressingModelPhysicalStorageBuffer64EXT:
  420. pointer_size_and_alignment_ = 8;
  421. break;
  422. }
  423. }
  424. SpvAddressingModel ValidationState_t::addressing_model() const {
  425. return addressing_model_;
  426. }
  427. void ValidationState_t::set_memory_model(SpvMemoryModel mm) {
  428. memory_model_ = mm;
  429. }
  430. SpvMemoryModel ValidationState_t::memory_model() const { return memory_model_; }
  431. spv_result_t ValidationState_t::RegisterFunction(
  432. uint32_t id, uint32_t ret_type_id, SpvFunctionControlMask function_control,
  433. uint32_t function_type_id) {
  434. assert(in_function_body() == false &&
  435. "RegisterFunction can only be called when parsing the binary outside "
  436. "of another function");
  437. in_function_ = true;
  438. module_functions_.emplace_back(id, ret_type_id, function_control,
  439. function_type_id);
  440. id_to_function_.emplace(id, &current_function());
  441. // TODO(umar): validate function type and type_id
  442. return SPV_SUCCESS;
  443. }
  444. spv_result_t ValidationState_t::RegisterFunctionEnd() {
  445. assert(in_function_body() == true &&
  446. "RegisterFunctionEnd can only be called when parsing the binary "
  447. "inside of another function");
  448. assert(in_block() == false &&
  449. "RegisterFunctionParameter can only be called when parsing the binary "
  450. "ouside of a block");
  451. current_function().RegisterFunctionEnd();
  452. in_function_ = false;
  453. return SPV_SUCCESS;
  454. }
  455. Instruction* ValidationState_t::AddOrderedInstruction(
  456. const spv_parsed_instruction_t* inst) {
  457. ordered_instructions_.emplace_back(inst);
  458. ordered_instructions_.back().SetLineNum(ordered_instructions_.size());
  459. return &ordered_instructions_.back();
  460. }
  461. // Improves diagnostic messages by collecting names of IDs
  462. void ValidationState_t::RegisterDebugInstruction(const Instruction* inst) {
  463. switch (inst->opcode()) {
  464. case SpvOpName: {
  465. const auto target = inst->GetOperandAs<uint32_t>(0);
  466. const auto* str = reinterpret_cast<const char*>(inst->words().data() +
  467. inst->operand(1).offset);
  468. AssignNameToId(target, str);
  469. break;
  470. }
  471. case SpvOpMemberName: {
  472. const auto target = inst->GetOperandAs<uint32_t>(0);
  473. const auto* str = reinterpret_cast<const char*>(inst->words().data() +
  474. inst->operand(2).offset);
  475. AssignNameToId(target, str);
  476. break;
  477. }
  478. case SpvOpSourceContinued:
  479. case SpvOpSource:
  480. case SpvOpSourceExtension:
  481. case SpvOpString:
  482. case SpvOpLine:
  483. case SpvOpNoLine:
  484. default:
  485. break;
  486. }
  487. }
  488. void ValidationState_t::RegisterInstruction(Instruction* inst) {
  489. if (inst->id()) all_definitions_.insert(std::make_pair(inst->id(), inst));
  490. // If the instruction is using an OpTypeSampledImage as an operand, it should
  491. // be recorded. The validator will ensure that all usages of an
  492. // OpTypeSampledImage and its definition are in the same basic block.
  493. for (uint16_t i = 0; i < inst->operands().size(); ++i) {
  494. const spv_parsed_operand_t& operand = inst->operand(i);
  495. if (SPV_OPERAND_TYPE_ID == operand.type) {
  496. const uint32_t operand_word = inst->word(operand.offset);
  497. Instruction* operand_inst = FindDef(operand_word);
  498. if (operand_inst && SpvOpSampledImage == operand_inst->opcode()) {
  499. RegisterSampledImageConsumer(operand_word, inst);
  500. }
  501. }
  502. }
  503. }
  504. std::vector<Instruction*> ValidationState_t::getSampledImageConsumers(
  505. uint32_t sampled_image_id) const {
  506. std::vector<Instruction*> result;
  507. auto iter = sampled_image_consumers_.find(sampled_image_id);
  508. if (iter != sampled_image_consumers_.end()) {
  509. result = iter->second;
  510. }
  511. return result;
  512. }
  513. void ValidationState_t::RegisterSampledImageConsumer(uint32_t sampled_image_id,
  514. Instruction* consumer) {
  515. sampled_image_consumers_[sampled_image_id].push_back(consumer);
  516. }
  517. uint32_t ValidationState_t::getIdBound() const { return id_bound_; }
  518. void ValidationState_t::setIdBound(const uint32_t bound) { id_bound_ = bound; }
  519. bool ValidationState_t::RegisterUniqueTypeDeclaration(const Instruction* inst) {
  520. std::vector<uint32_t> key;
  521. key.push_back(static_cast<uint32_t>(inst->opcode()));
  522. for (size_t index = 0; index < inst->operands().size(); ++index) {
  523. const spv_parsed_operand_t& operand = inst->operand(index);
  524. if (operand.type == SPV_OPERAND_TYPE_RESULT_ID) continue;
  525. const int words_begin = operand.offset;
  526. const int words_end = words_begin + operand.num_words;
  527. assert(words_end <= static_cast<int>(inst->words().size()));
  528. key.insert(key.end(), inst->words().begin() + words_begin,
  529. inst->words().begin() + words_end);
  530. }
  531. return unique_type_declarations_.insert(std::move(key)).second;
  532. }
  533. uint32_t ValidationState_t::GetTypeId(uint32_t id) const {
  534. const Instruction* inst = FindDef(id);
  535. return inst ? inst->type_id() : 0;
  536. }
  537. SpvOp ValidationState_t::GetIdOpcode(uint32_t id) const {
  538. const Instruction* inst = FindDef(id);
  539. return inst ? inst->opcode() : SpvOpNop;
  540. }
  541. uint32_t ValidationState_t::GetComponentType(uint32_t id) const {
  542. const Instruction* inst = FindDef(id);
  543. assert(inst);
  544. switch (inst->opcode()) {
  545. case SpvOpTypeFloat:
  546. case SpvOpTypeInt:
  547. case SpvOpTypeBool:
  548. return id;
  549. case SpvOpTypeVector:
  550. return inst->word(2);
  551. case SpvOpTypeMatrix:
  552. return GetComponentType(inst->word(2));
  553. case SpvOpTypeCooperativeMatrixNV:
  554. return inst->word(2);
  555. default:
  556. break;
  557. }
  558. if (inst->type_id()) return GetComponentType(inst->type_id());
  559. assert(0);
  560. return 0;
  561. }
  562. uint32_t ValidationState_t::GetDimension(uint32_t id) const {
  563. const Instruction* inst = FindDef(id);
  564. assert(inst);
  565. switch (inst->opcode()) {
  566. case SpvOpTypeFloat:
  567. case SpvOpTypeInt:
  568. case SpvOpTypeBool:
  569. return 1;
  570. case SpvOpTypeVector:
  571. case SpvOpTypeMatrix:
  572. return inst->word(3);
  573. case SpvOpTypeCooperativeMatrixNV:
  574. // Actual dimension isn't known, return 0
  575. return 0;
  576. default:
  577. break;
  578. }
  579. if (inst->type_id()) return GetDimension(inst->type_id());
  580. assert(0);
  581. return 0;
  582. }
  583. uint32_t ValidationState_t::GetBitWidth(uint32_t id) const {
  584. const uint32_t component_type_id = GetComponentType(id);
  585. const Instruction* inst = FindDef(component_type_id);
  586. assert(inst);
  587. if (inst->opcode() == SpvOpTypeFloat || inst->opcode() == SpvOpTypeInt)
  588. return inst->word(2);
  589. if (inst->opcode() == SpvOpTypeBool) return 1;
  590. assert(0);
  591. return 0;
  592. }
  593. bool ValidationState_t::IsVoidType(uint32_t id) const {
  594. const Instruction* inst = FindDef(id);
  595. assert(inst);
  596. return inst->opcode() == SpvOpTypeVoid;
  597. }
  598. bool ValidationState_t::IsFloatScalarType(uint32_t id) const {
  599. const Instruction* inst = FindDef(id);
  600. assert(inst);
  601. return inst->opcode() == SpvOpTypeFloat;
  602. }
  603. bool ValidationState_t::IsFloatVectorType(uint32_t id) const {
  604. const Instruction* inst = FindDef(id);
  605. assert(inst);
  606. if (inst->opcode() == SpvOpTypeVector) {
  607. return IsFloatScalarType(GetComponentType(id));
  608. }
  609. return false;
  610. }
  611. bool ValidationState_t::IsFloatScalarOrVectorType(uint32_t id) const {
  612. const Instruction* inst = FindDef(id);
  613. assert(inst);
  614. if (inst->opcode() == SpvOpTypeFloat) {
  615. return true;
  616. }
  617. if (inst->opcode() == SpvOpTypeVector) {
  618. return IsFloatScalarType(GetComponentType(id));
  619. }
  620. return false;
  621. }
  622. bool ValidationState_t::IsIntScalarType(uint32_t id) const {
  623. const Instruction* inst = FindDef(id);
  624. assert(inst);
  625. return inst->opcode() == SpvOpTypeInt;
  626. }
  627. bool ValidationState_t::IsIntVectorType(uint32_t id) const {
  628. const Instruction* inst = FindDef(id);
  629. assert(inst);
  630. if (inst->opcode() == SpvOpTypeVector) {
  631. return IsIntScalarType(GetComponentType(id));
  632. }
  633. return false;
  634. }
  635. bool ValidationState_t::IsIntScalarOrVectorType(uint32_t id) const {
  636. const Instruction* inst = FindDef(id);
  637. assert(inst);
  638. if (inst->opcode() == SpvOpTypeInt) {
  639. return true;
  640. }
  641. if (inst->opcode() == SpvOpTypeVector) {
  642. return IsIntScalarType(GetComponentType(id));
  643. }
  644. return false;
  645. }
  646. bool ValidationState_t::IsUnsignedIntScalarType(uint32_t id) const {
  647. const Instruction* inst = FindDef(id);
  648. assert(inst);
  649. return inst->opcode() == SpvOpTypeInt && inst->word(3) == 0;
  650. }
  651. bool ValidationState_t::IsUnsignedIntVectorType(uint32_t id) const {
  652. const Instruction* inst = FindDef(id);
  653. assert(inst);
  654. if (inst->opcode() == SpvOpTypeVector) {
  655. return IsUnsignedIntScalarType(GetComponentType(id));
  656. }
  657. return false;
  658. }
  659. bool ValidationState_t::IsSignedIntScalarType(uint32_t id) const {
  660. const Instruction* inst = FindDef(id);
  661. assert(inst);
  662. return inst->opcode() == SpvOpTypeInt && inst->word(3) == 1;
  663. }
  664. bool ValidationState_t::IsSignedIntVectorType(uint32_t id) const {
  665. const Instruction* inst = FindDef(id);
  666. assert(inst);
  667. if (inst->opcode() == SpvOpTypeVector) {
  668. return IsSignedIntScalarType(GetComponentType(id));
  669. }
  670. return false;
  671. }
  672. bool ValidationState_t::IsBoolScalarType(uint32_t id) const {
  673. const Instruction* inst = FindDef(id);
  674. assert(inst);
  675. return inst->opcode() == SpvOpTypeBool;
  676. }
  677. bool ValidationState_t::IsBoolVectorType(uint32_t id) const {
  678. const Instruction* inst = FindDef(id);
  679. assert(inst);
  680. if (inst->opcode() == SpvOpTypeVector) {
  681. return IsBoolScalarType(GetComponentType(id));
  682. }
  683. return false;
  684. }
  685. bool ValidationState_t::IsBoolScalarOrVectorType(uint32_t id) const {
  686. const Instruction* inst = FindDef(id);
  687. assert(inst);
  688. if (inst->opcode() == SpvOpTypeBool) {
  689. return true;
  690. }
  691. if (inst->opcode() == SpvOpTypeVector) {
  692. return IsBoolScalarType(GetComponentType(id));
  693. }
  694. return false;
  695. }
  696. bool ValidationState_t::IsFloatMatrixType(uint32_t id) const {
  697. const Instruction* inst = FindDef(id);
  698. assert(inst);
  699. if (inst->opcode() == SpvOpTypeMatrix) {
  700. return IsFloatScalarType(GetComponentType(id));
  701. }
  702. return false;
  703. }
  704. bool ValidationState_t::GetMatrixTypeInfo(uint32_t id, uint32_t* num_rows,
  705. uint32_t* num_cols,
  706. uint32_t* column_type,
  707. uint32_t* component_type) const {
  708. if (!id) return false;
  709. const Instruction* mat_inst = FindDef(id);
  710. assert(mat_inst);
  711. if (mat_inst->opcode() != SpvOpTypeMatrix) return false;
  712. const uint32_t vec_type = mat_inst->word(2);
  713. const Instruction* vec_inst = FindDef(vec_type);
  714. assert(vec_inst);
  715. if (vec_inst->opcode() != SpvOpTypeVector) {
  716. assert(0);
  717. return false;
  718. }
  719. *num_cols = mat_inst->word(3);
  720. *num_rows = vec_inst->word(3);
  721. *column_type = mat_inst->word(2);
  722. *component_type = vec_inst->word(2);
  723. return true;
  724. }
  725. bool ValidationState_t::GetStructMemberTypes(
  726. uint32_t struct_type_id, std::vector<uint32_t>* member_types) const {
  727. member_types->clear();
  728. if (!struct_type_id) return false;
  729. const Instruction* inst = FindDef(struct_type_id);
  730. assert(inst);
  731. if (inst->opcode() != SpvOpTypeStruct) return false;
  732. *member_types =
  733. std::vector<uint32_t>(inst->words().cbegin() + 2, inst->words().cend());
  734. if (member_types->empty()) return false;
  735. return true;
  736. }
  737. bool ValidationState_t::IsPointerType(uint32_t id) const {
  738. const Instruction* inst = FindDef(id);
  739. assert(inst);
  740. return inst->opcode() == SpvOpTypePointer;
  741. }
  742. bool ValidationState_t::GetPointerTypeInfo(uint32_t id, uint32_t* data_type,
  743. uint32_t* storage_class) const {
  744. if (!id) return false;
  745. const Instruction* inst = FindDef(id);
  746. assert(inst);
  747. if (inst->opcode() != SpvOpTypePointer) return false;
  748. *storage_class = inst->word(2);
  749. *data_type = inst->word(3);
  750. return true;
  751. }
  752. bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
  753. const Instruction* inst = FindDef(id);
  754. assert(inst);
  755. return inst->opcode() == SpvOpTypeCooperativeMatrixNV;
  756. }
  757. bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const {
  758. if (!IsCooperativeMatrixType(id)) return false;
  759. return IsFloatScalarType(FindDef(id)->word(2));
  760. }
  761. bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const {
  762. if (!IsCooperativeMatrixType(id)) return false;
  763. return IsIntScalarType(FindDef(id)->word(2));
  764. }
  765. bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
  766. if (!IsCooperativeMatrixType(id)) return false;
  767. return IsUnsignedIntScalarType(FindDef(id)->word(2));
  768. }
  769. spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
  770. const Instruction* inst, uint32_t m1, uint32_t m2) {
  771. const auto m1_type = FindDef(m1);
  772. const auto m2_type = FindDef(m2);
  773. if (m1_type->opcode() != SpvOpTypeCooperativeMatrixNV ||
  774. m2_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
  775. return diag(SPV_ERROR_INVALID_DATA, inst)
  776. << "Expected cooperative matrix types";
  777. }
  778. uint32_t m1_scope_id = m1_type->GetOperandAs<uint32_t>(2);
  779. uint32_t m1_rows_id = m1_type->GetOperandAs<uint32_t>(3);
  780. uint32_t m1_cols_id = m1_type->GetOperandAs<uint32_t>(4);
  781. uint32_t m2_scope_id = m2_type->GetOperandAs<uint32_t>(2);
  782. uint32_t m2_rows_id = m2_type->GetOperandAs<uint32_t>(3);
  783. uint32_t m2_cols_id = m2_type->GetOperandAs<uint32_t>(4);
  784. bool m1_is_int32 = false, m1_is_const_int32 = false, m2_is_int32 = false,
  785. m2_is_const_int32 = false;
  786. uint32_t m1_value = 0, m2_value = 0;
  787. std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
  788. EvalInt32IfConst(m1_scope_id);
  789. std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
  790. EvalInt32IfConst(m2_scope_id);
  791. if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
  792. return diag(SPV_ERROR_INVALID_DATA, inst)
  793. << "Expected scopes of Matrix and Result Type to be "
  794. << "identical";
  795. }
  796. std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
  797. EvalInt32IfConst(m1_rows_id);
  798. std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
  799. EvalInt32IfConst(m2_rows_id);
  800. if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
  801. return diag(SPV_ERROR_INVALID_DATA, inst)
  802. << "Expected rows of Matrix type and Result Type to be "
  803. << "identical";
  804. }
  805. std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
  806. EvalInt32IfConst(m1_cols_id);
  807. std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
  808. EvalInt32IfConst(m2_cols_id);
  809. if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
  810. return diag(SPV_ERROR_INVALID_DATA, inst)
  811. << "Expected columns of Matrix type and Result Type to be "
  812. << "identical";
  813. }
  814. return SPV_SUCCESS;
  815. }
  816. uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst,
  817. size_t operand_index) const {
  818. return GetTypeId(inst->GetOperandAs<uint32_t>(operand_index));
  819. }
  820. bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const {
  821. const Instruction* inst = FindDef(id);
  822. if (!inst) {
  823. assert(0 && "Instruction not found");
  824. return false;
  825. }
  826. if (inst->opcode() != SpvOpConstant && inst->opcode() != SpvOpSpecConstant)
  827. return false;
  828. if (!IsIntScalarType(inst->type_id())) return false;
  829. if (inst->words().size() == 4) {
  830. *val = inst->word(3);
  831. } else {
  832. assert(inst->words().size() == 5);
  833. *val = inst->word(3);
  834. *val |= uint64_t(inst->word(4)) << 32;
  835. }
  836. return true;
  837. }
  838. std::tuple<bool, bool, uint32_t> ValidationState_t::EvalInt32IfConst(
  839. uint32_t id) const {
  840. const Instruction* const inst = FindDef(id);
  841. assert(inst);
  842. const uint32_t type = inst->type_id();
  843. if (type == 0 || !IsIntScalarType(type) || GetBitWidth(type) != 32) {
  844. return std::make_tuple(false, false, 0);
  845. }
  846. // Spec constant values cannot be evaluated so don't consider constant for
  847. // the purpose of this method.
  848. if (!spvOpcodeIsConstant(inst->opcode()) ||
  849. spvOpcodeIsSpecConstant(inst->opcode())) {
  850. return std::make_tuple(true, false, 0);
  851. }
  852. if (inst->opcode() == SpvOpConstantNull) {
  853. return std::make_tuple(true, true, 0);
  854. }
  855. assert(inst->words().size() == 4);
  856. return std::make_tuple(true, true, inst->word(3));
  857. }
  858. void ValidationState_t::ComputeFunctionToEntryPointMapping() {
  859. for (const uint32_t entry_point : entry_points()) {
  860. std::stack<uint32_t> call_stack;
  861. std::set<uint32_t> visited;
  862. call_stack.push(entry_point);
  863. while (!call_stack.empty()) {
  864. const uint32_t called_func_id = call_stack.top();
  865. call_stack.pop();
  866. if (!visited.insert(called_func_id).second) continue;
  867. function_to_entry_points_[called_func_id].push_back(entry_point);
  868. const Function* called_func = function(called_func_id);
  869. if (called_func) {
  870. // Other checks should error out on this invalid SPIR-V.
  871. for (const uint32_t new_call : called_func->function_call_targets()) {
  872. call_stack.push(new_call);
  873. }
  874. }
  875. }
  876. }
  877. }
  878. void ValidationState_t::ComputeRecursiveEntryPoints() {
  879. for (const Function func : functions()) {
  880. std::stack<uint32_t> call_stack;
  881. std::set<uint32_t> visited;
  882. for (const uint32_t new_call : func.function_call_targets()) {
  883. call_stack.push(new_call);
  884. }
  885. while (!call_stack.empty()) {
  886. const uint32_t called_func_id = call_stack.top();
  887. call_stack.pop();
  888. if (!visited.insert(called_func_id).second) continue;
  889. if (called_func_id == func.id()) {
  890. for (const uint32_t entry_point :
  891. function_to_entry_points_[called_func_id])
  892. recursive_entry_points_.insert(entry_point);
  893. break;
  894. }
  895. const Function* called_func = function(called_func_id);
  896. if (called_func) {
  897. // Other checks should error out on this invalid SPIR-V.
  898. for (const uint32_t new_call : called_func->function_call_targets()) {
  899. call_stack.push(new_call);
  900. }
  901. }
  902. }
  903. }
  904. }
  905. const std::vector<uint32_t>& ValidationState_t::FunctionEntryPoints(
  906. uint32_t func) const {
  907. auto iter = function_to_entry_points_.find(func);
  908. if (iter == function_to_entry_points_.end()) {
  909. return empty_ids_;
  910. } else {
  911. return iter->second;
  912. }
  913. }
  914. std::set<uint32_t> ValidationState_t::EntryPointReferences(uint32_t id) const {
  915. std::set<uint32_t> referenced_entry_points;
  916. const auto inst = FindDef(id);
  917. if (!inst) return referenced_entry_points;
  918. std::vector<const Instruction*> stack;
  919. stack.push_back(inst);
  920. while (!stack.empty()) {
  921. const auto current_inst = stack.back();
  922. stack.pop_back();
  923. if (const auto func = current_inst->function()) {
  924. // Instruction lives in a function, we can stop searching.
  925. const auto function_entry_points = FunctionEntryPoints(func->id());
  926. referenced_entry_points.insert(function_entry_points.begin(),
  927. function_entry_points.end());
  928. } else {
  929. // Instruction is in the global scope, keep searching its uses.
  930. for (auto pair : current_inst->uses()) {
  931. const auto next_inst = pair.first;
  932. stack.push_back(next_inst);
  933. }
  934. }
  935. }
  936. return referenced_entry_points;
  937. }
  938. std::string ValidationState_t::Disassemble(const Instruction& inst) const {
  939. const spv_parsed_instruction_t& c_inst(inst.c_inst());
  940. return Disassemble(c_inst.words, c_inst.num_words);
  941. }
  942. std::string ValidationState_t::Disassemble(const uint32_t* words,
  943. uint16_t num_words) const {
  944. uint32_t disassembly_options = SPV_BINARY_TO_TEXT_OPTION_NO_HEADER |
  945. SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES;
  946. return spvInstructionBinaryToText(context()->target_env, words, num_words,
  947. words_, num_words_, disassembly_options);
  948. }
  949. bool ValidationState_t::LogicallyMatch(const Instruction* lhs,
  950. const Instruction* rhs,
  951. bool check_decorations) {
  952. if (lhs->opcode() != rhs->opcode()) {
  953. return false;
  954. }
  955. if (check_decorations) {
  956. const auto& dec_a = id_decorations(lhs->id());
  957. const auto& dec_b = id_decorations(rhs->id());
  958. for (const auto& dec : dec_b) {
  959. if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
  960. return false;
  961. }
  962. }
  963. }
  964. if (lhs->opcode() == SpvOpTypeArray) {
  965. // Size operands must match.
  966. if (lhs->GetOperandAs<uint32_t>(2u) != rhs->GetOperandAs<uint32_t>(2u)) {
  967. return false;
  968. }
  969. // Elements must match or logically match.
  970. const auto lhs_ele_id = lhs->GetOperandAs<uint32_t>(1u);
  971. const auto rhs_ele_id = rhs->GetOperandAs<uint32_t>(1u);
  972. if (lhs_ele_id == rhs_ele_id) {
  973. return true;
  974. }
  975. const auto lhs_ele = FindDef(lhs_ele_id);
  976. const auto rhs_ele = FindDef(rhs_ele_id);
  977. if (!lhs_ele || !rhs_ele) {
  978. return false;
  979. }
  980. return LogicallyMatch(lhs_ele, rhs_ele, check_decorations);
  981. } else if (lhs->opcode() == SpvOpTypeStruct) {
  982. // Number of elements must match.
  983. if (lhs->operands().size() != rhs->operands().size()) {
  984. return false;
  985. }
  986. for (size_t i = 1u; i < lhs->operands().size(); ++i) {
  987. const auto lhs_ele_id = lhs->GetOperandAs<uint32_t>(i);
  988. const auto rhs_ele_id = rhs->GetOperandAs<uint32_t>(i);
  989. // Elements must match or logically match.
  990. if (lhs_ele_id == rhs_ele_id) {
  991. continue;
  992. }
  993. const auto lhs_ele = FindDef(lhs_ele_id);
  994. const auto rhs_ele = FindDef(rhs_ele_id);
  995. if (!lhs_ele || !rhs_ele) {
  996. return false;
  997. }
  998. if (!LogicallyMatch(lhs_ele, rhs_ele, check_decorations)) {
  999. return false;
  1000. }
  1001. }
  1002. // All checks passed.
  1003. return true;
  1004. }
  1005. // No other opcodes are acceptable at this point. Arrays and structs are
  1006. // caught above and if they're elements are not arrays or structs they are
  1007. // required to match exactly.
  1008. return false;
  1009. }
  1010. const Instruction* ValidationState_t::TracePointer(
  1011. const Instruction* inst) const {
  1012. auto base_ptr = inst;
  1013. while (base_ptr->opcode() == SpvOpAccessChain ||
  1014. base_ptr->opcode() == SpvOpInBoundsAccessChain ||
  1015. base_ptr->opcode() == SpvOpPtrAccessChain ||
  1016. base_ptr->opcode() == SpvOpInBoundsPtrAccessChain ||
  1017. base_ptr->opcode() == SpvOpCopyObject) {
  1018. base_ptr = FindDef(base_ptr->GetOperandAs<uint32_t>(2u));
  1019. }
  1020. return base_ptr;
  1021. }
  1022. bool ValidationState_t::ContainsSizedIntOrFloatType(uint32_t id, SpvOp type,
  1023. uint32_t width) const {
  1024. if (type != SpvOpTypeInt && type != SpvOpTypeFloat) return false;
  1025. const auto inst = FindDef(id);
  1026. if (!inst) return false;
  1027. if (inst->opcode() == type) {
  1028. return inst->GetOperandAs<uint32_t>(1u) == width;
  1029. }
  1030. switch (inst->opcode()) {
  1031. case SpvOpTypeArray:
  1032. case SpvOpTypeRuntimeArray:
  1033. case SpvOpTypeVector:
  1034. case SpvOpTypeMatrix:
  1035. case SpvOpTypeImage:
  1036. case SpvOpTypeSampledImage:
  1037. case SpvOpTypeCooperativeMatrixNV:
  1038. return ContainsSizedIntOrFloatType(inst->GetOperandAs<uint32_t>(1u), type,
  1039. width);
  1040. case SpvOpTypePointer:
  1041. if (IsForwardPointer(id)) return false;
  1042. return ContainsSizedIntOrFloatType(inst->GetOperandAs<uint32_t>(2u), type,
  1043. width);
  1044. case SpvOpTypeFunction:
  1045. case SpvOpTypeStruct: {
  1046. for (uint32_t i = 1; i < inst->operands().size(); ++i) {
  1047. if (ContainsSizedIntOrFloatType(inst->GetOperandAs<uint32_t>(i), type,
  1048. width))
  1049. return true;
  1050. }
  1051. return false;
  1052. }
  1053. default:
  1054. return false;
  1055. }
  1056. }
  1057. bool ValidationState_t::ContainsLimitedUseIntOrFloatType(uint32_t id) const {
  1058. if ((!HasCapability(SpvCapabilityInt16) &&
  1059. ContainsSizedIntOrFloatType(id, SpvOpTypeInt, 16)) ||
  1060. (!HasCapability(SpvCapabilityInt8) &&
  1061. ContainsSizedIntOrFloatType(id, SpvOpTypeInt, 8)) ||
  1062. (!HasCapability(SpvCapabilityFloat16) &&
  1063. ContainsSizedIntOrFloatType(id, SpvOpTypeFloat, 16))) {
  1064. return true;
  1065. }
  1066. return false;
  1067. }
  1068. bool ValidationState_t::IsValidStorageClass(
  1069. SpvStorageClass storage_class) const {
  1070. if (spvIsWebGPUEnv(context()->target_env)) {
  1071. switch (storage_class) {
  1072. case SpvStorageClassUniformConstant:
  1073. case SpvStorageClassUniform:
  1074. case SpvStorageClassStorageBuffer:
  1075. case SpvStorageClassInput:
  1076. case SpvStorageClassOutput:
  1077. case SpvStorageClassImage:
  1078. case SpvStorageClassWorkgroup:
  1079. case SpvStorageClassPrivate:
  1080. case SpvStorageClassFunction:
  1081. return true;
  1082. default:
  1083. return false;
  1084. }
  1085. }
  1086. if (spvIsVulkanEnv(context()->target_env)) {
  1087. switch (storage_class) {
  1088. case SpvStorageClassUniformConstant:
  1089. case SpvStorageClassUniform:
  1090. case SpvStorageClassStorageBuffer:
  1091. case SpvStorageClassInput:
  1092. case SpvStorageClassOutput:
  1093. case SpvStorageClassImage:
  1094. case SpvStorageClassWorkgroup:
  1095. case SpvStorageClassPrivate:
  1096. case SpvStorageClassFunction:
  1097. case SpvStorageClassPushConstant:
  1098. case SpvStorageClassPhysicalStorageBuffer:
  1099. case SpvStorageClassRayPayloadNV:
  1100. case SpvStorageClassIncomingRayPayloadNV:
  1101. case SpvStorageClassHitAttributeNV:
  1102. case SpvStorageClassCallableDataNV:
  1103. case SpvStorageClassIncomingCallableDataNV:
  1104. case SpvStorageClassShaderRecordBufferNV:
  1105. return true;
  1106. default:
  1107. return false;
  1108. }
  1109. }
  1110. return true;
  1111. }
  1112. } // namespace val
  1113. } // namespace spvtools