validate_function.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/val/validate.h"
  15. #include <algorithm>
  16. #include "source/opcode.h"
  17. #include "source/val/instruction.h"
  18. #include "source/val/validation_state.h"
  19. namespace spvtools {
  20. namespace val {
  21. namespace {
  22. spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
  23. const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
  24. const auto function_type = _.FindDef(function_type_id);
  25. if (!function_type || SpvOpTypeFunction != function_type->opcode()) {
  26. return _.diag(SPV_ERROR_INVALID_ID, inst)
  27. << "OpFunction Function Type <id> '" << _.getIdName(function_type_id)
  28. << "' is not a function type.";
  29. }
  30. const auto return_id = function_type->GetOperandAs<uint32_t>(1);
  31. if (return_id != inst->type_id()) {
  32. return _.diag(SPV_ERROR_INVALID_ID, inst)
  33. << "OpFunction Result Type <id> '" << _.getIdName(inst->type_id())
  34. << "' does not match the Function Type's return type <id> '"
  35. << _.getIdName(return_id) << "'.";
  36. }
  37. for (auto& pair : inst->uses()) {
  38. const auto* use = pair.first;
  39. const std::vector<SpvOp> acceptable = {
  40. SpvOpFunctionCall,
  41. SpvOpEntryPoint,
  42. SpvOpEnqueueKernel,
  43. SpvOpGetKernelNDrangeSubGroupCount,
  44. SpvOpGetKernelNDrangeMaxSubGroupSize,
  45. SpvOpGetKernelWorkGroupSize,
  46. SpvOpGetKernelPreferredWorkGroupSizeMultiple,
  47. SpvOpGetKernelLocalSizeForSubgroupCount,
  48. SpvOpGetKernelMaxNumSubgroups};
  49. if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
  50. acceptable.end()) {
  51. return _.diag(SPV_ERROR_INVALID_ID, use)
  52. << "Invalid use of function result id " << _.getIdName(inst->id())
  53. << ".";
  54. }
  55. }
  56. return SPV_SUCCESS;
  57. }
  58. spv_result_t ValidateFunctionParameter(ValidationState_t& _,
  59. const Instruction* inst) {
  60. // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
  61. size_t param_index = 0;
  62. size_t inst_num = inst->LineNum() - 1;
  63. if (inst_num == 0) {
  64. return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
  65. << "Function parameter cannot be the first instruction.";
  66. }
  67. auto func_inst = &_.ordered_instructions()[inst_num];
  68. while (--inst_num) {
  69. func_inst = &_.ordered_instructions()[inst_num];
  70. if (func_inst->opcode() == SpvOpFunction) {
  71. break;
  72. } else if (func_inst->opcode() == SpvOpFunctionParameter) {
  73. ++param_index;
  74. }
  75. }
  76. if (func_inst->opcode() != SpvOpFunction) {
  77. return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
  78. << "Function parameter must be preceded by a function.";
  79. }
  80. const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
  81. const auto function_type = _.FindDef(function_type_id);
  82. if (!function_type) {
  83. return _.diag(SPV_ERROR_INVALID_ID, func_inst)
  84. << "Missing function type definition.";
  85. }
  86. if (param_index >= function_type->words().size() - 3) {
  87. return _.diag(SPV_ERROR_INVALID_ID, inst)
  88. << "Too many OpFunctionParameters for " << func_inst->id()
  89. << ": expected " << function_type->words().size() - 3
  90. << " based on the function's type";
  91. }
  92. const auto param_type =
  93. _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
  94. if (!param_type || inst->type_id() != param_type->id()) {
  95. return _.diag(SPV_ERROR_INVALID_ID, inst)
  96. << "OpFunctionParameter Result Type <id> '"
  97. << _.getIdName(inst->type_id())
  98. << "' does not match the OpTypeFunction parameter "
  99. "type of the same index.";
  100. }
  101. // Validate that PhysicalStorageBufferEXT have one of Restrict, Aliased,
  102. // RestrictPointerEXT, or AliasedPointerEXT.
  103. auto param_nonarray_type_id = param_type->id();
  104. while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) {
  105. param_nonarray_type_id =
  106. _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
  107. }
  108. if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) {
  109. auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
  110. if (param_nonarray_type->GetOperandAs<uint32_t>(1u) ==
  111. SpvStorageClassPhysicalStorageBufferEXT) {
  112. // check for Aliased or Restrict
  113. const auto& decorations = _.id_decorations(inst->id());
  114. bool foundAliased = std::any_of(
  115. decorations.begin(), decorations.end(), [](const Decoration& d) {
  116. return SpvDecorationAliased == d.dec_type();
  117. });
  118. bool foundRestrict = std::any_of(
  119. decorations.begin(), decorations.end(), [](const Decoration& d) {
  120. return SpvDecorationRestrict == d.dec_type();
  121. });
  122. if (!foundAliased && !foundRestrict) {
  123. return _.diag(SPV_ERROR_INVALID_ID, inst)
  124. << "OpFunctionParameter " << inst->id()
  125. << ": expected Aliased or Restrict for PhysicalStorageBufferEXT "
  126. "pointer.";
  127. }
  128. if (foundAliased && foundRestrict) {
  129. return _.diag(SPV_ERROR_INVALID_ID, inst)
  130. << "OpFunctionParameter " << inst->id()
  131. << ": can't specify both Aliased and Restrict for "
  132. "PhysicalStorageBufferEXT pointer.";
  133. }
  134. } else {
  135. const auto pointee_type_id =
  136. param_nonarray_type->GetOperandAs<uint32_t>(2);
  137. const auto pointee_type = _.FindDef(pointee_type_id);
  138. if (SpvOpTypePointer == pointee_type->opcode() &&
  139. pointee_type->GetOperandAs<uint32_t>(1u) ==
  140. SpvStorageClassPhysicalStorageBufferEXT) {
  141. // check for AliasedPointerEXT/RestrictPointerEXT
  142. const auto& decorations = _.id_decorations(inst->id());
  143. bool foundAliased = std::any_of(
  144. decorations.begin(), decorations.end(), [](const Decoration& d) {
  145. return SpvDecorationAliasedPointerEXT == d.dec_type();
  146. });
  147. bool foundRestrict = std::any_of(
  148. decorations.begin(), decorations.end(), [](const Decoration& d) {
  149. return SpvDecorationRestrictPointerEXT == d.dec_type();
  150. });
  151. if (!foundAliased && !foundRestrict) {
  152. return _.diag(SPV_ERROR_INVALID_ID, inst)
  153. << "OpFunctionParameter " << inst->id()
  154. << ": expected AliasedPointerEXT or RestrictPointerEXT for "
  155. "PhysicalStorageBufferEXT pointer.";
  156. }
  157. if (foundAliased && foundRestrict) {
  158. return _.diag(SPV_ERROR_INVALID_ID, inst)
  159. << "OpFunctionParameter " << inst->id()
  160. << ": can't specify both AliasedPointerEXT and "
  161. "RestrictPointerEXT for PhysicalStorageBufferEXT pointer.";
  162. }
  163. }
  164. }
  165. }
  166. return SPV_SUCCESS;
  167. }
  168. spv_result_t ValidateFunctionCall(ValidationState_t& _,
  169. const Instruction* inst) {
  170. const auto function_id = inst->GetOperandAs<uint32_t>(2);
  171. const auto function = _.FindDef(function_id);
  172. if (!function || SpvOpFunction != function->opcode()) {
  173. return _.diag(SPV_ERROR_INVALID_ID, inst)
  174. << "OpFunctionCall Function <id> '" << _.getIdName(function_id)
  175. << "' is not a function.";
  176. }
  177. auto return_type = _.FindDef(function->type_id());
  178. if (!return_type || return_type->id() != inst->type_id()) {
  179. return _.diag(SPV_ERROR_INVALID_ID, inst)
  180. << "OpFunctionCall Result Type <id> '"
  181. << _.getIdName(inst->type_id())
  182. << "'s type does not match Function <id> '"
  183. << _.getIdName(return_type->id()) << "'s return type.";
  184. }
  185. const auto function_type_id = function->GetOperandAs<uint32_t>(3);
  186. const auto function_type = _.FindDef(function_type_id);
  187. if (!function_type || function_type->opcode() != SpvOpTypeFunction) {
  188. return _.diag(SPV_ERROR_INVALID_ID, inst)
  189. << "Missing function type definition.";
  190. }
  191. const auto function_call_arg_count = inst->words().size() - 4;
  192. const auto function_param_count = function_type->words().size() - 3;
  193. if (function_param_count != function_call_arg_count) {
  194. return _.diag(SPV_ERROR_INVALID_ID, inst)
  195. << "OpFunctionCall Function <id>'s parameter count does not match "
  196. "the argument count.";
  197. }
  198. for (size_t argument_index = 3, param_index = 2;
  199. argument_index < inst->operands().size();
  200. argument_index++, param_index++) {
  201. const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
  202. const auto argument = _.FindDef(argument_id);
  203. if (!argument) {
  204. return _.diag(SPV_ERROR_INVALID_ID, inst)
  205. << "Missing argument " << argument_index - 3 << " definition.";
  206. }
  207. const auto argument_type = _.FindDef(argument->type_id());
  208. if (!argument_type) {
  209. return _.diag(SPV_ERROR_INVALID_ID, inst)
  210. << "Missing argument " << argument_index - 3
  211. << " type definition.";
  212. }
  213. const auto parameter_type_id =
  214. function_type->GetOperandAs<uint32_t>(param_index);
  215. const auto parameter_type = _.FindDef(parameter_type_id);
  216. if (!parameter_type || argument_type->id() != parameter_type->id()) {
  217. return _.diag(SPV_ERROR_INVALID_ID, inst)
  218. << "OpFunctionCall Argument <id> '" << _.getIdName(argument_id)
  219. << "'s type does not match Function <id> '"
  220. << _.getIdName(parameter_type_id) << "'s parameter type.";
  221. }
  222. if (_.addressing_model() == SpvAddressingModelLogical) {
  223. if (parameter_type->opcode() == SpvOpTypePointer) {
  224. SpvStorageClass sc = parameter_type->GetOperandAs<SpvStorageClass>(1u);
  225. // Validate which storage classes can be pointer operands.
  226. switch (sc) {
  227. case SpvStorageClassUniformConstant:
  228. case SpvStorageClassFunction:
  229. case SpvStorageClassPrivate:
  230. case SpvStorageClassWorkgroup:
  231. case SpvStorageClassAtomicCounter:
  232. // These are always allowed.
  233. break;
  234. case SpvStorageClassStorageBuffer:
  235. if (!_.features().variable_pointers_storage_buffer) {
  236. return _.diag(SPV_ERROR_INVALID_ID, inst)
  237. << "StorageBuffer pointer operand "
  238. << _.getIdName(argument_id)
  239. << " requires a variable pointers capability";
  240. }
  241. break;
  242. default:
  243. return _.diag(SPV_ERROR_INVALID_ID, inst)
  244. << "Invalid storage class for pointer operand "
  245. << _.getIdName(argument_id);
  246. }
  247. // Validate memory object declaration requirements.
  248. if (argument->opcode() != SpvOpVariable &&
  249. argument->opcode() != SpvOpFunctionParameter) {
  250. const bool ssbo_vptr =
  251. _.features().variable_pointers_storage_buffer &&
  252. sc == SpvStorageClassStorageBuffer;
  253. const bool wg_vptr =
  254. _.features().variable_pointers && sc == SpvStorageClassWorkgroup;
  255. if (!ssbo_vptr && !wg_vptr) {
  256. return _.diag(SPV_ERROR_INVALID_ID, inst)
  257. << "Pointer operand " << _.getIdName(argument_id)
  258. << " must be a memory object declaration";
  259. }
  260. }
  261. }
  262. }
  263. }
  264. return SPV_SUCCESS;
  265. }
  266. } // namespace
  267. spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
  268. switch (inst->opcode()) {
  269. case SpvOpFunction:
  270. if (auto error = ValidateFunction(_, inst)) return error;
  271. break;
  272. case SpvOpFunctionParameter:
  273. if (auto error = ValidateFunctionParameter(_, inst)) return error;
  274. break;
  275. case SpvOpFunctionCall:
  276. if (auto error = ValidateFunctionCall(_, inst)) return error;
  277. break;
  278. default:
  279. break;
  280. }
  281. return SPV_SUCCESS;
  282. }
  283. } // namespace val
  284. } // namespace spvtools