# chakravyuh/eval/plot_coevolution.py
# UjjwalPardeshi
# deploy: latest main to HF Space
# 03815d6
"""Generate publication-quality co-evolution plots for B.2.
Reads result JSONs in `logs/` and writes PNGs to `plots/chakravyuh_plots/`.
Gracefully degrades if some JSONs are missing.
Usage:
python eval/plot_coevolution.py
"""
from __future__ import annotations
import json
import math
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# Repository root, taken as the grandparent of this file's resolved location.
REPO = Path(__file__).resolve().parent.parent
# Directory the result JSONs are read from.
LOGS = REPO / "logs"
# Directory the PNGs are written to.
PLOTS = REPO / "plots" / "chakravyuh_plots"
# NOTE: import-time side effect -- the output directory is created as soon as
# this module is loaded, not only when a plot is written.
PLOTS.mkdir(parents=True, exist_ok=True)
# Input files. Each plot function checks `.exists()` on the ones it needs and
# skips itself when they are absent.
PHASE1_HEADTOHEAD = LOGS / "b2_phase1_scammer_vs_v2_lora.json"
PHASE2_EVAL = LOGS / "b2_phase2_coevolution_eval.json"
PHASE1_BESTOF8 = LOGS / "b2_phase1_scammer_eval_n64_bestof8.json"
PHASE1_SINGLESHOT = LOGS / "b2_phase1_scammer_eval_n64.json"
# Shared colour palette (hex).
C_SCRIPTED = "#bdbdbd"  # grey: ScriptedAnalyzer bars
C_V2 = "#fb8c00"  # orange: v2 Analyzer bars
C_COEVO = "#43a047"  # green: v2-coevolved bars and the training-curve line
C_TRAIN = "#1e88e5"  # blue: train-split tick labels / scatter points
C_HELDOUT = "#8e24aa"  # purple: held-out-split tick labels / scatter points
def wilson_ci(k: int, n: int, z: float = 1.96) -> tuple[float, float]:
    """Return the Wilson score interval for ``k`` successes out of ``n`` trials.

    ``z`` is the normal quantile used for the interval (the default 1.96
    corresponds to roughly 95% coverage). The bounds are clamped to ``[0, 1]``.
    With zero trials there is nothing to estimate and ``(0.0, 0.0)`` is returned.
    """
    if n == 0:
        return 0.0, 0.0

    p_hat = k / n
    z_sq = z * z
    denom = 1 + z_sq / n
    midpoint = (p_hat + z_sq / (2 * n)) / denom
    spread = p_hat * (1 - p_hat) / n + z_sq / (4 * n * n)
    half_width = z * math.sqrt(spread) / denom

    lower = max(0.0, midpoint - half_width)
    upper = min(1.0, midpoint + half_width)
    return lower, upper
def _ci_err(rate: float, lo: float, hi: float) -> tuple[float, float]:
    """Convert a confidence interval to matplotlib's lower/upper-error format.

    Matplotlib's ``yerr``/``xerr`` take non-negative distances from the plotted
    value rather than absolute bounds, so this returns
    ``(rate - lo, hi - rate)``.

    Each distance is clamped at zero: a bound that should coincide with the
    rate (e.g. the upper bound at a 100% rate) can land a few ULPs on the wrong
    side after floating-point rounding, and matplotlib raises on negative
    error lengths.

    Args:
        rate: The plotted point estimate.
        lo: Lower bound of the interval.
        hi: Upper bound of the interval.

    Returns:
        ``(lower_error, upper_error)``, both ``>= 0``.
    """
    return max(0.0, rate - lo), max(0.0, hi - rate)
