| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472 |
- // Copyright (c) 2023 Google Inc.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- #include <stddef.h>
- #include <algorithm>
- #include <cassert>
- #include <cstdint>
- #include <functional>
- #include <initializer_list>
- #include <iterator>
- #include <limits>
- #include <type_traits>
- #include <vector>
- #ifndef SOURCE_ENUM_SET_H_
- #define SOURCE_ENUM_SET_H_
- #include "source/latest_version_spirv_header.h"
- namespace spvtools {
- // This container is optimized to store and retrieve unsigned enum values.
- // The base model for this implementation is an open-addressing hashtable with
- // linear probing. For small enums (max index < 64), all operations are O(1).
- //
- // - Enums are stored in buckets (64 contiguous values max per bucket)
- // - Buckets ranges don't overlap, but don't have to be contiguous.
- // - Enums are packed into 64-bits buckets, using 1 bit per enum value.
- //
- // Example:
- // - MyEnum { A = 0, B = 1, C = 64, D = 65 }
- // - 2 buckets are required:
- // - bucket 0, storing values in the range [ 0; 64[
- // - bucket 1, storing values in the range [64; 128[
- //
- // - Buckets are stored in a sorted vector (sorted by bucket range).
- // - Retrieval is done by computing the theoretical bucket index using the enum
- // value, and
- // doing a linear scan from this position.
- // - Insertion is done by retrieving the bucket and either:
- // - inserting a new bucket in the sorted vector when no buckets has a
- // compatible range.
- // - setting the corresponding bit in the bucket.
- // This means insertion in the middle/beginning can cause a memmove when no
- // bucket is available. In our case, this happens at most 23 times for the
- // largest enum we have (Opcodes).
- template <typename T>
- class EnumSet {
- private:
- using BucketType = uint64_t;
- using ElementType = std::underlying_type_t<T>;
- static_assert(std::is_enum_v<T>, "EnumSets only works with enums.");
- static_assert(std::is_signed_v<ElementType> == false,
- "EnumSet doesn't supports signed enums.");
- // Each bucket can hold up to `kBucketSize` distinct, contiguous enum values.
- // The first value a bucket can hold must be aligned on `kBucketSize`.
- struct Bucket {
- // bit mask to store `kBucketSize` enums.
- BucketType data;
- // 1st enum this bucket can represent.
- T start;
- friend bool operator==(const Bucket& lhs, const Bucket& rhs) {
- return lhs.start == rhs.start && lhs.data == rhs.data;
- }
- };
- // How many distinct values can a bucket hold? 1 bit per value.
- static constexpr size_t kBucketSize = sizeof(BucketType) * 8ULL;
- public:
- class Iterator {
- public:
- typedef Iterator self_type;
- typedef T value_type;
- typedef T& reference;
- typedef T* pointer;
- typedef std::forward_iterator_tag iterator_category;
- typedef size_t difference_type;
- Iterator(const Iterator& other)
- : set_(other.set_),
- bucketIndex_(other.bucketIndex_),
- bucketOffset_(other.bucketOffset_) {}
- Iterator& operator++() {
- do {
- if (bucketIndex_ >= set_->buckets_.size()) {
- bucketIndex_ = set_->buckets_.size();
- bucketOffset_ = 0;
- break;
- }
- if (bucketOffset_ + 1 == kBucketSize) {
- bucketOffset_ = 0;
- ++bucketIndex_;
- } else {
- ++bucketOffset_;
- }
- } while (bucketIndex_ < set_->buckets_.size() &&
- !set_->HasEnumAt(bucketIndex_, bucketOffset_));
- return *this;
- }
- Iterator operator++(int) {
- Iterator old = *this;
- operator++();
- return old;
- }
- T operator*() const {
- assert(set_->HasEnumAt(bucketIndex_, bucketOffset_) &&
- "operator*() called on an invalid iterator.");
- return GetValueFromBucket(set_->buckets_[bucketIndex_], bucketOffset_);
- }
- bool operator!=(const Iterator& other) const {
- return set_ != other.set_ || bucketOffset_ != other.bucketOffset_ ||
- bucketIndex_ != other.bucketIndex_;
- }
- bool operator==(const Iterator& other) const {
- return !(operator!=(other));
- }
- Iterator& operator=(const Iterator& other) {
- set_ = other.set_;
- bucketIndex_ = other.bucketIndex_;
- bucketOffset_ = other.bucketOffset_;
- return *this;
- }
- private:
- Iterator(const EnumSet* set, size_t bucketIndex, ElementType bucketOffset)
- : set_(set), bucketIndex_(bucketIndex), bucketOffset_(bucketOffset) {}
- private:
- const EnumSet* set_ = nullptr;
- // Index of the bucket in the vector.
- size_t bucketIndex_ = 0;
- // Offset in bits in the current bucket.
- ElementType bucketOffset_ = 0;
- friend class EnumSet;
- };
- // Required to allow the use of std::inserter.
- using value_type = T;
- using const_iterator = Iterator;
- using iterator = Iterator;
- public:
- iterator cbegin() const noexcept {
- auto it = iterator(this, /* bucketIndex= */ 0, /* bucketOffset= */ 0);
- if (buckets_.size() == 0) {
- return it;
- }
- // The iterator has the logic to find the next valid bit. If the value 0
- // is not stored, use it to find the next valid bit.
- if (!HasEnumAt(it.bucketIndex_, it.bucketOffset_)) {
- ++it;
- }
- return it;
- }
- iterator begin() const noexcept { return cbegin(); }
- iterator cend() const noexcept {
- return iterator(this, buckets_.size(), /* bucketOffset= */ 0);
- }
- iterator end() const noexcept { return cend(); }
- // Creates an empty set.
- EnumSet() : buckets_(0), size_(0) {}
- // Creates a set and store `value` in it.
- EnumSet(T value) : EnumSet() { insert(value); }
- // Creates a set and stores each `values` in it.
- EnumSet(std::initializer_list<T> values) : EnumSet() {
- for (auto item : values) {
- insert(item);
- }
- }
- // Creates a set, and insert `count` enum values pointed by `array` in it.
- EnumSet(ElementType count, const T* array) : EnumSet() {
- for (ElementType i = 0; i < count; i++) {
- insert(array[i]);
- }
- }
- // Creates a set initialized with the content of the range [begin; end[.
- template <class InputIt>
- EnumSet(InputIt begin, InputIt end) : EnumSet() {
- for (; begin != end; ++begin) {
- insert(*begin);
- }
- }
- // Copies the EnumSet `other` into a new EnumSet.
- EnumSet(const EnumSet& other)
- : buckets_(other.buckets_), size_(other.size_) {}
- // Moves the EnumSet `other` into a new EnumSet.
- EnumSet(EnumSet&& other)
- : buckets_(std::move(other.buckets_)), size_(other.size_) {}
- // Deep-copies the EnumSet `other` into this EnumSet.
- EnumSet& operator=(const EnumSet& other) {
- buckets_ = other.buckets_;
- size_ = other.size_;
- return *this;
- }
- // Matches std::unordered_set::insert behavior.
- std::pair<iterator, bool> insert(const T& value) {
- const size_t index = FindBucketForValue(value);
- const ElementType offset = ComputeBucketOffset(value);
- if (index >= buckets_.size() ||
- buckets_[index].start != ComputeBucketStart(value)) {
- size_ += 1;
- InsertBucketFor(index, value);
- return std::make_pair(Iterator(this, index, offset), true);
- }
- auto& bucket = buckets_[index];
- const auto mask = ComputeMaskForValue(value);
- if (bucket.data & mask) {
- return std::make_pair(Iterator(this, index, offset), false);
- }
- size_ += 1;
- bucket.data |= ComputeMaskForValue(value);
- return std::make_pair(Iterator(this, index, offset), true);
- }
- // Inserts `value` in the set if possible.
- // Similar to `std::unordered_set::insert`, except the hint is ignored.
- // Returns an iterator to the inserted element, or the element preventing
- // insertion.
- iterator insert(const_iterator, const T& value) {
- return insert(value).first;
- }
- // Inserts `value` in the set if possible.
- // Similar to `std::unordered_set::insert`, except the hint is ignored.
- // Returns an iterator to the inserted element, or the element preventing
- // insertion.
- iterator insert(const_iterator, T&& value) { return insert(value).first; }
- // Inserts all the values in the range [`first`; `last[.
- // Similar to `std::unordered_set::insert`.
- template <class InputIt>
- void insert(InputIt first, InputIt last) {
- for (auto it = first; it != last; ++it) {
- insert(*it);
- }
- }
- // Removes the value `value` into the set.
- // Similar to `std::unordered_set::erase`.
- // Returns the number of erased elements.
- size_t erase(const T& value) {
- const size_t index = FindBucketForValue(value);
- if (index >= buckets_.size() ||
- buckets_[index].start != ComputeBucketStart(value)) {
- return 0;
- }
- auto& bucket = buckets_[index];
- const auto mask = ComputeMaskForValue(value);
- if (!(bucket.data & mask)) {
- return 0;
- }
- size_ -= 1;
- bucket.data &= ~mask;
- if (bucket.data == 0) {
- buckets_.erase(buckets_.cbegin() + index);
- }
- return 1;
- }
- // Returns true if `value` is present in the set.
- bool contains(T value) const {
- const size_t index = FindBucketForValue(value);
- if (index >= buckets_.size() ||
- buckets_[index].start != ComputeBucketStart(value)) {
- return false;
- }
- auto& bucket = buckets_[index];
- return bucket.data & ComputeMaskForValue(value);
- }
- // Returns the 1 if `value` is present in the set, `0` otherwise.
- inline size_t count(T value) const { return contains(value) ? 1 : 0; }
- // Returns true if the set is holds no values.
- inline bool empty() const { return size_ == 0; }
- // Returns the number of enums stored in this set.
- size_t size() const { return size_; }
- // Returns true if this set contains at least one value contained in `in_set`.
- // Note: If `in_set` is empty, this function returns true.
- bool HasAnyOf(const EnumSet<T>& in_set) const {
- if (in_set.empty()) {
- return true;
- }
- auto lhs = buckets_.cbegin();
- auto rhs = in_set.buckets_.cbegin();
- while (lhs != buckets_.cend() && rhs != in_set.buckets_.cend()) {
- if (lhs->start == rhs->start) {
- if (lhs->data & rhs->data) {
- // At least 1 bit is shared. Early return.
- return true;
- }
- lhs++;
- rhs++;
- continue;
- }
- // LHS bucket is smaller than the current RHS bucket. Catching up on RHS.
- if (lhs->start < rhs->start) {
- lhs++;
- continue;
- }
- // Otherwise, RHS needs to catch up on LHS.
- rhs++;
- }
- return false;
- }
- private:
- // Returns the index of the last bucket in which `value` could be stored.
- static constexpr inline size_t ComputeLargestPossibleBucketIndexFor(T value) {
- return static_cast<size_t>(value) / kBucketSize;
- }
- // Returns the smallest enum value that could be contained in the same bucket
- // as `value`.
- static constexpr inline T ComputeBucketStart(T value) {
- return static_cast<T>(kBucketSize *
- ComputeLargestPossibleBucketIndexFor(value));
- }
- // Returns the index of the bit that corresponds to `value` in the bucket.
- static constexpr inline ElementType ComputeBucketOffset(T value) {
- return static_cast<ElementType>(value) % kBucketSize;
- }
- // Returns the bitmask used to represent the enum `value` in its bucket.
- static constexpr inline BucketType ComputeMaskForValue(T value) {
- return 1ULL << ComputeBucketOffset(value);
- }
- // Returns the `enum` stored in `bucket` at `offset`.
- // `offset` is the bit-offset in the bucket storage.
- static constexpr inline T GetValueFromBucket(const Bucket& bucket,
- BucketType offset) {
- return static_cast<T>(static_cast<ElementType>(bucket.start) + offset);
- }
- // For a given enum `value`, finds the bucket index that could contain this
- // value. If no such bucket is found, the index at which the new bucket should
- // be inserted is returned.
- size_t FindBucketForValue(T value) const {
- // Set is empty, insert at 0.
- if (buckets_.size() == 0) {
- return 0;
- }
- const T wanted_start = ComputeBucketStart(value);
- assert(buckets_.size() > 0 &&
- "Size must not be 0 here. Has the code above changed?");
- size_t index = std::min(buckets_.size() - 1,
- ComputeLargestPossibleBucketIndexFor(value));
- // This loops behaves like std::upper_bound with a reverse iterator.
- // Buckets are sorted. 3 main cases:
- // - The bucket matches
- // => returns the bucket index.
- // - The found bucket is larger
- // => scans left until it finds the correct bucket, or insertion point.
- // - The found bucket is smaller
- // => We are at the end, so we return past-end index for insertion.
- for (; buckets_[index].start >= wanted_start; index--) {
- if (index == 0) {
- return 0;
- }
- }
- return index + 1;
- }
- // Creates a new bucket to store `value` and inserts it at `index`.
- // If the `index` is past the end, the bucket is inserted at the end of the
- // vector.
- void InsertBucketFor(size_t index, T value) {
- const T bucket_start = ComputeBucketStart(value);
- Bucket bucket = {1ULL << ComputeBucketOffset(value), bucket_start};
- auto it = buckets_.emplace(buckets_.begin() + index, std::move(bucket));
- #if defined(NDEBUG)
- (void)it; // Silencing unused variable warning.
- #else
- assert(std::next(it) == buckets_.end() ||
- std::next(it)->start > bucket_start);
- assert(it == buckets_.begin() || std::prev(it)->start < bucket_start);
- #endif
- }
- // Returns true if the bucket at `bucketIndex/ stores the enum at
- // `bucketOffset`, false otherwise.
- bool HasEnumAt(size_t bucketIndex, BucketType bucketOffset) const {
- assert(bucketIndex < buckets_.size());
- assert(bucketOffset < kBucketSize);
- return buckets_[bucketIndex].data & (1ULL << bucketOffset);
- }
- // Returns true if `lhs` and `rhs` hold the exact same values.
- friend bool operator==(const EnumSet& lhs, const EnumSet& rhs) {
- if (lhs.size_ != rhs.size_) {
- return false;
- }
- if (lhs.buckets_.size() != rhs.buckets_.size()) {
- return false;
- }
- return lhs.buckets_ == rhs.buckets_;
- }
- // Returns true if `lhs` and `rhs` hold at least 1 different value.
- friend bool operator!=(const EnumSet& lhs, const EnumSet& rhs) {
- return !(lhs == rhs);
- }
- // Storage for the buckets.
- std::vector<Bucket> buckets_;
- // How many enums is this set storing.
- size_t size_ = 0;
- };
- // A set of spv::Capability.
- using CapabilitySet = EnumSet<spv::Capability>;
- } // namespace spvtools
- #endif // SOURCE_ENUM_SET_H_
|