# logtriage-env / merge_curves.py
# Last commit by OGrohit: "Add train.py and merge_curves.py" (8dc2306)
"""
merge_curves.py — Merge checkpoint data from all 3 tasks into one reward_curve.png
Place in repo root. Run after all 3 tasks have completed training.
Usage:
python merge_curves.py
Output:
reward_curve.png — 3-line plot, one per task
"""
import json
import os
import sys
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
# Directory scanned for per-task checkpoint JSON files (produced by train.py).
CHECKPOINT_DIR = "./phase2_checkpoints"
# Destination of the merged figure, relative to the working directory.
OUTPUT_PATH = "reward_curve.png"

# Per-task plot settings, keyed by the checkpoint filename prefix. Insertion
# order is the plotting and legend order.
# NOTE(review): "max_steps" is not read anywhere in this script; presumably
# it mirrors the per-episode step cap used during training — confirm.
TASKS = {
    "single_crash": dict(
        color="#00ff9d",
        label="Task 1: Single Crash (Easy)",
        max_steps=8,
    ),
    "cascading_failure": dict(
        color="#ffaa00",
        label="Task 2: Cascading Failure (Medium)",
        max_steps=12,
    ),
    "silent_degradation": dict(
        color="#ff3b3b",
        label="Task 3: Silent Degradation (Hard)",
        max_steps=15,
    ),
}
def load_task_rewards(task_id, checkpoint_dir=None):
    """Load the reward history from the highest-episode checkpoint for a task.

    A checkpoint is any ``.json`` file in *checkpoint_dir* whose name starts
    with *task_id*. Its episode number is the integer after the last
    ``"_ep"`` in the name (``single_crash_ep40.json`` -> 40); names without a
    parseable number count as episode 0. Ties are broken by filename so the
    choice does not depend on ``os.listdir`` ordering.

    Args:
        task_id: Task key used as the checkpoint filename prefix.
        checkpoint_dir: Directory to scan. Defaults to ``CHECKPOINT_DIR``.

    Returns:
        The checkpoint's ``"rewards"`` list, or ``[]`` (after printing a
        message) if the directory or a matching file is missing, or if the
        chosen file is unreadable, not valid JSON, or not a JSON object.
    """
    if checkpoint_dir is None:
        checkpoint_dir = CHECKPOINT_DIR
    if not os.path.isdir(checkpoint_dir):
        print(f"[ERROR] Checkpoint dir not found: {checkpoint_dir}")
        return []
    files = [
        f for f in os.listdir(checkpoint_dir)
        if f.startswith(task_id) and f.endswith(".json")
    ]
    if not files:
        print(f"[WARN] No checkpoint found for task: {task_id}")
        return []

    def ep_num(fname):
        """Episode number parsed from ``..._ep<N>.json``; 0 if absent or invalid."""
        stem = fname[: -len(".json")]
        try:
            return int(stem.rsplit("_ep", 1)[1])
        except (IndexError, ValueError):
            return 0

    # Highest episode wins; filename breaks ties deterministically.
    latest = max(files, key=lambda fname: (ep_num(fname), fname))
    path = os.path.join(checkpoint_dir, latest)
    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as err:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError; one bad
        # checkpoint should skip this task, not abort the whole merge.
        print(f"[WARN] Could not read checkpoint {latest}: {err}")
        return []
    if not isinstance(data, dict):
        print(f"[WARN] Unexpected checkpoint format in {latest}: expected a JSON object")
        return []
    # Treat a missing or null "rewards" entry as no data.
    rewards = data.get("rewards") or []
    print(f"[OK] {task_id}: loaded {len(rewards)} episodes from {latest}")
    return rewards
def smooth(rewards, window=5):
    """Return the trailing moving average of *rewards*.

    Element ``i`` of the result is the mean of ``rewards[i - window + 1 : i + 1]``.
    The first ``window - 1`` points average over however many values exist so
    far, so the output always has the same length as the input.

    Args:
        rewards: Sliceable sequence of numbers.
        window: Number of trailing values to average; must be >= 1.

    Returns:
        A list of floats, ``len(rewards)`` long (empty for empty input).

    Raises:
        ValueError: If *window* is less than 1 (previously this surfaced as a
            ZeroDivisionError).
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    # Each window ends just before `end`; its length is min(end, window).
    return [
        sum(rewards[max(0, end - window):end]) / min(end, window)
        for end in range(1, len(rewards) + 1)
    ]
def print_stats(task_id, rewards):
    """Print mean reward over the first and last (up to) 10 episodes and their change.

    Prints nothing for an empty *rewards*. With fewer than 20 episodes the
    two windows overlap; with 10 or fewer they are identical.
    """
    if rewards:
        count = min(10, len(rewards))
        head_mean = sum(rewards[:count]) / count
        tail_mean = sum(rewards[-count:]) / count
        delta = tail_mean - head_mean
        # Explicit "+" for non-negative change; negatives already carry "-".
        prefix = "+" if delta >= 0 else ""
        report = (
            f" {task_id}:",
            f" First 10 avg : {head_mean:+.3f}",
            f" Last 10 avg : {tail_mean:+.3f}",
            f" Improvement : {prefix}{delta:.3f}",
        )
        print("\n".join(report))
def _plot_task(ax, rewards, color):
    """Draw one task's raw rewards, 5-episode moving average and start/end markers on *ax*."""
    episodes = list(range(1, len(rewards) + 1))
    smoothed = smooth(rewards, window=5)
    # Raw line (faint) so per-episode noise stays visible behind the trend.
    ax.plot(
        episodes, rewards,
        alpha=0.2,
        color=color,
        linewidth=0.8,
        zorder=2,
    )
    # Smoothed line (bold), drawn above the raw line.
    ax.plot(
        episodes, smoothed,
        color=color,
        linewidth=2.5,
        zorder=3,
    )
    # Start marker smaller/translucent, end marker larger/solid.
    ax.scatter([1], [rewards[0]], color=color, s=40, zorder=4, alpha=0.6)
    ax.scatter([len(rewards)], [rewards[-1]], color=color, s=60, zorder=4)


