validate_atomics.cpp

// Copyright (c) 2017 Google Inc.
// Modifications Copyright (C) 2020 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Validates correctness of atomic SPIR-V instructions.

#include "source/val/validate.h"

#include "source/diagnostic.h"
#include "source/opcode.h"
#include "source/spirv_target_env.h"
#include "source/util/bitutils.h"
#include "source/val/instruction.h"
#include "source/val/validate_memory_semantics.h"
#include "source/val/validate_scopes.h"
#include "source/val/validation_state.h"

namespace {

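// Returns true if the given storage class may be used with atomic
// instructions under the universal (environment-independent) rules.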
bool IsStorageClassAllowedByUniversalRules(uint32_t storage_class) {
  switch (storage_class) {
    case SpvStorageClassUniform:
    case SpvStorageClassStorageBuffer:
    case SpvStorageClassWorkgroup:
    case SpvStorageClassCrossWorkgroup:
    case SpvStorageClassGeneric:
    case SpvStorageClassAtomicCounter:
    case SpvStorageClassImage:
    case SpvStorageClassFunction:
    case SpvStorageClassPhysicalStorageBufferEXT:
      return true;
    default:
      return false;
  }
}

}  // namespace

namespace spvtools {
namespace val {

// Validates correctness of atomic instructions.
spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
  const SpvOp opcode = inst->opcode();
  const uint32_t result_type = inst->type_id();
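  // Record whether this is one of the atomic opcodes that may legally operate
  // on floating-point data; the Result Type checks below use this together
  // with the AtomicFloat32/64AddEXT capability checks.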
  bool is_atomic_float_opcode = false;
  if (opcode == SpvOpAtomicLoad || opcode == SpvOpAtomicStore ||
      opcode == SpvOpAtomicFAddEXT || opcode == SpvOpAtomicExchange) {
    is_atomic_float_opcode = true;
  }
  switch (opcode) {
    case SpvOpAtomicLoad:
    case SpvOpAtomicStore:
    case SpvOpAtomicExchange:
    case SpvOpAtomicFAddEXT:
    case SpvOpAtomicCompareExchange:
    case SpvOpAtomicCompareExchangeWeak:
    case SpvOpAtomicIIncrement:
    case SpvOpAtomicIDecrement:
    case SpvOpAtomicIAdd:
    case SpvOpAtomicISub:
    case SpvOpAtomicSMin:
    case SpvOpAtomicUMin:
    case SpvOpAtomicSMax:
    case SpvOpAtomicUMax:
    case SpvOpAtomicAnd:
    case SpvOpAtomicOr:
    case SpvOpAtomicXor:
    case SpvOpAtomicFlagTestAndSet:
    case SpvOpAtomicFlagClear: {
      if (_.HasCapability(SpvCapabilityKernel) &&
          (opcode == SpvOpAtomicLoad || opcode == SpvOpAtomicExchange ||
           opcode == SpvOpAtomicCompareExchange)) {
        if (!_.IsFloatScalarType(result_type) &&
            !_.IsIntScalarType(result_type)) {
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
                 << spvOpcodeString(opcode)
                 << ": expected Result Type to be int or float scalar type";
        }
      } else if (opcode == SpvOpAtomicFlagTestAndSet) {
        if (!_.IsBoolScalarType(result_type)) {
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
                 << spvOpcodeString(opcode)
                 << ": expected Result Type to be bool scalar type";
        }
      } else if (opcode == SpvOpAtomicFlagClear || opcode == SpvOpAtomicStore) {
        assert(result_type == 0);
      } else {
        if (_.IsFloatScalarType(result_type)) {
          if (is_atomic_float_opcode) {
            if (opcode == SpvOpAtomicFAddEXT) {
              if ((_.GetBitWidth(result_type) == 32) &&
                  (!_.HasCapability(SpvCapabilityAtomicFloat32AddEXT))) {
                return _.diag(SPV_ERROR_INVALID_DATA, inst)
                       << spvOpcodeString(opcode)
                       << ": float add atomics require the "
                          "AtomicFloat32AddEXT capability";
              }
              if ((_.GetBitWidth(result_type) == 64) &&
                  (!_.HasCapability(SpvCapabilityAtomicFloat64AddEXT))) {
                return _.diag(SPV_ERROR_INVALID_DATA, inst)
                       << spvOpcodeString(opcode)
                       << ": float add atomics require the "
                          "AtomicFloat64AddEXT capability";
              }
            }
          } else {
            return _.diag(SPV_ERROR_INVALID_DATA, inst)
                   << spvOpcodeString(opcode)
                   << ": expected Result Type to be int scalar type";
          }
        } else if (_.IsIntScalarType(result_type) &&
                   opcode == SpvOpAtomicFAddEXT) {
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
                 << spvOpcodeString(opcode)
                 << ": expected Result Type to be float scalar type";
        } else if (!_.IsFloatScalarType(result_type) &&
                   !_.IsIntScalarType(result_type)) {
          switch (opcode) {
            case SpvOpAtomicFAddEXT:
              return _.diag(SPV_ERROR_INVALID_DATA, inst)
                     << spvOpcodeString(opcode)
                     << ": expected Result Type to be float scalar type";
            case SpvOpAtomicIIncrement:
            case SpvOpAtomicIDecrement:
            case SpvOpAtomicIAdd:
            case SpvOpAtomicISub:
            case SpvOpAtomicSMin:
            case SpvOpAtomicSMax:
            case SpvOpAtomicUMin:
            case SpvOpAtomicUMax:
              return _.diag(SPV_ERROR_INVALID_DATA, inst)
                     << spvOpcodeString(opcode)
                     << ": expected Result Type to be integer scalar type";
            default:
              return _.diag(SPV_ERROR_INVALID_DATA, inst)
                     << spvOpcodeString(opcode)
                     << ": expected Result Type to be int or float scalar "
                        "type";
          }
        }

        if (spvIsVulkanEnv(_.context()->target_env) &&
            (_.GetBitWidth(result_type) != 32 &&
             (_.GetBitWidth(result_type) != 64 ||
              !_.HasCapability(SpvCapabilityInt64ImageEXT)))) {
          switch (opcode) {
            case SpvOpAtomicSMin:
            case SpvOpAtomicUMin:
            case SpvOpAtomicSMax:
            case SpvOpAtomicUMax:
            case SpvOpAtomicAnd:
            case SpvOpAtomicOr:
            case SpvOpAtomicXor:
            case SpvOpAtomicIAdd:
            case SpvOpAtomicISub:
            case SpvOpAtomicFAddEXT:
            case SpvOpAtomicLoad:
            case SpvOpAtomicStore:
            case SpvOpAtomicExchange:
            case SpvOpAtomicIIncrement:
            case SpvOpAtomicIDecrement:
            case SpvOpAtomicCompareExchangeWeak:
            case SpvOpAtomicCompareExchange: {
              if (_.GetBitWidth(result_type) == 64 &&
                  _.IsIntScalarType(result_type) &&
                  !_.HasCapability(SpvCapabilityInt64Atomics))
                return _.diag(SPV_ERROR_INVALID_DATA, inst)
                       << spvOpcodeString(opcode)
                       << ": 64-bit atomics require the Int64Atomics "
                          "capability";
            } break;
            default:
              return _.diag(SPV_ERROR_INVALID_DATA, inst)
                     << spvOpcodeString(opcode)
                     << ": according to the Vulkan spec atomic Result Type "
                        "needs to be a 32-bit int scalar type";
          }
        }
      }
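      // OpAtomicFlagClear and OpAtomicStore have no Result Type / Result Id,
      // so their Pointer operand is operand 0; every other atomic opcode has
      // the Result Type and Result Id as operands 0 and 1.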
      uint32_t operand_index =
          opcode == SpvOpAtomicFlagClear || opcode == SpvOpAtomicStore ? 0 : 2;
      const uint32_t pointer_type = _.GetOperandTypeId(inst, operand_index++);
      uint32_t data_type = 0;
      uint32_t storage_class = 0;
      if (!_.GetPointerTypeInfo(pointer_type, &data_type, &storage_class)) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << spvOpcodeString(opcode)
               << ": expected Pointer to be of type OpTypePointer";
      }

      // Validate storage class against universal rules
      if (!IsStorageClassAllowedByUniversalRules(storage_class)) {
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
               << spvOpcodeString(opcode)
               << ": storage class forbidden by universal validation rules.";
      }

      // Then Shader rules
      if (_.HasCapability(SpvCapabilityShader)) {
        if (spvIsVulkanEnv(_.context()->target_env)) {
          if ((storage_class != SpvStorageClassUniform) &&
              (storage_class != SpvStorageClassStorageBuffer) &&
              (storage_class != SpvStorageClassWorkgroup) &&
              (storage_class != SpvStorageClassImage) &&
              (storage_class != SpvStorageClassPhysicalStorageBuffer)) {
            return _.diag(SPV_ERROR_INVALID_DATA, inst)
                   << _.VkErrorID(4686) << spvOpcodeString(opcode)
                   << ": Vulkan spec only allows storage classes for atomic "
                      "to be: Uniform, Workgroup, Image, StorageBuffer, or "
                      "PhysicalStorageBuffer.";
          }
        } else if (storage_class == SpvStorageClassFunction) {
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
                 << spvOpcodeString(opcode)
                 << ": Function storage class forbidden when the Shader "
                    "capability is declared.";
        }
      }

      // And finally OpenCL environment rules
      if (spvIsOpenCLEnv(_.context()->target_env)) {
        if ((storage_class != SpvStorageClassFunction) &&
            (storage_class != SpvStorageClassWorkgroup) &&
            (storage_class != SpvStorageClassCrossWorkgroup) &&
            (storage_class != SpvStorageClassGeneric)) {
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
                 << spvOpcodeString(opcode)
                 << ": storage class must be Function, Workgroup, "
                    "CrossWorkGroup or Generic in the OpenCL environment.";
        }

        if (_.context()->target_env == SPV_ENV_OPENCL_1_2) {
          if (storage_class == SpvStorageClassGeneric) {
            return _.diag(SPV_ERROR_INVALID_DATA, inst)
                   << "Storage class cannot be Generic in OpenCL 1.2 "
                      "environment";
          }
        }
      }

      if (opcode == SpvOpAtomicFlagTestAndSet ||
          opcode == SpvOpAtomicFlagClear) {
        if (!_.IsIntScalarType(data_type) || _.GetBitWidth(data_type) != 32) {
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
                 << spvOpcodeString(opcode)
                 << ": expected Pointer to point to a value of 32-bit int "
                    "type";
        }
      } else if (opcode == SpvOpAtomicStore) {
        if (!_.IsFloatScalarType(data_type) && !_.IsIntScalarType(data_type)) {
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
                 << spvOpcodeString(opcode)
                 << ": expected Pointer to be a pointer to int or float "
                 << "scalar type";
        }
      } else {
        if (data_type != result_type) {
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
                 << spvOpcodeString(opcode)
                 << ": expected Pointer to point to a value of type Result "
                    "Type";
        }
      }
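      // The next operands are the Memory Scope and the Memory Semantics
      // (compare-exchange opcodes carry a second, "unequal" semantics word).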
      auto memory_scope = inst->GetOperandAs<const uint32_t>(operand_index++);
      if (auto error = ValidateMemoryScope(_, inst, memory_scope)) {
        return error;
      }

      const auto equal_semantics_index = operand_index++;
      if (auto error = ValidateMemorySemantics(_, inst, equal_semantics_index))
        return error;

      if (opcode == SpvOpAtomicCompareExchange ||
          opcode == SpvOpAtomicCompareExchangeWeak) {
        const auto unequal_semantics_index = operand_index++;
        if (auto error =
                ValidateMemorySemantics(_, inst, unequal_semantics_index))
          return error;

        // Volatile bits must match for equal and unequal semantics. Previous
        // checks guarantee they are 32-bit constants, but we need to recheck
        // whether they are evaluatable constants.
        bool is_int32 = false;
        bool is_equal_const = false;
        bool is_unequal_const = false;
        uint32_t equal_value = 0;
        uint32_t unequal_value = 0;
        std::tie(is_int32, is_equal_const, equal_value) = _.EvalInt32IfConst(
            inst->GetOperandAs<uint32_t>(equal_semantics_index));
        std::tie(is_int32, is_unequal_const, unequal_value) =
            _.EvalInt32IfConst(
                inst->GetOperandAs<uint32_t>(unequal_semantics_index));
        if (is_equal_const && is_unequal_const &&
            ((equal_value & SpvMemorySemanticsVolatileMask) ^
             (unequal_value & SpvMemorySemanticsVolatileMask))) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << "Volatile mask setting must match for Equal and Unequal "
                    "memory semantics";
        }
      }
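      // For opcodes that carry a Value operand, its type must match the
      // pointee type (OpAtomicStore) or the Result Type (all other opcodes);
      // compare-exchange opcodes additionally carry a Comparator of Result
      // Type.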
      if (opcode == SpvOpAtomicStore) {
        const uint32_t value_type = _.GetOperandTypeId(inst, 3);
        if (value_type != data_type) {
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
                 << spvOpcodeString(opcode)
                 << ": expected Value type and the type pointed to by "
                    "Pointer to be the same";
        }
      } else if (opcode != SpvOpAtomicLoad && opcode != SpvOpAtomicIIncrement &&
                 opcode != SpvOpAtomicIDecrement &&
                 opcode != SpvOpAtomicFlagTestAndSet &&
                 opcode != SpvOpAtomicFlagClear) {
        const uint32_t value_type = _.GetOperandTypeId(inst, operand_index++);
        if (value_type != result_type) {
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
                 << spvOpcodeString(opcode)
                 << ": expected Value to be of type Result Type";
        }
      }

      if (opcode == SpvOpAtomicCompareExchange ||
          opcode == SpvOpAtomicCompareExchangeWeak) {
        const uint32_t comparator_type =
            _.GetOperandTypeId(inst, operand_index++);
        if (comparator_type != result_type) {
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
                 << spvOpcodeString(opcode)
                 << ": expected Comparator to be of type Result Type";
        }
      }

      break;
    }

    default:
      break;
  }

  return SPV_SUCCESS;
}

}  // namespace val
}  // namespace spvtools