"""Evaluation metrics for VQA: Accuracy, EM, F1, BLEU-1~4, METEOR, and Semantic Score."""

from __future__ import annotations
from collections import Counter
import numpy as np
import torch
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from nltk.translate.meteor_score import meteor_score as _nltk_meteor

import nltk
try:
    nltk.data.find('corpora/wordnet')
except LookupError:
    print("[INFO] Đang tự động tải bộ từ điển NLTK WordNet cho METEOR score...")
    nltk.download('wordnet', quiet=True)
    nltk.download('omw-1.4', quiet=True)

# 1. Semantic Score (SentenceTransformer)
try:
    from sentence_transformers import SentenceTransformer, util
    semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
except Exception as e:
    semantic_model = None
    print(f"Warning: Could not load SentenceTransformer: {e}")

# 2. BERTScore
try:
    from bert_score import BERTScorer
    # Force a multilingual model to avoid a Tokenizer attribute error on Python 3.12
    device = "cuda" if torch.cuda.is_available() else "cpu"
    bert_scorer = BERTScorer(model_type="bert-base-multilingual-cased", device=device)
except ImportError:
    print("[WARNING] Thư viện bert_score chưa được cài đặt.")
    bert_scorer = None
except Exception as e:
    bert_scorer = None
    print(f"Warning: Could not load BERTScorer: {e}")

# 3. ROUGE-L
try:
    from rouge_score import rouge_scorer as rs
    rouge_l_scorer = rs.RougeScorer(['rougeL'], use_stemmer=True)
except Exception as e:
    rouge_l_scorer = None
    print(f"Warning: Could not load rouge-score: {e}")

# [FIX] Import from the local text_utils instead of non-existent src.data.preprocessing
from .text_utils import normalize_answer, majority_answer

def compute_rouge_l(pred: str, refs) -> float:
    """Tính ROUGE-L (Lấy MAX over multiple refs)."""
    if not rouge_l_scorer: return 0.0
    if isinstance(refs, str): refs = [refs]
    best_rouge = 0.0
    for r in refs:
        score = rouge_l_scorer.score(normalize_answer(r), normalize_answer(pred))['rougeL'].fmeasure
        best_rouge = max(best_rouge, score)
    return best_rouge

def compute_bertscore(preds: list[str], refs: list) -> float:
    """Tính BERTScore cho cả batch."""
    if not bert_scorer or not preds or not refs:
        return 0.0
    
    clean_preds = [normalize_answer(p) if normalize_answer(p).strip() else "." for p in preds]
    clean_refs = [majority_answer(r) if isinstance(r, list) else normalize_answer(r) for r in refs]
    clean_refs = [r if r.strip() else "." for r in clean_refs]
    
    try:
        # idf weighting is left at its default (off), which keeps scoring fast
        P, R, F1 = bert_scorer.score(clean_preds, clean_refs)
        return float(F1.mean().item())
    except Exception as e:
        print(f"[WARNING] BERTScore error: {e}")
        return 0.0

def compute_exact_match(pred: str, refs) -> float:
    """So khớp chính xác lấy MAX (soft match over multiple refs)."""
    if isinstance(refs, str): refs = [refs]
    return float(any(normalize_answer(pred) == normalize_answer(r) for r in refs))

def compute_f1(pred: str, refs) -> float:
    """Tính F1-score ở mức độ token. Lấy MAX over multiple refs."""
    if isinstance(refs, str): refs = [refs]
    best_f1 = 0.0
    p_toks = normalize_answer(pred).split()
    for r in refs:
        r_toks = normalize_answer(r).split()
        if not p_toks or not r_toks:
            f1 = float(p_toks == r_toks)
        else:
            common = Counter(p_toks) & Counter(r_toks)
            num_same = sum(common.values())
            if num_same == 0:
                f1 = 0.0
            else:
                precision = num_same / len(p_toks)
                recall = num_same / len(r_toks)
                f1 = 2 * precision * recall / (precision + recall)
        best_f1 = max(best_f1, f1)
    return best_f1
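
# Worked example for compute_f1 (illustrative tokens, assuming normalize_answer
# passes them through unchanged): prediction ["red", "apple"] vs. reference
# ["green", "apple"] share one token, giving precision = recall = 1/2 and
# F1 = 2 * (1/2 * 1/2) / (1/2 + 1/2) = 0.5.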

def compute_bleu(pred: str, refs) -> dict[str, float]:
    """Tính BLEU from 1 đến 4 sử dụng corpus-level refs."""
    if isinstance(refs, str): refs = [refs]
    smoothie = SmoothingFunction().method4
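    # method4 smoothing helps keep BLEU-2 through BLEU-4 non-zero for short
    # answers that have no higher-order n-gram overlap with any reference.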
    p_toks = normalize_answer(pred).split()
    r_toks_list = [normalize_answer(r).split() for r in refs if normalize_answer(r).strip()]
    
    if not p_toks or not r_toks_list:
        return {"bleu1": 0.0, "bleu2": 0.0, "bleu3": 0.0, "bleu4": 0.0}
    
    weights = [
        (1, 0, 0, 0),          # BLEU-1
        (0.5, 0.5, 0, 0),      # BLEU-2
        (1/3, 1/3, 1/3, 0),    # BLEU-3 (exact thirds so the weights sum to 1)
        (0.25, 0.25, 0.25, 0.25) # BLEU-4
    ]
    
    return {
        f"bleu{i+1}": sentence_bleu(r_toks_list, p_toks, weights=w, smoothing_function=smoothie)
        for i, w in enumerate(weights)
    }

def compute_meteor(pred: str, refs) -> float:
    """Tính METEOR score (hỗ trợ N refs)."""
    if isinstance(refs, str): refs = [refs]
    p_toks = normalize_answer(pred).split()
    r_toks_list = [normalize_answer(r).split() for r in refs if normalize_answer(r).strip()]
    if not p_toks or not r_toks_list:
        return 0.0
    return _nltk_meteor(r_toks_list, p_toks)

def compute_vqa_accuracy(pred: str, direct_answers) -> float:
    """
    Tính VQA Accuracy mềm: min(#người_cùng_đáp_án / 3, 1.0).
    Using cho các tập dữ liệu có nhiều người gắn nhãn (như A-OKVQA).
    """
    if isinstance(direct_answers, str):
        return compute_exact_match(pred, direct_answers)
    
    normed_pred = normalize_answer(pred)
    matches = sum(1 for a in direct_answers if normalize_answer(a) == normed_pred)
    return min(matches / 3.0, 1.0)
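
# Worked example for compute_vqa_accuracy (illustrative): with direct_answers
# = ["cat", "cat", "cat", "dog"], predicting "cat" matches 3 annotators and
# scores min(3/3, 1.0) = 1.0, while predicting "dog" scores min(1/3, 1.0) ≈ 0.33.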

def compute_semantic_score(preds: list[str], refs: list) -> float:
    """Tính điểm tương đồng ngữ nghĩa bằng Cosine Similarity."""
    if not semantic_model or not preds or not refs:
        return 0.0
    
    clean_preds = [normalize_answer(p) for p in preds]
    # Take the most representative string if it's a list for semantic comparison
    clean_refs = [majority_answer(r) if isinstance(r, list) else normalize_answer(r) for r in refs]
    
    # Encode to Vector (Embeddings)
    pred_embs = semantic_model.encode(clean_preds, convert_to_tensor=True, show_progress_bar=False)
    ref_embs = semantic_model.encode(clean_refs, convert_to_tensor=True, show_progress_bar=False)
    
    # Compute the cosine similarity matrix and take its diagonal (1-to-1 pairs)
    cosine_scores = util.cos_sim(pred_embs, ref_embs)
    scores = torch.diag(cosine_scores)
    
    return float(scores.mean().item())

def batch_metrics(predictions: list[str], references: list) -> dict[str, float]:
    """Tổng hợp toàn bộ chỉ số đo lường trên batch."""
    results = {
        "accuracy": [], "em": [], "f1": [], "meteor": [],
        "bleu1": [], "bleu2": [], "bleu3": [], "bleu4": [],
        "rouge_l": []
    }
    
    for pred, ref in zip(predictions, references):
        # Pass full refs list to compute_f1, compute_bleu to maximize score
        results["accuracy"].append(compute_vqa_accuracy(pred, ref))
        results["em"].append(compute_exact_match(pred, ref))
        results["f1"].append(compute_f1(pred, ref))
        results["meteor"].append(compute_meteor(pred, ref))
        results["rouge_l"].append(compute_rouge_l(pred, ref))
        
        bleus = compute_bleu(pred, ref)
        for k, v in bleus.items():
            results[k].append(v)
            
    # Average traditional metrics
    final_metrics = {k: float(np.mean(v)) for k, v in results.items()}
    
    # Compute Semantic Score and BERTScore for entire batch
    final_metrics["semantic"] = compute_semantic_score(predictions, references)
    final_metrics["bert_score"] = compute_bertscore(predictions, references)
    
    return final_metrics
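
# Minimal smoke-test sketch (an illustration, not part of the metric API):
# shows the expected input shapes. Predictions are plain strings; each
# reference may be a single string or a list of annotator answers. Assumes
# the optional SentenceTransformer and bert_score models loaded above.
if __name__ == "__main__":
    sample_preds = ["a red apple", "two dogs"]
    sample_refs = [["red apple", "apple", "an apple"], "two dogs"]
    for name, value in batch_metrics(sample_preds, sample_refs).items():
        print(f"{name}: {value:.4f}")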