diff options
| author | Douglas B. Rumbaugh <dbr4@psu.edu> | 2024-05-14 16:31:05 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-05-14 16:31:05 -0400 |
| commit | 47916da2ba5ed5bee2dda3cbcc58d39e1e931bfc (patch) | |
| tree | ee5613ce182b2c9caa228d3abeb65dc27fef2db3 /include/query/irs.h | |
| parent | 4a834497d5f82c817d634925250158d85ca825c2 (diff) | |
| parent | 8643fe194dec05b4e3f3ea31e162ac0b2b00e162 (diff) | |
| download | dynamic-extension-47916da2ba5ed5bee2dda3cbcc58d39e1e931bfc.tar.gz | |
Merge pull request #4 from dbrumbaugh/master
Updates for VLDB revision
Diffstat (limited to 'include/query/irs.h')
| -rw-r--r-- | include/query/irs.h | 85 |
1 files changed, 56 insertions, 29 deletions
diff --git a/include/query/irs.h b/include/query/irs.h index e2d9325..879d070 100644 --- a/include/query/irs.h +++ b/include/query/irs.h @@ -40,7 +40,12 @@ struct BufferState { size_t sample_size; BufferView<R> *buffer; + psudb::Alias *alias; + BufferState(BufferView<R> *buffer) : buffer(buffer) {} + ~BufferState() { + delete alias; + } }; template <RecordInterface R, ShardInterface<R> S, bool Rejection=true> @@ -72,6 +77,7 @@ public: res->cutoff = res->buffer->get_record_count(); res->sample_size = 0; + res->alias = nullptr; if constexpr (Rejection) { return res; @@ -96,39 +102,51 @@ public: std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0); size_t buffer_sz = 0; - std::vector<size_t> weights; - if constexpr (Rejection) { - weights.push_back((bs) ? bs->cutoff : 0); - } else { - weights.push_back((bs) ? bs->records.size() : 0); + /* for simplicity of static structure testing */ + if (!bs) { + assert(shard_states.size() == 1); + auto state = (State<R> *) shard_states[0]; + state->sample_size = p->sample_size; + return; } - size_t total_weight = 0; - for (auto &s : shard_states) { - auto state = (State<R> *) s; - total_weight += state->total_weight; - weights.push_back(state->total_weight); - } + /* we only need to build the shard alias on the first call */ + if (bs->alias == nullptr) { + std::vector<size_t> weights; + if constexpr (Rejection) { + weights.push_back((bs) ? bs->cutoff : 0); + } else { + weights.push_back((bs) ? bs->records.size() : 0); + } - // if no valid records fall within the query range, just - // set all of the sample sizes to 0 and bail out. - if (total_weight == 0) { - for (size_t i=0; i<shard_states.size(); i++) { - auto state = (State<R> *) shard_states[i]; - state->sample_size = 0; + size_t total_weight = weights[0]; + for (auto &s : shard_states) { + auto state = (State<R> *) s; + total_weight += state->total_weight; + weights.push_back(state->total_weight); } - return; - } + // if no valid records fall within the query range, just + // set all of the sample sizes to 0 and bail out. + if (total_weight == 0) { + for (size_t i=0; i<shard_states.size(); i++) { + auto state = (State<R> *) shard_states[i]; + state->sample_size = 0; + } - std::vector<double> normalized_weights; - for (auto w : weights) { - normalized_weights.push_back((double) w / (double) total_weight); + return; + } + + std::vector<double> normalized_weights; + for (auto w : weights) { + normalized_weights.push_back((double) w / (double) total_weight); + } + + bs->alias = new psudb::Alias(normalized_weights); } - auto shard_alias = psudb::Alias(normalized_weights); for (size_t i=0; i<p->sample_size; i++) { - auto idx = shard_alias.get(p->rng); + auto idx = bs->alias->get(p->rng); if (idx == 0) { buffer_sz++; } else { @@ -198,16 +216,12 @@ public: return result; } - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) { - std::vector<R> output; - + static void merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { for (size_t i=0; i<results.size(); i++) { for (size_t j=0; j<results[i].size(); j++) { output.emplace_back(results[i][j].rec); } } - - return output; } static void delete_query_state(void *state) { @@ -219,5 +233,18 @@ public: auto s = (BufferState<R> *) state; delete s; } + + static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { + auto p = (Parms<R> *) parms; + + if (results.size() < p->sample_size) { + auto q = *p; + q.sample_size -= results.size(); + process_query_states(&q, states, buffer_state); + return true; + } + + return false; + } }; }} |