diff options
Diffstat (limited to 'include/query/wss.h')
| -rw-r--r-- | include/query/wss.h | 282 |
1 files changed, 128 insertions, 154 deletions
diff --git a/include/query/wss.h b/include/query/wss.h index fb0b414..54620ca 100644 --- a/include/query/wss.h +++ b/include/query/wss.h @@ -6,7 +6,7 @@ * Distributed under the Modified BSD License. * * A query class for weighted set sampling. This - * class is tightly coupled with include/shard/Alias.h, + * class is tightly coupled with include/shard/Alias.h, * and so is probably of limited general utility. */ #pragma once @@ -14,203 +14,177 @@ #include "framework/QueryRequirements.h" #include "psu-ds/Alias.h" -namespace de { namespace wss { +namespace de { +namespace wss { -template <WeightedRecordInterface R> -struct Parms { +template <ShardInterface S> class Query { + typedef typename S::RECORD R; + +public: + struct Parameters { size_t sample_size; gsl_rng *rng; -}; + }; -template <WeightedRecordInterface R> -struct State { - decltype(R::weight) total_weight; + struct LocalQuery { size_t sample_size; + decltype(R::weight) total_weight; - State() { - total_weight = 0; - } -}; + Parameters global_parms; + }; + + struct LocalQueryBuffer { + BufferView<R> *buffer; -template <RecordInterface R> -struct BufferState { - size_t cutoff; size_t sample_size; - psudb::Alias *alias; - decltype(R::weight) max_weight; decltype(R::weight) total_weight; - BufferView<R> *buffer; + decltype(R::weight) max_weight; + size_t cutoff; - ~BufferState() { - delete alias; - } -}; + std::unique_ptr<psudb::Alias> alias; -template <RecordInterface R, ShardInterface<R> S, bool Rejection=true> -class Query { -public: - constexpr static bool EARLY_ABORT=false; - constexpr static bool SKIP_DELETE_FILTER=false; + Parameters global_parms; + }; - static void *get_query_state(S *shard, void *parms) { - auto res = new State<R>(); - res->total_weight = shard->get_total_weight(); - res->sample_size = 0; + constexpr static bool EARLY_ABORT = false; + constexpr static bool SKIP_DELETE_FILTER = false; - return res; - } + typedef Wrapped<R> LocalResultType; + typedef R ResultType; - static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) { - BufferState<R> *state = new BufferState<R>(); - auto parameters = (Parms<R>*) parms; - if constexpr (Rejection) { - state->cutoff = buffer->get_record_count() - 1; - state->max_weight = buffer->get_max_weight(); - state->total_weight = buffer->get_total_weight(); - state->buffer = buffer; - return state; - } + static LocalQuery *local_preproc(S *shard, Parameters *parms) { + auto query = new LocalQuery(); - std::vector<double> weights; + query->global_parms = *parms; + query->total_weight = shard->get_total_weight(); + query->sample_size = 0; - double total_weight = 0.0; - state->buffer = buffer; + return query; + } - for (size_t i = 0; i <= buffer->get_record_count(); i++) { - auto rec = buffer->get_data(i); - weights.push_back(rec->rec.weight); - total_weight += rec->rec.weight; - } + static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer, + Parameters *parms) { + auto query = new LocalQueryBuffer(); - for (size_t i = 0; i < weights.size(); i++) { - weights[i] = weights[i] / total_weight; - } + query->cutoff = buffer->get_record_count() - 1; - state->alias = new psudb::Alias(weights); - state->total_weight = total_weight; + query->max_weight = 0; + query->total_weight = 0; - return state; - } + for (size_t i = 0; i < buffer->get_record_count(); i++) { + auto weight = buffer->get(i)->rec.weight; + query->total_weight += weight; - static void process_query_states(void *query_parms, std::vector<void*> &shard_states, std::vector<void*> &buffer_states) { - auto p = (Parms<R> *) query_parms; - - std::vector<size_t> shard_sample_sizes(shard_states.size()+buffer_states.size(), 0); - size_t buffer_sz = 0; - - std::vector<decltype(R::weight)> weights; - - decltype(R::weight) total_weight = 0; - for (auto &s : buffer_states) { - auto bs = (BufferState<R> *) s; - total_weight += bs->total_weight; - weights.push_back(bs->total_weight); - } - - for (auto &s : shard_states) { - auto state = (State<R> *) s; - total_weight += state->total_weight; - weights.push_back(state->total_weight); - } - - std::vector<double> normalized_weights; - for (auto w : weights) { - normalized_weights.push_back((double) w / (double) total_weight); - } - - auto shard_alias = psudb::Alias(normalized_weights); - for (size_t i=0; i<p->sample_size; i++) { - auto idx = shard_alias.get(p->rng); - - if (idx < buffer_states.size()) { - auto state = (BufferState<R> *) buffer_states[idx]; - state->sample_size++; - } else { - auto state = (State<R> *) shard_states[idx - buffer_states.size()]; - state->sample_size++; - } - } + if (weight > query->max_weight) { + query->max_weight = weight; + } } - static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) { - auto rng = ((Parms<R> *) parms)->rng; + query->buffer = buffer; + query->global_parms = *parms; - auto state = (State<R> *) q_state; - auto sample_size = state->sample_size; + query->alias = nullptr; - std::vector<Wrapped<R>> result_set; + return query; + } - if (sample_size == 0) { - return result_set; - } - size_t attempts = 0; - do { - attempts++; - size_t idx = shard->get_weighted_sample(rng); - result_set.emplace_back(*shard->get_record_at(idx)); - } while (attempts < sample_size); + static void distribute_query(Parameters *parms, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { - return result_set; + if (!buffer_query) { + assert(local_queries.size() == 1); + local_queries[0]->sample_size = + local_queries[0]->global_parms.sample_size; + return; } - static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) { - auto st = (BufferState<R> *) state; - auto p = (Parms<R> *) parms; - auto buffer = st->buffer; + if (!buffer_query->alias) { + std::vector<decltype(R::weight)> weights; - std::vector<Wrapped<R>> result; - result.reserve(st->sample_size); + decltype(R::weight) total_weight = buffer_query->total_weight; + weights.push_back(total_weight); - if constexpr (Rejection) { - for (size_t i=0; i<st->sample_size; i++) { - auto idx = gsl_rng_uniform_int(p->rng, st->cutoff); - auto rec = buffer->get(idx); + for (auto &q : local_queries) { + total_weight += q->total_weight; + weights.push_back(q->total_weight); + q->sample_size = 0; + } - auto test = gsl_rng_uniform(p->rng) * st->max_weight; + std::vector<double> normalized_weights; + for (auto w : weights) { + normalized_weights.push_back((double)w / (double)total_weight); + } - if (test <= rec->rec.weight) { - result.emplace_back(*rec); - } - } - return result; - } + buffer_query->alias = std::make_unique<psudb::Alias>(normalized_weights); + } - for (size_t i=0; i<st->sample_size; i++) { - auto idx = st->alias->get(p->rng); - result.emplace_back(*(buffer->get_data() + idx)); - } + for (size_t i = 0; i < parms->sample_size; i++) { + auto idx = buffer_query->alias->get(parms->rng); - return result; + if (idx == 0) { + buffer_query->sample_size++; + } else { + local_queries[idx - 1]->sample_size++; + } } + } - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { - for (size_t i=0; i<results.size(); i++) { - for (size_t j=0; j<results[i].size(); j++) { - output.emplace_back(results[i][j].rec); - } - } + static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) { + std::vector<LocalResultType> result; - return output; + if (query->sample_size == 0) { + return result; } - static void delete_query_state(void *state) { - auto s = (State<R> *) state; - delete s; + for (size_t i = 0; i < query->sample_size; i++) { + size_t idx = shard->get_weighted_sample(query->global_parms.rng); + if (!shard->get_record_at(idx)->is_deleted()) { + result.emplace_back(*shard->get_record_at(idx)); + } } - static void delete_buffer_query_state(void *state) { - auto s = (BufferState<R> *) state; - delete s; + return result; + } + + static std::vector<LocalResultType> + local_query_buffer(LocalQueryBuffer *query) { + std::vector<LocalResultType> result; + + for (size_t i = 0; i < query->sample_size; i++) { + auto idx = gsl_rng_uniform_int(query->global_parms.rng, query->cutoff); + auto rec = query->buffer->get(idx); + + auto test = gsl_rng_uniform(query->global_parms.rng) * query->max_weight; + if (test <= rec->rec.weight && !rec->is_deleted()) { + result.emplace_back(*rec); + } } - static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { - auto p = (Parms<R> *) parms; + return result; + } - if (results.size() < p->sample_size) { - return true; - } - return false; + static void + combine(std::vector<std::vector<LocalResultType>> const &local_results, + Parameters *parms, std::vector<ResultType> &output) { + for (size_t i = 0; i < local_results.size(); i++) { + for (size_t j = 0; j < local_results[i].size(); j++) { + output.emplace_back(local_results[i][j].rec); + } + } + } + + static bool repeat(Parameters *parms, std::vector<ResultType> &output, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { + if (output.size() < parms->sample_size) { + parms->sample_size -= output.size(); + distribute_query(parms, local_queries, buffer_query); + return true; } -}; -}} + return false; + } +}; +} // namespace wss +} // namespace de |