"""Reward Curve Visualization — Auto-generate training plots.

Judges NEED to see reward curves.

Features:
- Per-episode reward with rolling average
- Trend line (slope shows learning rate)
- Phase transitions marked with vertical lines
- Milestone achievement annotations
- Component-level breakdown sub-plots
- Auto-saves PNG to training output directory

Usage:
    from training.reward_plotter import plot_reward_curves, log_episode_reward

    # During training:
    log_episode_reward(csv_path, episode=1, total_reward=0.42, breakdown={...})

    # After training:
    plot_reward_curves("outputs/reward_log.csv", "outputs/reward_plot.png")
"""

from __future__ import annotations

import csv
import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

# Column order of the reward log. The readers below index rows positionally,
# so this order must stay in sync with them.
_CSV_HEADER = [
    "episode", "total_reward", "tp_rate", "fp_rate", "fn_rate",
    "exp_accuracy", "terminal_bonus", "milestones", "phase",
    "task_id", "timestamp", "breakdown_json",
]

# Breakdown keys shown as rows of the component heatmap.
_HEATMAP_COMPONENTS = [
    "true_positive_catch",
    "explanation_accuracy",
    "correct_redirect",
    "audit_trail_quality",
    "incident_efficiency",
    "false_positive_penalty",
    "false_negative_penalty",
]


def log_episode_reward(
    csv_path: str | Path,
    episode: int,
    total_reward: float,
    tp_rate: float = 0.0,
    fp_rate: float = 0.0,
    fn_rate: float = 0.0,
    exp_accuracy: float = 0.0,
    terminal_bonus: float = 0.0,
    milestones: int = 0,
    phase: int = 1,
    task_id: str = "basic_oversight",
    breakdown: Optional[Dict[str, Any]] = None,
) -> None:
    """Append one episode reward to the CSV log.

    This is called after each GRPO episode to build the reward curve data.
    Parent directories are created on demand, and the header row is written
    when the file is missing or empty.

    Args:
        csv_path: Destination CSV file.
        episode: Episode number.
        total_reward: Total episode reward (stored rounded to 4 decimals).
        tp_rate: True-positive rate (rounded to 4 decimals).
        fp_rate: False-positive rate (rounded to 4 decimals).
        fn_rate: False-negative rate (rounded to 4 decimals).
        exp_accuracy: Explanation accuracy (rounded to 4 decimals).
        terminal_bonus: Terminal bonus (rounded to 4 decimals).
        milestones: Milestone count for the episode.
        phase: Training phase identifier.
        task_id: Task identifier stored alongside the episode.
        breakdown: Optional per-component reward mapping; stored as JSON,
            or as an empty cell when missing or empty.
    """
    csv_path = Path(csv_path)
    csv_path.parent.mkdir(parents=True, exist_ok=True)

    write_header = not csv_path.exists() or csv_path.stat().st_size == 0
    # Explicit encoding keeps the log portable across platform locales.
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(_CSV_HEADER)
        writer.writerow([
            episode,
            round(total_reward, 4),
            round(tp_rate, 4),
            round(fp_rate, 4),
            round(fn_rate, 4),
            round(exp_accuracy, 4),
            round(terminal_bonus, 4),
            milestones,
            phase,
            task_id,
            datetime.now().isoformat(),
            json.dumps(breakdown) if breakdown else "",
        ])


def _rolling_average(values: List[float], window: int) -> List[float]:
    """Return the trailing mean of *values* over at most *window* items."""
    window = max(1, window)
    averages = []
    for i in range(len(values)):
        # The window ends at i and holds at most `window` items. Starting at
        # i - window + 1 keeps the item count equal to the divisor.
        chunk = values[max(0, i - window + 1):i + 1]
        averages.append(sum(chunk) / len(chunk))
    return averages


def _read_reward_rows(csv_path: Path) -> List[tuple]:
    """Parse the numeric columns of the reward log, skipping unusable rows."""
    parsed = []
    # errors="replace": only numeric columns are parsed here, so a stray
    # non-UTF-8 byte in a text column must not abort the whole plot.
    with open(csv_path, newline="", encoding="utf-8", errors="replace") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header; an empty file simply yields no rows
        for line_no, row in enumerate(reader, start=2):
            if len(row) < 9:
                continue
            try:
                parsed.append((
                    int(row[0]),
                    float(row[1]),
                    float(row[2]),
                    float(row[3]),
                    float(row[4]),
                    float(row[5]),
                    float(row[6]),
                    int(row[7]),
                    int(row[8]),
                ))
            except ValueError:
                logger.warning(
                    "Skipping malformed row %d in %s", line_no, csv_path
                )
    return parsed


