Make search() return more information
This commit is contained in:
parent
20ca8b0f3a
commit
4b43a11e26
|
@ -155,11 +155,13 @@ where
|
||||||
&'a self,
|
&'a self,
|
||||||
point: &P,
|
point: &P,
|
||||||
search: &'a mut Search,
|
search: &'a mut Search,
|
||||||
) -> impl Iterator<Item = (f32, &'a V)> + ExactSizeIterator + 'a {
|
) -> impl Iterator<Item = (f32, &'a P, &'a V)> + ExactSizeIterator + 'a {
|
||||||
self.hnsw.search(point, search).map(move |candidate| {
|
self.hnsw
|
||||||
let value = &self.values[candidate.pid.0 as usize];
|
.search(point, search)
|
||||||
(candidate.distance.into(), value)
|
.map(move |(distance, pid, point)| {
|
||||||
})
|
let value = &self.values[pid.0 as usize];
|
||||||
|
(distance, point, value)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -336,14 +338,15 @@ where
|
||||||
/// The results are returned in the `out` parameter; the number of neighbors to search for
|
/// The results are returned in the `out` parameter; the number of neighbors to search for
|
||||||
/// is limited by the size of the `out` parameter, and the number of results found is returned
|
/// is limited by the size of the `out` parameter, and the number of results found is returned
|
||||||
/// in the return value.
|
/// in the return value.
|
||||||
pub fn search<'a>(
|
pub fn search<'a, 'b: 'a>(
|
||||||
&self,
|
&'b self,
|
||||||
point: &P,
|
point: &P,
|
||||||
search: &'a mut Search,
|
search: &'a mut Search,
|
||||||
) -> impl Iterator<Item = Candidate> + ExactSizeIterator + 'a {
|
) -> impl Iterator<Item = Item<'b, P>> + ExactSizeIterator + 'a {
|
||||||
search.reset();
|
search.reset();
|
||||||
|
let map = move |candidate| Item::new(candidate, self);
|
||||||
if self.points.is_empty() {
|
if self.points.is_empty() {
|
||||||
return search.iter();
|
return search.iter().map(map);
|
||||||
}
|
}
|
||||||
|
|
||||||
search.visited.reserve_capacity(self.points.len());
|
search.visited.reserve_capacity(self.points.len());
|
||||||
|
@ -365,7 +368,7 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
search.iter()
|
search.iter().map(map)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Iterate over the keys and values in this index
|
/// Iterate over the keys and values in this index
|
||||||
|
@ -377,6 +380,22 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct Item<'a, P> {
|
||||||
|
pub distance: f32,
|
||||||
|
pub pid: PointId,
|
||||||
|
pub point: &'a P,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, P> Item<'a, P> {
|
||||||
|
fn new(candidate: Candidate, hnsw: &'a Hnsw<P>) -> Self {
|
||||||
|
Self {
|
||||||
|
distance: candidate.distance.into_inner(),
|
||||||
|
pid: candidate.pid,
|
||||||
|
point: &hnsw[candidate.pid],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct Construction<'a, P: Point> {
|
struct Construction<'a, P: Point> {
|
||||||
zero: &'a [RwLock<ZeroNode>],
|
zero: &'a [RwLock<ZeroNode>],
|
||||||
pool: SearchPool,
|
pool: SearchPool,
|
||||||
|
|
|
@ -70,7 +70,7 @@ fn randomized(builder: Builder) -> (u64, usize) {
|
||||||
.collect::<HashSet<_>>();
|
.collect::<HashSet<_>>();
|
||||||
let found = results
|
let found = results
|
||||||
.take(100)
|
.take(100)
|
||||||
.map(|candidate| candidate.pid)
|
.map(|item| item.pid)
|
||||||
.collect::<HashSet<_>>();
|
.collect::<HashSet<_>>();
|
||||||
(seed, forced.intersection(&found).count())
|
(seed, forced.intersection(&found).count())
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue