summaryrefslogtreecommitdiffstats
path: root/include/query/irs.h
diff options
context:
space:
mode:
authorDouglas B. Rumbaugh <dbr4@psu.edu>2024-05-14 16:31:05 -0400
committerGitHub <noreply@github.com>2024-05-14 16:31:05 -0400
commit47916da2ba5ed5bee2dda3cbcc58d39e1e931bfc (patch)
treeee5613ce182b2c9caa228d3abeb65dc27fef2db3 /include/query/irs.h
parent4a834497d5f82c817d634925250158d85ca825c2 (diff)
parent8643fe194dec05b4e3f3ea31e162ac0b2b00e162 (diff)
downloaddynamic-extension-47916da2ba5ed5bee2dda3cbcc58d39e1e931bfc.tar.gz
Merge pull request #4 from dbrumbaugh/master
Updates for VLDB revision
Diffstat (limited to 'include/query/irs.h')
-rw-r--r--include/query/irs.h85
1 files changed, 56 insertions, 29 deletions
diff --git a/include/query/irs.h b/include/query/irs.h
index e2d9325..879d070 100644
--- a/include/query/irs.h
+++ b/include/query/irs.h
@@ -40,7 +40,12 @@ struct BufferState {
size_t sample_size;
BufferView<R> *buffer;
+ psudb::Alias *alias;
+
BufferState(BufferView<R> *buffer) : buffer(buffer) {}
+ ~BufferState() {
+ delete alias;
+ }
};
template <RecordInterface R, ShardInterface<R> S, bool Rejection=true>
@@ -72,6 +77,7 @@ public:
res->cutoff = res->buffer->get_record_count();
res->sample_size = 0;
+ res->alias = nullptr;
if constexpr (Rejection) {
return res;
@@ -96,39 +102,51 @@ public:
std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0);
size_t buffer_sz = 0;
- std::vector<size_t> weights;
- if constexpr (Rejection) {
- weights.push_back((bs) ? bs->cutoff : 0);
- } else {
- weights.push_back((bs) ? bs->records.size() : 0);
+ /* for simplicity of static structure testing */
+ if (!bs) {
+ assert(shard_states.size() == 1);
+ auto state = (State<R> *) shard_states[0];
+ state->sample_size = p->sample_size;
+ return;
}
- size_t total_weight = 0;
- for (auto &s : shard_states) {
- auto state = (State<R> *) s;
- total_weight += state->total_weight;
- weights.push_back(state->total_weight);
- }
+ /* we only need to build the shard alias on the first call */
+ if (bs->alias == nullptr) {
+ std::vector<size_t> weights;
+ if constexpr (Rejection) {
+ weights.push_back((bs) ? bs->cutoff : 0);
+ } else {
+ weights.push_back((bs) ? bs->records.size() : 0);
+ }
- // if no valid records fall within the query range, just
- // set all of the sample sizes to 0 and bail out.
- if (total_weight == 0) {
- for (size_t i=0; i<shard_states.size(); i++) {
- auto state = (State<R> *) shard_states[i];
- state->sample_size = 0;
+ size_t total_weight = weights[0];
+ for (auto &s : shard_states) {
+ auto state = (State<R> *) s;
+ total_weight += state->total_weight;
+ weights.push_back(state->total_weight);
}
- return;
- }
+ // if no valid records fall within the query range, just
+ // set all of the sample sizes to 0 and bail out.
+ if (total_weight == 0) {
+ for (size_t i=0; i<shard_states.size(); i++) {
+ auto state = (State<R> *) shard_states[i];
+ state->sample_size = 0;
+ }
- std::vector<double> normalized_weights;
- for (auto w : weights) {
- normalized_weights.push_back((double) w / (double) total_weight);
+ return;
+ }
+
+ std::vector<double> normalized_weights;
+ for (auto w : weights) {
+ normalized_weights.push_back((double) w / (double) total_weight);
+ }
+
+ bs->alias = new psudb::Alias(normalized_weights);
}
- auto shard_alias = psudb::Alias(normalized_weights);
for (size_t i=0; i<p->sample_size; i++) {
- auto idx = shard_alias.get(p->rng);
+ auto idx = bs->alias->get(p->rng);
if (idx == 0) {
buffer_sz++;
} else {
@@ -198,16 +216,12 @@ public:
return result;
}
- static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) {
- std::vector<R> output;
-
+ static void merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) {
for (size_t i=0; i<results.size(); i++) {
for (size_t j=0; j<results[i].size(); j++) {
output.emplace_back(results[i][j].rec);
}
}
-
- return output;
}
static void delete_query_state(void *state) {
@@ -219,5 +233,18 @@ public:
auto s = (BufferState<R> *) state;
delete s;
}
+
+ static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) {
+ auto p = (Parms<R> *) parms;
+
+ if (results.size() < p->sample_size) {
+ auto q = *p;
+ q.sample_size -= results.size();
+ process_query_states(&q, states, buffer_state);
+ return true;
+ }
+
+ return false;
+ }
};
}}