From cdc1242ba918f35b762abae0f82e72208d93482e Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Tue, 19 Jan 2021 11:00:29 +0100 Subject: [PATCH] Test both heuristic and simple selection --- src/lib.rs | 2 +- tests/all.rs | 31 +++++++++++++++++-------------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index ed4f414..359ce52 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -840,4 +840,4 @@ impl IndexMut for Vec { /// The parameter `M` from the paper /// /// This should become a generic argument to `Hnsw` when possible. -const M: usize = 12; +const M: usize = 6; diff --git a/tests/all.rs b/tests/all.rs index be2a60a..9bdefc3 100644 --- a/tests/all.rs +++ b/tests/all.rs @@ -4,7 +4,7 @@ use ordered_float::OrderedFloat; use rand::rngs::{StdRng, ThreadRng}; use rand::{Rng, SeedableRng}; -use instant_distance::{Hnsw, Point as _, PointId, Search}; +use instant_distance::{Builder, Heuristic, Hnsw, Point as _, PointId, Search}; #[test] fn basic() { @@ -24,10 +24,21 @@ fn basic() { } #[test] -fn randomized() { - let seed = ThreadRng::default().gen::(); - println!("seed {}", seed); +fn random_heuristic() { + let (seed, recall) = randomized(Builder::default().select_heuristic(Heuristic::default())); + println!("heuristic (seed = {}) recall = {}", seed, recall); + assert!(recall > 98, "expected at least 98, got {}", recall); +} +#[test] +fn random_simple() { + let (seed, recall) = randomized(Builder::default()); + println!("simple (seed = {}) recall = {}", seed, recall); + assert!(recall > 90, "expected at least 90, got {}", recall); +} + +fn randomized(builder: Builder) -> (u64, usize) { + let seed = ThreadRng::default().gen::(); let mut rng = StdRng::seed_from_u64(seed); let points = (0..1024) .into_iter() @@ -35,8 +46,6 @@ fn randomized() { .collect::>(); let query = Point(rng.gen(), rng.gen()); - println!("query: {:?}", query); - let mut nearest = Vec::with_capacity(256); for (i, p) in points.iter().enumerate() { nearest.push((OrderedFloat::from(query.distance(p)), i)); @@ -46,10 +55,7 @@ fn randomized() { } } - let (hnsw, pids) = Hnsw::::builder() - .seed(seed) - .select_heuristic(Default::default()) - .build(&points); + let (hnsw, pids) = builder.seed(seed).build(&points); let mut search = Search::default(); let mut results = vec![PointId::default(); 100]; let found = hnsw.search(&query, &mut results, &mut search); @@ -62,10 +68,7 @@ fn randomized() { .map(|(_, i)| pids[*i]) .collect::>(); let found = results.into_iter().take(found).collect::>(); - - let recall = forced.intersection(&found).count(); - println!("{} matched", recall); - assert!(recall > 90); + (seed, forced.intersection(&found).count()) } #[derive(Clone, Copy, Debug)]