Replace Euclid metric implementation

This commit is contained in:
Kuba Jaroszewski 2023-02-16 21:50:13 +01:00
parent bca31ad33f
commit ac4257495e
9 changed files with 345 additions and 68 deletions

View File

@ -1,5 +1,5 @@
[workspace] [workspace]
members = ["instant-distance", "instant-distance-py"] members = ["distance-metrics", "instant-distance", "instant-distance-py"]
[profile.bench] [profile.bench]
debug = true debug = true

View File

@ -0,0 +1,21 @@
[package]
name = "distance-metrics"
version = "0.6.0"
license = "MIT OR Apache-2.0"
edition = "2021"
rust-version = "1.58"
homepage = "https://github.com/InstantDomain/instant-distance"
repository = "https://github.com/InstantDomain/instant-distance"
documentation = "https://docs.rs/instant-distance"
workspace = ".."
readme = "../README.md"
[dependencies]
[dev-dependencies]
bencher = "0.1.5"
rand = { version = "0.8", features = ["small_rng"] }
[[bench]]
name = "all"
harness = false

View File

@ -0,0 +1,33 @@
use bencher::{benchmark_group, benchmark_main, Bencher};
use distance_metrics::{EuclidMetric, Metric};
use rand::{rngs::StdRng, Rng, SeedableRng};
benchmark_main!(benches);
benchmark_group!(benches, legacy, non_simd, metric::<EuclidMetric>,);
fn legacy(bench: &mut Bencher) {
let mut rng = StdRng::seed_from_u64(SEED);
let point_a = distance_metrics::FloatArray([rng.gen(); 300]);
let point_b = distance_metrics::FloatArray([rng.gen(); 300]);
bench.iter(|| distance_metrics::legacy_distance(&point_a, &point_b))
}
fn non_simd(bench: &mut Bencher) {
let mut rng = StdRng::seed_from_u64(SEED);
let point_a = [rng.gen(); 300];
let point_b = [rng.gen(); 300];
bench.iter(|| distance_metrics::euclid_distance(&point_a, &point_b))
}
fn metric<M: Metric>(bench: &mut Bencher) {
let mut rng = StdRng::seed_from_u64(SEED);
let point_a = [rng.gen(); 300];
let point_b = [rng.gen(); 300];
bench.iter(|| M::distance(&point_a, &point_b))
}
const SEED: u64 = 123456789;

115
distance-metrics/src/lib.rs Normal file
View File

@ -0,0 +1,115 @@
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub mod simd_sse;
#[cfg(target_arch = "x86_64")]
pub mod simd_avx;
#[cfg(target_arch = "aarch64")]
pub mod simd_neon;
/// Defines how to compare vectors
pub trait Metric {
/// Greater the value - more distant the vectors
fn distance(v1: &[f32], v2: &[f32]) -> f32;
}
#[cfg(target_arch = "x86_64")]
const MIN_DIM_SIZE_AVX: usize = 32;
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "aarch64", target_feature = "neon")
))]
const MIN_DIM_SIZE_SIMD: usize = 16;
#[derive(Clone, Copy)]
pub struct EuclidMetric {}
impl Metric for EuclidMetric {
fn distance(v1: &[f32], v2: &[f32]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx")
&& is_x86_feature_detected!("fma")
&& v1.len() >= MIN_DIM_SIZE_AVX
{
return unsafe { simd_avx::euclid_distance_avx(v1, v2) };
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD {
return unsafe { simd_sse::euclid_distance_sse(v1, v2) };
}
}
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
{
if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD {
return unsafe { simple_neon::euclid_distance_neon(v1, v2) };
}
}
euclid_distance(v1, v2)
}
}
pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 {
let s: f32 = v1
.iter()
.copied()
.zip(v2.iter().copied())
.map(|(a, b)| (a - b).powi(2))
.sum();
s.abs().sqrt()
}
pub fn legacy_distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
#[cfg(target_arch = "x86_64")]
{
use std::arch::x86_64::{
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_fmadd_ps,
_mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
};
debug_assert_eq!(lhs.0.len() % 8, 4);
unsafe {
let mut acc_8x = _mm256_setzero_ps();
for (lh_slice, rh_slice) in lhs.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
let diff = _mm256_sub_ps(lh_8x, rh_8x);
acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x);
}
let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half
let right = _mm256_castps256_ps128(acc_8x); // lower half
acc_4x = _mm_add_ps(acc_4x, right); // sum halves
let lh_4x = _mm_load_ps(lhs.0[DIMENSIONS - 4..].as_ptr());
let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
let diff = _mm_sub_ps(lh_4x, rh_4x);
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
let lower = _mm_movehl_ps(acc_4x, acc_4x);
acc_4x = _mm_add_ps(acc_4x, lower);
let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1);
acc_4x = _mm_add_ss(acc_4x, upper);
_mm_cvtss_f32(acc_4x)
}
}
#[cfg(not(target_arch = "x86_64"))]
lhs.0
.iter()
.zip(rhs.0.iter())
.map(|(&a, &b)| (a - b).powi(2))
.sum::<f32>()
}
#[repr(align(32))]
pub struct FloatArray(pub [f32; DIMENSIONS]);
const DIMENSIONS: usize = 300;

