summaryrefslogtreecommitdiffstats
path: root/include/query
diff options
context:
space:
mode:
authorDouglas B. Rumbaugh <dbr4@psu.edu>2024-02-09 14:06:59 -0500
committerGitHub <noreply@github.com>2024-02-09 14:06:59 -0500
commitbc0f3cca3a5b495fcae1d3ad8d09e6d714da5d30 (patch)
tree66333c55feb0ea8875a50e6dc07c8535d241bf1c /include/query
parent076e104b8672924c3d80cd1da2fdb5ebee1766ac (diff)
parent46885246313358a3b606eca139b20280e96db10e (diff)
downloaddynamic-extension-bc0f3cca3a5b495fcae1d3ad8d09e6d714da5d30.tar.gz
Merge pull request #1 from dbrumbaugh/new-buffer
Initial Concurrency Implementation
Diffstat (limited to 'include/query')
-rw-r--r--include/query/irs.h223
-rw-r--r--include/query/knn.h159
-rw-r--r--include/query/rangecount.h165
-rw-r--r--include/query/rangequery.h174
-rw-r--r--include/query/wirs.h244
-rw-r--r--include/query/wss.h209
6 files changed, 1174 insertions, 0 deletions
diff --git a/include/query/irs.h b/include/query/irs.h
new file mode 100644
index 0000000..e2d9325
--- /dev/null
+++ b/include/query/irs.h
@@ -0,0 +1,223 @@
+/*
+ * include/query/irs.h
+ *
+ * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu>
+ *
+ * Distributed under the Modified BSD License.
+ *
+ * A query class for independent range sampling. This query requires
+ * that the shard support get_lower_bound(key), get_upper_bound(key),
+ * and get_record_at(index).
+ */
+#pragma once
+
+#include "framework/QueryRequirements.h"
+#include "psu-ds/Alias.h"
+
+namespace de { namespace irs {
+
+template <RecordInterface R>
+struct Parms {
+ decltype(R::key) lower_bound;
+ decltype(R::key) upper_bound;
+ size_t sample_size;
+ gsl_rng *rng;
+};
+
+
+template <RecordInterface R>
+struct State {
+ size_t lower_bound;
+ size_t upper_bound;
+ size_t sample_size;
+ size_t total_weight;
+};
+
+template <RecordInterface R>
+struct BufferState {
+ size_t cutoff;
+ std::vector<Wrapped<R>> records;
+ size_t sample_size;
+ BufferView<R> *buffer;
+
+ BufferState(BufferView<R> *buffer) : buffer(buffer) {}
+};
+
+template <RecordInterface R, ShardInterface<R> S, bool Rejection=true>
+class Query {
+public:
+ constexpr static bool EARLY_ABORT=false;
+ constexpr static bool SKIP_DELETE_FILTER=false;
+
+ static void *get_query_state(S *shard, void *parms) {
+ auto res = new State<R>();
+ decltype(R::key) lower_key = ((Parms<R> *) parms)->lower_bound;
+ decltype(R::key) upper_key = ((Parms<R> *) parms)->upper_bound;
+
+ res->lower_bound = shard->get_lower_bound(lower_key);
+ res->upper_bound = shard->get_upper_bound(upper_key);
+
+ if (res->lower_bound == shard->get_record_count()) {
+ res->total_weight = 0;
+ } else {
+ res->total_weight = res->upper_bound - res->lower_bound;
+ }
+
+ res->sample_size = 0;
+ return res;
+ }
+
+ static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
+ auto res = new BufferState<R>(buffer);
+
+ res->cutoff = res->buffer->get_record_count();
+ res->sample_size = 0;
+
+ if constexpr (Rejection) {
+ return res;
+ }
+
+ auto lower_key = ((Parms<R> *) parms)->lower_bound;
+ auto upper_key = ((Parms<R> *) parms)->upper_bound;
+
+ for (size_t i=0; i<res->cutoff; i++) {
+ if ((res->buffer->get(i)->rec.key >= lower_key) && (buffer->get(i)->rec.key <= upper_key)) {
+ res->records.emplace_back(*(res->buffer->get(i)));
+ }
+ }
+
+ return res;
+ }
+
+ static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void *buffer_state) {
+ auto p = (Parms<R> *) query_parms;
+ auto bs = (buffer_state) ? (BufferState<R> *) buffer_state : nullptr;
+
+ std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0);
+ size_t buffer_sz = 0;
+
+ std::vector<size_t> weights;
+ if constexpr (Rejection) {
+ weights.push_back((bs) ? bs->cutoff : 0);
+ } else {
+ weights.push_back((bs) ? bs->records.size() : 0);
+ }
+
+ size_t total_weight = 0;
+ for (auto &s : shard_states) {
+ auto state = (State<R> *) s;
+ total_weight += state->total_weight;
+ weights.push_back(state->total_weight);
+ }
+
+ // if no valid records fall within the query range, just
+ // set all of the sample sizes to 0 and bail out.
+ if (total_weight == 0) {
+ for (size_t i=0; i<shard_states.size(); i++) {
+ auto state = (State<R> *) shard_states[i];
+ state->sample_size = 0;
+ }
+
+ return;
+ }
+
+ std::vector<double> normalized_weights;
+ for (auto w : weights) {
+ normalized_weights.push_back((double) w / (double) total_weight);
+ }
+
+ auto shard_alias = psudb::Alias(normalized_weights);
+ for (size_t i=0; i<p->sample_size; i++) {
+ auto idx = shard_alias.get(p->rng);
+ if (idx == 0) {
+ buffer_sz++;
+ } else {
+ shard_sample_sizes[idx - 1]++;
+ }
+ }
+
+ if (bs) {
+ bs->sample_size = buffer_sz;
+ }
+ for (size_t i=0; i<shard_states.size(); i++) {
+ auto state = (State<R> *) shard_states[i];
+ state->sample_size = shard_sample_sizes[i+1];
+ }
+ }
+
+ static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
+ auto lower_key = ((Parms<R> *) parms)->lower_bound;
+ auto upper_key = ((Parms<R> *) parms)->upper_bound;
+ auto rng = ((Parms<R> *) parms)->rng;
+
+ auto state = (State<R> *) q_state;
+ auto sample_sz = state->sample_size;
+
+ std::vector<Wrapped<R>> result_set;
+
+ if (sample_sz == 0 || state->lower_bound == shard->get_record_count()) {
+ return result_set;
+ }
+
+ size_t attempts = 0;
+ size_t range_length = state->upper_bound - state->lower_bound;
+ do {
+ attempts++;
+ size_t idx = (range_length > 0) ? gsl_rng_uniform_int(rng, range_length) : 0;
+ result_set.emplace_back(*shard->get_record_at(state->lower_bound + idx));
+ } while (attempts < sample_sz);
+
+ return result_set;
+ }
+
+ static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
+ auto st = (BufferState<R> *) state;
+ auto p = (Parms<R> *) parms;
+
+ std::vector<Wrapped<R>> result;
+ result.reserve(st->sample_size);
+
+ if constexpr (Rejection) {
+ for (size_t i=0; i<st->sample_size; i++) {
+ auto idx = gsl_rng_uniform_int(p->rng, st->cutoff);
+ auto rec = st->buffer->get(idx);
+
+ if (rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) {
+ result.emplace_back(*rec);
+ }
+ }
+
+ return result;
+ }
+
+ for (size_t i=0; i<st->sample_size; i++) {
+ auto idx = gsl_rng_uniform_int(p->rng, st->records.size());
+ result.emplace_back(st->records[idx]);
+ }
+
+ return result;
+ }
+
+ static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) {
+ std::vector<R> output;
+
+ for (size_t i=0; i<results.size(); i++) {
+ for (size_t j=0; j<results[i].size(); j++) {
+ output.emplace_back(results[i][j].rec);
+ }
+ }
+
+ return output;
+ }
+
+ static void delete_query_state(void *state) {
+ auto s = (State<R> *) state;
+ delete s;
+ }
+
+ static void delete_buffer_query_state(void *state) {
+ auto s = (BufferState<R> *) state;
+ delete s;
+ }
+};
+}}
diff --git a/include/query/knn.h b/include/query/knn.h
new file mode 100644
index 0000000..19dcf5c
--- /dev/null
+++ b/include/query/knn.h
@@ -0,0 +1,159 @@
+/*
+ * include/query/knn.h
+ *
+ * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu>
+ *
+ * Distributed under the Modified BSD License.
+ *
+ * A query class for k-NN queries, designed for use with the VPTree
+ * shard.
+ *
+ * FIXME: no support for tombstone deletes just yet. This would require a
+ * query resumption mechanism, most likely.
+ */
+#pragma once
+
+#include "framework/QueryRequirements.h"
+#include "psu-ds/PriorityQueue.h"
+
+namespace de { namespace knn {
+
+using psudb::PriorityQueue;
+
+template <NDRecordInterface R>
+struct Parms {
+ R point;
+ size_t k;
+};
+
+template <NDRecordInterface R>
+struct State {
+ size_t k;
+};
+
+template <NDRecordInterface R>
+struct BufferState {
+ BufferView<R> *buffer;
+
+ BufferState(BufferView<R> *buffer)
+ : buffer(buffer) {}
+};
+
+template <NDRecordInterface R, ShardInterface<R> S>
+class Query {
+public:
+ constexpr static bool EARLY_ABORT=false;
+ constexpr static bool SKIP_DELETE_FILTER=true;
+
+ static void *get_query_state(S *shard, void *parms) {
+ return nullptr;
+ }
+
+ static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
+ return new BufferState<R>(buffer);
+ }
+
+ static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void* buffer_state) {
+ return;
+ }
+
+ static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
+ std::vector<Wrapped<R>> results;
+ Parms<R> *p = (Parms<R> *) parms;
+ Wrapped<R> wrec;
+ wrec.rec = p->point;
+ wrec.header = 0;
+
+ PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(p->k, &wrec);
+
+ shard->search(p->point, p->k, pq);
+
+ while (pq.size() > 0) {
+ results.emplace_back(*pq.peek().data);
+ pq.pop();
+ }
+
+ return results;
+ }
+
+ static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
+ Parms<R> *p = (Parms<R> *) parms;
+ BufferState<R> *s = (BufferState<R> *) state;
+ Wrapped<R> wrec;
+ wrec.rec = p->point;
+ wrec.header = 0;
+
+ size_t k = p->k;
+
+ PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(k, &wrec);
+ for (size_t i=0; i<s->buffer->get_record_count(); i++) {
+ // Skip over deleted records (under tagging)
+ if (s->buffer->get(i)->is_deleted()) {
+ continue;
+ }
+
+ if (pq.size() < k) {
+ pq.push(s->buffer->get(i));
+ } else {
+ double head_dist = pq.peek().data->rec.calc_distance(wrec.rec);
+ double cur_dist = (s->buffer->get(i))->rec.calc_distance(wrec.rec);
+
+ if (cur_dist < head_dist) {
+ pq.pop();
+ pq.push(s->buffer->get(i));
+ }
+ }
+ }
+
+ std::vector<Wrapped<R>> results;
+ while (pq.size() > 0) {
+ results.emplace_back(*(pq.peek().data));
+ pq.pop();
+ }
+
+ return results;
+ }
+
+ static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) {
+ Parms<R> *p = (Parms<R> *) parms;
+ R rec = p->point;
+ size_t k = p->k;
+
+ PriorityQueue<R, DistCmpMax<R>> pq(k, &rec);
+ for (size_t i=0; i<results.size(); i++) {
+ for (size_t j=0; j<results[i].size(); j++) {
+ if (pq.size() < k) {
+ pq.push(&results[i][j].rec);
+ } else {
+ double head_dist = pq.peek().data->calc_distance(rec);
+ double cur_dist = results[i][j].rec.calc_distance(rec);
+
+ if (cur_dist < head_dist) {
+ pq.pop();
+ pq.push(&results[i][j].rec);
+ }
+ }
+ }
+ }
+
+ std::vector<R> output;
+ while (pq.size() > 0) {
+ output.emplace_back(*pq.peek().data);
+ pq.pop();
+ }
+
+ return output;
+ }
+
+ static void delete_query_state(void *state) {
+ auto s = (State<R> *) state;
+ delete s;
+ }
+
+ static void delete_buffer_query_state(void *state) {
+ auto s = (BufferState<R> *) state;
+ delete s;
+ }
+};
+
+}}
diff --git a/include/query/rangecount.h b/include/query/rangecount.h
new file mode 100644
index 0000000..6c57809
--- /dev/null
+++ b/include/query/rangecount.h
@@ -0,0 +1,165 @@
+/*
+ * include/query/rangecount.h
+ *
+ * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu>
+ *
+ * Distributed under the Modified BSD License.
+ *
+ * A query class for single dimensional range count queries. This query
+ * requires that the shard support get_lower_bound(key) and
+ * get_record_at(index).
+ */
+#pragma once
+
+#include "framework/QueryRequirements.h"
+
+namespace de { namespace rc {
+
+template <RecordInterface R>
+struct Parms {
+ decltype(R::key) lower_bound;
+ decltype(R::key) upper_bound;
+};
+
+template <RecordInterface R>
+struct State {
+ size_t start_idx;
+ size_t stop_idx;
+};
+
+template <RecordInterface R>
+struct BufferState {
+ BufferView<R> *buffer;
+
+ BufferState(BufferView<R> *buffer)
+ : buffer(buffer) {}
+};
+
+template <KVPInterface R, ShardInterface<R> S>
+class Query {
+public:
+ constexpr static bool EARLY_ABORT=false;
+ constexpr static bool SKIP_DELETE_FILTER=true;
+
+ static void *get_query_state(S *shard, void *parms) {
+ auto res = new State<R>();
+ auto p = (Parms<R> *) parms;
+
+ res->start_idx = shard->get_lower_bound(p->lower_bound);
+ res->stop_idx = shard->get_record_count();
+
+ return res;
+ }
+
+ static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
+ auto res = new BufferState<R>(buffer);
+
+ return res;
+ }
+
+ static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void* buffer_state) {
+ return;
+ }
+
+ static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
+ std::vector<Wrapped<R>> records;
+ auto p = (Parms<R> *) parms;
+ auto s = (State<R> *) q_state;
+
+ size_t reccnt = 0;
+ size_t tscnt = 0;
+
+ Wrapped<R> res;
+ res.rec.key= 0; // records
+ res.rec.value = 0; // tombstones
+ records.emplace_back(res);
+
+ /*
+ * if the returned index is one past the end of the
+ * records for the PGM, then there are not records
+ * in the index falling into the specified range.
+ */
+ if (s->start_idx == shard->get_record_count()) {
+ return records;
+ }
+
+ auto ptr = shard->get_record_at(s->start_idx);
+
+ /*
+ * roll the pointer forward to the first record that is
+ * greater than or equal to the lower bound.
+ */
+ while(ptr < shard->get_data() + s->stop_idx && ptr->rec.key < p->lower_bound) {
+ ptr++;
+ }
+
+ while (ptr < shard->get_data() + s->stop_idx && ptr->rec.key <= p->upper_bound) {
+ if (!ptr->is_deleted()) {
+ if (ptr->is_tombstone()) {
+ records[0].rec.value++;
+ } else {
+ records[0].rec.key++;
+ }
+ }
+
+ ptr++;
+ }
+
+ return records;
+ }
+
+ static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
+ auto p = (Parms<R> *) parms;
+ auto s = (BufferState<R> *) state;
+
+ std::vector<Wrapped<R>> records;
+
+ Wrapped<R> res;
+ res.rec.key= 0; // records
+ res.rec.value = 0; // tombstones
+ records.emplace_back(res);
+
+ for (size_t i=0; i<s->buffer->get_record_count(); i++) {
+ auto rec = s->buffer->get(i);
+ if (rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound
+ && !rec->is_deleted()) {
+ if (rec->is_tombstone()) {
+ records[0].rec.value++;
+ } else {
+ records[0].rec.key++;
+ }
+ }
+ }
+
+ return records;
+ }
+
+ static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) {
+
+ R res;
+ res.key = 0;
+ res.value = 0;
+ std::vector<R> output;
+ output.emplace_back(res);
+
+ for (size_t i=0; i<results.size(); i++) {
+ output[0].key += results[i][0].rec.key; // records
+ output[0].value += results[i][0].rec.value; // tombstones
+ }
+
+ output[0].key -= output[0].value;
+ return output;
+ }
+
+ static void delete_query_state(void *state) {
+ auto s = (State<R> *) state;
+ delete s;
+ }
+
+ static void delete_buffer_query_state(void *state) {
+ auto s = (BufferState<R> *) state;
+ delete s;
+ }
+};
+
+}}
diff --git a/include/query/rangequery.h b/include/query/rangequery.h
new file mode 100644
index 0000000..24b38ec
--- /dev/null
+++ b/include/query/rangequery.h
@@ -0,0 +1,174 @@
+/*
+ * include/query/rangequery.h
+ *
+ * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu>
+ *
+ * Distributed under the Modified BSD License.
+ *
+ * A query class for single dimensional range queries. This query requires
+ * that the shard support get_lower_bound(key) and get_record_at(index).
+ */
+#pragma once
+
+#include "framework/QueryRequirements.h"
+#include "psu-ds/PriorityQueue.h"
+#include "util/Cursor.h"
+
+namespace de { namespace rq {
+
+template <RecordInterface R>
+struct Parms {
+ decltype(R::key) lower_bound;
+ decltype(R::key) upper_bound;
+};
+
+template <RecordInterface R>
+struct State {
+ size_t start_idx;
+ size_t stop_idx;
+};
+
+template <RecordInterface R>
+struct BufferState {
+ BufferView<R> *buffer;
+
+ BufferState(BufferView<R> *buffer)
+ : buffer(buffer) {}
+};
+
+template <RecordInterface R, ShardInterface<R> S>
+class Query {
+public:
+ constexpr static bool EARLY_ABORT=false;
+ constexpr static bool SKIP_DELETE_FILTER=true;
+
+ static void *get_query_state(S *shard, void *parms) {
+ auto res = new State<R>();
+ auto p = (Parms<R> *) parms;
+
+ res->start_idx = shard->get_lower_bound(p->lower_bound);
+ res->stop_idx = shard->get_record_count();
+
+ return res;
+ }
+
+ static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
+ auto res = new BufferState<R>(buffer);
+
+ return res;
+ }
+
+ static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void* buffer_state) {
+ return;
+ }
+
+ static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
+ std::vector<Wrapped<R>> records;
+ auto p = (Parms<R> *) parms;
+ auto s = (State<R> *) q_state;
+
+ /*
+ * if the returned index is one past the end of the
+ * records for the PGM, then there are not records
+ * in the index falling into the specified range.
+ */
+ if (s->start_idx == shard->get_record_count()) {
+ return records;
+ }
+
+ auto ptr = shard->get_record_at(s->start_idx);
+
+ /*
+ * roll the pointer forward to the first record that is
+ * greater than or equal to the lower bound.
+ */
+ while(ptr < shard->get_data() + s->stop_idx && ptr->rec.key < p->lower_bound) {
+ ptr++;
+ }
+
+ while (ptr < shard->get_data() + s->stop_idx && ptr->rec.key <= p->upper_bound) {
+ records.emplace_back(*ptr);
+ ptr++;
+ }
+
+ return records;
+ }
+
+ static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
+ auto p = (Parms<R> *) parms;
+ auto s = (BufferState<R> *) state;
+
+ std::vector<Wrapped<R>> records;
+ for (size_t i=0; i<s->buffer->get_record_count(); i++) {
+ auto rec = s->buffer->get(i);
+ if (rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) {
+ records.emplace_back(*rec);
+ }
+ }
+
+ return records;
+ }
+
+ static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) {
+ std::vector<Cursor<Wrapped<R>>> cursors;
+ cursors.reserve(results.size());
+
+ psudb::PriorityQueue<Wrapped<R>> pq(results.size());
+ size_t total = 0;
+ size_t tmp_n = results.size();
+
+
+ for (size_t i = 0; i < tmp_n; ++i)
+ if (results[i].size() > 0){
+ auto base = results[i].data();
+ cursors.emplace_back(Cursor{base, base + results[i].size(), 0, results[i].size()});
+ assert(i == cursors.size() - 1);
+ total += results[i].size();
+ pq.push(cursors[i].ptr, tmp_n - i - 1);
+ } else {
+ cursors.emplace_back(Cursor<Wrapped<R>>{nullptr, nullptr, 0, 0});
+ }
+
+ if (total == 0) {
+ return std::vector<R>();
+ }
+
+ std::vector<R> output;
+ output.reserve(total);
+
+ while (pq.size()) {
+ auto now = pq.peek();
+ auto next = pq.size() > 1 ? pq.peek(1) : psudb::queue_record<Wrapped<R>>{nullptr, 0};
+ if (!now.data->is_tombstone() && next.data != nullptr &&
+ now.data->rec == next.data->rec && next.data->is_tombstone()) {
+
+ pq.pop(); pq.pop();
+ auto& cursor1 = cursors[tmp_n - now.version - 1];
+ auto& cursor2 = cursors[tmp_n - next.version - 1];
+ if (advance_cursor<Wrapped<R>>(cursor1)) pq.push(cursor1.ptr, now.version);
+ if (advance_cursor<Wrapped<R>>(cursor2)) pq.push(cursor2.ptr, next.version);
+ } else {
+ auto& cursor = cursors[tmp_n - now.version - 1];
+ if (!now.data->is_tombstone()) output.push_back(cursor.ptr->rec);
+
+ pq.pop();
+
+ if (advance_cursor<Wrapped<R>>(cursor)) pq.push(cursor.ptr, now.version);
+ }
+ }
+
+ return output;
+ }
+
+ static void delete_query_state(void *state) {
+ auto s = (State<R> *) state;
+ delete s;
+ }
+
+ static void delete_buffer_query_state(void *state) {
+ auto s = (BufferState<R> *) state;
+ delete s;
+ }
+};
+
+}}
diff --git a/include/query/wirs.h b/include/query/wirs.h
new file mode 100644
index 0000000..ae82194
--- /dev/null
+++ b/include/query/wirs.h
@@ -0,0 +1,244 @@
+/*
+ * include/query/wirs.h
+ *
+ * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu>
+ *
+ * Distributed under the Modified BSD License.
+ *
+ * A query class for weighted independent range sampling. This
+ * class is tightly coupled with include/shard/AugBTree.h, and
+ * so is probably of limited general utility.
+ */
+#pragma once
+
+#include "framework/QueryRequirements.h"
+#include "psu-ds/Alias.h"
+
+namespace de { namespace wirs {
+
+template <WeightedRecordInterface R>
+struct Parms {
+ decltype(R::key) lower_bound;
+ decltype(R::key) upper_bound;
+ size_t sample_size;
+ gsl_rng *rng;
+};
+
+template <WeightedRecordInterface R>
+struct State {
+ decltype(R::weight) total_weight;
+ std::vector<void*> nodes;
+ psudb::Alias* top_level_alias;
+ size_t sample_size;
+
+ State() {
+ total_weight = 0;
+ top_level_alias = nullptr;
+ }
+
+ ~State() {
+ if (top_level_alias) delete top_level_alias;
+ }
+};
+
+template <RecordInterface R>
+struct BufferState {
+ size_t cutoff;
+ psudb::Alias* alias;
+ std::vector<Wrapped<R>> records;
+ decltype(R::weight) max_weight;
+ size_t sample_size;
+ decltype(R::weight) total_weight;
+ BufferView<R> *buffer;
+
+ ~BufferState() {
+ delete alias;
+ }
+};
+
+template <RecordInterface R, ShardInterface<R> S, bool Rejection=true>
+class Query {
+public:
+ constexpr static bool EARLY_ABORT=false;
+ constexpr static bool SKIP_DELETE_FILTER=false;
+
+ static void *get_query_state(S *shard, void *parms) {
+ auto res = new State<R>();
+ decltype(R::key) lower_key = ((Parms<R> *) parms)->lower_bound;
+ decltype(R::key) upper_key = ((Parms<R> *) parms)->upper_bound;
+
+ std::vector<decltype(R::weight)> weights;
+ res->total_weight = shard->find_covering_nodes(lower_key, upper_key, res->nodes, weights);
+
+ std::vector<double> normalized_weights;
+ for (auto weight : weights) {
+ normalized_weights.emplace_back(weight / res->total_weight);
+ }
+
+ res->top_level_alias = new psudb::Alias(normalized_weights);
+ res->sample_size = 0;
+
+ return res;
+ }
+
+ static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
+ BufferState<R> *state = new BufferState<R>();
+ auto parameters = (Parms<R>*) parms;
+
+ if constexpr (Rejection) {
+ state->cutoff = buffer->get_record_count() - 1;
+ state->max_weight = buffer->get_max_weight();
+ state->total_weight = buffer->get_total_weight();
+ state->sample_size = 0;
+ state->buffer = buffer;
+ return state;
+ }
+
+ std::vector<decltype(R::weight)> weights;
+
+ state->buffer = buffer;
+ decltype(R::weight) total_weight = 0;
+
+ for (size_t i = 0; i <= buffer->get_record_count(); i++) {
+ auto rec = buffer->get(i);
+
+ if (rec->rec.key >= parameters->lower_bound && rec->rec.key <= parameters->upper_bound && !rec->is_tombstone() && !rec->is_deleted()) {
+ weights.push_back(rec->rec.weight);
+ state->records.push_back(*rec);
+ total_weight += rec->rec.weight;
+ }
+ }
+
+ std::vector<double> normalized_weights;
+ for (size_t i = 0; i < weights.size(); i++) {
+ normalized_weights.push_back(weights[i] / total_weight);
+ }
+
+ state->total_weight = total_weight;
+ state->alias = new psudb::Alias(normalized_weights);
+ state->sample_size = 0;
+
+ return state;
+ }
+
+ static void process_query_states(void *query_parms, std::vector<void*> &shard_states, std::vector<void*> &buffer_states) {
+ auto p = (Parms<R> *) query_parms;
+
+ std::vector<size_t> shard_sample_sizes(shard_states.size()+buffer_states.size(), 0);
+ size_t buffer_sz = 0;
+
+ std::vector<decltype(R::weight)> weights;
+
+ decltype(R::weight) total_weight = 0;
+ for (auto &s : buffer_states) {
+ auto bs = (BufferState<R> *) s;
+ total_weight += bs->total_weight;
+ weights.push_back(bs->total_weight);
+ }
+
+ for (auto &s : shard_states) {
+ auto state = (State<R> *) s;
+ total_weight += state->total_weight;
+ weights.push_back(state->total_weight);
+ }
+
+ std::vector<double> normalized_weights;
+ for (auto w : weights) {
+ normalized_weights.push_back((double) w / (double) total_weight);
+ }
+
+ auto shard_alias = psudb::Alias(normalized_weights);
+ for (size_t i=0; i<p->sample_size; i++) {
+ auto idx = shard_alias.get(p->rng);
+
+ if (idx < buffer_states.size()) {
+ auto state = (BufferState<R> *) buffer_states[idx];
+ state->sample_size++;
+ } else {
+ auto state = (State<R> *) shard_states[idx - buffer_states.size()];
+ state->sample_size++;
+ }
+ }
+ }
+
+ static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
+ auto lower_key = ((Parms<R> *) parms)->lower_bound;
+ auto upper_key = ((Parms<R> *) parms)->upper_bound;
+ auto rng = ((Parms<R> *) parms)->rng;
+
+ auto state = (State<R> *) q_state;
+ auto sample_size = state->sample_size;
+
+ std::vector<Wrapped<R>> result_set;
+
+ if (sample_size == 0) {
+ return result_set;
+ }
+ size_t cnt = 0;
+ size_t attempts = 0;
+
+ for (size_t i=0; i<sample_size; i++) {
+ auto rec = shard->get_weighted_sample(lower_key, upper_key,
+ state->nodes[state->top_level_alias->get(rng)],
+ rng);
+ if (rec) {
+ result_set.emplace_back(*rec);
+ }
+ }
+
+ return result_set;
+ }
+
+ static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
+ auto st = (BufferState<R> *) state;
+ auto p = (Parms<R> *) parms;
+ auto buffer = st->buffer;
+
+ std::vector<Wrapped<R>> result;
+ result.reserve(st->sample_size);
+
+ if constexpr (Rejection) {
+ for (size_t i=0; i<st->sample_size; i++) {
+ auto idx = gsl_rng_uniform_int(p->rng, st->cutoff);
+ auto rec = buffer->get(idx);
+
+ auto test = gsl_rng_uniform(p->rng) * st->max_weight;
+
+ if (test <= rec->rec.weight && rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) {
+ result.emplace_back(*rec);
+ }
+ }
+ return result;
+ }
+
+ for (size_t i=0; i<st->sample_size; i++) {
+ auto idx = st->alias->get(p->rng);
+ result.emplace_back(st->records[idx]);
+ }
+
+ return result;
+ }
+
+ static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) {
+ std::vector<R> output;
+
+ for (size_t i=0; i<results.size(); i++) {
+ for (size_t j=0; j<results[i].size(); j++) {
+ output.emplace_back(results[i][j].rec);
+ }
+ }
+
+ return output;
+ }
+
+ static void delete_query_state(void *state) {
+ auto s = (State<R> *) state;
+ delete s;
+ }
+
+ static void delete_buffer_query_state(void *state) {
+ auto s = (BufferState<R> *) state;
+ delete s;
+ }
+};
+}}
diff --git a/include/query/wss.h b/include/query/wss.h
new file mode 100644
index 0000000..8797035
--- /dev/null
+++ b/include/query/wss.h
@@ -0,0 +1,209 @@
+/*
+ * include/query/wss.h
+ *
+ * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu>
+ *
+ * Distributed under the Modified BSD License.
+ *
+ * A query class for weighted set sampling. This
+ * class is tightly coupled with include/shard/Alias.h,
+ * and so is probably of limited general utility.
+ */
+#pragma once
+
+#include "framework/QueryRequirements.h"
+#include "psu-ds/Alias.h"
+
+namespace de { namespace wss {
+
+template <WeightedRecordInterface R>
+struct Parms {
+ size_t sample_size;
+ gsl_rng *rng;
+};
+
+template <WeightedRecordInterface R>
+struct State {
+ decltype(R::weight) total_weight;
+ size_t sample_size;
+
+ State() {
+ total_weight = 0;
+ }
+};
+
+template <RecordInterface R>
+struct BufferState {
+ size_t cutoff;
+ size_t sample_size;
+ psudb::Alias *alias;
+ decltype(R::weight) max_weight;
+ decltype(R::weight) total_weight;
+ BufferView<R> *buffer;
+
+ ~BufferState() {
+ delete alias;
+ }
+};
+
+template <RecordInterface R, ShardInterface<R> S, bool Rejection=true>
+class Query {
+public:
+ constexpr static bool EARLY_ABORT=false;
+ constexpr static bool SKIP_DELETE_FILTER=false;
+
+ static void *get_query_state(S *shard, void *parms) {
+ auto res = new State<R>();
+ res->total_weight = shard->get_total_weight();
+ res->sample_size = 0;
+
+ return res;
+ }
+
+ static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
+ BufferState<R> *state = new BufferState<R>();
+ auto parameters = (Parms<R>*) parms;
+ if constexpr (Rejection) {
+ state->cutoff = buffer->get_record_count() - 1;
+ state->max_weight = buffer->get_max_weight();
+ state->total_weight = buffer->get_total_weight();
+ state->buffer = buffer;
+ return state;
+ }
+
+ std::vector<double> weights;
+
+ double total_weight = 0.0;
+ state->buffer = buffer;
+
+ for (size_t i = 0; i <= buffer->get_record_count(); i++) {
+ auto rec = buffer->get_data(i);
+ weights.push_back(rec->rec.weight);
+ total_weight += rec->rec.weight;
+ }
+
+ for (size_t i = 0; i < weights.size(); i++) {
+ weights[i] = weights[i] / total_weight;
+ }
+
+ state->alias = new psudb::Alias(weights);
+ state->total_weight = total_weight;
+
+ return state;
+ }
+
+ static void process_query_states(void *query_parms, std::vector<void*> &shard_states, std::vector<void*> &buffer_states) {
+ auto p = (Parms<R> *) query_parms;
+
+ std::vector<size_t> shard_sample_sizes(shard_states.size()+buffer_states.size(), 0);
+ size_t buffer_sz = 0;
+
+ std::vector<decltype(R::weight)> weights;
+
+ decltype(R::weight) total_weight = 0;
+ for (auto &s : buffer_states) {
+ auto bs = (BufferState<R> *) s;
+ total_weight += bs->total_weight;
+ weights.push_back(bs->total_weight);
+ }
+
+ for (auto &s : shard_states) {
+ auto state = (State<R> *) s;
+ total_weight += state->total_weight;
+ weights.push_back(state->total_weight);
+ }
+
+ std::vector<double> normalized_weights;
+ for (auto w : weights) {
+ normalized_weights.push_back((double) w / (double) total_weight);
+ }
+
+ auto shard_alias = psudb::Alias(normalized_weights);
+ for (size_t i=0; i<p->sample_size; i++) {
+ auto idx = shard_alias.get(p->rng);
+
+ if (idx < buffer_states.size()) {
+ auto state = (BufferState<R> *) buffer_states[idx];
+ state->sample_size++;
+ } else {
+ auto state = (State<R> *) shard_states[idx - buffer_states.size()];
+ state->sample_size++;
+ }
+ }
+ }
+
+ static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
+ auto rng = ((Parms<R> *) parms)->rng;
+
+ auto state = (State<R> *) q_state;
+ auto sample_size = state->sample_size;
+
+ std::vector<Wrapped<R>> result_set;
+
+ if (sample_size == 0) {
+ return result_set;
+ }
+ size_t attempts = 0;
+ do {
+ attempts++;
+ size_t idx = shard->get_weighted_sample(rng);
+ result_set.emplace_back(*shard->get_record_at(idx));
+ } while (attempts < sample_size);
+
+ return result_set;
+ }
+
+ static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
+ auto st = (BufferState<R> *) state;
+ auto p = (Parms<R> *) parms;
+ auto buffer = st->buffer;
+
+ std::vector<Wrapped<R>> result;
+ result.reserve(st->sample_size);
+
+ if constexpr (Rejection) {
+ for (size_t i=0; i<st->sample_size; i++) {
+ auto idx = gsl_rng_uniform_int(p->rng, st->cutoff);
+ auto rec = buffer->get(idx);
+
+ auto test = gsl_rng_uniform(p->rng) * st->max_weight;
+
+ if (test <= rec->rec.weight) {
+ result.emplace_back(*rec);
+ }
+ }
+ return result;
+ }
+
+ for (size_t i=0; i<st->sample_size; i++) {
+ auto idx = st->alias->get(p->rng);
+ result.emplace_back(*(buffer->get_data() + idx));
+ }
+
+ return result;
+ }
+
+ static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) {
+ std::vector<R> output;
+
+ for (size_t i=0; i<results.size(); i++) {
+ for (size_t j=0; j<results[i].size(); j++) {
+ output.emplace_back(results[i][j].rec);
+ }
+ }
+
+ return output;
+ }
+
+ static void delete_query_state(void *state) {
+ auto s = (State<R> *) state;
+ delete s;
+ }
+
+ static void delete_buffer_query_state(void *state) {
+ auto s = (BufferState<R> *) state;
+ delete s;
+ }
+};
+
+}}