Add Cosine distance metric
This commit is contained in:
parent
ac4257495e
commit
9d8c83fc38
1
Makefile
1
Makefile
|
@ -8,6 +8,7 @@ test-python: instant-distance-py/test/instant_distance.so
|
||||||
|
|
||||||
bench-python: instant-distance-py/test/instant_distance.so
|
bench-python: instant-distance-py/test/instant_distance.so
|
||||||
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config()' 'instant_distance.Hnsw.build(points, config)'
|
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config()' 'instant_distance.Hnsw.build(points, config)'
|
||||||
|
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config(); config.distance_metric = instant_distance.DistanceMetric.Cosine' 'instant_distance.Hnsw.build(points, config)'
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
cargo clean
|
cargo clean
|
||||||
|
|
|
@ -1,10 +1,18 @@
|
||||||
use bencher::{benchmark_group, benchmark_main, Bencher};
|
use bencher::{benchmark_group, benchmark_main, Bencher};
|
||||||
|
|
||||||
use distance_metrics::{EuclidMetric, Metric};
|
use distance_metrics::{
|
||||||
|
Metric, {CosineMetric, EuclidMetric},
|
||||||
|
};
|
||||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||||
|
|
||||||
benchmark_main!(benches);
|
benchmark_main!(benches);
|
||||||
benchmark_group!(benches, legacy, non_simd, metric::<EuclidMetric>,);
|
benchmark_group!(
|
||||||
|
benches,
|
||||||
|
legacy,
|
||||||
|
non_simd,
|
||||||
|
metric::<EuclidMetric>,
|
||||||
|
metric::<CosineMetric>
|
||||||
|
);
|
||||||
|
|
||||||
fn legacy(bench: &mut Bencher) {
|
fn legacy(bench: &mut Bencher) {
|
||||||
let mut rng = StdRng::seed_from_u64(SEED);
|
let mut rng = StdRng::seed_from_u64(SEED);
|
||||||
|
@ -24,8 +32,10 @@ fn non_simd(bench: &mut Bencher) {
|
||||||
|
|
||||||
fn metric<M: Metric>(bench: &mut Bencher) {
|
fn metric<M: Metric>(bench: &mut Bencher) {
|
||||||
let mut rng = StdRng::seed_from_u64(SEED);
|
let mut rng = StdRng::seed_from_u64(SEED);
|
||||||
let point_a = [rng.gen(); 300];
|
let mut point_a = [rng.gen(); 300];
|
||||||
let point_b = [rng.gen(); 300];
|
let mut point_b = [rng.gen(); 300];
|
||||||
|
M::preprocess(&mut point_a);
|
||||||
|
M::preprocess(&mut point_b);
|
||||||
|
|
||||||
bench.iter(|| M::distance(&point_a, &point_b))
|
bench.iter(|| M::distance(&point_a, &point_b))
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,9 @@ pub mod simd_neon;
|
||||||
pub trait Metric {
|
pub trait Metric {
|
||||||
/// Greater the value - more distant the vectors
|
/// Greater the value - more distant the vectors
|
||||||
fn distance(v1: &[f32], v2: &[f32]) -> f32;
|
fn distance(v1: &[f32], v2: &[f32]) -> f32;
|
||||||
|
|
||||||
|
/// Necessary vector transformations performed before adding it to the collection (like normalization)
|
||||||
|
fn preprocess(vector: &mut [f32]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(target_arch = "x86_64")]
|
#[cfg(target_arch = "x86_64")]
|
||||||
|
@ -54,6 +57,72 @@ impl Metric for EuclidMetric {
|
||||||
|
|
||||||
euclid_distance(v1, v2)
|
euclid_distance(v1, v2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn preprocess(_vector: &mut [f32]) {
|
||||||
|
// no-op
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub struct CosineMetric {}
|
||||||
|
|
||||||
|
impl Metric for CosineMetric {
|
||||||
|
fn distance(v1: &[f32], v2: &[f32]) -> f32 {
|
||||||
|
#[cfg(target_arch = "x86_64")]
|
||||||
|
{
|
||||||
|
if is_x86_feature_detected!("avx")
|
||||||
|
&& is_x86_feature_detected!("fma")
|
||||||
|
&& v1.len() >= MIN_DIM_SIZE_AVX
|
||||||
|
{
|
||||||
|
return 1.0 - unsafe { simd_avx::dot_similarity_avx(v1, v2) };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
|
{
|
||||||
|
if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD {
|
||||||
|
return 1.0 - unsafe { simd_sse::dot_similarity_sse(v1, v2) };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
|
||||||
|
{
|
||||||
|
if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD {
|
||||||
|
return 1.0 - unsafe { simd_neon::dot_similarity_neon(v1, v2) };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
1.0 - dot_similarity(v1, v2)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn preprocess(vector: &mut [f32]) {
|
||||||
|
#[cfg(target_arch = "x86_64")]
|
||||||
|
{
|
||||||
|
if is_x86_feature_detected!("avx")
|
||||||
|
&& is_x86_feature_detected!("fma")
|
||||||
|
&& vector.len() >= MIN_DIM_SIZE_AVX
|
||||||
|
{
|
||||||
|
return unsafe { simd_avx::cosine_preprocess_avx(vector) };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||||
|
{
|
||||||
|
if is_x86_feature_detected!("sse") && vector.len() >= MIN_DIM_SIZE_SIMD {
|
||||||
|
return unsafe { simd_sse::cosine_preprocess_sse(vector) };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
|
||||||
|
{
|
||||||
|
if std::arch::is_aarch64_feature_detected!("neon") && vector.len() >= MIN_DIM_SIZE_SIMD
|
||||||
|
{
|
||||||
|
return unsafe { simd_neon::cosine_preprocess_neon(vector) };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cosine_preprocess(vector);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 {
|
pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 {
|
||||||
|
@ -66,6 +135,21 @@ pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 {
|
||||||
s.abs().sqrt()
|
s.abs().sqrt()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn cosine_preprocess(vector: &mut [f32]) {
|
||||||
|
let mut length: f32 = vector.iter().map(|x| x * x).sum();
|
||||||
|
if length < f32::EPSILON {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
length = length.sqrt();
|
||||||
|
for x in vector.iter_mut() {
|
||||||
|
*x /= length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dot_similarity(v1: &[f32], v2: &[f32]) -> f32 {
|
||||||
|
v1.iter().zip(v2).map(|(a, b)| a * b).sum()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn legacy_distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
|
pub fn legacy_distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
|
||||||
#[cfg(target_arch = "x86_64")]
|
#[cfg(target_arch = "x86_64")]
|
||||||
{
|
{
|
||||||
|
|
|
@ -52,3 +52,135 @@ pub(crate) unsafe fn euclid_distance_avx(v1: &[f32], v2: &[f32]) -> f32 {
|
||||||
}
|
}
|
||||||
result.abs().sqrt()
|
result.abs().sqrt()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[target_feature(enable = "avx")]
|
||||||
|
#[target_feature(enable = "fma")]
|
||||||
|
pub(crate) unsafe fn cosine_preprocess_avx(vector: &mut [f32]) {
|
||||||
|
let n = vector.len();
|
||||||
|
let m = n - (n % 32);
|
||||||
|
let mut ptr: *const f32 = vector.as_ptr();
|
||||||
|
let mut sum256_1: __m256 = _mm256_setzero_ps();
|
||||||
|
let mut sum256_2: __m256 = _mm256_setzero_ps();
|
||||||
|
let mut sum256_3: __m256 = _mm256_setzero_ps();
|
||||||
|
let mut sum256_4: __m256 = _mm256_setzero_ps();
|
||||||
|
let mut i: usize = 0;
|
||||||
|
while i < m {
|
||||||
|
let m256_1 = _mm256_loadu_ps(ptr);
|
||||||
|
sum256_1 = _mm256_fmadd_ps(m256_1, m256_1, sum256_1);
|
||||||
|
|
||||||
|
let m256_2 = _mm256_loadu_ps(ptr.add(8));
|
||||||
|
sum256_2 = _mm256_fmadd_ps(m256_2, m256_2, sum256_2);
|
||||||
|
|
||||||
|
let m256_3 = _mm256_loadu_ps(ptr.add(16));
|
||||||
|
sum256_3 = _mm256_fmadd_ps(m256_3, m256_3, sum256_3);
|
||||||
|
|
||||||
|
let m256_4 = _mm256_loadu_ps(ptr.add(24));
|
||||||
|
sum256_4 = _mm256_fmadd_ps(m256_4, m256_4, sum256_4);
|
||||||
|
|
||||||
|
ptr = ptr.add(32);
|
||||||
|
i += 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut length = hsum256_ps_avx(sum256_1)
|
||||||
|
+ hsum256_ps_avx(sum256_2)
|
||||||
|
+ hsum256_ps_avx(sum256_3)
|
||||||
|
+ hsum256_ps_avx(sum256_4);
|
||||||
|
for i in 0..n - m {
|
||||||
|
length += (*ptr.add(i)).powi(2);
|
||||||
|
}
|
||||||
|
if length < f32::EPSILON {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
length = length.sqrt();
|
||||||
|
for x in vector.iter_mut() {
|
||||||
|
*x /= length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[target_feature(enable = "avx")]
|
||||||
|
#[target_feature(enable = "fma")]
|
||||||
|
pub(crate) unsafe fn dot_similarity_avx(v1: &[f32], v2: &[f32]) -> f32 {
|
||||||
|
let n = v1.len();
|
||||||
|
let m = n - (n % 32);
|
||||||
|
let mut ptr1: *const f32 = v1.as_ptr();
|
||||||
|
let mut ptr2: *const f32 = v2.as_ptr();
|
||||||
|
let mut sum256_1: __m256 = _mm256_setzero_ps();
|
||||||
|
let mut sum256_2: __m256 = _mm256_setzero_ps();
|
||||||
|
let mut sum256_3: __m256 = _mm256_setzero_ps();
|
||||||
|
let mut sum256_4: __m256 = _mm256_setzero_ps();
|
||||||
|
let mut i: usize = 0;
|
||||||
|
while i < m {
|
||||||
|
sum256_1 = _mm256_fmadd_ps(_mm256_loadu_ps(ptr1), _mm256_loadu_ps(ptr2), sum256_1);
|
||||||
|
sum256_2 = _mm256_fmadd_ps(
|
||||||
|
_mm256_loadu_ps(ptr1.add(8)),
|
||||||
|
_mm256_loadu_ps(ptr2.add(8)),
|
||||||
|
sum256_2,
|
||||||
|
);
|
||||||
|
sum256_3 = _mm256_fmadd_ps(
|
||||||
|
_mm256_loadu_ps(ptr1.add(16)),
|
||||||
|
_mm256_loadu_ps(ptr2.add(16)),
|
||||||
|
sum256_3,
|
||||||
|
);
|
||||||
|
sum256_4 = _mm256_fmadd_ps(
|
||||||
|
_mm256_loadu_ps(ptr1.add(24)),
|
||||||
|
_mm256_loadu_ps(ptr2.add(24)),
|
||||||
|
sum256_4,
|
||||||
|
);
|
||||||
|
|
||||||
|
ptr1 = ptr1.add(32);
|
||||||
|
ptr2 = ptr2.add(32);
|
||||||
|
i += 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = hsum256_ps_avx(sum256_1)
|
||||||
|
+ hsum256_ps_avx(sum256_2)
|
||||||
|
+ hsum256_ps_avx(sum256_3)
|
||||||
|
+ hsum256_ps_avx(sum256_4);
|
||||||
|
|
||||||
|
for i in 0..n - m {
|
||||||
|
result += (*ptr1.add(i)) * (*ptr2.add(i));
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
#[test]
|
||||||
|
fn test_spaces_avx() {
|
||||||
|
use super::*;
|
||||||
|
use crate::*;
|
||||||
|
|
||||||
|
if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
|
||||||
|
let v1: Vec<f32> = vec![
|
||||||
|
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
|
||||||
|
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
|
||||||
|
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
|
||||||
|
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
|
||||||
|
26., 27., 28., 29., 30., 31.,
|
||||||
|
];
|
||||||
|
let v2: Vec<f32> = vec![
|
||||||
|
40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55.,
|
||||||
|
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
|
||||||
|
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
|
||||||
|
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
|
||||||
|
56., 57., 58., 59., 60., 61.,
|
||||||
|
];
|
||||||
|
|
||||||
|
let euclid_simd = unsafe { euclid_distance_avx(&v1, &v2) };
|
||||||
|
let euclid = euclid_distance(&v1, &v2);
|
||||||
|
assert_eq!(euclid_simd, euclid);
|
||||||
|
|
||||||
|
let dot_simd = unsafe { dot_similarity_avx(&v1, &v2) };
|
||||||
|
let dot = dot_similarity(&v1, &v2);
|
||||||
|
assert_eq!(dot_simd, dot);
|
||||||
|
|
||||||
|
let mut v1 = v1;
|
||||||
|
let mut v1_copy = v1.clone();
|
||||||
|
unsafe { cosine_preprocess_avx(&mut v1) };
|
||||||
|
cosine_preprocess(&mut v1_copy);
|
||||||
|
assert_eq!(v1, v1_copy);
|
||||||
|
} else {
|
||||||
|
println!("avx test skipped");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -36,3 +36,108 @@ pub(crate) unsafe fn euclid_distance_neon(v1: &[f32], v2: &[f32]) -> f32 {
|
||||||
}
|
}
|
||||||
result.abs().sqrt()
|
result.abs().sqrt()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(target_feature = "neon")]
|
||||||
|
pub(crate) unsafe fn cosine_preprocess_neon(vector: &mut [f32]) {
|
||||||
|
let n = vector.len();
|
||||||
|
let m = n - (n % 16);
|
||||||
|
let mut ptr: *const f32 = vector.as_ptr();
|
||||||
|
let mut sum1 = vdupq_n_f32(0.);
|
||||||
|
let mut sum2 = vdupq_n_f32(0.);
|
||||||
|
let mut sum3 = vdupq_n_f32(0.);
|
||||||
|
let mut sum4 = vdupq_n_f32(0.);
|
||||||
|
|
||||||
|
let mut i: usize = 0;
|
||||||
|
while i < m {
|
||||||
|
let d1 = vld1q_f32(ptr);
|
||||||
|
sum1 = vfmaq_f32(sum1, d1, d1);
|
||||||
|
|
||||||
|
let d2 = vld1q_f32(ptr.add(4));
|
||||||
|
sum2 = vfmaq_f32(sum2, d2, d2);
|
||||||
|
|
||||||
|
let d3 = vld1q_f32(ptr.add(8));
|
||||||
|
sum3 = vfmaq_f32(sum3, d3, d3);
|
||||||
|
|
||||||
|
let d4 = vld1q_f32(ptr.add(12));
|
||||||
|
sum4 = vfmaq_f32(sum4, d4, d4);
|
||||||
|
|
||||||
|
ptr = ptr.add(16);
|
||||||
|
i += 16;
|
||||||
|
}
|
||||||
|
let mut length = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);
|
||||||
|
for v in vector.iter().take(n).skip(m) {
|
||||||
|
length += v.powi(2);
|
||||||
|
}
|
||||||
|
if length < f32::EPSILON {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let length = length.sqrt();
|
||||||
|
for x in vector.iter_mut() {
|
||||||
|
*x /= length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(target_feature = "neon")]
|
||||||
|
pub(crate) unsafe fn dot_similarity_neon(v1: &[f32], v2: &[f32]) -> f32 {
|
||||||
|
let n = v1.len();
|
||||||
|
let m = n - (n % 16);
|
||||||
|
let mut ptr1: *const f32 = v1.as_ptr();
|
||||||
|
let mut ptr2: *const f32 = v2.as_ptr();
|
||||||
|
let mut sum1 = vdupq_n_f32(0.);
|
||||||
|
let mut sum2 = vdupq_n_f32(0.);
|
||||||
|
let mut sum3 = vdupq_n_f32(0.);
|
||||||
|
let mut sum4 = vdupq_n_f32(0.);
|
||||||
|
|
||||||
|
let mut i: usize = 0;
|
||||||
|
while i < m {
|
||||||
|
sum1 = vfmaq_f32(sum1, vld1q_f32(ptr1), vld1q_f32(ptr2));
|
||||||
|
sum2 = vfmaq_f32(sum2, vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4)));
|
||||||
|
sum3 = vfmaq_f32(sum3, vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8)));
|
||||||
|
sum4 = vfmaq_f32(sum4, vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12)));
|
||||||
|
ptr1 = ptr1.add(16);
|
||||||
|
ptr2 = ptr2.add(16);
|
||||||
|
i += 16;
|
||||||
|
}
|
||||||
|
let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);
|
||||||
|
for i in 0..n - m {
|
||||||
|
result += (*ptr1.add(i)) * (*ptr2.add(i));
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
#[cfg(target_feature = "neon")]
|
||||||
|
#[test]
|
||||||
|
fn test_spaces_neon() {
|
||||||
|
use super::*;
|
||||||
|
use crate::*;
|
||||||
|
|
||||||
|
if std::arch::is_aarch64_feature_detected!("neon") {
|
||||||
|
let v1: Vec<f32> = vec![
|
||||||
|
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
|
||||||
|
26., 27., 28., 29., 30., 31.,
|
||||||
|
];
|
||||||
|
let v2: Vec<f32> = vec![
|
||||||
|
40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55.,
|
||||||
|
56., 57., 58., 59., 60., 61.,
|
||||||
|
];
|
||||||
|
|
||||||
|
let euclid_simd = unsafe { euclid_distance_neon(&v1, &v2) };
|
||||||
|
let euclid = euclid_distance(&v1, &v2);
|
||||||
|
assert_eq!(euclid_simd, euclid);
|
||||||
|
|
||||||
|
let dot_simd = unsafe { dot_similarity_neon(&v1, &v2) };
|
||||||
|
let dot = dot_similarity(&v1, &v2);
|
||||||
|
assert_eq!(dot_simd, dot);
|
||||||
|
|
||||||
|
let mut v1 = v1;
|
||||||
|
let mut v1_copy = v1.clone();
|
||||||
|
unsafe { cosine_preprocess_neon(&mut v1) };
|
||||||
|
cosine_preprocess(&mut v1_copy);
|
||||||
|
assert_eq!(v1, v1_copy);
|
||||||
|
} else {
|
||||||
|
println!("neon test skipped");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -48,3 +48,132 @@ pub(crate) unsafe fn euclid_distance_sse(v1: &[f32], v2: &[f32]) -> f32 {
|
||||||
}
|
}
|
||||||
result.abs().sqrt()
|
result.abs().sqrt()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[target_feature(enable = "sse")]
|
||||||
|
pub(crate) unsafe fn cosine_preprocess_sse(vector: &mut [f32]) {
|
||||||
|
let n = vector.len();
|
||||||
|
let m = n - (n % 16);
|
||||||
|
let mut ptr: *const f32 = vector.as_ptr();
|
||||||
|
let mut sum128_1: __m128 = _mm_setzero_ps();
|
||||||
|
let mut sum128_2: __m128 = _mm_setzero_ps();
|
||||||
|
let mut sum128_3: __m128 = _mm_setzero_ps();
|
||||||
|
let mut sum128_4: __m128 = _mm_setzero_ps();
|
||||||
|
|
||||||
|
let mut i: usize = 0;
|
||||||
|
while i < m {
|
||||||
|
let m128_1 = _mm_loadu_ps(ptr);
|
||||||
|
sum128_1 = _mm_add_ps(_mm_mul_ps(m128_1, m128_1), sum128_1);
|
||||||
|
|
||||||
|
let m128_2 = _mm_loadu_ps(ptr.add(4));
|
||||||
|
sum128_2 = _mm_add_ps(_mm_mul_ps(m128_2, m128_2), sum128_2);
|
||||||
|
|
||||||
|
let m128_3 = _mm_loadu_ps(ptr.add(8));
|
||||||
|
sum128_3 = _mm_add_ps(_mm_mul_ps(m128_3, m128_3), sum128_3);
|
||||||
|
|
||||||
|
let m128_4 = _mm_loadu_ps(ptr.add(12));
|
||||||
|
sum128_4 = _mm_add_ps(_mm_mul_ps(m128_4, m128_4), sum128_4);
|
||||||
|
|
||||||
|
ptr = ptr.add(16);
|
||||||
|
i += 16;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut length = hsum128_ps_sse(sum128_1)
|
||||||
|
+ hsum128_ps_sse(sum128_2)
|
||||||
|
+ hsum128_ps_sse(sum128_3)
|
||||||
|
+ hsum128_ps_sse(sum128_4);
|
||||||
|
for i in 0..n - m {
|
||||||
|
length += (*ptr.add(i)).powi(2);
|
||||||
|
}
|
||||||
|
if length < f32::EPSILON {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
length = length.sqrt();
|
||||||
|
for x in vector.iter_mut() {
|
||||||
|
*x /= length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[target_feature(enable = "sse")]
|
||||||
|
pub(crate) unsafe fn dot_similarity_sse(v1: &[f32], v2: &[f32]) -> f32 {
|
||||||
|
let n = v1.len();
|
||||||
|
let m = n - (n % 16);
|
||||||
|
let mut ptr1: *const f32 = v1.as_ptr();
|
||||||
|
let mut ptr2: *const f32 = v2.as_ptr();
|
||||||
|
let mut sum128_1: __m128 = _mm_setzero_ps();
|
||||||
|
let mut sum128_2: __m128 = _mm_setzero_ps();
|
||||||
|
let mut sum128_3: __m128 = _mm_setzero_ps();
|
||||||
|
let mut sum128_4: __m128 = _mm_setzero_ps();
|
||||||
|
|
||||||
|
let mut i: usize = 0;
|
||||||
|
while i < m {
|
||||||
|
sum128_1 = _mm_add_ps(_mm_mul_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2)), sum128_1);
|
||||||
|
|
||||||
|
sum128_2 = _mm_add_ps(
|
||||||
|
_mm_mul_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4))),
|
||||||
|
sum128_2,
|
||||||
|
);
|
||||||
|
|
||||||
|
sum128_3 = _mm_add_ps(
|
||||||
|
_mm_mul_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8))),
|
||||||
|
sum128_3,
|
||||||
|
);
|
||||||
|
|
||||||
|
sum128_4 = _mm_add_ps(
|
||||||
|
_mm_mul_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12))),
|
||||||
|
sum128_4,
|
||||||
|
);
|
||||||
|
|
||||||
|
ptr1 = ptr1.add(16);
|
||||||
|
ptr2 = ptr2.add(16);
|
||||||
|
i += 16;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = hsum128_ps_sse(sum128_1)
|
||||||
|
+ hsum128_ps_sse(sum128_2)
|
||||||
|
+ hsum128_ps_sse(sum128_3)
|
||||||
|
+ hsum128_ps_sse(sum128_4);
|
||||||
|
for i in 0..n - m {
|
||||||
|
result += (*ptr1.add(i)) * (*ptr2.add(i));
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
#[test]
|
||||||
|
fn test_spaces_sse() {
|
||||||
|
use super::*;
|
||||||
|
use crate::*;
|
||||||
|
|
||||||
|
if is_x86_feature_detected!("sse") {
|
||||||
|
let v1: Vec<f32> = vec![
|
||||||
|
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
|
||||||
|
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
|
||||||
|
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
|
||||||
|
26., 27., 28., 29., 30., 31.,
|
||||||
|
];
|
||||||
|
let v2: Vec<f32> = vec![
|
||||||
|
40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55.,
|
||||||
|
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
|
||||||
|
10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
|
||||||
|
56., 57., 58., 59., 60., 61.,
|
||||||
|
];
|
||||||
|
|
||||||
|
let euclid_simd = unsafe { euclid_distance_sse(&v1, &v2) };
|
||||||
|
let euclid = euclid_distance(&v1, &v2);
|
||||||
|
assert_eq!(euclid_simd, euclid);
|
||||||
|
|
||||||
|
let dot_simd = unsafe { dot_similarity_sse(&v1, &v2) };
|
||||||
|
let dot = dot_similarity(&v1, &v2);
|
||||||
|
assert_eq!(dot_simd, dot);
|
||||||
|
|
||||||
|
let mut v1 = v1;
|
||||||
|
let mut v1_copy = v1.clone();
|
||||||
|
unsafe { cosine_preprocess_sse(&mut v1) };
|
||||||
|
cosine_preprocess(&mut v1_copy);
|
||||||
|
assert_eq!(v1, v1_copy);
|
||||||
|
} else {
|
||||||
|
println!("sse test skipped");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -7,8 +7,8 @@ use std::io::{BufReader, BufWriter};
|
||||||
use std::iter::FromIterator;
|
use std::iter::FromIterator;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
use distance_metrics::EuclidMetric;
|
|
||||||
use distance_metrics::Metric;
|
use distance_metrics::Metric;
|
||||||
|
use distance_metrics::{CosineMetric, EuclidMetric};
|
||||||
use instant_distance::Point;
|
use instant_distance::Point;
|
||||||
use pyo3::conversion::IntoPy;
|
use pyo3::conversion::IntoPy;
|
||||||
use pyo3::exceptions::PyValueError;
|
use pyo3::exceptions::PyValueError;
|
||||||
|
@ -65,7 +65,7 @@ struct HnswMap {
|
||||||
#[derive(Deserialize, Serialize)]
|
#[derive(Deserialize, Serialize)]
|
||||||
enum HnswMapWithMetric {
|
enum HnswMapWithMetric {
|
||||||
Euclid(instant_distance::HnswMap<FloatArray<EuclidMetric>, MapValue>),
|
Euclid(instant_distance::HnswMap<FloatArray<EuclidMetric>, MapValue>),
|
||||||
Cosine(instant_distance::HnswMap<FloatArray<EuclidMetric>, MapValue>),
|
Cosine(instant_distance::HnswMap<FloatArray<CosineMetric>, MapValue>),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
|
@ -136,7 +136,7 @@ struct Hnsw {
|
||||||
#[derive(Deserialize, Serialize)]
|
#[derive(Deserialize, Serialize)]
|
||||||
enum HnswWithMetric {
|
enum HnswWithMetric {
|
||||||
Euclid(instant_distance::Hnsw<FloatArray<EuclidMetric>>),
|
Euclid(instant_distance::Hnsw<FloatArray<EuclidMetric>>),
|
||||||
Cosine(instant_distance::Hnsw<FloatArray<EuclidMetric>>),
|
Cosine(instant_distance::Hnsw<FloatArray<CosineMetric>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
|
@ -426,7 +426,8 @@ impl<M: Metric> FloatArray<M> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<M: Metric> From<Vec<f32>> for FloatArray<M> {
|
impl<M: Metric> From<Vec<f32>> for FloatArray<M> {
|
||||||
fn from(array: Vec<f32>) -> Self {
|
fn from(mut array: Vec<f32>) -> Self {
|
||||||
|
M::preprocess(&mut array);
|
||||||
Self {
|
Self {
|
||||||
array,
|
array,
|
||||||
phantom: PhantomData,
|
phantom: PhantomData,
|
||||||
|
|
Loading…
Reference in New Issue