summaryrefslogtreecommitdiffstats
path: root/include/shard/MemISAM.h
diff options
context:
space:
mode:
authorDouglas Rumbaugh <dbr4@psu.edu>2023-06-07 11:39:25 -0400
committerDouglas Rumbaugh <dbr4@psu.edu>2023-06-07 11:39:25 -0400
commit1a791e7241fb9898f58cd4642cf8cf8ec2a1c885 (patch)
treefbc59c0c52e2db66b252a7b47243c293ea008797 /include/shard/MemISAM.h
parent1800af2e9503302274e7ba35eed45aa5839af23d (diff)
downloaddynamic-extension-1a791e7241fb9898f58cd4642cf8cf8ec2a1c885.tar.gz
Added a pre-query hook for processing states
This is used for setting up the query alias structure stuff for sampling queries.
Diffstat (limited to 'include/shard/MemISAM.h')
-rw-r--r--include/shard/MemISAM.h46
1 files changed, 46 insertions, 0 deletions
diff --git a/include/shard/MemISAM.h b/include/shard/MemISAM.h
index 01a539a..ae1c682 100644
--- a/include/shard/MemISAM.h
+++ b/include/shard/MemISAM.h
@@ -39,12 +39,14 @@ template <RecordInterface R>
struct IRSState {
size_t lower_bound;
size_t upper_bound;
+ size_t sample_size;
};
template <RecordInterface R>
struct IRSBufferState {
size_t cutoff;
std::vector<Wrapped<R>> records;
+ size_t sample_size;
};
@@ -384,6 +386,50 @@ public:
return res;
}
+ static void process_query_states(void *query_parms, std::vector<void*> shard_states, void *buff_state) {
+ auto p = (irs_query_parms<R> *) query_parms;
+ auto bs = (IRSBufferState<R> *) buff_state;
+
+ std::vector<size_t> shard_sample_sizes = {0};
+ size_t buffer_sz = 0;
+
+ std::vector<size_t> weights;
+ if (Rejection) {
+ weights.push_back(bs->cutoff);
+ } else {
+ weights.push_back(bs->records.size());
+ }
+
+ decltype(R::weight) total_weight;
+ for (auto &s : shard_states) {
+ auto state = (IRSState<R> *) s;
+ total_weight += state->upper_bound - state->lower_bound;
+ weights.push_back(state->total_weight);
+ }
+
+ std::vector<double> normalized_weights;
+ for (auto w : weights) {
+ normalized_weights.push_back((double) w / (double) total_weight);
+ }
+
+ auto shard_alias = Alias(normalized_weights);
+ for (size_t i=0; i<p->sample_size; i++) {
+ auto idx = shard_alias.get(p->rng);
+ if (idx == 0) {
+ buffer_sz++;
+ } else {
+ shard_sample_sizes[idx - 1]++;
+ }
+ }
+
+
+ bs->sample_size = buffer_sz;
+ size_t i=1;
+ for (auto &s : shard_states) {
+ auto state = (IRSState<R> *) s;
+ state->sample_size = shard_sample_sizes[i++];
+ }
+ }
static std::vector<Wrapped<R>> query(MemISAM<R> *isam, void *q_state, void *parms) {
auto sample_sz = ((irs_query_parms<R> *) parms)->sample_size;
auto lower_key = ((irs_query_parms<R> *) parms)->lower_bound;