Allow access to the candidate distance
This commit is contained in:
parent
e9d0d99eb4
commit
89d066eb6b
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "instant-distance-py"
|
name = "instant-distance-py"
|
||||||
version = "0.1.1"
|
version = "0.2.0"
|
||||||
authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"]
|
authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
@ -11,7 +11,7 @@ crate-type = ["cdylib"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
bincode = "1.3.1"
|
bincode = "1.3.1"
|
||||||
instant-distance = { version = "0.2", path = "../instant-distance", features = ["with-serde"] }
|
instant-distance = { version = "0.3", path = "../instant-distance", features = ["with-serde"] }
|
||||||
pyo3 = { version = "0.13.2", features = ["extension-module"] }
|
pyo3 = { version = "0.13.2", features = ["extension-module"] }
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde-big-array = "0.3.2"
|
serde-big-array = "0.3.2"
|
||||||
|
|
|
@ -8,12 +8,13 @@ use instant_distance::Point;
|
||||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||||
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
|
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
|
||||||
use pyo3::types::{PyList, PyModule};
|
use pyo3::types::{PyList, PyModule};
|
||||||
use pyo3::{PyAny, PyErr, PyIterProtocol, PyRef, PyRefMut, PyResult, Python};
|
use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_big_array::big_array;
|
use serde_big_array::big_array;
|
||||||
|
|
||||||
#[pymodule]
|
#[pymodule]
|
||||||
fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
|
fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
|
||||||
|
m.add_class::<Candidate>()?;
|
||||||
m.add_class::<Heuristic>()?;
|
m.add_class::<Heuristic>()?;
|
||||||
m.add_class::<Config>()?;
|
m.add_class::<Config>()?;
|
||||||
m.add_class::<Search>()?;
|
m.add_class::<Search>()?;
|
||||||
|
@ -104,14 +105,14 @@ impl PyIterProtocol for Search {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the next closest point
|
/// Return the next closest point
|
||||||
fn __next__(mut slf: PyRefMut<Self>) -> Option<u32> {
|
fn __next__(mut slf: PyRefMut<Self>) -> Option<Candidate> {
|
||||||
let idx = match &slf.cur {
|
let idx = match &slf.cur {
|
||||||
Some(idx) => *idx,
|
Some(idx) => *idx,
|
||||||
None => return None,
|
None => return None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let pid = match slf.inner.get(idx) {
|
let candidate = match slf.inner.get(idx) {
|
||||||
Some(pid) => pid,
|
Some(c) => c,
|
||||||
None => {
|
None => {
|
||||||
slf.cur = None;
|
slf.cur = None;
|
||||||
return None;
|
return None;
|
||||||
|
@ -119,7 +120,10 @@ impl PyIterProtocol for Search {
|
||||||
};
|
};
|
||||||
|
|
||||||
slf.cur = Some(idx + 1);
|
slf.cur = Some(idx + 1);
|
||||||
Some(pid.into_inner())
|
Some(Candidate {
|
||||||
|
pid: candidate.pid.into_inner(),
|
||||||
|
distance: candidate.distance(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -234,6 +238,24 @@ impl From<Heuristic> for instant_distance::Heuristic {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A single search result: a neighboring point's identifier and its distance
|
||||||
|
#[pyclass]
|
||||||
|
struct Candidate {
|
||||||
|
/// Identifier for the neighboring point
|
||||||
|
#[pyo3(get)]
|
||||||
|
pid: u32,
|
||||||
|
/// Distance to the neighboring point
|
||||||
|
#[pyo3(get)]
|
||||||
|
distance: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyproto]
|
||||||
|
impl PyObjectProtocol for Candidate {
|
||||||
|
fn __repr__(&self) -> PyResult<String> {
|
||||||
|
Ok(format!("instant_distance.Candidate(pid={}, distance={})", self.pid, self.distance))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[repr(align(32))]
|
#[repr(align(32))]
|
||||||
#[derive(Clone, Deserialize, Serialize)]
|
#[derive(Clone, Deserialize, Serialize)]
|
||||||
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);
|
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);
|
||||||
|
|
|
@ -7,8 +7,8 @@ def main():
|
||||||
p = [random.random() for _ in range(300)]
|
p = [random.random() for _ in range(300)]
|
||||||
search = instant_distance.Search()
|
search = instant_distance.Search()
|
||||||
hnsw.search(p, search)
|
hnsw.search(p, search)
|
||||||
for pid in search:
|
for candidate in search:
|
||||||
print(pid)
|
print(candidate)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "instant-distance"
|
name = "instant-distance"
|
||||||
version = "0.2.0"
|
version = "0.3.0"
|
||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"]
|
authors = ["Dirkjan Ochtman <dirkjan@ochtman.nl>"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
|
|
|
@ -329,7 +329,7 @@ where
|
||||||
&self,
|
&self,
|
||||||
point: &P,
|
point: &P,
|
||||||
search: &'a mut Search,
|
search: &'a mut Search,
|
||||||
) -> impl Iterator<Item = PointId> + ExactSizeIterator + 'a {
|
) -> impl Iterator<Item = Candidate> + ExactSizeIterator + 'a {
|
||||||
search.reset();
|
search.reset();
|
||||||
if self.points.is_empty() {
|
if self.points.is_empty() {
|
||||||
return search.iter();
|
return search.iter();
|
||||||
|
@ -670,13 +670,13 @@ impl Search {
|
||||||
&self.nearest
|
&self.nearest
|
||||||
}
|
}
|
||||||
|
|
||||||
fn iter(&self) -> impl Iterator<Item = PointId> + ExactSizeIterator + '_ {
|
fn iter(&self) -> impl Iterator<Item = Candidate> + ExactSizeIterator + '_ {
|
||||||
self.nearest.iter().map(|candidate| candidate.pid)
|
self.nearest.iter().copied()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
pub fn get(&self, i: usize) -> Option<PointId> {
|
pub fn get(&self, i: usize) -> Option<Candidate> {
|
||||||
self.nearest.get(i).map(|candidate| candidate.pid)
|
self.nearest.get(i).copied()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -228,10 +228,19 @@ impl Iterator for DescendingLayerIter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A potential nearest neighbor
|
||||||
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
|
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
|
||||||
pub(crate) struct Candidate {
|
pub struct Candidate {
|
||||||
pub(crate) distance: OrderedFloat<f32>,
|
pub(crate) distance: OrderedFloat<f32>,
|
||||||
pub(crate) pid: PointId,
|
/// The identifier for the neighboring point
|
||||||
|
pub pid: PointId,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Candidate {
|
||||||
|
/// Distance to the neighboring point
|
||||||
|
pub fn distance(&self) -> f32 {
|
||||||
|
*self.distance
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// References a `Point` in the `Hnsw`
|
/// References a `Point` in the `Hnsw`
|
||||||
|
|
|
@ -49,7 +49,10 @@ fn randomized(builder: Builder) -> (u64, usize) {
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(_, i)| pids[*i])
|
.map(|(_, i)| pids[*i])
|
||||||
.collect::<HashSet<_>>();
|
.collect::<HashSet<_>>();
|
||||||
let found = results.take(100).collect::<HashSet<_>>();
|
let found = results
|
||||||
|
.take(100)
|
||||||
|
.map(|candidate| candidate.pid)
|
||||||
|
.collect::<HashSet<_>>();
|
||||||
(seed, forced.intersection(&found).count())
|
(seed, forced.intersection(&found).count())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue