summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--include/framework/MutableBuffer.h4
-rw-r--r--include/shard/WIRS.h115
-rw-r--r--tests/wirs_tests.cpp207
3 files changed, 226 insertions, 100 deletions
diff --git a/include/framework/MutableBuffer.h b/include/framework/MutableBuffer.h
index bc80922..3e0de40 100644
--- a/include/framework/MutableBuffer.h
+++ b/include/framework/MutableBuffer.h
@@ -151,6 +151,10 @@ public:
return m_data;
}
+ double get_max_weight() {
+ return m_max_weight;
+ }
+
private:
int32_t try_advance_tail() {
size_t new_tail = m_reccnt.fetch_add(1);
diff --git a/include/shard/WIRS.h b/include/shard/WIRS.h
index 9e4d911..7e3f468 100644
--- a/include/shard/WIRS.h
+++ b/include/shard/WIRS.h
@@ -35,9 +35,10 @@ struct wirs_query_parms {
decltype(R::key) lower_bound;
decltype(R::key) upper_bound;
size_t sample_size;
+ gsl_rng *rng;
};
-template <WeightedRecordInterface R>
+template <WeightedRecordInterface R, bool Rejection>
class WIRSQuery;
template <WeightedRecordInterface R>
@@ -60,6 +61,19 @@ struct WIRSState {
};
template <WeightedRecordInterface R>
+struct WIRSBufferState {
+ size_t cutoff;
+ Alias* alias;
+ std::vector<Wrapped<R>> records;
+ decltype(R::weight) max_weight;
+
+ ~WIRSBufferState() {
+ delete alias;
+ }
+
+};
+
+template <WeightedRecordInterface R>
class WIRS {
private:
@@ -69,7 +83,9 @@ private:
public:
- friend class WIRSQuery<R>;
+ // FIXME: there has to be a better way to do this
+ friend class WIRSQuery<R, true>;
+ friend class WIRSQuery<R, false>;
WIRS(MutableBuffer<R>* buffer)
: m_reccnt(0), m_tombstone_cnt(0), m_total_weight(0), m_root(nullptr) {
@@ -251,13 +267,13 @@ private:
bool covered_by(struct wirs_node<R>* node, const K& lower_key, const K& upper_key) {
auto low_index = node->low * m_group_size;
auto high_index = std::min((node->high + 1) * m_group_size - 1, m_reccnt - 1);
- return lower_key < m_data[low_index].key && m_data[high_index].key < upper_key;
+ return lower_key < m_data[low_index].rec.key && m_data[high_index].rec.key < upper_key;
}
bool intersects(struct wirs_node<R>* node, const K& lower_key, const K& upper_key) {
auto low_index = node->low * m_group_size;
auto high_index = std::min((node->high + 1) * m_group_size - 1, m_reccnt - 1);
- return lower_key < m_data[high_index].key && m_data[low_index].key < upper_key;
+ return lower_key < m_data[high_index].rec.key && m_data[low_index].rec.key < upper_key;
}
void build_wirs_structure() {
@@ -337,7 +353,7 @@ private:
};
-template <WeightedRecordInterface R>
+template <WeightedRecordInterface R, bool Rejection=true>
class WIRSQuery {
public:
static void *get_query_state(WIRS<R> *wirs, void *parms) {
@@ -373,10 +389,39 @@ public:
}
static void* get_buffer_query_state(MutableBuffer<R> *buffer, void *parms) {
- return nullptr;
+ WIRSBufferState<R> *state = new WIRSBufferState<R>();
+ auto parameters = (wirs_query_parms<R>*) parms;
+ if constexpr (Rejection) {
+ state->cutoff = buffer->get_record_count() - 1;
+ state->max_weight = buffer->get_max_weight();
+ return state;
+ }
+
+ std::vector<double> weights;
+
+ state->cutoff = buffer->get_record_count() - 1;
+ double tot_weight = 0.0;
+
+ for (size_t i = 0; i <= state->cutoff; i++) {
+ auto rec = buffer->get_data() + i;
+
+ if (rec->rec.key >= parameters->lower_bound && rec->rec.key <= parameters->upper_bound && !rec->is_tombstone() && !rec->is_deleted()) {
+ weights.push_back(rec->rec.weight);
+ state->records.push_back(*rec);
+ tot_weight += rec->rec.weight;
+ }
+ }
+
+ for (size_t i = 0; i < weights.size(); i++) {
+ weights[i] = weights[i] / tot_weight;
+ }
+
+ state->alias = new Alias(weights);
+
+ return state;
}
- static std::vector<Wrapped<R>> *query(WIRS<R> *wirs, void *q_state, void *parms) {
+ static std::vector<Wrapped<R>> query(WIRS<R> *wirs, void *q_state, void *parms) {
auto sample_sz = ((wirs_query_parms<R> *) parms)->sample_size;
auto lower_key = ((wirs_query_parms<R> *) parms)->lower_bound;
auto upper_key = ((wirs_query_parms<R> *) parms)->upper_bound;
@@ -384,10 +429,10 @@ public:
auto state = (WIRSState<R> *) q_state;
- std::vector<Wrapped<R>> *result_set = new std::vector<Wrapped<R>>();
+ std::vector<Wrapped<R>> result_set;
if (sample_sz == 0) {
- return 0;
+ return result_set;
}
// k -> sampling: three levels. 1. select a node -> select a fat point -> select a record.
size_t cnt = 0;
@@ -403,36 +448,66 @@ public:
auto record = wirs->m_data + rec_offset;
// bounds rejection
- if (lower_key > record->key || upper_key < record->key) {
+ if (lower_key > record->rec.key || upper_key < record->rec.key) {
continue;
}
- result_set->emplace_back(*record);
+ result_set.emplace_back(*record);
cnt++;
} while (attempts < sample_sz);
return result_set;
}
- static std::vector<Wrapped<R>> *buffer_query(MutableBuffer<R> *buffer, void *state, void *parms) {
- return new std::vector<Wrapped<R>>();
+ static std::vector<Wrapped<R>> buffer_query(MutableBuffer<R> *buffer, void *state, void *parms) {
+ auto st = (WIRSBufferState<R> *) state;
+ auto p = (wirs_query_parms<R> *) parms;
+
+ std::vector<Wrapped<R>> result;
+ result.reserve(p->sample_size);
+
+ if constexpr (Rejection) {
+ for (size_t i=0; i<p->sample_size; i++) {
+ auto idx = gsl_rng_uniform_int(p->rng, st->cutoff);
+ auto rec = buffer->get_data() + idx;
+
+ auto test = gsl_rng_uniform(p->rng) * st->max_weight;
+
+ if (test <= rec->rec.weight && rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) {
+ result.emplace_back(*rec);
+ }
+ }
+ return result;
+ }
+
+ for (size_t i=0; i<p->sample_size; i++) {
+ auto idx = st->alias->get(p->rng);
+ result.emplace_back(st->records[idx]);
+ }
+
+ return result;
}
- static std::vector<R> merge(std::vector<std::vector<R>*> *results) {
+ static std::vector<R> merge(std::vector<std::vector<R>> &results) {
std::vector<R> output;
- for (size_t i=0; i<results->size(); i++) {
- for (size_t j=0; j<(*results)[i]->size(); j++) {
- output->emplace_back(*((*results)[i])[j]);
+ for (size_t i=0; i<results.size(); i++) {
+ for (size_t j=0; j<results[i].size(); j++) {
+ output.emplace_back(results[i][j]);
}
}
return output;
}
- static void delete_query_state(void *parm) {
- wirs_query_parms<R> *parameters = parm;
- delete parameters;
+ static void delete_query_state(void *state) {
+ auto s = (WIRSState<R> *) state;
+ delete s;
+ }
+
+ static void delete_buffer_query_state(void *state) {
+ auto s = (WIRSBufferState<R> *) state;
+ delete s;
}
diff --git a/tests/wirs_tests.cpp b/tests/wirs_tests.cpp
index ba4b754..0a9b1d0 100644
--- a/tests/wirs_tests.cpp
+++ b/tests/wirs_tests.cpp
@@ -44,6 +44,7 @@ START_TEST(t_mbuffer_init)
delete shard;
}
+
START_TEST(t_wirs_init)
{
size_t n = 512;
@@ -94,31 +95,6 @@ START_TEST(t_wirs_init)
delete shard4;
}
-/*
-START_TEST(t_get_lower_bound_index)
-{
- size_t n = 10000;
- auto mbuffer = create_double_seq_mbuffer<WRec>(n);
-
- ck_assert_ptr_nonnull(mbuffer);
- Shard* shard = new Shard(mbuffer);
-
- ck_assert_int_eq(shard->get_record_count(), n);
- ck_assert_int_eq(shard->get_tombstone_count(), 0);
-
- auto tbl_records = mbuffer->sorted_output();
- for (size_t i=0; i<n; i++) {
- const WRec *tbl_rec = mbuffer->get_record_at(i);
- auto pos = shard->get_lower_bound(tbl_rec->key);
- ck_assert_int_eq(shard->get_record_at(pos)->key, tbl_rec->key);
- ck_assert_int_le(pos, i);
- }
-
- delete mbuffer;
- delete shard;
-}
-
-*/
START_TEST(t_full_cancelation)
{
@@ -150,8 +126,7 @@ START_TEST(t_full_cancelation)
END_TEST
-/*
-START_TEST(t_weighted_sampling)
+START_TEST(t_wirs_query)
{
size_t n=1000;
auto buffer = create_weighted_mbuffer<WRec>(n);
@@ -163,72 +138,156 @@ START_TEST(t_weighted_sampling)
size_t k = 1000;
- std::vector<WRec> results;
- results.reserve(k);
size_t cnt[3] = {0};
+ wirs_query_parms<WRec> parms = {lower_key, upper_key, k};
+ parms.rng = gsl_rng_alloc(gsl_rng_mt19937);
+
for (size_t i=0; i<1000; i++) {
- WIRS<WRec>::wirs_query_parms parms = {lower_key, upper_key};
- auto state = shard->get_query_state(&parms);
-
- shard->get_samples(state, results, lower_key, upper_key, k, g_rng);
+ auto state = WIRSQuery<WRec>::get_query_state(shard, &parms);
+ auto result = WIRSQuery<WRec>::query(shard, state, &parms);
- for (size_t j=0; j<k; j++) {
- cnt[results[j].key - 1]++;
+ for (size_t j=0; j<result.size(); j++) {
+ cnt[result[j].rec.key - 1]++;
}
- WIRS<WRec>::delete_query_state(state);
+ WIRSQuery<WRec>::delete_query_state(state);
}
ck_assert(roughly_equal(cnt[0] / 1000, (double) k/4.0, k, .05));
ck_assert(roughly_equal(cnt[1] / 1000, (double) k/4.0, k, .05));
ck_assert(roughly_equal(cnt[2] / 1000, (double) k/2.0, k, .05));
+ gsl_rng_free(parms.rng);
delete shard;
delete buffer;
}
END_TEST
-*/
-/*
-START_TEST(t_tombstone_check)
+template <RecordInterface R>
+std::vector<R> strip_wrapping(std::vector<Wrapped<R>> vec) {
+ std::vector<R> out(vec.size());
+ for (size_t i=0; i<vec.size(); i++) {
+ out[i] = vec[i].rec;
+ }
+
+ return out;
+}
+
+
+START_TEST(t_wirs_query_merge)
{
- size_t cnt = 1024;
- size_t ts_cnt = 256;
- auto buffer = new MutableBuffer<WRec>(cnt + ts_cnt, true, ts_cnt);
-
- std::vector<std::pair<uint64_t, uint32_t>> tombstones;
-
- uint64_t key = 1000;
- uint32_t val = 101;
- for (size_t i = 0; i < cnt; i++) {
- buffer->append({key, val, 1});
- key++;
- val++;
+ size_t n=1000;
+ auto buffer = create_weighted_mbuffer<WRec>(n);
+
+ Shard* shard = new Shard(buffer);
+
+ uint64_t lower_key = 0;
+ uint64_t upper_key = 5;
+
+ size_t k = 1000;
+
+ size_t cnt[3] = {0};
+ wirs_query_parms<WRec> parms = {lower_key, upper_key, k};
+ parms.rng = gsl_rng_alloc(gsl_rng_mt19937);
+
+ std::vector<std::vector<WRec>> results(2);
+
+ for (size_t i=0; i<1000; i++) {
+ auto state1 = WIRSQuery<WRec>::get_query_state(shard, &parms);
+ results[0] = strip_wrapping(WIRSQuery<WRec>::query(shard, state1, &parms));
+
+ auto state2 = WIRSQuery<WRec>::get_query_state(shard, &parms);
+ results[1] = strip_wrapping(WIRSQuery<WRec>::query(shard, state2, &parms));
+
+ WIRSQuery<WRec>::delete_query_state(state1);
+ WIRSQuery<WRec>::delete_query_state(state2);
}
- // ensure that the key range doesn't overlap, so nothing
- // gets cancelled.
- for (size_t i=0; i<ts_cnt; i++) {
- tombstones.push_back({i, i});
+ auto merged = WIRSQuery<WRec>::merge(results);
+
+ ck_assert_int_eq(merged.size(), 2*k);
+ for (size_t i=0; i<merged.size(); i++) {
+ ck_assert_int_ge(merged[i].key, lower_key);
+ ck_assert_int_le(merged[i].key, upper_key);
}
- for (size_t i=0; i<ts_cnt; i++) {
- buffer->append({tombstones[i].first, tombstones[i].second, 1, 1});
+ gsl_rng_free(parms.rng);
+ delete shard;
+ delete buffer;
+}
+END_TEST
+
+
+START_TEST(t_wirs_buffer_query_scan)
+{
+ size_t n=1000;
+ auto buffer = create_weighted_mbuffer<WRec>(n);
+
+ uint64_t lower_key = 0;
+ uint64_t upper_key = 5;
+
+ size_t k = 1000;
+
+ size_t cnt[3] = {0};
+ wirs_query_parms<WRec> parms = {lower_key, upper_key, k};
+ parms.rng = gsl_rng_alloc(gsl_rng_mt19937);
+
+ for (size_t i=0; i<1000; i++) {
+ auto state = WIRSQuery<WRec, false>::get_buffer_query_state(buffer, &parms);
+ auto result = WIRSQuery<WRec, false>::buffer_query(buffer, state, &parms);
+
+ for (size_t j=0; j<result.size(); j++) {
+ cnt[result[j].rec.key - 1]++;
+ }
+
+ WIRSQuery<WRec, false>::delete_buffer_query_state(state);
}
- auto shard = new Shard(buffer);
+ ck_assert(roughly_equal(cnt[0] / 1000, (double) k/4.0, k, .05));
+ ck_assert(roughly_equal(cnt[1] / 1000, (double) k/4.0, k, .05));
+ ck_assert(roughly_equal(cnt[2] / 1000, (double) k/2.0, k, .05));
+
+ gsl_rng_free(parms.rng);
+ delete buffer;
+}
+END_TEST
+
+
+START_TEST(t_wirs_buffer_query_rejection)
+{
+ size_t n=1000;
+ auto buffer = create_weighted_mbuffer<WRec>(n);
+
+ uint64_t lower_key = 0;
+ uint64_t upper_key = 5;
+
+ size_t k = 1000;
+
+ size_t cnt[3] = {0};
+ wirs_query_parms<WRec> parms = {lower_key, upper_key, k};
+ parms.rng = gsl_rng_alloc(gsl_rng_mt19937);
- for (size_t i=0; i<tombstones.size(); i++) {
- ck_assert(shard->check_tombstone({tombstones[i].first, tombstones[i].second}));
- ck_assert_int_eq(shard->get_rejection_count(), i+1);
+ for (size_t i=0; i<1000; i++) {
+ auto state = WIRSQuery<WRec>::get_buffer_query_state(buffer, &parms);
+ auto result = WIRSQuery<WRec>::buffer_query(buffer, state, &parms);
+
+ for (size_t j=0; j<result.size(); j++) {
+ cnt[result[j].rec.key - 1]++;
+ }
+
+ WIRSQuery<WRec>::delete_buffer_query_state(state);
}
- delete shard;
+ ck_assert(roughly_equal(cnt[0] / 1000, (double) k/4.0, k, .05));
+ ck_assert(roughly_equal(cnt[1] / 1000, (double) k/4.0, k, .05));
+ ck_assert(roughly_equal(cnt[2] / 1000, (double) k/2.0, k, .05));
+
+ gsl_rng_free(parms.rng);
delete buffer;
}
END_TEST
-*/
+
Suite *unit_testing()
{
@@ -241,29 +300,17 @@ Suite *unit_testing()
suite_add_tcase(unit, create);
- TCase *bounds = tcase_create("de:WIRS::get_{lower,upper}_bound Testing");
- //tcase_add_test(bounds, t_get_lower_bound_index);
- tcase_set_timeout(bounds, 100);
- suite_add_tcase(unit, bounds);
-
-
TCase *tombstone = tcase_create("de:WIRS::tombstone cancellation Testing");
tcase_add_test(tombstone, t_full_cancelation);
suite_add_tcase(unit, tombstone);
- /*
- TCase *sampling = tcase_create("de:WIRS::sampling Testing");
- tcase_add_test(sampling, t_weighted_sampling);
+ TCase *sampling = tcase_create("de:WIRS::WIRSQuery Testing");
+ tcase_add_test(sampling, t_wirs_query);
+ tcase_add_test(sampling, t_wirs_query_merge);
+ tcase_add_test(sampling, t_wirs_buffer_query_rejection);
+ tcase_add_test(sampling, t_wirs_buffer_query_scan);
suite_add_tcase(unit, sampling);
- */
-
-
- /*
- TCase *check_ts = tcase_create("de::WIRS::check_tombstone Testing");
- tcase_add_test(check_ts, t_tombstone_check);
- suite_add_tcase(unit, check_ts);
- */
return unit;
}