First cut of python bindings update

This commit is contained in:
Nicholas Rempel 2021-05-18 14:08:19 -07:00 committed by Dirkjan Ochtman
parent 99a473c298
commit 3d9e0a4b3d
4 changed files with 92 additions and 4 deletions

3
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,3 @@
{
"python.formatting.provider": "black"
}

View File

@ -19,6 +19,7 @@ fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Config>()?;
m.add_class::<Search>()?;
m.add_class::<Hnsw>()?;
m.add_class::<HnswMap>()?;
Ok(())
}
@ -79,6 +80,63 @@ impl Hnsw {
}
}
#[pyclass]
struct HnswMap {
inner: instant_distance::HnswMap<FloatArray, String>,
}
#[pymethods]
impl HnswMap {
#[getter]
fn values(&self) -> PyResult<Vec<String>> {
Ok(self.inner.values.clone())
}
/// Build the index
#[staticmethod]
fn build(points: &PyList, values: Vec<String>, config: &Config) -> PyResult<Self> {
let points = points
.into_iter()
.map(FloatArray::try_from)
.collect::<Result<Vec<_>, PyErr>>()?;
let hsnw_map = instant_distance::Builder::from(config).build(points, values);
Ok(Self { inner: hsnw_map })
}
/// Load an index from the given file name
#[staticmethod]
fn load(fname: &str) -> PyResult<Self> {
let hnsw_map =
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, String>>(
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
)
.map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?;
Ok(Self { inner: hnsw_map })
}
/// Dump the index to the given file name
fn dump(&self, fname: &str) -> PyResult<()> {
let f = BufWriter::with_capacity(32 * 1024 * 1024, File::create(fname)?);
bincode::serialize_into(f, &self.inner)
.map_err(|e| PyValueError::new_err(format!("serialization error: {:?}", e)))?;
Ok(())
}
/// Search the index for points neighboring the given point
///
/// The `search` object contains buffers used for searching. When the search completes,
/// iterate over the `Search` to get the results. The number of results should be equal
/// to the `ef_search` parameter set in the index's `config`.
///
/// For best performance, reusing `Search` objects is recommended.
fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> {
let point = FloatArray::try_from(point)?;
let _ = self.inner.search(&point, &mut search.inner);
search.cur = Some(0);
Ok(())
}
}
/// Search buffer and result set
#[pyclass]
struct Search {

View File

@ -1,6 +1,7 @@
import instant_distance, random
def main():
def test_hsnw():
points = [[random.random() for _ in range(300)] for _ in range(1024)]
config = instant_distance.Config()
(hnsw, ids) = instant_distance.Hnsw.build(points, config)
@ -10,5 +11,31 @@ def main():
for candidate in search:
print(candidate)
if __name__ == '__main__':
main()
def test_hsnw_map():
the_chosen_one = 123
embeddings = [[random.random() for _ in range(300)] for _ in range(1024)]
with open("/usr/share/dict/words", "r") as f: # *nix only
values = f.read().splitlines()[1024:]
config = instant_distance.Config()
hnsw_map = instant_distance.HnswMap.build(embeddings, values, config)
search = instant_distance.Search()
hnsw_map.search(embeddings[the_chosen_one], search)
closest_pid = list(search)[0].pid
approx_nearest = hnsw_map.values[closest_pid]
actual_word = values[the_chosen_one]
print("pid:\t\t", closest_pid)
print("approx word:\t", approx_nearest)
print("actual word:\t", actual_word)
assert approx_nearest == actual_word
if __name__ == "__main__":
test_hsnw()
test_hsnw_map()

View File

@ -130,7 +130,7 @@ impl Default for Heuristic {
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct HnswMap<P, V> {
hnsw: Hnsw<P>,
values: Vec<V>,
pub values: Vec<V>,
}
impl<P, V> HnswMap<P, V>