diff options
Diffstat (limited to 'include/query')
| -rw-r--r-- | include/query/irs.h | 85 | ||||
| -rw-r--r-- | include/query/knn.h | 11 | ||||
| -rw-r--r-- | include/query/pointlookup.h | 123 | ||||
| -rw-r--r-- | include/query/rangecount.h | 60 | ||||
| -rw-r--r-- | include/query/rangequery.h | 9 | ||||
| -rw-r--r-- | include/query/wirs.h | 13 | ||||
| -rw-r--r-- | include/query/wss.h | 13 |
7 files changed, 246 insertions, 68 deletions
diff --git a/include/query/irs.h b/include/query/irs.h index e2d9325..879d070 100644 --- a/include/query/irs.h +++ b/include/query/irs.h @@ -40,7 +40,12 @@ struct BufferState { size_t sample_size; BufferView<R> *buffer; + psudb::Alias *alias; + BufferState(BufferView<R> *buffer) : buffer(buffer) {} + ~BufferState() { + delete alias; + } }; template <RecordInterface R, ShardInterface<R> S, bool Rejection=true> @@ -72,6 +77,7 @@ public: res->cutoff = res->buffer->get_record_count(); res->sample_size = 0; + res->alias = nullptr; if constexpr (Rejection) { return res; @@ -96,39 +102,51 @@ public: std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0); size_t buffer_sz = 0; - std::vector<size_t> weights; - if constexpr (Rejection) { - weights.push_back((bs) ? bs->cutoff : 0); - } else { - weights.push_back((bs) ? bs->records.size() : 0); + /* for simplicity of static structure testing */ + if (!bs) { + assert(shard_states.size() == 1); + auto state = (State<R> *) shard_states[0]; + state->sample_size = p->sample_size; + return; } - size_t total_weight = 0; - for (auto &s : shard_states) { - auto state = (State<R> *) s; - total_weight += state->total_weight; - weights.push_back(state->total_weight); - } + /* we only need to build the shard alias on the first call */ + if (bs->alias == nullptr) { + std::vector<size_t> weights; + if constexpr (Rejection) { + weights.push_back((bs) ? bs->cutoff : 0); + } else { + weights.push_back((bs) ? bs->records.size() : 0); + } - // if no valid records fall within the query range, just - // set all of the sample sizes to 0 and bail out. - if (total_weight == 0) { - for (size_t i=0; i<shard_states.size(); i++) { - auto state = (State<R> *) shard_states[i]; - state->sample_size = 0; + size_t total_weight = weights[0]; + for (auto &s : shard_states) { + auto state = (State<R> *) s; + total_weight += state->total_weight; + weights.push_back(state->total_weight); } - return; - } + // if no valid records fall within the query range, just + // set all of the sample sizes to 0 and bail out. + if (total_weight == 0) { + for (size_t i=0; i<shard_states.size(); i++) { + auto state = (State<R> *) shard_states[i]; + state->sample_size = 0; + } - std::vector<double> normalized_weights; - for (auto w : weights) { - normalized_weights.push_back((double) w / (double) total_weight); + return; + } + + std::vector<double> normalized_weights; + for (auto w : weights) { + normalized_weights.push_back((double) w / (double) total_weight); + } + + bs->alias = new psudb::Alias(normalized_weights); } - auto shard_alias = psudb::Alias(normalized_weights); for (size_t i=0; i<p->sample_size; i++) { - auto idx = shard_alias.get(p->rng); + auto idx = bs->alias->get(p->rng); if (idx == 0) { buffer_sz++; } else { @@ -198,16 +216,12 @@ public: return result; } - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) { - std::vector<R> output; - + static void merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { for (size_t i=0; i<results.size(); i++) { for (size_t j=0; j<results[i].size(); j++) { output.emplace_back(results[i][j].rec); } } - - return output; } static void delete_query_state(void *state) { @@ -219,5 +233,18 @@ public: auto s = (BufferState<R> *) state; delete s; } + + static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { + auto p = (Parms<R> *) parms; + + if (results.size() < p->sample_size) { + auto q = *p; + q.sample_size -= results.size(); + process_query_states(&q, states, buffer_state); + return true; + } + + return false; + } }; }} diff --git a/include/query/knn.h b/include/query/knn.h index 19dcf5c..a227293 100644 --- a/include/query/knn.h +++ b/include/query/knn.h @@ -111,10 +111,10 @@ public: pq.pop(); } - return results; + return std::move(results); } - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) { + static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { Parms<R> *p = (Parms<R> *) parms; R rec = p->point; size_t k = p->k; @@ -136,13 +136,12 @@ public: } } - std::vector<R> output; while (pq.size() > 0) { output.emplace_back(*pq.peek().data); pq.pop(); } - return output; + return std::move(output); } static void delete_query_state(void *state) { @@ -154,6 +153,10 @@ public: auto s = (BufferState<R> *) state; delete s; } + + static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { + return false; + } }; }} diff --git a/include/query/pointlookup.h b/include/query/pointlookup.h new file mode 100644 index 0000000..94c2bce --- /dev/null +++ b/include/query/pointlookup.h @@ -0,0 +1,123 @@ +/* + * include/query/pointlookup.h + * + * Copyright (C) 2024 Douglas B. Rumbaugh <drumbaugh@psu.edu> + * + * Distributed under the Modified BSD License. + * + * A query class for point lookup operations. + * + * TODO: Currently, this only supports point lookups for unique keys (which + * is the case for the trie that we're building this to use). It would be + * pretty straightforward to extend it to return *all* records that match + * the search_key (including tombstone cancellation--it's invertible) to + * support non-unique indexes, or at least those implementing + * lower_bound(). + */ +#pragma once + +#include "framework/QueryRequirements.h" + +namespace de { namespace pl { + +template <RecordInterface R> +struct Parms { + decltype(R::key) search_key; +}; + +template <RecordInterface R> +struct State { +}; + +template <RecordInterface R> +struct BufferState { + BufferView<R> *buffer; + + BufferState(BufferView<R> *buffer) + : buffer(buffer) {} +}; + +template <KVPInterface R, ShardInterface<R> S> +class Query { +public: + constexpr static bool EARLY_ABORT=true; + constexpr static bool SKIP_DELETE_FILTER=true; + + static void *get_query_state(S *shard, void *parms) { + return nullptr; + } + + static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) { + auto res = new BufferState<R>(buffer); + + return res; + } + + static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void* buffer_state) { + return; + } + + static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) { + auto p = (Parms<R> *) parms; + auto s = (State<R> *) q_state; + + std::vector<Wrapped<R>> result; + + auto r = shard->point_lookup({p->search_key, 0}); + + if (r) { + result.push_back(*r); + } + + return result; + } + + static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) { + auto p = (Parms<R> *) parms; + auto s = (BufferState<R> *) state; + + std::vector<Wrapped<R>> records; + for (size_t i=0; i<s->buffer->get_record_count(); i++) { + auto rec = s->buffer->get(i); + + if (rec->rec.key == p->search_key) { + records.push_back(*rec); + return records; + } + } + + return records; + } + + static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { + for (auto r : results) { + if (r.size() > 0) { + if (r[0].is_deleted() || r[0].is_tombstone()) { + return output; + } + + output.push_back(r[0].rec); + return output; + } + } + + return output; + } + + static void delete_query_state(void *state) { + auto s = (State<R> *) state; + delete s; + } + + static void delete_buffer_query_state(void *state) { + auto s = (BufferState<R> *) state; + delete s; + } + + + static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { + return false; + } +}; + +}} diff --git a/include/query/rangecount.h b/include/query/rangecount.h index 6c57809..5b95cdd 100644 --- a/include/query/rangecount.h +++ b/include/query/rangecount.h @@ -35,20 +35,14 @@ struct BufferState { : buffer(buffer) {} }; -template <KVPInterface R, ShardInterface<R> S> +template <KVPInterface R, ShardInterface<R> S, bool FORCE_SCAN=false> class Query { public: constexpr static bool EARLY_ABORT=false; constexpr static bool SKIP_DELETE_FILTER=true; static void *get_query_state(S *shard, void *parms) { - auto res = new State<R>(); - auto p = (Parms<R> *) parms; - - res->start_idx = shard->get_lower_bound(p->lower_bound); - res->stop_idx = shard->get_record_count(); - - return res; + return nullptr; } static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) { @@ -74,37 +68,43 @@ public: res.rec.value = 0; // tombstones records.emplace_back(res); + + auto start_idx = shard->get_lower_bound(p->lower_bound); + auto stop_idx = shard->get_lower_bound(p->upper_bound); + /* * if the returned index is one past the end of the * records for the PGM, then there are not records * in the index falling into the specified range. */ - if (s->start_idx == shard->get_record_count()) { + if (start_idx == shard->get_record_count()) { return records; } - auto ptr = shard->get_record_at(s->start_idx); /* * roll the pointer forward to the first record that is * greater than or equal to the lower bound. */ - while(ptr < shard->get_data() + s->stop_idx && ptr->rec.key < p->lower_bound) { - ptr++; + auto recs = shard->get_data(); + while(start_idx < stop_idx && recs[start_idx].rec.key < p->lower_bound) { + start_idx++; } - while (ptr < shard->get_data() + s->stop_idx && ptr->rec.key <= p->upper_bound) { - if (!ptr->is_deleted()) { - if (ptr->is_tombstone()) { - records[0].rec.value++; - } else { - records[0].rec.key++; - } - } + while (stop_idx < shard->get_record_count() && recs[stop_idx].rec.key <= p->upper_bound) { + stop_idx++; + } + size_t idx = start_idx; + size_t ts_cnt = 0; - ptr++; + while (idx < stop_idx) { + ts_cnt += recs[idx].is_tombstone() * 2 + recs[idx].is_deleted(); + idx++; } + records[0].rec.key = idx - start_idx; + records[0].rec.value = ts_cnt; + return records; } @@ -119,8 +119,16 @@ public: res.rec.value = 0; // tombstones records.emplace_back(res); + size_t stop_idx; + if constexpr (FORCE_SCAN) { + stop_idx = s->buffer->get_capacity() / 2; + } else { + stop_idx = s->buffer->get_record_count(); + } + for (size_t i=0; i<s->buffer->get_record_count(); i++) { auto rec = s->buffer->get(i); + if (rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound && !rec->is_deleted()) { if (rec->is_tombstone()) { @@ -134,12 +142,10 @@ public: return records; } - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) { - + static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { R res; res.key = 0; res.value = 0; - std::vector<R> output; output.emplace_back(res); for (size_t i=0; i<results.size(); i++) { @@ -152,14 +158,16 @@ public: } static void delete_query_state(void *state) { - auto s = (State<R> *) state; - delete s; } static void delete_buffer_query_state(void *state) { auto s = (BufferState<R> *) state; delete s; } + + static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { + return false; + } }; }} diff --git a/include/query/rangequery.h b/include/query/rangequery.h index 24b38ec..e0690e6 100644 --- a/include/query/rangequery.h +++ b/include/query/rangequery.h @@ -109,7 +109,7 @@ public: return records; } - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) { + static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { std::vector<Cursor<Wrapped<R>>> cursors; cursors.reserve(results.size()); @@ -121,7 +121,7 @@ public: for (size_t i = 0; i < tmp_n; ++i) if (results[i].size() > 0){ auto base = results[i].data(); - cursors.emplace_back(Cursor{base, base + results[i].size(), 0, results[i].size()}); + cursors.emplace_back(Cursor<Wrapped<R>>{base, base + results[i].size(), 0, results[i].size()}); assert(i == cursors.size() - 1); total += results[i].size(); pq.push(cursors[i].ptr, tmp_n - i - 1); @@ -133,7 +133,6 @@ public: return std::vector<R>(); } - std::vector<R> output; output.reserve(total); while (pq.size()) { @@ -169,6 +168,10 @@ public: auto s = (BufferState<R> *) state; delete s; } + + static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { + return false; + } }; }} diff --git a/include/query/wirs.h b/include/query/wirs.h index ae82194..62b43f6 100644 --- a/include/query/wirs.h +++ b/include/query/wirs.h @@ -219,9 +219,7 @@ public: return result; } - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) { - std::vector<R> output; - + static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { for (size_t i=0; i<results.size(); i++) { for (size_t j=0; j<results[i].size(); j++) { output.emplace_back(results[i][j].rec); @@ -240,5 +238,14 @@ public: auto s = (BufferState<R> *) state; delete s; } + + static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { + auto p = (Parms<R> *) parms; + + if (results.size() < p->sample_size) { + return true; + } + return false; + } }; }} diff --git a/include/query/wss.h b/include/query/wss.h index 8797035..fb0b414 100644 --- a/include/query/wss.h +++ b/include/query/wss.h @@ -183,9 +183,7 @@ public: return result; } - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) { - std::vector<R> output; - + static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { for (size_t i=0; i<results.size(); i++) { for (size_t j=0; j<results[i].size(); j++) { output.emplace_back(results[i][j].rec); @@ -204,6 +202,15 @@ public: auto s = (BufferState<R> *) state; delete s; } + + static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { + auto p = (Parms<R> *) parms; + + if (results.size() < p->sample_size) { + return true; + } + return false; + } }; }} |