Allow access to the candidate distance

This commit is contained in:
Dirkjan Ochtman 2021-03-24 15:52:20 +01:00
parent e9d0d99eb4
commit 89d066eb6b
7 changed files with 52 additions and 18 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "instant-distance-py"
version = "0.1.1"
version = "0.2.0"
authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"]
edition = "2018"
license = "MIT OR Apache-2.0"
@ -11,7 +11,7 @@ crate-type = ["cdylib"]
[dependencies]
bincode = "1.3.1"
instant-distance = { version = "0.2", path = "../instant-distance", features = ["with-serde"] }
instant-distance = { version = "0.3", path = "../instant-distance", features = ["with-serde"] }
pyo3 = { version = "0.13.2", features = ["extension-module"] }
serde = { version = "1", features = ["derive"] }
serde-big-array = "0.3.2"

View File

@ -8,12 +8,13 @@ use instant_distance::Point;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
use pyo3::types::{PyList, PyModule};
use pyo3::{PyAny, PyErr, PyIterProtocol, PyRef, PyRefMut, PyResult, Python};
use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python};
use serde::{Deserialize, Serialize};
use serde_big_array::big_array;
#[pymodule]
fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Candidate>()?;
m.add_class::<Heuristic>()?;
m.add_class::<Config>()?;
m.add_class::<Search>()?;
@ -104,14 +105,14 @@ impl PyIterProtocol for Search {
}
/// Return the next closest point
fn __next__(mut slf: PyRefMut<Self>) -> Option<u32> {
fn __next__(mut slf: PyRefMut<Self>) -> Option<Candidate> {
let idx = match &slf.cur {
Some(idx) => *idx,
None => return None,
};
let pid = match slf.inner.get(idx) {
Some(pid) => pid,
let candidate = match slf.inner.get(idx) {
Some(c) => c,
None => {
slf.cur = None;
return None;
@ -119,7 +120,10 @@ impl PyIterProtocol for Search {
};
slf.cur = Some(idx + 1);
Some(pid.into_inner())
Some(Candidate {
pid: candidate.pid.into_inner(),
distance: candidate.distance(),
})
}
}
@ -234,6 +238,24 @@ impl From<Heuristic> for instant_distance::Heuristic {
}
}
/// A single search result: the identifier of a neighboring point and the distance to it
#[pyclass]
struct Candidate {
/// Identifier for the neighboring point
#[pyo3(get)]
pid: u32,
/// Distance to the neighboring point
#[pyo3(get)]
distance: f32,
}
#[pyproto]
impl PyObjectProtocol for Candidate {
    /// Build the textual representation shown for a `Candidate`, listing
    /// its point identifier and its distance.
    fn __repr__(&self) -> PyResult<String> {
        let Self { pid, distance } = self;
        let repr = format!("instant_distance.Candidate(pid={}, distance={})", pid, distance);
        Ok(repr)
    }
}
#[repr(align(32))]
#[derive(Clone, Deserialize, Serialize)]
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);

View File

@ -7,8 +7,8 @@ def main():
p = [random.random() for _ in range(300)]
search = instant_distance.Search()
hnsw.search(p, search)
for pid in search:
print(pid)
for candidate in search:
print(candidate)
if __name__ == '__main__':
main()

View File

@ -1,6 +1,6 @@
[package]
name = "instant-distance"
version = "0.2.0"
version = "0.3.0"
license = "MIT OR Apache-2.0"
authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"]
edition = "2018"

View File

@ -329,7 +329,7 @@ where
&self,
point: &P,
search: &'a mut Search,
) -> impl Iterator<Item = PointId> + ExactSizeIterator + 'a {
) -> impl Iterator<Item = Candidate> + ExactSizeIterator + 'a {
search.reset();
if self.points.is_empty() {
return search.iter();
@ -670,13 +670,13 @@ impl Search {
&self.nearest
}
fn iter(&self) -> impl Iterator<Item = PointId> + ExactSizeIterator + '_ {
self.nearest.iter().map(|candidate| candidate.pid)
fn iter(&self) -> impl Iterator<Item = Candidate> + ExactSizeIterator + '_ {
self.nearest.iter().copied()
}
#[doc(hidden)]
pub fn get(&self, i: usize) -> Option<PointId> {
self.nearest.get(i).map(|candidate| candidate.pid)
pub fn get(&self, i: usize) -> Option<Candidate> {
self.nearest.get(i).copied()
}
}

View File

@ -228,10 +228,19 @@ impl Iterator for DescendingLayerIter {
}
}
/// A potential nearest neighbor
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub(crate) struct Candidate {
pub struct Candidate {
pub(crate) distance: OrderedFloat<f32>,
pub(crate) pid: PointId,
/// The identifier for the neighboring point
pub pid: PointId,
}
impl Candidate {
    /// Returns the distance to the neighboring point as a plain `f32`
    pub fn distance(&self) -> f32 {
        // The field is stored wrapped in `OrderedFloat`; hand back the inner value.
        self.distance.into_inner()
    }
}
/// References a `Point` in the `Hnsw`

View File

@ -49,7 +49,10 @@ fn randomized(builder: Builder) -> (u64, usize) {
.iter()
.map(|(_, i)| pids[*i])
.collect::<HashSet<_>>();
let found = results.take(100).collect::<HashSet<_>>();
let found = results
.take(100)
.map(|candidate| candidate.pid)
.collect::<HashSet<_>>();
(seed, forced.intersection(&found).count())
}