Simplify search API
This commit is contained in:
parent
c6087de542
commit
3a41d916c3
23
src/lib.rs
23
src/lib.rs
|
@ -315,9 +315,17 @@ where
|
|||
/// The results are returned in the `out` parameter; the number of neighbors to search for
|
||||
/// is limited by the size of the `out` parameter, and the number of results found is returned
|
||||
/// in the return value.
|
||||
pub fn search(&self, point: &P, out: &mut [PointId], search: &mut Search) -> usize {
|
||||
pub fn search<'a>(
|
||||
&self,
|
||||
point: &P,
|
||||
search: &'a mut Search,
|
||||
) -> impl Iterator<Item = PointId> + ExactSizeIterator + 'a {
|
||||
fn map(candidate: &Candidate) -> PointId {
|
||||
candidate.pid
|
||||
}
|
||||
|
||||
if self.points.is_empty() {
|
||||
return 0;
|
||||
return (&[] as &[Candidate]).iter().map(map);
|
||||
}
|
||||
|
||||
search.visited.reserve_capacity(self.points.len());
|
||||
|
@ -340,11 +348,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
let nearest = &search.select_simple()[..out.len()];
|
||||
for (i, candidate) in nearest.iter().enumerate() {
|
||||
out[i] = candidate.pid;
|
||||
}
|
||||
nearest.len()
|
||||
search.select_simple().iter().map(map)
|
||||
}
|
||||
|
||||
/// Iterate over the keys and values in this index
|
||||
|
@ -376,7 +380,10 @@ fn insert<P: Point>(
|
|||
heuristic: &Option<Heuristic>,
|
||||
) {
|
||||
let found = match heuristic {
|
||||
None => &search.select_simple()[..M * 2],
|
||||
None => {
|
||||
let candidates = search.select_simple();
|
||||
&candidates[..Ord::min(candidates.len(), M * 2)]
|
||||
}
|
||||
Some(heuristic) => search.select_heuristic(&points[new], layer, points, *heuristic),
|
||||
};
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ use ordered_float::OrderedFloat;
|
|||
use rand::rngs::{StdRng, ThreadRng};
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use instant_distance::{Builder, Point as _, PointId, Search};
|
||||
use instant_distance::{Builder, Point as _, Search};
|
||||
|
||||
#[test]
|
||||
fn random_heuristic() {
|
||||
|
@ -40,9 +40,8 @@ fn randomized(builder: Builder) -> (u64, usize) {
|
|||
|
||||
let (hnsw, pids) = builder.seed(seed).build(&points);
|
||||
let mut search = Search::default();
|
||||
let mut results = vec![PointId::default(); 100];
|
||||
let found = hnsw.search(&query, &mut results, &mut search);
|
||||
assert_eq!(found, 100);
|
||||
let results = hnsw.search(&query, &mut search);
|
||||
assert!(results.len() >= 100);
|
||||
|
||||
nearest.sort_unstable();
|
||||
nearest.truncate(100);
|
||||
|
@ -50,7 +49,7 @@ fn randomized(builder: Builder) -> (u64, usize) {
|
|||
.iter()
|
||||
.map(|(_, i)| pids[*i])
|
||||
.collect::<HashSet<_>>();
|
||||
let found = results.into_iter().take(found).collect::<HashSet<_>>();
|
||||
let found = results.take(100).collect::<HashSet<_>>();
|
||||
(seed, forced.intersection(&found).count())
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue