summaryrefslogtreecommitdiffstats
path: root/include/shard/WIRS.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/shard/WIRS.h')
-rw-r--r--include/shard/WIRS.h22
1 files changed, 12 insertions, 10 deletions
diff --git a/include/shard/WIRS.h b/include/shard/WIRS.h
index 619c2fe..ab72129 100644
--- a/include/shard/WIRS.h
+++ b/include/shard/WIRS.h
@@ -392,6 +392,7 @@ public:
}
res->total_weight = total_weight;
res->top_level_alias = new Alias(weights);
+ res->sample_size = 0;
return res;
}
@@ -403,6 +404,7 @@ public:
state->cutoff = buffer->get_record_count() - 1;
state->max_weight = buffer->get_max_weight();
state->total_weight = buffer->get_total_weight();
+ state->sample_size = 0;
return state;
}
@@ -427,6 +429,7 @@ public:
state->total_weight = total_weight;
state->alias = new Alias(weights);
+ state->sample_size = 0;
return state;
}
@@ -435,13 +438,13 @@ public:
auto p = (wirs_query_parms<R> *) query_parms;
auto bs = (WIRSBufferState<R> *) buff_state;
- std::vector<size_t> shard_sample_sizes = {0};
+ std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0);
size_t buffer_sz = 0;
std::vector<decltype(R::weight)> weights;
weights.push_back(bs->total_weight);
- decltype(R::weight) total_weight;
+ decltype(R::weight) total_weight = 0;
for (auto &s : shard_states) {
auto state = (WIRSState<R> *) s;
total_weight += state->total_weight;
@@ -465,22 +468,21 @@ public:
bs->sample_size = buffer_sz;
- size_t i=1;
- for (auto &s : shard_states) {
- auto state = (WIRSState<R> *) s;
- state->sample_size = shard_sample_sizes[i++];
+ for (size_t i=0; i<shard_states.size(); i++) {
+ auto state = (WIRSState<R> *) shard_states[i];
+ state->sample_size = shard_sample_sizes[i+1];
}
}
static std::vector<Wrapped<R>> query(WIRS<R> *wirs, void *q_state, void *parms) {
- auto sample_size = ((wirs_query_parms<R> *) parms)->sample_size;
auto lower_key = ((wirs_query_parms<R> *) parms)->lower_bound;
auto upper_key = ((wirs_query_parms<R> *) parms)->upper_bound;
auto rng = ((wirs_query_parms<R> *) parms)->rng;
auto state = (WIRSState<R> *) q_state;
+ auto sample_size = state->sample_size;
std::vector<Wrapped<R>> result_set;
@@ -517,10 +519,10 @@ public:
auto p = (wirs_query_parms<R> *) parms;
std::vector<Wrapped<R>> result;
- result.reserve(p->sample_size);
+ result.reserve(st->sample_size);
if constexpr (Rejection) {
- for (size_t i=0; i<p->sample_size; i++) {
+ for (size_t i=0; i<st->sample_size; i++) {
auto idx = gsl_rng_uniform_int(p->rng, st->cutoff);
auto rec = buffer->get_data() + idx;
@@ -533,7 +535,7 @@ public:
return result;
}
- for (size_t i=0; i<p->sample_size; i++) {
+ for (size_t i=0; i<st->sample_size; i++) {
auto idx = st->alias->get(p->rng);
result.emplace_back(st->records[idx]);
}