Medical-VQA / src /engine /medical_eval.py
SpringWang08's picture
Deploy Medical VQA app
d63774a
import torch
from tqdm import tqdm
from src.utils.metrics import batch_metrics, compute_bertscore, compute_semantic_score
from src.utils.text_utils import is_medical_term_compliant, normalize_answer, postprocess_answer
def normalize_for_metric(text: str) -> str:
    """Return *text* without surrounding whitespace, in lowercase."""
    trimmed = text.strip()
    return trimmed.lower()
def _normalize_closed_answer(question_vi: str, question_en: str, pred_vi: str, pred_en: str = "") -> str:
    """Map descriptive yes/no-style outputs to closed-form labels.

    The question and prediction texts are passed through ``normalize_answer``;
    the Vietnamese and English predictions are then joined and scanned for cues.

    * Normality questions (question mentions "bình thường", "normal",
      "abnormal" or "bat thuong"): an explicit "not normal" maps to "không",
      a standalone "có"/"yes" token maps to "có".
    * Other (presence/absence) questions: negation cues are checked before
      affirmative cues, because "không có ..." contains "có".
    * In both cases, if nothing matched so far, a "normal" description maps to
      "có" and an "abnormal"/pathology description maps to "không".

    Args:
        question_vi: Vietnamese question text.
        question_en: English question text (may be empty).
        pred_vi: Predicted answer, Vietnamese slot.
        pred_en: Predicted answer, English slot (optional).

    Returns:
        "có" or "không" when a cue matched, otherwise the normalized
        prediction (Vietnamese first, then English).
    """
    import re

    # Properly encoded labels: the previous literals were mojibake and could
    # never equal correctly encoded reference answers.
    yes_label, no_label = "có", "không"

    # Cues saying "the finding is normal".
    normal_cues = [
        "bình thường",
        "normal",
        "no significant abnormalities",
        "no abnormality",
        "unremarkable",
        "appears to be normal",
        "without significant abnormalities",
        "không phát hiện bất thường",
    ]
    # Cues saying "something abnormal / pathological is described".
    abnormal_cues = [
        "bất thường",
        "abnormal",
        "abnormality detected",
        "fracture",
        "lesion",
        "mass",
        "effusion",
        "pneumothorax",
    ]
    # A negation followed within a few words by an abnormality term still
    # means "normal" (e.g. "no abnormal findings").
    negated_abnormal = re.compile(
        r"(?<!\w)(?:no|not|without|không)\s+(?:\w+\s+){0,3}(?:abnormal|bất thường)"
    )

    question_vi_norm = normalize_answer(question_vi)
    question_en_norm = normalize_answer(question_en)
    pred_vi_norm = normalize_answer(pred_vi)
    pred_en_norm = normalize_answer(pred_en)
    question_text = " ".join([question_vi_norm, question_en_norm])
    combined = " ".join(part for part in [pred_vi_norm, pred_en_norm] if part).strip()
    tokens = combined.split()

    def contains_any(cues):
        """Plain substring test of every cue against the combined prediction."""
        return any(cue in combined for cue in cues)

    def states_normal():
        """True when the prediction describes a normal finding.

        Cues must start at a word boundary so that "normal" no longer fires
        inside "abnormal"; previously that made the "abnormal" cues dead code
        and turned a bare "abnormal" answer into "có".
        """
        if negated_abnormal.search(combined):
            return True
        return any(re.search(r"(?<!\w)" + re.escape(cue), combined) for cue in normal_cues)

    is_normality_question = any(
        pattern in question_text
        for pattern in ["bình thường", "normal", "abnormal", "bat thuong"]
    )

    if is_normality_question:
        # NOTE(review): the mapping ignores question polarity ("is it normal?"
        # vs "is it abnormal?"); both map normal -> "có". Confirm intended.
        if contains_any(["không bình thường", "not normal"]):
            return no_label
        if any(word in tokens for word in ["có", "yes"]):
            return yes_label
    else:
        # For presence/absence questions, "không có ..." contains "có" but
        # semantically means no. Check negation before positive cues.
        # NOTE(review): these are substring tests, so "no" also matches inside
        # words such as "pneumonia" or "nodule"; kept as-is pending a decision
        # on the intended semantics.
        if contains_any(["không", "no", "absent", "not seen", "negative", "none"]):
            return no_label
        if contains_any(["có", "yes", "present", "detected", "positive"]):
            return yes_label

    # Shared descriptive cues (the normality branch and the fallback used
    # identical lists, so they are evaluated once here).
    if states_normal():
        return yes_label
    if contains_any(abnormal_cues):
        return no_label
    return pred_vi_norm or pred_en_norm
