First cut of Python bindings update
This commit is contained in:
parent
99a473c298
commit
1d27340883
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"python.formatting.provider": "black"
|
||||
}
|
|
@ -19,6 +19,7 @@ fn instant_distance(_: Python, m: &PyModule) -> PyResult<()> {
|
|||
m.add_class::<Config>()?;
|
||||
m.add_class::<Search>()?;
|
||||
m.add_class::<Hnsw>()?;
|
||||
m.add_class::<HnswMap>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -79,6 +80,63 @@ impl Hnsw {
|
|||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct HnswMap {
|
||||
inner: instant_distance::HnswMap<FloatArray, String>,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl HnswMap {
|
||||
#[getter]
|
||||
fn values(&self) -> PyResult<Vec<String>> {
|
||||
Ok(self.inner.values.clone())
|
||||
}
|
||||
/// Build the index
|
||||
#[staticmethod]
|
||||
fn build(points: &PyList, values: Vec<String>, config: &Config) -> PyResult<Self> {
|
||||
let points = points
|
||||
.into_iter()
|
||||
.map(FloatArray::try_from)
|
||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||
|
||||
let hsnw_map = instant_distance::Builder::from(config).build(points, values);
|
||||
Ok(Self { inner: hsnw_map })
|
||||
}
|
||||
|
||||
/// Load an index from the given file name
|
||||
#[staticmethod]
|
||||
fn load(fname: &str) -> PyResult<Self> {
|
||||
let hnsw_map =
|
||||
bincode::deserialize_from::<_, instant_distance::HnswMap<FloatArray, String>>(
|
||||
BufReader::with_capacity(32 * 1024 * 1024, File::open(fname)?),
|
||||
)
|
||||
.map_err(|e| PyValueError::new_err(format!("deserialization error: {:?}", e)))?;
|
||||
Ok(Self { inner: hnsw_map })
|
||||
}
|
||||
|
||||
/// Dump the index to the given file name
|
||||
fn dump(&self, fname: &str) -> PyResult<()> {
|
||||
let f = BufWriter::with_capacity(32 * 1024 * 1024, File::create(fname)?);
|
||||
bincode::serialize_into(f, &self.inner)
|
||||
.map_err(|e| PyValueError::new_err(format!("serialization error: {:?}", e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Search the index for points neighboring the given point
|
||||
///
|
||||
/// The `search` object contains buffers used for searching. When the search completes,
|
||||
/// iterate over the `Search` to get the results. The number of results should be equal
|
||||
/// to the `ef_search` parameter set in the index's `config`.
|
||||
///
|
||||
/// For best performance, reusing `Search` objects is recommended.
|
||||
fn search(&self, point: &PyAny, search: &mut Search) -> PyResult<()> {
|
||||
let point = FloatArray::try_from(point)?;
|
||||
let _ = self.inner.search(&point, &mut search.inner);
|
||||
search.cur = Some(0);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Search buffer and result set
|
||||
#[pyclass]
|
||||
struct Search {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import instant_distance, random
|
||||
|
||||
def main():
|
||||
|
||||
def test_hsnw():
|
||||
points = [[random.random() for _ in range(300)] for _ in range(1024)]
|
||||
config = instant_distance.Config()
|
||||
(hnsw, ids) = instant_distance.Hnsw.build(points, config)
|
||||
|
@ -10,5 +11,31 @@ def main():
|
|||
for candidate in search:
|
||||
print(candidate)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
def test_hsnw_map():
    """Build an `HnswMap` from random embeddings and check a self-lookup.

    Each embedding is paired with one word from the system dictionary. Searching
    for one of the indexed embeddings must report, as its closest match, the
    word that was paired with that embedding.
    """
    the_chosen_one = 123

    embeddings = [[random.random() for _ in range(300)] for _ in range(1024)]
    with open("/usr/share/dict/words", "r") as f:  # *nix only
        # Take exactly one word per embedding, so the values line up with the
        # points instead of passing the rest of the dictionary.
        values = f.read().splitlines()[: len(embeddings)]
    assert len(values) == len(embeddings), "word list has fewer entries than embeddings"

    config = instant_distance.Config()
    hnsw_map = instant_distance.HnswMap.build(embeddings, values, config)

    search = instant_distance.Search()
    hnsw_map.search(embeddings[the_chosen_one], search)

    # The first candidate yielded by the search is treated as the nearest one.
    closest_pid = list(search)[0].pid
    approx_nearest = hnsw_map.values[closest_pid]
    actual_word = values[the_chosen_one]

    print("pid:\t\t", closest_pid)
    print("approx word:\t", approx_nearest)
    print("actual word:\t", actual_word)

    assert approx_nearest == actual_word
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run both smoke tests, in order, when the file is executed as a script.
    for run_test in (test_hsnw, test_hsnw_map):
        run_test()
|
||||
|
|
|
@ -130,7 +130,7 @@ impl Default for Heuristic {
|
|||
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
|
||||
pub struct HnswMap<P, V> {
|
||||
hnsw: Hnsw<P>,
|
||||
values: Vec<V>,
|
||||
pub values: Vec<V>,
|
||||
}
|
||||
|
||||
impl<P, V> HnswMap<P, V>
|
||||
|
|
Loading…
Reference in New Issue