Make search() return more information

This commit is contained in:
Dirkjan Ochtman 2021-05-18 13:38:20 +02:00
parent 20ca8b0f3a
commit 4b43a11e26
2 changed files with 30 additions and 11 deletions

View File

@ -155,10 +155,12 @@ where
&'a self,
point: &P,
search: &'a mut Search,
) -> impl Iterator<Item = (f32, &'a V)> + ExactSizeIterator + 'a {
self.hnsw.search(point, search).map(move |candidate| {
let value = &self.values[candidate.pid.0 as usize];
(candidate.distance.into(), value)
) -> impl Iterator<Item = (f32, &'a P, &'a V)> + ExactSizeIterator + 'a {
self.hnsw
.search(point, search)
.map(move |(distance, pid, point)| {
let value = &self.values[pid.0 as usize];
(distance, point, value)
})
}
}
@ -336,14 +338,15 @@ where
/// The results are returned in the `out` parameter; the number of neighbors to search for
/// is limited by the size of the `out` parameter, and the number of results found is returned
/// in the return value.
pub fn search<'a>(
&self,
pub fn search<'a, 'b: 'a>(
&'b self,
point: &P,
search: &'a mut Search,
) -> impl Iterator<Item = Candidate> + ExactSizeIterator + 'a {
) -> impl Iterator<Item = Item<'b, P>> + ExactSizeIterator + 'a {
search.reset();
let map = move |candidate| Item::new(candidate, self);
if self.points.is_empty() {
return search.iter();
return search.iter().map(map);
}
search.visited.reserve_capacity(self.points.len());
@ -365,7 +368,7 @@ where
}
}
search.iter()
search.iter().map(map)
}
/// Iterate over the keys and values in this index
@ -377,6 +380,22 @@ where
}
}
pub struct Item<'a, P> {
pub distance: f32,
pub pid: PointId,
pub point: &'a P,
}
impl<'a, P> Item<'a, P> {
fn new(candidate: Candidate, hnsw: &'a Hnsw<P>) -> Self {
Self {
distance: candidate.distance.into_inner(),
pid: candidate.pid,
point: &hnsw[candidate.pid],
}
}
}
struct Construction<'a, P: Point> {
zero: &'a [RwLock<ZeroNode>],
pool: SearchPool,

View File

@ -70,7 +70,7 @@ fn randomized(builder: Builder) -> (u64, usize) {
.collect::<HashSet<_>>();
let found = results
.take(100)
.map(|candidate| candidate.pid)
.map(|item| item.pid)
.collect::<HashSet<_>>();
(seed, forced.intersection(&found).count())
}