def _style_axes(ax, legend_patches):
    """Add the zero-reward guide, grid, labels, title, legend and footnote to *ax*."""
    # Zero line: dashed guide at y=0, labelled just above it at episode 1.
    ax.axhline(y=0, color="#2a3545", linewidth=1, linestyle="--", zorder=1, alpha=0.8)
    ax.text(
        1, 0.01,
        "zero reward threshold",
        color="#2a3545",
        fontsize=9,
        va="bottom",
    )
    # Faint grid, kept behind the plotted data.
    ax.grid(True, alpha=0.1, color="#2a3545")
    ax.set_axisbelow(True)
    # Labels
    ax.set_xlabel("Episode", fontsize=12, color="#6b7d8f", labelpad=8)
    ax.set_ylabel("Episode Reward", fontsize=12, color="#6b7d8f", labelpad=8)
    ax.set_title(
        "LogTriageEnv — GRPO Training Reward Improvement",
        fontsize=14,
        color="#e8f0f8",
        fontweight="bold",
        pad=16,
    )
    # Tick and spine colors
    ax.tick_params(colors="#6b7d8f")
    for spine in ax.spines.values():
        spine.set_edgecolor("#1e2530")
    # Legend built from one color patch per plotted task.
    ax.legend(
        handles=legend_patches,
        loc="lower right",
        fontsize=10,
        facecolor="#0e1117",
        edgecolor="#1e2530",
        labelcolor="#c8d4e0",
    )
    # Footnote in axes-fraction coordinates (bottom-left corner).
    ax.annotate(
        "Higher reward = agent resolves incident faster with fewer wrong actions",
        xy=(0.02, 0.03),
        xycoords="axes fraction",
        fontsize=9,
        color="#6b7d8f",
        style="italic",
    )


def main():
    """Merge each task's latest checkpoint rewards into one plot saved at OUTPUT_PATH.

    For every task in ``TASKS`` whose checkpoint yields rewards, plots its
    curves and prints first/last-10 averages. Exits with status 1 if no task
    produced any rewards; otherwise saves the figure and prints follow-up
    commands.
    """
    print("\n=== merge_curves.py ===")
    print(f"Checkpoint dir : {CHECKPOINT_DIR}")
    print(f"Output : {OUTPUT_PATH}\n")
    # Dark background matching terminal aesthetic
    plt.style.use("dark_background")
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        fig.patch.set_facecolor("#0a0c0f")
        ax.set_facecolor("#0e1117")
        legend_patches = []
        for task_id, meta in TASKS.items():
            rewards = load_task_rewards(task_id)
            if not rewards:
                continue
            _plot_task(ax, rewards, meta["color"])
            legend_patches.append(
                mpatches.Patch(color=meta["color"], label=meta["label"])
            )
            print_stats(task_id, rewards)
        # Each plotted task adds exactly one legend patch, so an empty list
        # means no task had data.
        if not legend_patches:
            print("[ERROR] No checkpoints found in", CHECKPOINT_DIR)
            print(" Make sure train.py has run at least one task with --episodes > 0")
            sys.exit(1)
        _style_axes(ax, legend_patches)
        fig.tight_layout()
        fig.savefig(OUTPUT_PATH, dpi=150, bbox_inches="tight", facecolor="#0a0c0f")
    finally:
        # Release this figure on every path, including the sys.exit() above.
        plt.close(fig)
    print(f"\n[OK] Saved: {OUTPUT_PATH}")
    # "start" is the Windows shell opener; use `open` (macOS) or `xdg-open` (Linux).
    print(f" Open with: start {OUTPUT_PATH}")
    print(f" Push with: git add {OUTPUT_PATH} && git commit -m 'feat: 3-task reward curve' && git push")
if __name__ == "__main__":
main()