def plot_coevolution_headline() -> Path | None:
    """Draw the headline grouped-bar chart of Scammer bypass rate per split.

    Reads ``PHASE1_HEADTOHEAD`` (required) and ``PHASE2_EVAL`` (optional). Each
    split ("overall", "train_seeds", "held_out_seeds") gets one bar per
    defender: ScriptedAnalyzer, v2 and, when the phase 2 file exists,
    v2-coevolved. Error bars come from the Wilson intervals stored in the
    JSON, and each bar is annotated with its rate as a percentage.

    Returns:
        Path of the written PNG, or ``None`` when the phase 1 head-to-head
        file is missing.

    Raises:
        KeyError: if an expected ``aggregate`` field is absent from the JSON.
    """
    if not PHASE1_HEADTOHEAD.exists():
        print(f"[skip] no {PHASE1_HEADTOHEAD.name}")
        return None

    # Context managers close the handles promptly; JSON is read as UTF-8
    # rather than whatever the platform locale defaults to.
    with open(PHASE1_HEADTOHEAD, encoding="utf-8") as fh:
        h2h = json.load(fh)
    has_phase2 = PHASE2_EVAL.exists()
    p2 = None
    if has_phase2:
        with open(PHASE2_EVAL, encoding="utf-8") as fh:
            p2 = json.load(fh)

    splits = ["overall", "train_seeds", "held_out_seeds"]
    # NOTE: the sample counts in these labels are hard-coded, not read from the JSON.
    labels = ["Overall (n=64)", "Train (n=32)", "Held-out (n=32)"]

    def _series(source: dict, prefix: str) -> tuple[list, list]:
        """Pull per-split rates and Wilson CIs for one defender."""
        agg = source["aggregate"]
        return (
            [agg[s][f"{prefix}_bypass_rate"] for s in splits],
            [agg[s][f"{prefix}_wilson_95_ci"] for s in splits],
        )

    # Extract everything before creating the figure so that a malformed JSON
    # fails without leaving an open figure behind.
    # Each entry: (rates, cis, fill colour, edge colour, legend label).
    series = [
        (*_series(h2h, "scripted"), C_SCRIPTED, "#666", "ScriptedAnalyzer (rule-based)"),
        (*_series(h2h, "v2"), C_V2, "#444", "v2 Analyzer LoRA (round 1)"),
    ]
    if has_phase2:
        series.append(
            (*_series(p2, "coevolved"), C_COEVO, "#222", "v2-coevolved (round 2)")
        )

    n_bars = len(series)
    x = np.arange(len(splits))
    w = 0.27 if n_bars == 3 else 0.4
    # Centre each group on its tick: three bars sit at -w, 0, +w; two bars at
    # -w/2, +w/2 (previously the two-bar group was shifted left of the tick).
    offsets = (np.arange(n_bars) - (n_bars - 1) / 2) * w

    fig, ax = plt.subplots(figsize=(10, 5.5))
    for offset, (rates, cis, colour, edge, legend_label) in zip(offsets, series):
        # matplotlib rejects negative error lengths; clamp away rounding noise.
        err = np.array(
            [_ci_err(r, lo, hi) for r, (lo, hi) in zip(rates, cis)]
        ).T.clip(min=0.0)
        xs = x + offset
        ax.bar(xs, rates, w, yerr=err, capsize=4,
               color=colour, label=legend_label, edgecolor=edge, linewidth=0.5)
        # Annotate bar tops with percentages
        for xi, yi in zip(xs, rates):
            ax.text(xi, yi + 0.02, f"{yi:.0%}", ha="center", va="bottom",
                    fontsize=10, fontweight="bold")

    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=11)
    ax.set_ylabel("Scammer bypass rate (lower = stronger defender)", fontsize=11)
    ax.set_ylim(0, 1.10)
    ax.set_yticks([0, 0.25, 0.50, 0.75, 1.0])
    ax.set_yticklabels(["0%", "25%", "50%", "75%", "100%"])

    # The defender count in the title follows what is actually drawn.
    if has_phase2:
        title = ("B.2 Co-evolution head-to-head: same Scammer, three defenders"
                 " (round 1 vs round 2)")
    else:
        title = "B.2 Co-evolution head-to-head: same Scammer, two defenders"
    ax.set_title(title, fontsize=12, pad=12)
    ax.legend(loc="upper right", framealpha=0.95, fontsize=10)
    ax.grid(axis="y", alpha=0.3, linewidth=0.5)
    ax.set_axisbelow(True)

    out = PLOTS / "coevolution_headline.png"
    fig.tight_layout()
    fig.savefig(out, dpi=140, bbox_inches="tight")
    plt.close(fig)
    print(f"[ok] {out.name}")
    return out
