summaryrefslogtreecommitdiffstats
path: root/include/shard/WSS.h
diff options
context:
space:
mode:
authorDouglas Rumbaugh <dbr4@psu.edu>2023-06-07 12:04:13 -0400
committerDouglas Rumbaugh <dbr4@psu.edu>2023-06-07 12:04:13 -0400
commita6c17386c4e76576f578795947c1763e06f06f46 (patch)
treeb932b19e52b125dcb517cce9bc38b2bd89e0a1e8 /include/shard/WSS.h
parent1a791e7241fb9898f58cd4642cf8cf8ec2a1c885 (diff)
downloaddynamic-extension-a6c17386c4e76576f578795947c1763e06f06f46.tar.gz
Bugfixes for query state processing function
Diffstat (limited to 'include/shard/WSS.h')
-rw-r--r--include/shard/WSS.h17
1 files changed, 10 insertions, 7 deletions
diff --git a/include/shard/WSS.h b/include/shard/WSS.h
index 1069897..9300932 100644
--- a/include/shard/WSS.h
+++ b/include/shard/WSS.h
@@ -283,6 +283,8 @@ class WSSQuery {
public:
static void *get_query_state(WSS<R> *wss, void *parms) {
auto res = new WSSState<R>();
+ res->total_weight = wss->m_total_weight;
+ res->sample_size = 0;
return res;
}
@@ -293,6 +295,7 @@ public:
if constexpr (Rejection) {
state->cutoff = buffer->get_record_count() - 1;
state->max_weight = buffer->get_max_weight();
+ state->total_weight = buffer->get_total_weight();
return state;
}
@@ -312,6 +315,7 @@ public:
}
state->alias = new Alias(weights);
+ state->total_weight = total_weight;
return state;
}
@@ -320,13 +324,13 @@ public:
auto p = (wss_query_parms<R> *) query_parms;
auto bs = (WSSBufferState<R> *) buff_state;
- std::vector<size_t> shard_sample_sizes = {0};
+ std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0);
size_t buffer_sz = 0;
std::vector<decltype(R::weight)> weights;
weights.push_back(bs->total_weight);
- decltype(R::weight) total_weight;
+ decltype(R::weight) total_weight = 0;
for (auto &s : shard_states) {
auto state = (WSSState<R> *) s;
total_weight += state->total_weight;
@@ -350,18 +354,17 @@ public:
bs->sample_size = buffer_sz;
- size_t i=1;
- for (auto &s : shard_states) {
- auto state = (WSSState<R> *) s;
- state->sample_size = shard_sample_sizes[i++];
+ for (size_t i=0; i<shard_states.size(); i++) {
+ auto state = (WSSState<R> *) shard_states[i];
+ state->sample_size = shard_sample_sizes[i+1];
}
}
static std::vector<Wrapped<R>> query(WSS<R> *wss, void *q_state, void *parms) {
- auto sample_size = ((WSSState<R> *) q_state)->sample_size;
auto rng = ((wss_query_parms<R> *) parms)->rng;
auto state = (WSSState<R> *) q_state;
+ auto sample_size = state->sample_size;
std::vector<Wrapped<R>> result_set;