Spaces:
Paused
Paused
File size: 7,457 Bytes
d63774a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 | import torch
import json
import os
from src.utils.text_utils import postprocess_answer
class MedicalTranslator:
"""
Dịch thuật y tế với cơ chế Lazy Loading + Independent Fallback.
- Vi→En: MarianMT (Helsinki-NLP) trên CPU
- En→Vi: MedCrab-1.5B (4-bit) trên GPU phụ (nếu có)
Mỗi model load độc lập — nếu 1 cái fail, cái kia vẫn hoạt động.
"""
def __init__(self, device="cpu", dict_path="data/medical_dict.json"):
self.device_str = device # "cuda" hoặc "cpu"
# Chọn GPU: nếu Dual GPU → dùng cuda:1, nếu Single → dùng cuda:0
if torch.cuda.is_available() and device == "cuda":
if torch.cuda.device_count() > 1:
self.gpu_device = torch.device("cuda:1")
print(f"[INFO] Dual-GPU detected → Translator on {self.gpu_device}")
else:
self.gpu_device = torch.device("cuda:0")
else:
self.gpu_device = torch.device("cpu")
# State flags
self._load_attempted = False
self._vi2en_ready = False
self._en2vi_ready = False
# Models (lazy)
self._vi2en_model = None
self._vi2en_tokenizer = None
self._en2vi_model = None
self._en2vi_tokenizer = None
# Medical dictionary
self.med_dict = {}
if os.path.exists(dict_path):
try:
with open(dict_path, 'r', encoding='utf-8') as f:
self.med_dict = json.load(f)
except:
pass
def _lazy_load(self):
"""Nạp models. Chỉ gọi 1 lần duy nhất."""
if self._load_attempted:
return
self._load_attempted = True
print("[INFO] Đang nạp Translation Models (Lazy Load)...")
# ── 1. Helsinki-NLP Vi→En (Chạy trên CPU, nhẹ ~300MB) ──
try:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
vi2en_id = "Helsinki-NLP/opus-mt-vi-en"
self._vi2en_tokenizer = AutoTokenizer.from_pretrained(vi2en_id)
self._vi2en_model = AutoModelForSeq2SeqLM.from_pretrained(vi2en_id).to("cpu")
self._vi2en_model.eval()
self._vi2en_ready = True
print("[INFO] ✅ Helsinki-NLP (Vi→En) đã sẵn sàng trên CPU")
except Exception as e:
print(f"[WARNING] ❌ Helsinki-NLP load thất bại: {e}")
# ── 2. MedCrab En→Vi (4-bit trên GPU) ──
try:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16
)
medcrab_id = "pnnbao-ump/MedCrab-1.5B"
self._en2vi_tokenizer = AutoTokenizer.from_pretrained(medcrab_id)
d_map = {"": self.gpu_device} if self.gpu_device.type == "cuda" else None
self._en2vi_model = AutoModelForCausalLM.from_pretrained(
medcrab_id,
quantization_config=bnb_config,
device_map=d_map,
low_cpu_mem_usage=True
)
self._en2vi_model.eval()
self._en2vi_ready = True
print(f"[INFO] ✅ MedCrab-1.5B (En→Vi) đã sẵn sàng trên {self.gpu_device}")
except Exception as e:
print(f"[WARNING] ❌ MedCrab load thất bại: {e}")
# ── Vi → En ──
def translate_vi2en(self, text):
"""Dịch câu hỏi Tiếng Việt sang Tiếng Anh."""
if not text:
return text
self._lazy_load()
if not self._vi2en_ready:
# Fallback: trả về nguyên văn (LLaVA vẫn hiểu được một phần)
return text
try:
texts = text if isinstance(text, list) else [text]
results = []
for t in texts:
inputs = self._vi2en_tokenizer(t, return_tensors="pt", padding=True, truncation=True, max_length=128)
with torch.no_grad():
output_ids = self._vi2en_model.generate(**inputs, max_new_tokens=128)
translated = self._vi2en_tokenizer.decode(output_ids[0], skip_special_tokens=True)
results.append(translated)
return results if isinstance(text, list) else results[0]
except Exception as e:
print(f"[WARNING] Vi→En error: {e}")
return text
# ── En → Vi ──
def translate_en2vi(self, text):
"""Dịch kết quả từ LLaVA-Med sang Tiếng Việt."""
if not text:
return text
# 1. Ánh xạ trực tiếp nhãn nhị phân (nhanh + chính xác 100%)
if isinstance(text, str):
t = text.lower().strip().rstrip(".").rstrip(",").strip()
# Xử lý các câu trả lời dài bắt đầu bằng Yes/No của LLaVA (vd: "No, the image does not...")
if t.startswith("yes"):
return "có"
if t.startswith("no"):
return "không"
# Exact match trước
direct_map = {
"true": "có", "false": "không",
"correct": "có", "incorrect": "không",
"present": "có", "absent": "không",
"normal": "bình thường", "abnormal": "bất thường",
}
if t in direct_map:
return direct_map[t]
# 2. Dịch bằng MedCrab
self._lazy_load()
if not self._en2vi_ready:
if isinstance(text, list):
return text
return text
if isinstance(text, list):
return [self._medcrab_translate(t) for t in text]
return self._medcrab_translate(text)
def _medcrab_translate(self, text):
"""Dịch 1 câu En→Vi bằng MedCrab với ràng buộc ngắn gọn."""
# Kiểm tra ánh xạ trực tiếp trước
t = text.lower().strip().rstrip(".").rstrip(",").strip()
direct_map = {
"yes": "có", "no": "không",
"normal": "bình thường", "abnormal": "bất thường",
}
if t in direct_map:
return direct_map[t]
try:
prompt = f"English: {text}\nVietnamese (trả lời ngắn gọn):"
inputs = self._en2vi_tokenizer(prompt, return_tensors="pt").to(self.gpu_device)
with torch.no_grad():
outputs = self._en2vi_model.generate(
**inputs,
max_new_tokens=30,
repetition_penalty=1.2,
temperature=0.1,
do_sample=False,
pad_token_id=self._en2vi_tokenizer.eos_token_id
)
full_text = self._en2vi_tokenizer.decode(outputs[0], skip_special_tokens=True)
translated = full_text.split("Vietnamese (trả lời ngắn gọn):")[-1].strip()
return translated
except Exception as e:
print(f"[WARNING] En→Vi error: {e}")
return text
|