summaryrefslogtreecommitdiffstats
path: root/include/query/knn.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/query/knn.h')
-rw-r--r--include/query/knn.h42
1 files changed, 22 insertions, 20 deletions
diff --git a/include/query/knn.h b/include/query/knn.h
index 87ea10a..91a032c 100644
--- a/include/query/knn.h
+++ b/include/query/knn.h
@@ -39,8 +39,8 @@ public:
Parameters global_parms;
};
- typedef Wrapped<R> LocalResultType;
- typedef R ResultType;
+ typedef std::vector<const Wrapped<R> *> LocalResultType;
+ typedef std::vector<R> ResultType;
constexpr static bool EARLY_ABORT = false;
constexpr static bool SKIP_DELETE_FILTER = true;
@@ -66,8 +66,8 @@ public:
return;
}
- static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) {
- std::vector<LocalResultType> results;
+ static LocalResultType local_query(S *shard, LocalQuery *query) {
+ LocalResultType results;
Wrapped<R> wrec;
wrec.rec = query->global_parms.point;
@@ -79,17 +79,16 @@ public:
shard->search(query->global_parms.point, query->global_parms.k, pq);
while (pq.size() > 0) {
- results.emplace_back(*pq.peek().data);
+ results.emplace_back(pq.peek().data);
pq.pop();
}
return results;
}
- static std::vector<LocalResultType>
- local_query_buffer(LocalQueryBuffer *query) {
+ static LocalResultType local_query_buffer(LocalQueryBuffer *query) {
- std::vector<LocalResultType> results;
+ LocalResultType results;
Wrapped<R> wrec;
wrec.rec = query->global_parms.point;
@@ -118,41 +117,44 @@ public:
}
while (pq.size() > 0) {
- results.emplace_back(*(pq.peek().data));
+ results.emplace_back(pq.peek().data);
pq.pop();
}
- return std::move(results);
+ return results;
}
- static void
- combine(std::vector<std::vector<LocalResultType>> const &local_results,
- Parameters *parms, std::vector<ResultType> &output) {
+ static void combine(std::vector<LocalResultType> const &local_results,
+ Parameters *parms, ResultType &output) {
+
+ Wrapped<R> wrec;
+ wrec.rec = parms->point;
+ wrec.header = 0;
+ PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(parms->k, &wrec);
- PriorityQueue<R, DistCmpMax<R>> pq(parms->k, &(parms->point));
for (size_t i = 0; i < local_results.size(); i++) {
for (size_t j = 0; j < local_results[i].size(); j++) {
if (pq.size() < parms->k) {
- pq.push(&local_results[i][j].rec);
+ pq.push(local_results[i][j]);
} else {
- double head_dist = pq.peek().data->calc_distance(parms->point);
- double cur_dist = local_results[i][j].rec.calc_distance(parms->point);
+ double head_dist = pq.peek().data->rec.calc_distance(wrec.rec);
+ double cur_dist = local_results[i][j]->rec.calc_distance(wrec.rec);
if (cur_dist < head_dist) {
pq.pop();
- pq.push(&local_results[i][j].rec);
+ pq.push(local_results[i][j]);
}
}
}
}
while (pq.size() > 0) {
- output.emplace_back(*pq.peek().data);
+ output.emplace_back(pq.peek().data->rec);
pq.pop();
}
}
- static bool repeat(Parameters *parms, std::vector<ResultType> &output,
+ static bool repeat(Parameters *parms, ResultType &output,
std::vector<LocalQuery *> const &local_queries,
LocalQueryBuffer *buffer_query) {
return false;