Improve parallelization with fine-grained locks

Dirkjan Ochtman 2021-01-21 13:45:39 +01:00
parent f388fd0a46
commit 58bb0f315a
3 changed files with 132 additions and 122 deletions

Cargo.toml

@@ -9,6 +9,7 @@ edition = "2018"
 indicatif = { version = "0.15", optional = true }
 num_cpus = "1.13"
 ordered-float = "2.0"
+parking_lot = "0.11"
 rand = { version = "0.8", features = ["small_rng"] }
 rayon = "1.5"
 serde = { version = "1.0.118", features = ["derive"], optional = true }
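
A note on the new dependency (not part of the diff): `parking_lot`'s `Mutex` and `RwLock` mirror the std types, but `lock()`, `read()` and `write()` return the guard directly rather than a `Result`, since these locks are not poisoned on panic; that is why the code later in this diff can write `self.pool.lock().pop()` with no `unwrap()`. A minimal sketch:

use parking_lot::{Mutex, RwLock};

fn main() {
    // No `.unwrap()` needed: parking_lot locks are not poisoned on panic.
    let counter = Mutex::new(0u32);
    *counter.lock() += 1;

    let shared = RwLock::new(vec![1, 2, 3]);
    {
        let read = shared.read(); // multiple readers may hold this concurrently
        assert_eq!(read.len(), 3);
    } // read guard dropped here
    shared.write().push(4); // exclusive access for writers

    assert_eq!(counter.into_inner(), 1);
    assert_eq!(shared.into_inner(), vec![1, 2, 3, 4]);
}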

src/lib.rs

@@ -1,13 +1,16 @@
-use std::cmp::{max, min, Ordering, Reverse};
+use std::cmp::{max, Ordering, Reverse};
 use std::collections::BinaryHeap;
 use std::collections::HashSet;
+#[cfg(feature = "indicatif")]
+use std::sync::atomic::{self, AtomicUsize};

 #[cfg(feature = "indicatif")]
 use indicatif::ProgressBar;
 use ordered_float::OrderedFloat;
+use parking_lot::{Mutex, RwLock};
 use rand::rngs::SmallRng;
 use rand::SeedableRng;
-use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
+use rayon::iter::{IntoParallelIterator, ParallelIterator};
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};
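
For orientation (my sketch, not from this crate): with rayon, `&[T]` implements `IntoParallelIterator` and yields `&T`, so a shared slice can be processed in parallel by read-only closures; the old code needed `IntoParallelRefMutIterator` (`par_iter_mut`) because each task mutated its own element, while the new code only reads the slice and routes all mutation through locks.

use rayon::iter::{IntoParallelIterator, ParallelIterator};

fn main() {
    let ids: Vec<u32> = (0..16).collect();

    // `&[T]` implements IntoParallelIterator and yields `&T`, so the
    // closure gets shared (read-only) access to each element.
    let doubled_sum: u32 = ids.as_slice().into_par_iter().map(|id| id * 2).sum();

    let expected: u32 = ids.iter().map(|id| id * 2).sum();
    assert_eq!(doubled_sum, expected);
}
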
@@ -125,6 +128,7 @@ where
         let ef_search = builder.ef_search.unwrap_or(100);
         let ef_construction = builder.ef_construction.unwrap_or(100);
         let ml = builder.ml.unwrap_or_else(|| (M as f32).ln());
+        let heuristic = builder.heuristic;
         let mut rng = match builder.seed {
             Some(seed) => SmallRng::seed_from_u64(seed),
             None => SmallRng::from_entropy(),
@@ -136,6 +140,7 @@
         if let Some(bar) = &progress {
             bar.set_draw_delta(1_000);
             bar.set_length(points.len() as u64);
+            bar.set_message("Build index (preparation)");
         }

         if points.is_empty() {
@@ -220,88 +225,67 @@ where
         // Insert the first point so that we have an enter point to start searches with.
         let mut layers = vec![vec![]; top.0];
-        let mut zero = Vec::with_capacity(points.len());
-        zero.push(ZeroNode::default());
-
-        let mut insertion = Search {
-            ef: ef_construction,
-            visited: Visited::with_capacity(points.len()),
-            ..Default::default()
-        };
-
-        let mut pool = SearchPool {
-            pool: Vec::new(),
-            len: points.len(),
-        };
-        let mut batch = Vec::new();
-        let mut done = Vec::new();
-        let max_batch_len = num_cpus::get() * 4;
-        for (layer, mut range) in ranges {
+        let zero = points
+            .iter()
+            .map(|_| RwLock::new(ZeroNode::default()))
+            .collect::<Vec<_>>();
+
+        let pool = SearchPool::new(points.len());
+        #[cfg(feature = "indicatif")]
+        let done = AtomicUsize::new(0);
+        for (layer, range) in ranges {
             let num = if layer.0 > 0 { M } else { M * 2 };
             #[cfg(feature = "indicatif")]
             if let Some(bar) = &progress {
                 bar.set_message(&format!("Building index (layer {})", layer.0));
             }

-            while range.start < range.end {
-                let len = min(range.len(), max_batch_len);
-                batch.clear();
-                batch.extend(
-                    nodes[range.start..(range.start + len)]
-                        .iter()
-                        .map(|&(_, pid)| (pid, pool.pop())),
-                );
-
-                batch.par_iter_mut().for_each(|(pid, search)| {
-                    let point = &points[*pid];
-                    search.push(PointId(0), point, &points);
-                    for cur in top.descend() {
-                        search.ef = if cur <= layer { ef_construction } else { 1 };
-                        match cur > layer {
-                            true => {
-                                search.search(point, layers[cur.0 - 1].as_slice(), &points, num);
-                                search.cull();
-                            }
-                            false => {
-                                search.search(point, zero.as_slice(), &points, num);
-                                break;
-                            }
-                        }
-                    }
-                });
-
-                done.clear();
-                for (pid, mut search) in batch.drain(..) {
-                    for added in done.iter().copied() {
-                        search.push(added, &points[pid], &points);
-                    }
-                    insert(
-                        pid,
-                        &mut insertion,
-                        &mut search,
-                        &mut zero,
-                        &points,
-                        &builder.heuristic,
-                    );
-                    done.push(pid);
-                    pool.push(search);
-                }
-
-                #[cfg(feature = "indicatif")]
-                if let Some(bar) = &progress {
-                    bar.inc(done.len() as u64);
-                }
-                range.start += len;
-            }
+            nodes[range].into_par_iter().for_each(|(_, pid)| {
+                let (mut search, mut insertion) = pool.pop();
+                let point = &points.as_slice()[*pid];
+                search.reset();
+                search.push(PointId(0), point, &points);
+
+                for cur in top.descend() {
+                    search.ef = if cur <= layer { ef_construction } else { 1 };
+                    match cur > layer {
+                        true => {
+                            search.search(point, layers[cur.0 - 1].as_slice(), &points, num);
+                            search.cull();
+                        }
+                        false => {
+                            search.search(point, zero.as_slice(), &points, num);
+                            break;
+                        }
+                    }
+                }
+
+                insertion.ef = ef_construction;
+                insert(
+                    *pid,
+                    &mut insertion,
+                    &mut search,
+                    &zero,
+                    &points,
+                    &heuristic,
+                );
+
+                #[cfg(feature = "indicatif")]
+                if let Some(bar) = &progress {
+                    let value = done.fetch_add(1, atomic::Ordering::Relaxed);
+                    if value % 1000 == 0 {
+                        bar.set_position(value as u64);
+                    }
+                }
+
+                pool.push((search, insertion));
+            });

             // For layers above the zero layer, make a copy of the current state of the zero layer
             // with `nearest` truncated to `M` elements.
             if layer.0 > 0 {
                 let mut upper = Vec::with_capacity(zero.len());
-                upper.extend(zero.iter().map(UpperNode::from_zero));
+                upper.extend(zero.iter().map(|zero| UpperNode::from_zero(&zero.read())));
                 layers[layer.0 - 1] = upper;
             }
         }
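
The heart of the change: instead of filling the zero layer serially in batches, every point gets its own `RwLock<ZeroNode>` and rayon drives one task per point; contention only occurs when two tasks touch the same node, and progress is tracked with a relaxed atomic counter instead of a per-batch `bar.inc`. A minimal standalone sketch of that pattern (hypothetical types, not code from this crate):

use std::sync::atomic::{AtomicUsize, Ordering};

use parking_lot::RwLock;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

fn main() {
    // One lock per slot: writers to different slots never block each other.
    let nodes: Vec<RwLock<Vec<usize>>> = (0..100).map(|_| RwLock::new(Vec::new())).collect();
    let done = AtomicUsize::new(0);

    (0..nodes.len()).into_par_iter().for_each(|i| {
        // Read a neighbor's entry under a shared lock (guard dropped at end of statement)...
        let neighbor_len = nodes[(i + 1) % nodes.len()].read().len();
        // ...then update this slot under an exclusive lock held only briefly.
        nodes[i].write().push(neighbor_len);

        // Relaxed ordering is fine for a progress counter; no other data
        // depends on this value.
        done.fetch_add(1, Ordering::Relaxed);
    });

    assert_eq!(done.load(Ordering::Relaxed), nodes.len());
}
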
@@ -314,7 +298,7 @@ where
         (
             Self {
                 ef_search,
-                zero,
+                zero: zero.into_iter().map(|node| node.into_inner()).collect(),
                 points,
                 layers,
             },
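
Once the parallel phase is over, the locks are pure overhead; `parking_lot::RwLock::into_inner` consumes each lock and hands back the plain value for the finished, read-only index. In isolation (my example, not from the crate):

use parking_lot::RwLock;

fn main() {
    let locked: Vec<RwLock<u32>> = (0..4).map(RwLock::new).collect();

    // Consume each lock to get plain values back for read-only use.
    let plain: Vec<u32> = locked.into_iter().map(|cell| cell.into_inner()).collect();
    assert_eq!(plain, vec![0, 1, 2, 3]);
}
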
@@ -382,14 +366,14 @@ fn insert<P: Point>(
     new: PointId,
     insertion: &mut Search,
     search: &mut Search,
-    layer: &mut Vec<ZeroNode>,
+    layer: &[RwLock<ZeroNode>],
     points: &[P],
     heuristic: &Option<Heuristic>,
 ) {
-    layer.push(ZeroNode::default());
+    let mut node = layer[new].write();
     let found = match heuristic {
         None => search.select_simple(M * 2),
-        Some(heuristic) => search.select_heuristic(&points[new], &layer, points, *heuristic),
+        Some(heuristic) => search.select_heuristic(&points[new], layer, points, *heuristic),
     };

     // Just make sure the candidates are all unique
@@ -404,19 +388,22 @@ fn insert<P: Point>(
         if let Some(heuristic) = heuristic {
             let found = insertion.add_neighbor_heuristic(
                 new,
-                layer.as_slice().nearest_iter(pid),
+                layer.nearest_iter(pid),
                 layer,
                 &points[pid],
                 points,
                 *heuristic,
             );
-            layer[pid].rewrite(found.iter().map(|candidate| candidate.pid));
-            layer[new].set(i, pid);
+            layer[pid]
+                .write()
+                .rewrite(found.iter().map(|candidate| candidate.pid));
+            node.set(i, pid);
         } else {
             // Find the correct index to insert at to keep the neighbor's neighbors sorted
             let old = &points[pid];
             let idx = layer[pid]
+                .read()
                 .binary_search_by(|third| {
                     // `third` here is one of the neighbors of the new node's neighbor.
                     let third = match third {
@@ -429,30 +416,34 @@ fn insert<P: Point>(
                 })
                 .unwrap_or_else(|e| e);
-            layer[pid].insert(idx, new);
-            layer[new].set(i, pid);
+            layer[pid].write().insert(idx, new);
+            node.set(i, pid);
         }
     }
 }

 struct SearchPool {
-    pool: Vec<Search>,
+    pool: Mutex<Vec<(Search, Search)>>,
     len: usize,
 }

 impl SearchPool {
-    fn pop(&mut self) -> Search {
-        match self.pool.pop() {
-            Some(mut search) => {
-                search.reset();
-                search
-            }
-            None => Search::new(self.len),
-        }
-    }
-
-    fn push(&mut self, search: Search) {
-        self.pool.push(search);
+    fn new(len: usize) -> Self {
+        Self {
+            pool: Mutex::new(Vec::new()),
+            len,
+        }
+    }
+
+    fn pop(&self) -> (Search, Search) {
+        match self.pool.lock().pop() {
+            Some(res) => res,
+            None => (Search::new(self.len), Search::new(self.len)),
+        }
+    }
+
+    fn push(&self, item: (Search, Search)) {
+        self.pool.lock().push(item);
     }
 }
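
`SearchPool` is a free list of scratch state: a task pops a `(Search, Search)` pair (or creates one), uses it, and pushes it back, so buffer allocations are amortized across the whole build instead of paid per point. The same idea reduced to a sketch (hypothetical `Scratch`/`Pool` types, not the crate's API):

use parking_lot::Mutex;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

/// Reusable per-task workspace; imagine this holds large buffers.
#[derive(Default)]
struct Scratch {
    buf: Vec<u32>,
}

struct Pool {
    free: Mutex<Vec<Scratch>>,
}

impl Pool {
    fn new() -> Self {
        Self { free: Mutex::new(Vec::new()) }
    }

    /// Reuse a previously returned workspace, or allocate a fresh one.
    fn pop(&self) -> Scratch {
        self.free.lock().pop().unwrap_or_default()
    }

    /// Return the workspace so another task can reuse its allocation.
    fn push(&self, scratch: Scratch) {
        self.free.lock().push(scratch);
    }
}

fn main() {
    let pool = Pool::new();
    (0..1_000u32).into_par_iter().for_each(|i| {
        let mut scratch = pool.pop();
        scratch.buf.clear();
        scratch.buf.extend(0..i % 16); // stand-in for real work
        pool.push(scratch);
    });
}
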
@@ -516,11 +507,11 @@ impl Search {
         }
     }

-    fn add_neighbor_heuristic<P: Point>(
+    fn add_neighbor_heuristic<L: Layer, P: Point>(
         &mut self,
         new: PointId,
         current: impl Iterator<Item = PointId>,
-        layer: &[ZeroNode],
+        layer: L,
         point: &P,
         points: &[P],
         params: Heuristic,
@@ -530,16 +521,16 @@ impl Search {
         for pid in current {
             self.push(pid, point, points);
         }
-        self.select_heuristic(point, &layer, points, params)
+        self.select_heuristic(point, layer, points, params)
     }

     /// Heuristically sort and truncate neighbors in `self.nearest`
     ///
     /// Invariant: `self.nearest` must be in sorted (nearest first) order.
-    fn select_heuristic<P: Point>(
+    fn select_heuristic<L: Layer, P: Point>(
         &mut self,
         point: &P,
-        layer: &[ZeroNode],
+        layer: L,
         points: &[P],
         params: Heuristic,
     ) -> &[Candidate] {

src/types.rs

@@ -1,7 +1,8 @@
 use std::hash::Hash;
-use std::ops::{Deref, Index, IndexMut};
+use std::ops::{Deref, Index};

 use ordered_float::OrderedFloat;
+use parking_lot::{MappedRwLockReadGuard, RwLock, RwLockReadGuard};
 use rand::rngs::SmallRng;
 use rand::Rng;
 #[cfg(feature = "serde")]
@@ -69,11 +70,11 @@ impl UpperNode {
     }
 }

-impl Layer for &[UpperNode] {
-    fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
-        NearestIter {
-            nearest: &self[pid.0 as usize].0,
-        }
+impl<'a> Layer for &'a [UpperNode] {
+    type Slice = &'a [PointId];
+
+    fn nearest_iter(&self, pid: PointId) -> NearestIter<Self::Slice> {
+        NearestIter::new(&self[pid.0 as usize].0)
     }
 }
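
The new associated `Slice` type is what lets a single `Layer` trait serve both an unlocked layer, which can hand out a plain `&[PointId]`, and a locked layer, which must hand out a read guard projected onto the neighbor list with `RwLockReadGuard::map`. A reduced sketch of that shape (illustrative names, not the crate's exact trait):

use std::ops::Deref;

use parking_lot::{MappedRwLockReadGuard, RwLock, RwLockReadGuard};

type Id = u32;

struct Node(Vec<Id>);

/// A layer only needs to expose "the neighbors of a node" as something that
/// can be viewed as `[Id]`; how that view is produced is up to the impl.
trait Neighborhood {
    type View: Deref<Target = [Id]>;

    fn neighbors(&self, node: usize) -> Self::View;
}

// Unlocked storage: a plain borrow is enough.
impl<'a> Neighborhood for &'a [Node] {
    type View = &'a [Id];

    fn neighbors(&self, node: usize) -> Self::View {
        self[node].0.as_slice()
    }
}

// Locked storage: return a read guard projected onto the neighbor slice;
// the lock stays held for as long as the view is alive.
impl<'a> Neighborhood for &'a [RwLock<Node>] {
    type View = MappedRwLockReadGuard<'a, [Id]>;

    fn neighbors(&self, node: usize) -> Self::View {
        RwLockReadGuard::map(self[node].read(), |n| n.0.as_slice())
    }
}

fn main() {
    let plain = [Node(vec![1, 2]), Node(vec![3])];
    assert_eq!((&plain[..]).neighbors(0).len(), 2);

    let locked = [RwLock::new(Node(vec![4, 5]))];
    assert_eq!((&locked[..]).neighbors(0)[1], 5);
}
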
@@ -128,32 +129,63 @@ impl Deref for ZeroNode {
     }
 }

-impl Layer for &[ZeroNode] {
-    fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
-        NearestIter {
-            nearest: &self[pid.0 as usize].0,
-        }
+impl<'a> Layer for &'a [ZeroNode] {
+    type Slice = &'a [PointId];
+
+    fn nearest_iter(&self, pid: PointId) -> NearestIter<Self::Slice> {
+        NearestIter::new(&self[pid.0 as usize])
+    }
+}
+
+impl<'a> Layer for &'a [RwLock<ZeroNode>] {
+    type Slice = MappedRwLockReadGuard<'a, [PointId]>;
+
+    fn nearest_iter(&self, pid: PointId) -> NearestIter<Self::Slice> {
+        NearestIter::new(RwLockReadGuard::map(
+            self[pid.0 as usize].read(),
+            Deref::deref,
+        ))
     }
 }

 pub(crate) trait Layer {
-    fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>;
+    type Slice: Deref<Target = [PointId]>;
+
+    fn nearest_iter(&self, pid: PointId) -> NearestIter<Self::Slice>;
 }

-pub(crate) struct NearestIter<'a> {
-    nearest: &'a [PointId],
+pub(crate) struct NearestIter<T> {
+    node: T,
+    cur: usize,
 }

-impl<'a> Iterator for NearestIter<'a> {
+impl<T> NearestIter<T>
+where
+    T: Deref<Target = [PointId]>,
+{
+    fn new(node: T) -> Self {
+        Self { node, cur: 0 }
+    }
+}
+
+impl<T> Iterator for NearestIter<T>
+where
+    T: Deref<Target = [PointId]>,
+{
     type Item = PointId;

     fn next(&mut self) -> Option<Self::Item> {
-        let (&first, rest) = self.nearest.split_first()?;
-        if !first.is_valid() {
+        if self.cur >= self.node.len() {
             return None;
         }
-        self.nearest = rest;
-        Some(first)
+
+        let item = self.node[self.cur];
+        if !item.is_valid() {
+            self.cur = self.node.len();
+            return None;
+        }
+
+        self.cur += 1;
+        Some(item)
     }
 }
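
`NearestIter` can no longer store a bare `&'a [PointId]`: for the locked layer the slice only exists behind a guard, so the iterator now owns whatever derefs to the slice and walks it with a cursor. The same shape in isolation (illustrative example):

use std::ops::Deref;

/// Iterate over the items of anything that can be viewed as `[u32]`,
/// keeping the owner (a borrow, a guard, ...) alive inside the iterator.
struct SliceIter<T> {
    owner: T,
    cur: usize,
}

impl<T> Iterator for SliceIter<T>
where
    T: Deref<Target = [u32]>,
{
    type Item = u32;

    fn next(&mut self) -> Option<Self::Item> {
        // Index with a cursor instead of re-slicing: we cannot store both the
        // owner and a borrow of it in the same struct.
        let item = *self.owner.get(self.cur)?;
        self.cur += 1;
        Some(item)
    }
}

fn main() {
    // Works for a plain borrowed slice...
    let data = [10u32, 20, 30];
    let borrowed = SliceIter { owner: &data[..], cur: 0 };
    assert_eq!(borrowed.sum::<u32>(), 60);

    // ...and equally for an owning wrapper (Vec, or a lock guard that
    // derefs to `[u32]`).
    let owned = SliceIter { owner: vec![1u32, 2, 3], cur: 0 };
    assert_eq!(owned.count(), 3);
}
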
@@ -230,14 +262,6 @@ impl<P> Index<PointId> for Hnsw<P> {
     }
 }

-impl<P: Point> Index<PointId> for Vec<P> {
-    type Output = P;
-
-    fn index(&self, index: PointId) -> &Self::Output {
-        &self[index.0 as usize]
-    }
-}
-
 impl<P: Point> Index<PointId> for [P] {
     type Output = P;
@@ -246,18 +270,12 @@ impl<P: Point> Index<PointId> for [P] {
     }
 }

-impl Index<PointId> for Vec<ZeroNode> {
-    type Output = ZeroNode;
+impl Index<PointId> for [RwLock<ZeroNode>] {
+    type Output = RwLock<ZeroNode>;

     fn index(&self, index: PointId) -> &Self::Output {
         &self[index.0 as usize]
     }
 }

-impl IndexMut<PointId> for Vec<ZeroNode> {
-    fn index_mut(&mut self, index: PointId) -> &mut Self::Output {
-        &mut self[index.0 as usize]
-    }
-}
-
 pub(crate) const INVALID: PointId = PointId(u32::MAX);
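
The indexing impls follow the storage change: `Index<PointId>` now targets the unsized slice types actually used during the build, and `IndexMut` disappears because mutation goes through the lock rather than through `&mut` indexing. The newtype-index pattern on a slice of locks, in isolation (hypothetical `NodeId`/`Node` types):

use std::ops::Index;

use parking_lot::RwLock;

/// A typed index: using `NodeId` instead of a bare usize makes it harder to
/// index the wrong collection by accident.
#[derive(Clone, Copy)]
struct NodeId(u32);

struct Node(&'static str);

impl Index<NodeId> for [RwLock<Node>] {
    type Output = RwLock<Node>;

    fn index(&self, index: NodeId) -> &Self::Output {
        &self[index.0 as usize]
    }
}

fn main() {
    let nodes = vec![RwLock::new(Node("a")), RwLock::new(Node("b"))];
    let id = NodeId(1);
    // Indexing goes through the slice impl; mutation would go through the lock.
    assert_eq!(nodes.as_slice()[id].read().0, "b");
}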