def _compute_format_stats(preds: list[str], max_words: int) -> dict[str, float]:
if not preds:
return {
"max_10_word_compliance_rate": 0.0,
"medical_term_compliance_rate": 0.0,
"avg_answer_length": 0.0,
}
word_counts = [len(p.split()) for p in preds]
return {
"max_10_word_compliance_rate": sum(1 for count in word_counts if count <= max_words) / len(word_counts),
"medical_term_compliance_rate": sum(1 for pred in preds if is_medical_term_compliant(pred)) / len(preds),
"avg_answer_length": sum(word_counts) / len(word_counts),
}
def _build_bad_words_ids(processor, variant: str) -> list[list[int]] | None:
if variant not in {"B1", "B2", "DPO", "PPO"}:
return None
tokenizer = getattr(processor, "tokenizer", None)
if tokenizer is None:
return None
banned_phrases = [
"yes",
"no",
"the answer is",
"the image is",
"this image is",
"the image shows",
"the scan shows",
"there is",
"there are",
"it appears",
"the finding is",
]
bad_words_ids = []
for phrase in banned_phrases:
token_ids = tokenizer.encode(phrase, add_special_tokens=False)
if token_ids:
bad_words_ids.append(token_ids)
return bad_words_ids or None
def _attach_metric_views(metrics: dict[str, float]) -> dict[str, float]:
"""Add explicit metric names while preserving backward-compatible aliases."""
if "accuracy" in metrics:
metrics["accuracy_normalized"] = metrics["accuracy"]
if "em" in metrics:
metrics["em_normalized"] = metrics["em"]
if "f1" in metrics:
metrics["f1_normalized"] = metrics["f1"]
if "bleu1" in metrics:
metrics["bleu1_normalized"] = metrics["bleu1"]
if "bleu2" in metrics:
metrics["bleu2_normalized"] = metrics["bleu2"]
if "bleu3" in metrics:
metrics["bleu3_normalized"] = metrics["bleu3"]
if "bleu4" in metrics:
metrics["bleu4_normalized"] = metrics["bleu4"]
if "rouge_l" in metrics:
metrics["rouge_l_normalized"] = metrics["rouge_l"]
if "meteor" in metrics:
metrics["meteor_normalized"] = metrics["meteor"]
if "bert_score" in metrics:
metrics["bert_score_raw"] = metrics["bert_score"]
if "semantic" in metrics:
metrics["semantic_raw"] = metrics["semantic"]
return metrics
class MedicalVQAEvaluator:
    """
    Unified evaluation entry point for both Direction A and Direction B models.
    """

    def __init__(self, device, tokenizer=None, processor=None):
        # Stored as given; handed to the evaluation functions on demand.
        self.device, self.tokenizer, self.processor = device, tokenizer, processor

    def evaluate(self, model, dataloader, variant_type='A', beam_width=1):
        """
        Common interface for evaluating any variant.

        Variant 'A' is scored with ``evaluate_vqa`` using the stored
        tokenizer; every other variant is scored with
        ``evaluate_multimodal_vqa`` using the stored processor.
        """
        if variant_type != 'A':
            return evaluate_multimodal_vqa(
                model,
                dataloader,
                self.device,
                self.processor,
                beam_width,
                variant=variant_type,
            )
        return evaluate_vqa(model, dataloader, self.device, self.tokenizer, beam_width)
