Refactor to prepare for multiple distance metrics

This commit is contained in:
Kuba Jaroszewski 2023-02-16 21:37:35 +01:00
parent f1cb9ee234
commit bca31ad33f
3 changed files with 127 additions and 43 deletions

View File

@ -1,8 +1,14 @@
test-python:
cargo build --release
cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so
# Build the release library with RUSTFLAGS set to target the native CPU, then copy the
# artifact next to the Python tests as instant_distance.so. The copy takes the .dylib if
# it exists and otherwise the .so.
instant-distance-py/test/instant_distance.so: instant-distance-py/src/lib.rs
RUSTFLAGS="-C target-cpu=native" cargo build --release
([ -f target/release/libinstant_distance.dylib ] && cp target/release/libinstant_distance.dylib instant-distance-py/test/instant_distance.so) || \
([ -f target/release/libinstant_distance.so ] && cp target/release/libinstant_distance.so instant-distance-py/test/instant_distance.so)
# Run the Python test module with the copied extension on PYTHONPATH.
# NOTE(review): no .PHONY declaration for test-python/bench-python/clean is visible in
# this part of the file — confirm whether one exists elsewhere.
test-python: instant-distance-py/test/instant_distance.so
PYTHONPATH=instant-distance-py/test/ python3 -m test
# Time Hnsw.build on 1024 random points of 300 values each (timeit -n 10).
bench-python: instant-distance-py/test/instant_distance.so
PYTHONPATH=instant-distance-py/test/ python3 -m timeit -n 10 -s 'import random, instant_distance; points = [[random.random() for _ in range(300)] for _ in range(1024)]; config = instant_distance.Config()' 'instant_distance.Hnsw.build(points, config)'
# Remove cargo's build output and the copied extension module.
clean:
cargo clean
rm -f instant-distance-py/test/instant_distance.so

View File

@ -24,12 +24,46 @@ fn instant_distance_py(_: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Search>()?;
m.add_class::<Hnsw>()?;
m.add_class::<HnswMap>()?;
m.add_class::<DistanceMetric>()?;
Ok(())
}
/// Distance metric selected through `Config.distance_metric`
///
/// Defaults to `Euclid`.
// NOTE(review): in the code visible alongside this enum, both variants build their index
// from the same `FloatArray` point type, so choosing `Cosine` does not appear to change
// the distance computation yet — confirm before relying on it.
#[pyclass]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum DistanceMetric {
    /// Euclidean distance (the default)
    Euclid,
    /// Cosine distance
    Cosine,
}

