diff --git a/instant-distance-py/Cargo.toml b/instant-distance-py/Cargo.toml index a038c60..fc67abd 100644 --- a/instant-distance-py/Cargo.toml +++ b/instant-distance-py/Cargo.toml @@ -10,7 +10,8 @@ name = "instant_distance" crate-type = ["cdylib"] [dependencies] -instant-distance = { version = "0.2", path = "../instant-distance" } +bincode = "1.3.1" +instant-distance = { version = "0.2", path = "../instant-distance", features = ["with-serde"] } pyo3 = { version = "0.13.2", features = ["extension-module"] } serde = { version = "1", features = ["derive"] } serde-big-array = "0.3.2" diff --git a/instant-distance-py/src/lib.rs b/instant-distance-py/src/lib.rs index d417640..3466a06 100644 --- a/instant-distance-py/src/lib.rs +++ b/instant-distance-py/src/lib.rs @@ -1,9 +1,11 @@ #![allow(clippy::from_iter_instead_of_collect)] use std::convert::TryFrom; +use std::fs::File; +use std::io::{BufReader, BufWriter}; use std::iter::FromIterator; use instant_distance::Point; -use pyo3::exceptions::PyTypeError; +use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto}; use pyo3::types::{PyList, PyModule}; use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python}; @@ -39,6 +41,22 @@ impl Hnsw { Ok((Self { inner }, ids)) } + #[staticmethod] + fn load(fname: &str) -> PyResult { + let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw>( + BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?), + ) + .map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?; + Ok(Self { inner: hnsw }) + } + + fn dump(&self, fname: &str) -> PyResult<()> { + let f = BufWriter::with_capacity(32 * 1024 * 1024, File::create(fname)?); + bincode::serialize_into(f, &self.inner) + .map_err(|e| PyValueError::new_err(format!("serialization error: {:?}", e)))?; + Ok(()) + } + fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> { let point = FloatArray::try_from(point)?; let _ = self.inner.search(&point, &mut search.inner);