validate_constants.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opcode.h"
  15. #include "source/val/instruction.h"
  16. #include "source/val/validate.h"
  17. #include "source/val/validation_state.h"
  18. namespace spvtools {
  19. namespace val {
  20. namespace {
  21. spv_result_t ValidateConstantBool(ValidationState_t& _,
  22. const Instruction* inst) {
  23. auto type = _.FindDef(inst->type_id());
  24. if (!type || type->opcode() != SpvOpTypeBool) {
  25. return _.diag(SPV_ERROR_INVALID_ID, inst)
  26. << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> '"
  27. << _.getIdName(inst->type_id()) << "' is not a boolean type.";
  28. }
  29. return SPV_SUCCESS;
  30. }
  31. spv_result_t ValidateConstantComposite(ValidationState_t& _,
  32. const Instruction* inst) {
  33. std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode());
  34. const auto result_type = _.FindDef(inst->type_id());
  35. if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) {
  36. return _.diag(SPV_ERROR_INVALID_ID, inst)
  37. << opcode_name << " Result Type <id> '"
  38. << _.getIdName(inst->type_id()) << "' is not a composite type.";
  39. }
  40. const auto constituent_count = inst->words().size() - 3;
  41. switch (result_type->opcode()) {
  42. case SpvOpTypeVector: {
  43. const auto component_count = result_type->GetOperandAs<uint32_t>(2);
  44. if (component_count != constituent_count) {
  45. // TODO: Output ID's on diagnostic
  46. return _.diag(SPV_ERROR_INVALID_ID, inst)
  47. << opcode_name
  48. << " Constituent <id> count does not match "
  49. "Result Type <id> '"
  50. << _.getIdName(result_type->id())
  51. << "'s vector component count.";
  52. }
  53. const auto component_type =
  54. _.FindDef(result_type->GetOperandAs<uint32_t>(1));
  55. if (!component_type) {
  56. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  57. << "Component type is not defined.";
  58. }
  59. for (size_t constituent_index = 2;
  60. constituent_index < inst->operands().size(); constituent_index++) {
  61. const auto constituent_id =
  62. inst->GetOperandAs<uint32_t>(constituent_index);
  63. const auto constituent = _.FindDef(constituent_id);
  64. if (!constituent ||
  65. !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  66. return _.diag(SPV_ERROR_INVALID_ID, inst)
  67. << opcode_name << " Constituent <id> '"
  68. << _.getIdName(constituent_id)
  69. << "' is not a constant or undef.";
  70. }
  71. const auto constituent_result_type = _.FindDef(constituent->type_id());
  72. if (!constituent_result_type ||
  73. component_type->opcode() != constituent_result_type->opcode()) {
  74. return _.diag(SPV_ERROR_INVALID_ID, inst)
  75. << opcode_name << " Constituent <id> '"
  76. << _.getIdName(constituent_id)
  77. << "'s type does not match Result Type <id> '"
  78. << _.getIdName(result_type->id()) << "'s vector element type.";
  79. }
  80. }
  81. } break;
  82. case SpvOpTypeMatrix: {
  83. const auto column_count = result_type->GetOperandAs<uint32_t>(2);
  84. if (column_count != constituent_count) {
  85. // TODO: Output ID's on diagnostic
  86. return _.diag(SPV_ERROR_INVALID_ID, inst)
  87. << opcode_name
  88. << " Constituent <id> count does not match "
  89. "Result Type <id> '"
  90. << _.getIdName(result_type->id()) << "'s matrix column count.";
  91. }
  92. const auto column_type = _.FindDef(result_type->words()[2]);
  93. if (!column_type) {
  94. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  95. << "Column type is not defined.";
  96. }
  97. const auto component_count = column_type->GetOperandAs<uint32_t>(2);
  98. const auto component_type =
  99. _.FindDef(column_type->GetOperandAs<uint32_t>(1));
  100. if (!component_type) {
  101. return _.diag(SPV_ERROR_INVALID_ID, column_type)
  102. << "Component type is not defined.";
  103. }
  104. for (size_t constituent_index = 2;
  105. constituent_index < inst->operands().size(); constituent_index++) {
  106. const auto constituent_id =
  107. inst->GetOperandAs<uint32_t>(constituent_index);
  108. const auto constituent = _.FindDef(constituent_id);
  109. if (!constituent ||
  110. !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  111. // The message says "... or undef" because the spec does not say
  112. // undef is a constant.
  113. return _.diag(SPV_ERROR_INVALID_ID, inst)
  114. << opcode_name << " Constituent <id> '"
  115. << _.getIdName(constituent_id)
  116. << "' is not a constant or undef.";
  117. }
  118. const auto vector = _.FindDef(constituent->type_id());
  119. if (!vector) {
  120. return _.diag(SPV_ERROR_INVALID_ID, constituent)
  121. << "Result type is not defined.";
  122. }
  123. if (column_type->opcode() != vector->opcode()) {
  124. return _.diag(SPV_ERROR_INVALID_ID, inst)
  125. << opcode_name << " Constituent <id> '"
  126. << _.getIdName(constituent_id)
  127. << "' type does not match Result Type <id> '"
  128. << _.getIdName(result_type->id()) << "'s matrix column type.";
  129. }
  130. const auto vector_component_type =
  131. _.FindDef(vector->GetOperandAs<uint32_t>(1));
  132. if (component_type->id() != vector_component_type->id()) {
  133. return _.diag(SPV_ERROR_INVALID_ID, inst)
  134. << opcode_name << " Constituent <id> '"
  135. << _.getIdName(constituent_id)
  136. << "' component type does not match Result Type <id> '"
  137. << _.getIdName(result_type->id())
  138. << "'s matrix column component type.";
  139. }
  140. if (component_count != vector->words()[3]) {
  141. return _.diag(SPV_ERROR_INVALID_ID, inst)
  142. << opcode_name << " Constituent <id> '"
  143. << _.getIdName(constituent_id)
  144. << "' vector component count does not match Result Type <id> '"
  145. << _.getIdName(result_type->id())
  146. << "'s vector component count.";
  147. }
  148. }
  149. } break;
  150. case SpvOpTypeArray: {
  151. auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1));
  152. if (!element_type) {
  153. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  154. << "Element type is not defined.";
  155. }
  156. const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2));
  157. if (!length) {
  158. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  159. << "Length is not defined.";
  160. }
  161. bool is_int32;
  162. bool is_const;
  163. uint32_t value;
  164. std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id());
  165. if (is_int32 && is_const && value != constituent_count) {
  166. return _.diag(SPV_ERROR_INVALID_ID, inst)
  167. << opcode_name
  168. << " Constituent count does not match "
  169. "Result Type <id> '"
  170. << _.getIdName(result_type->id()) << "'s array length.";
  171. }
  172. for (size_t constituent_index = 2;
  173. constituent_index < inst->operands().size(); constituent_index++) {
  174. const auto constituent_id =
  175. inst->GetOperandAs<uint32_t>(constituent_index);
  176. const auto constituent = _.FindDef(constituent_id);
  177. if (!constituent ||
  178. !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  179. return _.diag(SPV_ERROR_INVALID_ID, inst)
  180. << opcode_name << " Constituent <id> '"
  181. << _.getIdName(constituent_id)
  182. << "' is not a constant or undef.";
  183. }
  184. const auto constituent_type = _.FindDef(constituent->type_id());
  185. if (!constituent_type) {
  186. return _.diag(SPV_ERROR_INVALID_ID, constituent)
  187. << "Result type is not defined.";
  188. }
  189. if (element_type->id() != constituent_type->id()) {
  190. return _.diag(SPV_ERROR_INVALID_ID, inst)
  191. << opcode_name << " Constituent <id> '"
  192. << _.getIdName(constituent_id)
  193. << "'s type does not match Result Type <id> '"
  194. << _.getIdName(result_type->id()) << "'s array element type.";
  195. }
  196. }
  197. } break;
  198. case SpvOpTypeStruct: {
  199. const auto member_count = result_type->words().size() - 2;
  200. if (member_count != constituent_count) {
  201. return _.diag(SPV_ERROR_INVALID_ID, inst)
  202. << opcode_name << " Constituent <id> '"
  203. << _.getIdName(inst->type_id())
  204. << "' count does not match Result Type <id> '"
  205. << _.getIdName(result_type->id()) << "'s struct member count.";
  206. }
  207. for (uint32_t constituent_index = 2, member_index = 1;
  208. constituent_index < inst->operands().size();
  209. constituent_index++, member_index++) {
  210. const auto constituent_id =
  211. inst->GetOperandAs<uint32_t>(constituent_index);
  212. const auto constituent = _.FindDef(constituent_id);
  213. if (!constituent ||
  214. !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  215. return _.diag(SPV_ERROR_INVALID_ID, inst)
  216. << opcode_name << " Constituent <id> '"
  217. << _.getIdName(constituent_id)
  218. << "' is not a constant or undef.";
  219. }
  220. const auto constituent_type = _.FindDef(constituent->type_id());
  221. if (!constituent_type) {
  222. return _.diag(SPV_ERROR_INVALID_ID, constituent)
  223. << "Result type is not defined.";
  224. }
  225. const auto member_type_id =
  226. result_type->GetOperandAs<uint32_t>(member_index);
  227. const auto member_type = _.FindDef(member_type_id);
  228. if (!member_type || member_type->id() != constituent_type->id()) {
  229. return _.diag(SPV_ERROR_INVALID_ID, inst)
  230. << opcode_name << " Constituent <id> '"
  231. << _.getIdName(constituent_id)
  232. << "' type does not match the Result Type <id> '"
  233. << _.getIdName(result_type->id()) << "'s member type.";
  234. }
  235. }
  236. } break;
  237. case SpvOpTypeCooperativeMatrixNV: {
  238. if (1 != constituent_count) {
  239. return _.diag(SPV_ERROR_INVALID_ID, inst)
  240. << opcode_name << " Constituent <id> '"
  241. << _.getIdName(inst->type_id()) << "' count must be one.";
  242. }
  243. const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
  244. const auto constituent = _.FindDef(constituent_id);
  245. if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  246. return _.diag(SPV_ERROR_INVALID_ID, inst)
  247. << opcode_name << " Constituent <id> '"
  248. << _.getIdName(constituent_id)
  249. << "' is not a constant or undef.";
  250. }
  251. const auto constituent_type = _.FindDef(constituent->type_id());
  252. if (!constituent_type) {
  253. return _.diag(SPV_ERROR_INVALID_ID, constituent)
  254. << "Result type is not defined.";
  255. }
  256. const auto component_type_id = result_type->GetOperandAs<uint32_t>(1);
  257. const auto component_type = _.FindDef(component_type_id);
  258. if (!component_type || component_type->id() != constituent_type->id()) {
  259. return _.diag(SPV_ERROR_INVALID_ID, inst)
  260. << opcode_name << " Constituent <id> '"
  261. << _.getIdName(constituent_id)
  262. << "' type does not match the Result Type <id> '"
  263. << _.getIdName(result_type->id()) << "'s component type.";
  264. }
  265. } break;
  266. default:
  267. break;
  268. }
  269. return SPV_SUCCESS;
  270. }
  271. spv_result_t ValidateConstantSampler(ValidationState_t& _,
  272. const Instruction* inst) {
  273. const auto result_type = _.FindDef(inst->type_id());
  274. if (!result_type || result_type->opcode() != SpvOpTypeSampler) {
  275. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  276. << "OpConstantSampler Result Type <id> '"
  277. << _.getIdName(inst->type_id()) << "' is not a sampler type.";
  278. }
  279. return SPV_SUCCESS;
  280. }
  281. // True if instruction defines a type that can have a null value, as defined by
  282. // the SPIR-V spec. Tracks composite-type components through module to check
  283. // nullability transitively.
  284. bool IsTypeNullable(const std::vector<uint32_t>& instruction,
  285. const ValidationState_t& _) {
  286. uint16_t opcode;
  287. uint16_t word_count;
  288. spvOpcodeSplit(instruction[0], &word_count, &opcode);
  289. switch (static_cast<SpvOp>(opcode)) {
  290. case SpvOpTypeBool:
  291. case SpvOpTypeInt:
  292. case SpvOpTypeFloat:
  293. case SpvOpTypePointer:
  294. case SpvOpTypeEvent:
  295. case SpvOpTypeDeviceEvent:
  296. case SpvOpTypeReserveId:
  297. case SpvOpTypeQueue:
  298. return true;
  299. case SpvOpTypeArray:
  300. case SpvOpTypeMatrix:
  301. case SpvOpTypeCooperativeMatrixNV:
  302. case SpvOpTypeVector: {
  303. auto base_type = _.FindDef(instruction[2]);
  304. return base_type && IsTypeNullable(base_type->words(), _);
  305. }
  306. case SpvOpTypeStruct: {
  307. for (size_t elementIndex = 2; elementIndex < instruction.size();
  308. ++elementIndex) {
  309. auto element = _.FindDef(instruction[elementIndex]);
  310. if (!element || !IsTypeNullable(element->words(), _)) return false;
  311. }
  312. return true;
  313. }
  314. default:
  315. return false;
  316. }
  317. }
  318. spv_result_t ValidateConstantNull(ValidationState_t& _,
  319. const Instruction* inst) {
  320. const auto result_type = _.FindDef(inst->type_id());
  321. if (!result_type || !IsTypeNullable(result_type->words(), _)) {
  322. return _.diag(SPV_ERROR_INVALID_ID, inst)
  323. << "OpConstantNull Result Type <id> '"
  324. << _.getIdName(inst->type_id()) << "' cannot have a null value.";
  325. }
  326. return SPV_SUCCESS;
  327. }
  328. spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
  329. const Instruction* inst) {
  330. const auto op = inst->GetOperandAs<SpvOp>(2);
  331. // The binary parser already ensures that the op is valid for *some*
  332. // environment. Here we check restrictions.
  333. switch (op) {
  334. case SpvOpQuantizeToF16:
  335. if (!_.HasCapability(SpvCapabilityShader)) {
  336. return _.diag(SPV_ERROR_INVALID_ID, inst)
  337. << "Specialization constant operation " << spvOpcodeString(op)
  338. << " requires Shader capability";
  339. }
  340. break;
  341. case SpvOpUConvert:
  342. if (!_.features().uconvert_spec_constant_op &&
  343. !_.HasCapability(SpvCapabilityKernel)) {
  344. return _.diag(SPV_ERROR_INVALID_ID, inst)
  345. << "Prior to SPIR-V 1.4, specialization constant operation "
  346. "UConvert requires Kernel capability or extension "
  347. "SPV_AMD_gpu_shader_int16";
  348. }
  349. break;
  350. case SpvOpConvertFToS:
  351. case SpvOpConvertSToF:
  352. case SpvOpConvertFToU:
  353. case SpvOpConvertUToF:
  354. case SpvOpConvertPtrToU:
  355. case SpvOpConvertUToPtr:
  356. case SpvOpGenericCastToPtr:
  357. case SpvOpPtrCastToGeneric:
  358. case SpvOpBitcast:
  359. case SpvOpFNegate:
  360. case SpvOpFAdd:
  361. case SpvOpFSub:
  362. case SpvOpFMul:
  363. case SpvOpFDiv:
  364. case SpvOpFRem:
  365. case SpvOpFMod:
  366. case SpvOpAccessChain:
  367. case SpvOpInBoundsAccessChain:
  368. case SpvOpPtrAccessChain:
  369. case SpvOpInBoundsPtrAccessChain:
  370. if (!_.HasCapability(SpvCapabilityKernel)) {
  371. return _.diag(SPV_ERROR_INVALID_ID, inst)
  372. << "Specialization constant operation " << spvOpcodeString(op)
  373. << " requires Kernel capability";
  374. }
  375. break;
  376. default:
  377. break;
  378. }
  379. // TODO(dneto): Validate result type and arguments to the various operations.
  380. return SPV_SUCCESS;
  381. }
  382. } // namespace
  383. spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) {
  384. switch (inst->opcode()) {
  385. case SpvOpConstantTrue:
  386. case SpvOpConstantFalse:
  387. case SpvOpSpecConstantTrue:
  388. case SpvOpSpecConstantFalse:
  389. if (auto error = ValidateConstantBool(_, inst)) return error;
  390. break;
  391. case SpvOpConstantComposite:
  392. case SpvOpSpecConstantComposite:
  393. if (auto error = ValidateConstantComposite(_, inst)) return error;
  394. break;
  395. case SpvOpConstantSampler:
  396. if (auto error = ValidateConstantSampler(_, inst)) return error;
  397. break;
  398. case SpvOpConstantNull:
  399. if (auto error = ValidateConstantNull(_, inst)) return error;
  400. break;
  401. case SpvOpSpecConstantOp:
  402. if (auto error = ValidateSpecConstantOp(_, inst)) return error;
  403. break;
  404. default:
  405. break;
  406. }
  407. // Generally disallow creating 8- or 16-bit constants unless the full
  408. // capabilities are present.
  409. if (spvOpcodeIsConstant(inst->opcode()) &&
  410. _.HasCapability(SpvCapabilityShader) &&
  411. !_.IsPointerType(inst->type_id()) &&
  412. _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
  413. return _.diag(SPV_ERROR_INVALID_ID, inst)
  414. << "Cannot form constants of 8- or 16-bit types";
  415. }
  416. return SPV_SUCCESS;
  417. }
  418. } // namespace val
  419. } // namespace spvtools