implements insert on Hnsw and HnswMap and test
This commit is contained in:
parent
cc8fcc0f5c
commit
61e6706448
|
@ -170,6 +170,12 @@ where
|
||||||
pub fn get(&self, i: usize, search: &Search) -> Option<MapItem<'_, P, V>> {
|
pub fn get(&self, i: usize, search: &Search) -> Option<MapItem<'_, P, V>> {
|
||||||
Some(MapItem::from(self.hnsw.get(i, search)?, self))
|
Some(MapItem::from(self.hnsw.get(i, search)?, self))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn insert(&mut self, point: P, value: V) -> Result<PointId, Box<dyn std::error::Error>> {
|
||||||
|
let point_id = self.hnsw.insert(point, 100, Some(Heuristic::default()));
|
||||||
|
self.values.push(value);
|
||||||
|
Ok(point_id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct MapItem<'a, P, V> {
|
pub struct MapItem<'a, P, V> {
|
||||||
|
@ -394,6 +400,55 @@ where
|
||||||
pub fn get(&self, i: usize, search: &Search) -> Option<Item<'_, P>> {
|
pub fn get(&self, i: usize, search: &Search) -> Option<Item<'_, P>> {
|
||||||
Some(Item::new(search.nearest.get(i).copied()?, self))
|
Some(Item::new(search.nearest.get(i).copied()?, self))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn insert(
|
||||||
|
&mut self,
|
||||||
|
point: P,
|
||||||
|
ef_construction: usize,
|
||||||
|
heuristic: Option<Heuristic>,
|
||||||
|
) -> PointId {
|
||||||
|
let new_pid = self.points.len();
|
||||||
|
let new_point_id = PointId(new_pid as u32);
|
||||||
|
|
||||||
|
self.points.push(point);
|
||||||
|
self.zero.push(ZeroNode::default());
|
||||||
|
|
||||||
|
let zeros = self
|
||||||
|
.zero
|
||||||
|
.iter()
|
||||||
|
.map(|z| RwLock::new(z.clone()))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let top = if self.layers.is_empty() {
|
||||||
|
LayerId(0)
|
||||||
|
} else {
|
||||||
|
LayerId(self.layers.len())
|
||||||
|
};
|
||||||
|
|
||||||
|
let construction = Construction {
|
||||||
|
zero: zeros.as_slice(),
|
||||||
|
pool: SearchPool::new(self.points.len()),
|
||||||
|
top,
|
||||||
|
points: self.points.as_slice(),
|
||||||
|
heuristic,
|
||||||
|
ef_construction,
|
||||||
|
#[cfg(feature = "indicatif")]
|
||||||
|
progress: None,
|
||||||
|
#[cfg(feature = "indicatif")]
|
||||||
|
done: AtomicUsize::new(0),
|
||||||
|
};
|
||||||
|
|
||||||
|
let new_layer = construction.top;
|
||||||
|
construction.insert(new_point_id, new_layer, &self.layers);
|
||||||
|
|
||||||
|
self.zero = construction
|
||||||
|
.zero
|
||||||
|
.iter()
|
||||||
|
.map(|node| node.read().clone())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
new_point_id
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Item<'a, P> {
|
pub struct Item<'a, P> {
|
||||||
|
|
|
@ -92,7 +92,89 @@ struct Point(f32, f32);
|
||||||
|
|
||||||
impl instant_distance::Point for Point {
|
impl instant_distance::Point for Point {
|
||||||
fn distance(&self, other: &Self) -> f32 {
|
fn distance(&self, other: &Self) -> f32 {
|
||||||
// Euclidean distance metric
|
// Euclidean distance metricØ
|
||||||
((self.0 - other.0).powi(2) + (self.1 - other.1).powi(2)).sqrt()
|
((self.0 - other.0).powi(2) + (self.1 - other.1).powi(2)).sqrt()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[allow(clippy::float_cmp, clippy::approx_constant)]
|
||||||
|
fn incremental_insert() {
|
||||||
|
let points = (0..4)
|
||||||
|
.map(|i| Point(i as f32, i as f32))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let values = vec!["zero", "one", "two", "three"];
|
||||||
|
let seed = ThreadRng::default().gen::<u64>();
|
||||||
|
let builder = Builder::default().seed(seed);
|
||||||
|
|
||||||
|
let mut map = builder.build(points, values);
|
||||||
|
|
||||||
|
map.insert(Point(4.0, 4.0), "four").expect("Should insert");
|
||||||
|
|
||||||
|
let mut search = Search::default();
|
||||||
|
|
||||||
|
for (i, item) in map.search(&Point(4.0, 4.0), &mut search).enumerate() {
|
||||||
|
match i {
|
||||||
|
0 => {
|
||||||
|
assert_eq!(item.distance, 0.0);
|
||||||
|
assert_eq!(item.value, &"four");
|
||||||
|
}
|
||||||
|
1 => {
|
||||||
|
assert_eq!(item.distance, 1.4142135);
|
||||||
|
assert!(item.value == &"three");
|
||||||
|
}
|
||||||
|
2 => {
|
||||||
|
assert_eq!(item.distance, 2.828427);
|
||||||
|
assert!(item.value == &"two");
|
||||||
|
}
|
||||||
|
3 => {
|
||||||
|
assert_eq!(item.distance, 4.2426405);
|
||||||
|
assert!(item.value == &"one");
|
||||||
|
}
|
||||||
|
4 => {
|
||||||
|
assert_eq!(item.distance, 5.656854);
|
||||||
|
assert!(item.value == &"zero");
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note
|
||||||
|
// This has the same expected results as incremental_insert but builds
|
||||||
|
// the whole map in one go. Only here for comparison.
|
||||||
|
{
|
||||||
|
let points = (0..5)
|
||||||
|
.map(|i| Point(i as f32, i as f32))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let values = vec!["zero", "one", "two", "three", "four"];
|
||||||
|
let seed = ThreadRng::default().gen::<u64>();
|
||||||
|
let builder = Builder::default().seed(seed);
|
||||||
|
let map = builder.build(points, values);
|
||||||
|
let mut search = Search::default();
|
||||||
|
for (i, item) in map.search(&Point(4.0, 4.0), &mut search).enumerate() {
|
||||||
|
match i {
|
||||||
|
0 => {
|
||||||
|
assert_eq!(item.distance, 0.0);
|
||||||
|
assert_eq!(item.value, &"four");
|
||||||
|
}
|
||||||
|
1 => {
|
||||||
|
assert_eq!(item.distance, 1.4142135);
|
||||||
|
assert!(item.value == &"three");
|
||||||
|
}
|
||||||
|
2 => {
|
||||||
|
assert_eq!(item.distance, 2.828427);
|
||||||
|
assert!(item.value == &"two");
|
||||||
|
}
|
||||||
|
3 => {
|
||||||
|
assert_eq!(item.distance, 4.2426405);
|
||||||
|
assert!(item.value == &"one");
|
||||||
|
}
|
||||||
|
4 => {
|
||||||
|
assert_eq!(item.distance, 5.656854);
|
||||||
|
assert!(item.value == &"zero");
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue