Spaces:
Running
Running
| """ | |
| merge_curves.py — Merge checkpoint data from all 3 tasks into one reward_curve.png | |
| Place in repo root. Run after all 3 tasks have completed training. | |
| Usage: | |
| python merge_curves.py | |
| Output: | |
| reward_curve.png — 3-line plot, one per task | |
| """ | |
| import json | |
| import os | |
| import sys | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
# Directory scanned for per-task checkpoint JSON files.
CHECKPOINT_DIR = "./phase2_checkpoints"

# File the merged reward plot is written to.
OUTPUT_PATH = "reward_curve.png"

# Plot metadata keyed by task id; iteration order is the plotting order.
# "color" and "label" drive the curve colour and legend entry. "max_steps"
# is carried along but not read by the plotting code visible in this file.
TASKS = dict(
    single_crash=dict(
        color="#00ff9d",
        label="Task 1: Single Crash (Easy)",
        max_steps=8,
    ),
    cascading_failure=dict(
        color="#ffaa00",
        label="Task 2: Cascading Failure (Medium)",
        max_steps=12,
    ),
    silent_degradation=dict(
        color="#ff3b3b",
        label="Task 3: Silent Degradation (Hard)",
        max_steps=15,
    ),
)
def load_task_rewards(task_id, checkpoint_dir=None):
    """Load the reward list from a task's highest-episode checkpoint.

    Candidate checkpoints are the ``*.json`` files in the checkpoint
    directory whose names start with ``task_id``. The episode number is the
    integer following the last ``_ep`` in the file name; names without a
    parsable number count as episode 0. Ties are broken by file name so the
    choice does not depend on directory listing order.

    Args:
        task_id: Task identifier used as the file-name prefix.
        checkpoint_dir: Directory to scan. Defaults to ``CHECKPOINT_DIR``.

    Returns:
        The value stored under ``"rewards"`` in the selected checkpoint, or
        an empty list when the directory is missing, no checkpoint matches,
        the file cannot be read or parsed, or it holds no rewards.
    """
    if checkpoint_dir is None:
        checkpoint_dir = CHECKPOINT_DIR

    if not os.path.isdir(checkpoint_dir):
        print(f"[ERROR] Checkpoint dir not found: {checkpoint_dir}")
        return []

    files = [
        f for f in os.listdir(checkpoint_dir)
        if f.startswith(task_id) and f.endswith(".json")
    ]
    if not files:
        print(f"[WARN] No checkpoint found for task: {task_id}")
        return []

    def ep_num(fname):
        # Every candidate ends with ".json" (filtered above), so slice it off
        # and read whatever follows the last "_ep" marker.
        stem = fname[:-len(".json")]
        try:
            return int(stem.rsplit("_ep", 1)[-1])
        except ValueError:
            return 0

    # Pick the checkpoint with the highest episode number.
    latest = max(files, key=lambda fname: (ep_num(fname), fname))
    path = os.path.join(checkpoint_dir, latest)

    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        # A half-written checkpoint should not abort the whole merge.
        print(f"[WARN] Could not read checkpoint {path}: {exc}")
        return []

    if not isinstance(data, dict):
        print(f"[WARN] Unexpected checkpoint format in {path}")
        return []

    rewards = data.get("rewards")
    if rewards is None:
        rewards = []
    print(f"[OK] {task_id}: loaded {len(rewards)} episodes from {latest}")
    return rewards
def smooth(rewards, window=5):
    """Return the trailing rolling mean of ``rewards``.

    Element ``k`` of the result is the mean of the last ``window`` values
    ending at position ``k``; near the start, where fewer than ``window``
    values exist, the mean is taken over the values available so far.
    """
    # One trailing slice per position, ending just after that position.
    tails = (
        rewards[max(0, end - window):end]
        for end in range(1, len(rewards) + 1)
    )
    return [sum(tail) / len(tail) for tail in tails]
def print_stats(task_id, rewards):
    """Print the mean of the first and last 10 rewards and their difference.

    With fewer than 10 rewards both averages cover every reward available.
    Prints nothing when ``rewards`` is empty.

    Args:
        task_id: Label printed on the header line.
        rewards: Sequence of per-episode reward numbers.
    """
    if not rewards:
        return

    # Slicing already clamps to the sequence length, so no explicit min().
    first10 = rewards[:10]
    last10 = rewards[-10:]
    avg_first = sum(first10) / len(first10)
    avg_last = sum(last10) / len(last10)
    improvement = avg_last - avg_first

    print(f" {task_id}:")
    print(f" First 10 avg : {avg_first:+.3f}")
    print(f" Last 10 avg : {avg_last:+.3f}")
    # Use the "+" format flag like the lines above instead of a hand-built
    # sign prefix, which could render "+-0.000" for a negative zero.
    print(f" Improvement : {improvement:+.3f}")
