diff --git a/src/lib.rs b/src/lib.rs index 5422284..01371af 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,7 +16,6 @@ use serde::{Deserialize, Serialize}; use serde_big_array::big_array; /// Parameters for building the `Hnsw` -#[derive(Default)] pub struct Builder { ef_search: Option, ef_construction: Option, @@ -43,8 +42,8 @@ impl Builder { self } - pub fn select_heuristic(mut self, params: Heuristic) -> Self { - self.heuristic = Some(params); + pub fn select_heuristic(mut self, params: Option) -> Self { + self.heuristic = params; self } @@ -77,6 +76,20 @@ impl Builder { } } +impl Default for Builder { + fn default() -> Self { + Self { + ef_search: None, + ef_construction: None, + heuristic: Some(Heuristic::default()), + ml: None, + seed: None, + #[cfg(feature = "indicatif")] + progress: None, + } + } +} + #[derive(Copy, Clone, Debug)] pub struct Heuristic { pub extend_candidates: bool, diff --git a/tests/all.rs b/tests/all.rs index 9bdefc3..6e04118 100644 --- a/tests/all.rs +++ b/tests/all.rs @@ -25,14 +25,14 @@ fn basic() { #[test] fn random_heuristic() { - let (seed, recall) = randomized(Builder::default().select_heuristic(Heuristic::default())); + let (seed, recall) = randomized(Builder::default()); println!("heuristic (seed = {}) recall = {}", seed, recall); assert!(recall > 98, "expected at least 98, got {}", recall); } #[test] fn random_simple() { - let (seed, recall) = randomized(Builder::default()); + let (seed, recall) = randomized(Builder::default().select_heuristic(None)); println!("simple (seed = {}) recall = {}", seed, recall); assert!(recall > 90, "expected at least 90, got {}", recall); }