summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDouglas Rumbaugh <dbr4@psu.edu>2024-02-08 16:38:44 -0500
committerDouglas Rumbaugh <dbr4@psu.edu>2024-02-08 16:39:08 -0500
commit711769574e647839677739192698e400529efe75 (patch)
tree6262e9aa99123cfdc6c9930020662a4dc9c136eb
parent923e27797f6fd3a2b04f1a7a8d990a49374f4c61 (diff)
downloaddynamic-extension-711769574e647839677739192698e400529efe75.tar.gz
Updated VPTree to new shard/query interfaces
-rw-r--r--CMakeLists.txt7
-rw-r--r--include/framework/interface/Record.h19
-rw-r--r--include/framework/structure/BufferView.h3
-rw-r--r--include/query/knn.h159
-rw-r--r--include/shard/VPTree.h282
-rw-r--r--tests/include/shard_standard.h16
-rw-r--r--tests/include/testing.h59
-rw-r--r--tests/vptree_tests.cpp94
8 files changed, 319 insertions, 320 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index be0fb15..81fdb63 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -60,9 +60,10 @@ if (tests)
target_include_directories(rangecount_tests PRIVATE include external/psudb-common/cpp/include)
- #add_executable(vptree_tests ${CMAKE_CURRENT_SOURCE_DIR}/tests/vptree_tests.cpp)
- #target_link_libraries(vptree_tests PUBLIC gsl check subunit pthread)
- #target_include_directories(vptree_tests PRIVATE include external/vptree external/psudb-common/cpp/include)
+ add_executable(vptree_tests ${CMAKE_CURRENT_SOURCE_DIR}/tests/vptree_tests.cpp)
+ target_link_libraries(vptree_tests PUBLIC gsl check subunit pthread atomic)
+ target_link_options(vptree_tests PUBLIC -mcx16)
+ target_include_directories(vptree_tests PRIVATE include external/vptree external/psudb-common/cpp/include)
add_executable(de_tier_tag ${CMAKE_CURRENT_SOURCE_DIR}/tests/de_tier_tag.cpp)
target_link_libraries(de_tier_tag PUBLIC gsl check subunit pthread atomic)
diff --git a/include/framework/interface/Record.h b/include/framework/interface/Record.h
index 457078d..29df4b6 100644
--- a/include/framework/interface/Record.h
+++ b/include/framework/interface/Record.h
@@ -212,4 +212,23 @@ struct RecordHash {
}
};
+template <typename R>
+class DistCmpMax {
+public:
+ DistCmpMax(R *baseline) : P(baseline) {}
+
+ inline bool operator()(const R *a, const R *b) requires WrappedInterface<R> {
+ return a->rec.calc_distance(P->rec) > b->rec.calc_distance(P->rec);
+ }
+
+ inline bool operator()(const R *a, const R *b) requires (!WrappedInterface<R>){
+ return a->calc_distance(*P) > b->calc_distance(*P);
+ }
+
+private:
+ R *P;
+};
+
+
+
}
diff --git a/include/framework/structure/BufferView.h b/include/framework/structure/BufferView.h
index edf6707..4e3de25 100644
--- a/include/framework/structure/BufferView.h
+++ b/include/framework/structure/BufferView.h
@@ -123,7 +123,6 @@ public:
Wrapped<R> *get(size_t i) {
assert(i < get_record_count());
- m_total += (m_data + to_idx(i))->rec.key;
return m_data + to_idx(i);
}
@@ -159,8 +158,6 @@ private:
psudb::BloomFilter<R> *m_tombstone_filter;
bool m_active;
- size_t m_total;
-
size_t to_idx(size_t i) {
size_t idx = (m_start + i >= m_cap) ? i = (m_cap - m_start)
: m_start + i;
diff --git a/include/query/knn.h b/include/query/knn.h
new file mode 100644
index 0000000..19dcf5c
--- /dev/null
+++ b/include/query/knn.h
@@ -0,0 +1,159 @@
+/*
+ * include/query/knn.h
+ *
+ * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu>
+ *
+ * Distributed under the Modified BSD License.
+ *
+ * A query class for k-NN queries, designed for use with the VPTree
+ * shard.
+ *
+ * FIXME: no support for tombstone deletes just yet. This would require a
+ * query resumption mechanism, most likely.
+ */
+#pragma once
+
+#include "framework/QueryRequirements.h"
+#include "psu-ds/PriorityQueue.h"
+
+namespace de { namespace knn {
+
+using psudb::PriorityQueue;
+
+template <NDRecordInterface R>
+struct Parms {
+ R point;
+ size_t k;
+};
+
+template <NDRecordInterface R>
+struct State {
+ size_t k;
+};
+
+template <NDRecordInterface R>
+struct BufferState {
+ BufferView<R> *buffer;
+
+ BufferState(BufferView<R> *buffer)
+ : buffer(buffer) {}
+};
+
+template <NDRecordInterface R, ShardInterface<R> S>
+class Query {
+public:
+ constexpr static bool EARLY_ABORT=false;
+ constexpr static bool SKIP_DELETE_FILTER=true;
+
+ static void *get_query_state(S *shard, void *parms) {
+ return nullptr;
+ }
+
+ static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
+ return new BufferState<R>(buffer);
+ }
+
+ static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void* buffer_state) {
+ return;
+ }
+
+ static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
+ std::vector<Wrapped<R>> results;
+ Parms<R> *p = (Parms<R> *) parms;
+ Wrapped<R> wrec;
+ wrec.rec = p->point;
+ wrec.header = 0;
+
+ PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(p->k, &wrec);
+
+ shard->search(p->point, p->k, pq);
+
+ while (pq.size() > 0) {
+ results.emplace_back(*pq.peek().data);
+ pq.pop();
+ }
+
+ return results;
+ }
+
+ static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
+ Parms<R> *p = (Parms<R> *) parms;
+ BufferState<R> *s = (BufferState<R> *) state;
+ Wrapped<R> wrec;
+ wrec.rec = p->point;
+ wrec.header = 0;
+
+ size_t k = p->k;
+
+ PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(k, &wrec);
+ for (size_t i=0; i<s->buffer->get_record_count(); i++) {
+ // Skip over deleted records (under tagging)
+ if (s->buffer->get(i)->is_deleted()) {
+ continue;
+ }
+
+ if (pq.size() < k) {
+ pq.push(s->buffer->get(i));
+ } else {
+ double head_dist = pq.peek().data->rec.calc_distance(wrec.rec);
+ double cur_dist = (s->buffer->get(i))->rec.calc_distance(wrec.rec);
+
+ if (cur_dist < head_dist) {
+ pq.pop();
+ pq.push(s->buffer->get(i));
+ }
+ }
+ }
+
+ std::vector<Wrapped<R>> results;
+ while (pq.size() > 0) {
+ results.emplace_back(*(pq.peek().data));
+ pq.pop();
+ }
+
+ return results;
+ }
+
+ static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) {
+ Parms<R> *p = (Parms<R> *) parms;
+ R rec = p->point;
+ size_t k = p->k;
+
+ PriorityQueue<R, DistCmpMax<R>> pq(k, &rec);
+ for (size_t i=0; i<results.size(); i++) {
+ for (size_t j=0; j<results[i].size(); j++) {
+ if (pq.size() < k) {
+ pq.push(&results[i][j].rec);
+ } else {
+ double head_dist = pq.peek().data->calc_distance(rec);
+ double cur_dist = results[i][j].rec.calc_distance(rec);
+
+ if (cur_dist < head_dist) {
+ pq.pop();
+ pq.push(&results[i][j].rec);
+ }
+ }
+ }
+ }
+
+ std::vector<R> output;
+ while (pq.size() > 0) {
+ output.emplace_back(*pq.peek().data);
+ pq.pop();
+ }
+
+ return output;
+ }
+
+ static void delete_query_state(void *state) {
+ auto s = (State<R> *) state;
+ delete s;
+ }
+
+ static void delete_buffer_query_state(void *state) {
+ auto s = (BufferState<R> *) state;
+ delete s;
+ }
+};
+
+}}
diff --git a/include/shard/VPTree.h b/include/shard/VPTree.h
index 2f5ebbb..ba13a87 100644
--- a/include/shard/VPTree.h
+++ b/include/shard/VPTree.h
@@ -5,98 +5,27 @@
*
* Distributed under the Modified BSD License.
*
- * A shard shim around the VPTree spatial index.
+ * A shard shim around a VPTree for high-dimensional metric similarity
+ * search.
*
- * FIXME: separate the KNN query class out into a standalone
- * file in include/query .
+ * FIXME: Does not yet support the tombstone delete policy.
*
*/
#pragma once
#include <vector>
-#include <cassert>
-#include <queue>
-#include <memory>
-#include <concepts>
-#include <map>
-#include <unordered_map>
-#include <functional>
+#include <unordered_map>
#include "framework/ShardRequirements.h"
-
#include "psu-ds/PriorityQueue.h"
-#include "util/Cursor.h"
-#include "psu-ds/BloomFilter.h"
-#include "util/bf_config.h"
using psudb::CACHELINE_SIZE;
-using psudb::BloomFilter;
using psudb::PriorityQueue;
using psudb::queue_record;
-using psudb::Alias;
+using psudb::byte;
namespace de {
-template <NDRecordInterface R>
-struct KNNQueryParms {
- R point;
- size_t k;
-};
-
-template <NDRecordInterface R>
-class KNNQuery;
-
-template <NDRecordInterface R>
-struct KNNState {
- size_t k;
-
- KNNState() {
- k = 0;
- }
-};
-
-template <NDRecordInterface R>
-struct KNNBufferState {
-
-};
-
-
-template <typename R>
-class KNNDistCmpMax {
-public:
- KNNDistCmpMax(R *baseline) : P(baseline) {}
-
- inline bool operator()(const R *a, const R *b) requires WrappedInterface<R> {
- return a->rec.calc_distance(P->rec) > b->rec.calc_distance(P->rec);
- }
-
- inline bool operator()(const R *a, const R *b) requires (!WrappedInterface<R>){
- return a->calc_distance(*P) > b->calc_distance(*P);
- }
-
-private:
- R *P;
-};
-
-template <typename R>
-class KNNDistCmpMin {
-public:
- KNNDistCmpMin(R *baseline) : P(baseline) {}
-
- inline bool operator()(const R *a, const R *b) requires WrappedInterface<R> {
- return a->rec.calc_distance(P->rec) < b->rec.calc_distance(P->rec);
- }
-
- inline bool operator()(const R *a, const R *b) requires (!WrappedInterface<R>){
- return a->calc_distance(*P) < b->calc_distance(*P);
- }
-
-private:
- R *P;
-};
-
-
-
template <NDRecordInterface R, size_t LEAFSZ=100, bool HMAP=false>
class VPTree {
private:
@@ -117,16 +46,19 @@ private:
}
};
-public:
- friend class KNNQuery<R>;
- VPTree(MutableBuffer<R>* buffer)
+
+public:
+ VPTree(BufferView<R> buffer)
: m_reccnt(0), m_tombstone_cnt(0), m_root(nullptr), m_node_cnt(0) {
- m_alloc_size = (buffer->get_record_count() * sizeof(Wrapped<R>)) + (CACHELINE_SIZE - (buffer->get_record_count() * sizeof(Wrapped<R>)) % CACHELINE_SIZE);
- assert(m_alloc_size % CACHELINE_SIZE == 0);
- m_data = (Wrapped<R>*)std::aligned_alloc(CACHELINE_SIZE, m_alloc_size);
- m_ptrs = new Wrapped<R>*[buffer->get_record_count()];
+
+ m_alloc_size = psudb::sf_aligned_alloc(CACHELINE_SIZE,
+ buffer.get_record_count() *
+ sizeof(Wrapped<R>),
+ (byte**) &m_data);
+
+ m_ptrs = new Wrapped<R>*[buffer.get_record_count()];
size_t offset = 0;
m_reccnt = 0;
@@ -135,8 +67,8 @@ public:
// this one will likely require the multi-pass
// approach, as otherwise we'll need to sort the
// records repeatedly on each reconstruction.
- for (size_t i=0; i<buffer->get_record_count(); i++) {
- auto rec = buffer->get_data() + i;
+ for (size_t i=0; i<buffer.get_record_count(); i++) {
+ auto rec = buffer.get(i);
if (rec->is_deleted()) {
continue;
@@ -154,25 +86,24 @@ public:
}
}
- VPTree(VPTree** shards, size_t len)
+ VPTree(std::vector<VPTree*> shards)
: m_reccnt(0), m_tombstone_cnt(0), m_root(nullptr), m_node_cnt(0) {
size_t attemp_reccnt = 0;
-
- for (size_t i=0; i<len; i++) {
+ for (size_t i=0; i<shards.size(); i++) {
attemp_reccnt += shards[i]->get_record_count();
}
-
- m_alloc_size = (attemp_reccnt * sizeof(Wrapped<R>)) + (CACHELINE_SIZE - (attemp_reccnt * sizeof(Wrapped<R>)) % CACHELINE_SIZE);
- assert(m_alloc_size % CACHELINE_SIZE == 0);
- m_data = (Wrapped<R>*)std::aligned_alloc(CACHELINE_SIZE, m_alloc_size);
+
+ m_alloc_size = psudb::sf_aligned_alloc(CACHELINE_SIZE,
+ attemp_reccnt * sizeof(Wrapped<R>),
+ (byte **) &m_data);
m_ptrs = new Wrapped<R>*[attemp_reccnt];
// FIXME: will eventually need to figure out tombstones
// this one will likely require the multi-pass
// approach, as otherwise we'll need to sort the
// records repeatedly on each reconstruction.
- for (size_t i=0; i<len; i++) {
+ for (size_t i=0; i<shards.size(); i++) {
for (size_t j=0; j<shards[i]->get_record_count(); j++) {
if (shards[i]->get_record_at(j)->is_deleted()) {
continue;
@@ -191,9 +122,9 @@ public:
}
~VPTree() {
- if (m_data) free(m_data);
- if (m_root) delete m_root;
- if (m_ptrs) delete[] m_ptrs;
+ free(m_data);
+ delete m_root;
+ delete[] m_ptrs;
}
Wrapped<R> *point_lookup(const R &rec, bool filter=false) {
@@ -248,11 +179,27 @@ public:
}
size_t get_aux_memory_usage() {
+ // FIXME: need to return the size of the unordered_map
return 0;
}
+ void search(const R &point, size_t k, PriorityQueue<Wrapped<R>,
+ DistCmpMax<Wrapped<R>>> &pq) {
+ double farthest = std::numeric_limits<double>::max();
+
+ internal_search(m_root, point, k, pq, &farthest);
+ }
private:
+ Wrapped<R>* m_data;
+ Wrapped<R>** m_ptrs;
+ std::unordered_map<R, size_t, RecordHash<R>> m_lookup_map;
+ size_t m_reccnt;
+ size_t m_tombstone_cnt;
+ size_t m_node_cnt;
+ size_t m_alloc_size;
+
+ vpnode *m_root;
vpnode *build_vptree() {
if (m_reccnt == 0) {
@@ -332,7 +279,6 @@ private:
return node;
}
-
void quickselect(size_t start, size_t stop, size_t k, Wrapped<R> *p, gsl_rng *rng) {
if (start == stop) return;
@@ -345,7 +291,6 @@ private:
}
}
-
size_t partition(size_t start, size_t stop, Wrapped<R> *p, gsl_rng *rng) {
auto pivot = start + gsl_rng_uniform_int(rng, stop - start);
double pivot_dist = p->rec.calc_distance(m_ptrs[pivot]->rec);
@@ -364,15 +309,15 @@ private:
return j;
}
-
void swap(size_t idx1, size_t idx2) {
auto tmp = m_ptrs[idx1];
m_ptrs[idx1] = m_ptrs[idx2];
m_ptrs[idx2] = tmp;
}
+ void internal_search(vpnode *node, const R &point, size_t k, PriorityQueue<Wrapped<R>,
+ DistCmpMax<Wrapped<R>>> &pq, double *farthest) {
- void search(vpnode *node, const R &point, size_t k, PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> &pq, double *farthest) {
if (node == nullptr) return;
if (node->leaf) {
@@ -408,151 +353,24 @@ private:
if (d < node->radius) {
if (d - (*farthest) <= node->radius) {
- search(node->inside, point, k, pq, farthest);
+ internal_search(node->inside, point, k, pq, farthest);
}
if (d + (*farthest) >= node->radius) {
- search(node->outside, point, k, pq, farthest);
+ internal_search(node->outside, point, k, pq, farthest);
}
} else {
if (d + (*farthest) >= node->radius) {
- search(node->outside, point, k, pq, farthest);
+ internal_search(node->outside, point, k, pq, farthest);
}
if (d - (*farthest) <= node->radius) {
- search(node->inside, point, k, pq, farthest);
+ internal_search(node->inside, point, k, pq, farthest);
}
}
}
- Wrapped<R>* m_data;
- Wrapped<R>** m_ptrs;
- std::unordered_map<R, size_t, RecordHash<R>> m_lookup_map;
- size_t m_reccnt;
- size_t m_tombstone_cnt;
- size_t m_node_cnt;
- size_t m_alloc_size;
-
- vpnode *m_root;
-};
-
-
-template <NDRecordInterface R>
-class KNNQuery {
-public:
- constexpr static bool EARLY_ABORT=false;
- constexpr static bool SKIP_DELETE_FILTER=true;
-
- static void *get_query_state(VPTree<R> *wss, void *parms) {
- return nullptr;
- }
-
- static void* get_buffer_query_state(MutableBuffer<R> *buffer, void *parms) {
- return nullptr;
- }
-
- static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void *buff_state) {
- return;
- }
-
- static std::vector<Wrapped<R>> query(VPTree<R> *wss, void *q_state, void *parms) {
- std::vector<Wrapped<R>> results;
- KNNQueryParms<R> *p = (KNNQueryParms<R> *) parms;
- Wrapped<R> wrec;
- wrec.rec = p->point;
- wrec.header = 0;
-
- PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> pq(p->k, &wrec);
-
- double farthest = std::numeric_limits<double>::max();
-
- wss->search(wss->m_root, p->point, p->k, pq, &farthest);
-
- while (pq.size() > 0) {
- results.emplace_back(*pq.peek().data);
- pq.pop();
- }
-
- return results;
- }
-
- static std::vector<Wrapped<R>> buffer_query(MutableBuffer<R> *buffer, void *state, void *parms) {
- KNNQueryParms<R> *p = (KNNQueryParms<R> *) parms;
- Wrapped<R> wrec;
- wrec.rec = p->point;
- wrec.header = 0;
-
- size_t k = p->k;
-
- PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> pq(k, &wrec);
- for (size_t i=0; i<buffer->get_record_count(); i++) {
- // Skip over deleted records (under tagging)
- if ((buffer->get_data())[i].is_deleted()) {
- continue;
- }
-
- if (pq.size() < k) {
- pq.push(buffer->get_data() + i);
- } else {
- double head_dist = pq.peek().data->rec.calc_distance(wrec.rec);
- double cur_dist = (buffer->get_data() + i)->rec.calc_distance(wrec.rec);
-
- if (cur_dist < head_dist) {
- pq.pop();
- pq.push(buffer->get_data() + i);
- }
- }
- }
-
- std::vector<Wrapped<R>> results;
- while (pq.size() > 0) {
- results.emplace_back(*(pq.peek().data));
- pq.pop();
- }
-
- return results;
- }
-
- static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) {
- KNNQueryParms<R> *p = (KNNQueryParms<R> *) parms;
- R rec = p->point;
- size_t k = p->k;
-
- PriorityQueue<R, KNNDistCmpMax<R>> pq(k, &rec);
- for (size_t i=0; i<results.size(); i++) {
- for (size_t j=0; j<results[i].size(); j++) {
- if (pq.size() < k) {
- pq.push(&results[i][j].rec);
- } else {
- double head_dist = pq.peek().data->calc_distance(rec);
- double cur_dist = results[i][j].rec.calc_distance(rec);
- if (cur_dist < head_dist) {
- pq.pop();
- pq.push(&results[i][j].rec);
- }
- }
- }
- }
-
- std::vector<R> output;
- while (pq.size() > 0) {
- output.emplace_back(*pq.peek().data);
- pq.pop();
- }
-
- return output;
- }
-
- static void delete_query_state(void *state) {
- auto s = (KNNState<R> *) state;
- delete s;
- }
-
- static void delete_buffer_query_state(void *state) {
- auto s = (KNNBufferState<R> *) state;
- delete s;
- }
-};
+ };
}
diff --git a/tests/include/shard_standard.h b/tests/include/shard_standard.h
index f50c1cb..7d17dcb 100644
--- a/tests/include/shard_standard.h
+++ b/tests/include/shard_standard.h
@@ -22,18 +22,22 @@
* should be included in the source file that includes this one, above the
* include statement.
*/
-//#include "shard/ISAMTree.h"
-//#include "testing.h"
-//#include <check.h>
-//using namespace de;
-//typedef ISAMTree<R> Shard;
+/*
+#include "shard/ISAMTree.h"
+#include "shard/ISAMTree.h"
+#include "testing.h"
+#include <check.h>
+using namespace de;
+typedef Rec R;
+typedef ISAMTree<R> Shard;
+*/
START_TEST(t_mbuffer_init)
{
auto buffer = new MutableBuffer<R>(512, 1024);
for (uint64_t i = 512; i > 0; i--) {
uint32_t v = i;
- buffer->append({i,v, 1});
+ buffer->append({i, v, 1});
}
for (uint64_t i = 1; i <= 256; ++i) {
diff --git a/tests/include/testing.h b/tests/include/testing.h
index 4e660dd..f935b53 100644
--- a/tests/include/testing.h
+++ b/tests/include/testing.h
@@ -23,7 +23,7 @@
typedef de::WeightedRecord<uint64_t, uint32_t, uint64_t> WRec;
typedef de::Record<uint64_t, uint32_t> Rec;
-typedef de::EuclidPoint<int64_t> PRec;
+typedef de::EuclidPoint<uint64_t> PRec;
template <de::RecordInterface R>
std::vector<R> strip_wrapping(std::vector<de::Wrapped<R>> vec) {
@@ -76,55 +76,48 @@ static bool roughly_equal(int n1, int n2, size_t mag, double epsilon) {
return ((double) std::abs(n1 - n2) / (double) mag) < epsilon;
}
-static de::MutableBuffer<PRec> *create_2d_mbuffer(size_t cnt) {
- auto buffer = new de::MutableBuffer<PRec>(cnt/2, cnt);
-
- for (int64_t i=0; i<cnt; i++) {
- buffer->append({rand(), rand()});
- }
-
- return buffer;
-}
-
-static de::MutableBuffer<PRec> *create_2d_sequential_mbuffer(size_t cnt) {
- auto buffer = new de::MutableBuffer<PRec>(cnt/2, cnt);
- for (int64_t i=0; i<cnt; i++) {
- buffer->append({i, i});
- }
-
- return buffer;
-}
-
-template <de::KVPInterface R>
+template <de::RecordInterface R>
static de::MutableBuffer<R> *create_test_mbuffer(size_t cnt)
{
auto buffer = new de::MutableBuffer<R>(cnt/2, cnt);
R rec;
- for (size_t i = 0; i < cnt; i++) {
- rec.key = rand();
- rec.value = rand();
+ if constexpr (de::KVPInterface<R>) {
+ for (size_t i = 0; i < cnt; i++) {
+ rec.key = rand();
+ rec.value = rand();
- if constexpr (de::WeightedRecordInterface<R>) {
- rec.weight = 1;
- }
+ if constexpr (de::WeightedRecordInterface<R>) {
+ rec.weight = 1;
+ }
- buffer->append(rec);
- }
+ buffer->append(rec);
+ }
+ } else if constexpr (de::NDRecordInterface<R>) {
+ for (size_t i=0; i<cnt; i++) {
+ uint64_t a = rand();
+ uint64_t b = rand();
+ buffer->append({a, b});
+ }
+ }
return buffer;
}
-template <de::KVPInterface R>
-static de::MutableBuffer<R> *create_sequential_mbuffer(decltype(R::key) start, decltype(R::key) stop)
+template <de::RecordInterface R>
+static de::MutableBuffer<R> *create_sequential_mbuffer(size_t start, size_t stop)
{
size_t cnt = stop - start;
auto buffer = new de::MutableBuffer<R>(cnt/2, cnt);
for (size_t i=start; i<stop; i++) {
R rec;
- rec.key = i;
- rec.value = i;
+ if constexpr (de::KVPInterface<R>) {
+ rec.key = i;
+ rec.value = i;
+ } else if constexpr (de::NDRecordInterface<R>) {
+ rec = {i, i};
+ }
if constexpr (de::WeightedRecordInterface<R>) {
rec.weight = 1;
diff --git a/tests/vptree_tests.cpp b/tests/vptree_tests.cpp
index fb568dd..ff99ba6 100644
--- a/tests/vptree_tests.cpp
+++ b/tests/vptree_tests.cpp
@@ -9,27 +9,28 @@
*
*/
+
+#include "include/testing.h"
#include "shard/VPTree.h"
-#include "testing.h"
-#include "vptree.hpp"
+#include "query/knn.h"
#include <check.h>
using namespace de;
-
-typedef VPTree<PRec> Shard;
+typedef PRec R;
+typedef VPTree<R> Shard;
START_TEST(t_mbuffer_init)
{
size_t n= 24;
- auto buffer = new MutableBuffer<PRec>(n, n);
+ auto buffer = new MutableBuffer<PRec>(n/2, n);
for (int64_t i=0; i<n; i++) {
- buffer->append({i, i});
+ buffer->append({(uint64_t) i, (uint64_t) i});
}
- Shard* shard = new Shard(buffer);
+ Shard* shard = new Shard(buffer->get_buffer_view());
ck_assert_uint_eq(shard->get_record_count(), n);
delete buffer;
@@ -40,16 +41,16 @@ START_TEST(t_mbuffer_init)
START_TEST(t_wss_init)
{
size_t n = 512;
- auto mbuffer1 = create_2d_mbuffer(n);
- auto mbuffer2 = create_2d_mbuffer(n);
- auto mbuffer3 = create_2d_mbuffer(n);
+ auto mbuffer1 = create_test_mbuffer<R>(n);
+ auto mbuffer2 = create_test_mbuffer<R>(n);
+ auto mbuffer3 = create_test_mbuffer<R>(n);
- auto shard1 = new Shard(mbuffer1);
- auto shard2 = new Shard(mbuffer2);
- auto shard3 = new Shard(mbuffer3);
+ auto shard1 = new Shard(mbuffer1->get_buffer_view());
+ auto shard2 = new Shard(mbuffer2->get_buffer_view());
+ auto shard3 = new Shard(mbuffer3->get_buffer_view());
- Shard* shards[3] = {shard1, shard2, shard3};
- auto shard4 = new Shard(shards, 3);
+ std::vector<Shard *> shards = {shard1, shard2, shard3};
+ auto shard4 = new Shard(shards);
ck_assert_int_eq(shard4->get_record_count(), n * 3);
ck_assert_int_eq(shard4->get_tombstone_count(), 0);
@@ -69,19 +70,23 @@ START_TEST(t_point_lookup)
{
size_t n = 16;
- auto buffer = create_2d_sequential_mbuffer(n);
- auto wss = Shard(buffer);
+ auto buffer = create_sequential_mbuffer<R>(0, n);
+ auto wss = Shard(buffer->get_buffer_view());
- for (size_t i=0; i<n; i++) {
- PRec r;
- auto rec = (buffer->get_data() + i);
- r.data[0] = rec->rec.data[0];
- r.data[1] = rec->rec.data[1];
+ {
+ auto bv = buffer->get_buffer_view();
- auto result = wss.point_lookup(r);
- ck_assert_ptr_nonnull(result);
- ck_assert_int_eq(result->rec.data[0], r.data[0]);
- ck_assert_int_eq(result->rec.data[1], r.data[1]);
+ for (size_t i=0; i<n; i++) {
+ PRec r;
+ auto rec = (bv.get(i));
+ r.data[0] = rec->rec.data[0];
+ r.data[1] = rec->rec.data[1];
+
+ auto result = wss.point_lookup(r);
+ ck_assert_ptr_nonnull(result);
+ ck_assert_int_eq(result->rec.data[0], r.data[0]);
+ ck_assert_int_eq(result->rec.data[1], r.data[1]);
+ }
}
delete buffer;
@@ -93,8 +98,8 @@ START_TEST(t_point_lookup_miss)
{
size_t n = 10000;
- auto buffer = create_2d_sequential_mbuffer(n);
- auto wss = Shard(buffer);
+ auto buffer = create_sequential_mbuffer<R>(0, n);
+ auto wss = Shard(buffer->get_buffer_view());
for (size_t i=n + 100; i<2*n; i++) {
PRec r;
@@ -112,24 +117,27 @@ START_TEST(t_point_lookup_miss)
START_TEST(t_buffer_query)
{
size_t n = 10000;
- auto buffer = create_2d_sequential_mbuffer(n);
+ auto buffer = create_sequential_mbuffer<R>(0, n);
PRec target;
target.data[0] = 120;
target.data[1] = 120;
- KNNQueryParms<PRec> p;
+ knn::Parms<PRec> p;
p.k = 10;
p.point = target;
- auto state = KNNQuery<PRec>::get_buffer_query_state(buffer, &p);
- auto result = KNNQuery<PRec>::buffer_query(buffer, state, &p);
- KNNQuery<PRec>::delete_buffer_query_state(state);
+ {
+ auto bv = buffer->get_buffer_view();
+ auto state = knn::Query<PRec, Shard>::get_buffer_query_state(&bv, &p);
+ auto result = knn::Query<PRec, Shard>::buffer_query(state, &p);
+ knn::Query<PRec, Shard>::delete_buffer_query_state(state);
- std::sort(result.begin(), result.end());
- size_t start = 120 - 5;
- for (size_t i=0; i<result.size(); i++) {
- ck_assert_int_eq(result[i].rec.data[0], start++);
+ std::sort(result.begin(), result.end());
+ size_t start = 120 - 5;
+ for (size_t i=0; i<result.size(); i++) {
+ ck_assert_int_eq(result[i].rec.data[0], start++);
+ }
}
delete buffer;
@@ -138,19 +146,19 @@ START_TEST(t_buffer_query)
START_TEST(t_knn_query)
{
size_t n = 1000;
- auto buffer = create_2d_sequential_mbuffer(n);
+ auto buffer = create_sequential_mbuffer<R>(0, n);
- auto vptree = VPTree<PRec>(buffer);
+ auto vptree = VPTree<PRec>(buffer->get_buffer_view());
- KNNQueryParms<PRec> p;
+ knn::Parms<PRec> p;
for (size_t i=0; i<100; i++) {
p.k = rand() % 150;
p.point.data[0] = rand() % (n-p.k);
p.point.data[1] = p.point.data[0];
- auto state = KNNQuery<PRec>::get_query_state(&vptree, &p);
- auto results = KNNQuery<PRec>::query(&vptree, state, &p);
- KNNQuery<PRec>::delete_query_state(state);
+ auto state = knn::Query<PRec, Shard>::get_query_state(&vptree, &p);
+ auto results = knn::Query<PRec, Shard>::query(&vptree, state, &p);
+ knn::Query<PRec, Shard>::delete_query_state(state);
ck_assert_int_eq(results.size(), p.k);