Replace Euclid metric implementation
This commit is contained in:
parent
bca31ad33f
commit
ac4257495e
|
@ -1,5 +1,5 @@
|
|||
[workspace]
|
||||
members = ["instant-distance", "instant-distance-py"]
|
||||
members = ["distance-metrics", "instant-distance", "instant-distance-py"]
|
||||
|
||||
[profile.bench]
|
||||
debug = true
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
[package]
|
||||
name = "distance-metrics"
|
||||
version = "0.6.0"
|
||||
license = "MIT OR Apache-2.0"
|
||||
edition = "2021"
|
||||
rust-version = "1.58"
|
||||
homepage = "https://github.com/InstantDomain/instant-distance"
|
||||
repository = "https://github.com/InstantDomain/instant-distance"
|
||||
documentation = "https://docs.rs/instant-distance"
|
||||
workspace = ".."
|
||||
readme = "../README.md"
|
||||
|
||||
[dependencies]
|
||||
|
||||
[dev-dependencies]
|
||||
bencher = "0.1.5"
|
||||
rand = { version = "0.8", features = ["small_rng"] }
|
||||
|
||||
[[bench]]
|
||||
name = "all"
|
||||
harness = false
|
|
@ -0,0 +1,33 @@
|
|||
use bencher::{benchmark_group, benchmark_main, Bencher};
|
||||
|
||||
use distance_metrics::{EuclidMetric, Metric};
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
|
||||
benchmark_main!(benches);
|
||||
benchmark_group!(benches, legacy, non_simd, metric::<EuclidMetric>,);
|
||||
|
||||
fn legacy(bench: &mut Bencher) {
|
||||
let mut rng = StdRng::seed_from_u64(SEED);
|
||||
let point_a = distance_metrics::FloatArray([rng.gen(); 300]);
|
||||
let point_b = distance_metrics::FloatArray([rng.gen(); 300]);
|
||||
|
||||
bench.iter(|| distance_metrics::legacy_distance(&point_a, &point_b))
|
||||
}
|
||||
|
||||
fn non_simd(bench: &mut Bencher) {
|
||||
let mut rng = StdRng::seed_from_u64(SEED);
|
||||
let point_a = [rng.gen(); 300];
|
||||
let point_b = [rng.gen(); 300];
|
||||
|
||||
bench.iter(|| distance_metrics::euclid_distance(&point_a, &point_b))
|
||||
}
|
||||
|
||||
fn metric<M: Metric>(bench: &mut Bencher) {
|
||||
let mut rng = StdRng::seed_from_u64(SEED);
|
||||
let point_a = [rng.gen(); 300];
|
||||
let point_b = [rng.gen(); 300];
|
||||
|
||||
bench.iter(|| M::distance(&point_a, &point_b))
|
||||
}
|
||||
|
||||
const SEED: u64 = 123456789;
|
|
@ -0,0 +1,115 @@
|
|||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
pub mod simd_sse;
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
pub mod simd_avx;
|
||||
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
pub mod simd_neon;
|
||||
|
||||
/// Defines how to compare vectors
|
||||
pub trait Metric {
|
||||
/// Greater the value - more distant the vectors
|
||||
fn distance(v1: &[f32], v2: &[f32]) -> f32;
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
const MIN_DIM_SIZE_AVX: usize = 32;
|
||||
|
||||
#[cfg(any(
|
||||
target_arch = "x86",
|
||||
target_arch = "x86_64",
|
||||
all(target_arch = "aarch64", target_feature = "neon")
|
||||
))]
|
||||
const MIN_DIM_SIZE_SIMD: usize = 16;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct EuclidMetric {}
|
||||
|
||||
impl Metric for EuclidMetric {
|
||||
fn distance(v1: &[f32], v2: &[f32]) -> f32 {
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
{
|
||||
if is_x86_feature_detected!("avx")
|
||||
&& is_x86_feature_detected!("fma")
|
||||
&& v1.len() >= MIN_DIM_SIZE_AVX
|
||||
{
|
||||
return unsafe { simd_avx::euclid_distance_avx(v1, v2) };
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
{
|
||||
if is_x86_feature_detected!("sse") && v1.len() >= MIN_DIM_SIZE_SIMD {
|
||||
return unsafe { simd_sse::euclid_distance_sse(v1, v2) };
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
|
||||
{
|
||||
if std::arch::is_aarch64_feature_detected!("neon") && v1.len() >= MIN_DIM_SIZE_SIMD {
|
||||
return unsafe { simple_neon::euclid_distance_neon(v1, v2) };
|
||||
}
|
||||
}
|
||||
|
||||
euclid_distance(v1, v2)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn euclid_distance(v1: &[f32], v2: &[f32]) -> f32 {
|
||||
let s: f32 = v1
|
||||
.iter()
|
||||
.copied()
|
||||
.zip(v2.iter().copied())
|
||||
.map(|(a, b)| (a - b).powi(2))
|
||||
.sum();
|
||||
s.abs().sqrt()
|
||||
}
|
||||
|
||||
pub fn legacy_distance(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
{
|
||||
use std::arch::x86_64::{
|
||||
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
|
||||
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_fmadd_ps,
|
||||
_mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
|
||||
};
|
||||
debug_assert_eq!(lhs.0.len() % 8, 4);
|
||||
|
||||
unsafe {
|
||||
let mut acc_8x = _mm256_setzero_ps();
|
||||
for (lh_slice, rh_slice) in lhs.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
|
||||
let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
|
||||
let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
|
||||
let diff = _mm256_sub_ps(lh_8x, rh_8x);
|
||||
acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x);
|
||||
}
|
||||
|
||||
let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half
|
||||
let right = _mm256_castps256_ps128(acc_8x); // lower half
|
||||
acc_4x = _mm_add_ps(acc_4x, right); // sum halves
|
||||
|
||||
let lh_4x = _mm_load_ps(lhs.0[DIMENSIONS - 4..].as_ptr());
|
||||
let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
|
||||
let diff = _mm_sub_ps(lh_4x, rh_4x);
|
||||
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
|
||||
|
||||
let lower = _mm_movehl_ps(acc_4x, acc_4x);
|
||||
acc_4x = _mm_add_ps(acc_4x, lower);
|
||||
let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1);
|
||||
acc_4x = _mm_add_ss(acc_4x, upper);
|
||||
_mm_cvtss_f32(acc_4x)
|
||||
}
|
||||
}
|
||||
#[cfg(not(target_arch = "x86_64"))]
|
||||
lhs.0
|
||||
.iter()
|
||||
.zip(rhs.0.iter())
|
||||
.map(|(&a, &b)| (a - b).powi(2))
|
||||
.sum::<f32>()
|
||||
}
|
||||
|
||||
#[repr(align(32))]
|
||||
pub struct FloatArray(pub [f32; DIMENSIONS]);
|
||||
|
||||
const DIMENSIONS: usize = 300;
|
|
@ -0,0 +1,54 @@
|
|||
use std::arch::x86_64::*;
|
||||
|
||||
#[target_feature(enable = "avx")]
|
||||
#[target_feature(enable = "fma")]
|
||||
unsafe fn hsum256_ps_avx(x: __m256) -> f32 {
|
||||
let x128: __m128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
|
||||
let x64: __m128 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
|
||||
let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
|
||||
_mm_cvtss_f32(x32)
|
||||
}
|
||||
|
||||
#[target_feature(enable = "avx")]
|
||||
#[target_feature(enable = "fma")]
|
||||
pub(crate) unsafe fn euclid_distance_avx(v1: &[f32], v2: &[f32]) -> f32 {
|
||||
let n = v1.len();
|
||||
let m = n - (n % 32);
|
||||
let mut ptr1: *const f32 = v1.as_ptr();
|
||||
let mut ptr2: *const f32 = v2.as_ptr();
|
||||
let mut sum256_1: __m256 = _mm256_setzero_ps();
|
||||
let mut sum256_2: __m256 = _mm256_setzero_ps();
|
||||
let mut sum256_3: __m256 = _mm256_setzero_ps();
|
||||
let mut sum256_4: __m256 = _mm256_setzero_ps();
|
||||
let mut i: usize = 0;
|
||||
while i < m {
|
||||
let sub256_1: __m256 =
|
||||
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(0)), _mm256_loadu_ps(ptr2.add(0)));
|
||||
sum256_1 = _mm256_fmadd_ps(sub256_1, sub256_1, sum256_1);
|
||||
|
||||
let sub256_2: __m256 =
|
||||
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(8)), _mm256_loadu_ps(ptr2.add(8)));
|
||||
sum256_2 = _mm256_fmadd_ps(sub256_2, sub256_2, sum256_2);
|
||||
|
||||
let sub256_3: __m256 =
|
||||
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(16)), _mm256_loadu_ps(ptr2.add(16)));
|
||||
sum256_3 = _mm256_fmadd_ps(sub256_3, sub256_3, sum256_3);
|
||||
|
||||
let sub256_4: __m256 =
|
||||
_mm256_sub_ps(_mm256_loadu_ps(ptr1.add(24)), _mm256_loadu_ps(ptr2.add(24)));
|
||||
sum256_4 = _mm256_fmadd_ps(sub256_4, sub256_4, sum256_4);
|
||||
|
||||
ptr1 = ptr1.add(32);
|
||||
ptr2 = ptr2.add(32);
|
||||
i += 32;
|
||||
}
|
||||
|
||||
let mut result = hsum256_ps_avx(sum256_1)
|
||||
+ hsum256_ps_avx(sum256_2)
|
||||
+ hsum256_ps_avx(sum256_3)
|
||||
+ hsum256_ps_avx(sum256_4);
|
||||
for i in 0..n - m {
|
||||
result += (*ptr1.add(i) - *ptr2.add(i)).powi(2);
|
||||
}
|
||||
result.abs().sqrt()
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
#[cfg(target_feature = "neon")]
|
||||
use std::arch::aarch64::*;
|
||||
|
||||
#[cfg(target_feature = "neon")]
|
||||
pub(crate) unsafe fn euclid_distance_neon(v1: &[f32], v2: &[f32]) -> f32 {
|
||||
let n = v1.len();
|
||||
let m = n - (n % 16);
|
||||
let mut ptr1: *const f32 = v1.as_ptr();
|
||||
let mut ptr2: *const f32 = v2.as_ptr();
|
||||
let mut sum1 = vdupq_n_f32(0.);
|
||||
let mut sum2 = vdupq_n_f32(0.);
|
||||
let mut sum3 = vdupq_n_f32(0.);
|
||||
let mut sum4 = vdupq_n_f32(0.);
|
||||
|
||||
let mut i: usize = 0;
|
||||
while i < m {
|
||||
let sub1 = vsubq_f32(vld1q_f32(ptr1), vld1q_f32(ptr2));
|
||||
sum1 = vfmaq_f32(sum1, sub1, sub1);
|
||||
|
||||
let sub2 = vsubq_f32(vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4)));
|
||||
sum2 = vfmaq_f32(sum2, sub2, sub2);
|
||||
|
||||
let sub3 = vsubq_f32(vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8)));
|
||||
sum3 = vfmaq_f32(sum3, sub3, sub3);
|
||||
|
||||
let sub4 = vsubq_f32(vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12)));
|
||||
sum4 = vfmaq_f32(sum4, sub4, sub4);
|
||||
|
||||
ptr1 = ptr1.add(16);
|
||||
ptr2 = ptr2.add(16);
|
||||
i += 16;
|
||||
}
|
||||
let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);
|
||||
for i in 0..n - m {
|
||||
result += (*ptr1.add(i) - *ptr2.add(i)).powi(2);
|
||||
}
|
||||
result.abs().sqrt()
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
#[cfg(target_arch = "x86")]
|
||||
use std::arch::x86::*;
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
use std::arch::x86_64::*;
|
||||
|
||||
#[target_feature(enable = "sse")]
|
||||
unsafe fn hsum128_ps_sse(x: __m128) -> f32 {
|
||||
let x64: __m128 = _mm_add_ps(x, _mm_movehl_ps(x, x));
|
||||
let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
|
||||
_mm_cvtss_f32(x32)
|
||||
}
|
||||
|
||||
#[target_feature(enable = "sse")]
|
||||
pub(crate) unsafe fn euclid_distance_sse(v1: &[f32], v2: &[f32]) -> f32 {
|
||||
let n = v1.len();
|
||||
let m = n - (n % 16);
|
||||
let mut ptr1: *const f32 = v1.as_ptr();
|
||||
let mut ptr2: *const f32 = v2.as_ptr();
|
||||
let mut sum128_1: __m128 = _mm_setzero_ps();
|
||||
let mut sum128_2: __m128 = _mm_setzero_ps();
|
||||
let mut sum128_3: __m128 = _mm_setzero_ps();
|
||||
let mut sum128_4: __m128 = _mm_setzero_ps();
|
||||
let mut i: usize = 0;
|
||||
while i < m {
|
||||
let sub128_1 = _mm_sub_ps(_mm_loadu_ps(ptr1), _mm_loadu_ps(ptr2));
|
||||
sum128_1 = _mm_add_ps(_mm_mul_ps(sub128_1, sub128_1), sum128_1);
|
||||
|
||||
let sub128_2 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(4)), _mm_loadu_ps(ptr2.add(4)));
|
||||
sum128_2 = _mm_add_ps(_mm_mul_ps(sub128_2, sub128_2), sum128_2);
|
||||
|
||||
let sub128_3 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(8)), _mm_loadu_ps(ptr2.add(8)));
|
||||
sum128_3 = _mm_add_ps(_mm_mul_ps(sub128_3, sub128_3), sum128_3);
|
||||
|
||||
let sub128_4 = _mm_sub_ps(_mm_loadu_ps(ptr1.add(12)), _mm_loadu_ps(ptr2.add(12)));
|
||||
sum128_4 = _mm_add_ps(_mm_mul_ps(sub128_4, sub128_4), sum128_4);
|
||||
|
||||
ptr1 = ptr1.add(16);
|
||||
ptr2 = ptr2.add(16);
|
||||
i += 16;
|
||||
}
|
||||
|
||||
let mut result = hsum128_ps_sse(sum128_1)
|
||||
+ hsum128_ps_sse(sum128_2)
|
||||
+ hsum128_ps_sse(sum128_3)
|
||||
+ hsum128_ps_sse(sum128_4);
|
||||
for i in 0..n - m {
|
||||
result += (*ptr1.add(i) - *ptr2.add(i)).powi(2);
|
||||
}
|
||||
result.abs().sqrt()
|
||||
}
|
|
@ -16,7 +16,7 @@ crate-type = ["cdylib"]
|
|||
|
||||
[dependencies]
|
||||
bincode = "1.3.1"
|
||||
distance-metrics = { version = "0.6", path = "../distance-metrics" }
|
||||
instant-distance = { version = "0.6", path = "../instant-distance", features = ["with-serde"] }
|
||||
pyo3 = { version = "0.18.0", features = ["extension-module"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde-big-array = "0.4.1"
|
||||
|
|
|
@ -5,15 +5,17 @@ use std::convert::TryFrom;
|
|||
use std::fs::File;
|
||||
use std::io::{BufReader, BufWriter};
|
||||
use std::iter::FromIterator;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use distance_metrics::EuclidMetric;
|
||||
use distance_metrics::Metric;
|
||||
use instant_distance::Point;
|
||||
use pyo3::conversion::IntoPy;
|
||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||
use pyo3::exceptions::PyValueError;
|
||||
use pyo3::types::{PyList, PyModule, PyString};
|
||||
use pyo3::{pyclass, pymethods, pymodule};
|
||||
use pyo3::{Py, PyAny, PyErr, PyObject, PyRef, PyRefMut, PyResult, Python};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_big_array::BigArray;
|
||||
|
||||
#[pymodule]
|
||||
#[pyo3(name = "instant_distance")]
|
||||
|
@ -62,8 +64,8 @@ struct HnswMap {
|
|||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
enum HnswMapWithMetric {
|
||||
Euclid(instant_distance::HnswMap<FloatArray, MapValue>),
|
||||
Cosine(instant_distance::HnswMap<FloatArray, MapValue>),
|
||||
Euclid(instant_distance::HnswMap<FloatArray<EuclidMetric>, MapValue>),
|
||||
Cosine(instant_distance::HnswMap<FloatArray<EuclidMetric>, MapValue>),
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
|
@ -126,9 +128,6 @@ impl HnswMap {
|
|||
}
|
||||
|
||||
/// An instance of hierarchical navigable small worlds
|
||||
///
|
||||
/// For now, this is specialized to only support 300-element (32-bit) float vectors
|
||||
/// with a squared Euclidean distance metric.
|
||||
#[pyclass]
|
||||
struct Hnsw {
|
||||
inner: HnswWithMetric,
|
||||
|
@ -136,8 +135,8 @@ struct Hnsw {
|
|||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
enum HnswWithMetric {
|
||||
Euclid(instant_distance::Hnsw<FloatArray>),
|
||||
Cosine(instant_distance::Hnsw<FloatArray>),
|
||||
Euclid(instant_distance::Hnsw<FloatArray<EuclidMetric>>),
|
||||
Cosine(instant_distance::Hnsw<FloatArray<EuclidMetric>>),
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
|
@ -414,73 +413,42 @@ impl Neighbor {
|
|||
}
|
||||
}
|
||||
|
||||
#[repr(align(32))]
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);
|
||||
struct FloatArray<M> {
|
||||
array: Vec<f32>,
|
||||
phantom: PhantomData<M>,
|
||||
}
|
||||
|
||||
impl FloatArray {
|
||||
impl<M: Metric> FloatArray<M> {
|
||||
fn try_from_pylist(list: &PyList) -> Result<Vec<Self>, PyErr> {
|
||||
list.into_iter().map(FloatArray::try_from).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&PyAny> for FloatArray {
|
||||
type Error = PyErr;
|
||||
|
||||
fn try_from(value: &PyAny) -> Result<Self, Self::Error> {
|
||||
let mut new = FloatArray([0.0; DIMENSIONS]);
|
||||
for (i, val) in value.iter()?.enumerate() {
|
||||
match i >= DIMENSIONS {
|
||||
true => return Err(PyTypeError::new_err("point array too long")),
|
||||
false => new.0[i] = val?.extract::<f32>()?,
|
||||
impl<M: Metric> From<Vec<f32>> for FloatArray<M> {
|
||||
fn from(array: Vec<f32>) -> Self {
|
||||
Self {
|
||||
array,
|
||||
phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
Ok(new)
|
||||
}
|
||||
}
|
||||
|
||||
impl Point for FloatArray {
|
||||
impl<M: Metric> TryFrom<&PyAny> for FloatArray<M> {
|
||||
type Error = PyErr;
|
||||
|
||||
fn try_from(value: &PyAny) -> Result<Self, Self::Error> {
|
||||
let array: Vec<f32> = value
|
||||
.iter()?
|
||||
.map(|val| val.and_then(|v| v.extract::<f32>()))
|
||||
.collect::<Result<_, _>>()?;
|
||||
Ok(Self::from(array))
|
||||
}
|
||||
}
|
||||
|
||||
impl<M: Metric + Clone + Sync> Point for FloatArray<M> {
|
||||
fn distance(&self, rhs: &Self) -> f32 {
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
{
|
||||
use std::arch::x86_64::{
|
||||
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
|
||||
_mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
|
||||
_mm_fmadd_ps, _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
|
||||
};
|
||||
debug_assert_eq!(self.0.len() % 8, 4);
|
||||
|
||||
unsafe {
|
||||
let mut acc_8x = _mm256_setzero_ps();
|
||||
for (lh_slice, rh_slice) in self.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
|
||||
let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
|
||||
let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
|
||||
let diff = _mm256_sub_ps(lh_8x, rh_8x);
|
||||
acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x);
|
||||
}
|
||||
|
||||
let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half
|
||||
let right = _mm256_castps256_ps128(acc_8x); // lower half
|
||||
acc_4x = _mm_add_ps(acc_4x, right); // sum halves
|
||||
|
||||
let lh_4x = _mm_load_ps(self.0[DIMENSIONS - 4..].as_ptr());
|
||||
let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
|
||||
let diff = _mm_sub_ps(lh_4x, rh_4x);
|
||||
acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
|
||||
|
||||
let lower = _mm_movehl_ps(acc_4x, acc_4x);
|
||||
acc_4x = _mm_add_ps(acc_4x, lower);
|
||||
let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1);
|
||||
acc_4x = _mm_add_ss(acc_4x, upper);
|
||||
_mm_cvtss_f32(acc_4x)
|
||||
}
|
||||
}
|
||||
#[cfg(not(target_arch = "x86_64"))]
|
||||
self.0
|
||||
.iter()
|
||||
.zip(rhs.0.iter())
|
||||
.map(|(&a, &b)| (a - b).powi(2))
|
||||
.sum::<f32>()
|
||||
M::distance(&self.array, &rhs.array)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -504,5 +472,3 @@ impl IntoPy<Py<PyAny>> for &'_ MapValue {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
const DIMENSIONS: usize = 300;
|
||||
|
|
Loading…
Reference in New Issue