// validate_atomics.cpp
  1. // Copyright (c) 2017 Google Inc.
  2. // Modifications Copyright (C) 2020 Advanced Micro Devices, Inc. All rights
  3. // reserved.
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. // Validates correctness of atomic SPIR-V instructions.
  17. #include "source/val/validate.h"
  18. #include "source/diagnostic.h"
  19. #include "source/opcode.h"
  20. #include "source/spirv_target_env.h"
  21. #include "source/util/bitutils.h"
  22. #include "source/val/instruction.h"
  23. #include "source/val/validate_memory_semantics.h"
  24. #include "source/val/validate_scopes.h"
  25. #include "source/val/validation_state.h"
  26. namespace {
  27. bool IsStorageClassAllowedByUniversalRules(spv::StorageClass storage_class) {
  28. switch (storage_class) {
  29. case spv::StorageClass::Uniform:
  30. case spv::StorageClass::StorageBuffer:
  31. case spv::StorageClass::Workgroup:
  32. case spv::StorageClass::CrossWorkgroup:
  33. case spv::StorageClass::Generic:
  34. case spv::StorageClass::AtomicCounter:
  35. case spv::StorageClass::Image:
  36. case spv::StorageClass::Function:
  37. case spv::StorageClass::PhysicalStorageBuffer:
  38. case spv::StorageClass::TaskPayloadWorkgroupEXT:
  39. return true;
  40. break;
  41. default:
  42. return false;
  43. }
  44. }
  45. bool HasReturnType(spv::Op opcode) {
  46. switch (opcode) {
  47. case spv::Op::OpAtomicStore:
  48. case spv::Op::OpAtomicFlagClear:
  49. return false;
  50. break;
  51. default:
  52. return true;
  53. }
  54. }
  55. bool HasOnlyFloatReturnType(spv::Op opcode) {
  56. switch (opcode) {
  57. case spv::Op::OpAtomicFAddEXT:
  58. case spv::Op::OpAtomicFMinEXT:
  59. case spv::Op::OpAtomicFMaxEXT:
  60. return true;
  61. break;
  62. default:
  63. return false;
  64. }
  65. }
  66. bool HasOnlyIntReturnType(spv::Op opcode) {
  67. switch (opcode) {
  68. case spv::Op::OpAtomicCompareExchange:
  69. case spv::Op::OpAtomicCompareExchangeWeak:
  70. case spv::Op::OpAtomicIIncrement:
  71. case spv::Op::OpAtomicIDecrement:
  72. case spv::Op::OpAtomicIAdd:
  73. case spv::Op::OpAtomicISub:
  74. case spv::Op::OpAtomicSMin:
  75. case spv::Op::OpAtomicUMin:
  76. case spv::Op::OpAtomicSMax:
  77. case spv::Op::OpAtomicUMax:
  78. case spv::Op::OpAtomicAnd:
  79. case spv::Op::OpAtomicOr:
  80. case spv::Op::OpAtomicXor:
  81. return true;
  82. break;
  83. default:
  84. return false;
  85. }
  86. }
  87. bool HasIntOrFloatReturnType(spv::Op opcode) {
  88. switch (opcode) {
  89. case spv::Op::OpAtomicLoad:
  90. case spv::Op::OpAtomicExchange:
  91. return true;
  92. break;
  93. default:
  94. return false;
  95. }
  96. }
  97. bool HasOnlyBoolReturnType(spv::Op opcode) {
  98. switch (opcode) {
  99. case spv::Op::OpAtomicFlagTestAndSet:
  100. return true;
  101. break;
  102. default:
  103. return false;
  104. }
  105. }
  106. } // namespace
  107. namespace spvtools {
  108. namespace val {
  109. // Validates correctness of atomic instructions.
  110. spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
  111. const spv::Op opcode = inst->opcode();
  112. switch (opcode) {
  113. case spv::Op::OpAtomicLoad:
  114. case spv::Op::OpAtomicStore:
  115. case spv::Op::OpAtomicExchange:
  116. case spv::Op::OpAtomicFAddEXT:
  117. case spv::Op::OpAtomicCompareExchange:
  118. case spv::Op::OpAtomicCompareExchangeWeak:
  119. case spv::Op::OpAtomicIIncrement:
  120. case spv::Op::OpAtomicIDecrement:
  121. case spv::Op::OpAtomicIAdd:
  122. case spv::Op::OpAtomicISub:
  123. case spv::Op::OpAtomicSMin:
  124. case spv::Op::OpAtomicUMin:
  125. case spv::Op::OpAtomicFMinEXT:
  126. case spv::Op::OpAtomicSMax:
  127. case spv::Op::OpAtomicUMax:
  128. case spv::Op::OpAtomicFMaxEXT:
  129. case spv::Op::OpAtomicAnd:
  130. case spv::Op::OpAtomicOr:
  131. case spv::Op::OpAtomicXor:
  132. case spv::Op::OpAtomicFlagTestAndSet:
  133. case spv::Op::OpAtomicFlagClear: {
  134. const uint32_t result_type = inst->type_id();
  135. // All current atomics only are scalar result
  136. // Validate return type first so can just check if pointer type is same
  137. // (if applicable)
  138. if (HasReturnType(opcode)) {
  139. if (HasOnlyFloatReturnType(opcode) &&
  140. !_.IsFloatScalarType(result_type)) {
  141. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  142. << spvOpcodeString(opcode)
  143. << ": expected Result Type to be float scalar type";
  144. } else if (HasOnlyIntReturnType(opcode) &&
  145. !_.IsIntScalarType(result_type)) {
  146. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  147. << spvOpcodeString(opcode)
  148. << ": expected Result Type to be integer scalar type";
  149. } else if (HasIntOrFloatReturnType(opcode) &&
  150. !_.IsFloatScalarType(result_type) &&
  151. !_.IsIntScalarType(result_type)) {
  152. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  153. << spvOpcodeString(opcode)
  154. << ": expected Result Type to be integer or float scalar type";
  155. } else if (HasOnlyBoolReturnType(opcode) &&
  156. !_.IsBoolScalarType(result_type)) {
  157. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  158. << spvOpcodeString(opcode)
  159. << ": expected Result Type to be bool scalar type";
  160. }
  161. }
  162. uint32_t operand_index = HasReturnType(opcode) ? 2 : 0;
  163. const uint32_t pointer_type = _.GetOperandTypeId(inst, operand_index++);
  164. uint32_t data_type = 0;
  165. spv::StorageClass storage_class;
  166. if (!_.GetPointerTypeInfo(pointer_type, &data_type, &storage_class)) {
  167. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  168. << spvOpcodeString(opcode)
  169. << ": expected Pointer to be of type OpTypePointer";
  170. }
  171. // Can't use result_type because OpAtomicStore doesn't have a result
  172. if (_.IsIntScalarType(data_type) && _.GetBitWidth(data_type) == 64 &&
  173. !_.HasCapability(spv::Capability::Int64Atomics)) {
  174. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  175. << spvOpcodeString(opcode)
  176. << ": 64-bit atomics require the Int64Atomics capability";
  177. }
  178. // Validate storage class against universal rules
  179. if (!IsStorageClassAllowedByUniversalRules(storage_class)) {
  180. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  181. << spvOpcodeString(opcode)
  182. << ": storage class forbidden by universal validation rules.";
  183. }
  184. // Then Shader rules
  185. if (_.HasCapability(spv::Capability::Shader)) {
  186. // Vulkan environment rule
  187. if (spvIsVulkanEnv(_.context()->target_env)) {
  188. if ((storage_class != spv::StorageClass::Uniform) &&
  189. (storage_class != spv::StorageClass::StorageBuffer) &&
  190. (storage_class != spv::StorageClass::Workgroup) &&
  191. (storage_class != spv::StorageClass::Image) &&
  192. (storage_class != spv::StorageClass::PhysicalStorageBuffer) &&
  193. (storage_class != spv::StorageClass::TaskPayloadWorkgroupEXT)) {
  194. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  195. << _.VkErrorID(4686) << spvOpcodeString(opcode)
  196. << ": Vulkan spec only allows storage classes for atomic to "
  197. "be: Uniform, Workgroup, Image, StorageBuffer, "
  198. "PhysicalStorageBuffer or TaskPayloadWorkgroupEXT.";
  199. }
  200. } else if (storage_class == spv::StorageClass::Function) {
  201. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  202. << spvOpcodeString(opcode)
  203. << ": Function storage class forbidden when the Shader "
  204. "capability is declared.";
  205. }
  206. if (opcode == spv::Op::OpAtomicFAddEXT) {
  207. // result type being float checked already
  208. if ((_.GetBitWidth(result_type) == 16) &&
  209. (!_.HasCapability(spv::Capability::AtomicFloat16AddEXT))) {
  210. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  211. << spvOpcodeString(opcode)
  212. << ": float add atomics require the AtomicFloat32AddEXT "
  213. "capability";
  214. }
  215. if ((_.GetBitWidth(result_type) == 32) &&
  216. (!_.HasCapability(spv::Capability::AtomicFloat32AddEXT))) {
  217. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  218. << spvOpcodeString(opcode)
  219. << ": float add atomics require the AtomicFloat32AddEXT "
  220. "capability";
  221. }
  222. if ((_.GetBitWidth(result_type) == 64) &&
  223. (!_.HasCapability(spv::Capability::AtomicFloat64AddEXT))) {
  224. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  225. << spvOpcodeString(opcode)
  226. << ": float add atomics require the AtomicFloat64AddEXT "
  227. "capability";
  228. }
  229. } else if (opcode == spv::Op::OpAtomicFMinEXT ||
  230. opcode == spv::Op::OpAtomicFMaxEXT) {
  231. if ((_.GetBitWidth(result_type) == 16) &&
  232. (!_.HasCapability(spv::Capability::AtomicFloat16MinMaxEXT))) {
  233. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  234. << spvOpcodeString(opcode)
  235. << ": float min/max atomics require the "
  236. "AtomicFloat16MinMaxEXT capability";
  237. }
  238. if ((_.GetBitWidth(result_type) == 32) &&
  239. (!_.HasCapability(spv::Capability::AtomicFloat32MinMaxEXT))) {
  240. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  241. << spvOpcodeString(opcode)
  242. << ": float min/max atomics require the "
  243. "AtomicFloat32MinMaxEXT capability";
  244. }
  245. if ((_.GetBitWidth(result_type) == 64) &&
  246. (!_.HasCapability(spv::Capability::AtomicFloat64MinMaxEXT))) {
  247. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  248. << spvOpcodeString(opcode)
  249. << ": float min/max atomics require the "
  250. "AtomicFloat64MinMaxEXT capability";
  251. }
  252. }
  253. }
  254. // And finally OpenCL environment rules
  255. if (spvIsOpenCLEnv(_.context()->target_env)) {
  256. if ((storage_class != spv::StorageClass::Function) &&
  257. (storage_class != spv::StorageClass::Workgroup) &&
  258. (storage_class != spv::StorageClass::CrossWorkgroup) &&
  259. (storage_class != spv::StorageClass::Generic)) {
  260. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  261. << spvOpcodeString(opcode)
  262. << ": storage class must be Function, Workgroup, "
  263. "CrossWorkGroup or Generic in the OpenCL environment.";
  264. }
  265. if (_.context()->target_env == SPV_ENV_OPENCL_1_2) {
  266. if (storage_class == spv::StorageClass::Generic) {
  267. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  268. << "Storage class cannot be Generic in OpenCL 1.2 "
  269. "environment";
  270. }
  271. }
  272. }
  273. // If result and pointer type are different, need to do special check here
  274. if (opcode == spv::Op::OpAtomicFlagTestAndSet ||
  275. opcode == spv::Op::OpAtomicFlagClear) {
  276. if (!_.IsIntScalarType(data_type) || _.GetBitWidth(data_type) != 32) {
  277. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  278. << spvOpcodeString(opcode)
  279. << ": expected Pointer to point to a value of 32-bit integer "
  280. "type";
  281. }
  282. } else if (opcode == spv::Op::OpAtomicStore) {
  283. if (!_.IsFloatScalarType(data_type) && !_.IsIntScalarType(data_type)) {
  284. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  285. << spvOpcodeString(opcode)
  286. << ": expected Pointer to be a pointer to integer or float "
  287. << "scalar type";
  288. }
  289. } else if (data_type != result_type) {
  290. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  291. << spvOpcodeString(opcode)
  292. << ": expected Pointer to point to a value of type Result "
  293. "Type";
  294. }
  295. auto memory_scope = inst->GetOperandAs<const uint32_t>(operand_index++);
  296. if (auto error = ValidateMemoryScope(_, inst, memory_scope)) {
  297. return error;
  298. }
  299. const auto equal_semantics_index = operand_index++;
  300. if (auto error = ValidateMemorySemantics(_, inst, equal_semantics_index,
  301. memory_scope))
  302. return error;
  303. if (opcode == spv::Op::OpAtomicCompareExchange ||
  304. opcode == spv::Op::OpAtomicCompareExchangeWeak) {
  305. const auto unequal_semantics_index = operand_index++;
  306. if (auto error = ValidateMemorySemantics(
  307. _, inst, unequal_semantics_index, memory_scope))
  308. return error;
  309. // Volatile bits must match for equal and unequal semantics. Previous
  310. // checks guarantee they are 32-bit constants, but we need to recheck
  311. // whether they are evaluatable constants.
  312. bool is_int32 = false;
  313. bool is_equal_const = false;
  314. bool is_unequal_const = false;
  315. uint32_t equal_value = 0;
  316. uint32_t unequal_value = 0;
  317. std::tie(is_int32, is_equal_const, equal_value) = _.EvalInt32IfConst(
  318. inst->GetOperandAs<uint32_t>(equal_semantics_index));
  319. std::tie(is_int32, is_unequal_const, unequal_value) =
  320. _.EvalInt32IfConst(
  321. inst->GetOperandAs<uint32_t>(unequal_semantics_index));
  322. if (is_equal_const && is_unequal_const &&
  323. ((equal_value & uint32_t(spv::MemorySemanticsMask::Volatile)) ^
  324. (unequal_value & uint32_t(spv::MemorySemanticsMask::Volatile)))) {
  325. return _.diag(SPV_ERROR_INVALID_ID, inst)
  326. << "Volatile mask setting must match for Equal and Unequal "
  327. "memory semantics";
  328. }
  329. }
  330. if (opcode == spv::Op::OpAtomicStore) {
  331. const uint32_t value_type = _.GetOperandTypeId(inst, 3);
  332. if (value_type != data_type) {
  333. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  334. << spvOpcodeString(opcode)
  335. << ": expected Value type and the type pointed to by "
  336. "Pointer to be the same";
  337. }
  338. } else if (opcode != spv::Op::OpAtomicLoad &&
  339. opcode != spv::Op::OpAtomicIIncrement &&
  340. opcode != spv::Op::OpAtomicIDecrement &&
  341. opcode != spv::Op::OpAtomicFlagTestAndSet &&
  342. opcode != spv::Op::OpAtomicFlagClear) {
  343. const uint32_t value_type = _.GetOperandTypeId(inst, operand_index++);
  344. if (value_type != result_type) {
  345. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  346. << spvOpcodeString(opcode)
  347. << ": expected Value to be of type Result Type";
  348. }
  349. }
  350. if (opcode == spv::Op::OpAtomicCompareExchange ||
  351. opcode == spv::Op::OpAtomicCompareExchangeWeak) {
  352. const uint32_t comparator_type =
  353. _.GetOperandTypeId(inst, operand_index++);
  354. if (comparator_type != result_type) {
  355. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  356. << spvOpcodeString(opcode)
  357. << ": expected Comparator to be of type Result Type";
  358. }
  359. }
  360. break;
  361. }
  362. default:
  363. break;
  364. }
  365. return SPV_SUCCESS;
  366. }
  367. } // namespace val
  368. } // namespace spvtools