|
import gradio as gr |
|
import yaml |
|
import torch |
|
from mmtafrica import load_params, translate |
|
from huggingface_hub import hf_hub_download |
|
|
|
# UI display name -> language code handed to mmtafrica.translate
# (see get_translation, which passes these codes as source_lang/target_lang).
language_map = {
    'English': 'en',
    'Swahili': 'sw',
    'Fon': 'fon',
    'Igbo': 'ig',
    'Kinyarwanda': 'rw',
    'Xhosa': 'xh',
    'Yoruba': 'yo',
    'French': 'fr',
}

# Dropdown choices, in the same insertion order as language_map.
available_languages = list(language_map)
|
|
|
|
|
# Local path to the pretrained MMTAfrica checkpoint, fetched from the
# Hugging Face Hub.
checkpoint = hf_hub_download(repo_id="chrisjay/mmtafrica", filename="mmt_translation.pt")

# PyTorch names the GPU device type "cuda"; "gpu" is not a valid torch device
# string. NOTE(review): presumably load_params hands this value to torch
# (e.g. torch.device / map_location) -- confirm in mmtafrica.load_params.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model/tokenizer state consumed by translate() in get_translation.
params = load_params({'checkpoint': checkpoint, 'device': device})
|
|
|
|
|
def get_translation(source_language, target_language, source_sentence=None):
    '''
    Translate a sentence between two of the supported languages.

    Parameters
    ----------
    source_language, target_language : str
        Display names, i.e. keys of ``language_map`` (e.g. 'Igbo', 'Fon').
    source_sentence : str, optional
        The text to translate.

    Returns
    -------
    str
        The translation, or a human-readable message when no translation
        could be produced. Problems are reported as text rather than raised
        so the Gradio UI always has something to display.
    '''
    import logging  # function-scope so the file's import block stays unchanged

    # A cleared dropdown (None) or an unknown name would otherwise raise a
    # KeyError outside the try-block and surface as a raw UI error.
    if source_language not in language_map or target_language not in language_map:
        return "Please select both a source and a target language."

    # Nothing to translate: skip the model call entirely.
    if not source_sentence or not source_sentence.strip():
        return "Please enter a sentence to translate."

    source_language_ = language_map[source_language]
    target_language_ = language_map[target_language]

    try:
        pred = translate(params, source_sentence,
                         source_lang=source_language_, target_lang=target_language_)
    except Exception as error:
        # UI boundary: keep the app responsive, but record the traceback
        # instead of discarding it.
        logging.getLogger(__name__).exception("Translation failed")
        return f"Issue with translation: \n {error}"

    # An empty (or missing) prediction means the model produced nothing usable.
    if not pred:
        return "Could not find translation"
    return pred
|
|
|
title = "MMTAfrica: Multilingual Machine Translation"

description = "Enjoy our MMT model that allows you to translate among 6 African languages, English and French!"

# Web UI. Gradio passes the inputs to get_translation positionally, in this
# order: source language, target language, sentence. The dropdowns are
# labelled so users can tell the source apart from the target.
# NOTE(review): gr.inputs / gr.outputs, enable_queue and theme='huggingface'
# are the legacy Gradio 2.x API; they were removed in later Gradio releases.
iface = gr.Interface(fn=get_translation,
                     inputs=[gr.inputs.Dropdown(choices=available_languages, default='Igbo',
                                                label='Source language'),
                             gr.inputs.Dropdown(choices=available_languages, default='Fon',
                                                label='Target language'),
                             gr.inputs.Textbox(label="Input")],
                     outputs=gr.outputs.Textbox(type="auto", label='Translation'),
                     title=title,
                     description=description,
                     enable_queue=True,
                     theme='huggingface')

# Only start the server when run as a script, so importing this module
# (e.g. for testing) does not block on launch().
if __name__ == "__main__":
    iface.launch()