summaryrefslogtreecommitdiffstats
path: root/include/query
diff options
context:
space:
mode:
Diffstat (limited to 'include/query')
-rw-r--r--include/query/irs.h360
-rw-r--r--include/query/knn.h224
-rw-r--r--include/query/pointlookup.h170
-rw-r--r--include/query/rangecount.h259
-rw-r--r--include/query/rangequery.h283
-rw-r--r--include/query/wirs.h251
-rw-r--r--include/query/wss.h282
7 files changed, 773 insertions, 1056 deletions
diff --git a/include/query/irs.h b/include/query/irs.h
index 879d070..6dec850 100644
--- a/include/query/irs.h
+++ b/include/query/irs.h
@@ -1,12 +1,12 @@
/*
* include/query/irs.h
*
- * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu>
+ * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu>
*
* Distributed under the Modified BSD License.
*
- * A query class for independent range sampling. This query requires
- * that the shard support get_lower_bound(key), get_upper_bound(key),
+ * A query class for independent range sampling. This query requires
+ * that the shard support get_lower_bound(key), get_upper_bound(key),
* and get_record_at(index).
*/
#pragma once
@@ -14,237 +14,227 @@
#include "framework/QueryRequirements.h"
#include "psu-ds/Alias.h"
-namespace de { namespace irs {
+namespace de {
+namespace irs {
-template <RecordInterface R>
-struct Parms {
+template <ShardInterface S, bool REJECTION = true> class Query {
+ typedef typename S::RECORD R;
+
+public:
+ struct Parameters {
decltype(R::key) lower_bound;
decltype(R::key) upper_bound;
size_t sample_size;
gsl_rng *rng;
-};
+ };
-
-template <RecordInterface R>
-struct State {
- size_t lower_bound;
- size_t upper_bound;
- size_t sample_size;
+ struct LocalQuery {
+ size_t lower_idx;
+ size_t upper_idx;
size_t total_weight;
-};
+ size_t sample_size;
+ Parameters global_parms;
+ };
+
+ struct LocalQueryBuffer {
+ BufferView<R> *buffer;
-template <RecordInterface R>
-struct BufferState {
size_t cutoff;
std::vector<Wrapped<R>> records;
+ std::unique_ptr<psudb::Alias> alias;
size_t sample_size;
- BufferView<R> *buffer;
- psudb::Alias *alias;
+ Parameters global_parms;
+ };
- BufferState(BufferView<R> *buffer) : buffer(buffer) {}
- ~BufferState() {
- delete alias;
- }
-};
+ typedef Wrapped<R> LocalResultType;
+ typedef R ResultType;
-template <RecordInterface R, ShardInterface<R> S, bool Rejection=true>
-class Query {
-public:
- constexpr static bool EARLY_ABORT=false;
- constexpr static bool SKIP_DELETE_FILTER=false;
+ constexpr static bool EARLY_ABORT = false;
+ constexpr static bool SKIP_DELETE_FILTER = false;
- static void *get_query_state(S *shard, void *parms) {
- auto res = new State<R>();
- decltype(R::key) lower_key = ((Parms<R> *) parms)->lower_bound;
- decltype(R::key) upper_key = ((Parms<R> *) parms)->upper_bound;
+ static LocalQuery *local_preproc(S *shard, Parameters *parms) {
+ auto query = new LocalQuery();
- res->lower_bound = shard->get_lower_bound(lower_key);
- res->upper_bound = shard->get_upper_bound(upper_key);
+ query->global_parms = *parms;
- if (res->lower_bound == shard->get_record_count()) {
- res->total_weight = 0;
- } else {
- res->total_weight = res->upper_bound - res->lower_bound;
- }
+ query->lower_idx = shard->get_lower_bound(query->global_parms.lower_bound);
+ query->upper_idx = shard->get_upper_bound(query->global_parms.upper_bound);
- res->sample_size = 0;
- return res;
+ if (query->lower_idx == shard->get_record_count()) {
+ query->total_weight = 0;
+ } else {
+ query->total_weight = query->upper_idx - query->lower_idx;
}
- static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
- auto res = new BufferState<R>(buffer);
-
- res->cutoff = res->buffer->get_record_count();
- res->sample_size = 0;
- res->alias = nullptr;
+ query->sample_size = 0;
+ return query;
+ }
- if constexpr (Rejection) {
- return res;
- }
-
- auto lower_key = ((Parms<R> *) parms)->lower_bound;
- auto upper_key = ((Parms<R> *) parms)->upper_bound;
+ static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer,
+ Parameters *parms) {
+ auto query = new LocalQueryBuffer();
+ query->buffer = buffer;
- for (size_t i=0; i<res->cutoff; i++) {
- if ((res->buffer->get(i)->rec.key >= lower_key) && (buffer->get(i)->rec.key <= upper_key)) {
- res->records.emplace_back(*(res->buffer->get(i)));
- }
- }
+ query->cutoff = query->buffer->get_record_count();
+ query->sample_size = 0;
+ query->alias = nullptr;
+ query->global_parms = *parms;
- return res;
+ if constexpr (REJECTION) {
+ return query;
}
- static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void *buffer_state) {
- auto p = (Parms<R> *) query_parms;
- auto bs = (buffer_state) ? (BufferState<R> *) buffer_state : nullptr;
-
- std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0);
- size_t buffer_sz = 0;
+ for (size_t i = 0; i < query->cutoff; i++) {
+ if ((query->buffer->get(i)->rec.key >= query->global_parms.lower_bound) &&
+ (buffer->get(i)->rec.key <= query->global_parms.upper_bound)) {
+ query->records.emplace_back(*(query->buffer->get(i)));
+ }
+ }
- /* for simplicity of static structure testing */
- if (!bs) {
- assert(shard_states.size() == 1);
- auto state = (State<R> *) shard_states[0];
- state->sample_size = p->sample_size;
- return;
- }
+ return query;
+ }
- /* we only need to build the shard alias on the first call */
- if (bs->alias == nullptr) {
- std::vector<size_t> weights;
- if constexpr (Rejection) {
- weights.push_back((bs) ? bs->cutoff : 0);
- } else {
- weights.push_back((bs) ? bs->records.size() : 0);
- }
-
- size_t total_weight = weights[0];
- for (auto &s : shard_states) {
- auto state = (State<R> *) s;
- total_weight += state->total_weight;
- weights.push_back(state->total_weight);
- }
-
- // if no valid records fall within the query range, just
- // set all of the sample sizes to 0 and bail out.
- if (total_weight == 0) {
- for (size_t i=0; i<shard_states.size(); i++) {
- auto state = (State<R> *) shard_states[i];
- state->sample_size = 0;
- }
-
- return;
- }
-
- std::vector<double> normalized_weights;
- for (auto w : weights) {
- normalized_weights.push_back((double) w / (double) total_weight);
- }
-
- bs->alias = new psudb::Alias(normalized_weights);
- }
+ static void distribute_query(Parameters *parms,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) {
- for (size_t i=0; i<p->sample_size; i++) {
- auto idx = bs->alias->get(p->rng);
- if (idx == 0) {
- buffer_sz++;
- } else {
- shard_sample_sizes[idx - 1]++;
- }
- }
+ std::vector<size_t> shard_sample_sizes(local_queries.size() + 1, 0);
+ size_t buffer_sz = 0;
- if (bs) {
- bs->sample_size = buffer_sz;
- }
- for (size_t i=0; i<shard_states.size(); i++) {
- auto state = (State<R> *) shard_states[i];
- state->sample_size = shard_sample_sizes[i+1];
- }
+ /* for simplicity of static structure testing */
+ if (!buffer_query) {
+ assert(local_queries.size() == 1);
+ local_queries[0]->sample_size =
+ local_queries[0]->global_parms.sample_size;
+ return;
}
- static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
- auto lower_key = ((Parms<R> *) parms)->lower_bound;
- auto upper_key = ((Parms<R> *) parms)->upper_bound;
- auto rng = ((Parms<R> *) parms)->rng;
-
- auto state = (State<R> *) q_state;
- auto sample_sz = state->sample_size;
-
- std::vector<Wrapped<R>> result_set;
-
- if (sample_sz == 0 || state->lower_bound == shard->get_record_count()) {
- return result_set;
+ /* we only need to build the shard alias on the first call */
+ if (buffer_query->alias == nullptr) {
+ std::vector<size_t> weights;
+ if constexpr (REJECTION) {
+ weights.push_back(buffer_query->cutoff);
+ } else {
+ weights.push_back(buffer_query->records.size());
+ }
+
+ size_t total_weight = weights[0];
+ for (auto &q : local_queries) {
+ total_weight += q->total_weight;
+ weights.push_back(q->total_weight);
+ }
+
+ /*
+ * if no valid records fall within the query range,
+ * set all of the sample sizes to 0 and bail out.
+ */
+ if (total_weight == 0) {
+ for (auto q : local_queries) {
+ q->sample_size = 0;
}
- size_t attempts = 0;
- size_t range_length = state->upper_bound - state->lower_bound;
- do {
- attempts++;
- size_t idx = (range_length > 0) ? gsl_rng_uniform_int(rng, range_length) : 0;
- result_set.emplace_back(*shard->get_record_at(state->lower_bound + idx));
- } while (attempts < sample_sz);
+ return;
+ }
- return result_set;
- }
+ std::vector<double> normalized_weights;
+ for (auto w : weights) {
+ normalized_weights.push_back((double)w / (double)total_weight);
+ }
- static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
- auto st = (BufferState<R> *) state;
- auto p = (Parms<R> *) parms;
+ buffer_query->alias = std::make_unique<psudb::Alias>(normalized_weights);
+ }
- std::vector<Wrapped<R>> result;
- result.reserve(st->sample_size);
+ for (size_t i = 0; i < parms->sample_size; i++) {
+ auto idx = buffer_query->alias->get(parms->rng);
+ if (idx == 0) {
+ buffer_sz++;
+ } else {
+ shard_sample_sizes[idx - 1]++;
+ }
+ }
- if constexpr (Rejection) {
- for (size_t i=0; i<st->sample_size; i++) {
- auto idx = gsl_rng_uniform_int(p->rng, st->cutoff);
- auto rec = st->buffer->get(idx);
+ if (buffer_query) {
+ buffer_query->sample_size = buffer_sz;
+ }
- if (rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) {
- result.emplace_back(*rec);
- }
- }
+ for (size_t i = 0; i < local_queries.size(); i++) {
+ local_queries[i]->sample_size = shard_sample_sizes[i];
+ }
+ }
- return result;
- }
+ static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) {
+ auto sample_sz = query->sample_size;
- for (size_t i=0; i<st->sample_size; i++) {
- auto idx = gsl_rng_uniform_int(p->rng, st->records.size());
- result.emplace_back(st->records[idx]);
- }
+ std::vector<LocalResultType> result_set;
- return result;
+ if (sample_sz == 0 || query->lower_idx == shard->get_record_count()) {
+ return result_set;
}
- static void merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) {
- for (size_t i=0; i<results.size(); i++) {
- for (size_t j=0; j<results[i].size(); j++) {
- output.emplace_back(results[i][j].rec);
- }
+ size_t attempts = 0;
+ size_t range_length = query->upper_idx - query->lower_idx;
+ do {
+ attempts++;
+ size_t idx =
+ (range_length > 0)
+ ? gsl_rng_uniform_int(query->global_parms.rng, range_length)
+ : 0;
+ result_set.emplace_back(*shard->get_record_at(query->lower_idx + idx));
+ } while (attempts < sample_sz);
+
+ return result_set;
+ }
+
+ static std::vector<LocalResultType>
+ local_query_buffer(LocalQueryBuffer *query) {
+ std::vector<LocalResultType> result;
+ result.reserve(query->sample_size);
+
+ if constexpr (REJECTION) {
+ for (size_t i = 0; i < query->sample_size; i++) {
+ auto idx = gsl_rng_uniform_int(query->global_parms.rng, query->cutoff);
+ auto rec = query->buffer->get(idx);
+
+ if (rec->rec.key >= query->global_parms.lower_bound &&
+ rec->rec.key <= query->global_parms.upper_bound) {
+ result.emplace_back(*rec);
}
- }
+ }
- static void delete_query_state(void *state) {
- auto s = (State<R> *) state;
- delete s;
+ return result;
}
- static void delete_buffer_query_state(void *state) {
- auto s = (BufferState<R> *) state;
- delete s;
+ for (size_t i = 0; i < query->sample_size; i++) {
+ auto idx =
+ gsl_rng_uniform_int(query->global_parms.rng, query->records.size());
+ result.emplace_back(query->records[idx]);
}
- static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) {
- auto p = (Parms<R> *) parms;
-
- if (results.size() < p->sample_size) {
- auto q = *p;
- q.sample_size -= results.size();
- process_query_states(&q, states, buffer_state);
- return true;
- }
+ return result;
+ }
- return false;
+ static void
+ combine(std::vector<std::vector<LocalResultType>> const &local_results,
+ Parameters *parms, std::vector<ResultType> &output) {
+ for (size_t i = 0; i < local_results.size(); i++) {
+ for (size_t j = 0; j < local_results[i].size(); j++) {
+ output.emplace_back(local_results[i][j].rec);
+ }
}
+ }
+
+ static bool repeat(Parameters *parms, std::vector<ResultType> &output,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) {
+ if (output.size() < parms->sample_size) {
+ parms->sample_size -= output.size();
+ distribute_query(parms, local_queries, buffer_query);
+ return true;
+ }
+
+ return false;
+ }
};
-}}
+} // namespace irs
+} // namespace de
diff --git a/include/query/knn.h b/include/query/knn.h
index a227293..87ea10a 100644
--- a/include/query/knn.h
+++ b/include/query/knn.h
@@ -6,7 +6,7 @@
* Distributed under the Modified BSD License.
*
* A query class for k-NN queries, designed for use with the VPTree
- * shard.
+ * shard.
*
* FIXME: no support for tombstone deletes just yet. This would require a
* query resumption mechanism, most likely.
@@ -16,147 +16,147 @@
#include "framework/QueryRequirements.h"
#include "psu-ds/PriorityQueue.h"
-namespace de { namespace knn {
+namespace de {
+namespace knn {
using psudb::PriorityQueue;
-template <NDRecordInterface R>
-struct Parms {
+template <ShardInterface S> class Query {
+ typedef typename S::RECORD R;
+
+public:
+ struct Parameters {
R point;
size_t k;
-};
+ };
-template <NDRecordInterface R>
-struct State {
- size_t k;
-};
+ struct LocalQuery {
+ Parameters global_parms;
+ };
-template <NDRecordInterface R>
-struct BufferState {
+ struct LocalQueryBuffer {
BufferView<R> *buffer;
+ Parameters global_parms;
+ };
- BufferState(BufferView<R> *buffer)
- : buffer(buffer) {}
-};
+ typedef Wrapped<R> LocalResultType;
+ typedef R ResultType;
+ constexpr static bool EARLY_ABORT = false;
+ constexpr static bool SKIP_DELETE_FILTER = true;
-template <NDRecordInterface R, ShardInterface<R> S>
-class Query {
-public:
- constexpr static bool EARLY_ABORT=false;
- constexpr static bool SKIP_DELETE_FILTER=true;
+ static LocalQuery *local_preproc(S *shard, Parameters *parms) {
+ auto query = new LocalQuery();
+ query->global_parms = *parms;
- static void *get_query_state(S *shard, void *parms) {
- return nullptr;
- }
+ return query;
+ }
- static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
- return new BufferState<R>(buffer);
- }
+ static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer,
+ Parameters *parms) {
+ auto query = new LocalQueryBuffer();
+ query->global_parms = *parms;
+ query->buffer = buffer;
- static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void* buffer_state) {
- return;
- }
+ return query;
+ }
- static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
- std::vector<Wrapped<R>> results;
- Parms<R> *p = (Parms<R> *) parms;
- Wrapped<R> wrec;
- wrec.rec = p->point;
- wrec.header = 0;
+ static void distribute_query(Parameters *parms,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) {
+ return;
+ }
- PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(p->k, &wrec);
+ static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) {
+ std::vector<LocalResultType> results;
- shard->search(p->point, p->k, pq);
+ Wrapped<R> wrec;
+ wrec.rec = query->global_parms.point;
+ wrec.header = 0;
- while (pq.size() > 0) {
- results.emplace_back(*pq.peek().data);
- pq.pop();
- }
+ PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(query->global_parms.k,
+ &wrec);
- return results;
+ shard->search(query->global_parms.point, query->global_parms.k, pq);
+
+ while (pq.size() > 0) {
+ results.emplace_back(*pq.peek().data);
+ pq.pop();
}
- static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
- Parms<R> *p = (Parms<R> *) parms;
- BufferState<R> *s = (BufferState<R> *) state;
- Wrapped<R> wrec;
- wrec.rec = p->point;
- wrec.header = 0;
-
- size_t k = p->k;
-
- PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(k, &wrec);
- for (size_t i=0; i<s->buffer->get_record_count(); i++) {
- // Skip over deleted records (under tagging)
- if (s->buffer->get(i)->is_deleted()) {
- continue;
- }
-
- if (pq.size() < k) {
- pq.push(s->buffer->get(i));
- } else {
- double head_dist = pq.peek().data->rec.calc_distance(wrec.rec);
- double cur_dist = (s->buffer->get(i))->rec.calc_distance(wrec.rec);
-
- if (cur_dist < head_dist) {
- pq.pop();
- pq.push(s->buffer->get(i));
- }
- }
- }
+ return results;
+ }
- std::vector<Wrapped<R>> results;
- while (pq.size() > 0) {
- results.emplace_back(*(pq.peek().data));
- pq.pop();
- }
+ static std::vector<LocalResultType>
+ local_query_buffer(LocalQueryBuffer *query) {
- return std::move(results);
- }
+ std::vector<LocalResultType> results;
- static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) {
- Parms<R> *p = (Parms<R> *) parms;
- R rec = p->point;
- size_t k = p->k;
-
- PriorityQueue<R, DistCmpMax<R>> pq(k, &rec);
- for (size_t i=0; i<results.size(); i++) {
- for (size_t j=0; j<results[i].size(); j++) {
- if (pq.size() < k) {
- pq.push(&results[i][j].rec);
- } else {
- double head_dist = pq.peek().data->calc_distance(rec);
- double cur_dist = results[i][j].rec.calc_distance(rec);
-
- if (cur_dist < head_dist) {
- pq.pop();
- pq.push(&results[i][j].rec);
- }
- }
- }
- }
+ Wrapped<R> wrec;
+ wrec.rec = query->global_parms.point;
+ wrec.header = 0;
- while (pq.size() > 0) {
- output.emplace_back(*pq.peek().data);
- pq.pop();
- }
+ PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(query->global_parms.k,
+ &wrec);
+
+ for (size_t i = 0; i < query->buffer->get_record_count(); i++) {
+ // Skip over deleted records (under tagging)
+ if (query->buffer->get(i)->is_deleted()) {
+ continue;
+ }
- return std::move(output);
+ if (pq.size() < query->global_parms.k) {
+ pq.push(query->buffer->get(i));
+ } else {
+ double head_dist = pq.peek().data->rec.calc_distance(wrec.rec);
+ double cur_dist = (query->buffer->get(i))->rec.calc_distance(wrec.rec);
+
+ if (cur_dist < head_dist) {
+ pq.pop();
+ pq.push(query->buffer->get(i));
+ }
+ }
}
- static void delete_query_state(void *state) {
- auto s = (State<R> *) state;
- delete s;
+ while (pq.size() > 0) {
+ results.emplace_back(*(pq.peek().data));
+ pq.pop();
}
- static void delete_buffer_query_state(void *state) {
- auto s = (BufferState<R> *) state;
- delete s;
+ return std::move(results);
+ }
+
+ static void
+ combine(std::vector<std::vector<LocalResultType>> const &local_results,
+ Parameters *parms, std::vector<ResultType> &output) {
+
+ PriorityQueue<R, DistCmpMax<R>> pq(parms->k, &(parms->point));
+ for (size_t i = 0; i < local_results.size(); i++) {
+ for (size_t j = 0; j < local_results[i].size(); j++) {
+ if (pq.size() < parms->k) {
+ pq.push(&local_results[i][j].rec);
+ } else {
+ double head_dist = pq.peek().data->calc_distance(parms->point);
+ double cur_dist = local_results[i][j].rec.calc_distance(parms->point);
+
+ if (cur_dist < head_dist) {
+ pq.pop();
+ pq.push(&local_results[i][j].rec);
+ }
+ }
+ }
}
- static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) {
- return false;
+ while (pq.size() > 0) {
+ output.emplace_back(*pq.peek().data);
+ pq.pop();
}
-};
+ }
-}}
+ static bool repeat(Parameters *parms, std::vector<ResultType> &output,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) {
+ return false;
+ }
+};
+} // namespace knn
+} // namespace de
diff --git a/include/query/pointlookup.h b/include/query/pointlookup.h
index 94c2bce..f3788de 100644
--- a/include/query/pointlookup.h
+++ b/include/query/pointlookup.h
@@ -18,106 +18,102 @@
#include "framework/QueryRequirements.h"
-namespace de { namespace pl {
+namespace de {
+namespace pl {
-template <RecordInterface R>
-struct Parms {
- decltype(R::key) search_key;
-};
+template <ShardInterface S> class Query {
+ typedef typename S::RECORD R;
-template <RecordInterface R>
-struct State {
-};
-
-template <RecordInterface R>
-struct BufferState {
- BufferView<R> *buffer;
-
- BufferState(BufferView<R> *buffer)
- : buffer(buffer) {}
-};
-
-template <KVPInterface R, ShardInterface<R> S>
-class Query {
public:
- constexpr static bool EARLY_ABORT=true;
- constexpr static bool SKIP_DELETE_FILTER=true;
-
- static void *get_query_state(S *shard, void *parms) {
- return nullptr;
- }
-
- static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
- auto res = new BufferState<R>(buffer);
+ struct Parameters {
+ decltype(R::key) search_key;
+ };
- return res;
- }
+ struct LocalQuery {
+ Parameters global_parms;
+ };
- static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void* buffer_state) {
- return;
+ struct LocalQueryBuffer {
+ BufferView<R> *buffer;
+ Parameters global_parms;
+ };
+
+ typedef Wrapped<R> LocalResultType;
+ typedef R ResultType;
+
+ constexpr static bool EARLY_ABORT = true;
+ constexpr static bool SKIP_DELETE_FILTER = true;
+
+ static LocalQuery *local_preproc(S *shard, Parameters *parms) {
+ auto query = new LocalQuery();
+ query->global_parms = *parms;
+ return query;
+ }
+
+ static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer,
+ Parameters *parms) {
+ auto query = new LocalQueryBuffer();
+ query->buffer = buffer;
+ query->global_parms = *parms;
+
+ return query;
+ }
+
+ static void distribute_query(Parameters *parms,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) {
+ return;
+ }
+
+ static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) {
+ std::vector<LocalResultType> result;
+
+ auto r = shard->point_lookup({query->global_parms.search_key, 0});
+
+ if (r) {
+ result.push_back(*r);
}
- static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
- auto p = (Parms<R> *) parms;
- auto s = (State<R> *) q_state;
-
- std::vector<Wrapped<R>> result;
-
- auto r = shard->point_lookup({p->search_key, 0});
+ return result;
+ }
+
+ static std::vector<LocalResultType>
+ local_query_buffer(LocalQueryBuffer *query) {
+ std::vector<LocalResultType> result;
- if (r) {
- result.push_back(*r);
- }
+ for (size_t i = 0; i < query->buffer->get_record_count(); i++) {
+ auto rec = query->buffer->get(i);
+ if (rec->rec.key == query->global_parms.search_key) {
+ result.push_back(*rec);
return result;
+ }
}
- static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
- auto p = (Parms<R> *) parms;
- auto s = (BufferState<R> *) state;
-
- std::vector<Wrapped<R>> records;
- for (size_t i=0; i<s->buffer->get_record_count(); i++) {
- auto rec = s->buffer->get(i);
-
- if (rec->rec.key == p->search_key) {
- records.push_back(*rec);
- return records;
- }
+ return result;
+ }
+
+
+ static void
+ combine(std::vector<std::vector<LocalResultType>> const &local_results,
+ Parameters *parms, std::vector<ResultType> &output) {
+ for (auto r : local_results) {
+ if (r.size() > 0) {
+ if (r[0].is_deleted() || r[0].is_tombstone()) {
+ return;
}
- return records;
- }
-
- static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) {
- for (auto r : results) {
- if (r.size() > 0) {
- if (r[0].is_deleted() || r[0].is_tombstone()) {
- return output;
- }
-
- output.push_back(r[0].rec);
- return output;
- }
- }
-
- return output;
- }
-
- static void delete_query_state(void *state) {
- auto s = (State<R> *) state;
- delete s;
- }
-
- static void delete_buffer_query_state(void *state) {
- auto s = (BufferState<R> *) state;
- delete s;
- }
-
-
- static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) {
- return false;
+ output.push_back(r[0].rec);
+ return;
+ }
}
+ }
+
+ static bool repeat(Parameters *parms, std::vector<ResultType> &output,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) {
+ return false;
+ }
};
-
-}}
+} // namespace pl
+} // namespace de
diff --git a/include/query/rangecount.h b/include/query/rangecount.h
index 5b95cdd..68d304d 100644
--- a/include/query/rangecount.h
+++ b/include/query/rangecount.h
@@ -5,169 +5,168 @@
*
* Distributed under the Modified BSD License.
*
- * A query class for single dimensional range count queries. This query
- * requires that the shard support get_lower_bound(key) and
+ * A query class for single dimensional range count queries. This query
+ * requires that the shard support get_lower_bound(key) and
* get_record_at(index).
*/
#pragma once
#include "framework/QueryRequirements.h"
-namespace de { namespace rc {
+namespace de {
+namespace rc {
-template <RecordInterface R>
-struct Parms {
+template <ShardInterface S, bool FORCE_SCAN = true> class Query {
+ typedef typename S::RECORD R;
+
+public:
+ struct Parameters {
decltype(R::key) lower_bound;
decltype(R::key) upper_bound;
-};
+ };
-template <RecordInterface R>
-struct State {
+ struct LocalQuery {
size_t start_idx;
size_t stop_idx;
-};
+ Parameters global_parms;
+ };
-template <RecordInterface R>
-struct BufferState {
+ struct LocalQueryBuffer {
BufferView<R> *buffer;
-
- BufferState(BufferView<R> *buffer)
- : buffer(buffer) {}
-};
-
-template <KVPInterface R, ShardInterface<R> S, bool FORCE_SCAN=false>
-class Query {
-public:
- constexpr static bool EARLY_ABORT=false;
- constexpr static bool SKIP_DELETE_FILTER=true;
-
- static void *get_query_state(S *shard, void *parms) {
- return nullptr;
- }
-
- static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
- auto res = new BufferState<R>(buffer);
-
- return res;
+ Parameters global_parms;
+ };
+
+ struct LocalResultType {
+ size_t record_count;
+ size_t tombstone_count;
+
+ bool is_deleted() {return false;}
+ bool is_tombstone() {return false;}
+ };
+
+ typedef size_t ResultType;
+ constexpr static bool EARLY_ABORT = false;
+ constexpr static bool SKIP_DELETE_FILTER = true;
+
+ static LocalQuery *local_preproc(S *shard, Parameters *parms) {
+ auto query = new LocalQuery();
+
+ query->start_idx = shard->get_lower_bound(parms->lower_bound);
+ query->stop_idx = shard->get_record_count();
+ query->global_parms.lower_bound = parms->lower_bound;
+ query->global_parms.upper_bound = parms->upper_bound;
+
+ return query;
+ }
+
+ static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer,
+ Parameters *parms) {
+ auto query = new LocalQueryBuffer();
+ query->buffer = buffer;
+ query->global_parms.lower_bound = parms->lower_bound;
+ query->global_parms.upper_bound = parms->upper_bound;
+
+ return query;
+ }
+
+ static void distribute_query(Parameters *parms,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) {
+ return;
+ }
+
+ static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) {
+ std::vector<LocalResultType> result;
+
+ /*
+ * if the returned index is one past the end of the
+ * records for the PGM, then there are not records
+ * in the index falling into the specified range.
+ */
+ if (query->start_idx == shard->get_record_count()) {
+ return result;
}
- static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void* buffer_state) {
- return;
+ auto ptr = shard->get_record_at(query->start_idx);
+ size_t reccnt = 0;
+ size_t tscnt = 0;
+
+ /*
+ * roll the pointer forward to the first record that is
+ * greater than or equal to the lower bound.
+ */
+ while (ptr < shard->get_data() + query->stop_idx &&
+ ptr->rec.key < query->global_parms.lower_bound) {
+ ptr++;
}
- static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
- std::vector<Wrapped<R>> records;
- auto p = (Parms<R> *) parms;
- auto s = (State<R> *) q_state;
-
- size_t reccnt = 0;
- size_t tscnt = 0;
-
- Wrapped<R> res;
- res.rec.key= 0; // records
- res.rec.value = 0; // tombstones
- records.emplace_back(res);
-
-
- auto start_idx = shard->get_lower_bound(p->lower_bound);
- auto stop_idx = shard->get_lower_bound(p->upper_bound);
+ while (ptr < shard->get_data() + query->stop_idx &&
+ ptr->rec.key <= query->global_parms.upper_bound) {
- /*
- * if the returned index is one past the end of the
- * records for the PGM, then there are not records
- * in the index falling into the specified range.
- */
- if (start_idx == shard->get_record_count()) {
- return records;
- }
-
-
- /*
- * roll the pointer forward to the first record that is
- * greater than or equal to the lower bound.
- */
- auto recs = shard->get_data();
- while(start_idx < stop_idx && recs[start_idx].rec.key < p->lower_bound) {
- start_idx++;
- }
-
- while (stop_idx < shard->get_record_count() && recs[stop_idx].rec.key <= p->upper_bound) {
- stop_idx++;
- }
- size_t idx = start_idx;
- size_t ts_cnt = 0;
+ if (!ptr->is_deleted()) {
+ reccnt++;
- while (idx < stop_idx) {
- ts_cnt += recs[idx].is_tombstone() * 2 + recs[idx].is_deleted();
- idx++;
+ if (ptr->is_tombstone()) {
+ tscnt++;
}
+ }
- records[0].rec.key = idx - start_idx;
- records[0].rec.value = ts_cnt;
-
- return records;
+ ptr++;
}
- static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
- auto p = (Parms<R> *) parms;
- auto s = (BufferState<R> *) state;
-
- std::vector<Wrapped<R>> records;
-
- Wrapped<R> res;
- res.rec.key= 0; // records
- res.rec.value = 0; // tombstones
- records.emplace_back(res);
-
- size_t stop_idx;
- if constexpr (FORCE_SCAN) {
- stop_idx = s->buffer->get_capacity() / 2;
- } else {
- stop_idx = s->buffer->get_record_count();
- }
-
- for (size_t i=0; i<s->buffer->get_record_count(); i++) {
- auto rec = s->buffer->get(i);
-
- if (rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound
- && !rec->is_deleted()) {
- if (rec->is_tombstone()) {
- records[0].rec.value++;
- } else {
- records[0].rec.key++;
- }
- }
+ result.push_back({reccnt, tscnt});
+ return result;
+ }
+
+ static std::vector<LocalResultType>
+ local_query_buffer(LocalQueryBuffer *query) {
+
+ std::vector<LocalResultType> result;
+ size_t reccnt = 0;
+ size_t tscnt = 0;
+ for (size_t i = 0; i < query->buffer->get_record_count(); i++) {
+ auto rec = query->buffer->get(i);
+ if (rec->rec.key >= query->global_parms.lower_bound &&
+ rec->rec.key <= query->global_parms.upper_bound) {
+ if (!rec->is_deleted()) {
+ reccnt++;
+ if (rec->is_tombstone()) {
+ tscnt++;
+ }
}
-
- return records;
+ }
}
- static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) {
- R res;
- res.key = 0;
- res.value = 0;
- output.emplace_back(res);
+ result.push_back({reccnt, tscnt});
- for (size_t i=0; i<results.size(); i++) {
- output[0].key += results[i][0].rec.key; // records
- output[0].value += results[i][0].rec.value; // tombstones
- }
+ return result;
+ }
- output[0].key -= output[0].value;
- return output;
- }
+ static void
+ combine(std::vector<std::vector<LocalResultType>> const &local_results,
+ Parameters *parms, std::vector<ResultType> &output) {
+ size_t reccnt = 0;
+ size_t tscnt = 0;
- static void delete_query_state(void *state) {
+ for (auto &local_result : local_results) {
+ reccnt += local_result[0].record_count;
+ tscnt += local_result[0].tombstone_count;
}
- static void delete_buffer_query_state(void *state) {
- auto s = (BufferState<R> *) state;
- delete s;
+ /* if more tombstones than results, clamp the output at 0 */
+ if (tscnt > reccnt) {
+ tscnt = reccnt;
}
- static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) {
- return false;
- }
+ output.push_back({reccnt - tscnt});
+ }
+
+ static bool repeat(Parameters *parms, std::vector<ResultType> &output,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) {
+ return false;
+ }
};
-}}
+} // namespace rc
+} // namespace de
diff --git a/include/query/rangequery.h b/include/query/rangequery.h
index e0690e6..e7be39c 100644
--- a/include/query/rangequery.h
+++ b/include/query/rangequery.h
@@ -1,177 +1,186 @@
/*
* include/query/rangequery.h
*
- * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu>
+ * Copyright (C) 2023-2024 Douglas B. Rumbaugh <drumbaugh@psu.edu>
*
* Distributed under the Modified BSD License.
*
- * A query class for single dimensional range queries. This query requires
+ * A query class for single dimensional range queries. This query requires
* that the shard support get_lower_bound(key) and get_record_at(index).
*/
#pragma once
#include "framework/QueryRequirements.h"
+#include "framework/interface/Record.h"
#include "psu-ds/PriorityQueue.h"
#include "util/Cursor.h"
-namespace de { namespace rq {
+namespace de {
+namespace rq {
-template <RecordInterface R>
-struct Parms {
+template <ShardInterface S> class Query {
+ typedef typename S::RECORD R;
+
+public:
+ struct Parameters {
decltype(R::key) lower_bound;
decltype(R::key) upper_bound;
-};
+ };
-template <RecordInterface R>
-struct State {
+ struct LocalQuery {
size_t start_idx;
size_t stop_idx;
-};
+ Parameters global_parms;
+ };
-template <RecordInterface R>
-struct BufferState {
+ struct LocalQueryBuffer {
BufferView<R> *buffer;
-
- BufferState(BufferView<R> *buffer)
- : buffer(buffer) {}
-};
-
-template <RecordInterface R, ShardInterface<R> S>
-class Query {
-public:
- constexpr static bool EARLY_ABORT=false;
- constexpr static bool SKIP_DELETE_FILTER=true;
-
- static void *get_query_state(S *shard, void *parms) {
- auto res = new State<R>();
- auto p = (Parms<R> *) parms;
-
- res->start_idx = shard->get_lower_bound(p->lower_bound);
- res->stop_idx = shard->get_record_count();
-
- return res;
+ Parameters global_parms;
+ };
+
+ typedef Wrapped<R> LocalResultType;
+ typedef R ResultType;
+
+ constexpr static bool EARLY_ABORT = false;
+ constexpr static bool SKIP_DELETE_FILTER = true;
+
+ static LocalQuery *local_preproc(S *shard, Parameters *parms) {
+ auto query = new LocalQuery();
+
+ query->start_idx = shard->get_lower_bound(parms->lower_bound);
+ query->stop_idx = shard->get_record_count();
+ query->global_parms = *parms;
+
+ return query;
+ }
+
+ static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer,
+ Parameters *parms) {
+ auto query = new LocalQueryBuffer();
+ query->buffer = buffer;
+ query->global_parms = *parms;
+
+ return query;
+ }
+
+ static void distribute_query(Parameters *parms,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) {
+ return;
+ }
+
+ static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) {
+ std::vector<LocalResultType> result;
+
+ /*
+ * if the returned index is one past the end of the
+ * records for the PGM, then there are not records
+ * in the index falling into the specified range.
+ */
+ if (query->start_idx == shard->get_record_count()) {
+ return result;
}
- static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
- auto res = new BufferState<R>(buffer);
+ auto ptr = shard->get_record_at(query->start_idx);
- return res;
+ /*
+ * roll the pointer forward to the first record that is
+ * greater than or equal to the lower bound.
+ */
+ while (ptr < shard->get_data() + query->stop_idx &&
+ ptr->rec.key < query->global_parms.lower_bound) {
+ ptr++;
}
- static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void* buffer_state) {
- return;
+ while (ptr < shard->get_data() + query->stop_idx &&
+ ptr->rec.key <= query->global_parms.upper_bound) {
+ result.emplace_back(*ptr);
+ ptr++;
}
- static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
- std::vector<Wrapped<R>> records;
- auto p = (Parms<R> *) parms;
- auto s = (State<R> *) q_state;
-
- /*
- * if the returned index is one past the end of the
- * records for the PGM, then there are not records
- * in the index falling into the specified range.
- */
- if (s->start_idx == shard->get_record_count()) {
- return records;
- }
-
- auto ptr = shard->get_record_at(s->start_idx);
-
- /*
- * roll the pointer forward to the first record that is
- * greater than or equal to the lower bound.
- */
- while(ptr < shard->get_data() + s->stop_idx && ptr->rec.key < p->lower_bound) {
- ptr++;
- }
-
- while (ptr < shard->get_data() + s->stop_idx && ptr->rec.key <= p->upper_bound) {
- records.emplace_back(*ptr);
- ptr++;
- }
-
- return records;
- }
+ return result;
+ }
- static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
- auto p = (Parms<R> *) parms;
- auto s = (BufferState<R> *) state;
+ static std::vector<LocalResultType>
+ local_query_buffer(LocalQueryBuffer *query) {
- std::vector<Wrapped<R>> records;
- for (size_t i=0; i<s->buffer->get_record_count(); i++) {
- auto rec = s->buffer->get(i);
- if (rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) {
- records.emplace_back(*rec);
- }
- }
-
- return records;
+ std::vector<LocalResultType> result;
+ for (size_t i = 0; i < query->buffer->get_record_count(); i++) {
+ auto rec = query->buffer->get(i);
+ if (rec->rec.key >= query->global_parms.lower_bound &&
+ rec->rec.key <= query->global_parms.upper_bound) {
+ result.emplace_back(*rec);
+ }
}
- static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) {
- std::vector<Cursor<Wrapped<R>>> cursors;
- cursors.reserve(results.size());
-
- psudb::PriorityQueue<Wrapped<R>> pq(results.size());
- size_t total = 0;
- size_t tmp_n = results.size();
-
-
- for (size_t i = 0; i < tmp_n; ++i)
- if (results[i].size() > 0){
- auto base = results[i].data();
- cursors.emplace_back(Cursor<Wrapped<R>>{base, base + results[i].size(), 0, results[i].size()});
- assert(i == cursors.size() - 1);
- total += results[i].size();
- pq.push(cursors[i].ptr, tmp_n - i - 1);
- } else {
- cursors.emplace_back(Cursor<Wrapped<R>>{nullptr, nullptr, 0, 0});
- }
-
- if (total == 0) {
- return std::vector<R>();
- }
-
- output.reserve(total);
-
- while (pq.size()) {
- auto now = pq.peek();
- auto next = pq.size() > 1 ? pq.peek(1) : psudb::queue_record<Wrapped<R>>{nullptr, 0};
- if (!now.data->is_tombstone() && next.data != nullptr &&
- now.data->rec == next.data->rec && next.data->is_tombstone()) {
-
- pq.pop(); pq.pop();
- auto& cursor1 = cursors[tmp_n - now.version - 1];
- auto& cursor2 = cursors[tmp_n - next.version - 1];
- if (advance_cursor<Wrapped<R>>(cursor1)) pq.push(cursor1.ptr, now.version);
- if (advance_cursor<Wrapped<R>>(cursor2)) pq.push(cursor2.ptr, next.version);
- } else {
- auto& cursor = cursors[tmp_n - now.version - 1];
- if (!now.data->is_tombstone()) output.push_back(cursor.ptr->rec);
-
- pq.pop();
-
- if (advance_cursor<Wrapped<R>>(cursor)) pq.push(cursor.ptr, now.version);
- }
- }
-
- return output;
+ return result;
+ }
+
+ static void
+ combine(std::vector<std::vector<LocalResultType>> const &local_results,
+ Parameters *parms, std::vector<ResultType> &output) {
+ std::vector<Cursor<LocalResultType>> cursors;
+ cursors.reserve(local_results.size());
+
+ psudb::PriorityQueue<LocalResultType> pq(local_results.size());
+ size_t total = 0;
+ size_t tmp_n = local_results.size();
+
+ for (size_t i = 0; i < tmp_n; ++i)
+ if (local_results[i].size() > 0) {
+ auto base = local_results[i].data();
+ cursors.emplace_back(Cursor<LocalResultType>{
+ base, base + local_results[i].size(), 0, local_results[i].size()});
+ assert(i == cursors.size() - 1);
+ total += local_results[i].size();
+ pq.push(cursors[i].ptr, tmp_n - i - 1);
+ } else {
+ cursors.emplace_back(Cursor<LocalResultType>{nullptr, nullptr, 0, 0});
+ }
+
+ if (total == 0) {
+ return;
}
- static void delete_query_state(void *state) {
- auto s = (State<R> *) state;
- delete s;
+ output.reserve(total);
+
+ while (pq.size()) {
+ auto now = pq.peek();
+ auto next = pq.size() > 1
+ ? pq.peek(1)
+ : psudb::queue_record<LocalResultType>{nullptr, 0};
+ if (!now.data->is_tombstone() && next.data != nullptr &&
+ now.data->rec == next.data->rec && next.data->is_tombstone()) {
+
+ pq.pop();
+ pq.pop();
+ auto &cursor1 = cursors[tmp_n - now.version - 1];
+ auto &cursor2 = cursors[tmp_n - next.version - 1];
+ if (advance_cursor<LocalResultType>(cursor1))
+ pq.push(cursor1.ptr, now.version);
+ if (advance_cursor<LocalResultType>(cursor2))
+ pq.push(cursor2.ptr, next.version);
+ } else {
+ auto &cursor = cursors[tmp_n - now.version - 1];
+ if (!now.data->is_tombstone())
+ output.push_back(cursor.ptr->rec);
+
+ pq.pop();
+
+ if (advance_cursor<LocalResultType>(cursor))
+ pq.push(cursor.ptr, now.version);
+ }
}
- static void delete_buffer_query_state(void *state) {
- auto s = (BufferState<R> *) state;
- delete s;
- }
+ return;
+ }
- static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) {
- return false;
- }
+ static bool repeat(Parameters *parms, std::vector<ResultType> &output,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) {
+ return false;
+ }
};
-}}
+} // namespace rq
+} // namespace de
diff --git a/include/query/wirs.h b/include/query/wirs.h
deleted file mode 100644
index 62b43f6..0000000
--- a/include/query/wirs.h
+++ /dev/null
@@ -1,251 +0,0 @@
-/*
- * include/query/wirs.h
- *
- * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu>
- *
- * Distributed under the Modified BSD License.
- *
- * A query class for weighted independent range sampling. This
- * class is tightly coupled with include/shard/AugBTree.h, and
- * so is probably of limited general utility.
- */
-#pragma once
-
-#include "framework/QueryRequirements.h"
-#include "psu-ds/Alias.h"
-
-namespace de { namespace wirs {
-
-template <WeightedRecordInterface R>
-struct Parms {
- decltype(R::key) lower_bound;
- decltype(R::key) upper_bound;
- size_t sample_size;
- gsl_rng *rng;
-};
-
-template <WeightedRecordInterface R>
-struct State {
- decltype(R::weight) total_weight;
- std::vector<void*> nodes;
- psudb::Alias* top_level_alias;
- size_t sample_size;
-
- State() {
- total_weight = 0;
- top_level_alias = nullptr;
- }
-
- ~State() {
- if (top_level_alias) delete top_level_alias;
- }
-};
-
-template <RecordInterface R>
-struct BufferState {
- size_t cutoff;
- psudb::Alias* alias;
- std::vector<Wrapped<R>> records;
- decltype(R::weight) max_weight;
- size_t sample_size;
- decltype(R::weight) total_weight;
- BufferView<R> *buffer;
-
- ~BufferState() {
- delete alias;
- }
-};
-
-template <RecordInterface R, ShardInterface<R> S, bool Rejection=true>
-class Query {
-public:
- constexpr static bool EARLY_ABORT=false;
- constexpr static bool SKIP_DELETE_FILTER=false;
-
- static void *get_query_state(S *shard, void *parms) {
- auto res = new State<R>();
- decltype(R::key) lower_key = ((Parms<R> *) parms)->lower_bound;
- decltype(R::key) upper_key = ((Parms<R> *) parms)->upper_bound;
-
- std::vector<decltype(R::weight)> weights;
- res->total_weight = shard->find_covering_nodes(lower_key, upper_key, res->nodes, weights);
-
- std::vector<double> normalized_weights;
- for (auto weight : weights) {
- normalized_weights.emplace_back(weight / res->total_weight);
- }
-
- res->top_level_alias = new psudb::Alias(normalized_weights);
- res->sample_size = 0;
-
- return res;
- }
-
- static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
- BufferState<R> *state = new BufferState<R>();
- auto parameters = (Parms<R>*) parms;
-
- if constexpr (Rejection) {
- state->cutoff = buffer->get_record_count() - 1;
- state->max_weight = buffer->get_max_weight();
- state->total_weight = buffer->get_total_weight();
- state->sample_size = 0;
- state->buffer = buffer;
- return state;
- }
-
- std::vector<decltype(R::weight)> weights;
-
- state->buffer = buffer;
- decltype(R::weight) total_weight = 0;
-
- for (size_t i = 0; i <= buffer->get_record_count(); i++) {
- auto rec = buffer->get(i);
-
- if (rec->rec.key >= parameters->lower_bound && rec->rec.key <= parameters->upper_bound && !rec->is_tombstone() && !rec->is_deleted()) {
- weights.push_back(rec->rec.weight);
- state->records.push_back(*rec);
- total_weight += rec->rec.weight;
- }
- }
-
- std::vector<double> normalized_weights;
- for (size_t i = 0; i < weights.size(); i++) {
- normalized_weights.push_back(weights[i] / total_weight);
- }
-
- state->total_weight = total_weight;
- state->alias = new psudb::Alias(normalized_weights);
- state->sample_size = 0;
-
- return state;
- }
-
- static void process_query_states(void *query_parms, std::vector<void*> &shard_states, std::vector<void*> &buffer_states) {
- auto p = (Parms<R> *) query_parms;
-
- std::vector<size_t> shard_sample_sizes(shard_states.size()+buffer_states.size(), 0);
- size_t buffer_sz = 0;
-
- std::vector<decltype(R::weight)> weights;
-
- decltype(R::weight) total_weight = 0;
- for (auto &s : buffer_states) {
- auto bs = (BufferState<R> *) s;
- total_weight += bs->total_weight;
- weights.push_back(bs->total_weight);
- }
-
- for (auto &s : shard_states) {
- auto state = (State<R> *) s;
- total_weight += state->total_weight;
- weights.push_back(state->total_weight);
- }
-
- std::vector<double> normalized_weights;
- for (auto w : weights) {
- normalized_weights.push_back((double) w / (double) total_weight);
- }
-
- auto shard_alias = psudb::Alias(normalized_weights);
- for (size_t i=0; i<p->sample_size; i++) {
- auto idx = shard_alias.get(p->rng);
-
- if (idx < buffer_states.size()) {
- auto state = (BufferState<R> *) buffer_states[idx];
- state->sample_size++;
- } else {
- auto state = (State<R> *) shard_states[idx - buffer_states.size()];
- state->sample_size++;
- }
- }
- }
-
- static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
- auto lower_key = ((Parms<R> *) parms)->lower_bound;
- auto upper_key = ((Parms<R> *) parms)->upper_bound;
- auto rng = ((Parms<R> *) parms)->rng;
-
- auto state = (State<R> *) q_state;
- auto sample_size = state->sample_size;
-
- std::vector<Wrapped<R>> result_set;
-
- if (sample_size == 0) {
- return result_set;
- }
- size_t cnt = 0;
- size_t attempts = 0;
-
- for (size_t i=0; i<sample_size; i++) {
- auto rec = shard->get_weighted_sample(lower_key, upper_key,
- state->nodes[state->top_level_alias->get(rng)],
- rng);
- if (rec) {
- result_set.emplace_back(*rec);
- }
- }
-
- return result_set;
- }
-
- static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
- auto st = (BufferState<R> *) state;
- auto p = (Parms<R> *) parms;
- auto buffer = st->buffer;
-
- std::vector<Wrapped<R>> result;
- result.reserve(st->sample_size);
-
- if constexpr (Rejection) {
- for (size_t i=0; i<st->sample_size; i++) {
- auto idx = gsl_rng_uniform_int(p->rng, st->cutoff);
- auto rec = buffer->get(idx);
-
- auto test = gsl_rng_uniform(p->rng) * st->max_weight;
-
- if (test <= rec->rec.weight && rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) {
- result.emplace_back(*rec);
- }
- }
- return result;
- }
-
- for (size_t i=0; i<st->sample_size; i++) {
- auto idx = st->alias->get(p->rng);
- result.emplace_back(st->records[idx]);
- }
-
- return result;
- }
-
- static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) {
- for (size_t i=0; i<results.size(); i++) {
- for (size_t j=0; j<results[i].size(); j++) {
- output.emplace_back(results[i][j].rec);
- }
- }
-
- return output;
- }
-
- static void delete_query_state(void *state) {
- auto s = (State<R> *) state;
- delete s;
- }
-
- static void delete_buffer_query_state(void *state) {
- auto s = (BufferState<R> *) state;
- delete s;
- }
-
- static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) {
- auto p = (Parms<R> *) parms;
-
- if (results.size() < p->sample_size) {
- return true;
- }
- return false;
- }
-};
-}}
diff --git a/include/query/wss.h b/include/query/wss.h
index fb0b414..54620ca 100644
--- a/include/query/wss.h
+++ b/include/query/wss.h
@@ -6,7 +6,7 @@
* Distributed under the Modified BSD License.
*
* A query class for weighted set sampling. This
- * class is tightly coupled with include/shard/Alias.h,
+ * class is tightly coupled with include/shard/Alias.h,
* and so is probably of limited general utility.
*/
#pragma once
@@ -14,203 +14,177 @@
#include "framework/QueryRequirements.h"
#include "psu-ds/Alias.h"
-namespace de { namespace wss {
+namespace de {
+namespace wss {
-template <WeightedRecordInterface R>
-struct Parms {
+template <ShardInterface S> class Query {
+ typedef typename S::RECORD R;
+
+public:
+ struct Parameters {
size_t sample_size;
gsl_rng *rng;
-};
+ };
-template <WeightedRecordInterface R>
-struct State {
- decltype(R::weight) total_weight;
+ struct LocalQuery {
size_t sample_size;
+ decltype(R::weight) total_weight;
- State() {
- total_weight = 0;
- }
-};
+ Parameters global_parms;
+ };
+
+ struct LocalQueryBuffer {
+ BufferView<R> *buffer;
-template <RecordInterface R>
-struct BufferState {
- size_t cutoff;
size_t sample_size;
- psudb::Alias *alias;
- decltype(R::weight) max_weight;
decltype(R::weight) total_weight;
- BufferView<R> *buffer;
+ decltype(R::weight) max_weight;
+ size_t cutoff;
- ~BufferState() {
- delete alias;
- }
-};
+ std::unique_ptr<psudb::Alias> alias;
-template <RecordInterface R, ShardInterface<R> S, bool Rejection=true>
-class Query {
-public:
- constexpr static bool EARLY_ABORT=false;
- constexpr static bool SKIP_DELETE_FILTER=false;
+ Parameters global_parms;
+ };
- static void *get_query_state(S *shard, void *parms) {
- auto res = new State<R>();
- res->total_weight = shard->get_total_weight();
- res->sample_size = 0;
+ constexpr static bool EARLY_ABORT = false;
+ constexpr static bool SKIP_DELETE_FILTER = false;
- return res;
- }
+ typedef Wrapped<R> LocalResultType;
+ typedef R ResultType;
- static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
- BufferState<R> *state = new BufferState<R>();
- auto parameters = (Parms<R>*) parms;
- if constexpr (Rejection) {
- state->cutoff = buffer->get_record_count() - 1;
- state->max_weight = buffer->get_max_weight();
- state->total_weight = buffer->get_total_weight();
- state->buffer = buffer;
- return state;
- }
+ static LocalQuery *local_preproc(S *shard, Parameters *parms) {
+ auto query = new LocalQuery();
- std::vector<double> weights;
+ query->global_parms = *parms;
+ query->total_weight = shard->get_total_weight();
+ query->sample_size = 0;
- double total_weight = 0.0;
- state->buffer = buffer;
+ return query;
+ }
- for (size_t i = 0; i <= buffer->get_record_count(); i++) {
- auto rec = buffer->get_data(i);
- weights.push_back(rec->rec.weight);
- total_weight += rec->rec.weight;
- }
+ static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer,
+ Parameters *parms) {
+ auto query = new LocalQueryBuffer();
- for (size_t i = 0; i < weights.size(); i++) {
- weights[i] = weights[i] / total_weight;
- }
+ query->cutoff = buffer->get_record_count() - 1;
- state->alias = new psudb::Alias(weights);
- state->total_weight = total_weight;
+ query->max_weight = 0;
+ query->total_weight = 0;
- return state;
- }
+ for (size_t i = 0; i < buffer->get_record_count(); i++) {
+ auto weight = buffer->get(i)->rec.weight;
+ query->total_weight += weight;
- static void process_query_states(void *query_parms, std::vector<void*> &shard_states, std::vector<void*> &buffer_states) {
- auto p = (Parms<R> *) query_parms;
-
- std::vector<size_t> shard_sample_sizes(shard_states.size()+buffer_states.size(), 0);
- size_t buffer_sz = 0;
-
- std::vector<decltype(R::weight)> weights;
-
- decltype(R::weight) total_weight = 0;
- for (auto &s : buffer_states) {
- auto bs = (BufferState<R> *) s;
- total_weight += bs->total_weight;
- weights.push_back(bs->total_weight);
- }
-
- for (auto &s : shard_states) {
- auto state = (State<R> *) s;
- total_weight += state->total_weight;
- weights.push_back(state->total_weight);
- }
-
- std::vector<double> normalized_weights;
- for (auto w : weights) {
- normalized_weights.push_back((double) w / (double) total_weight);
- }
-
- auto shard_alias = psudb::Alias(normalized_weights);
- for (size_t i=0; i<p->sample_size; i++) {
- auto idx = shard_alias.get(p->rng);
-
- if (idx < buffer_states.size()) {
- auto state = (BufferState<R> *) buffer_states[idx];
- state->sample_size++;
- } else {
- auto state = (State<R> *) shard_states[idx - buffer_states.size()];
- state->sample_size++;
- }
- }
+ if (weight > query->max_weight) {
+ query->max_weight = weight;
+ }
}
- static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
- auto rng = ((Parms<R> *) parms)->rng;
+ query->buffer = buffer;
+ query->global_parms = *parms;
- auto state = (State<R> *) q_state;
- auto sample_size = state->sample_size;
+ query->alias = nullptr;
- std::vector<Wrapped<R>> result_set;
+ return query;
+ }
- if (sample_size == 0) {
- return result_set;
- }
- size_t attempts = 0;
- do {
- attempts++;
- size_t idx = shard->get_weighted_sample(rng);
- result_set.emplace_back(*shard->get_record_at(idx));
- } while (attempts < sample_size);
+ static void distribute_query(Parameters *parms,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) {
- return result_set;
+ if (!buffer_query) {
+ assert(local_queries.size() == 1);
+ local_queries[0]->sample_size =
+ local_queries[0]->global_parms.sample_size;
+ return;
}
- static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
- auto st = (BufferState<R> *) state;
- auto p = (Parms<R> *) parms;
- auto buffer = st->buffer;
+ if (!buffer_query->alias) {
+ std::vector<decltype(R::weight)> weights;
- std::vector<Wrapped<R>> result;
- result.reserve(st->sample_size);
+ decltype(R::weight) total_weight = buffer_query->total_weight;
+ weights.push_back(total_weight);
- if constexpr (Rejection) {
- for (size_t i=0; i<st->sample_size; i++) {
- auto idx = gsl_rng_uniform_int(p->rng, st->cutoff);
- auto rec = buffer->get(idx);
+ for (auto &q : local_queries) {
+ total_weight += q->total_weight;
+ weights.push_back(q->total_weight);
+ q->sample_size = 0;
+ }
- auto test = gsl_rng_uniform(p->rng) * st->max_weight;
+ std::vector<double> normalized_weights;
+ for (auto w : weights) {
+ normalized_weights.push_back((double)w / (double)total_weight);
+ }
- if (test <= rec->rec.weight) {
- result.emplace_back(*rec);
- }
- }
- return result;
- }
+ buffer_query->alias = std::make_unique<psudb::Alias>(normalized_weights);
+ }
- for (size_t i=0; i<st->sample_size; i++) {
- auto idx = st->alias->get(p->rng);
- result.emplace_back(*(buffer->get_data() + idx));
- }
+ for (size_t i = 0; i < parms->sample_size; i++) {
+ auto idx = buffer_query->alias->get(parms->rng);
- return result;
+ if (idx == 0) {
+ buffer_query->sample_size++;
+ } else {
+ local_queries[idx - 1]->sample_size++;
+ }
}
+ }
- static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) {
- for (size_t i=0; i<results.size(); i++) {
- for (size_t j=0; j<results[i].size(); j++) {
- output.emplace_back(results[i][j].rec);
- }
- }
+ static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) {
+ std::vector<LocalResultType> result;
- return output;
+ if (query->sample_size == 0) {
+ return result;
}
- static void delete_query_state(void *state) {
- auto s = (State<R> *) state;
- delete s;
+ for (size_t i = 0; i < query->sample_size; i++) {
+ size_t idx = shard->get_weighted_sample(query->global_parms.rng);
+ if (!shard->get_record_at(idx)->is_deleted()) {
+ result.emplace_back(*shard->get_record_at(idx));
+ }
}
- static void delete_buffer_query_state(void *state) {
- auto s = (BufferState<R> *) state;
- delete s;
+ return result;
+ }
+
+ static std::vector<LocalResultType>
+ local_query_buffer(LocalQueryBuffer *query) {
+ std::vector<LocalResultType> result;
+
+ for (size_t i = 0; i < query->sample_size; i++) {
+ auto idx = gsl_rng_uniform_int(query->global_parms.rng, query->cutoff);
+ auto rec = query->buffer->get(idx);
+
+ auto test = gsl_rng_uniform(query->global_parms.rng) * query->max_weight;
+ if (test <= rec->rec.weight && !rec->is_deleted()) {
+ result.emplace_back(*rec);
+ }
}
- static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) {
- auto p = (Parms<R> *) parms;
+ return result;
+ }
- if (results.size() < p->sample_size) {
- return true;
- }
- return false;
+ static void
+ combine(std::vector<std::vector<LocalResultType>> const &local_results,
+ Parameters *parms, std::vector<ResultType> &output) {
+ for (size_t i = 0; i < local_results.size(); i++) {
+ for (size_t j = 0; j < local_results[i].size(); j++) {
+ output.emplace_back(local_results[i][j].rec);
+ }
+ }
+ }
+
+ static bool repeat(Parameters *parms, std::vector<ResultType> &output,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) {
+ if (output.size() < parms->sample_size) {
+ parms->sample_size -= output.size();
+ distribute_query(parms, local_queries, buffer_query);
+ return true;
}
-};
-}}
+ return false;
+ }
+};
+} // namespace wss
+} // namespace de