diff --git a/instant-distance-py/Cargo.toml b/instant-distance-py/Cargo.toml index 8eac673..b4112ec 100644 --- a/instant-distance-py/Cargo.toml +++ b/instant-distance-py/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "instant-distance-py" -version = "0.1.1" +version = "0.2.0" authors = ["Dirkjan Ochtman "] edition = "2018" license = "MIT OR Apache-2.0" @@ -11,7 +11,7 @@ crate-type = ["cdylib"] [dependencies] bincode = "1.3.1" -instant-distance = { version = "0.2", path = "../instant-distance", features = ["with-serde"] } +instant-distance = { version = "0.3", path = "../instant-distance", features = ["with-serde"] } pyo3 = { version = "0.13.2", features = ["extension-module"] } serde = { version = "1", features = ["derive"] } serde-big-array = "0.3.2" diff --git a/instant-distance-py/src/lib.rs b/instant-distance-py/src/lib.rs index 428b156..9d561a1 100644 --- a/instant-distance-py/src/lib.rs +++ b/instant-distance-py/src/lib.rs @@ -8,12 +8,13 @@ use instant_distance::Point; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto}; use pyo3::types::{PyList, PyModule}; -use pyo3::{PyAny, PyErr, PyIterProtocol, PyRef, PyRefMut, PyResult, Python}; +use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python}; use serde::{Deserialize, Serialize}; use serde_big_array::big_array; #[pymodule] fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -104,14 +105,14 @@ impl PyIterProtocol for Search { } /// Return the next closest point - fn __next__(mut slf: PyRefMut) -> Option { + fn __next__(mut slf: PyRefMut) -> Option { let idx = match &slf.cur { Some(idx) => *idx, None => return None, }; - let pid = match slf.inner.get(idx) { - Some(pid) => pid, + let candidate = match slf.inner.get(idx) { + Some(c) => c, None => { slf.cur = None; return None; @@ -119,7 +120,10 @@ impl PyIterProtocol for Search { }; slf.cur = Some(idx + 1); - Some(pid.into_inner()) + Some(Candidate { + pid: candidate.pid.into_inner(), + distance: candidate.distance(), + }) } } @@ -234,6 +238,24 @@ impl From for instant_distance::Heuristic { } } +/// Search buffer and result set +#[pyclass] +struct Candidate { + /// Identifier for the neighboring point + #[pyo3(get)] + pid: u32, + /// Distance to the neighboring point + #[pyo3(get)] + distance: f32, +} + +#[pyproto] +impl PyObjectProtocol for Candidate { + fn __repr__(&self) -> PyResult { + Ok(format!("instant_distance.Candidate(pid={}, distance={})", self.pid, self.distance)) + } +} + #[repr(align(32))] #[derive(Clone, Deserialize, Serialize)] struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]); diff --git a/instant-distance-py/test/test.py b/instant-distance-py/test/test.py index f93908f..28ee613 100644 --- a/instant-distance-py/test/test.py +++ b/instant-distance-py/test/test.py @@ -7,8 +7,8 @@ def main(): p = [random.random() for _ in range(300)] search = instant_distance.Search() hnsw.search(p, search) - for pid in search: - print(pid) + for candidate in search: + print(candidate) if __name__ == '__main__': main() diff --git a/instant-distance/Cargo.toml b/instant-distance/Cargo.toml index 0720ac0..0afa760 100644 --- a/instant-distance/Cargo.toml +++ b/instant-distance/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "instant-distance" -version = "0.2.0" +version = "0.3.0" license = "MIT OR Apache-2.0" authors = ["Dirkjan Ochtman "] edition = "2018" diff --git a/instant-distance/src/lib.rs b/instant-distance/src/lib.rs index 948ac70..e66746a 100644 --- a/instant-distance/src/lib.rs +++ b/instant-distance/src/lib.rs @@ -329,7 +329,7 @@ where &self, point: &P, search: &'a mut Search, - ) -> impl Iterator + ExactSizeIterator + 'a { + ) -> impl Iterator + ExactSizeIterator + 'a { search.reset(); if self.points.is_empty() { return search.iter(); @@ -670,13 +670,13 @@ impl Search { &self.nearest } - fn iter(&self) -> impl Iterator + ExactSizeIterator + '_ { - self.nearest.iter().map(|candidate| candidate.pid) + fn iter(&self) -> impl Iterator + ExactSizeIterator + '_ { + self.nearest.iter().copied() } #[doc(hidden)] - pub fn get(&self, i: usize) -> Option { - self.nearest.get(i).map(|candidate| candidate.pid) + pub fn get(&self, i: usize) -> Option { + self.nearest.get(i).copied() } } diff --git a/instant-distance/src/types.rs b/instant-distance/src/types.rs index 4163112..b9258a8 100644 --- a/instant-distance/src/types.rs +++ b/instant-distance/src/types.rs @@ -228,10 +228,19 @@ impl Iterator for DescendingLayerIter { } } +/// A potential nearest neighbor #[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)] -pub(crate) struct Candidate { +pub struct Candidate { pub(crate) distance: OrderedFloat, - pub(crate) pid: PointId, + /// The identifier for the neighboring point + pub pid: PointId, +} + +impl Candidate { + /// Distance to the neighboring point + pub fn distance(&self) -> f32 { + *self.distance + } } /// References a `Point` in the `Hnsw` diff --git a/instant-distance/tests/all.rs b/instant-distance/tests/all.rs index d054c63..bfb04f9 100644 --- a/instant-distance/tests/all.rs +++ b/instant-distance/tests/all.rs @@ -49,7 +49,10 @@ fn randomized(builder: Builder) -> (u64, usize) { .iter() .map(|(_, i)| pids[*i]) .collect::>(); - let found = results.take(100).collect::>(); + let found = results + .take(100) + .map(|candidate| candidate.pid) + .collect::>(); (seed, forced.intersection(&found).count()) }