summaryrefslogtreecommitdiffstats
path: root/include/shard/WIRS.h
diff options
context:
space:
mode:
authorDouglas Rumbaugh <dbr4@psu.edu>2023-06-07 11:39:25 -0400
committerDouglas Rumbaugh <dbr4@psu.edu>2023-06-07 11:39:25 -0400
commit1a791e7241fb9898f58cd4642cf8cf8ec2a1c885 (patch)
treefbc59c0c52e2db66b252a7b47243c293ea008797 /include/shard/WIRS.h
parent1800af2e9503302274e7ba35eed45aa5839af23d (diff)
downloaddynamic-extension-1a791e7241fb9898f58cd4642cf8cf8ec2a1c885.tar.gz
Added a pre-query hook for processing states
This is used for setting up the query alias structure stuff for sampling queries.
Diffstat (limited to 'include/shard/WIRS.h')
-rw-r--r--include/shard/WIRS.h72
1 files changed, 60 insertions, 12 deletions
diff --git a/include/shard/WIRS.h b/include/shard/WIRS.h
index f3696a4..619c2fe 100644
--- a/include/shard/WIRS.h
+++ b/include/shard/WIRS.h
@@ -51,12 +51,13 @@ struct wirs_node {
template <WeightedRecordInterface R>
struct WIRSState {
- decltype(R::weight) tot_weight;
+ decltype(R::weight) total_weight;
std::vector<wirs_node<R>*> nodes;
Alias* top_level_alias;
+ size_t sample_size;
WIRSState() {
- tot_weight = 0;
+ total_weight = 0;
top_level_alias = nullptr;
}
@@ -71,6 +72,8 @@ struct WIRSBufferState {
Alias* alias;
std::vector<Wrapped<R>> records;
decltype(R::weight) max_weight;
+ size_t sample_size;
+ decltype(R::weight) total_weight;
~WIRSBufferState() {
delete alias;
@@ -367,7 +370,7 @@ public:
decltype(R::key) upper_key = ((wirs_query_parms<R> *) parms)->upper_bound;
// Simulate a stack to unfold recursion.
- double tot_weight = 0.0;
+ double total_weight = 0.0;
struct wirs_node<R>* st[64] = {0};
st[0] = wirs->m_root;
size_t top = 1;
@@ -376,7 +379,7 @@ public:
if (wirs->covered_by(now, lower_key, upper_key) ||
(now->left == nullptr && now->right == nullptr && wirs->intersects(now, lower_key, upper_key))) {
res->nodes.emplace_back(now);
- tot_weight += now->weight;
+ total_weight += now->weight;
} else {
if (now->left && wirs->intersects(now->left, lower_key, upper_key)) st[top++] = now->left;
if (now->right && wirs->intersects(now->right, lower_key, upper_key)) st[top++] = now->right;
@@ -385,9 +388,9 @@ public:
std::vector<double> weights;
for (const auto& node: res->nodes) {
- weights.emplace_back(node->weight / tot_weight);
+ weights.emplace_back(node->weight / total_weight);
}
- res->tot_weight = tot_weight;
+ res->total_weight = total_weight;
res->top_level_alias = new Alias(weights);
return res;
@@ -399,13 +402,14 @@ public:
if constexpr (Rejection) {
state->cutoff = buffer->get_record_count() - 1;
state->max_weight = buffer->get_max_weight();
+ state->total_weight = buffer->get_total_weight();
return state;
}
std::vector<double> weights;
state->cutoff = buffer->get_record_count() - 1;
- double tot_weight = 0.0;
+ double total_weight = 0.0;
for (size_t i = 0; i <= state->cutoff; i++) {
auto rec = buffer->get_data() + i;
@@ -413,21 +417,65 @@ public:
if (rec->rec.key >= parameters->lower_bound && rec->rec.key <= parameters->upper_bound && !rec->is_tombstone() && !rec->is_deleted()) {
weights.push_back(rec->rec.weight);
state->records.push_back(*rec);
- tot_weight += rec->rec.weight;
+ total_weight += rec->rec.weight;
}
}
for (size_t i = 0; i < weights.size(); i++) {
- weights[i] = weights[i] / tot_weight;
+ weights[i] = weights[i] / total_weight;
}
+ state->total_weight = total_weight;
state->alias = new Alias(weights);
return state;
}
+ static void process_query_states(void *query_parms, std::vector<void*> shard_states, void *buff_state) {
+ auto p = (wirs_query_parms<R> *) query_parms;
+ auto bs = (WIRSBufferState<R> *) buff_state;
+
+ std::vector<size_t> shard_sample_sizes = {0};
+ size_t buffer_sz = 0;
+
+ std::vector<decltype(R::weight)> weights;
+ weights.push_back(bs->total_weight);
+
+ decltype(R::weight) total_weight;
+ for (auto &s : shard_states) {
+ auto state = (WIRSState<R> *) s;
+ total_weight += state->total_weight;
+ weights.push_back(state->total_weight);
+ }
+
+ std::vector<double> normalized_weights;
+ for (auto w : weights) {
+ normalized_weights.push_back((double) w / (double) total_weight);
+ }
+
+ auto shard_alias = Alias(normalized_weights);
+ for (size_t i=0; i<p->sample_size; i++) {
+ auto idx = shard_alias.get(p->rng);
+ if (idx == 0) {
+ buffer_sz++;
+ } else {
+ shard_sample_sizes[idx - 1]++;
+ }
+ }
+
+
+ bs->sample_size = buffer_sz;
+ size_t i=1;
+ for (auto &s : shard_states) {
+ auto state = (WIRSState<R> *) s;
+ state->sample_size = shard_sample_sizes[i++];
+ }
+ }
+
+
+
static std::vector<Wrapped<R>> query(WIRS<R> *wirs, void *q_state, void *parms) {
- auto sample_sz = ((wirs_query_parms<R> *) parms)->sample_size;
+ auto sample_size = ((wirs_query_parms<R> *) parms)->sample_size;
auto lower_key = ((wirs_query_parms<R> *) parms)->lower_bound;
auto upper_key = ((wirs_query_parms<R> *) parms)->upper_bound;
auto rng = ((wirs_query_parms<R> *) parms)->rng;
@@ -436,7 +484,7 @@ public:
std::vector<Wrapped<R>> result_set;
- if (sample_sz == 0) {
+ if (sample_size == 0) {
return result_set;
}
// k -> sampling: three levels. 1. select a node -> select a fat point -> select a record.
@@ -459,7 +507,7 @@ public:
result_set.emplace_back(*record);
cnt++;
- } while (attempts < sample_sz);
+ } while (attempts < sample_size);
return result_set;
}