From 12915304570507e00848aba700f0ed3a26dbb9b6 Mon Sep 17 00:00:00 2001 From: Douglas Rumbaugh Date: Thu, 13 Jul 2023 15:34:49 -0400 Subject: Initial commit of VPTree-related code Point lookups are currently broken; I suspect that there is something wrong with tree construction, although the quickselect implementation seems to be fine. --- CMakeLists.txt | 4 + include/framework/RecordInterface.h | 47 ++++-- include/shard/VPTree.h | 314 ++++++++++++++++++++++++++++++++++++ tests/testing.h | 31 +++- tests/vptree_tests.cpp | 162 +++++++++++++++++++ 5 files changed, 542 insertions(+), 16 deletions(-) create mode 100644 include/shard/VPTree.h create mode 100644 tests/vptree_tests.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index dfdf812..a7b2f8a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,6 +45,10 @@ if (tests) target_link_libraries(mutable_buffer_tests PUBLIC gsl check subunit pthread) target_include_directories(mutable_buffer_tests PRIVATE include) + add_executable(vptree_tests ${CMAKE_CURRENT_SOURCE_DIR}/tests/vptree_tests.cpp) + target_link_libraries(vptree_tests PUBLIC gsl check subunit pthread) + target_include_directories(vptree_tests PRIVATE include) + #add_executable(dynamic_extension_tests ${CMAKE_CURRENT_SOURCE_DIR}/tests/dynamic_extension_tests.cpp) #target_link_libraries(dynamic_extension_tests PUBLIC gsl check subunit pthread) #target_include_directories(dynamic_extension_tests PRIVATE include) diff --git a/include/framework/RecordInterface.h b/include/framework/RecordInterface.h index 70c8e01..6936a8b 100644 --- a/include/framework/RecordInterface.h +++ b/include/framework/RecordInterface.h @@ -11,6 +11,7 @@ #include #include +#include #include "util/base.h" @@ -18,13 +19,26 @@ namespace de { template concept RecordInterface = requires(R r, R s) { - r.key; - r.value; - { r < s } ->std::convertible_to; { r == s } ->std::convertible_to; }; +template +concept WeightedRecordInterface = requires(R r) { + {r.weight} -> std::convertible_to; +}; + 
+template +concept NDRecordInterface = RecordInterface && requires(R r, R s) { + {r.calc_distance(s)} -> std::convertible_to; +}; + +template +concept KVPInterface = RecordInterface && requires(R r) { + r.key; + r.value; +}; + template struct Wrapped { uint32_t header; @@ -51,17 +65,10 @@ struct Wrapped { } inline bool operator<(const Wrapped& other) const { - return (rec.key < other.rec.key) || (rec.key == other.rec.key && rec.value < other.rec.value) - || (rec.key == other.rec.key && rec.value == other.rec.value && header < other.header); + return rec < other.rec || (rec == other.rec && header < other.header); } }; - -template -concept WeightedRecordInterface = RecordInterface && requires(R r) { - {r.weight} -> std::convertible_to; -}; - template struct Record { K key; @@ -92,4 +99,22 @@ struct WeightedRecord { } }; +template +struct Point{ + V x; + V y; + + inline bool operator==(const Point& other) const { + return x == other.x && y == other.y; + } + + // lexicographic order + inline bool operator<(const Point& other) const { + return x < other.x || (x == other.x && y < other.y); + } + + inline double calc_distance(const Point& other) const { + return sqrt(pow(x - other.x, 2) + pow(y - other.y, 2)); + } +}; } diff --git a/include/shard/VPTree.h b/include/shard/VPTree.h new file mode 100644 index 0000000..5364537 --- /dev/null +++ b/include/shard/VPTree.h @@ -0,0 +1,314 @@ +/* + * include/shard/VPTree.h + * + * Copyright (C) 2023 Douglas Rumbaugh + * + * All rights reserved. Published under the Modified BSD License. 
+ * + */ +#pragma once + +#include +#include +#include +#include +#include + +#include "ds/PriorityQueue.h" +#include "util/Cursor.h" +#include "ds/BloomFilter.h" +#include "util/bf_config.h" +#include "framework/MutableBuffer.h" +#include "framework/RecordInterface.h" +#include "framework/ShardInterface.h" +#include "framework/QueryInterface.h" + +namespace de { + +thread_local size_t wss_cancelations = 0; + +template +struct KNNQueryParms { + R point; + size_t k; +}; + +template +class KNNQuery; + +template +struct KNNState { + size_t k; + + KNNState() { + k = 0; + } +}; + +template +struct KNNBufferState { + size_t cutoff; + size_t sample_size; + Alias* alias; + decltype(R::weight) max_weight; + decltype(R::weight) total_weight; + + ~KNNBufferState() { + delete alias; + } + +}; + +template +class VPTree { +private: + struct vpnode { + size_t idx = 0; + double radius = 0; + vpnode *inside = nullptr; + vpnode *outside = nullptr; + + ~vpnode() { + delete inside; + delete outside; + } + }; + +public: + friend class KNNQuery; + + VPTree(MutableBuffer* buffer) + : m_reccnt(0), m_tombstone_cnt(0), m_root(nullptr), m_node_cnt(0) { + + size_t alloc_size = (buffer->get_record_count() * sizeof(Wrapped)) + (CACHELINE_SIZE - (buffer->get_record_count() * sizeof(Wrapped)) % CACHELINE_SIZE); + assert(alloc_size % CACHELINE_SIZE == 0); + m_data = (Wrapped*)std::aligned_alloc(CACHELINE_SIZE, alloc_size); + + size_t offset = 0; + m_reccnt = 0; + + // FIXME: will eventually need to figure out tombstones + // this one will likely require the multi-pass + // approach, as otherwise we'll need to sort the + // records repeatedly on each reconstruction. 
+ for (size_t i=0; iget_record_count(); i++) { + auto rec = buffer->get_data() + i; + + if (rec->is_deleted()) { + continue; + } + + rec->header &= 3; + m_data[m_reccnt++] = *rec; + } + + if (m_reccnt > 0) { + m_root = build_vptree(); + } + } + + VPTree(VPTree** shards, size_t len) + : m_reccnt(0), m_tombstone_cnt(0), m_root(nullptr), m_node_cnt(0) { + + size_t attemp_reccnt = 0; + + for (size_t i=0; iget_record_count(); + } + + size_t alloc_size = (attemp_reccnt * sizeof(Wrapped)) + (CACHELINE_SIZE - (attemp_reccnt * sizeof(Wrapped)) % CACHELINE_SIZE); + assert(alloc_size % CACHELINE_SIZE == 0); + m_data = (Wrapped*)std::aligned_alloc(CACHELINE_SIZE, alloc_size); + + // FIXME: will eventually need to figure out tombstones + // this one will likely require the multi-pass + // approach, as otherwise we'll need to sort the + // records repeatedly on each reconstruction. + for (size_t i=0; iget_record_count(); j++) { + if (shards[i]->get_record_at(j)->is_deleted()) { + continue; + } + + m_data[m_reccnt++] = *shards[i]->get_record_at(j); + } + } + + if (m_reccnt > 0) { + m_root = build_vptree(); + } + } + + ~VPTree() { + if (m_data) free(m_data); + if (m_root) delete m_root; + } + + Wrapped *point_lookup(const R &rec, bool filter=false) { + auto node = m_root; + + while (node && m_data[node->idx].rec != rec) { + if (rec.calc_distance(m_data[node->idx].rec) >= node->radius) { + node = node->outside; + } else { + node = node->inside; + } + } + + return (node) ? 
m_data + node->idx : nullptr; + } + + Wrapped* get_data() const { + return m_data; + } + + size_t get_record_count() const { + return m_reccnt; + } + + size_t get_tombstone_count() const { + return m_tombstone_cnt; + } + + const Wrapped* get_record_at(size_t idx) const { + if (idx >= m_reccnt) return nullptr; + return m_data + idx; + } + + + size_t get_memory_usage() { + return m_node_cnt * sizeof(vpnode); + } + +private: + + vpnode *build_vptree() { + assert(m_reccnt > 0); + + size_t lower = 0; + size_t upper = m_reccnt - 1; + + auto rng = gsl_rng_alloc(gsl_rng_mt19937); + auto n = build_subtree(lower, upper, rng); + gsl_rng_free(rng); + return n; + } + + vpnode *build_subtree(size_t start, size_t stop, gsl_rng *rng) { + if (start >= stop) { + return nullptr; + } + + // select a random element to partition based on, and swap + // it to the front of the sub-array + auto i = start + gsl_rng_uniform_int(rng, stop - start); + swap(start, i); + + // partition elements based on their distance from the start, + // with those elements with distance falling below the median + // distance going into the left sub-array and those above + // the median in the right. This is easily done using QuickSelect. + auto mid = ((start+1) + stop) / 2; + quickselect(start + 1, stop, mid, m_data[start], rng); + + // Create a new node based on this partitioning + vpnode *node = new vpnode(); + + // store the radius of the circle used for partitioning the + // node. 
+ node->idx = start; + node->radius = m_data[start].rec.calc_distance(m_data[mid].rec); + + // recursively construct the left and right subtrees + node->inside = build_subtree(start + 1, mid - 1, rng); + node->outside = build_subtree(mid, stop, rng); + + m_node_cnt++; + + return node; + } + + + void quickselect(size_t start, size_t stop, size_t k, Wrapped p, gsl_rng *rng) { + if (start == stop) return; + + auto pivot = partition(start, stop, p, rng); + + if (k < pivot) { + quickselect(start, pivot - 1, k, p, rng); + } else if (k > pivot) { + quickselect(pivot + 1, stop, k, p, rng); + } + } + + + size_t partition(size_t start, size_t stop, Wrapped p, gsl_rng *rng) { + auto pivot = start + gsl_rng_uniform_int(rng, stop - start); + double pivot_dist = p.rec.calc_distance(m_data[pivot].rec); + + swap(pivot, stop); + + size_t j = start; + for (size_t i=start; i tmp = m_data[idx1]; + m_data[idx1] = m_data[idx2]; + m_data[idx2] = tmp; + } + + Wrapped* m_data; + size_t m_reccnt; + size_t m_tombstone_cnt; + size_t m_node_cnt; + + vpnode *m_root; + +}; + + +template +class KNNQuery { +public: + static void *get_query_state(VPTree *wss, void *parms) { + return nullptr; + } + + static void* get_buffer_query_state(MutableBuffer *buffer, void *parms) { + return nullptr; + } + + static void process_query_states(void *query_parms, std::vector shard_states, void *buff_state) { + } + + static std::vector> query(VPTree *wss, void *q_state, void *parms) { + } + + static std::vector> buffer_query(MutableBuffer *buffer, void *state, void *parms) { + } + + static std::vector merge(std::vector> &results) { + } + + static void delete_query_state(void *state) { + auto s = (KNNState *) state; + delete s; + } + + static void delete_buffer_query_state(void *state) { + auto s = (KNNBufferState *) state; + delete s; + } +}; + +} diff --git a/tests/testing.h b/tests/testing.h index fe6623e..1d5db59 100644 --- a/tests/testing.h +++ b/tests/testing.h @@ -23,6 +23,7 @@ typedef de::WeightedRecord 
WRec; typedef de::Record Rec; +typedef de::Point PRec; template std::vector strip_wrapping(std::vector> vec) { @@ -75,7 +76,26 @@ static bool roughly_equal(int n1, int n2, size_t mag, double epsilon) { return ((double) std::abs(n1 - n2) / (double) mag) < epsilon; } -template +static de::MutableBuffer *create_2d_mbuffer(size_t cnt) { + auto buffer = new de::MutableBuffer(cnt, cnt); + + for (int64_t i=0; iappend({rand(), rand()}); + } + + return buffer; +} + +static de::MutableBuffer *create_2d_sequential_mbuffer(size_t cnt) { + auto buffer = new de::MutableBuffer(cnt, cnt); + for (int64_t i=0; iappend({i, i}); + } + + return buffer; +} + +template static de::MutableBuffer *create_test_mbuffer(size_t cnt) { auto buffer = new de::MutableBuffer(cnt, cnt); @@ -95,7 +115,7 @@ static de::MutableBuffer *create_test_mbuffer(size_t cnt) return buffer; } -template +template static de::MutableBuffer *create_sequential_mbuffer(decltype(R::key) start, decltype(R::key) stop) { size_t cnt = stop - start; @@ -116,7 +136,7 @@ static de::MutableBuffer *create_sequential_mbuffer(decltype(R::key) start, d return buffer; } -template +template static de::MutableBuffer *create_test_mbuffer_tombstones(size_t cnt, size_t ts_cnt) { auto buffer = new de::MutableBuffer(cnt, ts_cnt); @@ -147,7 +167,8 @@ static de::MutableBuffer *create_test_mbuffer_tombstones(size_t cnt, size_t t return buffer; } -template +template +requires de::WeightedRecordInterface && de::KVPInterface static de::MutableBuffer *create_weighted_mbuffer(size_t cnt) { auto buffer = new de::MutableBuffer(cnt, cnt); @@ -170,7 +191,7 @@ static de::MutableBuffer *create_weighted_mbuffer(size_t cnt) return buffer; } -template +template static de::MutableBuffer *create_double_seq_mbuffer(size_t cnt, bool ts=false) { auto buffer = new de::MutableBuffer(cnt, cnt); diff --git a/tests/vptree_tests.cpp b/tests/vptree_tests.cpp new file mode 100644 index 0000000..1b5c18c --- /dev/null +++ b/tests/vptree_tests.cpp @@ -0,0 +1,162 @@ +/* + * 
tests/vptree_tests.cpp + * + * Unit tests for VPTree (knn queries) + * + * Copyright (C) 2023 Douglas Rumbaugh + * + * All rights reserved. Published under the Modified BSD License. + * + */ + +#include "shard/VPTree.h" +#include "testing.h" + +#include + +using namespace de; + + +typedef VPTree Shard; + +START_TEST(t_mbuffer_init) +{ + size_t n= 24; + auto buffer = new MutableBuffer(n, n); + + for (int64_t i=0; iappend({i, i}); + } + + Shard* shard = new Shard(buffer); + ck_assert_uint_eq(shard->get_record_count(), n); + + delete buffer; + delete shard; +} + + +START_TEST(t_wss_init) +{ + size_t n = 512; + auto mbuffer1 = create_2d_mbuffer(n); + auto mbuffer2 = create_2d_mbuffer(n); + auto mbuffer3 = create_2d_mbuffer(n); + + auto shard1 = new Shard(mbuffer1); + auto shard2 = new Shard(mbuffer2); + auto shard3 = new Shard(mbuffer3); + + Shard* shards[3] = {shard1, shard2, shard3}; + auto shard4 = new Shard(shards, 3); + + ck_assert_int_eq(shard4->get_record_count(), n * 3); + ck_assert_int_eq(shard4->get_tombstone_count(), 0); + + delete mbuffer1; + delete mbuffer2; + delete mbuffer3; + + delete shard1; + delete shard2; + delete shard3; + delete shard4; +} + + +START_TEST(t_point_lookup) +{ + size_t n = 30; + + auto buffer = create_2d_sequential_mbuffer(n); + auto wss = Shard(buffer); + + for (size_t i=0; iget_data() + i); + r.x = rec->rec.x; + r.y = rec->rec.y; + + fprintf(stderr, "%ld\n", i); + + auto result = wss.point_lookup(r); + ck_assert_ptr_nonnull(result); + ck_assert_int_eq(result->rec.x, r.x); + ck_assert_int_eq(result->rec.y, r.y); + } + + delete buffer; +} +END_TEST + + +START_TEST(t_point_lookup_miss) +{ + size_t n = 10000; + + auto buffer = create_2d_sequential_mbuffer(n); + auto wss = Shard(buffer); + + for (size_t i=n + 100; i<2*n; i++) { + PRec r; + r.x = i; + r.y = i; + + auto result = wss.point_lookup(r); + ck_assert_ptr_null(result); + } + + delete buffer; +} + + +Suite *unit_testing() +{ + Suite *unit = suite_create("VPTree Shard Unit 
Testing"); + + TCase *create = tcase_create("de::VPTree constructor Testing"); + tcase_add_test(create, t_mbuffer_init); + tcase_add_test(create, t_wss_init); + tcase_set_timeout(create, 100); + suite_add_tcase(unit, create); + + + TCase *lookup = tcase_create("de:VPTree:point_lookup Testing"); + tcase_add_test(lookup, t_point_lookup); + tcase_add_test(lookup, t_point_lookup_miss); + suite_add_tcase(unit, lookup); + + + /* + TCase *sampling = tcase_create("de:VPTree::VPTreeQuery Testing"); + tcase_add_test(sampling, t_wss_query); + tcase_add_test(sampling, t_wss_query_merge); + tcase_add_test(sampling, t_wss_buffer_query_rejection); + tcase_add_test(sampling, t_wss_buffer_query_scan); + suite_add_tcase(unit, sampling); + */ + + return unit; +} + + +int shard_unit_tests() +{ + int failed = 0; + Suite *unit = unit_testing(); + SRunner *unit_shardner = srunner_create(unit); + + srunner_run_all(unit_shardner, CK_NORMAL); + failed = srunner_ntests_failed(unit_shardner); + srunner_free(unit_shardner); + + return failed; +} + + +int main() +{ + int unit_failed = shard_unit_tests(); + + return (unit_failed == 0) ? EXIT_SUCCESS : EXIT_FAILURE; +} -- cgit v1.2.3