def evaluate_vqa(model, dataloader, device, tokenizer, beam_width=1, max_len=32, max_words=10):
    """Evaluate a model exposing ``inference()`` with a closed-question head and a generator.

    For every batch the model's ``inference`` method returns closed-question
    logits and generated token ids. Generated ids are decoded and
    post-processed; for closed questions (``label_closed != -1``) the scored
    prediction is replaced by the label picked from the closed logits.

    Args:
        model: Object with ``eval()`` and
            ``inference(images, input_ids, attention_mask, beam_width=, max_len=)``.
        dataloader: Iterable of dict batches with keys ``image``,
            ``input_ids``, ``attention_mask``, ``label_closed``,
            ``raw_questions``, ``raw_answer`` and optionally
            ``raw_answer_full``.
        device: Device the batch tensors are moved to.
        tokenizer: Tokenizer providing ``batch_decode``.
        beam_width: Beam width forwarded to ``model.inference``.
        max_len: Maximum generation length forwarded to ``model.inference``.
        max_words: Word limit passed to ``postprocess_answer``.

    Returns:
        Metrics dict with overall scores, the prediction / reference lists,
        ``closed`` / ``open`` subsets (when non-empty) with their ``*_eval``
        summaries, and ``long_answers_eval``.
    """
    model.eval()
    all_preds = []
    all_preds_raw = []
    all_preds_display = []
    all_refs = []
    all_refs_full = []
    all_is_closed = []

    # Closed-head class index -> Vietnamese label. The previous literals were
    # mojibake and could never equal correctly encoded reference answers.
    closed_map = {0: "không", 1: "có"}

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            images = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label_closed']
            # -1 marks an open question; anything else is a closed one.
            closed_flags = (labels != -1).tolist()

            # Call inference() to get BOTH head outputs, passing max_len through.
            logits_closed, pred_ids = model.inference(
                images, input_ids, attention_mask, beam_width=beam_width, max_len=max_len
            )

            # Decode the generative head and clean up subword artifacts.
            preds_text_raw = [
                postprocess_answer(t, max_words=max_words)
                for t in tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
            ]
            preds_text = list(preds_text_raw)

            # For closed (yes/no) questions use the classifier head instead of the generator.
            closed_preds_idx = torch.argmax(logits_closed, dim=-1)  # [B]
            for i in range(len(preds_text)):
                if closed_flags[i]:
                    preds_text[i] = closed_map[closed_preds_idx[i].item()]
                preds_text[i] = postprocess_answer(preds_text[i], max_words=max_words)

            # Debug: on the first batch show up to two closed and two open samples.
            if not all_preds:
                print("\n--- DEBUG PREDICTIONS ---")
                shown_closed, shown_open = 0, 0
                for i in range(len(preds_text)):
                    is_closed = closed_flags[i]
                    if (is_closed and shown_closed < 2) or (not is_closed and shown_open < 2):
                        q_type = "CLOSED" if is_closed else "OPEN"
                        print(f"[{q_type}] Q: {batch['raw_questions'][i]}")
                        print(f" Pred raw: '{preds_text_raw[i]}'")
                        print(f" Pred normalized: '{preds_text[i]}'")
                        print(f" GT : '{batch['raw_answer'][i]}'")
                        if is_closed:
                            shown_closed += 1
                        else:
                            shown_open += 1
                    if shown_closed >= 2 and shown_open >= 2:
                        break
                print("--------------------------\n")

            all_preds.extend([normalize_for_metric(p) for p in preds_text])
            all_preds_raw.extend([normalize_for_metric(p) for p in preds_text_raw])
            all_preds_display.extend([normalize_for_metric(p) for p in preds_text_raw])
            # Score against the Vietnamese reference answers.
            all_refs.extend([
                normalize_for_metric(postprocess_answer(r, max_words=max_words))
                for r in batch['raw_answer']
            ])
            all_refs_full.extend([
                normalize_for_metric(postprocess_answer(r, max_words=100))
                for r in batch.get('raw_answer_full', batch['raw_answer'])
            ])
            all_is_closed.extend(closed_flags)

    def _score(pred_list, ref_list, pred_raw_list):
        """Lexical metrics on the scored predictions, semantic metrics on the raw ones."""
        scores = batch_metrics(pred_list, ref_list)
        scores["semantic"] = compute_semantic_score(pred_raw_list, ref_list)
        scores["bert_score"] = compute_bertscore(pred_raw_list, ref_list)
        scores = _attach_metric_views(scores)
        scores.update(_compute_format_stats(pred_list, max_words=max_words))
        return scores

    metrics = _score(all_preds, all_refs, all_preds_raw)
    metrics['predictions'] = all_preds
    metrics['predictions_raw'] = all_preds_raw
    metrics['predictions_display'] = all_preds_display
    metrics['ground_truths'] = all_refs

    closed_preds = [p for p, c in zip(all_preds, all_is_closed) if c]
    closed_refs = [r for r, c in zip(all_refs, all_is_closed) if c]
    closed_preds_raw = [p for p, c in zip(all_preds_raw, all_is_closed) if c]
    if closed_preds:
        metrics['closed'] = _score(closed_preds, closed_refs, closed_preds_raw)
        metrics['closed_eval'] = {
            "accuracy": metrics['closed'].get("accuracy_normalized", 0.0),
            "em": metrics['closed'].get("em_normalized", 0.0),
            "f1": metrics['closed'].get("f1_normalized", 0.0),
            "count": len(closed_preds),
        }

    open_preds = [p for p, c in zip(all_preds, all_is_closed) if not c]
    open_refs = [r for r, c in zip(all_refs, all_is_closed) if not c]
    open_preds_raw = [p for p, c in zip(all_preds_raw, all_is_closed) if not c]
    if open_preds:
        metrics['open'] = _score(open_preds, open_refs, open_preds_raw)
        metrics['open_eval'] = {
            "semantic": metrics['open'].get("semantic_raw", 0.0),
            "bert_score": metrics['open'].get("bert_score_raw", 0.0),
            "f1": metrics['open'].get("f1_normalized", 0.0),
            "rouge_l": metrics['open'].get("rouge_l_normalized", 0.0),
            "count": len(open_preds),
        }

    # Compute the long-reference metrics once. The "*_normalized" aliases are
    # only added by _attach_metric_views, so fall back to the plain key; the
    # old lookups on the raw batch_metrics output defaulted to 0.
    long_metrics = batch_metrics(all_preds, all_refs_full)

    def _long(name):
        return long_metrics.get(f"{name}_normalized", long_metrics.get(name, 0))

    metrics['long_answers_eval'] = {
        "accuracy": _long("accuracy"),
        "f1": _long("f1"),
        "bleu4": _long("bleu4"),
        "semantic": compute_semantic_score(all_preds_raw, all_refs_full),
        "bert_score": compute_bertscore(all_preds_raw, all_refs_full),
    }
    return metrics
