"""Evaluation metrics for VQA: Accuracy, EM, F1, BLEU-1~4, METEOR, and Semantic Score."""
from __future__ import annotations
from collections import Counter
import numpy as np
import torch
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from nltk.translate.meteor_score import meteor_score as _nltk_meteor
import nltk
# METEOR requires the WordNet corpora; fetch them once at import time if absent.
try:
    nltk.data.find('corpora/wordnet')
except LookupError:
    # Message is Vietnamese: "Automatically downloading the NLTK WordNet corpus for METEOR score..."
    print("[INFO] Đang tự động tải bộ từ điển NLTK WordNet cho METEOR score...")
    nltk.download('wordnet', quiet=True)
    nltk.download('omw-1.4', quiet=True)
# 1. Semantic Score (SentenceTransformer)
# Loaded eagerly at import; `semantic_model is None` signals the metric is unavailable
# and compute_semantic_score() degrades to returning 0.0.
try:
    from sentence_transformers import SentenceTransformer, util
    semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
except Exception as e:
    semantic_model = None
    print(f"Warning: Could not load SentenceTransformer: {e}")
# 2. BERTScore
# `bert_scorer is None` signals unavailability; compute_bertscore() then returns 0.0.
try:
    from bert_score import BERTScorer
    # Force the multilingual model to avoid a Tokenizer attribute error on Python 3.12.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    bert_scorer = BERTScorer(model_type="bert-base-multilingual-cased", device=device)
except ImportError:
    # Message is Vietnamese: "The bert_score library is not installed."
    print("[WARNING] Thư viện bert_score chưa được cài đặt.")
    bert_scorer = None
except Exception as e:
    bert_scorer = None
    print(f"Warning: Could not load BERTScorer: {e}")
# 3. ROUGE-L
# `rouge_l_scorer is None` signals unavailability; compute_rouge_l() then returns 0.0.
try:
    from rouge_score import rouge_scorer as rs
    rouge_l_scorer = rs.RougeScorer(['rougeL'], use_stemmer=True)
except Exception as e:
    rouge_l_scorer = None
    print(f"Warning: Could not load rouge-score: {e}")
# [FIX] Import from the local text_utils instead of non-existent src.data.preprocessing
from .text_utils import normalize_answer, majority_answer
def compute_rouge_l(pred: str, refs) -> float:
    """Return the best ROUGE-L F-measure of ``pred`` over one or many references.

    Returns 0.0 when the rouge-score backend failed to load at import time.
    """
    if not rouge_l_scorer:
        return 0.0
    ref_list = [refs] if isinstance(refs, str) else refs
    # Score against every reference and keep the maximum.
    fmeasures = [
        rouge_l_scorer.score(normalize_answer(ref), normalize_answer(pred))['rougeL'].fmeasure
        for ref in ref_list
    ]
    return max(fmeasures, default=0.0)
def compute_bertscore(preds: list[str], refs: list) -> float:
    """Batch-level BERTScore: mean F1 over all (pred, ref) pairs.

    Empty strings are replaced by "." so the scorer never sees a blank input.
    Returns 0.0 when the scorer is unavailable, inputs are empty, or scoring fails.
    """
    if not bert_scorer or not preds or not refs:
        return 0.0
    safe_preds = []
    for p in preds:
        normed = normalize_answer(p)
        safe_preds.append(normed if normed.strip() else ".")
    safe_refs = []
    for r in refs:
        # Multi-annotator references collapse to their majority answer.
        text = majority_answer(r) if isinstance(r, list) else normalize_answer(r)
        safe_refs.append(text if text.strip() else ".")
    try:
        # idf weighting could be disabled here for speed if needed.
        _, _, f1 = bert_scorer.score(safe_preds, safe_refs)
        return float(f1.mean().item())
    except Exception as e:
        print(f"[WARNING] BERTScore error: {e}")
        return 0.0
def compute_exact_match(pred: str, refs) -> float:
    """Exact match after normalization: 1.0 if ``pred`` equals ANY reference, else 0.0."""
    ref_list = [refs] if isinstance(refs, str) else refs
    target = normalize_answer(pred)
    for ref in ref_list:
        if normalize_answer(ref) == target:
            return 1.0
    return 0.0
def compute_f1(pred: str, refs) -> float:
    """Token-level F1 between ``pred`` and references, maximized over all references."""
    ref_list = [refs] if isinstance(refs, str) else refs
    pred_tokens = normalize_answer(pred).split()
    pred_counts = Counter(pred_tokens)
    best = 0.0
    for ref in ref_list:
        ref_tokens = normalize_answer(ref).split()
        if not pred_tokens or not ref_tokens:
            # Both sides empty counts as a perfect match; only one empty scores 0.
            candidate = float(pred_tokens == ref_tokens)
        else:
            overlap = sum((pred_counts & Counter(ref_tokens)).values())
            if overlap == 0:
                candidate = 0.0
            else:
                precision = overlap / len(pred_tokens)
                recall = overlap / len(ref_tokens)
                candidate = 2 * precision * recall / (precision + recall)
        if candidate > best:
            best = candidate
    return best
