/* * include/query/knn.h * * Copyright (C) 2023 Douglas B. Rumbaugh * * Distributed under the Modified BSD License. * * A query class for k-NN queries, designed for use with the VPTree * shard. * * FIXME: no support for tombstone deletes just yet. This would require a * query resumption mechanism, most likely. */ #pragma once #include "framework/QueryRequirements.h" #include "psu-ds/PriorityQueue.h" namespace de { namespace knn { using psudb::PriorityQueue; template struct Parms { R point; size_t k; }; template struct State { size_t k; }; template struct BufferState { BufferView *buffer; BufferState(BufferView *buffer) : buffer(buffer) {} }; template S> class Query { public: constexpr static bool EARLY_ABORT=false; constexpr static bool SKIP_DELETE_FILTER=true; static void *get_query_state(S *shard, void *parms) { return nullptr; } static void* get_buffer_query_state(BufferView *buffer, void *parms) { return new BufferState(buffer); } static void process_query_states(void *query_parms, std::vector &shard_states, void* buffer_state) { return; } static std::vector> query(S *shard, void *q_state, void *parms) { std::vector> results; Parms *p = (Parms *) parms; Wrapped wrec; wrec.rec = p->point; wrec.header = 0; PriorityQueue, DistCmpMax>> pq(p->k, &wrec); shard->search(p->point, p->k, pq); while (pq.size() > 0) { results.emplace_back(*pq.peek().data); pq.pop(); } return results; } static std::vector> buffer_query(void *state, void *parms) { Parms *p = (Parms *) parms; BufferState *s = (BufferState *) state; Wrapped wrec; wrec.rec = p->point; wrec.header = 0; size_t k = p->k; PriorityQueue, DistCmpMax>> pq(k, &wrec); for (size_t i=0; ibuffer->get_record_count(); i++) { // Skip over deleted records (under tagging) if (s->buffer->get(i)->is_deleted()) { continue; } if (pq.size() < k) { pq.push(s->buffer->get(i)); } else { double head_dist = pq.peek().data->rec.calc_distance(wrec.rec); double cur_dist = (s->buffer->get(i))->rec.calc_distance(wrec.rec); if (cur_dist < head_dist) { pq.pop(); pq.push(s->buffer->get(i)); } } } std::vector> results; while (pq.size() > 0) { results.emplace_back(*(pq.peek().data)); pq.pop(); } return std::move(results); } static std::vector merge(std::vector>> &results, void *parms, std::vector &output) { Parms *p = (Parms *) parms; R rec = p->point; size_t k = p->k; PriorityQueue> pq(k, &rec); for (size_t i=0; icalc_distance(rec); double cur_dist = results[i][j].rec.calc_distance(rec); if (cur_dist < head_dist) { pq.pop(); pq.push(&results[i][j].rec); } } } } while (pq.size() > 0) { output.emplace_back(*pq.peek().data); pq.pop(); } return std::move(output); } static void delete_query_state(void *state) { auto s = (State *) state; delete s; } static void delete_buffer_query_state(void *state) { auto s = (BufferState *) state; delete s; } static bool repeat(void *parms, std::vector &results, std::vector states, void* buffer_state) { return false; } }; }}