Add Cosine distance metric

This commit is contained in:
Kuba Jaroszewski 2023-02-16 21:55:50 +01:00
parent ac4257495e
commit 9d8c83fc38
7 changed files with 470 additions and 8 deletions

View File

@ -8,6 +8,7 @@ test-python: instant-distance-py/test/instant_distance.so
bench-python: instant-distance-py/test/instant_distance.so bench-python: instant-distance-py/test/instant_distance.so
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config()' 'instant_distance.Hnsw.build(points, config)' PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config()' 'instant_distance.Hnsw.build(points, config)'
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config(); config.distance_metric = instant_distance.DistanceMetric.Cosine' 'instant_distance.Hnsw.build(points, config)'
clean: clean:
cargo clean cargo clean

View File

@ -1,10 +1,18 @@
use bencher::{benchmark_group, benchmark_main, Bencher}; use bencher::{benchmark_group, benchmark_main, Bencher};
use distance_metrics::{EuclidMetric, Metric}; use distance_metrics::{
Metric, {CosineMetric, EuclidMetric},
};
use rand::{rngs::StdRng, Rng, SeedableRng}; use rand::{rngs::StdRng, Rng, SeedableRng};
benchmark_main!(benches); benchmark_main!(benches);
benchmark_group!(benches, legacy, non_simd, metric::<EuclidMetric>,); benchmark_group!(
benches,
legacy,
non_simd,
metric::<EuclidMetric>,
metric::<CosineMetric>
);
fn legacy(bench: &mut Bencher) { fn legacy(bench: &mut Bencher) {
let mut rng = StdRng::seed_from_u64(SEED); let mut rng = StdRng::seed_from_u64(SEED);
@ -24,8 +32,10 @@ fn non_simd(bench: &mut Bencher) {
fn metric<M: Metric>(bench: &mut Bencher) { fn metric<M: Metric>(bench: &mut Bencher) {
let mut rng = StdRng::seed_from_u64(SEED); let mut rng = StdRng::seed_from_u64(SEED);
let point_a = [rng.gen(); 300]; let mut point_a = [rng.gen(); 300];
let point_b = [rng.gen(); 300]; let mut point_b = [rng.gen(); 300];
M::preprocess(&mut point_a);
M::preprocess(&mut point_b);
bench.iter(|| M::distance(&point_a, &point_b)) bench.iter(|| M::distance(&point_a, &point_b))
} }

View File

@ -11,6 +11,9 @@ pub mod simd_neon;
pub trait Metric { pub trait Metric {
/// Greater the value - more distant the vectors /// Greater the value - more distant the vectors
fn distance(v1: &[f32], v2: &[f32]) -> f32; fn distance(v1: &[f32], v2: &[f32]) -> f32;
/// Necessary vector transformations performed before adding it to the collection (like normalization)
fn preprocess(vector: &mut [f32]);
} }
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
@ -54,6 +57,72 @@ impl Metric for EuclidMetric {
euclid_distance(v1, v2) euclid_distance(v1, v2)
} }
fn preprocess(_vector: &mut [f32]) {
// no-op
}
}
#[derive(Clone, Copy)]
pub struct CosineMetric {}
impl Metric for CosineMetric {
fn distance(v1: &[f32], v2: &[f32]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx")
&& is_x86_feature_detected!("fma")
&& v1.len() >= MIN_DIM_SIZE_AVX
{
return 1.0 - unsafe { simd_avx::dot_similarity_avx(v1, v2) };
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD {
return 1.0 - unsafe { simd_sse::dot_similarity_sse(v1, v2) };
}
}
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
{
if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD {
return 1.0 - unsafe { simd_neon::dot_similarity_neon(v1, v2) };
}
}
1.0 - dot_similarity(v1, v2)
}
fn preprocess(vector: &mut [f32]) {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx")
&& is_x86_feature_detected!("fma")
&& vector.len() >= MIN_DIM_SIZE_AVX
{
return unsafe { simd_avx::cosine_preprocess_avx(vector) };
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("sse") && vector.len() >= MIN_DIM_SIZE_SIMD {
return unsafe { simd_sse::cosine_preprocess_sse(vector) };
}
}
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
{
if std::arch::is_aarch64_feature_detected!("neon") && vector.len() >= MIN_DIM_SIZE_SIMD
{
return unsafe { simd_neon::cosine_preprocess_neon(vector) };
}
}
cosine_preprocess(vector);
}
} }
pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 { pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 {
@ -66,6 +135,21 @@ pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 {
s.abs().sqrt() s.abs().sqrt()
} }
pub fn cosine_preprocess(vector: &mut [f32]) {
let mut length: f32 = vector.iter().map(|x| x * x).sum();
if length < f32::EPSILON {
return;
}
length = length.sqrt();
for x in vector.iter_mut() {
*x /= length;
}
}
pub fn dot_similarity(v1: &[f32], v2: &[f32]) -> f32 {
v1.iter().zip(v2).map(|(a, b)| a * b).sum()
}
pub fn legacy_distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 { pub fn legacy_distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
{ {

View File

@ -52,3 +52,135 @@ pub(crate) unsafe fn euclid_distance_avx(v1: &[f32], v2: &[f32]) -> f32 {
} }
result.abs().sqrt() result.abs().sqrt()
} }
#[target_feature(enable = "avx")]
#[target_feature(enable = "fma")]
pub(crate) unsafe fn cosine_preprocess_avx(vector: &mut [f32]) {
let n = vector.len();
let m = n - (n % 32);
let mut ptr: *const f32 = vector.as_ptr();
let mut sum256_1: __m256 = _mm256_setzero_ps();
let mut sum256_2: __m256 = _mm256_setzero_ps();
let mut sum256_3: __m256 = _mm256_setzero_ps();
let mut sum256_4: __m256 = _mm256_setzero_ps();
let mut i: usize = 0;
while i < m {
let m256_1 = _mm256_loadu_ps(ptr);
sum256_1 = _mm256_fmadd_ps(m256_1, m256_1, sum256_1);
let m256_2 = _mm256_loadu_ps(ptr.add(8));
sum256_2 = _mm256_fmadd_ps(m256_2, m256_2, sum256_2);
let m256_3 = _mm256_loadu_ps(ptr.add(16));
sum256_3 = _mm256_fmadd_ps(m256_3, m256_3, sum256_3);
let m256_4 = _mm256_loadu_ps(ptr.add(24));
sum256_4 = _mm256_fmadd_ps(m256_4, m256_4, sum256_4);
ptr = ptr.add(32);
i += 32;
}
let mut length = hsum256_ps_avx(sum256_1)
+ hsum256_ps_avx(sum256_2)
+ hsum256_ps_avx(sum256_3)
+ hsum256_ps_avx(sum256_4);
for i in 0..n - m {
length += (*ptr.add(i)).powi(2);
}
if length < f32::EPSILON {
return;
}
length = length.sqrt();
for x in vector.iter_mut() {
*x /= length;
}
}
#[target_feature(enable = "avx")]
#[target_feature(enable = "fma")]
pub(crate) unsafe fn dot_similarity_avx(v1: &[f32], v2: &[f32]) -> f32 {
let n = v1.len();
let m = n - (n % 32);
let mut ptr1: *const f32 = v1.as_ptr();
let mut ptr2: *const f32 = v2.as_ptr();
let mut sum256_1: __m256 = _mm256_setzero_ps();
let mut sum256_2: __m256 = _mm256_setzero_ps();
let mut sum256_3: __m256 = _mm256_setzero_ps();
let mut sum256_4: __m256 = _mm256_setzero_ps();
let mut i: usize = 0;
while i < m {
sum256_1 = _mm256_fmadd_ps(_mm256_loadu_ps(ptr1), _mm256_loadu_ps(ptr2), sum256_1);
sum256_2 = _mm256_fmadd_ps(
_mm256_loadu_ps(ptr1.add(8)),
_mm256_loadu_ps(ptr2.add(8)),
sum256_2,
);
sum256_3 = _mm256_fmadd_ps(
_mm256_loadu_ps(ptr1.add(16)),
_mm256_loadu_ps(ptr2.add(16)),
sum256_3,
);
sum256_4 = _mm256_fmadd_ps(
_mm256_loadu_ps(ptr1.add(24)),
_mm256_loadu_ps(ptr2.add(24)),
sum256_4,
);
ptr1 = ptr1.add(32);
ptr2 = ptr2.add(32);
i += 32;
}
let mut result = hsum256_ps_avx(sum256_1)
+ hsum256_ps_avx(sum256_2)
+ hsum256_ps_avx(sum256_3)
+ hsum256_ps_avx(sum256_4);
for i in 0..n - m {
result += (*ptr1.add(i)) * (*ptr2.add(i));
}
result
}
#[cfg(test)]
mod tests {
#[test]
fn test_spaces_avx() {
use super::*;
use crate::*;
if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
let v1: Vec<f32> = vec![
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
26., 27., 28., 29., 30., 31.,
];
let v2: Vec<f32> = vec![
40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55.,
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
56., 57., 58., 59., 60., 61.,
];
let euclid_simd = unsafe { euclid_distance_avx(&v1, &v2) };
let euclid = euclid_distance(&v1, &v2);
assert_eq!(euclid_simd, euclid);
let dot_simd = unsafe { dot_similarity_avx(&v1, &v2) };
let dot = dot_similarity(&v1, &v2);
assert_eq!(dot_simd, dot);
let mut v1 = v1;
let mut v1_copy = v1.clone();
unsafe { cosine_preprocess_avx(&mut v1) };
cosine_preprocess(&mut v1_copy);
assert_eq!(v1, v1_copy);
} else {
println!("avx test skipped");
}
}
}

