diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..de288e1 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.formatting.provider": "black" +} \ No newline at end of file diff --git a/instant-distance-py/src/lib.rs b/instant-distance-py/src/lib.rs index 94d24b2..ce0f89d 100644 --- a/instant-distance-py/src/lib.rs +++ b/instant-distance-py/src/lib.rs @@ -19,6 +19,7 @@ fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } @@ -79,6 +80,63 @@ impl Hnsw { } } +#[pyclass] +struct HnswMap { + inner: instant_distance::HnswMap, +} + +#[pymethods] +impl HnswMap { + #[getter] + fn values(&self) -> PyResult> { + Ok(self.inner.values.clone()) + } + /// Build the index + #[staticmethod] + fn build(points: &PyList, values: Vec, config: &Config) -> PyResult { + let points = points + .into_iter() + .map(FloatArray::try_from) + .collect::, PyErr>>()?; + + let hsnw_map = instant_distance::Builder::from(config).build(points, values); + Ok(Self { inner: hsnw_map }) + } + + /// Load an index from the given file name + #[staticmethod] + fn load(fname: &str) -> PyResult { + let hnsw_map = + bincode::deserialize_from::<_, instant_distance::HnswMap>( + BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?), + ) + .map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?; + Ok(Self { inner: hnsw_map }) + } + + /// Dump the index to the given file name + fn dump(&self, fname: &str) -> PyResult<()> { + let f = BufWriter::with_capacity(32 * 1024 * 1024, File::create(fname)?); + bincode::serialize_into(f, &self.inner) + .map_err(|e| PyValueError::new_err(format!("serialization error: {:?}", e)))?; + Ok(()) + } + + /// Search the index for points neighboring the given point + /// + /// The `search` object contains buffers used for searching. When the search completes, + /// iterate over the `Search` to get the results. The number of results should be equal + /// to the `ef_search` parameter set in the index's `config`. + /// + /// For best performance, reusing `Search` objects is recommended. + fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> { + let point = FloatArray::try_from(point)?; + let _ = self.inner.search(&point, &mut search.inner); + search.cur = Some(0); + Ok(()) + } +} + /// Search buffer and result set #[pyclass] struct Search { diff --git a/instant-distance-py/test/test.py b/instant-distance-py/test/test.py index 28ee613..931f697 100644 --- a/instant-distance-py/test/test.py +++ b/instant-distance-py/test/test.py @@ -1,6 +1,7 @@ import instant_distance, random -def main(): + +def test_hsnw(): points = [[random.random() for _ in range(300)] for _ in range(1024)] config = instant_distance.Config() (hnsw, ids) = instant_distance.Hnsw.build(points, config) @@ -10,5 +11,31 @@ def main(): for candidate in search: print(candidate) -if __name__ == '__main__': - main() + +def test_hsnw_map(): + the_chosen_one = 123 + + embeddings = [[random.random() for _ in range(300)] for _ in range(1024)] + with open("/usr/share/dict/words", "r") as f: # *nix only + values = f.read().splitlines()[1024:] + + config = instant_distance.Config() + hnsw_map = instant_distance.HnswMap.build(embeddings, values, config) + + search = instant_distance.Search() + hnsw_map.search(embeddings[the_chosen_one], search) + + closest_pid = list(search)[0].pid + approx_nearest = hnsw_map.values[closest_pid] + actual_word = values[the_chosen_one] + + print("pid:\t\t", closest_pid) + print("approx word:\t", approx_nearest) + print("actual word:\t", actual_word) + + assert approx_nearest == actual_word + + +if __name__ == "__main__": + test_hsnw() + test_hsnw_map() diff --git a/instant-distance/src/lib.rs b/instant-distance/src/lib.rs index 1a65c08..98c8adf 100644 --- a/instant-distance/src/lib.rs +++ b/instant-distance/src/lib.rs @@ -130,7 +130,7 @@ impl Default for Heuristic { #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] pub struct HnswMap { hnsw: Hnsw

, - values: Vec, + pub values: Vec, } impl HnswMap