py: store a generic type in HnswMap

This commit is contained in:
Dirkjan Ochtman 2021-05-21 17:31:47 +02:00
parent 4cacb38875
commit f6a5103abb
1 changed files with 30 additions and 4 deletions

View File

@ -8,7 +8,7 @@ use instant_distance::Point;
use pyo3::conversion::IntoPy; use pyo3::conversion::IntoPy;
use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto}; use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
use pyo3::types::{PyList, PyModule}; use pyo3::types::{PyList, PyModule, PyString};
use pyo3::{ use pyo3::{
Py, PyAny, PyErr, PyIterProtocol, PyObject, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python, Py, PyAny, PyErr, PyIterProtocol, PyObject, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python,
}; };
@ -28,19 +28,24 @@ fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
#[pyclass] #[pyclass]
struct HnswMap { struct HnswMap {
inner: instant_distance::HnswMap<FloatArray, String>, inner: instant_distance::HnswMap<FloatArray, MapValue>,
} }
#[pymethods] #[pymethods]
impl HnswMap { impl HnswMap {
/// Build the index /// Build the index
#[staticmethod] #[staticmethod]
fn build(points: &PyList, values: Vec<String>, config: &Config) -> PyResult<Self> { fn build(points: &PyList, values: &PyList, config: &Config) -> PyResult<Self> {
let points = points let points = points
.into_iter() .into_iter()
.map(FloatArray::try_from) .map(FloatArray::try_from)
.collect::<Result<Vec<_>, PyErr>>()?; .collect::<Result<Vec<_>, PyErr>>()?;
let values = values
.into_iter()
.map(MapValue::try_from)
.collect::<Result<Vec<_>, PyErr>>()?;
let hsnw_map = instant_distance::Builder::from(config).build(points, values); let hsnw_map = instant_distance::Builder::from(config).build(points, values);
Ok(Self { inner: hsnw_map }) Ok(Self { inner: hsnw_map })
} }
@ -49,7 +54,7 @@ impl HnswMap {
#[staticmethod] #[staticmethod]
fn load(fname: &str) -> PyResult<Self> { fn load(fname: &str) -> PyResult<Self> {
let hnsw_map = let hnsw_map =
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, String>>( bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, MapValue>>(
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?), BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
) )
.map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?; .map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?;
@ -400,4 +405,25 @@ impl Point for FloatArray {
} }
} }
#[derive(Clone, Deserialize, Serialize)]
enum MapValue {
String(String),
}
impl TryFrom<&PyAny> for MapValue {
type Error = PyErr;
fn try_from(value: &PyAny) -> Result<Self, Self::Error> {
Ok(MapValue::String(value.extract::<String>()?))
}
}
impl IntoPy<Py<PyAny>> for &'_ MapValue {
fn into_py(self, py: Python<'_>) -> Py<PyAny> {
match self {
MapValue::String(s) => PyString::new(py, s).into(),
}
}
}
const DIMENSIONS: usize = 300; const DIMENSIONS: usize = 300;