| |
import argparse
import json
import os
import re
from typing import Any, Dict, List
|
|
def normalize(s: str) -> str:
    """Return *s* in a canonical form for lenient text comparison.

    Triple-backtick code-fence markers are replaced by spaces, surrounding
    whitespace is stripped, the text is lower-cased, and every run of
    whitespace is collapsed to a single space.

    ``None`` is treated as the empty string and any other non-string value is
    converted with ``str()``, so a JSON ``null`` or numeric field does not
    raise ``AttributeError``.
    """
    if s is None:
        return ""
    if not isinstance(s, str):
        s = str(s)

    # Fence markers become spaces (not removed) so adjacent words stay apart.
    s = s.replace("```", " ").strip().lower()
    return re.sub(r"\s+", " ", s)
|
|
def main():
    """Score model predictions against ground truth and write a JSON report.

    Reads the JSON list at ``--pred_path``. Each item is queried for the keys
    ``id``, ``ground_truth`` and ``model_output``. Items whose ground truth is
    absent, ``None`` or the empty string are recorded with ``match: None`` and
    a ``reason`` field, and are excluded from the accuracy denominator.

    For every other item, the prediction counts as a match when its
    normalized text is non-empty and is a substring of the normalized ground
    truth. A report with a ``summary`` (``total_with_gt``, ``matched``,
    ``accuracy``) and per-item ``details`` is written to ``--out_path``, and a
    one-line summary is printed to stdout.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--pred_path", type=str, required=True, help="eval 的输出 JSON")
    ap.add_argument("--out_path", type=str, default="./valid_clean/valid.json", help="评分明细输出 JSON")
    args = ap.parse_args()

    with open(args.pred_path, "r", encoding="utf-8") as f:
        preds: List[Dict[str, Any]] = json.load(f)

    rows = []
    hit, total = 0, 0
    for item in preds:
        gt = item.get("ground_truth", "")
        pred = item.get("model_output", "")

        if gt is None or gt == "":
            rows.append({
                "id": item.get("id"),
                "match": None,
                "reason": "missing_ground_truth",
                "ground_truth": gt,
                "model_output": pred,
            })
            continue

        total += 1

        # A JSON null or non-string value would make normalize() fail on
        # str-only methods, so coerce both sides to text first.
        ngt = normalize(gt if isinstance(gt, str) else str(gt))
        npred = normalize("" if pred is None else str(pred))

        # An empty string is a substring of every string, so an empty
        # prediction must be rejected explicitly or it would count as a hit.
        # NOTE(review): containment is tested as prediction-in-ground-truth,
        # as in the original code; confirm this is the intended direction.
        match = bool(npred) and npred in ngt
        if match:
            hit += 1

        rows.append({
            "id": item.get("id"),
            "match": match,
            "ground_truth": gt,
            "model_output": pred,
        })

    summary = {
        "total_with_gt": total,
        "matched": hit,
        "accuracy": (hit / total) if total > 0 else None,
    }

    # The default output path lives in a sub-directory that may not exist yet.
    out_dir = os.path.dirname(args.out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    out = {"summary": summary, "details": rows}
    with open(args.out_path, "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False, indent=2)

    # accuracy is None when nothing was scored, so only format it when total > 0.
    if total:
        print(f"[SUMMARY] matched {hit}/{total} = {summary['accuracy']:.4f}")
    else:
        print("[SUMMARY] no GT")
|
|
# Run the scorer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|