diff options
| author | Douglas Rumbaugh <dbr4@psu.edu> | 2023-07-24 11:50:33 -0400 |
|---|---|---|
| committer | Douglas Rumbaugh <dbr4@psu.edu> | 2023-07-24 11:50:33 -0400 |
| commit | ac018f5f96c32c96158a239fbfeb9dc439c95548 (patch) | |
| tree | 22cfdf8aa0fd7f1680c37f38ec359a4dc69bada2 | |
| parent | 6b434ec5f2182cb9624a011bd8d65587cd5a0759 (diff) | |
| download | dynamic-extension-ac018f5f96c32c96158a239fbfeb9dc439c95548.tar.gz | |
Cosine Similarity Type
| -rw-r--r-- | benchmarks/include/bench_utility.h | 2 | ||||
| -rw-r--r-- | include/framework/RecordInterface.h | 52 | ||||
| -rw-r--r-- | tests/testing.h | 2 |
3 files changed, 50 insertions, 6 deletions
diff --git a/benchmarks/include/bench_utility.h b/benchmarks/include/bench_utility.h index a1a2773..b728cbd 100644 --- a/benchmarks/include/bench_utility.h +++ b/benchmarks/include/bench_utility.h @@ -38,7 +38,7 @@ typedef uint64_t weight_type; typedef de::WeightedRecord<key_type, value_type, weight_type> WRec; typedef de::Record<key_type, value_type> Rec; -typedef de::Point<double, 300> Word2VecRec; +typedef de::CosinePoint<double, 300> Word2VecRec; typedef de::DynamicExtension<WRec, de::WSS<WRec>, de::WSSQuery<WRec>> ExtendedWSS; typedef de::DynamicExtension<Rec, de::TrieSpline<Rec>, de::TrieSplineRangeQuery<Rec>> ExtendedTSRQ; diff --git a/include/framework/RecordInterface.h b/include/framework/RecordInterface.h index 8d40590..85a0794 100644 --- a/include/framework/RecordInterface.h +++ b/include/framework/RecordInterface.h @@ -112,10 +112,10 @@ struct WeightedRecord { }; template <typename V, size_t D=2> -struct Point{ +struct CosinePoint{ V data[D]; - inline bool operator==(const Point& other) const { + inline bool operator==(const CosinePoint& other) const { for (size_t i=0; i<D; i++) { if (data[i] != other.data[i]) { return false; @@ -126,7 +126,7 @@ struct Point{ } // lexicographic order - inline bool operator<(const Point& other) const { + inline bool operator<(const CosinePoint& other) const { for (size_t i=0; i<D; i++) { if (data[i] < other.data[i]) { return true; @@ -138,7 +138,51 @@ struct Point{ return false; } - inline double calc_distance(const Point& other) const { + inline double calc_distance(const CosinePoint& other) const { + + double prod = 0; + double asquared = 0; + double bsquared = 0; + + for (size_t i=0; i<D; i++) { + prod += data[i] * other.data[i]; + asquared += data[i]*data[i]; + bsquared += other.data[i]*other.data[i]; + } + + return prod / std::sqrt(asquared * bsquared); + } +}; + + +template <typename V, size_t D=2> +struct EuclidPoint{ + V data[D]; + + inline bool operator==(const EuclidPoint& other) const { + for (size_t i=0; i<D; i++) { + if (data[i] != other.data[i]) { + return false; + } + } + + return true; + } + + // lexicographic order + inline bool operator<(const EuclidPoint& other) const { + for (size_t i=0; i<D; i++) { + if (data[i] < other.data[i]) { + return true; + } else if (data[i] > other.data[i]) { + return false; + } + } + + return false; + } + + inline double calc_distance(const EuclidPoint& other) const { double dist = 0; for (size_t i=0; i<D; i++) { dist += pow(data[i] - other.data[i], 2); diff --git a/tests/testing.h b/tests/testing.h index 1d5db59..4d49474 100644 --- a/tests/testing.h +++ b/tests/testing.h @@ -23,7 +23,7 @@ typedef de::WeightedRecord<uint64_t, uint32_t, uint64_t> WRec; typedef de::Record<uint64_t, uint32_t> Rec; -typedef de::Point<int64_t> PRec; +typedef de::EuclidPoint<int64_t> PRec; template <de::RecordInterface R> std::vector<R> strip_wrapping(std::vector<de::Wrapped<R>> vec) { |