From 6b434ec5f2182cb9624a011bd8d65587cd5a0759 Mon Sep 17 00:00:00 2001 From: Douglas Rumbaugh Date: Mon, 24 Jul 2023 11:40:14 -0400 Subject: VPTree: KNN query initial implementation --- include/shard/VPTree.h | 130 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 126 insertions(+), 4 deletions(-) (limited to 'include/shard/VPTree.h') diff --git a/include/shard/VPTree.h b/include/shard/VPTree.h index 7376d4c..05ffd50 100644 --- a/include/shard/VPTree.h +++ b/include/shard/VPTree.h @@ -26,8 +26,6 @@ namespace de { -thread_local size_t wss_cancelations = 0; - template struct KNNQueryParms { R point; @@ -278,6 +276,41 @@ private: m_data[idx2] = tmp; } + + void search(vpnode *node, const R &point, size_t k, PriorityQueue &pq, double *farthest) { + if (node == nullptr) return; + + double d = point.calc_distance(m_data[node->idx].rec); + + if (d < *farthest) { + if (pq.size() == k) pq.pop(); + pq.push(&m_data[node->idx]); + if (pq.size() == k) { + *farthest = point.calc_distance(*pq.peek().data.rec); + } + } + + if (!node->inside && !node->outside) return; + + if (d < node->radius) { + if (d - *farthest <= node->radius) { + search(node->inside, point, k, pq, farthest); + } + + if (d + *farthest >= node->radius) { + search(node->outside, point, k, pq, farthest); + } + } else { + if (d + *farthest >= node->radius) { + search(node->outside, point, k, pq, farthest); + } + + if (d - *farthest <= node->radius) { + search(node->inside, point, k, pq, farthest); + } + } + } + Wrapped* m_data; std::unordered_map> m_lookup_map; size_t m_reccnt; @@ -285,9 +318,19 @@ private: size_t m_node_cnt; vpnode *m_root; - }; +template +class KNNDistCmp { +public: + inline bool operator()(queue_record *a, queue_record *b) requires WrappedInterface { + return a->rec.calc_distance(P.rec) > b->rec.calc_distance(P.rec); + } + + inline bool operator()(queue_record *a, queue_record *b) requires (!WrappedInterface){ + return a->data.calc_distance(P) > b->data.calc_distance(P); + } +}; template class KNNQuery { @@ -301,15 +344,94 @@ public: } static void process_query_states(void *query_parms, std::vector shard_states, void *buff_state) { + return; } static std::vector> query(VPTree *wss, void *q_state, void *parms) { + std::vector> results; + KNNQueryParms *p = (KNNQueryParms *) parms; + Wrapped wrec; + wrec.rec = p->point; + wrec.header = 0; + + PriorityQueue, KNNDistCmp, wrec>> pq; + + double farthest = std::numeric_limits::max(); + + wss->search(wss->m_root, p->point, p->k, pq, &farthest); + + while (pq.size() > 0) { + results.emplace_back(*pq.peek().data); + pq.pop(); + } + + return results; } static std::vector> buffer_query(MutableBuffer *buffer, void *state, void *parms) { + KNNQueryParms *p = (KNNQueryParms *) parms; + Wrapped wrec; + wrec.rec = p->point; + wrec.header = 0; + + size_t k = p->k; + + PriorityQueue, KNNDistCmp, wrec>> pq; + for (size_t i=0; iget_record_count(); i++) { + // Skip over deleted records (under tagging) + if (buffer->get_data()[i]->is_deleted()) { + continue; + } + + if (pq.size() < k) { + pq.push(buffer->get_data() + i); + } else { + double head_dist = pq.peek().data.rec->calc_distance(wrec.rec); + double cur_dist = (buffer->get_data() + i)->calc_distance(wrec.rec); + + if (cur_dist < head_dist) { + pq.pop(); + pq.push(buffer->get_data() + i); + } + } + } + + std::vector> results; + while (pq.size() > 0) { + results.emplace(*pq.peek().data); + pq.pop(); + } } - static std::vector merge(std::vector> &results) { + static std::vector merge(std::vector> &results, void *parms) { + KNNQueryParms *p = (KNNQueryParms *) parms; + R rec = p->point; + size_t k = p->k; + + PriorityQueue> pq; + for (size_t i=0; icalc_distance(rec); + double cur_dist = results[i][j].calc_distance(rec); + + if (cur_dist < head_dist) { + pq.pop(); + pq.push(&results[i][j]); + } + } + } + } + + std::vector output; + while (pq.size() > 0) { + output.emplace(*pq.peek().data); + pq.pop(); + } + + return output; } static void delete_query_state(void *state) { -- cgit v1.2.3