diff --git a/instant-distance-py/src/lib.rs b/instant-distance-py/src/lib.rs index dda381a..8c21b4a 100644 --- a/instant-distance-py/src/lib.rs +++ b/instant-distance-py/src/lib.rs @@ -8,13 +8,13 @@ use instant_distance::Point; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto}; use pyo3::types::{PyList, PyModule}; -use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python}; +use pyo3::{Py, PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python}; use serde::{Deserialize, Serialize}; use serde_big_array::big_array; #[pymodule] fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> { - m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -30,10 +30,6 @@ struct HnswMap { #[pymethods] impl HnswMap { - #[getter] - fn values(&self) -> PyResult> { - Ok(self.inner.values.clone()) - } /// Build the index #[staticmethod] fn build(points: &PyList, values: Vec, config: &Config) -> PyResult { @@ -72,10 +68,10 @@ impl HnswMap { /// to the `ef_search` parameter set in the index's `config`. /// /// For best performance, reusing `Search` objects is recommended. - fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> { + fn search(slf: Py, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> { let point = FloatArray::try_from(point)?; - let _ = self.inner.search(&point, &mut search.inner); - search.cur = Some(0); + let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner); + search.cur = Some((HnswType::Map(slf.clone_ref(py)), 0)); Ok(()) } } @@ -129,10 +125,10 @@ impl Hnsw { /// to the `ef_search` parameter set in the index's `config`. /// /// For best performance, reusing `Search` objects is recommended. - fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> { + fn search(slf: Py, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> { let point = FloatArray::try_from(point)?; - let _ = self.inner.search(&point, &mut search.inner); - search.cur = Some(0); + let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner); + search.cur = Some((HnswType::Hnsw(slf.clone_ref(py)), 0)); Ok(()) } } @@ -141,7 +137,7 @@ impl Hnsw { #[pyclass] struct Search { inner: instant_distance::Search, - cur: Option, + cur: Option<(HnswType, usize)>, } #[pymethods] @@ -163,28 +159,44 @@ impl PyIterProtocol for Search { } /// Return the next closest point - fn __next__(mut slf: PyRefMut) -> Option { - let idx = match &slf.cur { - Some(idx) => *idx, + fn __next__(mut slf: PyRefMut) -> Option { + let (index, idx) = match slf.cur.take() { + Some(x) => x, None => return None, }; - let candidate = match slf.inner.get(idx) { - Some(c) => c, - None => { - slf.cur = None; - return None; + let py = slf.py(); + let neighbor = match &index { + HnswType::Hnsw(hnsw) => { + let hnsw = hnsw.as_ref(py).borrow(); + let item = hnsw.inner.get(idx, &slf.inner); + item.map(|item| Neighbor { + distance: item.distance, + pid: item.pid.into_inner(), + value: None, + }) + } + HnswType::Map(map) => { + let map = map.as_ref(py).borrow(); + let item = map.inner.get(idx, &slf.inner); + item.map(|item| Neighbor { + distance: item.distance, + pid: item.pid.into_inner(), + value: Some(item.value.to_owned()), + }) } }; - slf.cur = Some(idx + 1); - Some(Candidate { - pid: candidate.pid.into_inner(), - distance: candidate.distance(), - }) + slf.cur = neighbor.as_ref().map(|_| (index, idx + 1)); + neighbor } } +enum HnswType { + Hnsw(Py), + Map(Py), +} + #[pyclass] #[derive(Copy, Clone, Default)] struct Config { @@ -296,24 +308,33 @@ impl From for instant_distance::Heuristic { } } -/// Search buffer and result set +/// Item found by the nearest neighbor search #[pyclass] -struct Candidate { - /// Identifier for the neighboring point - #[pyo3(get)] - pid: u32, +struct Neighbor { /// Distance to the neighboring point #[pyo3(get)] distance: f32, + /// Identifier for the neighboring point + #[pyo3(get)] + pid: u32, + /// Value for the neighboring point (only set for `HnswMap` results) + #[pyo3(get)] + value: Option, } #[pyproto] -impl PyObjectProtocol for Candidate { +impl PyObjectProtocol for Neighbor { fn __repr__(&self) -> PyResult { - Ok(format!( - "instant_distance.Candidate(pid={}, distance={})", - self.pid, self.distance - )) + match &self.value { + Some(s) => Ok(format!( + "instant_distance.Neighbor(distance={}, pid={}, value={})", + self.distance, self.pid, s, + )), + None => Ok(format!( + "instant_distance.Item(distance={}, pid={})", + self.distance, self.pid, + )), + } } } diff --git a/instant-distance-py/test/test.py b/instant-distance-py/test/test.py index 931f697..3ec2e23 100644 --- a/instant-distance-py/test/test.py +++ b/instant-distance-py/test/test.py @@ -24,12 +24,11 @@ def test_hsnw_map(): search = instant_distance.Search() hnsw_map.search(embeddings[the_chosen_one], search) + first = next(search) - closest_pid = list(search)[0].pid - approx_nearest = hnsw_map.values[closest_pid] + approx_nearest = first.value actual_word = values[the_chosen_one] - print("pid:\t\t", closest_pid) print("approx word:\t", approx_nearest) print("actual word:\t", actual_word) diff --git a/instant-distance/src/lib.rs b/instant-distance/src/lib.rs index 98c8adf..82cc867 100644 --- a/instant-distance/src/lib.rs +++ b/instant-distance/src/lib.rs @@ -156,18 +156,20 @@ where point: &P, search: &'a mut Search, ) -> impl Iterator> + ExactSizeIterator + 'a { - self.hnsw.search(point, search).map(move |item| MapItem { - distance: item.distance, - pid: item.pid, - point: item.point, - value: &self.values[item.pid.0 as usize], - }) + self.hnsw + .search(point, search) + .map(move |item| MapItem::from(item, self)) } /// Iterate over the keys and values in this index pub fn iter(&self) -> impl Iterator { self.hnsw.iter() } + + #[doc(hidden)] + pub fn get(&self, i: usize, search: &Search) -> Option> { + Some(MapItem::from(self.hnsw.get(i, search)?, self)) + } } pub struct MapItem<'a, P, V> { @@ -177,6 +179,17 @@ pub struct MapItem<'a, P, V> { pub value: &'a V, } +impl<'a, P, V> MapItem<'a, P, V> { + fn from(item: Item<'a, P>, map: &'a HnswMap) -> Self { + MapItem { + distance: item.distance, + pid: item.pid, + point: item.point, + value: &map.values[item.pid.0 as usize], + } + } +} + #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] pub struct Hnsw