def compute_bleu(pred: str, refs) -> dict[str, float]:
    """Sentence-level BLEU-1..4 of ``pred`` against one or more references.

    Uses NLTK ``sentence_bleu`` with smoothing method 4. All-empty inputs
    score 0.0 across the board.

    Returns:
        Mapping {"bleu1": ..., "bleu2": ..., "bleu3": ..., "bleu4": ...}.
    """
    if isinstance(refs, str):
        refs = [refs]
    smoothie = SmoothingFunction().method4
    p_toks = normalize_answer(pred).split()
    r_toks_list = [normalize_answer(r).split() for r in refs if normalize_answer(r).strip()]
    if not p_toks or not r_toks_list:
        return {"bleu1": 0.0, "bleu2": 0.0, "bleu3": 0.0, "bleu4": 0.0}
    # [FIX] Use exact 1/3 weights for BLEU-3: (0.33, 0.33, 0.33) summed to 0.99,
    # inconsistent with the exact uniform weights of BLEU-1/2/4.
    weights = [
        (1, 0, 0, 0),                      # BLEU-1
        (0.5, 0.5, 0, 0),                  # BLEU-2
        (1 / 3, 1 / 3, 1 / 3, 0),          # BLEU-3
        (0.25, 0.25, 0.25, 0.25),          # BLEU-4
    ]
    return {
        f"bleu{i+1}": sentence_bleu(r_toks_list, p_toks, weights=w, smoothing_function=smoothie)
        for i, w in enumerate(weights)
    }
def compute_meteor(pred: str, refs) -> float:
    """METEOR score of ``pred`` against one or more references (0.0 on empty input)."""
    ref_list = [refs] if isinstance(refs, str) else refs
    hypothesis = normalize_answer(pred).split()
    # Drop references that normalize to whitespace-only strings.
    tokenized_refs = []
    for ref in ref_list:
        normed = normalize_answer(ref)
        if normed.strip():
            tokenized_refs.append(normed.split())
    if not hypothesis or not tokenized_refs:
        return 0.0
    return _nltk_meteor(tokenized_refs, hypothesis)
def compute_vqa_accuracy(pred: str, direct_answers) -> float:
    """Soft VQA accuracy: min(#annotators agreeing with pred / 3, 1.0).

    Intended for datasets with multiple human-labeled answers (e.g. A-OKVQA).
    Falls back to exact match when a single string reference is supplied.
    """
    if isinstance(direct_answers, str):
        return compute_exact_match(pred, direct_answers)
    target = normalize_answer(pred)
    agreeing = sum(normalize_answer(ans) == target for ans in direct_answers)
    return min(agreeing / 3.0, 1.0)
def compute_semantic_score(preds: list[str], refs: list) -> float:
    """Mean cosine similarity between prediction and reference sentence embeddings.

    Returns 0.0 when the SentenceTransformer backend is unavailable or inputs
    are empty.
    """
    if not semantic_model or not preds or not refs:
        return 0.0
    norm_preds = [normalize_answer(p) for p in preds]
    # Collapse multi-reference lists to their majority answer before encoding.
    norm_refs = [
        majority_answer(r) if isinstance(r, list) else normalize_answer(r)
        for r in refs
    ]
    emb_preds = semantic_model.encode(norm_preds, convert_to_tensor=True, show_progress_bar=False)
    emb_refs = semantic_model.encode(norm_refs, convert_to_tensor=True, show_progress_bar=False)
    # Pairwise cosine matrix; the diagonal holds the 1-to-1 pred/ref similarities.
    pairwise = util.cos_sim(emb_preds, emb_refs)
    return float(torch.diag(pairwise).mean().item())
def batch_metrics(predictions: list[str], references: list) -> dict[str, float]:
    """Aggregate every supported metric over a batch of (prediction, reference) pairs.

    Per-sample metrics (accuracy, EM, F1, METEOR, BLEU-1..4, ROUGE-L) are
    averaged; semantic similarity and BERTScore are computed batch-wise.
    """
    metric_keys = ("accuracy", "em", "f1", "meteor",
                   "bleu1", "bleu2", "bleu3", "bleu4", "rouge_l")
    per_sample: dict[str, list[float]] = {key: [] for key in metric_keys}
    for pred, ref in zip(predictions, references):
        # Each helper accepts the full reference list and maximizes over it.
        per_sample["accuracy"].append(compute_vqa_accuracy(pred, ref))
        per_sample["em"].append(compute_exact_match(pred, ref))
        per_sample["f1"].append(compute_f1(pred, ref))
        per_sample["meteor"].append(compute_meteor(pred, ref))
        per_sample["rouge_l"].append(compute_rouge_l(pred, ref))
        for name, score in compute_bleu(pred, ref).items():
            per_sample[name].append(score)
    # Average the traditional per-sample metrics.
    summary = {key: float(np.mean(values)) for key, values in per_sample.items()}
    # Semantic score and BERTScore operate on the whole batch at once.
    summary["semantic"] = compute_semantic_score(predictions, references)
    summary["bert_score"] = compute_bertscore(predictions, references)
    return summary