summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Douglas Rumbaugh <dbr4@psu.edu>, 2023-07-24 11:50:33 -0400
committer: Douglas Rumbaugh <dbr4@psu.edu>, 2023-07-24 11:50:33 -0400
commit: ac018f5f96c32c96158a239fbfeb9dc439c95548 (patch)
tree: 22cfdf8aa0fd7f1680c37f38ec359a4dc69bada2
parent: 6b434ec5f2182cb9624a011bd8d65587cd5a0759 (diff)
download: dynamic-extension-ac018f5f96c32c96158a239fbfeb9dc439c95548.tar.gz
Cosine Similarity Type
-rw-r--r-- benchmarks/include/bench_utility.h | 2
-rw-r--r-- include/framework/RecordInterface.h | 52
-rw-r--r-- tests/testing.h | 2
3 files changed, 50 insertions, 6 deletions
diff --git a/benchmarks/include/bench_utility.h b/benchmarks/include/bench_utility.h
index a1a2773..b728cbd 100644
--- a/benchmarks/include/bench_utility.h
+++ b/benchmarks/include/bench_utility.h
@@ -38,7 +38,7 @@ typedef uint64_t weight_type;
typedef de::WeightedRecord<key_type, value_type, weight_type> WRec;
typedef de::Record<key_type, value_type> Rec;
-typedef de::Point<double, 300> Word2VecRec;
+typedef de::CosinePoint<double, 300> Word2VecRec;
typedef de::DynamicExtension<WRec, de::WSS<WRec>, de::WSSQuery<WRec>> ExtendedWSS;
typedef de::DynamicExtension<Rec, de::TrieSpline<Rec>, de::TrieSplineRangeQuery<Rec>> ExtendedTSRQ;
diff --git a/include/framework/RecordInterface.h b/include/framework/RecordInterface.h
index 8d40590..85a0794 100644
--- a/include/framework/RecordInterface.h
+++ b/include/framework/RecordInterface.h
@@ -112,10 +112,10 @@ struct WeightedRecord {
};
template <typename V, size_t D=2>
-struct Point{
+struct CosinePoint{
V data[D];
- inline bool operator==(const Point& other) const {
+ inline bool operator==(const CosinePoint& other) const {
for (size_t i=0; i<D; i++) {
if (data[i] != other.data[i]) {
return false;
@@ -126,7 +126,7 @@ struct Point{
}
// lexicographic order
- inline bool operator<(const Point& other) const {
+ inline bool operator<(const CosinePoint& other) const {
for (size_t i=0; i<D; i++) {
if (data[i] < other.data[i]) {
return true;
@@ -138,7 +138,51 @@ struct Point{
return false;
}
- inline double calc_distance(const Point& other) const {
+ inline double calc_distance(const CosinePoint& other) const {
+
+ double prod = 0;
+ double asquared = 0;
+ double bsquared = 0;
+
+ for (size_t i=0; i<D; i++) {
+ prod += data[i] * other.data[i];
+ asquared += data[i]*data[i];
+ bsquared += other.data[i]*other.data[i];
+ }
+
+ return prod / std::sqrt(asquared * bsquared);
+ }
+};
+
+
+template <typename V, size_t D=2>
+struct EuclidPoint{
+ V data[D];
+
+ inline bool operator==(const EuclidPoint& other) const {
+ for (size_t i=0; i<D; i++) {
+ if (data[i] != other.data[i]) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ // lexicographic order
+ inline bool operator<(const EuclidPoint& other) const {
+ for (size_t i=0; i<D; i++) {
+ if (data[i] < other.data[i]) {
+ return true;
+ } else if (data[i] > other.data[i]) {
+ return false;
+ }
+ }
+
+ return false;
+ }
+
+ inline double calc_distance(const EuclidPoint& other) const {
double dist = 0;
for (size_t i=0; i<D; i++) {
dist += pow(data[i] - other.data[i], 2);
diff --git a/tests/testing.h b/tests/testing.h
index 1d5db59..4d49474 100644
--- a/tests/testing.h
+++ b/tests/testing.h
@@ -23,7 +23,7 @@
typedef de::WeightedRecord<uint64_t, uint32_t, uint64_t> WRec;
typedef de::Record<uint64_t, uint32_t> Rec;
-typedef de::Point<int64_t> PRec;
+typedef de::EuclidPoint<int64_t> PRec;
template <de::RecordInterface R>
std::vector<R> strip_wrapping(std::vector<de::Wrapped<R>> vec) {