summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDouglas Rumbaugh <dbr4@psu.edu>2023-07-24 11:40:14 -0400
committerDouglas Rumbaugh <dbr4@psu.edu>2023-07-24 11:40:14 -0400
commit6b434ec5f2182cb9624a011bd8d65587cd5a0759 (patch)
tree04960f23b369f641fe8386ced6a23906e3c94b34
parent5f6dd8bbc12f981c69d01d9e2c2057bfc97d429c (diff)
downloaddynamic-extension-6b434ec5f2182cb9624a011bd8d65587cd5a0759.tar.gz
VPTree: KNN query initial implementation
-rw-r--r--include/shard/VPTree.h130
1 files changed, 126 insertions, 4 deletions
diff --git a/include/shard/VPTree.h b/include/shard/VPTree.h
index 7376d4c..05ffd50 100644
--- a/include/shard/VPTree.h
+++ b/include/shard/VPTree.h
@@ -26,8 +26,6 @@
namespace de {
-thread_local size_t wss_cancelations = 0;
-
template <NDRecordInterface R>
struct KNNQueryParms {
R point;
@@ -278,6 +276,41 @@ private:
m_data[idx2] = tmp;
}
+
+ void search(vpnode *node, const R &point, size_t k, PriorityQueue<R> &pq, double *farthest) {
+ if (node == nullptr) return;
+
+ double d = point.calc_distance(m_data[node->idx].rec);
+
+ if (d < *farthest) {
+ if (pq.size() == k) pq.pop();
+ pq.push(&m_data[node->idx]);
+ if (pq.size() == k) {
+ *farthest = point.calc_distance(*pq.peek().data.rec);
+ }
+ }
+
+ if (!node->inside && !node->outside) return;
+
+ if (d < node->radius) {
+ if (d - *farthest <= node->radius) {
+ search(node->inside, point, k, pq, farthest);
+ }
+
+ if (d + *farthest >= node->radius) {
+ search(node->outside, point, k, pq, farthest);
+ }
+ } else {
+ if (d + *farthest >= node->radius) {
+ search(node->outside, point, k, pq, farthest);
+ }
+
+ if (d - *farthest <= node->radius) {
+ search(node->inside, point, k, pq, farthest);
+ }
+ }
+ }
+
Wrapped<R>* m_data;
std::unordered_map<R, size_t, RecordHash<R>> m_lookup_map;
size_t m_reccnt;
@@ -285,9 +318,19 @@ private:
size_t m_node_cnt;
vpnode *m_root;
-
};
+template <NDRecordInterface R, R P>
+class KNNDistCmp {
+public:
+ inline bool operator()(queue_record<R> *a, queue_record<R> *b) requires WrappedInterface<R> {
+ return a->rec.calc_distance(P.rec) > b->rec.calc_distance(P.rec);
+ }
+
+ inline bool operator()(queue_record<R> *a, queue_record<R> *b) requires (!WrappedInterface<R>){
+ return a->data.calc_distance(P) > b->data.calc_distance(P);
+ }
+};
template <NDRecordInterface R>
class KNNQuery {
@@ -301,15 +344,94 @@ public:
}
static void process_query_states(void *query_parms, std::vector<void*> shard_states, void *buff_state) {
+ return;
}
static std::vector<Wrapped<R>> query(VPTree<R> *wss, void *q_state, void *parms) {
+ std::vector<Wrapped<R>> results;
+ KNNQueryParms<R> *p = (KNNQueryParms<R> *) parms;
+ Wrapped<R> wrec;
+ wrec.rec = p->point;
+ wrec.header = 0;
+
+ PriorityQueue<Wrapped<R>, KNNDistCmp<Wrapped<R>, wrec>> pq;
+
+ double farthest = std::numeric_limits<double>::max();
+
+ wss->search(wss->m_root, p->point, p->k, pq, &farthest);
+
+ while (pq.size() > 0) {
+ results.emplace_back(*pq.peek().data);
+ pq.pop();
+ }
+
+ return results;
}
static std::vector<Wrapped<R>> buffer_query(MutableBuffer<R> *buffer, void *state, void *parms) {
+ KNNQueryParms<R> *p = (KNNQueryParms<R> *) parms;
+ Wrapped<R> wrec;
+ wrec.rec = p->point;
+ wrec.header = 0;
+
+ size_t k = p->k;
+
+ PriorityQueue<Wrapped<R>, KNNDistCmp<Wrapped<R>, wrec>> pq;
+ for (size_t i=0; i<buffer->get_record_count(); i++) {
+ // Skip over deleted records (under tagging)
+ if (buffer->get_data()[i]->is_deleted()) {
+ continue;
+ }
+
+ if (pq.size() < k) {
+ pq.push(buffer->get_data() + i);
+ } else {
+ double head_dist = pq.peek().data.rec->calc_distance(wrec.rec);
+ double cur_dist = (buffer->get_data() + i)->calc_distance(wrec.rec);
+
+ if (cur_dist < head_dist) {
+ pq.pop();
+ pq.push(buffer->get_data() + i);
+ }
+ }
+ }
+
+ std::vector<Wrapped<R>> results;
+ while (pq.size() > 0) {
+ results.emplace(*pq.peek().data);
+ pq.pop();
+ }
}
- static std::vector<R> merge(std::vector<std::vector<R>> &results) {
+ static std::vector<R> merge(std::vector<std::vector<R>> &results, void *parms) {
+ KNNQueryParms<R> *p = (KNNQueryParms<R> *) parms;
+ R rec = p->point;
+ size_t k = p->k;
+
+ PriorityQueue<R, KNNDistCmp<R, rec>> pq;
+ for (size_t i=0; i<results.size(); i++) {
+ for (size_t j=0; j<results.size(); j++) {
+ if (pq.size() < k) {
+ pq.push(&results[i][j]);
+ } else {
+ double head_dist = pq.peek().data->calc_distance(rec);
+ double cur_dist = results[i][j].calc_distance(rec);
+
+ if (cur_dist < head_dist) {
+ pq.pop();
+ pq.push(&results[i][j]);
+ }
+ }
+ }
+ }
+
+ std::vector<R> output;
+ while (pq.size() > 0) {
+ output.emplace(*pq.peek().data);
+ pq.pop();
+ }
+
+ return output;
}
static void delete_query_state(void *state) {