Explicitly group state for HNSW construction

This commit is contained in:
Dirkjan Ochtman 2021-05-18 10:21:52 +02:00
parent c8a9529355
commit 20ca8b0f3a
1 changed files with 134 additions and 114 deletions

View File

@ -280,74 +280,43 @@ where
.map(|_| RwLock::new(ZeroNode::default())) .map(|_| RwLock::new(ZeroNode::default()))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let pool = SearchPool::new(points.len()); let state = Construction {
zero: zero.as_slice(),
pool: SearchPool::new(points.len()),
top,
points: &points,
heuristic,
ef_construction,
#[cfg(feature = "indicatif")] #[cfg(feature = "indicatif")]
let done = AtomicUsize::new(0); progress,
#[cfg(feature = "indicatif")]
done: AtomicUsize::new(0),
};
for (layer, range) in ranges { for (layer, range) in ranges {
let num = if layer.is_zero() { M * 2 } else { M };
#[cfg(feature = "indicatif")] #[cfg(feature = "indicatif")]
if let Some(bar) = &progress { if let Some(bar) = &state.progress {
bar.set_message(format!("Building index (layer {})", layer.0)); bar.set_message(format!("Building index (layer {})", layer.0));
} }
let end = range.end; let end = range.end;
nodes[range].into_par_iter().for_each(|(_, pid)| { nodes[range].into_par_iter().for_each(|(_, pid)| {
let node = zero.as_slice()[*pid].write(); let node = state.zero[*pid].write();
let (mut search, mut insertion) = pool.pop(); state.insert(*pid, node, layer, &layers);
let point = &points.as_slice()[*pid];
search.reset();
search.push(PointId(0), point, &points);
for cur in top.descend() {
search.ef = if cur <= layer { ef_construction } else { 1 };
match cur > layer {
true => {
search.search(point, layers[cur.0 - 1].as_slice(), &points, num);
search.cull();
}
false => {
search.search(point, zero.as_slice(), &points, num);
break;
}
}
}
insertion.ef = ef_construction;
insert(
*pid,
node,
&mut insertion,
&mut search,
&zero,
&points,
&heuristic,
);
#[cfg(feature = "indicatif")]
if let Some(bar) = &progress {
let value = done.fetch_add(1, atomic::Ordering::Relaxed);
if value % 1000 == 0 {
bar.set_position(value as u64);
}
}
pool.push((search, insertion));
}); });
// For layers above the zero layer, make a copy of the current state of the zero layer // For layers above the zero layer, make a copy of the current state of the zero layer
// with `nearest` truncated to `M` elements. // with `nearest` truncated to `M` elements.
if !layer.is_zero() { if !layer.is_zero() {
let mut upper = Vec::new(); (&state.zero[..end])
(&zero[..end])
.into_par_iter() .into_par_iter()
.map(|zero| UpperNode::from_zero(&zero.read())) .map(|zero| UpperNode::from_zero(&zero.read()))
.collect_into_vec(&mut upper); .collect_into_vec(&mut layers[layer.0 - 1]);
layers[layer.0 - 1] = upper;
} }
} }
#[cfg(feature = "indicatif")] #[cfg(feature = "indicatif")]
if let Some(bar) = progress { if let Some(bar) = &state.progress {
bar.finish(); bar.finish();
} }
@ -408,6 +377,20 @@ where
} }
} }
struct Construction<'a, P: Point> {
zero: &'a [RwLock<ZeroNode>],
pool: SearchPool,
top: LayerId,
points: &'a [P],
heuristic: Option<Heuristic>,
ef_construction: usize,
#[cfg(feature = "indicatif")]
progress: Option<ProgressBar>,
#[cfg(feature = "indicatif")]
done: AtomicUsize,
}
impl<'a, P: Point> Construction<'a, P> {
/// Insert new node in the zero layer /// Insert new node in the zero layer
/// ///
/// * `new`: the `PointId` for the new node /// * `new`: the `PointId` for the new node
@ -418,21 +401,47 @@ where
/// ///
/// Creates the new node, initializing its `nearest` array and updates the nearest neighbors /// Creates the new node, initializing its `nearest` array and updates the nearest neighbors
/// for the new node's neighbors if necessary before appending the new node to the layer. /// for the new node's neighbors if necessary before appending the new node to the layer.
fn insert<P: Point>( fn insert(
&self,
new: PointId, new: PointId,
mut node: parking_lot::RwLockWriteGuard<ZeroNode>, mut node: parking_lot::RwLockWriteGuard<ZeroNode>,
insertion: &mut Search, layer: LayerId,
search: &mut Search, layers: &[Vec<UpperNode>],
layer: &[RwLock<ZeroNode>],
points: &[P],
heuristic: &Option<Heuristic>,
) { ) {
let found = match heuristic { let (mut search, mut insertion) = self.pool.pop();
insertion.ef = self.ef_construction;
let point = &self.points[new];
search.reset();
search.push(PointId(0), point, &self.points);
let num = if layer.is_zero() { M * 2 } else { M };
for cur in self.top.descend() {
search.ef = if cur <= layer {
self.ef_construction
} else {
1
};
match cur > layer {
true => {
search.search(point, layers[cur.0 - 1].as_slice(), &self.points, num);
search.cull();
}
false => {
search.search(point, self.zero, &self.points, num);
break;
}
}
}
let found = match self.heuristic {
None => { None => {
let candidates = search.select_simple(); let candidates = search.select_simple();
&candidates[..Ord::min(candidates.len(), M * 2)] &candidates[..Ord::min(candidates.len(), M * 2)]
} }
Some(heuristic) => search.select_heuristic(&points[new], layer, points, *heuristic), Some(heuristic) => {
search.select_heuristic(&self.points[new], self.zero, self.points, heuristic)
}
}; };
// Just make sure the candidates are all unique // Just make sure the candidates are all unique
@ -444,24 +453,24 @@ fn insert<P: Point>(
for (i, candidate) in found.iter().enumerate() { for (i, candidate) in found.iter().enumerate() {
// `candidate` here is the new node's neighbor // `candidate` here is the new node's neighbor
let &Candidate { distance, pid } = candidate; let &Candidate { distance, pid } = candidate;
if let Some(heuristic) = heuristic { if let Some(heuristic) = self.heuristic {
let found = insertion.add_neighbor_heuristic( let found = insertion.add_neighbor_heuristic(
new, new,
layer.nearest_iter(pid), self.zero.nearest_iter(pid),
layer, self.zero,
&points[pid], &self.points[pid],
points, self.points,
*heuristic, heuristic,
); );
layer[pid] self.zero[pid]
.write() .write()
.rewrite(found.iter().map(|candidate| candidate.pid)); .rewrite(found.iter().map(|candidate| candidate.pid));
node.set(i, pid); node.set(i, pid);
} else { } else {
// Find the correct index to insert at to keep the neighbor's neighbors sorted // Find the correct index to insert at to keep the neighbor's neighbors sorted
let old = &points[pid]; let old = &self.points[pid];
let idx = layer[pid] let idx = self.zero[pid]
.read() .read()
.binary_search_by(|third| { .binary_search_by(|third| {
// `third` here is one of the neighbors of the new node's neighbor. // `third` here is one of the neighbors of the new node's neighbor.
@ -471,14 +480,25 @@ fn insert<P: Point>(
_ => return Ordering::Greater, _ => return Ordering::Greater,
}; };
distance.cmp(&old.distance(&points[third]).into()) distance.cmp(&old.distance(&self.points[third]).into())
}) })
.unwrap_or_else(|e| e); .unwrap_or_else(|e| e);
layer[pid].write().insert(idx, new); self.zero[pid].write().insert(idx, new);
node.set(i, pid); node.set(i, pid);
} }
} }
#[cfg(feature = "indicatif")]
if let Some(bar) = &self.progress {
let value = self.done.fetch_add(1, atomic::Ordering::Relaxed);
if value % 1000 == 0 {
bar.set_position(value as u64);
}
}
self.pool.push((search, insertion));
}
} }
struct SearchPool { struct SearchPool {