summaryrefslogtreecommitdiffstats
path: root/include
diff options
context:
space:
mode:
authorDouglas Rumbaugh <dbr4@psu.edu>2023-06-07 11:39:25 -0400
committerDouglas Rumbaugh <dbr4@psu.edu>2023-06-07 11:39:25 -0400
commit1a791e7241fb9898f58cd4642cf8cf8ec2a1c885 (patch)
treefbc59c0c52e2db66b252a7b47243c293ea008797 /include
parent1800af2e9503302274e7ba35eed45aa5839af23d (diff)
downloaddynamic-extension-1a791e7241fb9898f58cd4642cf8cf8ec2a1c885.tar.gz
Added a pre-query hook for processing states
This is used for setting up the query alias structure stuff for sampling queries.
Diffstat (limited to 'include')
-rw-r--r--include/framework/DynamicExtension.h4
-rw-r--r--include/shard/MemISAM.h46
-rw-r--r--include/shard/PGM.h4
-rw-r--r--include/shard/TrieSpline.h4
-rw-r--r--include/shard/WIRS.h72
-rw-r--r--include/shard/WSS.h71
6 files changed, 172 insertions, 29 deletions
diff --git a/include/framework/DynamicExtension.h b/include/framework/DynamicExtension.h
index 4f3a3bc..a345da6 100644
--- a/include/framework/DynamicExtension.h
+++ b/include/framework/DynamicExtension.h
@@ -56,7 +56,7 @@ static constexpr bool LSM_REJ_SAMPLE = false;
// True for leveling, false for tiering
static constexpr bool LSM_LEVELING = false;
-static constexpr bool DELETE_TAGGING = false;
+static constexpr bool DELETE_TAGGING = true;
// TODO: Replace the constexpr bools above
// with template parameters based on these
@@ -142,6 +142,8 @@ public:
level->get_query_states(shards, states, parms);
}
+ Q::process_query_states(parms, states, buffer_state);
+
std::vector<std::vector<R>> query_results(shards.size() + 1);
// Execute the query for the buffer
diff --git a/include/shard/MemISAM.h b/include/shard/MemISAM.h
index 01a539a..ae1c682 100644
--- a/include/shard/MemISAM.h
+++ b/include/shard/MemISAM.h
@@ -39,12 +39,14 @@ template <RecordInterface R>
struct IRSState {
size_t lower_bound;
size_t upper_bound;
+ size_t sample_size;
};
template <RecordInterface R>
struct IRSBufferState {
size_t cutoff;
std::vector<Wrapped<R>> records;
+ size_t sample_size;
};
@@ -384,6 +386,50 @@ public:
return res;
}
+ static void process_query_states(void *query_parms, std::vector<void*> shard_states, void *buff_state) {
+ auto p = (irs_query_parms<R> *) query_parms;
+ auto bs = (IRSBufferState<R> *) buff_state;
+
+ std::vector<size_t> shard_sample_sizes = {0};
+ size_t buffer_sz = 0;
+
+ std::vector<size_t> weights;
+ if (Rejection) {
+ weights.push_back(bs->cutoff);
+ } else {
+ weights.push_back(bs->records.size());
+ }
+
+ decltype(R::weight) total_weight;
+ for (auto &s : shard_states) {
+ auto state = (IRSState<R> *) s;
+ total_weight += state->upper_bound - state->lower_bound;
+ weights.push_back(state->total_weight);
+ }
+
+ std::vector<double> normalized_weights;
+ for (auto w : weights) {
+ normalized_weights.push_back((double) w / (double) total_weight);
+ }
+
+ auto shard_alias = Alias(normalized_weights);
+ for (size_t i=0; i<p->sample_size; i++) {
+ auto idx = shard_alias.get(p->rng);
+ if (idx == 0) {
+ buffer_sz++;
+ } else {
+ shard_sample_sizes[idx - 1]++;
+ }
+ }
+
+
+ bs->sample_size = buffer_sz;
+ size_t i=1;
+ for (auto &s : shard_states) {
+ auto state = (IRSState<R> *) s;
+ state->sample_size = shard_sample_sizes[i++];
+ }
+ }
static std::vector<Wrapped<R>> query(MemISAM<R> *isam, void *q_state, void *parms) {
auto sample_sz = ((irs_query_parms<R> *) parms)->sample_size;
auto lower_key = ((irs_query_parms<R> *) parms)->lower_bound;
diff --git a/include/shard/PGM.h b/include/shard/PGM.h
index f9e1dad..8b0bd69 100644
--- a/include/shard/PGM.h
+++ b/include/shard/PGM.h
@@ -286,6 +286,10 @@ public:
return res;
}
+ static void process_query_states(void *query_parms, std::vector<void*> shard_states, void *buff_state) {
+ return;
+ }
+
static std::vector<Wrapped<R>> query(PGM<R> *ts, void *q_state, void *parms) {
std::vector<Wrapped<R>> records;
auto p = (pgm_range_query_parms<R> *) parms;
diff --git a/include/shard/TrieSpline.h b/include/shard/TrieSpline.h
index fb0ed70..2341751 100644
--- a/include/shard/TrieSpline.h
+++ b/include/shard/TrieSpline.h
@@ -305,6 +305,10 @@ public:
return res;
}
+ static void process_query_states(void *query_parms, std::vector<void*> shard_states, void *buff_state) {
+ return;
+ }
+
static std::vector<Wrapped<R>> query(TrieSpline<R> *ts, void *q_state, void *parms) {
std::vector<Wrapped<R>> records;
auto p = (ts_range_query_parms<R> *) parms;
diff --git a/include/shard/WIRS.h b/include/shard/WIRS.h
index f3696a4..619c2fe 100644
--- a/include/shard/WIRS.h
+++ b/include/shard/WIRS.h
@@ -51,12 +51,13 @@ struct wirs_node {
template <WeightedRecordInterface R>
struct WIRSState {
- decltype(R::weight) tot_weight;
+ decltype(R::weight) total_weight;
std::vector<wirs_node<R>*> nodes;
Alias* top_level_alias;
+ size_t sample_size;
WIRSState() {
- tot_weight = 0;
+ total_weight = 0;
top_level_alias = nullptr;
}
@@ -71,6 +72,8 @@ struct WIRSBufferState {
Alias* alias;
std::vector<Wrapped<R>> records;
decltype(R::weight) max_weight;
+ size_t sample_size;
+ decltype(R::weight) total_weight;
~WIRSBufferState() {
delete alias;
@@ -367,7 +370,7 @@ public:
decltype(R::key) upper_key = ((wirs_query_parms<R> *) parms)->upper_bound;
// Simulate a stack to unfold recursion.
- double tot_weight = 0.0;
+ double total_weight = 0.0;
struct wirs_node<R>* st[64] = {0};
st[0] = wirs->m_root;
size_t top = 1;
@@ -376,7 +379,7 @@ public:
if (wirs->covered_by(now, lower_key, upper_key) ||
(now->left == nullptr && now->right == nullptr && wirs->intersects(now, lower_key, upper_key))) {
res->nodes.emplace_back(now);
- tot_weight += now->weight;
+ total_weight += now->weight;
} else {
if (now->left && wirs->intersects(now->left, lower_key, upper_key)) st[top++] = now->left;
if (now->right && wirs->intersects(now->right, lower_key, upper_key)) st[top++] = now->right;
@@ -385,9 +388,9 @@ public:
std::vector<double> weights;
for (const auto& node: res->nodes) {
- weights.emplace_back(node->weight / tot_weight);
+ weights.emplace_back(node->weight / total_weight);
}
- res->tot_weight = tot_weight;
+ res->total_weight = total_weight;
res->top_level_alias = new Alias(weights);
return res;
@@ -399,13 +402,14 @@ public:
if constexpr (Rejection) {
state->cutoff = buffer->get_record_count() - 1;
state->max_weight = buffer->get_max_weight();
+ state->total_weight = buffer->get_total_weight();
return state;
}
std::vector<double> weights;
state->cutoff = buffer->get_record_count() - 1;
- double tot_weight = 0.0;
+ double total_weight = 0.0;
for (size_t i = 0; i <= state->cutoff; i++) {
auto rec = buffer->get_data() + i;
@@ -413,21 +417,65 @@ public:
if (rec->rec.key >= parameters->lower_bound && rec->rec.key <= parameters->upper_bound && !rec->is_tombstone() && !rec->is_deleted()) {
weights.push_back(rec->rec.weight);
state->records.push_back(*rec);
- tot_weight += rec->rec.weight;
+ total_weight += rec->rec.weight;
}
}
for (size_t i = 0; i < weights.size(); i++) {
- weights[i] = weights[i] / tot_weight;
+ weights[i] = weights[i] / total_weight;
}
+ state->total_weight = total_weight;
state->alias = new Alias(weights);
return state;
}
+ static void process_query_states(void *query_parms, std::vector<void*> shard_states, void *buff_state) {
+ auto p = (wirs_query_parms<R> *) query_parms;
+ auto bs = (WIRSBufferState<R> *) buff_state;
+
+ std::vector<size_t> shard_sample_sizes = {0};
+ size_t buffer_sz = 0;
+
+ std::vector<decltype(R::weight)> weights;
+ weights.push_back(bs->total_weight);
+
+ decltype(R::weight) total_weight;
+ for (auto &s : shard_states) {
+ auto state = (WIRSState<R> *) s;
+ total_weight += state->total_weight;
+ weights.push_back(state->total_weight);
+ }
+
+ std::vector<double> normalized_weights;
+ for (auto w : weights) {
+ normalized_weights.push_back((double) w / (double) total_weight);
+ }
+
+ auto shard_alias = Alias(normalized_weights);
+ for (size_t i=0; i<p->sample_size; i++) {
+ auto idx = shard_alias.get(p->rng);
+ if (idx == 0) {
+ buffer_sz++;
+ } else {
+ shard_sample_sizes[idx - 1]++;
+ }
+ }
+
+
+ bs->sample_size = buffer_sz;
+ size_t i=1;
+ for (auto &s : shard_states) {
+ auto state = (WIRSState<R> *) s;
+ state->sample_size = shard_sample_sizes[i++];
+ }
+ }
+
+
+
static std::vector<Wrapped<R>> query(WIRS<R> *wirs, void *q_state, void *parms) {
- auto sample_sz = ((wirs_query_parms<R> *) parms)->sample_size;
+ auto sample_size = ((wirs_query_parms<R> *) parms)->sample_size;
auto lower_key = ((wirs_query_parms<R> *) parms)->lower_bound;
auto upper_key = ((wirs_query_parms<R> *) parms)->upper_bound;
auto rng = ((wirs_query_parms<R> *) parms)->rng;
@@ -436,7 +484,7 @@ public:
std::vector<Wrapped<R>> result_set;
- if (sample_sz == 0) {
+ if (sample_size == 0) {
return result_set;
}
// k -> sampling: three levels. 1. select a node -> select a fat point -> select a record.
@@ -459,7 +507,7 @@ public:
result_set.emplace_back(*record);
cnt++;
- } while (attempts < sample_sz);
+ } while (attempts < sample_size);
return result_set;
}
diff --git a/include/shard/WSS.h b/include/shard/WSS.h
index bb7ee2a..1069897 100644
--- a/include/shard/WSS.h
+++ b/include/shard/WSS.h
@@ -41,18 +41,21 @@ class WSSQuery;
template <WeightedRecordInterface R>
struct WSSState {
- decltype(R::weight) tot_weight;
+ decltype(R::weight) total_weight;
+ size_t sample_size;
WSSState() {
- tot_weight = 0;
+ total_weight = 0;
}
};
template <WeightedRecordInterface R>
struct WSSBufferState {
size_t cutoff;
+ size_t sample_size;
Alias* alias;
decltype(R::weight) max_weight;
+ decltype(R::weight) total_weight;
~WSSBufferState() {
delete alias;
@@ -296,16 +299,16 @@ public:
std::vector<double> weights;
state->cutoff = buffer->get_record_count() - 1;
- double tot_weight = 0.0;
+ double total_weight = 0.0;
for (size_t i = 0; i <= state->cutoff; i++) {
auto rec = buffer->get_data() + i;
weights.push_back(rec->rec.weight);
- tot_weight += rec->rec.weight;
+ total_weight += rec->rec.weight;
}
for (size_t i = 0; i < weights.size(); i++) {
- weights[i] = weights[i] / tot_weight;
+ weights[i] = weights[i] / total_weight;
}
state->alias = new Alias(weights);
@@ -313,15 +316,56 @@ public:
return state;
}
+ static void process_query_states(void *query_parms, std::vector<void*> shard_states, void *buff_state) {
+ auto p = (wss_query_parms<R> *) query_parms;
+ auto bs = (WSSBufferState<R> *) buff_state;
+
+ std::vector<size_t> shard_sample_sizes = {0};
+ size_t buffer_sz = 0;
+
+ std::vector<decltype(R::weight)> weights;
+ weights.push_back(bs->total_weight);
+
+ decltype(R::weight) total_weight;
+ for (auto &s : shard_states) {
+ auto state = (WSSState<R> *) s;
+ total_weight += state->total_weight;
+ weights.push_back(state->total_weight);
+ }
+
+ std::vector<double> normalized_weights;
+ for (auto w : weights) {
+ normalized_weights.push_back((double) w / (double) total_weight);
+ }
+
+ auto shard_alias = Alias(normalized_weights);
+ for (size_t i=0; i<p->sample_size; i++) {
+ auto idx = shard_alias.get(p->rng);
+ if (idx == 0) {
+ buffer_sz++;
+ } else {
+ shard_sample_sizes[idx - 1]++;
+ }
+ }
+
+
+ bs->sample_size = buffer_sz;
+ size_t i=1;
+ for (auto &s : shard_states) {
+ auto state = (WSSState<R> *) s;
+ state->sample_size = shard_sample_sizes[i++];
+ }
+ }
+
static std::vector<Wrapped<R>> query(WSS<R> *wss, void *q_state, void *parms) {
- auto sample_sz = ((wss_query_parms<R> *) parms)->sample_size;
+ auto sample_size = ((WSSState<R> *) q_state)->sample_size;
auto rng = ((wss_query_parms<R> *) parms)->rng;
auto state = (WSSState<R> *) q_state;
std::vector<Wrapped<R>> result_set;
- if (sample_sz == 0) {
+ if (sample_size == 0) {
return result_set;
}
size_t attempts = 0;
@@ -329,7 +373,7 @@ public:
attempts++;
size_t idx = wss->m_alias->get(rng);
result_set.emplace_back(*wss->get_record_at(idx));
- } while (attempts < sample_sz);
+ } while (attempts < sample_size);
return result_set;
}
@@ -339,10 +383,10 @@ public:
auto p = (wss_query_parms<R> *) parms;
std::vector<Wrapped<R>> result;
- result.reserve(p->sample_size);
+ result.reserve(st->sample_size);
if constexpr (Rejection) {
- for (size_t i=0; i<p->sample_size; i++) {
+ for (size_t i=0; i<st->sample_size; i++) {
auto idx = gsl_rng_uniform_int(p->rng, st->cutoff);
auto rec = buffer->get_data() + idx;
@@ -355,7 +399,7 @@ public:
return result;
}
- for (size_t i=0; i<p->sample_size; i++) {
+ for (size_t i=0; i<st->sample_size; i++) {
auto idx = st->alias->get(p->rng);
result.emplace_back(*(buffer->get_data() + idx));
}
@@ -384,11 +428,6 @@ public:
auto s = (WSSBufferState<R> *) state;
delete s;
}
-
-
- //{q.get_buffer_query_state(p, p)};
- //{q.buffer_query(p, p)};
-
};
}