From a6c17386c4e76576f578795947c1763e06f06f46 Mon Sep 17 00:00:00 2001 From: Douglas Rumbaugh Date: Wed, 7 Jun 2023 12:04:13 -0400 Subject: Bugfixes for query state processing function --- include/shard/MemISAM.h | 23 ++++++++++++----------- include/shard/WIRS.h | 22 ++++++++++++---------- include/shard/WSS.h | 17 ++++++++++------- 3 files changed, 34 insertions(+), 28 deletions(-) (limited to 'include/shard') diff --git a/include/shard/MemISAM.h b/include/shard/MemISAM.h index ae1c682..96c404e 100644 --- a/include/shard/MemISAM.h +++ b/include/shard/MemISAM.h @@ -361,6 +361,7 @@ public: res->lower_bound = isam->get_lower_bound(lower_key); res->upper_bound = isam->get_upper_bound(upper_key); + res->sample_size = 0; return res; } @@ -369,6 +370,7 @@ public: auto res = new IRSBufferState(); res->cutoff = buffer->get_record_count(); + res->sample_size = 0; if constexpr (Rejection) { return res; @@ -390,7 +392,7 @@ public: auto p = (irs_query_parms *) query_parms; auto bs = (IRSBufferState *) buff_state; - std::vector shard_sample_sizes = {0}; + std::vector shard_sample_sizes(shard_states.size()+1, 0); size_t buffer_sz = 0; std::vector weights; @@ -400,7 +402,7 @@ public: weights.push_back(bs->records.size()); } - decltype(R::weight) total_weight; + decltype(R::weight) total_weight = 0; for (auto &s : shard_states) { auto state = (IRSState *) s; total_weight += state->upper_bound - state->lower_bound; @@ -422,21 +424,20 @@ public: } } - bs->sample_size = buffer_sz; - size_t i=1; - for (auto &s : shard_states) { - auto state = (IRSState *) s; - state->sample_size = shard_sample_sizes[i++]; + for (size_t i=0; i *) shard_states[i]; + state->sample_size = shard_sample_sizes[i+1]; } } + static std::vector> query(MemISAM *isam, void *q_state, void *parms) { - auto sample_sz = ((irs_query_parms *) parms)->sample_size; auto lower_key = ((irs_query_parms *) parms)->lower_bound; auto upper_key = ((irs_query_parms *) parms)->upper_bound; auto rng = ((irs_query_parms *) parms)->rng; auto state = (IRSState *) q_state; + auto sample_sz = state->sample_size; std::vector> result_set; @@ -460,10 +461,10 @@ public: auto p = (irs_query_parms *) parms; std::vector> result; - result.reserve(p->sample_size); + result.reserve(st->sample_size); if constexpr (Rejection) { - for (size_t i=0; isample_size; i++) { + for (size_t i=0; isample_size; i++) { auto idx = gsl_rng_uniform_int(p->rng, st->cutoff); auto rec = buffer->get_data() + idx; @@ -475,7 +476,7 @@ public: return result; } - for (size_t i=0; isample_size; i++) { + for (size_t i=0; isample_size; i++) { auto idx = gsl_rng_uniform_int(p->rng, st->records.size()); result.emplace_back(st->records[idx]); } diff --git a/include/shard/WIRS.h b/include/shard/WIRS.h index 619c2fe..ab72129 100644 --- a/include/shard/WIRS.h +++ b/include/shard/WIRS.h @@ -392,6 +392,7 @@ public: } res->total_weight = total_weight; res->top_level_alias = new Alias(weights); + res->sample_size = 0; return res; } @@ -403,6 +404,7 @@ public: state->cutoff = buffer->get_record_count() - 1; state->max_weight = buffer->get_max_weight(); state->total_weight = buffer->get_total_weight(); + state->sample_size = 0; return state; } @@ -427,6 +429,7 @@ public: state->total_weight = total_weight; state->alias = new Alias(weights); + state->sample_size = 0; return state; } @@ -435,13 +438,13 @@ public: auto p = (wirs_query_parms *) query_parms; auto bs = (WIRSBufferState *) buff_state; - std::vector shard_sample_sizes = {0}; + std::vector shard_sample_sizes(shard_states.size()+1, 0); size_t buffer_sz = 0; std::vector weights; weights.push_back(bs->total_weight); - decltype(R::weight) total_weight; + decltype(R::weight) total_weight = 0; for (auto &s : shard_states) { auto state = (WIRSState *) s; total_weight += state->total_weight; @@ -465,22 +468,21 @@ public: bs->sample_size = buffer_sz; - size_t i=1; - for (auto &s : shard_states) { - auto state = (WIRSState *) s; - state->sample_size = shard_sample_sizes[i++]; + for (size_t i=0; i *) shard_states[i]; + state->sample_size = shard_sample_sizes[i+1]; } } static std::vector> query(WIRS *wirs, void *q_state, void *parms) { - auto sample_size = ((wirs_query_parms *) parms)->sample_size; auto lower_key = ((wirs_query_parms *) parms)->lower_bound; auto upper_key = ((wirs_query_parms *) parms)->upper_bound; auto rng = ((wirs_query_parms *) parms)->rng; auto state = (WIRSState *) q_state; + auto sample_size = state->sample_size; std::vector> result_set; @@ -517,10 +519,10 @@ public: auto p = (wirs_query_parms *) parms; std::vector> result; - result.reserve(p->sample_size); + result.reserve(st->sample_size); if constexpr (Rejection) { - for (size_t i=0; isample_size; i++) { + for (size_t i=0; isample_size; i++) { auto idx = gsl_rng_uniform_int(p->rng, st->cutoff); auto rec = buffer->get_data() + idx; @@ -533,7 +535,7 @@ public: return result; } - for (size_t i=0; isample_size; i++) { + for (size_t i=0; isample_size; i++) { auto idx = st->alias->get(p->rng); result.emplace_back(st->records[idx]); } diff --git a/include/shard/WSS.h b/include/shard/WSS.h index 1069897..9300932 100644 --- a/include/shard/WSS.h +++ b/include/shard/WSS.h @@ -283,6 +283,8 @@ class WSSQuery { public: static void *get_query_state(WSS *wss, void *parms) { auto res = new WSSState(); + res->total_weight = wss->m_total_weight; + res->sample_size = 0; return res; } @@ -293,6 +295,7 @@ public: if constexpr (Rejection) { state->cutoff = buffer->get_record_count() - 1; state->max_weight = buffer->get_max_weight(); + state->total_weight = buffer->get_total_weight(); return state; } @@ -312,6 +315,7 @@ public: } state->alias = new Alias(weights); + state->total_weight = total_weight; return state; } @@ -320,13 +324,13 @@ public: auto p = (wss_query_parms *) query_parms; auto bs = (WSSBufferState *) buff_state; - std::vector shard_sample_sizes = {0}; + std::vector shard_sample_sizes(shard_states.size()+1, 0); size_t buffer_sz = 0; std::vector weights; weights.push_back(bs->total_weight); - decltype(R::weight) total_weight; + decltype(R::weight) total_weight = 0; for (auto &s : shard_states) { auto state = (WSSState *) s; total_weight += state->total_weight; @@ -350,18 +354,17 @@ public: bs->sample_size = buffer_sz; - size_t i=1; - for (auto &s : shard_states) { - auto state = (WSSState *) s; - state->sample_size = shard_sample_sizes[i++]; + for (size_t i=0; i *) shard_states[i]; + state->sample_size = shard_sample_sizes[i+1]; } } static std::vector> query(WSS *wss, void *q_state, void *parms) { - auto sample_size = ((WSSState *) q_state)->sample_size; auto rng = ((wss_query_parms *) parms)->rng; auto state = (WSSState *) q_state; + auto sample_size = state->sample_size; std::vector> result_set; -- cgit v1.2.3