"""Generate publication-quality co-evolution plots for B.2. Reads result JSONs in `logs/` and writes PNGs to `plots/chakravyuh_plots/`. Gracefully degrades if some JSONs are missing. Usage: python eval/plot_coevolution.py """ from __future__ import annotations import json import math from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np REPO = Path(__file__).resolve().parent.parent LOGS = REPO / "logs" PLOTS = REPO / "plots" / "chakravyuh_plots" PLOTS.mkdir(parents=True, exist_ok=True) PHASE1_HEADTOHEAD = LOGS / "b2_phase1_scammer_vs_v2_lora.json" PHASE2_EVAL = LOGS / "b2_phase2_coevolution_eval.json" PHASE1_BESTOF8 = LOGS / "b2_phase1_scammer_eval_n64_bestof8.json" PHASE1_SINGLESHOT = LOGS / "b2_phase1_scammer_eval_n64.json" C_SCRIPTED = "#bdbdbd" C_V2 = "#fb8c00" C_COEVO = "#43a047" C_TRAIN = "#1e88e5" C_HELDOUT = "#8e24aa" def wilson_ci(k: int, n: int, z: float = 1.96) -> tuple[float, float]: if n == 0: return 0.0, 0.0 p = k / n denom = 1 + z * z / n center = (p + z * z / (2 * n)) / denom margin = z * math.sqrt(p * (1 - p) / n + z * z / (4 * n * n)) / denom return max(0.0, center - margin), min(1.0, center + margin) def _ci_err(rate: float, lo: float, hi: float) -> tuple[float, float]: """Convert Wilson CI to matplotlib's lower/upper-error format.""" return rate - lo, hi - rate def plot_coevolution_headline() -> Path | None: """Bar chart: Scripted / v2 / v2-coevolved bypass rates by split.""" if not PHASE1_HEADTOHEAD.exists(): print(f"[skip] no {PHASE1_HEADTOHEAD.name}") return None h2h = json.load(open(PHASE1_HEADTOHEAD)) has_phase2 = PHASE2_EVAL.exists() p2 = json.load(open(PHASE2_EVAL)) if has_phase2 else None splits = ["overall", "train_seeds", "held_out_seeds"] labels = ["Overall (n=64)", "Train (n=32)", "Held-out (n=32)"] n_splits = len(splits) scripted = [h2h["aggregate"][s]["scripted_bypass_rate"] for s in splits] scripted_ci = [h2h["aggregate"][s]["scripted_wilson_95_ci"] for s in splits] v2 = [h2h["aggregate"][s]["v2_bypass_rate"] for s in splits] v2_ci = [h2h["aggregate"][s]["v2_wilson_95_ci"] for s in splits] if has_phase2: coevo = [p2["aggregate"][s]["coevolved_bypass_rate"] for s in splits] coevo_ci = [p2["aggregate"][s]["coevolved_wilson_95_ci"] for s in splits] n_bars = 3 else: n_bars = 2 x = np.arange(n_splits) w = 0.27 if n_bars == 3 else 0.4 fig, ax = plt.subplots(figsize=(10, 5.5)) # Scripted bars s_err = np.array([_ci_err(r, lo, hi) for r, (lo, hi) in zip(scripted, scripted_ci)]).T ax.bar(x - w, scripted, w, yerr=s_err, capsize=4, color=C_SCRIPTED, label="ScriptedAnalyzer (rule-based)", edgecolor="#666", linewidth=0.5) # v2 bars v_err = np.array([_ci_err(r, lo, hi) for r, (lo, hi) in zip(v2, v2_ci)]).T ax.bar(x, v2, w, yerr=v_err, capsize=4, color=C_V2, label="v2 Analyzer LoRA (round 1)", edgecolor="#444", linewidth=0.5) if has_phase2: c_err = np.array([_ci_err(r, lo, hi) for r, (lo, hi) in zip(coevo, coevo_ci)]).T ax.bar(x + w, coevo, w, yerr=c_err, capsize=4, color=C_COEVO, label="v2-coevolved (round 2)", edgecolor="#222", linewidth=0.5) # Annotate bar tops with percentages bars_groups = [(x - w, scripted), (x, v2)] if has_phase2: bars_groups.append((x + w, coevo)) for xs, ys in bars_groups: for xi, yi in zip(xs, ys): ax.text(xi, yi + 0.02, f"{yi:.0%}", ha="center", va="bottom", fontsize=10, fontweight="bold") ax.set_xticks(x) ax.set_xticklabels(labels, fontsize=11) ax.set_ylabel("Scammer bypass rate (lower = stronger defender)", fontsize=11) ax.set_ylim(0, 1.10) ax.set_yticks([0, 0.25, 0.50, 0.75, 1.0]) ax.set_yticklabels(["0%", "25%", "50%", "75%", "100%"]) title = "B.2 Co-evolution head-to-head: same Scammer, three defenders" if has_phase2: title += " (round 1 vs round 2)" ax.set_title(title, fontsize=12, pad=12) ax.legend(loc="upper right", framealpha=0.95, fontsize=10) ax.grid(axis="y", alpha=0.3, linewidth=0.5) ax.set_axisbelow(True) out = PLOTS / "coevolution_headline.png" plt.tight_layout() plt.savefig(out, dpi=140, bbox_inches="tight") plt.close() print(f"[ok] {out.name}") return out def plot_per_category() -> Path | None: """Grouped bar chart: bypass rate per seed category, v2 vs v2-coevolved.""" if not PHASE1_HEADTOHEAD.exists(): print(f"[skip] no {PHASE1_HEADTOHEAD.name}") return None h2h = json.load(open(PHASE1_HEADTOHEAD)) has_phase2 = PHASE2_EVAL.exists() p2 = json.load(open(PHASE2_EVAL)) if has_phase2 else None # Aggregate per-seed (8 train + 8 held-out) samples = h2h["samples"] p2_samples = {(s["seed"], s.get("split")): s for s in p2["samples"]} if has_phase2 else {} by_seed: dict[str, dict] = {} for s in samples: seed = s["seed"] d = by_seed.setdefault(seed, {"split": s["split"], "v2_byp": 0, "co_byp": 0, "n": 0}) d["n"] += 1 if s["v2_bypass"]: d["v2_byp"] += 1 if has_phase2 and (seed, s["split"]) in p2_samples: if p2_samples[(seed, s["split"])].get("v2_coevolved_bypass"): d["co_byp"] += 1 rows = sorted(by_seed.items(), key=lambda kv: (kv[1]["split"], -kv[1]["v2_byp"])) short_labels = [] for seed, _ in rows: # Strip "Write a realistic " prefix and "scam ..." suffix for display s = seed.replace("Write a realistic ", "") for tail in (" scam message", " scam pretending", " scam asking", " scam promising", " scam claiming", " scam threatening", " scam pre-approving", " scam impersonating", " notification scam"): if tail in s: s = s.split(tail)[0] break short_labels.append(s[:35]) v2_rates = [d["v2_byp"] / d["n"] for _, d in rows] co_rates = [d["co_byp"] / d["n"] for _, d in rows] if has_phase2 else None splits = [d["split"] for _, d in rows] y = np.arange(len(rows)) h = 0.4 fig, ax = plt.subplots(figsize=(11, 8)) ax.barh(y - h / 2, v2_rates, h, color=C_V2, label="v2 Analyzer (round 1)", edgecolor="#444", linewidth=0.5) if has_phase2: ax.barh(y + h / 2, co_rates, h, color=C_COEVO, label="v2-coevolved (round 2)", edgecolor="#222", linewidth=0.5) # Color y-tick labels by split ax.set_yticks(y) ax.set_yticklabels(short_labels, fontsize=9) for tick, sp in zip(ax.get_yticklabels(), splits): tick.set_color(C_TRAIN if sp == "train" else C_HELDOUT) ax.set_xlabel("Scammer bypass rate (lower = stronger defender)", fontsize=11) ax.set_xlim(0, 1.05) ax.set_xticks([0, 0.25, 0.5, 0.75, 1.0]) ax.set_xticklabels(["0%", "25%", "50%", "75%", "100%"]) title = "Per-category bypass: v2" if has_phase2: title += " vs v2-coevolved" title += " (blue = train, purple = held-out)" ax.set_title(title, fontsize=12, pad=12) ax.legend(loc="lower right", framealpha=0.95, fontsize=10) ax.grid(axis="x", alpha=0.3, linewidth=0.5) ax.set_axisbelow(True) ax.invert_yaxis() out = PLOTS / "coevolution_per_category.png" plt.tight_layout() plt.savefig(out, dpi=140, bbox_inches="tight") plt.close() print(f"[ok] {out.name}") return out def plot_score_movement() -> Path | None: """Scatter: v2 score (x) vs v2-coevolved score (y) per sample. Requires phase 2.""" if not PHASE2_EVAL.exists(): print("[skip] score-movement scatter (needs phase 2 eval)") return None p2 = json.load(open(PHASE2_EVAL)) samples = p2["samples"] if not all("v2_score" in s and "v2_coevolved_score" in s for s in samples): print("[skip] phase 2 samples missing v2_score / v2_coevolved_score fields") return None train_pts = [(s["v2_score"], s["v2_coevolved_score"]) for s in samples if s["split"] == "train"] held_pts = [(s["v2_score"], s["v2_coevolved_score"]) for s in samples if s["split"] == "held_out"] fig, ax = plt.subplots(figsize=(7.5, 7.5)) # Quadrant shading ax.axhspan(0.5, 1.0, xmin=0, xmax=0.5, color="#a5d6a7", alpha=0.25, label="caught by coevolved (was bypass)") ax.axhspan(0, 0.5, xmin=0, xmax=0.5, color="#ef9a9a", alpha=0.20, label="bypass under both (true hard)") ax.axhspan(0, 0.5, xmin=0.5, xmax=1.0, color="#fff59d", alpha=0.25, label="regression (caught by v2, missed by coevolved)") if train_pts: xs, ys = zip(*train_pts) ax.scatter(xs, ys, s=46, c=C_TRAIN, alpha=0.8, edgecolors="#0d47a1", linewidth=0.6, label=f"train seeds (n={len(train_pts)})") if held_pts: xs, ys = zip(*held_pts) ax.scatter(xs, ys, s=46, c=C_HELDOUT, alpha=0.8, edgecolors="#4a148c", linewidth=0.6, label=f"held-out seeds (n={len(held_pts)})") ax.plot([0, 1], [0, 1], color="#666", linestyle="--", linewidth=1, label="identity") ax.axhline(0.5, color="black", linewidth=0.6, linestyle=":") ax.axvline(0.5, color="black", linewidth=0.6, linestyle=":") ax.set_xlim(-0.02, 1.02) ax.set_ylim(-0.02, 1.02) ax.set_xlabel("v2 Analyzer score (round 1)", fontsize=11) ax.set_ylabel("v2-coevolved Analyzer score (round 2)", fontsize=11) ax.set_title("Per-sample score movement: v2 → v2-coevolved\n(higher = stronger detection; threshold = 0.5)", fontsize=11, pad=10) ax.legend(loc="lower right", framealpha=0.95, fontsize=9) ax.grid(alpha=0.3, linewidth=0.5) ax.set_aspect("equal") out = PLOTS / "coevolution_score_movement.png" plt.tight_layout() plt.savefig(out, dpi=140, bbox_inches="tight") plt.close() print(f"[ok] {out.name}") return out def plot_training_curve() -> Path | None: """Phase 2 training trajectory. Plots GRPO reward if present; falls back to SFT loss. Needs `training_log_history` in phase 2 JSON's meta dict.""" if not PHASE2_EVAL.exists(): print("[skip] training curve (needs phase 2 eval)") return None p2 = json.load(open(PHASE2_EVAL)) log_history = p2.get("meta", {}).get("training_log_history") or p2.get("training_log_history") if not log_history: print("[skip] training curve (no training_log_history in phase 2 JSON)") return None method = p2.get("meta", {}).get("phase2_training", {}).get("method", "").lower() has_reward = any("reward" in e for e in log_history) has_loss = any("loss" in e for e in log_history) fig, ax = plt.subplots(figsize=(10, 5.5)) if has_reward: # GRPO regime — plot reward as primary signal steps = [e.get("step") for e in log_history if "reward" in e] rewards = [e.get("reward") for e in log_history if "reward" in e] ax.plot(steps, rewards, color=C_COEVO, linewidth=1.6, marker="o", markersize=3, label="mean reward (group)") ax.axhline(0, color="#999", linewidth=0.6, linestyle="-") ax.axhline(-0.3, color="#c62828", linewidth=0.8, linestyle="--", label="SafetyEarlyStop threshold (-0.3)") ax.fill_between(steps, -0.3, min(rewards) - 0.05, color="#ffcdd2", alpha=0.25) ax.set_ylabel("Mean GRPO reward (higher = better detection)", fontsize=11) ax.set_title("Phase 2 GRPO training trajectory — v2 → v2-coevolved", fontsize=12, pad=10) kl_entries = [e for e in log_history if "kl" in e] if kl_entries: ax2 = ax.twinx() ax2.plot([e["step"] for e in kl_entries], [e["kl"] for e in kl_entries], color="#1976d2", linewidth=1.2, alpha=0.6, linestyle=":", label="KL(policy || base)") ax2.set_ylabel("KL divergence", color="#1976d2", fontsize=10) ax2.tick_params(axis="y", labelcolor="#1976d2") elif has_loss: # SFT regime — plot cross-entropy loss as primary signal steps = [e.get("step") for e in log_history if "loss" in e] losses = [e.get("loss") for e in log_history if "loss" in e] ax.plot(steps, losses, color=C_COEVO, linewidth=1.6, marker="o", markersize=3, label="cross-entropy loss") ax.set_ylabel("SFT loss (lower = better fit to gold JSON)", fontsize=11) ax.set_title("Phase 2 SFT training trajectory — v2 → v2-coevolved (hardened on bypass cases)", fontsize=12, pad=10) if losses: ax.set_ylim(bottom=0, top=max(losses) * 1.1) lr_entries = [e for e in log_history if "learning_rate" in e] if lr_entries: ax2 = ax.twinx() ax2.plot([e["step"] for e in lr_entries], [e["learning_rate"] for e in lr_entries], color="#1976d2", linewidth=1.2, alpha=0.6, linestyle=":", label="learning rate (cosine decay)") ax2.set_ylabel("Learning rate", color="#1976d2", fontsize=10) ax2.tick_params(axis="y", labelcolor="#1976d2") ax2.ticklabel_format(axis="y", style="sci", scilimits=(0, 0)) else: print("[skip] log_history has neither 'reward' nor 'loss' entries") plt.close() return None ax.set_xlabel("Optimizer step", fontsize=11) ax.legend(loc="upper right" if has_loss else "lower right", framealpha=0.95, fontsize=10) ax.grid(alpha=0.3, linewidth=0.5) ax.set_axisbelow(True) out = PLOTS / "coevolution_training_curve.png" plt.tight_layout() plt.savefig(out, dpi=140, bbox_inches="tight") plt.close() print(f"[ok] {out.name}") return out def plot_scammer_phase1_summary() -> Path | None: """Scammer-side context: bypass rate vs ScriptedAnalyzer per category (single-shot vs best-of-8).""" if not (PHASE1_BESTOF8.exists() and PHASE1_SINGLESHOT.exists()): print("[skip] scammer phase1 summary (needs both eval files)") return None bo8 = json.load(open(PHASE1_BESTOF8)) ss = json.load(open(PHASE1_SINGLESHOT)) # Per-seed map for both runs bo8_per = {seed: r["bypass_rate"] for seed, r in bo8["per_seed"].items()} ss_per = {seed: r["bypass_rate"] for seed, r in ss["per_seed"].items()} seed_split = {s["seed"][:70]: s["split"] for s in bo8["samples"]} seeds = sorted(bo8_per.keys(), key=lambda s: (seed_split.get(s, "?"), -bo8_per[s])) short = [] for s in seeds: t = s.replace("Write a realistic ", "") for tail in (" scam message", " scam pretending", " scam asking", " scam promising", " scam claiming", " scam threatening", " scam pre-approving", " scam impersonating", " notification scam"): if tail in t: t = t.split(tail)[0] break short.append(t[:35]) splits = [seed_split.get(s, "?") for s in seeds] ss_rates = [ss_per[s] for s in seeds] bo8_rates = [bo8_per[s] for s in seeds] y = np.arange(len(seeds)) h = 0.4 fig, ax = plt.subplots(figsize=(11, 8)) ax.barh(y - h / 2, ss_rates, h, color="#90caf9", edgecolor="#444", linewidth=0.5, label="single-shot inference") ax.barh(y + h / 2, bo8_rates, h, color="#1565c0", edgecolor="#222", linewidth=0.5, label="best-of-8 inference") ax.set_yticks(y) ax.set_yticklabels(short, fontsize=9) for tick, sp in zip(ax.get_yticklabels(), splits): tick.set_color(C_TRAIN if sp == "train" else C_HELDOUT) ax.set_xlabel("Scammer bypass rate vs ScriptedAnalyzer (higher = stronger attacker)", fontsize=11) ax.set_xlim(0, 1.05) ax.set_xticks([0, 0.25, 0.5, 0.75, 1.0]) ax.set_xticklabels(["0%", "25%", "50%", "75%", "100%"]) ax.set_title("B.2 phase 1 Scammer: per-category bypass of rule-based defense\n" "(blue tick labels = training categories, purple = held-out novel)", fontsize=11, pad=10) ax.legend(loc="lower right", framealpha=0.95, fontsize=10) ax.grid(axis="x", alpha=0.3, linewidth=0.5) ax.set_axisbelow(True) ax.invert_yaxis() out = PLOTS / "scammer_phase1_per_category.png" plt.tight_layout() plt.savefig(out, dpi=140, bbox_inches="tight") plt.close() print(f"[ok] {out.name}") return out def main() -> None: print(f"Reading from: {LOGS}") print(f"Writing to: {PLOTS}\n") generated = [] for fn in ( plot_coevolution_headline, plot_per_category, plot_score_movement, plot_training_curve, plot_scammer_phase1_summary, ): out = fn() if out: generated.append(out) print(f"\nGenerated {len(generated)} plot(s):") for p in generated: print(f" {p.relative_to(REPO)}") if __name__ == "__main__": main()