Test recall in randomized test
This commit is contained in:
parent
6def318423
commit
df24ddc186
39
tests/all.rs
39
tests/all.rs
|
@ -1,3 +1,5 @@
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use ordered_float::OrderedFloat;
|
use ordered_float::OrderedFloat;
|
||||||
use rand::rngs::{StdRng, ThreadRng};
|
use rand::rngs::{StdRng, ThreadRng};
|
||||||
use rand::{Rng, SeedableRng};
|
use rand::{Rng, SeedableRng};
|
||||||
|
@ -35,35 +37,32 @@ fn randomized() {
|
||||||
let query = Point(rng.gen(), rng.gen());
|
let query = Point(rng.gen(), rng.gen());
|
||||||
println!("query: {:?}", query);
|
println!("query: {:?}", query);
|
||||||
|
|
||||||
|
let mut nearest = Vec::with_capacity(256);
|
||||||
for (i, p) in points.iter().enumerate() {
|
for (i, p) in points.iter().enumerate() {
|
||||||
println!("{:2} {:?} ({})", i, p, query.distance(p));
|
nearest.push((OrderedFloat::from(query.distance(p)), i));
|
||||||
|
if nearest.len() >= 200 {
|
||||||
|
nearest.sort_unstable();
|
||||||
|
nearest.truncate(100);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let (hnsw, pids) = Hnsw::<Point>::builder().seed(seed).build(&points);
|
let (hnsw, pids) = Hnsw::<Point>::builder().seed(seed).build(&points);
|
||||||
let mut search = Search::default();
|
let mut search = Search::default();
|
||||||
let mut results = vec![PointId::default()];
|
let mut results = vec![PointId::default(); 100];
|
||||||
let found = hnsw.search(&query, &mut results, &mut search);
|
let found = hnsw.search(&query, &mut results, &mut search);
|
||||||
assert_eq!(found, 1);
|
assert_eq!(found, 100);
|
||||||
|
|
||||||
let nearest = points
|
nearest.sort_unstable();
|
||||||
|
nearest.truncate(100);
|
||||||
|
let forced = nearest
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.map(|(_, i)| pids[*i])
|
||||||
.map(|(i, other)| (OrderedFloat::from(query.distance(other)), i))
|
.collect::<HashSet<_>>();
|
||||||
.min()
|
let found = results.into_iter().take(found).collect::<HashSet<_>>();
|
||||||
.unwrap();
|
|
||||||
println!(
|
|
||||||
"nearest (brute force): {:?} @ {:9.7}",
|
|
||||||
pids[nearest.1],
|
|
||||||
nearest.0.into_inner()
|
|
||||||
);
|
|
||||||
|
|
||||||
let index = pids.iter().position(|p| p == &results[0]).unwrap();
|
let recall = forced.intersection(&found).count();
|
||||||
println!(
|
println!("{} matched", recall);
|
||||||
"nearest (hnsw): {:?} @ {:9.7}",
|
assert!(recall > 95);
|
||||||
results[0],
|
|
||||||
query.distance(&points[index])
|
|
||||||
);
|
|
||||||
assert_eq!(pids[nearest.1], results[0]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug)]
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
|
Loading…
Reference in New Issue