From ab14529843d7bbb3a0d0c30b163ed444afee04ed Mon Sep 17 00:00:00 2001 From: Douglas Rumbaugh Date: Mon, 29 May 2023 16:37:43 -0400 Subject: WIRS Query tests + fixes --- include/framework/MutableBuffer.h | 4 ++ include/shard/WIRS.h | 115 +++++++++++++++++++++++++++++++------- 2 files changed, 99 insertions(+), 20 deletions(-) (limited to 'include') diff --git a/include/framework/MutableBuffer.h b/include/framework/MutableBuffer.h index bc80922..3e0de40 100644 --- a/include/framework/MutableBuffer.h +++ b/include/framework/MutableBuffer.h @@ -151,6 +151,10 @@ public: return m_data; } + double get_max_weight() { + return m_max_weight; + } + private: int32_t try_advance_tail() { size_t new_tail = m_reccnt.fetch_add(1); diff --git a/include/shard/WIRS.h b/include/shard/WIRS.h index 9e4d911..7e3f468 100644 --- a/include/shard/WIRS.h +++ b/include/shard/WIRS.h @@ -35,9 +35,10 @@ struct wirs_query_parms { decltype(R::key) lower_bound; decltype(R::key) upper_bound; size_t sample_size; + gsl_rng *rng; }; -template +template class WIRSQuery; template @@ -59,6 +60,19 @@ struct WIRSState { } }; +template +struct WIRSBufferState { + size_t cutoff; + Alias* alias; + std::vector> records; + decltype(R::weight) max_weight; + + ~WIRSBufferState() { + delete alias; + } + +}; + template class WIRS { private: @@ -69,7 +83,9 @@ private: public: - friend class WIRSQuery; + // FIXME: there has to be a better way to do this + friend class WIRSQuery; + friend class WIRSQuery; WIRS(MutableBuffer* buffer) : m_reccnt(0), m_tombstone_cnt(0), m_total_weight(0), m_root(nullptr) { @@ -251,13 +267,13 @@ private: bool covered_by(struct wirs_node* node, const K& lower_key, const K& upper_key) { auto low_index = node->low * m_group_size; auto high_index = std::min((node->high + 1) * m_group_size - 1, m_reccnt - 1); - return lower_key < m_data[low_index].key && m_data[high_index].key < upper_key; + return lower_key < m_data[low_index].rec.key && m_data[high_index].rec.key < upper_key; } bool intersects(struct wirs_node* node, const K& lower_key, const K& upper_key) { auto low_index = node->low * m_group_size; auto high_index = std::min((node->high + 1) * m_group_size - 1, m_reccnt - 1); - return lower_key < m_data[high_index].key && m_data[low_index].key < upper_key; + return lower_key < m_data[high_index].rec.key && m_data[low_index].rec.key < upper_key; } void build_wirs_structure() { @@ -337,7 +353,7 @@ private: }; -template +template class WIRSQuery { public: static void *get_query_state(WIRS *wirs, void *parms) { @@ -373,10 +389,39 @@ public: } static void* get_buffer_query_state(MutableBuffer *buffer, void *parms) { - return nullptr; + WIRSBufferState *state = new WIRSBufferState(); + auto parameters = (wirs_query_parms*) parms; + if constexpr (Rejection) { + state->cutoff = buffer->get_record_count() - 1; + state->max_weight = buffer->get_max_weight(); + return state; + } + + std::vector weights; + + state->cutoff = buffer->get_record_count() - 1; + double tot_weight = 0.0; + + for (size_t i = 0; i <= state->cutoff; i++) { + auto rec = buffer->get_data() + i; + + if (rec->rec.key >= parameters->lower_bound && rec->rec.key <= parameters->upper_bound && !rec->is_tombstone() && !rec->is_deleted()) { + weights.push_back(rec->rec.weight); + state->records.push_back(*rec); + tot_weight += rec->rec.weight; + } + } + + for (size_t i = 0; i < weights.size(); i++) { + weights[i] = weights[i] / tot_weight; + } + + state->alias = new Alias(weights); + + return state; } - static std::vector> *query(WIRS *wirs, void *q_state, void *parms) { + static std::vector> query(WIRS *wirs, void *q_state, void *parms) { auto sample_sz = ((wirs_query_parms *) parms)->sample_size; auto lower_key = ((wirs_query_parms *) parms)->lower_bound; auto upper_key = ((wirs_query_parms *) parms)->upper_bound; @@ -384,10 +429,10 @@ public: auto state = (WIRSState *) q_state; - std::vector> *result_set = new std::vector>(); + std::vector> result_set; if (sample_sz == 0) { - return 0; + return result_set; } // k -> sampling: three levels. 1. select a node -> select a fat point -> select a record. size_t cnt = 0; @@ -403,36 +448,66 @@ public: auto record = wirs->m_data + rec_offset; // bounds rejection - if (lower_key > record->key || upper_key < record->key) { + if (lower_key > record->rec.key || upper_key < record->rec.key) { continue; } - result_set->emplace_back(*record); + result_set.emplace_back(*record); cnt++; } while (attempts < sample_sz); return result_set; } - static std::vector> *buffer_query(MutableBuffer *buffer, void *state, void *parms) { - return new std::vector>(); + static std::vector> buffer_query(MutableBuffer *buffer, void *state, void *parms) { + auto st = (WIRSBufferState *) state; + auto p = (wirs_query_parms *) parms; + + std::vector> result; + result.reserve(p->sample_size); + + if constexpr (Rejection) { + for (size_t i=0; isample_size; i++) { + auto idx = gsl_rng_uniform_int(p->rng, st->cutoff); + auto rec = buffer->get_data() + idx; + + auto test = gsl_rng_uniform(p->rng) * st->max_weight; + + if (test <= rec->rec.weight && rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) { + result.emplace_back(*rec); + } + } + return result; + } + + for (size_t i=0; isample_size; i++) { + auto idx = st->alias->get(p->rng); + result.emplace_back(st->records[idx]); + } + + return result; } - static std::vector merge(std::vector*> *results) { + static std::vector merge(std::vector> &results) { std::vector output; - for (size_t i=0; isize(); i++) { - for (size_t j=0; j<(*results)[i]->size(); j++) { - output->emplace_back(*((*results)[i])[j]); + for (size_t i=0; i *parameters = parm; - delete parameters; + static void delete_query_state(void *state) { + auto s = (WIRSState *) state; + delete s; + } + + static void delete_buffer_query_state(void *state) { + auto s = (WIRSBufferState *) state; + delete s; } -- cgit v1.2.3