diff options
| author | Douglas B. Rumbaugh <dbr4@psu.edu> | 2024-12-06 13:13:51 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-12-06 18:13:51 +0000 |
| commit | 9fe305c7d28e993e55c55427f377ae7e3251ea4f (patch) | |
| tree | 384b687f64b84eb81bde2becac8a5f24916b07b4 /include/query | |
| parent | 47916da2ba5ed5bee2dda3cbcc58d39e1e931bfc (diff) | |
| download | dynamic-extension-9fe305c7d28e993e55c55427f377ae7e3251ea4f.tar.gz | |
Interface update (#5)
* Query Interface Adjustments/Refactoring
Began the process of adjusting the query interface (and also the shard
interface, to a lesser degree) to better accommodate the user. In
particular the following changes have been made,
1. The number of necessary template arguments for the query type
has been drastically reduced, while also removing the void pointers
and manual delete functions from the interface.
This was accomplished by requiring many of the sub-types associated
with a query (parameters, etc.) to be nested inside the main query
class, and by forcing the SHARD type to expose its associated
record type.
2. User-defined query return types are now supported.
Queries no longer are required to return strictly sets of records.
Instead, the query now has LocalResultType and ResultType
template parameters (which can be defaulted using a typedef in
the Query type itself), allowing much more flexibility.
Note that, at least for the short term, the LocalResultType must
still expose the same is_deleted/is_tombstone interface as a
Wrapped<R> used to, as this is currently needed for delete
filtering. A better approach to this is, hopefully, forthcoming.
3. Updated the ISAMTree.h shard and rangequery.h query to use the
new interfaces, and adjusted the associated unit tests as well.
4. Dropped the unnecessary "get_data()" function from the ShardInterface
concept.
5. Dropped the need to specify a record type in the ShardInterface
concept. This is now handled using a required Shard::RECORD
member of the Shard class itself, which should expose the name
of the record type.
* Updates to framework to support new Query/Shard interfaces
Pretty extensive adjustments to the framework, particularly to the
templates themselves, along with some type-renaming work, to support
the new query and shard interfaces.
Adjusted the external query interface to take an rvalue reference, rather
than a pointer, to the query parameters.
* Removed framework-level delete filtering
This was causing some issues with the new query interface, and should
probably be reworked anyway, so I'm temporarily (TM) removing the
feature.
* Updated benchmarks + remaining code for new interface
Diffstat (limited to 'include/query')
| -rw-r--r-- | include/query/irs.h | 360 | ||||
| -rw-r--r-- | include/query/knn.h | 224 | ||||
| -rw-r--r-- | include/query/pointlookup.h | 170 | ||||
| -rw-r--r-- | include/query/rangecount.h | 259 | ||||
| -rw-r--r-- | include/query/rangequery.h | 283 | ||||
| -rw-r--r-- | include/query/wirs.h | 251 | ||||
| -rw-r--r-- | include/query/wss.h | 282 |
7 files changed, 773 insertions, 1056 deletions
diff --git a/include/query/irs.h b/include/query/irs.h index 879d070..6dec850 100644 --- a/include/query/irs.h +++ b/include/query/irs.h @@ -1,12 +1,12 @@ /* * include/query/irs.h * - * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu> + * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu> * * Distributed under the Modified BSD License. * - * A query class for independent range sampling. This query requires - * that the shard support get_lower_bound(key), get_upper_bound(key), + * A query class for independent range sampling. This query requires + * that the shard support get_lower_bound(key), get_upper_bound(key), * and get_record_at(index). */ #pragma once @@ -14,237 +14,227 @@ #include "framework/QueryRequirements.h" #include "psu-ds/Alias.h" -namespace de { namespace irs { +namespace de { +namespace irs { -template <RecordInterface R> -struct Parms { +template <ShardInterface S, bool REJECTION = true> class Query { + typedef typename S::RECORD R; + +public: + struct Parameters { decltype(R::key) lower_bound; decltype(R::key) upper_bound; size_t sample_size; gsl_rng *rng; -}; + }; - -template <RecordInterface R> -struct State { - size_t lower_bound; - size_t upper_bound; - size_t sample_size; + struct LocalQuery { + size_t lower_idx; + size_t upper_idx; size_t total_weight; -}; + size_t sample_size; + Parameters global_parms; + }; + + struct LocalQueryBuffer { + BufferView<R> *buffer; -template <RecordInterface R> -struct BufferState { size_t cutoff; std::vector<Wrapped<R>> records; + std::unique_ptr<psudb::Alias> alias; size_t sample_size; - BufferView<R> *buffer; - psudb::Alias *alias; + Parameters global_parms; + }; - BufferState(BufferView<R> *buffer) : buffer(buffer) {} - ~BufferState() { - delete alias; - } -}; + typedef Wrapped<R> LocalResultType; + typedef R ResultType; -template <RecordInterface R, ShardInterface<R> S, bool Rejection=true> -class Query { -public: - constexpr static bool EARLY_ABORT=false; - constexpr static bool SKIP_DELETE_FILTER=false; + constexpr static bool EARLY_ABORT = false; + constexpr static bool SKIP_DELETE_FILTER = false; - static void *get_query_state(S *shard, void *parms) { - auto res = new State<R>(); - decltype(R::key) lower_key = ((Parms<R> *) parms)->lower_bound; - decltype(R::key) upper_key = ((Parms<R> *) parms)->upper_bound; + static LocalQuery *local_preproc(S *shard, Parameters *parms) { + auto query = new LocalQuery(); - res->lower_bound = shard->get_lower_bound(lower_key); - res->upper_bound = shard->get_upper_bound(upper_key); + query->global_parms = *parms; - if (res->lower_bound == shard->get_record_count()) { - res->total_weight = 0; - } else { - res->total_weight = res->upper_bound - res->lower_bound; - } + query->lower_idx = shard->get_lower_bound(query->global_parms.lower_bound); + query->upper_idx = shard->get_upper_bound(query->global_parms.upper_bound); - res->sample_size = 0; - return res; + if (query->lower_idx == shard->get_record_count()) { + query->total_weight = 0; + } else { + query->total_weight = query->upper_idx - query->lower_idx; } - static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) { - auto res = new BufferState<R>(buffer); - - res->cutoff = res->buffer->get_record_count(); - res->sample_size = 0; - res->alias = nullptr; + query->sample_size = 0; + return query; + } - if constexpr (Rejection) { - return res; - } - - auto lower_key = ((Parms<R> *) parms)->lower_bound; - auto upper_key = ((Parms<R> *) parms)->upper_bound; + static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer, + Parameters *parms) { + auto query = new LocalQueryBuffer(); + query->buffer = buffer; - for (size_t i=0; i<res->cutoff; i++) { - if ((res->buffer->get(i)->rec.key >= lower_key) && (buffer->get(i)->rec.key <= upper_key)) { - res->records.emplace_back(*(res->buffer->get(i))); - } - } + query->cutoff = query->buffer->get_record_count(); + query->sample_size = 0; + query->alias = nullptr; + query->global_parms = *parms; - return res; + if constexpr (REJECTION) { + return query; } - static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void *buffer_state) { - auto p = (Parms<R> *) query_parms; - auto bs = (buffer_state) ? (BufferState<R> *) buffer_state : nullptr; - - std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0); - size_t buffer_sz = 0; + for (size_t i = 0; i < query->cutoff; i++) { + if ((query->buffer->get(i)->rec.key >= query->global_parms.lower_bound) && + (buffer->get(i)->rec.key <= query->global_parms.upper_bound)) { + query->records.emplace_back(*(query->buffer->get(i))); + } + } - /* for simplicity of static structure testing */ - if (!bs) { - assert(shard_states.size() == 1); - auto state = (State<R> *) shard_states[0]; - state->sample_size = p->sample_size; - return; - } + return query; + } - /* we only need to build the shard alias on the first call */ - if (bs->alias == nullptr) { - std::vector<size_t> weights; - if constexpr (Rejection) { - weights.push_back((bs) ? bs->cutoff : 0); - } else { - weights.push_back((bs) ? bs->records.size() : 0); - } - - size_t total_weight = weights[0]; - for (auto &s : shard_states) { - auto state = (State<R> *) s; - total_weight += state->total_weight; - weights.push_back(state->total_weight); - } - - // if no valid records fall within the query range, just - // set all of the sample sizes to 0 and bail out. - if (total_weight == 0) { - for (size_t i=0; i<shard_states.size(); i++) { - auto state = (State<R> *) shard_states[i]; - state->sample_size = 0; - } - - return; - } - - std::vector<double> normalized_weights; - for (auto w : weights) { - normalized_weights.push_back((double) w / (double) total_weight); - } - - bs->alias = new psudb::Alias(normalized_weights); - } + static void distribute_query(Parameters *parms, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { - for (size_t i=0; i<p->sample_size; i++) { - auto idx = bs->alias->get(p->rng); - if (idx == 0) { - buffer_sz++; - } else { - shard_sample_sizes[idx - 1]++; - } - } + std::vector<size_t> shard_sample_sizes(local_queries.size() + 1, 0); + size_t buffer_sz = 0; - if (bs) { - bs->sample_size = buffer_sz; - } - for (size_t i=0; i<shard_states.size(); i++) { - auto state = (State<R> *) shard_states[i]; - state->sample_size = shard_sample_sizes[i+1]; - } + /* for simplicity of static structure testing */ + if (!buffer_query) { + assert(local_queries.size() == 1); + local_queries[0]->sample_size = + local_queries[0]->global_parms.sample_size; + return; } - static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) { - auto lower_key = ((Parms<R> *) parms)->lower_bound; - auto upper_key = ((Parms<R> *) parms)->upper_bound; - auto rng = ((Parms<R> *) parms)->rng; - - auto state = (State<R> *) q_state; - auto sample_sz = state->sample_size; - - std::vector<Wrapped<R>> result_set; - - if (sample_sz == 0 || state->lower_bound == shard->get_record_count()) { - return result_set; + /* we only need to build the shard alias on the first call */ + if (buffer_query->alias == nullptr) { + std::vector<size_t> weights; + if constexpr (REJECTION) { + weights.push_back(buffer_query->cutoff); + } else { + weights.push_back(buffer_query->records.size()); + } + + size_t total_weight = weights[0]; + for (auto &q : local_queries) { + total_weight += q->total_weight; + weights.push_back(q->total_weight); + } + + /* + * if no valid records fall within the query range, + * set all of the sample sizes to 0 and bail out. + */ + if (total_weight == 0) { + for (auto q : local_queries) { + q->sample_size = 0; } - size_t attempts = 0; - size_t range_length = state->upper_bound - state->lower_bound; - do { - attempts++; - size_t idx = (range_length > 0) ? gsl_rng_uniform_int(rng, range_length) : 0; - result_set.emplace_back(*shard->get_record_at(state->lower_bound + idx)); - } while (attempts < sample_sz); + return; + } - return result_set; - } + std::vector<double> normalized_weights; + for (auto w : weights) { + normalized_weights.push_back((double)w / (double)total_weight); + } - static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) { - auto st = (BufferState<R> *) state; - auto p = (Parms<R> *) parms; + buffer_query->alias = std::make_unique<psudb::Alias>(normalized_weights); + } - std::vector<Wrapped<R>> result; - result.reserve(st->sample_size); + for (size_t i = 0; i < parms->sample_size; i++) { + auto idx = buffer_query->alias->get(parms->rng); + if (idx == 0) { + buffer_sz++; + } else { + shard_sample_sizes[idx - 1]++; + } + } - if constexpr (Rejection) { - for (size_t i=0; i<st->sample_size; i++) { - auto idx = gsl_rng_uniform_int(p->rng, st->cutoff); - auto rec = st->buffer->get(idx); + if (buffer_query) { + buffer_query->sample_size = buffer_sz; + } - if (rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) { - result.emplace_back(*rec); - } - } + for (size_t i = 0; i < local_queries.size(); i++) { + local_queries[i]->sample_size = shard_sample_sizes[i]; + } + } - return result; - } + static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) { + auto sample_sz = query->sample_size; - for (size_t i=0; i<st->sample_size; i++) { - auto idx = gsl_rng_uniform_int(p->rng, st->records.size()); - result.emplace_back(st->records[idx]); - } + std::vector<LocalResultType> result_set; - return result; + if (sample_sz == 0 || query->lower_idx == shard->get_record_count()) { + return result_set; } - static void merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { - for (size_t i=0; i<results.size(); i++) { - for (size_t j=0; j<results[i].size(); j++) { - output.emplace_back(results[i][j].rec); - } + size_t attempts = 0; + size_t range_length = query->upper_idx - query->lower_idx; + do { + attempts++; + size_t idx = + (range_length > 0) + ? gsl_rng_uniform_int(query->global_parms.rng, range_length) + : 0; + result_set.emplace_back(*shard->get_record_at(query->lower_idx + idx)); + } while (attempts < sample_sz); + + return result_set; + } + + static std::vector<LocalResultType> + local_query_buffer(LocalQueryBuffer *query) { + std::vector<LocalResultType> result; + result.reserve(query->sample_size); + + if constexpr (REJECTION) { + for (size_t i = 0; i < query->sample_size; i++) { + auto idx = gsl_rng_uniform_int(query->global_parms.rng, query->cutoff); + auto rec = query->buffer->get(idx); + + if (rec->rec.key >= query->global_parms.lower_bound && + rec->rec.key <= query->global_parms.upper_bound) { + result.emplace_back(*rec); } - } + } - static void delete_query_state(void *state) { - auto s = (State<R> *) state; - delete s; + return result; } - static void delete_buffer_query_state(void *state) { - auto s = (BufferState<R> *) state; - delete s; + for (size_t i = 0; i < query->sample_size; i++) { + auto idx = + gsl_rng_uniform_int(query->global_parms.rng, query->records.size()); + result.emplace_back(query->records[idx]); } - static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { - auto p = (Parms<R> *) parms; - - if (results.size() < p->sample_size) { - auto q = *p; - q.sample_size -= results.size(); - process_query_states(&q, states, buffer_state); - return true; - } + return result; + } - return false; + static void + combine(std::vector<std::vector<LocalResultType>> const &local_results, + Parameters *parms, std::vector<ResultType> &output) { + for (size_t i = 0; i < local_results.size(); i++) { + for (size_t j = 0; j < local_results[i].size(); j++) { + output.emplace_back(local_results[i][j].rec); + } } + } + + static bool repeat(Parameters *parms, std::vector<ResultType> &output, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { + if (output.size() < parms->sample_size) { + parms->sample_size -= output.size(); + distribute_query(parms, local_queries, buffer_query); + return true; + } + + return false; + } }; -}} +} // namespace irs +} // namespace de diff --git a/include/query/knn.h b/include/query/knn.h index a227293..87ea10a 100644 --- a/include/query/knn.h +++ b/include/query/knn.h @@ -6,7 +6,7 @@ * Distributed under the Modified BSD License. * * A query class for k-NN queries, designed for use with the VPTree - * shard. + * shard. * * FIXME: no support for tombstone deletes just yet. This would require a * query resumption mechanism, most likely. @@ -16,147 +16,147 @@ #include "framework/QueryRequirements.h" #include "psu-ds/PriorityQueue.h" -namespace de { namespace knn { +namespace de { +namespace knn { using psudb::PriorityQueue; -template <NDRecordInterface R> -struct Parms { +template <ShardInterface S> class Query { + typedef typename S::RECORD R; + +public: + struct Parameters { R point; size_t k; -}; + }; -template <NDRecordInterface R> -struct State { - size_t k; -}; + struct LocalQuery { + Parameters global_parms; + }; -template <NDRecordInterface R> -struct BufferState { + struct LocalQueryBuffer { BufferView<R> *buffer; + Parameters global_parms; + }; - BufferState(BufferView<R> *buffer) - : buffer(buffer) {} -}; + typedef Wrapped<R> LocalResultType; + typedef R ResultType; + constexpr static bool EARLY_ABORT = false; + constexpr static bool SKIP_DELETE_FILTER = true; -template <NDRecordInterface R, ShardInterface<R> S> -class Query { -public: - constexpr static bool EARLY_ABORT=false; - constexpr static bool SKIP_DELETE_FILTER=true; + static LocalQuery *local_preproc(S *shard, Parameters *parms) { + auto query = new LocalQuery(); + query->global_parms = *parms; - static void *get_query_state(S *shard, void *parms) { - return nullptr; - } + return query; + } - static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) { - return new BufferState<R>(buffer); - } + static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer, + Parameters *parms) { + auto query = new LocalQueryBuffer(); + query->global_parms = *parms; + query->buffer = buffer; - static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void* buffer_state) { - return; - } + return query; + } - static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) { - std::vector<Wrapped<R>> results; - Parms<R> *p = (Parms<R> *) parms; - Wrapped<R> wrec; - wrec.rec = p->point; - wrec.header = 0; + static void distribute_query(Parameters *parms, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { + return; + } - PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(p->k, &wrec); + static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) { + std::vector<LocalResultType> results; - shard->search(p->point, p->k, pq); + Wrapped<R> wrec; + wrec.rec = query->global_parms.point; + wrec.header = 0; - while (pq.size() > 0) { - results.emplace_back(*pq.peek().data); - pq.pop(); - } + PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(query->global_parms.k, + &wrec); - return results; + shard->search(query->global_parms.point, query->global_parms.k, pq); + + while (pq.size() > 0) { + results.emplace_back(*pq.peek().data); + pq.pop(); } - static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) { - Parms<R> *p = (Parms<R> *) parms; - BufferState<R> *s = (BufferState<R> *) state; - Wrapped<R> wrec; - wrec.rec = p->point; - wrec.header = 0; - - size_t k = p->k; - - PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(k, &wrec); - for (size_t i=0; i<s->buffer->get_record_count(); i++) { - // Skip over deleted records (under tagging) - if (s->buffer->get(i)->is_deleted()) { - continue; - } - - if (pq.size() < k) { - pq.push(s->buffer->get(i)); - } else { - double head_dist = pq.peek().data->rec.calc_distance(wrec.rec); - double cur_dist = (s->buffer->get(i))->rec.calc_distance(wrec.rec); - - if (cur_dist < head_dist) { - pq.pop(); - pq.push(s->buffer->get(i)); - } - } - } + return results; + } - std::vector<Wrapped<R>> results; - while (pq.size() > 0) { - results.emplace_back(*(pq.peek().data)); - pq.pop(); - } + static std::vector<LocalResultType> + local_query_buffer(LocalQueryBuffer *query) { - return std::move(results); - } + std::vector<LocalResultType> results; - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { - Parms<R> *p = (Parms<R> *) parms; - R rec = p->point; - size_t k = p->k; - - PriorityQueue<R, DistCmpMax<R>> pq(k, &rec); - for (size_t i=0; i<results.size(); i++) { - for (size_t j=0; j<results[i].size(); j++) { - if (pq.size() < k) { - pq.push(&results[i][j].rec); - } else { - double head_dist = pq.peek().data->calc_distance(rec); - double cur_dist = results[i][j].rec.calc_distance(rec); - - if (cur_dist < head_dist) { - pq.pop(); - pq.push(&results[i][j].rec); - } - } - } - } + Wrapped<R> wrec; + wrec.rec = query->global_parms.point; + wrec.header = 0; - while (pq.size() > 0) { - output.emplace_back(*pq.peek().data); - pq.pop(); - } + PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(query->global_parms.k, + &wrec); + + for (size_t i = 0; i < query->buffer->get_record_count(); i++) { + // Skip over deleted records (under tagging) + if (query->buffer->get(i)->is_deleted()) { + continue; + } - return std::move(output); + if (pq.size() < query->global_parms.k) { + pq.push(query->buffer->get(i)); + } else { + double head_dist = pq.peek().data->rec.calc_distance(wrec.rec); + double cur_dist = (query->buffer->get(i))->rec.calc_distance(wrec.rec); + + if (cur_dist < head_dist) { + pq.pop(); + pq.push(query->buffer->get(i)); + } + } } - static void delete_query_state(void *state) { - auto s = (State<R> *) state; - delete s; + while (pq.size() > 0) { + results.emplace_back(*(pq.peek().data)); + pq.pop(); } - static void delete_buffer_query_state(void *state) { - auto s = (BufferState<R> *) state; - delete s; + return std::move(results); + } + + static void + combine(std::vector<std::vector<LocalResultType>> const &local_results, + Parameters *parms, std::vector<ResultType> &output) { + + PriorityQueue<R, DistCmpMax<R>> pq(parms->k, &(parms->point)); + for (size_t i = 0; i < local_results.size(); i++) { + for (size_t j = 0; j < local_results[i].size(); j++) { + if (pq.size() < parms->k) { + pq.push(&local_results[i][j].rec); + } else { + double head_dist = pq.peek().data->calc_distance(parms->point); + double cur_dist = local_results[i][j].rec.calc_distance(parms->point); + + if (cur_dist < head_dist) { + pq.pop(); + pq.push(&local_results[i][j].rec); + } + } + } } - static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { - return false; + while (pq.size() > 0) { + output.emplace_back(*pq.peek().data); + pq.pop(); } -}; + } -}} + static bool repeat(Parameters *parms, std::vector<ResultType> &output, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { + return false; + } +}; +} // namespace knn +} // namespace de diff --git a/include/query/pointlookup.h b/include/query/pointlookup.h index 94c2bce..f3788de 100644 --- a/include/query/pointlookup.h +++ b/include/query/pointlookup.h @@ -18,106 +18,102 @@ #include "framework/QueryRequirements.h" -namespace de { namespace pl { +namespace de { +namespace pl { -template <RecordInterface R> -struct Parms { - decltype(R::key) search_key; -}; +template <ShardInterface S> class Query { + typedef typename S::RECORD R; -template <RecordInterface R> -struct State { -}; - -template <RecordInterface R> -struct BufferState { - BufferView<R> *buffer; - - BufferState(BufferView<R> *buffer) - : buffer(buffer) {} -}; - -template <KVPInterface R, ShardInterface<R> S> -class Query { public: - constexpr static bool EARLY_ABORT=true; - constexpr static bool SKIP_DELETE_FILTER=true; - - static void *get_query_state(S *shard, void *parms) { - return nullptr; - } - - static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) { - auto res = new BufferState<R>(buffer); + struct Parameters { + decltype(R::key) search_key; + }; - return res; - } + struct LocalQuery { + Parameters global_parms; + }; - static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void* buffer_state) { - return; + struct LocalQueryBuffer { + BufferView<R> *buffer; + Parameters global_parms; + }; + + typedef Wrapped<R> LocalResultType; + typedef R ResultType; + + constexpr static bool EARLY_ABORT = true; + constexpr static bool SKIP_DELETE_FILTER = true; + + static LocalQuery *local_preproc(S *shard, Parameters *parms) { + auto query = new LocalQuery(); + query->global_parms = *parms; + return query; + } + + static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer, + Parameters *parms) { + auto query = new LocalQueryBuffer(); + query->buffer = buffer; + query->global_parms = *parms; + + return query; + } + + static void distribute_query(Parameters *parms, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { + return; + } + + static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) { + std::vector<LocalResultType> result; + + auto r = shard->point_lookup({query->global_parms.search_key, 0}); + + if (r) { + result.push_back(*r); } - static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) { - auto p = (Parms<R> *) parms; - auto s = (State<R> *) q_state; - - std::vector<Wrapped<R>> result; - - auto r = shard->point_lookup({p->search_key, 0}); + return result; + } + + static std::vector<LocalResultType> + local_query_buffer(LocalQueryBuffer *query) { + std::vector<LocalResultType> result; - if (r) { - result.push_back(*r); - } + for (size_t i = 0; i < query->buffer->get_record_count(); i++) { + auto rec = query->buffer->get(i); + if (rec->rec.key == query->global_parms.search_key) { + result.push_back(*rec); return result; + } } - static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) { - auto p = (Parms<R> *) parms; - auto s = (BufferState<R> *) state; - - std::vector<Wrapped<R>> records; - for (size_t i=0; i<s->buffer->get_record_count(); i++) { - auto rec = s->buffer->get(i); - - if (rec->rec.key == p->search_key) { - records.push_back(*rec); - return records; - } + return result; + } + + + static void + combine(std::vector<std::vector<LocalResultType>> const &local_results, + Parameters *parms, std::vector<ResultType> &output) { + for (auto r : local_results) { + if (r.size() > 0) { + if (r[0].is_deleted() || r[0].is_tombstone()) { + return; } - return records; - } - - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { - for (auto r : results) { - if (r.size() > 0) { - if (r[0].is_deleted() || r[0].is_tombstone()) { - return output; - } - - output.push_back(r[0].rec); - return output; - } - } - - return output; - } - - static void delete_query_state(void *state) { - auto s = (State<R> *) state; - delete s; - } - - static void delete_buffer_query_state(void *state) { - auto s = (BufferState<R> *) state; - delete s; - } - - - static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { - return false; + output.push_back(r[0].rec); + return; + } } + } + + static bool repeat(Parameters *parms, std::vector<ResultType> &output, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { + return false; + } }; - -}} +} // namespace pl +} // namespace de diff --git a/include/query/rangecount.h b/include/query/rangecount.h index 5b95cdd..68d304d 100644 --- a/include/query/rangecount.h +++ b/include/query/rangecount.h @@ -5,169 +5,168 @@ * * Distributed under the Modified BSD License. * - * A query class for single dimensional range count queries. This query - * requires that the shard support get_lower_bound(key) and + * A query class for single dimensional range count queries. This query + * requires that the shard support get_lower_bound(key) and * get_record_at(index). */ #pragma once #include "framework/QueryRequirements.h" -namespace de { namespace rc { +namespace de { +namespace rc { -template <RecordInterface R> -struct Parms { +template <ShardInterface S, bool FORCE_SCAN = true> class Query { + typedef typename S::RECORD R; + +public: + struct Parameters { decltype(R::key) lower_bound; decltype(R::key) upper_bound; -}; + }; -template <RecordInterface R> -struct State { + struct LocalQuery { size_t start_idx; size_t stop_idx; -}; + Parameters global_parms; + }; -template <RecordInterface R> -struct BufferState { + struct LocalQueryBuffer { BufferView<R> *buffer; - - BufferState(BufferView<R> *buffer) - : buffer(buffer) {} -}; - -template <KVPInterface R, ShardInterface<R> S, bool FORCE_SCAN=false> -class Query { -public: - constexpr static bool EARLY_ABORT=false; - constexpr static bool SKIP_DELETE_FILTER=true; - - static void *get_query_state(S *shard, void *parms) { - return nullptr; - } - - static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) { - auto res = new BufferState<R>(buffer); - - return res; + Parameters global_parms; + }; + + struct LocalResultType { + size_t record_count; + size_t tombstone_count; + + bool is_deleted() {return false;} + bool is_tombstone() {return false;} + }; + + typedef size_t ResultType; + constexpr static bool EARLY_ABORT = false; + constexpr static bool SKIP_DELETE_FILTER = true; + + static LocalQuery *local_preproc(S *shard, Parameters *parms) { + auto query = new LocalQuery(); + + query->start_idx = shard->get_lower_bound(parms->lower_bound); + query->stop_idx = shard->get_record_count(); + query->global_parms.lower_bound = parms->lower_bound; + query->global_parms.upper_bound = parms->upper_bound; + + return query; + } + + static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer, + Parameters *parms) { + auto query = new LocalQueryBuffer(); + query->buffer = buffer; + query->global_parms.lower_bound = parms->lower_bound; + query->global_parms.upper_bound = parms->upper_bound; + + return query; + } + + static void distribute_query(Parameters *parms, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { + return; + } + + static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) { + std::vector<LocalResultType> result; + + /* + * if the returned index is one past the end of the + * records for the PGM, then there are not records + * in the index falling into the specified range. + */ + if (query->start_idx == shard->get_record_count()) { + return result; } - static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void* buffer_state) { - return; + auto ptr = shard->get_record_at(query->start_idx); + size_t reccnt = 0; + size_t tscnt = 0; + + /* + * roll the pointer forward to the first record that is + * greater than or equal to the lower bound. + */ + while (ptr < shard->get_data() + query->stop_idx && + ptr->rec.key < query->global_parms.lower_bound) { + ptr++; } - static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) { - std::vector<Wrapped<R>> records; - auto p = (Parms<R> *) parms; - auto s = (State<R> *) q_state; - - size_t reccnt = 0; - size_t tscnt = 0; - - Wrapped<R> res; - res.rec.key= 0; // records - res.rec.value = 0; // tombstones - records.emplace_back(res); - - - auto start_idx = shard->get_lower_bound(p->lower_bound); - auto stop_idx = shard->get_lower_bound(p->upper_bound); + while (ptr < shard->get_data() + query->stop_idx && + ptr->rec.key <= query->global_parms.upper_bound) { - /* - * if the returned index is one past the end of the - * records for the PGM, then there are not records - * in the index falling into the specified range. - */ - if (start_idx == shard->get_record_count()) { - return records; - } - - - /* - * roll the pointer forward to the first record that is - * greater than or equal to the lower bound. - */ - auto recs = shard->get_data(); - while(start_idx < stop_idx && recs[start_idx].rec.key < p->lower_bound) { - start_idx++; - } - - while (stop_idx < shard->get_record_count() && recs[stop_idx].rec.key <= p->upper_bound) { - stop_idx++; - } - size_t idx = start_idx; - size_t ts_cnt = 0; + if (!ptr->is_deleted()) { + reccnt++; - while (idx < stop_idx) { - ts_cnt += recs[idx].is_tombstone() * 2 + recs[idx].is_deleted(); - idx++; + if (ptr->is_tombstone()) { + tscnt++; } + } - records[0].rec.key = idx - start_idx; - records[0].rec.value = ts_cnt; - - return records; + ptr++; } - static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) { - auto p = (Parms<R> *) parms; - auto s = (BufferState<R> *) state; - - std::vector<Wrapped<R>> records; - - Wrapped<R> res; - res.rec.key= 0; // records - res.rec.value = 0; // tombstones - records.emplace_back(res); - - size_t stop_idx; - if constexpr (FORCE_SCAN) { - stop_idx = s->buffer->get_capacity() / 2; - } else { - stop_idx = s->buffer->get_record_count(); - } - - for (size_t i=0; i<s->buffer->get_record_count(); i++) { - auto rec = s->buffer->get(i); - - if (rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound - && !rec->is_deleted()) { - if (rec->is_tombstone()) { - records[0].rec.value++; - } else { - records[0].rec.key++; - } - } + result.push_back({reccnt, tscnt}); + return result; + } + + static std::vector<LocalResultType> + local_query_buffer(LocalQueryBuffer *query) { + + std::vector<LocalResultType> result; + size_t reccnt = 0; + size_t tscnt = 0; + for (size_t i = 0; i < query->buffer->get_record_count(); i++) { + auto rec = query->buffer->get(i); + if (rec->rec.key >= query->global_parms.lower_bound && + rec->rec.key <= query->global_parms.upper_bound) { + if (!rec->is_deleted()) { + reccnt++; + if (rec->is_tombstone()) { + tscnt++; + } } - - return records; + } } - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { - R res; - res.key = 0; - res.value = 0; - output.emplace_back(res); + result.push_back({reccnt, tscnt}); - for (size_t i=0; i<results.size(); i++) { - output[0].key += results[i][0].rec.key; // records - output[0].value += results[i][0].rec.value; // tombstones - } + return result; + } - output[0].key -= output[0].value; - return output; - } + static void + combine(std::vector<std::vector<LocalResultType>> const &local_results, + Parameters *parms, std::vector<ResultType> &output) { + size_t reccnt = 0; + size_t tscnt = 0; - static void delete_query_state(void *state) { + for (auto &local_result : local_results) { + reccnt += local_result[0].record_count; + tscnt += local_result[0].tombstone_count; } - static void delete_buffer_query_state(void *state) { - auto s = (BufferState<R> *) state; - delete s; + /* if more tombstones than results, clamp the output at 0 */ + if (tscnt > reccnt) { + tscnt = reccnt; } - static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { - return false; - } + output.push_back({reccnt - tscnt}); + } + + static bool repeat(Parameters *parms, std::vector<ResultType> &output, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { + return false; + } }; -}} +} // namespace rc +} // namespace de diff --git a/include/query/rangequery.h b/include/query/rangequery.h index e0690e6..e7be39c 100644 --- a/include/query/rangequery.h +++ b/include/query/rangequery.h @@ -1,177 +1,186 @@ /* * include/query/rangequery.h * - * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu> + * Copyright (C) 2023-2024 Douglas B. Rumbaugh <drumbaugh@psu.edu> * * Distributed under the Modified BSD License. * - * A query class for single dimensional range queries. This query requires + * A query class for single dimensional range queries. This query requires * that the shard support get_lower_bound(key) and get_record_at(index). */ #pragma once #include "framework/QueryRequirements.h" +#include "framework/interface/Record.h" #include "psu-ds/PriorityQueue.h" #include "util/Cursor.h" -namespace de { namespace rq { +namespace de { +namespace rq { -template <RecordInterface R> -struct Parms { +template <ShardInterface S> class Query { + typedef typename S::RECORD R; + +public: + struct Parameters { decltype(R::key) lower_bound; decltype(R::key) upper_bound; -}; + }; -template <RecordInterface R> -struct State { + struct LocalQuery { size_t start_idx; size_t stop_idx; -}; + Parameters global_parms; + }; -template <RecordInterface R> -struct BufferState { + struct LocalQueryBuffer { BufferView<R> *buffer; - - BufferState(BufferView<R> *buffer) - : buffer(buffer) {} -}; - -template <RecordInterface R, ShardInterface<R> S> -class Query { -public: - constexpr static bool EARLY_ABORT=false; - constexpr static bool SKIP_DELETE_FILTER=true; - - static void *get_query_state(S *shard, void *parms) { - auto res = new State<R>(); - auto p = (Parms<R> *) parms; - - res->start_idx = shard->get_lower_bound(p->lower_bound); - res->stop_idx = shard->get_record_count(); - - return res; + Parameters global_parms; + }; + + typedef Wrapped<R> LocalResultType; + typedef R ResultType; + + constexpr static bool EARLY_ABORT = false; + constexpr static bool SKIP_DELETE_FILTER = true; + + static LocalQuery *local_preproc(S *shard, Parameters *parms) { + auto query = new LocalQuery(); + + query->start_idx = shard->get_lower_bound(parms->lower_bound); + query->stop_idx = shard->get_record_count(); + query->global_parms = *parms; + + return query; + } + + static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer, + Parameters *parms) { + auto query = new LocalQueryBuffer(); + query->buffer = buffer; + query->global_parms = *parms; + + return query; + } + + static void distribute_query(Parameters *parms, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { + return; + } + + static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) { + std::vector<LocalResultType> result; + + /* + * if the returned index is one past the end of the + * records for the PGM, then there are not records + * in the index falling into the specified range. + */ + if (query->start_idx == shard->get_record_count()) { + return result; } - static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) { - auto res = new BufferState<R>(buffer); + auto ptr = shard->get_record_at(query->start_idx); - return res; + /* + * roll the pointer forward to the first record that is + * greater than or equal to the lower bound. + */ + while (ptr < shard->get_data() + query->stop_idx && + ptr->rec.key < query->global_parms.lower_bound) { + ptr++; } - static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void* buffer_state) { - return; + while (ptr < shard->get_data() + query->stop_idx && + ptr->rec.key <= query->global_parms.upper_bound) { + result.emplace_back(*ptr); + ptr++; } - static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) { - std::vector<Wrapped<R>> records; - auto p = (Parms<R> *) parms; - auto s = (State<R> *) q_state; - - /* - * if the returned index is one past the end of the - * records for the PGM, then there are not records - * in the index falling into the specified range. - */ - if (s->start_idx == shard->get_record_count()) { - return records; - } - - auto ptr = shard->get_record_at(s->start_idx); - - /* - * roll the pointer forward to the first record that is - * greater than or equal to the lower bound. - */ - while(ptr < shard->get_data() + s->stop_idx && ptr->rec.key < p->lower_bound) { - ptr++; - } - - while (ptr < shard->get_data() + s->stop_idx && ptr->rec.key <= p->upper_bound) { - records.emplace_back(*ptr); - ptr++; - } - - return records; - } + return result; + } - static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) { - auto p = (Parms<R> *) parms; - auto s = (BufferState<R> *) state; + static std::vector<LocalResultType> + local_query_buffer(LocalQueryBuffer *query) { - std::vector<Wrapped<R>> records; - for (size_t i=0; i<s->buffer->get_record_count(); i++) { - auto rec = s->buffer->get(i); - if (rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) { - records.emplace_back(*rec); - } - } - - return records; + std::vector<LocalResultType> result; + for (size_t i = 0; i < query->buffer->get_record_count(); i++) { + auto rec = query->buffer->get(i); + if (rec->rec.key >= query->global_parms.lower_bound && + rec->rec.key <= query->global_parms.upper_bound) { + result.emplace_back(*rec); + } } - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { - std::vector<Cursor<Wrapped<R>>> cursors; - cursors.reserve(results.size()); - - psudb::PriorityQueue<Wrapped<R>> pq(results.size()); - size_t total = 0; - size_t tmp_n = results.size(); - - - for (size_t i = 0; i < tmp_n; ++i) - if (results[i].size() > 0){ - auto base = results[i].data(); - cursors.emplace_back(Cursor<Wrapped<R>>{base, base + results[i].size(), 0, results[i].size()}); - assert(i == cursors.size() - 1); - total += results[i].size(); - pq.push(cursors[i].ptr, tmp_n - i - 1); - } else { - cursors.emplace_back(Cursor<Wrapped<R>>{nullptr, nullptr, 0, 0}); - } - - if (total == 0) { - return std::vector<R>(); - } - - output.reserve(total); - - while (pq.size()) { - auto now = pq.peek(); - auto next = pq.size() > 1 ? pq.peek(1) : psudb::queue_record<Wrapped<R>>{nullptr, 0}; - if (!now.data->is_tombstone() && next.data != nullptr && - now.data->rec == next.data->rec && next.data->is_tombstone()) { - - pq.pop(); pq.pop(); - auto& cursor1 = cursors[tmp_n - now.version - 1]; - auto& cursor2 = cursors[tmp_n - next.version - 1]; - if (advance_cursor<Wrapped<R>>(cursor1)) pq.push(cursor1.ptr, now.version); - if (advance_cursor<Wrapped<R>>(cursor2)) pq.push(cursor2.ptr, next.version); - } else { - auto& cursor = cursors[tmp_n - now.version - 1]; - if (!now.data->is_tombstone()) output.push_back(cursor.ptr->rec); - - pq.pop(); - - if (advance_cursor<Wrapped<R>>(cursor)) pq.push(cursor.ptr, now.version); - } - } - - return output; + return result; + } + + static void + combine(std::vector<std::vector<LocalResultType>> const &local_results, + Parameters *parms, std::vector<ResultType> &output) { + std::vector<Cursor<LocalResultType>> cursors; + cursors.reserve(local_results.size()); + + psudb::PriorityQueue<LocalResultType> pq(local_results.size()); + size_t total = 0; + size_t tmp_n = local_results.size(); + + for (size_t i = 0; i < tmp_n; ++i) + if (local_results[i].size() > 0) { + auto base = local_results[i].data(); + cursors.emplace_back(Cursor<LocalResultType>{ + base, base + local_results[i].size(), 0, local_results[i].size()}); + assert(i == cursors.size() - 1); + total += local_results[i].size(); + pq.push(cursors[i].ptr, tmp_n - i - 1); + } else { + cursors.emplace_back(Cursor<LocalResultType>{nullptr, nullptr, 0, 0}); + } + + if (total == 0) { + return; } - static void delete_query_state(void *state) { - auto s = (State<R> *) state; - delete s; + output.reserve(total); + + while (pq.size()) { + auto now = pq.peek(); + auto next = pq.size() > 1 + ? pq.peek(1) + : psudb::queue_record<LocalResultType>{nullptr, 0}; + if (!now.data->is_tombstone() && next.data != nullptr && + now.data->rec == next.data->rec && next.data->is_tombstone()) { + + pq.pop(); + pq.pop(); + auto &cursor1 = cursors[tmp_n - now.version - 1]; + auto &cursor2 = cursors[tmp_n - next.version - 1]; + if (advance_cursor<LocalResultType>(cursor1)) + pq.push(cursor1.ptr, now.version); + if (advance_cursor<LocalResultType>(cursor2)) + pq.push(cursor2.ptr, next.version); + } else { + auto &cursor = cursors[tmp_n - now.version - 1]; + if (!now.data->is_tombstone()) + output.push_back(cursor.ptr->rec); + + pq.pop(); + + if (advance_cursor<LocalResultType>(cursor)) + pq.push(cursor.ptr, now.version); + } } - static void delete_buffer_query_state(void *state) { - auto s = (BufferState<R> *) state; - delete s; - } + return; + } - static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { - return false; - } + static bool repeat(Parameters *parms, std::vector<ResultType> &output, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { + return false; + } }; -}} +} // namespace rq +} // namespace de diff --git a/include/query/wirs.h b/include/query/wirs.h deleted file mode 100644 index 62b43f6..0000000 --- a/include/query/wirs.h +++ /dev/null @@ -1,251 +0,0 @@ -/* - * include/query/wirs.h - * - * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu> - * - * Distributed under the Modified BSD License. - * - * A query class for weighted independent range sampling. This - * class is tightly coupled with include/shard/AugBTree.h, and - * so is probably of limited general utility. - */ -#pragma once - -#include "framework/QueryRequirements.h" -#include "psu-ds/Alias.h" - -namespace de { namespace wirs { - -template <WeightedRecordInterface R> -struct Parms { - decltype(R::key) lower_bound; - decltype(R::key) upper_bound; - size_t sample_size; - gsl_rng *rng; -}; - -template <WeightedRecordInterface R> -struct State { - decltype(R::weight) total_weight; - std::vector<void*> nodes; - psudb::Alias* top_level_alias; - size_t sample_size; - - State() { - total_weight = 0; - top_level_alias = nullptr; - } - - ~State() { - if (top_level_alias) delete top_level_alias; - } -}; - -template <RecordInterface R> -struct BufferState { - size_t cutoff; - psudb::Alias* alias; - std::vector<Wrapped<R>> records; - decltype(R::weight) max_weight; - size_t sample_size; - decltype(R::weight) total_weight; - BufferView<R> *buffer; - - ~BufferState() { - delete alias; - } -}; - -template <RecordInterface R, ShardInterface<R> S, bool Rejection=true> -class Query { -public: - constexpr static bool EARLY_ABORT=false; - constexpr static bool SKIP_DELETE_FILTER=false; - - static void *get_query_state(S *shard, void *parms) { - auto res = new State<R>(); - decltype(R::key) lower_key = ((Parms<R> *) parms)->lower_bound; - decltype(R::key) upper_key = ((Parms<R> *) parms)->upper_bound; - - std::vector<decltype(R::weight)> weights; - res->total_weight = shard->find_covering_nodes(lower_key, upper_key, res->nodes, weights); - - std::vector<double> normalized_weights; - for (auto weight : weights) { - normalized_weights.emplace_back(weight / res->total_weight); - } - - res->top_level_alias = new psudb::Alias(normalized_weights); - res->sample_size = 0; - - return res; - } - - static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) { - BufferState<R> *state = new BufferState<R>(); - auto parameters = (Parms<R>*) parms; - - if constexpr (Rejection) { - state->cutoff = buffer->get_record_count() - 1; - state->max_weight = buffer->get_max_weight(); - state->total_weight = buffer->get_total_weight(); - state->sample_size = 0; - state->buffer = buffer; - return state; - } - - std::vector<decltype(R::weight)> weights; - - state->buffer = buffer; - decltype(R::weight) total_weight = 0; - - for (size_t i = 0; i <= buffer->get_record_count(); i++) { - auto rec = buffer->get(i); - - if (rec->rec.key >= parameters->lower_bound && rec->rec.key <= parameters->upper_bound && !rec->is_tombstone() && !rec->is_deleted()) { - weights.push_back(rec->rec.weight); - state->records.push_back(*rec); - total_weight += rec->rec.weight; - } - } - - std::vector<double> normalized_weights; - for (size_t i = 0; i < weights.size(); i++) { - normalized_weights.push_back(weights[i] / total_weight); - } - - state->total_weight = total_weight; - state->alias = new psudb::Alias(normalized_weights); - state->sample_size = 0; - - return state; - } - - static void process_query_states(void *query_parms, std::vector<void*> &shard_states, std::vector<void*> &buffer_states) { - auto p = (Parms<R> *) query_parms; - - std::vector<size_t> shard_sample_sizes(shard_states.size()+buffer_states.size(), 0); - size_t buffer_sz = 0; - - std::vector<decltype(R::weight)> weights; - - decltype(R::weight) total_weight = 0; - for (auto &s : buffer_states) { - auto bs = (BufferState<R> *) s; - total_weight += bs->total_weight; - weights.push_back(bs->total_weight); - } - - for (auto &s : shard_states) { - auto state = (State<R> *) s; - total_weight += state->total_weight; - weights.push_back(state->total_weight); - } - - std::vector<double> normalized_weights; - for (auto w : weights) { - normalized_weights.push_back((double) w / (double) total_weight); - } - - auto shard_alias = psudb::Alias(normalized_weights); - for (size_t i=0; i<p->sample_size; i++) { - auto idx = shard_alias.get(p->rng); - - if (idx < buffer_states.size()) { - auto state = (BufferState<R> *) buffer_states[idx]; - state->sample_size++; - } else { - auto state = (State<R> *) shard_states[idx - buffer_states.size()]; - state->sample_size++; - } - } - } - - static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) { - auto lower_key = ((Parms<R> *) parms)->lower_bound; - auto upper_key = ((Parms<R> *) parms)->upper_bound; - auto rng = ((Parms<R> *) parms)->rng; - - auto state = (State<R> *) q_state; - auto sample_size = state->sample_size; - - std::vector<Wrapped<R>> result_set; - - if (sample_size == 0) { - return result_set; - } - size_t cnt = 0; - size_t attempts = 0; - - for (size_t i=0; i<sample_size; i++) { - auto rec = shard->get_weighted_sample(lower_key, upper_key, - state->nodes[state->top_level_alias->get(rng)], - rng); - if (rec) { - result_set.emplace_back(*rec); - } - } - - return result_set; - } - - static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) { - auto st = (BufferState<R> *) state; - auto p = (Parms<R> *) parms; - auto buffer = st->buffer; - - std::vector<Wrapped<R>> result; - result.reserve(st->sample_size); - - if constexpr (Rejection) { - for (size_t i=0; i<st->sample_size; i++) { - auto idx = gsl_rng_uniform_int(p->rng, st->cutoff); - auto rec = buffer->get(idx); - - auto test = gsl_rng_uniform(p->rng) * st->max_weight; - - if (test <= rec->rec.weight && rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) { - result.emplace_back(*rec); - } - } - return result; - } - - for (size_t i=0; i<st->sample_size; i++) { - auto idx = st->alias->get(p->rng); - result.emplace_back(st->records[idx]); - } - - return result; - } - - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { - for (size_t i=0; i<results.size(); i++) { - for (size_t j=0; j<results[i].size(); j++) { - output.emplace_back(results[i][j].rec); - } - } - - return output; - } - - static void delete_query_state(void *state) { - auto s = (State<R> *) state; - delete s; - } - - static void delete_buffer_query_state(void *state) { - auto s = (BufferState<R> *) state; - delete s; - } - - static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { - auto p = (Parms<R> *) parms; - - if (results.size() < p->sample_size) { - return true; - } - return false; - } -}; -}} diff --git a/include/query/wss.h b/include/query/wss.h index fb0b414..54620ca 100644 --- a/include/query/wss.h +++ b/include/query/wss.h @@ -6,7 +6,7 @@ * Distributed under the Modified BSD License. * * A query class for weighted set sampling. This - * class is tightly coupled with include/shard/Alias.h, + * class is tightly coupled with include/shard/Alias.h, * and so is probably of limited general utility. */ #pragma once @@ -14,203 +14,177 @@ #include "framework/QueryRequirements.h" #include "psu-ds/Alias.h" -namespace de { namespace wss { +namespace de { +namespace wss { -template <WeightedRecordInterface R> -struct Parms { +template <ShardInterface S> class Query { + typedef typename S::RECORD R; + +public: + struct Parameters { size_t sample_size; gsl_rng *rng; -}; + }; -template <WeightedRecordInterface R> -struct State { - decltype(R::weight) total_weight; + struct LocalQuery { size_t sample_size; + decltype(R::weight) total_weight; - State() { - total_weight = 0; - } -}; + Parameters global_parms; + }; + + struct LocalQueryBuffer { + BufferView<R> *buffer; -template <RecordInterface R> -struct BufferState { - size_t cutoff; size_t sample_size; - psudb::Alias *alias; - decltype(R::weight) max_weight; decltype(R::weight) total_weight; - BufferView<R> *buffer; + decltype(R::weight) max_weight; + size_t cutoff; - ~BufferState() { - delete alias; - } -}; + std::unique_ptr<psudb::Alias> alias; -template <RecordInterface R, ShardInterface<R> S, bool Rejection=true> -class Query { -public: - constexpr static bool EARLY_ABORT=false; - constexpr static bool SKIP_DELETE_FILTER=false; + Parameters global_parms; + }; - static void *get_query_state(S *shard, void *parms) { - auto res = new State<R>(); - res->total_weight = shard->get_total_weight(); - res->sample_size = 0; + constexpr static bool EARLY_ABORT = false; + constexpr static bool SKIP_DELETE_FILTER = false; - return res; - } + typedef Wrapped<R> LocalResultType; + typedef R ResultType; - static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) { - BufferState<R> *state = new BufferState<R>(); - auto parameters = (Parms<R>*) parms; - if constexpr (Rejection) { - state->cutoff = buffer->get_record_count() - 1; - state->max_weight = buffer->get_max_weight(); - state->total_weight = buffer->get_total_weight(); - state->buffer = buffer; - return state; - } + static LocalQuery *local_preproc(S *shard, Parameters *parms) { + auto query = new LocalQuery(); - std::vector<double> weights; + query->global_parms = *parms; + query->total_weight = shard->get_total_weight(); + query->sample_size = 0; - double total_weight = 0.0; - state->buffer = buffer; + return query; + } - for (size_t i = 0; i <= buffer->get_record_count(); i++) { - auto rec = buffer->get_data(i); - weights.push_back(rec->rec.weight); - total_weight += rec->rec.weight; - } + static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer, + Parameters *parms) { + auto query = new LocalQueryBuffer(); - for (size_t i = 0; i < weights.size(); i++) { - weights[i] = weights[i] / total_weight; - } + query->cutoff = buffer->get_record_count() - 1; - state->alias = new psudb::Alias(weights); - state->total_weight = total_weight; + query->max_weight = 0; + query->total_weight = 0; - return state; - } + for (size_t i = 0; i < buffer->get_record_count(); i++) { + auto weight = buffer->get(i)->rec.weight; + query->total_weight += weight; - static void process_query_states(void *query_parms, std::vector<void*> &shard_states, std::vector<void*> &buffer_states) { - auto p = (Parms<R> *) query_parms; - - std::vector<size_t> shard_sample_sizes(shard_states.size()+buffer_states.size(), 0); - size_t buffer_sz = 0; - - std::vector<decltype(R::weight)> weights; - - decltype(R::weight) total_weight = 0; - for (auto &s : buffer_states) { - auto bs = (BufferState<R> *) s; - total_weight += bs->total_weight; - weights.push_back(bs->total_weight); - } - - for (auto &s : shard_states) { - auto state = (State<R> *) s; - total_weight += state->total_weight; - weights.push_back(state->total_weight); - } - - std::vector<double> normalized_weights; - for (auto w : weights) { - normalized_weights.push_back((double) w / (double) total_weight); - } - - auto shard_alias = psudb::Alias(normalized_weights); - for (size_t i=0; i<p->sample_size; i++) { - auto idx = shard_alias.get(p->rng); - - if (idx < buffer_states.size()) { - auto state = (BufferState<R> *) buffer_states[idx]; - state->sample_size++; - } else { - auto state = (State<R> *) shard_states[idx - buffer_states.size()]; - state->sample_size++; - } - } + if (weight > query->max_weight) { + query->max_weight = weight; + } } - static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) { - auto rng = ((Parms<R> *) parms)->rng; + query->buffer = buffer; + query->global_parms = *parms; - auto state = (State<R> *) q_state; - auto sample_size = state->sample_size; + query->alias = nullptr; - std::vector<Wrapped<R>> result_set; + return query; + } - if (sample_size == 0) { - return result_set; - } - size_t attempts = 0; - do { - attempts++; - size_t idx = shard->get_weighted_sample(rng); - result_set.emplace_back(*shard->get_record_at(idx)); - } while (attempts < sample_size); + static void distribute_query(Parameters *parms, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { - return result_set; + if (!buffer_query) { + assert(local_queries.size() == 1); + local_queries[0]->sample_size = + local_queries[0]->global_parms.sample_size; + return; } - static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) { - auto st = (BufferState<R> *) state; - auto p = (Parms<R> *) parms; - auto buffer = st->buffer; + if (!buffer_query->alias) { + std::vector<decltype(R::weight)> weights; - std::vector<Wrapped<R>> result; - result.reserve(st->sample_size); + decltype(R::weight) total_weight = buffer_query->total_weight; + weights.push_back(total_weight); - if constexpr (Rejection) { - for (size_t i=0; i<st->sample_size; i++) { - auto idx = gsl_rng_uniform_int(p->rng, st->cutoff); - auto rec = buffer->get(idx); + for (auto &q : local_queries) { + total_weight += q->total_weight; + weights.push_back(q->total_weight); + q->sample_size = 0; + } - auto test = gsl_rng_uniform(p->rng) * st->max_weight; + std::vector<double> normalized_weights; + for (auto w : weights) { + normalized_weights.push_back((double)w / (double)total_weight); + } - if (test <= rec->rec.weight) { - result.emplace_back(*rec); - } - } - return result; - } + buffer_query->alias = std::make_unique<psudb::Alias>(normalized_weights); + } - for (size_t i=0; i<st->sample_size; i++) { - auto idx = st->alias->get(p->rng); - result.emplace_back(*(buffer->get_data() + idx)); - } + for (size_t i = 0; i < parms->sample_size; i++) { + auto idx = buffer_query->alias->get(parms->rng); - return result; + if (idx == 0) { + buffer_query->sample_size++; + } else { + local_queries[idx - 1]->sample_size++; + } } + } - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { - for (size_t i=0; i<results.size(); i++) { - for (size_t j=0; j<results[i].size(); j++) { - output.emplace_back(results[i][j].rec); - } - } + static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) { + std::vector<LocalResultType> result; - return output; + if (query->sample_size == 0) { + return result; } - static void delete_query_state(void *state) { - auto s = (State<R> *) state; - delete s; + for (size_t i = 0; i < query->sample_size; i++) { + size_t idx = shard->get_weighted_sample(query->global_parms.rng); + if (!shard->get_record_at(idx)->is_deleted()) { + result.emplace_back(*shard->get_record_at(idx)); + } } - static void delete_buffer_query_state(void *state) { - auto s = (BufferState<R> *) state; - delete s; + return result; + } + + static std::vector<LocalResultType> + local_query_buffer(LocalQueryBuffer *query) { + std::vector<LocalResultType> result; + + for (size_t i = 0; i < query->sample_size; i++) { + auto idx = gsl_rng_uniform_int(query->global_parms.rng, query->cutoff); + auto rec = query->buffer->get(idx); + + auto test = gsl_rng_uniform(query->global_parms.rng) * query->max_weight; + if (test <= rec->rec.weight && !rec->is_deleted()) { + result.emplace_back(*rec); + } } - static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { - auto p = (Parms<R> *) parms; + return result; + } - if (results.size() < p->sample_size) { - return true; - } - return false; + static void + combine(std::vector<std::vector<LocalResultType>> const &local_results, + Parameters *parms, std::vector<ResultType> &output) { + for (size_t i = 0; i < local_results.size(); i++) { + for (size_t j = 0; j < local_results[i].size(); j++) { + output.emplace_back(local_results[i][j].rec); + } + } + } + + static bool repeat(Parameters *parms, std::vector<ResultType> &output, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { + if (output.size() < parms->sample_size) { + parms->sample_size -= output.size(); + distribute_query(parms, local_queries, buffer_query); + return true; } -}; -}} + return false; + } +}; +} // namespace wss +} // namespace de |