summaryrefslogtreecommitdiffstats
path: root/include/query/wss.h
diff options
context:
space:
mode:
authorDouglas Rumbaugh <dbr4@psu.edu>2023-11-07 12:29:03 -0500
committerDouglas Rumbaugh <dbr4@psu.edu>2023-11-07 12:29:03 -0500
commita2fe4b1616a1b2318f70e842382818ee44aea9e6 (patch)
tree40a3dcac716ded595d917d845b255f54b941260a /include/query/wss.h
parente02742b07540dd5a9bcbb44dae14856bf10955ed (diff)
downloaddynamic-extension-a2fe4b1616a1b2318f70e842382818ee44aea9e6.tar.gz
Alias shard fixes
Diffstat (limited to 'include/query/wss.h')
-rw-r--r--include/query/wss.h28
1 files changed, 14 insertions, 14 deletions
diff --git a/include/query/wss.h b/include/query/wss.h
index b8a5d54..794485c 100644
--- a/include/query/wss.h
+++ b/include/query/wss.h
@@ -90,15 +90,19 @@ public:
static void process_query_states(void *query_parms, std::vector<void*> &shard_states, std::vector<void*> &buffer_states) {
auto p = (Parms<R> *) query_parms;
- auto bs = (BufferState<R> *) buffer_states[0];
- std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0);
+ std::vector<size_t> shard_sample_sizes(shard_states.size()+buffer_states.size(), 0);
size_t buffer_sz = 0;
std::vector<decltype(R::weight)> weights;
- weights.push_back(bs->total_weight);
decltype(R::weight) total_weight = 0;
+ for (auto &s : buffer_states) {
+ auto bs = (BufferState<R> *) s;
+ total_weight += bs->total_weight;
+ weights.push_back(bs->total_weight);
+ }
+
for (auto &s : shard_states) {
auto state = (State<R> *) s;
total_weight += state->total_weight;
@@ -113,19 +117,15 @@ public:
auto shard_alias = psudb::Alias(normalized_weights);
for (size_t i=0; i<p->sample_size; i++) {
auto idx = shard_alias.get(p->rng);
- if (idx == 0) {
- buffer_sz++;
+
+ if (idx < buffer_states.size()) {
+ auto state = (BufferState<R> *) buffer_states[idx];
+ state->sample_size++;
} else {
- shard_sample_sizes[idx - 1]++;
+ auto state = (State<R> *) shard_states[idx - buffer_states.size()];
+ state->sample_size++;
}
}
-
-
- bs->sample_size = buffer_sz;
- for (size_t i=0; i<shard_states.size(); i++) {
- auto state = (State<R> *) shard_states[i];
- state->sample_size = shard_sample_sizes[i+1];
- }
}
static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
@@ -142,7 +142,7 @@ public:
size_t attempts = 0;
do {
attempts++;
- size_t idx = shard->m_alias->get(rng);
+ size_t idx = shard->get_weighted_sample(rng);
result_set.emplace_back(*shard->get_record_at(idx));
} while (attempts < sample_size);