Make search() return more information

This commit is contained in:
Dirkjan Ochtman 2021-05-18 13:38:20 +02:00
parent 20ca8b0f3a
commit 4b43a11e26
2 changed files with 30 additions and 11 deletions

View File

@ -155,11 +155,13 @@ where
&'a self, &'a self,
point: &P, point: &P,
search: &'a mut Search, search: &'a mut Search,
) -> impl Iterator<Item = (f32, &'a V)> + ExactSizeIterator + 'a { ) -> impl Iterator<Item = (f32, &'a P, &'a V)> + ExactSizeIterator + 'a {
self.hnsw.search(point, search).map(move |candidate| { self.hnsw
let value = &self.values[candidate.pid.0 as usize]; .search(point, search)
(candidate.distance.into(), value) .map(move |(distance, pid, point)| {
}) let value = &self.values[pid.0 as usize];
(distance, point, value)
})
} }
} }
@ -336,14 +338,15 @@ where
/// The results are returned in the `out` parameter; the number of neighbors to search for /// The results are returned in the `out` parameter; the number of neighbors to search for
/// is limited by the size of the `out` parameter, and the number of results found is returned /// is limited by the size of the `out` parameter, and the number of results found is returned
/// in the return value. /// in the return value.
pub fn search<'a>( pub fn search<'a, 'b: 'a>(
&self, &'b self,
point: &P, point: &P,
search: &'a mut Search, search: &'a mut Search,
) -> impl Iterator<Item = Candidate> + ExactSizeIterator + 'a { ) -> impl Iterator<Item = Item<'b, P>> + ExactSizeIterator + 'a {
search.reset(); search.reset();
let map = move |candidate| Item::new(candidate, self);
if self.points.is_empty() { if self.points.is_empty() {
return search.iter(); return search.iter().map(map);
} }
search.visited.reserve_capacity(self.points.len()); search.visited.reserve_capacity(self.points.len());
@ -365,7 +368,7 @@ where
} }
} }
search.iter() search.iter().map(map)
} }
/// Iterate over the keys and values in this index /// Iterate over the keys and values in this index
@ -377,6 +380,22 @@ where
} }
} }
pub struct Item<'a, P> {
pub distance: f32,
pub pid: PointId,
pub point: &'a P,
}
impl<'a, P> Item<'a, P> {
fn new(candidate: Candidate, hnsw: &'a Hnsw<P>) -> Self {
Self {
distance: candidate.distance.into_inner(),
pid: candidate.pid,
point: &hnsw[candidate.pid],
}
}
}
struct Construction<'a, P: Point> { struct Construction<'a, P: Point> {
zero: &'a [RwLock<ZeroNode>], zero: &'a [RwLock<ZeroNode>],
pool: SearchPool, pool: SearchPool,

View File

@ -70,7 +70,7 @@ fn randomized(builder: Builder) -> (u64, usize) {
.collect::<HashSet<_>>(); .collect::<HashSet<_>>();
let found = results let found = results
.take(100) .take(100)
.map(|candidate| candidate.pid) .map(|item| item.pid)
.collect::<HashSet<_>>(); .collect::<HashSet<_>>();
(seed, forced.intersection(&found).count()) (seed, forced.intersection(&found).count())
} }