py: add load() and dump() methods

This commit is contained in:
Dirkjan Ochtman 2021-03-24 11:40:39 +01:00
parent b6cff0a93c
commit 377a71cec2
2 changed files with 23 additions and 1 deletions

View File

@ -12,7 +12,8 @@ crate-type = ["cdylib"]
[dependencies] [dependencies]
ahash = "0.7.2" ahash = "0.7.2"
instant-segment = { version = "0.7", path = "../instant-segment" } bincode = "1.3.2"
instant-segment = { version = "0.7", path = "../instant-segment", features = ["with-serde"] }
pyo3 = { version = "0.13.2", features = ["extension-module"] } pyo3 = { version = "0.13.2", features = ["extension-module"] }
smartstring = "0.2.6" smartstring = "0.2.6"

View File

@ -1,3 +1,6 @@
use std::fs::File;
use std::io::{BufReader, BufWriter};
use pyo3::exceptions::PyValueError; use pyo3::exceptions::PyValueError;
use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto}; use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto};
use pyo3::types::{PyIterator, PyModule}; use pyo3::types::{PyIterator, PyModule};
@ -47,6 +50,24 @@ impl Segmenter {
}) })
} }
/// Load a segmenter from the given file name
#[staticmethod]
fn load(fname: &str) -> PyResult<Self> {
let hnsw = bincode::deserialize_from::<_, instant_segment::Segmenter>(
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
)
.map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?;
Ok(Self { inner: hnsw })
}
/// Dump the segmenter to the given file name
fn dump(&self, fname: &str) -> PyResult<()> {
let f = BufWriter::with_capacity(32 * 1024 * 1024, File::create(fname)?);
bincode::serialize_into(f, &self.inner)
.map_err(|e| PyValueError::new_err(format!("serialization error: {:?}", e)))?;
Ok(())
}
fn segment(&self, s: &str, search: &mut Search) -> PyResult<()> { fn segment(&self, s: &str, search: &mut Search) -> PyResult<()> {
match self.inner.segment(s, &mut search.inner) { match self.inner.segment(s, &mut search.inner) {
Ok(_) => { Ok(_) => {