diff options
| author | Douglas Rumbaugh <dbr4@psu.edu> | 2023-05-15 16:48:56 -0400 |
|---|---|---|
| committer | Douglas Rumbaugh <dbr4@psu.edu> | 2023-05-15 16:48:56 -0400 |
| commit | ff000799c3254f52e0beabbe9c62d10c3fc4178e (patch) | |
| tree | 49a1a045678315e8e215fd80409973679b793043 /include/shard/WIRS.h | |
| parent | 418e9b079e559c86f3a5b276f712ad2f5d66533c (diff) | |
| download | dynamic-extension-ff000799c3254f52e0beabbe9c62d10c3fc4178e.tar.gz | |
Record format generalization
Tombstone counting is currently broken, but the rest of the change appears
to be working.
Diffstat (limited to 'include/shard/WIRS.h')
| -rw-r--r-- | include/shard/WIRS.h | 108 |
1 files changed, 58 insertions, 50 deletions
diff --git a/include/shard/WIRS.h b/include/shard/WIRS.h index 39337bf..41766b9 100644 --- a/include/shard/WIRS.h +++ b/include/shard/WIRS.h @@ -12,6 +12,7 @@ #include <cassert> #include <queue> #include <memory> +#include <concepts> #include "ds/PriorityQueue.h" #include "util/Cursor.h" @@ -24,19 +25,26 @@ namespace de { thread_local size_t wirs_cancelations = 0; -template <typename K, typename V, typename W> +template <WeightedRecordInterface R> class WIRS { private: + + typedef decltype(R::key) K; + typedef decltype(R::value) V; + typedef decltype(R::weight) W; + + template <WeightedRecordInterface R_ = R> struct wirs_node { - struct wirs_node *left, *right; + struct wirs_node<R_> *left, *right; K low, high; W weight; Alias* alias; }; + template <WeightedRecordInterface R_ = R> struct WIRSState { W tot_weight; - std::vector<wirs_node*> nodes; + std::vector<wirs_node<R_>*> nodes; Alias* top_level_alias; ~WIRSState() { @@ -45,13 +53,13 @@ private: }; public: - WIRS(MutableBuffer<K, V, W>* buffer, BloomFilter* bf, bool tagging) + WIRS(MutableBuffer<R>* buffer, BloomFilter* bf, bool tagging) : m_reccnt(0), m_tombstone_cnt(0), m_deleted_cnt(0), m_total_weight(0), m_rejection_cnt(0), m_ts_check_cnt(0), m_tagging(tagging), m_root(nullptr) { - size_t alloc_size = (buffer->get_record_count() * sizeof(Record<K, V, W>)) + (CACHELINE_SIZE - (buffer->get_record_count() * sizeof(Record<K, V, W>)) % CACHELINE_SIZE); + size_t alloc_size = (buffer->get_record_count() * sizeof(R)) + (CACHELINE_SIZE - (buffer->get_record_count() * sizeof(R)) % CACHELINE_SIZE); assert(alloc_size % CACHELINE_SIZE == 0); - m_data = (Record<K, V, W>*)std::aligned_alloc(CACHELINE_SIZE, alloc_size); + m_data = (R*)std::aligned_alloc(CACHELINE_SIZE, alloc_size); size_t offset = 0; m_reccnt = 0; @@ -61,13 +69,13 @@ public: while (base < stop) { if (!m_tagging) { if (!(base->is_tombstone()) && (base + 1) < stop) { - if (base->match(base + 1) && (base + 1)->is_tombstone()) { + if (*base == *(base + 1) && 
(base + 1)->is_tombstone()) { base += 2; wirs_cancelations++; continue; } } - } else if (base->get_delete_status()) { + } else if (base->is_deleted()) { base += 1; continue; } @@ -92,10 +100,10 @@ public: WIRS(WIRS** shards, size_t len, BloomFilter* bf, bool tagging) : m_reccnt(0), m_tombstone_cnt(0), m_deleted_cnt(0), m_total_weight(0), m_rejection_cnt(0), m_ts_check_cnt(0), m_tagging(tagging), m_root(nullptr) { - std::vector<Cursor<K,V,W>> cursors; + std::vector<Cursor<R>> cursors; cursors.reserve(len); - PriorityQueue<K, V, W> pq(len); + PriorityQueue<R> pq(len); size_t attemp_reccnt = 0; @@ -106,28 +114,28 @@ public: attemp_reccnt += shards[i]->get_record_count(); pq.push(cursors[i].ptr, i); } else { - cursors.emplace_back(Cursor<K,V,W>{nullptr, nullptr, 0, 0}); + cursors.emplace_back(Cursor<R>{nullptr, nullptr, 0, 0}); } } - size_t alloc_size = (attemp_reccnt * sizeof(Record<K, V, W>)) + (CACHELINE_SIZE - (attemp_reccnt * sizeof(Record<K, V, W>)) % CACHELINE_SIZE); + size_t alloc_size = (attemp_reccnt * sizeof(R)) + (CACHELINE_SIZE - (attemp_reccnt * sizeof(R)) % CACHELINE_SIZE); assert(alloc_size % CACHELINE_SIZE == 0); - m_data = (Record<K, V, W>*)std::aligned_alloc(CACHELINE_SIZE, alloc_size); + m_data = (R*)std::aligned_alloc(CACHELINE_SIZE, alloc_size); while (pq.size()) { auto now = pq.peek(); - auto next = pq.size() > 1 ? pq.peek(1) : queue_record<K, V, W>{nullptr, 0}; + auto next = pq.size() > 1 ? 
pq.peek(1) : queue_record<R>{nullptr, 0}; if (!m_tagging && !now.data->is_tombstone() && next.data != nullptr && - now.data->match(next.data) && next.data->is_tombstone()) { + *now.data == *next.data && next.data->is_tombstone()) { pq.pop(); pq.pop(); auto& cursor1 = cursors[now.version]; auto& cursor2 = cursors[next.version]; - if (advance_cursor<K,V,W>(cursor1)) pq.push(cursor1.ptr, now.version); - if (advance_cursor<K,V,W>(cursor2)) pq.push(cursor2.ptr, next.version); + if (advance_cursor<R>(cursor1)) pq.push(cursor1.ptr, now.version); + if (advance_cursor<R>(cursor2)) pq.push(cursor2.ptr, next.version); } else { auto& cursor = cursors[now.version]; - if (!m_tagging || !cursor.ptr->get_delete_status()) { + if (!m_tagging || !cursor.ptr->is_deleted()) { m_data[m_reccnt++] = *cursor.ptr; m_total_weight += cursor.ptr->weight; if (bf && cursor.ptr->is_tombstone()) { @@ -137,7 +145,7 @@ public: } pq.pop(); - if (advance_cursor<K,V,W>(cursor)) pq.push(cursor.ptr, now.version); + if (advance_cursor<R>(cursor)) pq.push(cursor.ptr, now.version); } } @@ -155,16 +163,16 @@ public: free_tree(m_root); } - bool delete_record(const K& key, const V& val) { - size_t idx = get_lower_bound(key); + bool delete_record(const R& rec) { + size_t idx = get_lower_bound(rec.key); if (idx >= m_reccnt) { return false; } - while (idx < m_reccnt && m_data[idx].lt(key, val)) ++idx; + while (idx < m_reccnt && m_data[idx] < rec) ++idx; - if (m_data[idx].match(key, val, false)) { - m_data[idx].set_delete_status(); + if (m_data[idx] == R {rec.key, rec.val} && !m_data[idx].is_tombstone()) { + m_data[idx].set_delete(); m_deleted_cnt++; return true; } @@ -172,7 +180,7 @@ public: return false; } - void free_tree(struct wirs_node* node) { + void free_tree(struct wirs_node<R>* node) { if (node) { delete node->alias; free_tree(node->left); @@ -181,7 +189,7 @@ public: } } - Record<K, V, W>* sorted_output() const { + R* sorted_output() const { return m_data; } @@ -193,19 +201,19 @@ public: return 
m_tombstone_cnt; } - const Record<K, V, W>* get_record_at(size_t idx) const { + const R* get_record_at(size_t idx) const { if (idx >= m_reccnt) return nullptr; return m_data + idx; } // low - high -> decompose to a set of nodes. // Build Alias across the decomposed nodes. - WIRSState* get_sample_shard_state(const K& lower_key, const K& upper_key) { - WIRSState* res = new WIRSState(); + WIRSState<R>* get_sample_shard_state(const K& lower_key, const K& upper_key) { + auto res = new WIRSState(); // Simulate a stack to unfold recursion. double tot_weight = 0.0; - struct wirs_node* st[64] = {0}; + struct wirs_node<R>* st[64] = {0}; st[0] = m_root; size_t top = 1; while(top > 0) { @@ -231,15 +239,15 @@ public: } static void delete_state(void *state) { - auto s = (WIRSState *) state; + WIRSState<R> *s = (WIRSState<R> *) state; delete s; } // returns the number of records sampled // NOTE: This operation returns records strictly between the lower and upper bounds, not // including them. - size_t get_samples(void* shard_state, std::vector<Record<K, V, W>> &result_set, const K& lower_key, const K& upper_key, size_t sample_sz, gsl_rng *rng) { - WIRSState *state = (WIRSState *) shard_state; + size_t get_samples(void* shard_state, std::vector<R> &result_set, const K& lower_key, const K& upper_key, size_t sample_sz, gsl_rng *rng) { + WIRSState<R> *state = (WIRSState<R> *) shard_state; if (sample_sz == 0) { return 0; } @@ -295,30 +303,30 @@ public: auto ptr = m_data + get_lower_bound(key); - while (ptr < m_data + m_reccnt && ptr->lt(key, val)) { + while (ptr < m_data + m_reccnt && *ptr < R {key, val}) { ptr ++; } - bool result = (m_tagging) ? ptr->get_delete_status() - : ptr->match(key, val, true); + bool result = (m_tagging) ? 
ptr->is_deleted() + : *ptr == R {key, val} && ptr->is_tombstone(); m_rejection_cnt += result; return result; } - bool check_tombstone(const K& key, const V& val) { + bool check_tombstone(const R& rec) { m_ts_check_cnt++; - size_t idx = get_lower_bound(key); + size_t idx = get_lower_bound(rec.key); if (idx >= m_reccnt) { return false; } - auto ptr = m_data + get_lower_bound(key); + auto ptr = m_data + get_lower_bound(rec.key); - while (ptr < m_data + m_reccnt && ptr->lt(key, val)) { + while (ptr < m_data + m_reccnt && *ptr < rec) { ptr ++; } - bool result = ptr->match(key, val, true); + bool result = *ptr == rec && ptr->is_tombstone(); m_rejection_cnt += result; return result; @@ -340,21 +348,21 @@ public: private: - bool covered_by(struct wirs_node* node, const K& lower_key, const K& upper_key) { + bool covered_by(struct wirs_node<R>* node, const K& lower_key, const K& upper_key) { auto low_index = node->low * m_group_size; auto high_index = std::min((node->high + 1) * m_group_size - 1, m_reccnt - 1); return lower_key < m_data[low_index].key && m_data[high_index].key < upper_key; } - bool intersects(struct wirs_node* node, const K& lower_key, const K& upper_key) { + bool intersects(struct wirs_node<R>* node, const K& lower_key, const K& upper_key) { auto low_index = node->low * m_group_size; auto high_index = std::min((node->high + 1) * m_group_size - 1, m_reccnt - 1); return lower_key < m_data[high_index].key && m_data[low_index].key < upper_key; } - struct wirs_node* construct_wirs_node(const std::vector<W>& weights, size_t low, size_t high) { + struct wirs_node<R>* construct_wirs_node(const std::vector<W>& weights, size_t low, size_t high) { if (low == high) { - return new wirs_node{nullptr, nullptr, low, high, weights[low], new Alias({1.0})}; + return new wirs_node<R>{nullptr, nullptr, low, high, weights[low], new Alias({1.0})}; } else if (low > high) return nullptr; std::vector<double> node_weights; @@ -370,9 +378,9 @@ private: size_t mid = (low + high) / 2; - 
return new wirs_node{construct_wirs_node(weights, low, mid), - construct_wirs_node(weights, mid + 1, high), - low, high, sum, new Alias(node_weights)}; + return new wirs_node<R>{construct_wirs_node(weights, low, mid), + construct_wirs_node(weights, mid + 1, high), + low, high, sum, new Alias(node_weights)}; } @@ -410,9 +418,9 @@ private: m_root = construct_wirs_node(weights, 0, n_groups-1); } - Record<K, V, W>* m_data; + R* m_data; std::vector<Alias *> m_alias; - wirs_node* m_root; + wirs_node<R>* m_root; bool m_tagging; W m_total_weight; size_t m_reccnt; |