diff --git a/Cargo.toml b/Cargo.toml index a2a9a78..8f7f2fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ edition = "2018" [dependencies] ahash = "0.6.1" indicatif = { version = "0.15", optional = true } +num_cpus = "1.13" ordered-float = "2.0" rand = { version = "0.7.3", features = ["small_rng"] } rayon = "1.5" diff --git a/src/lib.rs b/src/lib.rs index a1438e5..0db2ca4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ use indicatif::ProgressBar; use ordered_float::OrderedFloat; use rand::rngs::SmallRng; use rand::{RngCore, SeedableRng}; +use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -32,6 +33,11 @@ where let ef_construction = builder.ef_construction.unwrap_or(100); #[cfg(feature = "indicatif")] let progress = builder.progress; + #[cfg(feature = "indicatif")] + if let Some(bar) = &progress { + bar.set_draw_delta(1_000); + bar.set_length(points.len() as u64); + } if points.is_empty() { return ( @@ -99,31 +105,51 @@ where let mut zero = Vec::with_capacity(points.len()); zero.push(ZeroNode::default()); - let mut search = Search::default(); - for (layer, range) in ranges { + let mut pool = SearchPool::default(); + let mut batch = Vec::new(); + let mut done = Vec::new(); + let max_batch_len = num_cpus::get(); + for (layer, mut range) in ranges { let num = if layer.0 > 0 { M } else { M * 2 }; - for &(_, pid) in &nodes[range] { + + while range.start < range.end { + let len = min(range.len(), max_batch_len); + batch.clear(); + batch.extend( + nodes[range.start..(range.start + len)] + .iter() + .map(|&(_, pid)| (pid, pool.pop())), + ); + + batch.par_iter_mut().for_each(|(pid, search)| { + let point = &points[*pid]; + search.push(PointId(0), point, &points); + for cur in top.descend() { + search.num = if cur <= layer { ef_construction } else { 1 }; + zero.search(point, search, &points, num); + match cur > layer { + true => search.cull(), + false => break, + } + } + }); + + done.clear(); + for (pid, mut search) in batch.drain(..) { + for added in done.iter().copied() { + search.push(added, &points[pid], &points); + } + zero.insert_node(pid, &search.nearest, &points); + done.push(pid); + pool.push(search); + } + #[cfg(feature = "indicatif")] - if pid.0 % 10_000 == 0 { - if let Some(bar) = &progress { - bar.set_position(pid.0 as u64); - } + if let Some(bar) = &progress { + bar.inc(done.len() as u64); } - search.reset(); - let point = &points[pid]; - search.push(PointId(0), &points[pid], &points); - - for cur in top.descend() { - search.num = if cur <= layer { ef_construction } else { 1 }; - zero.search(point, &mut search, &points, num); - match cur > layer { - true => search.cull(), - false => break, - } - } - - zero.insert_node(pid, &search.nearest, &points); + range.start += len; } // For layers above the zero layer, make a copy of the current state of the zero layer @@ -200,6 +226,27 @@ where } } +#[derive(Default)] +struct SearchPool { + pool: Vec, +} + +impl SearchPool { + fn pop(&mut self) -> Search { + match self.pool.pop() { + Some(mut search) => { + search.reset(); + search + } + None => Search::default(), + } + } + + fn push(&mut self, search: Search) { + self.pool.push(search); + } +} + /// Keeps mutable state for searching a point's nearest neighbors /// /// In particular, this contains most of the state used in algorithm 2. The structure is @@ -555,7 +602,7 @@ impl Iterator for DescendingLayerIter { } } -pub trait Point: Clone { +pub trait Point: Clone + Sync { fn distance(&self, other: &Self) -> f32; }