"""Compute Expected Calibration Error (ECE) + render reliability diagram.

Uses per-row eval data (`logs/eval_sft_per_row.jsonl`) to compute the
calibration of the SFT-baseline Analyzer. The v2 LoRA per-row scores
are not yet logged (B.12 in WIN_PLAN — needs GPU re-inference); when
they ship, the same script renders both side by side.

Outputs:
- logs/calibration_sft.json (ECE + per-bin counts)
- plots/chakravyuh_plots/ece_reliability.png (reliability diagram)

Reliability-diagram convention (Guo et al. 2017): bin scores into 10
equal-width bins, plot bin-mean confidence (x) vs bin-mean accuracy (y).
A perfectly calibrated model lies on y = x. ECE is the weighted average
of |confidence − accuracy| across bins, weighted by bin size.
"""
from __future__ import annotations

import argparse
import json
from dataclasses import asdict, dataclass
from pathlib import Path

import matplotlib.pyplot as plt


@dataclass(frozen=True)
class BinStats:
    bin_lo: float
    bin_hi: float
    n: int
    mean_score: float
    accuracy: float
    abs_gap: float


@dataclass(frozen=True)
class CalibrationReport:
    name: str
    n: int
    ece: float
    mce: float
    bins: list[BinStats]
    notes: list[str]


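# Shape of one serialized report in the output JSON (values illustrative;
# produced via dataclasses.asdict in main()):
#   {"name": "sft", "n": 200, "ece": 0.0731, "mce": 0.2104,
#    "bins": [{"bin_lo": 0.0, "bin_hi": 0.1, "n": 12, ...}, ...],
#    "notes": []}

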
def _load_per_row(path: Path) -> list[dict]:
    rows: list[dict] = []
    for line in path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if line:
            rows.append(json.loads(line))
    return rows


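# Per-row schema assumed by _compute (inferred from the key accesses below):
# a float "score" in [0, 1] and a boolean label under "ground_truth", with
# "is_scam" accepted as a fallback key. Illustrative row:
#   {"score": 0.87, "ground_truth": true}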
def _compute(name: str, rows: list[dict], n_bins: int = 10) -> CalibrationReport:
    bin_edges = [i / n_bins for i in range(n_bins + 1)]
    bins: list[BinStats] = []
    notes: list[str] = []
    total_abs = 0.0
    n_total = 0
    mce = 0.0
    for i in range(n_bins):
        lo, hi = bin_edges[i], bin_edges[i + 1]
        # The last bin is right-closed so a score of exactly 1.0 is still counted.
        if i == n_bins - 1:
            in_bin = [r for r in rows if lo <= r["score"] <= hi]
        else:
            in_bin = [r for r in rows if lo <= r["score"] < hi]
        n = len(in_bin)
        if n == 0:
            bins.append(BinStats(lo, hi, 0, 0.0, 0.0, 0.0))
            continue
        mean_score = sum(r["score"] for r in in_bin) / n
        accuracy = sum(int(bool(r.get("ground_truth", r.get("is_scam")))) for r in in_bin) / n
        abs_gap = abs(mean_score - accuracy)
        bins.append(BinStats(lo, hi, n, mean_score, accuracy, abs_gap))
        total_abs += n * abs_gap
        n_total += n
        if abs_gap > mce:
            mce = abs_gap
    # ECE: bin-size-weighted mean |confidence - accuracy|; MCE: worst single bin.
    ece = total_abs / max(1, n_total)
    return CalibrationReport(
        name=name,
        n=n_total,
        ece=round(ece, 4),
        mce=round(mce, 4),
        bins=bins,
        notes=notes,
    )


def _plot(reports: list[CalibrationReport], out_path: Path) -> None:
    fig, ax = plt.subplots(figsize=(7, 6))
    # y = x reference line: where a perfectly calibrated model would sit.
    ax.plot([0, 1], [0, 1], "k:", linewidth=1, label="perfect calibration (y = x)")
    colors = {"sft": "#1565c0", "scripted": "#c62828", "v2_lora": "#558b2f"}
    for r in reports:
        color = colors.get(r.name, "#6a1b9a")
        xs = [b.mean_score for b in r.bins if b.n > 0]
        ys = [b.accuracy for b in r.bins if b.n > 0]
        # Marker area scales with bin count, floored at 20 so sparse bins stay visible.
        sizes = [max(20, b.n * 4) for b in r.bins if b.n > 0]
        label = f"{r.name} (ECE = {r.ece:.3f}, n = {r.n})"
        ax.scatter(xs, ys, s=sizes, alpha=0.7, color=color, edgecolor="black", linewidth=0.5,
                   label=label)
        ax.plot(xs, ys, color=color, linewidth=1.2, alpha=0.6)
    ax.set_xlabel("Confidence (bin-mean predicted score)", fontsize=11)
    ax.set_ylabel("Accuracy (bin-mean ground-truth label)", fontsize=11)
    ax.set_title(
        "Reliability diagram — Chakravyuh defenders\n"
        "Lower ECE = better calibrated · marker size ∝ bin n",
        fontsize=12, fontweight="bold",
    )
    ax.set_xlim(-0.02, 1.02)
    ax.set_ylim(-0.02, 1.02)
    ax.grid(True, alpha=0.3)
    ax.legend(loc="upper left", fontsize=9, framealpha=0.95)
    ax.set_aspect("equal", adjustable="box")
    fig.tight_layout()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=120, bbox_inches="tight")
    plt.close(fig)
    print(f"Wrote {out_path} ({out_path.stat().st_size:,} bytes)")


def main() -> int:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--sft-per-row",
        type=Path,
        default=Path("logs/eval_sft_per_row.jsonl"),
        help="Per-row eval JSONL for the SFT baseline.",
    )
    parser.add_argument(
        "--out-json",
        type=Path,
        default=Path("logs/calibration_sft.json"),
    )
    parser.add_argument(
        "--out-plot",
        type=Path,
        default=Path("plots/chakravyuh_plots/ece_reliability.png"),
    )
    args = parser.parse_args()

    reports: list[CalibrationReport] = []
    if args.sft_per_row.exists():
        rows = _load_per_row(args.sft_per_row)
        report = _compute("sft", rows)
        reports.append(report)
        print(f"sft: ECE = {report.ece:.4f} · MCE = {report.mce:.4f} · n = {report.n}")
    else:
        print(f"warning: {args.sft_per_row} not found — skipping SFT")

    if not reports:
        print("error: no per-row data found")
        return 2

    args.out_json.parent.mkdir(parents=True, exist_ok=True)
    args.out_json.write_text(
        json.dumps([asdict(r) for r in reports], indent=2, default=lambda o: asdict(o)),
        encoding="utf-8",
    )
    print(f"Wrote {args.out_json}")
    _plot(reports, args.out_plot)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
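
# Example invocation (script file name here is illustrative; the paths match
# the argparse defaults above):
#   python compute_ece_reliability.py \
#       --sft-per-row logs/eval_sft_per_row.jsonl \
#       --out-json logs/calibration_sft.json \
#       --out-plot plots/chakravyuh_plots/ece_reliability.png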