diff options
Diffstat (limited to 'include/query/knn.h')
| -rw-r--r-- | include/query/knn.h | 224 |
1 files changed, 112 insertions, 112 deletions
diff --git a/include/query/knn.h b/include/query/knn.h index a227293..87ea10a 100644 --- a/include/query/knn.h +++ b/include/query/knn.h @@ -6,7 +6,7 @@ * Distributed under the Modified BSD License. * * A query class for k-NN queries, designed for use with the VPTree - * shard. + * shard. * * FIXME: no support for tombstone deletes just yet. This would require a * query resumption mechanism, most likely. @@ -16,147 +16,147 @@ #include "framework/QueryRequirements.h" #include "psu-ds/PriorityQueue.h" -namespace de { namespace knn { +namespace de { +namespace knn { using psudb::PriorityQueue; -template <NDRecordInterface R> -struct Parms { +template <ShardInterface S> class Query { + typedef typename S::RECORD R; + +public: + struct Parameters { R point; size_t k; -}; + }; -template <NDRecordInterface R> -struct State { - size_t k; -}; + struct LocalQuery { + Parameters global_parms; + }; -template <NDRecordInterface R> -struct BufferState { + struct LocalQueryBuffer { BufferView<R> *buffer; + Parameters global_parms; + }; - BufferState(BufferView<R> *buffer) - : buffer(buffer) {} -}; + typedef Wrapped<R> LocalResultType; + typedef R ResultType; + constexpr static bool EARLY_ABORT = false; + constexpr static bool SKIP_DELETE_FILTER = true; -template <NDRecordInterface R, ShardInterface<R> S> -class Query { -public: - constexpr static bool EARLY_ABORT=false; - constexpr static bool SKIP_DELETE_FILTER=true; + static LocalQuery *local_preproc(S *shard, Parameters *parms) { + auto query = new LocalQuery(); + query->global_parms = *parms; - static void *get_query_state(S *shard, void *parms) { - return nullptr; - } + return query; + } - static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) { - return new BufferState<R>(buffer); - } + static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer, + Parameters *parms) { + auto query = new LocalQueryBuffer(); + query->global_parms = *parms; + query->buffer = buffer; - static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void* buffer_state) { - return; - } + return query; + } - static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) { - std::vector<Wrapped<R>> results; - Parms<R> *p = (Parms<R> *) parms; - Wrapped<R> wrec; - wrec.rec = p->point; - wrec.header = 0; + static void distribute_query(Parameters *parms, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { + return; + } - PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(p->k, &wrec); + static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) { + std::vector<LocalResultType> results; - shard->search(p->point, p->k, pq); + Wrapped<R> wrec; + wrec.rec = query->global_parms.point; + wrec.header = 0; - while (pq.size() > 0) { - results.emplace_back(*pq.peek().data); - pq.pop(); - } + PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(query->global_parms.k, + &wrec); - return results; + shard->search(query->global_parms.point, query->global_parms.k, pq); + + while (pq.size() > 0) { + results.emplace_back(*pq.peek().data); + pq.pop(); } - static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) { - Parms<R> *p = (Parms<R> *) parms; - BufferState<R> *s = (BufferState<R> *) state; - Wrapped<R> wrec; - wrec.rec = p->point; - wrec.header = 0; - - size_t k = p->k; - - PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(k, &wrec); - for (size_t i=0; i<s->buffer->get_record_count(); i++) { - // Skip over deleted records (under tagging) - if (s->buffer->get(i)->is_deleted()) { - continue; - } - - if (pq.size() < k) { - pq.push(s->buffer->get(i)); - } else { - double head_dist = pq.peek().data->rec.calc_distance(wrec.rec); - double cur_dist = (s->buffer->get(i))->rec.calc_distance(wrec.rec); - - if (cur_dist < head_dist) { - pq.pop(); - pq.push(s->buffer->get(i)); - } - } - } + return results; + } - std::vector<Wrapped<R>> results; - while (pq.size() > 0) { - results.emplace_back(*(pq.peek().data)); - pq.pop(); - } + static std::vector<LocalResultType> + local_query_buffer(LocalQueryBuffer *query) { - return std::move(results); - } + std::vector<LocalResultType> results; - static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) { - Parms<R> *p = (Parms<R> *) parms; - R rec = p->point; - size_t k = p->k; - - PriorityQueue<R, DistCmpMax<R>> pq(k, &rec); - for (size_t i=0; i<results.size(); i++) { - for (size_t j=0; j<results[i].size(); j++) { - if (pq.size() < k) { - pq.push(&results[i][j].rec); - } else { - double head_dist = pq.peek().data->calc_distance(rec); - double cur_dist = results[i][j].rec.calc_distance(rec); - - if (cur_dist < head_dist) { - pq.pop(); - pq.push(&results[i][j].rec); - } - } - } - } + Wrapped<R> wrec; + wrec.rec = query->global_parms.point; + wrec.header = 0; - while (pq.size() > 0) { - output.emplace_back(*pq.peek().data); - pq.pop(); - } + PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(query->global_parms.k, + &wrec); + + for (size_t i = 0; i < query->buffer->get_record_count(); i++) { + // Skip over deleted records (under tagging) + if (query->buffer->get(i)->is_deleted()) { + continue; + } - return std::move(output); + if (pq.size() < query->global_parms.k) { + pq.push(query->buffer->get(i)); + } else { + double head_dist = pq.peek().data->rec.calc_distance(wrec.rec); + double cur_dist = (query->buffer->get(i))->rec.calc_distance(wrec.rec); + + if (cur_dist < head_dist) { + pq.pop(); + pq.push(query->buffer->get(i)); + } + } } - static void delete_query_state(void *state) { - auto s = (State<R> *) state; - delete s; + while (pq.size() > 0) { + results.emplace_back(*(pq.peek().data)); + pq.pop(); } - static void delete_buffer_query_state(void *state) { - auto s = (BufferState<R> *) state; - delete s; + return std::move(results); + } + + static void + combine(std::vector<std::vector<LocalResultType>> const &local_results, + Parameters *parms, std::vector<ResultType> &output) { + + PriorityQueue<R, DistCmpMax<R>> pq(parms->k, &(parms->point)); + for (size_t i = 0; i < local_results.size(); i++) { + for (size_t j = 0; j < local_results[i].size(); j++) { + if (pq.size() < parms->k) { + pq.push(&local_results[i][j].rec); + } else { + double head_dist = pq.peek().data->calc_distance(parms->point); + double cur_dist = local_results[i][j].rec.calc_distance(parms->point); + + if (cur_dist < head_dist) { + pq.pop(); + pq.push(&local_results[i][j].rec); + } + } + } } - static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) { - return false; + while (pq.size() > 0) { + output.emplace_back(*pq.peek().data); + pq.pop(); } -}; + } -}} + static bool repeat(Parameters *parms, std::vector<ResultType> &output, + std::vector<LocalQuery *> const &local_queries, + LocalQueryBuffer *buffer_query) { + return false; + } +}; +} // namespace knn +} // namespace de |