From ac4257495e5f02294e6ff5e8aec4362004495466 Mon Sep 17 00:00:00 2001 From: Kuba Jaroszewski Date: Thu, 16 Feb 2023 21:50:13 +0100 Subject: [PATCH] Replace Euclid metric implementation --- Cargo.toml | 2 +- distance-metrics/Cargo.toml | 21 ++++++ distance-metrics/benches/all.rs | 33 +++++++++ distance-metrics/src/lib.rs | 115 ++++++++++++++++++++++++++++++ distance-metrics/src/simd_avx.rs | 54 ++++++++++++++ distance-metrics/src/simd_neon.rs | 38 ++++++++++ distance-metrics/src/simd_sse.rs | 50 +++++++++++++ instant-distance-py/Cargo.toml | 2 +- instant-distance-py/src/lib.rs | 98 +++++++++---------------- 9 files changed, 345 insertions(+), 68 deletions(-) create mode 100644 distance-metrics/Cargo.toml create mode 100644 distance-metrics/benches/all.rs create mode 100644 distance-metrics/src/lib.rs create mode 100644 distance-metrics/src/simd_avx.rs create mode 100644 distance-metrics/src/simd_neon.rs create mode 100644 distance-metrics/src/simd_sse.rs diff --git a/Cargo.toml b/Cargo.toml index 0f7eac6..38984ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["instant-distance", "instant-distance-py"] +members = ["distance-metrics", "instant-distance", "instant-distance-py"] [profile.bench] debug = true diff --git a/distance-metrics/Cargo.toml b/distance-metrics/Cargo.toml new file mode 100644 index 0000000..7d0b580 --- /dev/null +++ b/distance-metrics/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "distance-metrics" +version = "0.6.0" +license = "MIT OR Apache-2.0" +edition = "2021" +rust-version = "1.58" +homepage = "https://github.com/InstantDomain/instant-distance" +repository = "https://github.com/InstantDomain/instant-distance" +documentation = "https://docs.rs/instant-distance" +workspace = ".." +readme = "../README.md" + +[dependencies] + +[dev-dependencies] +bencher = "0.1.5" +rand = { version = "0.8", features = ["small_rng"] } + +[[bench]] +name = "all" +harness = false diff --git a/distance-metrics/benches/all.rs b/distance-metrics/benches/all.rs new file mode 100644 index 0000000..2c7a305 --- /dev/null +++ b/distance-metrics/benches/all.rs @@ -0,0 +1,33 @@ +use bencher::{benchmark_group, benchmark_main, Bencher}; + +use distance_metrics::{EuclidMetric, Metric}; +use rand::{rngs::StdRng, Rng, SeedableRng}; + +benchmark_main!(benches); +benchmark_group!(benches, legacy, non_simd, metric::,); + +fn legacy(bench: &mut Bencher) { + let mut rng = StdRng::seed_from_u64(SEED); + let point_a = distance_metrics::FloatArray([rng.gen(); 300]); + let point_b = distance_metrics::FloatArray([rng.gen(); 300]); + + bench.iter(|| distance_metrics::legacy_distance(&point_a, &point_b)) +} + +fn non_simd(bench: &mut Bencher) { + let mut rng = StdRng::seed_from_u64(SEED); + let point_a = [rng.gen(); 300]; + let point_b = [rng.gen(); 300]; + + bench.iter(|| distance_metrics::euclid_distance(&point_a, &point_b)) +} + +fn metric(bench: &mut Bencher) { + let mut rng = StdRng::seed_from_u64(SEED); + let point_a = [rng.gen(); 300]; + let point_b = [rng.gen(); 300]; + + bench.iter(|| M::distance(&point_a, &point_b)) +} + +const SEED: u64 = 123456789; diff --git a/distance-metrics/src/lib.rs b/distance-metrics/src/lib.rs new file mode 100644 index 0000000..520a088 --- /dev/null +++ b/distance-metrics/src/lib.rs @@ -0,0 +1,115 @@ +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +pub mod simd_sse; + +#[cfg(target_arch = "x86_64")] +pub mod simd_avx; + +#[cfg(target_arch = "aarch64")] +pub mod simd_neon; + +/// Defines how to compare vectors +pub trait Metric { + /// Greater the value - more distant the vectors + fn distance(v1: &[f32], v2: &[f32]) -> f32; +} + +#[cfg(target_arch = "x86_64")] +const MIN_DIM_SIZE_AVX: usize = 32; + +#[cfg(any( + target_arch = "x86", + target_arch = "x86_64", + all(target_arch = "aarch64", target_feature = "neon") +))] +const MIN_DIM_SIZE_SIMD: usize = 16; + +#[derive(Clone, Copy)] +pub struct EuclidMetric {} + +impl Metric for EuclidMetric { + fn distance(v1: &[f32], v2: &[f32]) -> f32 { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx") + && is_x86_feature_detected!("fma") + && v1.len() >= MIN_DIM_SIZE_AVX + { + return unsafe { simd_avx::euclid_distance_avx(v1, v2) }; + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD { + return unsafe { simd_sse::euclid_distance_sse(v1, v2) }; + } + } + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD { + return unsafe { simple_neon::euclid_distance_neon(v1, v2) }; + } + } + + euclid_distance(v1, v2) + } +} + +pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 { + let s: f32 = v1 + .iter() + .copied() + .zip(v2.iter().copied()) + .map(|(a, b)| (a - b).powi(2)) + .sum(); + s.abs().sqrt() +} + +pub fn legacy_distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 { + #[cfg(target_arch = "x86_64")] + { + use std::arch::x86_64::{ + _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps, + _mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_fmadd_ps, + _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps, + }; + debug_assert_eq!(lhs.0.len() % 8, 4); + + unsafe { + let mut acc_8x = _mm256_setzero_ps(); + for (lh_slice, rh_slice) in lhs.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) { + let lh_8x = _mm256_load_ps(lh_slice.as_ptr()); + let rh_8x = _mm256_load_ps(rh_slice.as_ptr()); + let diff = _mm256_sub_ps(lh_8x, rh_8x); + acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x); + } + + let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half + let right = _mm256_castps256_ps128(acc_8x); // lower half + acc_4x = _mm_add_ps(acc_4x, right); // sum halves + + let lh_4x = _mm_load_ps(lhs.0[DIMENSIONS - 4..].as_ptr()); + let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr()); + let diff = _mm_sub_ps(lh_4x, rh_4x); + acc_4x = _mm_fmadd_ps(diff, diff, acc_4x); + + let lower = _mm_movehl_ps(acc_4x, acc_4x); + acc_4x = _mm_add_ps(acc_4x, lower); + let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1); + acc_4x = _mm_add_ss(acc_4x, upper); + _mm_cvtss_f32(acc_4x) + } + } + #[cfg(not(target_arch = "x86_64"))] + lhs.0 + .iter() + .zip(rhs.0.iter()) + .map(|(&a, &b)| (a - b).powi(2)) + .sum::() +} + +#[repr(align(32))] +pub struct FloatArray(pub [f32; DIMENSIONS]); + +const DIMENSIONS: usize = 300; diff --git a/distance-metrics/src/simd_avx.rs b/distance-metrics/src/simd_avx.rs new file mode 100644 index 0000000..bd120a8 --- /dev/null +++ b/distance-metrics/src/simd_avx.rs @@ -0,0 +1,54 @@ +use std::arch::x86_64::*; + +#[target_feature(enable = "avx")] +#[target_feature(enable = "fma")] +unsafe fn hsum256_ps_avx(x: __m256) -> f32 { + let x128: __m128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); + let x64: __m128 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + _mm_cvtss_f32(x32) +} + +#[target_feature(enable = "avx")] +#[target_feature(enable = "fma")] +pub(crate) unsafe fn euclid_distance_avx(v1: &[f32], v2: &[f32]) -> f32 { + let n = v1.len(); + let m = n - (n % 32); + let mut ptr1: *const f32 = v1.as_ptr(); + let mut ptr2: *const f32 = v2.as_ptr(); + let mut sum256_1: __m256 = _mm256_setzero_ps(); + let mut sum256_2: __m256 = _mm256_setzero_ps(); + let mut sum256_3: __m256 = _mm256_setzero_ps(); + let mut sum256_4: __m256 = _mm256_setzero_ps(); + let mut i: usize = 0; + while i < m { + let sub256_1: __m256 = + _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(0)), _mm256_loadu_ps(ptr2.add(0))); + sum256_1 = _mm256_fmadd_ps(sub256_1, sub256_1, sum256_1); + + let sub256_2: __m256 = + _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(8)), _mm256_loadu_ps(ptr2.add(8))); + sum256_2 = _mm256_fmadd_ps(sub256_2, sub256_2, sum256_2); + + let sub256_3: __m256 = + _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(16)), _mm256_loadu_ps(ptr2.add(16))); + sum256_3 = _mm256_fmadd_ps(sub256_3, sub256_3, sum256_3); + + let sub256_4: __m256 = + _mm256_sub_ps(_mm256_loadu_ps(ptr1.add(24)), _mm256_loadu_ps(ptr2.add(24))); + sum256_4 = _mm256_fmadd_ps(sub256_4, sub256_4, sum256_4); + + ptr1 = ptr1.add(32); + ptr2 = ptr2.add(32); + i += 32; + } + + let mut result = hsum256_ps_avx(sum256_1) + + hsum256_ps_avx(sum256_2) + + hsum256_ps_avx(sum256_3) + + hsum256_ps_avx(sum256_4); + for i in 0..n - m { + result += (*ptr1.add(i) - *ptr2.add(i)).powi(2); + } + result.abs().sqrt() +} diff --git a/distance-metrics/src/simd_neon.rs b/distance-metrics/src/simd_neon.rs new file mode 100644 index 0000000..7211d49 --- /dev/null +++ b/distance-metrics/src/simd_neon.rs @@ -0,0 +1,38 @@ +#[cfg(target_feature = "neon")] +use std::arch::aarch64::*; + +#[cfg(target_feature = "neon")] +pub(crate) unsafe fn euclid_distance_neon(v1: &[f32], v2: &[f32]) -> f32 { + let n = v1.len(); + let m = n - (n % 16); + let mut ptr1: *const f32 = v1.as_ptr(); + let mut ptr2: *const f32 = v2.as_ptr(); + let mut sum1 = vdupq_n_f32(0.); + let mut sum2 = vdupq_n_f32(0.); + let mut sum3 = vdupq_n_f32(0.); + let mut sum4 = vdupq_n_f32(0.); + + let mut i: usize = 0; + while i < m { + let sub1 = vsubq_f32(vld1q_f32(ptr1), vld1q_f32(ptr2)); + sum1 = vfmaq_f32(sum1, sub1, sub1); + + let sub2 = vsubq_f32(vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4))); + sum2 = vfmaq_f32(sum2, sub2, sub2); + + let sub3 = vsubq_f32(vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8))); + sum3 = vfmaq_f32(sum3, sub3, sub3); + + let sub4 = vsubq_f32(vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12))); + sum4 = vfmaq_f32(sum4, sub4, sub4); + + ptr1 = ptr1.add(16); + ptr2 = ptr2.add(16); + i += 16; + } + let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4); + for i in 0..n - m { + result += (*ptr1.add(i) - *ptr2.add(i)).powi(2); + } + result.abs().sqrt() +} diff --git a/distance-metrics/src/simd_sse.rs b/distance-metrics/src/simd_sse.rs new file mode 100644 index 0000000..eebc249 --- /dev/null +++ b/distance-metrics/src/simd_sse.rs @@ -0,0 +1,50 @@ +#[cfg(target_arch = "x86")] +use std::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; + +#[target_feature(enable = "sse")] +unsafe fn hsum128_ps_sse(x: __m128) -> f32 { + let x64: __m128 = _mm_add_ps(x, _mm_movehl_ps(x, x)); + let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + _mm_cvtss_f32(x32) +} + +#[target_feature(enable = "sse")] +pub(crate) unsafe fn euclid_distance_sse(v1: &[f32], v2: &[f32]) -> f32 { + let n = v1.len(); + let m = n - (n % 16); + let mut ptr1: *const f32 = v1.as_ptr(); + let mut ptr2: *const f32 = v2.as_ptr(); + let mut sum128_1: __m128 = _mm_setzero_ps(); + let mut sum128_2: __m128 = _mm_setzero_ps(); + let mut sum128_3: __m128 = _mm_setzero_ps(); + let mut sum128_4: __m128 = _mm_setzero_ps(); + let mut i: usize = 0; + while i < m { + let sub128_1 = _mm_sub_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2)); + sum128_1 = _mm_add_ps(_mm_mul_ps(sub128_1, sub128_1), sum128_1); + + let sub128_2 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4))); + sum128_2 = _mm_add_ps(_mm_mul_ps(sub128_2, sub128_2), sum128_2); + + let sub128_3 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8))); + sum128_3 = _mm_add_ps(_mm_mul_ps(sub128_3, sub128_3), sum128_3); + + let sub128_4 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12))); + sum128_4 = _mm_add_ps(_mm_mul_ps(sub128_4, sub128_4), sum128_4); + + ptr1 = ptr1.add(16); + ptr2 = ptr2.add(16); + i += 16; + } + + let mut result = hsum128_ps_sse(sum128_1) + + hsum128_ps_sse(sum128_2) + + hsum128_ps_sse(sum128_3) + + hsum128_ps_sse(sum128_4); + for i in 0..n - m { + result += (*ptr1.add(i) - *ptr2.add(i)).powi(2); + } + result.abs().sqrt() +} diff --git a/instant-distance-py/Cargo.toml b/instant-distance-py/Cargo.toml index eafab40..d749518 100644 --- a/instant-distance-py/Cargo.toml +++ b/instant-distance-py/Cargo.toml @@ -16,7 +16,7 @@ crate-type = ["cdylib"] [dependencies] bincode = "1.3.1" +distance-metrics = { version = "0.6", path = "../distance-metrics" } instant-distance = { version = "0.6", path = "../instant-distance", features = ["with-serde"] } pyo3 = { version = "0.18.0", features = ["extension-module"] } serde = { version = "1", features = ["derive"] } -serde-big-array = "0.4.1" diff --git a/instant-distance-py/src/lib.rs b/instant-distance-py/src/lib.rs index 223c557..4d98768 100644 --- a/instant-distance-py/src/lib.rs +++ b/instant-distance-py/src/lib.rs @@ -5,15 +5,17 @@ use std::convert::TryFrom; use std::fs::File; use std::io::{BufReader, BufWriter}; use std::iter::FromIterator; +use std::marker::PhantomData; +use distance_metrics::EuclidMetric; +use distance_metrics::Metric; use instant_distance::Point; use pyo3::conversion::IntoPy; -use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::exceptions::PyValueError; use pyo3::types::{PyList, PyModule, PyString}; use pyo3::{pyclass, pymethods, pymodule}; use pyo3::{Py, PyAny, PyErr, PyObject, PyRef, PyRefMut, PyResult, Python}; use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; #[pymodule] #[pyo3(name = "instant_distance")] @@ -62,8 +64,8 @@ struct HnswMap { #[derive(Deserialize, Serialize)] enum HnswMapWithMetric { - Euclid(instant_distance::HnswMap), - Cosine(instant_distance::HnswMap), + Euclid(instant_distance::HnswMap, MapValue>), + Cosine(instant_distance::HnswMap, MapValue>), } #[pymethods] @@ -126,9 +128,6 @@ impl HnswMap { } /// An instance of hierarchical navigable small worlds -/// -/// For now, this is specialized to only support 300-element (32-bit) float vectors -/// with a squared Euclidean distance metric. #[pyclass] struct Hnsw { inner: HnswWithMetric, @@ -136,8 +135,8 @@ struct Hnsw { #[derive(Deserialize, Serialize)] enum HnswWithMetric { - Euclid(instant_distance::Hnsw), - Cosine(instant_distance::Hnsw), + Euclid(instant_distance::Hnsw>), + Cosine(instant_distance::Hnsw>), } #[pymethods] @@ -414,73 +413,42 @@ impl Neighbor { } } -#[repr(align(32))] #[derive(Clone, Deserialize, Serialize)] -struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]); +struct FloatArray { + array: Vec, + phantom: PhantomData, +} -impl FloatArray { +impl FloatArray { fn try_from_pylist(list: &PyList) -> Result, PyErr> { list.into_iter().map(FloatArray::try_from).collect() } } -impl TryFrom<&PyAny> for FloatArray { - type Error = PyErr; - - fn try_from(value: &PyAny) -> Result { - let mut new = FloatArray([0.0; DIMENSIONS]); - for (i, val) in value.iter()?.enumerate() { - match i >= DIMENSIONS { - true => return Err(PyTypeError::new_err("point array too long")), - false => new.0[i] = val?.extract::()?, - } +impl From> for FloatArray { + fn from(array: Vec) -> Self { + Self { + array, + phantom: PhantomData, } - Ok(new) } } -impl Point for FloatArray { +impl TryFrom<&PyAny> for FloatArray { + type Error = PyErr; + + fn try_from(value: &PyAny) -> Result { + let array: Vec = value + .iter()? + .map(|val| val.and_then(|v| v.extract::())) + .collect::>()?; + Ok(Self::from(array)) + } +} + +impl Point for FloatArray { fn distance(&self, rhs: &Self) -> f32 { - #[cfg(target_arch = "x86_64")] - { - use std::arch::x86_64::{ - _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps, - _mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, - _mm_fmadd_ps, _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps, - }; - debug_assert_eq!(self.0.len() % 8, 4); - - unsafe { - let mut acc_8x = _mm256_setzero_ps(); - for (lh_slice, rh_slice) in self.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) { - let lh_8x = _mm256_load_ps(lh_slice.as_ptr()); - let rh_8x = _mm256_load_ps(rh_slice.as_ptr()); - let diff = _mm256_sub_ps(lh_8x, rh_8x); - acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x); - } - - let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half - let right = _mm256_castps256_ps128(acc_8x); // lower half - acc_4x = _mm_add_ps(acc_4x, right); // sum halves - - let lh_4x = _mm_load_ps(self.0[DIMENSIONS - 4..].as_ptr()); - let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr()); - let diff = _mm_sub_ps(lh_4x, rh_4x); - acc_4x = _mm_fmadd_ps(diff, diff, acc_4x); - - let lower = _mm_movehl_ps(acc_4x, acc_4x); - acc_4x = _mm_add_ps(acc_4x, lower); - let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1); - acc_4x = _mm_add_ss(acc_4x, upper); - _mm_cvtss_f32(acc_4x) - } - } - #[cfg(not(target_arch = "x86_64"))] - self.0 - .iter() - .zip(rhs.0.iter()) - .map(|(&a, &b)| (a - b).powi(2)) - .sum::() + M::distance(&self.array, &rhs.array) } } @@ -504,5 +472,3 @@ impl IntoPy> for &'_ MapValue { } } } - -const DIMENSIONS: usize = 300;