instant-distance/instant-distance-py/examples/translations/translate.py

154 lines
5.6 KiB
Python
Raw Normal View History

2021-05-29 16:04:36 +00:00
import asyncio
import json
import os
from string import digits, punctuation, whitespace
2021-05-29 16:04:36 +00:00
import sys
import aiohttp
import instant_distance
from progress.bar import IncrementalBar
# Number of lines read from each downloaded vector file before stopping.
MAX_LINES = 100_000

# Language codes to download; "en" is the source language of translations.
LANGS = ("en", "fr", "it")

# Placeholder inside DL_TEMPLATE that gets swapped for a language code.
LANG_REPLACE = "$$lang"
DL_TEMPLATE = (
    "https://dl.fbaipublicfiles.com/fasttext/vectors-aligned/wiki."
    + LANG_REPLACE
    + ".align.vec"
)

# Output files: the instant-distance index and the english word -> embedding map.
BUILT_IDX_PATH = "./data/" + "_".join(LANGS) + ".idx"
WORD_MAP_PATH = "./data/" + "_".join(LANGS) + ".json"
async def download_build_index():
    """
    Download aligned fastText word vectors and build the instant-distance index.

    This function downloads pre-trained word vectors trained on Wikipedia using fastText:
    https://fasttext.cc/docs/en/aligned-vectors.html

    The content is streamed and only the first MAX_LINES lines of each file are read,
    dropping the long tail of less common words. Each line is parsed as it arrives:
    english words are stored in a word -> embedding mapping (used later to convert the
    translation input to an embedding), while words of the other languages are
    collected as points for the instant-distance index.

    Side effects: writes the index to BUILT_IDX_PATH and the english mapping (as json)
    to WORD_MAP_PATH, creating the parent directory when it does not exist.

    Raises:
        aiohttp.ClientResponseError: if a download answers with an HTTP error status.
    """
    points = []
    values = []
    word_map = {}

    # The output directory does not depend on the language, so create it once.
    os.makedirs(os.path.dirname(BUILT_IDX_PATH), exist_ok=True)

    print("Downloading vector files and building indexes...")
    async with aiohttp.ClientSession() as session:
        for lang in LANGS:
            # Construct a url for each language
            url = DL_TEMPLATE.replace(LANG_REPLACE, lang)

            lineno = 0
            with IncrementalBar(
                f"Downloading {url.split('/')[-1]}", max=MAX_LINES
            ) as bar:
                async with session.get(url) as resp:
                    # Fail loudly on an HTTP error instead of trying to parse
                    # an error page as word vectors.
                    resp.raise_for_status()

                    while True:
                        lineno += 1
                        line = await resp.content.readline()
                        if not line:
                            # EOF
                            break

                        # We just use the top 100k embeddings to
                        # save on space and time
                        if lineno > MAX_LINES:
                            break

                        # Every line that was read advances the bar exactly once,
                        # whether or not it ends up being indexed.
                        bar.next()

                        # Strip the line terminator (and any trailing space) so the
                        # last token is always a clean number.
                        tokens = line.decode("utf-8").rstrip().split(" ")

                        # Skip blank lines and the `.vec` header line, which holds
                        # "<word count> <dimension>" rather than a word vector.
                        if len(tokens) < 2 or (lineno == 1 and len(tokens) == 2):
                            continue

                        # The first token is the word and the rest
                        # are the embedding
                        value = tokens[0]
                        embedding = [float(p) for p in tokens[1:]]

                        # We only go from english to the other two languages
                        if lang == "en":
                            word_map[value] = embedding
                            continue

                        # Don't index words that exist in english or
                        # that include non-letter characters to improve
                        # the quality of the results.
                        if (
                            value in word_map
                            or any(p in value for p in punctuation)
                            or any(p in value for p in whitespace)
                            or any(p in value for p in digits)
                        ):
                            continue

                        # We track values here to build the instant-distance index
                        # Every value is prepended with 2 character language code.
                        # This allows us to determine language output later.
                        values.append(lang + value)
                        points.append(embedding)

    # Build the instant-distance index and dump it out to a file with .idx suffix
    print("Building index... (this will take a while)")
    hnsw = instant_distance.HnswMap.build(points, values, instant_distance.Config())
    hnsw.dump(BUILT_IDX_PATH)

    # Store the mapping from string to embedding in a .json file
    with open(WORD_MAP_PATH, "w", encoding="utf-8") as f:
        json.dump(word_map, f)
async def translate(word):
    """
    Print the nearest non-english neighbors of an english word.

    This function relies on the index built in the `download_build_index` function.
    If the data does not yet exist, it will download and build the index.

    The input is expected to be english. A word is first mapped onto an embedding
    from the mapping stored as json. Then we use instant-distance to find the approximate
    nearest neighbors to that point (embedding) in order to translate to other languages.
    At most the first 10 search results are printed.

    Exits the process with status 1 when the word has no entry in the mapping.
    """
    data_exists = os.path.isfile(BUILT_IDX_PATH) and os.path.isfile(WORD_MAP_PATH)
    if not data_exists:
        print("Instant Distance index not present. Building...")
        await download_build_index()

    print("Loading indexes from filesystem...")
    with open(WORD_MAP_PATH, "r", encoding="utf-8") as f:
        word_map = json.load(f)

    # Get an embedding for the given word
    embedding = word_map.get(word)
    if not embedding:
        print(f"Word not recognized: {word}")
        # sys.exit is always available; the bare `exit` helper is only injected
        # by the `site` module and is meant for interactive use.
        sys.exit(1)

    hnsw = instant_distance.HnswMap.load(BUILT_IDX_PATH)
    search = instant_distance.Search()
    hnsw.search(embedding, search)

    # Print the results
    for result in list(search)[:10]:
        # We know that the first two characters of the value is the language code
        # from when we built the index.
        print(f"Language: {result.value[:2]}, Translation: {result.value[2:]}")
async def main():
    """
    Command line entry point.

    Supported invocations (arguments are read from `sys.argv`):
        build                     download the vectors and build the index
        translate <english word>  print translations of the given word

    Exits with status 0 after a successful command. Any other invocation
    (no command, unknown command, or `translate` without a word) prints the
    usage text and exits with status 1.
    """
    args = sys.argv[1:]
    command = args[0] if args else None

    # Arguments are validated up front rather than by catching IndexError around
    # the awaited calls, so an IndexError raised while building or translating
    # is not mistaken for a usage error.
    if command == "build":
        await download_build_index()
        sys.exit(0)
    if command == "translate" and len(args) > 1:
        await translate(args[1])
        sys.exit(0)

    # The usage text names the command that is actually accepted ("build").
    print(f"usage:\t{sys.argv[0]} build\n\t{sys.argv[0]} translate <english word>")
    sys.exit(1)
# Script entry point. asyncio.run() creates and closes its own event loop, which
# replaces the deprecated get_event_loop()/run_until_complete() pairing. The guard
# keeps the CLI from running (and exiting the process) when the module is imported.
if __name__ == "__main__":
    asyncio.run(main())