summaryrefslogtreecommitdiffstats
path: root/include/ds/BloomFilter.h
blob: d55a7af2de1bf7cf497492225ed7a7e5d6f0edba (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
/*
 * include/ds/BloomFilter.h
 *
 * Copyright (C) 2023 Dong Xie <dongx@psu.edu>
 *
 * All rights reserved. Published under the Modified BSD License.
 *
 */
#pragma once

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <new>

#include <gsl/gsl_rng.h>

#include "ds/BitArray.h"
#include "util/hash.h"
#include "util/base.h"
#include "util/Record.h"

namespace de {

/**
 * Bloom filter over keys of type K.
 *
 * A key is hashed by feeding its raw object bytes to hash_bytes_with_salt()
 * once per salt; each hash (mod the bit count) selects one bit to set/test.
 *
 * The salts are drawn from a GSL MT19937 generator created with
 * gsl_rng_alloc(), which seeds from gsl_rng_default_seed. Unless that global
 * is changed, every BloomFilter built with the same k gets the same salts.
 *
 * The class owns a raw salt buffer. Copying is therefore disabled: an
 * implicit copy would free the same buffer twice.
 */
template <typename K>
class BloomFilter {
public:
    /**
     * Build a filter with an explicit size.
     *
     * @param n_bits  number of bits in the underlying BitArray
     * @param k       number of salted hash functions (salts)
     * @throws std::bad_alloc if the salt buffer or the RNG cannot be allocated
     */
    BloomFilter(size_t n_bits, size_t k)
    : m_n_salts(k), m_n_bits(n_bits), salt(nullptr), m_bitarray(n_bits) {
        // No hash functions means no salts to generate or store.
        if (k == 0) return;

        salt = static_cast<uint16_t*>(
            aligned_alloc(CACHELINE_SIZE, CACHELINEALIGN(k * sizeof(uint16_t))));
        if (!salt) throw std::bad_alloc();

        gsl_rng *rng = gsl_rng_alloc(gsl_rng_mt19937);
        if (!rng) {
            // The destructor does not run for a throwing constructor, so
            // release the salt buffer here to avoid leaking it.
            free(salt);
            salt = nullptr;
            throw std::bad_alloc();
        }

        for (size_t i = 0; i < k; ++i) {
            // gsl_rng_uniform_int(rng, n) yields a value in [0, n-1], which
            // fits in 16 bits.
            salt[i] = static_cast<uint16_t>(gsl_rng_uniform_int(rng, 1 << 16));
        }

        gsl_rng_free(rng);
    }

    /**
     * Build a filter sized for a target false-positive rate.
     *
     * The bit count is m = -k*n / ln(1 - max_fpr^(1/k)). Parameters for
     * which that formula is not finite and positive (e.g. k == 0, n == 0,
     * max_fpr <= 0 or max_fpr >= 1) produce a zero-bit filter.
     *
     * @param max_fpr  desired maximum false-positive rate
     * @param n        expected number of inserted keys
     * @param k        number of salted hash functions
     */
    BloomFilter(double max_fpr, size_t n, size_t k)
    : BloomFilter(bits_for_fpr(max_fpr, n, k), k) {}

    // Owns a raw salt buffer released in the destructor; a member-wise copy
    // would free it twice.
    BloomFilter(const BloomFilter&) = delete;
    BloomFilter& operator=(const BloomFilter&) = delete;

    ~BloomFilter() {
        if (salt) free(salt);
    }

    /**
     * Set the k bits selected by `key`.
     *
     * @param key  key whose first `sz` bytes are hashed
     * @param sz   number of bytes of `key` to hash (defaults to sizeof(K));
     *             must not exceed the object's actual size
     * @return 1 if bits were set, 0 if the filter has no bits
     */
    int insert(const K& key, size_t sz = sizeof(K)) {
        if (m_bitarray.size() == 0) return 0;

        for (size_t i = 0; i < m_n_salts; ++i) {
            m_bitarray.set(hash_bytes_with_salt((const char*)&key, sz, salt[i]) % m_n_bits);
        }

        return 1;
    }

    /**
     * Test whether `key` may have been inserted.
     *
     * @return false if any selected bit is clear (definitely absent),
     *         true otherwise (possibly present).
     *
     * NOTE(review): a zero-bit filter always answers false. That is correct
     * for filters sized for n == 0 keys. It would be a false negative for a
     * zero-bit filter into which keys were nonetheless "inserted" (insert()
     * reports 0 in that case).
     */
    bool lookup(const K& key, size_t sz = sizeof(K)) {
        if (m_bitarray.size() == 0) return false;
        for (size_t i = 0; i < m_n_salts; ++i) {
            if (!m_bitarray.is_set(hash_bytes_with_salt((const char*)&key, sz, salt[i]) % m_n_bits))
                return false;
        }

        return true;
    }

    /// Clear all bits of the filter; the salts are kept.
    void clear() {
        m_bitarray.clear();
    }

    /// Memory reported by the underlying BitArray (the salt array is not counted).
    size_t get_memory_usage() {
        return this->m_bitarray.mem_size();
    }
private:
    /**
     * Bit count for a target false-positive rate: -k*n / ln(1 - max_fpr^(1/k)).
     *
     * Returns 0 when the result is NaN, infinite, or not positive, and
     * clamps to SIZE_MAX otherwise, so the cast to size_t is always defined.
     */
    static size_t bits_for_fpr(double max_fpr, size_t n, size_t k) {
        const double bits = -static_cast<double>(k * n)
                            / std::log(1.0 - std::pow(max_fpr, 1.0 / k));
        if (!std::isfinite(bits) || bits <= 0.0) return 0;

        const double cap = static_cast<double>(std::numeric_limits<size_t>::max());
        if (bits >= cap) return std::numeric_limits<size_t>::max();
        return static_cast<size_t>(bits);
    }

    size_t m_n_salts;
    size_t m_n_bits;
    uint16_t* salt = nullptr;  // m_n_salts cache-line-aligned salts, or nullptr when k == 0

    BitArray m_bitarray;
};

}