validate_tensor.cpp 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. // Copyright (c) 2023-2025 Arm Ltd.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Validates correctness of tensor instructions.
  15. #include "source/opcode.h"
  16. #include "source/val/validate.h"
  17. #include "source/val/validation_state.h"
  18. namespace spvtools {
  19. namespace val {
  20. namespace {
  21. bool IsRankedTensor(ValidationState_t& _, uint32_t id) {
  22. auto inst = _.FindDef(id);
  23. if (!inst || inst->opcode() != spv::Op::OpTypeTensorARM ||
  24. inst->words().size() <= 3) {
  25. return false;
  26. }
  27. return true;
  28. }
  29. uint64_t GetTensorTypeRank(ValidationState_t& _, uint32_t id) {
  30. auto inst = _.FindDef(id);
  31. if (!inst || inst->opcode() != spv::Op::OpTypeTensorARM ||
  32. inst->words().size() <= 3) {
  33. return 0;
  34. }
  35. uint64_t rank = 0;
  36. if (!_.EvalConstantValUint64(inst->word(3), &rank)) {
  37. return 0;
  38. }
  39. return rank;
  40. }
  41. bool IsScalarTypeOrOrArrayOfScalarType(ValidationState_t& _, uint32_t id) {
  42. auto inst = _.FindDef(id);
  43. if (!inst) {
  44. return false;
  45. }
  46. return _.IsScalarType(id) || (inst->opcode() == spv::Op::OpTypeArray &&
  47. _.IsScalarType(inst->word(2)));
  48. }
  49. spv_result_t ValidateTensorRead(ValidationState_t& _, const Instruction* inst) {
  50. // Result Type must be a scalar type or array of scalar type.
  51. if (!IsScalarTypeOrOrArrayOfScalarType(_, inst->type_id())) {
  52. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  53. << "Expected Result Type to be a scalar type or array of "
  54. "scalar type.";
  55. }
  56. // Tensor must be a Ranked Tensor.
  57. auto op_tensor = inst->word(3);
  58. auto inst_tensor = _.FindDef(op_tensor);
  59. if (!inst_tensor || !IsRankedTensor(_, inst_tensor->type_id())) {
  60. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  61. << "Expected Tensor to be an OpTypeTensorARM whose Rank is "
  62. "specified";
  63. }
  64. // The scalar type must be the same as the Element Type of Tensor.
  65. if (_.GetComponentType(inst_tensor->type_id()) !=
  66. _.GetComponentType(inst->type_id())) {
  67. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  68. << "Expected Result Type to be the same as the Element Type of "
  69. "Tensor.";
  70. }
  71. // Coordinates is an array whose Element Type must be an integer type and
  72. // whose Length must be equal to the Rank of Tensor.
  73. auto op_coord = inst->word(4);
  74. auto inst_coord = _.FindDef(op_coord);
  75. auto tensor_rank = GetTensorTypeRank(_, inst_tensor->type_id());
  76. if (!_.IsIntArrayType(inst_coord->type_id(), tensor_rank)) {
  77. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  78. << "Expected Coordinates to be an array whose Element Type is an "
  79. "integer type and whose Length is equal to the Rank of Tensor.";
  80. }
  81. // Validate Tensor Operands
  82. if (inst->words().size() > 5) {
  83. auto toperands = static_cast<spv::TensorOperandsMask>(inst->word(5));
  84. if ((toperands & spv::TensorOperandsMask::OutOfBoundsValueARM) !=
  85. spv::TensorOperandsMask::MaskNone) {
  86. if (inst->words().size() < 7) {
  87. return _.diag(SPV_ERROR_INVALID_ID, inst)
  88. << "A value must be provided after the OutOfBoundsValueARM "
  89. "Tensor Operand.";
  90. }
  91. auto op_oobval = inst->word(6);
  92. auto inst_oobval = _.FindDef(op_oobval);
  93. if (_.GetComponentType(inst_tensor->type_id()) !=
  94. _.GetComponentType(inst_oobval->type_id())) {
  95. return _.diag(SPV_ERROR_INVALID_ID, inst)
  96. << "Expected the type of the OutOfBoundsValueARM value to be "
  97. "the same "
  98. "as the Element Type of Tensor.";
  99. }
  100. }
  101. if ((toperands & spv::TensorOperandsMask::MakeElementAvailableARM) !=
  102. spv::TensorOperandsMask::MaskNone) {
  103. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  104. << "MakeElementAvailableARM cannot be used with OpTensorReadARM.";
  105. }
  106. if (((toperands & spv::TensorOperandsMask::MakeElementVisibleARM) !=
  107. spv::TensorOperandsMask::MaskNone) &&
  108. ((toperands & spv::TensorOperandsMask::NonPrivateElementARM) ==
  109. spv::TensorOperandsMask::MaskNone)) {
  110. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  111. << "MakeElementAvailableARM requires NonPrivateElementARM.";
  112. }
  113. }
  114. return SPV_SUCCESS;
  115. }
  116. spv_result_t ValidateTensorWrite(ValidationState_t& _,
  117. const Instruction* inst) {
  118. // Tensor must be a Ranked Tensor.
  119. auto op_tensor = inst->word(1);
  120. auto inst_tensor = _.FindDef(op_tensor);
  121. if (!IsRankedTensor(_, inst_tensor->type_id())) {
  122. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  123. << "Expected Tensor to be an OpTypeTensorARM whose Rank is "
  124. "specified";
  125. }
  126. // Coordinates is an array whose Element Type must be an integer type and
  127. // whose Length must be equal to the Rank of Tensor.
  128. auto op_coord = inst->word(2);
  129. auto inst_coord = _.FindDef(op_coord);
  130. auto tensor_rank = GetTensorTypeRank(_, inst_tensor->type_id());
  131. if (!_.IsIntArrayType(inst_coord->type_id(), tensor_rank)) {
  132. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  133. << "Expected Coordinates to be an array whose Element Type is an "
  134. "integer type and whose Length is equal to the Rank of Tensor.";
  135. }
  136. // Object must be an object of scalar type or array of scalar type.
  137. // The scalar type must be the same as the Element Type of Tensor.
  138. auto op_object = inst->word(3);
  139. auto inst_object = _.FindDef(op_object);
  140. if (!IsScalarTypeOrOrArrayOfScalarType(_, inst_object->type_id()) ||
  141. (_.GetComponentType(inst_object->type_id()) !=
  142. _.GetComponentType(inst_tensor->type_id()))) {
  143. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  144. << "Expected Object to be a scalar type or array of scalar "
  145. "type that is the same as the Element Type of Tensor.";
  146. }
  147. // Validate Tensor Operands
  148. if (inst->words().size() > 5) {
  149. auto toperands = static_cast<spv::TensorOperandsMask>(inst->word(4));
  150. if ((toperands & spv::TensorOperandsMask::OutOfBoundsValueARM) !=
  151. spv::TensorOperandsMask::MaskNone) {
  152. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  153. << "OutOfBoundsValue Tensor Operand not allowed with "
  154. "OpTensorWriteARM.";
  155. }
  156. if ((toperands & spv::TensorOperandsMask::MakeElementVisibleARM) !=
  157. spv::TensorOperandsMask::MaskNone) {
  158. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  159. << "MakeElementVisibleARM not allowed with OpTensorWriteARM.";
  160. }
  161. if (((toperands & spv::TensorOperandsMask::MakeElementAvailableARM) !=
  162. spv::TensorOperandsMask::MaskNone) &&
  163. ((toperands & spv::TensorOperandsMask::NonPrivateElementARM) ==
  164. spv::TensorOperandsMask::MaskNone)) {
  165. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  166. << "MakeElementAvailableARM requires NonPrivateElementARM.";
  167. }
  168. }
  169. return SPV_SUCCESS;
  170. }
  171. spv_result_t ValidateTensorQuerySize(ValidationState_t& _,
  172. const Instruction* inst) {
  173. // Check result type
  174. if (!_.IsIntScalarType(inst->type_id())) {
  175. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  176. << "Expected Result Type to be an integer type scalar";
  177. }
  178. // Check Tensor operand
  179. auto op_tensor = inst->word(3);
  180. auto inst_tensor = _.FindDef(op_tensor);
  181. if (!inst_tensor || !IsRankedTensor(_, inst_tensor->type_id())) {
  182. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  183. << "Expected Tensor to be an OpTypeTensorARM whose Rank is "
  184. "specified";
  185. }
  186. // Check Dimension operand
  187. auto op_dim = inst->word(4);
  188. auto inst_dim = _.FindDef(op_dim);
  189. if (!spvOpcodeIsConstant(inst_dim->opcode()) ||
  190. !_.IsIntScalarType(inst_dim->type_id())) {
  191. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  192. << "Dimension must come from a constant instruction of scalar "
  193. "integer type.";
  194. }
  195. auto inst_tensor_type = _.FindDef(inst_tensor->type_id());
  196. auto op_tensor_rank = inst_tensor_type->word(3);
  197. uint64_t tensor_rank = 0;
  198. uint64_t dim;
  199. if (_.EvalConstantValUint64(op_tensor_rank, &tensor_rank) &&
  200. _.EvalConstantValUint64(op_dim, &dim) && (dim >= tensor_rank)) {
  201. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  202. << "Dimension (" << dim << ") must be less than the Rank of Tensor ("
  203. << tensor_rank << ").";
  204. }
  205. return SPV_SUCCESS;
  206. }
  207. } // namespace
  208. // Validates correctness of tensor instructions.
  209. spv_result_t TensorPass(ValidationState_t& _, const Instruction* inst) {
  210. (void)_;
  211. const spv::Op opcode = inst->opcode();
  212. switch (opcode) {
  213. case spv::Op::OpTensorReadARM:
  214. return ValidateTensorRead(_, inst);
  215. case spv::Op::OpTensorWriteARM:
  216. return ValidateTensorWrite(_, inst);
  217. case spv::Op::OpTensorQuerySizeARM:
  218. return ValidateTensorQuerySize(_, inst);
  219. default:
  220. break;
  221. }
  222. return SPV_SUCCESS;
  223. }
  224. } // namespace val
  225. } // namespace spvtools