diff options
Diffstat (limited to 'include/shard')
| -rw-r--r-- | include/shard/AugBTree.h (renamed from include/shard/WIRS.h) | 351 |
1 files changed, 64 insertions, 287 deletions
diff --git a/include/shard/WIRS.h b/include/shard/AugBTree.h index bf29325..e32ec64 100644 --- a/include/shard/WIRS.h +++ b/include/shard/AugBTree.h @@ -1,5 +1,5 @@ /* - * include/shard/WIRS.h + * include/shard/AugBTree.h * * Copyright (C) 2023 Dong Xie <dongx@psu.edu> * Douglas B. Rumbaugh <drumbaugh@psu.edu> @@ -35,73 +35,23 @@ namespace de { thread_local size_t wirs_cancelations = 0; template <WeightedRecordInterface R> -struct wirs_query_parms { - decltype(R::key) lower_bound; - decltype(R::key) upper_bound; - size_t sample_size; - gsl_rng *rng; -}; - -template <WeightedRecordInterface R, bool Rejection> -class WIRSQuery; - -template <WeightedRecordInterface R> -struct wirs_node { - struct wirs_node<R> *left, *right; +struct AugBTreeNode { + struct AugBTreeNode<R> *left, *right; decltype(R::key) low, high; decltype(R::weight) weight; Alias* alias; }; template <WeightedRecordInterface R> -struct WIRSState { - decltype(R::weight) total_weight; - std::vector<wirs_node<R>*> nodes; - Alias* top_level_alias; - size_t sample_size; - - WIRSState() { - total_weight = 0; - top_level_alias = nullptr; - } - - ~WIRSState() { - if (top_level_alias) delete top_level_alias; - } -}; - -template <WeightedRecordInterface R> -struct WIRSBufferState { - size_t cutoff; - Alias* alias; - std::vector<Wrapped<R>> records; - decltype(R::weight) max_weight; - size_t sample_size; - decltype(R::weight) total_weight; - - ~WIRSBufferState() { - delete alias; - } - -}; - -template <WeightedRecordInterface R> -class WIRS { +class AugBTree { private: - typedef decltype(R::key) K; typedef decltype(R::value) V; typedef decltype(R::weight) W; public: - - // FIXME: there has to be a better way to do this - friend class WIRSQuery<R, true>; - friend class WIRSQuery<R, false>; - - WIRS(MutableBuffer<R>* buffer) + AugBTree(MutableBuffer<R>* buffer) : m_reccnt(0), m_tombstone_cnt(0), m_total_weight(0), m_root(nullptr) { - m_alloc_size = (buffer->get_record_count() * sizeof(Wrapped<R>)) + (CACHELINE_SIZE - (buffer->get_record_count() * sizeof(Wrapped<R>)) % CACHELINE_SIZE); assert(m_alloc_size % CACHELINE_SIZE == 0); m_data = (Wrapped<R>*)std::aligned_alloc(CACHELINE_SIZE, m_alloc_size); @@ -148,7 +98,7 @@ public: } } - WIRS(WIRS** shards, size_t len) + AugBTree(AugBTree** shards, size_t len) : m_reccnt(0), m_tombstone_cnt(0), m_total_weight(0), m_root(nullptr) { std::vector<Cursor<Wrapped<R>>> cursors; cursors.reserve(len); @@ -208,7 +158,7 @@ public: } } - ~WIRS() { + ~AugBTree() { if (m_data) free(m_data); for (size_t i=0; i<m_alias.size(); i++) { if (m_alias[i]) delete m_alias[i]; @@ -257,15 +207,13 @@ public: size_t get_memory_usage() { - return m_alloc_size + m_node_cnt * sizeof(wirs_node<Wrapped<R>>); + return m_alloc_size + m_node_cnt * sizeof(AugBTreeNode<Wrapped<R>>); } size_t get_aux_memory_usage() { return 0; } -private: - size_t get_lower_bound(const K& key) const { size_t min = 0; size_t max = m_reccnt - 1; @@ -284,13 +232,60 @@ private: return min; } - bool covered_by(struct wirs_node<R>* node, const K& lower_key, const K& upper_key) { + W find_covering_nodes(K lower_key, K upper_key, std::vector<void *> &nodes, std::vector<W> &weights) { + W total_weight = 0; + + /* Simulate a stack to unfold recursion. */ + struct AugBTreeNode<R>* st[64] = {0}; + st[0] = m_root; + size_t top = 1; + while(top > 0) { + auto now = st[--top]; + if (covered_by(now, lower_key, upper_key) || + (now->left == nullptr && now->right == nullptr && intersects(now, lower_key, upper_key))) { + nodes.emplace_back(now); + weights.emplace_back(now->weight); + total_weight += now->weight; + } else { + if (now->left && intersects(now->left, lower_key, upper_key)) st[top++] = now->left; + if (now->right && intersects(now->right, lower_key, upper_key)) st[top++] = now->right; + } + } + + + return total_weight; + } + + Wrapped<R> *get_weighted_sample(K lower_key, K upper_key, void *internal_node, gsl_rng *rng) { + /* k -> sampling: three levels. 1. select a node -> select a fat point -> select a record. */ + + /* first level */ + auto node = (AugBTreeNode<R>*) internal_node; + + /* second level */ + auto fat_point = node->low + node->alias->get(rng); + + /* third level */ + size_t rec_offset = fat_point * m_group_size + m_alias[fat_point]->get(rng); + auto record = m_data + rec_offset; + + /* bounds rejection */ + if (lower_key > record->rec.key || upper_key < record->rec.key) { + return nullptr; + } + + return record; + } + +private: + + bool covered_by(struct AugBTreeNode<R>* node, const K& lower_key, const K& upper_key) { auto low_index = node->low * m_group_size; auto high_index = std::min((node->high + 1) * m_group_size - 1, m_reccnt - 1); return lower_key < m_data[low_index].rec.key && m_data[high_index].rec.key < upper_key; } - bool intersects(struct wirs_node<R>* node, const K& lower_key, const K& upper_key) { + bool intersects(struct AugBTreeNode<R>* node, const K& lower_key, const K& upper_key) { auto low_index = node->low * m_group_size; auto high_index = std::min((node->high + 1) * m_group_size - 1, m_reccnt - 1); return lower_key < m_data[high_index].rec.key && m_data[low_index].rec.key < upper_key; @@ -327,12 +322,12 @@ private: assert(weights.size() == n_groups); - m_root = construct_wirs_node(weights, 0, n_groups-1); + m_root = construct_AugBTreeNode(weights, 0, n_groups-1); } - struct wirs_node<R>* construct_wirs_node(const std::vector<W>& weights, size_t low, size_t high) { + struct AugBTreeNode<R>* construct_AugBTreeNode(const std::vector<W>& weights, size_t low, size_t high) { if (low == high) { - return new wirs_node<R>{nullptr, nullptr, low, high, weights[low], new Alias({1.0})}; + return new AugBTreeNode<R>{nullptr, nullptr, low, high, weights[low], new Alias({1.0})}; } else if (low > high) return nullptr; std::vector<double> node_weights; @@ -348,12 +343,12 @@ private: m_node_cnt += 1; size_t mid = (low + high) / 2; - return new wirs_node<R>{construct_wirs_node(weights, low, mid), - construct_wirs_node(weights, mid + 1, high), + return new AugBTreeNode<R>{construct_AugBTreeNode(weights, low, mid), + construct_AugBTreeNode(weights, mid + 1, high), low, high, sum, new Alias(node_weights)}; } - void free_tree(struct wirs_node<R>* node) { + void free_tree(struct AugBTreeNode<R>* node) { if (node) { delete node->alias; free_tree(node->left); @@ -364,7 +359,7 @@ private: Wrapped<R>* m_data; std::vector<Alias *> m_alias; - wirs_node<R>* m_root; + AugBTreeNode<R>* m_root; W m_total_weight; size_t m_reccnt; size_t m_tombstone_cnt; @@ -373,222 +368,4 @@ private: size_t m_node_cnt; BloomFilter<R> *m_bf; }; - - -template <WeightedRecordInterface R, bool Rejection=true> -class WIRSQuery { -public: - - constexpr static bool EARLY_ABORT=false; - constexpr static bool SKIP_DELETE_FILTER=false; - - static void *get_query_state(WIRS<R> *wirs, void *parms) { - auto res = new WIRSState<R>(); - decltype(R::key) lower_key = ((wirs_query_parms<R> *) parms)->lower_bound; - decltype(R::key) upper_key = ((wirs_query_parms<R> *) parms)->upper_bound; - - // Simulate a stack to unfold recursion. - double total_weight = 0.0; - struct wirs_node<R>* st[64] = {0}; - st[0] = wirs->m_root; - size_t top = 1; - while(top > 0) { - auto now = st[--top]; - if (wirs->covered_by(now, lower_key, upper_key) || - (now->left == nullptr && now->right == nullptr && wirs->intersects(now, lower_key, upper_key))) { - res->nodes.emplace_back(now); - total_weight += now->weight; - } else { - if (now->left && wirs->intersects(now->left, lower_key, upper_key)) st[top++] = now->left; - if (now->right && wirs->intersects(now->right, lower_key, upper_key)) st[top++] = now->right; - } - } - - std::vector<double> weights; - for (const auto& node: res->nodes) { - weights.emplace_back(node->weight / total_weight); - } - res->total_weight = total_weight; - res->top_level_alias = new Alias(weights); - res->sample_size = 0; - - return res; - } - - static void* get_buffer_query_state(MutableBuffer<R> *buffer, void *parms) { - WIRSBufferState<R> *state = new WIRSBufferState<R>(); - auto parameters = (wirs_query_parms<R>*) parms; - if constexpr (Rejection) { - state->cutoff = buffer->get_record_count() - 1; - state->max_weight = buffer->get_max_weight(); - state->total_weight = buffer->get_total_weight(); - state->sample_size = 0; - return state; - } - - std::vector<double> weights; - - state->cutoff = buffer->get_record_count() - 1; - double total_weight = 0.0; - - for (size_t i = 0; i <= state->cutoff; i++) { - auto rec = buffer->get_data() + i; - - if (rec->rec.key >= parameters->lower_bound && rec->rec.key <= parameters->upper_bound && !rec->is_tombstone() && !rec->is_deleted()) { - weights.push_back(rec->rec.weight); - state->records.push_back(*rec); - total_weight += rec->rec.weight; - } - } - - for (size_t i = 0; i < weights.size(); i++) { - weights[i] = weights[i] / total_weight; - } - - state->total_weight = total_weight; - state->alias = new Alias(weights); - state->sample_size = 0; - - return state; - } - - static void process_query_states(void *query_parms, std::vector<void*> &shard_states, std::vector<void*> &buff_states) { - // FIXME: need to redo for the buffer vector interface - auto p = (wirs_query_parms<R> *) query_parms; - - std::vector<size_t> shard_sample_sizes(shard_states.size()+1, 0); - size_t buffer_sz = 0; - - decltype(R::weight) total_weight = 0; - std::vector<decltype(R::weight)> weights; - for (auto &s : buff_states) { - auto state = (WIRSBufferState<R> *) s; - total_weight += state->total_weight; - weights.push_back(state->total_weight); - } - - for (auto &s : shard_states) { - auto state = (WIRSState<R> *) s; - total_weight += state->total_weight; - weights.push_back(state->total_weight); - } - - std::vector<double> normalized_weights; - for (auto w : weights) { - normalized_weights.push_back((double) w / (double) total_weight); - } - - auto shard_alias = Alias(normalized_weights); - for (size_t i=0; i<p->sample_size; i++) { - auto idx = shard_alias.get(p->rng); - if (idx == 0) { - buffer_sz++; - } else { - shard_sample_sizes[idx - 1]++; - } - } - - for (size_t i=0; i<shard_states.size(); i++) { - auto state = (WIRSState<R> *) shard_states[i]; - state->sample_size = shard_sample_sizes[i+1]; - } - } - - - - static std::vector<Wrapped<R>> query(WIRS<R> *wirs, void *q_state, void *parms) { - auto lower_key = ((wirs_query_parms<R> *) parms)->lower_bound; - auto upper_key = ((wirs_query_parms<R> *) parms)->upper_bound; - auto rng = ((wirs_query_parms<R> *) parms)->rng; - - auto state = (WIRSState<R> *) q_state; - auto sample_size = state->sample_size; - - std::vector<Wrapped<R>> result_set; - - if (sample_size == 0) { - return result_set; - } - // k -> sampling: three levels. 1. select a node -> select a fat point -> select a record. - size_t cnt = 0; - size_t attempts = 0; - do { - ++attempts; - // first level.... - auto node = state->nodes[state->top_level_alias->get(rng)]; - // second level... - auto fat_point = node->low + node->alias->get(rng); - // third level... - size_t rec_offset = fat_point * wirs->m_group_size + wirs->m_alias[fat_point]->get(rng); - auto record = wirs->m_data + rec_offset; - - // bounds rejection - if (lower_key > record->rec.key || upper_key < record->rec.key) { - continue; - } - - result_set.emplace_back(*record); - cnt++; - } while (attempts < sample_size); - - return result_set; - } - - static std::vector<Wrapped<R>> buffer_query(MutableBuffer<R> *buffer, void *state, void *parms) { - auto st = (WIRSBufferState<R> *) state; - auto p = (wirs_query_parms<R> *) parms; - - std::vector<Wrapped<R>> result; - result.reserve(st->sample_size); - - if constexpr (Rejection) { - for (size_t i=0; i<st->sample_size; i++) { - auto idx = gsl_rng_uniform_int(p->rng, st->cutoff); - auto rec = buffer->get_data() + idx; - - auto test = gsl_rng_uniform(p->rng) * st->max_weight; - - if (test <= rec->rec.weight && rec->rec.key >= p->lower_bound && rec->rec.key <= p->upper_bound) { - result.emplace_back(*rec); - } - } - return result; - } - - for (size_t i=0; i<st->sample_size; i++) { - auto idx = st->alias->get(p->rng); - result.emplace_back(st->records[idx]); - } - - return result; - } - - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) { - std::vector<R> output; - - for (size_t i=0; i<results.size(); i++) { - for (size_t j=0; j<results[i].size(); j++) { - output.emplace_back(results[i][j].rec); - } - } - - return output; - } - - static void delete_query_state(void *state) { - auto s = (WIRSState<R> *) state; - delete s; - } - - static void delete_buffer_query_state(void *state) { - auto s = (WIRSBufferState<R> *) state; - delete s; - } - - - //{q.get_buffer_query_state(p, p)}; - //{q.buffer_query(p, p)}; - -}; - } |