diff options
Diffstat (limited to 'include/shard/VPTree.h')
| -rw-r--r-- | include/shard/VPTree.h | 329 |
1 files changed, 82 insertions, 247 deletions
diff --git a/include/shard/VPTree.h b/include/shard/VPTree.h index 8feec84..b342fe6 100644 --- a/include/shard/VPTree.h +++ b/include/shard/VPTree.h @@ -1,97 +1,31 @@ /* * include/shard/VPTree.h * - * Copyright (C) 2023 Douglas Rumbaugh <drumbaugh@psu.edu> + * Copyright (C) 2023 Douglas B. Rumbaugh <drumbaugh@psu.edu> * - * All outsides reserved. Published under the Modified BSD License. + * Distributed under the Modified BSD License. * + * A shard shim around a VPTree for high-dimensional metric similarity + * search. + * + * FIXME: Does not yet support the tombstone delete policy. + * TODO: The code in this file is very poorly commented. */ #pragma once #include <vector> -#include <cassert> -#include <queue> -#include <memory> -#include <concepts> -#include <map> +#include <unordered_map> +#include "framework/ShardRequirements.h" #include "psu-ds/PriorityQueue.h" -#include "util/Cursor.h" -#include "psu-ds/BloomFilter.h" -#include "util/bf_config.h" -#include "framework/MutableBuffer.h" -#include "framework/RecordInterface.h" -#include "framework/ShardInterface.h" -#include "framework/QueryInterface.h" using psudb::CACHELINE_SIZE; -using psudb::BloomFilter; using psudb::PriorityQueue; using psudb::queue_record; -using psudb::Alias; +using psudb::byte; namespace de { -template <NDRecordInterface R> -struct KNNQueryParms { - R point; - size_t k; -}; - -template <NDRecordInterface R> -class KNNQuery; - -template <NDRecordInterface R> -struct KNNState { - size_t k; - - KNNState() { - k = 0; - } -}; - -template <NDRecordInterface R> -struct KNNBufferState { - -}; - - -template <typename R> -class KNNDistCmpMax { -public: - KNNDistCmpMax(R *baseline) : P(baseline) {} - - inline bool operator()(const R *a, const R *b) requires WrappedInterface<R> { - return a->rec.calc_distance(P->rec) > b->rec.calc_distance(P->rec); - } - - inline bool operator()(const R *a, const R *b) requires (!WrappedInterface<R>){ - return a->calc_distance(*P) > b->calc_distance(*P); - } - -private: - R *P; -}; - -template <typename R> -class KNNDistCmpMin { -public: - KNNDistCmpMin(R *baseline) : P(baseline) {} - - inline bool operator()(const R *a, const R *b) requires WrappedInterface<R> { - return a->rec.calc_distance(P->rec) < b->rec.calc_distance(P->rec); - } - - inline bool operator()(const R *a, const R *b) requires (!WrappedInterface<R>){ - return a->calc_distance(*P) < b->calc_distance(*P); - } - -private: - R *P; -}; - - - template <NDRecordInterface R, size_t LEAFSZ=100, bool HMAP=false> class VPTree { private: @@ -112,16 +46,19 @@ private: } }; -public: - friend class KNNQuery<R>; - VPTree(MutableBuffer<R>* buffer) + +public: + VPTree(BufferView<R> buffer) : m_reccnt(0), m_tombstone_cnt(0), m_root(nullptr), m_node_cnt(0) { - m_alloc_size = (buffer->get_record_count() * sizeof(Wrapped<R>)) + (CACHELINE_SIZE - (buffer->get_record_count() * sizeof(Wrapped<R>)) % CACHELINE_SIZE); - assert(m_alloc_size % CACHELINE_SIZE == 0); - m_data = (Wrapped<R>*)std::aligned_alloc(CACHELINE_SIZE, m_alloc_size); - m_ptrs = new Wrapped<R>*[buffer->get_record_count()]; + + m_alloc_size = psudb::sf_aligned_alloc(CACHELINE_SIZE, + buffer.get_record_count() * + sizeof(Wrapped<R>), + (byte**) &m_data); + + m_ptrs = new Wrapped<R>*[buffer.get_record_count()]; size_t offset = 0; m_reccnt = 0; @@ -130,8 +67,8 @@ public: // this one will likely require the multi-pass // approach, as otherwise we'll need to sort the // records repeatedly on each reconstruction. - for (size_t i=0; i<buffer->get_record_count(); i++) { - auto rec = buffer->get_data() + i; + for (size_t i=0; i<buffer.get_record_count(); i++) { + auto rec = buffer.get(i); if (rec->is_deleted()) { continue; @@ -149,25 +86,24 @@ public: } } - VPTree(VPTree** shards, size_t len) + VPTree(std::vector<VPTree*> shards) : m_reccnt(0), m_tombstone_cnt(0), m_root(nullptr), m_node_cnt(0) { size_t attemp_reccnt = 0; - - for (size_t i=0; i<len; i++) { + for (size_t i=0; i<shards.size(); i++) { attemp_reccnt += shards[i]->get_record_count(); } - - m_alloc_size = (attemp_reccnt * sizeof(Wrapped<R>)) + (CACHELINE_SIZE - (attemp_reccnt * sizeof(Wrapped<R>)) % CACHELINE_SIZE); - assert(m_alloc_size % CACHELINE_SIZE == 0); - m_data = (Wrapped<R>*)std::aligned_alloc(CACHELINE_SIZE, m_alloc_size); + + m_alloc_size = psudb::sf_aligned_alloc(CACHELINE_SIZE, + attemp_reccnt * sizeof(Wrapped<R>), + (byte **) &m_data); m_ptrs = new Wrapped<R>*[attemp_reccnt]; // FIXME: will eventually need to figure out tombstones // this one will likely require the multi-pass // approach, as otherwise we'll need to sort the // records repeatedly on each reconstruction. - for (size_t i=0; i<len; i++) { + for (size_t i=0; i<shards.size(); i++) { for (size_t j=0; j<shards[i]->get_record_count(); j++) { if (shards[i]->get_record_at(j)->is_deleted()) { continue; @@ -186,9 +122,9 @@ public: } ~VPTree() { - if (m_data) free(m_data); - if (m_root) delete m_root; - if (m_ptrs) delete[] m_ptrs; + free(m_data); + delete m_root; + delete[] m_ptrs; } Wrapped<R> *point_lookup(const R &rec, bool filter=false) { @@ -242,7 +178,28 @@ public: return m_node_cnt * sizeof(vpnode) + m_reccnt * sizeof(R*) + m_alloc_size; } + size_t get_aux_memory_usage() { + // FIXME: need to return the size of the unordered_map + return 0; + } + + void search(const R &point, size_t k, PriorityQueue<Wrapped<R>, + DistCmpMax<Wrapped<R>>> &pq) { + double farthest = std::numeric_limits<double>::max(); + + internal_search(m_root, point, k, pq, &farthest); + } + private: + Wrapped<R>* m_data; + Wrapped<R>** m_ptrs; + std::unordered_map<R, size_t, RecordHash<R>> m_lookup_map; + size_t m_reccnt; + size_t m_tombstone_cnt; + size_t m_node_cnt; + size_t m_alloc_size; + + vpnode *m_root; vpnode *build_vptree() { if (m_reccnt == 0) { @@ -277,13 +234,15 @@ private: } vpnode *build_subtree(size_t start, size_t stop, gsl_rng *rng) { - // base-case: sometimes happens (probably because of the +1 and -1 - // in the first recursive call) + /* + * base-case: sometimes happens (probably because of the +1 and -1 + * in the first recursive call) + */ if (start > stop) { return nullptr; } - // base-case: create a leaf node + /* base-case: create a leaf node */ if (stop - start <= LEAFSZ) { vpnode *node = new vpnode(); node->start = start; @@ -294,26 +253,30 @@ private: return node; } - // select a random element to be the root of the - // subtree + /* + * select a random element to be the root of the + * subtree + */ auto i = start + gsl_rng_uniform_int(rng, stop - start + 1); swap(start, i); - // partition elements based on their distance from the start, - // with those elements with distance falling below the median - // distance going into the left sub-array and those above - // the median in the right. This is easily done using QuickSelect. + /* + * partition elements based on their distance from the start, + * with those elements with distance falling below the median + * distance going into the left sub-array and those above + * the median in the right. This is easily done using QuickSelect. + */ auto mid = (start + 1 + stop) / 2; quickselect(start + 1, stop, mid, m_ptrs[start], rng); - // Create a new node based on this partitioning + /* Create a new node based on this partitioning */ vpnode *node = new vpnode(); node->start = start; - // store the radius of the circle used for partitioning the node. + /* store the radius of the circle used for partitioning the node. */ node->radius = m_ptrs[start]->rec.calc_distance(m_ptrs[mid]->rec); - // recursively construct the left and right subtrees + /* recursively construct the left and right subtrees */ node->inside = build_subtree(start + 1, mid-1, rng); node->outside = build_subtree(mid, stop, rng); @@ -322,7 +285,8 @@ private: return node; } - + // TODO: The quickselect code can probably be generalized and moved out + // to psudb-common instead. void quickselect(size_t start, size_t stop, size_t k, Wrapped<R> *p, gsl_rng *rng) { if (start == stop) return; @@ -335,7 +299,8 @@ private: } } - + // TODO: The quickselect code can probably be generalized and moved out + // to psudb-common instead. size_t partition(size_t start, size_t stop, Wrapped<R> *p, gsl_rng *rng) { auto pivot = start + gsl_rng_uniform_int(rng, stop - start); double pivot_dist = p->rec.calc_distance(m_ptrs[pivot]->rec); @@ -354,15 +319,15 @@ private: return j; } - void swap(size_t idx1, size_t idx2) { auto tmp = m_ptrs[idx1]; m_ptrs[idx1] = m_ptrs[idx2]; m_ptrs[idx2] = tmp; } + void internal_search(vpnode *node, const R &point, size_t k, PriorityQueue<Wrapped<R>, + DistCmpMax<Wrapped<R>>> &pq, double *farthest) { - void search(vpnode *node, const R &point, size_t k, PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> &pq, double *farthest) { if (node == nullptr) return; if (node->leaf) { @@ -398,151 +363,21 @@ private: if (d < node->radius) { if (d - (*farthest) <= node->radius) { - search(node->inside, point, k, pq, farthest); + internal_search(node->inside, point, k, pq, farthest); } if (d + (*farthest) >= node->radius) { - search(node->outside, point, k, pq, farthest); + internal_search(node->outside, point, k, pq, farthest); } } else { if (d + (*farthest) >= node->radius) { - search(node->outside, point, k, pq, farthest); + internal_search(node->outside, point, k, pq, farthest); } if (d - (*farthest) <= node->radius) { - search(node->inside, point, k, pq, farthest); + internal_search(node->inside, point, k, pq, farthest); } } } - - Wrapped<R>* m_data; - Wrapped<R>** m_ptrs; - std::unordered_map<R, size_t, RecordHash<R>> m_lookup_map; - size_t m_reccnt; - size_t m_tombstone_cnt; - size_t m_node_cnt; - size_t m_alloc_size; - - vpnode *m_root; -}; - - -template <NDRecordInterface R> -class KNNQuery { -public: - constexpr static bool EARLY_ABORT=false; - constexpr static bool SKIP_DELETE_FILTER=true; - - static void *get_query_state(VPTree<R> *wss, void *parms) { - return nullptr; - } - - static void* get_buffer_query_state(MutableBuffer<R> *buffer, void *parms) { - return nullptr; - } - - static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void *buff_state) { - return; - } - - static std::vector<Wrapped<R>> query(VPTree<R> *wss, void *q_state, void *parms) { - std::vector<Wrapped<R>> results; - KNNQueryParms<R> *p = (KNNQueryParms<R> *) parms; - Wrapped<R> wrec; - wrec.rec = p->point; - wrec.header = 0; - - PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> pq(p->k, &wrec); - - double farthest = std::numeric_limits<double>::max(); - - wss->search(wss->m_root, p->point, p->k, pq, &farthest); - - while (pq.size() > 0) { - results.emplace_back(*pq.peek().data); - pq.pop(); - } - - return results; - } - - static std::vector<Wrapped<R>> buffer_query(MutableBuffer<R> *buffer, void *state, void *parms) { - KNNQueryParms<R> *p = (KNNQueryParms<R> *) parms; - Wrapped<R> wrec; - wrec.rec = p->point; - wrec.header = 0; - - size_t k = p->k; - - PriorityQueue<Wrapped<R>, KNNDistCmpMax<Wrapped<R>>> pq(k, &wrec); - for (size_t i=0; i<buffer->get_record_count(); i++) { - // Skip over deleted records (under tagging) - if ((buffer->get_data())[i].is_deleted()) { - continue; - } - - if (pq.size() < k) { - pq.push(buffer->get_data() + i); - } else { - double head_dist = pq.peek().data->rec.calc_distance(wrec.rec); - double cur_dist = (buffer->get_data() + i)->rec.calc_distance(wrec.rec); - - if (cur_dist < head_dist) { - pq.pop(); - pq.push(buffer->get_data() + i); - } - } - } - - std::vector<Wrapped<R>> results; - while (pq.size() > 0) { - results.emplace_back(*(pq.peek().data)); - pq.pop(); - } - - return results; - } - - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms) { - KNNQueryParms<R> *p = (KNNQueryParms<R> *) parms; - R rec = p->point; - size_t k = p->k; - - PriorityQueue<R, KNNDistCmpMax<R>> pq(k, &rec); - for (size_t i=0; i<results.size(); i++) { - for (size_t j=0; j<results[i].size(); j++) { - if (pq.size() < k) { - pq.push(&results[i][j].rec); - } else { - double head_dist = pq.peek().data->calc_distance(rec); - double cur_dist = results[i][j].rec.calc_distance(rec); - - if (cur_dist < head_dist) { - pq.pop(); - pq.push(&results[i][j].rec); - } - } - } - } - - std::vector<R> output; - while (pq.size() > 0) { - output.emplace_back(*pq.peek().data); - pq.pop(); - } - - return output; - } - - static void delete_query_state(void *state) { - auto s = (KNNState<R> *) state; - delete s; - } - - static void delete_buffer_query_state(void *state) { - auto s = (KNNBufferState<R> *) state; - delete s; - } -}; - + }; } |