validate_function.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/val/validate.h"
  15. #include <algorithm>
  16. #include "source/opcode.h"
  17. #include "source/val/instruction.h"
  18. #include "source/val/validation_state.h"
  19. namespace spvtools {
  20. namespace val {
  21. namespace {
  22. // Returns true if |a| and |b| are instruction defining pointers that point to
  23. // the same type.
  24. bool ArePointersToSameType(val::Instruction* a, val::Instruction* b) {
  25. if (a->opcode() != SpvOpTypePointer || b->opcode() != SpvOpTypePointer) {
  26. return false;
  27. }
  28. uint32_t a_type = a->GetOperandAs<uint32_t>(2);
  29. return a_type && (a_type == b->GetOperandAs<uint32_t>(2));
  30. }
  31. spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
  32. const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
  33. const auto function_type = _.FindDef(function_type_id);
  34. if (!function_type || SpvOpTypeFunction != function_type->opcode()) {
  35. return _.diag(SPV_ERROR_INVALID_ID, inst)
  36. << "OpFunction Function Type <id> '" << _.getIdName(function_type_id)
  37. << "' is not a function type.";
  38. }
  39. const auto return_id = function_type->GetOperandAs<uint32_t>(1);
  40. if (return_id != inst->type_id()) {
  41. return _.diag(SPV_ERROR_INVALID_ID, inst)
  42. << "OpFunction Result Type <id> '" << _.getIdName(inst->type_id())
  43. << "' does not match the Function Type's return type <id> '"
  44. << _.getIdName(return_id) << "'.";
  45. }
  46. const std::vector<SpvOp> acceptable = {
  47. SpvOpDecorate,
  48. SpvOpEnqueueKernel,
  49. SpvOpEntryPoint,
  50. SpvOpExecutionMode,
  51. SpvOpExecutionModeId,
  52. SpvOpFunctionCall,
  53. SpvOpGetKernelNDrangeSubGroupCount,
  54. SpvOpGetKernelNDrangeMaxSubGroupSize,
  55. SpvOpGetKernelWorkGroupSize,
  56. SpvOpGetKernelPreferredWorkGroupSizeMultiple,
  57. SpvOpGetKernelLocalSizeForSubgroupCount,
  58. SpvOpGetKernelMaxNumSubgroups,
  59. SpvOpName};
  60. for (auto& pair : inst->uses()) {
  61. const auto* use = pair.first;
  62. if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
  63. acceptable.end()) {
  64. return _.diag(SPV_ERROR_INVALID_ID, use)
  65. << "Invalid use of function result id " << _.getIdName(inst->id())
  66. << ".";
  67. }
  68. }
  69. return SPV_SUCCESS;
  70. }
  71. spv_result_t ValidateFunctionParameter(ValidationState_t& _,
  72. const Instruction* inst) {
  73. // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
  74. size_t param_index = 0;
  75. size_t inst_num = inst->LineNum() - 1;
  76. if (inst_num == 0) {
  77. return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
  78. << "Function parameter cannot be the first instruction.";
  79. }
  80. auto func_inst = &_.ordered_instructions()[inst_num];
  81. while (--inst_num) {
  82. func_inst = &_.ordered_instructions()[inst_num];
  83. if (func_inst->opcode() == SpvOpFunction) {
  84. break;
  85. } else if (func_inst->opcode() == SpvOpFunctionParameter) {
  86. ++param_index;
  87. }
  88. }
  89. if (func_inst->opcode() != SpvOpFunction) {
  90. return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
  91. << "Function parameter must be preceded by a function.";
  92. }
  93. const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
  94. const auto function_type = _.FindDef(function_type_id);
  95. if (!function_type) {
  96. return _.diag(SPV_ERROR_INVALID_ID, func_inst)
  97. << "Missing function type definition.";
  98. }
  99. if (param_index >= function_type->words().size() - 3) {
  100. return _.diag(SPV_ERROR_INVALID_ID, inst)
  101. << "Too many OpFunctionParameters for " << func_inst->id()
  102. << ": expected " << function_type->words().size() - 3
  103. << " based on the function's type";
  104. }
  105. const auto param_type =
  106. _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
  107. if (!param_type || inst->type_id() != param_type->id()) {
  108. return _.diag(SPV_ERROR_INVALID_ID, inst)
  109. << "OpFunctionParameter Result Type <id> '"
  110. << _.getIdName(inst->type_id())
  111. << "' does not match the OpTypeFunction parameter "
  112. "type of the same index.";
  113. }
  114. // Validate that PhysicalStorageBufferEXT have one of Restrict, Aliased,
  115. // RestrictPointerEXT, or AliasedPointerEXT.
  116. auto param_nonarray_type_id = param_type->id();
  117. while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) {
  118. param_nonarray_type_id =
  119. _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
  120. }
  121. if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) {
  122. auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
  123. if (param_nonarray_type->GetOperandAs<uint32_t>(1u) ==
  124. SpvStorageClassPhysicalStorageBufferEXT) {
  125. // check for Aliased or Restrict
  126. const auto& decorations = _.id_decorations(inst->id());
  127. bool foundAliased = std::any_of(
  128. decorations.begin(), decorations.end(), [](const Decoration& d) {
  129. return SpvDecorationAliased == d.dec_type();
  130. });
  131. bool foundRestrict = std::any_of(
  132. decorations.begin(), decorations.end(), [](const Decoration& d) {
  133. return SpvDecorationRestrict == d.dec_type();
  134. });
  135. if (!foundAliased && !foundRestrict) {
  136. return _.diag(SPV_ERROR_INVALID_ID, inst)
  137. << "OpFunctionParameter " << inst->id()
  138. << ": expected Aliased or Restrict for PhysicalStorageBufferEXT "
  139. "pointer.";
  140. }
  141. if (foundAliased && foundRestrict) {
  142. return _.diag(SPV_ERROR_INVALID_ID, inst)
  143. << "OpFunctionParameter " << inst->id()
  144. << ": can't specify both Aliased and Restrict for "
  145. "PhysicalStorageBufferEXT pointer.";
  146. }
  147. } else {
  148. const auto pointee_type_id =
  149. param_nonarray_type->GetOperandAs<uint32_t>(2);
  150. const auto pointee_type = _.FindDef(pointee_type_id);
  151. if (SpvOpTypePointer == pointee_type->opcode() &&
  152. pointee_type->GetOperandAs<uint32_t>(1u) ==
  153. SpvStorageClassPhysicalStorageBufferEXT) {
  154. // check for AliasedPointerEXT/RestrictPointerEXT
  155. const auto& decorations = _.id_decorations(inst->id());
  156. bool foundAliased = std::any_of(
  157. decorations.begin(), decorations.end(), [](const Decoration& d) {
  158. return SpvDecorationAliasedPointerEXT == d.dec_type();
  159. });
  160. bool foundRestrict = std::any_of(
  161. decorations.begin(), decorations.end(), [](const Decoration& d) {
  162. return SpvDecorationRestrictPointerEXT == d.dec_type();
  163. });
  164. if (!foundAliased && !foundRestrict) {
  165. return _.diag(SPV_ERROR_INVALID_ID, inst)
  166. << "OpFunctionParameter " << inst->id()
  167. << ": expected AliasedPointerEXT or RestrictPointerEXT for "
  168. "PhysicalStorageBufferEXT pointer.";
  169. }
  170. if (foundAliased && foundRestrict) {
  171. return _.diag(SPV_ERROR_INVALID_ID, inst)
  172. << "OpFunctionParameter " << inst->id()
  173. << ": can't specify both AliasedPointerEXT and "
  174. "RestrictPointerEXT for PhysicalStorageBufferEXT pointer.";
  175. }
  176. }
  177. }
  178. }
  179. return SPV_SUCCESS;
  180. }
  181. spv_result_t ValidateFunctionCall(ValidationState_t& _,
  182. const Instruction* inst) {
  183. const auto function_id = inst->GetOperandAs<uint32_t>(2);
  184. const auto function = _.FindDef(function_id);
  185. if (!function || SpvOpFunction != function->opcode()) {
  186. return _.diag(SPV_ERROR_INVALID_ID, inst)
  187. << "OpFunctionCall Function <id> '" << _.getIdName(function_id)
  188. << "' is not a function.";
  189. }
  190. auto return_type = _.FindDef(function->type_id());
  191. if (!return_type || return_type->id() != inst->type_id()) {
  192. return _.diag(SPV_ERROR_INVALID_ID, inst)
  193. << "OpFunctionCall Result Type <id> '"
  194. << _.getIdName(inst->type_id())
  195. << "'s type does not match Function <id> '"
  196. << _.getIdName(return_type->id()) << "'s return type.";
  197. }
  198. const auto function_type_id = function->GetOperandAs<uint32_t>(3);
  199. const auto function_type = _.FindDef(function_type_id);
  200. if (!function_type || function_type->opcode() != SpvOpTypeFunction) {
  201. return _.diag(SPV_ERROR_INVALID_ID, inst)
  202. << "Missing function type definition.";
  203. }
  204. const auto function_call_arg_count = inst->words().size() - 4;
  205. const auto function_param_count = function_type->words().size() - 3;
  206. if (function_param_count != function_call_arg_count) {
  207. return _.diag(SPV_ERROR_INVALID_ID, inst)
  208. << "OpFunctionCall Function <id>'s parameter count does not match "
  209. "the argument count.";
  210. }
  211. for (size_t argument_index = 3, param_index = 2;
  212. argument_index < inst->operands().size();
  213. argument_index++, param_index++) {
  214. const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
  215. const auto argument = _.FindDef(argument_id);
  216. if (!argument) {
  217. return _.diag(SPV_ERROR_INVALID_ID, inst)
  218. << "Missing argument " << argument_index - 3 << " definition.";
  219. }
  220. const auto argument_type = _.FindDef(argument->type_id());
  221. if (!argument_type) {
  222. return _.diag(SPV_ERROR_INVALID_ID, inst)
  223. << "Missing argument " << argument_index - 3
  224. << " type definition.";
  225. }
  226. const auto parameter_type_id =
  227. function_type->GetOperandAs<uint32_t>(param_index);
  228. const auto parameter_type = _.FindDef(parameter_type_id);
  229. if (!parameter_type ||
  230. (argument_type->id() != parameter_type->id() &&
  231. !(_.options()->relax_logical_pointer &&
  232. ArePointersToSameType(argument_type, parameter_type)))) {
  233. return _.diag(SPV_ERROR_INVALID_ID, inst)
  234. << "OpFunctionCall Argument <id> '" << _.getIdName(argument_id)
  235. << "'s type does not match Function <id> '"
  236. << _.getIdName(parameter_type_id) << "'s parameter type.";
  237. }
  238. if (_.addressing_model() == SpvAddressingModelLogical) {
  239. if (parameter_type->opcode() == SpvOpTypePointer &&
  240. !_.options()->relax_logical_pointer) {
  241. SpvStorageClass sc = parameter_type->GetOperandAs<SpvStorageClass>(1u);
  242. // Validate which storage classes can be pointer operands.
  243. switch (sc) {
  244. case SpvStorageClassUniformConstant:
  245. case SpvStorageClassFunction:
  246. case SpvStorageClassPrivate:
  247. case SpvStorageClassWorkgroup:
  248. case SpvStorageClassAtomicCounter:
  249. // These are always allowed.
  250. break;
  251. case SpvStorageClassStorageBuffer:
  252. if (!_.features().variable_pointers_storage_buffer) {
  253. return _.diag(SPV_ERROR_INVALID_ID, inst)
  254. << "StorageBuffer pointer operand "
  255. << _.getIdName(argument_id)
  256. << " requires a variable pointers capability";
  257. }
  258. break;
  259. default:
  260. return _.diag(SPV_ERROR_INVALID_ID, inst)
  261. << "Invalid storage class for pointer operand "
  262. << _.getIdName(argument_id);
  263. }
  264. // Validate memory object declaration requirements.
  265. if (argument->opcode() != SpvOpVariable &&
  266. argument->opcode() != SpvOpFunctionParameter) {
  267. const bool ssbo_vptr =
  268. _.features().variable_pointers_storage_buffer &&
  269. sc == SpvStorageClassStorageBuffer;
  270. const bool wg_vptr =
  271. _.features().variable_pointers && sc == SpvStorageClassWorkgroup;
  272. const bool uc_ptr = sc == SpvStorageClassUniformConstant;
  273. if (!ssbo_vptr && !wg_vptr && !uc_ptr) {
  274. return _.diag(SPV_ERROR_INVALID_ID, inst)
  275. << "Pointer operand " << _.getIdName(argument_id)
  276. << " must be a memory object declaration";
  277. }
  278. }
  279. }
  280. }
  281. }
  282. return SPV_SUCCESS;
  283. }
  284. } // namespace
  285. spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
  286. switch (inst->opcode()) {
  287. case SpvOpFunction:
  288. if (auto error = ValidateFunction(_, inst)) return error;
  289. break;
  290. case SpvOpFunctionParameter:
  291. if (auto error = ValidateFunctionParameter(_, inst)) return error;
  292. break;
  293. case SpvOpFunctionCall:
  294. if (auto error = ValidateFunctionCall(_, inst)) return error;
  295. break;
  296. default:
  297. break;
  298. }
  299. return SPV_SUCCESS;
  300. }
  301. } // namespace val
  302. } // namespace spvtools