Add load() and dump() methods
This commit is contained in:
parent
bdcf58168a
commit
90cea0f4c0
|
@ -10,7 +10,8 @@ name = "instant_distance"
|
||||||
crate-type = ["cdylib"]
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
instant-distance = { version = "0.2", path = "../instant-distance" }
|
bincode = "1.3.1"
|
||||||
|
instant-distance = { version = "0.2", path = "../instant-distance", features = ["with-serde"] }
|
||||||
pyo3 = { version = "0.13.2", features = ["extension-module"] }
|
pyo3 = { version = "0.13.2", features = ["extension-module"] }
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde-big-array = "0.3.2"
|
serde-big-array = "0.3.2"
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
#![allow(clippy::from_iter_instead_of_collect)]
|
#![allow(clippy::from_iter_instead_of_collect)]
|
||||||
use std::convert::TryFrom;
|
use std::convert::TryFrom;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::{BufReader, BufWriter};
|
||||||
use std::iter::FromIterator;
|
use std::iter::FromIterator;
|
||||||
|
|
||||||
use instant_distance::Point;
|
use instant_distance::Point;
|
||||||
use pyo3::exceptions::PyTypeError;
|
use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||||
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
|
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
|
||||||
use pyo3::types::{PyList, PyModule};
|
use pyo3::types::{PyList, PyModule};
|
||||||
use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python};
|
use pyo3::{PyAny, PyErr, PyIterProtocol, PyObjectProtocol, PyRef, PyRefMut, PyResult, Python};
|
||||||
|
@ -39,6 +41,22 @@ impl Hnsw {
|
||||||
Ok((Self { inner }, ids))
|
Ok((Self { inner }, ids))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[staticmethod]
|
||||||
|
fn load(fname: &str) -> PyResult<Self> {
|
||||||
|
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<FloatArray>>(
|
||||||
|
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
|
||||||
|
)
|
||||||
|
.map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?;
|
||||||
|
Ok(Self { inner: hnsw })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dump(&self, fname: &str) -> PyResult<()> {
|
||||||
|
let f = BufWriter::with_capacity(32 * 1024 * 1024, File::create(fname)?);
|
||||||
|
bincode::serialize_into(f, &self.inner)
|
||||||
|
.map_err(|e| PyValueError::new_err(format!("serialization error: {:?}", e)))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> {
|
fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> {
|
||||||
let point = FloatArray::try_from(point)?;
|
let point = FloatArray::try_from(point)?;
|
||||||
let _ = self.inner.search(&point, &mut search.inner);
|
let _ = self.inner.search(&point, &mut search.inner);
|
||||||
|
|
Loading…
Reference in New Issue