diff options
| author | Douglas Rumbaugh <dbr4@psu.edu> | 2023-07-24 16:49:21 -0400 |
|---|---|---|
| committer | Douglas Rumbaugh <dbr4@psu.edu> | 2023-07-24 16:49:21 -0400 |
| commit | d02fe67962c8002ddc6e0d6569128ae2645ea7fc (patch) | |
| tree | b0b27a29c58c65d51984318433f58698f297700e /include/shard | |
| parent | ac018f5f96c32c96158a239fbfeb9dc439c95548 (diff) | |
| download | dynamic-extension-d02fe67962c8002ddc6e0d6569128ae2645ea7fc.tar.gz | |
VPTree: fixed knn query
Diffstat (limited to 'include/shard')
| -rw-r--r-- | include/shard/VPTree.h | 136 |
1 file changed, 86 insertions, 50 deletions
diff --git a/include/shard/VPTree.h b/include/shard/VPTree.h index 05ffd50..927108c 100644 --- a/include/shard/VPTree.h +++ b/include/shard/VPTree.h @@ -46,26 +46,56 @@ struct KNNState { template <NDRecordInterface R> struct KNNBufferState { - size_t cutoff; - size_t sample_size; - Alias* alias; - decltype(R::weight) max_weight; - decltype(R::weight) total_weight; - - ~KNNBufferState() { - delete alias; + +}; + + +template <typename R> +class KNNDistCmpMax { +public: + KNNDistCmpMax(R *baseline) : P(baseline) {} + + inline bool operator()(const R *a, const R *b) requires WrappedInterface<R> { + return a->rec.calc_distance(P->rec) > b->rec.calc_distance(P->rec); + } + + inline bool operator()(const R *a, const R *b) requires (!WrappedInterface<R>){ + return a->calc_distance(*P) > b->calc_distance(*P); + } + +private: + R *P; +}; + +template <typename R> +class KNNDistCmpMin { +public: + KNNDistCmpMin(R *baseline) : P(baseline) {} + + inline bool operator()(const R *a, const R *b) requires WrappedInterface<R> { + return a->rec.calc_distance(P->rec) < b->rec.calc_distance(P->rec); + } + + inline bool operator()(const R *a, const R *b) requires (!WrappedInterface<R>){ + return a->calc_distance(*P) < b->calc_distance(*P); } +private: + R *P; }; + + template <NDRecordInterface R> class VPTree { private: struct vpnode { - size_t idx = 0; - double radius = 0; - vpnode *inside = nullptr; - vpnode *outside = nullptr; + size_t idx; + double radius; + vpnode *inside; + vpnode *outside; + + vpnode(size_t idx) : idx(idx), radius(0), inside(nullptr), outside(nullptr) {} ~vpnode() { delete inside; @@ -180,15 +210,17 @@ public: private: vpnode *build_vptree() { - assert(m_reccnt > 0); + if (m_reccnt == 0) { + return nullptr; + } size_t lower = 0; size_t upper = m_reccnt - 1; auto rng = gsl_rng_alloc(gsl_rng_mt19937); - auto n = build_subtree(lower, upper, rng); + auto root = build_subtree(lower, upper, rng); gsl_rng_free(rng); - return n; + return root; } void build_map() { @@ 
-204,32 +236,42 @@ private: } vpnode *build_subtree(size_t start, size_t stop, gsl_rng *rng) { - if (start >= stop) { + // base-case: sometimes happens (probably because of the +1 and -1 + // in the first recursive call) + if (start > stop) { return nullptr; } - // select a random element to partition based on, and swap - // it to the front of the sub-array - auto i = start + gsl_rng_uniform_int(rng, stop - start); + // base-case: create a leaf node + if (start == stop) { + vpnode *node = new vpnode(start); + node->idx = start; + + m_node_cnt++; + + return node; + } + + // select a random element to be the root of the + // subtree + auto i = start + gsl_rng_uniform_int(rng, stop - start + 1); swap(start, i); // partition elements based on their distance from the start, // with those elements with distance falling below the median // distance going into the left sub-array and those above // the median in the right. This is easily done using QuickSelect. - auto mid = ((start+1) + stop) / 2; + auto mid = (start + 1 + stop) / 2; quickselect(start + 1, stop, mid, m_data[start], rng); // Create a new node based on this partitioning - vpnode *node = new vpnode(); + vpnode *node = new vpnode(start); - // store the radius of the circle used for partitioning the - // node. - node->idx = start; + // store the radius of the circle used for partitioning the node. 
node->radius = m_data[start].rec.calc_distance(m_data[mid].rec); // recursively construct the left and right subtrees - node->inside = build_subtree(start + 1, mid - 1, rng); + node->inside = build_subtree(start + 1, mid-1, rng); node->outside = build_subtree(mid, stop, rng); m_node_cnt++; @@ -277,35 +319,38 @@ private: } - void search(vpnode *node, const R &point, size_t k, PriorityQueue<R> &pq, double *farthest) { + void search(vpnode *node, const R &point, size_t k, PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> &pq, double *farthest) { if (node == nullptr) return; double d = point.calc_distance(m_data[node->idx].rec); if (d < *farthest) { - if (pq.size() == k) pq.pop(); + if (pq.size() == k) { + auto t = pq.peek().data->rec; + pq.pop(); + } pq.push(&m_data[node->idx]); if (pq.size() == k) { - *farthest = point.calc_distance(*pq.peek().data.rec); + *farthest = point.calc_distance(pq.peek().data->rec); } } if (!node->inside && !node->outside) return; if (d < node->radius) { - if (d - *farthest <= node->radius) { + if (d - (*farthest) <= node->radius) { search(node->inside, point, k, pq, farthest); } - if (d + *farthest >= node->radius) { + if (d + (*farthest) >= node->radius) { search(node->outside, point, k, pq, farthest); } } else { - if (d + *farthest >= node->radius) { + if (d + (*farthest) >= node->radius) { search(node->outside, point, k, pq, farthest); } - if (d - *farthest <= node->radius) { + if (d - (*farthest) <= node->radius) { search(node->inside, point, k, pq, farthest); } } @@ -320,17 +365,6 @@ private: vpnode *m_root; }; -template <NDRecordInterface R, R P> -class KNNDistCmp { -public: - inline bool operator()(queue_record<R> *a, queue_record<R> *b) requires WrappedInterface<R> { - return a->rec.calc_distance(P.rec) > b->rec.calc_distance(P.rec); - } - - inline bool operator()(queue_record<R> *a, queue_record<R> *b) requires (!WrappedInterface<R>){ - return a->data.calc_distance(P) > b->data.calc_distance(P); - } -}; template 
<NDRecordInterface R> class KNNQuery { @@ -354,7 +388,7 @@ public: wrec.rec = p->point; wrec.header = 0; - PriorityQueue<Wrapped<R>, KNNDistCmp<Wrapped<R>, wrec>> pq; + PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> pq(p->k, &wrec); double farthest = std::numeric_limits<double>::max(); @@ -376,18 +410,18 @@ public: size_t k = p->k; - PriorityQueue<Wrapped<R>, KNNDistCmp<Wrapped<R>, wrec>> pq; + PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> pq(k, &wrec); for (size_t i=0; i<buffer->get_record_count(); i++) { // Skip over deleted records (under tagging) - if (buffer->get_data()[i]->is_deleted()) { + if ((buffer->get_data())[i].is_deleted()) { continue; } if (pq.size() < k) { pq.push(buffer->get_data() + i); } else { - double head_dist = pq.peek().data.rec->calc_distance(wrec.rec); - double cur_dist = (buffer->get_data() + i)->calc_distance(wrec.rec); + double head_dist = pq.peek().data->rec.calc_distance(wrec.rec); + double cur_dist = (buffer->get_data() + i)->rec.calc_distance(wrec.rec); if (cur_dist < head_dist) { pq.pop(); @@ -398,9 +432,11 @@ public: std::vector<Wrapped<R>> results; while (pq.size() > 0) { - results.emplace(*pq.peek().data); + results.emplace_back(*(pq.peek().data)); pq.pop(); } + + return results; } static std::vector<R> merge(std::vector<std::vector<R>> &results, void *parms) { @@ -408,7 +444,7 @@ public: R rec = p->point; size_t k = p->k; - PriorityQueue<R, KNNDistCmp<R, rec>> pq; + PriorityQueue<R, KNNDistCmpMax<R>> pq(k, &rec); for (size_t i=0; i<results.size(); i++) { for (size_t j=0; j<results.size(); j++) { if (pq.size() < k) { |