def plot_per_category() -> Path | None:
    """Draw horizontal grouped bars of bypass rate per seed prompt.

    Reads ``PHASE1_HEADTOHEAD`` (required) for the v2 Analyzer and
    ``PHASE2_EVAL`` (optional) for the co-evolved Analyzer. Samples are grouped
    by their ``seed`` string; rows are ordered by split, then by descending v2
    bypass count. Tick labels are coloured by split.

    Returns:
        Path of the written PNG, or ``None`` when the phase 1 head-to-head
        file is missing.

    Raises:
        KeyError: if a sample lacks ``seed``, ``split`` or ``v2_bypass``.
    """
    if not PHASE1_HEADTOHEAD.exists():
        print(f"[skip] no {PHASE1_HEADTOHEAD.name}")
        return None

    # Context managers close the handles promptly; JSON is read as UTF-8
    # rather than whatever the platform locale defaults to.
    with open(PHASE1_HEADTOHEAD, encoding="utf-8") as fh:
        h2h = json.load(fh)
    has_phase2 = PHASE2_EVAL.exists()
    p2 = None
    if has_phase2:
        with open(PHASE2_EVAL, encoding="utf-8") as fh:
            p2 = json.load(fh)

    # Round 1: tally v2 bypasses per seed prompt.
    by_seed: dict[str, dict] = {}
    for s in h2h["samples"]:
        d = by_seed.setdefault(s["seed"], {"split": s["split"], "v2_byp": 0, "n": 0})
        d["n"] += 1
        if s["v2_bypass"]:
            d["v2_byp"] += 1

    # Round 2: tally co-evolved bypasses from the phase 2 samples themselves.
    # Indexing phase 2 by (seed, split) in a plain dict would keep only the
    # last sample of each seed, making every per-seed rate either 0% or 100%.
    co_counts: dict[tuple, list] = {}  # (seed, split) -> [bypasses, total]
    if has_phase2:
        for s in p2["samples"]:
            tally = co_counts.setdefault((s["seed"], s.get("split")), [0, 0])
            tally[1] += 1
            if s.get("v2_coevolved_bypass"):
                tally[0] += 1

    rows = sorted(by_seed.items(), key=lambda kv: (kv[1]["split"], -kv[1]["v2_byp"]))

    # Strip "Write a realistic " prefix and "scam ..." suffix for display
    tails = (" scam message", " scam pretending", " scam asking", " scam promising",
             " scam claiming", " scam threatening", " scam pre-approving",
             " scam impersonating", " notification scam")
    short_labels = []
    for seed, _ in rows:
        text = seed.replace("Write a realistic ", "")
        for tail in tails:
            if tail in text:
                text = text.split(tail)[0]
                break
        short_labels.append(text[:35])

    v2_rates = [d["v2_byp"] / d["n"] for _, d in rows]
    co_rates = None
    if has_phase2:
        co_rates = []
        for seed, d in rows:
            # Seeds with no phase 2 samples are drawn as 0, as before.
            bypassed, total = co_counts.get((seed, d["split"]), (0, 0))
            co_rates.append(bypassed / total if total else 0.0)
    splits = [d["split"] for _, d in rows]

    y = np.arange(len(rows))
    h = 0.4
    fig, ax = plt.subplots(figsize=(11, 8))
    ax.barh(y - h / 2, v2_rates, h, color=C_V2, label="v2 Analyzer (round 1)",
            edgecolor="#444", linewidth=0.5)
    if has_phase2:
        ax.barh(y + h / 2, co_rates, h, color=C_COEVO, label="v2-coevolved (round 2)",
                edgecolor="#222", linewidth=0.5)

    # Color y-tick labels by split
    ax.set_yticks(y)
    ax.set_yticklabels(short_labels, fontsize=9)
    for tick, sp in zip(ax.get_yticklabels(), splits):
        tick.set_color(C_TRAIN if sp == "train" else C_HELDOUT)

    ax.set_xlabel("Scammer bypass rate (lower = stronger defender)", fontsize=11)
    ax.set_xlim(0, 1.05)
    ax.set_xticks([0, 0.25, 0.5, 0.75, 1.0])
    ax.set_xticklabels(["0%", "25%", "50%", "75%", "100%"])
    title = "Per-category bypass: v2"
    if has_phase2:
        title += " vs v2-coevolved"
    title += " (blue = train, purple = held-out)"
    ax.set_title(title, fontsize=12, pad=12)
    ax.legend(loc="lower right", framealpha=0.95, fontsize=10)
    ax.grid(axis="x", alpha=0.3, linewidth=0.5)
    ax.set_axisbelow(True)
    # Row 0 at the top, so the sort order reads top-to-bottom.
    ax.invert_yaxis()

    out = PLOTS / "coevolution_per_category.png"
    fig.tight_layout()
    fig.savefig(out, dpi=140, bbox_inches="tight")
    plt.close(fig)
    print(f"[ok] {out.name}")
    return out
