diff options
Diffstat (limited to 'include/shard/WSS.h')
| -rw-r--r-- | include/shard/WSS.h | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/include/shard/WSS.h b/include/shard/WSS.h index 1069897..9300932 100644 --- a/include/shard/WSS.h +++ b/include/shard/WSS.h @@ -283,6 +283,8 @@ class WSSQuery { public: static void *get_query_state(WSS<R> *wss, void *parms) { auto res = new WSSState<R>(); + res->total_weight = wss->m_total_weight; + res->sample_size = 0; return res; } @@ -293,6 +295,7 @@ public: if constexpr (Rejection) { state->cutoff = buffer->get_record_count() - 1; state->max_weight = buffer->get_max_weight(); + state->total_weight = buffer->get_total_weight(); return state; } @@ -312,6 +315,7 @@ public: } state->alias = new Alias(weights); + state->total_weight = total_weight; return state; } @@ -320,13 +324,13 @@ public: auto p = (wss_query_parms<R> *) query_parms; auto bs = (WSSBufferState<R> *) buff_state; - std::vector<size_t> shard_sample_sizes = {0}; + std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0); size_t buffer_sz = 0; std::vector<decltype(R::weight)> weights; weights.push_back(bs->total_weight); - decltype(R::weight) total_weight; + decltype(R::weight) total_weight = 0; for (auto &s : shard_states) { auto state = (WSSState<R> *) s; total_weight += state->total_weight; @@ -350,18 +354,17 @@ public: bs->sample_size = buffer_sz; - size_t i=1; - for (auto &s : shard_states) { - auto state = (WSSState<R> *) s; - state->sample_size = shard_sample_sizes[i++]; + for (size_t i=0; i<shard_states.size(); i++) { + auto state = (WSSState<R> *) shard_states[i]; + state->sample_size = shard_sample_sizes[i+1]; } } static std::vector<Wrapped<R>> query(WSS<R> *wss, void *q_state, void *parms) { - auto sample_size = ((WSSState<R> *) q_state)->sample_size; auto rng = ((wss_query_parms<R> *) parms)->rng; auto state = (WSSState<R> *) q_state; + auto sample_size = state->sample_size; std::vector<Wrapped<R>> result_set; |