Spaces:
Paused
Paused
| import torch | |
| from tqdm import tqdm | |
| from src.utils.metrics import batch_metrics, compute_bertscore, compute_semantic_score | |
| from src.utils.text_utils import is_medical_term_compliant, normalize_answer, postprocess_answer | |
| def normalize_for_metric(text: str) -> str: | |
| return text.strip().lower() | |
def _normalize_closed_answer(question_vi: str, question_en: str, pred_vi: str, pred_en: str = "") -> str:
    """Map descriptive yes/no-style outputs to the closed-form labels "cΓ³"/"khΓ΄ng".

    Combines the normalized Vietnamese and English predictions into one string
    and applies ordered substring heuristics, in two regimes:

    * "is it normal?"-style questions: normality wording maps to "cΓ³",
      abnormality/pathology wording maps to "khΓ΄ng".
    * presence/absence questions: negation cues are checked before positive
      cues, because answers such as "khΓ΄ng cΓ³ ..." contain "cΓ³" yet mean no.

    Falls back to the normalized prediction itself when nothing matches.

    Fixes vs. previous revision:
    * "normal" is no longer matched as a substring of "abnormal"/
      "abnormalities" (negative lookbehind), which previously returned "cΓ³"
      for explicitly abnormal answers to normality questions.
    * The fallback cue lists, previously duplicated verbatim, are shared.
    """
    import re

    question_vi_norm = normalize_answer(question_vi)
    question_en_norm = normalize_answer(question_en)
    pred_vi_norm = normalize_answer(pred_vi)
    pred_en_norm = normalize_answer(pred_en)
    combined = " ".join(part for part in [pred_vi_norm, pred_en_norm] if part).strip()

    # Shared cue lists (previously duplicated as the "fallback" lists).
    normality_positive_patterns = [
        "bΓ¬nh thΖ°α»ng",
        "normal",
        "no significant abnormalities",
        "no abnormality",
        "unremarkable",
        "appears to be normal",
        "without significant abnormalities",
        "khΓ΄ng phΓ‘t hiα»n bαΊ₯t thΖ°α»ng",
    ]
    normality_negative_patterns = [
        "bαΊ₯t thΖ°α»ng",
        "abnormal",
        "abnormality detected",
        "fracture",
        "lesion",
        "mass",
        "effusion",
        "pneumothorax",
    ]

    def _has_normality_positive(text: str) -> bool:
        # "normal" must not match the substring inside "abnormal"/"abnormalities".
        for pattern in normality_positive_patterns:
            if pattern == "normal":
                if re.search(r"(?<!ab)normal", text):
                    return True
            elif pattern in text:
                return True
        return False

    is_normality_question = any(
        pattern in " ".join([question_vi_norm, question_en_norm])
        for pattern in ["bΓ¬nh thΖ°α»ng", "normal", "abnormal", "bat thuong"]
    )
    if is_normality_question:
        # Explicit negations must win over "bΓ¬nh thΖ°α»ng"/"normal", which they
        # contain as substrings.
        explicit_negative_patterns = [
            "khΓ΄ng bΓ¬nh thΖ°α»ng",
            "not normal",
        ]
        explicit_positive_patterns = [
            "cΓ³",
            "yes",
        ]
        if any(pattern in combined for pattern in explicit_negative_patterns):
            return "khΓ΄ng"
        # Whole-token match so "cΓ³" does not fire from inside longer phrases.
        if any(pattern in combined.split() for pattern in explicit_positive_patterns):
            return "cΓ³"
        if _has_normality_positive(combined):
            return "cΓ³"
        if any(pattern in combined for pattern in normality_negative_patterns):
            return "khΓ΄ng"
    else:
        positive_patterns = [
            "cΓ³",
            "yes",
            "present",
            "detected",
            "positive",
        ]
        negative_patterns = [
            "khΓ΄ng",
            "no",
            "absent",
            "not seen",
            "negative",
            "none",
        ]
        # For presence/absence questions, "khΓ΄ng cΓ³ ..." contains "cΓ³" but
        # semantically means no. Check negation before positive cues.
        # NOTE(review): these are substring checks, so e.g. "no" also matches
        # inside words like "abnormal" or "nodule" — verify against the
        # normalize_answer output space before tightening.
        if any(pattern in combined for pattern in negative_patterns):
            return "khΓ΄ng"
        if any(pattern in combined for pattern in positive_patterns):
            return "cΓ³"
        # Last resort: reuse the normality cues for descriptive answers.
        if _has_normality_positive(combined):
            return "cΓ³"
        if any(pattern in combined for pattern in normality_negative_patterns):
            return "khΓ΄ng"
    return pred_vi_norm or pred_en_norm