def plot_reward_curves(
    csv_path: str | Path,
    out_path: Optional[str | Path] = None,
    title: str = "SENTINEL Oversight Agent — GRPO Training",
) -> Optional[str]:
    """Generate reward curve plots from training CSV log.

    Draws three stacked panels: total reward (raw, rolling average, linear
    trend and phase-change markers), TP/FP/FN rates, and terminal bonus
    with milestone count.

    Args:
        csv_path: Reward log written by ``log_episode_reward``.
        out_path: Where to save the PNG; defaults to ``csv_path`` with a
            ``.png`` suffix.
        title: Figure title.

    Returns:
        The path to the saved plot, or None if plotting failed
        (matplotlib/numpy missing, log missing, or no usable episodes).
    """
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        import numpy as np
    except ImportError:
        logger.warning("matplotlib/numpy not available — skipping reward plot")
        return None

    csv_path = Path(csv_path)
    if not csv_path.exists():
        logger.warning("No reward log at %s", csv_path)
        return None

    rows = _read_reward_rows(csv_path)
    if not rows:
        logger.warning("No episodes in %s", csv_path)
        return None

    # exp_accuracy is parsed (and validated) but not plotted.
    (episodes, totals, tp_rates, fp_rates, fn_rates, _exp_accuracies,
     terminal_bonuses, milestones_list, phases) = (
        list(column) for column in zip(*rows)
    )

    window = min(10, len(episodes))
    rolling = _rolling_average(totals, window)
    best_total = max(totals)

    # gridspec_kw instead of the height_ratios= shortcut, which only exists
    # in matplotlib >= 3.6.
    fig, (ax1, ax2, ax3) = plt.subplots(
        3, 1, figsize=(14, 12), gridspec_kw={"height_ratios": [3, 2, 2]}
    )
    try:
        fig.suptitle(title, fontsize=16, fontweight="bold", y=0.98)

        # --- Plot 1: Total Reward Curve ---
        ax1.plot(episodes, totals, alpha=0.25, color="#6366f1", marker="o",
                 markersize=3, label="Per episode")
        ax1.plot(episodes, rolling, color="#6366f1", linewidth=2.5,
                 label=f"Rolling avg ({window})")

        # Trend line: a linear fit needs at least two distinct x values.
        if len(set(episodes)) >= 2:
            z = np.polyfit(episodes, totals, 1)
            trend = np.poly1d(z)
            direction = "↑" if z[0] > 0 else "↓"
            ax1.plot(episodes, trend(np.asarray(episodes)), color="#ef4444",
                     linewidth=1.5, linestyle="--",
                     label=f"Trend ({direction} {abs(z[0]):.4f}/ep)")

        # Phase transitions: mark each episode whose phase differs from
        # the previous row's phase.
        for prev_phase, cur_phase, ep in zip(phases, phases[1:], episodes[1:]):
            if cur_phase != prev_phase:
                ax1.axvline(x=ep, color="#f59e0b", linestyle="--",
                            alpha=0.7, linewidth=1.5)
                ax1.text(ep, best_total * 0.95, f"Phase {cur_phase}",
                         rotation=90, fontsize=8, color="#f59e0b", ha="right")

        ax1.set_ylabel("Total Reward")
        ax1.set_title("Oversight Quality Over Training")
        ax1.legend(loc="lower right")
        ax1.grid(True, alpha=0.3)
        ax1.axhline(y=0, color="gray", linestyle="--", alpha=0.3)

        # Stats annotation
        mean_all = sum(totals) / len(totals)
        last10 = totals[-10:]
        mean_last10 = sum(last10) / len(last10)
        ax1.text(0.02, 0.02,
                 f"Episodes: {len(episodes)} | Mean: {mean_all:.3f} | "
                 f"Last-10 avg: {mean_last10:.3f} | Best: {best_total:.3f}",
                 transform=ax1.transAxes, fontsize=9,
                 verticalalignment="bottom",
                 bbox=dict(boxstyle="round", facecolor="#1e1e2e",
                           edgecolor="#6366f1", alpha=0.8),
                 color="white")

        # --- Plot 2: Detection Quality ---
        rate_series = (
            (tp_rates, "#10b981", "TP Rate (detection)"),
            (fp_rates, "#ef4444", "FP Rate (over-blocking)"),
            (fn_rates, "#f59e0b", "FN Rate (missed)"),
        )
        for rates, color, label in rate_series:
            ax2.plot(episodes, rates, color=color, linewidth=1.5, alpha=0.7,
                     label=label)
            ax2.plot(episodes, _rolling_average(rates, window), color=color,
                     linewidth=2.5)
        ax2.set_ylabel("Rate")
        ax2.set_title("Detection Quality: TP vs FP vs FN")
        ax2.legend(loc="center right")
        ax2.grid(True, alpha=0.3)
        ax2.set_ylim(-0.05, 1.05)

        # --- Plot 3: Terminal Bonus + Milestones ---
        ax3.bar(episodes, terminal_bonuses, alpha=0.4, color="#a855f7",
                label="Terminal Bonus")
        ax3_twin = ax3.twinx()
        ax3_twin.plot(episodes, milestones_list, color="#ec4899", linewidth=2,
                      marker="s", markersize=3, label="Milestones (of 8)")
        ax3_twin.set_ylabel("Milestones Achieved", color="#ec4899")
        ax3_twin.set_ylim(-0.5, 8.5)
        ax3_twin.tick_params(axis="y", labelcolor="#ec4899")
        ax3.set_xlabel("Episode")
        ax3.set_ylabel("Terminal Bonus")
        ax3.set_title("Terminal Reward & Milestone Progression")
        ax3.legend(loc="upper left")
        ax3_twin.legend(loc="upper right")
        ax3.grid(True, alpha=0.3)

        fig.tight_layout()

        save_path = Path(out_path) if out_path else csv_path.with_suffix(".png")
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight",
                    facecolor="#0a0a0f", edgecolor="none")
    finally:
        # Close this specific figure even if drawing/saving failed, so that
        # repeated calls during training do not accumulate open figures.
        plt.close(fig)

    logger.info("Reward plot saved to %s", save_path)
    return str(save_path)


