Browse Source

hashtable: reimplement as open-addressed robin hood hashtable

This is mostly ported from Taisei Project
Andrei Alexeyev 11 months ago
parent
commit
ba7b346e52
1 changed file with 313 additions and 115 deletions
  1. 313 115
      src/SDL_hashtable.c

+ 313 - 115
src/SDL_hashtable.c

@@ -18,26 +18,42 @@
      misrepresented as being the original software.
      misrepresented as being the original software.
   3. This notice may not be removed or altered from any source distribution.
   3. This notice may not be removed or altered from any source distribution.
 */
 */
+
 #include "SDL_internal.h"
 #include "SDL_internal.h"
 #include "SDL_hashtable.h"
 #include "SDL_hashtable.h"
 
 
+// XXX: We can't use SDL_assert here because it's going to call into hashtable code
+#include <assert.h>
+#define HT_ASSERT(x) assert(x)
+
 typedef struct SDL_HashItem
 typedef struct SDL_HashItem
 {
 {
+    // TODO: Splitting off values into a separate array might be more cache-friendly
     const void *key;
     const void *key;
     const void *value;
     const void *value;
-    struct SDL_HashItem *next;
+    Uint32 hash;
+    Uint32 probe_len : 31;
+    Uint32 live : 1;
 } SDL_HashItem;
 } SDL_HashItem;
 
 
