# Last commit: 63e0f06 by ayymen — "Switch to int8" (verified)
import gradio as gr
from transformers import NllbTokenizer
import ctranslate2
from huggingface_hub import snapshot_download
# Fetch the CTranslate2 export of the fine-tuned NLLB-200 model (int8, per the
# repository name) into a local folder named after the repository, then load
# both the CTranslate2 translator and the Hugging Face tokenizer from it.
MODEL_ID = "Tamazight-NLP/NLLB-200-600M-Tamazight-All-Data-3-epoch-ct2-int8"
_MODEL_DIR = MODEL_ID.split('/')[1]  # repository name without the owner prefix
snapshot_download(MODEL_ID, local_dir=_MODEL_DIR)
translator = ctranslate2.Translator(_MODEL_DIR)
nllb_tokenizer = NllbTokenizer.from_pretrained(_MODEL_DIR)
# UI display name -> language code of the form <language>_<Script>.
# translate() uses the source entry as the tokenizer's src_lang and the target
# entry as the forced first token of every decoded hypothesis. Insertion order
# is also the order shown in the Gradio dropdowns.
# NOTE(review): Tarifit maps to kab_* and Tachelhit/Central Atlas Tamazight to
# taq_*; these look like codes deliberately repurposed by the fine-tuned model,
# so confirm against the model card before "correcting" them.
NLLB_LANG_MAPPING = {
    "English": "eng_Latn",
    # Amazigh varieties (Tifinagh and Latin scripts)
    "Standard Moroccan Tamazight": "tzm_Tfng",
    "Tachelhit/Central Atlas Tamazight": "taq_Tfng",
    "Tachelhit/Central Atlas Tamazight (Latin)": "taq_Latn",
    "Tarifit": "kab_Tfng",
    "Tarifit (Latin)": "kab_Latn",
    # Arabic varieties
    "Moroccan Darija": "ary_Arab",
    "Modern Standard Arabic": "arb_Arab",
    # Other languages
    "Catalan": "cat_Latn",
    "Spanish": "spa_Latn",
    "French": "fra_Latn",
    "German": "deu_Latn",
    "Dutch": "nld_Latn",
    "Russian": "rus_Cyrl",
    "Italian": "ita_Latn",
    "Turkish": "tur_Latn",
    "Esperanto": "epo_Latn",
}
def translate(text, source_lang="English", target_lang="Tachelhit/Central Atlas Tamazight",
              max_length=237, num_beams=4, repetition_penalty=1.0):
    """Translate multi-line text while preserving line breaks.

    Each line is translated independently. All non-blank lines are sent to
    CTranslate2 in one ``translate_batch`` call rather than one call per line.
    Blank lines (empty or whitespace-only) come out as empty lines, so the
    output has exactly as many lines as the input.

    Args:
        text: Input text whose lines are separated by newline characters.
            ``None`` or an empty string yields an empty string.
        source_lang: Key of ``NLLB_LANG_MAPPING`` naming the input language.
        target_lang: Key of ``NLLB_LANG_MAPPING`` naming the output language.
        max_length: Maximum number of decoded tokens per line.
        num_beams: Beam size used for decoding.
        repetition_penalty: Penalty on previously generated tokens
            (1.0 means no penalty).

    Returns:
        The translated lines joined with newline characters.

    Raises:
        KeyError: If ``source_lang`` or ``target_lang`` is not a key of
            ``NLLB_LANG_MAPPING`` and the text has at least one non-blank line.
    """
    if not text:
        return ""

    lines = text.split("\n")
    translations = [""] * len(lines)
    # Positions of the lines that actually need the model; blank ones stay "".
    pending = [i for i, line in enumerate(lines) if line.strip()]
    if not pending:
        return "\n".join(translations)

    # The tokenizer reads src_lang when encoding, so set it once for the whole
    # batch instead of re-assigning the shared tokenizer for every line.
    nllb_tokenizer.src_lang = NLLB_LANG_MAPPING[source_lang]
    target_code = NLLB_LANG_MAPPING[target_lang]
    sources = [
        nllb_tokenizer.convert_ids_to_tokens(nllb_tokenizer.encode(lines[i]))
        for i in pending
    ]
    results = translator.translate_batch(
        sources,
        # Every hypothesis is forced to start with the target language code.
        target_prefix=[[target_code] for _ in sources],
        # Bound how many lines are decoded together so a very long paste does
        # not become a single huge batch.
        max_batch_size=32,
        # UI/API callers may hand over floats; the C++ binding expects ints.
        max_decoding_length=int(max_length),
        beam_size=int(num_beams),
        repetition_penalty=float(repetition_penalty),
    )
    for i, result in zip(pending, results):
        # Drop the forced target-language token before decoding to text.
        target_tokens = result.hypotheses[0][1:]
        translations[i] = nllb_tokenizer.decode(
            nllb_tokenizer.convert_tokens_to_ids(target_tokens),
            skip_special_tokens=True,
        )
    return "\n".join(translations)
# Web UI for translate(): Gradio passes the input components' values to the
# function positionally, so the list below must stay in the order
# (text, source_lang, target_lang, max_length, num_beams, repetition_penalty).
gradio_ui = gr.Interface(
    fn=translate,
    inputs=[
        gr.Textbox(
            lines=4,
            placeholder="ⵙⵙⴽⵛⵎ ⴰⴹⵕⵉⵚ...\nEnter text to translate...",
            label="Text",
        ),
        gr.Dropdown(
            choices=list(NLLB_LANG_MAPPING),
            value="English",
            label="Source Language",
        ),
        gr.Dropdown(
            choices=list(NLLB_LANG_MAPPING),
            value="Standard Moroccan Tamazight",
            label="Target Language",
        ),
        # NOTE(review): with minimum=1 and step=10 the slider grid is
        # 1, 11, ..., 391, so the default 237 is off-grid and the widget may
        # snap it; confirm whether the step or the default should change.
        gr.Slider(
            minimum=1,
            maximum=400,
            value=237,
            step=10,
            label="Max Length (in tokens). Increase in case the output looks truncated.",
        ),
        gr.Slider(
            minimum=1,
            maximum=25,
            value=4,
            step=1,
            label="Number of beams. Higher values might improve translation accuracy at the cost of speed.",
        ),
        gr.Slider(
            minimum=1,
            maximum=10,
            value=1.0,
            step=0.1,
            label="Repetition penalty.",
        ),
    ],
    outputs=gr.Textbox(lines=4, label="Translated text"),
    title="NLLB Tamazight Translation Demo",
)
gradio_ui.launch()