diff options
| -rw-r--r-- | include/framework/MutableBuffer.h | 4 | ||||
| -rw-r--r-- | include/shard/WIRS.h | 115 | ||||
| -rw-r--r-- | tests/wirs_tests.cpp | 207 |
3 files changed, 226 insertions, 100 deletions
diff --git a/include/framework/MutableBuffer.h b/include/framework/MutableBuffer.h index bc80922..3e0de40 100644 --- a/include/framework/MutableBuffer.h +++ b/include/framework/MutableBuffer.h @@ -151,6 +151,10 @@ public: return m_data; } + double get_max_weight() { + return m_max_weight; + } + private: int32_t try_advance_tail() { size_t new_tail = m_reccnt.fetch_add(1); diff --git a/include/shard/WIRS.h b/include/shard/WIRS.h index 9e4d911..7e3f468 100644 --- a/include/shard/WIRS.h +++ b/include/shard/WIRS.h @@ -35,9 +35,10 @@ struct wirs_query_parms { decltype(R::key) lower_bound; decltype(R::key) upper_bound; size_t sample_size; + gsl_rng *rng; }; -template <WeightedRecordInterface R> +template <WeightedRecordInterface R, bool Rejection> class WIRSQuery; template <WeightedRecordInterface R> @@ -60,6 +61,19 @@ struct WIRSState { }; template <WeightedRecordInterface R> +struct WIRSBufferState { + size_t cutoff; + Alias* alias; + std::vector<Wrapped<R>> records; + decltype(R::weight) max_weight; + + ~WIRSBufferState() { + delete alias; + } + +}; + +template <WeightedRecordInterface R> class WIRS { private: @@ -69,7 +83,9 @@ private: public: - friend class WIRSQuery<R>; + // FIXME: there has to be a better way to do this + friend class WIRSQuery<R, true>; + friend class WIRSQuery<R, false>; WIRS(MutableBuffer<R>* buffer) : m_reccnt(0), m_tombstone_cnt(0), m_total_weight(0), m_root(nullptr) { @@ -251,13 +267,13 @@ private: bool covered_by(struct wirs_node<R>* node, const K& lower_key, const K& upper_key) { auto low_index = node->low * m_group_size; auto high_index = std::min((node->high + 1) * m_group_size - 1, m_reccnt - 1); - return lower_key < m_data[low_index].key && m_data[high_index].key < upper_key; + return lower_key < m_data[low_index].rec.key && m_data[high_index].rec.key < upper_key; } bool intersects(struct wirs_node<R>* node, const K& lower_key, const K& upper_key) { auto low_index = node->low * m_group_size; auto high_index = std::min((node->high + 1) * m_group_size - 1, m_reccnt - 1); - return lower_key < m_data[high_index].key && m_data[low_index].key < upper_key; + return lower_key < m_data[high_index].rec.key && m_data[low_index].rec.key < upper_key; } void build_wirs_structure() { @@ -337,7 +353,7 @@ private: }; -template <WeightedRecordInterface R> +template <WeightedRecordInterface R, bool Rejection=true> class WIRSQuery { public: static void *get_query_state(WIRS<R> *wirs, void *parms) { @@ -373,10 +389,39 @@ public: } static void* get_buffer_query_state(MutableBuffer<R> *buffer, void *parms) { - return nullptr; + WIRSBufferState<R> *state = new WIRSBufferState<R>(); + auto parameters = (wirs_query_parms<R>*) parms; + if constexpr (Rejection) { + state->cutoff = buffer->get_record_count() - 1; + state->max_weight = buffer->get_max_weight(); + return state; + } + + std::vector<double> weights; + + state->cutoff = buffer->get_record_count() - 1; + double tot_weight = 0.0; + + for (size_t i = 0; i <= state->cutoff; i++) { + auto rec = buffer->get_data() + i; + + if (rec->rec.key >= parameters->lower_bound && rec->rec.key <= parameters->upper_bound && !rec->is_tombstone() && !rec->is_deleted()) { + weights.push_back(rec->rec.weight); + state->records.push_back(*rec); + tot_weight += rec->rec.weight; + } + } + + for (size_t i = 0; i < weights.size(); i++) { + weights[i] = weights[i] / tot_weight; + } + + state->alias = new Alias(weights); + + return state; } - static std::vector<Wrapped<R>> *query(WIRS<R> *wirs, void *q_state, void *parms) { + static std::vector<Wrapped<R>> query(WIRS<R> *wirs, void *q_state, void *parms) { auto sample_sz = ((wirs_query_parms<R> *) parms)->sample_size; auto lower_key = ((wirs_query_parms<R> *) parms)->lower_bound; auto upper_key = ((wirs_query_parms<R> *) parms)->upper_bound; @@ -384,10 +429,10 @@ public: auto state = (WIRSState<R> *) q_state; - std::vector<Wrapped<R>> *result_set = new std::vector<Wrapped<R>>(); + std::vector<Wrapped<R>> result_set; if (sample_sz == 0) { - return 0; + return result_set; } // k -> sampling: three levels. 1. select a node -> select a fat point -> select a record. size_t cnt = 0; @@ -403,36 +448,66 @@ public: auto record = wirs->m_data + rec_offset; // bounds rejection - if (lower_key > record->key || upper_key < record->key) { + if (lower_key > record->rec.key || upper_key < record->rec.key) { continue; } - result_set->emplace_back(*record); + result_set.emplace_back(*record); cnt++; } while (attempts < sample_sz); return result_set; } - static std::vector<Wrapped<R>> *buffer_query(MutableBuffer<R> *buffer, void *state, void *parms) { - return new std::vector<Wrapped<R>>(); + static std::vector<Wrapped<R>> buffer_query(MutableBuffer<R> *buffer, void *state, void *parms) { + auto st = (WIRSBufferState<R> *) state; + auto p = (wirs_query_parms<R> *) parms; + + std::vector<Wrapped<R>> result; + result.reserve(p->sample_size); + + if constexpr (Rejection) { + for (size_t i=0; i<p->sample_size; i++) { + auto idx = gsl_rng_uniform_int(p->rng, st->cutoff); + auto rec = buffer->get_data() + idx; + + auto test = gsl_rng_uniform(p->rng) * st->max_weight; + + if (test <= rec->rec.weight && rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) { + result.emplace_back(*rec); + } + } + return result; + } + + for (size_t i=0; i<p->sample_size; i++) { + auto idx = st->alias->get(p->rng); + result.emplace_back(st->records[idx]); + } + + return result; } - static std::vector<R> merge(std::vector<std::vector<R>*> *results) { + static std::vector<R> merge(std::vector<std::vector<R>> &results) { std::vector<R> output; - for (size_t i=0; i<results->size(); i++) { - for (size_t j=0; j<(*results)[i]->size(); j++) { - output->emplace_back(*((*results)[i])[j]); + for (size_t i=0; i<results.size(); i++) { + for (size_t j=0; j<results[i].size(); j++) { + output.emplace_back(results[i][j]); } } return output; } - static void delete_query_state(void *parm) { - wirs_query_parms<R> *parameters = parm; - delete parameters; + static void delete_query_state(void *state) { + auto s = (WIRSState<R> *) state; + delete s; + } + + static void delete_buffer_query_state(void *state) { + auto s = (WIRSBufferState<R> *) state; + delete s; } diff --git a/tests/wirs_tests.cpp b/tests/wirs_tests.cpp index ba4b754..0a9b1d0 100644 --- a/tests/wirs_tests.cpp +++ b/tests/wirs_tests.cpp @@ -44,6 +44,7 @@ START_TEST(t_mbuffer_init) delete shard; } + START_TEST(t_wirs_init) { size_t n = 512; @@ -94,31 +95,6 @@ START_TEST(t_wirs_init) delete shard4; } -/* -START_TEST(t_get_lower_bound_index) -{ - size_t n = 10000; - auto mbuffer = create_double_seq_mbuffer<WRec>(n); - - ck_assert_ptr_nonnull(mbuffer); - Shard* shard = new Shard(mbuffer); - - ck_assert_int_eq(shard->get_record_count(), n); - ck_assert_int_eq(shard->get_tombstone_count(), 0); - - auto tbl_records = mbuffer->sorted_output(); - for (size_t i=0; i<n; i++) { - const WRec *tbl_rec = mbuffer->get_record_at(i); - auto pos = shard->get_lower_bound(tbl_rec->key); - ck_assert_int_eq(shard->get_record_at(pos)->key, tbl_rec->key); - ck_assert_int_le(pos, i); - } - - delete mbuffer; - delete shard; -} - -*/ START_TEST(t_full_cancelation) { @@ -150,8 +126,7 @@ START_TEST(t_full_cancelation) END_TEST -/* -START_TEST(t_weighted_sampling) +START_TEST(t_wirs_query) { size_t n=1000; auto buffer = create_weighted_mbuffer<WRec>(n); @@ -163,72 +138,156 @@ START_TEST(t_weighted_sampling) size_t k = 1000; - std::vector<WRec> results; - results.reserve(k); size_t cnt[3] = {0}; + wirs_query_parms<WRec> parms = {lower_key, upper_key, k}; + parms.rng = gsl_rng_alloc(gsl_rng_mt19937); + for (size_t i=0; i<1000; i++) { - WIRS<WRec>::wirs_query_parms parms = {lower_key, upper_key}; - auto state = shard->get_query_state(&parms); - - shard->get_samples(state, results, lower_key, upper_key, k, g_rng); + auto state = WIRSQuery<WRec>::get_query_state(shard, &parms); + auto result = WIRSQuery<WRec>::query(shard, state, &parms); - for (size_t j=0; j<k; j++) { - cnt[results[j].key - 1]++; + for (size_t j=0; j<result.size(); j++) { + cnt[result[j].rec.key - 1]++; } - WIRS<WRec>::delete_query_state(state); + WIRSQuery<WRec>::delete_query_state(state); } ck_assert(roughly_equal(cnt[0] / 1000, (double) k/4.0, k, .05)); ck_assert(roughly_equal(cnt[1] / 1000, (double) k/4.0, k, .05)); ck_assert(roughly_equal(cnt[2] / 1000, (double) k/2.0, k, .05)); + gsl_rng_free(parms.rng); delete shard; delete buffer; } END_TEST -*/ -/* -START_TEST(t_tombstone_check) +template <RecordInterface R> +std::vector<R> strip_wrapping(std::vector<Wrapped<R>> vec) { + std::vector<R> out(vec.size()); + for (size_t i=0; i<vec.size(); i++) { + out[i] = vec[i].rec; + } + + return out; +} + + +START_TEST(t_wirs_query_merge) { - size_t cnt = 1024; - size_t ts_cnt = 256; - auto buffer = new MutableBuffer<WRec>(cnt + ts_cnt, true, ts_cnt); - - std::vector<std::pair<uint64_t, uint32_t>> tombstones; - - uint64_t key = 1000; - uint32_t val = 101; - for (size_t i = 0; i < cnt; i++) { - buffer->append({key, val, 1}); - key++; - val++; + size_t n=1000; + auto buffer = create_weighted_mbuffer<WRec>(n); + + Shard* shard = new Shard(buffer); + + uint64_t lower_key = 0; + uint64_t upper_key = 5; + + size_t k = 1000; + + size_t cnt[3] = {0}; + wirs_query_parms<WRec> parms = {lower_key, upper_key, k}; + parms.rng = gsl_rng_alloc(gsl_rng_mt19937); + + std::vector<std::vector<WRec>> results(2); + + for (size_t i=0; i<1000; i++) { + auto state1 = WIRSQuery<WRec>::get_query_state(shard, &parms); + results[0] = strip_wrapping(WIRSQuery<WRec>::query(shard, state1, &parms)); + + auto state2 = WIRSQuery<WRec>::get_query_state(shard, &parms); + results[1] = strip_wrapping(WIRSQuery<WRec>::query(shard, state2, &parms)); + + WIRSQuery<WRec>::delete_query_state(state1); + WIRSQuery<WRec>::delete_query_state(state2); } - // ensure that the key range doesn't overlap, so nothing - // gets cancelled. - for (size_t i=0; i<ts_cnt; i++) { - tombstones.push_back({i, i}); + auto merged = WIRSQuery<WRec>::merge(results); + + ck_assert_int_eq(merged.size(), 2*k); + for (size_t i=0; i<merged.size(); i++) { + ck_assert_int_ge(merged[i].key, lower_key); + ck_assert_int_le(merged[i].key, upper_key); } - for (size_t i=0; i<ts_cnt; i++) { - buffer->append({tombstones[i].first, tombstones[i].second, 1, 1}); + gsl_rng_free(parms.rng); + delete shard; + delete buffer; +} +END_TEST + + +START_TEST(t_wirs_buffer_query_scan) +{ + size_t n=1000; + auto buffer = create_weighted_mbuffer<WRec>(n); + + uint64_t lower_key = 0; + uint64_t upper_key = 5; + + size_t k = 1000; + + size_t cnt[3] = {0}; + wirs_query_parms<WRec> parms = {lower_key, upper_key, k}; + parms.rng = gsl_rng_alloc(gsl_rng_mt19937); + + for (size_t i=0; i<1000; i++) { + auto state = WIRSQuery<WRec, false>::get_buffer_query_state(buffer, &parms); + auto result = WIRSQuery<WRec, false>::buffer_query(buffer, state, &parms); + + for (size_t j=0; j<result.size(); j++) { + cnt[result[j].rec.key - 1]++; + } + + WIRSQuery<WRec, false>::delete_buffer_query_state(state); } - auto shard = new Shard(buffer); + ck_assert(roughly_equal(cnt[0] / 1000, (double) k/4.0, k, .05)); + ck_assert(roughly_equal(cnt[1] / 1000, (double) k/4.0, k, .05)); + ck_assert(roughly_equal(cnt[2] / 1000, (double) k/2.0, k, .05)); + + gsl_rng_free(parms.rng); + delete buffer; +} +END_TEST + + +START_TEST(t_wirs_buffer_query_rejection) +{ + size_t n=1000; + auto buffer = create_weighted_mbuffer<WRec>(n); + + uint64_t lower_key = 0; + uint64_t upper_key = 5; + + size_t k = 1000; + + size_t cnt[3] = {0}; + wirs_query_parms<WRec> parms = {lower_key, upper_key, k}; + parms.rng = gsl_rng_alloc(gsl_rng_mt19937); - for (size_t i=0; i<tombstones.size(); i++) { - ck_assert(shard->check_tombstone({tombstones[i].first, tombstones[i].second})); - ck_assert_int_eq(shard->get_rejection_count(), i+1); + for (size_t i=0; i<1000; i++) { + auto state = WIRSQuery<WRec>::get_buffer_query_state(buffer, &parms); + auto result = WIRSQuery<WRec>::buffer_query(buffer, state, &parms); + + for (size_t j=0; j<result.size(); j++) { + cnt[result[j].rec.key - 1]++; + } + + WIRSQuery<WRec>::delete_buffer_query_state(state); } - delete shard; + ck_assert(roughly_equal(cnt[0] / 1000, (double) k/4.0, k, .05)); + ck_assert(roughly_equal(cnt[1] / 1000, (double) k/4.0, k, .05)); + ck_assert(roughly_equal(cnt[2] / 1000, (double) k/2.0, k, .05)); + + gsl_rng_free(parms.rng); delete buffer; } END_TEST -*/ + Suite *unit_testing() { @@ -241,29 +300,17 @@ Suite *unit_testing() suite_add_tcase(unit, create); - TCase *bounds = tcase_create("de:WIRS::get_{lower,upper}_bound Testing"); - //tcase_add_test(bounds, t_get_lower_bound_index); - tcase_set_timeout(bounds, 100); - suite_add_tcase(unit, bounds); - - TCase *tombstone = tcase_create("de:WIRS::tombstone cancellation Testing"); tcase_add_test(tombstone, t_full_cancelation); suite_add_tcase(unit, tombstone); - /* - TCase *sampling = tcase_create("de:WIRS::sampling Testing"); - tcase_add_test(sampling, t_weighted_sampling); + TCase *sampling = tcase_create("de:WIRS::WIRSQuery Testing"); + tcase_add_test(sampling, t_wirs_query); + tcase_add_test(sampling, t_wirs_query_merge); + tcase_add_test(sampling, t_wirs_buffer_query_rejection); + tcase_add_test(sampling, t_wirs_buffer_query_scan); suite_add_tcase(unit, sampling); - */ - - - /* - TCase *check_ts = tcase_create("de::WIRS::check_tombstone Testing"); - tcase_add_test(check_ts, t_tombstone_check); - suite_add_tcase(unit, check_ts); - */ return unit; } |