impl Default for DistanceMetric {
    /// Returns `Euclid`, the metric used when none is set explicitly.
    fn default() -> Self {
        Self::Euclid
    }
}
// Dispatch helper for the metric-tagged index enums.
//
// Arguments: the enum name, an expression evaluating to (a reference to) that enum, the
// identifier to bind the wrapped index to, and the tokens to run against it. The tokens
// are pasted once per variant, so each copy is type-checked against that variant's own
// payload type; the arms are deliberately not merged into a single or-pattern.
macro_rules! impl_for_each_hnsw_with_metric {
    ($wrapper:ident, $value:expr, $bound:ident, $($body:tt)+) => {
        match $value {
            $wrapper::Euclid($bound) => { $($body)+ }
            $wrapper::Cosine($bound) => { $($body)+ }
        }
    };
}
#[pyclass]
struct HnswMap {
// NOTE(review): two `inner` declarations appear here. The `HnswMapWithMetric` one matches
// what `build` and `load` put into `Self { inner }` — confirm which one is current.
inner: instant_distance::HnswMap<FloatArray, MapValue>,
// Index tagged with the metric it was built for.
inner: HnswMapWithMetric,
}
/// A point-to-value map index tagged with the distance metric it was built for.
///
/// Both variants currently wrap the same `instant_distance::HnswMap<FloatArray, MapValue>`
/// type. The enum is what `HnswMap::load` deserializes.
#[derive(Deserialize, Serialize)]
enum HnswMapWithMetric {
/// Built with `DistanceMetric::Euclid`
Euclid(instant_distance::HnswMap<FloatArray, MapValue>),
/// Built with `DistanceMetric::Cosine`
Cosine(instant_distance::HnswMap<FloatArray, MapValue>),
}
#[pymethods]
@ -37,28 +71,32 @@ impl HnswMap {
/// Build the index
#[staticmethod]
fn build(points: &PyList, values: &PyList, config: &Config) -> PyResult<Self> {
// NOTE(review): two construction sequences follow, one ending in
// `Ok(Self { inner: hsnw_map })` and one ending in `Ok(Self { inner })`. This view does
// not mark which is current — confirm before editing either.
let points = points
.into_iter()
.map(FloatArray::try_from)
.collect::<Result<Vec<_>, PyErr>>()?;
// Convert every Python value to a `MapValue`, stopping at the first conversion error.
let values = values
.into_iter()
.map(MapValue::try_from)
.collect::<Result<Vec<_>, PyErr>>()?;
let hsnw_map = instant_distance::Builder::from(config).build(points, values);
Ok(Self { inner: hsnw_map })
// The builder is created once from `config`; only the wrapping variant differs per metric.
let builder = instant_distance::Builder::from(config);
// Each arm converts the points itself and tags the built map with the selected metric.
let inner = match config.distance_metric {
DistanceMetric::Euclid => {
let points = FloatArray::try_from_pylist(points)?;
HnswMapWithMetric::Euclid(builder.build(points, values))
}
DistanceMetric::Cosine => {
let points = FloatArray::try_from_pylist(points)?;
HnswMapWithMetric::Cosine(builder.build(points, values))
}
};
Ok(Self { inner })
}
/// Load an index from the given file name
#[staticmethod]
fn load(fname: &str) -> PyResult<Self> {
// NOTE(review): two deserialization statements follow, both binding `hnsw_map`; they
// differ in the target type. Confirm which one is current.
let hnsw_map =
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, MapValue>>(
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
)
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
// Read through a 32 MiB buffer and decode the metric-tagged enum; decode failures are
// reported as a Python ValueError.
// NOTE(review): a file written as a bare `instant_distance::HnswMap` presumably will not
// decode as `HnswMapWithMetric` — confirm whether older files must stay loadable.
let hnsw_map = bincode::deserialize_from::<_, HnswMapWithMetric>(BufReader::with_capacity(
32 * 1024 * 1024,
File::open(fname)?,
))
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
Ok(Self { inner: hnsw_map })
}
@ -78,8 +116,10 @@ impl HnswMap {
///
/// For best performance, reusing `Search` objects is recommended.
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> {
// NOTE(review): the two statements directly below and the macro invocation after them
// both run the search; this view does not mark which is current.
let point = FloatArray::try_from(point)?;
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
// The macro pastes the block once per metric variant, binding `hnsw` to the wrapped map.
// The value returned by `search` is discarded; results are read back through `search`.
impl_for_each_hnsw_with_metric!(HnswMapWithMetric, &slf.try_borrow(py)?.inner, hnsw, {
let point = FloatArray::try_from(point)?;
let _ = hnsw.search(&point, &mut search.inner);
});
// Record which index the results belong to; the 0 is presumably the starting result
// position — confirm in `Search`.
search.cur = Some((HnswType::Map(slf.clone_ref(py)), 0));
Ok(())
}
@ -91,7 +131,13 @@ impl HnswMap {
/// with a squared Euclidean distance metric.
#[pyclass]
struct Hnsw {
// NOTE(review): two `inner` declarations appear here. The `HnswWithMetric` one matches
// what `build` and `load` put into `Self { inner }` — confirm which one is current.
inner: instant_distance::Hnsw<FloatArray>,
// Index tagged with the metric it was built for.
inner: HnswWithMetric,
}
/// A point index tagged with the distance metric it was built for.
///
/// Both variants currently wrap the same `instant_distance::Hnsw<FloatArray>` type.
/// The enum is what `Hnsw::load` deserializes.
#[derive(Deserialize, Serialize)]
enum HnswWithMetric {
/// Built with `DistanceMetric::Euclid`
Euclid(instant_distance::Hnsw<FloatArray>),
/// Built with `DistanceMetric::Cosine`
Cosine(instant_distance::Hnsw<FloatArray>),
}
#[pymethods]
@ -99,12 +145,19 @@ impl Hnsw {
/// Build the index
#[staticmethod]
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<u32>)> {
// NOTE(review): two construction sequences follow, both binding `(inner, ids)`. This view
// does not mark which is current — confirm before editing either.
let points = input
.into_iter()
.map(FloatArray::try_from)
.collect::<Result<Vec<_>, PyErr>>()?;
let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points);
// The builder is created once from `config`; only the wrapping variant differs per metric.
let builder = instant_distance::Builder::from(config);
// Each arm converts the points itself and tags the built index with the selected metric.
let (inner, ids) = match config.distance_metric {
DistanceMetric::Euclid => {
let points = FloatArray::try_from_pylist(input)?;
let (hnsw, ids) = builder.build_hnsw(points);
(HnswWithMetric::Euclid(hnsw), ids)
}
DistanceMetric::Cosine => {
let points = FloatArray::try_from_pylist(input)?;
let (hnsw, ids) = builder.build_hnsw(points);
(HnswWithMetric::Cosine(hnsw), ids)
}
};
// Unwrap each returned point id with `into_inner()` to form the `Vec<u32>` returned
// alongside the index.
let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner()));
Ok((Self { inner }, ids))
}
@ -112,9 +165,10 @@ impl Hnsw {
/// Load an index from the given file name
#[staticmethod]
fn load(fname: &str) -> PyResult<Self> {
// NOTE(review): two `deserialize_from` openings follow (one targeting
// `instant_distance::Hnsw<FloatArray>`, one targeting `HnswWithMetric`) sharing a single
// `.map_err` tail; confirm which opening is current.
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<FloatArray>>(
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
)
// Read through a 32 MiB buffer and decode the metric-tagged enum.
// NOTE(review): a file written as a bare `instant_distance::Hnsw` presumably will not
// decode as `HnswWithMetric` — confirm whether older files must stay loadable.
let hnsw = bincode::deserialize_from::<_, HnswWithMetric>(BufReader::with_capacity(
32 * 1024 * 1024,
File::open(fname)?,
))
// Decode failures are reported as a Python ValueError.
.map_err(|e| PyValueError::new_err(format!("deserialization error: {e:?}")))?;
Ok(Self { inner: hnsw })
}
@ -135,8 +189,10 @@ impl Hnsw {
///
/// For best performance, reusing `Search` objects is recommended.
fn search(slf: Py<Self>, point: &PyAny, search: &mut Search, py: Python<'_>) -> PyResult<()> {
// NOTE(review): the two statements directly below and the macro invocation after them
// both run the search; this view does not mark which is current.
let point = FloatArray::try_from(point)?;
let _ = slf.try_borrow(py)?.inner.search(&point, &mut search.inner);
// The macro pastes the block once per metric variant, binding `hnsw` to the wrapped index.
// The value returned by `search` is discarded; results are read back through `search`.
impl_for_each_hnsw_with_metric!(HnswWithMetric, &slf.try_borrow(py)?.inner, hnsw, {
let point = FloatArray::try_from(point)?;
let _ = hnsw.search(&point, &mut search.inner);
});
// Record which index the results belong to; the 0 is presumably the starting result
// position — confirm in `Search`.
search.cur = Some((HnswType::Hnsw(slf.clone_ref(py)), 0));
Ok(())
}
@ -175,20 +231,24 @@ impl Search {
let neighbor = match &index {
HnswType::Hnsw(hnsw) => {
let hnsw = hnsw.as_ref(py).borrow();
let item = hnsw.inner.get(idx, &slf.inner);
item.map(|item| Neighbor {
distance: item.distance,
pid: item.pid.into_inner(),
value: py.None(),
impl_for_each_hnsw_with_metric!(HnswWithMetric, &hnsw.inner, hnsw, {
let item = hnsw.get(idx, &slf.inner);
item.map(|item| Neighbor {
distance: item.distance,
pid: item.pid.into_inner(),
value: py.None(),
})
})
}
HnswType::Map(map) => {
let map = map.as_ref(py).borrow();
let item = map.inner.get(idx, &slf.inner);
item.map(|item| Neighbor {
distance: item.distance,
pid: item.pid.into_inner(),
value: item.value.into_py(py),
impl_for_each_hnsw_with_metric!(HnswMapWithMetric, &map.inner, map, {
let item = map.get(idx, &slf.inner);
item.map(|item| Neighbor {
distance: item.distance,
pid: item.pid.into_inner(),
value: item.value.into_py(py),
})
})
}
};
@ -226,6 +286,11 @@ struct Config {
/// in order to get better results on clustered data points.
#[pyo3(get, set)]
heuristic: Option<Heuristic>,
/// Distance metric to use
///
/// Defaults to Euclidean distance
#[pyo3(get, set)]
distance_metric: DistanceMetric,
}
#[pymethods]
@ -235,12 +300,14 @@ impl Config {
let builder = instant_distance::Builder::default();
let (ef_search, ef_construction, ml, seed) = builder.into_parts();
let heuristic = Some(Heuristic::default());
let distance_metric = DistanceMetric::default();
Self {
ef_search,
ef_construction,
ml,
seed,
heuristic,
distance_metric,
}
}
}
@ -253,6 +320,7 @@ impl From<&Config> for instant_distance::Builder {
ml,
seed,
heuristic,
distance_metric: _,
} = *py;
Self::default()
.ef_search(ef_search)
@ -350,6 +418,12 @@ impl Neighbor {
// Point type stored in the indexes: a fixed-length array of `DIMENSIONS` f32 values.
// The array field is (de)serialized through `BigArray`.
#[derive(Clone, Deserialize, Serialize)]
struct FloatArray(#[serde(with = "BigArray")] [f32; DIMENSIONS]);
impl FloatArray {
fn try_from_pylist(list: &PyList) -> Result<Vec<Self>, PyErr> {
list.into_iter().map(FloatArray::try_from).collect()
}
}
impl TryFrom<&PyAny> for FloatArray {
type Error = PyErr;

View File

@ -1,9 +1,10 @@
import instant_distance, random
def test_hsnw():
def test_hsnw(distance_metric=instant_distance.DistanceMetric.Euclid):
points = [[random.random() for _ in range(300)] for _ in range(1024)]
config = instant_distance.Config()
config.distance_metric = distance_metric
(hnsw, ids) = instant_distance.Hnsw.build(points, config)
p = [random.random() for _ in range(300)]
search = instant_distance.Search()
@ -12,7 +13,7 @@ def test_hsnw():
print(candidate)
def test_hsnw_map():
def test_hsnw_map(distance_metric=instant_distance.DistanceMetric.Euclid):
the_chosen_one = 123
embeddings = [[random.random() for _ in range(300)] for _ in range(1024)]
@ -20,6 +21,7 @@ def test_hsnw_map():
values = f.read().splitlines()[1024:]
config = instant_distance.Config()
config.distance_metric = distance_metric
hnsw_map = instant_distance.HnswMap.build(embeddings, values, config)
search = instant_distance.Search()
@ -38,3 +40,5 @@ def test_hsnw_map():
# Run both tests with their default metric argument (DistanceMetric.Euclid), then again
# with DistanceMetric.Cosine.
if __name__ == "__main__":
test_hsnw()
test_hsnw_map()
test_hsnw(instant_distance.DistanceMetric.Cosine)
test_hsnw_map(instant_distance.DistanceMetric.Cosine)