summaryrefslogtreecommitdiffstats
path: root/include/query/knn.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/query/knn.h')
-rw-r--r--include/query/knn.h224
1 files changed, 112 insertions, 112 deletions
diff --git a/include/query/knn.h b/include/query/knn.h
index a227293..87ea10a 100644
--- a/include/query/knn.h
+++ b/include/query/knn.h
@@ -6,7 +6,7 @@
* Distributed under the Modified BSD License.
*
* A query class for k-NN queries, designed for use with the VPTree
- * shard.
+ * shard.
*
* FIXME: no support for tombstone deletes just yet. This would require a
* query resumption mechanism, most likely.
@@ -16,147 +16,147 @@
#include "framework/QueryRequirements.h"
#include "psu-ds/PriorityQueue.h"
-namespace de { namespace knn {
+namespace de {
+namespace knn {
using psudb::PriorityQueue;
-template <NDRecordInterface R>
-struct Parms {
+template <ShardInterface S> class Query {
+ typedef typename S::RECORD R;
+
+public:
+ struct Parameters {
R point;
size_t k;
-};
+ };
-template <NDRecordInterface R>
-struct State {
- size_t k;
-};
+ struct LocalQuery {
+ Parameters global_parms;
+ };
-template <NDRecordInterface R>
-struct BufferState {
+ struct LocalQueryBuffer {
BufferView<R> *buffer;
+ Parameters global_parms;
+ };
- BufferState(BufferView<R> *buffer)
- : buffer(buffer) {}
-};
+ typedef Wrapped<R> LocalResultType;
+ typedef R ResultType;
+ constexpr static bool EARLY_ABORT = false;
+ constexpr static bool SKIP_DELETE_FILTER = true;
-template <NDRecordInterface R, ShardInterface<R> S>
-class Query {
-public:
- constexpr static bool EARLY_ABORT=false;
- constexpr static bool SKIP_DELETE_FILTER=true;
+ static LocalQuery *local_preproc(S *shard, Parameters *parms) {
+ auto query = new LocalQuery();
+ query->global_parms = *parms;
- static void *get_query_state(S *shard, void *parms) {
- return nullptr;
- }
+ return query;
+ }
- static void* get_buffer_query_state(BufferView<R> *buffer, void *parms) {
- return new BufferState<R>(buffer);
- }
+ static LocalQueryBuffer *local_preproc_buffer(BufferView<R> *buffer,
+ Parameters *parms) {
+ auto query = new LocalQueryBuffer();
+ query->global_parms = *parms;
+ query->buffer = buffer;
- static void process_query_states(void *query_parms, std::vector<void*> &shard_states, void* buffer_state) {
- return;
- }
+ return query;
+ }
- static std::vector<Wrapped<R>> query(S *shard, void *q_state, void *parms) {
- std::vector<Wrapped<R>> results;
- Parms<R> *p = (Parms<R> *) parms;
- Wrapped<R> wrec;
- wrec.rec = p->point;
- wrec.header = 0;
+ static void distribute_query(Parameters *parms,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) {
+ return;
+ }
- PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(p->k, &wrec);
+ static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) {
+ std::vector<LocalResultType> results;
- shard->search(p->point, p->k, pq);
+ Wrapped<R> wrec;
+ wrec.rec = query->global_parms.point;
+ wrec.header = 0;
- while (pq.size() > 0) {
- results.emplace_back(*pq.peek().data);
- pq.pop();
- }
+ PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(query->global_parms.k,
+ &wrec);
- return results;
+ shard->search(query->global_parms.point, query->global_parms.k, pq);
+
+ while (pq.size() > 0) {
+ results.emplace_back(*pq.peek().data);
+ pq.pop();
}
- static std::vector<Wrapped<R>> buffer_query(void *state, void *parms) {
- Parms<R> *p = (Parms<R> *) parms;
- BufferState<R> *s = (BufferState<R> *) state;
- Wrapped<R> wrec;
- wrec.rec = p->point;
- wrec.header = 0;
-
- size_t k = p->k;
-
- PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(k, &wrec);
- for (size_t i=0; i<s->buffer->get_record_count(); i++) {
- // Skip over deleted records (under tagging)
- if (s->buffer->get(i)->is_deleted()) {
- continue;
- }
-
- if (pq.size() < k) {
- pq.push(s->buffer->get(i));
- } else {
- double head_dist = pq.peek().data->rec.calc_distance(wrec.rec);
- double cur_dist = (s->buffer->get(i))->rec.calc_distance(wrec.rec);
-
- if (cur_dist < head_dist) {
- pq.pop();
- pq.push(s->buffer->get(i));
- }
- }
- }
+ return results;
+ }
- std::vector<Wrapped<R>> results;
- while (pq.size() > 0) {
- results.emplace_back(*(pq.peek().data));
- pq.pop();
- }
+ static std::vector<LocalResultType>
+ local_query_buffer(LocalQueryBuffer *query) {
- return std::move(results);
- }
+ std::vector<LocalResultType> results;
- static std::vector<R> merge(std::vector<std::vector<Wrapped<R>>> &results, void *parms, std::vector<R> &output) {
- Parms<R> *p = (Parms<R> *) parms;
- R rec = p->point;
- size_t k = p->k;
-
- PriorityQueue<R, DistCmpMax<R>> pq(k, &rec);
- for (size_t i=0; i<results.size(); i++) {
- for (size_t j=0; j<results[i].size(); j++) {
- if (pq.size() < k) {
- pq.push(&results[i][j].rec);
- } else {
- double head_dist = pq.peek().data->calc_distance(rec);
- double cur_dist = results[i][j].rec.calc_distance(rec);
-
- if (cur_dist < head_dist) {
- pq.pop();
- pq.push(&results[i][j].rec);
- }
- }
- }
- }
+ Wrapped<R> wrec;
+ wrec.rec = query->global_parms.point;
+ wrec.header = 0;
- while (pq.size() > 0) {
- output.emplace_back(*pq.peek().data);
- pq.pop();
- }
+ PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(query->global_parms.k,
+ &wrec);
+
+ for (size_t i = 0; i < query->buffer->get_record_count(); i++) {
+ // Skip over deleted records (under tagging)
+ if (query->buffer->get(i)->is_deleted()) {
+ continue;
+ }
- return std::move(output);
+ if (pq.size() < query->global_parms.k) {
+ pq.push(query->buffer->get(i));
+ } else {
+ double head_dist = pq.peek().data->rec.calc_distance(wrec.rec);
+ double cur_dist = (query->buffer->get(i))->rec.calc_distance(wrec.rec);
+
+ if (cur_dist < head_dist) {
+ pq.pop();
+ pq.push(query->buffer->get(i));
+ }
+ }
}
- static void delete_query_state(void *state) {
- auto s = (State<R> *) state;
- delete s;
+ while (pq.size() > 0) {
+ results.emplace_back(*(pq.peek().data));
+ pq.pop();
}
- static void delete_buffer_query_state(void *state) {
- auto s = (BufferState<R> *) state;
- delete s;
+ return std::move(results);
+ }
+
+ static void
+ combine(std::vector<std::vector<LocalResultType>> const &local_results,
+ Parameters *parms, std::vector<ResultType> &output) {
+
+ PriorityQueue<R, DistCmpMax<R>> pq(parms->k, &(parms->point));
+ for (size_t i = 0; i < local_results.size(); i++) {
+ for (size_t j = 0; j < local_results[i].size(); j++) {
+ if (pq.size() < parms->k) {
+ pq.push(&local_results[i][j].rec);
+ } else {
+ double head_dist = pq.peek().data->calc_distance(parms->point);
+ double cur_dist = local_results[i][j].rec.calc_distance(parms->point);
+
+ if (cur_dist < head_dist) {
+ pq.pop();
+ pq.push(&local_results[i][j].rec);
+ }
+ }
+ }
}
- static bool repeat(void *parms, std::vector<R> &results, std::vector<void*> states, void* buffer_state) {
- return false;
+ while (pq.size() > 0) {
+ output.emplace_back(*pq.peek().data);
+ pq.pop();
}
-};
+ }
-}}
+ static bool repeat(Parameters *parms, std::vector<ResultType> &output,
+ std::vector<LocalQuery *> const &local_queries,
+ LocalQueryBuffer *buffer_query) {
+ return false;
+ }
+};
+} // namespace knn
+} // namespace de