summary | refs | log | tree | commit | diff | stats
path: root/include/shard/WIRS.h
diff options
context:
space:
mode:
author Douglas Rumbaugh <dbr4@psu.edu> 2023-05-15 16:48:56 -0400
committer Douglas Rumbaugh <dbr4@psu.edu> 2023-05-15 16:48:56 -0400
commit ff000799c3254f52e0beabbe9c62d10c3fc4178e (patch)
tree 49a1a045678315e8e215fd80409973679b793043 /include/shard/WIRS.h
parent 418e9b079e559c86f3a5b276f712ad2f5d66533c (diff)
download dynamic-extension-ff000799c3254f52e0beabbe9c62d10c3fc4178e.tar.gz
Record format generalization
Tombstone counting is currently broken, but the rest of the change appears to be working.
Diffstat (limited to 'include/shard/WIRS.h')
-rw-r--r-- include/shard/WIRS.h 108
1 files changed, 58 insertions, 50 deletions
diff --git a/include/shard/WIRS.h b/include/shard/WIRS.h
index 39337bf..41766b9 100644
--- a/include/shard/WIRS.h
+++ b/include/shard/WIRS.h
@@ -12,6 +12,7 @@
#include <cassert>
#include <queue>
#include <memory>
+#include <concepts>
#include "ds/PriorityQueue.h"
#include "util/Cursor.h"
@@ -24,19 +25,26 @@ namespace de {
thread_local size_t wirs_cancelations = 0;
-template <typename K, typename V, typename W>
+template <WeightedRecordInterface R>
class WIRS {
private:
+
+ typedef decltype(R::key) K;
+ typedef decltype(R::value) V;
+ typedef decltype(R::weight) W;
+
+ template <WeightedRecordInterface R_ = R>
struct wirs_node {
- struct wirs_node *left, *right;
+ struct wirs_node<R_> *left, *right;
K low, high;
W weight;
Alias* alias;
};
+ template <WeightedRecordInterface R_ = R>
struct WIRSState {
W tot_weight;
- std::vector<wirs_node*> nodes;
+ std::vector<wirs_node<R_>*> nodes;
Alias* top_level_alias;
~WIRSState() {
@@ -45,13 +53,13 @@ private:
};
public:
- WIRS(MutableBuffer<K, V, W>* buffer, BloomFilter* bf, bool tagging)
+ WIRS(MutableBuffer<R>* buffer, BloomFilter* bf, bool tagging)
: m_reccnt(0), m_tombstone_cnt(0), m_deleted_cnt(0), m_total_weight(0), m_rejection_cnt(0),
m_ts_check_cnt(0), m_tagging(tagging), m_root(nullptr) {
- size_t alloc_size = (buffer->get_record_count() * sizeof(Record<K, V, W>)) + (CACHELINE_SIZE - (buffer->get_record_count() * sizeof(Record<K, V, W>)) % CACHELINE_SIZE);
+ size_t alloc_size = (buffer->get_record_count() * sizeof(R)) + (CACHELINE_SIZE - (buffer->get_record_count() * sizeof(R)) % CACHELINE_SIZE);
assert(alloc_size % CACHELINE_SIZE == 0);
- m_data = (Record<K, V, W>*)std::aligned_alloc(CACHELINE_SIZE, alloc_size);
+ m_data = (R*)std::aligned_alloc(CACHELINE_SIZE, alloc_size);
size_t offset = 0;
m_reccnt = 0;
@@ -61,13 +69,13 @@ public:
while (base < stop) {
if (!m_tagging) {
if (!(base->is_tombstone()) && (base + 1) < stop) {
- if (base->match(base + 1) && (base + 1)->is_tombstone()) {
+ if (*base == *(base + 1) && (base + 1)->is_tombstone()) {
base += 2;
wirs_cancelations++;
continue;
}
}
- } else if (base->get_delete_status()) {
+ } else if (base->is_deleted()) {
base += 1;
continue;
}
@@ -92,10 +100,10 @@ public:
WIRS(WIRS** shards, size_t len, BloomFilter* bf, bool tagging)
: m_reccnt(0), m_tombstone_cnt(0), m_deleted_cnt(0), m_total_weight(0), m_rejection_cnt(0), m_ts_check_cnt(0),
m_tagging(tagging), m_root(nullptr) {
- std::vector<Cursor<K,V,W>> cursors;
+ std::vector<Cursor<R>> cursors;
cursors.reserve(len);
- PriorityQueue<K, V, W> pq(len);
+ PriorityQueue<R> pq(len);
size_t attemp_reccnt = 0;
@@ -106,28 +114,28 @@ public:
attemp_reccnt += shards[i]->get_record_count();
pq.push(cursors[i].ptr, i);
} else {
- cursors.emplace_back(Cursor<K,V,W>{nullptr, nullptr, 0, 0});
+ cursors.emplace_back(Cursor<R>{nullptr, nullptr, 0, 0});
}
}
- size_t alloc_size = (attemp_reccnt * sizeof(Record<K, V, W>)) + (CACHELINE_SIZE - (attemp_reccnt * sizeof(Record<K, V, W>)) % CACHELINE_SIZE);
+ size_t alloc_size = (attemp_reccnt * sizeof(R)) + (CACHELINE_SIZE - (attemp_reccnt * sizeof(R)) % CACHELINE_SIZE);
assert(alloc_size % CACHELINE_SIZE == 0);
- m_data = (Record<K, V, W>*)std::aligned_alloc(CACHELINE_SIZE, alloc_size);
+ m_data = (R*)std::aligned_alloc(CACHELINE_SIZE, alloc_size);
while (pq.size()) {
auto now = pq.peek();
- auto next = pq.size() > 1 ? pq.peek(1) : queue_record<K, V, W>{nullptr, 0};
+ auto next = pq.size() > 1 ? pq.peek(1) : queue_record<R>{nullptr, 0};
if (!m_tagging && !now.data->is_tombstone() && next.data != nullptr &&
- now.data->match(next.data) && next.data->is_tombstone()) {
+ *now.data == *next.data && next.data->is_tombstone()) {
pq.pop(); pq.pop();
auto& cursor1 = cursors[now.version];
auto& cursor2 = cursors[next.version];
- if (advance_cursor<K,V,W>(cursor1)) pq.push(cursor1.ptr, now.version);
- if (advance_cursor<K,V,W>(cursor2)) pq.push(cursor2.ptr, next.version);
+ if (advance_cursor<R>(cursor1)) pq.push(cursor1.ptr, now.version);
+ if (advance_cursor<R>(cursor2)) pq.push(cursor2.ptr, next.version);
} else {
auto& cursor = cursors[now.version];
- if (!m_tagging || !cursor.ptr->get_delete_status()) {
+ if (!m_tagging || !cursor.ptr->is_deleted()) {
m_data[m_reccnt++] = *cursor.ptr;
m_total_weight += cursor.ptr->weight;
if (bf && cursor.ptr->is_tombstone()) {
@@ -137,7 +145,7 @@ public:
}
pq.pop();
- if (advance_cursor<K,V,W>(cursor)) pq.push(cursor.ptr, now.version);
+ if (advance_cursor<R>(cursor)) pq.push(cursor.ptr, now.version);
}
}
@@ -155,16 +163,16 @@ public:
free_tree(m_root);
}
- bool delete_record(const K& key, const V& val) {
- size_t idx = get_lower_bound(key);
+ bool delete_record(const R& rec) {
+ size_t idx = get_lower_bound(rec.key);
if (idx >= m_reccnt) {
return false;
}
- while (idx < m_reccnt && m_data[idx].lt(key, val)) ++idx;
+ while (idx < m_reccnt && m_data[idx] < rec) ++idx;
- if (m_data[idx].match(key, val, false)) {
- m_data[idx].set_delete_status();
+ if (m_data[idx] == R {rec.key, rec.val} && !m_data[idx].is_tombstone()) {
+ m_data[idx].set_delete();
m_deleted_cnt++;
return true;
}
@@ -172,7 +180,7 @@ public:
return false;
}
- void free_tree(struct wirs_node* node) {
+ void free_tree(struct wirs_node<R>* node) {
if (node) {
delete node->alias;
free_tree(node->left);
@@ -181,7 +189,7 @@ public:
}
}
- Record<K, V, W>* sorted_output() const {
+ R* sorted_output() const {
return m_data;
}
@@ -193,19 +201,19 @@ public:
return m_tombstone_cnt;
}
- const Record<K, V, W>* get_record_at(size_t idx) const {
+ const R* get_record_at(size_t idx) const {
if (idx >= m_reccnt) return nullptr;
return m_data + idx;
}
// low - high -> decompose to a set of nodes.
// Build Alias across the decomposed nodes.
- WIRSState* get_sample_shard_state(const K& lower_key, const K& upper_key) {
- WIRSState* res = new WIRSState();
+ WIRSState<R>* get_sample_shard_state(const K& lower_key, const K& upper_key) {
+ auto res = new WIRSState();
// Simulate a stack to unfold recursion.
double tot_weight = 0.0;
- struct wirs_node* st[64] = {0};
+ struct wirs_node<R>* st[64] = {0};
st[0] = m_root;
size_t top = 1;
while(top > 0) {
@@ -231,15 +239,15 @@ public:
}
static void delete_state(void *state) {
- auto s = (WIRSState *) state;
+ WIRSState<R> *s = (WIRSState<R> *) state;
delete s;
}
// returns the number of records sampled
// NOTE: This operation returns records strictly between the lower and upper bounds, not
// including them.
- size_t get_samples(void* shard_state, std::vector<Record<K, V, W>> &result_set, const K& lower_key, const K& upper_key, size_t sample_sz, gsl_rng *rng) {
- WIRSState *state = (WIRSState *) shard_state;
+ size_t get_samples(void* shard_state, std::vector<R> &result_set, const K& lower_key, const K& upper_key, size_t sample_sz, gsl_rng *rng) {
+ WIRSState<R> *state = (WIRSState<R> *) shard_state;
if (sample_sz == 0) {
return 0;
}
@@ -295,30 +303,30 @@ public:
auto ptr = m_data + get_lower_bound(key);
- while (ptr < m_data + m_reccnt && ptr->lt(key, val)) {
+ while (ptr < m_data + m_reccnt && *ptr < R {key, val}) {
ptr ++;
}
- bool result = (m_tagging) ? ptr->get_delete_status()
- : ptr->match(key, val, true);
+ bool result = (m_tagging) ? ptr->is_deleted()
+ : *ptr == R {key, val} && ptr->is_tombstone();
m_rejection_cnt += result;
return result;
}
- bool check_tombstone(const K& key, const V& val) {
+ bool check_tombstone(const R& rec) {
m_ts_check_cnt++;
- size_t idx = get_lower_bound(key);
+ size_t idx = get_lower_bound(rec.key);
if (idx >= m_reccnt) {
return false;
}
- auto ptr = m_data + get_lower_bound(key);
+ auto ptr = m_data + get_lower_bound(rec.key);
- while (ptr < m_data + m_reccnt && ptr->lt(key, val)) {
+ while (ptr < m_data + m_reccnt && *ptr < rec) {
ptr ++;
}
- bool result = ptr->match(key, val, true);
+ bool result = *ptr == rec && ptr->is_tombstone();
m_rejection_cnt += result;
return result;
@@ -340,21 +348,21 @@ public:
private:
- bool covered_by(struct wirs_node* node, const K& lower_key, const K& upper_key) {
+ bool covered_by(struct wirs_node<R>* node, const K& lower_key, const K& upper_key) {
auto low_index = node->low * m_group_size;
auto high_index = std::min((node->high + 1) * m_group_size - 1, m_reccnt - 1);
return lower_key < m_data[low_index].key && m_data[high_index].key < upper_key;
}
- bool intersects(struct wirs_node* node, const K& lower_key, const K& upper_key) {
+ bool intersects(struct wirs_node<R>* node, const K& lower_key, const K& upper_key) {
auto low_index = node->low * m_group_size;
auto high_index = std::min((node->high + 1) * m_group_size - 1, m_reccnt - 1);
return lower_key < m_data[high_index].key && m_data[low_index].key < upper_key;
}
- struct wirs_node* construct_wirs_node(const std::vector<W>& weights, size_t low, size_t high) {
+ struct wirs_node<R>* construct_wirs_node(const std::vector<W>& weights, size_t low, size_t high) {
if (low == high) {
- return new wirs_node{nullptr, nullptr, low, high, weights[low], new Alias({1.0})};
+ return new wirs_node<R>{nullptr, nullptr, low, high, weights[low], new Alias({1.0})};
} else if (low > high) return nullptr;
std::vector<double> node_weights;
@@ -370,9 +378,9 @@ private:
size_t mid = (low + high) / 2;
- return new wirs_node{construct_wirs_node(weights, low, mid),
- construct_wirs_node(weights, mid + 1, high),
- low, high, sum, new Alias(node_weights)};
+ return new wirs_node<R>{construct_wirs_node(weights, low, mid),
+ construct_wirs_node(weights, mid + 1, high),
+ low, high, sum, new Alias(node_weights)};
}
@@ -410,9 +418,9 @@ private:
m_root = construct_wirs_node(weights, 0, n_groups-1);
}
- Record<K, V, W>* m_data;
+ R* m_data;
std::vector<Alias *> m_alias;
- wirs_node* m_root;
+ wirs_node<R>* m_root;
bool m_tagging;
W m_total_weight;
size_t m_reccnt;