Simplify search API

This commit is contained in:
Dirkjan Ochtman 2021-03-04 22:09:19 +01:00
parent c6087de542
commit 3a41d916c3
2 changed files with 19 additions and 13 deletions

View File

@ -315,9 +315,17 @@ where
/// The results are returned in the `out` parameter; the number of neighbors to search for /// The results are returned in the `out` parameter; the number of neighbors to search for
/// is limited by the size of the `out` parameter, and the number of results found is returned /// is limited by the size of the `out` parameter, and the number of results found is returned
/// in the return value. /// in the return value.
pub fn search(&self, point: &P, out: &mut [PointId], search: &mut Search) -> usize { pub fn search<'a>(
&self,
point: &P,
search: &'a mut Search,
) -> impl Iterator<Item = PointId> + ExactSizeIterator + 'a {
fn map(candidate: &Candidate) -> PointId {
candidate.pid
}
if self.points.is_empty() { if self.points.is_empty() {
return 0; return (&[] as &[Candidate]).iter().map(map);
} }
search.visited.reserve_capacity(self.points.len()); search.visited.reserve_capacity(self.points.len());
@ -340,11 +348,7 @@ where
} }
} }
let nearest = &search.select_simple()[..out.len()]; search.select_simple().iter().map(map)
for (i, candidate) in nearest.iter().enumerate() {
out[i] = candidate.pid;
}
nearest.len()
} }
/// Iterate over the keys and values in this index /// Iterate over the keys and values in this index
@ -376,7 +380,10 @@ fn insert<P: Point>(
heuristic: &Option<Heuristic>, heuristic: &Option<Heuristic>,
) { ) {
let found = match heuristic { let found = match heuristic {
None => &search.select_simple()[..M * 2], None => {
let candidates = search.select_simple();
&candidates[..Ord::min(candidates.len(), M * 2)]
}
Some(heuristic) => search.select_heuristic(&points[new], layer, points, *heuristic), Some(heuristic) => search.select_heuristic(&points[new], layer, points, *heuristic),
}; };

View File

@ -4,7 +4,7 @@ use ordered_float::OrderedFloat;
use rand::rngs::{StdRng, ThreadRng}; use rand::rngs::{StdRng, ThreadRng};
use rand::{Rng, SeedableRng}; use rand::{Rng, SeedableRng};
use instant_distance::{Builder, Point as _, PointId, Search}; use instant_distance::{Builder, Point as _, Search};
#[test] #[test]
fn random_heuristic() { fn random_heuristic() {
@ -40,9 +40,8 @@ fn randomized(builder: Builder) -> (u64, usize) {
let (hnsw, pids) = builder.seed(seed).build(&points); let (hnsw, pids) = builder.seed(seed).build(&points);
let mut search = Search::default(); let mut search = Search::default();
let mut results = vec![PointId::default(); 100]; let results = hnsw.search(&query, &mut search);
let found = hnsw.search(&query, &mut results, &mut search); assert!(results.len() >= 100);
assert_eq!(found, 100);
nearest.sort_unstable(); nearest.sort_unstable();
nearest.truncate(100); nearest.truncate(100);
@ -50,7 +49,7 @@ fn randomized(builder: Builder) -> (u64, usize) {
.iter() .iter()
.map(|(_, i)| pids[*i]) .map(|(_, i)| pids[*i])
.collect::<HashSet<_>>(); .collect::<HashSet<_>>();
let found = results.into_iter().take(found).collect::<HashSet<_>>(); let found = results.take(100).collect::<HashSet<_>>();
(seed, forced.intersection(&found).count()) (seed, forced.intersection(&found).count())
} }