From 9b7e3b1486fe34e1c18fc9a3324382220c9f14dc Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Thu, 18 Mar 2021 16:36:15 +0100 Subject: [PATCH] py: use integers as point identifiers --- instant-distance-py/src/lib.rs | 33 +++++---------------------------- 1 file changed, 5 insertions(+), 28 deletions(-) diff --git a/instant-distance-py/src/lib.rs b/instant-distance-py/src/lib.rs index 3466a06..2adb4a9 100644 --- a/instant-distance-py/src/lib.rs +++ b/instant-distance-py/src/lib.rs @@ -8,13 +8,12 @@ use instant_distance::Point; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto}; use pyo3::types::{PyList, PyModule}; -use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python}; +use pyo3::{PyAny, PyErr, PyIterProtocol, PyRef, PyRefMut, PyResult, Python}; use serde::{Deserialize, Serialize}; use serde_big_array::big_array; #[pymodule] fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> { - m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -30,14 +29,14 @@ struct Hnsw { #[pymethods] impl Hnsw { #[staticmethod] - fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec)> { + fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec)> { let points = input .into_iter() .map(FloatArray::try_from) .collect::, PyErr>>()?; let (inner, ids) = instant_distance::Builder::from(config).build(&points); - let ids = Vec::from_iter(ids.into_iter().map(PointId::from)); + let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner())); Ok((Self { inner }, ids)) } @@ -88,7 +87,7 @@ impl PyIterProtocol for Search { slf } - fn __next__(mut slf: PyRefMut) -> Option { + fn __next__(mut slf: PyRefMut) -> Option { let idx = match &slf.cur { Some(idx) => *idx, None => return None, @@ -103,7 +102,7 @@ impl PyIterProtocol for Search { }; slf.cur = Some(idx + 1); - Some(PointId::from(pid)) + Some(pid.into_inner()) } } @@ -204,28 +203,6 @@ impl From for instant_distance::Heuristic { } } -#[pyclass] -struct PointId { - inner: instant_distance::PointId, -} - -impl From for PointId { - fn from(inner: instant_distance::PointId) -> Self { - Self { inner } - } -} - -#[pyproto] -impl<'p> PyObjectProtocol<'p> for PointId { - fn __repr__(&self) -> PyResult { - Ok(format!("{:?}", self.inner)) - } - - fn __hash__(&'p self) -> PyResult { - Ok(self.inner.into_inner()) - } -} - #[repr(align(32))] #[derive(Clone, Deserialize, Serialize)] struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);