summaryrefslogtreecommitdiffstats
path: root/include/query/wirs.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/query/wirs.h')
-rw-r--r--include/query/wirs.h251
1 files changed, 0 insertions, 251 deletions
diff --git a/include/query/wirs.h b/include/query/wirs.h
deleted file mode 100644
index 62b43f6..0000000
--- a/include/query/wirs.h
+++ /dev/null
@@ -1,251 +0,0 @@
-/*
- * include/query/wirs.h
- *
- * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu>
- *
- * Distributed under the Modified BSD License.
- *
- * A query class for weighted independent range sampling. This
- * class is tightly coupled with include/shard/AugBTree.h, and
- * so is probably of limited general utility.
- */
-#pragma once
-
-#include "framework/QueryRequirements.h"
-#include "psu-ds/Alias.h"
-
-namespace de { namespace wirs {
-
-template <WeightedRecordInterface R>
-struct Parms {
- decltype(R::key) lower_bound;
- decltype(R::key) upper_bound;
- size_t sample_size;
- gsl_rng *rng;
-};
-
-template <WeightedRecordInterface R>
-struct State {
- decltype(R::weight) total_weight;
- std::vector<void*> nodes;
- psudb::Alias* top_level_alias;
- size_t sample_size;
-
- State() {
- total_weight = 0;
- top_level_alias = nullptr;
- }
-
- ~State() {
- if (top_level_alias) delete top_level_alias;
- }
-};
-
-template <RecordInterface R>
-struct BufferState {
- size_t cutoff;
- psudb::Alias* alias;
- std::vector<Wrapped<R>> records;
- decltype(R::weight) max_weight;
- size_t sample_size;
- decltype(R::weight) total_weight;
- BufferView<R> *buffer;
-
- ~BufferState() {
- delete alias;
- }
-};
-
-template <RecordInterface R, ShardInterface<R> S, bool Rejection=true>
-class Query {
-public:
- constexpr static bool EARLY_ABORT=false;
- constexpr static bool SKIP_DELETE_FILTER=false;
-
- static void *get_query_state(S *shard, void *parms) {
- auto res = new State<R>();
- decltype(R::key) lower_key = ((Parms<R> *) parms)->lower_bound;
- decltype(R::key) upper_key = ((Parms<R> *) parms)->upper_bound;
-
- std::vector<decltype(R::weight)> weights;
- res->total_weight = shard->find_covering_nodes(lower_key, upper_key, res->nodes, weights);
-
- std::vector<double> normalized_weights;
- for (auto weight : weights) {
- normalized_weights.emplace_back(weight / res->total_weight);
- }
-
- res->top_level_alias = new psudb::Alias(normalized_weights);
- res->sample_size = 0;
-
- return res;
- }
-
- static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
- BufferState<R> *state = new BufferState<R>();
- auto parameters = (Parms<R>*) parms;
-
- if constexpr (Rejection) {
- state->cutoff = buffer->get_record_count() - 1;
- state->max_weight = buffer->get_max_weight();
- state->total_weight = buffer->get_total_weight();
- state->sample_size = 0;
- state->buffer = buffer;
- return state;
- }
-
- std::vector<decltype(R::weight)> weights;
-
- state->buffer = buffer;
- decltype(R::weight) total_weight = 0;
-
- for (size_t i = 0; i <= buffer->get_record_count(); i++) {
- auto rec = buffer->get(i);
-
- if (rec->rec.key >= parameters->lower_bound && rec->rec.key <= parameters->upper_bound && !rec->is_tombstone() && !rec->is_deleted()) {
- weights.push_back(rec->rec.weight);
- state->records.push_back(*rec);
- total_weight += rec->rec.weight;
- }
- }
-
- std::vector<double> normalized_weights;
- for (size_t i = 0; i < weights.size(); i++) {
- normalized_weights.push_back(weights[i] / total_weight);
- }
-
- state->total_weight = total_weight;
- state->alias = new psudb::Alias(normalized_weights);
- state->sample_size = 0;
-
- return state;
- }
-
- static void process_query_states(void *query_parms, std::vector<void*> &shard_states, std::vector<void*> &buffer_states) {
- auto p = (Parms<R> *) query_parms;
-
- std::vector<size_t> shard_sample_sizes(shard_states.size()+buffer_states.size(), 0);
- size_t buffer_sz = 0;
-
- std::vector<decltype(R::weight)> weights;
-
- decltype(R::weight) total_weight = 0;
- for (auto &s : buffer_states) {
- auto bs = (BufferState<R> *) s;
- total_weight += bs->total_weight;
- weights.push_back(bs->total_weight);
- }
-
- for (auto &s : shard_states) {
- auto state = (State<R> *) s;
- total_weight += state->total_weight;
- weights.push_back(state->total_weight);
- }
-
- std::vector<double> normalized_weights;
- for (auto w : weights) {
- normalized_weights.push_back((double) w / (double) total_weight);
- }
-
- auto shard_alias = psudb::Alias(normalized_weights);
- for (size_t i=0; i<p->sample_size; i++) {
- auto idx = shard_alias.get(p->rng);
-
- if (idx < buffer_states.size()) {
- auto state = (BufferState<R> *) buffer_states[idx];
- state->sample_size++;
- } else {
- auto state = (State<R> *) shard_states[idx - buffer_states.size()];
- state->sample_size++;
- }
- }
- }
-
- static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
- auto lower_key = ((Parms<R> *) parms)->lower_bound;
- auto upper_key = ((Parms<R> *) parms)->upper_bound;
- auto rng = ((Parms<R> *) parms)->rng;
-
- auto state = (State<R> *) q_state;
- auto sample_size = state->sample_size;
-
- std::vector<Wrapped<R>> result_set;
-
- if (sample_size == 0) {
- return result_set;
- }
- size_t cnt = 0;
- size_t attempts = 0;
-
- for (size_t i=0; i<sample_size; i++) {
- auto rec = shard->get_weighted_sample(lower_key, upper_key,
- state->nodes[state->top_level_alias->get(rng)],
- rng);
- if (rec) {
- result_set.emplace_back(*rec);
- }
- }
-
- return result_set;
- }
-
- static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
- auto st = (BufferState<R> *) state;
- auto p = (Parms<R> *) parms;
- auto buffer = st->buffer;
-
- std::vector<Wrapped<R>> result;
- result.reserve(st->sample_size);
-
- if constexpr (Rejection) {
- for (size_t i=0; i<st->sample_size; i++) {
- auto idx = gsl_rng_uniform_int(p->rng, st->cutoff);
- auto rec = buffer->get(idx);
-
- auto test = gsl_rng_uniform(p->rng) * st->max_weight;
-
- if (test <= rec->rec.weight && rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) {
- result.emplace_back(*rec);
- }
- }
- return result;
- }
-
- for (size_t i=0; i<st->sample_size; i++) {
- auto idx = st->alias->get(p->rng);
- result.emplace_back(st->records[idx]);
- }
-
- return result;
- }
-
- static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) {
- for (size_t i=0; i<results.size(); i++) {
- for (size_t j=0; j<results[i].size(); j++) {
- output.emplace_back(results[i][j].rec);
- }
- }
-
- return output;
- }
-
- static void delete_query_state(void *state) {
- auto s = (State<R> *) state;
- delete s;
- }
-
- static void delete_buffer_query_state(void *state) {
- auto s = (BufferState<R> *) state;
- delete s;
- }
-
- static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) {
- auto p = (Parms<R> *) parms;
-
- if (results.size() < p->sample_size) {
- return true;
- }
- return false;
- }
-};
-}}