author     Douglas Rumbaugh <dbr4@psu.edu>    2023-07-13 15:34:49 -0400
committer  Douglas Rumbaugh <dbr4@psu.edu>    2023-07-13 15:34:49 -0400
commit     12915304570507e00848aba700f0ed3a26dbb9b6 (patch)
tree       ed69ab1bf95df0fa7924f677b4bad82f325fd96e
parent     369dc4c8b3331aa318f2a98eb973d0840541297d (diff)
download   dynamic-extension-12915304570507e00848aba700f0ed3a26dbb9b6.tar.gz
Initial commit of VPTree-related code
Point lookups are currently broken; I suspect that there is something wrong with tree construction, although the quickselect implementation seems to be fine.
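
An aside on the breakage noted above (a guess, not a confirmed diagnosis): build_subtree() returns nullptr whenever start >= stop, so any range that narrows to a single record never produces a node and that record becomes unreachable via point_lookup(); distance ties that land exactly on a node's radius may also need care. A minimal sketch of the base-case change this suggests, assuming the inclusive start/stop bounds used elsewhere in this commit:

    // Hypothetical revision of build_subtree's base cases (not part of this
    // commit): treat an empty range as nullptr and a single-record range as a
    // leaf, so no record is silently dropped from the tree.
    vpnode *build_subtree(size_t start, size_t stop, gsl_rng *rng) {
        if (start > stop) {
            return nullptr;                 // genuinely empty range
        }

        if (start == stop) {                // exactly one record: emit a leaf
            vpnode *leaf = new vpnode();
            leaf->idx = start;
            m_node_cnt++;
            return leaf;
        }

        /* ... remainder as committed: pick a random vantage point, quickselect
           the median distance to index mid, and recurse on [start+1, mid-1]
           (inside) and [mid, stop] (outside) ... */
    }
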
-rw-r--r--   CMakeLists.txt                         4
-rw-r--r--   include/framework/RecordInterface.h   47
-rw-r--r--   include/shard/VPTree.h               314
-rw-r--r--   tests/testing.h                       31
-rw-r--r--   tests/vptree_tests.cpp               162
5 files changed, 542 insertions, 16 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index dfdf812..a7b2f8a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -45,6 +45,10 @@ if (tests)
target_link_libraries(mutable_buffer_tests PUBLIC gsl check subunit pthread)
target_include_directories(mutable_buffer_tests PRIVATE include)
+ add_executable(vptree_tests ${CMAKE_CURRENT_SOURCE_DIR}/tests/vptree_tests.cpp)
+ target_link_libraries(vptree_tests PUBLIC gsl check subunit pthread)
+ target_include_directories(vptree_tests PRIVATE include)
+
#add_executable(dynamic_extension_tests ${CMAKE_CURRENT_SOURCE_DIR}/tests/dynamic_extension_tests.cpp)
#target_link_libraries(dynamic_extension_tests PUBLIC gsl check subunit pthread)
#target_include_directories(dynamic_extension_tests PRIVATE include)
diff --git a/include/framework/RecordInterface.h b/include/framework/RecordInterface.h
index 70c8e01..6936a8b 100644
--- a/include/framework/RecordInterface.h
+++ b/include/framework/RecordInterface.h
@@ -11,6 +11,7 @@
#include <cstring>
#include <concepts>
+#include <cmath>
#include "util/base.h"
@@ -18,13 +19,26 @@ namespace de {
template<typename R>
concept RecordInterface = requires(R r, R s) {
- r.key;
- r.value;
-
{ r < s } ->std::convertible_to<bool>;
{ r == s } ->std::convertible_to<bool>;
};
+template<typename R>
+concept WeightedRecordInterface = requires(R r) {
+ {r.weight} -> std::convertible_to<double>;
+};
+
+template<typename R>
+concept NDRecordInterface = RecordInterface<R> && requires(R r, R s) {
+ {r.calc_distance(s)} -> std::convertible_to<double>;
+};
+
+template <typename R>
+concept KVPInterface = RecordInterface<R> && requires(R r) {
+ r.key;
+ r.value;
+};
+
template<RecordInterface R>
struct Wrapped {
uint32_t header;
@@ -51,17 +65,10 @@ struct Wrapped {
}
inline bool operator<(const Wrapped& other) const {
- return (rec.key < other.rec.key) || (rec.key == other.rec.key && rec.value < other.rec.value)
- || (rec.key == other.rec.key && rec.value == other.rec.value && header < other.header);
+ return rec < other.rec || (rec == other.rec && header < other.header);
}
};
-
-template <typename R>
-concept WeightedRecordInterface = RecordInterface<R> && requires(R r) {
- {r.weight} -> std::convertible_to<double>;
-};
-
template <typename K, typename V>
struct Record {
K key;
@@ -92,4 +99,22 @@ struct WeightedRecord {
}
};
+template <typename V>
+struct Point{
+ V x;
+ V y;
+
+ inline bool operator==(const Point& other) const {
+ return x == other.x && y == other.y;
+ }
+
+ // lexicographic order
+ inline bool operator<(const Point& other) const {
+ return x < other.x || (x == other.x && y < other.y);
+ }
+
+ inline double calc_distance(const Point& other) const {
+ return sqrt(pow(x - other.x, 2) + pow(y - other.y, 2));
+ }
+};
}
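
For orientation on the concept split above, a rough sketch of compile-time checks one might write (the static_asserts are not part of the commit, just an illustration of what the new definitions imply):

    #include <cstdint>
    #include "framework/RecordInterface.h"

    // Sketch only: Point<V> supplies <, ==, and calc_distance(), so it models
    // NDRecordInterface but not KVPInterface; Record<K, V> keeps key/value and
    // therefore still models KVPInterface.
    static_assert(de::NDRecordInterface<de::Point<double>>);
    static_assert(!de::KVPInterface<de::Point<double>>);
    static_assert(de::KVPInterface<de::Record<uint64_t, uint32_t>>);
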
diff --git a/include/shard/VPTree.h b/include/shard/VPTree.h
new file mode 100644
index 0000000..5364537
--- /dev/null
+++ b/include/shard/VPTree.h
@@ -0,0 +1,314 @@
+/*
+ * include/shard/VPTree.h
+ *
+ * Copyright (C) 2023 Douglas Rumbaugh <drumbaugh@psu.edu>
+ *
+ * All rights reserved. Published under the Modified BSD License.
+ *
+ */
+#pragma once
+
+#include <vector>
+#include <cassert>
+#include <queue>
+#include <memory>
+#include <concepts>
+
+#include "ds/PriorityQueue.h"
+#include "util/Cursor.h"
+#include "ds/BloomFilter.h"
+#include "util/bf_config.h"
+#include "framework/MutableBuffer.h"
+#include "framework/RecordInterface.h"
+#include "framework/ShardInterface.h"
+#include "framework/QueryInterface.h"
+
+namespace de {
+
+thread_local size_t wss_cancelations = 0;
+
+template <NDRecordInterface R>
+struct KNNQueryParms {
+ R point;
+ size_t k;
+};
+
+template <NDRecordInterface R>
+class KNNQuery;
+
+template <NDRecordInterface R>
+struct KNNState {
+ size_t k;
+
+ KNNState() {
+ k = 0;
+ }
+};
+
+template <NDRecordInterface R>
+struct KNNBufferState {
+ size_t cutoff;
+ size_t sample_size;
+ Alias* alias;
+ decltype(R::weight) max_weight;
+ decltype(R::weight) total_weight;
+
+ ~KNNBufferState() {
+ delete alias;
+ }
+
+};
+
+template <NDRecordInterface R>
+class VPTree {
+private:
+ struct vpnode {
+ size_t idx = 0;
+ double radius = 0;
+ vpnode *inside = nullptr;
+ vpnode *outside = nullptr;
+
+ ~vpnode() {
+ delete inside;
+ delete outside;
+ }
+ };
+
+public:
+ friend class KNNQuery<R>;
+
+ VPTree(MutableBuffer<R>* buffer)
+ : m_reccnt(0), m_tombstone_cnt(0), m_root(nullptr), m_node_cnt(0) {
+
+ size_t alloc_size = (buffer->get_record_count() * sizeof(Wrapped<R>)) + (CACHELINE_SIZE - (buffer->get_record_count() * sizeof(Wrapped<R>)) % CACHELINE_SIZE);
+ assert(alloc_size % CACHELINE_SIZE == 0);
+ m_data = (Wrapped<R>*)std::aligned_alloc(CACHELINE_SIZE, alloc_size);
+
+ size_t offset = 0;
+ m_reccnt = 0;
+
+ // FIXME: will eventually need to figure out tombstones
+ // this one will likely require the multi-pass
+ // approach, as otherwise we'll need to sort the
+ // records repeatedly on each reconstruction.
+ for (size_t i=0; i<buffer->get_record_count(); i++) {
+ auto rec = buffer->get_data() + i;
+
+ if (rec->is_deleted()) {
+ continue;
+ }
+
+ rec->header &= 3;
+ m_data[m_reccnt++] = *rec;
+ }
+
+ if (m_reccnt > 0) {
+ m_root = build_vptree();
+ }
+ }
+
+ VPTree(VPTree** shards, size_t len)
+ : m_reccnt(0), m_tombstone_cnt(0), m_root(nullptr), m_node_cnt(0) {
+
+ size_t attemp_reccnt = 0;
+
+ for (size_t i=0; i<len; i++) {
+ attemp_reccnt += shards[i]->get_record_count();
+ }
+
+ size_t alloc_size = (attemp_reccnt * sizeof(Wrapped<R>)) + (CACHELINE_SIZE - (attemp_reccnt * sizeof(Wrapped<R>)) % CACHELINE_SIZE);
+ assert(alloc_size % CACHELINE_SIZE == 0);
+ m_data = (Wrapped<R>*)std::aligned_alloc(CACHELINE_SIZE, alloc_size);
+
+ // FIXME: will eventually need to figure out tombstones
+ // this one will likely require the multi-pass
+ // approach, as otherwise we'll need to sort the
+ // records repeatedly on each reconstruction.
+ for (size_t i=0; i<len; i++) {
+ for (size_t j=0; j<shards[i]->get_record_count(); j++) {
+ if (shards[i]->get_record_at(j)->is_deleted()) {
+ continue;
+ }
+
+ m_data[m_reccnt++] = *shards[i]->get_record_at(j);
+ }
+ }
+
+ if (m_reccnt > 0) {
+ m_root = build_vptree();
+ }
+ }
+
+ ~VPTree() {
+ if (m_data) free(m_data);
+ if (m_root) delete m_root;
+ }
+
+ Wrapped<R> *point_lookup(const R &rec, bool filter=false) {
+ auto node = m_root;
+
+ while (node && m_data[node->idx].rec != rec) {
+ if (rec.calc_distance(m_data[node->idx].rec) >= node->radius) {
+ node = node->outside;
+ } else {
+ node = node->inside;
+ }
+ }
+
+ return (node) ? m_data + node->idx : nullptr;
+ }
+
+ Wrapped<R>* get_data() const {
+ return m_data;
+ }
+
+ size_t get_record_count() const {
+ return m_reccnt;
+ }
+
+ size_t get_tombstone_count() const {
+ return m_tombstone_cnt;
+ }
+
+ const Wrapped<R>* get_record_at(size_t idx) const {
+ if (idx >= m_reccnt) return nullptr;
+ return m_data + idx;
+ }
+
+
+ size_t get_memory_usage() {
+ return m_node_cnt * sizeof(vpnode);
+ }
+
+private:
+
+ vpnode *build_vptree() {
+ assert(m_reccnt > 0);
+
+ size_t lower = 0;
+ size_t upper = m_reccnt - 1;
+
+ auto rng = gsl_rng_alloc(gsl_rng_mt19937);
+ auto n = build_subtree(lower, upper, rng);
+ gsl_rng_free(rng);
+ return n;
+ }
+
+ vpnode *build_subtree(size_t start, size_t stop, gsl_rng *rng) {
+ if (start >= stop) {
+ return nullptr;
+ }
+
+ // select a random element to partition based on, and swap
+ // it to the front of the sub-array
+ auto i = start + gsl_rng_uniform_int(rng, stop - start);
+ swap(start, i);
+
+ // partition elements based on their distance from the start,
+ // with those elements with distance falling below the median
+ // distance going into the left sub-array and those above
+ // the median in the right. This is easily done using QuickSelect.
+ auto mid = ((start+1) + stop) / 2;
+ quickselect(start + 1, stop, mid, m_data[start], rng);
+
+ // Create a new node based on this partitioning
+ vpnode *node = new vpnode();
+
+ // store the radius of the circle used for partitioning the
+ // node.
+ node->idx = start;
+ node->radius = m_data[start].rec.calc_distance(m_data[mid].rec);
+
+ // recursively construct the left and right subtrees
+ node->inside = build_subtree(start + 1, mid - 1, rng);
+ node->outside = build_subtree(mid, stop, rng);
+
+ m_node_cnt++;
+
+ return node;
+ }
+
+
+ void quickselect(size_t start, size_t stop, size_t k, Wrapped<R> p, gsl_rng *rng) {
+ if (start == stop) return;
+
+ auto pivot = partition(start, stop, p, rng);
+
+ if (k < pivot) {
+ quickselect(start, pivot - 1, k, p, rng);
+ } else if (k > pivot) {
+ quickselect(pivot + 1, stop, k, p, rng);
+ }
+ }
+
+
+ size_t partition(size_t start, size_t stop, Wrapped<R> p, gsl_rng *rng) {
+ auto pivot = start + gsl_rng_uniform_int(rng, stop - start);
+ double pivot_dist = p.rec.calc_distance(m_data[pivot].rec);
+
+ swap(pivot, stop);
+
+ size_t j = start;
+ for (size_t i=start; i<stop; i++) {
+ if (p.rec.calc_distance(m_data[i].rec) < pivot_dist) {
+ swap(j, i);
+ j++;
+ }
+ }
+
+ swap(j, stop);
+ return j;
+ }
+
+
+ void swap(size_t idx1, size_t idx2) {
+ Wrapped<R> tmp = m_data[idx1];
+ m_data[idx1] = m_data[idx2];
+ m_data[idx2] = tmp;
+ }
+
+ Wrapped<R>* m_data;
+ size_t m_reccnt;
+ size_t m_tombstone_cnt;
+ size_t m_node_cnt;
+
+ vpnode *m_root;
+
+};
+
+
+template <NDRecordInterface R>
+class KNNQuery {
+public:
+ static void *get_query_state(VPTree<R> *wss, void *parms) {
+ return nullptr;
+ }
+
+ static void* get_buffer_query_state(MutableBuffer<R> *buffer, void *parms) {
+ return nullptr;
+ }
+
+ static void process_query_states(void *query_parms, std::vector<void*> shard_states, void *buff_state) {
+ }
+
+ static std::vector<Wrapped<R>> query(VPTree<R> *wss, void *q_state, void *parms) {
+ }
+
+ static std::vector<Wrapped<R>> buffer_query(MutableBuffer<R> *buffer, void *state, void *parms) {
+ }
+
+ static std::vector<R> merge(std::vector<std::vector<R>> &results) {
+ }
+
+ static void delete_query_state(void *state) {
+ auto s = (KNNState<R> *) state;
+ delete s;
+ }
+
+ static void delete_buffer_query_state(void *state) {
+ auto s = (KNNBufferState<R> *) state;
+ delete s;
+ }
+};
+
+}
diff --git a/tests/testing.h b/tests/testing.h
index fe6623e..1d5db59 100644
--- a/tests/testing.h
+++ b/tests/testing.h
@@ -23,6 +23,7 @@
typedef de::WeightedRecord<uint64_t, uint32_t, uint64_t> WRec;
typedef de::Record<uint64_t, uint32_t> Rec;
+typedef de::Point<int64_t> PRec;
template <de::RecordInterface R>
std::vector<R> strip_wrapping(std::vector<de::Wrapped<R>> vec) {
@@ -75,7 +76,26 @@ static bool roughly_equal(int n1, int n2, size_t mag, double epsilon) {
return ((double) std::abs(n1 - n2) / (double) mag) < epsilon;
}
-template <de::RecordInterface R>
+static de::MutableBuffer<PRec> *create_2d_mbuffer(size_t cnt) {
+ auto buffer = new de::MutableBuffer<PRec>(cnt, cnt);
+
+ for (int64_t i=0; i<cnt; i++) {
+ buffer->append({rand(), rand()});
+ }
+
+ return buffer;
+}
+
+static de::MutableBuffer<PRec> *create_2d_sequential_mbuffer(size_t cnt) {
+ auto buffer = new de::MutableBuffer<PRec>(cnt, cnt);
+ for (int64_t i=0; i<cnt; i++) {
+ buffer->append({i, i});
+ }
+
+ return buffer;
+}
+
+template <de::KVPInterface R>
static de::MutableBuffer<R> *create_test_mbuffer(size_t cnt)
{
auto buffer = new de::MutableBuffer<R>(cnt, cnt);
@@ -95,7 +115,7 @@ static de::MutableBuffer<R> *create_test_mbuffer(size_t cnt)
return buffer;
}
-template <de::RecordInterface R>
+template <de::KVPInterface R>
static de::MutableBuffer<R> *create_sequential_mbuffer(decltype(R::key) start, decltype(R::key) stop)
{
size_t cnt = stop - start;
@@ -116,7 +136,7 @@ static de::MutableBuffer<R> *create_sequential_mbuffer(decltype(R::key) start, d
return buffer;
}
-template <de::RecordInterface R>
+template <de::KVPInterface R>
static de::MutableBuffer<R> *create_test_mbuffer_tombstones(size_t cnt, size_t ts_cnt)
{
auto buffer = new de::MutableBuffer<R>(cnt, ts_cnt);
@@ -147,7 +167,8 @@ static de::MutableBuffer<R> *create_test_mbuffer_tombstones(size_t cnt, size_t t
return buffer;
}
-template <de::WeightedRecordInterface R>
+template <typename R>
+requires de::WeightedRecordInterface<R> && de::KVPInterface<R>
static de::MutableBuffer<R> *create_weighted_mbuffer(size_t cnt)
{
auto buffer = new de::MutableBuffer<R>(cnt, cnt);
@@ -170,7 +191,7 @@ static de::MutableBuffer<R> *create_weighted_mbuffer(size_t cnt)
return buffer;
}
-template <de::RecordInterface R>
+template <de::KVPInterface R>
static de::MutableBuffer<R> *create_double_seq_mbuffer(size_t cnt, bool ts=false)
{
auto buffer = new de::MutableBuffer<R>(cnt, cnt);
diff --git a/tests/vptree_tests.cpp b/tests/vptree_tests.cpp
new file mode 100644
index 0000000..1b5c18c
--- /dev/null
+++ b/tests/vptree_tests.cpp
@@ -0,0 +1,162 @@
+/*
+ * tests/vptree_tests.cpp
+ *
+ * Unit tests for VPTree (knn queries)
+ *
+ * Copyright (C) 2023 Douglas Rumbaugh <drumbaugh@psu.edu>
+ *
+ * All rights reserved. Published under the Modified BSD License.
+ *
+ */
+
+#include "shard/VPTree.h"
+#include "testing.h"
+
+#include <check.h>
+
+using namespace de;
+
+
+typedef VPTree<PRec> Shard;
+
+START_TEST(t_mbuffer_init)
+{
+ size_t n= 24;
+ auto buffer = new MutableBuffer<PRec>(n, n);
+
+ for (int64_t i=0; i<n; i++) {
+ buffer->append({i, i});
+ }
+
+ Shard* shard = new Shard(buffer);
+ ck_assert_uint_eq(shard->get_record_count(), n);
+
+ delete buffer;
+ delete shard;
+}
+
+
+START_TEST(t_wss_init)
+{
+ size_t n = 512;
+ auto mbuffer1 = create_2d_mbuffer(n);
+ auto mbuffer2 = create_2d_mbuffer(n);
+ auto mbuffer3 = create_2d_mbuffer(n);
+
+ auto shard1 = new Shard(mbuffer1);
+ auto shard2 = new Shard(mbuffer2);
+ auto shard3 = new Shard(mbuffer3);
+
+ Shard* shards[3] = {shard1, shard2, shard3};
+ auto shard4 = new Shard(shards, 3);
+
+ ck_assert_int_eq(shard4->get_record_count(), n * 3);
+ ck_assert_int_eq(shard4->get_tombstone_count(), 0);
+
+ delete mbuffer1;
+ delete mbuffer2;
+ delete mbuffer3;
+
+ delete shard1;
+ delete shard2;
+ delete shard3;
+ delete shard4;
+}
+
+
+START_TEST(t_point_lookup)
+{
+ size_t n = 30;
+
+ auto buffer = create_2d_sequential_mbuffer(n);
+ auto wss = Shard(buffer);
+
+ for (size_t i=0; i<n; i++) {
+ PRec r;
+ auto rec = (buffer->get_data() + i);
+ r.x = rec->rec.x;
+ r.y = rec->rec.y;
+
+ fprintf(stderr, "%ld\n", i);
+
+ auto result = wss.point_lookup(r);
+ ck_assert_ptr_nonnull(result);
+ ck_assert_int_eq(result->rec.x, r.x);
+ ck_assert_int_eq(result->rec.y, r.y);
+ }
+
+ delete buffer;
+}
+END_TEST
+
+
+START_TEST(t_point_lookup_miss)
+{
+ size_t n = 10000;
+
+ auto buffer = create_2d_sequential_mbuffer(n);
+ auto wss = Shard(buffer);
+
+ for (size_t i=n + 100; i<2*n; i++) {
+ PRec r;
+ r.x = i;
+ r.y = i;
+
+ auto result = wss.point_lookup(r);
+ ck_assert_ptr_null(result);
+ }
+
+ delete buffer;
+}
+
+
+Suite *unit_testing()
+{
+ Suite *unit = suite_create("VPTree Shard Unit Testing");
+
+ TCase *create = tcase_create("de::VPTree constructor Testing");
+ tcase_add_test(create, t_mbuffer_init);
+ tcase_add_test(create, t_wss_init);
+ tcase_set_timeout(create, 100);
+ suite_add_tcase(unit, create);
+
+
+ TCase *lookup = tcase_create("de:VPTree:point_lookup Testing");
+ tcase_add_test(lookup, t_point_lookup);
+ tcase_add_test(lookup, t_point_lookup_miss);
+ suite_add_tcase(unit, lookup);
+
+
+ /*
+ TCase *sampling = tcase_create("de:VPTree::VPTreeQuery Testing");
+ tcase_add_test(sampling, t_wss_query);
+ tcase_add_test(sampling, t_wss_query_merge);
+ tcase_add_test(sampling, t_wss_buffer_query_rejection);
+ tcase_add_test(sampling, t_wss_buffer_query_scan);
+ suite_add_tcase(unit, sampling);
+ */
+
+ return unit;
+}
+
+
+int shard_unit_tests()
+{
+ int failed = 0;
+ Suite *unit = unit_testing();
+ SRunner *unit_shardner = srunner_create(unit);
+
+ srunner_run_all(unit_shardner, CK_NORMAL);
+ failed = srunner_ntests_failed(unit_shardner);
+ srunner_free(unit_shardner);
+
+ return failed;
+}
+
+
+int main()
+{
+ int unit_failed = shard_unit_tests();
+
+ return (unit_failed == 0) ? EXIT_SUCCESS : EXIT_FAILURE;
+}