diff options
| author | Douglas B. Rumbaugh <dbr4@psu.edu> | 2024-12-06 13:13:51 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-12-06 18:13:51 +0000 |
| commit | 9fe305c7d28e993e55c55427f377ae7e3251ea4f (patch) | |
| tree | 384b687f64b84eb81bde2becac8a5f24916b07b4 /include/query/wss.h | |
| parent | 47916da2ba5ed5bee2dda3cbcc58d39e1e931bfc (diff) | |
| download | dynamic-extension-9fe305c7d28e993e55c55427f377ae7e3251ea4f.tar.gz | |
Interface update (#5)
* Query Interface Adjustments/Refactoring
Began the process of adjusting the query interface (and also the shard
interface, to a lesser degree) to better accommodate the user. In
particular the following changes have been made,
1. The number of necessary template arguments for the query type
has been drastically reduced, while also removing the void pointers
and manual delete functions from the interface.
This was accomplished by requiring many of the sub-types associated
with a query (parameters, etc.) to be nested inside the main query
class, and by forcing the SHARD type to expose its associated
record type.
2. User-defined query return types are now supported.
Queries no longer are required to return strictly sets of records.
Instead, the query now has LocalResultType and ResultType
template parameters (which can be defaulted using a typedef in
the Query type itself), allowing much more flexibility.
Note that, at least for the short term, the LocalResultType must
still expose the same is_deleted/is_tombstone interface as a
Wrapped<R> used to, as this is currently needed for delete
filtering. A better approach to this is, hopefully, forthcoming.
3. Updated the ISAMTree.h shard and rangequery.h query to use the
new interfaces, and adjusted the associated unit tests as well.
4. Dropped the unnecessary "get_data()" function from the ShardInterface
concept.
5. Dropped the need to specify a record type in the ShardInterface
concept. This is now handled using a required Shard::RECORD
member of the Shard class itself, which should expose the name
of the record type.
* Updates to framework to support new Query/Shard interfaces
Pretty extensive adjustments to the framework, particularly to the
templates themselves, along with some type-renaming work, to support
the new query and shard interfaces.
Adjusted the external query interface to take an rvalue reference, rather
than a pointer, to the query parameters.
* Removed framework-level delete filtering
This was causing some issues with the new query interface, and should
probably be reworked anyway, so I'm temporarily (TM) removing the
feature.
* Updated benchmarks + remaining code for new interface
Diffstat (limited to 'include/query/wss.h')
| -rw-r--r-- | include/query/wss.h | 282 |
1 files changed, 128 insertions, 154 deletions
diff --git a/include/query/wss.h b/include/query/wss.h index fb0b414..54620ca 100644 --- a/include/query/wss.h +++ b/include/query/wss.h @@ -6,7 +6,7 @@ * Distributed under the Modified BSD License. * * A query class for weighted set sampling. This - * class is tightly coupled with include/shard/Alias.h, + * class is tightly coupled with include/shard/Alias.h, * and so is probably of limited general utility. */ #pragma once @@ -14,203 +14,177 @@ #include "framework/QueryRequirements.h" #include "psu-ds/Alias.h" -namespace de { namespace wss { +namespace de { +namespace wss { -template <WeightedRecordInterface R> -struct Parms { +template <ShardInterface S> class Query { + typedef typename S::RECORD R; + +public: + struct Parameters { size_t sample_size; gsl_rng *rng; -}; + }; -template <WeightedRecordInterface R> -struct State { - decltype(R::weight) total_weight; + struct LocalQuery { size_t sample_size; + decltype(R::weight) total_weight; - State() { - total_weight = 0; - } -}; + Parameters global_parms; + }; + + struct LocalQueryBuffer { + BufferView<R> *buffer; -template <RecordInterface R> -struct BufferState { - size_t cutoff; size_t sample_size; - psudb::Alias *alias; - decltype(R::weight) max_weight; decltype(R::weight) total_weight; - BufferView<R> *buffer; + decltype(R::weight) max_weight; + size_t cutoff; - ~BufferState() { - delete alias; - } -}; + std::unique_ptr<psudb::Alias> alias; -template <RecordInterface R, ShardInterface<R> S, bool Rejection=true> -class Query { -public: - constexpr static bool EARLY_ABORT=false; - constexpr static bool SKIP_DELETE_FILTER=false; + Parameters global_parms; + }; - static void *get_query_state(S *shard, void *parms) { - auto res = new State<R>(); - res->total_weight = shard->get_total_weight(); - res->sample_size = 0; + constexpr static bool EARLY_ABORT = false; + constexpr static bool SKIP_DELETE_FILTER = false; - return res; - } + typedef Wrapped<R> LocalResultType; + typedef R ResultType; - static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) { - BufferState<R> *state = new BufferState<R>(); - auto parameters = (Parms<R>*) parms; - if constexpr (Rejection) { - state->cutoff = buffer->get_record_count() - 1; - state->max_weight = buffer->get_max_weight(); - state->total_weight = buffer->get_total_weight(); - state->buffer = buffer; - return state; - } + static LocalQuery *local_preproc(S *shard, Parameters *parms) { + auto query = new LocalQuery(); - std::vector<double> weights; + query->global_parms = *parms; + query->total_weight = shard->get_total_weight(); + query->sample_size = 0; - double total_weight = 0.0; - state->buffer = buffer; + return query; + } - for (size_t i = 0; i <= buffer->get_record_count(); i++) { - auto rec = buffer->get_data(i); - weights.push_back(rec->rec.weight); - total_weight += rec->rec.weight; - } + static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer, + Parameters *parms) { + auto query = new LocalQueryBuffer(); - for (size_t i = 0; i < weights.size(); i++) { - weights[i] = weights[i] / total_weight; - } + query->cutoff = buffer->get_record_count() - 1; - state->alias = new psudb::Alias(weights); - state->total_weight = total_weight; + query->max_weight = 0; + query->total_weight = 0; - return state; - } + for (size_t i = 0; i < buffer->get_record_count(); i++) { + auto weight = buffer->get(i)->rec.weight; + query->total_weight += weight; - static void process_query_states(void *query_parms, std::vector<void*> &shard_states, std::vector<void*> &buffer_states) { - auto p = (Parms<R> *) query_parms; - - std::vector<size_t> shard_sample_sizes(shard_states.size()+buffer_states.size(), 0); - size_t buffer_sz = 0; - - std::vector<decltype(R::weight)> weights; - - decltype(R::weight) total_weight = 0; - for (auto &s : buffer_states) { - auto bs = (BufferState<R> *) s; - total_weight += bs->total_weight; - weights.push_back(bs->total_weight); - } - - for (auto &s : shard_states) { - auto state = (State<R> *) s; - total_weight += state->total_weight; - weights.push_back(state->total_weight); - } - - std::vector<double> normalized_weights; - for (auto w : weights) { - normalized_weights.push_back((double) w / (double) total_weight); - } - - auto shard_alias = psudb::Alias(normalized_weights); - for (size_t i=0; i<p->sample_size; i++) { - auto idx = shard_alias.get(p->rng); - - if (idx < buffer_states.size()) { - auto state = (BufferState<R> *) buffer_states[idx]; - state->sample_size++; - } else { - auto state = (State<R> *) shard_states[idx - buffer_states.size()]; - state->sample_size++; - } - } + if (weight > query->max_weight) { + query->max_weight = weight; + } } - static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) { - auto rng = ((Parms<R> *) parms)->rng; + query->buffer = buffer; + query->global_parms = *parms; - auto state = (State<R> *) q_state; - auto sample_size = state->sample_size; + query->alias = nullptr; - std::vector<Wrapped<R>> result_set; + return query; + } - if (sample_size == 0) { - return result_set; - } - size_t attempts = 0; - do { - attempts++; - size_t idx = shard->get_weighted_sample(rng); - result_set.emplace_back(*shard->get_record_at(idx)); - } while (attempts < sample_size); + static void distribute_query(Parameters *parms, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { - return result_set; + if (!buffer_query) { + assert(local_queries.size() == 1); + local_queries[0]->sample_size = + local_queries[0]->global_parms.sample_size; + return; } - static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) { - auto st = (BufferState<R> *) state; - auto p = (Parms<R> *) parms; - auto buffer = st->buffer; + if (!buffer_query->alias) { + std::vector<decltype(R::weight)> weights; - std::vector<Wrapped<R>> result; - result.reserve(st->sample_size); + decltype(R::weight) total_weight = buffer_query->total_weight; + weights.push_back(total_weight); - if constexpr (Rejection) { - for (size_t i=0; i<st->sample_size; i++) { - auto idx = gsl_rng_uniform_int(p->rng, st->cutoff); - auto rec = buffer->get(idx); + for (auto &q : local_queries) { + total_weight += q->total_weight; + weights.push_back(q->total_weight); + q->sample_size = 0; + } - auto test = gsl_rng_uniform(p->rng) * st->max_weight; + std::vector<double> normalized_weights; + for (auto w : weights) { + normalized_weights.push_back((double)w / (double)total_weight); + } - if (test <= rec->rec.weight) { - result.emplace_back(*rec); - } - } - return result; - } + buffer_query->alias = std::make_unique<psudb::Alias>(normalized_weights); + } - for (size_t i=0; i<st->sample_size; i++) { - auto idx = st->alias->get(p->rng); - result.emplace_back(*(buffer->get_data() + idx)); - } + for (size_t i = 0; i < parms->sample_size; i++) { + auto idx = buffer_query->alias->get(parms->rng); - return result; + if (idx == 0) { + buffer_query->sample_size++; + } else { + local_queries[idx - 1]->sample_size++; + } } + } - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { - for (size_t i=0; i<results.size(); i++) { - for (size_t j=0; j<results[i].size(); j++) { - output.emplace_back(results[i][j].rec); - } - } + static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) { + std::vector<LocalResultType> result; - return output; + if (query->sample_size == 0) { + return result; } - static void delete_query_state(void *state) { - auto s = (State<R> *) state; - delete s; + for (size_t i = 0; i < query->sample_size; i++) { + size_t idx = shard->get_weighted_sample(query->global_parms.rng); + if (!shard->get_record_at(idx)->is_deleted()) { + result.emplace_back(*shard->get_record_at(idx)); + } } - static void delete_buffer_query_state(void *state) { - auto s = (BufferState<R> *) state; - delete s; + return result; + } + + static std::vector<LocalResultType> + local_query_buffer(LocalQueryBuffer *query) { + std::vector<LocalResultType> result; + + for (size_t i = 0; i < query->sample_size; i++) { + auto idx = gsl_rng_uniform_int(query->global_parms.rng, query->cutoff); + auto rec = query->buffer->get(idx); + + auto test = gsl_rng_uniform(query->global_parms.rng) * query->max_weight; + if (test <= rec->rec.weight && !rec->is_deleted()) { + result.emplace_back(*rec); + } } - static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { - auto p = (Parms<R> *) parms; + return result; + } - if (results.size() < p->sample_size) { - return true; - } - return false; + static void + combine(std::vector<std::vector<LocalResultType>> const &local_results, + Parameters *parms, std::vector<ResultType> &output) { + for (size_t i = 0; i < local_results.size(); i++) { + for (size_t j = 0; j < local_results[i].size(); j++) { + output.emplace_back(local_results[i][j].rec); + } + } + } + + static bool repeat(Parameters *parms, std::vector<ResultType> &output, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { + if (output.size() < parms->sample_size) { + parms->sample_size -= output.size(); + distribute_query(parms, local_queries, buffer_query); + return true; } -}; -}} + return false; + } +}; +} // namespace wss +} // namespace de |