validate_interfaces.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <algorithm>
  15. #include <vector>
  16. #include "source/diagnostic.h"
  17. #include "source/spirv_constant.h"
  18. #include "source/spirv_target_env.h"
  19. #include "source/val/function.h"
  20. #include "source/val/instruction.h"
  21. #include "source/val/validate.h"
  22. #include "source/val/validation_state.h"
  23. namespace spvtools {
  24. namespace val {
  25. namespace {
  26. // Limit the number of checked locations to 4096. Multiplied by 4 to represent
  27. // all the components. This limit is set to be well beyond practical use cases.
  28. const uint32_t kMaxLocations = 4096 * 4;
  29. // Returns true if \c inst is an input or output variable.
  30. bool is_interface_variable(const Instruction* inst, bool is_spv_1_4) {
  31. if (is_spv_1_4) {
  32. // Starting in SPIR-V 1.4, all global variables are interface variables.
  33. return inst->opcode() == SpvOpVariable &&
  34. inst->word(3u) != SpvStorageClassFunction;
  35. } else {
  36. return inst->opcode() == SpvOpVariable &&
  37. (inst->word(3u) == SpvStorageClassInput ||
  38. inst->word(3u) == SpvStorageClassOutput);
  39. }
  40. }
  41. // Checks that \c var is listed as an interface in all the entry points that use
  42. // it.
  43. spv_result_t check_interface_variable(ValidationState_t& _,
  44. const Instruction* var) {
  45. std::vector<const Function*> functions;
  46. std::vector<const Instruction*> uses;
  47. for (auto use : var->uses()) {
  48. uses.push_back(use.first);
  49. }
  50. for (uint32_t i = 0; i < uses.size(); ++i) {
  51. const auto user = uses[i];
  52. if (const Function* func = user->function()) {
  53. functions.push_back(func);
  54. } else {
  55. // In the rare case that the variable is used by another instruction in
  56. // the global scope, continue searching for an instruction used in a
  57. // function.
  58. for (auto use : user->uses()) {
  59. uses.push_back(use.first);
  60. }
  61. }
  62. }
  63. std::sort(functions.begin(), functions.end(),
  64. [](const Function* lhs, const Function* rhs) {
  65. return lhs->id() < rhs->id();
  66. });
  67. functions.erase(std::unique(functions.begin(), functions.end()),
  68. functions.end());
  69. std::vector<uint32_t> entry_points;
  70. for (const auto func : functions) {
  71. for (auto id : _.FunctionEntryPoints(func->id())) {
  72. entry_points.push_back(id);
  73. }
  74. }
  75. std::sort(entry_points.begin(), entry_points.end());
  76. entry_points.erase(std::unique(entry_points.begin(), entry_points.end()),
  77. entry_points.end());
  78. for (auto id : entry_points) {
  79. for (const auto& desc : _.entry_point_descriptions(id)) {
  80. bool found = false;
  81. for (auto interface : desc.interfaces) {
  82. if (var->id() == interface) {
  83. found = true;
  84. break;
  85. }
  86. }
  87. if (!found) {
  88. return _.diag(SPV_ERROR_INVALID_ID, var)
  89. << "Interface variable id <" << var->id()
  90. << "> is used by entry point '" << desc.name << "' id <" << id
  91. << ">, but is not listed as an interface";
  92. }
  93. }
  94. }
  95. return SPV_SUCCESS;
  96. }
  97. // This function assumes a base location has been determined already. As such
  98. // any further location decorations are invalid.
  99. // TODO: if this code turns out to be slow, there is an opportunity to cache
  100. // the result for a given type id.
  101. spv_result_t NumConsumedLocations(ValidationState_t& _, const Instruction* type,
  102. uint32_t* num_locations) {
  103. *num_locations = 0;
  104. switch (type->opcode()) {
  105. case SpvOpTypeInt:
  106. case SpvOpTypeFloat:
  107. // Scalars always consume a single location.
  108. *num_locations = 1;
  109. break;
  110. case SpvOpTypeVector:
  111. // 3- and 4-component 64-bit vectors consume two locations.
  112. if ((_.ContainsSizedIntOrFloatType(type->id(), SpvOpTypeInt, 64) ||
  113. _.ContainsSizedIntOrFloatType(type->id(), SpvOpTypeFloat, 64)) &&
  114. (type->GetOperandAs<uint32_t>(2) > 2)) {
  115. *num_locations = 2;
  116. } else {
  117. *num_locations = 1;
  118. }
  119. break;
  120. case SpvOpTypeMatrix:
  121. // Matrices consume locations equal to the underlying vector type for
  122. // each column.
  123. NumConsumedLocations(_, _.FindDef(type->GetOperandAs<uint32_t>(1)),
  124. num_locations);
  125. *num_locations *= type->GetOperandAs<uint32_t>(2);
  126. break;
  127. case SpvOpTypeArray: {
  128. // Arrays consume locations equal to the underlying type times the number
  129. // of elements in the vector.
  130. NumConsumedLocations(_, _.FindDef(type->GetOperandAs<uint32_t>(1)),
  131. num_locations);
  132. bool is_int = false;
  133. bool is_const = false;
  134. uint32_t value = 0;
  135. // Attempt to evaluate the number of array elements.
  136. std::tie(is_int, is_const, value) =
  137. _.EvalInt32IfConst(type->GetOperandAs<uint32_t>(2));
  138. if (is_int && is_const) *num_locations *= value;
  139. break;
  140. }
  141. case SpvOpTypeStruct: {
  142. // Members cannot have location decorations at this point.
  143. if (_.HasDecoration(type->id(), SpvDecorationLocation)) {
  144. return _.diag(SPV_ERROR_INVALID_DATA, type)
  145. << "Members cannot be assigned a location";
  146. }
  147. // Structs consume locations equal to the sum of the locations consumed
  148. // by the members.
  149. for (uint32_t i = 1; i < type->operands().size(); ++i) {
  150. uint32_t member_locations = 0;
  151. if (auto error = NumConsumedLocations(
  152. _, _.FindDef(type->GetOperandAs<uint32_t>(i)),
  153. &member_locations)) {
  154. return error;
  155. }
  156. *num_locations += member_locations;
  157. }
  158. break;
  159. }
  160. default:
  161. break;
  162. }
  163. return SPV_SUCCESS;
  164. }
  165. // Returns the number of components consumed by types that support a component
  166. // decoration.
  167. uint32_t NumConsumedComponents(ValidationState_t& _, const Instruction* type) {
  168. uint32_t num_components = 0;
  169. switch (type->opcode()) {
  170. case SpvOpTypeInt:
  171. case SpvOpTypeFloat:
  172. // 64-bit types consume two components.
  173. if (type->GetOperandAs<uint32_t>(1) == 64) {
  174. num_components = 2;
  175. } else {
  176. num_components = 1;
  177. }
  178. break;
  179. case SpvOpTypeVector:
  180. // Vectors consume components equal to the underlying type's consumption
  181. // times the number of elements in the vector. Note that 3- and 4-element
  182. // vectors cannot have a component decoration (i.e. assumed to be zero).
  183. num_components =
  184. NumConsumedComponents(_, _.FindDef(type->GetOperandAs<uint32_t>(1)));
  185. num_components *= type->GetOperandAs<uint32_t>(2);
  186. break;
  187. default:
  188. // This is an error that is validated elsewhere.
  189. break;
  190. }
  191. return num_components;
  192. }
  193. // Populates |locations| (and/or |output_index1_locations|) with the use
  194. // location and component coordinates for |variable|. Indices are calculated as
  195. // 4 * location + component.
  196. spv_result_t GetLocationsForVariable(
  197. ValidationState_t& _, const Instruction* entry_point,
  198. const Instruction* variable, std::unordered_set<uint32_t>* locations,
  199. std::unordered_set<uint32_t>* output_index1_locations) {
  200. const bool is_fragment = entry_point->GetOperandAs<SpvExecutionModel>(0) ==
  201. SpvExecutionModelFragment;
  202. const bool is_output =
  203. variable->GetOperandAs<SpvStorageClass>(2) == SpvStorageClassOutput;
  204. auto ptr_type_id = variable->GetOperandAs<uint32_t>(0);
  205. auto ptr_type = _.FindDef(ptr_type_id);
  206. auto type_id = ptr_type->GetOperandAs<uint32_t>(2);
  207. auto type = _.FindDef(type_id);
  208. // Check for Location, Component and Index decorations on the variable. The
  209. // validator allows duplicate decorations if the location/component/index are
  210. // equal. Also track Patch and PerTaskNV decorations.
  211. bool has_location = false;
  212. uint32_t location = 0;
  213. bool has_component = false;
  214. uint32_t component = 0;
  215. bool has_index = false;
  216. uint32_t index = 0;
  217. bool has_patch = false;
  218. bool has_per_task_nv = false;
  219. bool has_per_vertex_nv = false;
  220. for (auto& dec : _.id_decorations(variable->id())) {
  221. if (dec.dec_type() == SpvDecorationLocation) {
  222. if (has_location && dec.params()[0] != location) {
  223. return _.diag(SPV_ERROR_INVALID_DATA, variable)
  224. << "Variable has conflicting location decorations";
  225. }
  226. has_location = true;
  227. location = dec.params()[0];
  228. } else if (dec.dec_type() == SpvDecorationComponent) {
  229. if (has_component && dec.params()[0] != component) {
  230. return _.diag(SPV_ERROR_INVALID_DATA, variable)
  231. << "Variable has conflicting component decorations";
  232. }
  233. has_component = true;
  234. component = dec.params()[0];
  235. } else if (dec.dec_type() == SpvDecorationIndex) {
  236. if (!is_output || !is_fragment) {
  237. return _.diag(SPV_ERROR_INVALID_DATA, variable)
  238. << "Index can only be applied to Fragment output variables";
  239. }
  240. if (has_index && dec.params()[0] != index) {
  241. return _.diag(SPV_ERROR_INVALID_DATA, variable)
  242. << "Variable has conflicting index decorations";
  243. }
  244. has_index = true;
  245. index = dec.params()[0];
  246. } else if (dec.dec_type() == SpvDecorationBuiltIn) {
  247. // Don't check built-ins.
  248. return SPV_SUCCESS;
  249. } else if (dec.dec_type() == SpvDecorationPatch) {
  250. has_patch = true;
  251. } else if (dec.dec_type() == SpvDecorationPerTaskNV) {
  252. has_per_task_nv = true;
  253. } else if (dec.dec_type() == SpvDecorationPerVertexNV) {
  254. has_per_vertex_nv = true;
  255. }
  256. }
  257. // Vulkan 14.1.3: Tessellation control and mesh per-vertex outputs and
  258. // tessellation control, evaluation and geometry per-vertex inputs have a
  259. // layer of arraying that is not included in interface matching.
  260. bool is_arrayed = false;
  261. switch (entry_point->GetOperandAs<SpvExecutionModel>(0)) {
  262. case SpvExecutionModelTessellationControl:
  263. if (!has_patch) {
  264. is_arrayed = true;
  265. }
  266. break;
  267. case SpvExecutionModelTessellationEvaluation:
  268. if (!is_output && !has_patch) {
  269. is_arrayed = true;
  270. }
  271. break;
  272. case SpvExecutionModelGeometry:
  273. if (!is_output) {
  274. is_arrayed = true;
  275. }
  276. break;
  277. case SpvExecutionModelFragment:
  278. if (!is_output && has_per_vertex_nv) {
  279. is_arrayed = true;
  280. }
  281. break;
  282. case SpvExecutionModelMeshNV:
  283. if (is_output && !has_per_task_nv) {
  284. is_arrayed = true;
  285. }
  286. break;
  287. default:
  288. break;
  289. }
  290. // Unpack arrayness.
  291. if (is_arrayed && (type->opcode() == SpvOpTypeArray ||
  292. type->opcode() == SpvOpTypeRuntimeArray)) {
  293. type_id = type->GetOperandAs<uint32_t>(1);
  294. type = _.FindDef(type_id);
  295. }
  296. if (type->opcode() == SpvOpTypeStruct) {
  297. // Don't check built-ins.
  298. if (_.HasDecoration(type_id, SpvDecorationBuiltIn)) return SPV_SUCCESS;
  299. }
  300. // Only block-decorated structs don't need a location on the variable.
  301. const bool is_block = _.HasDecoration(type_id, SpvDecorationBlock);
  302. if (!has_location && !is_block) {
  303. return _.diag(SPV_ERROR_INVALID_DATA, variable)
  304. << "Variable must be decorated with a location";
  305. }
  306. const std::string storage_class = is_output ? "output" : "input";
  307. if (has_location) {
  308. auto sub_type = type;
  309. bool is_int = false;
  310. bool is_const = false;
  311. uint32_t array_size = 1;
  312. // If the variable is still arrayed, mark the locations/components per
  313. // index.
  314. if (type->opcode() == SpvOpTypeArray) {
  315. // Determine the array size if possible and get the element type.
  316. std::tie(is_int, is_const, array_size) =
  317. _.EvalInt32IfConst(type->GetOperandAs<uint32_t>(2));
  318. if (!is_int || !is_const) array_size = 1;
  319. auto sub_type_id = type->GetOperandAs<uint32_t>(1);
  320. sub_type = _.FindDef(sub_type_id);
  321. }
  322. for (uint32_t array_idx = 0; array_idx < array_size; ++array_idx) {
  323. uint32_t num_locations = 0;
  324. if (auto error = NumConsumedLocations(_, sub_type, &num_locations))
  325. return error;
  326. uint32_t num_components = NumConsumedComponents(_, sub_type);
  327. uint32_t array_location = location + (num_locations * array_idx);
  328. uint32_t start = array_location * 4;
  329. if (kMaxLocations <= start) {
  330. // Too many locations, give up.
  331. break;
  332. }
  333. uint32_t end = (array_location + num_locations) * 4;
  334. if (num_components != 0) {
  335. start += component;
  336. end = array_location * 4 + component + num_components;
  337. }
  338. auto locs = locations;
  339. if (has_index && index == 1) locs = output_index1_locations;
  340. for (uint32_t i = start; i < end; ++i) {
  341. if (!locs->insert(i).second) {
  342. return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
  343. << "Entry-point has conflicting " << storage_class
  344. << " location assignment at location " << i / 4
  345. << ", component " << i % 4;
  346. }
  347. }
  348. }
  349. } else {
  350. // For Block-decorated structs with no location assigned to the variable,
  351. // each member of the block must be assigned a location. Also record any
  352. // member component assignments. The validator allows duplicate decorations
  353. // if they agree on the location/component.
  354. std::unordered_map<uint32_t, uint32_t> member_locations;
  355. std::unordered_map<uint32_t, uint32_t> member_components;
  356. for (auto& dec : _.id_decorations(type_id)) {
  357. if (dec.dec_type() == SpvDecorationLocation) {
  358. auto where = member_locations.find(dec.struct_member_index());
  359. if (where == member_locations.end()) {
  360. member_locations[dec.struct_member_index()] = dec.params()[0];
  361. } else if (where->second != dec.params()[0]) {
  362. return _.diag(SPV_ERROR_INVALID_DATA, type)
  363. << "Member index " << dec.struct_member_index()
  364. << " has conflicting location assignments";
  365. }
  366. } else if (dec.dec_type() == SpvDecorationComponent) {
  367. auto where = member_components.find(dec.struct_member_index());
  368. if (where == member_components.end()) {
  369. member_components[dec.struct_member_index()] = dec.params()[0];
  370. } else if (where->second != dec.params()[0]) {
  371. return _.diag(SPV_ERROR_INVALID_DATA, type)
  372. << "Member index " << dec.struct_member_index()
  373. << " has conflicting component assignments";
  374. }
  375. }
  376. }
  377. for (uint32_t i = 1; i < type->operands().size(); ++i) {
  378. auto where = member_locations.find(i - 1);
  379. if (where == member_locations.end()) {
  380. return _.diag(SPV_ERROR_INVALID_DATA, type)
  381. << "Member index " << i - 1
  382. << " is missing a location assignment";
  383. }
  384. location = where->second;
  385. auto member = _.FindDef(type->GetOperandAs<uint32_t>(i));
  386. uint32_t num_locations = 0;
  387. if (auto error = NumConsumedLocations(_, member, &num_locations))
  388. return error;
  389. // If the component is not specified, it is assumed to be zero.
  390. uint32_t num_components = NumConsumedComponents(_, member);
  391. component = 0;
  392. if (member_components.count(i - 1)) {
  393. component = member_components[i - 1];
  394. }
  395. uint32_t start = location * 4;
  396. if (kMaxLocations <= start) {
  397. // Too many locations, give up.
  398. continue;
  399. }
  400. uint32_t end = (location + num_locations) * 4;
  401. if (num_components != 0) {
  402. start += component;
  403. end = location * 4 + component + num_components;
  404. }
  405. for (uint32_t l = start; l < end; ++l) {
  406. if (!locations->insert(l).second) {
  407. return _.diag(SPV_ERROR_INVALID_DATA, entry_point)
  408. << "Entry-point has conflicting " << storage_class
  409. << " location assignment at location " << l / 4
  410. << ", component " << l % 4;
  411. }
  412. }
  413. }
  414. }
  415. return SPV_SUCCESS;
  416. }
  417. spv_result_t ValidateLocations(ValidationState_t& _,
  418. const Instruction* entry_point) {
  419. // According to Vulkan 14.1 only the following execution models have
  420. // locations assigned.
  421. switch (entry_point->GetOperandAs<SpvExecutionModel>(0)) {
  422. case SpvExecutionModelVertex:
  423. case SpvExecutionModelTessellationControl:
  424. case SpvExecutionModelTessellationEvaluation:
  425. case SpvExecutionModelGeometry:
  426. case SpvExecutionModelFragment:
  427. break;
  428. default:
  429. return SPV_SUCCESS;
  430. }
  431. // Locations are stored as a combined location and component values.
  432. std::unordered_set<uint32_t> input_locations;
  433. std::unordered_set<uint32_t> output_locations_index0;
  434. std::unordered_set<uint32_t> output_locations_index1;
  435. std::unordered_set<uint32_t> seen;
  436. for (uint32_t i = 3; i < entry_point->operands().size(); ++i) {
  437. auto interface_id = entry_point->GetOperandAs<uint32_t>(i);
  438. auto interface_var = _.FindDef(interface_id);
  439. auto storage_class = interface_var->GetOperandAs<SpvStorageClass>(2);
  440. if (storage_class != SpvStorageClassInput &&
  441. storage_class != SpvStorageClassOutput) {
  442. continue;
  443. }
  444. if (!seen.insert(interface_id).second) {
  445. // Pre-1.4 an interface variable could be listed multiple times in an
  446. // entry point. Validation for 1.4 or later is done elsewhere.
  447. continue;
  448. }
  449. auto locations = (storage_class == SpvStorageClassInput)
  450. ? &input_locations
  451. : &output_locations_index0;
  452. if (auto error = GetLocationsForVariable(
  453. _, entry_point, interface_var, locations, &output_locations_index1))
  454. return error;
  455. }
  456. return SPV_SUCCESS;
  457. }
  458. } // namespace
  459. spv_result_t ValidateInterfaces(ValidationState_t& _) {
  460. bool is_spv_1_4 = _.version() >= SPV_SPIRV_VERSION_WORD(1, 4);
  461. for (auto& inst : _.ordered_instructions()) {
  462. if (is_interface_variable(&inst, is_spv_1_4)) {
  463. if (auto error = check_interface_variable(_, &inst)) {
  464. return error;
  465. }
  466. }
  467. }
  468. if (spvIsVulkanEnv(_.context()->target_env)) {
  469. for (auto& inst : _.ordered_instructions()) {
  470. if (inst.opcode() == SpvOpEntryPoint) {
  471. if (auto error = ValidateLocations(_, &inst)) {
  472. return error;
  473. }
  474. }
  475. if (inst.opcode() == SpvOpTypeVoid) break;
  476. }
  477. }
  478. return SPV_SUCCESS;
  479. }
  480. } // namespace val
  481. } // namespace spvtools