def plot_score_movement() -> Path | None:
    """Scatter each phase 2 sample's v2 score (x) against its co-evolved score (y).

    Requires ``PHASE2_EVAL``, and every sample in it must carry both
    ``v2_score`` and ``v2_coevolved_score``. Points are coloured by split
    ("train" / "held_out"). Three shaded quadrants around the 0.5 threshold
    mark samples newly caught, missed by both, and regressed.

    Returns:
        Path of the written PNG, or ``None`` when the phase 2 file is missing
        or its samples lack the score fields.

    Raises:
        KeyError: if the JSON has no ``samples`` or a sample has no ``split``.
    """
    if not PHASE2_EVAL.exists():
        print("[skip] score-movement scatter (needs phase 2 eval)")
        return None

    # Context manager closes the handle promptly; JSON is read as UTF-8
    # rather than whatever the platform locale defaults to.
    with open(PHASE2_EVAL, encoding="utf-8") as fh:
        p2 = json.load(fh)
    samples = p2["samples"]
    if not all("v2_score" in s and "v2_coevolved_score" in s for s in samples):
        print("[skip] phase 2 samples missing v2_score / v2_coevolved_score fields")
        return None

    train_pts = [(s["v2_score"], s["v2_coevolved_score"])
                 for s in samples if s["split"] == "train"]
    held_pts = [(s["v2_score"], s["v2_coevolved_score"])
                for s in samples if s["split"] == "held_out"]

    fig, ax = plt.subplots(figsize=(7.5, 7.5))

    # Quadrant shading. axhspan's xmin/xmax are axes fractions, not data
    # values; 0.5 coincides with the data threshold only because the x-limits
    # set below (-0.02, 1.02) are symmetric about 0.5.
    ax.axhspan(0.5, 1.0, xmin=0, xmax=0.5, color="#a5d6a7", alpha=0.25,
               label="caught by coevolved (was bypass)")
    ax.axhspan(0, 0.5, xmin=0, xmax=0.5, color="#ef9a9a", alpha=0.20,
               label="bypass under both (true hard)")
    ax.axhspan(0, 0.5, xmin=0.5, xmax=1.0, color="#fff59d", alpha=0.25,
               label="regression (caught by v2, missed by coevolved)")

    if train_pts:
        xs, ys = zip(*train_pts)
        ax.scatter(xs, ys, s=46, c=C_TRAIN, alpha=0.8, edgecolors="#0d47a1",
                   linewidth=0.6, label=f"train seeds (n={len(train_pts)})")
    if held_pts:
        xs, ys = zip(*held_pts)
        ax.scatter(xs, ys, s=46, c=C_HELDOUT, alpha=0.8, edgecolors="#4a148c",
                   linewidth=0.6, label=f"held-out seeds (n={len(held_pts)})")

    # Diagonal = score unchanged between rounds; dotted lines = 0.5 threshold.
    ax.plot([0, 1], [0, 1], color="#666", linestyle="--", linewidth=1, label="identity")
    ax.axhline(0.5, color="black", linewidth=0.6, linestyle=":")
    ax.axvline(0.5, color="black", linewidth=0.6, linestyle=":")

    ax.set_xlim(-0.02, 1.02)
    ax.set_ylim(-0.02, 1.02)
    ax.set_xlabel("v2 Analyzer score (round 1)", fontsize=11)
    ax.set_ylabel("v2-coevolved Analyzer score (round 2)", fontsize=11)
    ax.set_title("Per-sample score movement: v2 → v2-coevolved\n(higher = stronger detection; threshold = 0.5)",
                 fontsize=11, pad=10)
    ax.legend(loc="lower right", framealpha=0.95, fontsize=9)
    ax.grid(alpha=0.3, linewidth=0.5)
    ax.set_aspect("equal")

    out = PLOTS / "coevolution_score_movement.png"
    fig.tight_layout()
    fig.savefig(out, dpi=140, bbox_inches="tight")
    plt.close(fig)
    print(f"[ok] {out.name}")
    return out
