| """Compute Expected Calibration Error (ECE) + render reliability diagram. | |
| Uses per-row eval data (`logs/eval_sft_per_row.jsonl`) to compute the | |
| calibration of the SFT-baseline Analyzer. The v2 LoRA per-row scores | |
| are not yet logged (B.12 in WIN_PLAN — needs GPU re-inference); when | |
| they ship the same script renders both side by side. | |
| Outputs: | |
| - logs/calibration_sft.json (ECE + per-bin counts) | |
| - plots/chakravyuh_plots/ece_reliability.png (reliability diagram) | |
| Reliability-diagram convention (Guo et al. 2017): bin scores into 10 | |
| equal-width bins, plot bin-mean confidence (x) vs bin-mean accuracy (y). | |
| A perfectly calibrated model lies on y = x. ECE is the weighted average | |
| of |confidence − accuracy| across bins, weighted by bin size. | |
| """ | |
from __future__ import annotations

import argparse
import json
from dataclasses import asdict, dataclass
from pathlib import Path

import matplotlib.pyplot as plt


@dataclass
class BinStats:
    bin_lo: float
    bin_hi: float
    n: int
    mean_score: float
    accuracy: float
    abs_gap: float


@dataclass
class CalibrationReport:
    name: str
    n: int
    ece: float
    mce: float
    bins: list[BinStats]
    notes: list[str]


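# Shape of one serialized report, as written by main() via
# dataclasses.asdict (field values here are made up for illustration):
#   {"name": "sft", "n": 200, "ece": 0.0712, "mce": 0.151,
#    "bins": [{"bin_lo": 0.0, "bin_hi": 0.1, "n": 12, ...}, ...],
#    "notes": []}
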
def _load_per_row(path: Path) -> list[dict]:
    rows: list[dict] = []
    for line in path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if line:
            rows.append(json.loads(line))
    return rows


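# Assumed per-row schema, inferred from the keys read in _compute(); the
# real eval logger may emit extra fields, which are ignored here:
#   {"score": 0.87, "ground_truth": 1}
# or, in older logs, {"score": 0.87, "is_scam": true}.
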
def _compute(name: str, rows: list[dict], n_bins: int = 10) -> CalibrationReport:
    bin_edges = [i / n_bins for i in range(n_bins + 1)]
    bins: list[BinStats] = []
    notes: list[str] = []
    total_abs = 0.0
    n_total = 0
    mce = 0.0
    for i in range(n_bins):
        lo, hi = bin_edges[i], bin_edges[i + 1]
        # The last bin is closed on the right so a score of exactly 1.0
        # is not dropped.
        if i == n_bins - 1:
            in_bin = [r for r in rows if lo <= r["score"] <= hi]
        else:
            in_bin = [r for r in rows if lo <= r["score"] < hi]
        n = len(in_bin)
        if n == 0:
            bins.append(BinStats(lo, hi, 0, 0.0, 0.0, 0.0))
            continue
        mean_score = sum(r["score"] for r in in_bin) / n
        # Older rows label ground truth as "is_scam"; fall back to it.
        accuracy = sum(int(bool(r.get("ground_truth", r.get("is_scam")))) for r in in_bin) / n
        abs_gap = abs(mean_score - accuracy)
        bins.append(BinStats(lo, hi, n, mean_score, accuracy, abs_gap))
        total_abs += n * abs_gap
        n_total += n
        if abs_gap > mce:
            mce = abs_gap
    ece = total_abs / max(1, n_total)
    return CalibrationReport(
        name=name,
        n=n_total,
        ece=round(ece, 4),
        mce=round(mce, 4),
        bins=bins,
        notes=notes,
    )


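# Quick REPL sanity check with two hypothetical rows (one confident hit in
# the [0.9, 1.0] bin, one confident correct rejection in [0.1, 0.2)):
#   >>> _compute("demo", [{"score": 0.9, "ground_truth": 1},
#   ...                   {"score": 0.1, "ground_truth": 0}]).ece
#   0.1
# Each occupied bin has gap |mean_score - accuracy| = 0.1, so the
# size-weighted average is 0.1.
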
def _plot(reports: list[CalibrationReport], out_path: Path) -> None:
    fig, ax = plt.subplots(figsize=(7, 6))
    ax.plot([0, 1], [0, 1], "k:", linewidth=1, label="perfect calibration (y = x)")
    colors = {"sft": "#1565c0", "scripted": "#c62828", "v2_lora": "#558b2f"}
    for r in reports:
        color = colors.get(r.name, "#6a1b9a")
        xs = [b.mean_score for b in r.bins if b.n > 0]
        ys = [b.accuracy for b in r.bins if b.n > 0]
        sizes = [max(20, b.n * 4) for b in r.bins if b.n > 0]
        label = f"{r.name} (ECE = {r.ece:.3f}, n = {r.n})"
        ax.scatter(xs, ys, s=sizes, alpha=0.7, color=color, edgecolor="black",
                   linewidth=0.5, label=label)
        ax.plot(xs, ys, color=color, linewidth=1.2, alpha=0.6)
    ax.set_xlabel("Confidence (bin-mean predicted score)", fontsize=11)
    ax.set_ylabel("Accuracy (bin-mean ground-truth label)", fontsize=11)
    ax.set_title(
        "Reliability diagram — Chakravyuh defenders\n"
        "Lower ECE = better calibrated · marker size ∝ bin n",
        fontsize=12, fontweight="bold",
    )
    ax.set_xlim(-0.02, 1.02)
    ax.set_ylim(-0.02, 1.02)
    ax.grid(True, alpha=0.3)
    ax.legend(loc="upper left", fontsize=9, framealpha=0.95)
    ax.set_aspect("equal", adjustable="box")
    fig.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=120, bbox_inches="tight")
    plt.close(fig)
    print(f"Wrote {out_path} ({out_path.stat().st_size:,} bytes)")


def main() -> int:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--sft-per-row",
        type=Path,
        default=Path("logs/eval_sft_per_row.jsonl"),
        help="Per-row eval JSONL for the SFT baseline.",
    )
    parser.add_argument(
        "--out-json",
        type=Path,
        default=Path("logs/calibration_sft.json"),
        help="Where to write the ECE/MCE report.",
    )
    parser.add_argument(
        "--out-plot",
        type=Path,
        default=Path("plots/chakravyuh_plots/ece_reliability.png"),
        help="Where to write the reliability diagram.",
    )
    args = parser.parse_args()

    reports: list[CalibrationReport] = []
    if args.sft_per_row.exists():
        rows = _load_per_row(args.sft_per_row)
        report = _compute("sft", rows)
        reports.append(report)
        print(f"sft: ECE = {report.ece:.4f} · MCE = {report.mce:.4f} · n = {report.n}")
    else:
        print(f"warning: {args.sft_per_row} not found — skipping SFT")
    if not reports:
        print("error: no per-row data found")
        return 2

    args.out_json.parent.mkdir(parents=True, exist_ok=True)
    # asdict() recurses into the nested BinStats dataclasses, so the result
    # is already JSON-serializable and no custom default encoder is needed.
    args.out_json.write_text(
        json.dumps([asdict(r) for r in reports], indent=2),
        encoding="utf-8",
    )
    print(f"Wrote {args.out_json}")
    _plot(reports, args.out_plot)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
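
# Example invocation from the repo root (the script filename below is a
# placeholder; use whatever this file is saved as):
#   python plot_ece_reliability.py \
#       --sft-per-row logs/eval_sft_per_row.jsonl \
#       --out-plot plots/chakravyuh_plots/ece_reliability.png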