View File

@ -36,3 +36,108 @@ pub(crate) unsafe fn euclid_distance_neon(v1: &[f32], v2: &[f32]) -> f32 {
} }
result.abs().sqrt() result.abs().sqrt()
} }
#[cfg(target_feature = "neon")]
pub(crate) unsafe fn cosine_preprocess_neon(vector: &mut [f32]) {
let n = vector.len();
let m = n - (n % 16);
let mut ptr: *const f32 = vector.as_ptr();
let mut sum1 = vdupq_n_f32(0.);
let mut sum2 = vdupq_n_f32(0.);
let mut sum3 = vdupq_n_f32(0.);
let mut sum4 = vdupq_n_f32(0.);
let mut i: usize = 0;
while i < m {
let d1 = vld1q_f32(ptr);
sum1 = vfmaq_f32(sum1, d1, d1);
let d2 = vld1q_f32(ptr.add(4));
sum2 = vfmaq_f32(sum2, d2, d2);
let d3 = vld1q_f32(ptr.add(8));
sum3 = vfmaq_f32(sum3, d3, d3);
let d4 = vld1q_f32(ptr.add(12));
sum4 = vfmaq_f32(sum4, d4, d4);
ptr = ptr.add(16);
i += 16;
}
let mut length = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);
for v in vector.iter().take(n).skip(m) {
length += v.powi(2);
}
if length < f32::EPSILON {
return;
}
let length = length.sqrt();
for x in vector.iter_mut() {
*x /= length;
}
}
#[cfg(target_feature = "neon")]
pub(crate) unsafe fn dot_similarity_neon(v1: &[f32], v2: &[f32]) -> f32 {
let n = v1.len();
let m = n - (n % 16);
let mut ptr1: *const f32 = v1.as_ptr();
let mut ptr2: *const f32 = v2.as_ptr();
let mut sum1 = vdupq_n_f32(0.);
let mut sum2 = vdupq_n_f32(0.);
let mut sum3 = vdupq_n_f32(0.);
let mut sum4 = vdupq_n_f32(0.);
let mut i: usize = 0;
while i < m {
sum1 = vfmaq_f32(sum1, vld1q_f32(ptr1), vld1q_f32(ptr2));
sum2 = vfmaq_f32(sum2, vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4)));
sum3 = vfmaq_f32(sum3, vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8)));
sum4 = vfmaq_f32(sum4, vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12)));
ptr1 = ptr1.add(16);
ptr2 = ptr2.add(16);
i += 16;
}
let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);
for i in 0..n - m {
result += (*ptr1.add(i)) * (*ptr2.add(i));
}
result
}
#[cfg(test)]
mod tests {
#[cfg(target_feature = "neon")]
#[test]
fn test_spaces_neon() {
use super::*;
use crate::*;
if std::arch::is_aarch64_feature_detected!("neon") {
let v1: Vec<f32> = vec![
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
26., 27., 28., 29., 30., 31.,
];
let v2: Vec<f32> = vec![
40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55.,
56., 57., 58., 59., 60., 61.,
];
let euclid_simd = unsafe { euclid_distance_neon(&v1, &v2) };
let euclid = euclid_distance(&v1, &v2);
assert_eq!(euclid_simd, euclid);
let dot_simd = unsafe { dot_similarity_neon(&v1, &v2) };
let dot = dot_similarity(&v1, &v2);
assert_eq!(dot_simd, dot);
let mut v1 = v1;
let mut v1_copy = v1.clone();
unsafe { cosine_preprocess_neon(&mut v1) };
cosine_preprocess(&mut v1_copy);
assert_eq!(v1, v1_copy);
} else {
println!("neon test skipped");
}
}
}

View File

@ -48,3 +48,132 @@ pub(crate) unsafe fn euclid_distance_sse(v1: &[f32], v2: &[f32]) -> f32 {
} }
result.abs().sqrt() result.abs().sqrt()
} }
#[target_feature(enable = "sse")]
pub(crate) unsafe fn cosine_preprocess_sse(vector: &mut [f32]) {
let n = vector.len();
let m = n - (n % 16);
let mut ptr: *const f32 = vector.as_ptr();
let mut sum128_1: __m128 = _mm_setzero_ps();
let mut sum128_2: __m128 = _mm_setzero_ps();
let mut sum128_3: __m128 = _mm_setzero_ps();
let mut sum128_4: __m128 = _mm_setzero_ps();
let mut i: usize = 0;
while i < m {
let m128_1 = _mm_loadu_ps(ptr);
sum128_1 = _mm_add_ps(_mm_mul_ps(m128_1, m128_1), sum128_1);
let m128_2 = _mm_loadu_ps(ptr.add(4));
sum128_2 = _mm_add_ps(_mm_mul_ps(m128_2, m128_2), sum128_2);
let m128_3 = _mm_loadu_ps(ptr.add(8));
sum128_3 = _mm_add_ps(_mm_mul_ps(m128_3, m128_3), sum128_3);
let m128_4 = _mm_loadu_ps(ptr.add(12));
sum128_4 = _mm_add_ps(_mm_mul_ps(m128_4, m128_4), sum128_4);
ptr = ptr.add(16);
i += 16;
}
let mut length = hsum128_ps_sse(sum128_1)
+ hsum128_ps_sse(sum128_2)
+ hsum128_ps_sse(sum128_3)
+ hsum128_ps_sse(sum128_4);
for i in 0..n - m {
length += (*ptr.add(i)).powi(2);
}
if length < f32::EPSILON {
return;
}
length = length.sqrt();
for x in vector.iter_mut() {
*x /= length;
}
}
#[target_feature(enable = "sse")]
pub(crate) unsafe fn dot_similarity_sse(v1: &[f32], v2: &[f32]) -> f32 {
let n = v1.len();
let m = n - (n % 16);
let mut ptr1: *const f32 = v1.as_ptr();
let mut ptr2: *const f32 = v2.as_ptr();
let mut sum128_1: __m128 = _mm_setzero_ps();
let mut sum128_2: __m128 = _mm_setzero_ps();
let mut sum128_3: __m128 = _mm_setzero_ps();
let mut sum128_4: __m128 = _mm_setzero_ps();
let mut i: usize = 0;
while i < m {
sum128_1 = _mm_add_ps(_mm_mul_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2)), sum128_1);
sum128_2 = _mm_add_ps(
_mm_mul_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4))),
sum128_2,
);
sum128_3 = _mm_add_ps(
_mm_mul_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8))),
sum128_3,
);
sum128_4 = _mm_add_ps(
_mm_mul_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12))),
sum128_4,
);
ptr1 = ptr1.add(16);
ptr2 = ptr2.add(16);
i += 16;
}
let mut result = hsum128_ps_sse(sum128_1)
+ hsum128_ps_sse(sum128_2)
+ hsum128_ps_sse(sum128_3)
+ hsum128_ps_sse(sum128_4);
for i in 0..n - m {
result += (*ptr1.add(i)) * (*ptr2.add(i));
}
result
}
#[cfg(test)]
mod tests {
#[test]
fn test_spaces_sse() {
use super::*;
use crate::*;
if is_x86_feature_detected!("sse") {
let v1: Vec<f32> = vec![
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
26., 27., 28., 29., 30., 31.,
];
let v2: Vec<f32> = vec![
40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55.,
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
56., 57., 58., 59., 60., 61.,
];
let euclid_simd = unsafe { euclid_distance_sse(&v1, &v2) };
let euclid = euclid_distance(&v1, &v2);
assert_eq!(euclid_simd, euclid);
let dot_simd = unsafe { dot_similarity_sse(&v1, &v2) };
let dot = dot_similarity(&v1, &v2);
assert_eq!(dot_simd, dot);
let mut v1 = v1;
let mut v1_copy = v1.clone();
unsafe { cosine_preprocess_sse(&mut v1) };
cosine_preprocess(&mut v1_copy);
assert_eq!(v1, v1_copy);
} else {
println!("sse test skipped");
}
}
}

View File

@ -7,8 +7,8 @@ use std::io::{BufReader, BufWriter};
use std::iter::FromIterator; use std::iter::FromIterator;
use std::marker::PhantomData; use std::marker::PhantomData;
use distance_metrics::EuclidMetric;
use distance_metrics::Metric; use distance_metrics::Metric;
use distance_metrics::{CosineMetric, EuclidMetric};
use instant_distance::Point; use instant_distance::Point;
use pyo3::conversion::IntoPy; use pyo3::conversion::IntoPy;
use pyo3::exceptions::PyValueError; use pyo3::exceptions::PyValueError;
@ -65,7 +65,7 @@ struct HnswMap {
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
enum HnswMapWithMetric { enum HnswMapWithMetric {
Euclid(instant_distance::HnswMap<FloatArray<EuclidMetric>, MapValue>), Euclid(instant_distance::HnswMap<FloatArray<EuclidMetric>, MapValue>),
Cosine(instant_distance::HnswMap<FloatArray<EuclidMetric>, MapValue>), Cosine(instant_distance::HnswMap<FloatArray<CosineMetric>, MapValue>),
} }
#[pymethods] #[pymethods]
@ -136,7 +136,7 @@ struct Hnsw {
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
enum HnswWithMetric { enum HnswWithMetric {
Euclid(instant_distance::Hnsw<FloatArray<EuclidMetric>>), Euclid(instant_distance::Hnsw<FloatArray<EuclidMetric>>),
Cosine(instant_distance::Hnsw<FloatArray<EuclidMetric>>), Cosine(instant_distance::Hnsw<FloatArray<CosineMetric>>),
} }
#[pymethods] #[pymethods]
@ -426,7 +426,8 @@ impl<M: Metric> FloatArray<M> {
} }
impl<M: Metric> From<Vec<f32>> for FloatArray<M> { impl<M: Metric> From<Vec<f32>> for FloatArray<M> {
fn from(array: Vec<f32>) -> Self { fn from(mut array: Vec<f32>) -> Self {
M::preprocess(&mut array);
Self { Self {
array, array,
phantom: PhantomData, phantom: PhantomData,