import torch import json import os from src.utils.text_utils import postprocess_answer class MedicalTranslator: """ Dịch thuật y tế với cơ chế Lazy Loading + Independent Fallback. - Vi→En: MarianMT (Helsinki-NLP) trên CPU - En→Vi: MedCrab-1.5B (4-bit) trên GPU phụ (nếu có) Mỗi model load độc lập — nếu 1 cái fail, cái kia vẫn hoạt động. """ def __init__(self, device="cpu", dict_path="data/medical_dict.json"): self.device_str = device # "cuda" hoặc "cpu" # Chọn GPU: nếu Dual GPU → dùng cuda:1, nếu Single → dùng cuda:0 if torch.cuda.is_available() and device == "cuda": if torch.cuda.device_count() > 1: self.gpu_device = torch.device("cuda:1") print(f"[INFO] Dual-GPU detected → Translator on {self.gpu_device}") else: self.gpu_device = torch.device("cuda:0") else: self.gpu_device = torch.device("cpu") # State flags self._load_attempted = False self._vi2en_ready = False self._en2vi_ready = False # Models (lazy) self._vi2en_model = None self._vi2en_tokenizer = None self._en2vi_model = None self._en2vi_tokenizer = None # Medical dictionary self.med_dict = {} if os.path.exists(dict_path): try: with open(dict_path, 'r', encoding='utf-8') as f: self.med_dict = json.load(f) except: pass def _lazy_load(self): """Nạp models. Chỉ gọi 1 lần duy nhất.""" if self._load_attempted: return self._load_attempted = True print("[INFO] Đang nạp Translation Models (Lazy Load)...") # ── 1. Helsinki-NLP Vi→En (Chạy trên CPU, nhẹ ~300MB) ── try: from transformers import AutoTokenizer, AutoModelForSeq2SeqLM vi2en_id = "Helsinki-NLP/opus-mt-vi-en" self._vi2en_tokenizer = AutoTokenizer.from_pretrained(vi2en_id) self._vi2en_model = AutoModelForSeq2SeqLM.from_pretrained(vi2en_id).to("cpu") self._vi2en_model.eval() self._vi2en_ready = True print("[INFO] ✅ Helsinki-NLP (Vi→En) đã sẵn sàng trên CPU") except Exception as e: print(f"[WARNING] ❌ Helsinki-NLP load thất bại: {e}") # ── 2. MedCrab En→Vi (4-bit trên GPU) ── try: from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16 ) medcrab_id = "pnnbao-ump/MedCrab-1.5B" self._en2vi_tokenizer = AutoTokenizer.from_pretrained(medcrab_id) d_map = {"": self.gpu_device} if self.gpu_device.type == "cuda" else None self._en2vi_model = AutoModelForCausalLM.from_pretrained( medcrab_id, quantization_config=bnb_config, device_map=d_map, low_cpu_mem_usage=True ) self._en2vi_model.eval() self._en2vi_ready = True print(f"[INFO] ✅ MedCrab-1.5B (En→Vi) đã sẵn sàng trên {self.gpu_device}") except Exception as e: print(f"[WARNING] ❌ MedCrab load thất bại: {e}") # ── Vi → En ── def translate_vi2en(self, text): """Dịch câu hỏi Tiếng Việt sang Tiếng Anh.""" if not text: return text self._lazy_load() if not self._vi2en_ready: # Fallback: trả về nguyên văn (LLaVA vẫn hiểu được một phần) return text try: texts = text if isinstance(text, list) else [text] results = [] for t in texts: inputs = self._vi2en_tokenizer(t, return_tensors="pt", padding=True, truncation=True, max_length=128) with torch.no_grad(): output_ids = self._vi2en_model.generate(**inputs, max_new_tokens=128) translated = self._vi2en_tokenizer.decode(output_ids[0], skip_special_tokens=True) results.append(translated) return results if isinstance(text, list) else results[0] except Exception as e: print(f"[WARNING] Vi→En error: {e}") return text # ── En → Vi ── def translate_en2vi(self, text): """Dịch kết quả từ LLaVA-Med sang Tiếng Việt.""" if not text: return text # 1. Ánh xạ trực tiếp nhãn nhị phân (nhanh + chính xác 100%) if isinstance(text, str): t = text.lower().strip().rstrip(".").rstrip(",").strip() # Xử lý các câu trả lời dài bắt đầu bằng Yes/No của LLaVA (vd: "No, the image does not...") if t.startswith("yes"): return "có" if t.startswith("no"): return "không" # Exact match trước direct_map = { "true": "có", "false": "không", "correct": "có", "incorrect": "không", "present": "có", "absent": "không", "normal": "bình thường", "abnormal": "bất thường", } if t in direct_map: return direct_map[t] # 2. Dịch bằng MedCrab self._lazy_load() if not self._en2vi_ready: if isinstance(text, list): return text return text if isinstance(text, list): return [self._medcrab_translate(t) for t in text] return self._medcrab_translate(text) def _medcrab_translate(self, text): """Dịch 1 câu En→Vi bằng MedCrab với ràng buộc ngắn gọn.""" # Kiểm tra ánh xạ trực tiếp trước t = text.lower().strip().rstrip(".").rstrip(",").strip() direct_map = { "yes": "có", "no": "không", "normal": "bình thường", "abnormal": "bất thường", } if t in direct_map: return direct_map[t] try: prompt = f"English: {text}\nVietnamese (trả lời ngắn gọn):" inputs = self._en2vi_tokenizer(prompt, return_tensors="pt").to(self.gpu_device) with torch.no_grad(): outputs = self._en2vi_model.generate( **inputs, max_new_tokens=30, repetition_penalty=1.2, temperature=0.1, do_sample=False, pad_token_id=self._en2vi_tokenizer.eos_token_id ) full_text = self._en2vi_tokenizer.decode(outputs[0], skip_special_tokens=True) translated = full_text.split("Vietnamese (trả lời ngắn gọn):")[-1].strip() return translated except Exception as e: print(f"[WARNING] En→Vi error: {e}") return text