Skip to content

Instantly share code, notes, and snippets.

@mortie
Last active December 2, 2024 22:28
Show Gist options
  • Save mortie/3cbb935ff539dadf529579739d70c0d7 to your computer and use it in GitHub Desktop.
Save mortie/3cbb935ff539dadf529579739d70c0d7 to your computer and use it in GitHub Desktop.
Fast flat hash set in C++ (Warning: not tested, probably has bugs)
#pragma once
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <new>
#include <utility>
/**
 * Flat (open-addressing) hash set using linear probing.
 *
 * Elements live directly in a power-of-two array of Key slots (buffer_), and a
 * separate bitmap (bits_) records which slots currently hold a constructed
 * element. Before each insert the table is doubled if it is at least half
 * full, so at least half of the slots are always empty and every probe loop is
 * guaranteed to reach an empty slot and terminate.
 *
 * Erasure uses backward-shift deletion (no tombstones): after removing an
 * element, later members of the same probe run are moved back into the hole
 * whenever that does not place them before their home slot.
 *
 * @tparam Key      Element type. Storage comes from std::malloc, so Key must
 *                  not need alignment stricter than std::max_align_t.
 * @tparam Hash     Hash functor. A fresh default-constructed instance is used
 *                  for every call, so it must not rely on per-instance state.
 *                  Only the low bits of the hash select the home slot.
 * @tparam KeyEqual Equality functor, also default-constructed per call.
 *
 * There is no internal synchronisation. Allocation failure throws
 * std::bad_alloc and leaves the set unchanged. If Key's move constructor,
 * Hash or KeyEqual throws during a rehash, the contents are unspecified and
 * memory may leak.
 */
template<
    typename Key,
    typename Hash = std::hash<Key>,
    typename KeyEqual = std::equal_to<Key>>
class FastHashSet {
public:
    /// Creates an empty set; nothing is allocated until the first insert.
    FastHashSet() = default;

    /// Deep copy. The new table has the same capacity as `other`, so every
    /// element can be copy-constructed into the very slot it occupies there.
    FastHashSet(const FastHashSet &other) : FastHashSet()
    {
        // Delegating to the default constructor makes this object "constructed"
        // already, so if a Key copy throws below, the destructor runs and
        // releases whatever was copied so far.
        if (other.cap_ == 0) {
            return;
        }
        rehash(other.cap_);
        for (size_t i = 0; i < other.cap_; ++i) {
            if (other.isOccupied(i)) {
                new (&buffer_[i]) Key(other.buffer_[i]);
                setOccupied(i);
                elementCount_ += 1;
            }
        }
    }

    /// Move construction: takes over `other`'s storage and leaves it empty.
    FastHashSet(FastHashSet &&other) noexcept : FastHashSet()
    {
        swap(other);
    }

    /// Copy- and move-assignment via copy-and-swap.
    FastHashSet &operator=(FastHashSet other) noexcept
    {
        swap(other);
        return *this;
    }

    ~FastHashSet()
    {
        clear();
        std::free(buffer_);
        std::free(bits_);
    }

    /// Exchanges the contents of two sets without touching any element.
    void swap(FastHashSet &other) noexcept
    {
        std::swap(cap_, other.cap_);
        std::swap(buffer_, other.buffer_);
        std::swap(bits_, other.bits_);
        std::swap(elementCount_, other.elementCount_);
    }

    /// Number of elements currently stored.
    size_t size() const
    {
        return elementCount_;
    }

    /// Destroys every element but keeps the allocated table for reuse.
    void clear()
    {
        if (cap_ == 0) {
            return;
        }
        for (size_t i = 0; i < cap_; ++i) {
            if (isOccupied(i)) {
                buffer_[i].~Key();
            }
        }
        std::memset(bits_, 0, bitWords(cap_) * sizeof(*bits_));
        // Without this reset, insert()'s growth check would see a stale count
        // and double the table long before it is actually half full.
        elementCount_ = 0;
    }

    /// Inserts `key` (copied or moved in via perfect forwarding) unless an
    /// equal element is already present. `K` may differ from Key as long as
    /// Hash, KeyEqual and Key's constructor accept it.
    template<typename K>
    void insert(K &&key)
    {
        if (cap_ == 0) {
            rehash(8);
        } else if (elementCount_ >= cap_ / 2) {
            // Keep the load factor at or below 1/2: probes stay short and an
            // empty slot always exists.
            rehash(cap_ * 2);
        }
        insertNoResize(std::forward<K>(key));
    }

