Improve parallelization with fine-grained locks
This commit is contained in:
parent
f388fd0a46
commit
58bb0f315a
|
@ -9,6 +9,7 @@ edition = "2018"
|
||||||
indicatif = { version = "0.15", optional = true }
|
indicatif = { version = "0.15", optional = true }
|
||||||
num_cpus = "1.13"
|
num_cpus = "1.13"
|
||||||
ordered-float = "2.0"
|
ordered-float = "2.0"
|
||||||
|
parking_lot = "0.11"
|
||||||
rand = { version = "0.8", features = ["small_rng"] }
|
rand = { version = "0.8", features = ["small_rng"] }
|
||||||
rayon = "1.5"
|
rayon = "1.5"
|
||||||
serde = { version = "1.0.118", features = ["derive"], optional = true }
|
serde = { version = "1.0.118", features = ["derive"], optional = true }
|
||||||
|
|
133
src/lib.rs
133
src/lib.rs
|
@ -1,13 +1,16 @@
|
||||||
use std::cmp::{max, min, Ordering, Reverse};
|
use std::cmp::{max, Ordering, Reverse};
|
||||||
use std::collections::BinaryHeap;
|
use std::collections::BinaryHeap;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
#[cfg(feature = "indicatif")]
|
||||||
|
use std::sync::atomic::{self, AtomicUsize};
|
||||||
|
|
||||||
#[cfg(feature = "indicatif")]
|
#[cfg(feature = "indicatif")]
|
||||||
use indicatif::ProgressBar;
|
use indicatif::ProgressBar;
|
||||||
use ordered_float::OrderedFloat;
|
use ordered_float::OrderedFloat;
|
||||||
|
use parking_lot::{Mutex, RwLock};
|
||||||
use rand::rngs::SmallRng;
|
use rand::rngs::SmallRng;
|
||||||
use rand::SeedableRng;
|
use rand::SeedableRng;
|
||||||
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
|
use rayon::iter::{IntoParallelIterator, ParallelIterator};
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
@ -125,6 +128,7 @@ where
|
||||||
let ef_search = builder.ef_search.unwrap_or(100);
|
let ef_search = builder.ef_search.unwrap_or(100);
|
||||||
let ef_construction = builder.ef_construction.unwrap_or(100);
|
let ef_construction = builder.ef_construction.unwrap_or(100);
|
||||||
let ml = builder.ml.unwrap_or_else(|| (M as f32).ln());
|
let ml = builder.ml.unwrap_or_else(|| (M as f32).ln());
|
||||||
|
let heuristic = builder.heuristic;
|
||||||
let mut rng = match builder.seed {
|
let mut rng = match builder.seed {
|
||||||
Some(seed) => SmallRng::seed_from_u64(seed),
|
Some(seed) => SmallRng::seed_from_u64(seed),
|
||||||
None => SmallRng::from_entropy(),
|
None => SmallRng::from_entropy(),
|
||||||
|
@ -136,6 +140,7 @@ where
|
||||||
if let Some(bar) = &progress {
|
if let Some(bar) = &progress {
|
||||||
bar.set_draw_delta(1_000);
|
bar.set_draw_delta(1_000);
|
||||||
bar.set_length(points.len() as u64);
|
bar.set_length(points.len() as u64);
|
||||||
|
bar.set_message("Build index (preparation)");
|
||||||
}
|
}
|
||||||
|
|
||||||
if points.is_empty() {
|
if points.is_empty() {
|
||||||
|
@ -220,42 +225,27 @@ where
|
||||||
// Insert the first point so that we have an enter point to start searches with.
|
// Insert the first point so that we have an enter point to start searches with.
|
||||||
|
|
||||||
let mut layers = vec![vec![]; top.0];
|
let mut layers = vec![vec![]; top.0];
|
||||||
let mut zero = Vec::with_capacity(points.len());
|
let zero = points
|
||||||
zero.push(ZeroNode::default());
|
.iter()
|
||||||
|
.map(|_| RwLock::new(ZeroNode::default()))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
let mut insertion = Search {
|
let pool = SearchPool::new(points.len());
|
||||||
ef: ef_construction,
|
#[cfg(feature = "indicatif")]
|
||||||
visited: Visited::with_capacity(points.len()),
|
let done = AtomicUsize::new(0);
|
||||||
..Default::default()
|
for (layer, range) in ranges {
|
||||||
};
|
|
||||||
|
|
||||||
let mut pool = SearchPool {
|
|
||||||
pool: Vec::new(),
|
|
||||||
len: points.len(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut batch = Vec::new();
|
|
||||||
let mut done = Vec::new();
|
|
||||||
let max_batch_len = num_cpus::get() * 4;
|
|
||||||
for (layer, mut range) in ranges {
|
|
||||||
let num = if layer.0 > 0 { M } else { M * 2 };
|
let num = if layer.0 > 0 { M } else { M * 2 };
|
||||||
#[cfg(feature = "indicatif")]
|
#[cfg(feature = "indicatif")]
|
||||||
if let Some(bar) = &progress {
|
if let Some(bar) = &progress {
|
||||||
bar.set_message(&format!("Building index (layer {})", layer.0));
|
bar.set_message(&format!("Building index (layer {})", layer.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
while range.start < range.end {
|
nodes[range].into_par_iter().for_each(|(_, pid)| {
|
||||||
let len = min(range.len(), max_batch_len);
|
let (mut search, mut insertion) = pool.pop();
|
||||||
batch.clear();
|
let point = &points.as_slice()[*pid];
|
||||||
batch.extend(
|
search.reset();
|
||||||
nodes[range.start..(range.start + len)]
|
|
||||||
.iter()
|
|
||||||
.map(|&(_, pid)| (pid, pool.pop())),
|
|
||||||
);
|
|
||||||
|
|
||||||
batch.par_iter_mut().for_each(|(pid, search)| {
|
|
||||||
let point = &points[*pid];
|
|
||||||
search.push(PointId(0), point, &points);
|
search.push(PointId(0), point, &points);
|
||||||
|
|
||||||
for cur in top.descend() {
|
for cur in top.descend() {
|
||||||
search.ef = if cur <= layer { ef_construction } else { 1 };
|
search.ef = if cur <= layer { ef_construction } else { 1 };
|
||||||
match cur > layer {
|
match cur > layer {
|
||||||
|
@ -269,39 +259,33 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
|
||||||
|
|
||||||
done.clear();
|
|
||||||
for (pid, mut search) in batch.drain(..) {
|
|
||||||
for added in done.iter().copied() {
|
|
||||||
search.push(added, &points[pid], &points);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
insertion.ef = ef_construction;
|
||||||
insert(
|
insert(
|
||||||
pid,
|
*pid,
|
||||||
&mut insertion,
|
&mut insertion,
|
||||||
&mut search,
|
&mut search,
|
||||||
&mut zero,
|
&zero,
|
||||||
&points,
|
&points,
|
||||||
&builder.heuristic,
|
&heuristic,
|
||||||
);
|
);
|
||||||
done.push(pid);
|
|
||||||
pool.push(search);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "indicatif")]
|
#[cfg(feature = "indicatif")]
|
||||||
if let Some(bar) = &progress {
|
if let Some(bar) = &progress {
|
||||||
bar.inc(done.len() as u64);
|
let value = done.fetch_add(1, atomic::Ordering::Relaxed);
|
||||||
|
if value % 1000 == 0 {
|
||||||
|
bar.set_position(value as u64);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
range.start += len;
|
pool.push((search, insertion));
|
||||||
}
|
});
|
||||||
|
|
||||||
// For layers above the zero layer, make a copy of the current state of the zero layer
|
// For layers above the zero layer, make a copy of the current state of the zero layer
|
||||||
// with `nearest` truncated to `M` elements.
|
// with `nearest` truncated to `M` elements.
|
||||||
if layer.0 > 0 {
|
if layer.0 > 0 {
|
||||||
let mut upper = Vec::with_capacity(zero.len());
|
let mut upper = Vec::with_capacity(zero.len());
|
||||||
upper.extend(zero.iter().map(UpperNode::from_zero));
|
upper.extend(zero.iter().map(|zero| UpperNode::from_zero(&zero.read())));
|
||||||
layers[layer.0 - 1] = upper;
|
layers[layer.0 - 1] = upper;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -314,7 +298,7 @@ where
|
||||||
(
|
(
|
||||||
Self {
|
Self {
|
||||||
ef_search,
|
ef_search,
|
||||||
zero,
|
zero: zero.into_iter().map(|node| node.into_inner()).collect(),
|
||||||
points,
|
points,
|
||||||
layers,
|
layers,
|
||||||
},
|
},
|
||||||
|
@ -382,14 +366,14 @@ fn insert<P: Point>(
|
||||||
new: PointId,
|
new: PointId,
|
||||||
insertion: &mut Search,
|
insertion: &mut Search,
|
||||||
search: &mut Search,
|
search: &mut Search,
|
||||||
layer: &mut Vec<ZeroNode>,
|
layer: &[RwLock<ZeroNode>],
|
||||||
points: &[P],
|
points: &[P],
|
||||||
heuristic: &Option<Heuristic>,
|
heuristic: &Option<Heuristic>,
|
||||||
) {
|
) {
|
||||||
layer.push(ZeroNode::default());
|
let mut node = layer[new].write();
|
||||||
let found = match heuristic {
|
let found = match heuristic {
|
||||||
None => search.select_simple(M * 2),
|
None => search.select_simple(M * 2),
|
||||||
Some(heuristic) => search.select_heuristic(&points[new], &layer, points, *heuristic),
|
Some(heuristic) => search.select_heuristic(&points[new], layer, points, *heuristic),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Just make sure the candidates are all unique
|
// Just make sure the candidates are all unique
|
||||||
|
@ -404,19 +388,22 @@ fn insert<P: Point>(
|
||||||
if let Some(heuristic) = heuristic {
|
if let Some(heuristic) = heuristic {
|
||||||
let found = insertion.add_neighbor_heuristic(
|
let found = insertion.add_neighbor_heuristic(
|
||||||
new,
|
new,
|
||||||
layer.as_slice().nearest_iter(pid),
|
layer.nearest_iter(pid),
|
||||||
layer,
|
layer,
|
||||||
&points[pid],
|
&points[pid],
|
||||||
points,
|
points,
|
||||||
*heuristic,
|
*heuristic,
|
||||||
);
|
);
|
||||||
|
|
||||||
layer[pid].rewrite(found.iter().map(|candidate| candidate.pid));
|
layer[pid]
|
||||||
layer[new].set(i, pid);
|
.write()
|
||||||
|
.rewrite(found.iter().map(|candidate| candidate.pid));
|
||||||
|
node.set(i, pid);
|
||||||
} else {
|
} else {
|
||||||
// Find the correct index to insert at to keep the neighbor's neighbors sorted
|
// Find the correct index to insert at to keep the neighbor's neighbors sorted
|
||||||
let old = &points[pid];
|
let old = &points[pid];
|
||||||
let idx = layer[pid]
|
let idx = layer[pid]
|
||||||
|
.read()
|
||||||
.binary_search_by(|third| {
|
.binary_search_by(|third| {
|
||||||
// `third` here is one of the neighbors of the new node's neighbor.
|
// `third` here is one of the neighbors of the new node's neighbor.
|
||||||
let third = match third {
|
let third = match third {
|
||||||
|
@ -429,30 +416,34 @@ fn insert<P: Point>(
|
||||||
})
|
})
|
||||||
.unwrap_or_else(|e| e);
|
.unwrap_or_else(|e| e);
|
||||||
|
|
||||||
layer[pid].insert(idx, new);
|
layer[pid].write().insert(idx, new);
|
||||||
layer[new].set(i, pid);
|
node.set(i, pid);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SearchPool {
|
struct SearchPool {
|
||||||
pool: Vec<Search>,
|
pool: Mutex<Vec<(Search, Search)>>,
|
||||||
len: usize,
|
len: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SearchPool {
|
impl SearchPool {
|
||||||
fn pop(&mut self) -> Search {
|
fn new(len: usize) -> Self {
|
||||||
match self.pool.pop() {
|
Self {
|
||||||
Some(mut search) => {
|
pool: Mutex::new(Vec::new()),
|
||||||
search.reset();
|
len,
|
||||||
search
|
|
||||||
}
|
|
||||||
None => Search::new(self.len),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn push(&mut self, search: Search) {
|
fn pop(&self) -> (Search, Search) {
|
||||||
self.pool.push(search);
|
match self.pool.lock().pop() {
|
||||||
|
Some(res) => res,
|
||||||
|
None => (Search::new(self.len), Search::new(self.len)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn push(&self, item: (Search, Search)) {
|
||||||
|
self.pool.lock().push(item);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -516,11 +507,11 @@ impl Search {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_neighbor_heuristic<P: Point>(
|
fn add_neighbor_heuristic<L: Layer, P: Point>(
|
||||||
&mut self,
|
&mut self,
|
||||||
new: PointId,
|
new: PointId,
|
||||||
current: impl Iterator<Item = PointId>,
|
current: impl Iterator<Item = PointId>,
|
||||||
layer: &[ZeroNode],
|
layer: L,
|
||||||
point: &P,
|
point: &P,
|
||||||
points: &[P],
|
points: &[P],
|
||||||
params: Heuristic,
|
params: Heuristic,
|
||||||
|
@ -530,16 +521,16 @@ impl Search {
|
||||||
for pid in current {
|
for pid in current {
|
||||||
self.push(pid, point, points);
|
self.push(pid, point, points);
|
||||||
}
|
}
|
||||||
self.select_heuristic(point, &layer, points, params)
|
self.select_heuristic(point, layer, points, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Heuristically sort and truncate neighbors in `self.nearest`
|
/// Heuristically sort and truncate neighbors in `self.nearest`
|
||||||
///
|
///
|
||||||
/// Invariant: `self.nearest` must be in sorted (nearest first) order.
|
/// Invariant: `self.nearest` must be in sorted (nearest first) order.
|
||||||
fn select_heuristic<P: Point>(
|
fn select_heuristic<L: Layer, P: Point>(
|
||||||
&mut self,
|
&mut self,
|
||||||
point: &P,
|
point: &P,
|
||||||
layer: &[ZeroNode],
|
layer: L,
|
||||||
points: &[P],
|
points: &[P],
|
||||||
params: Heuristic,
|
params: Heuristic,
|
||||||
) -> &[Candidate] {
|
) -> &[Candidate] {
|
||||||
|
|
86
src/types.rs
86
src/types.rs
|
@ -1,7 +1,8 @@
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
use std::ops::{Deref, Index, IndexMut};
|
use std::ops::{Deref, Index};
|
||||||
|
|
||||||
use ordered_float::OrderedFloat;
|
use ordered_float::OrderedFloat;
|
||||||
|
use parking_lot::{MappedRwLockReadGuard, RwLock, RwLockReadGuard};
|
||||||
use rand::rngs::SmallRng;
|
use rand::rngs::SmallRng;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
|
@ -69,11 +70,11 @@ impl UpperNode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Layer for &[UpperNode] {
|
impl<'a> Layer for &'a [UpperNode] {
|
||||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
type Slice = &'a [PointId];
|
||||||
NearestIter {
|
|
||||||
nearest: &self[pid.0 as usize].0,
|
fn nearest_iter(&self, pid: PointId) -> NearestIter<Self::Slice> {
|
||||||
}
|
NearestIter::new(&self[pid.0 as usize].0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -128,32 +129,63 @@ impl Deref for ZeroNode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Layer for &[ZeroNode] {
|
impl<'a> Layer for &'a [ZeroNode] {
|
||||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_> {
|
type Slice = &'a [PointId];
|
||||||
NearestIter {
|
|
||||||
nearest: &self[pid.0 as usize].0,
|
fn nearest_iter(&self, pid: PointId) -> NearestIter<Self::Slice> {
|
||||||
|
NearestIter::new(&self[pid.0 as usize])
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Layer for &'a [RwLock<ZeroNode>] {
|
||||||
|
type Slice = MappedRwLockReadGuard<'a, [PointId]>;
|
||||||
|
|
||||||
|
fn nearest_iter(&self, pid: PointId) -> NearestIter<Self::Slice> {
|
||||||
|
NearestIter::new(RwLockReadGuard::map(
|
||||||
|
self[pid.0 as usize].read(),
|
||||||
|
Deref::deref,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) trait Layer {
|
pub(crate) trait Layer {
|
||||||
fn nearest_iter(&self, pid: PointId) -> NearestIter<'_>;
|
type Slice: Deref<Target = [PointId]>;
|
||||||
|
fn nearest_iter(&self, pid: PointId) -> NearestIter<Self::Slice>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct NearestIter<'a> {
|
pub(crate) struct NearestIter<T> {
|
||||||
nearest: &'a [PointId],
|
node: T,
|
||||||
|
cur: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Iterator for NearestIter<'a> {
|
impl<T> NearestIter<T>
|
||||||
|
where
|
||||||
|
T: Deref<Target = [PointId]>,
|
||||||
|
{
|
||||||
|
fn new(node: T) -> Self {
|
||||||
|
Self { node, cur: 0 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Iterator for NearestIter<T>
|
||||||
|
where
|
||||||
|
T: Deref<Target = [PointId]>,
|
||||||
|
{
|
||||||
type Item = PointId;
|
type Item = PointId;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
let (&first, rest) = self.nearest.split_first()?;
|
if self.cur >= self.node.len() {
|
||||||
if !first.is_valid() {
|
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
self.nearest = rest;
|
|
||||||
Some(first)
|
let item = self.node[self.cur];
|
||||||
|
if !item.is_valid() {
|
||||||
|
self.cur = self.node.len();
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.cur += 1;
|
||||||
|
Some(item)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -230,14 +262,6 @@ impl<P> Index<PointId> for Hnsw<P> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<P: Point> Index<PointId> for Vec<P> {
|
|
||||||
type Output = P;
|
|
||||||
|
|
||||||
fn index(&self, index: PointId) -> &Self::Output {
|
|
||||||
&self[index.0 as usize]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<P: Point> Index<PointId> for [P] {
|
impl<P: Point> Index<PointId> for [P] {
|
||||||
type Output = P;
|
type Output = P;
|
||||||
|
|
||||||
|
@ -246,18 +270,12 @@ impl<P: Point> Index<PointId> for [P] {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Index<PointId> for Vec<ZeroNode> {
|
impl Index<PointId> for [RwLock<ZeroNode>] {
|
||||||
type Output = ZeroNode;
|
type Output = RwLock<ZeroNode>;
|
||||||
|
|
||||||
fn index(&self, index: PointId) -> &Self::Output {
|
fn index(&self, index: PointId) -> &Self::Output {
|
||||||
&self[index.0 as usize]
|
&self[index.0 as usize]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IndexMut<PointId> for Vec<ZeroNode> {
|
|
||||||
fn index_mut(&mut self, index: PointId) -> &mut Self::Output {
|
|
||||||
&mut self[index.0 as usize]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) const INVALID: PointId = PointId(u32::MAX);
|
pub(crate) const INVALID: PointId = PointId(u32::MAX);
|
||||||
|
|
Loading…
Reference in New Issue