py: use integers as point identifiers
This commit is contained in:
parent
90cea0f4c0
commit
c9f62ad3fb
|
@ -8,13 +8,12 @@ use instant_distance::Point;
|
|||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
|
||||
use pyo3::types::{PyList, PyModule};
|
||||
use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python};
|
||||
use pyo3::{PyAny, PyErr, PyIterProtocol, PyRef, PyRefMut, PyResult, Python};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_big_array::big_array;
|
||||
|
||||
#[pymodule]
|
||||
fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<PointId>()?;
|
||||
m.add_class::<Heuristic>()?;
|
||||
m.add_class::<Config>()?;
|
||||
m.add_class::<Search>()?;
|
||||
|
@ -30,14 +29,14 @@ struct Hnsw {
|
|||
#[pymethods]
|
||||
impl Hnsw {
|
||||
#[staticmethod]
|
||||
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<PointId>)> {
|
||||
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<u32>)> {
|
||||
let points = input
|
||||
.into_iter()
|
||||
.map(FloatArray::try_from)
|
||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||
|
||||
let (inner, ids) = instant_distance::Builder::from(config).build(&points);
|
||||
let ids = Vec::from_iter(ids.into_iter().map(PointId::from));
|
||||
let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner()));
|
||||
Ok((Self { inner }, ids))
|
||||
}
|
||||
|
||||
|
@ -88,7 +87,7 @@ impl PyIterProtocol for Search {
|
|||
slf
|
||||
}
|
||||
|
||||
fn __next__(mut slf: PyRefMut<Self>) -> Option<PointId> {
|
||||
fn __next__(mut slf: PyRefMut<Self>) -> Option<u32> {
|
||||
let idx = match &slf.cur {
|
||||
Some(idx) => *idx,
|
||||
None => return None,
|
||||
|
@ -103,7 +102,7 @@ impl PyIterProtocol for Search {
|
|||
};
|
||||
|
||||
slf.cur = Some(idx + 1);
|
||||
Some(PointId::from(pid))
|
||||
Some(pid.into_inner())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -204,28 +203,6 @@ impl From<Heuristic> for instant_distance::Heuristic {
|
|||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct PointId {
|
||||
inner: instant_distance::PointId,
|
||||
}
|
||||
|
||||
impl From<instant_distance::PointId> for PointId {
|
||||
fn from(inner: instant_distance::PointId) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl<'p> PyObjectProtocol<'p> for PointId {
|
||||
fn __repr__(&self) -> PyResult<String> {
|
||||
Ok(format!("{:?}", self.inner))
|
||||
}
|
||||
|
||||
fn __hash__(&'p self) -> PyResult<u32> {
|
||||
Ok(self.inner.into_inner())
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(align(32))]
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);
|
||||
|
|
Loading…
Reference in New Issue