Add more extensive randomized test
This commit is contained in:
parent
c4708ac032
commit
db1d3128ec
51
tests/all.rs
51
tests/all.rs
|
@ -1,4 +1,8 @@
|
|||
use instant_distance::{Hnsw, PointId, Search};
|
||||
use ordered_float::OrderedFloat;
|
||||
use rand::rngs::{StdRng, ThreadRng};
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
use instant_distance::{Hnsw, Point as _, PointId, Search};
|
||||
|
||||
#[test]
|
||||
fn basic() {
|
||||
|
@ -17,6 +21,51 @@ fn basic() {
|
|||
assert_eq!(&results, &[pids[0]]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn randomized() {
|
||||
let seed = ThreadRng::default().gen::<u64>();
|
||||
println!("seed {}", seed);
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
let points = (0..1024)
|
||||
.into_iter()
|
||||
.map(|_| Point(rng.gen(), rng.gen()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let query = Point(rng.gen(), rng.gen());
|
||||
println!("query: {:?}", query);
|
||||
|
||||
for (i, p) in points.iter().enumerate() {
|
||||
println!("{:2} {:?} ({})", i, p, query.distance(p));
|
||||
}
|
||||
|
||||
let (hnsw, pids) = Hnsw::<Point>::builder().seed(seed).build(&points);
|
||||
let mut search = Search::default();
|
||||
let mut results = vec![PointId::default()];
|
||||
let found = hnsw.search(&query, &mut results, &mut search);
|
||||
assert_eq!(found, 1);
|
||||
|
||||
let nearest = points
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, other)| (OrderedFloat::from(query.distance(other)), i))
|
||||
.min()
|
||||
.unwrap();
|
||||
println!(
|
||||
"nearest (brute force): {:?} @ {:9.7}",
|
||||
pids[nearest.1],
|
||||
nearest.0.into_inner()
|
||||
);
|
||||
|
||||
let index = pids.iter().position(|p| p == &results[0]).unwrap();
|
||||
println!(
|
||||
"nearest (hnsw): {:?} @ {:9.7}",
|
||||
results[0],
|
||||
query.distance(&points[index])
|
||||
);
|
||||
assert_eq!(pids[nearest.1], results[0]);
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct Point(f32, f32);
|
||||
|
||||
|
|
Loading…
Reference in New Issue