| def _compute_format_stats(preds: list[str], max_words: int) -> dict[str, float]: | |
| if not preds: | |
| return { | |
| "max_10_word_compliance_rate": 0.0, | |
| "medical_term_compliance_rate": 0.0, | |
| "avg_answer_length": 0.0, | |
| } | |
| word_counts = [len(p.split()) for p in preds] | |
| return { | |
| "max_10_word_compliance_rate": sum(1 for count in word_counts if count <= max_words) / len(word_counts), | |
| "medical_term_compliance_rate": sum(1 for pred in preds if is_medical_term_compliant(pred)) / len(preds), | |
| "avg_answer_length": sum(word_counts) / len(word_counts), | |
| } | |
| def _build_bad_words_ids(processor, variant: str) -> list[list[int]] | None: | |
| if variant not in {"B1", "B2", "DPO", "PPO"}: | |
| return None | |
| tokenizer = getattr(processor, "tokenizer", None) | |
| if tokenizer is None: | |
| return None | |
| banned_phrases = [ | |
| "yes", | |
| "no", | |
| "the answer is", | |
| "the image is", | |
| "this image is", | |
| "the image shows", | |
| "the scan shows", | |
| "there is", | |
| "there are", | |
| "it appears", | |
| "the finding is", | |
| ] | |
| bad_words_ids = [] | |
| for phrase in banned_phrases: | |
| token_ids = tokenizer.encode(phrase, add_special_tokens=False) | |
| if token_ids: | |
| bad_words_ids.append(token_ids) | |
| return bad_words_ids or None | |
| def _attach_metric_views(metrics: dict[str, float]) -> dict[str, float]: | |
| """Add explicit metric names while preserving backward-compatible aliases.""" | |
| if "accuracy" in metrics: | |
| metrics["accuracy_normalized"] = metrics["accuracy"] | |
| if "em" in metrics: | |
| metrics["em_normalized"] = metrics["em"] | |
| if "f1" in metrics: | |
| metrics["f1_normalized"] = metrics["f1"] | |
| if "bleu1" in metrics: | |
| metrics["bleu1_normalized"] = metrics["bleu1"] | |
| if "bleu2" in metrics: | |
| metrics["bleu2_normalized"] = metrics["bleu2"] | |
| if "bleu3" in metrics: | |
| metrics["bleu3_normalized"] = metrics["bleu3"] | |
| if "bleu4" in metrics: | |
| metrics["bleu4_normalized"] = metrics["bleu4"] | |
| if "rouge_l" in metrics: | |
| metrics["rouge_l_normalized"] = metrics["rouge_l"] | |
| if "meteor" in metrics: | |
| metrics["meteor_normalized"] = metrics["meteor"] | |
| if "bert_score" in metrics: | |
| metrics["bert_score_raw"] = metrics["bert_score"] | |
| if "semantic" in metrics: | |
| metrics["semantic_raw"] = metrics["semantic"] | |
| return metrics | |
class MedicalVQAEvaluator:
    """Unified evaluation entry point for both Direction A and Direction B variants."""

    def __init__(self, device, tokenizer=None, processor=None):
        # tokenizer drives the Direction-A decode path; processor drives B-variants.
        self.device = device
        self.tokenizer = tokenizer
        self.processor = processor

    def evaluate(self, model, dataloader, variant_type='A', beam_width=1):
        """Dispatch to the evaluator matching *variant_type* ('A' or a B-variant name)."""
        if variant_type != 'A':
            return evaluate_multimodal_vqa(
                model, dataloader, self.device, self.processor, beam_width, variant=variant_type
            )
        return evaluate_vqa(model, dataloader, self.device, self.tokenizer, beam_width)
