diff --git a/tests/all.rs b/tests/all.rs index 754e4db..2098a84 100644 --- a/tests/all.rs +++ b/tests/all.rs @@ -1,4 +1,8 @@ -use instant_distance::{Hnsw, PointId, Search}; +use ordered_float::OrderedFloat; +use rand::rngs::{StdRng, ThreadRng}; +use rand::{Rng, SeedableRng}; + +use instant_distance::{Hnsw, Point as _, PointId, Search}; #[test] fn basic() { @@ -17,6 +21,51 @@ fn basic() { assert_eq!(&results, &[pids[0]]); } +#[test] +fn randomized() { + let seed = ThreadRng::default().gen::(); + println!("seed {}", seed); + + let mut rng = StdRng::seed_from_u64(seed); + let points = (0..1024) + .into_iter() + .map(|_| Point(rng.gen(), rng.gen())) + .collect::>(); + + let query = Point(rng.gen(), rng.gen()); + println!("query: {:?}", query); + + for (i, p) in points.iter().enumerate() { + println!("{:2} {:?} ({})", i, p, query.distance(p)); + } + + let (hnsw, pids) = Hnsw::::builder().seed(seed).build(&points); + let mut search = Search::default(); + let mut results = vec![PointId::default()]; + let found = hnsw.search(&query, &mut results, &mut search); + assert_eq!(found, 1); + + let nearest = points + .iter() + .enumerate() + .map(|(i, other)| (OrderedFloat::from(query.distance(other)), i)) + .min() + .unwrap(); + println!( + "nearest (brute force): {:?} @ {:9.7}", + pids[nearest.1], + nearest.0.into_inner() + ); + + let index = pids.iter().position(|p| p == &results[0]).unwrap(); + println!( + "nearest (hnsw): {:?} @ {:9.7}", + results[0], + query.distance(&points[index]) + ); + assert_eq!(pids[nearest.1], results[0]); +} + #[derive(Clone, Copy, Debug)] struct Point(f32, f32);