| |
| """ |
| 对称比较直方图(仅输出 PNG): |
| - 同时计算 chosen→reject 与 reject→chosen 的 BERTScore-F1 与 ROUGE-L F1; |
| - 在每个指标上做方向平均(对称分数); |
| - 将两种指标的直方图画在同一张 PNG 中保存; |
| - 直接运行脚本(无需命令行参数)。 |
| """ |
| import os |
| import math |
| import numpy as np |
| import pandas as pd |
| from tqdm import tqdm |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
| from bert_score import score as bertscore |
| from rouge_score import rouge_scorer |
|
|
| |
# --- Input / output locations -------------------------------------------
# Parquet file read by main().
DATA_PATH = "/home/data/prefiltered.parquet"
# PNG file the two histograms are saved to.
PNG_PATH = "symmetric_metrics_hist.png"

# --- Columns holding the two texts compared in each row ------------------
CHOSEN_COL = "chosen"
REJECT_COL = "reject"

# --- BERTScore settings --------------------------------------------------
# Language code and model name forwarded to BERTScore.
LANG = "en"
BERTSCORE_MODEL = "roberta-large"
# Number of text pairs handed to BERTScore per outer-loop chunk.
BATCH_SIZE = 256
# Upper bound on the batch size passed to BERTScore itself.
BERT_BATCH_CAP = 64
|
|
| |
def norm_text(x):
    """Return ``x`` as a whitespace-stripped string; missing values become ``""``.

    Treated as missing: ``None``, ``pd.NA`` and any float NaN, including
    NumPy floating scalars such as ``np.float32("nan")``. Without the last
    two cases ``str(x)`` would yield the literal text ``"<NA>"`` / ``"nan"``.

    Args:
        x: Any cell value.

    Returns:
        ``str(x).strip()``, or ``""`` for a missing value.
    """
    if x is None or x is pd.NA:
        return ""
    # np.float64 subclasses float, but narrower NumPy floats do not, hence
    # the explicit np.floating check.
    if isinstance(x, (float, np.floating)) and math.isnan(x):
        return ""
    return str(x).strip()
|
|
def compute_bert_symmetric_f1(chosen_list, reject_list, lang, model_type, batch_size):
    """Compute the direction-averaged BERTScore F1 for paired texts.

    F1_sym = 0.5 * (F1(chosen -> reject) + F1(reject -> chosen)), with
    baseline rescaling enabled.

    Args:
        chosen_list: List of "chosen" texts.
        reject_list: List of "reject" texts, same length as ``chosen_list``.
        lang: Language code forwarded to BERTScore.
        model_type: Model name forwarded to BERTScore.
        batch_size: Number of pairs scored per outer-loop chunk; must be >= 1.
            The batch size given to BERTScore itself is capped at
            ``BERT_BATCH_CAP``.

    Returns:
        ``numpy.float32`` array with one score per input pair.

    Raises:
        ValueError: If the two lists differ in length or ``batch_size`` < 1.
    """
    if len(chosen_list) != len(reject_list):
        raise ValueError(
            "chosen_list and reject_list must have the same length "
            f"(got {len(chosen_list)} and {len(reject_list)})"
        )
    if batch_size < 1:
        # A negative step would make range() empty and silently return zeros.
        raise ValueError(f"batch_size must be >= 1 (got {batch_size})")

    n = len(chosen_list)
    out_f1 = np.zeros(n, dtype=np.float32)
    if n == 0:
        # Nothing to score; avoid loading the model at all.
        return out_f1

    # Imported here so the model-caching scorer is available without touching
    # the file-level imports.
    from bert_score import BERTScorer

    inner_batch = min(BERT_BATCH_CAP, batch_size)
    # Build the scorer once and reuse it for every chunk and both directions,
    # so the model is not loaded again on each scoring call.
    scorer = BERTScorer(
        model_type=model_type,
        lang=lang,
        rescale_with_baseline=True,
        batch_size=inner_batch,
    )

    # NOTE(review): with IDF weighting off, swapping candidate and reference
    # should only swap precision and recall, so both directions are expected
    # to give (nearly) the same F1 - confirm before dropping one pass.
    for start in tqdm(range(0, n, batch_size), desc="BERTScore Symmetric"):
        end = min(start + batch_size, n)
        c_batch = chosen_list[start:end]
        r_batch = reject_list[start:end]

        _, _, f1_cr = scorer.score(
            c_batch, r_batch, verbose=False, batch_size=inner_batch
        )
        _, _, f1_rc = scorer.score(
            r_batch, c_batch, verbose=False, batch_size=inner_batch
        )

        out_f1[start:end] = 0.5 * (f1_cr.cpu().numpy() + f1_rc.cpu().numpy())

    return out_f1