View File

@ -0,0 +1,54 @@
use std::arch::x86_64::*;
#[target_feature(enable = "avx")]
#[target_feature(enable = "fma")]
unsafe fn hsum256_ps_avx(x: __m256) -> f32 {
let x128: __m128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
let x64: __m128 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
_mm_cvtss_f32(x32)
}
#[target_feature(enable = "avx")]
#[target_feature(enable = "fma")]
pub(crate) unsafe fn euclid_distance_avx(v1: &[f32], v2: &[f32]) -> f32 {
let n = v1.len();
let m = n - (n % 32);
let mut ptr1: *const f32 = v1.as_ptr();
let mut ptr2: *const f32 = v2.as_ptr();
let mut sum256_1: __m256 = _mm256_setzero_ps();
let mut sum256_2: __m256 = _mm256_setzero_ps();
let mut sum256_3: __m256 = _mm256_setzero_ps();
let mut sum256_4: __m256 = _mm256_setzero_ps();
let mut i: usize = 0;
while i < m {
let sub256_1: __m256 =
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(0)), _mm256_loadu_ps(ptr2.add(0)));
sum256_1 = _mm256_fmadd_ps(sub256_1, sub256_1, sum256_1);
let sub256_2: __m256 =
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(8)), _mm256_loadu_ps(ptr2.add(8)));
sum256_2 = _mm256_fmadd_ps(sub256_2, sub256_2, sum256_2);
let sub256_3: __m256 =
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(16)), _mm256_loadu_ps(ptr2.add(16)));
sum256_3 = _mm256_fmadd_ps(sub256_3, sub256_3, sum256_3);
let sub256_4: __m256 =
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(24)), _mm256_loadu_ps(ptr2.add(24)));
sum256_4 = _mm256_fmadd_ps(sub256_4, sub256_4, sum256_4);
ptr1 = ptr1.add(32);
ptr2 = ptr2.add(32);
i += 32;
}
let mut result = hsum256_ps_avx(sum256_1)
+ hsum256_ps_avx(sum256_2)
+ hsum256_ps_avx(sum256_3)
+ hsum256_ps_avx(sum256_4);
for i in 0..n - m {
result += (*ptr1.add(i) - *ptr2.add(i)).powi(2);
}
result.abs().sqrt()
}

View File

@ -0,0 +1,38 @@
#[cfg(target_feature = "neon")]
use std::arch::aarch64::*;
#[cfg(target_feature = "neon")]
pub(crate) unsafe fn euclid_distance_neon(v1: &[f32], v2: &[f32]) -> f32 {
let n = v1.len();
let m = n - (n % 16);
let mut ptr1: *const f32 = v1.as_ptr();
let mut ptr2: *const f32 = v2.as_ptr();
let mut sum1 = vdupq_n_f32(0.);
let mut sum2 = vdupq_n_f32(0.);
let mut sum3 = vdupq_n_f32(0.);
let mut sum4 = vdupq_n_f32(0.);
let mut i: usize = 0;
while i < m {
let sub1 = vsubq_f32(vld1q_f32(ptr1), vld1q_f32(ptr2));
sum1 = vfmaq_f32(sum1, sub1, sub1);
let sub2 = vsubq_f32(vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4)));
sum2 = vfmaq_f32(sum2, sub2, sub2);
let sub3 = vsubq_f32(vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8)));
sum3 = vfmaq_f32(sum3, sub3, sub3);
let sub4 = vsubq_f32(vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12)));
sum4 = vfmaq_f32(sum4, sub4, sub4);
ptr1 = ptr1.add(16);
ptr2 = ptr2.add(16);
i += 16;
}
let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);
for i in 0..n - m {
result += (*ptr1.add(i) - *ptr2.add(i)).powi(2);
}
result.abs().sqrt()
}

