summaryrefslogtreecommitdiffstats
path: root/include/shard/WSS.h
diff options
context:
space:
mode:
authorDouglas Rumbaugh <dbr4@psu.edu>2023-06-07 11:39:25 -0400
committerDouglas Rumbaugh <dbr4@psu.edu>2023-06-07 11:39:25 -0400
commit1a791e7241fb9898f58cd4642cf8cf8ec2a1c885 (patch)
treefbc59c0c52e2db66b252a7b47243c293ea008797 /include/shard/WSS.h
parent1800af2e9503302274e7ba35eed45aa5839af23d (diff)
downloaddynamic-extension-1a791e7241fb9898f58cd4642cf8cf8ec2a1c885.tar.gz
Added a pre-query hook for processing states
This is used for setting up the query alias structure stuff for sampling queries.
Diffstat (limited to 'include/shard/WSS.h')
-rw-r--r--include/shard/WSS.h71
1 files changed, 55 insertions, 16 deletions
diff --git a/include/shard/WSS.h b/include/shard/WSS.h
index bb7ee2a..1069897 100644
--- a/include/shard/WSS.h
+++ b/include/shard/WSS.h
@@ -41,18 +41,21 @@ class WSSQuery;
template <WeightedRecordInterface R>
struct WSSState {
- decltype(R::weight) tot_weight;
+ decltype(R::weight) total_weight;
+ size_t sample_size;
WSSState() {
- tot_weight = 0;
+ total_weight = 0;
}
};
template <WeightedRecordInterface R>
struct WSSBufferState {
size_t cutoff;
+ size_t sample_size;
Alias* alias;
decltype(R::weight) max_weight;
+ decltype(R::weight) total_weight;
~WSSBufferState() {
delete alias;
@@ -296,16 +299,16 @@ public:
std::vector<double> weights;
state->cutoff = buffer->get_record_count() - 1;
- double tot_weight = 0.0;
+ double total_weight = 0.0;
for (size_t i = 0; i <= state->cutoff; i++) {
auto rec = buffer->get_data() + i;
weights.push_back(rec->rec.weight);
- tot_weight += rec->rec.weight;
+ total_weight += rec->rec.weight;
}
for (size_t i = 0; i < weights.size(); i++) {
- weights[i] = weights[i] / tot_weight;
+ weights[i] = weights[i] / total_weight;
}
state->alias = new Alias(weights);
@@ -313,15 +316,56 @@ public:
return state;
}
+ static void process_query_states(void *query_parms, std::vector<void*> shard_states, void *buff_state) {
+ auto p = (wss_query_parms<R> *) query_parms;
+ auto bs = (WSSBufferState<R> *) buff_state;
+
+ std::vector<size_t> shard_sample_sizes = {0};
+ size_t buffer_sz = 0;
+
+ std::vector<decltype(R::weight)> weights;
+ weights.push_back(bs->total_weight);
+
+ decltype(R::weight) total_weight;
+ for (auto &s : shard_states) {
+ auto state = (WSSState<R> *) s;
+ total_weight += state->total_weight;
+ weights.push_back(state->total_weight);
+ }
+
+ std::vector<double> normalized_weights;
+ for (auto w : weights) {
+ normalized_weights.push_back((double) w / (double) total_weight);
+ }
+
+ auto shard_alias = Alias(normalized_weights);
+ for (size_t i=0; i<p->sample_size; i++) {
+ auto idx = shard_alias.get(p->rng);
+ if (idx == 0) {
+ buffer_sz++;
+ } else {
+ shard_sample_sizes[idx - 1]++;
+ }
+ }
+
+
+ bs->sample_size = buffer_sz;
+ size_t i=1;
+ for (auto &s : shard_states) {
+ auto state = (WSSState<R> *) s;
+ state->sample_size = shard_sample_sizes[i++];
+ }
+ }
+
static std::vector<Wrapped<R>> query(WSS<R> *wss, void *q_state, void *parms) {
- auto sample_sz = ((wss_query_parms<R> *) parms)->sample_size;
+ auto sample_size = ((WSSState<R> *) q_state)->sample_size;
auto rng = ((wss_query_parms<R> *) parms)->rng;
auto state = (WSSState<R> *) q_state;
std::vector<Wrapped<R>> result_set;
- if (sample_sz == 0) {
+ if (sample_size == 0) {
return result_set;
}
size_t attempts = 0;
@@ -329,7 +373,7 @@ public:
attempts++;
size_t idx = wss->m_alias->get(rng);
result_set.emplace_back(*wss->get_record_at(idx));
- } while (attempts < sample_sz);
+ } while (attempts < sample_size);
return result_set;
}
@@ -339,10 +383,10 @@ public:
auto p = (wss_query_parms<R> *) parms;
std::vector<Wrapped<R>> result;
- result.reserve(p->sample_size);
+ result.reserve(st->sample_size);
if constexpr (Rejection) {
- for (size_t i=0; i<p->sample_size; i++) {
+ for (size_t i=0; i<st->sample_size; i++) {
auto idx = gsl_rng_uniform_int(p->rng, st->cutoff);
auto rec = buffer->get_data() + idx;
@@ -355,7 +399,7 @@ public:
return result;
}
- for (size_t i=0; i<p->sample_size; i++) {
+ for (size_t i=0; i<st->sample_size; i++) {
auto idx = st->alias->get(p->rng);
result.emplace_back(*(buffer->get_data() + idx));
}
@@ -384,11 +428,6 @@ public:
auto s = (WSSBufferState<R> *) state;
delete s;
}
-
-
- //{q.get_buffer_query_state(p, p)};
- //{q.buffer_query(p, p)};
-
};
}