validate_function.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <algorithm>
  15. #include "source/opcode.h"
  16. #include "source/val/instruction.h"
  17. #include "source/val/validate.h"
  18. #include "source/val/validation_state.h"
  19. namespace spvtools {
  20. namespace val {
  21. namespace {
  22. // Returns true if |a| and |b| are instructions defining pointers that point to
  23. // types logically match and the decorations that apply to |b| are a subset
  24. // of the decorations that apply to |a|.
  25. bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b,
  26. ValidationState_t& _) {
  27. if (a->opcode() != SpvOpTypePointer || b->opcode() != SpvOpTypePointer) {
  28. return false;
  29. }
  30. const auto& dec_a = _.id_decorations(a->id());
  31. const auto& dec_b = _.id_decorations(b->id());
  32. for (const auto& dec : dec_b) {
  33. if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
  34. return false;
  35. }
  36. }
  37. uint32_t a_type = a->GetOperandAs<uint32_t>(2);
  38. uint32_t b_type = b->GetOperandAs<uint32_t>(2);
  39. if (a_type == b_type) {
  40. return true;
  41. }
  42. Instruction* a_type_inst = _.FindDef(a_type);
  43. Instruction* b_type_inst = _.FindDef(b_type);
  44. return _.LogicallyMatch(a_type_inst, b_type_inst, true);
  45. }
  46. spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
  47. const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
  48. const auto function_type = _.FindDef(function_type_id);
  49. if (!function_type || SpvOpTypeFunction != function_type->opcode()) {
  50. return _.diag(SPV_ERROR_INVALID_ID, inst)
  51. << "OpFunction Function Type <id> '" << _.getIdName(function_type_id)
  52. << "' is not a function type.";
  53. }
  54. const auto return_id = function_type->GetOperandAs<uint32_t>(1);
  55. if (return_id != inst->type_id()) {
  56. return _.diag(SPV_ERROR_INVALID_ID, inst)
  57. << "OpFunction Result Type <id> '" << _.getIdName(inst->type_id())
  58. << "' does not match the Function Type's return type <id> '"
  59. << _.getIdName(return_id) << "'.";
  60. }
  61. const std::vector<SpvOp> acceptable = {
  62. SpvOpDecorate,
  63. SpvOpEnqueueKernel,
  64. SpvOpEntryPoint,
  65. SpvOpExecutionMode,
  66. SpvOpExecutionModeId,
  67. SpvOpFunctionCall,
  68. SpvOpGetKernelNDrangeSubGroupCount,
  69. SpvOpGetKernelNDrangeMaxSubGroupSize,
  70. SpvOpGetKernelWorkGroupSize,
  71. SpvOpGetKernelPreferredWorkGroupSizeMultiple,
  72. SpvOpGetKernelLocalSizeForSubgroupCount,
  73. SpvOpGetKernelMaxNumSubgroups,
  74. SpvOpName};
  75. for (auto& pair : inst->uses()) {
  76. const auto* use = pair.first;
  77. if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
  78. acceptable.end()) {
  79. return _.diag(SPV_ERROR_INVALID_ID, use)
  80. << "Invalid use of function result id " << _.getIdName(inst->id())
  81. << ".";
  82. }
  83. }
  84. return SPV_SUCCESS;
  85. }
  86. spv_result_t ValidateFunctionParameter(ValidationState_t& _,
  87. const Instruction* inst) {
  88. // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
  89. size_t param_index = 0;
  90. size_t inst_num = inst->LineNum() - 1;
  91. if (inst_num == 0) {
  92. return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
  93. << "Function parameter cannot be the first instruction.";
  94. }
  95. auto func_inst = &_.ordered_instructions()[inst_num];
  96. while (--inst_num) {
  97. func_inst = &_.ordered_instructions()[inst_num];
  98. if (func_inst->opcode() == SpvOpFunction) {
  99. break;
  100. } else if (func_inst->opcode() == SpvOpFunctionParameter) {
  101. ++param_index;
  102. }
  103. }
  104. if (func_inst->opcode() != SpvOpFunction) {
  105. return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
  106. << "Function parameter must be preceded by a function.";
  107. }
  108. const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
  109. const auto function_type = _.FindDef(function_type_id);
  110. if (!function_type) {
  111. return _.diag(SPV_ERROR_INVALID_ID, func_inst)
  112. << "Missing function type definition.";
  113. }
  114. if (param_index >= function_type->words().size() - 3) {
  115. return _.diag(SPV_ERROR_INVALID_ID, inst)
  116. << "Too many OpFunctionParameters for " << func_inst->id()
  117. << ": expected " << function_type->words().size() - 3
  118. << " based on the function's type";
  119. }
  120. const auto param_type =
  121. _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
  122. if (!param_type || inst->type_id() != param_type->id()) {
  123. return _.diag(SPV_ERROR_INVALID_ID, inst)
  124. << "OpFunctionParameter Result Type <id> '"
  125. << _.getIdName(inst->type_id())
  126. << "' does not match the OpTypeFunction parameter "
  127. "type of the same index.";
  128. }
  129. // Validate that PhysicalStorageBufferEXT have one of Restrict, Aliased,
  130. // RestrictPointerEXT, or AliasedPointerEXT.
  131. auto param_nonarray_type_id = param_type->id();
  132. while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) {
  133. param_nonarray_type_id =
  134. _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
  135. }
  136. if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) {
  137. auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
  138. if (param_nonarray_type->GetOperandAs<uint32_t>(1u) ==
  139. SpvStorageClassPhysicalStorageBufferEXT) {
  140. // check for Aliased or Restrict
  141. const auto& decorations = _.id_decorations(inst->id());
  142. bool foundAliased = std::any_of(
  143. decorations.begin(), decorations.end(), [](const Decoration& d) {
  144. return SpvDecorationAliased == d.dec_type();
  145. });
  146. bool foundRestrict = std::any_of(
  147. decorations.begin(), decorations.end(), [](const Decoration& d) {
  148. return SpvDecorationRestrict == d.dec_type();
  149. });
  150. if (!foundAliased && !foundRestrict) {
  151. return _.diag(SPV_ERROR_INVALID_ID, inst)
  152. << "OpFunctionParameter " << inst->id()
  153. << ": expected Aliased or Restrict for PhysicalStorageBufferEXT "
  154. "pointer.";
  155. }
  156. if (foundAliased && foundRestrict) {
  157. return _.diag(SPV_ERROR_INVALID_ID, inst)
  158. << "OpFunctionParameter " << inst->id()
  159. << ": can't specify both Aliased and Restrict for "
  160. "PhysicalStorageBufferEXT pointer.";
  161. }
  162. } else {
  163. const auto pointee_type_id =
  164. param_nonarray_type->GetOperandAs<uint32_t>(2);
  165. const auto pointee_type = _.FindDef(pointee_type_id);
  166. if (SpvOpTypePointer == pointee_type->opcode() &&
  167. pointee_type->GetOperandAs<uint32_t>(1u) ==
  168. SpvStorageClassPhysicalStorageBufferEXT) {
  169. // check for AliasedPointerEXT/RestrictPointerEXT
  170. const auto& decorations = _.id_decorations(inst->id());
  171. bool foundAliased = std::any_of(
  172. decorations.begin(), decorations.end(), [](const Decoration& d) {
  173. return SpvDecorationAliasedPointerEXT == d.dec_type();
  174. });
  175. bool foundRestrict = std::any_of(
  176. decorations.begin(), decorations.end(), [](const Decoration& d) {
  177. return SpvDecorationRestrictPointerEXT == d.dec_type();
  178. });
  179. if (!foundAliased && !foundRestrict) {
  180. return _.diag(SPV_ERROR_INVALID_ID, inst)
  181. << "OpFunctionParameter " << inst->id()
  182. << ": expected AliasedPointerEXT or RestrictPointerEXT for "
  183. "PhysicalStorageBufferEXT pointer.";
  184. }
  185. if (foundAliased && foundRestrict) {
  186. return _.diag(SPV_ERROR_INVALID_ID, inst)
  187. << "OpFunctionParameter " << inst->id()
  188. << ": can't specify both AliasedPointerEXT and "
  189. "RestrictPointerEXT for PhysicalStorageBufferEXT pointer.";
  190. }
  191. }
  192. }
  193. }
  194. return SPV_SUCCESS;
  195. }
  196. spv_result_t ValidateFunctionCall(ValidationState_t& _,
  197. const Instruction* inst) {
  198. const auto function_id = inst->GetOperandAs<uint32_t>(2);
  199. const auto function = _.FindDef(function_id);
  200. if (!function || SpvOpFunction != function->opcode()) {
  201. return _.diag(SPV_ERROR_INVALID_ID, inst)
  202. << "OpFunctionCall Function <id> '" << _.getIdName(function_id)
  203. << "' is not a function.";
  204. }
  205. auto return_type = _.FindDef(function->type_id());
  206. if (!return_type || return_type->id() != inst->type_id()) {
  207. return _.diag(SPV_ERROR_INVALID_ID, inst)
  208. << "OpFunctionCall Result Type <id> '"
  209. << _.getIdName(inst->type_id())
  210. << "'s type does not match Function <id> '"
  211. << _.getIdName(return_type->id()) << "'s return type.";
  212. }
  213. const auto function_type_id = function->GetOperandAs<uint32_t>(3);
  214. const auto function_type = _.FindDef(function_type_id);
  215. if (!function_type || function_type->opcode() != SpvOpTypeFunction) {
  216. return _.diag(SPV_ERROR_INVALID_ID, inst)
  217. << "Missing function type definition.";
  218. }
  219. const auto function_call_arg_count = inst->words().size() - 4;
  220. const auto function_param_count = function_type->words().size() - 3;
  221. if (function_param_count != function_call_arg_count) {
  222. return _.diag(SPV_ERROR_INVALID_ID, inst)
  223. << "OpFunctionCall Function <id>'s parameter count does not match "
  224. "the argument count.";
  225. }
  226. for (size_t argument_index = 3, param_index = 2;
  227. argument_index < inst->operands().size();
  228. argument_index++, param_index++) {
  229. const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
  230. const auto argument = _.FindDef(argument_id);
  231. if (!argument) {
  232. return _.diag(SPV_ERROR_INVALID_ID, inst)
  233. << "Missing argument " << argument_index - 3 << " definition.";
  234. }
  235. const auto argument_type = _.FindDef(argument->type_id());
  236. if (!argument_type) {
  237. return _.diag(SPV_ERROR_INVALID_ID, inst)
  238. << "Missing argument " << argument_index - 3
  239. << " type definition.";
  240. }
  241. const auto parameter_type_id =
  242. function_type->GetOperandAs<uint32_t>(param_index);
  243. const auto parameter_type = _.FindDef(parameter_type_id);
  244. if (!parameter_type || argument_type->id() != parameter_type->id()) {
  245. if (!_.options()->before_hlsl_legalization ||
  246. !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) {
  247. return _.diag(SPV_ERROR_INVALID_ID, inst)
  248. << "OpFunctionCall Argument <id> '" << _.getIdName(argument_id)
  249. << "'s type does not match Function <id> '"
  250. << _.getIdName(parameter_type_id) << "'s parameter type.";
  251. }
  252. }
  253. if (_.addressing_model() == SpvAddressingModelLogical) {
  254. if (parameter_type->opcode() == SpvOpTypePointer &&
  255. !_.options()->relax_logical_pointer) {
  256. SpvStorageClass sc = parameter_type->GetOperandAs<SpvStorageClass>(1u);
  257. // Validate which storage classes can be pointer operands.
  258. switch (sc) {
  259. case SpvStorageClassUniformConstant:
  260. case SpvStorageClassFunction:
  261. case SpvStorageClassPrivate:
  262. case SpvStorageClassWorkgroup:
  263. case SpvStorageClassAtomicCounter:
  264. // These are always allowed.
  265. break;
  266. case SpvStorageClassStorageBuffer:
  267. if (!_.features().variable_pointers_storage_buffer) {
  268. return _.diag(SPV_ERROR_INVALID_ID, inst)
  269. << "StorageBuffer pointer operand "
  270. << _.getIdName(argument_id)
  271. << " requires a variable pointers capability";
  272. }
  273. break;
  274. default:
  275. return _.diag(SPV_ERROR_INVALID_ID, inst)
  276. << "Invalid storage class for pointer operand "
  277. << _.getIdName(argument_id);
  278. }
  279. // Validate memory object declaration requirements.
  280. if (argument->opcode() != SpvOpVariable &&
  281. argument->opcode() != SpvOpFunctionParameter) {
  282. const bool ssbo_vptr =
  283. _.features().variable_pointers_storage_buffer &&
  284. sc == SpvStorageClassStorageBuffer;
  285. const bool wg_vptr =
  286. _.features().variable_pointers && sc == SpvStorageClassWorkgroup;
  287. const bool uc_ptr = sc == SpvStorageClassUniformConstant;
  288. if (!ssbo_vptr && !wg_vptr && !uc_ptr) {
  289. return _.diag(SPV_ERROR_INVALID_ID, inst)
  290. << "Pointer operand " << _.getIdName(argument_id)
  291. << " must be a memory object declaration";
  292. }
  293. }
  294. }
  295. }
  296. }
  297. return SPV_SUCCESS;
  298. }
  299. } // namespace
  300. spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
  301. switch (inst->opcode()) {
  302. case SpvOpFunction:
  303. if (auto error = ValidateFunction(_, inst)) return error;
  304. break;
  305. case SpvOpFunctionParameter:
  306. if (auto error = ValidateFunctionParameter(_, inst)) return error;
  307. break;
  308. case SpvOpFunctionCall:
  309. if (auto error = ValidateFunctionCall(_, inst)) return error;
  310. break;
  311. default:
  312. break;
  313. }
  314. return SPV_SUCCESS;
  315. }
  316. } // namespace val
  317. } // namespace spvtools