/* * include/shard/VPTree.h * * Copyright (C) 2023 Douglas Rumbaugh * * All outsides reserved. Published under the Modified BSD License. * */ #pragma once #include #include #include #include #include #include #include "ds/PriorityQueue.h" #include "util/Cursor.h" #include "ds/BloomFilter.h" #include "util/bf_config.h" #include "framework/MutableBuffer.h" #include "framework/RecordInterface.h" #include "framework/ShardInterface.h" #include "framework/QueryInterface.h" namespace de { template struct KNNQueryParms { R point; size_t k; }; template class KNNQuery; template struct KNNState { size_t k; KNNState() { k = 0; } }; template struct KNNBufferState { }; template class KNNDistCmpMax { public: KNNDistCmpMax(R *baseline) : P(baseline) {} inline bool operator()(const R *a, const R *b) requires WrappedInterface { return a->rec.calc_distance(P->rec) > b->rec.calc_distance(P->rec); } inline bool operator()(const R *a, const R *b) requires (!WrappedInterface){ return a->calc_distance(*P) > b->calc_distance(*P); } private: R *P; }; template class KNNDistCmpMin { public: KNNDistCmpMin(R *baseline) : P(baseline) {} inline bool operator()(const R *a, const R *b) requires WrappedInterface { return a->rec.calc_distance(P->rec) < b->rec.calc_distance(P->rec); } inline bool operator()(const R *a, const R *b) requires (!WrappedInterface){ return a->calc_distance(*P) < b->calc_distance(*P); } private: R *P; }; template class VPTree { private: struct vpnode { size_t idx; double radius; vpnode *inside; vpnode *outside; vpnode(size_t idx) : idx(idx), radius(0), inside(nullptr), outside(nullptr) {} ~vpnode() { delete inside; delete outside; } }; public: friend class KNNQuery; VPTree(MutableBuffer* buffer) : m_reccnt(0), m_tombstone_cnt(0), m_root(nullptr), m_node_cnt(0) { size_t alloc_size = (buffer->get_record_count() * sizeof(Wrapped)) + (CACHELINE_SIZE - (buffer->get_record_count() * sizeof(Wrapped)) % CACHELINE_SIZE); assert(alloc_size % CACHELINE_SIZE == 0); m_data = (Wrapped*)std::aligned_alloc(CACHELINE_SIZE, alloc_size); size_t offset = 0; m_reccnt = 0; // FIXME: will eventually need to figure out tombstones // this one will likely require the multi-pass // approach, as otherwise we'll need to sort the // records repeatedly on each reconstruction. for (size_t i=0; iget_record_count(); i++) { auto rec = buffer->get_data() + i; if (rec->is_deleted()) { continue; } rec->header &= 3; m_data[m_reccnt++] = *rec; } if (m_reccnt > 0) { m_root = build_vptree(); build_map(); } } VPTree(VPTree** shards, size_t len) : m_reccnt(0), m_tombstone_cnt(0), m_root(nullptr), m_node_cnt(0) { size_t attemp_reccnt = 0; for (size_t i=0; iget_record_count(); } size_t alloc_size = (attemp_reccnt * sizeof(Wrapped)) + (CACHELINE_SIZE - (attemp_reccnt * sizeof(Wrapped)) % CACHELINE_SIZE); assert(alloc_size % CACHELINE_SIZE == 0); m_data = (Wrapped*)std::aligned_alloc(CACHELINE_SIZE, alloc_size); // FIXME: will eventually need to figure out tombstones // this one will likely require the multi-pass // approach, as otherwise we'll need to sort the // records repeatedly on each reconstruction. for (size_t i=0; iget_record_count(); j++) { if (shards[i]->get_record_at(j)->is_deleted()) { continue; } m_data[m_reccnt++] = *shards[i]->get_record_at(j); } } if (m_reccnt > 0) { m_root = build_vptree(); build_map(); } } ~VPTree() { if (m_data) free(m_data); if (m_root) delete m_root; } Wrapped *point_lookup(const R &rec, bool filter=false) { auto idx = m_lookup_map.find(rec); if (idx == m_lookup_map.end()) { return nullptr; } return m_data + idx->second; } Wrapped* get_data() const { return m_data; } size_t get_record_count() const { return m_reccnt; } size_t get_tombstone_count() const { return m_tombstone_cnt; } const Wrapped* get_record_at(size_t idx) const { if (idx >= m_reccnt) return nullptr; return m_data + idx; } size_t get_memory_usage() { return m_node_cnt * sizeof(vpnode); } private: vpnode *build_vptree() { if (m_reccnt == 0) { return nullptr; } size_t lower = 0; size_t upper = m_reccnt - 1; auto rng = gsl_rng_alloc(gsl_rng_mt19937); auto root = build_subtree(lower, upper, rng); gsl_rng_free(rng); return root; } void build_map() { for (size_t i=0; i stop) { return nullptr; } // base-case: create a leaf node if (start == stop) { vpnode *node = new vpnode(start); node->idx = start; m_node_cnt++; return node; } // select a random element to be the root of the // subtree auto i = start + gsl_rng_uniform_int(rng, stop - start + 1); swap(start, i); // partition elements based on their distance from the start, // with those elements with distance falling below the median // distance going into the left sub-array and those above // the median in the right. This is easily done using QuickSelect. auto mid = (start + 1 + stop) / 2; quickselect(start + 1, stop, mid, m_data[start], rng); // Create a new node based on this partitioning vpnode *node = new vpnode(start); // store the radius of the circle used for partitioning the node. node->radius = m_data[start].rec.calc_distance(m_data[mid].rec); // recursively construct the left and right subtrees node->inside = build_subtree(start + 1, mid-1, rng); node->outside = build_subtree(mid, stop, rng); m_node_cnt++; return node; } void quickselect(size_t start, size_t stop, size_t k, Wrapped p, gsl_rng *rng) { if (start == stop) return; auto pivot = partition(start, stop, p, rng); if (k < pivot) { quickselect(start, pivot - 1, k, p, rng); } else if (k > pivot) { quickselect(pivot + 1, stop, k, p, rng); } } size_t partition(size_t start, size_t stop, Wrapped p, gsl_rng *rng) { auto pivot = start + gsl_rng_uniform_int(rng, stop - start); double pivot_dist = p.rec.calc_distance(m_data[pivot].rec); swap(pivot, stop); size_t j = start; for (size_t i=start; i tmp = m_data[idx1]; m_data[idx1] = m_data[idx2]; m_data[idx2] = tmp; } void search(vpnode *node, const R &point, size_t k, PriorityQueue, KNNDistCmpMax>> &pq, double *farthest) { if (node == nullptr) return; double d = point.calc_distance(m_data[node->idx].rec); if (d < *farthest) { if (pq.size() == k) { auto t = pq.peek().data->rec; pq.pop(); } pq.push(&m_data[node->idx]); if (pq.size() == k) { *farthest = point.calc_distance(pq.peek().data->rec); } } if (!node->inside && !node->outside) return; if (d < node->radius) { if (d - (*farthest) <= node->radius) { search(node->inside, point, k, pq, farthest); } if (d + (*farthest) >= node->radius) { search(node->outside, point, k, pq, farthest); } } else { if (d + (*farthest) >= node->radius) { search(node->outside, point, k, pq, farthest); } if (d - (*farthest) <= node->radius) { search(node->inside, point, k, pq, farthest); } } } Wrapped* m_data; std::unordered_map> m_lookup_map; size_t m_reccnt; size_t m_tombstone_cnt; size_t m_node_cnt; vpnode *m_root; }; template class KNNQuery { public: static void *get_query_state(VPTree *wss, void *parms) { return nullptr; } static void* get_buffer_query_state(MutableBuffer *buffer, void *parms) { return nullptr; } static void process_query_states(void *query_parms, std::vector shard_states, void *buff_state) { return; } static std::vector> query(VPTree *wss, void *q_state, void *parms) { std::vector> results; KNNQueryParms *p = (KNNQueryParms *) parms; Wrapped wrec; wrec.rec = p->point; wrec.header = 0; PriorityQueue, KNNDistCmpMax>> pq(p->k, &wrec); double farthest = std::numeric_limits::max(); wss->search(wss->m_root, p->point, p->k, pq, &farthest); while (pq.size() > 0) { results.emplace_back(*pq.peek().data); pq.pop(); } return results; } static std::vector> buffer_query(MutableBuffer *buffer, void *state, void *parms) { KNNQueryParms *p = (KNNQueryParms *) parms; Wrapped wrec; wrec.rec = p->point; wrec.header = 0; size_t k = p->k; PriorityQueue, KNNDistCmpMax>> pq(k, &wrec); for (size_t i=0; iget_record_count(); i++) { // Skip over deleted records (under tagging) if ((buffer->get_data())[i].is_deleted()) { continue; } if (pq.size() < k) { pq.push(buffer->get_data() + i); } else { double head_dist = pq.peek().data->rec.calc_distance(wrec.rec); double cur_dist = (buffer->get_data() + i)->rec.calc_distance(wrec.rec); if (cur_dist < head_dist) { pq.pop(); pq.push(buffer->get_data() + i); } } } std::vector> results; while (pq.size() > 0) { results.emplace_back(*(pq.peek().data)); pq.pop(); } return results; } static std::vector merge(std::vector> &results, void *parms) { KNNQueryParms *p = (KNNQueryParms *) parms; R rec = p->point; size_t k = p->k; PriorityQueue> pq(k, &rec); for (size_t i=0; icalc_distance(rec); double cur_dist = results[i][j].calc_distance(rec); if (cur_dist < head_dist) { pq.pop(); pq.push(&results[i][j]); } } } } std::vector output; while (pq.size() > 0) { output.emplace(*pq.peek().data); pq.pop(); } return output; } static void delete_query_state(void *state) { auto s = (KNNState *) state; delete s; } static void delete_buffer_query_state(void *state) { auto s = (KNNBufferState *) state; delete s; } }; }