diff options
Diffstat (limited to 'include/query/knn.h')
| -rw-r--r-- | include/query/knn.h | 42 |
1 files changed, 22 insertions, 20 deletions
diff --git a/include/query/knn.h b/include/query/knn.h index 87ea10a..91a032c 100644 --- a/include/query/knn.h +++ b/include/query/knn.h @@ -39,8 +39,8 @@ public: Parameters global_parms; }; - typedef Wrapped<R> LocalResultType; - typedef R ResultType; + typedef std::vector<const Wrapped<R> *> LocalResultType; + typedef std::vector<R> ResultType; constexpr static bool EARLY_ABORT = false; constexpr static bool SKIP_DELETE_FILTER = true; @@ -66,8 +66,8 @@ public: return; } - static std::vector<LocalResultType> local_query(S *shard, LocalQuery *query) { - std::vector<LocalResultType> results; + static LocalResultType local_query(S *shard, LocalQuery *query) { + LocalResultType results; Wrapped<R> wrec; wrec.rec = query->global_parms.point; @@ -79,17 +79,16 @@ public: shard->search(query->global_parms.point, query->global_parms.k, pq); while (pq.size() > 0) { - results.emplace_back(*pq.peek().data); + results.emplace_back(pq.peek().data); pq.pop(); } return results; } - static std::vector<LocalResultType> - local_query_buffer(LocalQueryBuffer *query) { + static LocalResultType local_query_buffer(LocalQueryBuffer *query) { - std::vector<LocalResultType> results; + LocalResultType results; Wrapped<R> wrec; wrec.rec = query->global_parms.point; @@ -118,41 +117,44 @@ public: } while (pq.size() > 0) { - results.emplace_back(*(pq.peek().data)); + results.emplace_back(pq.peek().data); pq.pop(); } - return std::move(results); + return results; } - static void - combine(std::vector<std::vector<LocalResultType>> const &local_results, - Parameters *parms, std::vector<ResultType> &output) { + static void combine(std::vector<LocalResultType> const &local_results, + Parameters *parms, ResultType &output) { + + Wrapped<R> wrec; + wrec.rec = parms->point; + wrec.header = 0; + PriorityQueue<Wrapped<R>, DistCmpMax<Wrapped<R>>> pq(parms->k, &wrec); - PriorityQueue<R, DistCmpMax<R>> pq(parms->k, &(parms->point)); for (size_t i = 0; i < local_results.size(); i++) { for (size_t j = 0; j < local_results[i].size(); j++) { if (pq.size() < parms->k) { - pq.push(&local_results[i][j].rec); + pq.push(local_results[i][j]); } else { - double head_dist = pq.peek().data->calc_distance(parms->point); - double cur_dist = local_results[i][j].rec.calc_distance(parms->point); + double head_dist = pq.peek().data->rec.calc_distance(wrec.rec); + double cur_dist = local_results[i][j]->rec.calc_distance(wrec.rec); if (cur_dist < head_dist) { pq.pop(); - pq.push(&local_results[i][j].rec); + pq.push(local_results[i][j]); } } } } while (pq.size() > 0) { - output.emplace_back(*pq.peek().data); + output.emplace_back(pq.peek().data->rec); pq.pop(); } } - static bool repeat(Parameters *parms, std::vector<ResultType> &output, + static bool repeat(Parameters *parms, ResultType &output, std::vector<LocalQuery *> const &local_queries, LocalQueryBuffer *buffer_query) { return false; |