HashTable.h 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662
  1. // Jolt Physics Library (https://github.com/jrouwe/JoltPhysics)
  2. // SPDX-FileCopyrightText: 2024 Jorrit Rouwe
  3. // SPDX-License-Identifier: MIT
  4. #pragma once
  5. #include <Jolt/Math/BVec16.h>
  6. JPH_NAMESPACE_BEGIN
  7. /// Helper class for implementing an UnorderedSet or UnorderedMap
  8. /// Based on CppCon 2017: Matt Kulukundis "Designing a Fast, Efficient, Cache-friendly Hash Table, Step by Step"
  9. /// See: https://www.youtube.com/watch?v=ncHmEUmJZf4
  10. template <class Key, class KeyValue, class HashTableDetail, class Hash, class KeyEqual>
  11. class HashTable
  12. {
  13. public:
/// Properties
using value_type = KeyValue;			///< Type of the elements stored in the buckets
using size_type = uint32;				///< Type used for element counts and bucket indices
using difference_type = ptrdiff_t;		///< Signed type that the iterators expose as their difference_type
  18. private:
  19. /// Base class for iterators
  20. template <class Table, class Iterator>
  21. class IteratorBase
  22. {
  23. public:
  24. /// Properties
  25. using difference_type = typename Table::difference_type;
  26. using value_type = typename Table::value_type;
  27. using iterator_category = std::forward_iterator_tag;
  28. /// Copy constructor
  29. IteratorBase(const IteratorBase &inRHS) = default;
  30. /// Assignment operator
  31. IteratorBase & operator = (const IteratorBase &inRHS) = default;
  32. /// Iterator at start of table
  33. explicit IteratorBase(Table *inTable) :
  34. mTable(inTable),
  35. mIndex(0)
  36. {
  37. while (mIndex < mTable->mMaxSize && (mTable->mControl[mIndex] & cBucketUsed) == 0)
  38. ++mIndex;
  39. }
  40. /// Iterator at specific index
  41. IteratorBase(Table *inTable, size_type inIndex) :
  42. mTable(inTable),
  43. mIndex(inIndex)
  44. {
  45. }
  46. /// Prefix increment
  47. Iterator & operator ++ ()
  48. {
  49. JPH_ASSERT(IsValid());
  50. do
  51. {
  52. ++mIndex;
  53. }
  54. while (mIndex < mTable->mMaxSize && (mTable->mControl[mIndex] & cBucketUsed) == 0);
  55. return static_cast<Iterator &>(*this);
  56. }
  57. /// Postfix increment
  58. Iterator operator ++ (int)
  59. {
  60. Iterator result(mTable, mIndex);
  61. ++(*this);
  62. return result;
  63. }
  64. /// Access to key value pair
  65. const KeyValue & operator * () const
  66. {
  67. JPH_ASSERT(IsValid());
  68. return mTable->mData[mIndex];
  69. }
  70. /// Access to key value pair
  71. const KeyValue * operator -> () const
  72. {
  73. JPH_ASSERT(IsValid());
  74. return mTable->mData + mIndex;
  75. }
  76. /// Equality operator
  77. bool operator == (const Iterator &inRHS) const
  78. {
  79. return mIndex == inRHS.mIndex && mTable == inRHS.mTable;
  80. }
  81. /// Inequality operator
  82. bool operator != (const Iterator &inRHS) const
  83. {
  84. return !(*this == inRHS);
  85. }
  86. /// Check that the iterator is valid
  87. bool IsValid() const
  88. {
  89. return mIndex < mTable->mMaxSize
  90. && (mTable->mControl[mIndex] & cBucketUsed) != 0;
  91. }
  92. Table * mTable;
  93. size_type mIndex;
  94. };
  95. /// Allocate space for the hash table
  96. void AllocateTable(size_type inMaxSize)
  97. {
  98. JPH_ASSERT(mData == nullptr);
  99. mMaxSize = inMaxSize;
  100. mMaxLoad = uint32((cMaxLoadFactorNumerator * inMaxSize) / cMaxLoadFactorDenominator);
  101. size_type required_size = mMaxSize * (sizeof(KeyValue) + 1) + 15; // Add 15 bytes to mirror the first 15 bytes of the control values
  102. if constexpr (cNeedsAlignedAllocate)
  103. mData = reinterpret_cast<KeyValue *>(AlignedAllocate(required_size, alignof(KeyValue)));
  104. else
  105. mData = reinterpret_cast<KeyValue *>(Allocate(required_size));
  106. mControl = reinterpret_cast<uint8 *>(mData + mMaxSize);
  107. }
  108. /// Copy the contents of another hash table
  109. void CopyTable(const HashTable &inRHS)
  110. {
  111. if (inRHS.empty())
  112. return;
  113. AllocateTable(inRHS.mMaxSize);
  114. // Copy control bytes
  115. memcpy(mControl, inRHS.mControl, mMaxSize + 15);
  116. // Copy elements
  117. uint index = 0;
  118. for (const uint8 *control = mControl, *control_end = mControl + mMaxSize; control != control_end; ++control, ++index)
  119. if (*control & cBucketUsed)
  120. ::new (mData + index) KeyValue(inRHS.mData[index]);
  121. mSize = inRHS.mSize;
  122. }
  123. /// Grow the table to the next power of 2
  124. void GrowTable()
  125. {
  126. // Calculate new size
  127. size_type new_max_size = max<size_type>(mMaxSize << 1, 16);
  128. if (new_max_size < mMaxSize)
  129. {
  130. JPH_ASSERT(false, "Overflow in hash table size, can't grow!");
  131. return;
  132. }
  133. // Move the old table to a temporary structure
  134. size_type old_max_size = mMaxSize;
  135. KeyValue *old_data = mData;
  136. const uint8 *old_control = mControl;
  137. mData = nullptr;
  138. mControl = nullptr;
  139. mSize = 0;
  140. mMaxSize = 0;
  141. mMaxLoad = 0;
  142. // Allocate new table
  143. AllocateTable(new_max_size);
  144. // Reset all control bytes
  145. memset(mControl, cBucketEmpty, mMaxSize + 15);
  146. if (old_data != nullptr)
  147. {
  148. // Copy all elements from the old table
  149. for (size_type i = 0; i < old_max_size; ++i)
  150. if (old_control[i] & cBucketUsed)
  151. {
  152. size_type index;
  153. KeyValue *element = old_data + i;
  154. JPH_IF_ENABLE_ASSERTS(bool inserted =) InsertKey</* AllowDeleted= */ false>(HashTableDetail::sGetKey(*element), index);
  155. JPH_ASSERT(inserted);
  156. ::new (mData + index) KeyValue(std::move(*element));
  157. element->~KeyValue();
  158. }
  159. // Free memory
  160. if constexpr (cNeedsAlignedAllocate)
  161. AlignedFree(old_data);
  162. else
  163. Free(old_data);
  164. }
  165. }
  166. protected:
  167. /// Get an element by index
  168. KeyValue & GetElement(size_type inIndex) const
  169. {
  170. return mData[inIndex];
  171. }
  172. /// Insert a key into the map, returns true if the element was inserted, false if it already existed.
  173. /// outIndex is the index at which the element should be constructed / where it is located.
  174. template <bool AllowDeleted = true>
  175. bool InsertKey(const Key &inKey, size_type &outIndex)
  176. {
  177. // Ensure we have enough space
  178. if (mSize + 1 >= mMaxLoad)
  179. GrowTable();
  180. // Calculate hash
  181. uint64 hash_value = Hash { } (inKey);
  182. // Split hash into control byte and index
  183. uint8 control = cBucketUsed | uint8(hash_value);
  184. size_type bucket_mask = mMaxSize - 1;
  185. size_type index = size_type(hash_value >> 7) & bucket_mask;
  186. BVec16 control16 = BVec16::sReplicate(control);
  187. BVec16 bucket_empty = BVec16::sZero();
  188. BVec16 bucket_deleted = BVec16::sReplicate(cBucketDeleted);
  189. // Keeps track of the index of the first deleted bucket we found
  190. constexpr size_type cNoDeleted = ~size_type(0);
  191. size_type first_deleted_index = cNoDeleted;
  192. // Linear probing
  193. KeyEqual equal;
  194. for (;;)
  195. {
  196. // Read 16 control values (note that we added 15 bytes at the end of the control values that mirror the first 15 bytes)
  197. BVec16 control_bytes = BVec16::sLoadByte16(mControl + index);
  198. // Check for the control value we're looking for
  199. uint32 control_equal = uint32(BVec16::sEquals(control_bytes, control16).GetTrues());
  200. // Check for empty buckets
  201. uint32 control_empty = uint32(BVec16::sEquals(control_bytes, bucket_empty).GetTrues());
  202. // Check if we're still scanning for deleted buckets
  203. if constexpr (AllowDeleted)
  204. if (first_deleted_index == cNoDeleted)
  205. {
  206. // Check if any buckets have been deleted, if so store the first one
  207. uint32 control_deleted = uint32(BVec16::sEquals(control_bytes, bucket_deleted).GetTrues());
  208. if (control_deleted != 0)
  209. first_deleted_index = index + CountTrailingZeros(control_deleted);
  210. }
  211. // Index within the 16 buckets
  212. size_type local_index = index;
  213. // Loop while there's still buckets to process
  214. while ((control_equal | control_empty) != 0)
  215. {
  216. // Get the index of the first bucket that is either equal or empty
  217. uint first_equal = CountTrailingZeros(control_equal);
  218. uint first_empty = CountTrailingZeros(control_empty);
  219. // Check if we first found a bucket with equal control value before an empty bucket
  220. if (first_equal < first_empty)
  221. {
  222. // Skip to the bucket
  223. local_index += first_equal;
  224. // Make sure that our index is not beyond the end of the table
  225. local_index &= bucket_mask;
  226. // We found a bucket with same control value
  227. if (equal(HashTableDetail::sGetKey(mData[local_index]), inKey))
  228. {
  229. // Element already exists
  230. outIndex = local_index;
  231. return false;
  232. }
  233. // Skip past this bucket
  234. local_index++;
  235. uint shift = first_equal + 1;
  236. control_equal >>= shift;
  237. control_empty >>= shift;
  238. }
  239. else
  240. {
  241. // An empty bucket was found, we can insert a new item
  242. JPH_ASSERT(control_empty != 0);
  243. // Get the location of the first empty or deleted bucket
  244. local_index += first_empty;
  245. if constexpr (AllowDeleted)
  246. if (first_deleted_index < local_index)
  247. local_index = first_deleted_index;
  248. // Make sure that our index is not beyond the end of the table
  249. local_index &= bucket_mask;
  250. // Update control byte
  251. mControl[local_index] = control;
  252. if (local_index < 15)
  253. mControl[mMaxSize + local_index] = control; // Mirror the first 15 bytes at the end of the control values
  254. ++mSize;
  255. // Return index to newly allocated bucket
  256. outIndex = local_index;
  257. return true;
  258. }
  259. }
  260. // Move to next batch of 16 buckets
  261. index = (index + 16) & bucket_mask;
  262. }
  263. }
  264. public:
  265. /// Non-const iterator
  266. class iterator : public IteratorBase<HashTable, iterator>
  267. {
  268. using Base = IteratorBase<HashTable, iterator>;
  269. public:
  270. /// Properties
  271. using reference = typename Base::value_type &;
  272. using pointer = typename Base::value_type *;
  273. /// Constructors
  274. explicit iterator(HashTable *inTable) : Base(inTable) { }
  275. iterator(HashTable *inTable, size_type inIndex) : Base(inTable, inIndex) { }
  276. iterator(const iterator &inIterator) : Base(inIterator) { }
  277. /// Assignment
  278. iterator & operator = (const iterator &inRHS) { Base::operator = (inRHS); return *this; }
  279. using Base::operator *;
  280. /// Non-const access to key value pair
  281. KeyValue & operator * ()
  282. {
  283. JPH_ASSERT(this->IsValid());
  284. return this->mTable->mData[this->mIndex];
  285. }
  286. using Base::operator ->;
  287. /// Non-const access to key value pair
  288. KeyValue * operator -> ()
  289. {
  290. JPH_ASSERT(this->IsValid());
  291. return this->mTable->mData + this->mIndex;
  292. }
  293. };
  294. /// Const iterator
  295. class const_iterator : public IteratorBase<const HashTable, const_iterator>
  296. {
  297. using Base = IteratorBase<const HashTable, const_iterator>;
  298. public:
  299. /// Properties
  300. using reference = const typename Base::value_type &;
  301. using pointer = const typename Base::value_type *;
  302. /// Constructors
  303. explicit const_iterator(const HashTable *inTable) : Base(inTable) { }
  304. const_iterator(const HashTable *inTable, size_type inIndex) : Base(inTable, inIndex) { }
  305. const_iterator(const const_iterator &inRHS) : Base(inRHS) { }
  306. const_iterator(const iterator &inIterator) : Base(inIterator.mTable, inIterator.mIndex) { }
  307. /// Assignment
  308. const_iterator & operator = (const iterator &inRHS) { this->mTable = inRHS.mTable; this->mIndex = inRHS.mIndex; return *this; }
  309. const_iterator & operator = (const const_iterator &inRHS) { Base::operator = (inRHS); return *this; }
  310. };
/// Default constructor, creates a table without storage (all members keep their default initializers)
HashTable() = default;

/// Copy constructor, duplicates the elements of inRHS through CopyTable
HashTable(const HashTable &inRHS)
{
	CopyTable(inRHS);
}
  318. /// Move constructor
  319. HashTable(HashTable &&ioRHS) noexcept :
  320. mData(ioRHS.mData),
  321. mControl(ioRHS.mControl),
  322. mSize(ioRHS.mSize),
  323. mMaxSize(ioRHS.mMaxSize),
  324. mMaxLoad(ioRHS.mMaxLoad)
  325. {
  326. ioRHS.mData = nullptr;
  327. ioRHS.mControl = nullptr;
  328. ioRHS.mSize = 0;
  329. ioRHS.mMaxSize = 0;
  330. ioRHS.mMaxLoad = 0;
  331. }
  332. /// Assignment operator
  333. HashTable & operator = (const HashTable &inRHS)
  334. {
  335. if (this != &inRHS)
  336. {
  337. clear();
  338. CopyTable(inRHS);
  339. }
  340. return *this;
  341. }
/// Destructor, destroys all elements and frees the storage through clear()
~HashTable()
{
	clear();
}
  347. /// Reserve memory for a certain number of elements
  348. void reserve(size_type inMaxSize)
  349. {
  350. // Calculate max size based on load factor
  351. size_type max_size = GetNextPowerOf2(max<uint32>((cMaxLoadFactorDenominator * inMaxSize) / cMaxLoadFactorNumerator, 16));
  352. if (max_size <= mMaxSize)
  353. return;
  354. // Allocate buffers
  355. AllocateTable(max_size);
  356. // Reset all control bytes
  357. memset(mControl, cBucketEmpty, mMaxSize + 15);
  358. }
  359. /// Destroy the entire hash table
  360. void clear()
  361. {
  362. // Delete all elements
  363. if constexpr (!std::is_trivially_destructible<KeyValue>())
  364. if (!empty())
  365. for (size_type i = 0; i < mMaxSize; ++i)
  366. if (mControl[i] & cBucketUsed)
  367. mData[i].~KeyValue();
  368. if (mData != nullptr)
  369. {
  370. // Free memory
  371. if constexpr (cNeedsAlignedAllocate)
  372. AlignedFree(mData);
  373. else
  374. Free(mData);
  375. // Reset members
  376. mData = nullptr;
  377. mControl = nullptr;
  378. mSize = 0;
  379. mMaxSize = 0;
  380. mMaxLoad = 0;
  381. }
  382. }
  383. /// Iterator to first element
  384. iterator begin()
  385. {
  386. return iterator(this);
  387. }
  388. /// Iterator to one beyond last element
  389. iterator end()
  390. {
  391. return iterator(this, mMaxSize);
  392. }
  393. /// Iterator to first element
  394. const_iterator begin() const
  395. {
  396. return const_iterator(this);
  397. }
  398. /// Iterator to one beyond last element
  399. const_iterator end() const
  400. {
  401. return const_iterator(this, mMaxSize);
  402. }
  403. /// Iterator to first element
  404. const_iterator cbegin() const
  405. {
  406. return const_iterator(this);
  407. }
  408. /// Iterator to one beyond last element
  409. const_iterator cend() const
  410. {
  411. return const_iterator(this, mMaxSize);
  412. }
  413. /// Check if there are no elements in the table
  414. bool empty() const
  415. {
  416. return mSize == 0;
  417. }
  418. /// Number of elements in the table
  419. size_type size() const
  420. {
  421. return mSize;
  422. }
  423. /// Insert a new element, returns iterator and if the element was inserted
  424. std::pair<iterator, bool> insert(const value_type &inValue)
  425. {
  426. size_type index;
  427. bool inserted = InsertKey(HashTableDetail::sGetKey(inValue), index);
  428. if (inserted)
  429. ::new (mData + index) KeyValue(inValue);
  430. return std::make_pair(iterator(this, index), inserted);
  431. }
  432. /// Find an element, returns iterator to element or end() if not found
  433. const_iterator find(const Key &inKey) const
  434. {
  435. // Check if we have any data
  436. if (empty())
  437. return cend();
  438. // Calculate hash
  439. uint64 hash_value = Hash { } (inKey);
  440. // Split hash into control byte and index
  441. uint8 control = cBucketUsed | uint8(hash_value);
  442. size_type bucket_mask = mMaxSize - 1;
  443. size_type index = size_type(hash_value >> 7) & bucket_mask;
  444. BVec16 control16 = BVec16::sReplicate(control);
  445. BVec16 bucket_empty = BVec16::sZero();
  446. // Linear probing
  447. KeyEqual equal;
  448. for (;;)
  449. {
  450. // Read 16 control values (note that we added 15 bytes at the end of the control values that mirror the first 15 bytes)
  451. BVec16 control_bytes = BVec16::sLoadByte16(mControl + index);
  452. // Check for the control value we're looking for
  453. uint32 control_equal = uint32(BVec16::sEquals(control_bytes, control16).GetTrues());
  454. // Check for empty buckets
  455. uint32 control_empty = uint32(BVec16::sEquals(control_bytes, bucket_empty).GetTrues());
  456. // Index within the 16 buckets
  457. size_type local_index = index;
  458. // Loop while there's still buckets to process
  459. while ((control_equal | control_empty) != 0)
  460. {
  461. // Get the index of the first bucket that is either equal or empty
  462. uint first_equal = CountTrailingZeros(control_equal);
  463. uint first_empty = CountTrailingZeros(control_empty);
  464. // Check if we first found a bucket with equal control value before an empty bucket
  465. if (first_equal < first_empty)
  466. {
  467. // Skip to the bucket
  468. local_index += first_equal;
  469. // Make sure that our index is not beyond the end of the table
  470. local_index &= bucket_mask;
  471. // We found a bucket with same control value
  472. if (equal(HashTableDetail::sGetKey(mData[local_index]), inKey))
  473. {
  474. // Element found
  475. return const_iterator(this, local_index);
  476. }
  477. // Skip past this bucket
  478. local_index++;
  479. uint shift = first_equal + 1;
  480. control_equal >>= shift;
  481. control_empty >>= shift;
  482. }
  483. else
  484. {
  485. // An empty bucket was found, we didn't find the element
  486. JPH_ASSERT(control_empty != 0);
  487. return cend();
  488. }
  489. }
  490. // Move to next batch of 16 buckets
  491. index = (index + 16) & bucket_mask;
  492. }
  493. }
  494. /// @brief Erase an element by iterator
  495. void erase(const const_iterator &inIterator)
  496. {
  497. JPH_ASSERT(inIterator.IsValid());
  498. // Mark the bucket as deleted
  499. mControl[inIterator.mIndex] = cBucketDeleted;
  500. if (inIterator.mIndex < 15)
  501. mControl[inIterator.mIndex + mMaxSize] = cBucketDeleted;
  502. // Destruct the element
  503. mData[inIterator.mIndex].~KeyValue();
  504. // Decrease size
  505. --mSize;
  506. }
  507. /// @brief Erase an element by key
  508. size_type erase(const Key &inKey)
  509. {
  510. const_iterator it = find(inKey);
  511. if (it == cend())
  512. return 0;
  513. erase(it);
  514. return 1;
  515. }
  516. /// Swap the contents of two hash tables
  517. void swap(HashTable &ioRHS) noexcept
  518. {
  519. std::swap(mData, ioRHS.mData);
  520. std::swap(mControl, ioRHS.mControl);
  521. std::swap(mSize, ioRHS.mSize);
  522. std::swap(mMaxSize, ioRHS.mMaxSize);
  523. std::swap(mMaxLoad, ioRHS.mMaxLoad);
  524. }
  525. private:
/// If this allocator needs to fall back to aligned allocations because the type requires it
/// (true when KeyValue needs more than 8 byte alignment on 32 bit platforms or more than 16 byte alignment otherwise,
/// presumably the alignment that Allocate guarantees - verify against the allocator)
static constexpr bool cNeedsAlignedAllocate = alignof(KeyValue) > (JPH_CPU_ADDRESS_BITS == 32? 8 : 16);

/// Max load factor is cMaxLoadFactorNumerator / cMaxLoadFactorDenominator
/// (declared as uint64 so that the load calculations are done in 64 bit)
static constexpr uint64 cMaxLoadFactorNumerator = 7;
static constexpr uint64 cMaxLoadFactorDenominator = 8;

/// Values that the control bytes can have
static constexpr uint8 cBucketEmpty = 0;		// Bucket was never used, a probe stops here
static constexpr uint8 cBucketDeleted = 0x7f;	// Bucket held an element that was erased, has the cBucketUsed bit clear and differs from cBucketEmpty
static constexpr uint8 cBucketUsed = 0x80; // Lowest 7 bits are lowest 7 bits of the hash value

/// The buckets, an array of size mMaxSize
KeyValue * mData = nullptr;

/// Control bytes, an array of size mMaxSize + 15
/// (the last 15 bytes mirror the first 15 so that 16 control bytes can be loaded starting at any bucket index)
uint8 * mControl = nullptr;

/// Number of elements in the table
size_type mSize = 0;

/// Max number of elements that can be stored in the table
/// (this is the number of buckets, indices are wrapped with mMaxSize - 1 as a mask)
size_type mMaxSize = 0;

/// Max number of elements in the table before it should grow
size_type mMaxLoad = 0;
  545. };
  546. JPH_NAMESPACE_END