summaryrefslogtreecommitdiffstats
path: root/include/shard/VPTree.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/shard/VPTree.h')
-rw-r--r--include/shard/VPTree.h51
1 files changed, 31 insertions, 20 deletions
diff --git a/include/shard/VPTree.h b/include/shard/VPTree.h
index b342fe6..d5a2393 100644
--- a/include/shard/VPTree.h
+++ b/include/shard/VPTree.h
@@ -58,7 +58,7 @@ public:
sizeof(Wrapped<R>),
(byte**) &m_data);
- m_ptrs = new Wrapped<R>*[buffer.get_record_count()];
+ m_ptrs = new vp_ptr[buffer.get_record_count()];
size_t offset = 0;
m_reccnt = 0;
@@ -76,7 +76,7 @@ public:
rec->header &= 3;
m_data[m_reccnt] = *rec;
- m_ptrs[m_reccnt] = &m_data[m_reccnt];
+ m_ptrs[m_reccnt].ptr = &m_data[m_reccnt];
m_reccnt++;
}
@@ -97,7 +97,7 @@ public:
m_alloc_size = psudb::sf_aligned_alloc(CACHELINE_SIZE,
attemp_reccnt * sizeof(Wrapped<R>),
(byte **) &m_data);
- m_ptrs = new Wrapped<R>*[attemp_reccnt];
+ m_ptrs = new vp_ptr[attemp_reccnt];
// FIXME: will eventually need to figure out tombstones
// this one will likely require the multi-pass
@@ -110,7 +110,7 @@ public:
}
m_data[m_reccnt] = *shards[i]->get_record_at(j);
- m_ptrs[m_reccnt] = &m_data[m_reccnt];
+ m_ptrs[m_reccnt].ptr = &m_data[m_reccnt];
m_reccnt++;
}
}
@@ -139,8 +139,8 @@ public:
} else {
vpnode *node = m_root;
- while (!node->leaf && m_ptrs[node->start]->rec != rec) {
- if (rec.calc_distance((m_ptrs[node->start]->rec)) >= node->radius) {
+ while (!node->leaf && m_ptrs[node->start].ptr->rec != rec) {
+ if (rec.calc_distance((m_ptrs[node->start].ptr->rec)) >= node->radius) {
node = node->outside;
} else {
node = node->inside;
@@ -148,8 +148,8 @@ public:
}
for (size_t i=node->start; i<=node->stop; i++) {
- if (m_ptrs[i]->rec == rec) {
- return m_ptrs[i];
+ if (m_ptrs[i].ptr->rec == rec) {
+ return m_ptrs[i].ptr;
}
}
@@ -175,7 +175,7 @@ public:
}
size_t get_memory_usage() {
- return m_node_cnt * sizeof(vpnode) + m_reccnt * sizeof(R*) + m_alloc_size;
+ return m_node_cnt * sizeof(vpnode) + m_reccnt * sizeof(R*);
}
size_t get_aux_memory_usage() {
@@ -191,8 +191,12 @@ public:
}
private:
+ struct vp_ptr {
+ Wrapped<R> *ptr;
+ double dist;
+ };
Wrapped<R>* m_data;
- Wrapped<R>** m_ptrs;
+ vp_ptr* m_ptrs;
std::unordered_map<R, size_t, RecordHash<R>> m_lookup_map;
size_t m_reccnt;
size_t m_tombstone_cnt;
@@ -260,6 +264,11 @@ private:
auto i = start + gsl_rng_uniform_int(rng, stop - start + 1);
swap(start, i);
+ /* for efficiency, we'll pre-calculate the distances between each point and the root */
+ for (size_t i=start+1; i<=stop; i++) {
+ m_ptrs[i].dist = m_ptrs[start].ptr->rec.calc_distance(m_ptrs[i].ptr->rec);
+ }
+
/*
* partition elements based on their distance from the start,
* with those elements with distance falling below the median
@@ -267,14 +276,15 @@ private:
* the median in the right. This is easily done using QuickSelect.
*/
auto mid = (start + 1 + stop) / 2;
- quickselect(start + 1, stop, mid, m_ptrs[start], rng);
+ quickselect(start + 1, stop, mid, m_ptrs[start].ptr, rng);
/* Create a new node based on this partitioning */
vpnode *node = new vpnode();
node->start = start;
/* store the radius of the circle used for partitioning the node. */
- node->radius = m_ptrs[start]->rec.calc_distance(m_ptrs[mid]->rec);
+ node->radius = m_ptrs[start].ptr->rec.calc_distance(m_ptrs[mid].ptr->rec);
+ m_ptrs[start].dist = node->radius;
/* recursively construct the left and right subtrees */
node->inside = build_subtree(start + 1, mid-1, rng);
@@ -285,8 +295,6 @@ private:
return node;
}
- // TODO: The quickselect code can probably be generalized and moved out
- // to psudb-common instead.
void quickselect(size_t start, size_t stop, size_t k, Wrapped<R> *p, gsl_rng *rng) {
if (start == stop) return;
@@ -303,13 +311,16 @@ private:
// to psudb-common instead.
size_t partition(size_t start, size_t stop, Wrapped<R> *p, gsl_rng *rng) {
auto pivot = start + gsl_rng_uniform_int(rng, stop - start);
- double pivot_dist = p->rec.calc_distance(m_ptrs[pivot]->rec);
+ //double pivot_dist = p->rec.calc_distance(m_ptrs[pivot]->rec);
swap(pivot, stop);
size_t j = start;
for (size_t i=start; i<stop; i++) {
- if (p->rec.calc_distance(m_ptrs[i]->rec) < pivot_dist) {
+ if (m_ptrs[i].dist < m_ptrs[stop].dist) {
+ //assert(distances[i - start] == p->rec.calc_distance(m_ptrs[i]->rec));
+ //if (distances[i - start] < distances[stop - start]) {
+ //if (p->rec .calc_distance(m_ptrs[i]->rec) < pivot_dist) {
swap(j, i);
j++;
}
@@ -332,13 +343,13 @@ private:
if (node->leaf) {
for (size_t i=node->start; i<=node->stop; i++) {
- double d = point.calc_distance(m_ptrs[i]->rec);
+ double d = point.calc_distance(m_ptrs[i].ptr->rec);
if (d < *farthest) {
if (pq.size() == k) {
pq.pop();
}
- pq.push(m_ptrs[i]);
+ pq.push(m_ptrs[i].ptr);
if (pq.size() == k) {
*farthest = point.calc_distance(pq.peek().data->rec);
}
@@ -348,14 +359,14 @@ private:
return;
}
- double d = point.calc_distance(m_ptrs[node->start]->rec);
+ double d = point.calc_distance(m_ptrs[node->start].ptr->rec);
if (d < *farthest) {
if (pq.size() == k) {
auto t = pq.peek().data->rec;
pq.pop();
}
- pq.push(m_ptrs[node->start]);
+ pq.push(m_ptrs[node->start].ptr);
if (pq.size() == k) {
*farthest = point.calc_distance(pq.peek().data->rec);
}