py: reverse order of HnswMap and Hnsw
This commit is contained in:
parent
3d9e0a4b3d
commit
07d2f1aedc
|
@ -23,63 +23,6 @@ fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An instance of hierarchical navigable small worlds
|
|
||||||
///
|
|
||||||
/// For now, this is specialized to only support 300-element (32-bit) float vectors
|
|
||||||
/// with a squared Euclidean distance metric.
|
|
||||||
#[pyclass]
|
|
||||||
struct Hnsw {
|
|
||||||
inner: instant_distance::Hnsw<FloatArray>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pymethods]
|
|
||||||
impl Hnsw {
|
|
||||||
/// Build the index
|
|
||||||
#[staticmethod]
|
|
||||||
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<u32>)> {
|
|
||||||
let points = input
|
|
||||||
.into_iter()
|
|
||||||
.map(FloatArray::try_from)
|
|
||||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
|
||||||
|
|
||||||
let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points);
|
|
||||||
let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner()));
|
|
||||||
Ok((Self { inner }, ids))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Load an index from the given file name
|
|
||||||
#[staticmethod]
|
|
||||||
fn load(fname: &str) -> PyResult<Self> {
|
|
||||||
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<FloatArray>>(
|
|
||||||
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
|
|
||||||
)
|
|
||||||
.map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?;
|
|
||||||
Ok(Self { inner: hnsw })
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Dump the index to the given file name
|
|
||||||
fn dump(&self, fname: &str) -> PyResult<()> {
|
|
||||||
let f = BufWriter::with_capacity(32 * 1024 * 1024, File::create(fname)?);
|
|
||||||
bincode::serialize_into(f, &self.inner)
|
|
||||||
.map_err(|e| PyValueError::new_err(format!("serialization error: {:?}", e)))?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Search the index for points neighboring the given point
|
|
||||||
///
|
|
||||||
/// The `search` object contains buffers used for searching. When the search completes,
|
|
||||||
/// iterate over the `Search` to get the results. The number of results should be equal
|
|
||||||
/// to the `ef_search` parameter set in the index's `config`.
|
|
||||||
///
|
|
||||||
/// For best performance, reusing `Search` objects is recommended.
|
|
||||||
fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> {
|
|
||||||
let point = FloatArray::try_from(point)?;
|
|
||||||
let _ = self.inner.search(&point, &mut search.inner);
|
|
||||||
search.cur = Some(0);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
struct HnswMap {
|
struct HnswMap {
|
||||||
inner: instant_distance::HnswMap<FloatArray, String>,
|
inner: instant_distance::HnswMap<FloatArray, String>,
|
||||||
|
@ -137,6 +80,63 @@ impl HnswMap {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// An instance of hierarchical navigable small worlds
|
||||||
|
///
|
||||||
|
/// For now, this is specialized to only support 300-element (32-bit) float vectors
|
||||||
|
/// with a squared Euclidean distance metric.
|
||||||
|
#[pyclass]
|
||||||
|
struct Hnsw {
|
||||||
|
inner: instant_distance::Hnsw<FloatArray>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pymethods]
|
||||||
|
impl Hnsw {
|
||||||
|
/// Build the index
|
||||||
|
#[staticmethod]
|
||||||
|
fn build(input: &PyList, config: &Config) -> PyResult<(Self, Vec<u32>)> {
|
||||||
|
let points = input
|
||||||
|
.into_iter()
|
||||||
|
.map(FloatArray::try_from)
|
||||||
|
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||||
|
|
||||||
|
let (inner, ids) = instant_distance::Builder::from(config).build_hnsw(points);
|
||||||
|
let ids = Vec::from_iter(ids.into_iter().map(|pid| pid.into_inner()));
|
||||||
|
Ok((Self { inner }, ids))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load an index from the given file name
|
||||||
|
#[staticmethod]
|
||||||
|
fn load(fname: &str) -> PyResult<Self> {
|
||||||
|
let hnsw = bincode::deserialize_from::<_, instant_distance::Hnsw<FloatArray>>(
|
||||||
|
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
|
||||||
|
)
|
||||||
|
.map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?;
|
||||||
|
Ok(Self { inner: hnsw })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Dump the index to the given file name
|
||||||
|
fn dump(&self, fname: &str) -> PyResult<()> {
|
||||||
|
let f = BufWriter::with_capacity(32 * 1024 * 1024, File::create(fname)?);
|
||||||
|
bincode::serialize_into(f, &self.inner)
|
||||||
|
.map_err(|e| PyValueError::new_err(format!("serialization error: {:?}", e)))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Search the index for points neighboring the given point
|
||||||
|
///
|
||||||
|
/// The `search` object contains buffers used for searching. When the search completes,
|
||||||
|
/// iterate over the `Search` to get the results. The number of results should be equal
|
||||||
|
/// to the `ef_search` parameter set in the index's `config`.
|
||||||
|
///
|
||||||
|
/// For best performance, reusing `Search` objects is recommended.
|
||||||
|
fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> {
|
||||||
|
let point = FloatArray::try_from(point)?;
|
||||||
|
let _ = self.inner.search(&point, &mut search.inner);
|
||||||
|
search.cur = Some(0);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Search buffer and result set
|
/// Search buffer and result set
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
struct Search {
|
struct Search {
|
||||||
|
|
Loading…
Reference in New Issue