diff options
| author | Douglas Rumbaugh <dbr4@psu.edu> | 2023-11-07 12:29:03 -0500 |
|---|---|---|
| committer | Douglas Rumbaugh <dbr4@psu.edu> | 2023-11-07 12:29:03 -0500 |
| commit | a2fe4b1616a1b2318f70e842382818ee44aea9e6 (patch) | |
| tree | 40a3dcac716ded595d917d845b255f54b941260a /include/query | |
| parent | e02742b07540dd5a9bcbb44dae14856bf10955ed (diff) | |
| download | dynamic-extension-a2fe4b1616a1b2318f70e842382818ee44aea9e6.tar.gz | |
Alias shard fixes
Diffstat (limited to 'include/query')
| -rw-r--r-- | include/query/wss.h | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/include/query/wss.h b/include/query/wss.h index b8a5d54..794485c 100644 --- a/include/query/wss.h +++ b/include/query/wss.h @@ -90,15 +90,19 @@ public: static void process_query_states(void *query_parms, std::vector<void*> &shard_states, std::vector<void*> &buffer_states) { auto p = (Parms<R> *) query_parms; - auto bs = (BufferState<R> *) buffer_states[0]; - std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0); + std::vector<size_t> shard_sample_sizes(shard_states.size()+buffer_states.size(), 0); size_t buffer_sz = 0; std::vector<decltype(R::weight)> weights; - weights.push_back(bs->total_weight); decltype(R::weight) total_weight = 0; + for (auto &s : buffer_states) { + auto bs = (BufferState<R> *) s; + total_weight += bs->total_weight; + weights.push_back(bs->total_weight); + } + for (auto &s : shard_states) { auto state = (State<R> *) s; total_weight += state->total_weight; @@ -113,19 +117,15 @@ public: auto shard_alias = psudb::Alias(normalized_weights); for (size_t i=0; i<p->sample_size; i++) { auto idx = shard_alias.get(p->rng); - if (idx == 0) { - buffer_sz++; + + if (idx < buffer_states.size()) { + auto state = (BufferState<R> *) buffer_states[idx]; + state->sample_size++; } else { - shard_sample_sizes[idx - 1]++; + auto state = (State<R> *) shard_states[idx - buffer_states.size()]; + state->sample_size++; } } - - - bs->sample_size = buffer_sz; - for (size_t i=0; i<shard_states.size(); i++) { - auto state = (State<R> *) shard_states[i]; - state->sample_size = shard_sample_sizes[i+1]; - } } static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) { @@ -142,7 +142,7 @@ public: size_t attempts = 0; do { attempts++; - size_t idx = shard->m_alias->get(rng); + size_t idx = shard->get_weighted_sample(rng); result_set.emplace_back(*shard->get_record_at(idx)); } while (attempts < sample_size); |