def evaluate_vqa(model, dataloader, device, tokenizer, beam_width=1, max_len=32, max_words=10):
    """Evaluate a Direction-A VQA model with a dual head (closed classifier + generator).

    Runs `model.inference` over *dataloader*, decodes the generative head with
    *tokenizer*, overrides closed (yes/no) questions with the classifier head,
    and returns one metrics dict with overall, closed-only, open-only and
    long-answer sections plus the raw prediction/reference lists.

    Args:
        model: exposes `inference(images, input_ids, attention_mask,
            beam_width, max_len)` -> (closed-head logits, generated token ids).
        dataloader: yields batches with 'image', 'input_ids', 'attention_mask',
            'label_closed' (-1 marks open questions), 'raw_questions',
            'raw_answer' and optionally 'raw_answer_full'.
        device: torch device the tensors are moved to.
        tokenizer: used to batch-decode the generated token ids.
        beam_width: beam size forwarded to `model.inference`.
        max_len: max generation length forwarded to `model.inference`.
        max_words: word cap applied by `postprocess_answer`.
    """
    model.eval()
    all_preds = []          # final predictions used for scoring
    all_preds_raw = []      # generative-head outputs before the closed-head override
    all_preds_display = []  # human-readable predictions kept for reports
    all_refs = []           # short references (capped at max_words)
    all_refs_full = []      # long-form references (capped at 100 words)
    all_is_closed = []      # per-sample closed/open flags
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            images = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label_closed']
            # Call inference() to get BOTH head outputs; max_len comes from config.
            logits_closed, pred_ids = model.inference(images, input_ids, attention_mask, beam_width=beam_width, max_len=max_len)
            # Decode the generative head and clean up subword artifacts.
            preds_text_raw = [
                postprocess_answer(t, max_words=max_words)
                for t in tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
            ]
            preds_text = list(preds_text_raw)
            # For closed (yes/no) questions, use the classifier head instead of the generator.
            closed_map = {0: "khΓ΄ng", 1: "cΓ³"}
            closed_preds_idx = torch.argmax(logits_closed, dim=-1)  # [B]
            for i in range(len(preds_text)):
                if labels[i].item() != -1:  # closed question
                    preds_text[i] = closed_map[closed_preds_idx[i].item()]
                    preds_text[i] = postprocess_answer(preds_text[i], max_words=max_words)
            # Debug: show both closed and open samples to verify output diversity.
            if len(all_preds) == 0:
                print("\n--- DEBUG PREDICTIONS ---")
                shown_closed, shown_open = 0, 0
                for i in range(len(preds_text)):
                    is_closed = labels[i].item() != -1
                    if (is_closed and shown_closed < 2) or (not is_closed and shown_open < 2):
                        q_type = "CLOSED" if is_closed else "OPEN"
                        print(f"[{q_type}] Q: {batch['raw_questions'][i]}")
                        print(f" Pred raw: '{preds_text_raw[i]}'")
                        print(f" Pred normalized: '{preds_text[i]}'")
                        print(f" GT : '{batch['raw_answer'][i]}'")
                        if is_closed: shown_closed += 1
                        else: shown_open += 1
                    if shown_closed >= 2 and shown_open >= 2:
                        break
                print("--------------------------\n")
            all_preds.extend([normalize_for_metric(p) for p in preds_text])
            all_preds_raw.extend([normalize_for_metric(p) for p in preds_text_raw])
            all_preds_display.extend([normalize_for_metric(p) for p in preds_text_raw])
            # Score against the Vietnamese ground-truth answers.
            all_refs.extend([normalize_for_metric(postprocess_answer(r, max_words=max_words)) for r in batch['raw_answer']])
            all_refs_full.extend([normalize_for_metric(postprocess_answer(r, max_words=100)) for r in batch.get('raw_answer_full', batch['raw_answer'])])
            is_closed = (batch['label_closed'] != -1).tolist()
            all_is_closed.extend(is_closed)
    # Overall metrics: lexical metrics on normalized preds, embedding metrics on raw preds.
    metrics = batch_metrics(all_preds, all_refs)
    metrics["semantic"] = compute_semantic_score(all_preds_raw, all_refs)
    metrics["bert_score"] = compute_bertscore(all_preds_raw, all_refs)
    metrics = _attach_metric_views(metrics)
    metrics.update(_compute_format_stats(all_preds, max_words=max_words))
    metrics['predictions'] = all_preds
    metrics['predictions_raw'] = all_preds_raw
    metrics['predictions_display'] = all_preds_display
    metrics['ground_truths'] = all_refs
    # Closed-question subset metrics.
    closed_preds = [p for p, c in zip(all_preds, all_is_closed) if c]
    closed_refs = [r for r, c in zip(all_refs, all_is_closed) if c]
    closed_preds_raw = [p for p, c in zip(all_preds_raw, all_is_closed) if c]
    if closed_preds:
        metrics['closed'] = batch_metrics(closed_preds, closed_refs)
        metrics['closed']["semantic"] = compute_semantic_score(closed_preds_raw, closed_refs)
        metrics['closed']["bert_score"] = compute_bertscore(closed_preds_raw, closed_refs)
        metrics['closed'] = _attach_metric_views(metrics['closed'])
        metrics['closed'].update(_compute_format_stats(closed_preds, max_words=max_words))
        metrics['closed_eval'] = {
            "accuracy": metrics['closed'].get("accuracy_normalized", 0.0),
            "em": metrics['closed'].get("em_normalized", 0.0),
            "f1": metrics['closed'].get("f1_normalized", 0.0),
            "count": len(closed_preds),
        }
    # Open-question subset metrics.
    open_preds = [p for p, c in zip(all_preds, all_is_closed) if not c]
    open_refs = [r for r, c in zip(all_refs, all_is_closed) if not c]
    open_preds_raw = [p for p, c in zip(all_preds_raw, all_is_closed) if not c]
    if open_preds:
        metrics['open'] = batch_metrics(open_preds, open_refs)
        metrics['open']["semantic"] = compute_semantic_score(open_preds_raw, open_refs)
        metrics['open']["bert_score"] = compute_bertscore(open_preds_raw, open_refs)
        metrics['open'] = _attach_metric_views(metrics['open'])
        metrics['open'].update(_compute_format_stats(open_preds, max_words=max_words))
        metrics['open_eval'] = {
            "semantic": metrics['open'].get("semantic_raw", 0.0),
            "bert_score": metrics['open'].get("bert_score_raw", 0.0),
            "f1": metrics['open'].get("f1_normalized", 0.0),
            "rouge_l": metrics['open'].get("rouge_l_normalized", 0.0),
            "count": len(open_preds),
        }
    # NOTE(review): batch_metrics is recomputed three times here, and it likely
    # returns plain keys ("accuracy", "f1", "bleu4") — the "*_normalized"
    # aliases are only added by _attach_metric_views — so these .get(...) calls
    # may always fall back to 0. Verify against batch_metrics' return keys.
    metrics['long_answers_eval'] = {
        "accuracy": batch_metrics(all_preds, all_refs_full).get("accuracy_normalized", 0),
        "f1": batch_metrics(all_preds, all_refs_full).get("f1_normalized", 0),
        "bleu4": batch_metrics(all_preds, all_refs_full).get("bleu4_normalized", 0),
        "semantic": compute_semantic_score(all_preds_raw, all_refs_full),
        "bert_score": compute_bertscore(all_preds_raw, all_refs_full)
    }
    return metrics
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # B1 HELPERS | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _B1_FEW_SHOT = ( | |
| "Q: Is there cardiomegaly? A: yes\n" | |
| "Q: What organ is shown? A: lung\n" | |
| "Q: Is the aorta normal? A: no\n" | |
| "Q: What abnormality is present? A: pleural effusion\n" | |
| ) | |
| def _build_b1_prompt(question_en: str, max_words: int) -> str: | |
| """ | |
| Few-shot prompt Γ©p LLaVA trαΊ£ lα»i ngαΊ―n (β€max_words tα»« y tαΊΏ), khΓ΄ng sinh cΓ’u dΓ i. | |
| ΔαΊ·t 4 vΓ dα»₯ in-context trΖ°α»c cΓ’u hα»i thα»±c Δα» suppress verbose prefix. | |
| """ | |
| return ( | |
| f"USER: <image>\n" | |
| f"Answer each question with medical terminology only, " | |
| f"no more than {max_words} words, no full sentences.\n" | |
| f"{_B1_FEW_SHOT}" | |
| f"Q: {question_en} A: ASSISTANT:" | |
| ) | |
# En -> Vi fast lookup table (50+ medical terms frequently seen in SLAKE + VQA-RAD).
# Used by _en_to_vi_direct for exact-match translation before falling back to
# the translation model. Keys must be lowercase, whitespace-trimmed English.
_EN_VI_DIRECT: dict[str, str] = {
    # binary
    "yes": "cΓ³", "no": "khΓ΄ng",
    "present": "cΓ³", "absent": "khΓ΄ng",
    "normal": "bΓ¬nh thΖ°α»ng", "abnormal": "bαΊ₯t thΖ°α»ng",
    "true": "cΓ³", "false": "khΓ΄ng",
    "positive": "cΓ³", "negative": "khΓ΄ng",
    # anatomy
    "lung": "phα»i", "lungs": "phα»i",
    "heart": "tim", "liver": "gan", "spleen": "lΓ‘ch",
    "kidney": "thαΊn", "brain": "nΓ£o", "bladder": "bΓ ng quang",
    "chest": "ngα»±c", "abdomen": "bα»₯ng", "pelvis": "xΖ°Ζ‘ng chαΊu",
    "spine": "cα»t sα»ng", "rib": "xΖ°Ζ‘ng sΖ°α»n", "ribs": "xΖ°Ζ‘ng sΖ°α»n",
    "trachea": "khΓ quαΊ£n", "aorta": "Δα»ng mαΊ‘ch chα»§",
    "diaphragm": "cΖ‘ hoΓ nh", "mediastinum": "trung thαΊ₯t",
    # modality
    "chest x-ray": "x-quang ngα»±c", "x-ray": "x-quang", "xray": "x-quang",
    "mri": "mri", "ct": "ct", "ultrasound": "siΓͺu Γ’m",
    "ct scan": "ct", "mri scan": "mri",
    # planes
    "axial": "mαΊ·t phαΊ³ng ngang",
    "coronal": "mαΊ·t phαΊ³ng vΓ nh",
    "sagittal": "mαΊ·t phαΊ³ng dα»c",
    "transverse": "mαΊ·t phαΊ³ng ngang",
    # pathologies
    "cardiomegaly": "tim to",
    "pneumonia": "viΓͺm phα»i",
    "pleural effusion": "trΓ n dα»ch mΓ ng phα»i",
    "pneumothorax": "trΓ n khΓ mΓ ng phα»i",
    "fracture": "gΓ£y xΖ°Ζ‘ng",
    "edema": "phΓΉ nα»",
    "pulmonary edema": "phΓΉ phα»i",
    "consolidation": "ΔΓ΄ng ΔαΊ·c",
    "atelectasis": "xαΊΉp phα»i",
    "opacity": "mα» Δα»₯c",
    "mass": "khα»i u",
    "nodule": "nα»t",
    "lesion": "tα»n thΖ°Ζ‘ng",
    "tumor": "khα»i u",
    "effusion": "trΓ n dα»ch",
    "infiltrate": "thΓ’m nhiα» m",
    "fibrosis": "xΖ‘ hΓ³a",
    "calcification": "vΓ΄i hΓ³a",
    "carcinoma": "ung thΖ°",
    "metastasis": "di cΔn",
    # laterality / location
    "bilateral": "hai bΓͺn",
    "unilateral": "mα»t bΓͺn",
    "left": "trΓ‘i", "right": "phαΊ£i",
    "upper": "trΓͺn", "lower": "dΖ°α»i",
    "right upper quadrant": "phΓa trΓͺn bΓͺn phαΊ£i",
    "left upper quadrant": "phΓa trΓͺn bΓͺn trΓ‘i",
    "right lower quadrant": "phΓa dΖ°α»i bΓͺn phαΊ£i",
    "left lower quadrant": "phΓa dΖ°α»i bΓͺn trΓ‘i",
    "right upper": "phΓa trΓͺn bΓͺn phαΊ£i",
    "left upper": "phΓa trΓͺn bΓͺn trΓ‘i",
    "upper left": "phΓa trΓͺn bΓͺn trΓ‘i",
    "upper right": "phΓa trΓͺn bΓͺn phαΊ£i",
    "lower left": "phΓa dΖ°α»i bΓͺn trΓ‘i",
    "lower right": "phΓa dΖ°α»i bΓͺn phαΊ£i",
}
| def _extract_key_medical_term(raw_en: str, max_words: int) -> str: | |
| """ | |
| LoαΊ‘i bα» verbose prefix LLaVA hay sinh ("The image shows a chest X-ray with..."), | |
| chα» giα»― lαΊ‘i thuαΊt ngα»― y tαΊΏ chΓnh. | |
| """ | |
| import re | |
| text = raw_en.strip().lower() | |
| # CΓ‘c prefix verbose phα» biαΊΏn cαΊ§n xΓ³a | |
| prefixes = [ | |
| r"^the (image|scan|x-ray|xray|mri|ct|picture|photo|radiograph) (shows?|depicts?|demonstrates?|reveals?|indicates?|presents?)\s+", | |
| r"^based on the (image|scan|x-ray|mri|ct)\s*,?\s*", | |
| r"^in (this|the) (image|scan|x-ray|mri|ct)\s*,?\s*", | |
| r"^i (can see|observe|notice|see)\s+", | |
| r"^there (is|are)\s+(a |an |some )?", | |
| r"^(it |this )(shows?|is|appears?|looks?)\s+(like\s+)?", | |
| r"^the (patient|subject)\s+(has|shows?|presents?)\s+", | |
| r"^(a|an|the)\s+", | |
| r"^[a-z\s]+ is (located|seen|found|present)( in| at| on)?\s+(the\s+)?", | |
| ] | |
| for pat in prefixes: | |
| text = re.sub(pat, "", text) | |
| text = re.sub(r"[.!?,;:]+$", "", text).strip() | |
| text = re.sub(r"\s+", " ", text).strip() | |
| words = text.split() | |
| return " ".join(words[:max_words]) if words else raw_en.strip() | |
def _en_to_vi_direct(en_text: str) -> str | None:
    """Exact-match English-to-Vietnamese term lookup.

    The input is trimmed and lower-cased before the dictionary probe. Returns
    None when the term is not in `_EN_VI_DIRECT`, in which case the caller
    falls back to the translation model.
    """
    return _EN_VI_DIRECT.get(en_text.strip().lower())