    /// Removes the element equal to `key`, if any.
    void erase(const Key &key)
    {
        if (cap_ == 0) {
            return;
        }
        const size_t mask = cap_ - 1;

        // Locate the element; reaching an empty slot means it is absent.
        size_t hole = Hash{}(key) & mask;
        while (true) {
            if (!isOccupied(hole)) {
                return;
            }
            if (KeyEqual{}(key, buffer_[hole])) {
                break;
            }
            hole = (hole + 1) & mask;
        }
        buffer_[hole].~Key();
        elementCount_ -= 1;

        // Backward-shift deletion: scan the rest of the probe run. An element
        // at slot j may fill the hole only if its home slot is NOT cyclically
        // inside (hole, j]; equivalently, its distance from home is at least
        // the distance from the hole to j. Elements that must stay put are
        // skipped, but the scan continues to the next empty slot, because
        // elements further along may still belong before the hole.
        for (size_t j = (hole + 1) & mask; isOccupied(j); j = (j + 1) & mask) {
            const size_t home = Hash{}(buffer_[j]) & mask;
            if (((j - home) & mask) >= ((j - hole) & mask)) {
                new (&buffer_[hole]) Key(std::move(buffer_[j]));
                buffer_[j].~Key();
                hole = j;
            }
        }
        // The final hole's bit is still set from the element that lived there.
        clearOccupied(hole);
    }

    /// Returns true if an element equal to `key` is stored.
    bool contains(const Key &key) const
    {
        if (cap_ == 0) {
            // No table yet; bits_ is null and must not be read.
            return false;
        }
        const size_t mask = cap_ - 1;
        for (size_t i = Hash{}(key) & mask; isOccupied(i); i = (i + 1) & mask) {
            if (KeyEqual{}(key, buffer_[i])) {
                return true;
            }
        }
        return false;
    }

private:
    // Puts `key` into the first free slot of its probe run, or does nothing if
    // an equal element is met first. The caller guarantees a free slot exists.
    template<typename K>
    void insertNoResize(K &&key)
    {
        const size_t mask = cap_ - 1;
        size_t index = Hash{}(key) & mask;
        while (isOccupied(index)) {
            if (KeyEqual{}(key, buffer_[index])) {
                return;
            }
            index = (index + 1) & mask;
        }
        new (&buffer_[index]) Key(std::forward<K>(key));
        setOccupied(index);
        elementCount_ += 1;
    }

    // Moves every element into a freshly allocated table of `newCap` slots.
    // newCap must be a power of two, because probing masks with newCap - 1.
    void rehash(size_t newCap)
    {
        // Allocate first so that a failure leaves the set untouched.
        Key *newBuffer = static_cast<Key *>(std::malloc(sizeof(Key) * newCap));
        unsigned long long *newBits = static_cast<unsigned long long *>(
            std::calloc(bitWords(newCap), sizeof(unsigned long long)));
        if (newBuffer == nullptr || newBits == nullptr) {
            std::free(newBuffer);
            std::free(newBits);
            throw std::bad_alloc();
        }

        const size_t oldCap = cap_;
        Key *oldBuffer = buffer_;
        unsigned long long *oldBits = bits_;

        cap_ = newCap;
        buffer_ = newBuffer;
        bits_ = newBits;
        elementCount_ = 0;

        for (size_t i = 0; i < oldCap; ++i) {
            if (getBit(oldBits, i)) {
                insertNoResize(std::move(oldBuffer[i]));
                oldBuffer[i].~Key();
            }
        }
        std::free(oldBuffer);
        std::free(oldBits);
    }

    // Number of bitmap words allocated for a table of `cap` slots.
    static size_t bitWords(size_t cap)
    {
        return wordOf(cap) + 1;
    }

    // Bitmap word holding the bit for slot `bit`.
    static size_t wordOf(size_t bit)
    {
        return bit / (sizeof(unsigned long long) * 8);
    }

    // Mask selecting slot `bit` within its bitmap word.
    static unsigned long long maskOf(size_t bit)
    {
        return 1ULL << (bit % (sizeof(unsigned long long) * 8));
    }

    bool isOccupied(size_t index) const
    {
        return getBit(bits_, index);
    }

    void setOccupied(size_t index)
    {
        setBit(bits_, index);
    }

    void clearOccupied(size_t index)
    {
        clearBit(bits_, index);
    }

    static bool getBit(const unsigned long long *bits, size_t bit)
    {
        return (bits[wordOf(bit)] & maskOf(bit)) != 0;
    }

    static void setBit(unsigned long long *bits, size_t bit)
    {
        bits[wordOf(bit)] |= maskOf(bit);
    }

    static void clearBit(unsigned long long *bits, size_t bit)
    {
        bits[wordOf(bit)] &= ~maskOf(bit);
    }

