From 9fe305c7d28e993e55c55427f377ae7e3251ea4f Mon Sep 17 00:00:00 2001 From: "Douglas B. Rumbaugh" Date: Fri, 6 Dec 2024 13:13:51 -0500 Subject: Interface update (#5) * Query Interface Adjustments/Refactoring Began the process of adjusting the query interface (and also the shard interface, to a lesser degree) to better accommodate the user. In particular the following changes have been made, 1. The number of necessary template arguments for the query type has been drastically reduced, while also removing the void pointers and manual delete functions from the interface. This was accomplished by requiring many of the sub-types associated with a query (parameters, etc.) to be nested inside the main query class, and by forcing the SHARD type to expose its associated record type. 2. User-defined query return types are now supported. Queries no longer are required to return strictly sets of records. Instead, the query now has LocalResultType and ResultType template parameters (which can be defaulted using a typedef in the Query type itself), allowing much more flexibility. Note that, at least for the short term, the LocalResultType must still expose the same is_deleted/is_tombstone interface as a Wrapped used to, as this is currently needed for delete filtering. A better approach to this is, hopefully, forthcoming. 3. Updated the ISAMTree.h shard and rangequery.h query to use the new interfaces, and adjusted the associated unit tests as well. 4. Dropped the unnecessary "get_data()" function from the ShardInterface concept. 5. Dropped the need to specify a record type in the ShardInterface concept. This is now handled using a required Shard::RECORD member of the Shard class itself, which should expose the name of the record type. * Updates to framework to support new Query/Shard interfaces Pretty extensive adjustments to the framework, particularly to the templates themselves, along with some type-renaming work, to support the new query and shard interfaces. Adjusted the external query interface to take an rvalue reference, rather than a pointer, to the query parameters. * Removed framework-level delete filtering This was causing some issues with the new query interface, and should probably be reworked anyway, so I'm temporarily (TM) removing the feature. * Updated benchmarks + remaining code for new interface --- include/query/irs.h | 360 +++++++++++++++++++++++++--------------------------- 1 file changed, 175 insertions(+), 185 deletions(-) (limited to 'include/query/irs.h') diff --git a/include/query/irs.h b/include/query/irs.h index 879d070..6dec850 100644 --- a/include/query/irs.h +++ b/include/query/irs.h @@ -1,12 +1,12 @@ /* * include/query/irs.h * - * Copyright (C) 2023 Douglas B. Rumbaugh + * Copyright (C) 2023 Douglas B. Rumbaugh * * Distributed under the Modified BSD License. * - * A query class for independent range sampling. This query requires - * that the shard support get_lower_bound(key), get_upper_bound(key), + * A query class for independent range sampling. This query requires + * that the shard support get_lower_bound(key), get_upper_bound(key), * and get_record_at(index). */ #pragma once @@ -14,237 +14,227 @@ #include "framework/QueryRequirements.h" #include "psu-ds/Alias.h" -namespace de { namespace irs { +namespace de { +namespace irs { -template -struct Parms { +template class Query { + typedef typename S::RECORD R; + +public: + struct Parameters { decltype(R::key) lower_bound; decltype(R::key) upper_bound; size_t sample_size; gsl_rng *rng; -}; + }; - -template -struct State { - size_t lower_bound; - size_t upper_bound; - size_t sample_size; + struct LocalQuery { + size_t lower_idx; + size_t upper_idx; size_t total_weight; -}; + size_t sample_size; + Parameters global_parms; + }; + + struct LocalQueryBuffer { + BufferView *buffer; -template -struct BufferState { size_t cutoff; std::vector> records; + std::unique_ptr alias; size_t sample_size; - BufferView *buffer; - psudb::Alias *alias; + Parameters global_parms; + }; - BufferState(BufferView *buffer) : buffer(buffer) {} - ~BufferState() { - delete alias; - } -}; + typedef Wrapped LocalResultType; + typedef R ResultType; -template S, bool Rejection=true> -class Query { -public: - constexpr static bool EARLY_ABORT=false; - constexpr static bool SKIP_DELETE_FILTER=false; + constexpr static bool EARLY_ABORT = false; + constexpr static bool SKIP_DELETE_FILTER = false; - static void *get_query_state(S *shard, void *parms) { - auto res = new State(); - decltype(R::key) lower_key = ((Parms *) parms)->lower_bound; - decltype(R::key) upper_key = ((Parms *) parms)->upper_bound; + static LocalQuery *local_preproc(S *shard, Parameters *parms) { + auto query = new LocalQuery(); - res->lower_bound = shard->get_lower_bound(lower_key); - res->upper_bound = shard->get_upper_bound(upper_key); + query->global_parms = *parms; - if (res->lower_bound == shard->get_record_count()) { - res->total_weight = 0; - } else { - res->total_weight = res->upper_bound - res->lower_bound; - } + query->lower_idx = shard->get_lower_bound(query->global_parms.lower_bound); + query->upper_idx = shard->get_upper_bound(query->global_parms.upper_bound); - res->sample_size = 0; - return res; + if (query->lower_idx == shard->get_record_count()) { + query->total_weight = 0; + } else { + query->total_weight = query->upper_idx - query->lower_idx; } - static void* get_buffer_query_state(BufferView *buffer, void *parms) { - auto res = new BufferState(buffer); - - res->cutoff = res->buffer->get_record_count(); - res->sample_size = 0; - res->alias = nullptr; + query->sample_size = 0; + return query; + } - if constexpr (Rejection) { - return res; - } - - auto lower_key = ((Parms *) parms)->lower_bound; - auto upper_key = ((Parms *) parms)->upper_bound; + static LocalQueryBuffer *local_preproc_buffer(BufferView *buffer, + Parameters *parms) { + auto query = new LocalQueryBuffer(); + query->buffer = buffer; - for (size_t i=0; icutoff; i++) { - if ((res->buffer->get(i)->rec.key >= lower_key) && (buffer->get(i)->rec.key <= upper_key)) { - res->records.emplace_back(*(res->buffer->get(i))); - } - } + query->cutoff = query->buffer->get_record_count(); + query->sample_size = 0; + query->alias = nullptr; + query->global_parms = *parms; - return res; + if constexpr (REJECTION) { + return query; } - static void process_query_states(void *query_parms, std::vector &shard_states, void *buffer_state) { - auto p = (Parms *) query_parms; - auto bs = (buffer_state) ? (BufferState *) buffer_state : nullptr; - - std::vector shard_sample_sizes(shard_states.size()+1, 0); - size_t buffer_sz = 0; + for (size_t i = 0; i < query->cutoff; i++) { + if ((query->buffer->get(i)->rec.key >= query->global_parms.lower_bound) && + (buffer->get(i)->rec.key <= query->global_parms.upper_bound)) { + query->records.emplace_back(*(query->buffer->get(i))); + } + } - /* for simplicity of static structure testing */ - if (!bs) { - assert(shard_states.size() == 1); - auto state = (State *) shard_states[0]; - state->sample_size = p->sample_size; - return; - } + return query; + } - /* we only need to build the shard alias on the first call */ - if (bs->alias == nullptr) { - std::vector weights; - if constexpr (Rejection) { - weights.push_back((bs) ? bs->cutoff : 0); - } else { - weights.push_back((bs) ? bs->records.size() : 0); - } - - size_t total_weight = weights[0]; - for (auto &s : shard_states) { - auto state = (State *) s; - total_weight += state->total_weight; - weights.push_back(state->total_weight); - } - - // if no valid records fall within the query range, just - // set all of the sample sizes to 0 and bail out. - if (total_weight == 0) { - for (size_t i=0; i *) shard_states[i]; - state->sample_size = 0; - } - - return; - } - - std::vector normalized_weights; - for (auto w : weights) { - normalized_weights.push_back((double) w / (double) total_weight); - } - - bs->alias = new psudb::Alias(normalized_weights); - } + static void distribute_query(Parameters *parms, + std::vector const &local_queries, + LocalQueryBuffer *buffer_query) { - for (size_t i=0; isample_size; i++) { - auto idx = bs->alias->get(p->rng); - if (idx == 0) { - buffer_sz++; - } else { - shard_sample_sizes[idx - 1]++; - } - } + std::vector shard_sample_sizes(local_queries.size() + 1, 0); + size_t buffer_sz = 0; - if (bs) { - bs->sample_size = buffer_sz; - } - for (size_t i=0; i *) shard_states[i]; - state->sample_size = shard_sample_sizes[i+1]; - } + /* for simplicity of static structure testing */ + if (!buffer_query) { + assert(local_queries.size() == 1); + local_queries[0]->sample_size = + local_queries[0]->global_parms.sample_size; + return; } - static std::vector> query(S *shard, void *q_state, void *parms) { - auto lower_key = ((Parms *) parms)->lower_bound; - auto upper_key = ((Parms *) parms)->upper_bound; - auto rng = ((Parms *) parms)->rng; - - auto state = (State *) q_state; - auto sample_sz = state->sample_size; - - std::vector> result_set; - - if (sample_sz == 0 || state->lower_bound == shard->get_record_count()) { - return result_set; + /* we only need to build the shard alias on the first call */ + if (buffer_query->alias == nullptr) { + std::vector weights; + if constexpr (REJECTION) { + weights.push_back(buffer_query->cutoff); + } else { + weights.push_back(buffer_query->records.size()); + } + + size_t total_weight = weights[0]; + for (auto &q : local_queries) { + total_weight += q->total_weight; + weights.push_back(q->total_weight); + } + + /* + * if no valid records fall within the query range, + * set all of the sample sizes to 0 and bail out. + */ + if (total_weight == 0) { + for (auto q : local_queries) { + q->sample_size = 0; } - size_t attempts = 0; - size_t range_length = state->upper_bound - state->lower_bound; - do { - attempts++; - size_t idx = (range_length > 0) ? gsl_rng_uniform_int(rng, range_length) : 0; - result_set.emplace_back(*shard->get_record_at(state->lower_bound + idx)); - } while (attempts < sample_sz); + return; + } - return result_set; - } + std::vector normalized_weights; + for (auto w : weights) { + normalized_weights.push_back((double)w / (double)total_weight); + } - static std::vector> buffer_query(void *state, void *parms) { - auto st = (BufferState *) state; - auto p = (Parms *) parms; + buffer_query->alias = std::make_unique(normalized_weights); + } - std::vector> result; - result.reserve(st->sample_size); + for (size_t i = 0; i < parms->sample_size; i++) { + auto idx = buffer_query->alias->get(parms->rng); + if (idx == 0) { + buffer_sz++; + } else { + shard_sample_sizes[idx - 1]++; + } + } - if constexpr (Rejection) { - for (size_t i=0; isample_size; i++) { - auto idx = gsl_rng_uniform_int(p->rng, st->cutoff); - auto rec = st->buffer->get(idx); + if (buffer_query) { + buffer_query->sample_size = buffer_sz; + } - if (rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) { - result.emplace_back(*rec); - } - } + for (size_t i = 0; i < local_queries.size(); i++) { + local_queries[i]->sample_size = shard_sample_sizes[i]; + } + } - return result; - } + static std::vector local_query(S *shard, LocalQuery *query) { + auto sample_sz = query->sample_size; - for (size_t i=0; isample_size; i++) { - auto idx = gsl_rng_uniform_int(p->rng, st->records.size()); - result.emplace_back(st->records[idx]); - } + std::vector result_set; - return result; + if (sample_sz == 0 || query->lower_idx == shard->get_record_count()) { + return result_set; } - static void merge(std::vector>> &results, void *parms, std::vector &output) { - for (size_t i=0; iupper_idx - query->lower_idx; + do { + attempts++; + size_t idx = + (range_length > 0) + ? gsl_rng_uniform_int(query->global_parms.rng, range_length) + : 0; + result_set.emplace_back(*shard->get_record_at(query->lower_idx + idx)); + } while (attempts < sample_sz); + + return result_set; + } + + static std::vector + local_query_buffer(LocalQueryBuffer *query) { + std::vector result; + result.reserve(query->sample_size); + + if constexpr (REJECTION) { + for (size_t i = 0; i < query->sample_size; i++) { + auto idx = gsl_rng_uniform_int(query->global_parms.rng, query->cutoff); + auto rec = query->buffer->get(idx); + + if (rec->rec.key >= query->global_parms.lower_bound && + rec->rec.key <= query->global_parms.upper_bound) { + result.emplace_back(*rec); } - } + } - static void delete_query_state(void *state) { - auto s = (State *) state; - delete s; + return result; } - static void delete_buffer_query_state(void *state) { - auto s = (BufferState *) state; - delete s; + for (size_t i = 0; i < query->sample_size; i++) { + auto idx = + gsl_rng_uniform_int(query->global_parms.rng, query->records.size()); + result.emplace_back(query->records[idx]); } - static bool repeat(void *parms, std::vector &results, std::vector states, void* buffer_state) { - auto p = (Parms *) parms; - - if (results.size() < p->sample_size) { - auto q = *p; - q.sample_size -= results.size(); - process_query_states(&q, states, buffer_state); - return true; - } + return result; + } - return false; + static void + combine(std::vector> const &local_results, + Parameters *parms, std::vector &output) { + for (size_t i = 0; i < local_results.size(); i++) { + for (size_t j = 0; j < local_results[i].size(); j++) { + output.emplace_back(local_results[i][j].rec); + } } + } + + static bool repeat(Parameters *parms, std::vector &output, + std::vector const &local_queries, + LocalQueryBuffer *buffer_query) { + if (output.size() < parms->sample_size) { + parms->sample_size -= output.size(); + distribute_query(parms, local_queries, buffer_query); + return true; + } + + return false; + } }; -}} +} // namespace irs +} // namespace de -- cgit v1.2.3