summary | refs | log | tree | commit | diff | stats
path: root/include/shard
diff options
context:
space:
mode:
author: Douglas Rumbaugh <dbr4@psu.edu> 2023-07-24 16:49:21 -0400
committer: Douglas Rumbaugh <dbr4@psu.edu> 2023-07-24 16:49:21 -0400
commit: d02fe67962c8002ddc6e0d6569128ae2645ea7fc (patch)
tree: b0b27a29c58c65d51984318433f58698f297700e /include/shard
parent: ac018f5f96c32c96158a239fbfeb9dc439c95548 (diff)
download: dynamic-extension-d02fe67962c8002ddc6e0d6569128ae2645ea7fc.tar.gz
VPTree: fixed knn query
Diffstat (limited to 'include/shard')
-rw-r--r-- include/shard/VPTree.h 136
1 files changed, 86 insertions, 50 deletions
diff --git a/include/shard/VPTree.h b/include/shard/VPTree.h
index 05ffd50..927108c 100644
--- a/include/shard/VPTree.h
+++ b/include/shard/VPTree.h
@@ -46,26 +46,56 @@ struct KNNState {
template <NDRecordInterface R>
struct KNNBufferState {
- size_t cutoff;
- size_t sample_size;
- Alias* alias;
- decltype(R::weight) max_weight;
- decltype(R::weight) total_weight;
-
- ~KNNBufferState() {
- delete alias;
+
+};
+
+
+template <typename R>
+class KNNDistCmpMax {
+public:
+ KNNDistCmpMax(R *baseline) : P(baseline) {}
+
+ inline bool operator()(const R *a, const R *b) requires WrappedInterface<R> {
+ return a->rec.calc_distance(P->rec) > b->rec.calc_distance(P->rec);
+ }
+
+ inline bool operator()(const R *a, const R *b) requires (!WrappedInterface<R>){
+ return a->calc_distance(*P) > b->calc_distance(*P);
+ }
+
+private:
+ R *P;
+};
+
+template <typename R>
+class KNNDistCmpMin {
+public:
+ KNNDistCmpMin(R *baseline) : P(baseline) {}
+
+ inline bool operator()(const R *a, const R *b) requires WrappedInterface<R> {
+ return a->rec.calc_distance(P->rec) < b->rec.calc_distance(P->rec);
+ }
+
+ inline bool operator()(const R *a, const R *b) requires (!WrappedInterface<R>){
+ return a->calc_distance(*P) < b->calc_distance(*P);
}
+private:
+ R *P;
};
+
+
template <NDRecordInterface R>
class VPTree {
private:
struct vpnode {
- size_t idx = 0;
- double radius = 0;
- vpnode *inside = nullptr;
- vpnode *outside = nullptr;
+ size_t idx;
+ double radius;
+ vpnode *inside;
+ vpnode *outside;
+
+ vpnode(size_t idx) : idx(idx), radius(0), inside(nullptr), outside(nullptr) {}
~vpnode() {
delete inside;
@@ -180,15 +210,17 @@ public:
private:
vpnode *build_vptree() {
- assert(m_reccnt > 0);
+ if (m_reccnt == 0) {
+ return nullptr;
+ }
size_t lower = 0;
size_t upper = m_reccnt - 1;
auto rng = gsl_rng_alloc(gsl_rng_mt19937);
- auto n = build_subtree(lower, upper, rng);
+ auto root = build_subtree(lower, upper, rng);
gsl_rng_free(rng);
- return n;
+ return root;
}
void build_map() {
@@ -204,32 +236,42 @@ private:
}
vpnode *build_subtree(size_t start, size_t stop, gsl_rng *rng) {
- if (start >= stop) {
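+ // base-case: empty range. This occurs when mid == start + 1 (i.e. the
+ // sub-array holds two or three elements), because the "inside" recursive
+ // call below is then given the range [start + 1, mid - 1] = [start + 1, start].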
+ if (start > stop) {
return nullptr;
}
- // select a random element to partition based on, and swap
- // it to the front of the sub-array
- auto i = start + gsl_rng_uniform_int(rng, stop - start);
+ // base-case: create a leaf node
+ if (start == stop) {
+ vpnode *node = new vpnode(start);
+ node->idx = start;
+
+ m_node_cnt++;
+
+ return node;
+ }
+
+ // select a random element to be the root of the
+ // subtree
+ auto i = start + gsl_rng_uniform_int(rng, stop - start + 1);
swap(start, i);
// partition the remaining elements by their distance from the element
// at start: elements whose distance falls below the median distance go
// into the left (inside) sub-array, and those above the median go into
// the right (outside) sub-array. This is easily done using QuickSelect.
- auto mid = ((start+1) + stop) / 2;
+ auto mid = (start + 1 + stop) / 2;
quickselect(start + 1, stop, mid, m_data[start], rng);
// Create a new node based on this partitioning
- vpnode *node = new vpnode();
+ vpnode *node = new vpnode(start);
- // store the radius of the circle used for partitioning the
- // node.
- node->idx = start;
+ // store the radius of the circle used for partitioning the node.
node->radius = m_data[start].rec.calc_distance(m_data[mid].rec);
// recursively construct the left and right subtrees
- node->inside = build_subtree(start + 1, mid - 1, rng);
+ node->inside = build_subtree(start + 1, mid-1, rng);
node->outside = build_subtree(mid, stop, rng);
m_node_cnt++;
@@ -277,35 +319,38 @@ private:
}
- void search(vpnode *node, const R &point, size_t k, PriorityQueue<R> &pq, double *farthest) {
+ void search(vpnode *node, const R &point, size_t k, PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> &pq, double *farthest) {
if (node == nullptr) return;
double d = point.calc_distance(m_data[node->idx].rec);
if (d < *farthest) {
- if (pq.size() == k) pq.pop();
+ if (pq.size() == k) {
+ auto t = pq.peek().data->rec;
+ pq.pop();
+ }
pq.push(&m_data[node->idx]);
if (pq.size() == k) {
- *farthest = point.calc_distance(*pq.peek().data.rec);
+ *farthest = point.calc_distance(pq.peek().data->rec);
}
}
if (!node->inside && !node->outside) return;
if (d < node->radius) {
- if (d - *farthest <= node->radius) {
+ if (d - (*farthest) <= node->radius) {
search(node->inside, point, k, pq, farthest);
}
- if (d + *farthest >= node->radius) {
+ if (d + (*farthest) >= node->radius) {
search(node->outside, point, k, pq, farthest);
}
} else {
- if (d + *farthest >= node->radius) {
+ if (d + (*farthest) >= node->radius) {
search(node->outside, point, k, pq, farthest);
}
- if (d - *farthest <= node->radius) {
+ if (d - (*farthest) <= node->radius) {
search(node->inside, point, k, pq, farthest);
}
}
@@ -320,17 +365,6 @@ private:
vpnode *m_root;
};
-template <NDRecordInterface R, R P>
-class KNNDistCmp {
-public:
- inline bool operator()(queue_record<R> *a, queue_record<R> *b) requires WrappedInterface<R> {
- return a->rec.calc_distance(P.rec) > b->rec.calc_distance(P.rec);
- }
-
- inline bool operator()(queue_record<R> *a, queue_record<R> *b) requires (!WrappedInterface<R>){
- return a->data.calc_distance(P) > b->data.calc_distance(P);
- }
-};
template <NDRecordInterface R>
class KNNQuery {
@@ -354,7 +388,7 @@ public:
wrec.rec = p->point;
wrec.header = 0;
- PriorityQueue<Wrapped<R>, KNNDistCmp<Wrapped<R>, wrec>> pq;
+ PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> pq(p->k, &wrec);
double farthest = std::numeric_limits<double>::max();
@@ -376,18 +410,18 @@ public:
size_t k = p->k;
- PriorityQueue<Wrapped<R>, KNNDistCmp<Wrapped<R>, wrec>> pq;
+ PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> pq(k, &wrec);
for (size_t i=0; i<buffer->get_record_count(); i++) {
// Skip over deleted records (under tagging)
- if (buffer->get_data()[i]->is_deleted()) {
+ if ((buffer->get_data())[i].is_deleted()) {
continue;
}
if (pq.size() < k) {
pq.push(buffer->get_data() + i);
} else {
- double head_dist = pq.peek().data.rec->calc_distance(wrec.rec);
- double cur_dist = (buffer->get_data() + i)->calc_distance(wrec.rec);
+ double head_dist = pq.peek().data->rec.calc_distance(wrec.rec);
+ double cur_dist = (buffer->get_data() + i)->rec.calc_distance(wrec.rec);
if (cur_dist < head_dist) {
pq.pop();
@@ -398,9 +432,11 @@ public:
std::vector<Wrapped<R>> results;
while (pq.size() > 0) {
- results.emplace(*pq.peek().data);
+ results.emplace_back(*(pq.peek().data));
pq.pop();
}
+
+ return results;
}
static std::vector<R> merge(std::vector<std::vector<R>> &results, void *parms) {
@@ -408,7 +444,7 @@ public:
R rec = p->point;
size_t k = p->k;
- PriorityQueue<R, KNNDistCmp<R, rec>> pq;
+ PriorityQueue<R, KNNDistCmpMax<R>> pq(k, &rec);
for (size_t i=0; i<results.size(); i++) {
for (size_t j=0; j<results.size(); j++) {
if (pq.size() < k) {