Refactor to prepare for multiple distance metrics
This commit is contained in:
parent
f1cb9ee234
commit
bca31ad33f
12
Makefile
12
Makefile
|
@ -1,8 +1,14 @@
|
|||
test-python:
|
||||
cargo build --release
|
||||
cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so
|
||||
instant-distance-py/test/instant_distance.so: instant-distance-py/src/lib.rs
|
||||
RUSTFLAGS="-C target-cpu=native" cargo build --release
|
||||
([ -f target/release/libinstant_distance.dylib ] && cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so) || \
|
||||
([ -f target/release/libinstant_distance.so ] && cp target/release/libinstant_distance.so instant-distance-py/test/instant_distance.so)
|
||||
|
||||
test-python: instant-distance-py/test/instant_distance.so
|
||||
PYTHONPATH=instant-distance-py/test/ python3 -m test
|
||||
|
||||
bench-python: instant-distance-py/test/instant_distance.so
|
||||
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config()' 'instant_distance.Hnsw.build(points, config)'
|
||||
|
||||
clean:
|
||||
cargo clean
|
||||
rm -f instant-distance-py/test/instant_distance.so
|
||||
|
|
|
@ -24,12 +24,46 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> {
|
|||
m.add_class::<Search>()?;
|
||||
m.add_class::<Hnsw>()?;
|
||||
m.add_class::<HnswMap>()?;
|
||||
m.add_class::<DistanceMetric>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Copy, Clone)]
|
||||
enum DistanceMetric {
|
||||
Euclid,
|
||||
Cosine,
|
||||
}
|
||||
|
||||
impl Default for DistanceMetric {
|
||||
fn default() -> Self {
|
||||
Self::Euclid
|
||||
}
|
||||
}
|
||||
|
||||
// Helper macro for dispatching to inner implementation
|
||||
macro_rules! impl_for_each_hnsw_with_metric {
|
||||
($type:ident, $instance:expr, $inner:ident, $($tokens:tt)+) => {
|
||||
match $instance {
|
||||
$type::Euclid($inner) => {
|
||||
$($tokens)+
|
||||
}
|
||||
$type::Cosine($inner) => {
|
||||
$($tokens)+
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct HnswMap {
|
||||
inner: instant_distance::HnswMap<FloatArray, MapValue>,
|
||||
inner: HnswMapWithMetric,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
enum HnswMapWithMetric {
|
||||
Euclid(instant_distance::HnswMap<FloatArray, MapValue>),
|
||||
Cosine(instant_distance::HnswMap<FloatArray, MapValue>),
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
|
@ -37,28 +71,32 @@ impl HnswMap {
|
|||
/// Build the index
|
||||
#[staticmethod]
|
||||
fn build(points: &PyList, values: &PyList, config: &Config) -> PyResult<Self> {
|
||||
let points = points
|
||||
.into_iter()
|
||||
.map(FloatArray::try_from)
|
||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||
|
||||
let values = values
|
||||
.into_iter()
|
||||
.map(MapValue::try_from)
|
||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||
|
||||
let hsnw_map = instant_distance::Builder::from(config).build(points, values);
|
||||
Ok(Self { inner: hsnw_map })
|
||||
let builder = instant_distance::Builder::from(config);
|
||||
let inner = match config.distance_metric {
|
||||
DistanceMetric::Euclid => {
|
||||
let points = FloatArray::try_from_pylist(points)?;
|
||||
HnswMapWithMetric::Euclid(builder.build(points, values))
|
||||
}
|
||||
DistanceMetric::Cosine => {
|
||||
let points = FloatArray::try_from_pylist(points)?;
|
||||
HnswMapWithMetric::Cosine(builder.build(points, values))
|
||||
}
|
||||
};
|
||||
Ok(Self { inner })
|
||||
}
|
||||
|
||||
/// Load an index from the given file name
|
||||
#[staticmethod]
|
||||
fn load(fname: &str) -> PyResult<Self> {
|
||||
let hnsw_map =
|
||||
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, MapValue>>(
|
||||
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
|
||||
)
|
||||
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
|
||||
let hnsw_map = bincode::deserialize_from::<_, HnswMapWithMetric>(BufReader::with_capacity(
|
||||
32 * 1024 * 1024,
|
||||
File::open(fname)?,
|
||||
))
|
||||
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
|
||||
Ok(Self { inner: hnsw_map })
|
||||
}
|
||||
|
||||
|
@ -78,8 +116,10 @@ impl HnswMap {
|
|||
///
|
||||
/// For best performance, reusing `Search` objects is recommended.
|
||||
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> {
|
||||
let point = FloatArray::try_from(point)?;
|
||||
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
|
||||
impl_for_each_hnsw_with_metric!(HnswMapWithMetric, &slf.try_borrow(py)?.inner, hnsw, {
|
||||
let point = FloatArray::try_from(point)?;
|
||||
let _ = hnsw.search(&point, &mut search.inner);
|
||||
});
|
||||
search.cur = Some((HnswType::Map(slf.clone_ref(py)), 0));
|
||||
Ok(())
|
||||
}
|
||||
|
@ -91,7 +131,13 @@ impl HnswMap {
|
|||
/// with a squared Euclidean distance metric.
|
||||
#[pyclass]
|
||||
struct Hnsw {
|
||||
inner: instant_distance::Hnsw<FloatArray>,
|
||||
inner: HnswWithMetric,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
enum HnswWithMetric {
|
||||
Euclid(instant_distance::Hnsw<FloatArray>),
|
||||
Cosine(instant_distance::Hnsw<FloatArray>),
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
|
@ -99,12 +145,19 @@ impl Hnsw {
|
|||
/// Build the index
|
||||
#[staticmethod]
|
||||
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<u32>)> {
|
||||
let points = input
|
||||
.into_iter()
|
||||
.map(FloatArray::try_from)
|
||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||
|
||||
let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points);
|
||||
let builder = instant_distance::Builder::from(config);
|
||||
let (inner, ids) = match config.distance_metric {
|
||||
DistanceMetric::Euclid => {
|
||||
let points = FloatArray::try_from_pylist(input)?;
|
||||
let (hnsw, ids) = builder.build_hnsw(points);
|
||||
(HnswWithMetric::Euclid(hnsw), ids)
|
||||
}
|
||||
DistanceMetric::Cosine => {
|
||||
let points = FloatArray::try_from_pylist(input)?;
|
||||
let (hnsw, ids) = builder.build_hnsw(points);
|
||||
(HnswWithMetric::Cosine(hnsw), ids)
|
||||
}
|
||||
};
|
||||
let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner()));
|
||||
Ok((Self { inner }, ids))
|
||||
}
|
||||
|
@ -112,9 +165,10 @@ impl Hnsw {
|
|||
/// Load an index from the given file name
|
||||
#[staticmethod]
|
||||
fn load(fname: &str) -> PyResult<Self> {
|
||||
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<FloatArray>>(
|
||||
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
|
||||
)
|
||||
let hnsw = bincode::deserialize_from::<_, HnswWithMetric>(BufReader::with_capacity(
|
||||
32 * 1024 * 1024,
|
||||
File::open(fname)?,
|
||||
))
|
||||
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
|
||||
Ok(Self { inner: hnsw })
|
||||
}
|
||||
|
@ -135,8 +189,10 @@ impl Hnsw {
|
|||
///
|
||||
/// For best performance, reusing `Search` objects is recommended.
|
||||
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> {
|
||||
let point = FloatArray::try_from(point)?;
|
||||
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
|
||||
impl_for_each_hnsw_with_metric!(HnswWithMetric, &slf.try_borrow(py)?.inner, hnsw, {
|
||||
let point = FloatArray::try_from(point)?;
|
||||
let _ = hnsw.search(&point, &mut search.inner);
|
||||
});
|
||||
search.cur = Some((HnswType::Hnsw(slf.clone_ref(py)), 0));
|
||||
Ok(())
|
||||
}
|
||||
|
@ -175,20 +231,24 @@ impl Search {
|
|||
let neighbor = match &index {
|
||||
HnswType::Hnsw(hnsw) => {
|
||||
let hnsw = hnsw.as_ref(py).borrow();
|
||||
let item = hnsw.inner.get(idx, &slf.inner);
|
||||
item.map(|item| Neighbor {
|
||||
distance: item.distance,
|
||||
pid: item.pid.into_inner(),
|
||||
value: py.None(),
|
||||
impl_for_each_hnsw_with_metric!(HnswWithMetric, &hnsw.inner, hnsw, {
|
||||
let item = hnsw.get(idx, &slf.inner);
|
||||
item.map(|item| Neighbor {
|
||||
distance: item.distance,
|
||||
pid: item.pid.into_inner(),
|
||||
value: py.None(),
|
||||
})
|
||||
})
|
||||
}
|
||||
HnswType::Map(map) => {
|
||||
let map = map.as_ref(py).borrow();
|
||||
let item = map.inner.get(idx, &slf.inner);
|
||||
item.map(|item| Neighbor {
|
||||
distance: item.distance,
|
||||
pid: item.pid.into_inner(),
|
||||
value: item.value.into_py(py),
|
||||
impl_for_each_hnsw_with_metric!(HnswMapWithMetric, &map.inner, map, {
|
||||
let item = map.get(idx, &slf.inner);
|
||||
item.map(|item| Neighbor {
|
||||
distance: item.distance,
|
||||
pid: item.pid.into_inner(),
|
||||
value: item.value.into_py(py),
|
||||
})
|
||||
})
|
||||
}
|
||||
};
|
||||
|
@ -226,6 +286,11 @@ struct Config {
|
|||
/// in order to get better results on clustered data points.
|
||||
#[pyo3(get, set)]
|
||||
heuristic: Option<Heuristic>,
|
||||
/// Distance metric to use
|
||||
///
|
||||
/// Defaults to Euclidean distance
|
||||
#[pyo3(get, set)]
|
||||
distance_metric: DistanceMetric,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
|
@ -235,12 +300,14 @@ impl Config {
|
|||
let builder = instant_distance::Builder::default();
|
||||
let (ef_search, ef_construction, ml, seed) = builder.into_parts();
|
||||
let heuristic = Some(Heuristic::default());
|
||||
let distance_metric = DistanceMetric::default();
|
||||
Self {
|
||||
ef_search,
|
||||
ef_construction,
|
||||
ml,
|
||||
seed,
|
||||
heuristic,
|
||||
distance_metric,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -253,6 +320,7 @@ impl From<&Config> for instant_distance::Builder {
|
|||
ml,
|
||||
seed,
|
||||
heuristic,
|
||||
distance_metric: _,
|
||||
} = *py;
|
||||
Self::default()
|
||||
.ef_search(ef_search)
|
||||
|
@ -350,6 +418,12 @@ impl Neighbor {
|
|||
#[derive(Clone, Deserialize, Serialize)]
|
||||
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);
|
||||
|
||||
impl FloatArray {
|
||||
fn try_from_pylist(list: &PyList) -> Result<Vec<Self>, PyErr> {
|
||||
list.into_iter().map(FloatArray::try_from).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&PyAny> for FloatArray {
|
||||
type Error = PyErr;
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import instant_distance, random
|
||||
|
||||
|
||||
def test_hsnw():
|
||||
def test_hsnw(distance_metric=instant_distance.DistanceMetric.Euclid):
|
||||
points = [[random.random() for _ in range(300)] for _ in range(1024)]
|
||||
config = instant_distance.Config()
|
||||
config.distance_metric = distance_metric
|
||||
(hnsw, ids) = instant_distance.Hnsw.build(points, config)
|
||||
p = [random.random() for _ in range(300)]
|
||||
search = instant_distance.Search()
|
||||
|
@ -12,7 +13,7 @@ def test_hsnw():
|
|||
print(candidate)
|
||||
|
||||
|
||||
def test_hsnw_map():
|
||||
def test_hsnw_map(distance_metric=instant_distance.DistanceMetric.Euclid):
|
||||
the_chosen_one = 123
|
||||
|
||||
embeddings = [[random.random() for _ in range(300)] for _ in range(1024)]
|
||||
|
@ -20,6 +21,7 @@ def test_hsnw_map():
|
|||
values = f.read().splitlines()[1024:]
|
||||
|
||||
config = instant_distance.Config()
|
||||
config.distance_metric = distance_metric
|
||||
hnsw_map = instant_distance.HnswMap.build(embeddings, values, config)
|
||||
|
||||
search = instant_distance.Search()
|
||||
|
@ -38,3 +40,5 @@ def test_hsnw_map():
|
|||
if __name__ == "__main__":
|
||||
test_hsnw()
|
||||
test_hsnw_map()
|
||||
test_hsnw(instant_distance.DistanceMetric.Cosine)
|
||||
test_hsnw_map(instant_distance.DistanceMetric.Cosine)
|
||||
|
|
Loading…
Reference in New Issue