summary | refs | log | tree | commit | diff | stats
path: root/include/query/irs.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/query/irs.h')
-rw-r--r-- include/query/irs.h 360
1 files changed, 175 insertions, 185 deletions
diff --git a/include/query/irs.h b/include/query/irs.h
index 879d070..6dec850 100644
--- a/include/query/irs.h
+++ b/include/query/irs.h
@@ -1,12 +1,12 @@
/*
* include/query/irs.h
*
- * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu>
+ * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu>
*
* Distributed under the Modified BSD License.
*
- * A query class for independent range sampling. This query requires
- * that the shard support get_lower_bound(key), get_upper_bound(key),
+ * A query class for independent range sampling. This query requires
+ * that the shard support get_lower_bound(key), get_upper_bound(key),
* and get_record_at(index).
*/
#pragma once
@@ -14,237 +14,227 @@
#include "framework/QueryRequirements.h"
#include "psu-ds/Alias.h"
-namespace de { namespace irs {
+namespace de {
+namespace irs {
-template <RecordInterface R>
-struct Parms {
+template <ShardInterface S, bool REJECTION = true> class Query {
+ typedef typename S::RECORD R;
+
+public:
+ struct Parameters {
decltype(R::key) lower_bound;
decltype(R::key) upper_bound;
size_t sample_size;
gsl_rng *rng;
-};
+ };
-
-template <RecordInterface R>
-struct State {
- size_t lower_bound;
- size_t upper_bound;
- size_t sample_size;
+ struct LocalQuery {
+ size_t lower_idx;
+ size_t upper_idx;
size_t total_weight;
-};
+ size_t sample_size;
+ Parameters global_parms;
+ };
+
+ struct LocalQueryBuffer {
+ BufferView<R> *buffer;
-template <RecordInterface R>
-struct BufferState {
size_t cutoff;
std::vector<Wrapped<R>> records;
+ std::unique_ptr<psudb::Alias> alias;
size_t sample_size;
- BufferView<R> *buffer;
- psudb::Alias *alias;
+ Parameters global_parms;
+ };
- BufferState(BufferView<R> *buffer) : buffer(buffer) {}
- ~BufferState() {
- delete alias;
- }
-};
+ typedef Wrapped<R> LocalResultType;
+ typedef R ResultType;
-template <RecordInterface R, ShardInterface<R> S, bool Rejection=true>
-class Query {
-public:
- constexpr static bool EARLY_ABORT=false;
- constexpr static bool SKIP_DELETE_FILTER=false;
+ constexpr static bool EARLY_ABORT = false;
+ constexpr static bool SKIP_DELETE_FILTER = false;
- static void *get_query_state(S *shard, void *parms) {
- auto res = new State<R>();
- decltype(R::key) lower_key = ((Parms<R> *) parms)->lower_bound;
- decltype(R::key) upper_key = ((Parms<R> *) parms)->upper_bound;
+ static LocalQuery *local_preproc(S *shard, Parameters *parms) {
+ auto query = new LocalQuery();
- res->lower_bound = shard->get_lower_bound(lower_key);
- res->upper_bound = shard->get_upper_bound(upper_key);
+ query->global_parms = *parms;
- if (res->lower_bound == shard->get_record_count()) {
- res->total_weight = 0;
- } else {
- res->total_weight = res->upper_bound - res->lower_bound;
- }
+ query->lower_idx = shard->get_lower_bound(query->global_parms.lower_bound);
+ query->upper_idx = shard->get_upper_bound(query->global_parms.upper_bound);
- res->sample_size = 0;
- return res;
+ if (query->lower_idx == shard->get_record_count()) {
+ query->total_weight = 0;
+ } else {
+ query->total_weight = query->upper_idx - query->lower_idx;
}
- static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
- auto res = new BufferState<R>(buffer);
-
- res->cutoff = res->buffer->get_record_count();
- res->sample_size = 0;
- res->alias = nullptr;
+ query->sample_size = 0;
+ return query;
+ }
- if constexpr (Rejection) {
- return res;
- }
-
- auto lower_key = ((Parms<R> *) parms)->lower_bound;
- auto upper_key = ((Parms<R> *) parms)->upper_bound;
+ static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer,
+ Parameters *parms) {
+ auto query = new LocalQueryBuffer();
+ query->buffer = buffer;
- for (size_t i=0; i<res->cutoff; i++) {
- if ((res->buffer->get(i)->rec.key >= lower_key) && (buffer->get(i)->rec.key <= upper_key)) {
- res->records.emplace_back(*(res->buffer->get(i)));
- }
- }
+ query->cutoff = query->buffer->get_record_count();
+ query->sample_size = 0;
+ query->alias = nullptr;
+ query->global_parms = *parms;
- return res;
+ if constexpr (REJECTION) {
+ return query;
}
- static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void *buffer_state) {
- auto p = (Parms<R> *) query_parms;
- auto bs = (buffer_state) ? (BufferState<R> *) buffer_state : nullptr;
-
- std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0);
- size_t buffer_sz = 0;
+ for (size_t i = 0; i < query->cutoff; i++) {
+ if ((query->buffer->get(i)->rec.key >= query->global_parms.lower_bound) &&
+ (buffer->get(i)->rec.key <= query->global_parms.upper_bound)) {
+ query->records.emplace_back(*(query->buffer->get(i)));
+ }
+ }
- /* for simplicity of static structure testing */
- if (!bs) {
- assert(shard_states.size() == 1);
- auto state = (State<R> *) shard_states[0];
- state->sample_size = p->sample_size;
- return;
- }
+ return query;
+ }
- /* we only need to build the shard alias on the first call */
- if (bs->alias == nullptr) {
- std::vector<size_t> weights;
- if constexpr (Rejection) {
- weights.push_back((bs) ? bs->cutoff : 0);
- } else {
- weights.push_back((bs) ? bs->records.size() : 0);
- }
-
- size_t total_weight = weights[0];
- for (auto &s : shard_states) {
- auto state = (State<R> *) s;
- total_weight += state->total_weight;
- weights.push_back(state->total_weight);
- }
-
- // if no valid records fall within the query range, just
- // set all of the sample sizes to 0 and bail out.
- if (total_weight == 0) {
- for (size_t i=0; i<shard_states.size(); i++) {
- auto state = (State<R> *) shard_states[i];
- state->sample_size = 0;
- }
-
- return;
- }
-
- std::vector<double> normalized_weights;
- for (auto w : weights) {
- normalized_weights.push_back((double) w / (double) total_weight);
- }
-
- bs->alias = new psudb::Alias(normalized_weights);
- }
+ static void distribute_query(Parameters *parms,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) {
- for (size_t i=0; i<p->sample_size; i++) {
- auto idx = bs->alias->get(p->rng);
- if (idx == 0) {
- buffer_sz++;
- } else {
- shard_sample_sizes[idx - 1]++;
- }
- }
+ std::vector<size_t> shard_sample_sizes(local_queries.size() + 1, 0);
+ size_t buffer_sz = 0;
- if (bs) {
- bs->sample_size = buffer_sz;
- }
- for (size_t i=0; i<shard_states.size(); i++) {
- auto state = (State<R> *) shard_states[i];
- state->sample_size = shard_sample_sizes[i+1];
- }
+ /* for simplicity of static structure testing */
+ if (!buffer_query) {
+ assert(local_queries.size() == 1);
+ local_queries[0]->sample_size =
+ local_queries[0]->global_parms.sample_size;
+ return;
}
- static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
- auto lower_key = ((Parms<R> *) parms)->lower_bound;
- auto upper_key = ((Parms<R> *) parms)->upper_bound;
- auto rng = ((Parms<R> *) parms)->rng;
-
- auto state = (State<R> *) q_state;
- auto sample_sz = state->sample_size;
-
- std::vector<Wrapped<R>> result_set;
-
- if (sample_sz == 0 || state->lower_bound == shard->get_record_count()) {
- return result_set;
+ /* we only need to build the shard alias on the first call */
+ if (buffer_query->alias == nullptr) {
+ std::vector<size_t> weights;
+ if constexpr (REJECTION) {
+ weights.push_back(buffer_query->cutoff);
+ } else {
+ weights.push_back(buffer_query->records.size());
+ }
+
+ size_t total_weight = weights[0];
+ for (auto &q : local_queries) {
+ total_weight += q->total_weight;
+ weights.push_back(q->total_weight);
+ }
+
+ /*
+ * if no valid records fall within the query range,
+ * set all of the sample sizes to 0 and bail out.
+ */
+ if (total_weight == 0) {
+ for (auto q : local_queries) {
+ q->sample_size = 0;
}
- size_t attempts = 0;
- size_t range_length = state->upper_bound - state->lower_bound;
- do {
- attempts++;
- size_t idx = (range_length > 0) ? gsl_rng_uniform_int(rng, range_length) : 0;
- result_set.emplace_back(*shard->get_record_at(state->lower_bound + idx));
- } while (attempts < sample_sz);
+ return;
+ }
- return result_set;
- }
+ std::vector<double> normalized_weights;
+ for (auto w : weights) {
+ normalized_weights.push_back((double)w / (double)total_weight);
+ }
- static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
- auto st = (BufferState<R> *) state;
- auto p = (Parms<R> *) parms;
+ buffer_query->alias = std::make_unique<psudb::Alias>(normalized_weights);
+ }
- std::vector<Wrapped<R>> result;
- result.reserve(st->sample_size);
+ for (size_t i = 0; i < parms->sample_size; i++) {
+ auto idx = buffer_query->alias->get(parms->rng);
+ if (idx == 0) {
+ buffer_sz++;
+ } else {
+ shard_sample_sizes[idx - 1]++;
+ }
+ }
- if constexpr (Rejection) {
- for (size_t i=0; i<st->sample_size; i++) {
- auto idx = gsl_rng_uniform_int(p->rng, st->cutoff);
- auto rec = st->buffer->get(idx);
+ if (buffer_query) {
+ buffer_query->sample_size = buffer_sz;
+ }
- if (rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) {
- result.emplace_back(*rec);
- }
- }
+ for (size_t i = 0; i < local_queries.size(); i++) {
+ local_queries[i]->sample_size = shard_sample_sizes[i];
+ }
+ }
- return result;
- }
+ static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) {
+ auto sample_sz = query->sample_size;
- for (size_t i=0; i<st->sample_size; i++) {
- auto idx = gsl_rng_uniform_int(p->rng, st->records.size());
- result.emplace_back(st->records[idx]);
- }
+ std::vector<LocalResultType> result_set;
- return result;
+ if (sample_sz == 0 || query->lower_idx == shard->get_record_count()) {
+ return result_set;
}
- static void merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) {
- for (size_t i=0; i<results.size(); i++) {
- for (size_t j=0; j<results[i].size(); j++) {
- output.emplace_back(results[i][j].rec);
- }
+ size_t attempts = 0;
+ size_t range_length = query->upper_idx - query->lower_idx;
+ do {
+ attempts++;
+ size_t idx =
+ (range_length > 0)
+ ? gsl_rng_uniform_int(query->global_parms.rng, range_length)
+ : 0;
+ result_set.emplace_back(*shard->get_record_at(query->lower_idx + idx));
+ } while (attempts < sample_sz);
+
+ return result_set;
+ }
+
+ static std::vector<LocalResultType>
+ local_query_buffer(LocalQueryBuffer *query) {
+ std::vector<LocalResultType> result;
+ result.reserve(query->sample_size);
+
+ if constexpr (REJECTION) {
+ for (size_t i = 0; i < query->sample_size; i++) {
+ auto idx = gsl_rng_uniform_int(query->global_parms.rng, query->cutoff);
+ auto rec = query->buffer->get(idx);
+
+ if (rec->rec.key >= query->global_parms.lower_bound &&
+ rec->rec.key <= query->global_parms.upper_bound) {
+ result.emplace_back(*rec);
}
- }
+ }
- static void delete_query_state(void *state) {
- auto s = (State<R> *) state;
- delete s;
+ return result;
}
- static void delete_buffer_query_state(void *state) {
- auto s = (BufferState<R> *) state;
- delete s;
+ for (size_t i = 0; i < query->sample_size; i++) {
+ auto idx =
+ gsl_rng_uniform_int(query->global_parms.rng, query->records.size());
+ result.emplace_back(query->records[idx]);
}
- static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) {
- auto p = (Parms<R> *) parms;
-
- if (results.size() < p->sample_size) {
- auto q = *p;
- q.sample_size -= results.size();
- process_query_states(&q, states, buffer_state);
- return true;
- }
+ return result;
+ }
- return false;
+  /*
+   * Flatten the per-source local results into the final output vector,
+   * unwrapping each Wrapped<R> down to its underlying record. Records
+   * are appended in source order; output is not cleared first. The
+   * parms argument is not used by this query type.
+   */
+  static void
+  combine(std::vector<std::vector<LocalResultType>> const &local_results,
+          Parameters *parms, std::vector<ResultType> &output) {
+    for (auto const &source_results : local_results) {
+      for (auto const &wrapped : source_results) {
+        output.emplace_back(wrapped.rec);
+      }
    }
+  }
+
+  /*
+   * Decide whether another sampling round is needed. If output holds
+   * fewer records than parms->sample_size, the shortfall is
+   * redistributed across the buffer and shards via distribute_query
+   * and true is returned; otherwise false is returned.
+   *
+   * The shortfall is written into a local copy of the parameters, not
+   * into *parms. combine() only appends to output, so (assuming the
+   * caller keeps accumulating into the same vector across rounds)
+   * parms->sample_size must keep holding the full requested size for
+   * the comparison above to stay meaningful. Shrinking it in place
+   * would also silently alter the caller's query parameters.
+   */
+  static bool repeat(Parameters *parms, std::vector<ResultType> &output,
+                     std::vector<LocalQuery *> const &local_queries,
+                     LocalQueryBuffer *buffer_query) {
+    if (output.size() < parms->sample_size) {
+      Parameters remaining = *parms;
+      remaining.sample_size -= output.size();
+      distribute_query(&remaining, local_queries, buffer_query);
+      return true;
+    }
+
+    return false;
+  }
};
-}}
+} // namespace irs
+} // namespace de