    size_t cap_ = 0;                     // slot count: 0 or a power of two
    Key *buffer_ = nullptr;              // cap_ slots, constructed only where the bit is set
    unsigned long long *bits_ = nullptr; // occupancy bitmap, bitWords(cap_) words
    size_t elementCount_ = 0;            // number of constructed elements
};
/*
* Benchmark results (with -O2 -DNDEBUG):
*
* Apple M1 Pro (macOS):
* insertALot :: unordered_set: 24.4842ms per iteration (41 iters)
* insertALot :: absl::node_hash_set: 26.3697ms per iteration (38 iters)
* insertALot :: absl::flat_hash_set: 9.01965ms per iteration (111 iters)
* insertALot :: FastHashSet: 4.83171ms per iteration (207 iters)
*
* insertALot (re-used) :: unordered_set: 22.4338ms per iteration (45 iters)
* insertALot (re-used) :: absl::node_hash_set: 27.8833ms per iteration (36 iters)
* insertALot (re-used) :: absl::flat_hash_set: 8.97861ms per iteration (112 iters)
* insertALot (re-used) :: FastHashSet: 3.23483ms per iteration (310 iters)
*
* AMD R9 5950x (Fedora Linux):
* insertALot :: unordered_set: 31.151ms per iteration (33 iters)
* insertALot :: absl::node_hash_set: 36.5258ms per iteration (28 iters)
* insertALot :: absl::flat_hash_set: 10.6556ms per iteration (94 iters)
* insertALot :: FastHashSet: 3.88768ms per iteration (258 iters)
*
* insertALot (re-used) :: unordered_set: 15.7228ms per iteration (64 iters)
* insertALot (re-used) :: absl::node_hash_set: 35.8289ms per iteration (28 iters)
* insertALot (re-used) :: absl::flat_hash_set: 10.913ms per iteration (92 iters)
* insertALot (re-used) :: FastHashSet: 1.58122ms per iteration (633 iters)
*/
#include "FastHashSet.h"
#include <cstdint>
#include <unordered_set>
#include <time.h>
#include <stdlib.h>
#include <iostream>
#include <absl/container/node_hash_set.h>
#include <absl/container/flat_hash_set.h>
// Returns a CLOCK_MONOTONIC timestamp in seconds. The clock's starting point
// is unspecified, so only differences between two calls are meaningful.
double now() {
    timespec stamp;
    clock_gettime(CLOCK_MONOTONIC, &stamp);
    const double wholeSeconds = stamp.tv_sec;
    const double fractionalSeconds = stamp.tv_nsec / 1e9;
    return wholeSeconds + fractionalSeconds;
}
// Benchmark body: 100 rounds, each building a brand-new Set and inserting
// the 10000 values 0, 13, 26, ..., 129987. Construction and destruction of
// the set are therefore part of every round.
template<typename Set>
void insertALot() {
    constexpr int rounds = 100;
    constexpr int perRound = 10000;
    for (int round = 0; round != rounds; ++round) {
        Set fresh;
        for (int value = 0; value != perRound * 13; value += 13) {
            fresh.insert(value);
        }
    }
}
// Benchmark body: a single Set is reused for 100 rounds. Each round first
// calls clear() and then inserts the 10000 values 0, 13, 26, ..., 129987, so
// any capacity the container keeps across clear() can be reused.
template<typename Set>
void reUseInsertALot() {
    constexpr int rounds = 100;
    constexpr int perRound = 10000;
    Set shared;
    int round = 0;
    while (round < rounds) {
        shared.clear();
        for (int value = 0; value != perRound * 13; value += 13) {
            shared.insert(value);
        }
        ++round;
    }
}
// Times `func` and prints the mean wall-clock time per call in milliseconds.
// It first makes 10 untimed calls (presumably so caches and the allocator can
// settle). It then keeps calling until more than 10 timed calls have run AND
// more than one second has elapsed.
void runTest(const char *name, void (*func)()) {
    for (int remaining = 10; remaining > 0; --remaining) {
        func();
    }

    const double begin = now();
    int iterations = 0;
    double elapsed = 0;
    do {
        func();
        ++iterations;
        elapsed = now() - begin;
    } while (iterations <= 10 || elapsed <= 1);

    const double msPerIteration = (elapsed / iterations) * 1000;
    std::cout << name << ": " << msPerIteration << "ms per iteration ("
              << iterations << " iters)\n";
}
int main() {
runTest("insertALot :: unordered_set", insertALot<std::unordered_set<int>>);
runTest("insertALot :: absl::node_hash_set", insertALot<absl::node_hash_set<int>>);
runTest("insertALot :: absl::flat_hash_set", insertALot<absl::flat_hash_set<int>>);
runTest("insertALot :: FastHashSet", insertALot<FastHashSet<int>>);
std::cout << '\n';
runTest("insertALot (re-used) :: unordered_set", reUseInsertALot<std::unordered_set<int>>);
runTest("insertALot (re-used) :: absl::node_hash_set", reUseInsertALot<absl::node_hash_set<int>>);
runTest("insertALot (re-used) :: absl::flat_hash_set", reUseInsertALot<absl::flat_hash_set<int>>);
runTest("insertALot (re-used) :: FastHashSet", reUseInsertALot<FastHashSet<int>>);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment