From 438feac7e56fee425d9c6f1a43298ff9dc5b71d1 Mon Sep 17 00:00:00 2001 From: Douglas Rumbaugh Date: Fri, 19 Apr 2024 17:38:16 -0400 Subject: Properly implemented support for iteratively decomposable problems --- include/query/irs.h | 85 +++++++++++++++++++++++++++++---------------- include/query/knn.h | 7 ++-- include/query/pointlookup.h | 8 +++-- include/query/rangecount.h | 8 +++-- include/query/rangequery.h | 7 ++-- include/query/wirs.h | 13 +++++-- include/query/wss.h | 13 +++++-- 7 files changed, 97 insertions(+), 44 deletions(-) (limited to 'include/query') diff --git a/include/query/irs.h b/include/query/irs.h index 51eb4e2..879d070 100644 --- a/include/query/irs.h +++ b/include/query/irs.h @@ -40,7 +40,12 @@ struct BufferState { size_t sample_size; BufferView *buffer; + psudb::Alias *alias; + BufferState(BufferView *buffer) : buffer(buffer) {} + ~BufferState() { + delete alias; + } }; template S, bool Rejection=true> @@ -72,6 +77,7 @@ public: res->cutoff = res->buffer->get_record_count(); res->sample_size = 0; + res->alias = nullptr; if constexpr (Rejection) { return res; @@ -96,39 +102,51 @@ public: std::vector shard_sample_sizes(shard_states.size()+1, 0); size_t buffer_sz = 0; - std::vector weights; - if constexpr (Rejection) { - weights.push_back((bs) ? bs->cutoff : 0); - } else { - weights.push_back((bs) ? bs->records.size() : 0); + /* for simplicity of static structure testing */ + if (!bs) { + assert(shard_states.size() == 1); + auto state = (State *) shard_states[0]; + state->sample_size = p->sample_size; + return; } - size_t total_weight = weights[0]; - for (auto &s : shard_states) { - auto state = (State *) s; - total_weight += state->total_weight; - weights.push_back(state->total_weight); - } + /* we only need to build the shard alias on the first call */ + if (bs->alias == nullptr) { + std::vector weights; + if constexpr (Rejection) { + weights.push_back((bs) ? bs->cutoff : 0); + } else { + weights.push_back((bs) ? bs->records.size() : 0); + } - // if no valid records fall within the query range, just - // set all of the sample sizes to 0 and bail out. - if (total_weight == 0) { - for (size_t i=0; i *) shard_states[i]; - state->sample_size = 0; + size_t total_weight = weights[0]; + for (auto &s : shard_states) { + auto state = (State *) s; + total_weight += state->total_weight; + weights.push_back(state->total_weight); } - return; - } + // if no valid records fall within the query range, just + // set all of the sample sizes to 0 and bail out. + if (total_weight == 0) { + for (size_t i=0; i *) shard_states[i]; + state->sample_size = 0; + } - std::vector normalized_weights; - for (auto w : weights) { - normalized_weights.push_back((double) w / (double) total_weight); + return; + } + + std::vector normalized_weights; + for (auto w : weights) { + normalized_weights.push_back((double) w / (double) total_weight); + } + + bs->alias = new psudb::Alias(normalized_weights); } - auto shard_alias = psudb::Alias(normalized_weights); for (size_t i=0; isample_size; i++) { - auto idx = shard_alias.get(p->rng); + auto idx = bs->alias->get(p->rng); if (idx == 0) { buffer_sz++; } else { @@ -198,16 +216,12 @@ public: return result; } - static std::vector merge(std::vector>> &results, void *parms) { - std::vector output; - + static void merge(std::vector>> &results, void *parms, std::vector &output) { for (size_t i=0; i *) state; delete s; } + + static bool repeat(void *parms, std::vector &results, std::vector states, void* buffer_state) { + auto p = (Parms *) parms; + + if (results.size() < p->sample_size) { + auto q = *p; + q.sample_size -= results.size(); + process_query_states(&q, states, buffer_state); + return true; + } + + return false; + } }; }} diff --git a/include/query/knn.h b/include/query/knn.h index 19dcf5c..c856a74 100644 --- a/include/query/knn.h +++ b/include/query/knn.h @@ -114,7 +114,7 @@ public: return results; } - static std::vector merge(std::vector>> &results, void *parms) { + static std::vector merge(std::vector>> &results, void *parms, std::vector &output) { Parms *p = (Parms *) parms; R rec = p->point; size_t k = p->k; @@ -136,7 +136,6 @@ public: } } - std::vector output; while (pq.size() > 0) { output.emplace_back(*pq.peek().data); pq.pop(); @@ -154,6 +153,10 @@ public: auto s = (BufferState *) state; delete s; } + + static bool repeat(void *parms, std::vector &results, std::vector states, void* buffer_state) { + return false; + } }; }} diff --git a/include/query/pointlookup.h b/include/query/pointlookup.h index 35d38e3..94c2bce 100644 --- a/include/query/pointlookup.h +++ b/include/query/pointlookup.h @@ -89,8 +89,7 @@ public: return records; } - static std::vector merge(std::vector>> &results, void *parms) { - std::vector output; + static std::vector merge(std::vector>> &results, void *parms, std::vector &output) { for (auto r : results) { if (r.size() > 0) { if (r[0].is_deleted() || r[0].is_tombstone()) { @@ -114,6 +113,11 @@ public: auto s = (BufferState *) state; delete s; } + + + static bool repeat(void *parms, std::vector &results, std::vector states, void* buffer_state) { + return false; + } }; }} diff --git a/include/query/rangecount.h b/include/query/rangecount.h index 6c57809..c20feaa 100644 --- a/include/query/rangecount.h +++ b/include/query/rangecount.h @@ -134,12 +134,10 @@ public: return records; } - static std::vector merge(std::vector>> &results, void *parms) { - + static std::vector merge(std::vector>> &results, void *parms, std::vector &output) { R res; res.key = 0; res.value = 0; - std::vector output; output.emplace_back(res); for (size_t i=0; i *) state; delete s; } + + static bool repeat(void *parms, std::vector &results, std::vector states, void* buffer_state) { + return false; + } }; }} diff --git a/include/query/rangequery.h b/include/query/rangequery.h index e6ab581..e0690e6 100644 --- a/include/query/rangequery.h +++ b/include/query/rangequery.h @@ -109,7 +109,7 @@ public: return records; } - static std::vector merge(std::vector>> &results, void *parms) { + static std::vector merge(std::vector>> &results, void *parms, std::vector &output) { std::vector>> cursors; cursors.reserve(results.size()); @@ -133,7 +133,6 @@ public: return std::vector(); } - std::vector output; output.reserve(total); while (pq.size()) { @@ -169,6 +168,10 @@ public: auto s = (BufferState *) state; delete s; } + + static bool repeat(void *parms, std::vector &results, std::vector states, void* buffer_state) { + return false; + } }; }} diff --git a/include/query/wirs.h b/include/query/wirs.h index ae82194..62b43f6 100644 --- a/include/query/wirs.h +++ b/include/query/wirs.h @@ -219,9 +219,7 @@ public: return result; } - static std::vector merge(std::vector>> &results, void *parms) { - std::vector output; - + static std::vector merge(std::vector>> &results, void *parms, std::vector &output) { for (size_t i=0; i *) state; delete s; } + + static bool repeat(void *parms, std::vector &results, std::vector states, void* buffer_state) { + auto p = (Parms *) parms; + + if (results.size() < p->sample_size) { + return true; + } + return false; + } }; }} diff --git a/include/query/wss.h b/include/query/wss.h index 8797035..fb0b414 100644 --- a/include/query/wss.h +++ b/include/query/wss.h @@ -183,9 +183,7 @@ public: return result; } - static std::vector merge(std::vector>> &results, void *parms) { - std::vector output; - + static std::vector merge(std::vector>> &results, void *parms, std::vector &output) { for (size_t i=0; i *) state; delete s; } + + static bool repeat(void *parms, std::vector &results, std::vector states, void* buffer_state) { + auto p = (Parms *) parms; + + if (results.size() < p->sample_size) { + return true; + } + return false; + } }; }} -- cgit v1.2.3