diff --git a/Cargo.toml b/Cargo.toml index fd9e608..948faaf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2018" [dependencies] ahash = "0.6.1" +indicatif = { version = "0.15", optional = true } ordered-float = "2.0" rand = { version = "0.7.3", features = ["small_rng"] } rayon = "1.5" diff --git a/src/lib.rs b/src/lib.rs index 5bf6b56..d052517 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,8 @@ use std::hash::Hash; use std::ops::Index; use ahash::AHashSet as HashSet; +#[cfg(feature = "indicatif")] +use indicatif::ProgressBar; use ordered_float::OrderedFloat; use rand::rngs::SmallRng; use rand::{RngCore, SeedableRng}; @@ -28,6 +30,8 @@ where fn new(points: &[P], builder: Builder) -> (Self, Vec) { let ef_search = builder.ef_search.unwrap_or(100); let ef_construction = builder.ef_construction.unwrap_or(100); + #[cfg(feature = "indicatif")] + let progress = builder.progress; if points.is_empty() { return ( @@ -99,6 +103,13 @@ where for (layer, range) in ranges { let num = if layer.0 > 0 { M } else { M * 2 }; for &(_, pid) in &nodes[range] { + #[cfg(feature = "indicatif")] + if pid.0 % 10_000 == 0 { + if let Some(bar) = &progress { + bar.set_position(pid.0 as u64); + } + } + search.reset(); let point = &points[pid]; search.push(PointId(0), &points[pid], &points); @@ -128,6 +139,11 @@ where } } + #[cfg(feature = "indicatif")] + if let Some(bar) = progress { + bar.finish(); + } + ( Self { ef_search, @@ -258,6 +274,8 @@ impl Default for Search { pub struct Builder { ef_search: Option, ef_construction: Option, + #[cfg(feature = "indicatif")] + progress: Option, } impl Builder { @@ -274,6 +292,12 @@ impl Builder { self } + #[cfg(feature = "indicatif")] + pub fn progress(mut self, bar: ProgressBar) -> Self { + self.progress = Some(bar); + self + } + pub fn build(self, points: &[P]) -> (Hnsw

, Vec) { Hnsw::new(points, self) }