def plot_training_curve() -> Path | None:
    """Plot the phase 2 training trajectory from the logged history.

    Looks for ``training_log_history`` under the phase 2 JSON's ``meta`` dict,
    falling back to the top level. If any entry has a ``reward`` key the
    reward series is plotted (with ``kl`` on a secondary axis when present).
    Otherwise the ``loss`` series is plotted (with ``learning_rate`` on a
    secondary axis when present).

    Returns:
        Path of the written PNG, or ``None`` when the phase 2 file is missing,
        has no log history, or the history has neither reward nor loss entries.
    """
    if not PHASE2_EVAL.exists():
        print("[skip] training curve (needs phase 2 eval)")
        return None

    # Context manager closes the handle promptly; JSON is read as UTF-8
    # rather than whatever the platform locale defaults to.
    with open(PHASE2_EVAL, encoding="utf-8") as fh:
        p2 = json.load(fh)
    log_history = p2.get("meta", {}).get("training_log_history") or p2.get("training_log_history")
    if not log_history:
        print("[skip] training curve (no training_log_history in phase 2 JSON)")
        return None

    reward_entries = [e for e in log_history if "reward" in e]
    loss_entries = [e for e in log_history if "loss" in e]
    # Decide there is something to plot before creating a figure.
    if not reward_entries and not loss_entries:
        print("[skip] log_history has neither 'reward' nor 'loss' entries")
        return None

    threshold = -0.3  # the value drawn and labelled as the SafetyEarlyStop threshold
    fig, ax = plt.subplots(figsize=(10, 5.5))
    ax2 = None  # secondary y-axis, created only when a secondary series exists

    if reward_entries:
        # GRPO regime — plot reward as primary signal
        steps = [e.get("step") for e in reward_entries]
        rewards = [e.get("reward") for e in reward_entries]
        ax.plot(steps, rewards, color=C_COEVO, linewidth=1.6, marker="o", markersize=3,
                label="mean reward (group)")
        ax.axhline(0, color="#999", linewidth=0.6, linestyle="-")
        ax.axhline(threshold, color="#c62828", linewidth=0.8, linestyle="--",
                   label="SafetyEarlyStop threshold (-0.3)")
        # Shade the region below the threshold. The floor is never allowed
        # above the threshold; otherwise, when every reward exceeds it, the
        # band would be painted over the safe region instead.
        floor = min(min(rewards), threshold) - 0.05
        ax.fill_between(steps, threshold, floor, color="#ffcdd2", alpha=0.25)
        ax.set_ylabel("Mean GRPO reward (higher = better detection)", fontsize=11)
        ax.set_title("Phase 2 GRPO training trajectory — v2 → v2-coevolved", fontsize=12, pad=10)
        kl_entries = [e for e in log_history if "kl" in e]
        if kl_entries:
            ax2 = ax.twinx()
            ax2.plot([e["step"] for e in kl_entries], [e["kl"] for e in kl_entries],
                     color="#1976d2", linewidth=1.2, alpha=0.6, linestyle=":",
                     label="KL(policy || base)")
            ax2.set_ylabel("KL divergence", color="#1976d2", fontsize=10)
            ax2.tick_params(axis="y", labelcolor="#1976d2")
    else:
        # SFT regime — plot cross-entropy loss as primary signal
        steps = [e.get("step") for e in loss_entries]
        losses = [e.get("loss") for e in loss_entries]
        ax.plot(steps, losses, color=C_COEVO, linewidth=1.6, marker="o", markersize=3,
                label="cross-entropy loss")
        ax.set_ylabel("SFT loss (lower = better fit to gold JSON)", fontsize=11)
        ax.set_title("Phase 2 SFT training trajectory — v2 → v2-coevolved (hardened on bypass cases)",
                     fontsize=12, pad=10)
        ax.set_ylim(bottom=0, top=max(losses) * 1.1)
        lr_entries = [e for e in log_history if "learning_rate" in e]
        if lr_entries:
            ax2 = ax.twinx()
            ax2.plot([e["step"] for e in lr_entries], [e["learning_rate"] for e in lr_entries],
                     color="#1976d2", linewidth=1.2, alpha=0.6, linestyle=":",
                     label="learning rate (cosine decay)")
            ax2.set_ylabel("Learning rate", color="#1976d2", fontsize=10)
            ax2.tick_params(axis="y", labelcolor="#1976d2")
            ax2.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))

    ax.set_xlabel("Optimizer step", fontsize=11)
    # A legend only lists artists of its own Axes, so the twin-axis series has
    # to be merged in explicitly or its label is never shown.
    handles, legend_labels = ax.get_legend_handles_labels()
    if ax2 is not None:
        extra_handles, extra_labels = ax2.get_legend_handles_labels()
        handles = handles + extra_handles
        legend_labels = legend_labels + extra_labels
    # Location follows the series actually plotted: lower right for the reward
    # plot, upper right for the loss plot.
    ax.legend(handles, legend_labels,
              loc="lower right" if reward_entries else "upper right",
              framealpha=0.95, fontsize=10)
    ax.grid(alpha=0.3, linewidth=0.5)
    ax.set_axisbelow(True)

    out = PLOTS / "coevolution_training_curve.png"
    fig.tight_layout()
    fig.savefig(out, dpi=140, bbox_inches="tight")
    plt.close(fig)
    print(f"[ok] {out.name}")
    return out
