summaryrefslogtreecommitdiffstats
path: root/include/shard
diff options
context:
space:
mode:
Diffstat (limited to 'include/shard')
-rw-r--r--include/shard/WIRS.h160
1 files changed, 117 insertions, 43 deletions
diff --git a/include/shard/WIRS.h b/include/shard/WIRS.h
index 7dee496..4c7563f 100644
--- a/include/shard/WIRS.h
+++ b/include/shard/WIRS.h
@@ -37,6 +37,28 @@ struct wirs_query_parms {
class InternalLevel;
template <WeightedRecordInterface R>
+class WIRSQuery;
+
+template <WeightedRecordInterface R>
+struct wirs_node {
+ struct wirs_node<R> *left, *right;
+ decltype(R::key) low, high;
+ decltype(R::weight) weight;
+ Alias* alias;
+};
+
+template <WeightedRecordInterface R>
+struct WIRSState {
+ decltype(R::weight) tot_weight;
+ std::vector<wirs_node<R>*> nodes;
+ Alias* top_level_alias;
+
+ ~WIRSState() {
+ if (top_level_alias) delete top_level_alias;
+ }
+};
+
+template <WeightedRecordInterface R>
class WIRS {
friend class InternalLevel;
private:
@@ -45,27 +67,10 @@ private:
typedef decltype(R::value) V;
typedef decltype(R::weight) W;
- template <WeightedRecordInterface R_ = R>
- struct wirs_node {
- struct wirs_node<R_> *left, *right;
- K low, high;
- W weight;
- Alias* alias;
- };
-
- template <WeightedRecordInterface R_ = R>
- struct WIRSState {
- W tot_weight;
- std::vector<wirs_node<R_>*> nodes;
- Alias* top_level_alias;
-
- ~WIRSState() {
- if (top_level_alias) delete top_level_alias;
- }
- };
-
public:
+ friend class WIRSQuery<R>;
+
WIRS(MutableBuffer<R>* buffer)
: m_reccnt(0), m_tombstone_cnt(0), m_total_weight(0), m_root(nullptr) {
@@ -252,30 +257,6 @@ private:
auto high_index = std::min((node->high + 1) * m_group_size - 1, m_reccnt - 1);
return lower_key < m_data[high_index].key && m_data[low_index].key < upper_key;
}
-
- struct wirs_node<R>* construct_wirs_node(const std::vector<W>& weights, size_t low, size_t high) {
- if (low == high) {
- return new wirs_node<R>{nullptr, nullptr, low, high, weights[low], new Alias({1.0})};
- } else if (low > high) return nullptr;
-
- std::vector<double> node_weights;
- W sum = 0;
- for (size_t i = low; i < high; ++i) {
- node_weights.emplace_back(weights[i]);
- sum += weights[i];
- }
-
- for (auto& w: node_weights)
- if (sum) w /= sum;
- else w = 1.0 / node_weights.size();
-
-
- size_t mid = (low + high) / 2;
- return new wirs_node<R>{construct_wirs_node(weights, low, mid),
- construct_wirs_node(weights, mid + 1, high),
- low, high, sum, new Alias(node_weights)};
- }
-
void build_wirs_structure() {
m_group_size = std::ceil(std::log(m_reccnt));
@@ -330,4 +311,97 @@ private:
BloomFilter<K> m_bf;
};
+
+template <WeightedRecordInterface R>
+class WIRSQuery {
+public:
+ static void *get_query_state(wirs_query_parms<R> *parameters, WIRS<R> *wirs) {
+ auto res = new WIRSState<R>();
+ decltype(R::key) lower_key = ((wirs_query_parms<R> *) parameters)->lower_bound;
+ decltype(R::key) upper_key = ((wirs_query_parms<R> *) parameters)->upper_bound;
+
+ // Simulate a stack to unfold recursion.
+ double tot_weight = 0.0;
+ struct wirs_node<R>* st[64] = {0};
+ st[0] = wirs->m_root;
+ size_t top = 1;
+ while(top > 0) {
+ auto now = st[--top];
+ if (covered_by(now, lower_key, upper_key) ||
+ (now->left == nullptr && now->right == nullptr && intersects(now, lower_key, upper_key))) {
+ res->nodes.emplace_back(now);
+ tot_weight += now->weight;
+ } else {
+ if (now->left && intersects(now->left, lower_key, upper_key)) st[top++] = now->left;
+ if (now->right && intersects(now->right, lower_key, upper_key)) st[top++] = now->right;
+ }
+ }
+
+ std::vector<double> weights;
+ for (const auto& node: res->nodes) {
+ weights.emplace_back(node->weight / tot_weight);
+ }
+ res->tot_weight = tot_weight;
+ res->top_level_alias = new Alias(weights);
+
+ return res;
+ }
+
+ static std::vector<R> *query(wirs_query_parms<R> *parameters, WIRSState<R> *state, WIRS<R> *wirs) {
+ auto sample_sz = parameters->sample_size;
+ auto lower_key = parameters->lower_bound;
+ auto upper_key = parameters->upper_bound;
+ auto rng = parameters->rng;
+
+ std::vector<R> *result_set = new std::vector<R>();
+
+ if (sample_sz == 0) {
+ return 0;
+ }
+ // k -> sampling: three levels. 1. select a node -> select a fat point -> select a record.
+ size_t cnt = 0;
+ size_t attempts = 0;
+ do {
+ ++attempts;
+ // first level....
+ auto node = state->nodes[state->top_level_alias->get(rng)];
+ // second level...
+ auto fat_point = node->low + node->alias->get(rng);
+ // third level...
+ size_t rec_offset = fat_point * wirs->m_group_size + wirs->m_alias[fat_point]->get(rng);
+ auto record = wirs->m_data + rec_offset;
+
+ // bounds rejection
+ if (lower_key > record->key || upper_key < record->key) {
+ continue;
+ }
+
+ result_set->emplace_back(*record);
+ cnt++;
+ } while (attempts < sample_sz);
+
+ return result_set;
+ }
+
+ static std::vector<R> *merge(std::vector<std::vector<R>> *results) {
+ std::vector<R> *output = new std::vector<R>();
+
+ for (size_t i=0; i<results->size(); i++) {
+ for (size_t j=0; j<(*results)[i]->size(); j++) {
+ output->emplace_back(*((*results)[i])[j]);
+ }
+ }
+ return output;
+ }
+
+ static void delete_query_state(wirs_query_parms<R> *parameters) {
+ delete parameters;
+ }
+
+
+ //{q.get_buffer_query_state(p, p)};
+ //{q.buffer_query(p, p)};
+
+};
+
}