py: store a generic type in HnswMap

This commit is contained in:
Dirkjan Ochtman 2021-05-21 17:31:47 +02:00
parent 4cacb38875
commit f6a5103abb
1 changed files with 30 additions and 4 deletions

View File

@ -8,7 +8,7 @@ use instant_distance::Point;
use pyo3::conversion::IntoPy;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
use pyo3::types::{PyList, PyModule};
use pyo3::types::{PyList, PyModule, PyString};
use pyo3::{
Py, PyAny, PyErr, PyIterProtocol, PyObject, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python,
};
@ -28,19 +28,24 @@ fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
#[pyclass]
struct HnswMap {
inner: instant_distance::HnswMap<FloatArray, String>,
inner: instant_distance::HnswMap<FloatArray, MapValue>,
}
#[pymethods]
impl HnswMap {
/// Build the index
#[staticmethod]
fn build(points: &PyList, values: Vec<String>, config: &Config) -> PyResult<Self> {
fn build(points: &PyList, values: &PyList, config: &Config) -> PyResult<Self> {
let points = points
.into_iter()
.map(FloatArray::try_from)
.collect::<Result<Vec<_>, PyErr>>()?;
let values = values
.into_iter()
.map(MapValue::try_from)
.collect::<Result<Vec<_>, PyErr>>()?;
let hsnw_map = instant_distance::Builder::from(config).build(points, values);
Ok(Self { inner: hsnw_map })
}
@ -49,7 +54,7 @@ impl HnswMap {
#[staticmethod]
fn load(fname: &str) -> PyResult<Self> {
let hnsw_map =
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, String>>(
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, MapValue>>(
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
)
.map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?;
@ -400,4 +405,25 @@ impl Point for FloatArray {
}
}
#[derive(Clone, Deserialize, Serialize)]
enum MapValue {
String(String),
}
impl TryFrom<&PyAny> for MapValue {
type Error = PyErr;
fn try_from(value: &PyAny) -> Result<Self, Self::Error> {
Ok(MapValue::String(value.extract::<String>()?))
}
}
impl IntoPy<Py<PyAny>> for &'_ MapValue {
fn into_py(self, py: Python<'_>) -> Py<PyAny> {
match self {
MapValue::String(s) => PyString::new(py, s).into(),
}
}
}
const DIMENSIONS: usize = 300;