Use u32 for PointIds to save memory

This commit is contained in:
Dirkjan Ochtman 2020-12-14 16:37:15 +01:00
parent db925603d0
commit 0e0ffe201e
1 changed file with 42 additions and 32 deletions

View File

@ -38,6 +38,7 @@ where
// construction. This allows us to copy higher layers to lower layers as construction // construction. This allows us to copy higher layers to lower layers as construction
// progresses, while preserving randomness in each point's layer and insertion order. // progresses, while preserving randomness in each point's layer and insertion order.
assert!(points.len() < u32::MAX as usize);
let mut rng = SmallRng::from_entropy(); let mut rng = SmallRng::from_entropy();
let mut nodes = (0..points.len()) let mut nodes = (0..points.len())
.map(|i| (LayerId::random(&mut rng), i)) .map(|i| (LayerId::random(&mut rng), i))
@ -51,7 +52,7 @@ where
let mut new_nodes = Vec::with_capacity(points.len()); let mut new_nodes = Vec::with_capacity(points.len());
let mut out = vec![PointId::invalid(); points.len()]; let mut out = vec![PointId::invalid(); points.len()];
for (i, &(layer, idx)) in nodes.iter().enumerate() { for (i, &(layer, idx)) in nodes.iter().enumerate() {
let pid = PointId(i); let pid = PointId(i as u32);
new_points.push(points[idx].clone()); new_points.push(points[idx].clone());
new_nodes.push((layer, pid)); new_nodes.push((layer, pid));
out[idx] = pid; out[idx] = pid;
@ -136,8 +137,6 @@ where
/// The results are returned in the `out` parameter; the number of neighbors to search for /// The results are returned in the `out` parameter; the number of neighbors to search for
/// is limited by the size of the `out` parameter, and the number of results found is returned /// is limited by the size of the `out` parameter, and the number of results found is returned
/// in the return value. /// in the return value.
///
/// `PointId` values can be initialized with `PointId::invalid()`.
pub fn search(&self, point: &P, out: &mut [PointId], search: &mut Search) -> usize { pub fn search(&self, point: &P, out: &mut [PointId], search: &mut Search) -> usize {
if self.points.is_empty() { if self.points.is_empty() {
return 0; return 0;
@ -249,7 +248,7 @@ impl<P> Index<PointId> for Hnsw<P> {
type Output = P; type Output = P;
fn index(&self, index: PointId) -> &Self::Output { fn index(&self, index: PointId) -> &Self::Output {
&self.points[index.0] &self.points[index.0 as usize]
} }
} }
@ -257,7 +256,7 @@ impl<P: Point> Index<PointId> for Vec<P> {
type Output = P; type Output = P;
fn index(&self, index: PointId) -> &Self::Output { fn index(&self, index: PointId) -> &Self::Output {
&self[index.0] &self[index.0 as usize]
} }
} }
@ -265,7 +264,7 @@ impl<P: Point> Index<PointId> for [P] {
type Output = P; type Output = P;
fn index(&self, index: PointId) -> &Self::Output { fn index(&self, index: PointId) -> &Self::Output {
&self[index.0] &self[index.0 as usize]
} }
} }
@ -329,7 +328,7 @@ trait Layer {
} }
} }
for pid in self.nodes()[candidate.pid.0].nearest_iter().take(num) { for pid in self.nodes()[candidate.pid.0 as usize].nearest_iter().take(num) {
search.push(pid, point, points); search.push(pid, point, points);
} }
} }
@ -356,22 +355,22 @@ trait Layer {
for (i, candidate) in found.iter().take(Self::LINKS).enumerate() { for (i, candidate) in found.iter().take(Self::LINKS).enumerate() {
// `candidate` here is the new node's neighbor // `candidate` here is the new node's neighbor
let &Candidate { distance, pid } = candidate; let &Candidate { distance, pid } = candidate;
new_nearest[i] = Some(pid); // Update the new node's `nearest` new_nearest[i] = pid; // Update the new node's `nearest`
let old = &points[pid]; let old = &points[pid];
let nearest = self.nodes()[pid.0].nearest(); let nearest = self.nodes()[pid.0 as usize].nearest();
// Find the correct index to insert at to keep the neighbor's neighbors sorted // Find the correct index to insert at to keep the neighbor's neighbors sorted
let idx = nearest let idx = nearest
.binary_search_by(|third| { .binary_search_by(|third| {
// `third` here is one of the neighbors of the new node's neighbor. // `third` here is one of the neighbors of the new node's neighbor.
let third = match third { let third = match third {
Some(nid) => *nid, pid if pid.is_valid() => *pid,
// if `third` is `None`, our new `node` is always "closer" // if `third` is `None`, our new `node` is always "closer"
None => return Ordering::Greater, _ => return Ordering::Greater,
}; };
let third_distance = OrderedFloat::from(old.distance(&points[third.0])); let third_distance = OrderedFloat::from(old.distance(&points[third.0 as usize]));
distance.cmp(&third_distance) distance.cmp(&third_distance)
}) })
.unwrap_or_else(|e| e); .unwrap_or_else(|e| e);
@ -382,15 +381,15 @@ trait Layer {
continue; continue;
} }
let nearest = self.nodes_mut()[pid.0].nearest_mut(); let nearest = self.nodes_mut()[pid.0 as usize].nearest_mut();
if nearest[idx].is_none() { if !nearest[idx].is_valid() {
nearest[idx] = Some(new); nearest[idx] = new;
continue; continue;
} }
let end = Self::LINKS - 1; let end = Self::LINKS - 1;
nearest.copy_within(idx..end, idx + 1); nearest.copy_within(idx..end, idx + 1);
nearest[idx] = Some(new); nearest[idx] = new;
} }
self.push(node); self.push(node);
@ -409,15 +408,15 @@ struct UpperNode {
/// The nearest neighbors on this layer /// The nearest neighbors on this layer
/// ///
/// This is always kept in sorted order (near to far). /// This is always kept in sorted order (near to far).
nearest: [Option<PointId>; M], nearest: [PointId; M],
} }
impl Node for UpperNode { impl Node for UpperNode {
fn nearest(&self) -> &[Option<PointId>] { fn nearest(&self) -> &[PointId] {
&self.nearest &self.nearest
} }
fn nearest_mut(&mut self) -> &mut [Option<PointId>] { fn nearest_mut(&mut self) -> &mut [PointId] {
&mut self.nearest &mut self.nearest
} }
@ -434,15 +433,15 @@ struct ZeroNode {
/// The nearest neighbors on this layer /// The nearest neighbors on this layer
/// ///
/// This is always kept in sorted order (near to far). /// This is always kept in sorted order (near to far).
nearest: [Option<PointId>; M * 2], nearest: [PointId; M * 2],
} }
impl Node for ZeroNode { impl Node for ZeroNode {
fn nearest(&self) -> &[Option<PointId>] { fn nearest(&self) -> &[PointId] {
&self.nearest &self.nearest
} }
fn nearest_mut(&mut self) -> &mut [Option<PointId>] { fn nearest_mut(&mut self) -> &mut [PointId] {
&mut self.nearest &mut self.nearest
} }
@ -454,13 +453,13 @@ impl Node for ZeroNode {
} }
trait Node: Default { trait Node: Default {
fn nearest(&self) -> &[Option<PointId>]; fn nearest(&self) -> &[PointId];
fn nearest_mut(&mut self) -> &mut [Option<PointId>]; fn nearest_mut(&mut self) -> &mut [PointId];
fn nearest_iter(&self) -> NearestIter<'_>; fn nearest_iter(&self) -> NearestIter<'_>;
} }
struct NearestIter<'a> { struct NearestIter<'a> {
nearest: &'a [Option<PointId>], nearest: &'a [PointId],
} }
impl<'a> Iterator for NearestIter<'a> { impl<'a> Iterator for NearestIter<'a> {
@ -468,11 +467,11 @@ impl<'a> Iterator for NearestIter<'a> {
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
let (&first, rest) = self.nearest.split_first()?; let (&first, rest) = self.nearest.split_first()?;
self.nearest = rest; if !first.is_valid() {
if first.is_none() { return None;
self.nearest = &[];
} }
first self.nearest = rest;
Some(first)
} }
} }
@ -530,11 +529,22 @@ struct Candidate {
/// This can be used to index into the `Hnsw` to refer to the `Point` data. /// This can be used to index into the `Hnsw` to refer to the `Point` data.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct PointId(usize); pub struct PointId(u32);
impl PointId { impl PointId {
pub fn invalid() -> Self { fn invalid() -> Self {
PointId(usize::MAX) PointId(u32::MAX)
}
/// Whether this value represents a valid point
pub fn is_valid(self) -> bool {
self.0 != u32::MAX
}
}
impl Default for PointId {
fn default() -> Self {
PointId::invalid()
} }
} }