From de5fa133758e2f0aad855ac58dff5cfa13d06f74 Mon Sep 17 00:00:00 2001 From: Douglas Rumbaugh Date: Mon, 29 May 2023 10:35:13 -0400 Subject: WIRS query interface --- include/shard/WIRS.h | 160 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 117 insertions(+), 43 deletions(-) diff --git a/include/shard/WIRS.h b/include/shard/WIRS.h index 7dee496..4c7563f 100644 --- a/include/shard/WIRS.h +++ b/include/shard/WIRS.h @@ -36,6 +36,28 @@ struct wirs_query_parms { class InternalLevel; +template +class WIRSQuery; + +template +struct wirs_node { + struct wirs_node *left, *right; + decltype(R::key) low, high; + decltype(R::weight) weight; + Alias* alias; +}; + +template +struct WIRSState { + decltype(R::weight) tot_weight; + std::vector*> nodes; + Alias* top_level_alias; + + ~WIRSState() { + if (top_level_alias) delete top_level_alias; + } +}; + template class WIRS { friend class InternalLevel; @@ -45,27 +67,10 @@ private: typedef decltype(R::value) V; typedef decltype(R::weight) W; - template - struct wirs_node { - struct wirs_node *left, *right; - K low, high; - W weight; - Alias* alias; - }; - - template - struct WIRSState { - W tot_weight; - std::vector*> nodes; - Alias* top_level_alias; - - ~WIRSState() { - if (top_level_alias) delete top_level_alias; - } - }; - public: + friend class WIRSQuery; + WIRS(MutableBuffer* buffer) : m_reccnt(0), m_tombstone_cnt(0), m_total_weight(0), m_root(nullptr) { @@ -252,30 +257,6 @@ private: auto high_index = std::min((node->high + 1) * m_group_size - 1, m_reccnt - 1); return lower_key < m_data[high_index].key && m_data[low_index].key < upper_key; } - - struct wirs_node* construct_wirs_node(const std::vector& weights, size_t low, size_t high) { - if (low == high) { - return new wirs_node{nullptr, nullptr, low, high, weights[low], new Alias({1.0})}; - } else if (low > high) return nullptr; - - std::vector node_weights; - W sum = 0; - for (size_t i = low; i < high; ++i) { - node_weights.emplace_back(weights[i]); - sum += weights[i]; - } - - for (auto& w: node_weights) - if (sum) w /= sum; - else w = 1.0 / node_weights.size(); - - - size_t mid = (low + high) / 2; - return new wirs_node{construct_wirs_node(weights, low, mid), - construct_wirs_node(weights, mid + 1, high), - low, high, sum, new Alias(node_weights)}; - } - void build_wirs_structure() { m_group_size = std::ceil(std::log(m_reccnt)); @@ -330,4 +311,97 @@ private: BloomFilter m_bf; }; + +template +class WIRSQuery { +public: + static void *get_query_state(wirs_query_parms *parameters, WIRS *wirs) { + auto res = new WIRSState(); + decltype(R::key) lower_key = ((wirs_query_parms *) parameters)->lower_bound; + decltype(R::key) upper_key = ((wirs_query_parms *) parameters)->upper_bound; + + // Simulate a stack to unfold recursion. + double tot_weight = 0.0; + struct wirs_node* st[64] = {0}; + st[0] = wirs->m_root; + size_t top = 1; + while(top > 0) { + auto now = st[--top]; + if (covered_by(now, lower_key, upper_key) || + (now->left == nullptr && now->right == nullptr && intersects(now, lower_key, upper_key))) { + res->nodes.emplace_back(now); + tot_weight += now->weight; + } else { + if (now->left && intersects(now->left, lower_key, upper_key)) st[top++] = now->left; + if (now->right && intersects(now->right, lower_key, upper_key)) st[top++] = now->right; + } + } + + std::vector weights; + for (const auto& node: res->nodes) { + weights.emplace_back(node->weight / tot_weight); + } + res->tot_weight = tot_weight; + res->top_level_alias = new Alias(weights); + + return res; + } + + static std::vector *query(wirs_query_parms *parameters, WIRSState *state, WIRS *wirs) { + auto sample_sz = parameters->sample_size; + auto lower_key = parameters->lower_bound; + auto upper_key = parameters->upper_bound; + auto rng = parameters->rng; + + std::vector *result_set = new std::vector(); + + if (sample_sz == 0) { + return 0; + } + // k -> sampling: three levels. 1. select a node -> select a fat point -> select a record. + size_t cnt = 0; + size_t attempts = 0; + do { + ++attempts; + // first level.... + auto node = state->nodes[state->top_level_alias->get(rng)]; + // second level... + auto fat_point = node->low + node->alias->get(rng); + // third level... + size_t rec_offset = fat_point * wirs->m_group_size + wirs->m_alias[fat_point]->get(rng); + auto record = wirs->m_data + rec_offset; + + // bounds rejection + if (lower_key > record->key || upper_key < record->key) { + continue; + } + + result_set->emplace_back(*record); + cnt++; + } while (attempts < sample_sz); + + return result_set; + } + + static std::vector *merge(std::vector> *results) { + std::vector *output = new std::vector(); + + for (size_t i=0; isize(); i++) { + for (size_t j=0; j<(*results)[i]->size(); j++) { + output->emplace_back(*((*results)[i])[j]); + } + } + return output; + } + + static void delete_query_state(wirs_query_parms *parameters) { + delete parameters; + } + + + //{q.get_buffer_query_state(p, p)}; + //{q.buffer_query(p, p)}; + +}; + } -- cgit v1.2.3