validate_constants.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/opcode.h"
  15. #include "source/val/instruction.h"
  16. #include "source/val/validate.h"
  17. #include "source/val/validation_state.h"
  18. namespace spvtools {
  19. namespace val {
  20. namespace {
  21. spv_result_t ValidateConstantBool(ValidationState_t& _,
  22. const Instruction* inst) {
  23. auto type = _.FindDef(inst->type_id());
  24. if (!type || type->opcode() != SpvOpTypeBool) {
  25. return _.diag(SPV_ERROR_INVALID_ID, inst)
  26. << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> '"
  27. << _.getIdName(inst->type_id()) << "' is not a boolean type.";
  28. }
  29. return SPV_SUCCESS;
  30. }
  31. spv_result_t ValidateConstantComposite(ValidationState_t& _,
  32. const Instruction* inst) {
  33. std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode());
  34. const auto result_type = _.FindDef(inst->type_id());
  35. if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) {
  36. return _.diag(SPV_ERROR_INVALID_ID, inst)
  37. << opcode_name << " Result Type <id> '"
  38. << _.getIdName(inst->type_id()) << "' is not a composite type.";
  39. }
  40. const auto constituent_count = inst->words().size() - 3;
  41. switch (result_type->opcode()) {
  42. case SpvOpTypeVector: {
  43. const auto component_count = result_type->GetOperandAs<uint32_t>(2);
  44. if (component_count != constituent_count) {
  45. // TODO: Output ID's on diagnostic
  46. return _.diag(SPV_ERROR_INVALID_ID, inst)
  47. << opcode_name
  48. << " Constituent <id> count does not match "
  49. "Result Type <id> '"
  50. << _.getIdName(result_type->id())
  51. << "'s vector component count.";
  52. }
  53. const auto component_type =
  54. _.FindDef(result_type->GetOperandAs<uint32_t>(1));
  55. if (!component_type) {
  56. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  57. << "Component type is not defined.";
  58. }
  59. for (size_t constituent_index = 2;
  60. constituent_index < inst->operands().size(); constituent_index++) {
  61. const auto constituent_id =
  62. inst->GetOperandAs<uint32_t>(constituent_index);
  63. const auto constituent = _.FindDef(constituent_id);
  64. if (!constituent ||
  65. !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  66. return _.diag(SPV_ERROR_INVALID_ID, inst)
  67. << opcode_name << " Constituent <id> '"
  68. << _.getIdName(constituent_id)
  69. << "' is not a constant or undef.";
  70. }
  71. const auto constituent_result_type = _.FindDef(constituent->type_id());
  72. if (!constituent_result_type ||
  73. component_type->opcode() != constituent_result_type->opcode()) {
  74. return _.diag(SPV_ERROR_INVALID_ID, inst)
  75. << opcode_name << " Constituent <id> '"
  76. << _.getIdName(constituent_id)
  77. << "'s type does not match Result Type <id> '"
  78. << _.getIdName(result_type->id()) << "'s vector element type.";
  79. }
  80. }
  81. } break;
  82. case SpvOpTypeMatrix: {
  83. const auto column_count = result_type->GetOperandAs<uint32_t>(2);
  84. if (column_count != constituent_count) {
  85. // TODO: Output ID's on diagnostic
  86. return _.diag(SPV_ERROR_INVALID_ID, inst)
  87. << opcode_name
  88. << " Constituent <id> count does not match "
  89. "Result Type <id> '"
  90. << _.getIdName(result_type->id()) << "'s matrix column count.";
  91. }
  92. const auto column_type = _.FindDef(result_type->words()[2]);
  93. if (!column_type) {
  94. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  95. << "Column type is not defined.";
  96. }
  97. const auto component_count = column_type->GetOperandAs<uint32_t>(2);
  98. const auto component_type =
  99. _.FindDef(column_type->GetOperandAs<uint32_t>(1));
  100. if (!component_type) {
  101. return _.diag(SPV_ERROR_INVALID_ID, column_type)
  102. << "Component type is not defined.";
  103. }
  104. for (size_t constituent_index = 2;
  105. constituent_index < inst->operands().size(); constituent_index++) {
  106. const auto constituent_id =
  107. inst->GetOperandAs<uint32_t>(constituent_index);
  108. const auto constituent = _.FindDef(constituent_id);
  109. if (!constituent ||
  110. !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  111. // The message says "... or undef" because the spec does not say
  112. // undef is a constant.
  113. return _.diag(SPV_ERROR_INVALID_ID, inst)
  114. << opcode_name << " Constituent <id> '"
  115. << _.getIdName(constituent_id)
  116. << "' is not a constant or undef.";
  117. }
  118. const auto vector = _.FindDef(constituent->type_id());
  119. if (!vector) {
  120. return _.diag(SPV_ERROR_INVALID_ID, constituent)
  121. << "Result type is not defined.";
  122. }
  123. if (column_type->opcode() != vector->opcode()) {
  124. return _.diag(SPV_ERROR_INVALID_ID, inst)
  125. << opcode_name << " Constituent <id> '"
  126. << _.getIdName(constituent_id)
  127. << "' type does not match Result Type <id> '"
  128. << _.getIdName(result_type->id()) << "'s matrix column type.";
  129. }
  130. const auto vector_component_type =
  131. _.FindDef(vector->GetOperandAs<uint32_t>(1));
  132. if (component_type->id() != vector_component_type->id()) {
  133. return _.diag(SPV_ERROR_INVALID_ID, inst)
  134. << opcode_name << " Constituent <id> '"
  135. << _.getIdName(constituent_id)
  136. << "' component type does not match Result Type <id> '"
  137. << _.getIdName(result_type->id())
  138. << "'s matrix column component type.";
  139. }
  140. if (component_count != vector->words()[3]) {
  141. return _.diag(SPV_ERROR_INVALID_ID, inst)
  142. << opcode_name << " Constituent <id> '"
  143. << _.getIdName(constituent_id)
  144. << "' vector component count does not match Result Type <id> '"
  145. << _.getIdName(result_type->id())
  146. << "'s vector component count.";
  147. }
  148. }
  149. } break;
  150. case SpvOpTypeArray: {
  151. auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1));
  152. if (!element_type) {
  153. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  154. << "Element type is not defined.";
  155. }
  156. const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2));
  157. if (!length) {
  158. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  159. << "Length is not defined.";
  160. }
  161. bool is_int32;
  162. bool is_const;
  163. uint32_t value;
  164. std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id());
  165. if (is_int32 && is_const && value != constituent_count) {
  166. return _.diag(SPV_ERROR_INVALID_ID, inst)
  167. << opcode_name
  168. << " Constituent count does not match "
  169. "Result Type <id> '"
  170. << _.getIdName(result_type->id()) << "'s array length.";
  171. }
  172. for (size_t constituent_index = 2;
  173. constituent_index < inst->operands().size(); constituent_index++) {
  174. const auto constituent_id =
  175. inst->GetOperandAs<uint32_t>(constituent_index);
  176. const auto constituent = _.FindDef(constituent_id);
  177. if (!constituent ||
  178. !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  179. return _.diag(SPV_ERROR_INVALID_ID, inst)
  180. << opcode_name << " Constituent <id> '"
  181. << _.getIdName(constituent_id)
  182. << "' is not a constant or undef.";
  183. }
  184. const auto constituent_type = _.FindDef(constituent->type_id());
  185. if (!constituent_type) {
  186. return _.diag(SPV_ERROR_INVALID_ID, constituent)
  187. << "Result type is not defined.";
  188. }
  189. if (element_type->id() != constituent_type->id()) {
  190. return _.diag(SPV_ERROR_INVALID_ID, inst)
  191. << opcode_name << " Constituent <id> '"
  192. << _.getIdName(constituent_id)
  193. << "'s type does not match Result Type <id> '"
  194. << _.getIdName(result_type->id()) << "'s array element type.";
  195. }
  196. }
  197. } break;
  198. case SpvOpTypeStruct: {
  199. const auto member_count = result_type->words().size() - 2;
  200. if (member_count != constituent_count) {
  201. return _.diag(SPV_ERROR_INVALID_ID, inst)
  202. << opcode_name << " Constituent <id> '"
  203. << _.getIdName(inst->type_id())
  204. << "' count does not match Result Type <id> '"
  205. << _.getIdName(result_type->id()) << "'s struct member count.";
  206. }
  207. for (uint32_t constituent_index = 2, member_index = 1;
  208. constituent_index < inst->operands().size();
  209. constituent_index++, member_index++) {
  210. const auto constituent_id =
  211. inst->GetOperandAs<uint32_t>(constituent_index);
  212. const auto constituent = _.FindDef(constituent_id);
  213. if (!constituent ||
  214. !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  215. return _.diag(SPV_ERROR_INVALID_ID, inst)
  216. << opcode_name << " Constituent <id> '"
  217. << _.getIdName(constituent_id)
  218. << "' is not a constant or undef.";
  219. }
  220. const auto constituent_type = _.FindDef(constituent->type_id());
  221. if (!constituent_type) {
  222. return _.diag(SPV_ERROR_INVALID_ID, constituent)
  223. << "Result type is not defined.";
  224. }
  225. const auto member_type_id =
  226. result_type->GetOperandAs<uint32_t>(member_index);
  227. const auto member_type = _.FindDef(member_type_id);
  228. if (!member_type || member_type->id() != constituent_type->id()) {
  229. return _.diag(SPV_ERROR_INVALID_ID, inst)
  230. << opcode_name << " Constituent <id> '"
  231. << _.getIdName(constituent_id)
  232. << "' type does not match the Result Type <id> '"
  233. << _.getIdName(result_type->id()) << "'s member type.";
  234. }
  235. }
  236. } break;
  237. case SpvOpTypeCooperativeMatrixNV: {
  238. if (1 != constituent_count) {
  239. return _.diag(SPV_ERROR_INVALID_ID, inst)
  240. << opcode_name << " Constituent <id> '"
  241. << _.getIdName(inst->type_id()) << "' count must be one.";
  242. }
  243. const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
  244. const auto constituent = _.FindDef(constituent_id);
  245. if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
  246. return _.diag(SPV_ERROR_INVALID_ID, inst)
  247. << opcode_name << " Constituent <id> '"
  248. << _.getIdName(constituent_id)
  249. << "' is not a constant or undef.";
  250. }
  251. const auto constituent_type = _.FindDef(constituent->type_id());
  252. if (!constituent_type) {
  253. return _.diag(SPV_ERROR_INVALID_ID, constituent)
  254. << "Result type is not defined.";
  255. }
  256. const auto component_type_id = result_type->GetOperandAs<uint32_t>(1);
  257. const auto component_type = _.FindDef(component_type_id);
  258. if (!component_type || component_type->id() != constituent_type->id()) {
  259. return _.diag(SPV_ERROR_INVALID_ID, inst)
  260. << opcode_name << " Constituent <id> '"
  261. << _.getIdName(constituent_id)
  262. << "' type does not match the Result Type <id> '"
  263. << _.getIdName(result_type->id()) << "'s component type.";
  264. }
  265. } break;
  266. default:
  267. break;
  268. }
  269. return SPV_SUCCESS;
  270. }
  271. spv_result_t ValidateConstantSampler(ValidationState_t& _,
  272. const Instruction* inst) {
  273. const auto result_type = _.FindDef(inst->type_id());
  274. if (!result_type || result_type->opcode() != SpvOpTypeSampler) {
  275. return _.diag(SPV_ERROR_INVALID_ID, result_type)
  276. << "OpConstantSampler Result Type <id> '"
  277. << _.getIdName(inst->type_id()) << "' is not a sampler type.";
  278. }
  279. return SPV_SUCCESS;
  280. }
  281. // True if instruction defines a type that can have a null value, as defined by
  282. // the SPIR-V spec. Tracks composite-type components through module to check
  283. // nullability transitively.
  284. bool IsTypeNullable(const std::vector<uint32_t>& instruction,
  285. const ValidationState_t& _) {
  286. uint16_t opcode;
  287. uint16_t word_count;
  288. spvOpcodeSplit(instruction[0], &word_count, &opcode);
  289. switch (static_cast<SpvOp>(opcode)) {
  290. case SpvOpTypeBool:
  291. case SpvOpTypeInt:
  292. case SpvOpTypeFloat:
  293. case SpvOpTypeEvent:
  294. case SpvOpTypeDeviceEvent:
  295. case SpvOpTypeReserveId:
  296. case SpvOpTypeQueue:
  297. return true;
  298. case SpvOpTypeArray:
  299. case SpvOpTypeMatrix:
  300. case SpvOpTypeCooperativeMatrixNV:
  301. case SpvOpTypeVector: {
  302. auto base_type = _.FindDef(instruction[2]);
  303. return base_type && IsTypeNullable(base_type->words(), _);
  304. }
  305. case SpvOpTypeStruct: {
  306. for (size_t elementIndex = 2; elementIndex < instruction.size();
  307. ++elementIndex) {
  308. auto element = _.FindDef(instruction[elementIndex]);
  309. if (!element || !IsTypeNullable(element->words(), _)) return false;
  310. }
  311. return true;
  312. }
  313. case SpvOpTypePointer:
  314. if (instruction[2] == SpvStorageClassPhysicalStorageBuffer) {
  315. return false;
  316. }
  317. return true;
  318. default:
  319. return false;
  320. }
  321. }
  322. spv_result_t ValidateConstantNull(ValidationState_t& _,
  323. const Instruction* inst) {
  324. const auto result_type = _.FindDef(inst->type_id());
  325. if (!result_type || !IsTypeNullable(result_type->words(), _)) {
  326. return _.diag(SPV_ERROR_INVALID_ID, inst)
  327. << "OpConstantNull Result Type <id> '"
  328. << _.getIdName(inst->type_id()) << "' cannot have a null value.";
  329. }
  330. return SPV_SUCCESS;
  331. }
  332. // Validates that OpSpecConstant specializes to either int or float type.
  333. spv_result_t ValidateSpecConstant(ValidationState_t& _,
  334. const Instruction* inst) {
  335. // Operand 0 is the <id> of the type that we're specializing to.
  336. auto type_id = inst->GetOperandAs<const uint32_t>(0);
  337. auto type_instruction = _.FindDef(type_id);
  338. auto type_opcode = type_instruction->opcode();
  339. if (type_opcode != SpvOpTypeInt && type_opcode != SpvOpTypeFloat) {
  340. return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant "
  341. "must be an integer or "
  342. "floating-point number.";
  343. }
  344. return SPV_SUCCESS;
  345. }
  346. spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
  347. const Instruction* inst) {
  348. const auto op = inst->GetOperandAs<SpvOp>(2);
  349. // The binary parser already ensures that the op is valid for *some*
  350. // environment. Here we check restrictions.
  351. switch (op) {
  352. case SpvOpQuantizeToF16:
  353. if (!_.HasCapability(SpvCapabilityShader)) {
  354. return _.diag(SPV_ERROR_INVALID_ID, inst)
  355. << "Specialization constant operation " << spvOpcodeString(op)
  356. << " requires Shader capability";
  357. }
  358. break;
  359. case SpvOpUConvert:
  360. if (!_.features().uconvert_spec_constant_op &&
  361. !_.HasCapability(SpvCapabilityKernel)) {
  362. return _.diag(SPV_ERROR_INVALID_ID, inst)
  363. << "Prior to SPIR-V 1.4, specialization constant operation "
  364. "UConvert requires Kernel capability or extension "
  365. "SPV_AMD_gpu_shader_int16";
  366. }
  367. break;
  368. case SpvOpConvertFToS:
  369. case SpvOpConvertSToF:
  370. case SpvOpConvertFToU:
  371. case SpvOpConvertUToF:
  372. case SpvOpConvertPtrToU:
  373. case SpvOpConvertUToPtr:
  374. case SpvOpGenericCastToPtr:
  375. case SpvOpPtrCastToGeneric:
  376. case SpvOpBitcast:
  377. case SpvOpFNegate:
  378. case SpvOpFAdd:
  379. case SpvOpFSub:
  380. case SpvOpFMul:
  381. case SpvOpFDiv:
  382. case SpvOpFRem:
  383. case SpvOpFMod:
  384. case SpvOpAccessChain:
  385. case SpvOpInBoundsAccessChain:
  386. case SpvOpPtrAccessChain:
  387. case SpvOpInBoundsPtrAccessChain:
  388. if (!_.HasCapability(SpvCapabilityKernel)) {
  389. return _.diag(SPV_ERROR_INVALID_ID, inst)
  390. << "Specialization constant operation " << spvOpcodeString(op)
  391. << " requires Kernel capability";
  392. }
  393. break;
  394. default:
  395. break;
  396. }
  397. // TODO(dneto): Validate result type and arguments to the various operations.
  398. return SPV_SUCCESS;
  399. }
  400. } // namespace
  401. spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) {
  402. switch (inst->opcode()) {
  403. case SpvOpConstantTrue:
  404. case SpvOpConstantFalse:
  405. case SpvOpSpecConstantTrue:
  406. case SpvOpSpecConstantFalse:
  407. if (auto error = ValidateConstantBool(_, inst)) return error;
  408. break;
  409. case SpvOpConstantComposite:
  410. case SpvOpSpecConstantComposite:
  411. if (auto error = ValidateConstantComposite(_, inst)) return error;
  412. break;
  413. case SpvOpConstantSampler:
  414. if (auto error = ValidateConstantSampler(_, inst)) return error;
  415. break;
  416. case SpvOpConstantNull:
  417. if (auto error = ValidateConstantNull(_, inst)) return error;
  418. break;
  419. case SpvOpSpecConstant:
  420. if (auto error = ValidateSpecConstant(_, inst)) return error;
  421. break;
  422. case SpvOpSpecConstantOp:
  423. if (auto error = ValidateSpecConstantOp(_, inst)) return error;
  424. break;
  425. default:
  426. break;
  427. }
  428. // Generally disallow creating 8- or 16-bit constants unless the full
  429. // capabilities are present.
  430. if (spvOpcodeIsConstant(inst->opcode()) &&
  431. _.HasCapability(SpvCapabilityShader) &&
  432. !_.IsPointerType(inst->type_id()) &&
  433. _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
  434. return _.diag(SPV_ERROR_INVALID_ID, inst)
  435. << "Cannot form constants of 8- or 16-bit types";
  436. }
  437. return SPV_SUCCESS;
  438. }
  439. } // namespace val
  440. } // namespace spvtools