File size: 4,058 Bytes
ca51260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import gc
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Process-wide cache of the single checkpoint currently held in memory.
# All three names stay None until load_model() populates them.
model = None             # causal-LM instance used for generation
tokenizer = None         # tokenizer loaded from the same model id
current_model_id = None  # model id that `model`/`tokenizer` were loaded from

# Choices shown in the "Target Language" dropdown; the selected name is
# interpolated verbatim into the translation prompt.
LANGUAGES = [
    "Arabic",
    "Bengali",
    "Burmese",
    "Cantonese",
    "Chinese",
    "Czech",
    "Dutch",
    "English",
    "Filipino",
    "French",
    "German",
    "Gujarati",
    "Hebrew",
    "Hindi",
    "Indonesian",
    "Italian",
    "Japanese",
    "Kazakh",
    "Khmer",
    "Korean",
    "Malay",
    "Marathi",
    "Mongolian",
    "Persian",
    "Polish",
    "Portuguese",
    "Russian",
    "Spanish",
    "Tamil",
    "Telugu",
    "Thai",
    "Tibetan",
    "Traditional Chinese",
    "Turkish",
    "Ukrainian",
    "Urdu",
    "Uyghur",
    "Vietnamese",
]

# Model ids offered in the "Model" dropdown, largest first.
# translate() picks its sampling settings by checking for "30B" in the id.
MODELS = [
    "tencent/Hy-MT2-30B-A3B",
    "tencent/Hy-MT2-7B",
    "tencent/Hy-MT2-1.8B",
]

def load_model(model_id):
    """Make ``model_id`` the checkpoint held in the module-level globals.

    Does nothing when ``model_id`` is already the current model. Otherwise
    the previously loaded model/tokenizer (if any) are released, the new
    tokenizer and model are loaded, and ``tokenizer``, ``model`` and
    ``current_model_id`` are updated together once both loads succeeded.

    Args:
        model_id: Model identifier handed to ``from_pretrained``.

    Raises:
        Whatever ``AutoTokenizer.from_pretrained`` or
        ``AutoModelForCausalLM.from_pretrained`` raises. In that case the
        globals are left as ``None`` so a later call starts from a clean
        "nothing loaded" state.
    """
    global current_model_id, tokenizer, model

    if current_model_id == model_id:
        return

    print(f"Switching model from {current_model_id} to {model_id}...")

    if model is not None:
        # Rebind to None rather than `del`: deleting the globals would leave
        # the names undefined if the load below fails, and every later call
        # would then die with NameError instead of retrying.
        model = None
        tokenizer = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Until the new checkpoint is fully loaded nothing is current; without
    # this, a failed switch would keep reporting the old (released) model id
    # and the early return above would skip reloading it.
    current_model_id = None

    print(f"Loading tokenizer for {model_id}...")
    new_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    print(f"Loading model {model_id}...")
    new_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    new_model.eval()

    # Publish all three globals only after both loads succeeded.
    tokenizer = new_tokenizer
    model = new_model
    current_model_id = model_id
    print("Model loaded successfully.")

@spaces.GPU
def translate(source_text, target_lang, selected_model):
    """Translate ``source_text`` into ``target_lang`` with ``selected_model``.

    Ensures the requested checkpoint is loaded via ``load_model``, wraps the
    text in a single user-turn chat prompt, generates a completion and
    returns only the newly generated tokens decoded to text.

    Args:
        source_text: Text to translate. ``None``, empty or whitespace-only
            input returns "" without touching the model.
        target_lang: Language name interpolated into the prompt.
        selected_model: Model id passed to ``load_model``; ids containing
            "30B" use a different set of generation parameters.

    Returns:
        The decoded translation, "" for blank input, or a string starting
        with "Error during generation:" if anything raises along the way
        (errors are reported as text rather than propagated).
    """
    if not source_text or not source_text.strip():
        return ""

    try:
        load_model(selected_model)

        prompt = f"Translate the following text into {target_lang}. Note that you should only output the translated result without any additional explanation:\n\n{source_text}"
        messages = [{"role": "user", "content": prompt}]

        # return_dict=True is required for the `**inputs` unpacking and the
        # `inputs["input_ids"]` lookup below: without it, transformers
        # versions whose default is return_dict=False hand back a bare
        # tensor of token ids instead of a mapping.
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)

        # NOTE(review): these sampling parameters only take effect when
        # sampling is enabled; presumably the checkpoints' generation config
        # sets do_sample=True - confirm.
        if "30B" in selected_model:
            gen_kwargs = {
                "temperature": 0.7,
                "top_p": 1.0,
                "repetition_penalty": 1.0,
            }
        else:
            gen_kwargs = {
                "temperature": 0.7,
                "top_p": 0.6,
                "top_k": 20,
                "repetition_penalty": 1.05,
            }

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=4096,
                **gen_kwargs,
            )

        # generate() returns prompt + completion; drop the prompt tokens so
        # only the newly generated text is decoded.
        prompt_length = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return response

    except Exception as e:
        return f"Error during generation: {str(e)}\n\n(Note: Zero GPU environments may timeout or run out of memory when loading large models dynamically.)"

# UI definition: inputs and controls in the left column, read-only result in
# the right column; the button wires the three inputs into translate().
with gr.Blocks(title="Hy-MT2 Translator") as demo:
    gr.Markdown("# Hy-MT2 Translator")
    gr.Markdown("https://huggingface.co/collections/tencent/hy-mt2")

    with gr.Row():
        # Left column: text to translate plus language/model pickers.
        with gr.Column():
            source_text = gr.Textbox(
                label="Source Text",
                lines=8,
                placeholder="Enter text to translate...",
            )

            with gr.Row():
                target_lang = gr.Dropdown(
                    choices=LANGUAGES,
                    value="English",
                    label="Target Language",
                )
                model_selector = gr.Dropdown(
                    choices=MODELS,
                    value="tencent/Hy-MT2-1.8B",
                    label="Model",
                )

            translate_btn = gr.Button("Translate", variant="primary")

        # Right column: translation output (not editable by the user).
        with gr.Column():
            output_text = gr.Textbox(
                label="Translated Text",
                lines=12,
                interactive=False,
            )

    translate_btn.click(
        fn=translate,
        inputs=[source_text, target_lang, model_selector],
        outputs=output_text,
    )

# Start the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()