py: store a generic type in HnswMap
This commit is contained in:
parent
4cacb38875
commit
f6a5103abb
|
@ -8,7 +8,7 @@ use instant_distance::Point;
|
||||||
use pyo3::conversion::IntoPy;
|
use pyo3::conversion::IntoPy;
|
||||||
use pyo3::exceptions::{PyTypeError, PyValueError};
|
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||||
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
|
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
|
||||||
use pyo3::types::{PyList, PyModule};
|
use pyo3::types::{PyList, PyModule, PyString};
|
||||||
use pyo3::{
|
use pyo3::{
|
||||||
Py, PyAny, PyErr, PyIterProtocol, PyObject, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python,
|
Py, PyAny, PyErr, PyIterProtocol, PyObject, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python,
|
||||||
};
|
};
|
||||||
|
@ -28,19 +28,24 @@ fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
|
||||||
|
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
struct HnswMap {
|
struct HnswMap {
|
||||||
inner: instant_distance::HnswMap<FloatArray, String>,
|
inner: instant_distance::HnswMap<FloatArray, MapValue>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl HnswMap {
|
impl HnswMap {
|
||||||
/// Build the index
|
/// Build the index
|
||||||
#[staticmethod]
|
#[staticmethod]
|
||||||
fn build(points: &PyList, values: Vec<String>, config: &Config) -> PyResult<Self> {
|
fn build(points: &PyList, values: &PyList, config: &Config) -> PyResult<Self> {
|
||||||
let points = points
|
let points = points
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(FloatArray::try_from)
|
.map(FloatArray::try_from)
|
||||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||||
|
|
||||||
|
let values = values
|
||||||
|
.into_iter()
|
||||||
|
.map(MapValue::try_from)
|
||||||
|
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||||
|
|
||||||
let hsnw_map = instant_distance::Builder::from(config).build(points, values);
|
let hsnw_map = instant_distance::Builder::from(config).build(points, values);
|
||||||
Ok(Self { inner: hsnw_map })
|
Ok(Self { inner: hsnw_map })
|
||||||
}
|
}
|
||||||
|
@ -49,7 +54,7 @@ impl HnswMap {
|
||||||
#[staticmethod]
|
#[staticmethod]
|
||||||
fn load(fname: &str) -> PyResult<Self> {
|
fn load(fname: &str) -> PyResult<Self> {
|
||||||
let hnsw_map =
|
let hnsw_map =
|
||||||
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, String>>(
|
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, MapValue>>(
|
||||||
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
|
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
|
||||||
)
|
)
|
||||||
.map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?;
|
.map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?;
|
||||||
|
@ -400,4 +405,25 @@ impl Point for FloatArray {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, Serialize)]
|
||||||
|
enum MapValue {
|
||||||
|
String(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<&PyAny> for MapValue {
|
||||||
|
type Error = PyErr;
|
||||||
|
|
||||||
|
fn try_from(value: &PyAny) -> Result<Self, Self::Error> {
|
||||||
|
Ok(MapValue::String(value.extract::<String>()?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoPy<Py<PyAny>> for &'_ MapValue {
|
||||||
|
fn into_py(self, py: Python<'_>) -> Py<PyAny> {
|
||||||
|
match self {
|
||||||
|
MapValue::String(s) => PyString::new(py, s).into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const DIMENSIONS: usize = 300;
|
const DIMENSIONS: usize = 300;
|
||||||
|
|
Loading…
Reference in New Issue