Allow access to the candidate distance
This commit is contained in:
parent
e9d0d99eb4
commit
89d066eb6b
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "instant-distance-py"
|
||||
version = "0.1.1"
|
||||
version = "0.2.0"
|
||||
authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"]
|
||||
edition = "2018"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
@ -11,7 +11,7 @@ crate-type = ["cdylib"]
|
|||
|
||||
[dependencies]
|
||||
bincode = "1.3.1"
|
||||
instant-distance = { version = "0.2", path = "../instant-distance", features = ["with-serde"] }
|
||||
instant-distance = { version = "0.3", path = "../instant-distance", features = ["with-serde"] }
|
||||
pyo3 = { version = "0.13.2", features = ["extension-module"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde-big-array = "0.3.2"
|
||||
|
|
|
@ -8,12 +8,13 @@ use instant_distance::Point;
|
|||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
|
||||
use pyo3::types::{PyList, PyModule};
|
||||
use pyo3::{PyAny, PyErr, PyIterProtocol, PyRef, PyRefMut, PyResult, Python};
|
||||
use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_big_array::big_array;
|
||||
|
||||
#[pymodule]
|
||||
fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<Candidate>()?;
|
||||
m.add_class::<Heuristic>()?;
|
||||
m.add_class::<Config>()?;
|
||||
m.add_class::<Search>()?;
|
||||
|
@ -104,14 +105,14 @@ impl PyIterProtocol for Search {
|
|||
}
|
||||
|
||||
/// Return the next closest point
|
||||
fn __next__(mut slf: PyRefMut<Self>) -> Option<u32> {
|
||||
fn __next__(mut slf: PyRefMut<Self>) -> Option<Candidate> {
|
||||
let idx = match &slf.cur {
|
||||
Some(idx) => *idx,
|
||||
None => return None,
|
||||
};
|
||||
|
||||
let pid = match slf.inner.get(idx) {
|
||||
Some(pid) => pid,
|
||||
let candidate = match slf.inner.get(idx) {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
slf.cur = None;
|
||||
return None;
|
||||
|
@ -119,7 +120,10 @@ impl PyIterProtocol for Search {
|
|||
};
|
||||
|
||||
slf.cur = Some(idx + 1);
|
||||
Some(pid.into_inner())
|
||||
Some(Candidate {
|
||||
pid: candidate.pid.into_inner(),
|
||||
distance: candidate.distance(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -234,6 +238,24 @@ impl From<Heuristic> for instant_distance::Heuristic {
|
|||
}
|
||||
}
|
||||
|
||||
/// Search buffer and result set
|
||||
#[pyclass]
|
||||
struct Candidate {
|
||||
/// Identifier for the neighboring point
|
||||
#[pyo3(get)]
|
||||
pid: u32,
|
||||
/// Distance to the neighboring point
|
||||
#[pyo3(get)]
|
||||
distance: f32,
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyObjectProtocol for Candidate {
|
||||
fn __repr__(&self) -> PyResult<String> {
|
||||
Ok(format!("instant_distance.Candidate(pid={}, distance={})", self.pid, self.distance))
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(align(32))]
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);
|
||||
|
|
|
@ -7,8 +7,8 @@ def main():
|
|||
p = [random.random() for _ in range(300)]
|
||||
search = instant_distance.Search()
|
||||
hnsw.search(p, search)
|
||||
for pid in search:
|
||||
print(pid)
|
||||
for candidate in search:
|
||||
print(candidate)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "instant-distance"
|
||||
version = "0.2.0"
|
||||
version = "0.3.0"
|
||||
license = "MIT OR Apache-2.0"
|
||||
authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"]
|
||||
edition = "2018"
|
||||
|
|
|
@ -329,7 +329,7 @@ where
|
|||
&self,
|
||||
point: &P,
|
||||
search: &'a mut Search,
|
||||
) -> impl Iterator<Item = PointId> + ExactSizeIterator + 'a {
|
||||
) -> impl Iterator<Item = Candidate> + ExactSizeIterator + 'a {
|
||||
search.reset();
|
||||
if self.points.is_empty() {
|
||||
return search.iter();
|
||||
|
@ -670,13 +670,13 @@ impl Search {
|
|||
&self.nearest
|
||||
}
|
||||
|
||||
fn iter(&self) -> impl Iterator<Item = PointId> + ExactSizeIterator + '_ {
|
||||
self.nearest.iter().map(|candidate| candidate.pid)
|
||||
fn iter(&self) -> impl Iterator<Item = Candidate> + ExactSizeIterator + '_ {
|
||||
self.nearest.iter().copied()
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn get(&self, i: usize) -> Option<PointId> {
|
||||
self.nearest.get(i).map(|candidate| candidate.pid)
|
||||
pub fn get(&self, i: usize) -> Option<Candidate> {
|
||||
self.nearest.get(i).copied()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -228,10 +228,19 @@ impl Iterator for DescendingLayerIter {
|
|||
}
|
||||
}
|
||||
|
||||
/// A potential nearest neighbor
|
||||
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
|
||||
pub(crate) struct Candidate {
|
||||
pub struct Candidate {
|
||||
pub(crate) distance: OrderedFloat<f32>,
|
||||
pub(crate) pid: PointId,
|
||||
/// The identifier for the neighboring point
|
||||
pub pid: PointId,
|
||||
}
|
||||
|
||||
impl Candidate {
|
||||
/// Distance to the neighboring point
|
||||
pub fn distance(&self) -> f32 {
|
||||
*self.distance
|
||||
}
|
||||
}
|
||||
|
||||
/// References a `Point` in the `Hnsw`
|
||||
|
|
|
@ -49,7 +49,10 @@ fn randomized(builder: Builder) -> (u64, usize) {
|
|||
.iter()
|
||||
.map(|(_, i)| pids[*i])
|
||||
.collect::<HashSet<_>>();
|
||||
let found = results.take(100).collect::<HashSet<_>>();
|
||||
let found = results
|
||||
.take(100)
|
||||
.map(|candidate| candidate.pid)
|
||||
.collect::<HashSet<_>>();
|
||||
(seed, forced.intersection(&found).count())
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue