Hy-MT2 / app.py
noxwano's picture
Upload 3 files
ca51260 verified
import gc
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Module-level state describing the currently loaded model; load_model() is
# the function that rebinds these three names.
current_model_id = None  # Hub id of the loaded model, or None before any load
tokenizer = None  # tokenizer that belongs to current_model_id
model = None  # causal-LM that belongs to current_model_id

# Target languages offered in the UI dropdown, kept in alphabetical order.
LANGUAGES = [
    "Arabic",
    "Bengali", "Burmese",
    "Cantonese", "Chinese", "Czech",
    "Dutch",
    "English",
    "Filipino", "French",
    "German", "Gujarati",
    "Hebrew", "Hindi",
    "Indonesian", "Italian",
    "Japanese",
    "Kazakh", "Khmer", "Korean",
    "Malay", "Marathi", "Mongolian",
    "Persian", "Polish", "Portuguese",
    "Russian",
    "Spanish",
    "Tamil", "Telugu", "Thai", "Tibetan", "Traditional Chinese", "Turkish",
    "Ukrainian", "Urdu", "Uyghur",
    "Vietnamese",
]

# Selectable checkpoints, listed largest to smallest. translate() chooses its
# sampling settings by checking whether the selected id contains "30B".
MODELS = [
    "tencent/Hy-MT2-30B-A3B",
    "tencent/Hy-MT2-7B",
    "tencent/Hy-MT2-1.8B",
]
def load_model(model_id):
    """Make ``model_id`` the active model, loading it only if needed.

    Rebinds the module-level ``tokenizer``, ``model`` and ``current_model_id``.
    Returns immediately when ``model_id`` is already the active model. When
    switching, the previous model and tokenizer are released first, followed
    by a garbage collection and, if CUDA is available, an empty_cache().

    Args:
        model_id: Hugging Face Hub repo id passed to ``from_pretrained``.

    Raises:
        Any exception raised by ``AutoTokenizer.from_pretrained`` or
        ``AutoModelForCausalLM.from_pretrained``. In that case no model is
        marked active (``current_model_id`` is None), so the next call
        retries the load instead of assuming it succeeded.
    """
    global current_model_id, tokenizer, model
    if current_model_id == model_id:
        return
    print(f"Switching model from {current_model_id} to {model_id}...")
    if model is not None:
        # Rebind to None instead of `del`: deleting a `global` name removes it
        # from the module, so a failed load below would leave `model`
        # undefined and the next call would crash with NameError.
        model = None
        tokenizer = None
        # Forget the old id as well; otherwise re-selecting it after a failed
        # switch would early-return while no model is actually loaded.
        current_model_id = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    print(f"Loading tokenizer for {model_id}...")
    new_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    print(f"Loading model {model_id}...")
    new_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    new_model.eval()
    # Publish all three globals together, only after both loads succeeded,
    # so they never describe different checkpoints.
    tokenizer, model, current_model_id = new_tokenizer, new_model, model_id
    print("Model loaded successfully.")
def _generation_kwargs(model_id):
    """Return the sampling keyword arguments passed to ``model.generate``.

    Ids containing "30B" get temperature 0.7, top_p 1.0 and no repetition
    penalty; every other id gets temperature 0.7, top_p 0.6, top_k 20 and a
    repetition penalty of 1.05.
    """
    # do_sample=True is explicit because generate() only applies temperature,
    # top_p and top_k in sampling mode; without it they can be ignored in
    # favour of greedy decoding, depending on the checkpoint's config.
    if "30B" in model_id:
        return {
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 1.0,
            "repetition_penalty": 1.0,
        }
    return {
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.6,
        "top_k": 20,
        "repetition_penalty": 1.05,
    }


@spaces.GPU
def translate(source_text, target_lang, selected_model):
    """Translate ``source_text`` into ``target_lang`` with ``selected_model``.

    Args:
        source_text: Text to translate. Empty, whitespace-only or None input
            returns "" without loading a model.
        target_lang: Language name inserted into the instruction prompt.
        selected_model: Hub id handed to ``load_model`` (switches models if
            it differs from the one currently loaded).

    Returns:
        The decoded newly generated text, or a human-readable error message
        if loading or generation raised. Exceptions are not propagated,
        because the return value is shown directly in the UI.
    """
    if not source_text or not source_text.strip():
        return ""
    try:
        load_model(selected_model)
        prompt = f"Translate the following text into {target_lang}. Note that you should only output the translated result without any additional explanation:\n\n{source_text}"
        messages = [{"role": "user", "content": prompt}]
        # return_dict=True yields a BatchEncoding (input_ids + attention_mask)
        # that can be unpacked into generate(). Without it the call can return
        # a bare tensor, and both `**inputs` and inputs["input_ids"] fail.
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=4096,
                **_generation_kwargs(selected_model),
            )
        # Decode only the tokens generated after the prompt.
        prompt_len = inputs["input_ids"].shape[-1]
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: show the failure in the output box, and also print the
        # traceback so it is visible in the server logs.
        import traceback

        traceback.print_exc()
        return f"Error during generation: {str(e)}\n\n(Note: Zero GPU environments may timeout or run out of memory when loading large models dynamically.)"
with gr.Blocks(title="Hy-MT2 Translator") as demo:
    # Page header: title plus a link to the model collection.
    gr.Markdown("# Hy-MT2 Translator")
    gr.Markdown("https://huggingface.co/collections/tencent/hy-mt2")

    with gr.Row():
        # Left column: input text, language/model pickers and the trigger.
        with gr.Column():
            source_box = gr.Textbox(
                label="Source Text",
                lines=8,
                placeholder="Enter text to translate...",
            )
            with gr.Row():
                language_picker = gr.Dropdown(
                    choices=LANGUAGES,
                    value="English",
                    label="Target Language",
                )
                model_picker = gr.Dropdown(
                    choices=MODELS,
                    value="tencent/Hy-MT2-1.8B",
                    label="Model",
                )
            run_button = gr.Button("Translate", variant="primary")

        # Right column: read-only box that receives the translation.
        with gr.Column():
            result_box = gr.Textbox(
                label="Translated Text",
                lines=12,
                interactive=False,
            )

    # Inputs are listed in translate()'s positional order:
    # (source_text, target_lang, selected_model).
    translate_inputs = [source_box, language_picker, model_picker]
    run_button.click(
        fn=translate,
        inputs=translate_inputs,
        outputs=result_box,
    )

if __name__ == "__main__":
    demo.launch()