# chakravyuh/eval/calibration_analysis.py
"""Compute Expected Calibration Error (ECE) + render reliability diagram.
Uses per-row eval data (`logs/eval_sft_per_row.jsonl`) to compute the
calibration of the SFT-baseline Analyzer. The v2 LoRA per-row scores
are not yet logged (B.12 in WIN_PLAN — needs GPU re-inference); when
they ship the same script renders both side by side.
Outputs:
- logs/calibration_sft.json (ECE + per-bin counts)
- plots/chakravyuh_plots/ece_reliability.png (reliability diagram)
Reliability-diagram convention (Guo et al. 2017): bin scores into 10
equal-width bins, plot bin-mean confidence (x) vs bin-mean accuracy (y).
A perfectly calibrated model lies on y = x. ECE is the weighted average
of |confidence − accuracy| across bins, weighted by bin size.
"""
from __future__ import annotations

import argparse
import json
from dataclasses import asdict, dataclass
from pathlib import Path

import matplotlib.pyplot as plt


@dataclass(frozen=True)
class BinStats:
    """Per-bin statistics for one reliability-diagram bin."""

    bin_lo: float
    bin_hi: float
    n: int
    mean_score: float
    accuracy: float
    abs_gap: float


@dataclass(frozen=True)
class CalibrationReport:
    """Calibration summary for one defender (ECE, MCE, per-bin stats)."""

    name: str
    n: int
    ece: float
    mce: float
    bins: list[BinStats]
    notes: list[str]
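
# main() serialises reports with dataclasses.asdict, so calibration_sft.json is
# a list of objects of this shape (values illustrative, not from a real run):
#   {"name": "sft", "n": 200, "ece": 0.0612, "mce": 0.1400,
#    "bins": [{"bin_lo": 0.0, "bin_hi": 0.1, "n": 17, "mean_score": 0.04,
#              "accuracy": 0.06, "abs_gap": 0.02}, ...],
#    "notes": []}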


def _load_per_row(path: Path) -> list[dict]:
    """Read a JSONL file into a list of dicts, skipping blank lines."""
    rows: list[dict] = []
    for line in path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if line:
            rows.append(json.loads(line))
    return rows
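
# Each per-row record is expected to look roughly like this (illustrative; the
# field names are inferred from the reads in _compute, and real logs may carry
# additional keys that this script ignores):
#   {"score": 0.87, "ground_truth": true}
# `score` is the defender's predicted scam probability in [0, 1];
# `ground_truth` (or the legacy `is_scam`) is the binary label.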


def _compute(name: str, rows: list[dict], n_bins: int = 10) -> CalibrationReport:
    """Bin rows by predicted score and compute ECE / MCE (Guo et al. 2017)."""
    bin_edges = [i / n_bins for i in range(n_bins + 1)]
    bins: list[BinStats] = []
    notes: list[str] = []
    total_abs = 0.0
    n_total = 0
    mce = 0.0
    for i in range(n_bins):
        lo, hi = bin_edges[i], bin_edges[i + 1]
        # The last bin is closed on the right so a score of exactly 1.0 is counted.
        if i == n_bins - 1:
            in_bin = [r for r in rows if lo <= r["score"] <= hi]
        else:
            in_bin = [r for r in rows if lo <= r["score"] < hi]
        n = len(in_bin)
        if n == 0:
            bins.append(BinStats(lo, hi, 0, 0.0, 0.0, 0.0))
            continue
        mean_score = sum(r["score"] for r in in_bin) / n
        # Ground-truth label: prefer `ground_truth`, fall back to legacy `is_scam`.
        accuracy = sum(int(bool(r.get("ground_truth", r.get("is_scam")))) for r in in_bin) / n
        abs_gap = abs(mean_score - accuracy)
        bins.append(BinStats(lo, hi, n, mean_score, accuracy, abs_gap))
        total_abs += n * abs_gap
        n_total += n
        if abs_gap > mce:
            mce = abs_gap
    ece = total_abs / max(1, n_total)
    return CalibrationReport(
        name=name,
        n=n_total,
        ece=round(ece, 4),
        mce=round(mce, 4),
        bins=bins,
        notes=notes,
    )
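
# Worked example of the ECE / MCE arithmetic above (numbers are illustrative,
# not from a real run). With two non-empty bins:
#   bin A: n = 60, mean_score = 0.30, accuracy = 0.25 -> |gap| = 0.05
#   bin B: n = 40, mean_score = 0.80, accuracy = 0.90 -> |gap| = 0.10
# ECE = (60 * 0.05 + 40 * 0.10) / 100 = 0.07 and MCE = max gap = 0.10.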


def _plot(reports: list[CalibrationReport], out_path: Path) -> None:
    """Render one reliability diagram with a curve per report."""
    fig, ax = plt.subplots(figsize=(7, 6))
    ax.plot([0, 1], [0, 1], "k:", linewidth=1, label="perfect calibration (y = x)")
    colors = {"sft": "#1565c0", "scripted": "#c62828", "v2_lora": "#558b2f"}
    for r in reports:
        color = colors.get(r.name, "#6a1b9a")
        # Only non-empty bins are plotted; marker size scales with bin count.
        xs = [b.mean_score for b in r.bins if b.n > 0]
        ys = [b.accuracy for b in r.bins if b.n > 0]
        sizes = [max(20, b.n * 4) for b in r.bins if b.n > 0]
        label = f"{r.name} (ECE = {r.ece:.3f}, n = {r.n})"
        ax.scatter(xs, ys, s=sizes, alpha=0.7, color=color, edgecolor="black",
                   linewidth=0.5, label=label)
        ax.plot(xs, ys, color=color, linewidth=1.2, alpha=0.6)
    ax.set_xlabel("Confidence (bin-mean predicted score)", fontsize=11)
    ax.set_ylabel("Accuracy (bin-mean ground-truth label)", fontsize=11)
    ax.set_title(
        "Reliability diagram — Chakravyuh defenders\n"
        "Lower ECE = better calibrated · marker size ∝ bin n",
        fontsize=12, fontweight="bold",
    )
    ax.set_xlim(-0.02, 1.02)
    ax.set_ylim(-0.02, 1.02)
    ax.grid(True, alpha=0.3)
    ax.legend(loc="upper left", fontsize=9, framealpha=0.95)
    ax.set_aspect("equal", adjustable="box")
    fig.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=120, bbox_inches="tight")
    plt.close(fig)
    print(f"Wrote {out_path} ({out_path.stat().st_size:,} bytes)")


def main() -> int:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--sft-per-row",
        type=Path,
        default=Path("logs/eval_sft_per_row.jsonl"),
        help="Per-row eval JSONL for the SFT baseline.",
    )
    parser.add_argument(
        "--out-json",
        type=Path,
        default=Path("logs/calibration_sft.json"),
        help="Where to write the calibration report JSON.",
    )
    parser.add_argument(
        "--out-plot",
        type=Path,
        default=Path("plots/chakravyuh_plots/ece_reliability.png"),
        help="Where to write the reliability-diagram PNG.",
    )
    args = parser.parse_args()

    reports: list[CalibrationReport] = []
    if args.sft_per_row.exists():
        rows = _load_per_row(args.sft_per_row)
        report = _compute("sft", rows)
        reports.append(report)
        print(f"sft: ECE = {report.ece:.4f} · MCE = {report.mce:.4f} · n = {report.n}")
    else:
        print(f"warning: {args.sft_per_row} not found — skipping SFT")
    if not reports:
        print("error: no per-row data found")
        return 2

    args.out_json.parent.mkdir(parents=True, exist_ok=True)
    # asdict() recurses into nested dataclasses, so the BinStats rows serialise too.
    args.out_json.write_text(
        json.dumps([asdict(r) for r in reports], indent=2),
        encoding="utf-8",
    )
    print(f"Wrote {args.out_json}")
    _plot(reports, args.out_plot)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())