From 3a41d916c38afcb4eb509631d6da2834d477c998 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Thu, 4 Mar 2021 22:09:19 +0100 Subject: [PATCH] Simplify search API --- src/lib.rs | 23 +++++++++++++++-------- tests/all.rs | 9 ++++----- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a32cc90..a12e0da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -315,9 +315,17 @@ where /// The results are returned in the `out` parameter; the number of neighbors to search for /// is limited by the size of the `out` parameter, and the number of results found is returned /// in the return value. - pub fn search(&self, point: &P, out: &mut [PointId], search: &mut Search) -> usize { + pub fn search<'a>( + &self, + point: &P, + search: &'a mut Search, + ) -> impl Iterator + ExactSizeIterator + 'a { + fn map(candidate: &Candidate) -> PointId { + candidate.pid + } + if self.points.is_empty() { - return 0; + return (&[] as &[Candidate]).iter().map(map); } search.visited.reserve_capacity(self.points.len()); @@ -340,11 +348,7 @@ where } } - let nearest = &search.select_simple()[..out.len()]; - for (i, candidate) in nearest.iter().enumerate() { - out[i] = candidate.pid; - } - nearest.len() + search.select_simple().iter().map(map) } /// Iterate over the keys and values in this index @@ -376,7 +380,10 @@ fn insert( heuristic: &Option, ) { let found = match heuristic { - None => &search.select_simple()[..M * 2], + None => { + let candidates = search.select_simple(); + &candidates[..Ord::min(candidates.len(), M * 2)] + } Some(heuristic) => search.select_heuristic(&points[new], layer, points, *heuristic), }; diff --git a/tests/all.rs b/tests/all.rs index 3cab9ca..d054c63 100644 --- a/tests/all.rs +++ b/tests/all.rs @@ -4,7 +4,7 @@ use ordered_float::OrderedFloat; use rand::rngs::{StdRng, ThreadRng}; use rand::{Rng, SeedableRng}; -use instant_distance::{Builder, Point as _, PointId, Search}; +use instant_distance::{Builder, Point as _, Search}; #[test] fn random_heuristic() { @@ -40,9 +40,8 @@ fn randomized(builder: Builder) -> (u64, usize) { let (hnsw, pids) = builder.seed(seed).build(&points); let mut search = Search::default(); - let mut results = vec![PointId::default(); 100]; - let found = hnsw.search(&query, &mut results, &mut search); - assert_eq!(found, 100); + let results = hnsw.search(&query, &mut search); + assert!(results.len() >= 100); nearest.sort_unstable(); nearest.truncate(100); @@ -50,7 +49,7 @@ fn randomized(builder: Builder) -> (u64, usize) { .iter() .map(|(_, i)| pids[*i]) .collect::>(); - let found = results.into_iter().take(found).collect::>(); + let found = results.take(100).collect::>(); (seed, forced.intersection(&found).count()) }