validate_non_uniform.cpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Validates correctness of non-uniform group SPIR-V instructions.
  15. #include "source/opcode.h"
  16. #include "source/spirv_constant.h"
  17. #include "source/spirv_target_env.h"
  18. #include "source/val/instruction.h"
  19. #include "source/val/validate.h"
  20. #include "source/val/validate_scopes.h"
  21. #include "source/val/validation_state.h"
  22. namespace spvtools {
  23. namespace val {
  24. namespace {
  25. spv_result_t ValidateGroupNonUniformElect(ValidationState_t& _,
  26. const Instruction* inst) {
  27. if (!_.IsBoolScalarType(inst->type_id())) {
  28. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  29. << "Result must be a boolean scalar type";
  30. }
  31. return SPV_SUCCESS;
  32. }
  33. spv_result_t ValidateGroupNonUniformAnyAll(ValidationState_t& _,
  34. const Instruction* inst) {
  35. if (!_.IsBoolScalarType(inst->type_id())) {
  36. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  37. << "Result must be a boolean scalar type";
  38. }
  39. if (!_.IsBoolScalarType(_.GetOperandTypeId(inst, 3))) {
  40. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  41. << "Predicate must be a boolean scalar type";
  42. }
  43. return SPV_SUCCESS;
  44. }
  45. spv_result_t ValidateGroupNonUniformAllEqual(ValidationState_t& _,
  46. const Instruction* inst) {
  47. if (!_.IsBoolScalarType(inst->type_id())) {
  48. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  49. << "Result must be a boolean scalar type";
  50. }
  51. const auto value_type = _.GetOperandTypeId(inst, 3);
  52. if (!_.IsFloatScalarOrVectorType(value_type) &&
  53. !_.IsIntScalarOrVectorType(value_type) &&
  54. !_.IsBoolScalarOrVectorType(value_type)) {
  55. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  56. << "Value must be a scalar or vector of integer, floating-point, or "
  57. "boolean type";
  58. }
  59. return SPV_SUCCESS;
  60. }
  61. spv_result_t ValidateGroupNonUniformBroadcastShuffle(ValidationState_t& _,
  62. const Instruction* inst) {
  63. const auto type_id = inst->type_id();
  64. if (!_.IsFloatScalarOrVectorType(type_id) &&
  65. !_.IsIntScalarOrVectorType(type_id) &&
  66. !_.IsBoolScalarOrVectorType(type_id)) {
  67. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  68. << "Result must be a scalar or vector of integer, floating-point, "
  69. "or boolean type";
  70. }
  71. const auto value_type_id = _.GetOperandTypeId(inst, 3);
  72. if (value_type_id != type_id) {
  73. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  74. << "The type of Value must match the Result type";
  75. }
  76. const auto GetOperandName = [](const spv::Op opcode) {
  77. std::string operand;
  78. switch (opcode) {
  79. case spv::Op::OpGroupNonUniformBroadcast:
  80. case spv::Op::OpGroupNonUniformShuffle:
  81. operand = "Id";
  82. break;
  83. case spv::Op::OpGroupNonUniformShuffleXor:
  84. operand = "Mask";
  85. break;
  86. case spv::Op::OpGroupNonUniformQuadBroadcast:
  87. operand = "Index";
  88. break;
  89. case spv::Op::OpGroupNonUniformQuadSwap:
  90. operand = "Direction";
  91. break;
  92. case spv::Op::OpGroupNonUniformShuffleUp:
  93. case spv::Op::OpGroupNonUniformShuffleDown:
  94. default:
  95. operand = "Delta";
  96. break;
  97. }
  98. return operand;
  99. };
  100. const auto id_type_id = _.GetOperandTypeId(inst, 4);
  101. if (!_.IsUnsignedIntScalarType(id_type_id)) {
  102. std::string operand = GetOperandName(inst->opcode());
  103. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  104. << operand << " must be an unsigned integer scalar";
  105. }
  106. const bool should_be_constant =
  107. inst->opcode() == spv::Op::OpGroupNonUniformQuadSwap ||
  108. ((inst->opcode() == spv::Op::OpGroupNonUniformBroadcast ||
  109. inst->opcode() == spv::Op::OpGroupNonUniformQuadBroadcast) &&
  110. _.version() < SPV_SPIRV_VERSION_WORD(1, 5));
  111. if (should_be_constant) {
  112. const auto id_id = inst->GetOperandAs<uint32_t>(4);
  113. const auto id_op = _.GetIdOpcode(id_id);
  114. if (!spvOpcodeIsConstant(id_op)) {
  115. std::string operand = GetOperandName(inst->opcode());
  116. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  117. << "Before SPIR-V 1.5, " << operand
  118. << " must be a constant instruction";
  119. }
  120. }
  121. return SPV_SUCCESS;
  122. }
  123. spv_result_t ValidateGroupNonUniformBroadcastFirst(ValidationState_t& _,
  124. const Instruction* inst) {
  125. const auto type_id = inst->type_id();
  126. if (!_.IsFloatScalarOrVectorType(type_id) &&
  127. !_.IsIntScalarOrVectorType(type_id) &&
  128. !_.IsBoolScalarOrVectorType(type_id)) {
  129. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  130. << "Result must be a scalar or vector of integer, floating-point, "
  131. "or boolean type";
  132. }
  133. const auto value_type_id = _.GetOperandTypeId(inst, 3);
  134. if (value_type_id != type_id) {
  135. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  136. << "The type of Value must match the Result type";
  137. }
  138. return SPV_SUCCESS;
  139. }
  140. spv_result_t ValidateGroupNonUniformBallot(ValidationState_t& _,
  141. const Instruction* inst) {
  142. if (!_.IsUnsignedIntVectorType(inst->type_id())) {
  143. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  144. << "Result must be a 4-component unsigned integer vector";
  145. }
  146. if (_.GetDimension(inst->type_id()) != 4) {
  147. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  148. << "Result must be a 4-component unsigned integer vector";
  149. }
  150. const auto pred_type_id = _.GetOperandTypeId(inst, 3);
  151. if (!_.IsBoolScalarType(pred_type_id)) {
  152. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  153. << "Predicate must be a boolean scalar";
  154. }
  155. return SPV_SUCCESS;
  156. }
  157. spv_result_t ValidateGroupNonUniformInverseBallot(ValidationState_t& _,
  158. const Instruction* inst) {
  159. if (!_.IsBoolScalarType(inst->type_id())) {
  160. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  161. << "Result must be a boolean scalar";
  162. }
  163. const auto value_type_id = _.GetOperandTypeId(inst, 3);
  164. if (!_.IsUnsignedIntVectorType(value_type_id)) {
  165. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  166. << "Value must be a 4-component unsigned integer vector";
  167. }
  168. if (_.GetDimension(value_type_id) != 4) {
  169. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  170. << "Value must be a 4-component unsigned integer vector";
  171. }
  172. return SPV_SUCCESS;
  173. }
  174. spv_result_t ValidateGroupNonUniformBallotBitExtract(ValidationState_t& _,
  175. const Instruction* inst) {
  176. if (!_.IsBoolScalarType(inst->type_id())) {
  177. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  178. << "Result must be a boolean scalar";
  179. }
  180. const auto value_type_id = _.GetOperandTypeId(inst, 3);
  181. if (!_.IsUnsignedIntVectorType(value_type_id)) {
  182. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  183. << "Value must be a 4-component unsigned integer vector";
  184. }
  185. if (_.GetDimension(value_type_id) != 4) {
  186. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  187. << "Value must be a 4-component unsigned integer vector";
  188. }
  189. const auto id_type_id = _.GetOperandTypeId(inst, 4);
  190. if (!_.IsUnsignedIntScalarType(id_type_id)) {
  191. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  192. << "Id must be an unsigned integer scalar";
  193. }
  194. return SPV_SUCCESS;
  195. }
  196. spv_result_t ValidateGroupNonUniformBallotBitCount(ValidationState_t& _,
  197. const Instruction* inst) {
  198. // Scope is already checked by ValidateExecutionScope() above.
  199. const uint32_t result_type = inst->type_id();
  200. if (!_.IsUnsignedIntScalarType(result_type)) {
  201. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  202. << "Expected Result Type to be an unsigned integer type scalar.";
  203. }
  204. const auto value = inst->GetOperandAs<uint32_t>(4);
  205. const auto value_type = _.FindDef(value)->type_id();
  206. if (!_.IsUnsignedIntVectorType(value_type) ||
  207. _.GetDimension(value_type) != 4) {
  208. return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Value to be a "
  209. "vector of four components "
  210. "of integer type scalar";
  211. }
  212. const auto group = inst->GetOperandAs<spv::GroupOperation>(3);
  213. if (spvIsVulkanEnv(_.context()->target_env)) {
  214. if ((group != spv::GroupOperation::Reduce) &&
  215. (group != spv::GroupOperation::InclusiveScan) &&
  216. (group != spv::GroupOperation::ExclusiveScan)) {
  217. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  218. << _.VkErrorID(4685)
  219. << "In Vulkan: The OpGroupNonUniformBallotBitCount group "
  220. "operation must be only: Reduce, InclusiveScan, or "
  221. "ExclusiveScan.";
  222. }
  223. }
  224. return SPV_SUCCESS;
  225. }
  226. spv_result_t ValidateGroupNonUniformBallotFind(ValidationState_t& _,
  227. const Instruction* inst) {
  228. if (!_.IsUnsignedIntScalarType(inst->type_id())) {
  229. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  230. << "Result must be an unsigned integer scalar";
  231. }
  232. const auto value_type_id = _.GetOperandTypeId(inst, 3);
  233. if (!_.IsUnsignedIntVectorType(value_type_id)) {
  234. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  235. << "Value must be a 4-component unsigned integer vector";
  236. }
  237. if (_.GetDimension(value_type_id) != 4) {
  238. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  239. << "Value must be a 4-component unsigned integer vector";
  240. }
  241. return SPV_SUCCESS;
  242. }
  243. spv_result_t ValidateGroupNonUniformArithmetic(ValidationState_t& _,
  244. const Instruction* inst) {
  245. const bool is_unsigned = inst->opcode() == spv::Op::OpGroupNonUniformUMin ||
  246. inst->opcode() == spv::Op::OpGroupNonUniformUMax;
  247. const bool is_float = inst->opcode() == spv::Op::OpGroupNonUniformFAdd ||
  248. inst->opcode() == spv::Op::OpGroupNonUniformFMul ||
  249. inst->opcode() == spv::Op::OpGroupNonUniformFMin ||
  250. inst->opcode() == spv::Op::OpGroupNonUniformFMax;
  251. const bool is_bool = inst->opcode() == spv::Op::OpGroupNonUniformLogicalAnd ||
  252. inst->opcode() == spv::Op::OpGroupNonUniformLogicalOr ||
  253. inst->opcode() == spv::Op::OpGroupNonUniformLogicalXor;
  254. if (is_float) {
  255. if (!_.IsFloatScalarOrVectorType(inst->type_id())) {
  256. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  257. << "Result must be a floating-point scalar or vector";
  258. }
  259. } else if (is_bool) {
  260. if (!_.IsBoolScalarOrVectorType(inst->type_id())) {
  261. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  262. << "Result must be a boolean scalar or vector";
  263. }
  264. } else if (is_unsigned) {
  265. if (!_.IsUnsignedIntScalarOrVectorType(inst->type_id())) {
  266. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  267. << "Result must be an unsigned integer scalar or vector";
  268. }
  269. } else if (!_.IsIntScalarOrVectorType(inst->type_id())) {
  270. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  271. << "Result must be an integer scalar or vector";
  272. }
  273. const auto value_type_id = _.GetOperandTypeId(inst, 4);
  274. if (value_type_id != inst->type_id()) {
  275. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  276. << "The type of Value must match the Result type";
  277. }
  278. const auto group_op = inst->GetOperandAs<spv::GroupOperation>(3);
  279. bool is_clustered_reduce = group_op == spv::GroupOperation::ClusteredReduce;
  280. bool is_partitioned_nv =
  281. group_op == spv::GroupOperation::PartitionedReduceNV ||
  282. group_op == spv::GroupOperation::PartitionedInclusiveScanNV ||
  283. group_op == spv::GroupOperation::PartitionedExclusiveScanNV;
  284. if (inst->operands().size() <= 5) {
  285. if (is_clustered_reduce) {
  286. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  287. << "ClusterSize must be present when Operation is ClusteredReduce";
  288. } else if (is_partitioned_nv) {
  289. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  290. << "Ballot must be present when Operation is PartitionedReduceNV, "
  291. "PartitionedInclusiveScanNV, or PartitionedExclusiveScanNV";
  292. }
  293. } else {
  294. const auto operand_id = inst->GetOperandAs<uint32_t>(5);
  295. const auto* operand = _.FindDef(operand_id);
  296. if (is_partitioned_nv) {
  297. if (!operand || !_.IsIntScalarOrVectorType(operand->type_id())) {
  298. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  299. << "Ballot must be a 4-component integer vector";
  300. }
  301. if (_.GetDimension(operand->type_id()) != 4) {
  302. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  303. << "Ballot must be a 4-component integer vector";
  304. }
  305. } else {
  306. if (!operand || !_.IsUnsignedIntScalarType(operand->type_id())) {
  307. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  308. << "ClusterSize must be an unsigned integer scalar";
  309. }
  310. if (!spvOpcodeIsConstant(operand->opcode())) {
  311. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  312. << "ClusterSize must be a constant instruction";
  313. }
  314. }
  315. }
  316. return SPV_SUCCESS;
  317. }
  318. spv_result_t ValidateGroupNonUniformRotateKHR(ValidationState_t& _,
  319. const Instruction* inst) {
  320. // Scope is already checked by ValidateExecutionScope() above.
  321. const uint32_t result_type = inst->type_id();
  322. if (!_.IsIntScalarOrVectorType(result_type) &&
  323. !_.IsFloatScalarOrVectorType(result_type) &&
  324. !_.IsBoolScalarOrVectorType(result_type)) {
  325. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  326. << "Expected Result Type to be a scalar or vector of "
  327. "floating-point, integer or boolean type.";
  328. }
  329. const uint32_t value_type = _.GetTypeId(inst->GetOperandAs<uint32_t>(3));
  330. if (value_type != result_type) {
  331. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  332. << "Result Type must be the same as the type of Value.";
  333. }
  334. const uint32_t delta_type = _.GetTypeId(inst->GetOperandAs<uint32_t>(4));
  335. if (!_.IsUnsignedIntScalarType(delta_type)) {
  336. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  337. << "Delta must be a scalar of integer type, whose Signedness "
  338. "operand is 0.";
  339. }
  340. if (inst->words().size() > 6) {
  341. const uint32_t cluster_size_op_id = inst->GetOperandAs<uint32_t>(5);
  342. const Instruction* cluster_size_inst = _.FindDef(cluster_size_op_id);
  343. if (!cluster_size_inst ||
  344. !_.IsUnsignedIntScalarType(cluster_size_inst->type_id())) {
  345. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  346. << "ClusterSize must be a scalar of integer type, whose "
  347. "Signedness operand is 0.";
  348. }
  349. if (!spvOpcodeIsConstant(cluster_size_inst->opcode())) {
  350. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  351. << "ClusterSize must come from a constant instruction.";
  352. }
  353. uint64_t cluster_size;
  354. const bool valid_const =
  355. _.EvalConstantValUint64(cluster_size_op_id, &cluster_size);
  356. if (valid_const &&
  357. ((cluster_size == 0) || ((cluster_size & (cluster_size - 1)) != 0))) {
  358. return _.diag(SPV_WARNING, inst)
  359. << "Behavior is undefined unless ClusterSize is at least 1 and a "
  360. "power of 2.";
  361. }
  362. // TODO(kpet) Warn about undefined behavior when ClusterSize is greater than
  363. // the declared SubGroupSize
  364. }
  365. return SPV_SUCCESS;
  366. }
  367. } // namespace
  368. // Validates correctness of non-uniform group instructions.
  369. spv_result_t NonUniformPass(ValidationState_t& _, const Instruction* inst) {
  370. const spv::Op opcode = inst->opcode();
  371. if (spvOpcodeIsNonUniformGroupOperation(opcode)) {
  372. // OpGroupNonUniformQuadAllKHR and OpGroupNonUniformQuadAnyKHR don't have
  373. // scope paramter
  374. if ((opcode != spv::Op::OpGroupNonUniformQuadAllKHR) &&
  375. (opcode != spv::Op::OpGroupNonUniformQuadAnyKHR)) {
  376. const uint32_t execution_scope = inst->GetOperandAs<uint32_t>(2);
  377. if (auto error = ValidateExecutionScope(_, inst, execution_scope)) {
  378. return error;
  379. }
  380. }
  381. }
  382. switch (opcode) {
  383. case spv::Op::OpGroupNonUniformElect:
  384. return ValidateGroupNonUniformElect(_, inst);
  385. case spv::Op::OpGroupNonUniformAny:
  386. case spv::Op::OpGroupNonUniformAll:
  387. return ValidateGroupNonUniformAnyAll(_, inst);
  388. case spv::Op::OpGroupNonUniformAllEqual:
  389. return ValidateGroupNonUniformAllEqual(_, inst);
  390. case spv::Op::OpGroupNonUniformBroadcast:
  391. case spv::Op::OpGroupNonUniformShuffle:
  392. case spv::Op::OpGroupNonUniformShuffleXor:
  393. case spv::Op::OpGroupNonUniformShuffleUp:
  394. case spv::Op::OpGroupNonUniformShuffleDown:
  395. case spv::Op::OpGroupNonUniformQuadBroadcast:
  396. case spv::Op::OpGroupNonUniformQuadSwap:
  397. return ValidateGroupNonUniformBroadcastShuffle(_, inst);
  398. case spv::Op::OpGroupNonUniformBroadcastFirst:
  399. return ValidateGroupNonUniformBroadcastFirst(_, inst);
  400. case spv::Op::OpGroupNonUniformBallot:
  401. return ValidateGroupNonUniformBallot(_, inst);
  402. case spv::Op::OpGroupNonUniformInverseBallot:
  403. return ValidateGroupNonUniformInverseBallot(_, inst);
  404. case spv::Op::OpGroupNonUniformBallotBitExtract:
  405. return ValidateGroupNonUniformBallotBitExtract(_, inst);
  406. case spv::Op::OpGroupNonUniformBallotBitCount:
  407. return ValidateGroupNonUniformBallotBitCount(_, inst);
  408. case spv::Op::OpGroupNonUniformBallotFindLSB:
  409. case spv::Op::OpGroupNonUniformBallotFindMSB:
  410. return ValidateGroupNonUniformBallotFind(_, inst);
  411. case spv::Op::OpGroupNonUniformIAdd:
  412. case spv::Op::OpGroupNonUniformFAdd:
  413. case spv::Op::OpGroupNonUniformIMul:
  414. case spv::Op::OpGroupNonUniformFMul:
  415. case spv::Op::OpGroupNonUniformSMin:
  416. case spv::Op::OpGroupNonUniformUMin:
  417. case spv::Op::OpGroupNonUniformFMin:
  418. case spv::Op::OpGroupNonUniformSMax:
  419. case spv::Op::OpGroupNonUniformUMax:
  420. case spv::Op::OpGroupNonUniformFMax:
  421. case spv::Op::OpGroupNonUniformBitwiseAnd:
  422. case spv::Op::OpGroupNonUniformBitwiseOr:
  423. case spv::Op::OpGroupNonUniformBitwiseXor:
  424. case spv::Op::OpGroupNonUniformLogicalAnd:
  425. case spv::Op::OpGroupNonUniformLogicalOr:
  426. case spv::Op::OpGroupNonUniformLogicalXor:
  427. return ValidateGroupNonUniformArithmetic(_, inst);
  428. case spv::Op::OpGroupNonUniformRotateKHR:
  429. return ValidateGroupNonUniformRotateKHR(_, inst);
  430. default:
  431. break;
  432. }
  433. return SPV_SUCCESS;
  434. }
  435. } // namespace val
  436. } // namespace spvtools