diff options
| author | Douglas Rumbaugh <dbr4@psu.edu> | 2023-06-07 11:39:25 -0400 |
|---|---|---|
| committer | Douglas Rumbaugh <dbr4@psu.edu> | 2023-06-07 11:39:25 -0400 |
| commit | 1a791e7241fb9898f58cd4642cf8cf8ec2a1c885 (patch) | |
| tree | fbc59c0c52e2db66b252a7b47243c293ea008797 /include/shard/WSS.h | |
| parent | 1800af2e9503302274e7ba35eed45aa5839af23d (diff) | |
| download | dynamic-extension-1a791e7241fb9898f58cd4642cf8cf8ec2a1c885.tar.gz | |
Added a pre-query hook for processing query states
This is used to set up the alias structure that divides a sampling query's
sample size between the buffer and the shards.
Diffstat (limited to 'include/shard/WSS.h')
| -rw-r--r-- | include/shard/WSS.h | 71 |
1 file changed, 55 insertions, 16 deletions
diff --git a/include/shard/WSS.h b/include/shard/WSS.h
index bb7ee2a..1069897 100644
--- a/include/shard/WSS.h
+++ b/include/shard/WSS.h
@@ -41,18 +41,21 @@ class WSSQuery;
 
 template <WeightedRecordInterface R>
 struct WSSState {
-    decltype(R::weight) tot_weight;
+    decltype(R::weight) total_weight;
+    size_t sample_size = 0; // per-shard draw count, filled in by process_query_states
 
     WSSState() {
-        tot_weight = 0;
+        total_weight = 0;
     }
 };
 
 template <WeightedRecordInterface R>
 struct WSSBufferState {
     size_t cutoff;
+    size_t sample_size = 0; // buffer draw count, filled in by process_query_states
     Alias* alias;
     decltype(R::weight) max_weight;
+    decltype(R::weight) total_weight = 0;
 
     ~WSSBufferState() {
         delete alias;
@@ -296,16 +299,16 @@ public:
         std::vector<double> weights;
 
         state->cutoff = buffer->get_record_count() - 1;
-        double tot_weight = 0.0;
+        double total_weight = 0.0;
 
         for (size_t i = 0; i <= state->cutoff; i++) {
             auto rec = buffer->get_data() + i;
             weights.push_back(rec->rec.weight);
-            tot_weight += rec->rec.weight;
+            total_weight += rec->rec.weight;
         }
 
         for (size_t i = 0; i < weights.size(); i++) {
-            weights[i] = weights[i] / tot_weight;
+            weights[i] = weights[i] / total_weight;
         }
 
         state->alias = new Alias(weights);
@@ -313,15 +316,56 @@ public:
         return state;
     }
 
+    static void process_query_states(void *query_parms, std::vector<void*> shard_states, void *buff_state) {
+        auto p = (wss_query_parms<R> *) query_parms;
+        auto bs = (WSSBufferState<R> *) buff_state;
+        // Pre-query hook: split p->sample_size draws between the buffer (alias slot 0) and the shards (slots 1..n).
+        std::vector<size_t> shard_sample_sizes(shard_states.size(), 0); // one counter per shard
+        size_t buffer_sz = 0;
+
+        std::vector<decltype(R::weight)> weights;
+        weights.push_back(bs->total_weight);
+        // Seed the sum with the buffer's weight so the normalized weights add up to 1.
+        decltype(R::weight) total_weight = bs->total_weight;
+        for (auto &s : shard_states) {
+            auto state = (WSSState<R> *) s;
+            total_weight += state->total_weight;
+            weights.push_back(state->total_weight);
+        }
+
+        std::vector<double> normalized_weights;
+        for (auto w : weights) {
+            normalized_weights.push_back((double) w / (double) total_weight);
+        }
+
+        auto shard_alias = Alias(normalized_weights);
+        for (size_t i=0; i<p->sample_size; i++) {
+            auto idx = shard_alias.get(p->rng);
+            if (idx == 0) {
+                buffer_sz++;
+            } else {
+                shard_sample_sizes[idx - 1]++;
+            }
+        }
+
+        // Publish the counts; shard_sample_sizes is 0-based and parallel to shard_states.
+        bs->sample_size = buffer_sz;
+        size_t i=0;
+        for (auto &s : shard_states) {
+            auto state = (WSSState<R> *) s;
+            state->sample_size = shard_sample_sizes[i++];
+        }
+    }
+
     static std::vector<Wrapped<R>> query(WSS<R> *wss, void *q_state, void *parms) {
-        auto sample_sz = ((wss_query_parms<R> *) parms)->sample_size;
+        auto sample_size = ((WSSState<R> *) q_state)->sample_size;
         auto rng = ((wss_query_parms<R> *) parms)->rng;
 
         auto state = (WSSState<R> *) q_state;
 
         std::vector<Wrapped<R>> result_set;
 
-        if (sample_sz == 0) {
+        if (sample_size == 0) {
             return result_set;
         }
         size_t attempts = 0;
@@ -329,7 +373,7 @@ public:
             attempts++;
             size_t idx = wss->m_alias->get(rng);
             result_set.emplace_back(*wss->get_record_at(idx));
-        } while (attempts < sample_sz);
+        } while (attempts < sample_size);
 
         return result_set;
     }
@@ -339,10 +383,10 @@ public:
         auto p = (wss_query_parms<R> *) parms;
 
         std::vector<Wrapped<R>> result;
-        result.reserve(p->sample_size);
+        result.reserve(st->sample_size);
 
         if constexpr (Rejection) {
-            for (size_t i=0; i<p->sample_size; i++) {
+            for (size_t i=0; i<st->sample_size; i++) {
                 auto idx = gsl_rng_uniform_int(p->rng, st->cutoff);
                 auto rec = buffer->get_data() + idx;
 
@@ -355,7 +399,7 @@ public:
             return result;
         }
 
-        for (size_t i=0; i<p->sample_size; i++) {
+        for (size_t i=0; i<st->sample_size; i++) {
             auto idx = st->alias->get(p->rng);
             result.emplace_back(*(buffer->get_data() + idx));
         }
@@ -384,11 +428,6 @@ public:
         auto s = (WSSBufferState<R> *) state;
         delete s;
     }
-
-
-    //{q.get_buffer_query_state(p, p)};
-    //{q.buffer_query(p, p)};
-
 };
 
 }