diff --git a/Cargo.toml b/Cargo.toml index 441d9bd..200b7a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ edition = "2018" indicatif = { version = "0.15", optional = true } num_cpus = "1.13" ordered-float = "2.0" +parking_lot = "0.11" rand = { version = "0.8", features = ["small_rng"] } rayon = "1.5" serde = { version = "1.0.118", features = ["derive"], optional = true } diff --git a/src/lib.rs b/src/lib.rs index 357ca07..af7ed2d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,16 @@ -use std::cmp::{max, min, Ordering, Reverse}; +use std::cmp::{max, Ordering, Reverse}; use std::collections::BinaryHeap; use std::collections::HashSet; +#[cfg(feature = "indicatif")] +use std::sync::atomic::{self, AtomicUsize}; #[cfg(feature = "indicatif")] use indicatif::ProgressBar; use ordered_float::OrderedFloat; +use parking_lot::{Mutex, RwLock}; use rand::rngs::SmallRng; use rand::SeedableRng; -use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -125,6 +128,7 @@ where let ef_search = builder.ef_search.unwrap_or(100); let ef_construction = builder.ef_construction.unwrap_or(100); let ml = builder.ml.unwrap_or_else(|| (M as f32).ln()); + let heuristic = builder.heuristic; let mut rng = match builder.seed { Some(seed) => SmallRng::seed_from_u64(seed), None => SmallRng::from_entropy(), @@ -136,6 +140,7 @@ where if let Some(bar) = &progress { bar.set_draw_delta(1_000); bar.set_length(points.len() as u64); + bar.set_message("Build index (preparation)"); } if points.is_empty() { @@ -220,88 +225,67 @@ where // Insert the first point so that we have an enter point to start searches with. let mut layers = vec![vec![]; top.0]; - let mut zero = Vec::with_capacity(points.len()); - zero.push(ZeroNode::default()); + let zero = points + .iter() + .map(|_| RwLock::new(ZeroNode::default())) + .collect::>(); - let mut insertion = Search { - ef: ef_construction, - visited: Visited::with_capacity(points.len()), - ..Default::default() - }; - - let mut pool = SearchPool { - pool: Vec::new(), - len: points.len(), - }; - - let mut batch = Vec::new(); - let mut done = Vec::new(); - let max_batch_len = num_cpus::get() * 4; - for (layer, mut range) in ranges { + let pool = SearchPool::new(points.len()); + #[cfg(feature = "indicatif")] + let done = AtomicUsize::new(0); + for (layer, range) in ranges { let num = if layer.0 > 0 { M } else { M * 2 }; #[cfg(feature = "indicatif")] if let Some(bar) = &progress { bar.set_message(&format!("Building index (layer {})", layer.0)); } - while range.start < range.end { - let len = min(range.len(), max_batch_len); - batch.clear(); - batch.extend( - nodes[range.start..(range.start + len)] - .iter() - .map(|&(_, pid)| (pid, pool.pop())), - ); + nodes[range].into_par_iter().for_each(|(_, pid)| { + let (mut search, mut insertion) = pool.pop(); + let point = &points.as_slice()[*pid]; + search.reset(); + search.push(PointId(0), point, &points); - batch.par_iter_mut().for_each(|(pid, search)| { - let point = &points[*pid]; - search.push(PointId(0), point, &points); - for cur in top.descend() { - search.ef = if cur <= layer { ef_construction } else { 1 }; - match cur > layer { - true => { - search.search(point, layers[cur.0 - 1].as_slice(), &points, num); - search.cull(); - } - false => { - search.search(point, zero.as_slice(), &points, num); - break; - } + for cur in top.descend() { + search.ef = if cur <= layer { ef_construction } else { 1 }; + match cur > layer { + true => { + search.search(point, layers[cur.0 - 1].as_slice(), &points, num); + search.cull(); + } + false => { + search.search(point, zero.as_slice(), &points, num); + break; } } - }); - - done.clear(); - for (pid, mut search) in batch.drain(..) { - for added in done.iter().copied() { - search.push(added, &points[pid], &points); - } - - insert( - pid, - &mut insertion, - &mut search, - &mut zero, - &points, - &builder.heuristic, - ); - done.push(pid); - pool.push(search); } + insertion.ef = ef_construction; + insert( + *pid, + &mut insertion, + &mut search, + &zero, + &points, + &heuristic, + ); + #[cfg(feature = "indicatif")] if let Some(bar) = &progress { - bar.inc(done.len() as u64); + let value = done.fetch_add(1, atomic::Ordering::Relaxed); + if value % 1000 == 0 { + bar.set_position(value as u64); + } } - range.start += len; - } + pool.push((search, insertion)); + }); // For layers above the zero layer, make a copy of the current state of the zero layer // with `nearest` truncated to `M` elements. if layer.0 > 0 { let mut upper = Vec::with_capacity(zero.len()); - upper.extend(zero.iter().map(UpperNode::from_zero)); + upper.extend(zero.iter().map(|zero| UpperNode::from_zero(&zero.read()))); layers[layer.0 - 1] = upper; } } @@ -314,7 +298,7 @@ where ( Self { ef_search, - zero, + zero: zero.into_iter().map(|node| node.into_inner()).collect(), points, layers, }, @@ -382,14 +366,14 @@ fn insert( new: PointId, insertion: &mut Search, search: &mut Search, - layer: &mut Vec, + layer: &[RwLock], points: &[P], heuristic: &Option, ) { - layer.push(ZeroNode::default()); + let mut node = layer[new].write(); let found = match heuristic { None => search.select_simple(M * 2), - Some(heuristic) => search.select_heuristic(&points[new], &layer, points, *heuristic), + Some(heuristic) => search.select_heuristic(&points[new], layer, points, *heuristic), }; // Just make sure the candidates are all unique @@ -404,19 +388,22 @@ fn insert( if let Some(heuristic) = heuristic { let found = insertion.add_neighbor_heuristic( new, - layer.as_slice().nearest_iter(pid), + layer.nearest_iter(pid), layer, &points[pid], points, *heuristic, ); - layer[pid].rewrite(found.iter().map(|candidate| candidate.pid)); - layer[new].set(i, pid); + layer[pid] + .write() + .rewrite(found.iter().map(|candidate| candidate.pid)); + node.set(i, pid); } else { // Find the correct index to insert at to keep the neighbor's neighbors sorted let old = &points[pid]; let idx = layer[pid] + .read() .binary_search_by(|third| { // `third` here is one of the neighbors of the new node's neighbor. let third = match third { @@ -429,30 +416,34 @@ fn insert( }) .unwrap_or_else(|e| e); - layer[pid].insert(idx, new); - layer[new].set(i, pid); + layer[pid].write().insert(idx, new); + node.set(i, pid); } } } struct SearchPool { - pool: Vec, + pool: Mutex>, len: usize, } impl SearchPool { - fn pop(&mut self) -> Search { - match self.pool.pop() { - Some(mut search) => { - search.reset(); - search - } - None => Search::new(self.len), + fn new(len: usize) -> Self { + Self { + pool: Mutex::new(Vec::new()), + len, } } - fn push(&mut self, search: Search) { - self.pool.push(search); + fn pop(&self) -> (Search, Search) { + match self.pool.lock().pop() { + Some(res) => res, + None => (Search::new(self.len), Search::new(self.len)), + } + } + + fn push(&self, item: (Search, Search)) { + self.pool.lock().push(item); } } @@ -516,11 +507,11 @@ impl Search { } } - fn add_neighbor_heuristic( + fn add_neighbor_heuristic( &mut self, new: PointId, current: impl Iterator, - layer: &[ZeroNode], + layer: L, point: &P, points: &[P], params: Heuristic, @@ -530,16 +521,16 @@ impl Search { for pid in current { self.push(pid, point, points); } - self.select_heuristic(point, &layer, points, params) + self.select_heuristic(point, layer, points, params) } /// Heuristically sort and truncate neighbors in `self.nearest` /// /// Invariant: `self.nearest` must be in sorted (nearest first) order. - fn select_heuristic( + fn select_heuristic( &mut self, point: &P, - layer: &[ZeroNode], + layer: L, points: &[P], params: Heuristic, ) -> &[Candidate] { diff --git a/src/types.rs b/src/types.rs index 0c3af9f..4431342 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,7 +1,8 @@ use std::hash::Hash; -use std::ops::{Deref, Index, IndexMut}; +use std::ops::{Deref, Index}; use ordered_float::OrderedFloat; +use parking_lot::{MappedRwLockReadGuard, RwLock, RwLockReadGuard}; use rand::rngs::SmallRng; use rand::Rng; #[cfg(feature = "serde")] @@ -69,11 +70,11 @@ impl UpperNode { } } -impl Layer for &[UpperNode] { - fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { - NearestIter { - nearest: &self[pid.0 as usize].0, - } +impl<'a> Layer for &'a [UpperNode] { + type Slice = &'a [PointId]; + + fn nearest_iter(&self, pid: PointId) -> NearestIter { + NearestIter::new(&self[pid.0 as usize].0) } } @@ -128,32 +129,63 @@ impl Deref for ZeroNode { } } -impl Layer for &[ZeroNode] { - fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> { - NearestIter { - nearest: &self[pid.0 as usize].0, - } +impl<'a> Layer for &'a [ZeroNode] { + type Slice = &'a [PointId]; + + fn nearest_iter(&self, pid: PointId) -> NearestIter { + NearestIter::new(&self[pid.0 as usize]) + } +} + +impl<'a> Layer for &'a [RwLock] { + type Slice = MappedRwLockReadGuard<'a, [PointId]>; + + fn nearest_iter(&self, pid: PointId) -> NearestIter { + NearestIter::new(RwLockReadGuard::map( + self[pid.0 as usize].read(), + Deref::deref, + )) } } pub(crate) trait Layer { - fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>; + type Slice: Deref; + fn nearest_iter(&self, pid: PointId) -> NearestIter; } -pub(crate) struct NearestIter<'a> { - nearest: &'a [PointId], +pub(crate) struct NearestIter { + node: T, + cur: usize, } -impl<'a> Iterator for NearestIter<'a> { +impl NearestIter +where + T: Deref, +{ + fn new(node: T) -> Self { + Self { node, cur: 0 } + } +} + +impl Iterator for NearestIter +where + T: Deref, +{ type Item = PointId; fn next(&mut self) -> Option { - let (&first, rest) = self.nearest.split_first()?; - if !first.is_valid() { + if self.cur >= self.node.len() { return None; } - self.nearest = rest; - Some(first) + + let item = self.node[self.cur]; + if !item.is_valid() { + self.cur = self.node.len(); + return None; + } + + self.cur += 1; + Some(item) } } @@ -230,14 +262,6 @@ impl

Index for Hnsw

{ } } -impl Index for Vec

{ - type Output = P; - - fn index(&self, index: PointId) -> &Self::Output { - &self[index.0 as usize] - } -} - impl Index for [P] { type Output = P; @@ -246,18 +270,12 @@ impl Index for [P] { } } -impl Index for Vec { - type Output = ZeroNode; +impl Index for [RwLock] { + type Output = RwLock; fn index(&self, index: PointId) -> &Self::Output { &self[index.0 as usize] } } -impl IndexMut for Vec { - fn index_mut(&mut self, index: PointId) -> &mut Self::Output { - &mut self[index.0 as usize] - } -} - pub(crate) const INVALID: PointId = PointId(u32::MAX);