diff options
| author | Douglas Rumbaugh <dbr4@psu.edu> | 2023-11-07 12:29:03 -0500 |
|---|---|---|
| committer | Douglas Rumbaugh <dbr4@psu.edu> | 2023-11-07 12:29:03 -0500 |
| commit | a2fe4b1616a1b2318f70e842382818ee44aea9e6 (patch) | |
| tree | 40a3dcac716ded595d917d845b255f54b941260a /include | |
| parent | e02742b07540dd5a9bcbb44dae14856bf10955ed (diff) | |
| download | dynamic-extension-a2fe4b1616a1b2318f70e842382818ee44aea9e6.tar.gz | |
Alias shard fixes
Diffstat (limited to 'include')
| -rw-r--r-- | include/query/wss.h | 28 | ||||
| -rw-r--r-- | include/shard/Alias.h | 13 |
2 files changed, 25 insertions, 16 deletions
diff --git a/include/query/wss.h b/include/query/wss.h index b8a5d54..794485c 100644 --- a/include/query/wss.h +++ b/include/query/wss.h @@ -90,15 +90,19 @@ public: static void process_query_states(void *query_parms, std::vector<void*> &shard_states, std::vector<void*> &buffer_states) { auto p = (Parms<R> *) query_parms; - auto bs = (BufferState<R> *) buffer_states[0]; - std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0); + std::vector<size_t> shard_sample_sizes(shard_states.size()+buffer_states.size(), 0); size_t buffer_sz = 0; std::vector<decltype(R::weight)> weights; - weights.push_back(bs->total_weight); decltype(R::weight) total_weight = 0; + for (auto &s : buffer_states) { + auto bs = (BufferState<R> *) s; + total_weight += bs->total_weight; + weights.push_back(bs->total_weight); + } + for (auto &s : shard_states) { auto state = (State<R> *) s; total_weight += state->total_weight; @@ -113,19 +117,15 @@ public: auto shard_alias = psudb::Alias(normalized_weights); for (size_t i=0; i<p->sample_size; i++) { auto idx = shard_alias.get(p->rng); - if (idx == 0) { - buffer_sz++; + + if (idx < buffer_states.size()) { + auto state = (BufferState<R> *) buffer_states[idx]; + state->sample_size++; } else { - shard_sample_sizes[idx - 1]++; + auto state = (State<R> *) shard_states[idx - buffer_states.size()]; + state->sample_size++; } } - - - bs->sample_size = buffer_sz; - for (size_t i=0; i<shard_states.size(); i++) { - auto state = (State<R> *) shard_states[i]; - state->sample_size = shard_sample_sizes[i+1]; - } } static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) { @@ -142,7 +142,7 @@ public: size_t attempts = 0; do { attempts++; - size_t idx = shard->m_alias->get(rng); + size_t idx = shard->get_weighted_sample(rng); result_set.emplace_back(*shard->get_record_at(idx)); } while (attempts < sample_size); diff --git a/include/shard/Alias.h b/include/shard/Alias.h index b6b16c5..a4a7d02 100644 --- a/include/shard/Alias.h +++ b/include/shard/Alias.h @@ -19,7 +19,7 @@ #include "psu-ds/PriorityQueue.h" #include "util/Cursor.h" -#include "psu-ds/psudb::Alias.h" +#include "psu-ds/Alias.h" #include "psu-ds/BloomFilter.h" #include "util/bf_config.h" @@ -207,7 +207,13 @@ public: return 0; } -private: + W get_total_weight() { + return m_total_weight; + } + + size_t get_weighted_sample(gsl_rng *rng) const { + return m_alias->get(rng); + } size_t get_lower_bound(const K& key) const { size_t min = 0; @@ -227,6 +233,8 @@ private: return min; } +private: + void build_alias_structure(std::vector<W> &weights) { // normalize the weights vector @@ -249,3 +257,4 @@ private: size_t m_alloc_size; BloomFilter<R> *m_bf; }; +} |