/*
 * include/shard/VPTree.h
 *
 * Copyright (C) 2023 Douglas Rumbaugh
 *
 * All rights reserved. Published under the Modified BSD License.
 *
 */
#pragma once

#include <vector>
#include <cassert>
#include <cstdlib>
#include <limits>
#include <unordered_map>
#include <gsl/gsl_rng.h>

#include "ds/PriorityQueue.h"
#include "util/Cursor.h"
#include "ds/BloomFilter.h"
#include "util/bf_config.h"
#include "framework/MutableBuffer.h"
#include "framework/RecordInterface.h"
#include "framework/ShardInterface.h"
#include "framework/QueryInterface.h"

namespace de {

template <NDRecordInterface R>
struct KNNQueryParms {
    R point;
    size_t k;
};

template <NDRecordInterface R>
class KNNQuery;

template <NDRecordInterface R>
struct KNNState {
    size_t k;

    KNNState() {
        k = 0;
    }
};

template <NDRecordInterface R>
struct KNNBufferState {
};

template <typename R>
class KNNDistCmpMax {
public:
    KNNDistCmpMax(R *baseline) : P(baseline) {}

    inline bool operator()(const R *a, const R *b) requires WrappedInterface<R> {
        return a->rec.calc_distance(P->rec) > b->rec.calc_distance(P->rec);
    }

    inline bool operator()(const R *a, const R *b) requires (!WrappedInterface<R>) {
        return a->calc_distance(*P) > b->calc_distance(*P);
    }

private:
    R *P;
};

template <typename R>
class KNNDistCmpMin {
public:
    KNNDistCmpMin(R *baseline) : P(baseline) {}

    inline bool operator()(const R *a, const R *b) requires WrappedInterface<R> {
        return a->rec.calc_distance(P->rec) < b->rec.calc_distance(P->rec);
    }

    inline bool operator()(const R *a, const R *b) requires (!WrappedInterface<R>) {
        return a->calc_distance(*P) < b->calc_distance(*P);
    }

private:
    R *P;
};

template <NDRecordInterface R, size_t LEAFSZ = 100>
class VPTree {
private:
    struct vpnode {
        size_t start;
        size_t stop;
        bool leaf;
        double radius;
        vpnode *inside;
        vpnode *outside;

        vpnode() : start(0), stop(0), leaf(false), radius(0.0),
                   inside(nullptr), outside(nullptr) {}

        ~vpnode() {
            delete inside;
            delete outside;
        }
    };

public:
    friend class KNNQuery<R>;

    VPTree(MutableBuffer<R>* buffer)
        : m_reccnt(0), m_tombstone_cnt(0), m_root(nullptr), m_node_cnt(0) {

        size_t alloc_size = (buffer->get_record_count() * sizeof(Wrapped<R>))
                          + (CACHELINE_SIZE - (buffer->get_record_count() * sizeof(Wrapped<R>)) % CACHELINE_SIZE);
        assert(alloc_size % CACHELINE_SIZE == 0);

        m_data = (Wrapped<R>*) std::aligned_alloc(CACHELINE_SIZE, alloc_size);
        m_ptrs = new Wrapped<R>*[buffer->get_record_count()];
        m_reccnt = 0;

        // FIXME: will eventually need to figure out tombstones
        //        this one will likely require the multi-pass
        //        approach, as otherwise we'll need to sort the
        //        records repeatedly on each reconstruction.
        for (size_t i = 0; i < buffer->get_record_count(); i++) {
            auto rec = buffer->get_data() + i;

            if (rec->is_deleted()) {
                continue;
            }

            // clear all header bits other than the delete/tombstone flags
            rec->header &= 3;
            m_data[m_reccnt] = *rec;
            m_ptrs[m_reccnt] = &m_data[m_reccnt];
            m_reccnt++;
        }

        if (m_reccnt > 0) {
            m_root = build_vptree();
            build_map();
        }
    }

    VPTree(VPTree** shards, size_t len)
        : m_reccnt(0), m_tombstone_cnt(0), m_root(nullptr), m_node_cnt(0) {

        size_t attempt_reccnt = 0;
        for (size_t i = 0; i < len; i++) {
            attempt_reccnt += shards[i]->get_record_count();
        }

        size_t alloc_size = (attempt_reccnt * sizeof(Wrapped<R>))
                          + (CACHELINE_SIZE - (attempt_reccnt * sizeof(Wrapped<R>)) % CACHELINE_SIZE);
        assert(alloc_size % CACHELINE_SIZE == 0);

        m_data = (Wrapped<R>*) std::aligned_alloc(CACHELINE_SIZE, alloc_size);
        m_ptrs = new Wrapped<R>*[attempt_reccnt];

        // FIXME: will eventually need to figure out tombstones
        //        this one will likely require the multi-pass
        //        approach, as otherwise we'll need to sort the
        //        records repeatedly on each reconstruction.
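        // Copy the live (non-deleted) records from each input shard into
        // this shard's contiguous record array, rebuilding the pointer
        // table that the tree construction permutes in place.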
        for (size_t i = 0; i < len; i++) {
            for (size_t j = 0; j < shards[i]->get_record_count(); j++) {
                if (shards[i]->get_record_at(j)->is_deleted()) {
                    continue;
                }

                m_data[m_reccnt] = *shards[i]->get_record_at(j);
                m_ptrs[m_reccnt] = &m_data[m_reccnt];
                m_reccnt++;
            }
        }

        if (m_reccnt > 0) {
            m_root = build_vptree();
            build_map();
        }
    }

    ~VPTree() {
        if (m_data) free(m_data);
        if (m_root) delete m_root;
        if (m_ptrs) delete[] m_ptrs;
    }

    Wrapped<R> *point_lookup(const R &rec, bool filter = false) {
        auto idx = m_lookup_map.find(rec);
        if (idx == m_lookup_map.end()) {
            return nullptr;
        }

        return m_data + idx->second;
    }

    Wrapped<R>* get_data() const {
        return m_data;
    }

    size_t get_record_count() const {
        return m_reccnt;
    }

    size_t get_tombstone_count() const {
        return m_tombstone_cnt;
    }

    const Wrapped<R>* get_record_at(size_t idx) const {
        if (idx >= m_reccnt) return nullptr;
        return m_data + idx;
    }

    size_t get_memory_usage() {
        return m_node_cnt * sizeof(vpnode);
    }

private:
    vpnode *build_vptree() {
        if (m_reccnt == 0) {
            return nullptr;
        }

        size_t lower = 0;
        size_t upper = m_reccnt - 1;

        auto rng = gsl_rng_alloc(gsl_rng_mt19937);
        auto root = build_subtree(lower, upper, rng);
        gsl_rng_free(rng);
        return root;
    }

    // build a hash table over the stored records to support point lookups
    void build_map() {
        for (size_t i = 0; i < m_reccnt; i++) {
            m_lookup_map.insert({m_data[i].rec, i});
        }
    }

    vpnode *build_subtree(size_t start, size_t stop, gsl_rng *rng) {
        if (start > stop) {
            return nullptr;
        }

        // base-case: create a leaf node
        if (stop - start <= LEAFSZ) {
            vpnode *node = new vpnode();
            node->start = start;
            node->stop = stop;
            node->leaf = true;

            m_node_cnt++;
            return node;
        }

        // select a random element to be the root of the subtree
        auto i = start + gsl_rng_uniform_int(rng, stop - start + 1);
        swap(start, i);

        // partition elements based on their distance from the start,
        // with those elements with distance falling below the median
        // distance going into the left sub-array and those above
        // the median in the right. This is easily done using QuickSelect.
        auto mid = (start + 1 + stop) / 2;
        quickselect(start + 1, stop, mid, m_ptrs[start], rng);

        // create a new node based on this partitioning
        vpnode *node = new vpnode();
        node->start = start;

        // store the radius of the circle used for partitioning the node.
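        // (the median element sits exactly on this boundary: records in
        // [start+1, mid-1] lie strictly inside the radius, and records in
        // [mid, stop] lie on or outside it)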
        node->radius = m_ptrs[start]->rec.calc_distance(m_ptrs[mid]->rec);

        // recursively construct the inside and outside subtrees
        node->inside = build_subtree(start + 1, mid - 1, rng);
        node->outside = build_subtree(mid, stop, rng);

        m_node_cnt++;

        return node;
    }

    void quickselect(size_t start, size_t stop, size_t k, Wrapped<R> *p, gsl_rng *rng) {
        if (start == stop) return;

        auto pivot = partition(start, stop, p, rng);

        if (k < pivot) {
            quickselect(start, pivot - 1, k, p, rng);
        } else if (k > pivot) {
            quickselect(pivot + 1, stop, k, p, rng);
        }
    }

    size_t partition(size_t start, size_t stop, Wrapped<R> *p, gsl_rng *rng) {
        auto pivot = start + gsl_rng_uniform_int(rng, stop - start);
        double pivot_dist = p->rec.calc_distance(m_ptrs[pivot]->rec);

        swap(pivot, stop);

        size_t j = start;
        for (size_t i = start; i < stop; i++) {
            if (p->rec.calc_distance(m_ptrs[i]->rec) < pivot_dist) {
                swap(j, i);
                j++;
            }
        }

        swap(j, stop);
        return j;
    }

    void swap(size_t idx1, size_t idx2) {
        auto tmp = m_ptrs[idx1];
        m_ptrs[idx1] = m_ptrs[idx2];
        m_ptrs[idx2] = tmp;
    }

    void search(vpnode *node, const R &point, size_t k,
                PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> &pq,
                double *farthest) {
        if (node == nullptr) return;

        if (node->leaf) {
            for (size_t i = node->start; i <= node->stop; i++) {
                double d = point.calc_distance(m_ptrs[i]->rec);
                if (d < *farthest) {
                    if (pq.size() == k) {
                        pq.pop();
                    }
                    pq.push(m_ptrs[i]);
                    if (pq.size() == k) {
                        *farthest = point.calc_distance(pq.peek().data->rec);
                    }
                }
            }

            return;
        }

        double d = point.calc_distance(m_ptrs[node->start]->rec);
        if (d < *farthest) {
            if (pq.size() == k) {
                pq.pop();
            }
            pq.push(m_ptrs[node->start]);
            if (pq.size() == k) {
                *farthest = point.calc_distance(pq.peek().data->rec);
            }
        }

        // descend into the subtree containing the query point first, and
        // visit the other subtree only if the current k-th nearest distance
        // does not rule it out
        if (d < node->radius) {
            if (d - (*farthest) <= node->radius) {
                search(node->inside, point, k, pq, farthest);
            }

            if (d + (*farthest) >= node->radius) {
                search(node->outside, point, k, pq, farthest);
            }
        } else {
            if (d + (*farthest) >= node->radius) {
                search(node->outside, point, k, pq, farthest);
            }

            if (d - (*farthest) <= node->radius) {
                search(node->inside, point, k, pq, farthest);
            }
        }
    }

    Wrapped<R>* m_data;
    Wrapped<R>** m_ptrs;
    std::unordered_map<R, size_t, RecordHash<R>> m_lookup_map;
    size_t m_reccnt;
    size_t m_tombstone_cnt;
    size_t m_node_cnt;

    vpnode *m_root;
};

template <NDRecordInterface R>
class KNNQuery {
public:
    static void *get_query_state(VPTree<R> *wss, void *parms) {
        return nullptr;
    }

    static void *get_buffer_query_state(MutableBuffer<R> *buffer, void *parms) {
        return nullptr;
    }

    static void process_query_states(void *query_parms, std::vector<void*> shard_states, void *buff_state) {
        return;
    }

    static std::vector<Wrapped<R>> query(VPTree<R> *wss, void *q_state, void *parms) {
        std::vector<Wrapped<R>> results;
        KNNQueryParms<R> *p = (KNNQueryParms<R> *) parms;
        Wrapped<R> wrec;
        wrec.rec = p->point;
        wrec.header = 0;

        PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> pq(p->k, &wrec);

        double farthest = std::numeric_limits<double>::max();

        wss->search(wss->m_root, p->point, p->k, pq, &farthest);

        while (pq.size() > 0) {
            results.emplace_back(*pq.peek().data);
            pq.pop();
        }

        return results;
    }

    static std::vector<Wrapped<R>> buffer_query(MutableBuffer<R> *buffer, void *state, void *parms) {
        KNNQueryParms<R> *p = (KNNQueryParms<R> *) parms;
        Wrapped<R> wrec;
        wrec.rec = p->point;
        wrec.header = 0;

        size_t k = p->k;

        PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> pq(k, &wrec);
        for (size_t i = 0; i < buffer->get_record_count(); i++) {
            // skip over deleted records (under tagging)
            if ((buffer->get_data())[i].is_deleted()) {
                continue;
            }

            if (pq.size() < k) {
                pq.push(buffer->get_data() + i);
            } else {
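                // the queue is full, so compare the candidate against the
                // current k-th nearest neighbor (the heap head) and replace
                // it if the candidate is closer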
                double head_dist = pq.peek().data->rec.calc_distance(wrec.rec);
                double cur_dist = (buffer->get_data() + i)->rec.calc_distance(wrec.rec);

                if (cur_dist < head_dist) {
                    pq.pop();
                    pq.push(buffer->get_data() + i);
                }
            }
        }

        std::vector<Wrapped<R>> results;
        while (pq.size() > 0) {
            results.emplace_back(*(pq.peek().data));
            pq.pop();
        }

        return results;
    }

    static std::vector<R> merge(std::vector<std::vector<R>> &results, void *parms) {
        KNNQueryParms<R> *p = (KNNQueryParms<R> *) parms;
        R rec = p->point;
        size_t k = p->k;

        PriorityQueue<R, KNNDistCmpMax<R>> pq(k, &rec);
        for (size_t i = 0; i < results.size(); i++) {
            for (size_t j = 0; j < results[i].size(); j++) {
                if (pq.size() < k) {
                    pq.push(&results[i][j]);
                } else {
                    double head_dist = pq.peek().data->calc_distance(rec);
                    double cur_dist = results[i][j].calc_distance(rec);

                    if (cur_dist < head_dist) {
                        pq.pop();
                        pq.push(&results[i][j]);
                    }
                }
            }
        }

        std::vector<R> output;
        while (pq.size() > 0) {
            output.emplace_back(*pq.peek().data);
            pq.pop();
        }

        return output;
    }

    static void delete_query_state(void *state) {
        auto s = (KNNState<R> *) state;
        delete s;
    }

    static void delete_buffer_query_state(void *state) {
        auto s = (KNNBufferState<R> *) state;
        delete s;
    }
};

}
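/*
 * Usage sketch (illustrative only, not part of the header): building a
 * VPTree shard from a mutable buffer and running a k-NN query against it.
 * `Point` is a hypothetical record type assumed to satisfy the framework's
 * record interface (i.e., it provides calc_distance()); buffer setup is
 * elided, as its construction is defined in framework/MutableBuffer.h.
 *
 *     de::MutableBuffer<Point> *buffer = ... ;  // populated elsewhere
 *
 *     de::VPTree<Point> shard(buffer);
 *
 *     de::KNNQueryParms<Point> parms;
 *     parms.point = Point{0.0, 0.0};
 *     parms.k = 10;
 *
 *     auto state = de::KNNQuery<Point>::get_query_state(&shard, &parms);
 *     auto nearest = de::KNNQuery<Point>::query(&shard, state, &parms);
 *     de::KNNQuery<Point>::delete_query_state(state);
 */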