Expose result item getters directly from Hnsw/HnswMap
This commit is contained in:
parent
07d2f1aedc
commit
e9f77d8714
|
@ -8,13 +8,13 @@ use instant_distance::Point;
|
||||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||||
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
|
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
|
||||||
use pyo3::types::{PyList, PyModule};
|
use pyo3::types::{PyList, PyModule};
|
||||||
use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python};
|
use pyo3::{Py, PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_big_array::big_array;
|
use serde_big_array::big_array;
|
||||||
|
|
||||||
#[pymodule]
|
#[pymodule]
|
||||||
fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
|
fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
|
||||||
m.add_class::<Candidate>()?;
|
m.add_class::<Neighbor>()?;
|
||||||
m.add_class::<Heuristic>()?;
|
m.add_class::<Heuristic>()?;
|
||||||
m.add_class::<Config>()?;
|
m.add_class::<Config>()?;
|
||||||
m.add_class::<Search>()?;
|
m.add_class::<Search>()?;
|
||||||
|
@ -30,10 +30,6 @@ struct HnswMap {
|
||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl HnswMap {
|
impl HnswMap {
|
||||||
#[getter]
|
|
||||||
fn values(&self) -> PyResult<Vec<String>> {
|
|
||||||
Ok(self.inner.values.clone())
|
|
||||||
}
|
|
||||||
/// Build the index
|
/// Build the index
|
||||||
#[staticmethod]
|
#[staticmethod]
|
||||||
fn build(points: &PyList, values: Vec<String>, config: &Config) -> PyResult<Self> {
|
fn build(points: &PyList, values: Vec<String>, config: &Config) -> PyResult<Self> {
|
||||||
|
@ -72,10 +68,10 @@ impl HnswMap {
|
||||||
/// to the `ef_search` parameter set in the index's `config`.
|
/// to the `ef_search` parameter set in the index's `config`.
|
||||||
///
|
///
|
||||||
/// For best performance, reusing `Search` objects is recommended.
|
/// For best performance, reusing `Search` objects is recommended.
|
||||||
fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> {
|
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> {
|
||||||
let point = FloatArray::try_from(point)?;
|
let point = FloatArray::try_from(point)?;
|
||||||
let _ = self.inner.search(&point, &mut search.inner);
|
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
|
||||||
search.cur = Some(0);
|
search.cur = Some((HnswType::Map(slf.clone_ref(py)), 0));
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -129,10 +125,10 @@ impl Hnsw {
|
||||||
/// to the `ef_search` parameter set in the index's `config`.
|
/// to the `ef_search` parameter set in the index's `config`.
|
||||||
///
|
///
|
||||||
/// For best performance, reusing `Search` objects is recommended.
|
/// For best performance, reusing `Search` objects is recommended.
|
||||||
fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> {
|
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> {
|
||||||
let point = FloatArray::try_from(point)?;
|
let point = FloatArray::try_from(point)?;
|
||||||
let _ = self.inner.search(&point, &mut search.inner);
|
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
|
||||||
search.cur = Some(0);
|
search.cur = Some((HnswType::Hnsw(slf.clone_ref(py)), 0));
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -141,7 +137,7 @@ impl Hnsw {
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
struct Search {
|
struct Search {
|
||||||
inner: instant_distance::Search,
|
inner: instant_distance::Search,
|
||||||
cur: Option<usize>,
|
cur: Option<(HnswType, usize)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
|
@ -163,28 +159,44 @@ impl PyIterProtocol for Search {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the next closest point
|
/// Return the next closest point
|
||||||
fn __next__(mut slf: PyRefMut<Self>) -> Option<Candidate> {
|
fn __next__(mut slf: PyRefMut<Self>) -> Option<Neighbor> {
|
||||||
let idx = match &slf.cur {
|
let (index, idx) = match slf.cur.take() {
|
||||||
Some(idx) => *idx,
|
Some(x) => x,
|
||||||
None => return None,
|
None => return None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let candidate = match slf.inner.get(idx) {
|
let py = slf.py();
|
||||||
Some(c) => c,
|
let neighbor = match &index {
|
||||||
None => {
|
HnswType::Hnsw(hnsw) => {
|
||||||
slf.cur = None;
|
let hnsw = hnsw.as_ref(py).borrow();
|
||||||
return None;
|
let item = hnsw.inner.get(idx, &slf.inner);
|
||||||
|
item.map(|item| Neighbor {
|
||||||
|
distance: item.distance,
|
||||||
|
pid: item.pid.into_inner(),
|
||||||
|
value: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
HnswType::Map(map) => {
|
||||||
|
let map = map.as_ref(py).borrow();
|
||||||
|
let item = map.inner.get(idx, &slf.inner);
|
||||||
|
item.map(|item| Neighbor {
|
||||||
|
distance: item.distance,
|
||||||
|
pid: item.pid.into_inner(),
|
||||||
|
value: Some(item.value.to_owned()),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
slf.cur = Some(idx + 1);
|
slf.cur = neighbor.as_ref().map(|_| (index, idx + 1));
|
||||||
Some(Candidate {
|
neighbor
|
||||||
pid: candidate.pid.into_inner(),
|
|
||||||
distance: candidate.distance(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum HnswType {
|
||||||
|
Hnsw(Py<Hnsw>),
|
||||||
|
Map(Py<HnswMap>),
|
||||||
|
}
|
||||||
|
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
#[derive(Copy, Clone, Default)]
|
#[derive(Copy, Clone, Default)]
|
||||||
struct Config {
|
struct Config {
|
||||||
|
@ -296,24 +308,33 @@ impl From<Heuristic> for instant_distance::Heuristic {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Search buffer and result set
|
/// Item found by the nearest neighbor search
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
struct Candidate {
|
struct Neighbor {
|
||||||
/// Identifier for the neighboring point
|
|
||||||
#[pyo3(get)]
|
|
||||||
pid: u32,
|
|
||||||
/// Distance to the neighboring point
|
/// Distance to the neighboring point
|
||||||
#[pyo3(get)]
|
#[pyo3(get)]
|
||||||
distance: f32,
|
distance: f32,
|
||||||
|
/// Identifier for the neighboring point
|
||||||
|
#[pyo3(get)]
|
||||||
|
pid: u32,
|
||||||
|
/// Value for the neighboring point (only set for `HnswMap` results)
|
||||||
|
#[pyo3(get)]
|
||||||
|
value: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyproto]
|
#[pyproto]
|
||||||
impl PyObjectProtocol for Candidate {
|
impl PyObjectProtocol for Neighbor {
|
||||||
fn __repr__(&self) -> PyResult<String> {
|
fn __repr__(&self) -> PyResult<String> {
|
||||||
Ok(format!(
|
match &self.value {
|
||||||
"instant_distance.Candidate(pid={}, distance={})",
|
Some(s) => Ok(format!(
|
||||||
self.pid, self.distance
|
"instant_distance.Neighbor(distance={}, pid={}, value={})",
|
||||||
))
|
self.distance, self.pid, s,
|
||||||
|
)),
|
||||||
|
None => Ok(format!(
|
||||||
|
"instant_distance.Item(distance={}, pid={})",
|
||||||
|
self.distance, self.pid,
|
||||||
|
)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,12 +24,11 @@ def test_hsnw_map():
|
||||||
|
|
||||||
search = instant_distance.Search()
|
search = instant_distance.Search()
|
||||||
hnsw_map.search(embeddings[the_chosen_one], search)
|
hnsw_map.search(embeddings[the_chosen_one], search)
|
||||||
|
first = next(search)
|
||||||
|
|
||||||
closest_pid = list(search)[0].pid
|
approx_nearest = first.value
|
||||||
approx_nearest = hnsw_map.values[closest_pid]
|
|
||||||
actual_word = values[the_chosen_one]
|
actual_word = values[the_chosen_one]
|
||||||
|
|
||||||
print("pid:\t\t", closest_pid)
|
|
||||||
print("approx word:\t", approx_nearest)
|
print("approx word:\t", approx_nearest)
|
||||||
print("actual word:\t", actual_word)
|
print("actual word:\t", actual_word)
|
||||||
|
|
||||||
|
|
|
@ -156,18 +156,20 @@ where
|
||||||
point: &P,
|
point: &P,
|
||||||
search: &'a mut Search,
|
search: &'a mut Search,
|
||||||
) -> impl Iterator<Item = MapItem<'a, P, V>> + ExactSizeIterator + 'a {
|
) -> impl Iterator<Item = MapItem<'a, P, V>> + ExactSizeIterator + 'a {
|
||||||
self.hnsw.search(point, search).map(move |item| MapItem {
|
self.hnsw
|
||||||
distance: item.distance,
|
.search(point, search)
|
||||||
pid: item.pid,
|
.map(move |item| MapItem::from(item, self))
|
||||||
point: item.point,
|
|
||||||
value: &self.values[item.pid.0 as usize],
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Iterate over the keys and values in this index
|
/// Iterate over the keys and values in this index
|
||||||
pub fn iter(&self) -> impl Iterator<Item = (PointId, &P)> {
|
pub fn iter(&self) -> impl Iterator<Item = (PointId, &P)> {
|
||||||
self.hnsw.iter()
|
self.hnsw.iter()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[doc(hidden)]
|
||||||
|
pub fn get(&self, i: usize, search: &Search) -> Option<MapItem<'_, P, V>> {
|
||||||
|
Some(MapItem::from(self.hnsw.get(i, search)?, self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct MapItem<'a, P, V> {
|
pub struct MapItem<'a, P, V> {
|
||||||
|
@ -177,6 +179,17 @@ pub struct MapItem<'a, P, V> {
|
||||||
pub value: &'a V,
|
pub value: &'a V,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'a, P, V> MapItem<'a, P, V> {
|
||||||
|
fn from(item: Item<'a, P>, map: &'a HnswMap<P, V>) -> Self {
|
||||||
|
MapItem {
|
||||||
|
distance: item.distance,
|
||||||
|
pid: item.pid,
|
||||||
|
point: item.point,
|
||||||
|
value: &map.values[item.pid.0 as usize],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||||
pub struct Hnsw<P> {
|
pub struct Hnsw<P> {
|
||||||
ef_search: usize,
|
ef_search: usize,
|
||||||
|
@ -377,6 +390,11 @@ where
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(i, p)| (PointId(i as u32), p))
|
.map(|(i, p)| (PointId(i as u32), p))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[doc(hidden)]
|
||||||
|
pub fn get(&self, i: usize, search: &Search) -> Option<Item<'_, P>> {
|
||||||
|
Some(Item::new(search.nearest.get(i).copied()?, self))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Item<'a, P> {
|
pub struct Item<'a, P> {
|
||||||
|
@ -746,11 +764,6 @@ impl Search {
|
||||||
fn iter(&self) -> impl Iterator<Item = Candidate> + ExactSizeIterator + '_ {
|
fn iter(&self) -> impl Iterator<Item = Candidate> + ExactSizeIterator + '_ {
|
||||||
self.nearest.iter().copied()
|
self.nearest.iter().copied()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[doc(hidden)]
|
|
||||||
pub fn get(&self, i: usize) -> Option<Candidate> {
|
|
||||||
self.nearest.get(i).copied()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for Search {
|
impl Default for Search {
|
||||||
|
|
|
@ -236,13 +236,6 @@ pub struct Candidate {
|
||||||
pub pid: PointId,
|
pub pid: PointId,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Candidate {
|
|
||||||
/// Distance to the neighboring point
|
|
||||||
pub fn distance(&self) -> f32 {
|
|
||||||
*self.distance
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// References a `Point` in the `Hnsw`
|
/// References a `Point` in the `Hnsw`
|
||||||
///
|
///
|
||||||
/// This can be used to index into the `Hnsw` to refer to the `Point` data.
|
/// This can be used to index into the `Hnsw` to refer to the `Point` data.
|
||||||
|
@ -262,6 +255,14 @@ impl PointId {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[doc(hidden)]
|
||||||
|
// Not part of the public API; only for use in bindings
|
||||||
|
impl From<u32> for PointId {
|
||||||
|
fn from(id: u32) -> Self {
|
||||||
|
PointId(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Default for PointId {
|
impl Default for PointId {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
INVALID
|
INVALID
|
||||||
|
|
Loading…
Reference in New Issue