diff --git a/.cargo/config.toml b/.cargo/config.toml
new file mode 100644
index 0000000..15d5d32
--- /dev/null
+++ b/.cargo/config.toml
@@ -0,0 +1,5 @@
+[target.x86_64-apple-darwin]
+rustflags = [
+    "-C", "link-arg=-undefined",
+    "-C", "link-arg=dynamic_lookup",
+]
diff --git a/Cargo.toml b/Cargo.toml
index c309130..0f7eac6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,5 +1,5 @@
 [workspace]
-members = ["instant-distance"]
+members = ["instant-distance", "instant-distance-py"]
 
 [profile.bench]
 debug = true
diff --git a/instant-distance-py/Cargo.toml b/instant-distance-py/Cargo.toml
new file mode 100644
index 0000000..eef7403
--- /dev/null
+++ b/instant-distance-py/Cargo.toml
@@ -0,0 +1,16 @@
+[package]
+name = "instant-distance-py"
+version = "0.1.0"
+authors = ["Dirkjan Ochtman "]
+edition = "2018"
+license = "MIT OR Apache-2.0"
+
+[lib]
+name = "instant_distance"
+crate-type = ["cdylib"]
+
+[dependencies]
+instant-distance = { version = "0.2", path = "../instant-distance" }
+pyo3 = { version = "0.13.2", features = ["extension-module"] }
+serde = { version = "1", features = ["derive"] }
+serde-big-array = "0.3.2"
diff --git a/instant-distance-py/src/lib.rs b/instant-distance-py/src/lib.rs
new file mode 100644
index 0000000..d417640
--- /dev/null
+++ b/instant-distance-py/src/lib.rs
@@ -0,0 +1,342 @@
+//! Python bindings for `instant-distance`, built as a `pyo3` extension module.
+
+// `Vec::from_iter()` is used on purpose below, so silence the `collect()` suggestion.
+#![allow(clippy::from_iter_instead_of_collect)]
+use std::convert::TryFrom;
+use std::iter::FromIterator;
+
+use instant_distance::Point;
+use pyo3::exceptions::PyTypeError;
+use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
+use pyo3::types::{PyList, PyModule};
+use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python};
+use serde::{Deserialize, Serialize};
+use serde_big_array::big_array;
+
+/// Registers the exported classes on the `instant_distance` Python module.
+#[pymodule]
+fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
+    m.add_class::<Hnsw>()?;
+    m.add_class::<Search>()?;
+    m.add_class::<Config>()?;
+    m.add_class::<Heuristic>()?;
+    m.add_class::<PointId>()?;
+    Ok(())
+}
+
+/// Python-visible wrapper around an `instant_distance::Hnsw` of `FloatArray` points.
+#[pyclass]
+struct Hnsw {
+    inner: instant_distance::Hnsw<FloatArray>,
+}
+
+#[pymethods]
+impl Hnsw {
+    /// Builds an index from `input` using the parameters in `config`.
+    ///
+    /// Returns the index together with the point IDs produced by the builder. Fails if
+    /// any element of `input` cannot be converted into a `FloatArray`.
+    #[staticmethod]
+    fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<PointId>)> {
+        let points = input
+            .into_iter()
+            .map(FloatArray::try_from)
+            .collect::<Result<Vec<_>, PyErr>>()?;
+
+        let (inner, ids) = instant_distance::Builder::from(config).build(&points);
+        let ids = Vec::from_iter(ids.into_iter().map(PointId::from));
+        Ok((Self { inner }, ids))
+    }
+
+    /// Searches the index for `point`, storing the results in `search`.
+    ///
+    /// On success `search` is rewound so that iterating it starts at the first result.
+    fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> {
+        let point = FloatArray::try_from(point)?;
+        // The return value is discarded: results are read back through `search.inner`.
+        let _ = self.inner.search(&point, &mut search.inner);
+        search.cur = Some(0);
+        Ok(())
+    }
+}
+
+/// Search state that also acts as a Python iterator over the most recent results.
+#[pyclass]
+struct Search {
+    inner: instant_distance::Search,
+    // Index of the next result to yield; `None` means there is nothing to iterate.
+    cur: Option<usize>,
+}
+
+#[pymethods]
+impl Search {
+    /// Creates a search state with no results to iterate.
+    #[new]
+    fn new() -> Self {
+        Self {
+            inner: instant_distance::Search::default(),
+            cur: None,
+        }
+    }
+}
+
+#[pyproto]
+impl PyIterProtocol for Search {
+    fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
+        slf
+    }
+
+    fn __next__(mut slf: PyRefMut<Self>) -> Option<PointId> {
+        let idx = slf.cur?;
+        let pid = match slf.inner.get(idx) {
+            Some(pid) => pid,
+            None => {
+                // Exhausted: stay stopped until `Hnsw::search()` rewinds the cursor.
+                slf.cur = None;
+                return None;
+            }
+        };
+
+        slf.cur = Some(idx + 1);
+        Some(PointId::from(pid))
+    }
+}
+
+/// Index parameters; converted into an `instant_distance::Builder` when building.
+#[pyclass]
+#[derive(Copy, Clone)]
+struct Config {
+    #[pyo3(get, set)]
+    ef_search: usize,
+    #[pyo3(get, set)]
+    ef_construction: usize,
+    #[pyo3(get, set)]
+    ml: f32,
+    #[pyo3(get, set)]
+    seed: u64,
+    #[pyo3(get, set)]
+    heuristic: Option<Heuristic>,
+}
+
+#[pymethods]
+impl Config {
+    /// Creates a configuration holding the builder's default parameters.
+    #[new]
+    fn new() -> Self {
+        Self::default()
+    }
+}
+
+impl Default for Config {
+    // Implemented by hand because a derived `Default` would zero every field instead
+    // of matching `instant_distance::Builder::default()`.
+    fn default() -> Self {
+        let builder = instant_distance::Builder::default();
+        let (ef_search, ef_construction, ml, seed) = builder.into_parts();
+        Self {
+            ef_search,
+            ef_construction,
+            ml,
+            seed,
+            heuristic: Some(Heuristic::default()),
+        }
+    }
+}
+
+impl From<&Config> for instant_distance::Builder {
+    fn from(py: &Config) -> Self {
+        let Config {
+            ef_search,
+            ef_construction,
+            ml,
+            seed,
+            heuristic,
+        } = *py;
+        Self::default()
+            .ef_search(ef_search)
+            .ef_construction(ef_construction)
+            .ml(ml)
+            .seed(seed)
+            .select_heuristic(heuristic.map(|h| h.into()))
+    }
+}
+
+/// Python-visible mirror of `instant_distance::Heuristic`.
+#[pyclass]
+#[derive(Copy, Clone)]
+struct Heuristic {
+    #[pyo3(get, set)]
+    extend_candidates: bool,
+    #[pyo3(get, set)]
+    keep_pruned: bool,
+}
+
+#[pymethods]
+impl Heuristic {
+    /// Creates a heuristic holding the library's default settings.
+    #[new]
+    fn new() -> Self {
+        Self::default()
+    }
+}
+
+impl Default for Heuristic {
+    // Delegates to the library so these defaults cannot drift from the ones that
+    // `instant_distance::Heuristic::default()` provides.
+    fn default() -> Self {
+        Self::from(instant_distance::Heuristic::default())
+    }
+}
+
+impl From<instant_distance::Heuristic> for Heuristic {
+    fn from(inner: instant_distance::Heuristic) -> Self {
+        let instant_distance::Heuristic {
+            extend_candidates,
+            keep_pruned,
+        } = inner;
+        Self {
+            extend_candidates,
+            keep_pruned,
+        }
+    }
+}
+
+impl From<Heuristic> for instant_distance::Heuristic {
+    fn from(py: Heuristic) -> Self {
+        let Heuristic {
+            extend_candidates,
+            keep_pruned,
+        } = py;
+        Self {
+            extend_candidates,
+            keep_pruned,
+        }
+    }
+}
+
+/// Python-visible wrapper around `instant_distance::PointId`.
+#[pyclass]
+struct PointId {
+    inner: instant_distance::PointId,
+}
+
+impl From<instant_distance::PointId> for PointId {
+    fn from(inner: instant_distance::PointId) -> Self {
+        Self { inner }
+    }
+}
+
+#[pyproto]
+impl<'p> PyObjectProtocol<'p> for PointId {
+    fn __repr__(&self) -> PyResult<String> {
+        Ok(format!("{:?}", self.inner))
+    }
+
+    fn __hash__(&'p self) -> PyResult<u32> {
+        Ok(self.inner.into_inner())
+    }
+}
+
+/// A point with exactly `DIMENSIONS` `f32` coordinates.
+///
+/// The 32-byte alignment lets the AVX distance kernel use aligned loads.
+#[repr(align(32))]
+#[derive(Clone, Deserialize, Serialize)]
+struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);
+
+impl TryFrom<&PyAny> for FloatArray {
+    type Error = PyErr;
+
+    /// Reads the coordinates of a point from a Python iterable of floats.
+    ///
+    /// Inputs with fewer than `DIMENSIONS` items are padded with zeros; inputs with
+    /// more items are rejected with a `TypeError`.
+    fn try_from(value: &PyAny) -> Result<Self, Self::Error> {
+        let mut new = FloatArray([0.0; DIMENSIONS]);
+        for (i, val) in value.iter()?.enumerate() {
+            let slot = new
+                .0
+                .get_mut(i)
+                .ok_or_else(|| PyTypeError::new_err("point array too long"))?;
+            *slot = val?.extract::<f32>()?;
+        }
+        Ok(new)
+    }
+}
+
+big_array! { BigArray; DIMENSIONS }
+
+impl Point for FloatArray {
+    /// Returns the sum of squared coordinate differences between `self` and `rhs`.
+    ///
+    /// The AVX/FMA kernel is used only when the running CPU reports both features;
+    /// every other CPU and architecture takes the portable scalar path.
+    fn distance(&self, rhs: &Self) -> f32 {
+        #[cfg(target_arch = "x86_64")]
+        {
+            if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
+                // SAFETY: both CPU features the kernel requires were just detected.
+                return unsafe { distance_avx_fma(self, rhs) };
+            }
+        }
+
+        distance_scalar(self, rhs)
+    }
+}
+
+/// Portable sum of squared coordinate differences.
+fn distance_scalar(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
+    lhs.0
+        .iter()
+        .zip(rhs.0.iter())
+        .map(|(l, r)| (l - r) * (l - r))
+        .sum()
+}
+
+/// AVX/FMA sum of squared coordinate differences.
+///
+/// # Safety
+///
+/// The running CPU must support both AVX and FMA.
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx,fma")]
+unsafe fn distance_avx_fma(lhs: &FloatArray, rhs: &FloatArray) -> f32 {
+    use std::arch::x86_64::{
+        _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_load_ps,
+        _mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_fmadd_ps,
+        _mm_load_ps, _mm_movehl_ps, _mm_shuffle_ps, _mm_sub_ps,
+    };
+    debug_assert_eq!(lhs.0.len() % 8, 4);
+
+    // Aligned loads are valid here: `FloatArray` is 32-byte aligned, so every 8-float
+    // chunk, and the 4-float tail that follows the last chunk, starts on a 32-byte boundary.
+    let mut acc_8x = _mm256_setzero_ps();
+    for (lh_slice, rh_slice) in lhs.0.chunks_exact(8).zip(rhs.0.chunks_exact(8)) {
+        let lh_8x = _mm256_load_ps(lh_slice.as_ptr());
+        let rh_8x = _mm256_load_ps(rh_slice.as_ptr());
+        let diff = _mm256_sub_ps(lh_8x, rh_8x);
+        acc_8x = _mm256_fmadd_ps(diff, diff, acc_8x);
+    }
+
+    let mut acc_4x = _mm256_extractf128_ps(acc_8x, 1); // upper half
+    let right = _mm256_castps256_ps128(acc_8x); // lower half
+    acc_4x = _mm_add_ps(acc_4x, right); // sum halves
+
+    // The last four coordinates do not fill an 8-wide chunk; fold them in 4-wide.
+    let lh_4x = _mm_load_ps(lhs.0[DIMENSIONS - 4..].as_ptr());
+    let rh_4x = _mm_load_ps(rhs.0[DIMENSIONS - 4..].as_ptr());
+    let diff = _mm_sub_ps(lh_4x, rh_4x);
+    acc_4x = _mm_fmadd_ps(diff, diff, acc_4x);
+
+    // Horizontal sum of the four lanes.
+    let lower = _mm_movehl_ps(acc_4x, acc_4x);
+    acc_4x = _mm_add_ps(acc_4x, lower);
+    let upper = _mm_shuffle_ps(acc_4x, acc_4x, 0x1);
+    acc_4x = _mm_add_ss(acc_4x, upper);
+    _mm_cvtss_f32(acc_4x)
+}
+
+const DIMENSIONS: usize = 300;
+
+// Compile-time check: the AVX kernel's 4-float tail requires `DIMENSIONS % 8 == 4`.
+const _: [(); 4] = [(); DIMENSIONS % 8];