|
|
def compute_rougeL_symmetric_f1(chosen_list, reject_list, use_stemmer=True):
    """Compute the direction-averaged ROUGE-L F-measure for paired texts.

    F1_sym = 0.5 * (F1(chosen -> reject) + F1(reject -> chosen)).

    Args:
        chosen_list: Sequence of "chosen" texts.
        reject_list: Sequence of "reject" texts, same length as ``chosen_list``.
        use_stemmer: Forwarded to ``rouge_scorer.RougeScorer``.

    Returns:
        ``numpy.float32`` array with one score per input pair.

    Raises:
        ValueError: If the two sequences differ in length.
    """
    if len(chosen_list) != len(reject_list):
        raise ValueError(
            "chosen_list and reject_list must have the same length "
            f"(got {len(chosen_list)} and {len(reject_list)})"
        )

    n = len(chosen_list)
    out = np.zeros(n, dtype=np.float32)
    if n == 0:
        # Nothing to score; skip building the scorer.
        return out

    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=use_stemmer)
    pairs = tqdm(
        zip(chosen_list, reject_list),
        total=n,
        desc="ROUGE-L Symmetric",
    )
    # NOTE(review): swapping the two arguments presumably only swaps
    # precision and recall, leaving the F-measure unchanged - confirm before
    # dropping the second call.
    for i, (c, r) in enumerate(pairs):
        forward = scorer.score(c, r)["rougeL"].fmeasure
        backward = scorer.score(r, c)["rougeL"].fmeasure
        out[i] = 0.5 * (forward + backward)

    # ``out`` is already float32, so no conversion copy is needed.
    return out
|
|
| |
def _hist_edges(values, num_edges=30):
    """Return ``num_edges`` evenly spaced bin edges spanning ``values``."""
    lo = values.min()
    hi = values.max()
    if lo == hi:
        # All scores identical: widen the range so the bins have non-zero
        # width and the single bar is actually visible.
        lo, hi = lo - 0.5, hi + 0.5
    return np.linspace(lo, hi, num_edges)


def main():
    """Score every chosen/reject pair and save both histograms to one PNG.

    Reads ``DATA_PATH``, normalises the two text columns with ``norm_text``,
    drops rows where either text is empty, computes the symmetric BERTScore
    F1 and symmetric ROUGE-L F1, and writes side-by-side histograms to
    ``PNG_PATH``.

    Raises:
        ValueError: If a required column is missing, or if no rows remain
            after filtering.
    """
    # Load the data and check the required columns.
    df = pd.read_parquet(DATA_PATH)
    if CHOSEN_COL not in df.columns or REJECT_COL not in df.columns:
        raise ValueError(f"输入文件缺少列:{CHOSEN_COL} 或 {REJECT_COL}")

    # Normalise the text and keep only rows where both sides are non-empty.
    df[CHOSEN_COL] = df[CHOSEN_COL].map(norm_text)
    df[REJECT_COL] = df[REJECT_COL].map(norm_text)
    mask = (df[CHOSEN_COL].str.len() > 0) & (df[REJECT_COL].str.len() > 0)
    df = df[mask].reset_index(drop=True)

    chosen_list = df[CHOSEN_COL].tolist()
    reject_list = df[REJECT_COL].tolist()
    if not chosen_list:
        raise ValueError("过滤后没有有效样本。请检查输入列内容。")

    # Symmetric BERTScore F1.
    berts_f1_sym = compute_bert_symmetric_f1(
        chosen_list, reject_list,
        lang=LANG,
        model_type=BERTSCORE_MODEL,
        batch_size=BATCH_SIZE,
    )

    # Symmetric ROUGE-L F1.
    rougeL_f1_sym = compute_rougeL_symmetric_f1(
        chosen_list, reject_list, use_stemmer=True
    )

    # Draw both histograms side by side in one figure.
    fig, (ax_bert, ax_rouge) = plt.subplots(1, 2, figsize=(12, 5))

    ax_bert.hist(
        berts_f1_sym, bins=_hist_edges(berts_f1_sym),
        color='blue', alpha=0.7, edgecolor='black',
    )
    ax_bert.set_title("Distribution of F1 BERT Scores")
    ax_bert.set_xlabel("F1 BERT Score")
    ax_bert.set_ylabel("Frequency")

    ax_rouge.hist(
        rougeL_f1_sym, bins=_hist_edges(rougeL_f1_sym),
        color='green', alpha=0.7, edgecolor='black',
    )
    ax_rouge.set_title("Distribution of F1 ROUGE-L Scores")
    ax_rouge.set_xlabel("F1 ROUGE-L Score")
    ax_rouge.set_ylabel("Frequency")

    fig.tight_layout()
    fig.savefig(PNG_PATH, dpi=300)
    # Release the figure once it is on disk.
    plt.close(fig)
    print(f"[Info] 直方图已保存:{os.path.abspath(PNG_PATH)}")


if __name__ == "__main__":
    main()
|
|