summaryrefslogtreecommitdiffstats
path: root/include/shard
diff options
context:
space:
mode:
authorDouglas Rumbaugh <dbr4@psu.edu>2023-06-07 12:04:13 -0400
committerDouglas Rumbaugh <dbr4@psu.edu>2023-06-07 12:04:13 -0400
commita6c17386c4e76576f578795947c1763e06f06f46 (patch)
treeb932b19e52b125dcb517cce9bc38b2bd89e0a1e8 /include/shard
parent1a791e7241fb9898f58cd4642cf8cf8ec2a1c885 (diff)
downloaddynamic-extension-a6c17386c4e76576f578795947c1763e06f06f46.tar.gz
Bugfixes for query state processing function
Diffstat (limited to 'include/shard')
-rw-r--r--include/shard/MemISAM.h23
-rw-r--r--include/shard/WIRS.h22
-rw-r--r--include/shard/WSS.h17
3 files changed, 34 insertions, 28 deletions
diff --git a/include/shard/MemISAM.h b/include/shard/MemISAM.h
index ae1c682..96c404e 100644
--- a/include/shard/MemISAM.h
+++ b/include/shard/MemISAM.h
@@ -361,6 +361,7 @@ public:
res->lower_bound = isam->get_lower_bound(lower_key);
res->upper_bound = isam->get_upper_bound(upper_key);
+ res->sample_size = 0;
return res;
}
@@ -369,6 +370,7 @@ public:
auto res = new IRSBufferState<R>();
res->cutoff = buffer->get_record_count();
+ res->sample_size = 0;
if constexpr (Rejection) {
return res;
@@ -390,7 +392,7 @@ public:
auto p = (irs_query_parms<R> *) query_parms;
auto bs = (IRSBufferState<R> *) buff_state;
- std::vector<size_t> shard_sample_sizes = {0};
+ std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0);
size_t buffer_sz = 0;
std::vector<size_t> weights;
@@ -400,7 +402,7 @@ public:
weights.push_back(bs->records.size());
}
- decltype(R::weight) total_weight;
+ decltype(R::weight) total_weight = 0;
for (auto &s : shard_states) {
auto state = (IRSState<R> *) s;
total_weight += state->upper_bound - state->lower_bound;
@@ -422,21 +424,20 @@ public:
}
}
-
bs->sample_size = buffer_sz;
- size_t i=1;
- for (auto &s : shard_states) {
- auto state = (IRSState<R> *) s;
- state->sample_size = shard_sample_sizes[i++];
+ for (size_t i=0; i<shard_states.size(); i++) {
+ auto state = (IRSState<R> *) shard_states[i];
+ state->sample_size = shard_sample_sizes[i+1];
}
}
+
static std::vector<Wrapped<R>> query(MemISAM<R> *isam, void *q_state, void *parms) {
- auto sample_sz = ((irs_query_parms<R> *) parms)->sample_size;
auto lower_key = ((irs_query_parms<R> *) parms)->lower_bound;
auto upper_key = ((irs_query_parms<R> *) parms)->upper_bound;
auto rng = ((irs_query_parms<R> *) parms)->rng;
auto state = (IRSState<R> *) q_state;
+ auto sample_sz = state->sample_size;
std::vector<Wrapped<R>> result_set;
@@ -460,10 +461,10 @@ public:
auto p = (irs_query_parms<R> *) parms;
std::vector<Wrapped<R>> result;
- result.reserve(p->sample_size);
+ result.reserve(st->sample_size);
if constexpr (Rejection) {
- for (size_t i=0; i<p->sample_size; i++) {
+ for (size_t i=0; i<st->sample_size; i++) {
auto idx = gsl_rng_uniform_int(p->rng, st->cutoff);
auto rec = buffer->get_data() + idx;
@@ -475,7 +476,7 @@ public:
return result;
}
- for (size_t i=0; i<p->sample_size; i++) {
+ for (size_t i=0; i<st->sample_size; i++) {
auto idx = gsl_rng_uniform_int(p->rng, st->records.size());
result.emplace_back(st->records[idx]);
}
diff --git a/include/shard/WIRS.h b/include/shard/WIRS.h
index 619c2fe..ab72129 100644
--- a/include/shard/WIRS.h
+++ b/include/shard/WIRS.h
@@ -392,6 +392,7 @@ public:
}
res->total_weight = total_weight;
res->top_level_alias = new Alias(weights);
+ res->sample_size = 0;
return res;
}
@@ -403,6 +404,7 @@ public:
state->cutoff = buffer->get_record_count() - 1;
state->max_weight = buffer->get_max_weight();
state->total_weight = buffer->get_total_weight();
+ state->sample_size = 0;
return state;
}
@@ -427,6 +429,7 @@ public:
state->total_weight = total_weight;
state->alias = new Alias(weights);
+ state->sample_size = 0;
return state;
}
@@ -435,13 +438,13 @@ public:
auto p = (wirs_query_parms<R> *) query_parms;
auto bs = (WIRSBufferState<R> *) buff_state;
- std::vector<size_t> shard_sample_sizes = {0};
+ std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0);
size_t buffer_sz = 0;
std::vector<decltype(R::weight)> weights;
weights.push_back(bs->total_weight);
- decltype(R::weight) total_weight;
+ decltype(R::weight) total_weight = 0;
for (auto &s : shard_states) {
auto state = (WIRSState<R> *) s;
total_weight += state->total_weight;
@@ -465,22 +468,21 @@ public:
bs->sample_size = buffer_sz;
- size_t i=1;
- for (auto &s : shard_states) {
- auto state = (WIRSState<R> *) s;
- state->sample_size = shard_sample_sizes[i++];
+ for (size_t i=0; i<shard_states.size(); i++) {
+ auto state = (WIRSState<R> *) shard_states[i];
+ state->sample_size = shard_sample_sizes[i+1];
}
}
static std::vector<Wrapped<R>> query(WIRS<R> *wirs, void *q_state, void *parms) {
- auto sample_size = ((wirs_query_parms<R> *) parms)->sample_size;
auto lower_key = ((wirs_query_parms<R> *) parms)->lower_bound;
auto upper_key = ((wirs_query_parms<R> *) parms)->upper_bound;
auto rng = ((wirs_query_parms<R> *) parms)->rng;
auto state = (WIRSState<R> *) q_state;
+ auto sample_size = state->sample_size;
std::vector<Wrapped<R>> result_set;
@@ -517,10 +519,10 @@ public:
auto p = (wirs_query_parms<R> *) parms;
std::vector<Wrapped<R>> result;
- result.reserve(p->sample_size);
+ result.reserve(st->sample_size);
if constexpr (Rejection) {
- for (size_t i=0; i<p->sample_size; i++) {
+ for (size_t i=0; i<st->sample_size; i++) {
auto idx = gsl_rng_uniform_int(p->rng, st->cutoff);
auto rec = buffer->get_data() + idx;
@@ -533,7 +535,7 @@ public:
return result;
}
- for (size_t i=0; i<p->sample_size; i++) {
+ for (size_t i=0; i<st->sample_size; i++) {
auto idx = st->alias->get(p->rng);
result.emplace_back(st->records[idx]);
}
diff --git a/include/shard/WSS.h b/include/shard/WSS.h
index 1069897..9300932 100644
--- a/include/shard/WSS.h
+++ b/include/shard/WSS.h
@@ -283,6 +283,8 @@ class WSSQuery {
public:
static void *get_query_state(WSS<R> *wss, void *parms) {
auto res = new WSSState<R>();
+ res->total_weight = wss->m_total_weight;
+ res->sample_size = 0;
return res;
}
@@ -293,6 +295,7 @@ public:
if constexpr (Rejection) {
state->cutoff = buffer->get_record_count() - 1;
state->max_weight = buffer->get_max_weight();
+ state->total_weight = buffer->get_total_weight();
return state;
}
@@ -312,6 +315,7 @@ public:
}
state->alias = new Alias(weights);
+ state->total_weight = total_weight;
return state;
}
@@ -320,13 +324,13 @@ public:
auto p = (wss_query_parms<R> *) query_parms;
auto bs = (WSSBufferState<R> *) buff_state;
- std::vector<size_t> shard_sample_sizes = {0};
+ std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0);
size_t buffer_sz = 0;
std::vector<decltype(R::weight)> weights;
weights.push_back(bs->total_weight);
- decltype(R::weight) total_weight;
+ decltype(R::weight) total_weight = 0;
for (auto &s : shard_states) {
auto state = (WSSState<R> *) s;
total_weight += state->total_weight;
@@ -350,18 +354,17 @@ public:
bs->sample_size = buffer_sz;
- size_t i=1;
- for (auto &s : shard_states) {
- auto state = (WSSState<R> *) s;
- state->sample_size = shard_sample_sizes[i++];
+ for (size_t i=0; i<shard_states.size(); i++) {
+ auto state = (WSSState<R> *) shard_states[i];
+ state->sample_size = shard_sample_sizes[i+1];
}
}
static std::vector<Wrapped<R>> query(WSS<R> *wss, void *q_state, void *parms) {
- auto sample_size = ((WSSState<R> *) q_state)->sample_size;
auto rng = ((wss_query_parms<R> *) parms)->rng;
auto state = (WSSState<R> *) q_state;
+ auto sample_size = state->sample_size;
std::vector<Wrapped<R>> result_set;