diff --git a/.cargo/config.toml b/.cargo/config.toml
index 15d5d32..92403a7 100644
--- a/.cargo/config.toml
+++ b/.cargo/config.toml
@@ -1,4 +1,4 @@
-[target.x86_64-apple-darwin]
+[build]
 rustflags = [
   "-C", "link-arg=-undefined",
   "-C", "link-arg=dynamic_lookup",
diff --git a/instant-distance-py/src/lib.rs b/instant-distance-py/src/lib.rs
index eb976c2..1a5b034 100644
--- a/instant-distance-py/src/lib.rs
+++ b/instant-distance-py/src/lib.rs
@@ -371,37 +371,46 @@ big_array! { BigArray; DIMENSIONS }
 
 impl Point for FloatArray {
     fn distance(&self, rhs: &Self) -> f32 {
-        use std::arch::x86_64::{
-            _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
-            _mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_fmadd_ps,
-            _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
-        };
-        debug_assert_eq!(self.0.len() % 8, 4);
+        #[cfg(target_arch = "x86_64")]
+        {
+            use std::arch::x86_64::{
+                _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
+                _mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
+                _mm_fmadd_ps, _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
+            };
+            debug_assert_eq!(self.0.len() % 8, 4);
 
-        unsafe {
-            let mut acc_8x = _mm256_setzero_ps();
-            for (lh_slice, rh_slice) in self.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
-                let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
-                let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
-                let diff = _mm256_sub_ps(lh_8x, rh_8x);
-                acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x);
+            unsafe {
+                let mut acc_8x = _mm256_setzero_ps();
+                for (lh_slice, rh_slice) in self.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
+                    let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
+                    let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
+                    let diff = _mm256_sub_ps(lh_8x, rh_8x);
+                    acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x);
+                }
+
+                let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half
+                let right = _mm256_castps256_ps128(acc_8x); // lower half
+                acc_4x = _mm_add_ps(acc_4x, right); // sum halves
+
+                let lh_4x = _mm_load_ps(self.0[DIMENSIONS - 4..].as_ptr());
+                let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
+                let diff = _mm_sub_ps(lh_4x, rh_4x);
+                acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
+
+                let lower = _mm_movehl_ps(acc_4x, acc_4x);
+                acc_4x = _mm_add_ps(acc_4x, lower);
+                let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1);
+                acc_4x = _mm_add_ss(acc_4x, upper);
+                _mm_cvtss_f32(acc_4x)
             }
-
-            let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half
-            let right = _mm256_castps256_ps128(acc_8x); // lower half
-            acc_4x = _mm_add_ps(acc_4x, right); // sum halves
-
-            let lh_4x = _mm_load_ps(self.0[DIMENSIONS - 4..].as_ptr());
-            let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
-            let diff = _mm_sub_ps(lh_4x, rh_4x);
-            acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
-
-            let lower = _mm_movehl_ps(acc_4x, acc_4x);
-            acc_4x = _mm_add_ps(acc_4x, lower);
-            let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1);
-            acc_4x = _mm_add_ss(acc_4x, upper);
-            _mm_cvtss_f32(acc_4x)
         }
+        #[cfg(not(target_arch = "x86_64"))]
+        self.0
+            .iter()
+            .zip(rhs.0.iter())
+            .map(|(&a, &b)| (a - b).powi(2))
+            .sum::<f32>()
     }
 }
 