{ ef_search: usize, @@ -377,6 +390,11 @@ where .enumerate() .map(|(i, p)| (PointId(i as u32), p)) } + + #[doc(hidden)] + pub fn get(&self, i: usize, search: &Search) -> Option> { + Some(Item::new(search.nearest.get(i).copied()?, self)) + } } pub struct Item<'a, P> { @@ -746,11 +764,6 @@ impl Search { fn iter(&self) -> impl Iterator + ExactSizeIterator + '_ { self.nearest.iter().copied() } - - #[doc(hidden)] - pub fn get(&self, i: usize) -> Option { - self.nearest.get(i).copied() - } } impl Default for Search { diff --git a/instant-distance/src/types.rs b/instant-distance/src/types.rs index b9258a8..f2845f1 100644 --- a/instant-distance/src/types.rs +++ b/instant-distance/src/types.rs @@ -236,13 +236,6 @@ pub struct Candidate { pub pid: PointId, } -impl Candidate { - /// Distance to the neighboring point - pub fn distance(&self) -> f32 { - *self.distance - } -} - /// References a `Point` in the `Hnsw` /// /// This can be used to index into the `Hnsw` to refer to the `Point` data. @@ -262,6 +255,14 @@ impl PointId { } } +#[doc(hidden)] +// Not part of the public API; only for use in bindings +impl From for PointId { + fn from(id: u32) -> Self { + PointId(id) + } +} + impl Default for PointId { fn default() -> Self { INVALID