From 711769574e647839677739192698e400529efe75 Mon Sep 17 00:00:00 2001 From: Douglas Rumbaugh Date: Thu, 8 Feb 2024 16:38:44 -0500 Subject: Updated VPTree to new shard/query interfaces --- include/query/knn.h | 159 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 include/query/knn.h (limited to 'include/query/knn.h') diff --git a/include/query/knn.h b/include/query/knn.h new file mode 100644 index 0000000..19dcf5c --- /dev/null +++ b/include/query/knn.h @@ -0,0 +1,159 @@ +/* + * include/query/knn.h + * + * Copyright (C) 2023 Douglas B. Rumbaugh + * + * Distributed under the Modified BSD License. + * + * A query class for k-NN queries, designed for use with the VPTree + * shard. + * + * FIXME: no support for tombstone deletes just yet. This would require a + * query resumption mechanism, most likely. + */ +#pragma once + +#include "framework/QueryRequirements.h" +#include "psu-ds/PriorityQueue.h" + +namespace de { namespace knn { + +using psudb::PriorityQueue; + +template +struct Parms { + R point; + size_t k; +}; + +template +struct State { + size_t k; +}; + +template +struct BufferState { + BufferView *buffer; + + BufferState(BufferView *buffer) + : buffer(buffer) {} +}; + +template S> +class Query { +public: + constexpr static bool EARLY_ABORT=false; + constexpr static bool SKIP_DELETE_FILTER=true; + + static void *get_query_state(S *shard, void *parms) { + return nullptr; + } + + static void* get_buffer_query_state(BufferView *buffer, void *parms) { + return new BufferState(buffer); + } + + static void process_query_states(void *query_parms, std::vector &shard_states, void* buffer_state) { + return; + } + + static std::vector> query(S *shard, void *q_state, void *parms) { + std::vector> results; + Parms *p = (Parms *) parms; + Wrapped wrec; + wrec.rec = p->point; + wrec.header = 0; + + PriorityQueue, DistCmpMax>> pq(p->k, &wrec); + + shard->search(p->point, p->k, pq); + + while (pq.size() > 0) { + results.emplace_back(*pq.peek().data); + pq.pop(); + } + + return results; + } + + static std::vector> buffer_query(void *state, void *parms) { + Parms *p = (Parms *) parms; + BufferState *s = (BufferState *) state; + Wrapped wrec; + wrec.rec = p->point; + wrec.header = 0; + + size_t k = p->k; + + PriorityQueue, DistCmpMax>> pq(k, &wrec); + for (size_t i=0; ibuffer->get_record_count(); i++) { + // Skip over deleted records (under tagging) + if (s->buffer->get(i)->is_deleted()) { + continue; + } + + if (pq.size() < k) { + pq.push(s->buffer->get(i)); + } else { + double head_dist = pq.peek().data->rec.calc_distance(wrec.rec); + double cur_dist = (s->buffer->get(i))->rec.calc_distance(wrec.rec); + + if (cur_dist < head_dist) { + pq.pop(); + pq.push(s->buffer->get(i)); + } + } + } + + std::vector> results; + while (pq.size() > 0) { + results.emplace_back(*(pq.peek().data)); + pq.pop(); + } + + return results; + } + + static std::vector merge(std::vector>> &results, void *parms) { + Parms *p = (Parms *) parms; + R rec = p->point; + size_t k = p->k; + + PriorityQueue> pq(k, &rec); + for (size_t i=0; icalc_distance(rec); + double cur_dist = results[i][j].rec.calc_distance(rec); + + if (cur_dist < head_dist) { + pq.pop(); + pq.push(&results[i][j].rec); + } + } + } + } + + std::vector output; + while (pq.size() > 0) { + output.emplace_back(*pq.peek().data); + pq.pop(); + } + + return output; + } + + static void delete_query_state(void *state) { + auto s = (State *) state; + delete s; + } + + static void delete_buffer_query_state(void *state) { + auto s = (BufferState *) state; + delete s; + } +}; + +}} -- cgit v1.2.3