def _dual_score_open(
    preds_vi: list,
    preds_en: list,
    refs_vi: list,
    refs_en: list,
) -> list:
    """Pick, per open-ended sample, the prediction language with the higher F1.

    Compares the Vietnamese prediction against the Vietnamese reference and
    the (normalized) English prediction against the English reference, keeping
    whichever prediction scores better. This guards against the translation
    step destroying an otherwise-correct answer (the 0% open-ended failure
    mode). Ties favor Vietnamese; a missing English reference scores 0.0.
    """
    from src.utils.metrics import compute_f1
    from src.utils.text_utils import normalize_answer

    best = []
    for pred_vi, pred_en, ref_vi, ref_en in zip(preds_vi, preds_en, refs_vi, refs_en):
        score_vi = compute_f1(pred_vi, ref_vi)
        if ref_en:
            score_en = compute_f1(normalize_answer(pred_en), normalize_answer(ref_en))
        else:
            score_en = 0.0
        chosen = pred_vi if score_vi >= score_en else normalize_answer(pred_en)
        best.append(chosen)
    return best
def evaluate_multimodal_vqa(
    model,
    dataloader,
    device,
    processor,
    beam_width=1,
    max_words=10,
    variant='B1',
    beam_width_closed=None,
    beam_width_open=None,
    max_new_tokens_closed=None,
    max_new_tokens_open=None,
    generation_batch_size=None,
):
    """B1 zero-shot evaluation & B2/DPO/PPO fine-tuned evaluation.

    B1 prompts the model in English with a few-shot template, then maps the
    English generations back to Vietnamese (closed-answer normalization, fast
    dictionary lookup, or a translation model). B2/DPO/PPO variants are
    prompted directly in Vietnamese. Scoring is done against the Vietnamese
    references; B1 additionally keeps English predictions for dual-language
    open-ended rescoring.

    Args:
        model: generative model exposing `generate(**inputs, ...)`.
        dataloader: yields dicts with 'raw_image', 'raw_questions',
            'raw_questions_en', 'raw_answer', 'raw_answer_en', 'label_closed'
            (-1 marks open questions) and optionally 'raw_answer_full'.
        device: torch device; `device.type` also selects the translator backend.
        processor: multimodal processor (tokenizer + image preprocessing).
        beam_width: default beam size for both question types.
        max_words: word cap applied by `postprocess_answer`.
        variant: 'B1' (zero-shot) or 'B2'/'DPO'/'PPO' (fine-tuned).
        beam_width_closed / beam_width_open: per-type beam overrides.
        max_new_tokens_closed / max_new_tokens_open: per-type generation budgets.
        generation_batch_size: chunk size for generation (clamped to >= 1).

    Returns:
        Metrics dict with overall, 'closed', 'open' and 'long_answers_eval'
        sections plus raw prediction/reference lists.
    """
    model.eval()
    all_preds = []          # Vietnamese predictions used for scoring
    all_preds_raw = []      # display-form predictions (pre closed-normalization)
    all_preds_display = []  # human-readable predictions for reports
    all_preds_en = []       # cleaned English predictions (B1 only)
    all_refs = []           # Vietnamese references (capped at max_words)
    all_refs_full = []      # long-form references (capped at 100 words)
    all_refs_en = []        # English references
    all_is_closed = []      # per-sample closed/open flags
    # Local imports keep these heavy project dependencies off the module import path.
    from src.utils.translator import MedicalTranslator
    translator = MedicalTranslator(device=device.type)
    from src.models.multimodal_vqa import MultimodalVQA
    wrapper = MultimodalVQA()
    # Resolve per-question-type generation settings from the shared defaults.
    beam_width_closed = beam_width if beam_width_closed is None else beam_width_closed
    beam_width_open = beam_width if beam_width_open is None else beam_width_open
    max_new_tokens_closed = 4 if max_new_tokens_closed is None else max_new_tokens_closed
    max_new_tokens_open = (max_words + 6) if max_new_tokens_open is None else max_new_tokens_open
    generation_batch_size = 1 if generation_batch_size is None else max(1, int(generation_batch_size))
    bad_words_ids = _build_bad_words_ids(processor, variant)
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Evaluating {variant}")):
            raw_images = batch.get('raw_image')
            questions_vi = batch.get('raw_questions', [])
            questions_en = batch.get('raw_questions_en', [])
            refs_vi_raw = batch.get('raw_answer', [])
            refs_en_raw = batch.get('raw_answer_en', [])
            labels = batch['label_closed']
            if variant == 'B1':
                # B1 (Zero-shot) needs English translation & English few-shot prompt
                if not questions_en or any(not str(q).strip() for q in questions_en):
                    questions_en = translator.translate_vi2en(questions_vi)
                prompts = [_build_b1_prompt(q, max_words) for q in questions_en]
            else:
                # B2 / DPO / PPO (Fine-tuned) expect Vietnamese instruction directly
                prompts = [wrapper.build_instruction_prompt(q, language="vi", include_answer=False) for q in questions_vi]
            preds_raw = [""] * len(prompts)
            closed_idx = [i for i, lbl in enumerate(labels.tolist()) if lbl != -1]
            open_idx = [i for i, lbl in enumerate(labels.tolist()) if lbl == -1]
            def _run_generation(sample_indices, num_beams, max_new_tokens):
                # Generate answers for the given sample indices in chunks and
                # return the decoded strings in the same order.
                if not sample_indices:
                    return []
                decoded_outputs = []
                # Beam search is memory-hungry, so only batch >1 when greedy.
                chunk_size = generation_batch_size if num_beams > 1 else max(generation_batch_size, 2)
                for start in range(0, len(sample_indices), chunk_size):
                    chunk_indices = sample_indices[start:start + chunk_size]
                    text_subset = [prompts[i] for i in chunk_indices]
                    image_subset = [raw_images[i] for i in chunk_indices] if raw_images is not None else None
                    if image_subset is not None:
                        inputs = processor(
                            text=text_subset,
                            images=image_subset,
                            return_tensors="pt",
                            padding=True,
                        ).to(device)
                        if "pixel_values" in inputs:
                            # Cast to bf16 — presumably matching the vision tower dtype.
                            inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
                    else:
                        inputs = processor(text=text_subset, return_tensors="pt", padding=True).to(device)
                    output_ids = model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        do_sample=False,
                        num_beams=num_beams,
                        early_stopping=num_beams > 1,
                        bad_words_ids=bad_words_ids,
                    )
                    # Drop the echoed prompt tokens before decoding.
                    input_token_len = inputs.input_ids.shape[1]
                    decoded_outputs.extend(
                        processor.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
                    )
                    del inputs, output_ids
                    if device.type == "cuda":
                        torch.cuda.empty_cache()
                return decoded_outputs
            if variant == 'B1':
                # B1 uses one prompt style/budget for every question.
                generated = _run_generation(list(range(len(prompts))), beam_width_open, max_new_tokens_open)
                preds_raw = generated
            else:
                # Fine-tuned variants: short budget for closed, longer for open.
                for idx, pred in zip(closed_idx, _run_generation(closed_idx, beam_width_closed, max_new_tokens_closed)):
                    preds_raw[idx] = pred
                for idx, pred in zip(open_idx, _run_generation(open_idx, beam_width_open, max_new_tokens_open)):
                    preds_raw[idx] = pred
            preds_vi = []
            preds_vi_display = []
            preds_en_clean = []
            if variant == 'B1':
                # [FIX 2] Strip the verbose prefix but keep the key medical term; the
                # generous word cap (50) avoids chopping the English sentence before translation.
                preds_en_clean = [_extract_key_medical_term(p, 50) for p in preds_raw]
                # [FIX 3 + 5] Per-sample: closed -> normalize the English pred first;
                # open -> fast dictionary lookup, then the translation model.
                needs_translate_idx = []  # sample indices still needing translation
                needs_translate_txt = []
                for i, pred_en in enumerate(preds_en_clean):
                    if labels[i].item() != -1:
                        # Closed: _normalize_closed_answer on the English pred (more reliable).
                        preds_vi.append(
                            _normalize_closed_answer(
                                questions_vi[i], questions_en[i], pred_en, pred_en
                            )
                        )
                    else:
                        # Open: try the fast dictionary first.
                        vi_direct = _en_to_vi_direct(pred_en)
                        if vi_direct is not None:
                            preds_vi.append(postprocess_answer(vi_direct, max_words=max_words))
                        else:
                            preds_vi.append(None)  # placeholder, filled in below
                            needs_translate_idx.append(i)
                            needs_translate_txt.append(pred_en)
                # Batch-translate whatever the dictionary could not cover.
                if needs_translate_txt:
                    translated = translator.translate_en2vi(needs_translate_txt)
                    if isinstance(translated, str):
                        translated = [translated]
                    for idx, vi in zip(needs_translate_idx, translated):
                        preds_vi[idx] = postprocess_answer(vi, max_words=max_words)
                preds_vi_display = list(preds_vi)
            else:
                # B2 / DPO / PPO directly outputs Vietnamese, no translation needed
                preds_vi_display = [postprocess_answer(p, max_words=max_words) if p else "" for p in preds_raw]
                for i, pred_vi in enumerate(preds_raw):
                    if labels[i].item() != -1:
                        preds_vi.append(
                            _normalize_closed_answer(
                                questions_vi[i], questions_en[i] if i < len(questions_en) else "", pred_vi
                            )
                        )
                    else:
                        preds_vi.append(pred_vi)
                preds_en_clean = [""] * len(preds_raw)
            # Make sure no None placeholders survive.
            preds_vi = [postprocess_answer(p, max_words=max_words) if p else "" for p in preds_vi]
            preds_vi_display = [postprocess_answer(p, max_words=max_words) if p else "" for p in preds_vi_display]
            preds_vi_raw = list(preds_vi_display)
            # References.
            refs_vi = [postprocess_answer(r, max_words=max_words) for r in refs_vi_raw]
            refs_en = [postprocess_answer(r, max_words=max_words) if r else "" for r in refs_en_raw]
            # Debug output for the first batch only.
            if batch_idx == 0:
                print(f"\n--- DEBUG {variant} (Evaluation) ---")
                for i in range(min(4, len(preds_vi))):
                    q_type = "CLOSED" if labels[i].item() != -1 else "OPEN"
                    if variant == 'B1':
                        print(f"[{q_type}] Q (En): {questions_en[i]}")
                        print(f" Pred (En raw): '{preds_raw[i]}'")
                        print(f" Pred (En clean): '{preds_en_clean[i]}'")
                    else:
                        print(f"[{q_type}] Q (Vi): {questions_vi[i]}")
                        print(f" Pred (Vi raw): '{preds_raw[i]}'")
                        print(f" Pred display: '{preds_vi_display[i]}'")
                    print(f" Pred (Vi): '{preds_vi[i]}'")
                    print(f" GT (Vi): '{refs_vi[i]}' | GT (En): '{refs_en[i]}'")
                print("-----------------------------------------\n")
            all_preds.extend([normalize_for_metric(p) for p in preds_vi])
            all_preds_raw.extend([normalize_for_metric(p) for p in preds_vi_raw])
            all_preds_display.extend([normalize_for_metric(p) for p in preds_vi_display])
            all_preds_en.extend([normalize_for_metric(p) for p in preds_en_clean])
            all_refs.extend([normalize_for_metric(r) for r in refs_vi])
            all_refs_full.extend([normalize_for_metric(postprocess_answer(r, max_words=100)) for r in batch.get('raw_answer_full', batch['raw_answer'])])
            all_refs_en.extend([normalize_for_metric(r) for r in refs_en])
            all_is_closed.extend((labels != -1).tolist())
    # [FIX 4] Dual-language scoring for open-ended questions (B1 only): keep
    # whichever of the Vi/En predictions scores the higher token-F1.
    if variant == 'B1':
        open_idx = [i for i, c in enumerate(all_is_closed) if not c]
        if open_idx:
            best_open = _dual_score_open(
                [all_preds[i] for i in open_idx],
                [all_preds_en[i] for i in open_idx],
                [all_refs[i] for i in open_idx],
                [all_refs_en[i] for i in open_idx],
            )
            for k, i in enumerate(open_idx):
                all_preds[i] = best_open[k]
    # ββ Compute metrics ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
    metrics = batch_metrics(all_preds, all_refs)
    metrics["semantic"] = compute_semantic_score(all_preds_raw, all_refs)
    metrics["bert_score"] = compute_bertscore(all_preds_raw, all_refs)
    metrics = _attach_metric_views(metrics)
    metrics.update(_compute_format_stats(all_preds, max_words=max_words))
    metrics['predictions'] = all_preds
    metrics['predictions_raw'] = all_preds_raw
    metrics['predictions_display'] = all_preds_display
    metrics['predictions_en'] = all_preds_en
    metrics['ground_truths'] = all_refs
    metrics['ground_truths_en'] = all_refs_en
    def _subset(pred_list, ref_list, pred_raw_list):
        # Full metric suite for a closed-only or open-only subset.
        m = batch_metrics(pred_list, ref_list)
        m["semantic"] = compute_semantic_score(pred_raw_list, ref_list)
        m["bert_score"] = compute_bertscore(pred_raw_list, ref_list)
        m = _attach_metric_views(m)
        m.update(_compute_format_stats(pred_list, max_words=max_words))
        return m
    closed_idx = [i for i, c in enumerate(all_is_closed) if c]
    open_idx = [i for i, c in enumerate(all_is_closed) if not c]
    if closed_idx:
        metrics['closed'] = _subset(
            [all_preds[i] for i in closed_idx],
            [all_refs[i] for i in closed_idx],
            [all_preds_raw[i] for i in closed_idx],
        )
        metrics['closed_eval'] = {
            "accuracy": metrics['closed'].get("accuracy_normalized", 0.0),
            "em": metrics['closed'].get("em_normalized", 0.0),
            "f1": metrics['closed'].get("f1_normalized", 0.0),
            "count": len(closed_idx),
        }
    if open_idx:
        metrics['open'] = _subset(
            [all_preds[i] for i in open_idx],
            [all_refs[i] for i in open_idx],
            [all_preds_raw[i] for i in open_idx],
        )
        metrics['open_eval'] = {
            "semantic": metrics['open'].get("semantic_raw", 0.0),
            "bert_score": metrics['open'].get("bert_score_raw", 0.0),
            "f1": metrics['open'].get("f1_normalized", 0.0),
            "rouge_l": metrics['open'].get("rouge_l_normalized", 0.0),
            "count": len(open_idx),
        }
    # NOTE(review): batch_metrics is recomputed three times here, and it likely
    # returns plain keys ("accuracy", "f1", "bleu4") — the "*_normalized"
    # aliases are only added by _attach_metric_views — so these .get(...) calls
    # may always fall back to 0. Verify against batch_metrics' return keys.
    metrics['long_answers_eval'] = {
        "accuracy": batch_metrics(all_preds, all_refs_full).get("accuracy_normalized", 0),
        "f1": batch_metrics(all_preds, all_refs_full).get("f1_normalized", 0),
        "bleu4": batch_metrics(all_preds, all_refs_full).get("bleu4_normalized", 0),
        "semantic": compute_semantic_score(all_preds_raw, all_refs_full),
        "bert_score": compute_bertscore(all_preds_raw, all_refs_full)
    }
    return metrics