From d02fe67962c8002ddc6e0d6569128ae2645ea7fc Mon Sep 17 00:00:00 2001 From: Douglas Rumbaugh Date: Mon, 24 Jul 2023 16:49:21 -0400 Subject: VPTree: fixed knn query --- include/ds/PriorityQueue.h | 6 +- include/framework/RecordInterface.h | 10 ++- include/shard/VPTree.h | 136 +++++++++++++++++++++++------------- tests/vptree_tests.cpp | 74 +++++++++++++++++--- 4 files changed, 164 insertions(+), 62 deletions(-) diff --git a/include/ds/PriorityQueue.h b/include/ds/PriorityQueue.h index a8e9ba5..4612eef 100644 --- a/include/ds/PriorityQueue.h +++ b/include/ds/PriorityQueue.h @@ -23,6 +23,7 @@ struct queue_record { template class standard_minheap { public: + standard_minheap(R *baseline) {} inline bool operator()(const R* a, const R* b) { return *a < *b; } @@ -31,6 +32,7 @@ public: template class standard_maxheap { public: + standard_maxheap(R *baseline) {} inline bool operator()(const R* a, const R* b) { return *a > *b; } @@ -39,7 +41,8 @@ public: template > class PriorityQueue { public: - PriorityQueue(size_t size) : data(size), tail(0) {} + PriorityQueue(size_t size, R* cmp_baseline=nullptr) : data(size), tail(0), cmp(cmp_baseline) {} + ~PriorityQueue() = default; size_t size() const { @@ -97,6 +100,7 @@ private: std::vector> data; CMP cmp; size_t tail; + R *baseline; /* * Swap the elements at position a and position diff --git a/include/framework/RecordInterface.h b/include/framework/RecordInterface.h index 85a0794..cea9fbe 100644 --- a/include/framework/RecordInterface.h +++ b/include/framework/RecordInterface.h @@ -49,6 +49,7 @@ concept WrappedInterface = RecordInterface && requires(R r, R s, bool b) { {r.set_tombstone(b)}; {r.is_tombstone()} -> std::convertible_to; {r < s} -> std::convertible_to; + {r == s} ->std::convertible_to; }; template @@ -79,6 +80,11 @@ struct Wrapped { inline bool operator<(const Wrapped& other) const { return rec < other.rec || (rec == other.rec && header < other.header); } + + inline bool operator==(const Wrapped& other) const { + return rec == other.rec; + } + }; template @@ -185,10 +191,10 @@ struct EuclidPoint{ inline double calc_distance(const EuclidPoint& other) const { double dist = 0; for (size_t i=0; i struct KNNBufferState { - size_t cutoff; - size_t sample_size; - Alias* alias; - decltype(R::weight) max_weight; - decltype(R::weight) total_weight; - - ~KNNBufferState() { - delete alias; + +}; + + +template +class KNNDistCmpMax { +public: + KNNDistCmpMax(R *baseline) : P(baseline) {} + + inline bool operator()(const R *a, const R *b) requires WrappedInterface { + return a->rec.calc_distance(P->rec) > b->rec.calc_distance(P->rec); + } + + inline bool operator()(const R *a, const R *b) requires (!WrappedInterface){ + return a->calc_distance(*P) > b->calc_distance(*P); + } + +private: + R *P; +}; + +template +class KNNDistCmpMin { +public: + KNNDistCmpMin(R *baseline) : P(baseline) {} + + inline bool operator()(const R *a, const R *b) requires WrappedInterface { + return a->rec.calc_distance(P->rec) < b->rec.calc_distance(P->rec); + } + + inline bool operator()(const R *a, const R *b) requires (!WrappedInterface){ + return a->calc_distance(*P) < b->calc_distance(*P); } +private: + R *P; }; + + template class VPTree { private: struct vpnode { - size_t idx = 0; - double radius = 0; - vpnode *inside = nullptr; - vpnode *outside = nullptr; + size_t idx; + double radius; + vpnode *inside; + vpnode *outside; + + vpnode(size_t idx) : idx(idx), radius(0), inside(nullptr), outside(nullptr) {} ~vpnode() { delete inside; @@ -180,15 +210,17 @@ public: private: vpnode *build_vptree() { - assert(m_reccnt > 0); + if (m_reccnt == 0) { + return nullptr; + } size_t lower = 0; size_t upper = m_reccnt - 1; auto rng = gsl_rng_alloc(gsl_rng_mt19937); - auto n = build_subtree(lower, upper, rng); + auto root = build_subtree(lower, upper, rng); gsl_rng_free(rng); - return n; + return root; } void build_map() { @@ -204,32 +236,42 @@ private: } vpnode *build_subtree(size_t start, size_t stop, gsl_rng *rng) { - if (start >= stop) { + // base-case: sometimes happens (probably because of the +1 and -1 + // in the first recursive call) + if (start > stop) { return nullptr; } - // select a random element to partition based on, and swap - // it to the front of the sub-array - auto i = start + gsl_rng_uniform_int(rng, stop - start); + // base-case: create a leaf node + if (start == stop) { + vpnode *node = new vpnode(start); + node->idx = start; + + m_node_cnt++; + + return node; + } + + // select a random element to be the root of the + // subtree + auto i = start + gsl_rng_uniform_int(rng, stop - start + 1); swap(start, i); // partition elements based on their distance from the start, // with those elements with distance falling below the median // distance going into the left sub-array and those above // the median in the right. This is easily done using QuickSelect. - auto mid = ((start+1) + stop) / 2; + auto mid = (start + 1 + stop) / 2; quickselect(start + 1, stop, mid, m_data[start], rng); // Create a new node based on this partitioning - vpnode *node = new vpnode(); + vpnode *node = new vpnode(start); - // store the radius of the circle used for partitioning the - // node. - node->idx = start; + // store the radius of the circle used for partitioning the node. node->radius = m_data[start].rec.calc_distance(m_data[mid].rec); // recursively construct the left and right subtrees - node->inside = build_subtree(start + 1, mid - 1, rng); + node->inside = build_subtree(start + 1, mid-1, rng); node->outside = build_subtree(mid, stop, rng); m_node_cnt++; @@ -277,35 +319,38 @@ private: } - void search(vpnode *node, const R &point, size_t k, PriorityQueue &pq, double *farthest) { + void search(vpnode *node, const R &point, size_t k, PriorityQueue, KNNDistCmpMax>> &pq, double *farthest) { if (node == nullptr) return; double d = point.calc_distance(m_data[node->idx].rec); if (d < *farthest) { - if (pq.size() == k) pq.pop(); + if (pq.size() == k) { + auto t = pq.peek().data->rec; + pq.pop(); + } pq.push(&m_data[node->idx]); if (pq.size() == k) { - *farthest = point.calc_distance(*pq.peek().data.rec); + *farthest = point.calc_distance(pq.peek().data->rec); } } if (!node->inside && !node->outside) return; if (d < node->radius) { - if (d - *farthest <= node->radius) { + if (d - (*farthest) <= node->radius) { search(node->inside, point, k, pq, farthest); } - if (d + *farthest >= node->radius) { + if (d + (*farthest) >= node->radius) { search(node->outside, point, k, pq, farthest); } } else { - if (d + *farthest >= node->radius) { + if (d + (*farthest) >= node->radius) { search(node->outside, point, k, pq, farthest); } - if (d - *farthest <= node->radius) { + if (d - (*farthest) <= node->radius) { search(node->inside, point, k, pq, farthest); } } @@ -320,17 +365,6 @@ private: vpnode *m_root; }; -template -class KNNDistCmp { -public: - inline bool operator()(queue_record *a, queue_record *b) requires WrappedInterface { - return a->rec.calc_distance(P.rec) > b->rec.calc_distance(P.rec); - } - - inline bool operator()(queue_record *a, queue_record *b) requires (!WrappedInterface){ - return a->data.calc_distance(P) > b->data.calc_distance(P); - } -}; template class KNNQuery { @@ -354,7 +388,7 @@ public: wrec.rec = p->point; wrec.header = 0; - PriorityQueue, KNNDistCmp, wrec>> pq; + PriorityQueue, KNNDistCmpMax>> pq(p->k, &wrec); double farthest = std::numeric_limits::max(); @@ -376,18 +410,18 @@ public: size_t k = p->k; - PriorityQueue, KNNDistCmp, wrec>> pq; + PriorityQueue, KNNDistCmpMax>> pq(k, &wrec); for (size_t i=0; iget_record_count(); i++) { // Skip over deleted records (under tagging) - if (buffer->get_data()[i]->is_deleted()) { + if ((buffer->get_data())[i].is_deleted()) { continue; } if (pq.size() < k) { pq.push(buffer->get_data() + i); } else { - double head_dist = pq.peek().data.rec->calc_distance(wrec.rec); - double cur_dist = (buffer->get_data() + i)->calc_distance(wrec.rec); + double head_dist = pq.peek().data->rec.calc_distance(wrec.rec); + double cur_dist = (buffer->get_data() + i)->rec.calc_distance(wrec.rec); if (cur_dist < head_dist) { pq.pop(); @@ -398,9 +432,11 @@ public: std::vector> results; while (pq.size() > 0) { - results.emplace(*pq.peek().data); + results.emplace_back(*(pq.peek().data)); pq.pop(); } + + return results; } static std::vector merge(std::vector> &results, void *parms) { @@ -408,7 +444,7 @@ public: R rec = p->point; size_t k = p->k; - PriorityQueue> pq; + PriorityQueue> pq(k, &rec); for (size_t i=0; i @@ -108,6 +109,65 @@ START_TEST(t_point_lookup_miss) } +START_TEST(t_buffer_query) +{ + size_t n = 10000; + auto buffer = create_2d_sequential_mbuffer(n); + + PRec target; + target.data[0] = 120; + target.data[1] = 120; + + KNNQueryParms p; + p.k = 10; + p.point = target; + + auto state = KNNQuery::get_buffer_query_state(buffer, &p); + auto result = KNNQuery::buffer_query(buffer, state, &p); + KNNQuery::delete_buffer_query_state(state); + + std::sort(result.begin(), result.end()); + size_t start = 120 - 5; + for (size_t i=0; i p; + p.k = 10; + p.point = target; + + auto state = KNNQuery::get_buffer_query_state(buffer, &p); + auto result = KNNQuery::buffer_query(buffer, state, &p); + + KNNQuery::delete_buffer_query_state(state); + + auto vptree = VPTree(buffer); + auto state_2 = KNNQuery::get_query_state(&vptree, &p); + auto result_2 = KNNQuery::query(&vptree, state_2, &p); + KNNQuery::delete_query_state(state_2); + + std::sort(result_2.begin(), result_2.end()); + size_t start = 46; + for (size_t i=0; i