From 1a791e7241fb9898f58cd4642cf8cf8ec2a1c885 Mon Sep 17 00:00:00 2001 From: Douglas Rumbaugh Date: Wed, 7 Jun 2023 11:39:25 -0400 Subject: Added a pre-query hook for processing states This is used for setting up the query alias structure stuff for sampling queries. --- include/framework/DynamicExtension.h | 4 +- include/shard/MemISAM.h | 46 +++++++++++++++++++++++ include/shard/PGM.h | 4 ++ include/shard/TrieSpline.h | 4 ++ include/shard/WIRS.h | 72 ++++++++++++++++++++++++++++++------ include/shard/WSS.h | 71 +++++++++++++++++++++++++++-------- 6 files changed, 172 insertions(+), 29 deletions(-) diff --git a/include/framework/DynamicExtension.h b/include/framework/DynamicExtension.h index 4f3a3bc..a345da6 100644 --- a/include/framework/DynamicExtension.h +++ b/include/framework/DynamicExtension.h @@ -56,7 +56,7 @@ static constexpr bool LSM_REJ_SAMPLE = false; // True for leveling, false for tiering static constexpr bool LSM_LEVELING = false; -static constexpr bool DELETE_TAGGING = false; +static constexpr bool DELETE_TAGGING = true; // TODO: Replace the constexpr bools above // with template parameters based on these @@ -142,6 +142,8 @@ public: level->get_query_states(shards, states, parms); } + Q::process_query_states(parms, states, buffer_state); + std::vector> query_results(shards.size() + 1); // Execute the query for the buffer diff --git a/include/shard/MemISAM.h b/include/shard/MemISAM.h index 01a539a..ae1c682 100644 --- a/include/shard/MemISAM.h +++ b/include/shard/MemISAM.h @@ -39,12 +39,14 @@ template struct IRSState { size_t lower_bound; size_t upper_bound; + size_t sample_size; }; template struct IRSBufferState { size_t cutoff; std::vector> records; + size_t sample_size; }; @@ -384,6 +386,50 @@ public: return res; } + static void process_query_states(void *query_parms, std::vector shard_states, void *buff_state) { + auto p = (irs_query_parms *) query_parms; + auto bs = (IRSBufferState *) buff_state; + + std::vector shard_sample_sizes = {0}; + size_t buffer_sz = 0; + + std::vector weights; + if (Rejection) { + weights.push_back(bs->cutoff); + } else { + weights.push_back(bs->records.size()); + } + + decltype(R::weight) total_weight; + for (auto &s : shard_states) { + auto state = (IRSState *) s; + total_weight += state->upper_bound - state->lower_bound; + weights.push_back(state->total_weight); + } + + std::vector normalized_weights; + for (auto w : weights) { + normalized_weights.push_back((double) w / (double) total_weight); + } + + auto shard_alias = Alias(normalized_weights); + for (size_t i=0; isample_size; i++) { + auto idx = shard_alias.get(p->rng); + if (idx == 0) { + buffer_sz++; + } else { + shard_sample_sizes[idx - 1]++; + } + } + + + bs->sample_size = buffer_sz; + size_t i=1; + for (auto &s : shard_states) { + auto state = (IRSState *) s; + state->sample_size = shard_sample_sizes[i++]; + } + } static std::vector> query(MemISAM *isam, void *q_state, void *parms) { auto sample_sz = ((irs_query_parms *) parms)->sample_size; auto lower_key = ((irs_query_parms *) parms)->lower_bound; diff --git a/include/shard/PGM.h b/include/shard/PGM.h index f9e1dad..8b0bd69 100644 --- a/include/shard/PGM.h +++ b/include/shard/PGM.h @@ -286,6 +286,10 @@ public: return res; } + static void process_query_states(void *query_parms, std::vector shard_states, void *buff_state) { + return; + } + static std::vector> query(PGM *ts, void *q_state, void *parms) { std::vector> records; auto p = (pgm_range_query_parms *) parms; diff --git a/include/shard/TrieSpline.h b/include/shard/TrieSpline.h index fb0ed70..2341751 100644 --- a/include/shard/TrieSpline.h +++ b/include/shard/TrieSpline.h @@ -305,6 +305,10 @@ public: return res; } + static void process_query_states(void *query_parms, std::vector shard_states, void *buff_state) { + return; + } + static std::vector> query(TrieSpline *ts, void *q_state, void *parms) { std::vector> records; auto p = (ts_range_query_parms *) parms; diff --git a/include/shard/WIRS.h b/include/shard/WIRS.h index f3696a4..619c2fe 100644 --- a/include/shard/WIRS.h +++ b/include/shard/WIRS.h @@ -51,12 +51,13 @@ struct wirs_node { template struct WIRSState { - decltype(R::weight) tot_weight; + decltype(R::weight) total_weight; std::vector*> nodes; Alias* top_level_alias; + size_t sample_size; WIRSState() { - tot_weight = 0; + total_weight = 0; top_level_alias = nullptr; } @@ -71,6 +72,8 @@ struct WIRSBufferState { Alias* alias; std::vector> records; decltype(R::weight) max_weight; + size_t sample_size; + decltype(R::weight) total_weight; ~WIRSBufferState() { delete alias; @@ -367,7 +370,7 @@ public: decltype(R::key) upper_key = ((wirs_query_parms *) parms)->upper_bound; // Simulate a stack to unfold recursion. - double tot_weight = 0.0; + double total_weight = 0.0; struct wirs_node* st[64] = {0}; st[0] = wirs->m_root; size_t top = 1; @@ -376,7 +379,7 @@ public: if (wirs->covered_by(now, lower_key, upper_key) || (now->left == nullptr && now->right == nullptr && wirs->intersects(now, lower_key, upper_key))) { res->nodes.emplace_back(now); - tot_weight += now->weight; + total_weight += now->weight; } else { if (now->left && wirs->intersects(now->left, lower_key, upper_key)) st[top++] = now->left; if (now->right && wirs->intersects(now->right, lower_key, upper_key)) st[top++] = now->right; @@ -385,9 +388,9 @@ public: std::vector weights; for (const auto& node: res->nodes) { - weights.emplace_back(node->weight / tot_weight); + weights.emplace_back(node->weight / total_weight); } - res->tot_weight = tot_weight; + res->total_weight = total_weight; res->top_level_alias = new Alias(weights); return res; @@ -399,13 +402,14 @@ public: if constexpr (Rejection) { state->cutoff = buffer->get_record_count() - 1; state->max_weight = buffer->get_max_weight(); + state->total_weight = buffer->get_total_weight(); return state; } std::vector weights; state->cutoff = buffer->get_record_count() - 1; - double tot_weight = 0.0; + double total_weight = 0.0; for (size_t i = 0; i <= state->cutoff; i++) { auto rec = buffer->get_data() + i; @@ -413,21 +417,65 @@ public: if (rec->rec.key >= parameters->lower_bound && rec->rec.key <= parameters->upper_bound && !rec->is_tombstone() && !rec->is_deleted()) { weights.push_back(rec->rec.weight); state->records.push_back(*rec); - tot_weight += rec->rec.weight; + total_weight += rec->rec.weight; } } for (size_t i = 0; i < weights.size(); i++) { - weights[i] = weights[i] / tot_weight; + weights[i] = weights[i] / total_weight; } + state->total_weight = total_weight; state->alias = new Alias(weights); return state; } + static void process_query_states(void *query_parms, std::vector shard_states, void *buff_state) { + auto p = (wirs_query_parms *) query_parms; + auto bs = (WIRSBufferState *) buff_state; + + std::vector shard_sample_sizes = {0}; + size_t buffer_sz = 0; + + std::vector weights; + weights.push_back(bs->total_weight); + + decltype(R::weight) total_weight; + for (auto &s : shard_states) { + auto state = (WIRSState *) s; + total_weight += state->total_weight; + weights.push_back(state->total_weight); + } + + std::vector normalized_weights; + for (auto w : weights) { + normalized_weights.push_back((double) w / (double) total_weight); + } + + auto shard_alias = Alias(normalized_weights); + for (size_t i=0; isample_size; i++) { + auto idx = shard_alias.get(p->rng); + if (idx == 0) { + buffer_sz++; + } else { + shard_sample_sizes[idx - 1]++; + } + } + + + bs->sample_size = buffer_sz; + size_t i=1; + for (auto &s : shard_states) { + auto state = (WIRSState *) s; + state->sample_size = shard_sample_sizes[i++]; + } + } + + + static std::vector> query(WIRS *wirs, void *q_state, void *parms) { - auto sample_sz = ((wirs_query_parms *) parms)->sample_size; + auto sample_size = ((wirs_query_parms *) parms)->sample_size; auto lower_key = ((wirs_query_parms *) parms)->lower_bound; auto upper_key = ((wirs_query_parms *) parms)->upper_bound; auto rng = ((wirs_query_parms *) parms)->rng; @@ -436,7 +484,7 @@ public: std::vector> result_set; - if (sample_sz == 0) { + if (sample_size == 0) { return result_set; } // k -> sampling: three levels. 1. select a node -> select a fat point -> select a record. @@ -459,7 +507,7 @@ public: result_set.emplace_back(*record); cnt++; - } while (attempts < sample_sz); + } while (attempts < sample_size); return result_set; } diff --git a/include/shard/WSS.h b/include/shard/WSS.h index bb7ee2a..1069897 100644 --- a/include/shard/WSS.h +++ b/include/shard/WSS.h @@ -41,18 +41,21 @@ class WSSQuery; template struct WSSState { - decltype(R::weight) tot_weight; + decltype(R::weight) total_weight; + size_t sample_size; WSSState() { - tot_weight = 0; + total_weight = 0; } }; template struct WSSBufferState { size_t cutoff; + size_t sample_size; Alias* alias; decltype(R::weight) max_weight; + decltype(R::weight) total_weight; ~WSSBufferState() { delete alias; @@ -296,16 +299,16 @@ public: std::vector weights; state->cutoff = buffer->get_record_count() - 1; - double tot_weight = 0.0; + double total_weight = 0.0; for (size_t i = 0; i <= state->cutoff; i++) { auto rec = buffer->get_data() + i; weights.push_back(rec->rec.weight); - tot_weight += rec->rec.weight; + total_weight += rec->rec.weight; } for (size_t i = 0; i < weights.size(); i++) { - weights[i] = weights[i] / tot_weight; + weights[i] = weights[i] / total_weight; } state->alias = new Alias(weights); @@ -313,15 +316,56 @@ public: return state; } + static void process_query_states(void *query_parms, std::vector shard_states, void *buff_state) { + auto p = (wss_query_parms *) query_parms; + auto bs = (WSSBufferState *) buff_state; + + std::vector shard_sample_sizes = {0}; + size_t buffer_sz = 0; + + std::vector weights; + weights.push_back(bs->total_weight); + + decltype(R::weight) total_weight; + for (auto &s : shard_states) { + auto state = (WSSState *) s; + total_weight += state->total_weight; + weights.push_back(state->total_weight); + } + + std::vector normalized_weights; + for (auto w : weights) { + normalized_weights.push_back((double) w / (double) total_weight); + } + + auto shard_alias = Alias(normalized_weights); + for (size_t i=0; isample_size; i++) { + auto idx = shard_alias.get(p->rng); + if (idx == 0) { + buffer_sz++; + } else { + shard_sample_sizes[idx - 1]++; + } + } + + + bs->sample_size = buffer_sz; + size_t i=1; + for (auto &s : shard_states) { + auto state = (WSSState *) s; + state->sample_size = shard_sample_sizes[i++]; + } + } + static std::vector> query(WSS *wss, void *q_state, void *parms) { - auto sample_sz = ((wss_query_parms *) parms)->sample_size; + auto sample_size = ((WSSState *) q_state)->sample_size; auto rng = ((wss_query_parms *) parms)->rng; auto state = (WSSState *) q_state; std::vector> result_set; - if (sample_sz == 0) { + if (sample_size == 0) { return result_set; } size_t attempts = 0; @@ -329,7 +373,7 @@ public: attempts++; size_t idx = wss->m_alias->get(rng); result_set.emplace_back(*wss->get_record_at(idx)); - } while (attempts < sample_sz); + } while (attempts < sample_size); return result_set; } @@ -339,10 +383,10 @@ public: auto p = (wss_query_parms *) parms; std::vector> result; - result.reserve(p->sample_size); + result.reserve(st->sample_size); if constexpr (Rejection) { - for (size_t i=0; isample_size; i++) { + for (size_t i=0; isample_size; i++) { auto idx = gsl_rng_uniform_int(p->rng, st->cutoff); auto rec = buffer->get_data() + idx; @@ -355,7 +399,7 @@ public: return result; } - for (size_t i=0; isample_size; i++) { + for (size_t i=0; isample_size; i++) { auto idx = st->alias->get(p->rng); result.emplace_back(*(buffer->get_data() + idx)); } @@ -384,11 +428,6 @@ public: auto s = (WSSBufferState *) state; delete s; } - - - //{q.get_buffer_query_state(p, p)}; - //{q.buffer_query(p, p)}; - }; } -- cgit v1.2.3