summaryrefslogtreecommitdiffstats
path: root/include/shard/WIRS.h
diff options
context:
space:
mode:
author: Douglas Rumbaugh <dbr4@psu.edu> 2023-05-29 16:37:43 -0400
committer: Douglas Rumbaugh <dbr4@psu.edu> 2023-05-29 16:37:43 -0400
commit: ab14529843d7bbb3a0d0c30b163ed444afee04ed (patch)
tree: cfb3019e4f3cac70f22869512fdce94c5aed8516 /include/shard/WIRS.h
parent: 85942ad2cb99b8a0984579d7dba9504f351ffac0 (diff)
download: dynamic-extension-ab14529843d7bbb3a0d0c30b163ed444afee04ed.tar.gz
WIRS Query tests + fixes
Diffstat (limited to 'include/shard/WIRS.h')
-rw-r--r-- include/shard/WIRS.h 115
1 files changed, 95 insertions, 20 deletions
diff --git a/include/shard/WIRS.h b/include/shard/WIRS.h
index 9e4d911..7e3f468 100644
--- a/include/shard/WIRS.h
+++ b/include/shard/WIRS.h
@@ -35,9 +35,10 @@ struct wirs_query_parms {
decltype(R::key) lower_bound;
decltype(R::key) upper_bound;
size_t sample_size;
+ gsl_rng *rng;
};
-template <WeightedRecordInterface R>
+template <WeightedRecordInterface R, bool Rejection> // Rejection selects the buffer sampling path: rejection sampling when true, a prebuilt Alias structure when false
class WIRSQuery;
template <WeightedRecordInterface R>
@@ -60,6 +61,19 @@ struct WIRSState {
};
+// Per-query state describing the mutable buffer, built by
+// WIRSQuery::get_buffer_query_state and released by
+// WIRSQuery::delete_buffer_query_state.
template <WeightedRecordInterface R>
+struct WIRSBufferState {
+ size_t cutoff = 0; // set from buffer->get_record_count() - 1 when the state is built
+ Alias* alias = nullptr; // owned; only allocated on the non-rejection path
+ std::vector<Wrapped<R>> records; // copies of the in-range, live buffer records (non-rejection path)
+ decltype(R::weight) max_weight{}; // set from buffer->get_max_weight() on the rejection path
+
+ WIRSBufferState() = default;
+
+ // The destructor deletes `alias`, so a copy would free it twice.
+ WIRSBufferState(const WIRSBufferState&) = delete;
+ WIRSBufferState& operator=(const WIRSBufferState&) = delete;
+
+ ~WIRSBufferState() {
+ delete alias;
+ }
+
+};
+
+template <WeightedRecordInterface R>
class WIRS {
private:
@@ -69,7 +83,9 @@ private:
public:
- friend class WIRSQuery<R>;
+ // FIXME: there has to be a better way to do this
+ friend class WIRSQuery<R, true>;
+ friend class WIRSQuery<R, false>;
WIRS(MutableBuffer<R>* buffer)
: m_reccnt(0), m_tombstone_cnt(0), m_total_weight(0), m_root(nullptr) {
@@ -251,13 +267,13 @@ private:
bool covered_by(struct wirs_node<R>* node, const K& lower_key, const K& upper_key) {
auto low_index = node->low * m_group_size;
auto high_index = std::min((node->high + 1) * m_group_size - 1, m_reccnt - 1);
- return lower_key < m_data[low_index].key && m_data[high_index].key < upper_key;
+ return lower_key < m_data[low_index].rec.key && m_data[high_index].rec.key < upper_key; // true when the keys at the node's first and last record index lie strictly inside (lower_key, upper_key)
}
bool intersects(struct wirs_node<R>* node, const K& lower_key, const K& upper_key) {
auto low_index = node->low * m_group_size;
auto high_index = std::min((node->high + 1) * m_group_size - 1, m_reccnt - 1);
- return lower_key < m_data[high_index].key && m_data[low_index].key < upper_key;
+ // Reports whether the key span between the node's first and last record
+ // index overlaps the query range. The sampling code accepts records whose
+ // key equals either bound, so the comparison is inclusive: a node whose
+ // last key equals lower_key (or whose first key equals upper_key) still
+ // holds qualifying records and must not be skipped.
+ return lower_key <= m_data[high_index].rec.key && m_data[low_index].rec.key <= upper_key;
}
void build_wirs_structure() {
@@ -337,7 +353,7 @@ private:
};
-template <WeightedRecordInterface R>
+template <WeightedRecordInterface R, bool Rejection=true>
class WIRSQuery {
public:
static void *get_query_state(WIRS<R> *wirs, void *parms) {
@@ -373,10 +389,39 @@ public:
}
static void* get_buffer_query_state(MutableBuffer<R> *buffer, void *parms) {
- return nullptr;
+ // Builds the WIRSBufferState consumed by buffer_query.
+ //
+ // buffer: mutable buffer to sample from.
+ // parms: pointer to a wirs_query_parms<R>; only the key bounds are read here.
+ // Returns a heap-allocated WIRSBufferState<R>* (as void*) that the caller
+ // releases with delete_buffer_query_state.
+ auto parameters = static_cast<wirs_query_parms<R> *>(parms);
+ auto state = new WIRSBufferState<R>();
+ state->alias = nullptr;
+
+ // Read the record count once. An empty buffer must not compute 0 - 1 on a
+ // size_t: that wraps to SIZE_MAX and sends the scan below far out of bounds.
+ const size_t reccnt = buffer->get_record_count();
+ state->cutoff = (reccnt > 0) ? reccnt - 1 : 0;
+
+ if constexpr (Rejection) {
+ // Rejection sampling only needs the weight ceiling; no alias structure.
+ state->max_weight = buffer->get_max_weight();
+ return state;
+ }
+
+ std::vector<double> weights;
+ double tot_weight = 0.0;
+
+ // Keep only live records whose key falls inside the inclusive query range.
+ for (size_t i = 0; i < reccnt; i++) {
+ auto rec = buffer->get_data() + i;
+
+ if (rec->rec.key >= parameters->lower_bound && rec->rec.key <= parameters->upper_bound && !rec->is_tombstone() && !rec->is_deleted()) {
+ weights.push_back(rec->rec.weight);
+ state->records.push_back(*rec);
+ tot_weight += rec->rec.weight;
+ }
+ }
+
+ // Normalize only when there is weight to divide by; otherwise every entry
+ // would become NaN.
+ if (tot_weight > 0.0) {
+ for (size_t i = 0; i < weights.size(); i++) {
+ weights[i] = weights[i] / tot_weight;
+ }
+ }
+
+ // Always allocated on this path, even when `weights` is empty.
+ state->alias = new Alias(weights);
+
+ return state;
}
- static std::vector<Wrapped<R>> *query(WIRS<R> *wirs, void *q_state, void *parms) {
+ static std::vector<Wrapped<R>> query(WIRS<R> *wirs, void *q_state, void *parms) {
auto sample_sz = ((wirs_query_parms<R> *) parms)->sample_size;
auto lower_key = ((wirs_query_parms<R> *) parms)->lower_bound;
auto upper_key = ((wirs_query_parms<R> *) parms)->upper_bound;
@@ -384,10 +429,10 @@ public:
auto state = (WIRSState<R> *) q_state;
- std::vector<Wrapped<R>> *result_set = new std::vector<Wrapped<R>>();
+ std::vector<Wrapped<R>> result_set;
if (sample_sz == 0) {
- return 0;
+ return result_set;
}
// k -> sampling: three levels. 1. select a node -> select a fat point -> select a record.
size_t cnt = 0;
@@ -403,36 +448,66 @@ public:
auto record = wirs->m_data + rec_offset;
// bounds rejection
- if (lower_key > record->key || upper_key < record->key) {
+ if (lower_key > record->rec.key || upper_key < record->rec.key) {
continue;
}
- result_set->emplace_back(*record);
+ result_set.emplace_back(*record);
cnt++;
} while (attempts < sample_sz);
return result_set;
}
- static std::vector<Wrapped<R>> *buffer_query(MutableBuffer<R> *buffer, void *state, void *parms) {
- return new std::vector<Wrapped<R>>();
+ // Samples records from the mutable buffer.
+ //
+ // buffer: buffer whose records are sampled (read directly on the rejection path).
+ // state: WIRSBufferState<R>* produced by get_buffer_query_state.
+ // parms: wirs_query_parms<R>* supplying sample_size, the key bounds and the rng.
+ // Returns the sampled records by value. On the rejection path each of the
+ // sample_size draws is attempted once and rejected draws are not retried, so
+ // fewer than sample_size records may be returned.
+ static std::vector<Wrapped<R>> buffer_query(MutableBuffer<R> *buffer, void *state, void *parms) {
+ auto st = static_cast<WIRSBufferState<R> *>(state);
+ auto p = static_cast<wirs_query_parms<R> *>(parms);
+
+ std::vector<Wrapped<R>> result;
+ result.reserve(p->sample_size);
+
+ if constexpr (Rejection) {
+ // cutoff is the index of the last visible record, so cutoff + 1 records
+ // are candidates. gsl_rng_uniform_int(r, n) draws from [0, n-1] and
+ // treats n == 0 as an error, so passing cutoff itself would never pick
+ // the last record and would fail on a one-record buffer. candidates is 0
+ // when cutoff wrapped around for an empty buffer.
+ const size_t candidates = st->cutoff + 1;
+ if (candidates == 0 || buffer->get_record_count() == 0) {
+ return result;
+ }
+
+ // NOTE(review): unlike the alias path, this path does not filter
+ // tombstoned or deleted records - confirm that is handled by the caller.
+ for (size_t i=0; i<p->sample_size; i++) {
+ auto idx = gsl_rng_uniform_int(p->rng, candidates);
+ auto rec = buffer->get_data() + idx;
+
+ // Accept with probability weight / max_weight, then apply the key bounds.
+ auto test = gsl_rng_uniform(p->rng) * st->max_weight;
+
+ if (test <= rec->rec.weight && rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) {
+ result.emplace_back(*rec);
+ }
+ }
+ return result;
+ }
+
+ // Nothing qualified when the state was built: there is nothing to index.
+ if (st->alias == nullptr || st->records.empty()) {
+ return result;
+ }
+
+ for (size_t i=0; i<p->sample_size; i++) {
+ auto idx = st->alias->get(p->rng);
+ result.emplace_back(st->records[idx]);
+ }
+
+ return result;
}
- static std::vector<R> merge(std::vector<std::vector<R>*> *results) {
+ static std::vector<R> merge(std::vector<std::vector<R>> &results) { // concatenates all result vectors, in order, into one vector; NOTE(review): query()/buffer_query() return std::vector<Wrapped<R>> while this takes std::vector<R> - confirm callers unwrap first
std::vector<R> output;
- for (size_t i=0; i<results->size(); i++) {
- for (size_t j=0; j<(*results)[i]->size(); j++) {
- output->emplace_back(*((*results)[i])[j]);
+ for (size_t i=0; i<results.size(); i++) { // outer index: one result vector per iteration
+ for (size_t j=0; j<results[i].size(); j++) { // inner index: each record of that vector
+ output.emplace_back(results[i][j]); // appends a copy; the input vectors are not modified
}
}
return output;
}
- static void delete_query_state(void *parm) {
- wirs_query_parms<R> *parameters = parm;
- delete parameters;
+ // Frees a query state; the pointer is treated as a WIRSState<R>.
+ static void delete_query_state(void *state) {
+ delete static_cast<WIRSState<R> *>(state);
+ }
+
+ // Frees a buffer query state; the pointer is treated as a WIRSBufferState<R>.
+ static void delete_buffer_query_state(void *state) {
+ delete static_cast<WIRSBufferState<R> *>(state);
}