validate_constants.cpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opcode.h"
  15. #include "source/val/instruction.h"
  16. #include "source/val/validate.h"
  17. #include "source/val/validation_state.h"
  18. namespace spvtools {
  19. namespace val {
  20. namespace {
  21. spv_result_t ValidateConstantBool(ValidationState_t& _,
  22. const Instruction* inst) {
  23. auto type = _.FindDef(inst->type_id());
  24. if (!type || type->opcode() != spv::Op::OpTypeBool) {
  25. return _.diag(SPV_ERROR_INVALID_ID, inst)
  26. << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> "
  27. << _.getIdName(inst->type_id()) << " is not a boolean type.";
  28. }
  29. return SPV_SUCCESS;
  30. }
  31. spv_result_t ValidateConstantComposite(ValidationState_t& _,
  32. const Instruction* inst) {
  33. std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode());
  34. const auto result_type = _.FindDef(inst->type_id());
  35. if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) {
  36. return _.diag(SPV_ERROR_INVALID_ID, inst)
  37. << opcode_name << " Result Type <id> "
  38. << _.getIdName(inst->type_id()) << " is not a composite type.";
  39. }
  40. const auto constituent_count = inst->words().size() - 3;
  41. switch (result_type->opcode()) {
  42. case spv::Op::OpTypeVector:
  43. case spv::Op::OpTypeCooperativeVectorNV: {
  44. uint32_t num_result_components = _.GetDimension(result_type->id());
  45. bool comp_is_int32 = true, comp_is_const_int32 = true;
  46. if (result_type->opcode() == spv::Op::OpTypeCooperativeVectorNV) {
  47. uint32_t comp_count_id = result_type->GetOperandAs<uint32_t>(2);
  48. std::tie(comp_is_int32, comp_is_const_int32, num_result_components) =
  49. _.EvalInt32IfConst(comp_count_id);
  50. }
  51. if (comp_is_const_int32 && num_result_components != constituent_count) {
  52. // TODO: Output ID's on diagnostic
  53. return _.diag(SPV_ERROR_INVALID_ID, inst)
  54. << opcode_name
  55. << " Constituent <id> count does not match "
  56. "Result Type <id> "
  57. << _.getIdName(result_type->id()) << "s vector component count.";
  58. }
  59. const auto component_type =
  60. _.FindDef(result_type->GetOperandAs<uint32_t>(1));
  61. if (!component_type) {
  62. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  63. << "Component type is not defined.";
  64. }
  65. for (size_t constituent_index = 2;
  66. constituent_index < inst->operands().size(); constituent_index++) {
  67. const auto constituent_id =
  68. inst->GetOperandAs<uint32_t>(constituent_index);
  69. const auto constituent = _.FindDef(constituent_id);
  70. if (!constituent ||
  71. !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  72. return _.diag(SPV_ERROR_INVALID_ID, inst)
  73. << opcode_name << " Constituent <id> "
  74. << _.getIdName(constituent_id)
  75. << " is not a constant or undef.";
  76. }
  77. const auto constituent_result_type = _.FindDef(constituent->type_id());
  78. if (!constituent_result_type ||
  79. component_type->id() != constituent_result_type->id()) {
  80. return _.diag(SPV_ERROR_INVALID_ID, inst)
  81. << opcode_name << " Constituent <id> "
  82. << _.getIdName(constituent_id)
  83. << "s type does not match Result Type <id> "
  84. << _.getIdName(result_type->id()) << "s vector element type.";
  85. }
  86. }
  87. } break;
  88. case spv::Op::OpTypeMatrix: {
  89. const auto column_count = result_type->GetOperandAs<uint32_t>(2);
  90. if (column_count != constituent_count) {
  91. // TODO: Output ID's on diagnostic
  92. return _.diag(SPV_ERROR_INVALID_ID, inst)
  93. << opcode_name
  94. << " Constituent <id> count does not match "
  95. "Result Type <id> "
  96. << _.getIdName(result_type->id()) << "s matrix column count.";
  97. }
  98. const auto column_type = _.FindDef(result_type->words()[2]);
  99. if (!column_type) {
  100. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  101. << "Column type is not defined.";
  102. }
  103. const auto component_count = column_type->GetOperandAs<uint32_t>(2);
  104. const auto component_type =
  105. _.FindDef(column_type->GetOperandAs<uint32_t>(1));
  106. if (!component_type) {
  107. return _.diag(SPV_ERROR_INVALID_ID, column_type)
  108. << "Component type is not defined.";
  109. }
  110. for (size_t constituent_index = 2;
  111. constituent_index < inst->operands().size(); constituent_index++) {
  112. const auto constituent_id =
  113. inst->GetOperandAs<uint32_t>(constituent_index);
  114. const auto constituent = _.FindDef(constituent_id);
  115. if (!constituent ||
  116. !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  117. // The message says "... or undef" because the spec does not say
  118. // undef is a constant.
  119. return _.diag(SPV_ERROR_INVALID_ID, inst)
  120. << opcode_name << " Constituent <id> "
  121. << _.getIdName(constituent_id)
  122. << " is not a constant or undef.";
  123. }
  124. const auto vector = _.FindDef(constituent->type_id());
  125. if (!vector) {
  126. return _.diag(SPV_ERROR_INVALID_ID, constituent)
  127. << "Result type is not defined.";
  128. }
  129. if (column_type->opcode() != vector->opcode()) {
  130. return _.diag(SPV_ERROR_INVALID_ID, inst)
  131. << opcode_name << " Constituent <id> "
  132. << _.getIdName(constituent_id)
  133. << " type does not match Result Type <id> "
  134. << _.getIdName(result_type->id()) << "s matrix column type.";
  135. }
  136. const auto vector_component_type =
  137. _.FindDef(vector->GetOperandAs<uint32_t>(1));
  138. if (component_type->id() != vector_component_type->id()) {
  139. return _.diag(SPV_ERROR_INVALID_ID, inst)
  140. << opcode_name << " Constituent <id> "
  141. << _.getIdName(constituent_id)
  142. << " component type does not match Result Type <id> "
  143. << _.getIdName(result_type->id())
  144. << "s matrix column component type.";
  145. }
  146. if (component_count != vector->words()[3]) {
  147. return _.diag(SPV_ERROR_INVALID_ID, inst)
  148. << opcode_name << " Constituent <id> "
  149. << _.getIdName(constituent_id)
  150. << " vector component count does not match Result Type <id> "
  151. << _.getIdName(result_type->id())
  152. << "s vector component count.";
  153. }
  154. }
  155. } break;
  156. case spv::Op::OpTypeArray: {
  157. auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1));
  158. if (!element_type) {
  159. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  160. << "Element type is not defined.";
  161. }
  162. const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2));
  163. if (!length) {
  164. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  165. << "Length is not defined.";
  166. }
  167. bool is_int32;
  168. bool is_const;
  169. uint32_t value;
  170. std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id());
  171. if (is_int32 && is_const && value != constituent_count) {
  172. return _.diag(SPV_ERROR_INVALID_ID, inst)
  173. << opcode_name
  174. << " Constituent count does not match "
  175. "Result Type <id> "
  176. << _.getIdName(result_type->id()) << "s array length.";
  177. }
  178. for (size_t constituent_index = 2;
  179. constituent_index < inst->operands().size(); constituent_index++) {
  180. const auto constituent_id =
  181. inst->GetOperandAs<uint32_t>(constituent_index);
  182. const auto constituent = _.FindDef(constituent_id);
  183. if (!constituent ||
  184. !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  185. return _.diag(SPV_ERROR_INVALID_ID, inst)
  186. << opcode_name << " Constituent <id> "
  187. << _.getIdName(constituent_id)
  188. << " is not a constant or undef.";
  189. }
  190. const auto constituent_type = _.FindDef(constituent->type_id());
  191. if (!constituent_type) {
  192. return _.diag(SPV_ERROR_INVALID_ID, constituent)
  193. << "Result type is not defined.";
  194. }
  195. if (element_type->id() != constituent_type->id()) {
  196. return _.diag(SPV_ERROR_INVALID_ID, inst)
  197. << opcode_name << " Constituent <id> "
  198. << _.getIdName(constituent_id)
  199. << "s type does not match Result Type <id> "
  200. << _.getIdName(result_type->id()) << "s array element type.";
  201. }
  202. }
  203. } break;
  204. case spv::Op::OpTypeStruct: {
  205. const auto member_count = result_type->words().size() - 2;
  206. if (member_count != constituent_count) {
  207. return _.diag(SPV_ERROR_INVALID_ID, inst)
  208. << opcode_name << " Constituent <id> "
  209. << _.getIdName(inst->type_id())
  210. << " count does not match Result Type <id> "
  211. << _.getIdName(result_type->id()) << "s struct member count.";
  212. }
  213. for (uint32_t constituent_index = 2, member_index = 1;
  214. constituent_index < inst->operands().size();
  215. constituent_index++, member_index++) {
  216. const auto constituent_id =
  217. inst->GetOperandAs<uint32_t>(constituent_index);
  218. const auto constituent = _.FindDef(constituent_id);
  219. if (!constituent ||
  220. !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  221. return _.diag(SPV_ERROR_INVALID_ID, inst)
  222. << opcode_name << " Constituent <id> "
  223. << _.getIdName(constituent_id)
  224. << " is not a constant or undef.";
  225. }
  226. const auto constituent_type = _.FindDef(constituent->type_id());
  227. if (!constituent_type) {
  228. return _.diag(SPV_ERROR_INVALID_ID, constituent)
  229. << "Result type is not defined.";
  230. }
  231. const auto member_type_id =
  232. result_type->GetOperandAs<uint32_t>(member_index);
  233. const auto member_type = _.FindDef(member_type_id);
  234. if (!member_type || member_type->id() != constituent_type->id()) {
  235. return _.diag(SPV_ERROR_INVALID_ID, inst)
  236. << opcode_name << " Constituent <id> "
  237. << _.getIdName(constituent_id)
  238. << " type does not match the Result Type <id> "
  239. << _.getIdName(result_type->id()) << "s member type.";
  240. }
  241. }
  242. } break;
  243. case spv::Op::OpTypeCooperativeMatrixKHR:
  244. case spv::Op::OpTypeCooperativeMatrixNV: {
  245. if (1 != constituent_count) {
  246. return _.diag(SPV_ERROR_INVALID_ID, inst)
  247. << opcode_name << " Constituent <id> "
  248. << _.getIdName(inst->type_id()) << " count must be one.";
  249. }
  250. const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
  251. const auto constituent = _.FindDef(constituent_id);
  252. if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  253. return _.diag(SPV_ERROR_INVALID_ID, inst)
  254. << opcode_name << " Constituent <id> "
  255. << _.getIdName(constituent_id) << " is not a constant or undef.";
  256. }
  257. const auto constituent_type = _.FindDef(constituent->type_id());
  258. if (!constituent_type) {
  259. return _.diag(SPV_ERROR_INVALID_ID, constituent)
  260. << "Result type is not defined.";
  261. }
  262. const auto component_type_id = result_type->GetOperandAs<uint32_t>(1);
  263. const auto component_type = _.FindDef(component_type_id);
  264. if (!component_type || component_type->id() != constituent_type->id()) {
  265. return _.diag(SPV_ERROR_INVALID_ID, inst)
  266. << opcode_name << " Constituent <id> "
  267. << _.getIdName(constituent_id)
  268. << " type does not match the Result Type <id> "
  269. << _.getIdName(result_type->id()) << "s component type.";
  270. }
  271. } break;
  272. default:
  273. break;
  274. }
  275. return SPV_SUCCESS;
  276. }
  277. spv_result_t ValidateConstantSampler(ValidationState_t& _,
  278. const Instruction* inst) {
  279. const auto result_type = _.FindDef(inst->type_id());
  280. if (!result_type || result_type->opcode() != spv::Op::OpTypeSampler) {
  281. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  282. << "OpConstantSampler Result Type <id> "
  283. << _.getIdName(inst->type_id()) << " is not a sampler type.";
  284. }
  285. return SPV_SUCCESS;
  286. }
  287. // True if instruction defines a type that can have a null value, as defined by
  288. // the SPIR-V spec. Tracks composite-type components through module to check
  289. // nullability transitively.
  290. bool IsTypeNullable(const std::vector<uint32_t>& instruction,
  291. const ValidationState_t& _) {
  292. uint16_t opcode;
  293. uint16_t word_count;
  294. spvOpcodeSplit(instruction[0], &word_count, &opcode);
  295. switch (static_cast<spv::Op>(opcode)) {
  296. case spv::Op::OpTypeBool:
  297. case spv::Op::OpTypeInt:
  298. case spv::Op::OpTypeFloat:
  299. case spv::Op::OpTypeEvent:
  300. case spv::Op::OpTypeDeviceEvent:
  301. case spv::Op::OpTypeReserveId:
  302. case spv::Op::OpTypeQueue:
  303. return true;
  304. case spv::Op::OpTypeArray:
  305. case spv::Op::OpTypeMatrix:
  306. case spv::Op::OpTypeCooperativeMatrixNV:
  307. case spv::Op::OpTypeCooperativeMatrixKHR:
  308. case spv::Op::OpTypeCooperativeVectorNV:
  309. case spv::Op::OpTypeVector: {
  310. auto base_type = _.FindDef(instruction[2]);
  311. return base_type && IsTypeNullable(base_type->words(), _);
  312. }
  313. case spv::Op::OpTypeStruct: {
  314. for (size_t elementIndex = 2; elementIndex < instruction.size();
  315. ++elementIndex) {
  316. auto element = _.FindDef(instruction[elementIndex]);
  317. if (!element || !IsTypeNullable(element->words(), _)) return false;
  318. }
  319. return true;
  320. }
  321. case spv::Op::OpTypeUntypedPointerKHR:
  322. case spv::Op::OpTypePointer:
  323. if (spv::StorageClass(instruction[2]) ==
  324. spv::StorageClass::PhysicalStorageBuffer) {
  325. return false;
  326. }
  327. return true;
  328. default:
  329. return false;
  330. }
  331. }
  332. spv_result_t ValidateConstantNull(ValidationState_t& _,
  333. const Instruction* inst) {
  334. const auto result_type = _.FindDef(inst->type_id());
  335. if (!result_type || !IsTypeNullable(result_type->words(), _)) {
  336. return _.diag(SPV_ERROR_INVALID_ID, inst)
  337. << "OpConstantNull Result Type <id> " << _.getIdName(inst->type_id())
  338. << " cannot have a null value.";
  339. }
  340. return SPV_SUCCESS;
  341. }
  342. // Validates that OpSpecConstant specializes to either int or float type.
  343. spv_result_t ValidateSpecConstant(ValidationState_t& _,
  344. const Instruction* inst) {
  345. // Operand 0 is the <id> of the type that we're specializing to.
  346. auto type_id = inst->GetOperandAs<const uint32_t>(0);
  347. auto type_instruction = _.FindDef(type_id);
  348. auto type_opcode = type_instruction->opcode();
  349. if (type_opcode != spv::Op::OpTypeInt &&
  350. type_opcode != spv::Op::OpTypeFloat) {
  351. return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant "
  352. "must be an integer or "
  353. "floating-point number.";
  354. }
  355. return SPV_SUCCESS;
  356. }
  357. spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
  358. const Instruction* inst) {
  359. const auto op = inst->GetOperandAs<spv::Op>(2);
  360. // The binary parser already ensures that the op is valid for *some*
  361. // environment. Here we check restrictions.
  362. switch (op) {
  363. case spv::Op::OpQuantizeToF16:
  364. if (!_.HasCapability(spv::Capability::Shader)) {
  365. return _.diag(SPV_ERROR_INVALID_ID, inst)
  366. << "Specialization constant operation " << spvOpcodeString(op)
  367. << " requires Shader capability";
  368. }
  369. break;
  370. case spv::Op::OpUConvert:
  371. if (!_.features().uconvert_spec_constant_op &&
  372. !_.HasCapability(spv::Capability::Kernel)) {
  373. return _.diag(SPV_ERROR_INVALID_ID, inst)
  374. << "Prior to SPIR-V 1.4, specialization constant operation "
  375. "UConvert requires Kernel capability or extension "
  376. "SPV_AMD_gpu_shader_int16";
  377. }
  378. break;
  379. case spv::Op::OpConvertFToS:
  380. case spv::Op::OpConvertSToF:
  381. case spv::Op::OpConvertFToU:
  382. case spv::Op::OpConvertUToF:
  383. case spv::Op::OpConvertPtrToU:
  384. case spv::Op::OpConvertUToPtr:
  385. case spv::Op::OpGenericCastToPtr:
  386. case spv::Op::OpPtrCastToGeneric:
  387. case spv::Op::OpBitcast:
  388. case spv::Op::OpFNegate:
  389. case spv::Op::OpFAdd:
  390. case spv::Op::OpFSub:
  391. case spv::Op::OpFMul:
  392. case spv::Op::OpFDiv:
  393. case spv::Op::OpFRem:
  394. case spv::Op::OpFMod:
  395. case spv::Op::OpAccessChain:
  396. case spv::Op::OpInBoundsAccessChain:
  397. case spv::Op::OpPtrAccessChain:
  398. case spv::Op::OpInBoundsPtrAccessChain:
  399. if (!_.HasCapability(spv::Capability::Kernel)) {
  400. return _.diag(SPV_ERROR_INVALID_ID, inst)
  401. << "Specialization constant operation " << spvOpcodeString(op)
  402. << " requires Kernel capability";
  403. }
  404. break;
  405. default:
  406. break;
  407. }
  408. // TODO(dneto): Validate result type and arguments to the various operations.
  409. return SPV_SUCCESS;
  410. }
  411. } // namespace
  412. spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) {
  413. switch (inst->opcode()) {
  414. case spv::Op::OpConstantTrue:
  415. case spv::Op::OpConstantFalse:
  416. case spv::Op::OpSpecConstantTrue:
  417. case spv::Op::OpSpecConstantFalse:
  418. if (auto error = ValidateConstantBool(_, inst)) return error;
  419. break;
  420. case spv::Op::OpConstantComposite:
  421. case spv::Op::OpSpecConstantComposite:
  422. if (auto error = ValidateConstantComposite(_, inst)) return error;
  423. break;
  424. case spv::Op::OpConstantSampler:
  425. if (auto error = ValidateConstantSampler(_, inst)) return error;
  426. break;
  427. case spv::Op::OpConstantNull:
  428. if (auto error = ValidateConstantNull(_, inst)) return error;
  429. break;
  430. case spv::Op::OpSpecConstant:
  431. if (auto error = ValidateSpecConstant(_, inst)) return error;
  432. break;
  433. case spv::Op::OpSpecConstantOp:
  434. if (auto error = ValidateSpecConstantOp(_, inst)) return error;
  435. break;
  436. default:
  437. break;
  438. }
  439. // Generally disallow creating 8- or 16-bit constants unless the full
  440. // capabilities are present.
  441. if (spvOpcodeIsConstant(inst->opcode()) &&
  442. _.HasCapability(spv::Capability::Shader) &&
  443. !_.IsPointerType(inst->type_id()) &&
  444. _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
  445. return _.diag(SPV_ERROR_INVALID_ID, inst)
  446. << "Cannot form constants of 8- or 16-bit types";
  447. }
  448. return SPV_SUCCESS;
  449. }
  450. } // namespace val
  451. } // namespace spvtools