diff options
| author | Douglas Rumbaugh <dbr4@psu.edu> | 2023-07-24 11:40:14 -0400 |
|---|---|---|
| committer | Douglas Rumbaugh <dbr4@psu.edu> | 2023-07-24 11:40:14 -0400 |
| commit | 6b434ec5f2182cb9624a011bd8d65587cd5a0759 (patch) | |
| tree | 04960f23b369f641fe8386ced6a23906e3c94b34 | |
| parent | 5f6dd8bbc12f981c69d01d9e2c2057bfc97d429c (diff) | |
| download | dynamic-extension-6b434ec5f2182cb9624a011bd8d65587cd5a0759.tar.gz | |
VPTree: KNN query initial implementation
| -rw-r--r-- | include/shard/VPTree.h | 130 |
1 files changed, 126 insertions, 4 deletions
diff --git a/include/shard/VPTree.h b/include/shard/VPTree.h index 7376d4c..05ffd50 100644 --- a/include/shard/VPTree.h +++ b/include/shard/VPTree.h @@ -26,8 +26,6 @@ namespace de { -thread_local size_t wss_cancelations = 0; - template <NDRecordInterface R> struct KNNQueryParms { R point; @@ -278,6 +276,41 @@ private: m_data[idx2] = tmp; } + + void search(vpnode *node, const R &point, size_t k, PriorityQueue<R> &pq, double *farthest) { + if (node == nullptr) return; + + double d = point.calc_distance(m_data[node->idx].rec); + + if (d < *farthest) { + if (pq.size() == k) pq.pop(); + pq.push(&m_data[node->idx]); + if (pq.size() == k) { + *farthest = point.calc_distance(*pq.peek().data.rec); + } + } + + if (!node->inside && !node->outside) return; + + if (d < node->radius) { + if (d - *farthest <= node->radius) { + search(node->inside, point, k, pq, farthest); + } + + if (d + *farthest >= node->radius) { + search(node->outside, point, k, pq, farthest); + } + } else { + if (d + *farthest >= node->radius) { + search(node->outside, point, k, pq, farthest); + } + + if (d - *farthest <= node->radius) { + search(node->inside, point, k, pq, farthest); + } + } + } + Wrapped<R>* m_data; std::unordered_map<R, size_t, RecordHash<R>> m_lookup_map; size_t m_reccnt; @@ -285,9 +318,19 @@ private: size_t m_node_cnt; vpnode *m_root; - }; +template <NDRecordInterface R, R P> +class KNNDistCmp { +public: + inline bool operator()(queue_record<R> *a, queue_record<R> *b) requires WrappedInterface<R> { + return a->rec.calc_distance(P.rec) > b->rec.calc_distance(P.rec); + } + + inline bool operator()(queue_record<R> *a, queue_record<R> *b) requires (!WrappedInterface<R>){ + return a->data.calc_distance(P) > b->data.calc_distance(P); + } +}; template <NDRecordInterface R> class KNNQuery { @@ -301,15 +344,94 @@ public: } static void process_query_states(void *query_parms, std::vector<void*> shard_states, void *buff_state) { + return; } static std::vector<Wrapped<R>> query(VPTree<R> *wss, void *q_state, void *parms) { + std::vector<Wrapped<R>> results; + KNNQueryParms<R> *p = (KNNQueryParms<R> *) parms; + Wrapped<R> wrec; + wrec.rec = p->point; + wrec.header = 0; + + PriorityQueue<Wrapped<R>, KNNDistCmp<Wrapped<R>, wrec>> pq; + + double farthest = std::numeric_limits<double>::max(); + + wss->search(wss->m_root, p->point, p->k, pq, &farthest); + + while (pq.size() > 0) { + results.emplace_back(*pq.peek().data); + pq.pop(); + } + + return results; } static std::vector<Wrapped<R>> buffer_query(MutableBuffer<R> *buffer, void *state, void *parms) { + KNNQueryParms<R> *p = (KNNQueryParms<R> *) parms; + Wrapped<R> wrec; + wrec.rec = p->point; + wrec.header = 0; + + size_t k = p->k; + + PriorityQueue<Wrapped<R>, KNNDistCmp<Wrapped<R>, wrec>> pq; + for (size_t i=0; i<buffer->get_record_count(); i++) { + // Skip over deleted records (under tagging) + if (buffer->get_data()[i]->is_deleted()) { + continue; + } + + if (pq.size() < k) { + pq.push(buffer->get_data() + i); + } else { + double head_dist = pq.peek().data.rec->calc_distance(wrec.rec); + double cur_dist = (buffer->get_data() + i)->calc_distance(wrec.rec); + + if (cur_dist < head_dist) { + pq.pop(); + pq.push(buffer->get_data() + i); + } + } + } + + std::vector<Wrapped<R>> results; + while (pq.size() > 0) { + results.emplace(*pq.peek().data); + pq.pop(); + } } - static std::vector<R> merge(std::vector<std::vector<R>> &results) { + static std::vector<R> merge(std::vector<std::vector<R>> &results, void *parms) { + KNNQueryParms<R> *p = (KNNQueryParms<R> *) parms; + R rec = p->point; + size_t k = p->k; + + PriorityQueue<R, KNNDistCmp<R, rec>> pq; + for (size_t i=0; i<results.size(); i++) { + for (size_t j=0; j<results.size(); j++) { + if (pq.size() < k) { + pq.push(&results[i][j]); + } else { + double head_dist = pq.peek().data->calc_distance(rec); + double cur_dist = results[i][j].calc_distance(rec); + + if (cur_dist < head_dist) { + pq.pop(); + pq.push(&results[i][j]); + } + } + } + } + + std::vector<R> output; + while (pq.size() > 0) { + output.emplace(*pq.peek().data); + pq.pop(); + } + + return output; } static void delete_query_state(void *state) { |