# ─────────────────────────────────────────────────────────────────────────────
# B1 HELPERS
# ─────────────────────────────────────────────────────────────────────────────
_B1_FEW_SHOT = (
"Q: Is there cardiomegaly? A: yes\n"
"Q: What organ is shown? A: lung\n"
"Q: Is the aorta normal? A: no\n"
"Q: What abnormality is present? A: pleural effusion\n"
)
def _build_b1_prompt(question_en: str, max_words: int) -> str:
"""
Few-shot prompt Γ©p LLaVA trαΊ£ lời ngαΊ―n (≀max_words tα»« y tαΊΏ), khΓ΄ng sinh cΓ’u dΓ i.
Đặt 4 vΓ­ dα»₯ in-context trΖ°α»›c cΓ’u hỏi thα»±c để suppress verbose prefix.
"""
return (
f"USER: <image>\n"
f"Answer each question with medical terminology only, "
f"no more than {max_words} words, no full sentences.\n"
f"{_B1_FEW_SHOT}"
f"Q: {question_en} A: ASSISTANT:"
)
# En β†’ Vi fast lookup (50+ thuαΊ­t ngα»― y tαΊΏ thường gαΊ·p trong SLAKE + VQA-RAD)
_EN_VI_DIRECT: dict = {
# binary
"yes": "cΓ³", "no": "khΓ΄ng",
"present": "cΓ³", "absent": "khΓ΄ng",
"normal": "bΓ¬nh thường", "abnormal": "bαΊ₯t thường",
"true": "cΓ³", "false": "khΓ΄ng",
"positive": "cΓ³", "negative": "khΓ΄ng",
# anatomy
"lung": "phα»•i", "lungs": "phα»•i",
"heart": "tim", "liver": "gan", "spleen": "lΓ‘ch",
"kidney": "thαΊ­n", "brain": "nΓ£o", "bladder": "bΓ ng quang",
"chest": "ngα»±c", "abdomen": "bα»₯ng", "pelvis": "xΖ°Ζ‘ng chαΊ­u",
"spine": "cα»™t sα»‘ng", "rib": "xΖ°Ζ‘ng sườn", "ribs": "xΖ°Ζ‘ng sườn",
"trachea": "khΓ­ quαΊ£n", "aorta": "Δ‘α»™ng mαΊ‘ch chα»§",
"diaphragm": "cΖ‘ hoΓ nh", "mediastinum": "trung thαΊ₯t",
# modality
"chest x-ray": "x-quang ngα»±c", "x-ray": "x-quang", "xray": "x-quang",
"mri": "mri", "ct": "ct", "ultrasound": "siΓͺu Γ’m",
"ct scan": "ct", "mri scan": "mri",
# planes
"axial": "mαΊ·t phαΊ³ng ngang",
"coronal": "mαΊ·t phαΊ³ng vΓ nh",
"sagittal": "mặt phẳng dọc",
"transverse": "mαΊ·t phαΊ³ng ngang",
# pathologies
"cardiomegaly": "tim to",
"pneumonia": "viΓͺm phα»•i",
"pleural effusion": "trΓ n dα»‹ch mΓ ng phα»•i",
"pneumothorax": "trΓ n khΓ­ mΓ ng phα»•i",
"fracture": "gΓ£y xΖ°Ζ‘ng",
"edema": "phù nề",
"pulmonary edema": "phΓΉ phα»•i",
"consolidation": "Δ‘Γ΄ng Δ‘αΊ·c",
"atelectasis": "xαΊΉp phα»•i",
"opacity": "mờ Δ‘α»₯c",
"mass": "khα»‘i u",
"nodule": "nα»‘t",
"lesion": "tα»•n thΖ°Ζ‘ng",
"tumor": "khα»‘i u",
"effusion": "trΓ n dα»‹ch",
"infiltrate": "thΓ’m nhiα»…m",
"fibrosis": "xΖ‘ hΓ³a",
"calcification": "vΓ΄i hΓ³a",
"carcinoma": "ung thΖ°",
"metastasis": "di căn",
"bilateral": "hai bΓͺn",
"unilateral": "mα»™t bΓͺn",
"left": "trΓ‘i", "right": "phαΊ£i",
"upper": "trΓͺn", "lower": "dΖ°α»›i",
"right upper quadrant": "phΓ­a trΓͺn bΓͺn phαΊ£i",
"left upper quadrant": "phΓ­a trΓͺn bΓͺn trΓ‘i",
"right lower quadrant": "phΓ­a dΖ°α»›i bΓͺn phαΊ£i",
"left lower quadrant": "phΓ­a dΖ°α»›i bΓͺn trΓ‘i",
"right upper": "phΓ­a trΓͺn bΓͺn phαΊ£i",
"left upper": "phΓ­a trΓͺn bΓͺn trΓ‘i",
"upper left": "phΓ­a trΓͺn bΓͺn trΓ‘i",
"upper right": "phΓ­a trΓͺn bΓͺn phαΊ£i",
"lower left": "phΓ­a dΖ°α»›i bΓͺn trΓ‘i",
"lower right": "phΓ­a dΖ°α»›i bΓͺn phαΊ£i",
}
def _extract_key_medical_term(raw_en: str, max_words: int) -> str:
"""
Loẑi bỏ verbose prefix LLaVA hay sinh ("The image shows a chest X-ray with..."),
chỉ giα»― lαΊ‘i thuαΊ­t ngα»― y tαΊΏ chΓ­nh.
"""
import re
text = raw_en.strip().lower()
# CΓ‘c prefix verbose phα»• biαΊΏn cαΊ§n xΓ³a
prefixes = [
r"^the (image|scan|x-ray|xray|mri|ct|picture|photo|radiograph) (shows?|depicts?|demonstrates?|reveals?|indicates?|presents?)\s+",
r"^based on the (image|scan|x-ray|mri|ct)\s*,?\s*",
r"^in (this|the) (image|scan|x-ray|mri|ct)\s*,?\s*",
r"^i (can see|observe|notice|see)\s+",
r"^there (is|are)\s+(a |an |some )?",
r"^(it |this )(shows?|is|appears?|looks?)\s+(like\s+)?",
r"^the (patient|subject)\s+(has|shows?|presents?)\s+",
r"^(a|an|the)\s+",
r"^[a-z\s]+ is (located|seen|found|present)( in| at| on)?\s+(the\s+)?",
]
for pat in prefixes:
text = re.sub(pat, "", text)
text = re.sub(r"[.!?,;:]+$", "", text).strip()
text = re.sub(r"\s+", " ", text).strip()
words = text.split()
return " ".join(words[:max_words]) if words else raw_en.strip()
def _en_to_vi_direct(en_text: str) -> str | None:
    """
    Fast dictionary lookup of an English term in ``_EN_VI_DIRECT``.

    The input is stripped and lowercased, then matched exactly against the
    dictionary keys. Returns None when there is no entry, so the caller can
    fall back to the translation model.
    """
    lookup_key = en_text.strip().lower()
    return _EN_VI_DIRECT.get(lookup_key)
