Add load() and dump() methods

This commit is contained in:
Dirkjan Ochtman 2021-03-18 13:54:16 +01:00
parent bdcf58168a
commit 90cea0f4c0
2 changed files with 21 additions and 2 deletions

View File

@ -10,7 +10,8 @@ name = "instant_distance"
crate-type = ["cdylib"] crate-type = ["cdylib"]
[dependencies] [dependencies]
instant-distance = { version = "0.2", path = "../instant-distance" } bincode = "1.3.1"
instant-distance = { version = "0.2", path = "../instant-distance", features = ["with-serde"] }
pyo3 = { version = "0.13.2", features = ["extension-module"] } pyo3 = { version = "0.13.2", features = ["extension-module"] }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde-big-array = "0.3.2" serde-big-array = "0.3.2"

View File

@ -1,9 +1,11 @@
#![allow(clippy::from_iter_instead_of_collect)] #![allow(clippy::from_iter_instead_of_collect)]
use std::convert::TryFrom; use std::convert::TryFrom;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::iter::FromIterator; use std::iter::FromIterator;
use instant_distance::Point; use instant_distance::Point;
use pyo3::exceptions::PyTypeError; use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto}; use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
use pyo3::types::{PyList, PyModule}; use pyo3::types::{PyList, PyModule};
use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python}; use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python};
@ -39,6 +41,22 @@ impl Hnsw {
Ok((Self { inner }, ids)) Ok((Self { inner }, ids))
} }
#[staticmethod]
fn load(fname: &str) -> PyResult<Self> {
    // Read a bincode-serialized index from the file at `fname`.
    //
    // Failing to open the file is propagated through `?`; any failure while
    // decoding is reported as a `PyValueError`.
    let file = File::open(fname)?;
    // The reader is given a 32 MiB buffer.
    let reader = BufReader::with_capacity(32 * 1024 * 1024, file);
    let inner: instant_distance::Hnsw<FloatArray> = match bincode::deserialize_from(reader) {
        Ok(index) => index,
        Err(e) => {
            return Err(PyValueError::new_err(format!(
                "deserialization error: {:?}",
                e
            )))
        }
    };
    Ok(Self { inner })
}
fn dump(&self, fname: &str) -> PyResult<()> {
    // Serialize the index with bincode into the file at `fname`, creating or
    // truncating it.
    //
    // Failing to create the file or to flush the final bytes is propagated
    // through `?`; a failure while encoding is reported as a `PyValueError`.
    let mut f = BufWriter::with_capacity(32 * 1024 * 1024, File::create(fname)?);
    // Lend the writer instead of moving it so it is still available for the
    // explicit flush below.
    bincode::serialize_into(&mut f, &self.inner)
        .map_err(|e| PyValueError::new_err(format!("serialization error: {:?}", e)))?;
    // `BufWriter` flushes on drop but discards any error, which would let a
    // truncated file go unnoticed; flush explicitly so the failure surfaces.
    std::io::Write::flush(&mut f)?;
    Ok(())
}
fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> { fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> {
let point = FloatArray::try_from(point)?; let point = FloatArray::try_from(point)?;
let _ = self.inner.search(&point, &mut search.inner); let _ = self.inner.search(&point, &mut search.inner);