validate_constants.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opcode.h"
  15. #include "source/val/instruction.h"
  16. #include "source/val/validate.h"
  17. #include "source/val/validation_state.h"
  18. namespace spvtools {
  19. namespace val {
  20. namespace {
  21. spv_result_t ValidateConstantBool(ValidationState_t& _,
  22. const Instruction* inst) {
  23. auto type = _.FindDef(inst->type_id());
  24. if (!type || type->opcode() != spv::Op::OpTypeBool) {
  25. return _.diag(SPV_ERROR_INVALID_ID, inst)
  26. << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> "
  27. << _.getIdName(inst->type_id()) << " is not a boolean type.";
  28. }
  29. return SPV_SUCCESS;
  30. }
  31. spv_result_t ValidateConstantComposite(ValidationState_t& _,
  32. const Instruction* inst) {
  33. std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode());
  34. const auto result_type = _.FindDef(inst->type_id());
  35. if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) {
  36. return _.diag(SPV_ERROR_INVALID_ID, inst)
  37. << opcode_name << " Result Type <id> "
  38. << _.getIdName(inst->type_id()) << " is not a composite type.";
  39. }
  40. const auto constituent_count = inst->words().size() - 3;
  41. switch (result_type->opcode()) {
  42. case spv::Op::OpTypeVector: {
  43. const auto component_count = result_type->GetOperandAs<uint32_t>(2);
  44. if (component_count != constituent_count) {
  45. // TODO: Output ID's on diagnostic
  46. return _.diag(SPV_ERROR_INVALID_ID, inst)
  47. << opcode_name
  48. << " Constituent <id> count does not match "
  49. "Result Type <id> "
  50. << _.getIdName(result_type->id()) << "s vector component count.";
  51. }
  52. const auto component_type =
  53. _.FindDef(result_type->GetOperandAs<uint32_t>(1));
  54. if (!component_type) {
  55. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  56. << "Component type is not defined.";
  57. }
  58. for (size_t constituent_index = 2;
  59. constituent_index < inst->operands().size(); constituent_index++) {
  60. const auto constituent_id =
  61. inst->GetOperandAs<uint32_t>(constituent_index);
  62. const auto constituent = _.FindDef(constituent_id);
  63. if (!constituent ||
  64. !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  65. return _.diag(SPV_ERROR_INVALID_ID, inst)
  66. << opcode_name << " Constituent <id> "
  67. << _.getIdName(constituent_id)
  68. << " is not a constant or undef.";
  69. }
  70. const auto constituent_result_type = _.FindDef(constituent->type_id());
  71. if (!constituent_result_type ||
  72. component_type->opcode() != constituent_result_type->opcode()) {
  73. return _.diag(SPV_ERROR_INVALID_ID, inst)
  74. << opcode_name << " Constituent <id> "
  75. << _.getIdName(constituent_id)
  76. << "s type does not match Result Type <id> "
  77. << _.getIdName(result_type->id()) << "s vector element type.";
  78. }
  79. }
  80. } break;
  81. case spv::Op::OpTypeMatrix: {
  82. const auto column_count = result_type->GetOperandAs<uint32_t>(2);
  83. if (column_count != constituent_count) {
  84. // TODO: Output ID's on diagnostic
  85. return _.diag(SPV_ERROR_INVALID_ID, inst)
  86. << opcode_name
  87. << " Constituent <id> count does not match "
  88. "Result Type <id> "
  89. << _.getIdName(result_type->id()) << "s matrix column count.";
  90. }
  91. const auto column_type = _.FindDef(result_type->words()[2]);
  92. if (!column_type) {
  93. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  94. << "Column type is not defined.";
  95. }
  96. const auto component_count = column_type->GetOperandAs<uint32_t>(2);
  97. const auto component_type =
  98. _.FindDef(column_type->GetOperandAs<uint32_t>(1));
  99. if (!component_type) {
  100. return _.diag(SPV_ERROR_INVALID_ID, column_type)
  101. << "Component type is not defined.";
  102. }
  103. for (size_t constituent_index = 2;
  104. constituent_index < inst->operands().size(); constituent_index++) {
  105. const auto constituent_id =
  106. inst->GetOperandAs<uint32_t>(constituent_index);
  107. const auto constituent = _.FindDef(constituent_id);
  108. if (!constituent ||
  109. !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  110. // The message says "... or undef" because the spec does not say
  111. // undef is a constant.
  112. return _.diag(SPV_ERROR_INVALID_ID, inst)
  113. << opcode_name << " Constituent <id> "
  114. << _.getIdName(constituent_id)
  115. << " is not a constant or undef.";
  116. }
  117. const auto vector = _.FindDef(constituent->type_id());
  118. if (!vector) {
  119. return _.diag(SPV_ERROR_INVALID_ID, constituent)
  120. << "Result type is not defined.";
  121. }
  122. if (column_type->opcode() != vector->opcode()) {
  123. return _.diag(SPV_ERROR_INVALID_ID, inst)
  124. << opcode_name << " Constituent <id> "
  125. << _.getIdName(constituent_id)
  126. << " type does not match Result Type <id> "
  127. << _.getIdName(result_type->id()) << "s matrix column type.";
  128. }
  129. const auto vector_component_type =
  130. _.FindDef(vector->GetOperandAs<uint32_t>(1));
  131. if (component_type->id() != vector_component_type->id()) {
  132. return _.diag(SPV_ERROR_INVALID_ID, inst)
  133. << opcode_name << " Constituent <id> "
  134. << _.getIdName(constituent_id)
  135. << " component type does not match Result Type <id> "
  136. << _.getIdName(result_type->id())
  137. << "s matrix column component type.";
  138. }
  139. if (component_count != vector->words()[3]) {
  140. return _.diag(SPV_ERROR_INVALID_ID, inst)
  141. << opcode_name << " Constituent <id> "
  142. << _.getIdName(constituent_id)
  143. << " vector component count does not match Result Type <id> "
  144. << _.getIdName(result_type->id())
  145. << "s vector component count.";
  146. }
  147. }
  148. } break;
  149. case spv::Op::OpTypeArray: {
  150. auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1));
  151. if (!element_type) {
  152. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  153. << "Element type is not defined.";
  154. }
  155. const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2));
  156. if (!length) {
  157. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  158. << "Length is not defined.";
  159. }
  160. bool is_int32;
  161. bool is_const;
  162. uint32_t value;
  163. std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id());
  164. if (is_int32 && is_const && value != constituent_count) {
  165. return _.diag(SPV_ERROR_INVALID_ID, inst)
  166. << opcode_name
  167. << " Constituent count does not match "
  168. "Result Type <id> "
  169. << _.getIdName(result_type->id()) << "s array length.";
  170. }
  171. for (size_t constituent_index = 2;
  172. constituent_index < inst->operands().size(); constituent_index++) {
  173. const auto constituent_id =
  174. inst->GetOperandAs<uint32_t>(constituent_index);
  175. const auto constituent = _.FindDef(constituent_id);
  176. if (!constituent ||
  177. !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  178. return _.diag(SPV_ERROR_INVALID_ID, inst)
  179. << opcode_name << " Constituent <id> "
  180. << _.getIdName(constituent_id)
  181. << " is not a constant or undef.";
  182. }
  183. const auto constituent_type = _.FindDef(constituent->type_id());
  184. if (!constituent_type) {
  185. return _.diag(SPV_ERROR_INVALID_ID, constituent)
  186. << "Result type is not defined.";
  187. }
  188. if (element_type->id() != constituent_type->id()) {
  189. return _.diag(SPV_ERROR_INVALID_ID, inst)
  190. << opcode_name << " Constituent <id> "
  191. << _.getIdName(constituent_id)
  192. << "s type does not match Result Type <id> "
  193. << _.getIdName(result_type->id()) << "s array element type.";
  194. }
  195. }
  196. } break;
  197. case spv::Op::OpTypeStruct: {
  198. const auto member_count = result_type->words().size() - 2;
  199. if (member_count != constituent_count) {
  200. return _.diag(SPV_ERROR_INVALID_ID, inst)
  201. << opcode_name << " Constituent <id> "
  202. << _.getIdName(inst->type_id())
  203. << " count does not match Result Type <id> "
  204. << _.getIdName(result_type->id()) << "s struct member count.";
  205. }
  206. for (uint32_t constituent_index = 2, member_index = 1;
  207. constituent_index < inst->operands().size();
  208. constituent_index++, member_index++) {
  209. const auto constituent_id =
  210. inst->GetOperandAs<uint32_t>(constituent_index);
  211. const auto constituent = _.FindDef(constituent_id);
  212. if (!constituent ||
  213. !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  214. return _.diag(SPV_ERROR_INVALID_ID, inst)
  215. << opcode_name << " Constituent <id> "
  216. << _.getIdName(constituent_id)
  217. << " is not a constant or undef.";
  218. }
  219. const auto constituent_type = _.FindDef(constituent->type_id());
  220. if (!constituent_type) {
  221. return _.diag(SPV_ERROR_INVALID_ID, constituent)
  222. << "Result type is not defined.";
  223. }
  224. const auto member_type_id =
  225. result_type->GetOperandAs<uint32_t>(member_index);
  226. const auto member_type = _.FindDef(member_type_id);
  227. if (!member_type || member_type->id() != constituent_type->id()) {
  228. return _.diag(SPV_ERROR_INVALID_ID, inst)
  229. << opcode_name << " Constituent <id> "
  230. << _.getIdName(constituent_id)
  231. << " type does not match the Result Type <id> "
  232. << _.getIdName(result_type->id()) << "s member type.";
  233. }
  234. }
  235. } break;
  236. case spv::Op::OpTypeCooperativeMatrixNV: {
  237. if (1 != constituent_count) {
  238. return _.diag(SPV_ERROR_INVALID_ID, inst)
  239. << opcode_name << " Constituent <id> "
  240. << _.getIdName(inst->type_id()) << " count must be one.";
  241. }
  242. const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
  243. const auto constituent = _.FindDef(constituent_id);
  244. if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  245. return _.diag(SPV_ERROR_INVALID_ID, inst)
  246. << opcode_name << " Constituent <id> "
  247. << _.getIdName(constituent_id) << " is not a constant or undef.";
  248. }
  249. const auto constituent_type = _.FindDef(constituent->type_id());
  250. if (!constituent_type) {
  251. return _.diag(SPV_ERROR_INVALID_ID, constituent)
  252. << "Result type is not defined.";
  253. }
  254. const auto component_type_id = result_type->GetOperandAs<uint32_t>(1);
  255. const auto component_type = _.FindDef(component_type_id);
  256. if (!component_type || component_type->id() != constituent_type->id()) {
  257. return _.diag(SPV_ERROR_INVALID_ID, inst)
  258. << opcode_name << " Constituent <id> "
  259. << _.getIdName(constituent_id)
  260. << " type does not match the Result Type <id> "
  261. << _.getIdName(result_type->id()) << "s component type.";
  262. }
  263. } break;
  264. default:
  265. break;
  266. }
  267. return SPV_SUCCESS;
  268. }
  269. spv_result_t ValidateConstantSampler(ValidationState_t& _,
  270. const Instruction* inst) {
  271. const auto result_type = _.FindDef(inst->type_id());
  272. if (!result_type || result_type->opcode() != spv::Op::OpTypeSampler) {
  273. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  274. << "OpConstantSampler Result Type <id> "
  275. << _.getIdName(inst->type_id()) << " is not a sampler type.";
  276. }
  277. return SPV_SUCCESS;
  278. }
  279. // True if instruction defines a type that can have a null value, as defined by
  280. // the SPIR-V spec. Tracks composite-type components through module to check
  281. // nullability transitively.
  282. bool IsTypeNullable(const std::vector<uint32_t>& instruction,
  283. const ValidationState_t& _) {
  284. uint16_t opcode;
  285. uint16_t word_count;
  286. spvOpcodeSplit(instruction[0], &word_count, &opcode);
  287. switch (static_cast<spv::Op>(opcode)) {
  288. case spv::Op::OpTypeBool:
  289. case spv::Op::OpTypeInt:
  290. case spv::Op::OpTypeFloat:
  291. case spv::Op::OpTypeEvent:
  292. case spv::Op::OpTypeDeviceEvent:
  293. case spv::Op::OpTypeReserveId:
  294. case spv::Op::OpTypeQueue:
  295. return true;
  296. case spv::Op::OpTypeArray:
  297. case spv::Op::OpTypeMatrix:
  298. case spv::Op::OpTypeCooperativeMatrixNV:
  299. case spv::Op::OpTypeVector: {
  300. auto base_type = _.FindDef(instruction[2]);
  301. return base_type && IsTypeNullable(base_type->words(), _);
  302. }
  303. case spv::Op::OpTypeStruct: {
  304. for (size_t elementIndex = 2; elementIndex < instruction.size();
  305. ++elementIndex) {
  306. auto element = _.FindDef(instruction[elementIndex]);
  307. if (!element || !IsTypeNullable(element->words(), _)) return false;
  308. }
  309. return true;
  310. }
  311. case spv::Op::OpTypePointer:
  312. if (spv::StorageClass(instruction[2]) ==
  313. spv::StorageClass::PhysicalStorageBuffer) {
  314. return false;
  315. }
  316. return true;
  317. default:
  318. return false;
  319. }
  320. }
  321. spv_result_t ValidateConstantNull(ValidationState_t& _,
  322. const Instruction* inst) {
  323. const auto result_type = _.FindDef(inst->type_id());
  324. if (!result_type || !IsTypeNullable(result_type->words(), _)) {
  325. return _.diag(SPV_ERROR_INVALID_ID, inst)
  326. << "OpConstantNull Result Type <id> " << _.getIdName(inst->type_id())
  327. << " cannot have a null value.";
  328. }
  329. return SPV_SUCCESS;
  330. }
  331. // Validates that OpSpecConstant specializes to either int or float type.
  332. spv_result_t ValidateSpecConstant(ValidationState_t& _,
  333. const Instruction* inst) {
  334. // Operand 0 is the <id> of the type that we're specializing to.
  335. auto type_id = inst->GetOperandAs<const uint32_t>(0);
  336. auto type_instruction = _.FindDef(type_id);
  337. auto type_opcode = type_instruction->opcode();
  338. if (type_opcode != spv::Op::OpTypeInt &&
  339. type_opcode != spv::Op::OpTypeFloat) {
  340. return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant "
  341. "must be an integer or "
  342. "floating-point number.";
  343. }
  344. return SPV_SUCCESS;
  345. }
  346. spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
  347. const Instruction* inst) {
  348. const auto op = inst->GetOperandAs<spv::Op>(2);
  349. // The binary parser already ensures that the op is valid for *some*
  350. // environment. Here we check restrictions.
  351. switch (op) {
  352. case spv::Op::OpQuantizeToF16:
  353. if (!_.HasCapability(spv::Capability::Shader)) {
  354. return _.diag(SPV_ERROR_INVALID_ID, inst)
  355. << "Specialization constant operation " << spvOpcodeString(op)
  356. << " requires Shader capability";
  357. }
  358. break;
  359. case spv::Op::OpUConvert:
  360. if (!_.features().uconvert_spec_constant_op &&
  361. !_.HasCapability(spv::Capability::Kernel)) {
  362. return _.diag(SPV_ERROR_INVALID_ID, inst)
  363. << "Prior to SPIR-V 1.4, specialization constant operation "
  364. "UConvert requires Kernel capability or extension "
  365. "SPV_AMD_gpu_shader_int16";
  366. }
  367. break;
  368. case spv::Op::OpConvertFToS:
  369. case spv::Op::OpConvertSToF:
  370. case spv::Op::OpConvertFToU:
  371. case spv::Op::OpConvertUToF:
  372. case spv::Op::OpConvertPtrToU:
  373. case spv::Op::OpConvertUToPtr:
  374. case spv::Op::OpGenericCastToPtr:
  375. case spv::Op::OpPtrCastToGeneric:
  376. case spv::Op::OpBitcast:
  377. case spv::Op::OpFNegate:
  378. case spv::Op::OpFAdd:
  379. case spv::Op::OpFSub:
  380. case spv::Op::OpFMul:
  381. case spv::Op::OpFDiv:
  382. case spv::Op::OpFRem:
  383. case spv::Op::OpFMod:
  384. case spv::Op::OpAccessChain:
  385. case spv::Op::OpInBoundsAccessChain:
  386. case spv::Op::OpPtrAccessChain:
  387. case spv::Op::OpInBoundsPtrAccessChain:
  388. if (!_.HasCapability(spv::Capability::Kernel)) {
  389. return _.diag(SPV_ERROR_INVALID_ID, inst)
  390. << "Specialization constant operation " << spvOpcodeString(op)
  391. << " requires Kernel capability";
  392. }
  393. break;
  394. default:
  395. break;
  396. }
  397. // TODO(dneto): Validate result type and arguments to the various operations.
  398. return SPV_SUCCESS;
  399. }
  400. } // namespace
  401. spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) {
  402. switch (inst->opcode()) {
  403. case spv::Op::OpConstantTrue:
  404. case spv::Op::OpConstantFalse:
  405. case spv::Op::OpSpecConstantTrue:
  406. case spv::Op::OpSpecConstantFalse:
  407. if (auto error = ValidateConstantBool(_, inst)) return error;
  408. break;
  409. case spv::Op::OpConstantComposite:
  410. case spv::Op::OpSpecConstantComposite:
  411. if (auto error = ValidateConstantComposite(_, inst)) return error;
  412. break;
  413. case spv::Op::OpConstantSampler:
  414. if (auto error = ValidateConstantSampler(_, inst)) return error;
  415. break;
  416. case spv::Op::OpConstantNull:
  417. if (auto error = ValidateConstantNull(_, inst)) return error;
  418. break;
  419. case spv::Op::OpSpecConstant:
  420. if (auto error = ValidateSpecConstant(_, inst)) return error;
  421. break;
  422. case spv::Op::OpSpecConstantOp:
  423. if (auto error = ValidateSpecConstantOp(_, inst)) return error;
  424. break;
  425. default:
  426. break;
  427. }
  428. // Generally disallow creating 8- or 16-bit constants unless the full
  429. // capabilities are present.
  430. if (spvOpcodeIsConstant(inst->opcode()) &&
  431. _.HasCapability(spv::Capability::Shader) &&
  432. !_.IsPointerType(inst->type_id()) &&
  433. _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
  434. return _.diag(SPV_ERROR_INVALID_ID, inst)
  435. << "Cannot form constants of 8- or 16-bit types";
  436. }
  437. return SPV_SUCCESS;
  438. }
  439. } // namespace val
  440. } // namespace spvtools