diff --git a/instant-distance/src/lib.rs b/instant-distance/src/lib.rs index 54d58be..a2e5371 100644 --- a/instant-distance/src/lib.rs +++ b/instant-distance/src/lib.rs @@ -1,6 +1,7 @@ use std::cmp::{max, Ordering, Reverse}; use std::collections::BinaryHeap; use std::collections::HashSet; +use std::ops::Range; #[cfg(feature = "indicatif")] use std::sync::atomic::{self, AtomicUsize}; @@ -280,75 +281,19 @@ where .map(|_| RwLock::new(ZeroNode::default())) .collect::>(); - let pool = SearchPool::new(points.len()); - #[cfg(feature = "indicatif")] - let done = AtomicUsize::new(0); - for (layer, range) in ranges { - let num = if layer.is_zero() { M * 2 } else { M }; + Construction { + zero: zero.as_slice(), + pool: SearchPool::new(points.len()), + top, + points: &points, + heuristic, + ef_construction, #[cfg(feature = "indicatif")] - if let Some(bar) = &progress { - bar.set_message(format!("Building index (layer {})", layer.0)); - } - - let end = range.end; - nodes[range].into_iter().for_each(|(_, pid)| { - let node = zero.as_slice()[*pid].write(); - let (mut search, mut insertion) = pool.pop(); - let point = &points.as_slice()[*pid]; - search.reset(); - search.push(PointId(0), point, &points); - - for cur in top.descend() { - search.ef = if cur <= layer { ef_construction } else { 1 }; - match cur > layer { - true => { - search.search(point, layers[cur.0 - 1].as_slice(), &points, num); - search.cull(); - } - false => { - search.search(point, zero.as_slice(), &points, num); - break; - } - } - } - - insertion.ef = ef_construction; - insert( - *pid, - node, - &mut insertion, - &mut search, - &zero, - &points, - &heuristic, - ); - - #[cfg(feature = "indicatif")] - if let Some(bar) = &progress { - let value = done.fetch_add(1, atomic::Ordering::Relaxed); - if value % 1000 == 0 { - bar.set_position(value as u64); - } - } - - pool.push((search, insertion)); - }); - - // For layers above the zero layer, make a copy of the current state of the zero layer - // with `nearest` truncated to `M` elements. - if !layer.is_zero() { - let upper = (&zero[..end]) - .into_iter() - .map(|zero| UpperNode::from_zero(&zero.read())) - .collect(); - layers[layer.0 - 1] = upper; - } - } - - #[cfg(feature = "indicatif")] - if let Some(bar) = progress { - bar.finish(); + progress, + #[cfg(feature = "indicatif")] + done: AtomicUsize::new(0), } + .build(nodes, ranges.into_iter(), &mut layers); ( Self { @@ -407,77 +352,164 @@ where } } -/// Insert new node in the zero layer -/// -/// * `new`: the `PointId` for the new node -/// * `insertion`: a `Search` for shrinking a neighbor set (only used with heuristic neighbor selection) -/// * `search`: the result for searching potential neighbors for the new node -/// * `layer` contains all the nodes at the current layer -/// * `points` is a slice of all the points in the index -/// -/// Creates the new node, initializing its `nearest` array and updates the nearest neighbors -/// for the new node's neighbors if necessary before appending the new node to the layer. -fn insert( - new: PointId, - mut node: parking_lot::RwLockWriteGuard, - insertion: &mut Search, - search: &mut Search, - layer: &[RwLock], - points: &[P], - heuristic: &Option, -) { - let found = match heuristic { - None => { - let candidates = search.select_simple(); - &candidates[..Ord::min(candidates.len(), M * 2)] +struct Construction<'a, P: Point> { + zero: &'a [RwLock], + pool: SearchPool, + top: LayerId, + points: &'a [P], + heuristic: Option, + ef_construction: usize, + #[cfg(feature = "indicatif")] + progress: Option, + #[cfg(feature = "indicatif")] + done: AtomicUsize, +} + +impl<'a, P: Point> Construction<'a, P> { + fn build( + &self, + nodes: Vec<(LayerId, PointId)>, + ranges: impl Iterator)>, + layers: &mut [Vec], + ) { + for (layer, range) in ranges { + #[cfg(feature = "indicatif")] + if let Some(bar) = &self.progress { + bar.set_message(format!("Building index (layer {})", layer.0)); + } + + let end = range.end; + nodes[range].into_iter().for_each(|(_, pid)| { + let node = self.zero[*pid].write(); + self.insert(*pid, node, layer, &layers); + }); + + // For layers above the zero layer, make a copy of the current state of the zero layer + // with `nearest` truncated to `M` elements. + if !layer.is_zero() { + let upper = (&self.zero[..end]) + .into_iter() + .map(|zero| UpperNode::from_zero(&zero.read())) + .collect(); + layers[layer.0 - 1] = upper; + } } - Some(heuristic) => search.select_heuristic(&points[new], layer, points, *heuristic), - }; - // Just make sure the candidates are all unique - debug_assert_eq!( - found.len(), - found.iter().map(|c| c.pid).collect::>().len() - ); - - for (i, candidate) in found.iter().enumerate() { - // `candidate` here is the new node's neighbor - let &Candidate { distance, pid } = candidate; - if let Some(heuristic) = heuristic { - let found = insertion.add_neighbor_heuristic( - new, - layer.nearest_iter(pid), - layer, - &points[pid], - points, - *heuristic, - ); - - layer[pid] - .write() - .rewrite(found.iter().map(|candidate| candidate.pid)); - node.set(i, pid); - } else { - // Find the correct index to insert at to keep the neighbor's neighbors sorted - let old = &points[pid]; - let idx = layer[pid] - .read() - .binary_search_by(|third| { - // `third` here is one of the neighbors of the new node's neighbor. - let third = match third { - pid if pid.is_valid() => *pid, - // if `third` is `None`, our new `node` is always "closer" - _ => return Ordering::Greater, - }; - - distance.cmp(&old.distance(&points[third]).into()) - }) - .unwrap_or_else(|e| e); - - layer[pid].write().insert(idx, new); - node.set(i, pid); + #[cfg(feature = "indicatif")] + if let Some(bar) = &self.progress { + bar.finish(); } } + + /// Insert new node in the zero layer + /// + /// * `new`: the `PointId` for the new node + /// * `insertion`: a `Search` for shrinking a neighbor set (only used with heuristic neighbor selection) + /// * `search`: the result for searching potential neighbors for the new node + /// * `layer` contains all the nodes at the current layer + /// * `points` is a slice of all the points in the index + /// + /// Creates the new node, initializing its `nearest` array and updates the nearest neighbors + /// for the new node's neighbors if necessary before appending the new node to the layer. + #[inline(always)] + fn insert( + &self, + new: PointId, + mut node: parking_lot::RwLockWriteGuard, + layer: LayerId, + layers: &[Vec], + ) { + let (mut search, mut insertion) = self.pool.pop(); + insertion.ef = self.ef_construction; + + let point = &self.points[new]; + search.reset(); + search.push(PointId(0), point, &self.points); + let num = if layer.is_zero() { M * 2 } else { M }; + + for cur in self.top.descend() { + search.ef = if cur <= layer { + self.ef_construction + } else { + 1 + }; + match cur > layer { + true => { + search.search(point, layers[cur.0 - 1].as_slice(), &self.points, num); + search.cull(); + } + false => { + search.search(point, self.zero, &self.points, num); + break; + } + } + } + + let found = match self.heuristic { + None => { + let candidates = search.select_simple(); + &candidates[..Ord::min(candidates.len(), M * 2)] + } + Some(heuristic) => { + search.select_heuristic(&self.points[new], self.zero, self.points, heuristic) + } + }; + + // Just make sure the candidates are all unique + debug_assert_eq!( + found.len(), + found.iter().map(|c| c.pid).collect::>().len() + ); + + for (i, candidate) in found.iter().enumerate() { + // `candidate` here is the new node's neighbor + let &Candidate { distance, pid } = candidate; + if let Some(heuristic) = self.heuristic { + let found = insertion.add_neighbor_heuristic( + new, + self.zero.nearest_iter(pid), + self.zero, + &self.points[pid], + self.points, + heuristic, + ); + + self.zero[pid] + .write() + .rewrite(found.iter().map(|candidate| candidate.pid)); + node.set(i, pid); + } else { + // Find the correct index to insert at to keep the neighbor's neighbors sorted + let old = &self.points[pid]; + let idx = self.zero[pid] + .read() + .binary_search_by(|third| { + // `third` here is one of the neighbors of the new node's neighbor. + let third = match third { + pid if pid.is_valid() => *pid, + // if `third` is `None`, our new `node` is always "closer" + _ => return Ordering::Greater, + }; + + distance.cmp(&old.distance(&self.points[third]).into()) + }) + .unwrap_or_else(|e| e); + + self.zero[pid].write().insert(idx, new); + node.set(i, pid); + } + } + + #[cfg(feature = "indicatif")] + if let Some(bar) = &self.progress { + let value = self.done.fetch_add(1, atomic::Ordering::Relaxed); + if value % 1000 == 0 { + bar.set_position(value as u64); + } + } + + self.pool.push((search, insertion)); + } } struct SearchPool {