Medical-VQA / src /utils /translator.py
SpringWang08's picture
Deploy Medical VQA app
d63774a
import torch
import json
import os
from src.utils.text_utils import postprocess_answer
class MedicalTranslator:
"""
Dịch thuật y tế với cơ chế Lazy Loading + Independent Fallback.
- Vi→En: MarianMT (Helsinki-NLP) trên CPU
- En→Vi: MedCrab-1.5B (4-bit) trên GPU phụ (nếu có)
Mỗi model load độc lập — nếu 1 cái fail, cái kia vẫn hoạt động.
"""
def __init__(self, device="cpu", dict_path="data/medical_dict.json"):
self.device_str = device # "cuda" hoặc "cpu"
# Chọn GPU: nếu Dual GPU → dùng cuda:1, nếu Single → dùng cuda:0
if torch.cuda.is_available() and device == "cuda":
if torch.cuda.device_count() > 1:
self.gpu_device = torch.device("cuda:1")
print(f"[INFO] Dual-GPU detected → Translator on {self.gpu_device}")
else:
self.gpu_device = torch.device("cuda:0")
else:
self.gpu_device = torch.device("cpu")
# State flags
self._load_attempted = False
self._vi2en_ready = False
self._en2vi_ready = False
# Models (lazy)
self._vi2en_model = None
self._vi2en_tokenizer = None
self._en2vi_model = None
self._en2vi_tokenizer = None
# Medical dictionary
self.med_dict = {}
if os.path.exists(dict_path):
try:
with open(dict_path, 'r', encoding='utf-8') as f:
self.med_dict = json.load(f)
except:
pass
def _lazy_load(self):
"""Nạp models. Chỉ gọi 1 lần duy nhất."""
if self._load_attempted:
return
self._load_attempted = True
print("[INFO] Đang nạp Translation Models (Lazy Load)...")
# ── 1. Helsinki-NLP Vi→En (Chạy trên CPU, nhẹ ~300MB) ──
try:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
vi2en_id = "Helsinki-NLP/opus-mt-vi-en"
self._vi2en_tokenizer = AutoTokenizer.from_pretrained(vi2en_id)
self._vi2en_model = AutoModelForSeq2SeqLM.from_pretrained(vi2en_id).to("cpu")
self._vi2en_model.eval()
self._vi2en_ready = True
print("[INFO] ✅ Helsinki-NLP (Vi→En) đã sẵn sàng trên CPU")
except Exception as e:
print(f"[WARNING] ❌ Helsinki-NLP load thất bại: {e}")
# ── 2. MedCrab En→Vi (4-bit trên GPU) ──
try:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16
)
medcrab_id = "pnnbao-ump/MedCrab-1.5B"
self._en2vi_tokenizer = AutoTokenizer.from_pretrained(medcrab_id)
d_map = {"": self.gpu_device} if self.gpu_device.type == "cuda" else None
self._en2vi_model = AutoModelForCausalLM.from_pretrained(
medcrab_id,
quantization_config=bnb_config,
device_map=d_map,
low_cpu_mem_usage=True
)
self._en2vi_model.eval()
self._en2vi_ready = True
print(f"[INFO] ✅ MedCrab-1.5B (En→Vi) đã sẵn sàng trên {self.gpu_device}")
except Exception as e:
print(f"[WARNING] ❌ MedCrab load thất bại: {e}")
# ── Vi → En ──
def translate_vi2en(self, text):
"""Dịch câu hỏi Tiếng Việt sang Tiếng Anh."""
if not text:
return text
self._lazy_load()
if not self._vi2en_ready:
# Fallback: trả về nguyên văn (LLaVA vẫn hiểu được một phần)
return text
try:
texts = text if isinstance(text, list) else [text]
results = []
for t in texts:
inputs = self._vi2en_tokenizer(t, return_tensors="pt", padding=True, truncation=True, max_length=128)
with torch.no_grad():
output_ids = self._vi2en_model.generate(**inputs, max_new_tokens=128)
translated = self._vi2en_tokenizer.decode(output_ids[0], skip_special_tokens=True)
results.append(translated)
return results if isinstance(text, list) else results[0]
except Exception as e:
print(f"[WARNING] Vi→En error: {e}")
return text
# ── En → Vi ──
def translate_en2vi(self, text):
"""Dịch kết quả từ LLaVA-Med sang Tiếng Việt."""
if not text:
return text
# 1. Ánh xạ trực tiếp nhãn nhị phân (nhanh + chính xác 100%)
if isinstance(text, str):
t = text.lower().strip().rstrip(".").rstrip(",").strip()
# Xử lý các câu trả lời dài bắt đầu bằng Yes/No của LLaVA (vd: "No, the image does not...")
if t.startswith("yes"):
return "có"
if t.startswith("no"):
return "không"
# Exact match trước
direct_map = {
"true": "có", "false": "không",
"correct": "có", "incorrect": "không",
"present": "có", "absent": "không",
"normal": "bình thường", "abnormal": "bất thường",
}
if t in direct_map:
return direct_map[t]
# 2. Dịch bằng MedCrab
self._lazy_load()
if not self._en2vi_ready:
if isinstance(text, list):
return text
return text
if isinstance(text, list):
return [self._medcrab_translate(t) for t in text]
return self._medcrab_translate(text)
def _medcrab_translate(self, text):
"""Dịch 1 câu En→Vi bằng MedCrab với ràng buộc ngắn gọn."""
# Kiểm tra ánh xạ trực tiếp trước
t = text.lower().strip().rstrip(".").rstrip(",").strip()
direct_map = {
"yes": "có", "no": "không",
"normal": "bình thường", "abnormal": "bất thường",
}
if t in direct_map:
return direct_map[t]
try:
prompt = f"English: {text}\nVietnamese (trả lời ngắn gọn):"
inputs = self._en2vi_tokenizer(prompt, return_tensors="pt").to(self.gpu_device)
with torch.no_grad():
outputs = self._en2vi_model.generate(
**inputs,
max_new_tokens=30,
repetition_penalty=1.2,
temperature=0.1,
do_sample=False,
pad_token_id=self._en2vi_tokenizer.eos_token_id
)
full_text = self._en2vi_tokenizer.decode(outputs[0], skip_special_tokens=True)
translated = full_text.split("Vietnamese (trả lời ngắn gọn):")[-1].strip()
return translated
except Exception as e:
print(f"[WARNING] En→Vi error: {e}")
return text