validate_function.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <algorithm>
  15. #include "source/opcode.h"
  16. #include "source/table2.h"
  17. #include "source/val/instruction.h"
  18. #include "source/val/validate.h"
  19. #include "source/val/validation_state.h"
  20. namespace spvtools {
  21. namespace val {
  22. namespace {
  23. // Returns true if |a| and |b| are instructions defining pointers that point to
  24. // types logically match and the decorations that apply to |b| are a subset
  25. // of the decorations that apply to |a|.
  26. bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b,
  27. ValidationState_t& _) {
  28. if (a->opcode() != spv::Op::OpTypePointer ||
  29. b->opcode() != spv::Op::OpTypePointer) {
  30. return false;
  31. }
  32. const auto& dec_a = _.id_decorations(a->id());
  33. const auto& dec_b = _.id_decorations(b->id());
  34. for (const auto& dec : dec_b) {
  35. if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
  36. return false;
  37. }
  38. }
  39. uint32_t a_type = a->GetOperandAs<uint32_t>(2);
  40. uint32_t b_type = b->GetOperandAs<uint32_t>(2);
  41. if (a_type == b_type) {
  42. return true;
  43. }
  44. Instruction* a_type_inst = _.FindDef(a_type);
  45. Instruction* b_type_inst = _.FindDef(b_type);
  46. return _.LogicallyMatch(a_type_inst, b_type_inst, true);
  47. }
  48. spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
  49. const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
  50. const auto function_type = _.FindDef(function_type_id);
  51. if (!function_type || spv::Op::OpTypeFunction != function_type->opcode()) {
  52. return _.diag(SPV_ERROR_INVALID_ID, inst)
  53. << "OpFunction Function Type <id> " << _.getIdName(function_type_id)
  54. << " is not a function type.";
  55. }
  56. const auto return_id = function_type->GetOperandAs<uint32_t>(1);
  57. if (return_id != inst->type_id()) {
  58. return _.diag(SPV_ERROR_INVALID_ID, inst)
  59. << "OpFunction Result Type <id> " << _.getIdName(inst->type_id())
  60. << " does not match the Function Type's return type <id> "
  61. << _.getIdName(return_id) << ".";
  62. }
  63. const std::vector<spv::Op> acceptable = {
  64. spv::Op::OpGroupDecorate,
  65. spv::Op::OpDecorate,
  66. spv::Op::OpEnqueueKernel,
  67. spv::Op::OpEntryPoint,
  68. spv::Op::OpExecutionMode,
  69. spv::Op::OpExecutionModeId,
  70. spv::Op::OpFunctionCall,
  71. spv::Op::OpGetKernelNDrangeSubGroupCount,
  72. spv::Op::OpGetKernelNDrangeMaxSubGroupSize,
  73. spv::Op::OpGetKernelWorkGroupSize,
  74. spv::Op::OpGetKernelPreferredWorkGroupSizeMultiple,
  75. spv::Op::OpGetKernelLocalSizeForSubgroupCount,
  76. spv::Op::OpGetKernelMaxNumSubgroups,
  77. spv::Op::OpName,
  78. spv::Op::OpCooperativeMatrixPerElementOpNV,
  79. spv::Op::OpCooperativeMatrixReduceNV,
  80. spv::Op::OpCooperativeMatrixLoadTensorNV,
  81. spv::Op::OpConditionalEntryPointINTEL,
  82. };
  83. for (auto& pair : inst->uses()) {
  84. const auto* use = pair.first;
  85. if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
  86. acceptable.end() &&
  87. !use->IsNonSemantic() && !use->IsDebugInfo()) {
  88. return _.diag(SPV_ERROR_INVALID_ID, use)
  89. << "Invalid use of function result id " << _.getIdName(inst->id())
  90. << ".";
  91. }
  92. }
  93. return SPV_SUCCESS;
  94. }
  95. spv_result_t ValidateFunctionParameter(ValidationState_t& _,
  96. const Instruction* inst) {
  97. // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
  98. size_t param_index = 0;
  99. size_t inst_num = inst->LineNum() - 1;
  100. auto func_inst = &_.ordered_instructions()[inst_num];
  101. while (--inst_num) {
  102. func_inst = &_.ordered_instructions()[inst_num];
  103. if (func_inst->opcode() == spv::Op::OpFunction) {
  104. break;
  105. } else if (func_inst->opcode() == spv::Op::OpFunctionParameter) {
  106. ++param_index;
  107. }
  108. }
  109. if (func_inst->opcode() != spv::Op::OpFunction) {
  110. return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
  111. << "Function parameter must be preceded by a function.";
  112. }
  113. const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
  114. const auto function_type = _.FindDef(function_type_id);
  115. if (!function_type) {
  116. return _.diag(SPV_ERROR_INVALID_ID, func_inst)
  117. << "Missing function type definition.";
  118. }
  119. if (param_index >= function_type->words().size() - 3) {
  120. return _.diag(SPV_ERROR_INVALID_ID, inst)
  121. << "Too many OpFunctionParameters for " << func_inst->id()
  122. << ": expected " << function_type->words().size() - 3
  123. << " based on the function's type";
  124. }
  125. const auto param_type =
  126. _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
  127. if (!param_type || inst->type_id() != param_type->id()) {
  128. return _.diag(SPV_ERROR_INVALID_ID, inst)
  129. << "OpFunctionParameter Result Type <id> "
  130. << _.getIdName(inst->type_id())
  131. << " does not match the OpTypeFunction parameter "
  132. "type of the same index.";
  133. }
  134. return SPV_SUCCESS;
  135. }
  136. spv_result_t ValidateFunctionCall(ValidationState_t& _,
  137. const Instruction* inst) {
  138. const auto function_id = inst->GetOperandAs<uint32_t>(2);
  139. const auto function = _.FindDef(function_id);
  140. if (!function || spv::Op::OpFunction != function->opcode()) {
  141. return _.diag(SPV_ERROR_INVALID_ID, inst)
  142. << "OpFunctionCall Function <id> " << _.getIdName(function_id)
  143. << " is not a function.";
  144. }
  145. auto return_type = _.FindDef(function->type_id());
  146. if (!return_type || return_type->id() != inst->type_id()) {
  147. return _.diag(SPV_ERROR_INVALID_ID, inst)
  148. << "OpFunctionCall Result Type <id> " << _.getIdName(inst->type_id())
  149. << "s type does not match Function <id> "
  150. << _.getIdName(return_type->id()) << "s return type.";
  151. }
  152. if (!_.options()->relax_logical_pointer &&
  153. (_.addressing_model() == spv::AddressingModel::Logical ||
  154. _.addressing_model() == spv::AddressingModel::PhysicalStorageBuffer64)) {
  155. if (return_type->opcode() == spv::Op::OpTypePointer ||
  156. return_type->opcode() == spv::Op::OpTypeUntypedPointerKHR) {
  157. const auto sc = return_type->GetOperandAs<spv::StorageClass>(1);
  158. if (sc != spv::StorageClass::PhysicalStorageBuffer) {
  159. if (!_.HasCapability(spv::Capability::VariablePointersStorageBuffer) &&
  160. sc == spv::StorageClass::StorageBuffer) {
  161. return _.diag(SPV_ERROR_INVALID_ID, inst)
  162. << "In Logical addressing, functions may only return a "
  163. "storage buffer pointer if the "
  164. "VariablePointersStorageBuffer capability is declared";
  165. } else if (!_.HasCapability(spv::Capability::VariablePointers) &&
  166. sc == spv::StorageClass::Workgroup) {
  167. return _.diag(SPV_ERROR_INVALID_ID, inst)
  168. << "In Logical addressing, functions may only return a "
  169. "workgroup pointer if the VariablePointers capability is "
  170. "declared";
  171. } else if (sc != spv::StorageClass::StorageBuffer &&
  172. sc != spv::StorageClass::Workgroup) {
  173. return _.diag(SPV_ERROR_INVALID_ID, inst)
  174. << "In Logical addressing, functions may not return a pointer "
  175. "in this storage class";
  176. }
  177. }
  178. }
  179. }
  180. const auto function_type_id = function->GetOperandAs<uint32_t>(3);
  181. const auto function_type = _.FindDef(function_type_id);
  182. if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) {
  183. return _.diag(SPV_ERROR_INVALID_ID, inst)
  184. << "Missing function type definition.";
  185. }
  186. const auto function_call_arg_count = inst->words().size() - 4;
  187. const auto function_param_count = function_type->words().size() - 3;
  188. if (function_param_count != function_call_arg_count) {
  189. return _.diag(SPV_ERROR_INVALID_ID, inst)
  190. << "OpFunctionCall Function <id>'s parameter count does not match "
  191. "the argument count.";
  192. }
  193. for (size_t argument_index = 3, param_index = 2;
  194. argument_index < inst->operands().size();
  195. argument_index++, param_index++) {
  196. const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
  197. const auto argument = _.FindDef(argument_id);
  198. if (!argument) {
  199. return _.diag(SPV_ERROR_INVALID_ID, inst)
  200. << "Missing argument " << argument_index - 3 << " definition.";
  201. }
  202. const auto argument_type = _.FindDef(argument->type_id());
  203. if (!argument_type) {
  204. return _.diag(SPV_ERROR_INVALID_ID, inst)
  205. << "Missing argument " << argument_index - 3
  206. << " type definition.";
  207. }
  208. const auto parameter_type_id =
  209. function_type->GetOperandAs<uint32_t>(param_index);
  210. const auto parameter_type = _.FindDef(parameter_type_id);
  211. if (!parameter_type || argument_type->id() != parameter_type->id()) {
  212. if (!parameter_type || !_.options()->before_hlsl_legalization ||
  213. !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) {
  214. return _.diag(SPV_ERROR_INVALID_ID, inst)
  215. << "OpFunctionCall Argument <id> " << _.getIdName(argument_id)
  216. << "s type does not match Function <id> "
  217. << _.getIdName(parameter_type_id) << "s parameter type.";
  218. }
  219. }
  220. if (_.addressing_model() == spv::AddressingModel::Logical ||
  221. _.addressing_model() == spv::AddressingModel::PhysicalStorageBuffer64) {
  222. if ((parameter_type->opcode() == spv::Op::OpTypePointer ||
  223. parameter_type->opcode() == spv::Op::OpTypeUntypedPointerKHR) &&
  224. !_.options()->relax_logical_pointer) {
  225. spv::StorageClass sc =
  226. parameter_type->GetOperandAs<spv::StorageClass>(1u);
  227. if (sc != spv::StorageClass::PhysicalStorageBuffer) {
  228. // Validate which storage classes can be pointer operands.
  229. switch (sc) {
  230. case spv::StorageClass::UniformConstant:
  231. case spv::StorageClass::Function:
  232. case spv::StorageClass::Private:
  233. case spv::StorageClass::Workgroup:
  234. case spv::StorageClass::AtomicCounter:
  235. // SPV_EXT_tile_image
  236. case spv::StorageClass::TileImageEXT:
  237. // SPV_KHR_ray_tracing
  238. case spv::StorageClass::ShaderRecordBufferKHR:
  239. // These are always allowed.
  240. break;
  241. case spv::StorageClass::StorageBuffer:
  242. if (!_.features().variable_pointers) {
  243. return _.diag(SPV_ERROR_INVALID_ID, inst)
  244. << "StorageBuffer pointer operand "
  245. << _.getIdName(argument_id)
  246. << " requires a variable pointers capability";
  247. }
  248. break;
  249. default:
  250. return _.diag(SPV_ERROR_INVALID_ID, inst)
  251. << "Invalid storage class for pointer operand "
  252. << _.getIdName(argument_id);
  253. }
  254. // Validate memory object declaration requirements.
  255. if (argument->opcode() != spv::Op::OpVariable &&
  256. argument->opcode() != spv::Op::OpUntypedVariableKHR &&
  257. argument->opcode() != spv::Op::OpFunctionParameter) {
  258. const bool ssbo_vptr =
  259. _.HasCapability(
  260. spv::Capability::VariablePointersStorageBuffer) &&
  261. sc == spv::StorageClass::StorageBuffer;
  262. const bool wg_vptr =
  263. _.HasCapability(spv::Capability::VariablePointers) &&
  264. sc == spv::StorageClass::Workgroup;
  265. const bool uc_ptr = sc == spv::StorageClass::UniformConstant;
  266. if (!_.options()->before_hlsl_legalization && !ssbo_vptr &&
  267. !wg_vptr && !uc_ptr) {
  268. return _.diag(SPV_ERROR_INVALID_ID, inst)
  269. << "Pointer operand " << _.getIdName(argument_id)
  270. << " must be a memory object declaration";
  271. }
  272. }
  273. }
  274. }
  275. }
  276. }
  277. return SPV_SUCCESS;
  278. }
  279. spv_result_t ValidateCooperativeMatrixPerElementOp(ValidationState_t& _,
  280. const Instruction* inst) {
  281. const auto function_id = inst->GetOperandAs<uint32_t>(3);
  282. const auto function = _.FindDef(function_id);
  283. if (!function || spv::Op::OpFunction != function->opcode()) {
  284. return _.diag(SPV_ERROR_INVALID_ID, inst)
  285. << "OpCooperativeMatrixPerElementOpNV Function <id> "
  286. << _.getIdName(function_id) << " is not a function.";
  287. }
  288. const auto matrix_id = inst->GetOperandAs<uint32_t>(2);
  289. const auto matrix = _.FindDef(matrix_id);
  290. const auto matrix_type_id = matrix->type_id();
  291. if (!_.IsCooperativeMatrixKHRType(matrix_type_id)) {
  292. return _.diag(SPV_ERROR_INVALID_ID, inst)
  293. << "OpCooperativeMatrixPerElementOpNV Matrix <id> "
  294. << _.getIdName(matrix_id) << " is not a cooperative matrix.";
  295. }
  296. const auto result_type_id = inst->GetOperandAs<uint32_t>(0);
  297. if (matrix_type_id != result_type_id) {
  298. return _.diag(SPV_ERROR_INVALID_ID, inst)
  299. << "OpCooperativeMatrixPerElementOpNV Result Type <id> "
  300. << _.getIdName(result_type_id) << " must match matrix type <id> "
  301. << _.getIdName(matrix_type_id) << ".";
  302. }
  303. const auto matrix_comp_type_id =
  304. _.FindDef(matrix_type_id)->GetOperandAs<uint32_t>(1);
  305. const auto function_type_id = function->GetOperandAs<uint32_t>(3);
  306. const auto function_type = _.FindDef(function_type_id);
  307. auto return_type_id = function_type->GetOperandAs<uint32_t>(1);
  308. if (return_type_id != matrix_comp_type_id) {
  309. return _.diag(SPV_ERROR_INVALID_ID, inst)
  310. << "OpCooperativeMatrixPerElementOpNV function return type <id> "
  311. << _.getIdName(return_type_id)
  312. << " must match matrix component type <id> "
  313. << _.getIdName(matrix_comp_type_id) << ".";
  314. }
  315. if (function_type->operands().size() < 5) {
  316. return _.diag(SPV_ERROR_INVALID_ID, inst)
  317. << "OpCooperativeMatrixPerElementOpNV function type <id> "
  318. << _.getIdName(function_type_id)
  319. << " must have a least three parameters.";
  320. }
  321. const auto param0_id = function_type->GetOperandAs<uint32_t>(2);
  322. const auto param1_id = function_type->GetOperandAs<uint32_t>(3);
  323. const auto param2_id = function_type->GetOperandAs<uint32_t>(4);
  324. if (!_.IsIntScalarType(param0_id) || _.GetBitWidth(param0_id) != 32) {
  325. return _.diag(SPV_ERROR_INVALID_ID, inst)
  326. << "OpCooperativeMatrixPerElementOpNV function type first parameter "
  327. "type <id> "
  328. << _.getIdName(param0_id) << " must be a 32-bit integer.";
  329. }
  330. if (!_.IsIntScalarType(param1_id) || _.GetBitWidth(param1_id) != 32) {
  331. return _.diag(SPV_ERROR_INVALID_ID, inst)
  332. << "OpCooperativeMatrixPerElementOpNV function type second "
  333. "parameter type <id> "
  334. << _.getIdName(param1_id) << " must be a 32-bit integer.";
  335. }
  336. if (param2_id != matrix_comp_type_id) {
  337. return _.diag(SPV_ERROR_INVALID_ID, inst)
  338. << "OpCooperativeMatrixPerElementOpNV function type third parameter "
  339. "type <id> "
  340. << _.getIdName(param2_id) << " must match matrix component type.";
  341. }
  342. return SPV_SUCCESS;
  343. }
  344. } // namespace
  345. spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
  346. switch (inst->opcode()) {
  347. case spv::Op::OpFunction:
  348. if (auto error = ValidateFunction(_, inst)) return error;
  349. break;
  350. case spv::Op::OpFunctionParameter:
  351. if (auto error = ValidateFunctionParameter(_, inst)) return error;
  352. break;
  353. case spv::Op::OpFunctionCall:
  354. if (auto error = ValidateFunctionCall(_, inst)) return error;
  355. break;
  356. case spv::Op::OpCooperativeMatrixPerElementOpNV:
  357. if (auto error = ValidateCooperativeMatrixPerElementOp(_, inst))
  358. return error;
  359. break;
  360. default:
  361. break;
  362. }
  363. return SPV_SUCCESS;
  364. }
  365. } // namespace val
  366. } // namespace spvtools