def _dual_score_open(
    preds_vi: list,
    preds_en: list,
    refs_vi: list,
    refs_en: list,
) -> list:
    """
    For every open question, compare the Vietnamese F1 with the English F1
    and keep whichever prediction scores better.

    The Vietnamese prediction wins ties. The English side is only scored when
    an English reference exists; when it wins, the normalized English
    prediction is returned for that sample.
    """
    from src.utils.metrics import compute_f1
    from src.utils.text_utils import normalize_answer

    chosen = []
    for vi_pred, en_pred, vi_ref, en_ref in zip(preds_vi, preds_en, refs_vi, refs_en):
        vi_score = compute_f1(vi_pred, vi_ref)
        if en_ref:
            en_score = compute_f1(normalize_answer(en_pred), normalize_answer(en_ref))
        else:
            en_score = 0.0

        if vi_score >= en_score:
            chosen.append(vi_pred)
        else:
            chosen.append(normalize_answer(en_pred))
    return chosen
def evaluate_multimodal_vqa(
    model,
    dataloader,
    device,
    processor,
    beam_width=1,
    max_words=10,
    variant='B1',
    beam_width_closed=None,
    beam_width_open=None,
    max_new_tokens_closed=None,
    max_new_tokens_open=None,
    generation_batch_size=None,
):
    """
    B1 Zero-Shot evaluation & B2/DPO/PPO Fine-Tuned evaluation.

    B1 prompts the model in English with a few-shot prompt, cleans the English
    output and maps it to Vietnamese (closed-answer normalization, dictionary
    lookup, then the translation model). Other variants are prompted in
    Vietnamese and only closed answers are normalized.

    Args:
        model: Generative model exposing ``eval()`` and ``generate(**inputs, ...)``.
        dataloader: Iterable of dict batches with ``label_closed`` and the
            optional keys ``raw_image``, ``raw_questions``,
            ``raw_questions_en``, ``raw_answer``, ``raw_answer_en``,
            ``raw_answer_full``.
        device: ``torch.device`` or device string.
        processor: Callable processor that also provides ``batch_decode``.
        beam_width: Default beam width for closed and open questions.
        max_words: Word limit passed to ``postprocess_answer``.
        variant: 'B1' for zero-shot; anything else takes the fine-tuned path.
        beam_width_closed / beam_width_open: Optional per-type beam widths.
        max_new_tokens_closed / max_new_tokens_open: Optional generation
            budgets (defaults: 4 and ``max_words + 6``).
        generation_batch_size: Generation chunk size (default 1; at least 2
            is used when ``num_beams`` is 1).

    Returns:
        Metrics dict with overall scores, prediction / reference lists,
        ``closed`` / ``open`` subsets (when non-empty) with their ``*_eval``
        summaries, and ``long_answers_eval``.
    """
    # `.type` is read below, so accept plain device strings as well.
    if not isinstance(device, torch.device):
        device = torch.device(device)

    model.eval()
    all_preds = []
    all_preds_raw = []
    all_preds_display = []
    all_preds_en = []
    all_refs = []
    all_refs_full = []
    all_refs_en = []
    all_is_closed = []

    from src.utils.translator import MedicalTranslator
    translator = MedicalTranslator(device=device.type)
    from src.models.multimodal_vqa import MultimodalVQA
    wrapper = MultimodalVQA()

    beam_width_closed = beam_width if beam_width_closed is None else beam_width_closed
    beam_width_open = beam_width if beam_width_open is None else beam_width_open
    max_new_tokens_closed = 4 if max_new_tokens_closed is None else max_new_tokens_closed
    max_new_tokens_open = (max_words + 6) if max_new_tokens_open is None else max_new_tokens_open
    generation_batch_size = 1 if generation_batch_size is None else max(1, int(generation_batch_size))
    bad_words_ids = _build_bad_words_ids(processor, variant)

    def _run_generation(prompts, raw_images, sample_indices, num_beams, max_new_tokens):
        """Generate answers for the selected samples, chunk by chunk, in index order."""
        if not sample_indices:
            return []
        decoded_outputs = []
        chunk_size = generation_batch_size if num_beams > 1 else max(generation_batch_size, 2)
        for start in range(0, len(sample_indices), chunk_size):
            chunk_indices = sample_indices[start:start + chunk_size]
            text_subset = [prompts[i] for i in chunk_indices]
            image_subset = [raw_images[i] for i in chunk_indices] if raw_images is not None else None
            if image_subset is not None:
                inputs = processor(
                    text=text_subset,
                    images=image_subset,
                    return_tensors="pt",
                    padding=True,
                ).to(device)
                if "pixel_values" in inputs:
                    # NOTE(review): bfloat16 is hard-coded; presumably matches the model dtype - confirm.
                    inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
            else:
                inputs = processor(text=text_subset, return_tensors="pt", padding=True).to(device)
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                num_beams=num_beams,
                early_stopping=num_beams > 1,
                bad_words_ids=bad_words_ids,
            )
            # Keep only the tokens generated after the prompt.
            input_token_len = inputs.input_ids.shape[1]
            decoded_outputs.extend(
                processor.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
            )
            del inputs, output_ids
            if device.type == "cuda":
                torch.cuda.empty_cache()
        return decoded_outputs

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Evaluating {variant}")):
            raw_images = batch.get('raw_image')
            questions_vi = batch.get('raw_questions', [])
            questions_en = batch.get('raw_questions_en', [])
            refs_vi_raw = batch.get('raw_answer', [])
            refs_en_raw = batch.get('raw_answer_en', [])
            labels = batch['label_closed']
            # -1 marks an open question; anything else is a closed one.
            closed_flags = (labels != -1).tolist()

            if variant == 'B1':
                # B1 (Zero-shot) needs English translation & English few-shot prompt
                if not questions_en or any(not str(q).strip() for q in questions_en):
                    questions_en = translator.translate_vi2en(questions_vi)
                prompts = [_build_b1_prompt(q, max_words) for q in questions_en]
            else:
                # B2 / DPO / PPO (Fine-tuned) expect Vietnamese instruction directly
                prompts = [
                    wrapper.build_instruction_prompt(q, language="vi", include_answer=False)
                    for q in questions_vi
                ]

            preds_raw = [""] * len(prompts)
            closed_idx = [i for i, flag in enumerate(closed_flags) if flag]
            open_idx = [i for i, flag in enumerate(closed_flags) if not flag]

            if variant == 'B1':
                preds_raw = _run_generation(
                    prompts, raw_images, list(range(len(prompts))), beam_width_open, max_new_tokens_open
                )
            else:
                closed_out = _run_generation(
                    prompts, raw_images, closed_idx, beam_width_closed, max_new_tokens_closed
                )
                for idx, pred in zip(closed_idx, closed_out):
                    preds_raw[idx] = pred
                open_out = _run_generation(
                    prompts, raw_images, open_idx, beam_width_open, max_new_tokens_open
                )
                for idx, pred in zip(open_idx, open_out):
                    preds_raw[idx] = pred

            preds_vi = []
            preds_vi_display = []
            preds_en_clean = []
            if variant == 'B1':
                # Strip the verbose prefix and keep the key medical term. A generous
                # 50-word limit avoids chopping the English text before translation.
                preds_en_clean = [_extract_key_medical_term(p, 50) for p in preds_raw]

                # Per sample: closed -> normalize the English prediction;
                # open -> dictionary lookup first, then the translation model.
                needs_translate_idx = []  # positions that still need translation
                needs_translate_txt = []
                for i, pred_en in enumerate(preds_en_clean):
                    if closed_flags[i]:
                        # Closed: feed the English prediction into both prediction slots.
                        preds_vi.append(
                            _normalize_closed_answer(
                                questions_vi[i], questions_en[i], pred_en, pred_en
                            )
                        )
                    else:
                        # Open: try the fast dictionary first.
                        vi_direct = _en_to_vi_direct(pred_en)
                        if vi_direct is not None:
                            preds_vi.append(postprocess_answer(vi_direct, max_words=max_words))
                        else:
                            preds_vi.append(None)  # placeholder, filled in below
                            needs_translate_idx.append(i)
                            needs_translate_txt.append(pred_en)

                # Batch-translate whatever the dictionary could not handle.
                if needs_translate_txt:
                    translated = translator.translate_en2vi(needs_translate_txt)
                    if isinstance(translated, str):
                        translated = [translated]
                    for idx, vi in zip(needs_translate_idx, translated):
                        preds_vi[idx] = postprocess_answer(vi, max_words=max_words)
                preds_vi_display = list(preds_vi)
            else:
                # B2 / DPO / PPO directly outputs Vietnamese, no translation needed
                preds_vi_display = [
                    postprocess_answer(p, max_words=max_words) if p else "" for p in preds_raw
                ]
                for i, pred_vi in enumerate(preds_raw):
                    if closed_flags[i]:
                        preds_vi.append(
                            _normalize_closed_answer(
                                questions_vi[i],
                                questions_en[i] if i < len(questions_en) else "",
                                pred_vi,
                            )
                        )
                    else:
                        preds_vi.append(pred_vi)
                preds_en_clean = [""] * len(preds_raw)

            # Make sure no None placeholder survives.
            preds_vi = [postprocess_answer(p, max_words=max_words) if p else "" for p in preds_vi]
            preds_vi_display = [
                postprocess_answer(p, max_words=max_words) if p else "" for p in preds_vi_display
            ]
            preds_vi_raw = list(preds_vi_display)

            # Refs
            refs_vi = [postprocess_answer(r, max_words=max_words) for r in refs_vi_raw]
            refs_en = [postprocess_answer(r, max_words=max_words) if r else "" for r in refs_en_raw]
            # 'raw_answer_en' may be missing or short. Pad so the English
            # references stay index-aligned with the predictions; otherwise the
            # debug print and the B1 dual scoring below raise IndexError or
            # silently compare the wrong samples.
            if len(refs_en) < len(refs_vi):
                refs_en.extend([""] * (len(refs_vi) - len(refs_en)))

            # Debug the first batch.
            if batch_idx == 0:
                print(f"\n--- DEBUG {variant} (Evaluation) ---")
                for i in range(min(4, len(preds_vi))):
                    q_type = "CLOSED" if closed_flags[i] else "OPEN"
                    if variant == 'B1':
                        print(f"[{q_type}] Q (En): {questions_en[i]}")
                        print(f" Pred (En raw): '{preds_raw[i]}'")
                        print(f" Pred (En clean): '{preds_en_clean[i]}'")
                    else:
                        print(f"[{q_type}] Q (Vi): {questions_vi[i]}")
                        print(f" Pred (Vi raw): '{preds_raw[i]}'")
                        print(f" Pred display: '{preds_vi_display[i]}'")
                    print(f" Pred (Vi): '{preds_vi[i]}'")
                    print(f" GT (Vi): '{refs_vi[i]}' | GT (En): '{refs_en[i]}'")
                print("-----------------------------------------\n")

            all_preds.extend([normalize_for_metric(p) for p in preds_vi])
            all_preds_raw.extend([normalize_for_metric(p) for p in preds_vi_raw])
            all_preds_display.extend([normalize_for_metric(p) for p in preds_vi_display])
            all_preds_en.extend([normalize_for_metric(p) for p in preds_en_clean])
            all_refs.extend([normalize_for_metric(r) for r in refs_vi])
            all_refs_full.extend([
                normalize_for_metric(postprocess_answer(r, max_words=100))
                for r in batch.get('raw_answer_full', batch['raw_answer'])
            ])
            all_refs_en.extend([normalize_for_metric(r) for r in refs_en])
            all_is_closed.extend(closed_flags)

    closed_idx = [i for i, c in enumerate(all_is_closed) if c]
    open_idx = [i for i, c in enumerate(all_is_closed) if not c]

    # Dual-language scoring for open-ended questions (B1 only).
    # NOTE(review): when the English prediction wins it is still scored against
    # the Vietnamese reference below - confirm this is intended.
    if variant == 'B1' and open_idx:
        best_open = _dual_score_open(
            [all_preds[i] for i in open_idx],
            [all_preds_en[i] for i in open_idx],
            [all_refs[i] for i in open_idx],
            [all_refs_en[i] for i in open_idx],
        )
        for k, i in enumerate(open_idx):
            all_preds[i] = best_open[k]

    # -- Compute metrics ------------------------------------------------------
    def _subset(pred_list, ref_list, pred_raw_list):
        """Lexical metrics on the scored predictions, semantic metrics on the raw ones."""
        m = batch_metrics(pred_list, ref_list)
        m["semantic"] = compute_semantic_score(pred_raw_list, ref_list)
        m["bert_score"] = compute_bertscore(pred_raw_list, ref_list)
        m = _attach_metric_views(m)
        m.update(_compute_format_stats(pred_list, max_words=max_words))
        return m

    metrics = _subset(all_preds, all_refs, all_preds_raw)
    metrics['predictions'] = all_preds
    metrics['predictions_raw'] = all_preds_raw
    metrics['predictions_display'] = all_preds_display
    metrics['predictions_en'] = all_preds_en
    metrics['ground_truths'] = all_refs
    metrics['ground_truths_en'] = all_refs_en

    if closed_idx:
        metrics['closed'] = _subset(
            [all_preds[i] for i in closed_idx],
            [all_refs[i] for i in closed_idx],
            [all_preds_raw[i] for i in closed_idx],
        )
        metrics['closed_eval'] = {
            "accuracy": metrics['closed'].get("accuracy_normalized", 0.0),
            "em": metrics['closed'].get("em_normalized", 0.0),
            "f1": metrics['closed'].get("f1_normalized", 0.0),
            "count": len(closed_idx),
        }
    if open_idx:
        metrics['open'] = _subset(
            [all_preds[i] for i in open_idx],
            [all_refs[i] for i in open_idx],
            [all_preds_raw[i] for i in open_idx],
        )
        metrics['open_eval'] = {
            "semantic": metrics['open'].get("semantic_raw", 0.0),
            "bert_score": metrics['open'].get("bert_score_raw", 0.0),
            "f1": metrics['open'].get("f1_normalized", 0.0),
            "rouge_l": metrics['open'].get("rouge_l_normalized", 0.0),
            "count": len(open_idx),
        }

    # Compute the long-reference metrics once. The "*_normalized" aliases are
    # only added by _attach_metric_views, so fall back to the plain key; the
    # old lookups on the raw batch_metrics output defaulted to 0.
    long_metrics = batch_metrics(all_preds, all_refs_full)

    def _long(name):
        return long_metrics.get(f"{name}_normalized", long_metrics.get(name, 0))

    metrics['long_answers_eval'] = {
        "accuracy": _long("accuracy"),
        "f1": _long("f1"),
        "bleu4": _long("bleu4"),
        "semantic": compute_semantic_score(all_preds_raw, all_refs_full),
        "bert_score": compute_bertscore(all_preds_raw, all_refs_full),
    }
    return metrics