diff --git a/instant-distance/src/lib.rs b/instant-distance/src/lib.rs index 9f5a252..fa82d43 100644 --- a/instant-distance/src/lib.rs +++ b/instant-distance/src/lib.rs @@ -155,11 +155,13 @@ where &'a self, point: &P, search: &'a mut Search, - ) -> impl Iterator + ExactSizeIterator + 'a { - self.hnsw.search(point, search).map(move |candidate| { - let value = &self.values[candidate.pid.0 as usize]; - (candidate.distance.into(), value) - }) + ) -> impl Iterator + ExactSizeIterator + 'a { + self.hnsw + .search(point, search) + .map(move |(distance, pid, point)| { + let value = &self.values[pid.0 as usize]; + (distance, point, value) + }) } } @@ -336,14 +338,15 @@ where /// The results are returned in the `out` parameter; the number of neighbors to search for /// is limited by the size of the `out` parameter, and the number of results found is returned /// in the return value. - pub fn search<'a>( - &self, + pub fn search<'a, 'b: 'a>( + &'b self, point: &P, search: &'a mut Search, - ) -> impl Iterator + ExactSizeIterator + 'a { + ) -> impl Iterator> + ExactSizeIterator + 'a { search.reset(); + let map = move |candidate| Item::new(candidate, self); if self.points.is_empty() { - return search.iter(); + return search.iter().map(map); } search.visited.reserve_capacity(self.points.len()); @@ -365,7 +368,7 @@ where } } - search.iter() + search.iter().map(map) } /// Iterate over the keys and values in this index @@ -377,6 +380,22 @@ where } } +pub struct Item<'a, P> { + pub distance: f32, + pub pid: PointId, + pub point: &'a P, +} + +impl<'a, P> Item<'a, P> { + fn new(candidate: Candidate, hnsw: &'a Hnsw

) -> Self { + Self { + distance: candidate.distance.into_inner(), + pid: candidate.pid, + point: &hnsw[candidate.pid], + } + } +} + struct Construction<'a, P: Point> { zero: &'a [RwLock], pool: SearchPool, diff --git a/instant-distance/tests/all.rs b/instant-distance/tests/all.rs index 92a9ead..3c25964 100644 --- a/instant-distance/tests/all.rs +++ b/instant-distance/tests/all.rs @@ -70,7 +70,7 @@ fn randomized(builder: Builder) -> (u64, usize) { .collect::>(); let found = results .take(100) - .map(|candidate| candidate.pid) + .map(|item| item.pid) .collect::>(); (seed, forced.intersection(&found).count()) }