diff --git a/Makefile b/Makefile index 6825e3b..1b40721 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,7 @@ test-python: instant-distance-py/test/instant_distance.so bench-python: instant-distance-py/test/instant_distance.so PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config()' 'instant_distance.Hnsw.build(points, config)' + PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config(); config.distance_metric = instant_distance.DistanceMetric.Cosine' 'instant_distance.Hnsw.build(points, config)' clean: cargo clean diff --git a/distance-metrics/benches/all.rs b/distance-metrics/benches/all.rs index 2c7a305..b3405f7 100644 --- a/distance-metrics/benches/all.rs +++ b/distance-metrics/benches/all.rs @@ -1,10 +1,18 @@ use bencher::{benchmark_group, benchmark_main, Bencher}; -use distance_metrics::{EuclidMetric, Metric}; +use distance_metrics::{ + Metric, {CosineMetric, EuclidMetric}, +}; use rand::{rngs::StdRng, Rng, SeedableRng}; benchmark_main!(benches); -benchmark_group!(benches, legacy, non_simd, metric::,); +benchmark_group!( + benches, + legacy, + non_simd, + metric::, + metric:: +); fn legacy(bench: &mut Bencher) { let mut rng = StdRng::seed_from_u64(SEED); @@ -24,8 +32,10 @@ fn non_simd(bench: &mut Bencher) { fn metric(bench: &mut Bencher) { let mut rng = StdRng::seed_from_u64(SEED); - let point_a = [rng.gen(); 300]; - let point_b = [rng.gen(); 300]; + let mut point_a = [rng.gen(); 300]; + let mut point_b = [rng.gen(); 300]; + M::preprocess(&mut point_a); + M::preprocess(&mut point_b); bench.iter(|| M::distance(&point_a, &point_b)) } diff --git a/distance-metrics/src/lib.rs b/distance-metrics/src/lib.rs index 520a088..e197918 100644 --- a/distance-metrics/src/lib.rs +++ b/distance-metrics/src/lib.rs @@ -11,6 +11,9 @@ pub mod simd_neon; pub trait Metric { /// Greater the value - more distant the vectors fn distance(v1: &[f32], v2: &[f32]) -> f32; + + /// Necessary vector transformations performed before adding it to the collection (like normalization) + fn preprocess(vector: &mut [f32]); } #[cfg(target_arch = "x86_64")] @@ -54,6 +57,72 @@ impl Metric for EuclidMetric { euclid_distance(v1, v2) } + + fn preprocess(_vector: &mut [f32]) { + // no-op + } +} + +#[derive(Clone, Copy)] +pub struct CosineMetric {} + +impl Metric for CosineMetric { + fn distance(v1: &[f32], v2: &[f32]) -> f32 { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx") + && is_x86_feature_detected!("fma") + && v1.len() >= MIN_DIM_SIZE_AVX + { + return 1.0 - unsafe { simd_avx::dot_similarity_avx(v1, v2) }; + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD { + return 1.0 - unsafe { simd_sse::dot_similarity_sse(v1, v2) }; + } + } + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD { + return 1.0 - unsafe { simd_neon::dot_similarity_neon(v1, v2) }; + } + } + + 1.0 - dot_similarity(v1, v2) + } + + fn preprocess(vector: &mut [f32]) { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx") + && is_x86_feature_detected!("fma") + && vector.len() >= MIN_DIM_SIZE_AVX + { + return unsafe { simd_avx::cosine_preprocess_avx(vector) }; + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("sse") && vector.len() >= MIN_DIM_SIZE_SIMD { + return unsafe { simd_sse::cosine_preprocess_sse(vector) }; + } + } + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + if std::arch::is_aarch64_feature_detected!("neon") && vector.len() >= MIN_DIM_SIZE_SIMD + { + return unsafe { simd_neon::cosine_preprocess_neon(vector) }; + } + } + + cosine_preprocess(vector); + } } pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 { @@ -66,6 +135,21 @@ pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 { s.abs().sqrt() } +pub fn cosine_preprocess(vector: &mut [f32]) { + let mut length: f32 = vector.iter().map(|x| x * x).sum(); + if length < f32::EPSILON { + return; + } + length = length.sqrt(); + for x in vector.iter_mut() { + *x /= length; + } +} + +pub fn dot_similarity(v1: &[f32], v2: &[f32]) -> f32 { + v1.iter().zip(v2).map(|(a, b)| a * b).sum() +} + pub fn legacy_distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 { #[cfg(target_arch = "x86_64")] { diff --git a/distance-metrics/src/simd_avx.rs b/distance-metrics/src/simd_avx.rs index bd120a8..ea8a566 100644 --- a/distance-metrics/src/simd_avx.rs +++ b/distance-metrics/src/simd_avx.rs @@ -52,3 +52,135 @@ pub(crate) unsafe fn euclid_distance_avx(v1: &[f32], v2: &[f32]) -> f32 { } result.abs().sqrt() } + +#[target_feature(enable = "avx")] +#[target_feature(enable = "fma")] +pub(crate) unsafe fn cosine_preprocess_avx(vector: &mut [f32]) { + let n = vector.len(); + let m = n - (n % 32); + let mut ptr: *const f32 = vector.as_ptr(); + let mut sum256_1: __m256 = _mm256_setzero_ps(); + let mut sum256_2: __m256 = _mm256_setzero_ps(); + let mut sum256_3: __m256 = _mm256_setzero_ps(); + let mut sum256_4: __m256 = _mm256_setzero_ps(); + let mut i: usize = 0; + while i < m { + let m256_1 = _mm256_loadu_ps(ptr); + sum256_1 = _mm256_fmadd_ps(m256_1, m256_1, sum256_1); + + let m256_2 = _mm256_loadu_ps(ptr.add(8)); + sum256_2 = _mm256_fmadd_ps(m256_2, m256_2, sum256_2); + + let m256_3 = _mm256_loadu_ps(ptr.add(16)); + sum256_3 = _mm256_fmadd_ps(m256_3, m256_3, sum256_3); + + let m256_4 = _mm256_loadu_ps(ptr.add(24)); + sum256_4 = _mm256_fmadd_ps(m256_4, m256_4, sum256_4); + + ptr = ptr.add(32); + i += 32; + } + + let mut length = hsum256_ps_avx(sum256_1) + + hsum256_ps_avx(sum256_2) + + hsum256_ps_avx(sum256_3) + + hsum256_ps_avx(sum256_4); + for i in 0..n - m { + length += (*ptr.add(i)).powi(2); + } + if length < f32::EPSILON { + return; + } + length = length.sqrt(); + for x in vector.iter_mut() { + *x /= length; + } +} + +#[target_feature(enable = "avx")] +#[target_feature(enable = "fma")] +pub(crate) unsafe fn dot_similarity_avx(v1: &[f32], v2: &[f32]) -> f32 { + let n = v1.len(); + let m = n - (n % 32); + let mut ptr1: *const f32 = v1.as_ptr(); + let mut ptr2: *const f32 = v2.as_ptr(); + let mut sum256_1: __m256 = _mm256_setzero_ps(); + let mut sum256_2: __m256 = _mm256_setzero_ps(); + let mut sum256_3: __m256 = _mm256_setzero_ps(); + let mut sum256_4: __m256 = _mm256_setzero_ps(); + let mut i: usize = 0; + while i < m { + sum256_1 = _mm256_fmadd_ps(_mm256_loadu_ps(ptr1), _mm256_loadu_ps(ptr2), sum256_1); + sum256_2 = _mm256_fmadd_ps( + _mm256_loadu_ps(ptr1.add(8)), + _mm256_loadu_ps(ptr2.add(8)), + sum256_2, + ); + sum256_3 = _mm256_fmadd_ps( + _mm256_loadu_ps(ptr1.add(16)), + _mm256_loadu_ps(ptr2.add(16)), + sum256_3, + ); + sum256_4 = _mm256_fmadd_ps( + _mm256_loadu_ps(ptr1.add(24)), + _mm256_loadu_ps(ptr2.add(24)), + sum256_4, + ); + + ptr1 = ptr1.add(32); + ptr2 = ptr2.add(32); + i += 32; + } + + let mut result = hsum256_ps_avx(sum256_1) + + hsum256_ps_avx(sum256_2) + + hsum256_ps_avx(sum256_3) + + hsum256_ps_avx(sum256_4); + + for i in 0..n - m { + result += (*ptr1.add(i)) * (*ptr2.add(i)); + } + result +} + +#[cfg(test)] +mod tests { + #[test] + fn test_spaces_avx() { + use super::*; + use crate::*; + + if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") { + let v1: Vec = vec![ + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 26., 27., 28., 29., 30., 31., + ]; + let v2: Vec = vec![ + 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 56., 57., 58., 59., 60., 61., + ]; + + let euclid_simd = unsafe { euclid_distance_avx(&v1, &v2) }; + let euclid = euclid_distance(&v1, &v2); + assert_eq!(euclid_simd, euclid); + + let dot_simd = unsafe { dot_similarity_avx(&v1, &v2) }; + let dot = dot_similarity(&v1, &v2); + assert_eq!(dot_simd, dot); + + let mut v1 = v1; + let mut v1_copy = v1.clone(); + unsafe { cosine_preprocess_avx(&mut v1) }; + cosine_preprocess(&mut v1_copy); + assert_eq!(v1, v1_copy); + } else { + println!("avx test skipped"); + } + } +} diff --git a/distance-metrics/src/simd_neon.rs b/distance-metrics/src/simd_neon.rs index 7211d49..7141d85 100644 --- a/distance-metrics/src/simd_neon.rs +++ b/distance-metrics/src/simd_neon.rs @@ -36,3 +36,108 @@ pub(crate) unsafe fn euclid_distance_neon(v1: &[f32], v2: &[f32]) -> f32 { } result.abs().sqrt() } + +#[cfg(target_feature = "neon")] +pub(crate) unsafe fn cosine_preprocess_neon(vector: &mut [f32]) { + let n = vector.len(); + let m = n - (n % 16); + let mut ptr: *const f32 = vector.as_ptr(); + let mut sum1 = vdupq_n_f32(0.); + let mut sum2 = vdupq_n_f32(0.); + let mut sum3 = vdupq_n_f32(0.); + let mut sum4 = vdupq_n_f32(0.); + + let mut i: usize = 0; + while i < m { + let d1 = vld1q_f32(ptr); + sum1 = vfmaq_f32(sum1, d1, d1); + + let d2 = vld1q_f32(ptr.add(4)); + sum2 = vfmaq_f32(sum2, d2, d2); + + let d3 = vld1q_f32(ptr.add(8)); + sum3 = vfmaq_f32(sum3, d3, d3); + + let d4 = vld1q_f32(ptr.add(12)); + sum4 = vfmaq_f32(sum4, d4, d4); + + ptr = ptr.add(16); + i += 16; + } + let mut length = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4); + for v in vector.iter().take(n).skip(m) { + length += v.powi(2); + } + if length < f32::EPSILON { + return; + } + let length = length.sqrt(); + for x in vector.iter_mut() { + *x /= length; + } +} + +#[cfg(target_feature = "neon")] +pub(crate) unsafe fn dot_similarity_neon(v1: &[f32], v2: &[f32]) -> f32 { + let n = v1.len(); + let m = n - (n % 16); + let mut ptr1: *const f32 = v1.as_ptr(); + let mut ptr2: *const f32 = v2.as_ptr(); + let mut sum1 = vdupq_n_f32(0.); + let mut sum2 = vdupq_n_f32(0.); + let mut sum3 = vdupq_n_f32(0.); + let mut sum4 = vdupq_n_f32(0.); + + let mut i: usize = 0; + while i < m { + sum1 = vfmaq_f32(sum1, vld1q_f32(ptr1), vld1q_f32(ptr2)); + sum2 = vfmaq_f32(sum2, vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4))); + sum3 = vfmaq_f32(sum3, vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8))); + sum4 = vfmaq_f32(sum4, vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12))); + ptr1 = ptr1.add(16); + ptr2 = ptr2.add(16); + i += 16; + } + let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4); + for i in 0..n - m { + result += (*ptr1.add(i)) * (*ptr2.add(i)); + } + result +} + +#[cfg(test)] +mod tests { + #[cfg(target_feature = "neon")] + #[test] + fn test_spaces_neon() { + use super::*; + use crate::*; + + if std::arch::is_aarch64_feature_detected!("neon") { + let v1: Vec = vec![ + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 26., 27., 28., 29., 30., 31., + ]; + let v2: Vec = vec![ + 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., + 56., 57., 58., 59., 60., 61., + ]; + + let euclid_simd = unsafe { euclid_distance_neon(&v1, &v2) }; + let euclid = euclid_distance(&v1, &v2); + assert_eq!(euclid_simd, euclid); + + let dot_simd = unsafe { dot_similarity_neon(&v1, &v2) }; + let dot = dot_similarity(&v1, &v2); + assert_eq!(dot_simd, dot); + + let mut v1 = v1; + let mut v1_copy = v1.clone(); + unsafe { cosine_preprocess_neon(&mut v1) }; + cosine_preprocess(&mut v1_copy); + assert_eq!(v1, v1_copy); + } else { + println!("neon test skipped"); + } + } +} diff --git a/distance-metrics/src/simd_sse.rs b/distance-metrics/src/simd_sse.rs index eebc249..596afba 100644 --- a/distance-metrics/src/simd_sse.rs +++ b/distance-metrics/src/simd_sse.rs @@ -48,3 +48,132 @@ pub(crate) unsafe fn euclid_distance_sse(v1: &[f32], v2: &[f32]) -> f32 { } result.abs().sqrt() } + +#[target_feature(enable = "sse")] +pub(crate) unsafe fn cosine_preprocess_sse(vector: &mut [f32]) { + let n = vector.len(); + let m = n - (n % 16); + let mut ptr: *const f32 = vector.as_ptr(); + let mut sum128_1: __m128 = _mm_setzero_ps(); + let mut sum128_2: __m128 = _mm_setzero_ps(); + let mut sum128_3: __m128 = _mm_setzero_ps(); + let mut sum128_4: __m128 = _mm_setzero_ps(); + + let mut i: usize = 0; + while i < m { + let m128_1 = _mm_loadu_ps(ptr); + sum128_1 = _mm_add_ps(_mm_mul_ps(m128_1, m128_1), sum128_1); + + let m128_2 = _mm_loadu_ps(ptr.add(4)); + sum128_2 = _mm_add_ps(_mm_mul_ps(m128_2, m128_2), sum128_2); + + let m128_3 = _mm_loadu_ps(ptr.add(8)); + sum128_3 = _mm_add_ps(_mm_mul_ps(m128_3, m128_3), sum128_3); + + let m128_4 = _mm_loadu_ps(ptr.add(12)); + sum128_4 = _mm_add_ps(_mm_mul_ps(m128_4, m128_4), sum128_4); + + ptr = ptr.add(16); + i += 16; + } + + let mut length = hsum128_ps_sse(sum128_1) + + hsum128_ps_sse(sum128_2) + + hsum128_ps_sse(sum128_3) + + hsum128_ps_sse(sum128_4); + for i in 0..n - m { + length += (*ptr.add(i)).powi(2); + } + if length < f32::EPSILON { + return; + } + length = length.sqrt(); + for x in vector.iter_mut() { + *x /= length; + } +} + +#[target_feature(enable = "sse")] +pub(crate) unsafe fn dot_similarity_sse(v1: &[f32], v2: &[f32]) -> f32 { + let n = v1.len(); + let m = n - (n % 16); + let mut ptr1: *const f32 = v1.as_ptr(); + let mut ptr2: *const f32 = v2.as_ptr(); + let mut sum128_1: __m128 = _mm_setzero_ps(); + let mut sum128_2: __m128 = _mm_setzero_ps(); + let mut sum128_3: __m128 = _mm_setzero_ps(); + let mut sum128_4: __m128 = _mm_setzero_ps(); + + let mut i: usize = 0; + while i < m { + sum128_1 = _mm_add_ps(_mm_mul_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2)), sum128_1); + + sum128_2 = _mm_add_ps( + _mm_mul_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4))), + sum128_2, + ); + + sum128_3 = _mm_add_ps( + _mm_mul_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8))), + sum128_3, + ); + + sum128_4 = _mm_add_ps( + _mm_mul_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12))), + sum128_4, + ); + + ptr1 = ptr1.add(16); + ptr2 = ptr2.add(16); + i += 16; + } + + let mut result = hsum128_ps_sse(sum128_1) + + hsum128_ps_sse(sum128_2) + + hsum128_ps_sse(sum128_3) + + hsum128_ps_sse(sum128_4); + for i in 0..n - m { + result += (*ptr1.add(i)) * (*ptr2.add(i)); + } + result +} + +#[cfg(test)] +mod tests { + #[test] + fn test_spaces_sse() { + use super::*; + use crate::*; + + if is_x86_feature_detected!("sse") { + let v1: Vec = vec![ + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 26., 27., 28., 29., 30., 31., + ]; + let v2: Vec = vec![ + 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., + 56., 57., 58., 59., 60., 61., + ]; + + let euclid_simd = unsafe { euclid_distance_sse(&v1, &v2) }; + let euclid = euclid_distance(&v1, &v2); + assert_eq!(euclid_simd, euclid); + + let dot_simd = unsafe { dot_similarity_sse(&v1, &v2) }; + let dot = dot_similarity(&v1, &v2); + assert_eq!(dot_simd, dot); + + let mut v1 = v1; + let mut v1_copy = v1.clone(); + unsafe { cosine_preprocess_sse(&mut v1) }; + cosine_preprocess(&mut v1_copy); + assert_eq!(v1, v1_copy); + } else { + println!("sse test skipped"); + } + } +} diff --git a/instant-distance-py/src/lib.rs b/instant-distance-py/src/lib.rs index 4d98768..6f40449 100644 --- a/instant-distance-py/src/lib.rs +++ b/instant-distance-py/src/lib.rs @@ -7,8 +7,8 @@ use std::io::{BufReader, BufWriter}; use std::iter::FromIterator; use std::marker::PhantomData; -use distance_metrics::EuclidMetric; use distance_metrics::Metric; +use distance_metrics::{CosineMetric, EuclidMetric}; use instant_distance::Point; use pyo3::conversion::IntoPy; use pyo3::exceptions::PyValueError; @@ -65,7 +65,7 @@ struct HnswMap { #[derive(Deserialize, Serialize)] enum HnswMapWithMetric { Euclid(instant_distance::HnswMap, MapValue>), - Cosine(instant_distance::HnswMap, MapValue>), + Cosine(instant_distance::HnswMap, MapValue>), } #[pymethods] @@ -136,7 +136,7 @@ struct Hnsw { #[derive(Deserialize, Serialize)] enum HnswWithMetric { Euclid(instant_distance::Hnsw>), - Cosine(instant_distance::Hnsw>), + Cosine(instant_distance::Hnsw>), } #[pymethods] @@ -426,7 +426,8 @@ impl FloatArray { } impl From> for FloatArray { - fn from(array: Vec) -> Self { + fn from(mut array: Vec) -> Self { + M::preprocess(&mut array); Self { array, phantom: PhantomData,