diff --git a/instant-distance-py/src/lib.rs b/instant-distance-py/src/lib.rs index ce0f89d..dda381a 100644 --- a/instant-distance-py/src/lib.rs +++ b/instant-distance-py/src/lib.rs @@ -23,63 +23,6 @@ fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> { Ok(()) } -/// An instance of hierarchical navigable small worlds -/// -/// For now, this is specialized to only support 300-element (32-bit) float vectors -/// with a squared Euclidean distance metric. -#[pyclass] -struct Hnsw { - inner: instant_distance::Hnsw, -} - -#[pymethods] -impl Hnsw { - /// Build the index - #[staticmethod] - fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec)> { - let points = input - .into_iter() - .map(FloatArray::try_from) - .collect::, PyErr>>()?; - - let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points); - let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner())); - Ok((Self { inner }, ids)) - } - - /// Load an index from the given file name - #[staticmethod] - fn load(fname: &str) -> PyResult { - let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw>( - BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?), - ) - .map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?; - Ok(Self { inner: hnsw }) - } - - /// Dump the index to the given file name - fn dump(&self, fname: &str) -> PyResult<()> { - let f = BufWriter::with_capacity(32 * 1024 * 1024, File::create(fname)?); - bincode::serialize_into(f, &self.inner) - .map_err(|e| PyValueError::new_err(format!("serialization error: {:?}", e)))?; - Ok(()) - } - - /// Search the index for points neighboring the given point - /// - /// The `search` object contains buffers used for searching. When the search completes, - /// iterate over the `Search` to get the results. The number of results should be equal - /// to the `ef_search` parameter set in the index's `config`. - /// - /// For best performance, reusing `Search` objects is recommended. - fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> { - let point = FloatArray::try_from(point)?; - let _ = self.inner.search(&point, &mut search.inner); - search.cur = Some(0); - Ok(()) - } -} - #[pyclass] struct HnswMap { inner: instant_distance::HnswMap, @@ -137,6 +80,63 @@ impl HnswMap { } } +/// An instance of hierarchical navigable small worlds +/// +/// For now, this is specialized to only support 300-element (32-bit) float vectors +/// with a squared Euclidean distance metric. +#[pyclass] +struct Hnsw { + inner: instant_distance::Hnsw, +} + +#[pymethods] +impl Hnsw { + /// Build the index + #[staticmethod] + fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec)> { + let points = input + .into_iter() + .map(FloatArray::try_from) + .collect::, PyErr>>()?; + + let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points); + let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner())); + Ok((Self { inner }, ids)) + } + + /// Load an index from the given file name + #[staticmethod] + fn load(fname: &str) -> PyResult { + let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw>( + BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?), + ) + .map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?; + Ok(Self { inner: hnsw }) + } + + /// Dump the index to the given file name + fn dump(&self, fname: &str) -> PyResult<()> { + let f = BufWriter::with_capacity(32 * 1024 * 1024, File::create(fname)?); + bincode::serialize_into(f, &self.inner) + .map_err(|e| PyValueError::new_err(format!("serialization error: {:?}", e)))?; + Ok(()) + } + + /// Search the index for points neighboring the given point + /// + /// The `search` object contains buffers used for searching. When the search completes, + /// iterate over the `Search` to get the results. The number of results should be equal + /// to the `ef_search` parameter set in the index's `config`. + /// + /// For best performance, reusing `Search` objects is recommended. + fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> { + let point = FloatArray::try_from(point)?; + let _ = self.inner.search(&point, &mut search.inner); + search.cur = Some(0); + Ok(()) + } +} + /// Search buffer and result set #[pyclass] struct Search {