def plot_scammer_phase1_summary() -> Path | None:
    """Draw per-seed Scammer bypass rates: single-shot vs best-of-8 inference.

    Requires both ``PHASE1_BESTOF8`` and ``PHASE1_SINGLESHOT``. Rates come from
    each file's ``per_seed`` mapping. Rows are ordered by split, then by
    descending best-of-8 rate, and tick labels are coloured by split. Seeds
    present in the best-of-8 run but absent from the single-shot run are left
    out with a warning instead of aborting the plot.

    Returns:
        Path of the written PNG, or ``None`` when either input file is missing.

    Raises:
        KeyError: if ``per_seed`` / ``samples`` or their expected fields are absent.
    """
    if not (PHASE1_BESTOF8.exists() and PHASE1_SINGLESHOT.exists()):
        print("[skip] scammer phase1 summary (needs both eval files)")
        return None

    # Context managers close the handles promptly; JSON is read as UTF-8
    # rather than whatever the platform locale defaults to.
    with open(PHASE1_BESTOF8, encoding="utf-8") as fh:
        bo8 = json.load(fh)
    with open(PHASE1_SINGLESHOT, encoding="utf-8") as fh:
        ss = json.load(fh)

    # Per-seed map for both runs
    bo8_per = {seed: r["bypass_rate"] for seed, r in bo8["per_seed"].items()}
    ss_per = {seed: r["bypass_rate"] for seed, r in ss["per_seed"].items()}
    # NOTE(review): presumably `per_seed` keys are the seed prompt truncated to
    # 70 characters, hence the slice -- confirm against the eval writer.
    seed_split = {s["seed"][:70]: s["split"] for s in bo8["samples"]}

    missing = [s for s in bo8_per if s not in ss_per]
    if missing:
        print(f"[warn] {len(missing)} seed(s) absent from {PHASE1_SINGLESHOT.name}; "
              "omitted from plot")
    seeds = sorted((s for s in bo8_per if s in ss_per),
                   key=lambda s: (seed_split.get(s, "?"), -bo8_per[s]))

    # Strip "Write a realistic " prefix and "scam ..." suffix for display
    tails = (" scam message", " scam pretending", " scam asking", " scam promising",
             " scam claiming", " scam threatening", " scam pre-approving",
             " scam impersonating", " notification scam")
    short = []
    for s in seeds:
        t = s.replace("Write a realistic ", "")
        for tail in tails:
            if tail in t:
                t = t.split(tail)[0]
                break
        short.append(t[:35])

    splits = [seed_split.get(s, "?") for s in seeds]
    ss_rates = [ss_per[s] for s in seeds]
    bo8_rates = [bo8_per[s] for s in seeds]

    y = np.arange(len(seeds))
    h = 0.4
    fig, ax = plt.subplots(figsize=(11, 8))
    ax.barh(y - h / 2, ss_rates, h, color="#90caf9",
            edgecolor="#444", linewidth=0.5, label="single-shot inference")
    ax.barh(y + h / 2, bo8_rates, h, color="#1565c0",
            edgecolor="#222", linewidth=0.5, label="best-of-8 inference")

    ax.set_yticks(y)
    ax.set_yticklabels(short, fontsize=9)
    # Anything that is not "train" (including the "?" fallback) gets the
    # held-out colour.
    for tick, sp in zip(ax.get_yticklabels(), splits):
        tick.set_color(C_TRAIN if sp == "train" else C_HELDOUT)

    ax.set_xlabel("Scammer bypass rate vs ScriptedAnalyzer (higher = stronger attacker)", fontsize=11)
    ax.set_xlim(0, 1.05)
    ax.set_xticks([0, 0.25, 0.5, 0.75, 1.0])
    ax.set_xticklabels(["0%", "25%", "50%", "75%", "100%"])
    ax.set_title("B.2 phase 1 Scammer: per-category bypass of rule-based defense\n"
                 "(blue tick labels = training categories, purple = held-out novel)",
                 fontsize=11, pad=10)
    ax.legend(loc="lower right", framealpha=0.95, fontsize=10)
    ax.grid(axis="x", alpha=0.3, linewidth=0.5)
    ax.set_axisbelow(True)
    # Row 0 at the top, so the sort order reads top-to-bottom.
    ax.invert_yaxis()

    out = PLOTS / "scammer_phase1_per_category.png"
    fig.tight_layout()
    fig.savefig(out, dpi=140, bbox_inches="tight")
    plt.close(fig)
    print(f"[ok] {out.name}")
    return out
def main() -> None:
    """Run every plot generator in order and list the PNGs that were written.

    Generators that skip themselves return ``None`` and are left out of the
    final listing. Paths are printed relative to the repository root.
    """
    print(f"Reading from: {LOGS}")
    print(f"Writing to: {PLOTS}\n")

    plotters = (
        plot_coevolution_headline,
        plot_per_category,
        plot_score_movement,
        plot_training_curve,
        plot_scammer_phase1_summary,
    )
    # The inner generator calls each plotter lazily, in declaration order.
    written = [path for path in (make() for make in plotters) if path]

    print(f"\nGenerated {len(written)} plot(s):")
    for path in written:
        print(f" {path.relative_to(REPO)}")


if __name__ == "__main__":
    main()