Simplify search API

This commit is contained in:
Dirkjan Ochtman 2021-03-04 22:09:19 +01:00
parent c6087de542
commit 3a41d916c3
2 changed files with 19 additions and 13 deletions

View File

@ -315,9 +315,17 @@ where
/// The results are returned in the `out` parameter; the number of neighbors to search for
/// is limited by the size of the `out` parameter, and the number of results found is returned
/// in the return value.
pub fn search(&self, point: &P, out: &mut [PointId], search: &mut Search) -> usize {
pub fn search<'a>(
&self,
point: &P,
search: &'a mut Search,
) -> impl Iterator<Item = PointId> + ExactSizeIterator + 'a {
fn map(candidate: &Candidate) -> PointId {
candidate.pid
}
if self.points.is_empty() {
return 0;
return (&[] as &[Candidate]).iter().map(map);
}
search.visited.reserve_capacity(self.points.len());
@ -340,11 +348,7 @@ where
}
}
let nearest = &search.select_simple()[..out.len()];
for (i, candidate) in nearest.iter().enumerate() {
out[i] = candidate.pid;
}
nearest.len()
search.select_simple().iter().map(map)
}
/// Iterate over the keys and values in this index
@ -376,7 +380,10 @@ fn insert<P: Point>(
heuristic: &Option<Heuristic>,
) {
let found = match heuristic {
None => &search.select_simple()[..M * 2],
None => {
let candidates = search.select_simple();
&candidates[..Ord::min(candidates.len(), M * 2)]
}
Some(heuristic) => search.select_heuristic(&points[new], layer, points, *heuristic),
};

View File

@ -4,7 +4,7 @@ use ordered_float::OrderedFloat;
use rand::rngs::{StdRng, ThreadRng};
use rand::{Rng, SeedableRng};
use instant_distance::{Builder, Point as _, PointId, Search};
use instant_distance::{Builder, Point as _, Search};
#[test]
fn random_heuristic() {
@ -40,9 +40,8 @@ fn randomized(builder: Builder) -> (u64, usize) {
let (hnsw, pids) = builder.seed(seed).build(&points);
let mut search = Search::default();
let mut results = vec![PointId::default(); 100];
let found = hnsw.search(&query, &mut results, &mut search);
assert_eq!(found, 100);
let results = hnsw.search(&query, &mut search);
assert!(results.len() >= 100);
nearest.sort_unstable();
nearest.truncate(100);
@ -50,7 +49,7 @@ fn randomized(builder: Builder) -> (u64, usize) {
.iter()
.map(|(_, i)| pids[*i])
.collect::<HashSet<_>>();
let found = results.into_iter().take(found).collect::<HashSet<_>>();
let found = results.take(100).collect::<HashSet<_>>();
(seed, forced.intersection(&found).count())
}