summary | refs | log | tree | commit | diff | stats
path: root/include/query/wss.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/query/wss.h')
-rw-r--r-- include/query/wss.h 282
1 files changed, 128 insertions, 154 deletions
diff --git a/include/query/wss.h b/include/query/wss.h
index fb0b414..54620ca 100644
--- a/include/query/wss.h
+++ b/include/query/wss.h
@@ -6,7 +6,7 @@
* Distributed under the Modified BSD License.
*
* A query class for weighted set sampling. This
- * class is tightly coupled with include/shard/Alias.h,
+ * class is tightly coupled with include/shard/Alias.h,
* and so is probably of limited general utility.
*/
#pragma once
@@ -14,203 +14,177 @@
#include "framework/QueryRequirements.h"
#include "psu-ds/Alias.h"
-namespace de { namespace wss {
+namespace de {
+namespace wss {
-template <WeightedRecordInterface R>
-struct Parms {
+template <ShardInterface S> class Query {
+ typedef typename S::RECORD R;
+
+public:
+ struct Parameters { // Query arguments: requested sample count plus the GSL rng used for the random draws in this class.
size_t sample_size;
gsl_rng *rng;
-};
+ };
-template <WeightedRecordInterface R>
-struct State {
- decltype(R::weight) total_weight;
+ struct LocalQuery { // Per-shard query state; allocated and filled by local_preproc().
size_t sample_size;
+ decltype(R::weight) total_weight; // Set from shard->get_total_weight() in local_preproc().
- State() {
- total_weight = 0;
- }
-};
+ Parameters global_parms; // By-value copy of the caller's Parameters.
+ };
+
+ struct LocalQueryBuffer { // Buffer-side query state; allocated and filled by local_preproc_buffer().
+ BufferView<R> *buffer; // Stored from the caller's pointer; not freed anywhere in the visible code.
-template <RecordInterface R>
-struct BufferState {
- size_t cutoff;
size_t sample_size;
- psudb::Alias *alias;
- decltype(R::weight) max_weight;
decltype(R::weight) total_weight;
- BufferView<R> *buffer;
+ decltype(R::weight) max_weight; // Largest record weight found in the buffer; scales the acceptance test in local_query_buffer().
+ size_t cutoff; // Bound handed to gsl_rng_uniform_int() when picking a buffer index.
- ~BufferState() {
- delete alias;
- }
-};
+ std::unique_ptr<psudb::Alias> alias; // Built on first use by distribute_query() over the buffer and shard weights.
-template <RecordInterface R, ShardInterface<R> S, bool Rejection=true>
-class Query {
-public:
- constexpr static bool EARLY_ABORT=false;
- constexpr static bool SKIP_DELETE_FILTER=false;
+ Parameters global_parms; // By-value copy of the caller's Parameters.
+ };
- static void *get_query_state(S *shard, void *parms) {
- auto res = new State<R>();
- res->total_weight = shard->get_total_weight();
- res->sample_size = 0;
+ constexpr static bool EARLY_ABORT = false; // NOTE(review): not read in the visible code; presumably consumed by the query framework — confirm.
+ constexpr static bool SKIP_DELETE_FILTER = false; // NOTE(review): same as above; meaning is defined outside this view — confirm.
- return res;
- }
+ typedef Wrapped<R> LocalResultType; // Element type returned by local_query() and local_query_buffer().
+ typedef R ResultType; // Element type of the output vector filled by combine().
- static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
- BufferState<R> *state = new BufferState<R>();
- auto parameters = (Parms<R>*) parms;
- if constexpr (Rejection) {
- state->cutoff = buffer->get_record_count() - 1;
- state->max_weight = buffer->get_max_weight();
- state->total_weight = buffer->get_total_weight();
- state->buffer = buffer;
- return state;
- }
+ static LocalQuery *local_preproc(S *shard, Parameters *parms) { // Builds the per-shard state; returns a new-allocated LocalQuery (its release is not shown in this view).
+ auto query = new LocalQuery();
- std::vector<double> weights;
+ query->global_parms = *parms; // Copy by value: later writes to *parms are not visible through this struct.
+ query->total_weight = shard->get_total_weight();
+ query->sample_size = 0; // The shard's share is assigned later by distribute_query().
- double total_weight = 0.0;
- state->buffer = buffer;
+ return query;
+ }
- for (size_t i = 0; i <= buffer->get_record_count(); i++) {
- auto rec = buffer->get_data(i);
- weights.push_back(rec->rec.weight);
- total_weight += rec->rec.weight;
- }
+ static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer,
+ Parameters *parms) { // Builds the buffer-side state in one pass over the buffer view; returns a new-allocated LocalQueryBuffer.
+ auto query = new LocalQueryBuffer();
- for (size_t i = 0; i < weights.size(); i++) {
- weights[i] = weights[i] / total_weight;
- }
+ query->cutoff = buffer->get_record_count() - 1; // NOTE(review): cutoff is size_t, so a record count of 0 wraps to a huge value — confirm callers never pass an empty buffer.
- state->alias = new psudb::Alias(weights);
- state->total_weight = total_weight;
+ query->max_weight = 0;
+ query->total_weight = 0;
- return state;
- }
+ for (size_t i = 0; i < buffer->get_record_count(); i++) { // Accumulate the total weight and track the maximum weight.
+ auto weight = buffer->get(i)->rec.weight;
+ query->total_weight += weight;
- static void process_query_states(void *query_parms, std::vector<void*> &shard_states, std::vector<void*> &buffer_states) {
- auto p = (Parms<R> *) query_parms;
-
- std::vector<size_t> shard_sample_sizes(shard_states.size()+buffer_states.size(), 0);
- size_t buffer_sz = 0;
-
- std::vector<decltype(R::weight)> weights;
-
- decltype(R::weight) total_weight = 0;
- for (auto &s : buffer_states) {
- auto bs = (BufferState<R> *) s;
- total_weight += bs->total_weight;
- weights.push_back(bs->total_weight);
- }
-
- for (auto &s : shard_states) {
- auto state = (State<R> *) s;
- total_weight += state->total_weight;
- weights.push_back(state->total_weight);
- }
-
- std::vector<double> normalized_weights;
- for (auto w : weights) {
- normalized_weights.push_back((double) w / (double) total_weight);
- }
-
- auto shard_alias = psudb::Alias(normalized_weights);
- for (size_t i=0; i<p->sample_size; i++) {
- auto idx = shard_alias.get(p->rng);
-
- if (idx < buffer_states.size()) {
- auto state = (BufferState<R> *) buffer_states[idx];
- state->sample_size++;
- } else {
- auto state = (State<R> *) shard_states[idx - buffer_states.size()];
- state->sample_size++;
- }
- }
+ if (weight > query->max_weight) {
+ query->max_weight = weight;
+ }
}
- static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
- auto rng = ((Parms<R> *) parms)->rng;
+ query->buffer = buffer;
+ query->global_parms = *parms; // By-value copy of the caller's Parameters.
- auto state = (State<R> *) q_state;
- auto sample_size = state->sample_size;
+ query->alias = nullptr; // The alias table is created later, in distribute_query().
- std::vector<Wrapped<R>> result_set;
+ return query;
+ }
- if (sample_size == 0) {
- return result_set;
- }
- size_t attempts = 0;
- do {
- attempts++;
- size_t idx = shard->get_weighted_sample(rng);
- result_set.emplace_back(*shard->get_record_at(idx));
- } while (attempts < sample_size);
+ static void distribute_query(Parameters *parms,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) { // Splits parms->sample_size draws between the buffer and the shards in proportion to their total weights.
- return result_set;
+ if (!buffer_query) { // No buffer query: the single local query takes the whole request.
+ assert(local_queries.size() == 1);
+ local_queries[0]->sample_size =
+ local_queries[0]->global_parms.sample_size; // NOTE(review): reads the copy taken in local_preproc(), not parms, so repeat()'s decrement is not reflected here — confirm intended.
+ return;
}
- static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
- auto st = (BufferState<R> *) state;
- auto p = (Parms<R> *) parms;
- auto buffer = st->buffer;
+ if (!buffer_query->alias) { // The alias table is built once and reused on later calls.
+ std::vector<decltype(R::weight)> weights;
- std::vector<Wrapped<R>> result;
- result.reserve(st->sample_size);
+ decltype(R::weight) total_weight = buffer_query->total_weight;
+ weights.push_back(total_weight); // Slot 0 is the buffer; local_queries[i] occupies slot i + 1.
- if constexpr (Rejection) {
- for (size_t i=0; i<st->sample_size; i++) {
- auto idx = gsl_rng_uniform_int(p->rng, st->cutoff);
- auto rec = buffer->get(idx);
+ for (auto &q : local_queries) {
+ total_weight += q->total_weight;
+ weights.push_back(q->total_weight);
+ q->sample_size = 0; // NOTE(review): counters are zeroed only while building the alias; later calls add to the previous counts, and buffer_query->sample_size is never reset here — confirm.
+ }
- auto test = gsl_rng_uniform(p->rng) * st->max_weight;
+ std::vector<double> normalized_weights;
+ for (auto w : weights) {
+ normalized_weights.push_back((double)w / (double)total_weight);
+ }
- if (test <= rec->rec.weight) {
- result.emplace_back(*rec);
- }
- }
- return result;
- }
+ buffer_query->alias = std::make_unique<psudb::Alias>(normalized_weights);
+ }
- for (size_t i=0; i<st->sample_size; i++) {
- auto idx = st->alias->get(p->rng);
- result.emplace_back(*(buffer->get_data() + idx));
- }
+ for (size_t i = 0; i < parms->sample_size; i++) { // One alias draw per requested sample; each draw increments the chosen slot's counter.
+ auto idx = buffer_query->alias->get(parms->rng);
- return result;
+ if (idx == 0) {
+ buffer_query->sample_size++;
+ } else {
+ local_queries[idx - 1]->sample_size++;
+ }
}
+ }
- static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) {
- for (size_t i=0; i<results.size(); i++) {
- for (size_t j=0; j<results[i].size(); j++) {
- output.emplace_back(results[i][j].rec);
- }
- }
+ static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) { // Makes query->sample_size draws from one shard via get_weighted_sample().
+ std::vector<LocalResultType> result;
- return output;
+ if (query->sample_size == 0) {
+ return result;
}
- static void delete_query_state(void *state) {
- auto s = (State<R> *) state;
- delete s;
+ for (size_t i = 0; i < query->sample_size; i++) {
+ size_t idx = shard->get_weighted_sample(query->global_parms.rng);
+ if (!shard->get_record_at(idx)->is_deleted()) { // Deleted records are dropped, so fewer than sample_size results may come back.
+ result.emplace_back(*shard->get_record_at(idx));
+ }
}
- static void delete_buffer_query_state(void *state) {
- auto s = (BufferState<R> *) state;
- delete s;
+ return result;
+ }
+
+ static std::vector<LocalResultType>
+ local_query_buffer(LocalQueryBuffer *query) { // Rejection-samples the buffer: uniform index, then accept when a uniform draw scaled by max_weight is <= the record's weight.
+ std::vector<LocalResultType> result;
+
+ for (size_t i = 0; i < query->sample_size; i++) {
+ auto idx = gsl_rng_uniform_int(query->global_parms.rng, query->cutoff); // NOTE(review): GSL documents gsl_rng_uniform_int(r, n) as returning [0, n-1]; with cutoff = record count - 1 the last record looks unreachable and a 1-record buffer gives n = 0 — confirm.
+ auto rec = query->buffer->get(idx);
+
+ auto test = gsl_rng_uniform(query->global_parms.rng) * query->max_weight;
+ if (test <= rec->rec.weight && !rec->is_deleted()) { // Rejected or deleted picks are skipped, so the result may hold fewer than sample_size records.
+ result.emplace_back(*rec);
+ }
}
- static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) {
- auto p = (Parms<R> *) parms;
+ return result;
+ }
- if (results.size() < p->sample_size) {
- return true;
- }
- return false;
+ static void
+ combine(std::vector<std::vector<LocalResultType>> const &local_results,
+ Parameters *parms, std::vector<ResultType> &output) { // Appends the .rec of every local result to output, in order; parms is unused.
+ for (size_t i = 0; i < local_results.size(); i++) {
+ for (size_t j = 0; j < local_results[i].size(); j++) {
+ output.emplace_back(local_results[i][j].rec);
+ }
+ }
+ }
+
+ static bool repeat(Parameters *parms, std::vector<ResultType> &output,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) { // Returns true, after re-running distribute_query(), while output holds fewer than parms->sample_size records; false otherwise.
+ if (output.size() < parms->sample_size) {
+ parms->sample_size -= output.size(); // NOTE(review): mutates the caller's Parameters; if output accumulates across rounds, earlier records are subtracted again on the next round — confirm.
+ distribute_query(parms, local_queries, buffer_query);
+ return true;
}
-};
-}}
+ return false;
+ }
+};
+} // namespace wss
+} // namespace de