File size: 7,457 Bytes
d63774a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import torch
import json
import os

from src.utils.text_utils import postprocess_answer

class MedicalTranslator:
    """
    Dịch thuật y tế với cơ chế Lazy Loading + Independent Fallback.
    - Vi→En: MarianMT (Helsinki-NLP) trên CPU
    - En→Vi: MedCrab-1.5B (4-bit) trên GPU phụ (nếu có)
    Mỗi model load độc lập — nếu 1 cái fail, cái kia vẫn hoạt động.
    """
    def __init__(self, device="cpu", dict_path="data/medical_dict.json"):
        self.device_str = device  # "cuda" hoặc "cpu"
        
        # Chọn GPU: nếu Dual GPU → dùng cuda:1, nếu Single → dùng cuda:0
        if torch.cuda.is_available() and device == "cuda":
            if torch.cuda.device_count() > 1:
                self.gpu_device = torch.device("cuda:1")
                print(f"[INFO] Dual-GPU detected → Translator on {self.gpu_device}")
            else:
                self.gpu_device = torch.device("cuda:0")
        else:
            self.gpu_device = torch.device("cpu")
        
        # State flags
        self._load_attempted = False
        self._vi2en_ready = False
        self._en2vi_ready = False
        
        # Models (lazy)
        self._vi2en_model = None
        self._vi2en_tokenizer = None
        self._en2vi_model = None
        self._en2vi_tokenizer = None
        
        # Medical dictionary
        self.med_dict = {}
        if os.path.exists(dict_path):
            try:
                with open(dict_path, 'r', encoding='utf-8') as f:
                    self.med_dict = json.load(f)
            except:
                pass

    def _lazy_load(self):
        """Nạp models. Chỉ gọi 1 lần duy nhất."""
        if self._load_attempted:
            return
        self._load_attempted = True
        print("[INFO] Đang nạp Translation Models (Lazy Load)...")
        
        # ── 1. Helsinki-NLP Vi→En (Chạy trên CPU, nhẹ ~300MB) ──
        try:
            from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
            vi2en_id = "Helsinki-NLP/opus-mt-vi-en"
            self._vi2en_tokenizer = AutoTokenizer.from_pretrained(vi2en_id)
            self._vi2en_model = AutoModelForSeq2SeqLM.from_pretrained(vi2en_id).to("cpu")
            self._vi2en_model.eval()
            self._vi2en_ready = True
            print("[INFO] ✅ Helsinki-NLP (Vi→En) đã sẵn sàng trên CPU")
        except Exception as e:
            print(f"[WARNING] ❌ Helsinki-NLP load thất bại: {e}")
        
        # ── 2. MedCrab En→Vi (4-bit trên GPU) ──
        try:
            from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16
            )
            medcrab_id = "pnnbao-ump/MedCrab-1.5B"
            self._en2vi_tokenizer = AutoTokenizer.from_pretrained(medcrab_id)
            
            d_map = {"": self.gpu_device} if self.gpu_device.type == "cuda" else None
            self._en2vi_model = AutoModelForCausalLM.from_pretrained(
                medcrab_id,
                quantization_config=bnb_config,
                device_map=d_map,
                low_cpu_mem_usage=True
            )
            self._en2vi_model.eval()
            self._en2vi_ready = True
            print(f"[INFO] ✅ MedCrab-1.5B (En→Vi) đã sẵn sàng trên {self.gpu_device}")
        except Exception as e:
            print(f"[WARNING] ❌ MedCrab load thất bại: {e}")

    # ── Vi → En ──
    def translate_vi2en(self, text):
        """Dịch câu hỏi Tiếng Việt sang Tiếng Anh."""
        if not text:
            return text
        self._lazy_load()
        
        if not self._vi2en_ready:
            # Fallback: trả về nguyên văn (LLaVA vẫn hiểu được một phần)
            return text
        
        try:
            texts = text if isinstance(text, list) else [text]
            results = []
            for t in texts:
                inputs = self._vi2en_tokenizer(t, return_tensors="pt", padding=True, truncation=True, max_length=128)
                with torch.no_grad():
                    output_ids = self._vi2en_model.generate(**inputs, max_new_tokens=128)
                translated = self._vi2en_tokenizer.decode(output_ids[0], skip_special_tokens=True)
                results.append(translated)
            return results if isinstance(text, list) else results[0]
        except Exception as e:
            print(f"[WARNING] Vi→En error: {e}")
            return text

    # ── En → Vi ──
    def translate_en2vi(self, text):
        """Dịch kết quả từ LLaVA-Med sang Tiếng Việt."""
        if not text:
            return text
        
        # 1. Ánh xạ trực tiếp nhãn nhị phân (nhanh + chính xác 100%)
        if isinstance(text, str):
            t = text.lower().strip().rstrip(".").rstrip(",").strip()
            
            # Xử lý các câu trả lời dài bắt đầu bằng Yes/No của LLaVA (vd: "No, the image does not...")
            if t.startswith("yes"):
                return "có"
            if t.startswith("no"):
                return "không"
                
            # Exact match trước
            direct_map = {
                "true": "có", "false": "không",
                "correct": "có", "incorrect": "không",
                "present": "có", "absent": "không",
                "normal": "bình thường", "abnormal": "bất thường",
            }
            if t in direct_map:
                return direct_map[t]
        
        # 2. Dịch bằng MedCrab
        self._lazy_load()
        if not self._en2vi_ready:
            if isinstance(text, list):
                return text
            return text
        
        if isinstance(text, list):
            return [self._medcrab_translate(t) for t in text]
        return self._medcrab_translate(text)

    def _medcrab_translate(self, text):
        """Dịch 1 câu En→Vi bằng MedCrab với ràng buộc ngắn gọn."""
        # Kiểm tra ánh xạ trực tiếp trước
        t = text.lower().strip().rstrip(".").rstrip(",").strip()
        direct_map = {
            "yes": "có", "no": "không",
            "normal": "bình thường", "abnormal": "bất thường",
        }
        if t in direct_map:
            return direct_map[t]
        
        try:
            prompt = f"English: {text}\nVietnamese (trả lời ngắn gọn):"
            inputs = self._en2vi_tokenizer(prompt, return_tensors="pt").to(self.gpu_device)
            
            with torch.no_grad():
                outputs = self._en2vi_model.generate(
                    **inputs,
                    max_new_tokens=30,
                    repetition_penalty=1.2,
                    temperature=0.1,
                    do_sample=False,
                    pad_token_id=self._en2vi_tokenizer.eos_token_id
                )
            
            full_text = self._en2vi_tokenizer.decode(outputs[0], skip_special_tokens=True)
            translated = full_text.split("Vietnamese (trả lời ngắn gọn):")[-1].strip()
            return translated
        except Exception as e:
            print(f"[WARNING] En→Vi error: {e}")
            return text