"""Evaluation metrics for VQA: Accuracy, EM, F1, BLEU-1~4, METEOR, and Semantic Score.""" from __future__ import annotations from collections import Counter import numpy as np import torch from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu from nltk.translate.meteor_score import meteor_score as _nltk_meteor import nltk try: nltk.data.find('corpora/wordnet') except LookupError: print("[INFO] Đang tự động tải bộ từ điển NLTK WordNet cho METEOR score...") nltk.download('wordnet', quiet=True) nltk.download('omw-1.4', quiet=True) # 1. Semantic Score (SentenceTransformer) try: from sentence_transformers import SentenceTransformer, util semantic_model = SentenceTransformer('all-MiniLM-L6-v2') except Exception as e: semantic_model = None print(f"Warning: Could not load SentenceTransformer: {e}") # 2. BERTScore try: from bert_score import BERTScorer # Ép sử dụng model multilingual để tránh lỗi attribute của Tokenizer trên Python 3.12 device = "cuda" if torch.cuda.is_available() else "cpu" bert_scorer = BERTScorer(model_type="bert-base-multilingual-cased", device=device) except ImportError: print("[WARNING] Thư viện bert_score chưa được cài đặt.") bert_scorer = None except Exception as e: bert_scorer = None print(f"Warning: Could not load BERTScorer: {e}") # 3. ROUGE-L try: from rouge_score import rouge_scorer as rs rouge_l_scorer = rs.RougeScorer(['rougeL'], use_stemmer=True) except Exception as e: rouge_l_scorer = None print(f"Warning: Could not load rouge-score: {e}") # [FIX] Import from the local text_utils instead of non-existent src.data.preprocessing from .text_utils import normalize_answer, majority_answer def compute_rouge_l(pred: str, refs) -> float: """Tính ROUGE-L (Lấy MAX over multiple refs).""" if not rouge_l_scorer: return 0.0 if isinstance(refs, str): refs = [refs] best_rouge = 0.0 for r in refs: score = rouge_l_scorer.score(normalize_answer(r), normalize_answer(pred))['rougeL'].fmeasure best_rouge = max(best_rouge, score) return best_rouge def compute_bertscore(preds: list[str], refs: list) -> float: """Tính BERTScore cho cả batch.""" if not bert_scorer or not preds or not refs: return 0.0 clean_preds = [normalize_answer(p) if normalize_answer(p).strip() else "." for p in preds] clean_refs = [majority_answer(r) if isinstance(r, list) else normalize_answer(r) for r in refs] clean_refs = [r if r.strip() else "." for r in clean_refs] try: # Tăng tốc bằng cách tắt idf nếu cần P, R, F1 = bert_scorer.score(clean_preds, clean_refs) return float(F1.mean().item()) except Exception as e: print(f"[WARNING] BERTScore error: {e}") return 0.0 def compute_exact_match(pred: str, refs) -> float: """So khớp chính xác lấy MAX (soft match over multiple refs).""" if isinstance(refs, str): refs = [refs] return float(any(normalize_answer(pred) == normalize_answer(r) for r in refs)) def compute_f1(pred: str, refs) -> float: """Tính F1-score ở mức độ token. 
def compute_bleu(pred: str, refs) -> dict[str, float]:
    """Compute BLEU-1 through BLEU-4 against multiple refs."""
    if isinstance(refs, str):
        refs = [refs]
    smoothie = SmoothingFunction().method4
    p_toks = normalize_answer(pred).split()
    r_toks_list = [normalize_answer(r).split() for r in refs if normalize_answer(r).strip()]
    if not p_toks or not r_toks_list:
        return {"bleu1": 0.0, "bleu2": 0.0, "bleu3": 0.0, "bleu4": 0.0}
    weights = [
        (1, 0, 0, 0),              # BLEU-1
        (0.5, 0.5, 0, 0),          # BLEU-2
        (0.33, 0.33, 0.33, 0),     # BLEU-3
        (0.25, 0.25, 0.25, 0.25),  # BLEU-4
    ]
    return {
        f"bleu{i+1}": sentence_bleu(r_toks_list, p_toks, weights=w, smoothing_function=smoothie)
        for i, w in enumerate(weights)
    }


def compute_meteor(pred: str, refs) -> float:
    """Compute METEOR score (supports N refs)."""
    if isinstance(refs, str):
        refs = [refs]
    p_toks = normalize_answer(pred).split()
    r_toks_list = [normalize_answer(r).split() for r in refs if normalize_answer(r).strip()]
    if not p_toks or not r_toks_list:
        return 0.0
    return _nltk_meteor(r_toks_list, p_toks)


def compute_vqa_accuracy(pred: str, direct_answers) -> float:
    """
    Compute soft VQA accuracy: min(#annotators_with_same_answer / 3, 1.0).
    Used for datasets with multiple annotators (such as A-OKVQA).
    """
    if isinstance(direct_answers, str):
        return compute_exact_match(pred, direct_answers)
    normed_pred = normalize_answer(pred)
    matches = sum(1 for a in direct_answers if normalize_answer(a) == normed_pred)
    return min(matches / 3.0, 1.0)


def compute_semantic_score(preds: list[str], refs: list) -> float:
    """Compute semantic similarity via cosine similarity of sentence embeddings."""
    if not semantic_model or not preds or not refs:
        return 0.0
    clean_preds = [normalize_answer(p) for p in preds]
    # Take the most representative string if it's a list for semantic comparison
    clean_refs = [majority_answer(r) if isinstance(r, list) else normalize_answer(r) for r in refs]
    # Encode to vectors (embeddings)
    pred_embs = semantic_model.encode(clean_preds, convert_to_tensor=True, show_progress_bar=False)
    ref_embs = semantic_model.encode(clean_refs, convert_to_tensor=True, show_progress_bar=False)
    # Compute the cosine similarity matrix and take the diagonal (1-to-1 comparison)
    cosine_scores = util.cos_sim(pred_embs, ref_embs)
    scores = torch.diag(cosine_scores)
    return float(scores.mean().item())
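# Shape note for compute_semantic_score above (illustrative): with B predictions
# and B references, util.cos_sim returns a (B, B) matrix of pairwise cosine
# similarities; torch.diag keeps only the (i, i) entries, so each prediction is
# scored against its own reference, and the mean of that diagonal is returned.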
def batch_metrics(predictions: list[str], references: list) -> dict[str, float]:
    """Aggregate all metrics over a batch."""
    results = {
        "accuracy": [], "em": [], "f1": [], "meteor": [],
        "bleu1": [], "bleu2": [], "bleu3": [], "bleu4": [], "rouge_l": []
    }
    for pred, ref in zip(predictions, references):
        # Pass the full refs list so per-example metrics take the MAX over refs
        results["accuracy"].append(compute_vqa_accuracy(pred, ref))
        results["em"].append(compute_exact_match(pred, ref))
        results["f1"].append(compute_f1(pred, ref))
        results["meteor"].append(compute_meteor(pred, ref))
        results["rouge_l"].append(compute_rouge_l(pred, ref))
        bleus = compute_bleu(pred, ref)
        for k, v in bleus.items():
            results[k].append(v)

    # Average the per-example metrics
    final_metrics = {k: float(np.mean(v)) for k, v in results.items()}

    # Compute Semantic Score and BERTScore over the entire batch
    final_metrics["semantic"] = compute_semantic_score(predictions, references)
    final_metrics["bert_score"] = compute_bertscore(predictions, references)
    return final_metrics
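

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative strings, not from the source).
    # Because this module uses a relative import, run it as a module, e.g.
    # `python -m <package>.metrics` from the project root ("<package>" is a
    # placeholder for the actual package name), rather than directly.
    preds = ["a brown dog", "two"]
    refs = [["the brown dog", "a dog"], ["2", "two"]]
    for name, value in sorted(batch_metrics(preds, refs).items()):
        print(f"{name}: {value:.4f}")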