validate_function.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <algorithm>
  15. #include "source/opcode.h"
  16. #include "source/val/instruction.h"
  17. #include "source/val/validate.h"
  18. #include "source/val/validation_state.h"
  19. namespace spvtools {
  20. namespace val {
  21. namespace {
  22. // Returns true if |a| and |b| are instructions defining pointers that point to
  23. // types logically match and the decorations that apply to |b| are a subset
  24. // of the decorations that apply to |a|.
  25. bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b,
  26. ValidationState_t& _) {
  27. if (a->opcode() != spv::Op::OpTypePointer ||
  28. b->opcode() != spv::Op::OpTypePointer) {
  29. return false;
  30. }
  31. const auto& dec_a = _.id_decorations(a->id());
  32. const auto& dec_b = _.id_decorations(b->id());
  33. for (const auto& dec : dec_b) {
  34. if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
  35. return false;
  36. }
  37. }
  38. uint32_t a_type = a->GetOperandAs<uint32_t>(2);
  39. uint32_t b_type = b->GetOperandAs<uint32_t>(2);
  40. if (a_type == b_type) {
  41. return true;
  42. }
  43. Instruction* a_type_inst = _.FindDef(a_type);
  44. Instruction* b_type_inst = _.FindDef(b_type);
  45. return _.LogicallyMatch(a_type_inst, b_type_inst, true);
  46. }
  47. spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
  48. const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
  49. const auto function_type = _.FindDef(function_type_id);
  50. if (!function_type || spv::Op::OpTypeFunction != function_type->opcode()) {
  51. return _.diag(SPV_ERROR_INVALID_ID, inst)
  52. << "OpFunction Function Type <id> " << _.getIdName(function_type_id)
  53. << " is not a function type.";
  54. }
  55. const auto return_id = function_type->GetOperandAs<uint32_t>(1);
  56. if (return_id != inst->type_id()) {
  57. return _.diag(SPV_ERROR_INVALID_ID, inst)
  58. << "OpFunction Result Type <id> " << _.getIdName(inst->type_id())
  59. << " does not match the Function Type's return type <id> "
  60. << _.getIdName(return_id) << ".";
  61. }
  62. const std::vector<spv::Op> acceptable = {
  63. spv::Op::OpGroupDecorate,
  64. spv::Op::OpDecorate,
  65. spv::Op::OpEnqueueKernel,
  66. spv::Op::OpEntryPoint,
  67. spv::Op::OpExecutionMode,
  68. spv::Op::OpExecutionModeId,
  69. spv::Op::OpFunctionCall,
  70. spv::Op::OpGetKernelNDrangeSubGroupCount,
  71. spv::Op::OpGetKernelNDrangeMaxSubGroupSize,
  72. spv::Op::OpGetKernelWorkGroupSize,
  73. spv::Op::OpGetKernelPreferredWorkGroupSizeMultiple,
  74. spv::Op::OpGetKernelLocalSizeForSubgroupCount,
  75. spv::Op::OpGetKernelMaxNumSubgroups,
  76. spv::Op::OpName};
  77. for (auto& pair : inst->uses()) {
  78. const auto* use = pair.first;
  79. if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
  80. acceptable.end() &&
  81. !use->IsNonSemantic() && !use->IsDebugInfo()) {
  82. return _.diag(SPV_ERROR_INVALID_ID, use)
  83. << "Invalid use of function result id " << _.getIdName(inst->id())
  84. << ".";
  85. }
  86. }
  87. return SPV_SUCCESS;
  88. }
  89. spv_result_t ValidateFunctionParameter(ValidationState_t& _,
  90. const Instruction* inst) {
  91. // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
  92. size_t param_index = 0;
  93. size_t inst_num = inst->LineNum() - 1;
  94. if (inst_num == 0) {
  95. return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
  96. << "Function parameter cannot be the first instruction.";
  97. }
  98. auto func_inst = &_.ordered_instructions()[inst_num];
  99. while (--inst_num) {
  100. func_inst = &_.ordered_instructions()[inst_num];
  101. if (func_inst->opcode() == spv::Op::OpFunction) {
  102. break;
  103. } else if (func_inst->opcode() == spv::Op::OpFunctionParameter) {
  104. ++param_index;
  105. }
  106. }
  107. if (func_inst->opcode() != spv::Op::OpFunction) {
  108. return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
  109. << "Function parameter must be preceded by a function.";
  110. }
  111. const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
  112. const auto function_type = _.FindDef(function_type_id);
  113. if (!function_type) {
  114. return _.diag(SPV_ERROR_INVALID_ID, func_inst)
  115. << "Missing function type definition.";
  116. }
  117. if (param_index >= function_type->words().size() - 3) {
  118. return _.diag(SPV_ERROR_INVALID_ID, inst)
  119. << "Too many OpFunctionParameters for " << func_inst->id()
  120. << ": expected " << function_type->words().size() - 3
  121. << " based on the function's type";
  122. }
  123. const auto param_type =
  124. _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
  125. if (!param_type || inst->type_id() != param_type->id()) {
  126. return _.diag(SPV_ERROR_INVALID_ID, inst)
  127. << "OpFunctionParameter Result Type <id> "
  128. << _.getIdName(inst->type_id())
  129. << " does not match the OpTypeFunction parameter "
  130. "type of the same index.";
  131. }
  132. // Validate that PhysicalStorageBuffer have one of Restrict, Aliased,
  133. // RestrictPointer, or AliasedPointer.
  134. auto param_nonarray_type_id = param_type->id();
  135. while (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypeArray) {
  136. param_nonarray_type_id =
  137. _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
  138. }
  139. if (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypePointer) {
  140. auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
  141. if (param_nonarray_type->GetOperandAs<spv::StorageClass>(1u) ==
  142. spv::StorageClass::PhysicalStorageBuffer) {
  143. // check for Aliased or Restrict
  144. const auto& decorations = _.id_decorations(inst->id());
  145. bool foundAliased = std::any_of(
  146. decorations.begin(), decorations.end(), [](const Decoration& d) {
  147. return spv::Decoration::Aliased == d.dec_type();
  148. });
  149. bool foundRestrict = std::any_of(
  150. decorations.begin(), decorations.end(), [](const Decoration& d) {
  151. return spv::Decoration::Restrict == d.dec_type();
  152. });
  153. if (!foundAliased && !foundRestrict) {
  154. return _.diag(SPV_ERROR_INVALID_ID, inst)
  155. << "OpFunctionParameter " << inst->id()
  156. << ": expected Aliased or Restrict for PhysicalStorageBuffer "
  157. "pointer.";
  158. }
  159. if (foundAliased && foundRestrict) {
  160. return _.diag(SPV_ERROR_INVALID_ID, inst)
  161. << "OpFunctionParameter " << inst->id()
  162. << ": can't specify both Aliased and Restrict for "
  163. "PhysicalStorageBuffer pointer.";
  164. }
  165. } else {
  166. const auto pointee_type_id =
  167. param_nonarray_type->GetOperandAs<uint32_t>(2);
  168. const auto pointee_type = _.FindDef(pointee_type_id);
  169. if (spv::Op::OpTypePointer == pointee_type->opcode() &&
  170. pointee_type->GetOperandAs<spv::StorageClass>(1u) ==
  171. spv::StorageClass::PhysicalStorageBuffer) {
  172. // check for AliasedPointer/RestrictPointer
  173. const auto& decorations = _.id_decorations(inst->id());
  174. bool foundAliased = std::any_of(
  175. decorations.begin(), decorations.end(), [](const Decoration& d) {
  176. return spv::Decoration::AliasedPointer == d.dec_type();
  177. });
  178. bool foundRestrict = std::any_of(
  179. decorations.begin(), decorations.end(), [](const Decoration& d) {
  180. return spv::Decoration::RestrictPointer == d.dec_type();
  181. });
  182. if (!foundAliased && !foundRestrict) {
  183. return _.diag(SPV_ERROR_INVALID_ID, inst)
  184. << "OpFunctionParameter " << inst->id()
  185. << ": expected AliasedPointer or RestrictPointer for "
  186. "PhysicalStorageBuffer pointer.";
  187. }
  188. if (foundAliased && foundRestrict) {
  189. return _.diag(SPV_ERROR_INVALID_ID, inst)
  190. << "OpFunctionParameter " << inst->id()
  191. << ": can't specify both AliasedPointer and "
  192. "RestrictPointer for PhysicalStorageBuffer pointer.";
  193. }
  194. }
  195. }
  196. }
  197. return SPV_SUCCESS;
  198. }
  199. spv_result_t ValidateFunctionCall(ValidationState_t& _,
  200. const Instruction* inst) {
  201. const auto function_id = inst->GetOperandAs<uint32_t>(2);
  202. const auto function = _.FindDef(function_id);
  203. if (!function || spv::Op::OpFunction != function->opcode()) {
  204. return _.diag(SPV_ERROR_INVALID_ID, inst)
  205. << "OpFunctionCall Function <id> " << _.getIdName(function_id)
  206. << " is not a function.";
  207. }
  208. auto return_type = _.FindDef(function->type_id());
  209. if (!return_type || return_type->id() != inst->type_id()) {
  210. return _.diag(SPV_ERROR_INVALID_ID, inst)
  211. << "OpFunctionCall Result Type <id> " << _.getIdName(inst->type_id())
  212. << "s type does not match Function <id> "
  213. << _.getIdName(return_type->id()) << "s return type.";
  214. }
  215. const auto function_type_id = function->GetOperandAs<uint32_t>(3);
  216. const auto function_type = _.FindDef(function_type_id);
  217. if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) {
  218. return _.diag(SPV_ERROR_INVALID_ID, inst)
  219. << "Missing function type definition.";
  220. }
  221. const auto function_call_arg_count = inst->words().size() - 4;
  222. const auto function_param_count = function_type->words().size() - 3;
  223. if (function_param_count != function_call_arg_count) {
  224. return _.diag(SPV_ERROR_INVALID_ID, inst)
  225. << "OpFunctionCall Function <id>'s parameter count does not match "
  226. "the argument count.";
  227. }
  228. for (size_t argument_index = 3, param_index = 2;
  229. argument_index < inst->operands().size();
  230. argument_index++, param_index++) {
  231. const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
  232. const auto argument = _.FindDef(argument_id);
  233. if (!argument) {
  234. return _.diag(SPV_ERROR_INVALID_ID, inst)
  235. << "Missing argument " << argument_index - 3 << " definition.";
  236. }
  237. const auto argument_type = _.FindDef(argument->type_id());
  238. if (!argument_type) {
  239. return _.diag(SPV_ERROR_INVALID_ID, inst)
  240. << "Missing argument " << argument_index - 3
  241. << " type definition.";
  242. }
  243. const auto parameter_type_id =
  244. function_type->GetOperandAs<uint32_t>(param_index);
  245. const auto parameter_type = _.FindDef(parameter_type_id);
  246. if (!parameter_type || argument_type->id() != parameter_type->id()) {
  247. if (!_.options()->before_hlsl_legalization ||
  248. !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) {
  249. return _.diag(SPV_ERROR_INVALID_ID, inst)
  250. << "OpFunctionCall Argument <id> " << _.getIdName(argument_id)
  251. << "s type does not match Function <id> "
  252. << _.getIdName(parameter_type_id) << "s parameter type.";
  253. }
  254. }
  255. if (_.addressing_model() == spv::AddressingModel::Logical) {
  256. if (parameter_type->opcode() == spv::Op::OpTypePointer &&
  257. !_.options()->relax_logical_pointer) {
  258. spv::StorageClass sc =
  259. parameter_type->GetOperandAs<spv::StorageClass>(1u);
  260. // Validate which storage classes can be pointer operands.
  261. switch (sc) {
  262. case spv::StorageClass::UniformConstant:
  263. case spv::StorageClass::Function:
  264. case spv::StorageClass::Private:
  265. case spv::StorageClass::Workgroup:
  266. case spv::StorageClass::AtomicCounter:
  267. // These are always allowed.
  268. break;
  269. case spv::StorageClass::StorageBuffer:
  270. if (!_.features().variable_pointers) {
  271. return _.diag(SPV_ERROR_INVALID_ID, inst)
  272. << "StorageBuffer pointer operand "
  273. << _.getIdName(argument_id)
  274. << " requires a variable pointers capability";
  275. }
  276. break;
  277. default:
  278. return _.diag(SPV_ERROR_INVALID_ID, inst)
  279. << "Invalid storage class for pointer operand "
  280. << _.getIdName(argument_id);
  281. }
  282. // Validate memory object declaration requirements.
  283. if (argument->opcode() != spv::Op::OpVariable &&
  284. argument->opcode() != spv::Op::OpFunctionParameter) {
  285. const bool ssbo_vptr = _.features().variable_pointers &&
  286. sc == spv::StorageClass::StorageBuffer;
  287. const bool wg_vptr =
  288. _.HasCapability(spv::Capability::VariablePointers) &&
  289. sc == spv::StorageClass::Workgroup;
  290. const bool uc_ptr = sc == spv::StorageClass::UniformConstant;
  291. if (!ssbo_vptr && !wg_vptr && !uc_ptr) {
  292. return _.diag(SPV_ERROR_INVALID_ID, inst)
  293. << "Pointer operand " << _.getIdName(argument_id)
  294. << " must be a memory object declaration";
  295. }
  296. }
  297. }
  298. }
  299. }
  300. return SPV_SUCCESS;
  301. }
  302. } // namespace
  303. spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
  304. switch (inst->opcode()) {
  305. case spv::Op::OpFunction:
  306. if (auto error = ValidateFunction(_, inst)) return error;
  307. break;
  308. case spv::Op::OpFunctionParameter:
  309. if (auto error = ValidateFunctionParameter(_, inst)) return error;
  310. break;
  311. case spv::Op::OpFunctionCall:
  312. if (auto error = ValidateFunctionCall(_, inst)) return error;
  313. break;
  314. default:
  315. break;
  316. }
  317. return SPV_SUCCESS;
  318. }
  319. } // namespace val
  320. } // namespace spvtools