py: use integers as point identifiers
This commit is contained in:
parent
90cea0f4c0
commit
c9f62ad3fb
|
@ -8,13 +8,12 @@ use instant_distance::Point;
|
||||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||||
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
|
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
|
||||||
use pyo3::types::{PyList, PyModule};
|
use pyo3::types::{PyList, PyModule};
|
||||||
use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python};
|
use pyo3::{PyAny, PyErr, PyIterProtocol, PyRef, PyRefMut, PyResult, Python};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_big_array::big_array;
|
use serde_big_array::big_array;
|
||||||
|
|
||||||
#[pymodule]
|
#[pymodule]
|
||||||
fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
|
fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
|
||||||
m.add_class::<PointId>()?;
|
|
||||||
m.add_class::<Heuristic>()?;
|
m.add_class::<Heuristic>()?;
|
||||||
m.add_class::<Config>()?;
|
m.add_class::<Config>()?;
|
||||||
m.add_class::<Search>()?;
|
m.add_class::<Search>()?;
|
||||||
|
@ -30,14 +29,14 @@ struct Hnsw {
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl Hnsw {
|
impl Hnsw {
|
||||||
#[staticmethod]
|
#[staticmethod]
|
||||||
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<PointId>)> {
|
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<u32>)> {
|
||||||
let points = input
|
let points = input
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(FloatArray::try_from)
|
.map(FloatArray::try_from)
|
||||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||||
|
|
||||||
let (inner, ids) = instant_distance::Builder::from(config).build(&points);
|
let (inner, ids) = instant_distance::Builder::from(config).build(&points);
|
||||||
let ids = Vec::from_iter(ids.into_iter().map(PointId::from));
|
let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner()));
|
||||||
Ok((Self { inner }, ids))
|
Ok((Self { inner }, ids))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -88,7 +87,7 @@ impl PyIterProtocol for Search {
|
||||||
slf
|
slf
|
||||||
}
|
}
|
||||||
|
|
||||||
fn __next__(mut slf: PyRefMut<Self>) -> Option<PointId> {
|
fn __next__(mut slf: PyRefMut<Self>) -> Option<u32> {
|
||||||
let idx = match &slf.cur {
|
let idx = match &slf.cur {
|
||||||
Some(idx) => *idx,
|
Some(idx) => *idx,
|
||||||
None => return None,
|
None => return None,
|
||||||
|
@ -103,7 +102,7 @@ impl PyIterProtocol for Search {
|
||||||
};
|
};
|
||||||
|
|
||||||
slf.cur = Some(idx + 1);
|
slf.cur = Some(idx + 1);
|
||||||
Some(PointId::from(pid))
|
Some(pid.into_inner())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -204,28 +203,6 @@ impl From<Heuristic> for instant_distance::Heuristic {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyclass]
|
|
||||||
struct PointId {
|
|
||||||
inner: instant_distance::PointId,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<instant_distance::PointId> for PointId {
|
|
||||||
fn from(inner: instant_distance::PointId) -> Self {
|
|
||||||
Self { inner }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pyproto]
|
|
||||||
impl<'p> PyObjectProtocol<'p> for PointId {
|
|
||||||
fn __repr__(&self) -> PyResult<String> {
|
|
||||||
Ok(format!("{:?}", self.inner))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn __hash__(&'p self) -> PyResult<u32> {
|
|
||||||
Ok(self.inner.into_inner())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[repr(align(32))]
|
#[repr(align(32))]
|
||||||
#[derive(Clone, Deserialize, Serialize)]
|
#[derive(Clone, Deserialize, Serialize)]
|
||||||
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);
|
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);
|
||||||
|
|
Loading…
Reference in New Issue