def plot_component_breakdown(
    csv_path: str | Path,
    out_path: Optional[str | Path] = None,
) -> Optional[str]:
    """Generate a heatmap of reward component evolution.

    Each heatmap column is one logged episode that has a non-empty
    ``breakdown_json`` cell; each row is one reward component.

    Args:
        csv_path: Reward log written by ``log_episode_reward``.
        out_path: Where to save the PNG; defaults to
            ``component_heatmap.png`` next to ``csv_path``.

    Returns:
        The path to the saved heatmap, or None if matplotlib/numpy is
        missing, the log is missing, or no usable breakdowns were found.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        import numpy as np
    except ImportError:
        logger.warning(
            "matplotlib/numpy not available — skipping component heatmap"
        )
        return None

    csv_path = Path(csv_path)
    if not csv_path.exists():
        return None

    # Read breakdowns
    episodes = []
    breakdowns = []
    with open(csv_path, newline="", encoding="utf-8", errors="replace") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header; tolerate an empty file
        for line_no, row in enumerate(reader, start=2):
            if len(row) < 12 or not row[11]:
                continue
            try:
                episode = int(row[0])
                breakdown = json.loads(row[11])
            except ValueError:  # also covers json.JSONDecodeError
                logger.warning(
                    "Skipping malformed row %d in %s", line_no, csv_path
                )
                continue
            if not isinstance(breakdown, dict):
                continue
            episodes.append(episode)
            breakdowns.append(breakdown)

    if not breakdowns:
        return None

    # Extract component values; missing or non-numeric entries count as 0.
    components = _HEATMAP_COMPONENTS
    data = np.zeros((len(components), len(breakdowns)))
    for j, bd in enumerate(breakdowns):
        for i, comp in enumerate(components):
            value = bd.get(comp, 0.0)
            data[i, j] = value if isinstance(value, (int, float)) else 0.0

    fig, ax = plt.subplots(figsize=(14, 6))
    try:
        im = ax.imshow(data, aspect="auto", cmap="RdYlGn", vmin=-0.3, vmax=1.0)
        ax.set_yticks(range(len(components)))
        ax.set_yticklabels([c.replace("_", " ").title() for c in components])

        # Label (at most 10) columns with their real episode numbers;
        # imshow would otherwise show bare column indices under "Episode".
        tick_idx = np.linspace(
            0, len(episodes) - 1, num=min(len(episodes), 10), dtype=int
        )
        ax.set_xticks(tick_idx)
        ax.set_xticklabels([episodes[k] for k in tick_idx])
        ax.set_xlabel("Episode")

        # NOTE(review): the title says "10-Component" but only the 7
        # components listed in _HEATMAP_COMPONENTS are plotted — confirm
        # the intended wording.
        ax.set_title("Reward Component Evolution — 10-Component Breakdown")
        fig.colorbar(im, ax=ax, label="Component Score")
        fig.tight_layout()

        save_path = (
            Path(out_path) if out_path
            else csv_path.with_name("component_heatmap.png")
        )
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=150)
    finally:
        plt.close(fig)

    return str(save_path)