from __future__ import annotations

import argparse
import csv
import time
from pathlib import Path

try:
    from bert_score import score as bertscore
    from rouge_score import rouge_scorer
    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
except Exception as exc:  # pragma: no cover
    raise SystemExit(
        "Evaluation requires bert-score, rouge-score, torch and transformers. "
        "Install dependencies first."
    ) from exc

from data_utils import load_jsonl


def parse_args():
    parser = argparse.ArgumentParser(description="Evaluate summarization models")
    parser.add_argument("--test-path", required=True)
    parser.add_argument("--model-name", default="fnlp/bart-base-chinese")
    parser.add_argument("--max-source-length", type=int, default=512)
    parser.add_argument("--target-length", type=int, default=120)
    parser.add_argument("--tolerance", type=float, default=0.2)
    parser.add_argument("--output-csv", default="metrics_report.csv")
    parser.add_argument("--qafacteval-model-folder", default=None)
    return parser.parse_args()


def length_hit(text: str, target_length: int, tolerance: float) -> bool:
    """Return True if the summary length falls within the tolerance band around the target."""
    low = int(target_length * (1 - tolerance))
    high = int(target_length * (1 + tolerance))
    return low <= len(text) <= high


def try_qafacteval(model_folder: str | None, sources, preds):
    """Score factual consistency with QAFactEval when its models are available.

    Returns one LERC-QUIP score per prediction, or a list of Nones when the
    metric cannot run (no model folder given, or qafacteval not installed).
    """
    if not model_folder:
        return [None] * len(preds)
    try:
        from qafacteval import QAFactEval
    except Exception:
        return [None] * len(preds)
    metric = QAFactEval(
        lerc_quip_path=f"{model_folder}/quip-512-mocha",
        generation_model_path=f"{model_folder}/generation/model.tar.gz",
        answering_model_dir=f"{model_folder}/answering",
        lerc_model_path=f"{model_folder}/lerc/model.tar.gz",
        lerc_pretrained_model_path=f"{model_folder}/lerc/pretraining.tar.gz",
        cuda_device=0 if torch.cuda.is_available() else -1,
        use_lerc_quip=True,
        verbose=False,
        generation_batch_size=8,
        answering_batch_size=8,
        lerc_batch_size=4,
    )
    results = metric.score_batch(list(sources), [[p] for p in preds], return_qa_pairs=True)
    return [row[0]["qa-eval"].get("lerc_quip") for row in results]


def main():
    args = parse_args()
    examples = load_jsonl(args.test_path)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False)

    sources = []
    refs = []
    preds = []
    times_ms = []
    length_flags = []
    for ex in examples:
        inputs = tokenizer(
            ex.article,
            return_tensors="pt",
            truncation=True,
            max_length=args.max_source_length,
        )
        # Chinese BART uses a BERT-style tokenizer; generate() does not accept token_type_ids.
        inputs.pop("token_type_ids", None)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        start = time.perf_counter()
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=max(48, min(192, int(args.target_length * 1.1))),
                num_beams=4,
                no_repeat_ngram_size=3,
                length_penalty=1.0,
                early_stopping=True,
            )
        elapsed_ms = (time.perf_counter() - start) * 1000
        pred = tokenizer.decode(out[0], skip_special_tokens=True).strip()
        sources.append(ex.article)
        refs.append(ex.summary)
        preds.append(pred)
        times_ms.append(elapsed_ms)
        length_flags.append(length_hit(pred, args.target_length, args.tolerance))

    # Per-example ROUGE-L F1, corpus BERTScore (only F1 is reported), and optional QAFactEval.
    rouge_ls = [scorer.score(ref, pred)["rougeL"].fmeasure for ref, pred in zip(refs, preds)]
    P, R, F1 = bertscore(preds, refs, lang="zh", verbose=False)
    qafacteval_scores = try_qafacteval(args.qafacteval_model_folder, sources, preds)

    rouge_l = sum(rouge_ls) / max(1, len(rouge_ls))
    bert_f1 = float(F1.mean().item()) if hasattr(F1.mean(), "item") else float(F1.mean())
    length_rate = sum(1 for v in length_flags if v) / max(1, len(length_flags))
    avg_latency = sum(times_ms) / max(1, len(times_ms))
    qafacteval_valid = [s for s in qafacteval_scores if s is not None]
    qafacteval_avg = sum(qafacteval_valid) / len(qafacteval_valid) if qafacteval_valid else None

    print(f"ROUGE-L: {rouge_l:.4f}")
    print(f"BERTScore: {bert_f1:.4f}")
    print(f"Length Hit Rate: {length_rate:.4f}")
    print(f"Avg Latency(ms): {avg_latency:.2f}")
    if qafacteval_avg is not None:
        print(f"QAFactEval: {qafacteval_avg:.4f}")
    else:
        print("QAFactEval: N/A")

    out_path = Path(args.output_csv)
    with out_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["model", "rouge_l", "bertscore", "qafacteval", "length_hit_rate", "avg_latency_ms"])
        writer.writerow(
            [
                args.model_name,
                f"{rouge_l:.4f}",
                f"{bert_f1:.4f}",
                f"{qafacteval_avg:.4f}" if qafacteval_avg is not None else "",
                f"{length_rate:.4f}",
                f"{avg_latency:.2f}",
            ]
        )
    print(f"saved metrics to {out_path}")


if __name__ == "__main__":
    main()
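
# Minimal usage sketch. Assumptions not stated in the script itself: the file
# is saved as evaluate.py, and data_utils.load_jsonl yields records exposing
# `article` and `summary` attributes (as the loop above implies). The test-set
# path below is illustrative.
#
#   python evaluate.py \
#       --test-path data/test.jsonl \
#       --model-name fnlp/bart-base-chinese \
#       --target-length 120 --tolerance 0.2 \
#       --output-csv metrics_report.csv
#
# Passing --qafacteval-model-folder with a downloaded QAFactEval model
# directory additionally reports the LERC-QUIP factuality score; without it,
# the qafacteval column in the CSV is left blank.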