diff options
| author | Douglas Rumbaugh <dbr4@psu.edu> | 2023-07-24 16:49:21 -0400 |
|---|---|---|
| committer | Douglas Rumbaugh <dbr4@psu.edu> | 2023-07-24 16:49:21 -0400 |
| commit | d02fe67962c8002ddc6e0d6569128ae2645ea7fc (patch) | |
| tree | b0b27a29c58c65d51984318433f58698f297700e | |
| parent | ac018f5f96c32c96158a239fbfeb9dc439c95548 (diff) | |
| download | dynamic-extension-d02fe67962c8002ddc6e0d6569128ae2645ea7fc.tar.gz | |
VPTree: fixed knn query
| -rw-r--r-- | include/ds/PriorityQueue.h | 6 | ||||
| -rw-r--r-- | include/framework/RecordInterface.h | 10 | ||||
| -rw-r--r-- | include/shard/VPTree.h | 136 | ||||
| -rw-r--r-- | tests/vptree_tests.cpp | 74 |
4 files changed, 164 insertions, 62 deletions
diff --git a/include/ds/PriorityQueue.h b/include/ds/PriorityQueue.h index a8e9ba5..4612eef 100644 --- a/include/ds/PriorityQueue.h +++ b/include/ds/PriorityQueue.h @@ -23,6 +23,7 @@ struct queue_record { template <typename R> class standard_minheap { public: + standard_minheap(R *baseline) {} inline bool operator()(const R* a, const R* b) { return *a < *b; } @@ -31,6 +32,7 @@ public: template <typename R> class standard_maxheap { public: + standard_maxheap(R *baseline) {} inline bool operator()(const R* a, const R* b) { return *a > *b; } @@ -39,7 +41,8 @@ public: template <typename R, typename CMP=standard_minheap<R>> class PriorityQueue { public: - PriorityQueue(size_t size) : data(size), tail(0) {} + PriorityQueue(size_t size, R* cmp_baseline=nullptr) : data(size), tail(0), cmp(cmp_baseline) {} + ~PriorityQueue() = default; size_t size() const { @@ -97,6 +100,7 @@ private: std::vector<queue_record<R>> data; CMP cmp; size_t tail; + R *baseline; /* * Swap the elements at position a and position diff --git a/include/framework/RecordInterface.h b/include/framework/RecordInterface.h index 85a0794..cea9fbe 100644 --- a/include/framework/RecordInterface.h +++ b/include/framework/RecordInterface.h @@ -49,6 +49,7 @@ concept WrappedInterface = RecordInterface<R> && requires(R r, R s, bool b) { {r.set_tombstone(b)}; {r.is_tombstone()} -> std::convertible_to<bool>; {r < s} -> std::convertible_to<bool>; + {r == s} ->std::convertible_to<bool>; }; template<RecordInterface R> @@ -79,6 +80,11 @@ struct Wrapped { inline bool operator<(const Wrapped& other) const { return rec < other.rec || (rec == other.rec && header < other.header); } + + inline bool operator==(const Wrapped& other) const { + return rec == other.rec; + } + }; template <typename K, typename V> @@ -185,10 +191,10 @@ struct EuclidPoint{ inline double calc_distance(const EuclidPoint& other) const { double dist = 0; for (size_t i=0; i<D; i++) { - dist += pow(data[i] - other.data[i], 2); + dist += (data[i] - other.data[i]) * (data[i] - other.data[i]); } - return sqrt(dist); + return std::sqrt(dist); } }; diff --git a/include/shard/VPTree.h b/include/shard/VPTree.h index 05ffd50..927108c 100644 --- a/include/shard/VPTree.h +++ b/include/shard/VPTree.h @@ -46,26 +46,56 @@ struct KNNState { template <NDRecordInterface R> struct KNNBufferState { - size_t cutoff; - size_t sample_size; - Alias* alias; - decltype(R::weight) max_weight; - decltype(R::weight) total_weight; - - ~KNNBufferState() { - delete alias; + +}; + + +template <typename R> +class KNNDistCmpMax { +public: + KNNDistCmpMax(R *baseline) : P(baseline) {} + + inline bool operator()(const R *a, const R *b) requires WrappedInterface<R> { + return a->rec.calc_distance(P->rec) > b->rec.calc_distance(P->rec); + } + + inline bool operator()(const R *a, const R *b) requires (!WrappedInterface<R>){ + return a->calc_distance(*P) > b->calc_distance(*P); + } + +private: + R *P; +}; + +template <typename R> +class KNNDistCmpMin { +public: + KNNDistCmpMin(R *baseline) : P(baseline) {} + + inline bool operator()(const R *a, const R *b) requires WrappedInterface<R> { + return a->rec.calc_distance(P->rec) < b->rec.calc_distance(P->rec); + } + + inline bool operator()(const R *a, const R *b) requires (!WrappedInterface<R>){ + return a->calc_distance(*P) < b->calc_distance(*P); } +private: + R *P; }; + + template <NDRecordInterface R> class VPTree { private: struct vpnode { - size_t idx = 0; - double radius = 0; - vpnode *inside = nullptr; - vpnode *outside = nullptr; + size_t idx; + double radius; + vpnode *inside; + vpnode *outside; + + vpnode(size_t idx) : idx(idx), radius(0), inside(nullptr), outside(nullptr) {} ~vpnode() { delete inside; @@ -180,15 +210,17 @@ public: private: vpnode *build_vptree() { - assert(m_reccnt > 0); + if (m_reccnt == 0) { + return nullptr; + } size_t lower = 0; size_t upper = m_reccnt - 1; auto rng = gsl_rng_alloc(gsl_rng_mt19937); - auto n = build_subtree(lower, upper, rng); + auto root = build_subtree(lower, upper, rng); gsl_rng_free(rng); - return n; + return root; } void build_map() { @@ -204,32 +236,42 @@ private: } vpnode *build_subtree(size_t start, size_t stop, gsl_rng *rng) { - if (start >= stop) { + // base-case: sometimes happens (probably because of the +1 and -1 + // in the first recursive call) + if (start > stop) { return nullptr; } - // select a random element to partition based on, and swap - // it to the front of the sub-array - auto i = start + gsl_rng_uniform_int(rng, stop - start); + // base-case: create a leaf node + if (start == stop) { + vpnode *node = new vpnode(start); + node->idx = start; + + m_node_cnt++; + + return node; + } + + // select a random element to be the root of the + // subtree + auto i = start + gsl_rng_uniform_int(rng, stop - start + 1); swap(start, i); // partition elements based on their distance from the start, // with those elements with distance falling below the median // distance going into the left sub-array and those above // the median in the right. This is easily done using QuickSelect. - auto mid = ((start+1) + stop) / 2; + auto mid = (start + 1 + stop) / 2; quickselect(start + 1, stop, mid, m_data[start], rng); // Create a new node based on this partitioning - vpnode *node = new vpnode(); + vpnode *node = new vpnode(start); - // store the radius of the circle used for partitioning the - // node. - node->idx = start; + // store the radius of the circle used for partitioning the node. node->radius = m_data[start].rec.calc_distance(m_data[mid].rec); // recursively construct the left and right subtrees - node->inside = build_subtree(start + 1, mid - 1, rng); + node->inside = build_subtree(start + 1, mid-1, rng); node->outside = build_subtree(mid, stop, rng); m_node_cnt++; @@ -277,35 +319,38 @@ private: } - void search(vpnode *node, const R &point, size_t k, PriorityQueue<R> &pq, double *farthest) { + void search(vpnode *node, const R &point, size_t k, PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> &pq, double *farthest) { if (node == nullptr) return; double d = point.calc_distance(m_data[node->idx].rec); if (d < *farthest) { - if (pq.size() == k) pq.pop(); + if (pq.size() == k) { + auto t = pq.peek().data->rec; + pq.pop(); + } pq.push(&m_data[node->idx]); if (pq.size() == k) { - *farthest = point.calc_distance(*pq.peek().data.rec); + *farthest = point.calc_distance(pq.peek().data->rec); } } if (!node->inside && !node->outside) return; if (d < node->radius) { - if (d - *farthest <= node->radius) { + if (d - (*farthest) <= node->radius) { search(node->inside, point, k, pq, farthest); } - if (d + *farthest >= node->radius) { + if (d + (*farthest) >= node->radius) { search(node->outside, point, k, pq, farthest); } } else { - if (d + *farthest >= node->radius) { + if (d + (*farthest) >= node->radius) { search(node->outside, point, k, pq, farthest); } - if (d - *farthest <= node->radius) { + if (d - (*farthest) <= node->radius) { search(node->inside, point, k, pq, farthest); } } @@ -320,17 +365,6 @@ private: vpnode *m_root; }; -template <NDRecordInterface R, R P> -class KNNDistCmp { -public: - inline bool operator()(queue_record<R> *a, queue_record<R> *b) requires WrappedInterface<R> { - return a->rec.calc_distance(P.rec) > b->rec.calc_distance(P.rec); - } - - inline bool operator()(queue_record<R> *a, queue_record<R> *b) requires (!WrappedInterface<R>){ - return a->data.calc_distance(P) > b->data.calc_distance(P); - } -}; template <NDRecordInterface R> class KNNQuery { @@ -354,7 +388,7 @@ public: wrec.rec = p->point; wrec.header = 0; - PriorityQueue<Wrapped<R>, KNNDistCmp<Wrapped<R>, wrec>> pq; + PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> pq(p->k, &wrec); double farthest = std::numeric_limits<double>::max(); @@ -376,18 +410,18 @@ public: size_t k = p->k; - PriorityQueue<Wrapped<R>, KNNDistCmp<Wrapped<R>, wrec>> pq; + PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> pq(k, &wrec); for (size_t i=0; i<buffer->get_record_count(); i++) { // Skip over deleted records (under tagging) - if (buffer->get_data()[i]->is_deleted()) { + if ((buffer->get_data())[i].is_deleted()) { continue; } if (pq.size() < k) { pq.push(buffer->get_data() + i); } else { - double head_dist = pq.peek().data.rec->calc_distance(wrec.rec); - double cur_dist = (buffer->get_data() + i)->calc_distance(wrec.rec); + double head_dist = pq.peek().data->rec.calc_distance(wrec.rec); + double cur_dist = (buffer->get_data() + i)->rec.calc_distance(wrec.rec); if (cur_dist < head_dist) { pq.pop(); @@ -398,9 +432,11 @@ public: std::vector<Wrapped<R>> results; while (pq.size() > 0) { - results.emplace(*pq.peek().data); + results.emplace_back(*(pq.peek().data)); pq.pop(); } + + return results; } static std::vector<R> merge(std::vector<std::vector<R>> &results, void *parms) { @@ -408,7 +444,7 @@ public: R rec = p->point; size_t k = p->k; - PriorityQueue<R, KNNDistCmp<R, rec>> pq; + PriorityQueue<R, KNNDistCmpMax<R>> pq(k, &rec); for (size_t i=0; i<results.size(); i++) { for (size_t j=0; j<results.size(); j++) { if (pq.size() < k) { diff --git a/tests/vptree_tests.cpp b/tests/vptree_tests.cpp index b86e1e9..894e64a 100644 --- a/tests/vptree_tests.cpp +++ b/tests/vptree_tests.cpp @@ -11,6 +11,7 @@ #include "shard/VPTree.h" #include "testing.h" +#include "vptree.hpp" #include <check.h> @@ -108,6 +109,65 @@ START_TEST(t_point_lookup_miss) } +START_TEST(t_buffer_query) +{ + size_t n = 10000; + auto buffer = create_2d_sequential_mbuffer(n); + + PRec target; + target.data[0] = 120; + target.data[1] = 120; + + KNNQueryParms<PRec> p; + p.k = 10; + p.point = target; + + auto state = KNNQuery<PRec>::get_buffer_query_state(buffer, &p); + auto result = KNNQuery<PRec>::buffer_query(buffer, state, &p); + KNNQuery<PRec>::delete_buffer_query_state(state); + + std::sort(result.begin(), result.end()); + size_t start = 120 - 5; + for (size_t i=0; i<result.size(); i++) { + ck_assert_int_eq(result[i].rec.data[0], start++); + } + + delete buffer; +} + +START_TEST(t_knn_query) +{ + size_t n = 100; + auto buffer = create_2d_sequential_mbuffer(n); + + PRec target; + target.data[0] = 50; + target.data[1] = 50; + + KNNQueryParms<PRec> p; + p.k = 10; + p.point = target; + + auto state = KNNQuery<PRec>::get_buffer_query_state(buffer, &p); + auto result = KNNQuery<PRec>::buffer_query(buffer, state, &p); + + KNNQuery<PRec>::delete_buffer_query_state(state); + + auto vptree = VPTree<PRec>(buffer); + auto state_2 = KNNQuery<PRec>::get_query_state(&vptree, &p); + auto result_2 = KNNQuery<PRec>::query(&vptree, state_2, &p); + KNNQuery<PRec>::delete_query_state(state_2); + + std::sort(result_2.begin(), result_2.end()); + size_t start = 46; + for (size_t i=0; i<result_2.size(); i++) { + ck_assert_int_eq(result_2[i].rec.data[0], start++); + } + + delete buffer; +} + + Suite *unit_testing() { Suite *unit = suite_create("VPTree Shard Unit Testing"); @@ -121,18 +181,14 @@ Suite *unit_testing() TCase *lookup = tcase_create("de:VPTree:point_lookup Testing"); tcase_add_test(lookup, t_point_lookup); - //tcase_add_test(lookup, t_point_lookup_miss); + tcase_add_test(lookup, t_point_lookup_miss); suite_add_tcase(unit, lookup); - /* - TCase *sampling = tcase_create("de:VPTree::VPTreeQuery Testing"); - tcase_add_test(sampling, t_wss_query); - tcase_add_test(sampling, t_wss_query_merge); - tcase_add_test(sampling, t_wss_buffer_query_rejection); - tcase_add_test(sampling, t_wss_buffer_query_scan); - suite_add_tcase(unit, sampling); - */ + TCase *query = tcase_create("de:VPTree::VPTreeQuery Testing"); + tcase_add_test(query, t_buffer_query); + tcase_add_test(query, t_knn_query); + suite_add_tcase(unit, query); return unit; } |