From b1bd3525a101975686443fdd69c2a85dc76ca7e3 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Mon, 17 May 2021 16:52:51 +0200 Subject: [PATCH] Add a high-level HnswMap type (fixes #7) --- instant-distance-py/src/lib.rs | 2 +- instant-distance/benches/all.rs | 2 +- instant-distance/src/lib.rs | 43 ++++++++++++++++++++++++++++++++- instant-distance/tests/all.rs | 21 +++++++++++++++- 4 files changed, 64 insertions(+), 4 deletions(-) diff --git a/instant-distance-py/src/lib.rs b/instant-distance-py/src/lib.rs index 4a9355e..94d24b2 100644 --- a/instant-distance-py/src/lib.rs +++ b/instant-distance-py/src/lib.rs @@ -41,7 +41,7 @@ impl Hnsw { .map(FloatArray::try_from) .collect::, PyErr>>()?; - let (inner, ids) = instant_distance::Builder::from(config).build(points); + let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points); let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner())); Ok((Self { inner }, ids)) } diff --git a/instant-distance/benches/all.rs b/instant-distance/benches/all.rs index edc21a4..60b422b 100644 --- a/instant-distance/benches/all.rs +++ b/instant-distance/benches/all.rs @@ -15,7 +15,7 @@ fn build_heuristic(bench: &mut Bencher) { .map(|_| Point(rng.gen(), rng.gen())) .collect::>(); - bench.iter(|| Builder::default().seed(seed).build(points.clone())) + bench.iter(|| Builder::default().seed(seed).build_hnsw(points.clone())) } /* diff --git a/instant-distance/src/lib.rs b/instant-distance/src/lib.rs index caf54cc..c1f39f5 100644 --- a/instant-distance/src/lib.rs +++ b/instant-distance/src/lib.rs @@ -74,8 +74,13 @@ impl Builder { self } + /// Build an `HnswMap` with the given sets of points and values + pub fn build(self, points: Vec

, values: Vec) -> HnswMap { + HnswMap::new(points, values, self) + } + /// Build the `Hnsw` with the given set of points - pub fn build(self, points: Vec

) -> (Hnsw

, Vec) { + pub fn build_hnsw(self, points: Vec

) -> (Hnsw

, Vec) { Hnsw::new(points, self) } @@ -122,6 +127,42 @@ impl Default for Heuristic { } } +#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] +pub struct HnswMap { + hnsw: Hnsw

, + values: Vec, +} + +impl HnswMap +where + P: Point, + V: Clone, +{ + fn new(points: Vec

, values: Vec, builder: Builder) -> Self { + let (hnsw, ids) = Hnsw::new(points, builder); + + let mut sorted = ids.into_iter().enumerate().collect::>(); + sorted.sort_unstable_by(|a, b| a.1.cmp(&b.1)); + let new = sorted + .into_iter() + .map(|(src, _)| values[src].clone()) + .collect(); + + Self { hnsw, values: new } + } + + pub fn search<'a>( + &'a self, + point: &P, + search: &'a mut Search, + ) -> impl Iterator + ExactSizeIterator + 'a { + self.hnsw.search(point, search).map(move |candidate| { + let value = &self.values[candidate.pid.0 as usize]; + (candidate.distance.into(), value) + }) + } +} + #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] pub struct Hnsw

{ ef_search: usize, diff --git a/instant-distance/tests/all.rs b/instant-distance/tests/all.rs index 2f0b49e..92a9ead 100644 --- a/instant-distance/tests/all.rs +++ b/instant-distance/tests/all.rs @@ -6,6 +6,25 @@ use rand::{Rng, SeedableRng}; use instant_distance::{Builder, Point as _, Search}; +#[test] +fn map() { + let points = (0..16) + .into_iter() + .map(|i| Point(i as f32, i as f32)) + .collect::>(); + let values = vec![ + "zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten", + "eleven", "twelve", "thirteen", "fourteen", "fifteen", + ]; + + let seed = ThreadRng::default().gen::(); + println!("map (seed = {})", seed); + let map = Builder::default().seed(seed).build(points, values); + let mut search = Search::default(); + + let _ = map.search(&Point(2.0, 2.0), &mut search); +} + #[test] fn random_heuristic() { let (seed, recall) = randomized(Builder::default()); @@ -38,7 +57,7 @@ fn randomized(builder: Builder) -> (u64, usize) { } } - let (hnsw, pids) = builder.seed(seed).build(points); + let (hnsw, pids) = builder.seed(seed).build_hnsw(points); let mut search = Search::default(); let results = hnsw.search(&query, &mut search); assert!(results.len() >= 100);