implements insert on Hnsw and HnswMap and test
This commit is contained in:
parent
cc8fcc0f5c
commit
61e6706448
|
@ -170,6 +170,12 @@ where
|
|||
pub fn get(&self, i: usize, search: &Search) -> Option<MapItem<'_, P, V>> {
|
||||
Some(MapItem::from(self.hnsw.get(i, search)?, self))
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, point: P, value: V) -> Result<PointId, Box<dyn std::error::Error>> {
|
||||
let point_id = self.hnsw.insert(point, 100, Some(Heuristic::default()));
|
||||
self.values.push(value);
|
||||
Ok(point_id)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MapItem<'a, P, V> {
|
||||
|
@ -394,6 +400,55 @@ where
|
|||
pub fn get(&self, i: usize, search: &Search) -> Option<Item<'_, P>> {
|
||||
Some(Item::new(search.nearest.get(i).copied()?, self))
|
||||
}
|
||||
|
||||
pub fn insert(
|
||||
&mut self,
|
||||
point: P,
|
||||
ef_construction: usize,
|
||||
heuristic: Option<Heuristic>,
|
||||
) -> PointId {
|
||||
let new_pid = self.points.len();
|
||||
let new_point_id = PointId(new_pid as u32);
|
||||
|
||||
self.points.push(point);
|
||||
self.zero.push(ZeroNode::default());
|
||||
|
||||
let zeros = self
|
||||
.zero
|
||||
.iter()
|
||||
.map(|z| RwLock::new(z.clone()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let top = if self.layers.is_empty() {
|
||||
LayerId(0)
|
||||
} else {
|
||||
LayerId(self.layers.len())
|
||||
};
|
||||
|
||||
let construction = Construction {
|
||||
zero: zeros.as_slice(),
|
||||
pool: SearchPool::new(self.points.len()),
|
||||
top,
|
||||
points: self.points.as_slice(),
|
||||
heuristic,
|
||||
ef_construction,
|
||||
#[cfg(feature = "indicatif")]
|
||||
progress: None,
|
||||
#[cfg(feature = "indicatif")]
|
||||
done: AtomicUsize::new(0),
|
||||
};
|
||||
|
||||
let new_layer = construction.top;
|
||||
construction.insert(new_point_id, new_layer, &self.layers);
|
||||
|
||||
self.zero = construction
|
||||
.zero
|
||||
.iter()
|
||||
.map(|node| node.read().clone())
|
||||
.collect();
|
||||
|
||||
new_point_id
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Item<'a, P> {
|
||||
|
|
|
@ -92,7 +92,89 @@ struct Point(f32, f32);
|
|||
|
||||
impl instant_distance::Point for Point {
|
||||
fn distance(&self, other: &Self) -> f32 {
|
||||
// Euclidean distance metric
|
||||
// Euclidean distance metricØ
|
||||
((self.0 - other.0).powi(2) + (self.1 - other.1).powi(2)).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::float_cmp, clippy::approx_constant)]
|
||||
fn incremental_insert() {
|
||||
let points = (0..4)
|
||||
.map(|i| Point(i as f32, i as f32))
|
||||
.collect::<Vec<_>>();
|
||||
let values = vec!["zero", "one", "two", "three"];
|
||||
let seed = ThreadRng::default().gen::<u64>();
|
||||
let builder = Builder::default().seed(seed);
|
||||
|
||||
let mut map = builder.build(points, values);
|
||||
|
||||
map.insert(Point(4.0, 4.0), "four").expect("Should insert");
|
||||
|
||||
let mut search = Search::default();
|
||||
|
||||
for (i, item) in map.search(&Point(4.0, 4.0), &mut search).enumerate() {
|
||||
match i {
|
||||
0 => {
|
||||
assert_eq!(item.distance, 0.0);
|
||||
assert_eq!(item.value, &"four");
|
||||
}
|
||||
1 => {
|
||||
assert_eq!(item.distance, 1.4142135);
|
||||
assert!(item.value == &"three");
|
||||
}
|
||||
2 => {
|
||||
assert_eq!(item.distance, 2.828427);
|
||||
assert!(item.value == &"two");
|
||||
}
|
||||
3 => {
|
||||
assert_eq!(item.distance, 4.2426405);
|
||||
assert!(item.value == &"one");
|
||||
}
|
||||
4 => {
|
||||
assert_eq!(item.distance, 5.656854);
|
||||
assert!(item.value == &"zero");
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
// Note
|
||||
// This has the same expected results as incremental_insert but builds
|
||||
// the whole map in one go. Only here for comparison.
|
||||
{
|
||||
let points = (0..5)
|
||||
.map(|i| Point(i as f32, i as f32))
|
||||
.collect::<Vec<_>>();
|
||||
let values = vec!["zero", "one", "two", "three", "four"];
|
||||
let seed = ThreadRng::default().gen::<u64>();
|
||||
let builder = Builder::default().seed(seed);
|
||||
let map = builder.build(points, values);
|
||||
let mut search = Search::default();
|
||||
for (i, item) in map.search(&Point(4.0, 4.0), &mut search).enumerate() {
|
||||
match i {
|
||||
0 => {
|
||||
assert_eq!(item.distance, 0.0);
|
||||
assert_eq!(item.value, &"four");
|
||||
}
|
||||
1 => {
|
||||
assert_eq!(item.distance, 1.4142135);
|
||||
assert!(item.value == &"three");
|
||||
}
|
||||
2 => {
|
||||
assert_eq!(item.distance, 2.828427);
|
||||
assert!(item.value == &"two");
|
||||
}
|
||||
3 => {
|
||||
assert_eq!(item.distance, 4.2426405);
|
||||
assert!(item.value == &"one");
|
||||
}
|
||||
4 => {
|
||||
assert_eq!(item.distance, 5.656854);
|
||||
assert!(item.value == &"zero");
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue