summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--include/ds/PriorityQueue.h6
-rw-r--r--include/framework/RecordInterface.h10
-rw-r--r--include/shard/VPTree.h136
-rw-r--r--tests/vptree_tests.cpp74
4 files changed, 164 insertions, 62 deletions
diff --git a/include/ds/PriorityQueue.h b/include/ds/PriorityQueue.h
index a8e9ba5..4612eef 100644
--- a/include/ds/PriorityQueue.h
+++ b/include/ds/PriorityQueue.h
@@ -23,6 +23,7 @@ struct queue_record {
template <typename R>
class standard_minheap {
public:
+ standard_minheap(R *baseline) {}
inline bool operator()(const R* a, const R* b) {
return *a < *b;
}
@@ -31,6 +32,7 @@ public:
template <typename R>
class standard_maxheap {
public:
+ standard_maxheap(R *baseline) {}
inline bool operator()(const R* a, const R* b) {
return *a > *b;
}
@@ -39,7 +41,8 @@ public:
template <typename R, typename CMP=standard_minheap<R>>
class PriorityQueue {
public:
- PriorityQueue(size_t size) : data(size), tail(0) {}
+ PriorityQueue(size_t size, R* cmp_baseline=nullptr) : data(size), tail(0), cmp(cmp_baseline) {}
+
~PriorityQueue() = default;
size_t size() const {
@@ -97,6 +100,7 @@ private:
std::vector<queue_record<R>> data;
CMP cmp;
size_t tail;
+ R *baseline;
/*
* Swap the elements at position a and position
diff --git a/include/framework/RecordInterface.h b/include/framework/RecordInterface.h
index 85a0794..cea9fbe 100644
--- a/include/framework/RecordInterface.h
+++ b/include/framework/RecordInterface.h
@@ -49,6 +49,7 @@ concept WrappedInterface = RecordInterface<R> && requires(R r, R s, bool b) {
{r.set_tombstone(b)};
{r.is_tombstone()} -> std::convertible_to<bool>;
{r < s} -> std::convertible_to<bool>;
+ {r == s} ->std::convertible_to<bool>;
};
template<RecordInterface R>
@@ -79,6 +80,11 @@ struct Wrapped {
inline bool operator<(const Wrapped& other) const {
return rec < other.rec || (rec == other.rec && header < other.header);
}
+
+ inline bool operator==(const Wrapped& other) const {
+ return rec == other.rec;
+ }
+
};
template <typename K, typename V>
@@ -185,10 +191,10 @@ struct EuclidPoint{
inline double calc_distance(const EuclidPoint& other) const {
double dist = 0;
for (size_t i=0; i<D; i++) {
- dist += pow(data[i] - other.data[i], 2);
+ dist += (data[i] - other.data[i]) * (data[i] - other.data[i]);
}
- return sqrt(dist);
+ return std::sqrt(dist);
}
};
diff --git a/include/shard/VPTree.h b/include/shard/VPTree.h
index 05ffd50..927108c 100644
--- a/include/shard/VPTree.h
+++ b/include/shard/VPTree.h
@@ -46,26 +46,56 @@ struct KNNState {
template <NDRecordInterface R>
struct KNNBufferState {
- size_t cutoff;
- size_t sample_size;
- Alias* alias;
- decltype(R::weight) max_weight;
- decltype(R::weight) total_weight;
-
- ~KNNBufferState() {
- delete alias;
+
+};
+
+
+template <typename R>
+class KNNDistCmpMax {
+public:
+ KNNDistCmpMax(R *baseline) : P(baseline) {}
+
+ inline bool operator()(const R *a, const R *b) requires WrappedInterface<R> {
+ return a->rec.calc_distance(P->rec) > b->rec.calc_distance(P->rec);
+ }
+
+ inline bool operator()(const R *a, const R *b) requires (!WrappedInterface<R>){
+ return a->calc_distance(*P) > b->calc_distance(*P);
+ }
+
+private:
+ R *P;
+};
+
+template <typename R>
+class KNNDistCmpMin {
+public:
+ KNNDistCmpMin(R *baseline) : P(baseline) {}
+
+ inline bool operator()(const R *a, const R *b) requires WrappedInterface<R> {
+ return a->rec.calc_distance(P->rec) < b->rec.calc_distance(P->rec);
+ }
+
+ inline bool operator()(const R *a, const R *b) requires (!WrappedInterface<R>){
+ return a->calc_distance(*P) < b->calc_distance(*P);
}
+private:
+ R *P;
};
+
+
template <NDRecordInterface R>
class VPTree {
private:
struct vpnode {
- size_t idx = 0;
- double radius = 0;
- vpnode *inside = nullptr;
- vpnode *outside = nullptr;
+ size_t idx;
+ double radius;
+ vpnode *inside;
+ vpnode *outside;
+
+ vpnode(size_t idx) : idx(idx), radius(0), inside(nullptr), outside(nullptr) {}
~vpnode() {
delete inside;
@@ -180,15 +210,17 @@ public:
private:
vpnode *build_vptree() {
- assert(m_reccnt > 0);
+ if (m_reccnt == 0) {
+ return nullptr;
+ }
size_t lower = 0;
size_t upper = m_reccnt - 1;
auto rng = gsl_rng_alloc(gsl_rng_mt19937);
- auto n = build_subtree(lower, upper, rng);
+ auto root = build_subtree(lower, upper, rng);
gsl_rng_free(rng);
- return n;
+ return root;
}
void build_map() {
@@ -204,32 +236,42 @@ private:
}
vpnode *build_subtree(size_t start, size_t stop, gsl_rng *rng) {
- if (start >= stop) {
+ // base-case: sometimes happens (probably because of the +1 and -1
+ // in the first recursive call)
+ if (start > stop) {
return nullptr;
}
- // select a random element to partition based on, and swap
- // it to the front of the sub-array
- auto i = start + gsl_rng_uniform_int(rng, stop - start);
+ // base-case: create a leaf node
+ if (start == stop) {
+ vpnode *node = new vpnode(start);
+ node->idx = start;
+
+ m_node_cnt++;
+
+ return node;
+ }
+
+ // select a random element to be the root of the
+ // subtree
+ auto i = start + gsl_rng_uniform_int(rng, stop - start + 1);
swap(start, i);
// partition elements based on their distance from the start,
// with those elements with distance falling below the median
// distance going into the left sub-array and those above
// the median in the right. This is easily done using QuickSelect.
- auto mid = ((start+1) + stop) / 2;
+ auto mid = (start + 1 + stop) / 2;
quickselect(start + 1, stop, mid, m_data[start], rng);
// Create a new node based on this partitioning
- vpnode *node = new vpnode();
+ vpnode *node = new vpnode(start);
- // store the radius of the circle used for partitioning the
- // node.
- node->idx = start;
+ // store the radius of the circle used for partitioning the node.
node->radius = m_data[start].rec.calc_distance(m_data[mid].rec);
// recursively construct the left and right subtrees
- node->inside = build_subtree(start + 1, mid - 1, rng);
+ node->inside = build_subtree(start + 1, mid-1, rng);
node->outside = build_subtree(mid, stop, rng);
m_node_cnt++;
@@ -277,35 +319,38 @@ private:
}
- void search(vpnode *node, const R &point, size_t k, PriorityQueue<R> &pq, double *farthest) {
+ void search(vpnode *node, const R &point, size_t k, PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> &pq, double *farthest) {
if (node == nullptr) return;
double d = point.calc_distance(m_data[node->idx].rec);
if (d < *farthest) {
- if (pq.size() == k) pq.pop();
+ if (pq.size() == k) {
+ auto t = pq.peek().data->rec;
+ pq.pop();
+ }
pq.push(&m_data[node->idx]);
if (pq.size() == k) {
- *farthest = point.calc_distance(*pq.peek().data.rec);
+ *farthest = point.calc_distance(pq.peek().data->rec);
}
}
if (!node->inside && !node->outside) return;
if (d < node->radius) {
- if (d - *farthest <= node->radius) {
+ if (d - (*farthest) <= node->radius) {
search(node->inside, point, k, pq, farthest);
}
- if (d + *farthest >= node->radius) {
+ if (d + (*farthest) >= node->radius) {
search(node->outside, point, k, pq, farthest);
}
} else {
- if (d + *farthest >= node->radius) {
+ if (d + (*farthest) >= node->radius) {
search(node->outside, point, k, pq, farthest);
}
- if (d - *farthest <= node->radius) {
+ if (d - (*farthest) <= node->radius) {
search(node->inside, point, k, pq, farthest);
}
}
@@ -320,17 +365,6 @@ private:
vpnode *m_root;
};
-template <NDRecordInterface R, R P>
-class KNNDistCmp {
-public:
- inline bool operator()(queue_record<R> *a, queue_record<R> *b) requires WrappedInterface<R> {
- return a->rec.calc_distance(P.rec) > b->rec.calc_distance(P.rec);
- }
-
- inline bool operator()(queue_record<R> *a, queue_record<R> *b) requires (!WrappedInterface<R>){
- return a->data.calc_distance(P) > b->data.calc_distance(P);
- }
-};
template <NDRecordInterface R>
class KNNQuery {
@@ -354,7 +388,7 @@ public:
wrec.rec = p->point;
wrec.header = 0;
- PriorityQueue<Wrapped<R>, KNNDistCmp<Wrapped<R>, wrec>> pq;
+ PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> pq(p->k, &wrec);
double farthest = std::numeric_limits<double>::max();
@@ -376,18 +410,18 @@ public:
size_t k = p->k;
- PriorityQueue<Wrapped<R>, KNNDistCmp<Wrapped<R>, wrec>> pq;
+ PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> pq(k, &wrec);
for (size_t i=0; i<buffer->get_record_count(); i++) {
// Skip over deleted records (under tagging)
- if (buffer->get_data()[i]->is_deleted()) {
+ if ((buffer->get_data())[i].is_deleted()) {
continue;
}
if (pq.size() < k) {
pq.push(buffer->get_data() + i);
} else {
- double head_dist = pq.peek().data.rec->calc_distance(wrec.rec);
- double cur_dist = (buffer->get_data() + i)->calc_distance(wrec.rec);
+ double head_dist = pq.peek().data->rec.calc_distance(wrec.rec);
+ double cur_dist = (buffer->get_data() + i)->rec.calc_distance(wrec.rec);
if (cur_dist < head_dist) {
pq.pop();
@@ -398,9 +432,11 @@ public:
std::vector<Wrapped<R>> results;
while (pq.size() > 0) {
- results.emplace(*pq.peek().data);
+ results.emplace_back(*(pq.peek().data));
pq.pop();
}
+
+ return results;
}
static std::vector<R> merge(std::vector<std::vector<R>> &results, void *parms) {
@@ -408,7 +444,7 @@ public:
R rec = p->point;
size_t k = p->k;
- PriorityQueue<R, KNNDistCmp<R, rec>> pq;
+ PriorityQueue<R, KNNDistCmpMax<R>> pq(k, &rec);
for (size_t i=0; i<results.size(); i++) {
for (size_t j=0; j<results.size(); j++) {
if (pq.size() < k) {
diff --git a/tests/vptree_tests.cpp b/tests/vptree_tests.cpp
index b86e1e9..894e64a 100644
--- a/tests/vptree_tests.cpp
+++ b/tests/vptree_tests.cpp
@@ -11,6 +11,7 @@
#include "shard/VPTree.h"
#include "testing.h"
+#include "vptree.hpp"
#include <check.h>
@@ -108,6 +109,65 @@ START_TEST(t_point_lookup_miss)
}
+START_TEST(t_buffer_query)
+{
+ size_t n = 10000;
+ auto buffer = create_2d_sequential_mbuffer(n);
+
+ PRec target;
+ target.data[0] = 120;
+ target.data[1] = 120;
+
+ KNNQueryParms<PRec> p;
+ p.k = 10;
+ p.point = target;
+
+ auto state = KNNQuery<PRec>::get_buffer_query_state(buffer, &p);
+ auto result = KNNQuery<PRec>::buffer_query(buffer, state, &p);
+ KNNQuery<PRec>::delete_buffer_query_state(state);
+
+ std::sort(result.begin(), result.end());
+ size_t start = 120 - 5;
+ for (size_t i=0; i<result.size(); i++) {
+ ck_assert_int_eq(result[i].rec.data[0], start++);
+ }
+
+ delete buffer;
+}
+
+START_TEST(t_knn_query)
+{
+ size_t n = 100;
+ auto buffer = create_2d_sequential_mbuffer(n);
+
+ PRec target;
+ target.data[0] = 50;
+ target.data[1] = 50;
+
+ KNNQueryParms<PRec> p;
+ p.k = 10;
+ p.point = target;
+
+ auto state = KNNQuery<PRec>::get_buffer_query_state(buffer, &p);
+ auto result = KNNQuery<PRec>::buffer_query(buffer, state, &p);
+
+ KNNQuery<PRec>::delete_buffer_query_state(state);
+
+ auto vptree = VPTree<PRec>(buffer);
+ auto state_2 = KNNQuery<PRec>::get_query_state(&vptree, &p);
+ auto result_2 = KNNQuery<PRec>::query(&vptree, state_2, &p);
+ KNNQuery<PRec>::delete_query_state(state_2);
+
+ std::sort(result_2.begin(), result_2.end());
+ size_t start = 46;
+ for (size_t i=0; i<result_2.size(); i++) {
+ ck_assert_int_eq(result_2[i].rec.data[0], start++);
+ }
+
+ delete buffer;
+}
+
+
Suite *unit_testing()
{
Suite *unit = suite_create("VPTree Shard Unit Testing");
@@ -121,18 +181,14 @@ Suite *unit_testing()
TCase *lookup = tcase_create("de:VPTree:point_lookup Testing");
tcase_add_test(lookup, t_point_lookup);
- //tcase_add_test(lookup, t_point_lookup_miss);
+ tcase_add_test(lookup, t_point_lookup_miss);
suite_add_tcase(unit, lookup);
- /*
- TCase *sampling = tcase_create("de:VPTree::VPTreeQuery Testing");
- tcase_add_test(sampling, t_wss_query);
- tcase_add_test(sampling, t_wss_query_merge);
- tcase_add_test(sampling, t_wss_buffer_query_rejection);
- tcase_add_test(sampling, t_wss_buffer_query_scan);
- suite_add_tcase(unit, sampling);
- */
+ TCase *query = tcase_create("de:VPTree::VPTreeQuery Testing");
+ tcase_add_test(query, t_buffer_query);
+ tcase_add_test(query, t_knn_query);
+ suite_add_tcase(unit, query);
return unit;
}