py: store a generic type in HnswMap
This commit is contained in:
parent
0dd6a4ece8
commit
ed9a488a27
|
@ -8,7 +8,7 @@ use instant_distance::Point;
|
|||
use pyo3::conversion::IntoPy;
|
||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
|
||||
use pyo3::types::{PyList, PyModule};
|
||||
use pyo3::types::{PyList, PyModule, PyString};
|
||||
use pyo3::{
|
||||
Py, PyAny, PyErr, PyIterProtocol, PyObject, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python,
|
||||
};
|
||||
|
@ -28,19 +28,24 @@ fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
|
|||
|
||||
#[pyclass]
|
||||
struct HnswMap {
|
||||
inner: instant_distance::HnswMap<FloatArray, String>,
|
||||
inner: instant_distance::HnswMap<FloatArray, MapValue>,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl HnswMap {
|
||||
/// Build the index
|
||||
#[staticmethod]
|
||||
fn build(points: &PyList, values: Vec<String>, config: &Config) -> PyResult<Self> {
|
||||
fn build(points: &PyList, values: &PyList, config: &Config) -> PyResult<Self> {
|
||||
let points = points
|
||||
.into_iter()
|
||||
.map(FloatArray::try_from)
|
||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||
|
||||
let values = values
|
||||
.into_iter()
|
||||
.map(MapValue::try_from)
|
||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||
|
||||
let hsnw_map = instant_distance::Builder::from(config).build(points, values);
|
||||
Ok(Self { inner: hsnw_map })
|
||||
}
|
||||
|
@ -49,7 +54,7 @@ impl HnswMap {
|
|||
#[staticmethod]
|
||||
fn load(fname: &str) -> PyResult<Self> {
|
||||
let hnsw_map =
|
||||
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, String>>(
|
||||
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, MapValue>>(
|
||||
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
|
||||
)
|
||||
.map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?;
|
||||
|
@ -400,4 +405,25 @@ impl Point for FloatArray {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
enum MapValue {
|
||||
String(String),
|
||||
}
|
||||
|
||||
impl TryFrom<&PyAny> for MapValue {
|
||||
type Error = PyErr;
|
||||
|
||||
fn try_from(value: &PyAny) -> Result<Self, Self::Error> {
|
||||
Ok(MapValue::String(value.extract::<String>()?))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoPy<Py<PyAny>> for &'_ MapValue {
|
||||
fn into_py(self, py: Python<'_>) -> Py<PyAny> {
|
||||
match self {
|
||||
MapValue::String(s) => PyString::new(py, s).into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const DIMENSIONS: usize = 300;
|
||||
|
|
Loading…
Reference in New Issue