Allow access to the candidate distance

This commit is contained in:
Dirkjan Ochtman 2021-03-24 15:52:20 +01:00
parent e9d0d99eb4
commit 89d066eb6b
7 changed files with 52 additions and 18 deletions

View File

@ -1,6 +1,6 @@
[package] [package]
name = "instant-distance-py" name = "instant-distance-py"
version = "0.1.1" version = "0.2.0"
authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"] authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"]
edition = "2018" edition = "2018"
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
@ -11,7 +11,7 @@ crate-type = ["cdylib"]
[dependencies] [dependencies]
bincode = "1.3.1" bincode = "1.3.1"
instant-distance = { version = "0.2", path = "../instant-distance", features = ["with-serde"] } instant-distance = { version = "0.3", path = "../instant-distance", features = ["with-serde"] }
pyo3 = { version = "0.13.2", features = ["extension-module"] } pyo3 = { version = "0.13.2", features = ["extension-module"] }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde-big-array = "0.3.2" serde-big-array = "0.3.2"

View File

@ -8,12 +8,13 @@ use instant_distance::Point;
use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto}; use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
use pyo3::types::{PyList, PyModule}; use pyo3::types::{PyList, PyModule};
use pyo3::{PyAny, PyErr, PyIterProtocol, PyRef, PyRefMut, PyResult, Python}; use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_big_array::big_array; use serde_big_array::big_array;
#[pymodule] #[pymodule]
fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> { fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Candidate>()?;
m.add_class::<Heuristic>()?; m.add_class::<Heuristic>()?;
m.add_class::<Config>()?; m.add_class::<Config>()?;
m.add_class::<Search>()?; m.add_class::<Search>()?;
@ -104,14 +105,14 @@ impl PyIterProtocol for Search {
} }
/// Return the next closest point /// Return the next closest point
fn __next__(mut slf: PyRefMut<Self>) -> Option<u32> { fn __next__(mut slf: PyRefMut<Self>) -> Option<Candidate> {
let idx = match &slf.cur { let idx = match &slf.cur {
Some(idx) => *idx, Some(idx) => *idx,
None => return None, None => return None,
}; };
let pid = match slf.inner.get(idx) { let candidate = match slf.inner.get(idx) {
Some(pid) => pid, Some(c) => c,
None => { None => {
slf.cur = None; slf.cur = None;
return None; return None;
@ -119,7 +120,10 @@ impl PyIterProtocol for Search {
}; };
slf.cur = Some(idx + 1); slf.cur = Some(idx + 1);
Some(pid.into_inner()) Some(Candidate {
pid: candidate.pid.into_inner(),
distance: candidate.distance(),
})
} }
} }
@ -234,6 +238,24 @@ impl From<Heuristic> for instant_distance::Heuristic {
} }
} }
/// Search buffer and result set
#[pyclass]
struct Candidate {
/// Identifier for the neighboring point
#[pyo3(get)]
pid: u32,
/// Distance to the neighboring point
#[pyo3(get)]
distance: f32,
}
#[pyproto]
impl PyObjectProtocol for Candidate {
fn __repr__(&self) -> PyResult<String> {
Ok(format!("instant_distance.Candidate(pid={}, distance={})", self.pid, self.distance))
}
}
#[repr(align(32))] #[repr(align(32))]
#[derive(Clone, Deserialize, Serialize)] #[derive(Clone, Deserialize, Serialize)]
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]); struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);

View File

@ -7,8 +7,8 @@ def main():
p = [random.random() for _ in range(300)] p = [random.random() for _ in range(300)]
search = instant_distance.Search() search = instant_distance.Search()
hnsw.search(p, search) hnsw.search(p, search)
for pid in search: for candidate in search:
print(pid) print(candidate)
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -1,6 +1,6 @@
[package] [package]
name = "instant-distance" name = "instant-distance"
version = "0.2.0" version = "0.3.0"
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"] authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"]
edition = "2018" edition = "2018"

View File

@ -329,7 +329,7 @@ where
&self, &self,
point: &P, point: &P,
search: &'a mut Search, search: &'a mut Search,
) -> impl Iterator<Item = PointId> + ExactSizeIterator + 'a { ) -> impl Iterator<Item = Candidate> + ExactSizeIterator + 'a {
search.reset(); search.reset();
if self.points.is_empty() { if self.points.is_empty() {
return search.iter(); return search.iter();
@ -670,13 +670,13 @@ impl Search {
&self.nearest &self.nearest
} }
fn iter(&self) -> impl Iterator<Item = PointId> + ExactSizeIterator + '_ { fn iter(&self) -> impl Iterator<Item = Candidate> + ExactSizeIterator + '_ {
self.nearest.iter().map(|candidate| candidate.pid) self.nearest.iter().copied()
} }
#[doc(hidden)] #[doc(hidden)]
pub fn get(&self, i: usize) -> Option<PointId> { pub fn get(&self, i: usize) -> Option<Candidate> {
self.nearest.get(i).map(|candidate| candidate.pid) self.nearest.get(i).copied()
} }
} }

View File

@ -228,10 +228,19 @@ impl Iterator for DescendingLayerIter {
} }
} }
/// A potential nearest neighbor
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)] #[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub(crate) struct Candidate { pub struct Candidate {
pub(crate) distance: OrderedFloat<f32>, pub(crate) distance: OrderedFloat<f32>,
pub(crate) pid: PointId, /// The identifier for the neighboring point
pub pid: PointId,
}
impl Candidate {
/// Distance to the neighboring point
pub fn distance(&self) -> f32 {
*self.distance
}
} }
/// References a `Point` in the `Hnsw` /// References a `Point` in the `Hnsw`

View File

@ -49,7 +49,10 @@ fn randomized(builder: Builder) -> (u64, usize) {
.iter() .iter()
.map(|(_, i)| pids[*i]) .map(|(_, i)| pids[*i])
.collect::<HashSet<_>>(); .collect::<HashSet<_>>();
let found = results.take(100).collect::<HashSet<_>>(); let found = results
.take(100)
.map(|candidate| candidate.pid)
.collect::<HashSet<_>>();
(seed, forced.intersection(&found).count()) (seed, forced.intersection(&found).count())
} }