summary | refs | log | tree | commit | diff | stats
path: root/include
diff options
context:
space:
mode:
author Douglas Rumbaugh <dbr4@psu.edu> 2023-11-07 12:29:03 -0500
committer Douglas Rumbaugh <dbr4@psu.edu> 2023-11-07 12:29:03 -0500
commit a2fe4b1616a1b2318f70e842382818ee44aea9e6 (patch)
tree 40a3dcac716ded595d917d845b255f54b941260a /include
parent e02742b07540dd5a9bcbb44dae14856bf10955ed (diff)
download dynamic-extension-a2fe4b1616a1b2318f70e842382818ee44aea9e6.tar.gz
Alias shard fixes
Diffstat (limited to 'include')
-rw-r--r-- include/query/wss.h 28
-rw-r--r-- include/shard/Alias.h 13
2 files changed, 25 insertions, 16 deletions
diff --git a/include/query/wss.h b/include/query/wss.h
index b8a5d54..794485c 100644
--- a/include/query/wss.h
+++ b/include/query/wss.h
@@ -90,15 +90,19 @@ public:
static void process_query_states(void *query_parms, std::vector<void*> &shard_states, std::vector<void*> &buffer_states) {
auto p = (Parms<R> *) query_parms;
- auto bs = (BufferState<R> *) buffer_states[0];
- std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0);
+ std::vector<size_t> shard_sample_sizes(shard_states.size()+buffer_states.size(), 0);
size_t buffer_sz = 0;
std::vector<decltype(R::weight)> weights;
- weights.push_back(bs->total_weight);
decltype(R::weight) total_weight = 0;
+ for (auto &s : buffer_states) {
+ auto bs = (BufferState<R> *) s;
+ total_weight += bs->total_weight;
+ weights.push_back(bs->total_weight);
+ }
+
for (auto &s : shard_states) {
auto state = (State<R> *) s;
total_weight += state->total_weight;
@@ -113,19 +117,15 @@ public:
auto shard_alias = psudb::Alias(normalized_weights);
for (size_t i=0; i<p->sample_size; i++) {
auto idx = shard_alias.get(p->rng);
- if (idx == 0) {
- buffer_sz++;
+
+ if (idx < buffer_states.size()) {
+ auto state = (BufferState<R> *) buffer_states[idx];
+ state->sample_size++;
} else {
- shard_sample_sizes[idx - 1]++;
+ auto state = (State<R> *) shard_states[idx - buffer_states.size()];
+ state->sample_size++;
}
}
-
-
- bs->sample_size = buffer_sz;
- for (size_t i=0; i<shard_states.size(); i++) {
- auto state = (State<R> *) shard_states[i];
- state->sample_size = shard_sample_sizes[i+1];
- }
}
static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
@@ -142,7 +142,7 @@ public:
size_t attempts = 0;
do {
attempts++;
- size_t idx = shard->m_alias->get(rng);
+ size_t idx = shard->get_weighted_sample(rng);
result_set.emplace_back(*shard->get_record_at(idx));
} while (attempts < sample_size);
diff --git a/include/shard/Alias.h b/include/shard/Alias.h
index b6b16c5..a4a7d02 100644
--- a/include/shard/Alias.h
+++ b/include/shard/Alias.h
@@ -19,7 +19,7 @@
#include "psu-ds/PriorityQueue.h"
#include "util/Cursor.h"
-#include "psu-ds/psudb::Alias.h"
+#include "psu-ds/Alias.h"
#include "psu-ds/BloomFilter.h"
#include "util/bf_config.h"
@@ -207,7 +207,13 @@ public:
return 0;
}
-private:
+ W get_total_weight() {
+ return m_total_weight;
+ }
+
+ size_t get_weighted_sample(gsl_rng *rng) const {
+ return m_alias->get(rng);
+ }
size_t get_lower_bound(const K& key) const {
size_t min = 0;
@@ -227,6 +233,8 @@ private:
return min;
}
+private:
+
void build_alias_structure(std::vector<W> &weights) {
// normalize the weights vector
@@ -249,3 +257,4 @@ private:
size_t m_alloc_size;
BloomFilter<R> *m_bf;
};
+}