View File

@ -0,0 +1,50 @@
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[target_feature(enable = "sse")]
unsafe fn hsum128_ps_sse(x: __m128) -> f32 {
let x64: __m128 = _mm_add_ps(x, _mm_movehl_ps(x, x));
let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
_mm_cvtss_f32(x32)
}
#[target_feature(enable = "sse")]
pub(crate) unsafe fn euclid_distance_sse(v1: &[f32], v2: &[f32]) -> f32 {
let n = v1.len();
let m = n - (n % 16);
let mut ptr1: *const f32 = v1.as_ptr();
let mut ptr2: *const f32 = v2.as_ptr();
let mut sum128_1: __m128 = _mm_setzero_ps();
let mut sum128_2: __m128 = _mm_setzero_ps();
let mut sum128_3: __m128 = _mm_setzero_ps();
let mut sum128_4: __m128 = _mm_setzero_ps();
let mut i: usize = 0;
while i < m {
let sub128_1 = _mm_sub_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2));
sum128_1 = _mm_add_ps(_mm_mul_ps(sub128_1, sub128_1), sum128_1);
let sub128_2 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4)));
sum128_2 = _mm_add_ps(_mm_mul_ps(sub128_2, sub128_2), sum128_2);
let sub128_3 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8)));
sum128_3 = _mm_add_ps(_mm_mul_ps(sub128_3, sub128_3), sum128_3);
let sub128_4 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12)));
sum128_4 = _mm_add_ps(_mm_mul_ps(sub128_4, sub128_4), sum128_4);
ptr1 = ptr1.add(16);
ptr2 = ptr2.add(16);
i += 16;
}
let mut result = hsum128_ps_sse(sum128_1)
+ hsum128_ps_sse(sum128_2)
+ hsum128_ps_sse(sum128_3)
+ hsum128_ps_sse(sum128_4);
for i in 0..n - m {
result += (*ptr1.add(i) - *ptr2.add(i)).powi(2);
}
result.abs().sqrt()
}

View File

@ -16,7 +16,7 @@ crate-type = ["cdylib"]
[dependencies] [dependencies]
bincode = "1.3.1" bincode = "1.3.1"
distance-metrics = { version = "0.6", path = "../distance-metrics" }
instant-distance = { version = "0.6", path = "../instant-distance", features = ["with-serde"] } instant-distance = { version = "0.6", path = "../instant-distance", features = ["with-serde"] }
pyo3 = { version = "0.18.0", features = ["extension-module"] } pyo3 = { version = "0.18.0", features = ["extension-module"] }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde-big-array = "0.4.1"

View File

