enum_set.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. // Copyright (c) 2023 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  #include <stddef.h>

  #include <algorithm>
  #include <cassert>
  #include <cstddef>
  #include <cstdint>
  #include <functional>
  #include <initializer_list>
  #include <iterator>
  #include <limits>
  #include <type_traits>
  #include <vector>
  24. #ifndef SOURCE_ENUM_SET_H_
  25. #define SOURCE_ENUM_SET_H_
  26. #include "source/latest_version_spirv_header.h"
  27. namespace spvtools {
  // This container is optimized to store and retrieve unsigned enum values.
  // The base model for this implementation is an open-addressing hashtable with
  // linear probing. For small enums (max enum value < 64), all operations are
  // O(1).
  //
  // - Enums are stored in buckets (at most 64 contiguous values per bucket).
  // - Bucket ranges don't overlap, but don't have to be contiguous.
  // - Enums are packed into 64-bit buckets, using 1 bit per enum value.
  //
  // Example:
  //  - MyEnum { A = 0, B = 1, C = 64, D = 65 }
  //  - 2 buckets are required:
  //     - bucket 0, storing values in the range [ 0; 64[
  //     - bucket 1, storing values in the range [64; 128[
  //
  // - Buckets are stored in a sorted vector (sorted by bucket range).
  // - Retrieval is done by computing the theoretical bucket index from the enum
  //   value, and doing a linear scan from this position.
  // - Insertion is done by retrieving the bucket and either:
  //   - inserting a new bucket in the sorted vector when no bucket has a
  //     compatible range, or
  //   - setting the corresponding bit in the bucket.
  //   This means insertion in the middle/beginning can cause a memmove when no
  //   bucket is available. In our case, this happens at most 23 times for the
  //   largest enum we have (Opcodes).
  template <typename T>
  class EnumSet {
   private:
    using BucketType = uint64_t;
    using ElementType = std::underlying_type_t<T>;
    static_assert(std::is_enum_v<T>, "EnumSets only works with enums.");
    static_assert(std::is_signed_v<ElementType> == false,
                  "EnumSet doesn't supports signed enums.");

    // Each bucket can hold up to `kBucketSize` distinct, contiguous enum values.
    // The first value a bucket can hold must be aligned on `kBucketSize`.
    struct Bucket {
      // Bit mask storing up to `kBucketSize` enums: bit `i` is set when the
      // value `start + i` is in the set.
      BucketType data;
      // First enum value this bucket can represent.
      T start;

      friend bool operator==(const Bucket& lhs, const Bucket& rhs) {
        return lhs.start == rhs.start && lhs.data == rhs.data;
      }
    };

    // How many distinct values can a bucket hold? 1 bit per value.
    static constexpr size_t kBucketSize = sizeof(BucketType) * 8ULL;

   public:
    // Forward iterator over the values stored in an EnumSet, visited in
    // ascending order. Dereferencing yields the enum by value. An iterator
    // addresses a bucket by its index in the set, so inserting into or erasing
    // from the set may invalidate existing iterators.
    class Iterator {
     public:
      using self_type = Iterator;
      using value_type = T;
      using reference = T&;
      using pointer = T*;
      using iterator_category = std::forward_iterator_tag;
      // The iterator requirements mandate a signed difference type.
      using difference_type = std::ptrdiff_t;

      // Forward iterators must be default-constructible. A default-constructed
      // iterator is not attached to any set: only assign to it or compare it.
      Iterator() = default;
      Iterator(const Iterator& other) = default;
      Iterator& operator=(const Iterator& other) = default;

      // Advances to the next value stored in the set, or to `end()` when there
      // is none. Incrementing `end()` leaves it unchanged.
      Iterator& operator++() {
        do {
          if (bucketIndex_ >= set_->buckets_.size()) {
            bucketIndex_ = set_->buckets_.size();
            bucketOffset_ = 0;
            break;
          }
          if (bucketOffset_ + 1 == kBucketSize) {
            bucketOffset_ = 0;
            ++bucketIndex_;
          } else {
            ++bucketOffset_;
          }
        } while (bucketIndex_ < set_->buckets_.size() &&
                 !set_->HasEnumAt(bucketIndex_, bucketOffset_));
        return *this;
      }

      // Post-increment: advances the iterator and returns its previous state.
      Iterator operator++(int) {
        Iterator old = *this;
        operator++();
        return old;
      }

      // Returns the enum value this iterator points to. Must not be called on
      // `end()` (asserted in debug builds).
      T operator*() const {
        assert(set_->HasEnumAt(bucketIndex_, bucketOffset_) &&
               "operator*() called on an invalid iterator.");
        return GetValueFromBucket(set_->buckets_[bucketIndex_], bucketOffset_);
      }

      // Iterators are equal when they refer to the same position of the same
      // set.
      bool operator!=(const Iterator& other) const {
        return set_ != other.set_ || bucketOffset_ != other.bucketOffset_ ||
               bucketIndex_ != other.bucketIndex_;
      }

      bool operator==(const Iterator& other) const {
        return !(operator!=(other));
      }

     private:
      Iterator(const EnumSet* set, size_t bucketIndex, ElementType bucketOffset)
          : set_(set), bucketIndex_(bucketIndex), bucketOffset_(bucketOffset) {}

      // Set being iterated; nullptr for a default-constructed iterator.
      const EnumSet* set_ = nullptr;
      // Index of the bucket in the vector; equals `buckets_.size()` for end().
      size_t bucketIndex_ = 0;
      // Offset in bits in the current bucket.
      ElementType bucketOffset_ = 0;

      friend class EnumSet;
    };

    // Required to allow the use of std::inserter.
    using value_type = T;
    using const_iterator = Iterator;
    using iterator = Iterator;

   public:
    // Returns an iterator to the smallest value in the set, or `cend()` when
    // the set is empty.
    iterator cbegin() const noexcept {
      auto it = iterator(this, /* bucketIndex= */ 0, /* bucketOffset= */ 0);
      if (buckets_.size() == 0) {
        return it;
      }
      // The iterator has the logic to find the next valid bit. If the first
      // bit of the first bucket is not set, use it to find the next valid bit.
      if (!HasEnumAt(it.bucketIndex_, it.bucketOffset_)) {
        ++it;
      }
      return it;
    }

    iterator begin() const noexcept { return cbegin(); }

    // Returns the past-the-end iterator.
    iterator cend() const noexcept {
      return iterator(this, buckets_.size(), /* bucketOffset= */ 0);
    }

    iterator end() const noexcept { return cend(); }

    // Creates an empty set.
    EnumSet() : buckets_(0), size_(0) {}

    // Creates a set and stores `value` in it.
    EnumSet(T value) : EnumSet() { insert(value); }

    // Creates a set and stores each of `values` in it.
    EnumSet(std::initializer_list<T> values) : EnumSet() {
      for (auto item : values) {
        insert(item);
      }
    }

    // Creates a set, and inserts the `count` enum values pointed to by `array`.
    EnumSet(ElementType count, const T* array) : EnumSet() {
      for (ElementType i = 0; i < count; i++) {
        insert(array[i]);
      }
    }

    // Creates a set initialized with the content of the range [begin; end[.
    template <class InputIt>
    EnumSet(InputIt begin, InputIt end) : EnumSet() {
      for (; begin != end; ++begin) {
        insert(*begin);
      }
    }

    // Copies the EnumSet `other` into a new EnumSet.
    EnumSet(const EnumSet& other)
        : buckets_(other.buckets_), size_(other.size_) {}

    // Moves the EnumSet `other` into a new EnumSet. `other` is left empty.
    EnumSet(EnumSet&& other) noexcept
        : buckets_(std::move(other.buckets_)), size_(other.size_) {
      // The moved-from bucket vector is empty; reset the count to match, so
      // `other.size()`/`other.empty()` stay consistent with its content.
      other.buckets_.clear();
      other.size_ = 0;
    }

    // Deep-copies the EnumSet `other` into this EnumSet.
    EnumSet& operator=(const EnumSet& other) {
      buckets_ = other.buckets_;
      size_ = other.size_;
      return *this;
    }

    // Moves the content of `other` into this EnumSet. `other` is left empty.
    // Without this overload, move-assignment silently fell back to a copy.
    EnumSet& operator=(EnumSet&& other) noexcept {
      if (this != &other) {
        buckets_ = std::move(other.buckets_);
        size_ = other.size_;
        other.buckets_.clear();
        other.size_ = 0;
      }
      return *this;
    }

    // Matches std::unordered_set::insert behavior: returns an iterator to
    // `value` and whether it was newly inserted.
    std::pair<iterator, bool> insert(const T& value) {
      const size_t index = FindBucketForValue(value);
      const ElementType offset = ComputeBucketOffset(value);
      if (index >= buckets_.size() ||
          buckets_[index].start != ComputeBucketStart(value)) {
        size_ += 1;
        InsertBucketFor(index, value);
        return std::make_pair(Iterator(this, index, offset), true);
      }
      auto& bucket = buckets_[index];
      const auto mask = ComputeMaskForValue(value);
      if (bucket.data & mask) {
        return std::make_pair(Iterator(this, index, offset), false);
      }
      size_ += 1;
      bucket.data |= mask;
      return std::make_pair(Iterator(this, index, offset), true);
    }

    // Inserts `value` in the set if possible.
    // Similar to `std::unordered_set::insert`, except the hint is ignored.
    // Returns an iterator to the inserted element, or the element preventing
    // insertion.
    iterator insert(const_iterator, const T& value) {
      return insert(value).first;
    }

    // Inserts `value` in the set if possible.
    // Similar to `std::unordered_set::insert`, except the hint is ignored.
    // Returns an iterator to the inserted element, or the element preventing
    // insertion.
    iterator insert(const_iterator, T&& value) { return insert(value).first; }

    // Inserts all the values in the range [`first`; `last`[.
    // Similar to `std::unordered_set::insert`.
    template <class InputIt>
    void insert(InputIt first, InputIt last) {
      for (auto it = first; it != last; ++it) {
        insert(*it);
      }
    }

    // Removes the value `value` from the set.
    // Similar to `std::unordered_set::erase`.
    // Returns the number of erased elements (0 or 1). A bucket left with no
    // values is removed from the vector.
    size_t erase(const T& value) {
      const size_t index = FindBucketForValue(value);
      if (index >= buckets_.size() ||
          buckets_[index].start != ComputeBucketStart(value)) {
        return 0;
      }
      auto& bucket = buckets_[index];
      const auto mask = ComputeMaskForValue(value);
      if (!(bucket.data & mask)) {
        return 0;
      }
      size_ -= 1;
      bucket.data &= ~mask;
      if (bucket.data == 0) {
        buckets_.erase(buckets_.cbegin() + index);
      }
      return 1;
    }

    // Returns true if `value` is present in the set.
    bool contains(T value) const {
      const size_t index = FindBucketForValue(value);
      if (index >= buckets_.size() ||
          buckets_[index].start != ComputeBucketStart(value)) {
        return false;
      }
      auto& bucket = buckets_[index];
      return bucket.data & ComputeMaskForValue(value);
    }

    // Returns 1 if `value` is present in the set, 0 otherwise.
    inline size_t count(T value) const { return contains(value) ? 1 : 0; }

    // Returns true if the set holds no values.
    inline bool empty() const { return size_ == 0; }

    // Returns the number of enums stored in this set.
    size_t size() const { return size_; }

    // Returns true if this set contains at least one value contained in `in_set`.
    // Note: If `in_set` is empty, this function returns true.
    bool HasAnyOf(const EnumSet<T>& in_set) const {
      if (in_set.empty()) {
        return true;
      }
      // Both bucket vectors are sorted by `start`: walk them in lockstep.
      auto lhs = buckets_.cbegin();
      auto rhs = in_set.buckets_.cbegin();
      while (lhs != buckets_.cend() && rhs != in_set.buckets_.cend()) {
        if (lhs->start == rhs->start) {
          if (lhs->data & rhs->data) {
            // At least 1 bit is shared. Early return.
            return true;
          }
          lhs++;
          rhs++;
          continue;
        }
        // LHS bucket is smaller than the current RHS bucket. Catching up on RHS.
        if (lhs->start < rhs->start) {
          lhs++;
          continue;
        }
        // Otherwise, RHS needs to catch up on LHS.
        rhs++;
      }
      return false;
    }

   private:
    // Returns the index of the last bucket in which `value` could be stored.
    static constexpr inline size_t ComputeLargestPossibleBucketIndexFor(T value) {
      return static_cast<size_t>(value) / kBucketSize;
    }

    // Returns the smallest enum value that could be contained in the same bucket
    // as `value`.
    static constexpr inline T ComputeBucketStart(T value) {
      return static_cast<T>(kBucketSize *
                            ComputeLargestPossibleBucketIndexFor(value));
    }

    // Returns the index of the bit that corresponds to `value` in the bucket.
    static constexpr inline ElementType ComputeBucketOffset(T value) {
      return static_cast<ElementType>(value) % kBucketSize;
    }

    // Returns the bitmask used to represent the enum `value` in its bucket.
    static constexpr inline BucketType ComputeMaskForValue(T value) {
      return 1ULL << ComputeBucketOffset(value);
    }

    // Returns the `enum` stored in `bucket` at `offset`.
    // `offset` is the bit-offset in the bucket storage.
    static constexpr inline T GetValueFromBucket(const Bucket& bucket,
                                                 BucketType offset) {
      return static_cast<T>(static_cast<ElementType>(bucket.start) + offset);
    }

    // For a given enum `value`, finds the bucket index that could contain this
    // value. If no such bucket is found, the index at which the new bucket
    // should be inserted is returned (possibly `buckets_.size()`). In other
    // words: the index of the first bucket whose start is not smaller than the
    // start of `value`'s bucket.
    size_t FindBucketForValue(T value) const {
      // Set is empty, insert at 0.
      if (buckets_.size() == 0) {
        return 0;
      }
      const T wanted_start = ComputeBucketStart(value);
      assert(buckets_.size() > 0 &&
             "Size must not be 0 here. Has the code above changed?");
      size_t index = std::min(buckets_.size() - 1,
                              ComputeLargestPossibleBucketIndexFor(value));
      // This loop behaves like std::upper_bound with a reverse iterator.
      // Buckets are sorted. 3 main cases:
      //  - The bucket matches
      //    => returns the bucket index.
      //  - The found bucket is larger
      //    => scans left until it finds the correct bucket, or insertion point.
      //  - The found bucket is smaller
      //    => We are at the end, so we return past-end index for insertion.
      for (; buckets_[index].start >= wanted_start; index--) {
        if (index == 0) {
          return 0;
        }
      }
      return index + 1;
    }

    // Creates a new bucket to store `value` and inserts it at `index`.
    // If the `index` is past the end, the bucket is inserted at the end of the
    // vector.
    void InsertBucketFor(size_t index, T value) {
      const T bucket_start = ComputeBucketStart(value);
      Bucket bucket = {ComputeMaskForValue(value), bucket_start};
      auto it = buckets_.emplace(buckets_.begin() + index, std::move(bucket));
  #if defined(NDEBUG)
      (void)it;  // Silencing unused variable warning.
  #else
      assert(std::next(it) == buckets_.end() ||
             std::next(it)->start > bucket_start);
      assert(it == buckets_.begin() || std::prev(it)->start < bucket_start);
  #endif
    }

    // Returns true if the bucket at `bucketIndex` stores the enum at
    // `bucketOffset`, false otherwise.
    bool HasEnumAt(size_t bucketIndex, BucketType bucketOffset) const {
      assert(bucketIndex < buckets_.size());
      assert(bucketOffset < kBucketSize);
      return buckets_[bucketIndex].data & (1ULL << bucketOffset);
    }

    // Returns true if `lhs` and `rhs` hold the exact same values.
    friend bool operator==(const EnumSet& lhs, const EnumSet& rhs) {
      if (lhs.size_ != rhs.size_) {
        return false;
      }
      if (lhs.buckets_.size() != rhs.buckets_.size()) {
        return false;
      }
      return lhs.buckets_ == rhs.buckets_;
    }

    // Returns true if `lhs` and `rhs` hold at least 1 different value.
    friend bool operator!=(const EnumSet& lhs, const EnumSet& rhs) {
      return !(lhs == rhs);
    }

    // Storage for the buckets, sorted by ascending `start`; never contains a
    // bucket with no bit set.
    std::vector<Bucket> buckets_;
    // How many enums this set is storing.
    size_t size_ = 0;
  };
  393. // A set of spv::Capability.
  394. using CapabilitySet = EnumSet<spv::Capability>;
  395. } // namespace spvtools
  396. #endif // SOURCE_ENUM_SET_H_