Simplify search API
This commit is contained in:
parent
c6087de542
commit
3a41d916c3
23
src/lib.rs
23
src/lib.rs
|
@ -315,9 +315,17 @@ where
|
||||||
/// The results are returned in the `out` parameter; the number of neighbors to search for
|
/// The results are returned in the `out` parameter; the number of neighbors to search for
|
||||||
/// is limited by the size of the `out` parameter, and the number of results found is returned
|
/// is limited by the size of the `out` parameter, and the number of results found is returned
|
||||||
/// in the return value.
|
/// in the return value.
|
||||||
pub fn search(&self, point: &P, out: &mut [PointId], search: &mut Search) -> usize {
|
pub fn search<'a>(
|
||||||
|
&self,
|
||||||
|
point: &P,
|
||||||
|
search: &'a mut Search,
|
||||||
|
) -> impl Iterator<Item = PointId> + ExactSizeIterator + 'a {
|
||||||
|
fn map(candidate: &Candidate) -> PointId {
|
||||||
|
candidate.pid
|
||||||
|
}
|
||||||
|
|
||||||
if self.points.is_empty() {
|
if self.points.is_empty() {
|
||||||
return 0;
|
return (&[] as &[Candidate]).iter().map(map);
|
||||||
}
|
}
|
||||||
|
|
||||||
search.visited.reserve_capacity(self.points.len());
|
search.visited.reserve_capacity(self.points.len());
|
||||||
|
@ -340,11 +348,7 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let nearest = &search.select_simple()[..out.len()];
|
search.select_simple().iter().map(map)
|
||||||
for (i, candidate) in nearest.iter().enumerate() {
|
|
||||||
out[i] = candidate.pid;
|
|
||||||
}
|
|
||||||
nearest.len()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Iterate over the keys and values in this index
|
/// Iterate over the keys and values in this index
|
||||||
|
@ -376,7 +380,10 @@ fn insert<P: Point>(
|
||||||
heuristic: &Option<Heuristic>,
|
heuristic: &Option<Heuristic>,
|
||||||
) {
|
) {
|
||||||
let found = match heuristic {
|
let found = match heuristic {
|
||||||
None => &search.select_simple()[..M * 2],
|
None => {
|
||||||
|
let candidates = search.select_simple();
|
||||||
|
&candidates[..Ord::min(candidates.len(), M * 2)]
|
||||||
|
}
|
||||||
Some(heuristic) => search.select_heuristic(&points[new], layer, points, *heuristic),
|
Some(heuristic) => search.select_heuristic(&points[new], layer, points, *heuristic),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ use ordered_float::OrderedFloat;
|
||||||
use rand::rngs::{StdRng, ThreadRng};
|
use rand::rngs::{StdRng, ThreadRng};
|
||||||
use rand::{Rng, SeedableRng};
|
use rand::{Rng, SeedableRng};
|
||||||
|
|
||||||
use instant_distance::{Builder, Point as _, PointId, Search};
|
use instant_distance::{Builder, Point as _, Search};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn random_heuristic() {
|
fn random_heuristic() {
|
||||||
|
@ -40,9 +40,8 @@ fn randomized(builder: Builder) -> (u64, usize) {
|
||||||
|
|
||||||
let (hnsw, pids) = builder.seed(seed).build(&points);
|
let (hnsw, pids) = builder.seed(seed).build(&points);
|
||||||
let mut search = Search::default();
|
let mut search = Search::default();
|
||||||
let mut results = vec![PointId::default(); 100];
|
let results = hnsw.search(&query, &mut search);
|
||||||
let found = hnsw.search(&query, &mut results, &mut search);
|
assert!(results.len() >= 100);
|
||||||
assert_eq!(found, 100);
|
|
||||||
|
|
||||||
nearest.sort_unstable();
|
nearest.sort_unstable();
|
||||||
nearest.truncate(100);
|
nearest.truncate(100);
|
||||||
|
@ -50,7 +49,7 @@ fn randomized(builder: Builder) -> (u64, usize) {
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(_, i)| pids[*i])
|
.map(|(_, i)| pids[*i])
|
||||||
.collect::<HashSet<_>>();
|
.collect::<HashSet<_>>();
|
||||||
let found = results.into_iter().take(found).collect::<HashSet<_>>();
|
let found = results.take(100).collect::<HashSet<_>>();
|
||||||
(seed, forced.intersection(&found).count())
|
(seed, forced.intersection(&found).count())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue