diff --git a/instant-segment-py/Cargo.toml b/instant-segment-py/Cargo.toml index 6c5289f..ad33261 100644 --- a/instant-segment-py/Cargo.toml +++ b/instant-segment-py/Cargo.toml @@ -12,7 +12,8 @@ crate-type = ["cdylib"] [dependencies] ahash = "0.7.2" -instant-segment = { version = "0.7", path = "../instant-segment" } +bincode = "1.3.2" +instant-segment = { version = "0.7", path = "../instant-segment", features = ["with-serde"] } pyo3 = { version = "0.13.2", features = ["extension-module"] } smartstring = "0.2.6" diff --git a/instant-segment-py/src/lib.rs b/instant-segment-py/src/lib.rs index 2043fe5..8a3b4a9 100644 --- a/instant-segment-py/src/lib.rs +++ b/instant-segment-py/src/lib.rs @@ -1,3 +1,6 @@ +use std::fs::File; +use std::io::{BufReader, BufWriter}; + use pyo3::exceptions::PyValueError; use pyo3::proc_macro::{pyclass, pymethods, pymodule, pyproto}; use pyo3::types::{PyIterator, PyModule}; @@ -47,6 +50,24 @@ impl Segmenter { }) } + /// Load a segmenter from the given file name + #[staticmethod] + fn load(fname: &str) -> PyResult { + let hnsw = bincode::deserialize_from::<_, instant_segment::Segmenter>( + BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?), + ) + .map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?; + Ok(Self { inner: hnsw }) + } + + /// Dump the segmenter to the given file name + fn dump(&self, fname: &str) -> PyResult<()> { + let f = BufWriter::with_capacity(32 * 1024 * 1024, File::create(fname)?); + bincode::serialize_into(f, &self.inner) + .map_err(|e| PyValueError::new_err(format!("serialization error: {:?}", e)))?; + Ok(()) + } + fn segment(&self, s: &str, search: &mut Search) -> PyResult<()> { match self.inner.segment(s, &mut search.inner) { Ok(_) => {