@ -5,15 +5,17 @@ use std::convert::TryFrom;
use std::fs::File; use std::fs::File;
use std::io::{BufReader, BufWriter}; use std::io::{BufReader, BufWriter};
use std::iter::FromIterator; use std::iter::FromIterator;
use std::marker::PhantomData;
use distance_metrics::EuclidMetric;
use distance_metrics::Metric;
use instant_distance::Point; use instant_distance::Point;
use pyo3::conversion::IntoPy; use pyo3::conversion::IntoPy;
use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::exceptions::PyValueError;
use pyo3::types::{PyList, PyModule, PyString}; use pyo3::types::{PyList, PyModule, PyString};
use pyo3::{pyclass, pymethods, pymodule}; use pyo3::{pyclass, pymethods, pymodule};
use pyo3::{Py, PyAny, PyErr, PyObject, PyRef, PyRefMut, PyResult, Python}; use pyo3::{Py, PyAny, PyErr, PyObject, PyRef, PyRefMut, PyResult, Python};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_big_array::BigArray;
#[pymodule] #[pymodule]
#[pyo3(name = "instant_distance")] #[pyo3(name = "instant_distance")]
@ -62,8 +64,8 @@ struct HnswMap {
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
enum HnswMapWithMetric { enum HnswMapWithMetric {
Euclid(instant_distance::HnswMap<FloatArray, MapValue>), Euclid(instant_distance::HnswMap<FloatArray<EuclidMetric>, MapValue>),
Cosine(instant_distance::HnswMap<FloatArray, MapValue>), Cosine(instant_distance::HnswMap<FloatArray<EuclidMetric>, MapValue>),
} }
#[pymethods] #[pymethods]
@ -126,9 +128,6 @@ impl HnswMap {
} }
/// An instance of hierarchical navigable small worlds /// An instance of hierarchical navigable small worlds
///
/// For now, this is specialized to only support 300-element (32-bit) float vectors
/// with a squared Euclidean distance metric.
#[pyclass] #[pyclass]
struct Hnsw { struct Hnsw {
inner: HnswWithMetric, inner: HnswWithMetric,
@ -136,8 +135,8 @@ struct Hnsw {
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
enum HnswWithMetric { enum HnswWithMetric {
Euclid(instant_distance::Hnsw<FloatArray>), Euclid(instant_distance::Hnsw<FloatArray<EuclidMetric>>),
Cosine(instant_distance::Hnsw<FloatArray>), Cosine(instant_distance::Hnsw<FloatArray<EuclidMetric>>),
} }
#[pymethods] #[pymethods]
@ -414,73 +413,42 @@ impl Neighbor {
} }
} }
#[repr(align(32))]
#[derive(Clone, Deserialize, Serialize)] #[derive(Clone, Deserialize, Serialize)]
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]); struct FloatArray<M> {
array: Vec<f32>,
phantom: PhantomData<M>,
}
impl FloatArray { impl<M: Metric> FloatArray<M> {
fn try_from_pylist(list: &PyList) -> Result<Vec<Self>, PyErr> { fn try_from_pylist(list: &PyList) -> Result<Vec<Self>, PyErr> {
list.into_iter().map(FloatArray::try_from).collect() list.into_iter().map(FloatArray::try_from).collect()
} }
} }
impl TryFrom<&PyAny> for FloatArray { impl<M: Metric> From<Vec<f32>> for FloatArray<M> {
type Error = PyErr; fn from(array: Vec<f32>) -> Self {
Self {
fn try_from(value: &PyAny) -> Result<Self, Self::Error> { array,
let mut new = FloatArray([0.0; DIMENSIONS]); phantom: PhantomData,
for (i, val) in value.iter()?.enumerate() {
match i >= DIMENSIONS {
true => return Err(PyTypeError::new_err("point array too long")),
false => new.0[i] = val?.extract::<f32>()?,
}
} }
Ok(new)
} }
} }
impl Point for FloatArray { impl<M: Metric> TryFrom<&PyAny> for FloatArray<M> {
type Error = PyErr;
fn try_from(value: &PyAny) -> Result<Self, Self::Error> {
let array: Vec<f32> = value
.iter()?
.map(|val| val.and_then(|v| v.extract::<f32>()))
.collect::<Result<_, _>>()?;
Ok(Self::from(array))
}
}
impl<M: Metric + Clone + Sync> Point for FloatArray<M> {
fn distance(&self, rhs: &Self) -> f32 { fn distance(&self, rhs: &Self) -> f32 {
#[cfg(target_arch = "x86_64")] M::distance(&self.array, &rhs.array)
{
use std::arch::x86_64::{
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
_mm_fmadd_ps, _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
};
debug_assert_eq!(self.0.len() % 8, 4);
unsafe {
let mut acc_8x = _mm256_setzero_ps();
for (lh_slice, rh_slice) in self.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
let diff = _mm256_sub_ps(lh_8x, rh_8x);
acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x);
}
let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half
let right = _mm256_castps256_ps128(acc_8x); // lower half
acc_4x = _mm_add_ps(acc_4x, right); // sum halves
let lh_4x = _mm_load_ps(self.0[DIMENSIONS - 4..].as_ptr());
let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
let diff = _mm_sub_ps(lh_4x, rh_4x);
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
let lower = _mm_movehl_ps(acc_4x, acc_4x);
acc_4x = _mm_add_ps(acc_4x, lower);
let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1);
acc_4x = _mm_add_ss(acc_4x, upper);
_mm_cvtss_f32(acc_4x)
}
}
#[cfg(not(target_arch = "x86_64"))]
self.0
.iter()
.zip(rhs.0.iter())
.map(|(&a, &b)| (a - b).powi(2))
.sum::<f32>()
} }
} }
@ -504,5 +472,3 @@ impl IntoPy<Py<PyAny>> for &'_ MapValue {
} }
} }
} }
const DIMENSIONS: usize = 300;