From 20ca8b0f3a3cb143b4321c66ad5ef5f8476a602b Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Tue, 18 May 2021 10:21:52 +0200 Subject: [PATCH] Explicitly group state for HNSW construction --- instant-distance/src/lib.rs | 248 +++++++++++++++++++----------------- 1 file changed, 134 insertions(+), 114 deletions(-) diff --git a/instant-distance/src/lib.rs b/instant-distance/src/lib.rs index c1f39f5..9f5a252 100644 --- a/instant-distance/src/lib.rs +++ b/instant-distance/src/lib.rs @@ -280,74 +280,43 @@ where .map(|_| RwLock::new(ZeroNode::default())) .collect::>(); - let pool = SearchPool::new(points.len()); - #[cfg(feature = "indicatif")] - let done = AtomicUsize::new(0); - for (layer, range) in ranges { - let num = if layer.is_zero() { M * 2 } else { M }; + let state = Construction { + zero: zero.as_slice(), + pool: SearchPool::new(points.len()), + top, + points: &points, + heuristic, + ef_construction, #[cfg(feature = "indicatif")] - if let Some(bar) = &progress { + progress, + #[cfg(feature = "indicatif")] + done: AtomicUsize::new(0), + }; + + for (layer, range) in ranges { + #[cfg(feature = "indicatif")] + if let Some(bar) = &state.progress { bar.set_message(format!("Building index (layer {})", layer.0)); } let end = range.end; nodes[range].into_par_iter().for_each(|(_, pid)| { - let node = zero.as_slice()[*pid].write(); - let (mut search, mut insertion) = pool.pop(); - let point = &points.as_slice()[*pid]; - search.reset(); - search.push(PointId(0), point, &points); - - for cur in top.descend() { - search.ef = if cur <= layer { ef_construction } else { 1 }; - match cur > layer { - true => { - search.search(point, layers[cur.0 - 1].as_slice(), &points, num); - search.cull(); - } - false => { - search.search(point, zero.as_slice(), &points, num); - break; - } - } - } - - insertion.ef = ef_construction; - insert( - *pid, - node, - &mut insertion, - &mut search, - &zero, - &points, - &heuristic, - ); - - #[cfg(feature = "indicatif")] - if let Some(bar) = &progress { - let value = done.fetch_add(1, atomic::Ordering::Relaxed); - if value % 1000 == 0 { - bar.set_position(value as u64); - } - } - - pool.push((search, insertion)); + let node = state.zero[*pid].write(); + state.insert(*pid, node, layer, &layers); }); // For layers above the zero layer, make a copy of the current state of the zero layer // with `nearest` truncated to `M` elements. if !layer.is_zero() { - let mut upper = Vec::new(); - (&zero[..end]) + (&state.zero[..end]) .into_par_iter() .map(|zero| UpperNode::from_zero(&zero.read())) - .collect_into_vec(&mut upper); - layers[layer.0 - 1] = upper; + .collect_into_vec(&mut layers[layer.0 - 1]); } } #[cfg(feature = "indicatif")] - if let Some(bar) = progress { + if let Some(bar) = &state.progress { bar.finish(); } @@ -408,76 +377,127 @@ where } } -/// Insert new node in the zero layer -/// -/// * `new`: the `PointId` for the new node -/// * `insertion`: a `Search` for shrinking a neighbor set (only used with heuristic neighbor selection) -/// * `search`: the result for searching potential neighbors for the new node -/// * `layer` contains all the nodes at the current layer -/// * `points` is a slice of all the points in the index -/// -/// Creates the new node, initializing its `nearest` array and updates the nearest neighbors -/// for the new node's neighbors if necessary before appending the new node to the layer. -fn insert( - new: PointId, - mut node: parking_lot::RwLockWriteGuard, - insertion: &mut Search, - search: &mut Search, - layer: &[RwLock], - points: &[P], - heuristic: &Option, -) { - let found = match heuristic { - None => { - let candidates = search.select_simple(); - &candidates[..Ord::min(candidates.len(), M * 2)] +struct Construction<'a, P: Point> { + zero: &'a [RwLock], + pool: SearchPool, + top: LayerId, + points: &'a [P], + heuristic: Option, + ef_construction: usize, + #[cfg(feature = "indicatif")] + progress: Option, + #[cfg(feature = "indicatif")] + done: AtomicUsize, +} + +impl<'a, P: Point> Construction<'a, P> { + /// Insert new node in the zero layer + /// + /// * `new`: the `PointId` for the new node + /// * `insertion`: a `Search` for shrinking a neighbor set (only used with heuristic neighbor selection) + /// * `search`: the result for searching potential neighbors for the new node + /// * `layer` contains all the nodes at the current layer + /// * `points` is a slice of all the points in the index + /// + /// Creates the new node, initializing its `nearest` array and updates the nearest neighbors + /// for the new node's neighbors if necessary before appending the new node to the layer. + fn insert( + &self, + new: PointId, + mut node: parking_lot::RwLockWriteGuard, + layer: LayerId, + layers: &[Vec], + ) { + let (mut search, mut insertion) = self.pool.pop(); + insertion.ef = self.ef_construction; + + let point = &self.points[new]; + search.reset(); + search.push(PointId(0), point, &self.points); + let num = if layer.is_zero() { M * 2 } else { M }; + + for cur in self.top.descend() { + search.ef = if cur <= layer { + self.ef_construction + } else { + 1 + }; + match cur > layer { + true => { + search.search(point, layers[cur.0 - 1].as_slice(), &self.points, num); + search.cull(); + } + false => { + search.search(point, self.zero, &self.points, num); + break; + } + } } - Some(heuristic) => search.select_heuristic(&points[new], layer, points, *heuristic), - }; - // Just make sure the candidates are all unique - debug_assert_eq!( - found.len(), - found.iter().map(|c| c.pid).collect::>().len() - ); + let found = match self.heuristic { + None => { + let candidates = search.select_simple(); + &candidates[..Ord::min(candidates.len(), M * 2)] + } + Some(heuristic) => { + search.select_heuristic(&self.points[new], self.zero, self.points, heuristic) + } + }; - for (i, candidate) in found.iter().enumerate() { - // `candidate` here is the new node's neighbor - let &Candidate { distance, pid } = candidate; - if let Some(heuristic) = heuristic { - let found = insertion.add_neighbor_heuristic( - new, - layer.nearest_iter(pid), - layer, - &points[pid], - points, - *heuristic, - ); + // Just make sure the candidates are all unique + debug_assert_eq!( + found.len(), + found.iter().map(|c| c.pid).collect::>().len() + ); - layer[pid] - .write() - .rewrite(found.iter().map(|candidate| candidate.pid)); - node.set(i, pid); - } else { - // Find the correct index to insert at to keep the neighbor's neighbors sorted - let old = &points[pid]; - let idx = layer[pid] - .read() - .binary_search_by(|third| { - // `third` here is one of the neighbors of the new node's neighbor. - let third = match third { - pid if pid.is_valid() => *pid, - // if `third` is `None`, our new `node` is always "closer" - _ => return Ordering::Greater, - }; + for (i, candidate) in found.iter().enumerate() { + // `candidate` here is the new node's neighbor + let &Candidate { distance, pid } = candidate; + if let Some(heuristic) = self.heuristic { + let found = insertion.add_neighbor_heuristic( + new, + self.zero.nearest_iter(pid), + self.zero, + &self.points[pid], + self.points, + heuristic, + ); - distance.cmp(&old.distance(&points[third]).into()) - }) - .unwrap_or_else(|e| e); + self.zero[pid] + .write() + .rewrite(found.iter().map(|candidate| candidate.pid)); + node.set(i, pid); + } else { + // Find the correct index to insert at to keep the neighbor's neighbors sorted + let old = &self.points[pid]; + let idx = self.zero[pid] + .read() + .binary_search_by(|third| { + // `third` here is one of the neighbors of the new node's neighbor. + let third = match third { + pid if pid.is_valid() => *pid, + // if `third` is `None`, our new `node` is always "closer" + _ => return Ordering::Greater, + }; - layer[pid].write().insert(idx, new); - node.set(i, pid); + distance.cmp(&old.distance(&self.points[third]).into()) + }) + .unwrap_or_else(|e| e); + + self.zero[pid].write().insert(idx, new); + node.set(i, pid); + } } + + #[cfg(feature = "indicatif")] + if let Some(bar) = &self.progress { + let value = self.done.fetch_add(1, atomic::Ordering::Relaxed); + if value % 1000 == 0 { + bar.set_position(value as u64); + } + } + + self.pool.push((search, insertion)); } }