diff --git a/instant-distance-py/src/lib.rs b/instant-distance-py/src/lib.rs index 5f74d90..eb976c2 100644 --- a/instant-distance-py/src/lib.rs +++ b/instant-distance-py/src/lib.rs @@ -8,7 +8,7 @@ use instant_distance::Point; use pyo3::conversion::IntoPy; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto}; -use pyo3::types::{PyList, PyModule}; +use pyo3::types::{PyList, PyModule, PyString}; use pyo3::{ Py, PyAny, PyErr, PyIterProtocol, PyObject, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python, }; @@ -28,19 +28,24 @@ fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> { #[pyclass] struct HnswMap { - inner: instant_distance::HnswMap, + inner: instant_distance::HnswMap, } #[pymethods] impl HnswMap { /// Build the index #[staticmethod] - fn build(points: &PyList, values: Vec, config: &Config) -> PyResult { + fn build(points: &PyList, values: &PyList, config: &Config) -> PyResult { let points = points .into_iter() .map(FloatArray::try_from) .collect::, PyErr>>()?; + let values = values + .into_iter() + .map(MapValue::try_from) + .collect::, PyErr>>()?; + let hsnw_map = instant_distance::Builder::from(config).build(points, values); Ok(Self { inner: hsnw_map }) } @@ -49,7 +54,7 @@ impl HnswMap { #[staticmethod] fn load(fname: &str) -> PyResult { let hnsw_map = - bincode::deserialize_from::<_, instant_distance::HnswMap>( + bincode::deserialize_from::<_, instant_distance::HnswMap>( BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?), ) .map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?; @@ -400,4 +405,25 @@ impl Point for FloatArray { } } +#[derive(Clone, Deserialize, Serialize)] +enum MapValue { + String(String), +} + +impl TryFrom<&PyAny> for MapValue { + type Error = PyErr; + + fn try_from(value: &PyAny) -> Result { + Ok(MapValue::String(value.extract::()?)) + } +} + +impl IntoPy> for &'_ MapValue { + fn into_py(self, py: Python<'_>) -> Py { + match self { + MapValue::String(s) => PyString::new(py, s).into(), + } + } +} + const DIMENSIONS: usize = 300;