py: use integers as point identifiers

This commit is contained in:
Dirkjan Ochtman 2021-03-18 16:36:15 +01:00
parent 90cea0f4c0
commit c9f62ad3fb
1 changed file with 5 additions and 28 deletions

View File

@@ -8,13 +8,12 @@ use instant_distance::Point;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
use pyo3::types::{PyList, PyModule};
use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python};
use pyo3::{PyAny, PyErr, PyIterProtocol, PyRef, PyRefMut, PyResult, Python};
use serde::{Deserialize, Serialize};
use serde_big_array::big_array;
#[pymodule]
fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PointId>()?;
m.add_class::<Heuristic>()?;
m.add_class::<Config>()?;
m.add_class::<Search>()?;
@@ -30,14 +29,14 @@ struct Hnsw {
#[pymethods]
impl Hnsw {
#[staticmethod]
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<PointId>)> {
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<u32>)> {
let points = input
.into_iter()
.map(FloatArray::try_from)
.collect::<Result<Vec<_>, PyErr>>()?;
let (inner, ids) = instant_distance::Builder::from(config).build(&points);
let ids = Vec::from_iter(ids.into_iter().map(PointId::from));
let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner()));
Ok((Self { inner }, ids))
}
@@ -88,7 +87,7 @@ impl PyIterProtocol for Search {
slf
}
fn __next__(mut slf: PyRefMut<Self>) -> Option<PointId> {
fn __next__(mut slf: PyRefMut<Self>) -> Option<u32> {
let idx = match &slf.cur {
Some(idx) => *idx,
None => return None,
@@ -103,7 +102,7 @@ impl PyIterProtocol for Search {
};
slf.cur = Some(idx + 1);
Some(PointId::from(pid))
Some(pid.into_inner())
}
}
@@ -204,28 +203,6 @@ impl From<Heuristic> for instant_distance::Heuristic {
}
}
// Python-exposed wrapper class (`#[pyclass]`) around the library's point
// identifier type, `instant_distance::PointId`.
#[pyclass]
struct PointId {
// The wrapped identifier from the `instant_distance` crate.
inner: instant_distance::PointId,
}
// Conversion from the library's identifier into the Python-facing wrapper.
impl From<instant_distance::PointId> for PointId {
// Stores `inner` unchanged in the wrapper's only field.
fn from(inner: instant_distance::PointId) -> Self {
Self { inner }
}
}
// Python object-protocol hooks (`__repr__` and `__hash__`) for `PointId`.
#[pyproto]
impl<'p> PyObjectProtocol<'p> for PointId {
// `__repr__`: returns the `Debug` formatting of the wrapped identifier.
fn __repr__(&self) -> PyResult<String> {
Ok(format!("{:?}", self.inner))
}
// `__hash__`: returns the `u32` produced by the wrapped identifier's
// `into_inner()`, so the hash is the raw identifier value itself.
fn __hash__(&'p self) -> PyResult<u32> {
Ok(self.inner.into_inner())
}
}
// Newtype over a fixed-size array of `DIMENSIONS` `f32` values; the type is
// declared with 32-byte alignment.
// NOTE(review): the alignment is presumably there for SIMD-friendly access — confirm.
#[repr(align(32))]
#[derive(Clone, Deserialize, Serialize)]
// The array field is (de)serialized through `BigArray` instead of serde's
// built-in array handling; presumably `BigArray` is generated by the
// `big_array!` macro imported from `serde_big_array` — verify.
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);