validate_composites.cpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619
  1. // Copyright (c) 2017 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Validates correctness of composite SPIR-V instructions.
  15. #include "source/val/validate.h"
  16. #include "source/diagnostic.h"
  17. #include "source/opcode.h"
  18. #include "source/spirv_target_env.h"
  19. #include "source/val/instruction.h"
  20. #include "source/val/validation_state.h"
  21. namespace spvtools {
  22. namespace val {
  23. namespace {
  24. // Returns the type of the value accessed by OpCompositeExtract or
  25. // OpCompositeInsert instruction. The function traverses the hierarchy of
  26. // nested data structures (structs, arrays, vectors, matrices) as directed by
  27. // the sequence of indices in the instruction. May return error if traversal
  28. // fails (encountered non-composite, out of bounds, no indices, nesting too
  29. // deep).
  30. spv_result_t GetExtractInsertValueType(ValidationState_t& _,
  31. const Instruction* inst,
  32. uint32_t* member_type) {
  33. const SpvOp opcode = inst->opcode();
  34. assert(opcode == SpvOpCompositeExtract || opcode == SpvOpCompositeInsert);
  35. uint32_t word_index = opcode == SpvOpCompositeExtract ? 4 : 5;
  36. const uint32_t num_words = static_cast<uint32_t>(inst->words().size());
  37. const uint32_t composite_id_index = word_index - 1;
  38. const uint32_t num_indices = num_words - word_index;
  39. const uint32_t kCompositeExtractInsertMaxNumIndices = 255;
  40. if (num_indices == 0) {
  41. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  42. << "Expected at least one index to Op"
  43. << spvOpcodeString(inst->opcode()) << ", zero found";
  44. } else if (num_indices > kCompositeExtractInsertMaxNumIndices) {
  45. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  46. << "The number of indexes in Op" << spvOpcodeString(opcode)
  47. << " may not exceed " << kCompositeExtractInsertMaxNumIndices
  48. << ". Found " << num_indices << " indexes.";
  49. }
  50. *member_type = _.GetTypeId(inst->word(composite_id_index));
  51. if (*member_type == 0) {
  52. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  53. << "Expected Composite to be an object of composite type";
  54. }
  55. for (; word_index < num_words; ++word_index) {
  56. const uint32_t component_index = inst->word(word_index);
  57. const Instruction* const type_inst = _.FindDef(*member_type);
  58. assert(type_inst);
  59. switch (type_inst->opcode()) {
  60. case SpvOpTypeVector: {
  61. *member_type = type_inst->word(2);
  62. const uint32_t vector_size = type_inst->word(3);
  63. if (component_index >= vector_size) {
  64. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  65. << "Vector access is out of bounds, vector size is "
  66. << vector_size << ", but access index is " << component_index;
  67. }
  68. break;
  69. }
  70. case SpvOpTypeMatrix: {
  71. *member_type = type_inst->word(2);
  72. const uint32_t num_cols = type_inst->word(3);
  73. if (component_index >= num_cols) {
  74. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  75. << "Matrix access is out of bounds, matrix has " << num_cols
  76. << " columns, but access index is " << component_index;
  77. }
  78. break;
  79. }
  80. case SpvOpTypeArray: {
  81. uint64_t array_size = 0;
  82. auto size = _.FindDef(type_inst->word(3));
  83. *member_type = type_inst->word(2);
  84. if (spvOpcodeIsSpecConstant(size->opcode())) {
  85. // Cannot verify against the size of this array.
  86. break;
  87. }
  88. if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) {
  89. assert(0 && "Array type definition is corrupt");
  90. }
  91. if (component_index >= array_size) {
  92. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  93. << "Array access is out of bounds, array size is "
  94. << array_size << ", but access index is " << component_index;
  95. }
  96. break;
  97. }
  98. case SpvOpTypeRuntimeArray: {
  99. *member_type = type_inst->word(2);
  100. // Array size is unknown.
  101. break;
  102. }
  103. case SpvOpTypeStruct: {
  104. const size_t num_struct_members = type_inst->words().size() - 2;
  105. if (component_index >= num_struct_members) {
  106. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  107. << "Index is out of bounds, can not find index "
  108. << component_index << " in the structure <id> '"
  109. << type_inst->id() << "'. This structure has "
  110. << num_struct_members << " members. Largest valid index is "
  111. << num_struct_members - 1 << ".";
  112. }
  113. *member_type = type_inst->word(component_index + 2);
  114. break;
  115. }
  116. case SpvOpTypeCooperativeMatrixNV: {
  117. *member_type = type_inst->word(2);
  118. break;
  119. }
  120. default:
  121. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  122. << "Reached non-composite type while indexes still remain to "
  123. "be traversed.";
  124. }
  125. }
  126. return SPV_SUCCESS;
  127. }
  128. spv_result_t ValidateVectorExtractDynamic(ValidationState_t& _,
  129. const Instruction* inst) {
  130. const uint32_t result_type = inst->type_id();
  131. const SpvOp result_opcode = _.GetIdOpcode(result_type);
  132. if (!spvOpcodeIsScalarType(result_opcode)) {
  133. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  134. << "Expected Result Type to be a scalar type";
  135. }
  136. const uint32_t vector_type = _.GetOperandTypeId(inst, 2);
  137. const SpvOp vector_opcode = _.GetIdOpcode(vector_type);
  138. if (vector_opcode != SpvOpTypeVector) {
  139. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  140. << "Expected Vector type to be OpTypeVector";
  141. }
  142. if (_.GetComponentType(vector_type) != result_type) {
  143. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  144. << "Expected Vector component type to be equal to Result Type";
  145. }
  146. const auto index = _.FindDef(inst->GetOperandAs<uint32_t>(3));
  147. if (!index || index->type_id() == 0 || !_.IsIntScalarType(index->type_id())) {
  148. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  149. << "Expected Index to be int scalar";
  150. }
  151. if (_.HasCapability(SpvCapabilityShader) &&
  152. _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
  153. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  154. << "Cannot extract from a vector of 8- or 16-bit types";
  155. }
  156. return SPV_SUCCESS;
  157. }
  158. spv_result_t ValidateVectorInsertDyanmic(ValidationState_t& _,
  159. const Instruction* inst) {
  160. const uint32_t result_type = inst->type_id();
  161. const SpvOp result_opcode = _.GetIdOpcode(result_type);
  162. if (result_opcode != SpvOpTypeVector) {
  163. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  164. << "Expected Result Type to be OpTypeVector";
  165. }
  166. const uint32_t vector_type = _.GetOperandTypeId(inst, 2);
  167. if (vector_type != result_type) {
  168. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  169. << "Expected Vector type to be equal to Result Type";
  170. }
  171. const uint32_t component_type = _.GetOperandTypeId(inst, 3);
  172. if (_.GetComponentType(result_type) != component_type) {
  173. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  174. << "Expected Component type to be equal to Result Type "
  175. << "component type";
  176. }
  177. const uint32_t index_type = _.GetOperandTypeId(inst, 4);
  178. if (!_.IsIntScalarType(index_type)) {
  179. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  180. << "Expected Index to be int scalar";
  181. }
  182. if (_.HasCapability(SpvCapabilityShader) &&
  183. _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
  184. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  185. << "Cannot insert into a vector of 8- or 16-bit types";
  186. }
  187. return SPV_SUCCESS;
  188. }
  189. spv_result_t ValidateCompositeConstruct(ValidationState_t& _,
  190. const Instruction* inst) {
  191. const uint32_t num_operands = static_cast<uint32_t>(inst->operands().size());
  192. const uint32_t result_type = inst->type_id();
  193. const SpvOp result_opcode = _.GetIdOpcode(result_type);
  194. switch (result_opcode) {
  195. case SpvOpTypeVector: {
  196. const uint32_t num_result_components = _.GetDimension(result_type);
  197. const uint32_t result_component_type = _.GetComponentType(result_type);
  198. uint32_t given_component_count = 0;
  199. if (num_operands <= 3) {
  200. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  201. << "Expected number of constituents to be at least 2";
  202. }
  203. for (uint32_t operand_index = 2; operand_index < num_operands;
  204. ++operand_index) {
  205. const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index);
  206. if (operand_type == result_component_type) {
  207. ++given_component_count;
  208. } else {
  209. if (_.GetIdOpcode(operand_type) != SpvOpTypeVector ||
  210. _.GetComponentType(operand_type) != result_component_type) {
  211. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  212. << "Expected Constituents to be scalars or vectors of"
  213. << " the same type as Result Type components";
  214. }
  215. given_component_count += _.GetDimension(operand_type);
  216. }
  217. }
  218. if (num_result_components != given_component_count) {
  219. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  220. << "Expected total number of given components to be equal "
  221. << "to the size of Result Type vector";
  222. }
  223. break;
  224. }
  225. case SpvOpTypeMatrix: {
  226. uint32_t result_num_rows = 0;
  227. uint32_t result_num_cols = 0;
  228. uint32_t result_col_type = 0;
  229. uint32_t result_component_type = 0;
  230. if (!_.GetMatrixTypeInfo(result_type, &result_num_rows, &result_num_cols,
  231. &result_col_type, &result_component_type)) {
  232. assert(0);
  233. }
  234. if (result_num_cols + 2 != num_operands) {
  235. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  236. << "Expected total number of Constituents to be equal "
  237. << "to the number of columns of Result Type matrix";
  238. }
  239. for (uint32_t operand_index = 2; operand_index < num_operands;
  240. ++operand_index) {
  241. const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index);
  242. if (operand_type != result_col_type) {
  243. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  244. << "Expected Constituent type to be equal to the column "
  245. << "type Result Type matrix";
  246. }
  247. }
  248. break;
  249. }
  250. case SpvOpTypeArray: {
  251. const Instruction* const array_inst = _.FindDef(result_type);
  252. assert(array_inst);
  253. assert(array_inst->opcode() == SpvOpTypeArray);
  254. auto size = _.FindDef(array_inst->word(3));
  255. if (spvOpcodeIsSpecConstant(size->opcode())) {
  256. // Cannot verify against the size of this array.
  257. break;
  258. }
  259. uint64_t array_size = 0;
  260. if (!_.GetConstantValUint64(array_inst->word(3), &array_size)) {
  261. assert(0 && "Array type definition is corrupt");
  262. }
  263. if (array_size + 2 != num_operands) {
  264. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  265. << "Expected total number of Constituents to be equal "
  266. << "to the number of elements of Result Type array";
  267. }
  268. const uint32_t result_component_type = array_inst->word(2);
  269. for (uint32_t operand_index = 2; operand_index < num_operands;
  270. ++operand_index) {
  271. const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index);
  272. if (operand_type != result_component_type) {
  273. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  274. << "Expected Constituent type to be equal to the column "
  275. << "type Result Type array";
  276. }
  277. }
  278. break;
  279. }
  280. case SpvOpTypeStruct: {
  281. const Instruction* const struct_inst = _.FindDef(result_type);
  282. assert(struct_inst);
  283. assert(struct_inst->opcode() == SpvOpTypeStruct);
  284. if (struct_inst->operands().size() + 1 != num_operands) {
  285. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  286. << "Expected total number of Constituents to be equal "
  287. << "to the number of members of Result Type struct";
  288. }
  289. for (uint32_t operand_index = 2; operand_index < num_operands;
  290. ++operand_index) {
  291. const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index);
  292. const uint32_t member_type = struct_inst->word(operand_index);
  293. if (operand_type != member_type) {
  294. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  295. << "Expected Constituent type to be equal to the "
  296. << "corresponding member type of Result Type struct";
  297. }
  298. }
  299. break;
  300. }
  301. case SpvOpTypeCooperativeMatrixNV: {
  302. const auto result_type_inst = _.FindDef(result_type);
  303. assert(result_type_inst);
  304. const auto component_type_id =
  305. result_type_inst->GetOperandAs<uint32_t>(1);
  306. if (3 != num_operands) {
  307. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  308. << "Expected single constituent";
  309. }
  310. const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2);
  311. if (operand_type_id != component_type_id) {
  312. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  313. << "Expected Constituent type to be equal to the component type";
  314. }
  315. break;
  316. }
  317. default: {
  318. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  319. << "Expected Result Type to be a composite type";
  320. }
  321. }
  322. if (_.HasCapability(SpvCapabilityShader) &&
  323. _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
  324. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  325. << "Cannot create a composite containing 8- or 16-bit types";
  326. }
  327. return SPV_SUCCESS;
  328. }
  329. spv_result_t ValidateCompositeExtract(ValidationState_t& _,
  330. const Instruction* inst) {
  331. uint32_t member_type = 0;
  332. if (spv_result_t error = GetExtractInsertValueType(_, inst, &member_type)) {
  333. return error;
  334. }
  335. const uint32_t result_type = inst->type_id();
  336. if (result_type != member_type) {
  337. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  338. << "Result type (Op" << spvOpcodeString(_.GetIdOpcode(result_type))
  339. << ") does not match the type that results from indexing into "
  340. "the composite (Op"
  341. << spvOpcodeString(_.GetIdOpcode(member_type)) << ").";
  342. }
  343. if (_.HasCapability(SpvCapabilityShader) &&
  344. _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
  345. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  346. << "Cannot extract from a composite of 8- or 16-bit types";
  347. }
  348. return SPV_SUCCESS;
  349. }
  350. spv_result_t ValidateCompositeInsert(ValidationState_t& _,
  351. const Instruction* inst) {
  352. const uint32_t object_type = _.GetOperandTypeId(inst, 2);
  353. const uint32_t composite_type = _.GetOperandTypeId(inst, 3);
  354. const uint32_t result_type = inst->type_id();
  355. if (result_type != composite_type) {
  356. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  357. << "The Result Type must be the same as Composite type in Op"
  358. << spvOpcodeString(inst->opcode()) << " yielding Result Id "
  359. << result_type << ".";
  360. }
  361. uint32_t member_type = 0;
  362. if (spv_result_t error = GetExtractInsertValueType(_, inst, &member_type)) {
  363. return error;
  364. }
  365. if (object_type != member_type) {
  366. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  367. << "The Object type (Op"
  368. << spvOpcodeString(_.GetIdOpcode(object_type))
  369. << ") does not match the type that results from indexing into the "
  370. "Composite (Op"
  371. << spvOpcodeString(_.GetIdOpcode(member_type)) << ").";
  372. }
  373. if (_.HasCapability(SpvCapabilityShader) &&
  374. _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
  375. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  376. << "Cannot insert into a composite of 8- or 16-bit types";
  377. }
  378. return SPV_SUCCESS;
  379. }
  380. spv_result_t ValidateCopyObject(ValidationState_t& _, const Instruction* inst) {
  381. const uint32_t result_type = inst->type_id();
  382. const uint32_t operand_type = _.GetOperandTypeId(inst, 2);
  383. if (operand_type != result_type) {
  384. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  385. << "Expected Result Type and Operand type to be the same";
  386. }
  387. return SPV_SUCCESS;
  388. }
  389. spv_result_t ValidateTranspose(ValidationState_t& _, const Instruction* inst) {
  390. uint32_t result_num_rows = 0;
  391. uint32_t result_num_cols = 0;
  392. uint32_t result_col_type = 0;
  393. uint32_t result_component_type = 0;
  394. const uint32_t result_type = inst->type_id();
  395. if (!_.GetMatrixTypeInfo(result_type, &result_num_rows, &result_num_cols,
  396. &result_col_type, &result_component_type)) {
  397. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  398. << "Expected Result Type to be a matrix type";
  399. }
  400. const uint32_t matrix_type = _.GetOperandTypeId(inst, 2);
  401. uint32_t matrix_num_rows = 0;
  402. uint32_t matrix_num_cols = 0;
  403. uint32_t matrix_col_type = 0;
  404. uint32_t matrix_component_type = 0;
  405. if (!_.GetMatrixTypeInfo(matrix_type, &matrix_num_rows, &matrix_num_cols,
  406. &matrix_col_type, &matrix_component_type)) {
  407. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  408. << "Expected Matrix to be of type OpTypeMatrix";
  409. }
  410. if (result_component_type != matrix_component_type) {
  411. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  412. << "Expected component types of Matrix and Result Type to be "
  413. << "identical";
  414. }
  415. if (result_num_rows != matrix_num_cols ||
  416. result_num_cols != matrix_num_rows) {
  417. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  418. << "Expected number of columns and the column size of Matrix "
  419. << "to be the reverse of those of Result Type";
  420. }
  421. if (_.HasCapability(SpvCapabilityShader) &&
  422. _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
  423. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  424. << "Cannot transpose matrices of 16-bit floats";
  425. }
  426. return SPV_SUCCESS;
  427. }
  428. spv_result_t ValidateVectorShuffle(ValidationState_t& _,
  429. const Instruction* inst) {
  430. auto resultType = _.FindDef(inst->type_id());
  431. if (!resultType || resultType->opcode() != SpvOpTypeVector) {
  432. return _.diag(SPV_ERROR_INVALID_ID, inst)
  433. << "The Result Type of OpVectorShuffle must be"
  434. << " OpTypeVector. Found Op"
  435. << spvOpcodeString(static_cast<SpvOp>(resultType->opcode())) << ".";
  436. }
  437. // The number of components in Result Type must be the same as the number of
  438. // Component operands.
  439. auto componentCount = inst->operands().size() - 4;
  440. auto resultVectorDimension = resultType->GetOperandAs<uint32_t>(2);
  441. if (componentCount != resultVectorDimension) {
  442. return _.diag(SPV_ERROR_INVALID_ID, inst)
  443. << "OpVectorShuffle component literals count does not match "
  444. "Result Type <id> '"
  445. << _.getIdName(resultType->id()) << "'s vector component count.";
  446. }
  447. // Vector 1 and Vector 2 must both have vector types, with the same Component
  448. // Type as Result Type.
  449. auto vector1Object = _.FindDef(inst->GetOperandAs<uint32_t>(2));
  450. auto vector1Type = _.FindDef(vector1Object->type_id());
  451. auto vector2Object = _.FindDef(inst->GetOperandAs<uint32_t>(3));
  452. auto vector2Type = _.FindDef(vector2Object->type_id());
  453. if (!vector1Type || vector1Type->opcode() != SpvOpTypeVector) {
  454. return _.diag(SPV_ERROR_INVALID_ID, inst)
  455. << "The type of Vector 1 must be OpTypeVector.";
  456. }
  457. if (!vector2Type || vector2Type->opcode() != SpvOpTypeVector) {
  458. return _.diag(SPV_ERROR_INVALID_ID, inst)
  459. << "The type of Vector 2 must be OpTypeVector.";
  460. }
  461. auto resultComponentType = resultType->GetOperandAs<uint32_t>(1);
  462. if (vector1Type->GetOperandAs<uint32_t>(1) != resultComponentType) {
  463. return _.diag(SPV_ERROR_INVALID_ID, inst)
  464. << "The Component Type of Vector 1 must be the same as ResultType.";
  465. }
  466. if (vector2Type->GetOperandAs<uint32_t>(1) != resultComponentType) {
  467. return _.diag(SPV_ERROR_INVALID_ID, inst)
  468. << "The Component Type of Vector 2 must be the same as ResultType.";
  469. }
  470. // All Component literals must either be FFFFFFFF or in [0, N - 1].
  471. // For WebGPU specifically, Component literals cannot be FFFFFFFF.
  472. auto vector1ComponentCount = vector1Type->GetOperandAs<uint32_t>(2);
  473. auto vector2ComponentCount = vector2Type->GetOperandAs<uint32_t>(2);
  474. auto N = vector1ComponentCount + vector2ComponentCount;
  475. auto firstLiteralIndex = 4;
  476. const auto is_webgpu_env = spvIsWebGPUEnv(_.context()->target_env);
  477. for (size_t i = firstLiteralIndex; i < inst->operands().size(); ++i) {
  478. auto literal = inst->GetOperandAs<uint32_t>(i);
  479. if (literal != 0xFFFFFFFF && literal >= N) {
  480. return _.diag(SPV_ERROR_INVALID_ID, inst)
  481. << "Component index " << literal << " is out of bounds for "
  482. << "combined (Vector1 + Vector2) size of " << N << ".";
  483. }
  484. if (is_webgpu_env && literal == 0xFFFFFFFF) {
  485. return _.diag(SPV_ERROR_INVALID_ID, inst)
  486. << "Component literal at operand " << i - firstLiteralIndex
  487. << " cannot be 0xFFFFFFFF in WebGPU execution environment.";
  488. }
  489. }
  490. if (_.HasCapability(SpvCapabilityShader) &&
  491. _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
  492. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  493. << "Cannot shuffle a vector of 8- or 16-bit types";
  494. }
  495. return SPV_SUCCESS;
  496. }
  497. spv_result_t ValidateCopyLogical(ValidationState_t& _,
  498. const Instruction* inst) {
  499. const auto result_type = _.FindDef(inst->type_id());
  500. const auto source = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
  501. const auto source_type = _.FindDef(source->type_id());
  502. if (!source_type || !result_type || source_type == result_type) {
  503. return _.diag(SPV_ERROR_INVALID_ID, inst)
  504. << "Result Type must not equal the Operand type";
  505. }
  506. if (!_.LogicallyMatch(source_type, result_type, false)) {
  507. return _.diag(SPV_ERROR_INVALID_ID, inst)
  508. << "Result Type does not logically match the Operand type";
  509. }
  510. if (_.HasCapability(SpvCapabilityShader) &&
  511. _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
  512. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  513. << "Cannot copy composites of 8- or 16-bit types";
  514. }
  515. return SPV_SUCCESS;
  516. }
  517. } // anonymous namespace
  518. // Validates correctness of composite instructions.
  519. spv_result_t CompositesPass(ValidationState_t& _, const Instruction* inst) {
  520. switch (inst->opcode()) {
  521. case SpvOpVectorExtractDynamic:
  522. return ValidateVectorExtractDynamic(_, inst);
  523. case SpvOpVectorInsertDynamic:
  524. return ValidateVectorInsertDyanmic(_, inst);
  525. case SpvOpVectorShuffle:
  526. return ValidateVectorShuffle(_, inst);
  527. case SpvOpCompositeConstruct:
  528. return ValidateCompositeConstruct(_, inst);
  529. case SpvOpCompositeExtract:
  530. return ValidateCompositeExtract(_, inst);
  531. case SpvOpCompositeInsert:
  532. return ValidateCompositeInsert(_, inst);
  533. case SpvOpCopyObject:
  534. return ValidateCopyObject(_, inst);
  535. case SpvOpTranspose:
  536. return ValidateTranspose(_, inst);
  537. case SpvOpCopyLogical:
  538. return ValidateCopyLogical(_, inst);
  539. default:
  540. break;
  541. }
  542. return SPV_SUCCESS;
  543. }
  544. } // namespace val
  545. } // namespace spvtools