+// Must be a power of 2 >= sizeof(SDL_HashItem)
+#define MAX_HASHITEM_SIZEOF 32u
+SDL_COMPILE_TIME_ASSERT(sizeof_SDL_HashItem, sizeof(SDL_HashItem) <= MAX_HASHITEM_SIZEOF);
+
+// Anything larger than this will cause integer overflows
+#define MAX_HASHTABLE_SIZE (0x80000000u / (MAX_HASHITEM_SIZEOF))
+
 struct SDL_HashTable
 struct SDL_HashTable
 {
 {
-    SDL_HashItem **table;
-    Uint32 table_len;
-    int hash_shift;
-    bool stackable;
-    void *data;
+    SDL_HashItem *table;
     SDL_HashTable_HashFn hash;
     SDL_HashTable_HashFn hash;
     SDL_HashTable_KeyMatchFn keymatch;
     SDL_HashTable_KeyMatchFn keymatch;
     SDL_HashTable_NukeFn nuke;
     SDL_HashTable_NukeFn nuke;
+    void *data;
+    Uint32 hash_mask;
+    Uint32 max_probe_len;
+    Uint32 num_occupied_slots;
+    bool stackable;
 };
 };
 
 
 SDL_HashTable *SDL_CreateHashTable(void *data, const Uint32 num_buckets, const SDL_HashTable_HashFn hashfn,
 SDL_HashTable *SDL_CreateHashTable(void *data, const Uint32 num_buckets, const SDL_HashTable_HashFn hashfn,
@@ -47,26 +63,29 @@ SDL_HashTable *SDL_CreateHashTable(void *data, const Uint32 num_buckets, const S
 {
 {
     SDL_HashTable *table;
     SDL_HashTable *table;
 
 
-    // num_buckets must be a power of two so we can derive the bucket index with just a bitshift.
-    // Need at least two buckets, otherwise hash_shift would be 32, which is UB!
-    if ((num_buckets < 2) || !SDL_HasExactlyOneBitSet32(num_buckets)) {
+    // num_buckets must be a power of two so we can derive the bucket index with just a bit-and.
+    if ((num_buckets < 1) || !SDL_HasExactlyOneBitSet32(num_buckets)) {
         SDL_SetError("num_buckets must be a power of two");
         SDL_SetError("num_buckets must be a power of two");
         return NULL;
         return NULL;
     }
     }
 
 
-    table = (SDL_HashTable *) SDL_calloc(1, sizeof (SDL_HashTable));
+    if (num_buckets > MAX_HASHTABLE_SIZE) {
+        SDL_SetError("num_buckets is too large");
+        return NULL;
+    }
+
+    table = (SDL_HashTable *)SDL_calloc(1, sizeof(SDL_HashTable));
     if (!table) {
     if (!table) {
         return NULL;
         return NULL;
     }
     }
 
 
-    table->table = (SDL_HashItem **) SDL_calloc(num_buckets, sizeof (SDL_HashItem *));
+    table->table = (SDL_HashItem *)SDL_calloc(num_buckets, sizeof(SDL_HashItem));
     if (!table->table) {
     if (!table->table) {
         SDL_free(table);
         SDL_free(table);
         return NULL;
         return NULL;
     }
     }
 
 
-    table->table_len = num_buckets;
-    table->hash_shift = 32 - SDL_MostSignificantBitIndex32(num_buckets);
+    table->hash_mask = num_buckets - 1;
     table->stackable = stackable;
     table->stackable = stackable;
     table->data = data;
     table->data = data;
     table->hash = hashfn;
     table->hash = hashfn;
@@ -75,47 +94,232 @@ SDL_HashTable *SDL_CreateHashTable(void *data, const Uint32 num_buckets, const S
     return table;
     return table;
 }
 }
 
 
-static SDL_INLINE Uint32 calc_hash(const SDL_HashTable *table, const void *key)
+static SDL_INLINE Uint32 calc_hash(const SDL_HashTable *restrict table, const void *key)
 {
 {
-    // Mix the bits together, and use the highest bits as the bucket index.
     const Uint32 BitMixer = 0x9E3779B1u;
     const Uint32 BitMixer = 0x9E3779B1u;
-    return (table->hash(key, table->data) * BitMixer) >> table->hash_shift;
+    return table->hash(key, table->data) * BitMixer;
 }
 }
 
 
+static SDL_INLINE Uint32 get_probe_length(Uint32 zero_idx, Uint32 actual_idx, Uint32 num_buckets)
+{
+    // returns the probe sequence length from zero_idx to actual_idx
+
+    if (actual_idx < zero_idx) {
+        return num_buckets - zero_idx + actual_idx;
+    }
 
 
-bool SDL_InsertIntoHashTable(SDL_HashTable *table, const void *key, const void *value)
+    return actual_idx - zero_idx;
+}
+
+static SDL_HashItem *find_item(const SDL_HashTable *restrict ht, const void *key, Uint32 hash, Uint32 *restrict i, Uint32 *restrict probe_len)
 {
 {
-    SDL_HashItem *item;
-    Uint32 hash;
+    Uint32 hash_mask = ht->hash_mask;
+    Uint32 max_probe_len = ht->max_probe_len;
 
 
-    if (!table) {
+    SDL_HashItem *table = ht->table;
+
+    for (;;) {
+        SDL_HashItem *item = table + *i;
+        Uint32 item_hash = item->hash;
+
+        if (!item->live) {
+            return NULL;
+        }
+
+        if (item_hash == hash && ht->keymatch(item->key, key, ht->data)) {
+            return item;
+        }
+
+        Uint32 item_probe_len = item->probe_len;
+        HT_ASSERT(item_probe_len == get_probe_length(item_hash & hash_mask, (Uint32)(item - table), hash_mask + 1));
+
+        if (*probe_len > item_probe_len) {
+            return NULL;
+        }
+
+        if (++*probe_len > max_probe_len) {
+            return NULL;
+        }
+
+        *i = (*i + 1) & hash_mask;
+    }
+}
+
+static SDL_HashItem *find_first_item(const SDL_HashTable *restrict ht, const void *key, Uint32 hash) // Searches for key starting at its home slot; returns whatever find_item returns (the matching item, or NULL when there is none).
+{ // hash is the already-mixed hash of key; callers in this file obtain it from calc_hash().
+    Uint32 i = hash & ht->hash_mask; // home slot: the hash reduced to a table index by the mask
+    Uint32 probe_len = 0; // a fresh search has not stepped away from the home slot yet
+    return find_item(ht, key, hash, &i, &probe_len); // i and probe_len are only scratch cursors here; their final values are discarded
+}
+
+static SDL_HashItem *insert_item(SDL_HashItem *restrict item_to_insert, SDL_HashItem *restrict table, Uint32 hash_mask, Uint32 *max_probe_len_ptr) // Stores *item_to_insert in table (hash_mask + 1 slots) using Robin Hood displacement; returns the slot that received the caller's item.
+{ // Side effects: *item_to_insert is overwritten whenever a resident entry is evicted, and *max_probe_len_ptr is raised to the largest probe length written here.
+    Uint32 idx = item_to_insert->hash & hash_mask; // start probing at the item's home slot
+    SDL_HashItem temp_item, *target = NULL; // temp_item: swap scratch; target: first slot written, which is where the caller's item ended up
+    Uint32 num_buckets = hash_mask + 1;
+    // NOTE(review): the loop below only exits on a non-live slot, so the table must contain at least one; presumably the callers' resize policy guarantees that -- confirm.
+    for (;;) {
+        SDL_HashItem *candidate = table + idx;
+
+        if (!candidate->live) {
+            // Found an empty slot. Put it here and we're done.
+
+            *candidate = *item_to_insert;
+
+            if (target == NULL) { // no eviction happened earlier, so this slot holds the caller's item
+                target = candidate;
+            }
+
+            Uint32 probe_len = get_probe_length(candidate->hash & hash_mask, idx, num_buckets); // distance of idx from the stored item's home slot, wraparound included
+            candidate->probe_len = probe_len;
+
+            if (*max_probe_len_ptr < probe_len) { // keep the recorded maximum up to date
+                *max_probe_len_ptr = probe_len;
+            }
+
+            break;
+        }
+
+        Uint32 candidate_probe_len = candidate->probe_len;
+        HT_ASSERT(candidate_probe_len == get_probe_length(candidate->hash & hash_mask, idx, num_buckets)); // the cached probe length must agree with the slot position
+        Uint32 new_probe_len = get_probe_length(item_to_insert->hash & hash_mask, idx, num_buckets); // probe length our item would have if stored at idx
+
+        if (candidate_probe_len < new_probe_len) {
+            // Robin Hood hashing: the item at idx has a better probe length than our item would at this position.
+            // Evict it and put our item in its place, then continue looking for a new spot for the displaced item.
+            // This algorithm significantly reduces clustering in the table, making lookups take very few probes.
+
+            temp_item = *candidate;
+            *candidate = *item_to_insert;
+
+            if (target == NULL) { // only the first placement is the caller's item; later ones are displaced residents
+                target = candidate;
+            }
+
+            *item_to_insert = temp_item; // continue with the evicted entry; note this clobbers the caller's struct
+
+            HT_ASSERT(new_probe_len == get_probe_length(candidate->hash & hash_mask, idx, num_buckets));
+            candidate->probe_len = new_probe_len;
+
+            if (*max_probe_len_ptr < new_probe_len) {
+                *max_probe_len_ptr = new_probe_len;
+            }
+        }
+
+        idx = (idx + 1) & hash_mask; // linear probe, wrapping at the end of the table
+    }
+
+    return target;
+}
+
+static void delete_item(SDL_HashTable *restrict ht, SDL_HashItem *item) // Removes item (expected to point at a slot inside ht->table) and closes the gap by moving the following displaced entries back one slot each.
+{
+    Uint32 hash_mask = ht->hash_mask;
+    SDL_HashItem *table = ht->table;
+
+    if (ht->nuke) { // optional callback: receives the key/value being removed plus the table's user data
+        ht->nuke(item->key, item->value, ht->data);
+    }
+    ht->num_occupied_slots--;
+    // NOTE: ht->max_probe_len is not lowered anywhere in this function.
+    Uint32 idx = (Uint32)(item - ht->table); // slot index of the entry being removed
+
+    for (;;) {
+        idx = (idx + 1) & hash_mask; // look at the slot after the current hole, wrapping around
+        SDL_HashItem *next_item = table + idx;
+
+        if (next_item->probe_len < 1) { // next slot is either zeroed (probe_len 0) or already in its home slot: nothing left to shift
+            SDL_zerop(item); // clear the remaining hole, which also resets its live bit
+            return;
+        }
+
+        *item = *next_item; // pull the displaced entry one slot closer to its home
+        item->probe_len -= 1;
+        HT_ASSERT(item->probe_len < ht->max_probe_len);
+        item = next_item; // the slot just copied from becomes the new hole
+    }
+}
+
+static bool resize(SDL_HashTable *restrict ht, Uint32 new_size)
+{
+    SDL_HashItem *old_table = ht->table;
+    Uint32 old_size = ht->hash_mask + 1;
+    Uint32 new_hash_mask = new_size - 1;
+    SDL_HashItem *new_table = SDL_calloc(new_size, sizeof(*new_table));
+
+    if (!new_table) {
         return false;
         return false;
     }
     }
 
 
-    if ( (!table->stackable) && (SDL_FindInHashTable(table, key, NULL)) ) {
+    ht->max_probe_len = 0;
+    ht->hash_mask = new_hash_mask;
+    ht->table = new_table;
+
+    for (Uint32 i = 0; i < old_size; ++i) {
+        SDL_HashItem *item = old_table + i;
+        if (item->live) {
+            insert_item(item, new_table, new_hash_mask, &ht->max_probe_len);
+        }
+    }
+
+    SDL_free(old_table);
+    return true;
+}
+
+static bool maybe_resize(SDL_HashTable *restrict ht)
+{
+    Uint32 capacity = ht->hash_mask + 1;
+
+    if (capacity >= MAX_HASHTABLE_SIZE) {
         return false;
         return false;
     }
     }
 
 
-    // !!! FIXME: grow and rehash table if it gets too saturated.
-    item = (SDL_HashItem *) SDL_malloc(sizeof (SDL_HashItem));
-    if (!item) {
+    Uint32 max_load_factor = 217; // range: 0-255; 217 is roughly 85%
+    Uint32 resize_threshold = (max_load_factor * (Uint64)capacity) >> 8;
+
+    if (ht->num_occupied_slots > resize_threshold) {
+        return resize(ht, capacity * 2);
+    }
+
+    return true;
+}
+
+bool SDL_InsertIntoHashTable(SDL_HashTable *restrict table, const void *key, const void *value)
+{
+    SDL_HashItem *item;
+    Uint32 hash;
+
+    if (!table) {
         return false;
         return false;
     }
     }
 
 
     hash = calc_hash(table, key);
     hash = calc_hash(table, key);
+    item = find_first_item(table, key, hash);
 
 
-    item->key = key;
-    item->value = value;
-    item->next = table->table[hash];
-    table->table[hash] = item;
+    if (item && !table->stackable) {
+        // TODO: Maybe allow overwrites? We could do it more efficiently here than unset followed by insert.
+        return false;
+    }
 
 
-    return true;
+    SDL_HashItem new_item;
+    new_item.key = key;
+    new_item.value = value;
+    new_item.hash = hash;
+    new_item.live = true;
+
+    table->num_occupied_slots++;
+
+    if (!maybe_resize(table)) {
+        table->num_occupied_slots--;
+        return false;
+    }
+
+    return insert_item(&new_item, table->table, table->hash_mask, &table->max_probe_len);
 }
 }
 
 
 bool SDL_FindInHashTable(const SDL_HashTable *table, const void *key, const void **_value)
 bool SDL_FindInHashTable(const SDL_HashTable *table, const void *key, const void **_value)
 {
 {
     Uint32 hash;
     Uint32 hash;
-    void *data;
     SDL_HashItem *i;
     SDL_HashItem *i;
 
 
     if (!table) {
     if (!table) {
@@ -123,104 +327,101 @@ bool SDL_FindInHashTable(const SDL_HashTable *table, const void *key, const void
     }
     }
 
 
     hash = calc_hash(table, key);
     hash = calc_hash(table, key);
-    data = table->data;
-
-    for (i = table->table[hash]; i; i = i->next) {
-        if (table->keymatch(key, i->key, data)) {
-            if (_value) {
-                *_value = i->value;
-            }
-            return true;
-        }
-    }
+    i = find_first_item(table, key, hash);
+    *_value = i ? i->value : NULL;
 
 
-    return false;
+    return i;
 }
 }
 
 
 bool SDL_RemoveFromHashTable(SDL_HashTable *table, const void *key)
 bool SDL_RemoveFromHashTable(SDL_HashTable *table, const void *key)
 {
 {
     Uint32 hash;
     Uint32 hash;
-    SDL_HashItem *item = NULL;
-    SDL_HashItem *prev = NULL;
-    void *data;
+    SDL_HashItem *item;
 
 
     if (!table) {
     if (!table) {
         return false;
         return false;
     }
     }
 
 
-    hash = calc_hash(table, key);
-    data = table->data;
-
-    for (item = table->table[hash]; item; item = item->next) {
-        if (table->keymatch(key, item->key, data)) {
-            if (prev) {
-                prev->next = item->next;
-            } else {
-                table->table[hash] = item->next;
-            }
+    // FIXME: what to do for stacking hashtables?
+    // The original code removes just one item.
+    // This hashtable happens to preserve the insertion order of multi-value keys,
+    // so deleting the first one will always delete the least-recently inserted one.
+    // But maybe it makes more sense to remove all matching items?
 
 
-            if (table->nuke) {
-                table->nuke(item->key, item->value, data);
-            }
-            SDL_free(item);
-            return true;
-        }
+    hash = calc_hash(table, key);
+    item = find_first_item(table, key, hash);
 
 
-        prev = item;
+    if (!item) {
+        return false;
     }
     }
 
 
-    return false;
+    delete_item(table, item);
+    return true;
 }
 }
 
 
 bool SDL_IterateHashTableKey(const SDL_HashTable *table, const void *key, const void **_value, void **iter)
 bool SDL_IterateHashTableKey(const SDL_HashTable *table, const void *key, const void **_value, void **iter)
 {
 {
-    SDL_HashItem *item;
+    SDL_HashItem *item = (SDL_HashItem *)*iter;
 
 
     if (!table) {
     if (!table) {
         return false;
         return false;
     }
     }
 
 
-    item = *iter ? ((SDL_HashItem *)*iter)->next : table->table[calc_hash(table, key)];
+    Uint32 i, probe_len, hash;
 
 
-    while (item) {
-        if (table->keymatch(key, item->key, table->data)) {
-            *_value = item->value;
-            *iter = item;
-            return true;
-        }
-        item = item->next;
+    if (item) {
+        HT_ASSERT(item >= table->table);
+        HT_ASSERT(item < table->table + (table->hash_mask + 1));
+
+        hash = item->hash;
+        probe_len = item->probe_len + 1;
+        i = ((Uint32)(item - table->table) + 1) & table->hash_mask;
+        item = table->table + i;
+    } else {
+        hash = calc_hash(table, key);
+        i = hash & table->hash_mask;
+        probe_len = 0;
     }
     }
 
 
-    // no more matches.
-    *_value = NULL;
-    *iter = NULL;
-    return false;
+    item = find_item(table, key, hash, &i, &probe_len);
+
+    if (!item) {
+        *_value = NULL;
+        return false;
+    }
+
+    *_value = item->value;
+    *iter = item;
+
+    return true;
 }
 }
 
 
 bool SDL_IterateHashTable(const SDL_HashTable *table, const void **_key, const void **_value, void **iter)
 bool SDL_IterateHashTable(const SDL_HashTable *table, const void **_key, const void **_value, void **iter)
 {
 {
-    SDL_HashItem *item = (SDL_HashItem *) *iter;
-    Uint32 idx = 0;
+    SDL_HashItem *item = (SDL_HashItem *)*iter;
 
 
     if (!table) {
     if (!table) {
         return false;
         return false;
     }
     }
 
 
-    if (item) {
-        const SDL_HashItem *orig = item;
-        item = item->next;
-        if (!item) {
-            idx = calc_hash(table, orig->key) + 1;  // !!! FIXME: we probably shouldn't rehash each time.
-        }
+    if (!item) {
+        item = table->table;
+    } else {
+        item++;
     }
     }
 
 
-    while (!item && (idx < table->table_len)) {
-        item = table->table[idx++];  // skip empty buckets...
+    HT_ASSERT(item >= table->table);
+    SDL_HashItem *end = table->table + (table->hash_mask + 1);
+
+    while (item < end && !item->live) {
+        ++item;
     }
     }
 
 
-    if (!item) {  // no more matches?
+    HT_ASSERT(item <= end);
+
+    if (item == end) {
         *_key = NULL;
         *_key = NULL;
-        *iter = NULL;
+        *_value = NULL;
         return false;
         return false;
     }
     }
 
 
@@ -233,44 +434,41 @@ bool SDL_IterateHashTable(const SDL_HashTable *table, const void **_key, const v
 
 
 bool SDL_HashTableEmpty(SDL_HashTable *table)
 bool SDL_HashTableEmpty(SDL_HashTable *table)
 {
 {
-    if (table) {
-        Uint32 i;
+    return !(table && table->num_occupied_slots);
+}
 
 
-        for (i = 0; i < table->table_len; i++) {
-            SDL_HashItem *item = table->table[i];
-            if (item) {
-                return false;
-            }
+static void nuke_all(SDL_HashTable *restrict table)
+{
+    void *data = table->data;
+    SDL_HashItem *end = table->table + (table->hash_mask + 1);
+    SDL_HashItem *i;
+
+    for (i = table->table; i < end; ++i) {
+        if (i->live) {
+            table->nuke(i->key, i->value, data);
         }
         }
     }
     }
-    return true;
 }
 }
 
 
-void SDL_EmptyHashTable(SDL_HashTable *table)
+void SDL_EmptyHashTable(SDL_HashTable *restrict table)
 {
 {
     if (table) {
     if (table) {
-        void *data = table->data;
-        Uint32 i;
-
-        for (i = 0; i < table->table_len; i++) {
-            SDL_HashItem *item = table->table[i];
-            while (item) {
-                SDL_HashItem *next = item->next;
-                if (table->nuke) {
-                    table->nuke(item->key, item->value, data);
-                }
-                SDL_free(item);
-                item = next;
-            }
-            table->table[i] = NULL;
+        if (table->nuke) {
+            nuke_all(table);
         }
         }
+
+        SDL_memset(table->table, 0, sizeof(*table->table) * (table->hash_mask + 1));
+        table->num_occupied_slots = 0;
     }
     }
 }
 }
 
 
 void SDL_DestroyHashTable(SDL_HashTable *table)
 void SDL_DestroyHashTable(SDL_HashTable *table)
 {
 {
     if (table) {
     if (table) {
-        SDL_EmptyHashTable(table);
+        if (table->nuke) {
+            nuke_all(table);
+        }
+
         SDL_free(table->table);
         SDL_free(table->table);
         SDL_free(table);
         SDL_free(table);
     }
     }
@@ -298,13 +496,13 @@ bool SDL_KeyMatchString(const void *a, const void *b, void *data)
     const char *b_string = (const char *)b;
     const char *b_string = (const char *)b;
 
 
     if (a == b) {
     if (a == b) {
-        return true;  // same pointer, must match.
+        return true; // same pointer, must match.
     } else if (!a || !b) {
     } else if (!a || !b) {
-        return false;  // one pointer is NULL (and first test shows they aren't the same pointer), must not match.
+        return false; // one pointer is NULL (and first test shows they aren't the same pointer), must not match.
     } else if (a_string[0] != b_string[0]) {
     } else if (a_string[0] != b_string[0]) {
-        return false;  // we know they don't match
+        return false; // we know they don't match
     }
     }
-    return (SDL_strcmp(a_string, b_string) == 0);  // Check against actual string contents.
+    return (SDL_strcmp(a_string, b_string) == 0); // Check against actual string contents.
 }
 }
 
 
 // We assume we can fit the ID in the key directly
 // We assume we can fit the ID in the key directly