Expose result item getters directly from Hnsw/HnswMap

This commit is contained in:
Dirkjan Ochtman 2021-05-21 16:48:34 +02:00
parent 07d2f1aedc
commit e9f77d8714
4 changed files with 91 additions and 57 deletions

View File

@ -8,13 +8,13 @@ use instant_distance::Point;
use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto}; use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
use pyo3::types::{PyList, PyModule}; use pyo3::types::{PyList, PyModule};
use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python}; use pyo3::{Py, PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_big_array::big_array; use serde_big_array::big_array;
#[pymodule] #[pymodule]
fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> { fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Candidate>()?; m.add_class::<Neighbor>()?;
m.add_class::<Heuristic>()?; m.add_class::<Heuristic>()?;
m.add_class::<Config>()?; m.add_class::<Config>()?;
m.add_class::<Search>()?; m.add_class::<Search>()?;
@ -30,10 +30,6 @@ struct HnswMap {
#[pymethods] #[pymethods]
impl HnswMap { impl HnswMap {
#[getter]
fn values(&self) -> PyResult<Vec<String>> {
Ok(self.inner.values.clone())
}
/// Build the index /// Build the index
#[staticmethod] #[staticmethod]
fn build(points: &PyList, values: Vec<String>, config: &Config) -> PyResult<Self> { fn build(points: &PyList, values: Vec<String>, config: &Config) -> PyResult<Self> {
@ -72,10 +68,10 @@ impl HnswMap {
/// to the `ef_search` parameter set in the index's `config`. /// to the `ef_search` parameter set in the index's `config`.
/// ///
/// For best performance, reusing `Search` objects is recommended. /// For best performance, reusing `Search` objects is recommended.
fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> { fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> {
let point = FloatArray::try_from(point)?; let point = FloatArray::try_from(point)?;
let _ = self.inner.search(&point, &mut search.inner); let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
search.cur = Some(0); search.cur = Some((HnswType::Map(slf.clone_ref(py)), 0));
Ok(()) Ok(())
} }
} }
@ -129,10 +125,10 @@ impl Hnsw {
/// to the `ef_search` parameter set in the index's `config`. /// to the `ef_search` parameter set in the index's `config`.
/// ///
/// For best performance, reusing `Search` objects is recommended. /// For best performance, reusing `Search` objects is recommended.
fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> { fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> {
let point = FloatArray::try_from(point)?; let point = FloatArray::try_from(point)?;
let _ = self.inner.search(&point, &mut search.inner); let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
search.cur = Some(0); search.cur = Some((HnswType::Hnsw(slf.clone_ref(py)), 0));
Ok(()) Ok(())
} }
} }
@ -141,7 +137,7 @@ impl Hnsw {
#[pyclass] #[pyclass]
struct Search { struct Search {
inner: instant_distance::Search, inner: instant_distance::Search,
cur: Option<usize>, cur: Option<(HnswType, usize)>,
} }
#[pymethods] #[pymethods]
@ -163,28 +159,44 @@ impl PyIterProtocol for Search {
} }
/// Return the next closest point /// Return the next closest point
fn __next__(mut slf: PyRefMut<Self>) -> Option<Candidate> { fn __next__(mut slf: PyRefMut<Self>) -> Option<Neighbor> {
let idx = match &slf.cur { let (index, idx) = match slf.cur.take() {
Some(idx) => *idx, Some(x) => x,
None => return None, None => return None,
}; };
let candidate = match slf.inner.get(idx) { let py = slf.py();
Some(c) => c, let neighbor = match &index {
None => { HnswType::Hnsw(hnsw) => {
slf.cur = None; let hnsw = hnsw.as_ref(py).borrow();
return None; let item = hnsw.inner.get(idx, &slf.inner);
item.map(|item| Neighbor {
distance: item.distance,
pid: item.pid.into_inner(),
value: None,
})
}
HnswType::Map(map) => {
let map = map.as_ref(py).borrow();
let item = map.inner.get(idx, &slf.inner);
item.map(|item| Neighbor {
distance: item.distance,
pid: item.pid.into_inner(),
value: Some(item.value.to_owned()),
})
} }
}; };
slf.cur = Some(idx + 1); slf.cur = neighbor.as_ref().map(|_| (index, idx + 1));
Some(Candidate { neighbor
pid: candidate.pid.into_inner(),
distance: candidate.distance(),
})
} }
} }
enum HnswType {
Hnsw(Py<Hnsw>),
Map(Py<HnswMap>),
}
#[pyclass] #[pyclass]
#[derive(Copy, Clone, Default)] #[derive(Copy, Clone, Default)]
struct Config { struct Config {
@ -296,24 +308,33 @@ impl From<Heuristic> for instant_distance::Heuristic {
} }
} }
/// Search buffer and result set /// Item found by the nearest neighbor search
#[pyclass] #[pyclass]
struct Candidate { struct Neighbor {
/// Identifier for the neighboring point
#[pyo3(get)]
pid: u32,
/// Distance to the neighboring point /// Distance to the neighboring point
#[pyo3(get)] #[pyo3(get)]
distance: f32, distance: f32,
/// Identifier for the neighboring point
#[pyo3(get)]
pid: u32,
/// Value for the neighboring point (only set for `HnswMap` results)
#[pyo3(get)]
value: Option<String>,
} }
#[pyproto] #[pyproto]
impl PyObjectProtocol for Candidate { impl PyObjectProtocol for Neighbor {
fn __repr__(&self) -> PyResult<String> { fn __repr__(&self) -> PyResult<String> {
Ok(format!( match &self.value {
"instant_distance.Candidate(pid={}, distance={})", Some(s) => Ok(format!(
self.pid, self.distance "instant_distance.Neighbor(distance={}, pid={}, value={})",
)) self.distance, self.pid, s,
)),
None => Ok(format!(
"instant_distance.Item(distance={}, pid={})",
self.distance, self.pid,
)),
}
} }
} }

View File

@ -24,12 +24,11 @@ def test_hsnw_map():
search = instant_distance.Search() search = instant_distance.Search()
hnsw_map.search(embeddings[the_chosen_one], search) hnsw_map.search(embeddings[the_chosen_one], search)
first = next(search)
closest_pid = list(search)[0].pid approx_nearest = first.value
approx_nearest = hnsw_map.values[closest_pid]
actual_word = values[the_chosen_one] actual_word = values[the_chosen_one]
print("pid:\t\t", closest_pid)
print("approx word:\t", approx_nearest) print("approx word:\t", approx_nearest)
print("actual word:\t", actual_word) print("actual word:\t", actual_word)

View File

@ -156,18 +156,20 @@ where
point: &P, point: &P,
search: &'a mut Search, search: &'a mut Search,
) -> impl Iterator<Item = MapItem<'a, P, V>> + ExactSizeIterator + 'a { ) -> impl Iterator<Item = MapItem<'a, P, V>> + ExactSizeIterator + 'a {
self.hnsw.search(point, search).map(move |item| MapItem { self.hnsw
distance: item.distance, .search(point, search)
pid: item.pid, .map(move |item| MapItem::from(item, self))
point: item.point,
value: &self.values[item.pid.0 as usize],
})
} }
/// Iterate over the keys and values in this index /// Iterate over the keys and values in this index
pub fn iter(&self) -> impl Iterator<Item = (PointId, &P)> { pub fn iter(&self) -> impl Iterator<Item = (PointId, &P)> {
self.hnsw.iter() self.hnsw.iter()
} }
#[doc(hidden)]
pub fn get(&self, i: usize, search: &Search) -> Option<MapItem<'_, P, V>> {
Some(MapItem::from(self.hnsw.get(i, search)?, self))
}
} }
pub struct MapItem<'a, P, V> { pub struct MapItem<'a, P, V> {
@ -177,6 +179,17 @@ pub struct MapItem<'a, P, V> {
pub value: &'a V, pub value: &'a V,
} }
impl<'a, P, V> MapItem<'a, P, V> {
fn from(item: Item<'a, P>, map: &'a HnswMap<P, V>) -> Self {
MapItem {
distance: item.distance,
pid: item.pid,
point: item.point,
value: &map.values[item.pid.0 as usize],
}
}
}
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Hnsw<P> { pub struct Hnsw<P> {
ef_search: usize, ef_search: usize,
@ -377,6 +390,11 @@ where
.enumerate() .enumerate()
.map(|(i, p)| (PointId(i as u32), p)) .map(|(i, p)| (PointId(i as u32), p))
} }
#[doc(hidden)]
pub fn get(&self, i: usize, search: &Search) -> Option<Item<'_, P>> {
Some(Item::new(search.nearest.get(i).copied()?, self))
}
} }
pub struct Item<'a, P> { pub struct Item<'a, P> {
@ -746,11 +764,6 @@ impl Search {
fn iter(&self) -> impl Iterator<Item = Candidate> + ExactSizeIterator + '_ { fn iter(&self) -> impl Iterator<Item = Candidate> + ExactSizeIterator + '_ {
self.nearest.iter().copied() self.nearest.iter().copied()
} }
#[doc(hidden)]
pub fn get(&self, i: usize) -> Option<Candidate> {
self.nearest.get(i).copied()
}
} }
impl Default for Search { impl Default for Search {

View File

@ -236,13 +236,6 @@ pub struct Candidate {
pub pid: PointId, pub pid: PointId,
} }
impl Candidate {
/// Distance to the neighboring point
pub fn distance(&self) -> f32 {
*self.distance
}
}
/// References a `Point` in the `Hnsw` /// References a `Point` in the `Hnsw`
/// ///
/// This can be used to index into the `Hnsw` to refer to the `Point` data. /// This can be used to index into the `Hnsw` to refer to the `Point` data.
@ -262,6 +255,14 @@ impl PointId {
} }
} }
#[doc(hidden)]
// Not part of the public API; only for use in bindings
impl From<u32> for PointId {
fn from(id: u32) -> Self {
PointId(id)
}
}
impl Default for PointId { impl Default for PointId {
fn default() -> Self { fn default() -> Self {
INVALID INVALID