openenv / training /reward_plotter.py
sentinel-space-publisher
space: publish latest Sentinel app snapshot
c452421
"""Reward Curve Visualization — Auto-generate training plots.
Judges NEED to see reward curves.
Features:
- Per-episode reward with rolling average
- Trend line (slope shows learning rate)
- Phase transitions marked with vertical lines
- Milestone achievement annotations
- Component-level breakdown sub-plots
- Auto-saves PNG to training output directory
Usage:
from training.reward_plotter import plot_reward_curves, log_episode_reward
# During training:
log_episode_reward(csv_path, episode=1, reward=0.42, breakdown={...})
# After training:
plot_reward_curves("outputs/reward_log.csv", "outputs/reward_plot.png")
"""
from __future__ import annotations
import csv
import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
def log_episode_reward(
csv_path: str | Path,
episode: int,
total_reward: float,
tp_rate: float = 0.0,
fp_rate: float = 0.0,
fn_rate: float = 0.0,
exp_accuracy: float = 0.0,
terminal_bonus: float = 0.0,
milestones: int = 0,
phase: int = 1,
task_id: str = "basic_oversight",
breakdown: Optional[Dict[str, Any]] = None,
) -> None:
"""Append one episode reward to the CSV log.
This is called after each GRPO episode to build the reward curve data.
"""
csv_path = Path(csv_path)
csv_path.parent.mkdir(parents=True, exist_ok=True)
write_header = not csv_path.exists() or csv_path.stat().st_size == 0
with open(csv_path, "a", newline="") as f:
writer = csv.writer(f)
if write_header:
writer.writerow([
"episode", "total_reward", "tp_rate", "fp_rate", "fn_rate",
"exp_accuracy", "terminal_bonus", "milestones", "phase",
"task_id", "timestamp", "breakdown_json",
])
writer.writerow([
episode,
round(total_reward, 4),
round(tp_rate, 4),
round(fp_rate, 4),
round(fn_rate, 4),
round(exp_accuracy, 4),
round(terminal_bonus, 4),
milestones,
phase,
task_id,
datetime.now().isoformat(),
json.dumps(breakdown) if breakdown else "",
])
def plot_reward_curves(
csv_path: str | Path,
out_path: Optional[str | Path] = None,
title: str = "SENTINEL Oversight Agent — GRPO Training",
) -> Optional[str]:
"""Generate reward curve plots from training CSV log.
Returns the path to the saved plot, or None if plotting failed.
"""
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
except ImportError:
logger.warning("matplotlib/numpy not available — skipping reward plot")
return None
csv_path = Path(csv_path)
if not csv_path.exists():
logger.warning("No reward log at %s", csv_path)
return None
# Read CSV
episodes, totals, tp_rates, fp_rates, fn_rates = [], [], [], [], []
exp_accuracies, terminal_bonuses, milestones_list, phases = [], [], [], []
with open(csv_path) as f:
reader = csv.reader(f)
header = next(reader)
for row in reader:
if len(row) < 9:
continue
episodes.append(int(row[0]))
totals.append(float(row[1]))
tp_rates.append(float(row[2]))
fp_rates.append(float(row[3]))
fn_rates.append(float(row[4]))
exp_accuracies.append(float(row[5]))
terminal_bonuses.append(float(row[6]))
milestones_list.append(int(row[7]))
phases.append(int(row[8]))
if not episodes:
logger.warning("No episodes in %s", csv_path)
return None
# Rolling average
window = min(10, len(episodes))
def rolling_avg(vals):
return [
sum(vals[max(0, i - window):i + 1]) / min(i + 1, window)
for i in range(len(vals))
]
rolling = rolling_avg(totals)
# Create figure with 3 subplots
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(14, 12), height_ratios=[3, 2, 2])
fig.suptitle(title, fontsize=16, fontweight="bold", y=0.98)
# --- Plot 1: Total Reward Curve ---
ax1.plot(episodes, totals, alpha=0.25, color="#6366f1", marker="o",
markersize=3, label="Per episode")
ax1.plot(episodes, rolling, color="#6366f1", linewidth=2.5,
label=f"Rolling avg ({window})")
# Trend line
z = np.polyfit(episodes, totals, 1)
trend = np.poly1d(z)
direction = "↑" if z[0] > 0 else "↓"
ax1.plot(episodes, trend(episodes), color="#ef4444", linewidth=1.5,
linestyle="--", label=f"Trend ({direction} {abs(z[0]):.4f}/ep)")
# Phase transitions
phase_changes = []
for i in range(1, len(phases)):
if phases[i] != phases[i - 1]:
phase_changes.append(episodes[i])
ax1.axvline(x=episodes[i], color="#f59e0b", linestyle="--",
alpha=0.7, linewidth=1.5)
ax1.text(episodes[i], max(totals) * 0.95,
f"Phase {phases[i]}",
rotation=90, fontsize=8, color="#f59e0b", ha="right")
ax1.set_ylabel("Total Reward")
ax1.set_title("Oversight Quality Over Training")
ax1.legend(loc="lower right")
ax1.grid(True, alpha=0.3)
ax1.axhline(y=0, color="gray", linestyle="--", alpha=0.3)
# Stats annotation
mean_all = sum(totals) / len(totals)
last10 = totals[-10:]
mean_last10 = sum(last10) / len(last10)
ax1.text(0.02, 0.02,
f"Episodes: {len(episodes)} | Mean: {mean_all:.3f} | "
f"Last-10 avg: {mean_last10:.3f} | Best: {max(totals):.3f}",
transform=ax1.transAxes, fontsize=9, verticalalignment="bottom",
bbox=dict(boxstyle="round", facecolor="#1e1e2e", edgecolor="#6366f1",
alpha=0.8),
color="white")
# --- Plot 2: Detection Quality ---
ax2.plot(episodes, tp_rates, color="#10b981", linewidth=1.5,
alpha=0.7, label="TP Rate (detection)")
ax2.plot(episodes, rolling_avg(tp_rates), color="#10b981", linewidth=2.5)
ax2.plot(episodes, fp_rates, color="#ef4444", linewidth=1.5,
alpha=0.7, label="FP Rate (over-blocking)")
ax2.plot(episodes, rolling_avg(fp_rates), color="#ef4444", linewidth=2.5)
ax2.plot(episodes, fn_rates, color="#f59e0b", linewidth=1.5,
alpha=0.7, label="FN Rate (missed)")
ax2.plot(episodes, rolling_avg(fn_rates), color="#f59e0b", linewidth=2.5)
ax2.set_ylabel("Rate")
ax2.set_title("Detection Quality: TP vs FP vs FN")
ax2.legend(loc="center right")
ax2.grid(True, alpha=0.3)
ax2.set_ylim(-0.05, 1.05)
# --- Plot 3: Terminal Bonus + Milestones ---
ax3.bar(episodes, terminal_bonuses, alpha=0.4, color="#a855f7",
label="Terminal Bonus")
ax3_twin = ax3.twinx()
ax3_twin.plot(episodes, milestones_list, color="#ec4899", linewidth=2,
marker="s", markersize=3, label="Milestones (of 8)")
ax3_twin.set_ylabel("Milestones Achieved", color="#ec4899")
ax3_twin.set_ylim(-0.5, 8.5)
ax3_twin.tick_params(axis="y", labelcolor="#ec4899")
ax3.set_xlabel("Episode")
ax3.set_ylabel("Terminal Bonus")
ax3.set_title("Terminal Reward & Milestone Progression")
ax3.legend(loc="upper left")
ax3_twin.legend(loc="upper right")
ax3.grid(True, alpha=0.3)
plt.tight_layout()
save_path = Path(out_path) if out_path else csv_path.with_suffix(".png")
save_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(save_path, dpi=150, bbox_inches="tight",
facecolor="#0a0a0f", edgecolor="none")
plt.close()
logger.info("Reward plot saved to %s", save_path)
return str(save_path)
def plot_component_breakdown(
csv_path: str | Path,
out_path: Optional[str | Path] = None,
) -> Optional[str]:
"""Generate a heatmap of reward component evolution."""
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
except ImportError:
return None
csv_path = Path(csv_path)
if not csv_path.exists():
return None
# Read breakdowns
episodes = []
breakdowns = []
with open(csv_path) as f:
reader = csv.reader(f)
next(reader) # skip header
for row in reader:
if len(row) < 12 or not row[11]:
continue
episodes.append(int(row[0]))
breakdowns.append(json.loads(row[11]))
if not breakdowns:
return None
# Extract component values
components = [
"true_positive_catch", "explanation_accuracy", "correct_redirect",
"audit_trail_quality", "incident_efficiency",
"false_positive_penalty", "false_negative_penalty",
]
data = np.zeros((len(components), len(breakdowns)))
for j, bd in enumerate(breakdowns):
for i, comp in enumerate(components):
data[i, j] = bd.get(comp, 0.0)
fig, ax = plt.subplots(figsize=(14, 6))
im = ax.imshow(data, aspect="auto", cmap="RdYlGn", vmin=-0.3, vmax=1.0)
ax.set_yticks(range(len(components)))
ax.set_yticklabels([c.replace("_", " ").title() for c in components])
ax.set_xlabel("Episode")
ax.set_title("Reward Component Evolution — 10-Component Breakdown")
plt.colorbar(im, ax=ax, label="Component Score")
plt.tight_layout()
save_path = Path(out_path) if out_path else csv_path.with_name("component_heatmap.png")
plt.savefig(save_path, dpi=150)
plt.close()
return str(save_path)