def main():
    """Plot every task's reward history on one figure and save it.

    For each entry in ``TASKS`` the rewards returned by
    ``load_task_rewards`` are drawn as a faint raw line plus a bold
    5-episode rolling average, with markers on the first and last episode,
    and summary statistics are printed. The figure is written to
    ``OUTPUT_PATH``.

    Raises:
        SystemExit: With status 1 when no task yields any rewards.
    """
    print("\n=== merge_curves.py ===")
    print(f"Checkpoint dir : {CHECKPOINT_DIR}")
    print(f"Output : {OUTPUT_PATH}\n")

    # Dark background matching terminal aesthetic
    plt.style.use("dark_background")
    fig, ax = plt.subplots(figsize=(12, 6))
    fig.patch.set_facecolor("#0a0c0f")
    ax.set_facecolor("#0e1117")

    legend_patches = []
    for task_id, meta in TASKS.items():
        rewards = load_task_rewards(task_id)
        if not rewards:
            continue

        color = meta["color"]
        episodes = list(range(1, len(rewards) + 1))
        smoothed = smooth(rewards, window=5)

        # Raw line (faint)
        ax.plot(
            episodes, rewards,
            alpha=0.2,
            color=color,
            linewidth=0.8,
            zorder=2,
        )
        # Smoothed line (bold)
        ax.plot(
            episodes, smoothed,
            color=color,
            linewidth=2.5,
            zorder=3,
        )
        # Start/end markers
        ax.scatter([1], [rewards[0]], color=color, s=40, zorder=4, alpha=0.6)
        ax.scatter([len(rewards)], [rewards[-1]], color=color, s=60, zorder=4)

        legend_patches.append(mpatches.Patch(color=color, label=meta["label"]))
        print_stats(task_id, rewards)

    # One legend patch is added per plotted task, so an empty list means
    # nothing was loaded for any task.
    if not legend_patches:
        print("[ERROR] No checkpoints found in", CHECKPOINT_DIR)
        print(" Make sure train.py has run at least one task with --episodes > 0")
        plt.close(fig)
        sys.exit(1)

    # Zero line
    ax.axhline(y=0, color="#2a3545", linewidth=1, linestyle="--", zorder=1, alpha=0.8)
    ax.text(
        1, 0.01,
        "zero reward threshold",
        color="#2a3545",
        fontsize=9,
        va="bottom",
    )

    # Grid
    ax.grid(True, alpha=0.1, color="#2a3545")
    ax.set_axisbelow(True)

    # Labels
    ax.set_xlabel("Episode", fontsize=12, color="#6b7d8f", labelpad=8)
    ax.set_ylabel("Episode Reward", fontsize=12, color="#6b7d8f", labelpad=8)
    ax.set_title(
        "LogTriageEnv — GRPO Training Reward Improvement",
        fontsize=14,
        color="#e8f0f8",
        fontweight="bold",
        pad=16,
    )

    # Tick colors
    ax.tick_params(colors="#6b7d8f")
    for spine in ax.spines.values():
        spine.set_edgecolor("#1e2530")

    # Legend
    ax.legend(
        handles=legend_patches,
        loc="lower right",
        fontsize=10,
        facecolor="#0e1117",
        edgecolor="#1e2530",
        labelcolor="#c8d4e0",
    )

    # Annotation
    ax.annotate(
        "Higher reward = agent resolves incident faster with fewer wrong actions",
        xy=(0.02, 0.03),
        xycoords="axes fraction",
        fontsize=9,
        color="#6b7d8f",
        style="italic",
    )

    # Operate on the explicit figure handle rather than pyplot's implicit
    # "current figure" state.
    fig.tight_layout()
    fig.savefig(OUTPUT_PATH, dpi=150, bbox_inches="tight", facecolor="#0a0c0f")
    plt.close(fig)

    # Derive the hints from OUTPUT_PATH so they stay correct if it changes.
    print(f"\n[OK] Saved: {OUTPUT_PATH}")
    print(f" Open with: start {OUTPUT_PATH}")
    print(
        f" Push with: git add {OUTPUT_PATH}"
        " && git commit -m 'feat: 3-task reward curve' && git push"
    )


# Run the merge only when executed as a script, not when imported.
if __name__ == "__main__":
    main()