Spaces:
Sleeping
Sleeping
| """Training-progress evidence logging for CERNenv. | |
| Captures three classes of evidence required by the OpenEnv hackathon's | |
| "Showing Improvement in Rewards" judging criterion: | |
| 1. **Per-step training log** — every GRPO logging step records reward, | |
| loss, KL (Kullback-Leibler divergence), gradient norm and learning rate | |
| into ``evidence/training_log.csv``. A live-updating PNG curve is | |
| regenerated each time the log is appended. | |
| 2. **Mid-training checkpoint evaluations** — every ``eval_every_steps`` | |
| GRPO updates we re-evaluate the agent on a held-out task suite and | |
| append a row to ``evidence/checkpoint_evals.csv`` (training_step, | |
| mean_reward, success_rate, mass_acc, channel_acc). This produces the | |
| "progression" plot showing rewards rising over training. | |
| 3. **Before/after summary** — pre- and post-training evaluation JSONLs | |
| are turned into bar charts and reward distributions, plus a | |
| machine-readable ``evidence/before_after_metrics.json``. | |
| Everything ends up under ``evidence/`` so the trainer Space can serve | |
| the artifacts directly and ``scripts.push_to_hub`` can upload them | |
| with the model. | |
| """ | |
| from __future__ import annotations | |
| import csv | |
| import json | |
| import logging | |
| import os | |
| import threading | |
| from dataclasses import asdict, dataclass, field | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Sequence | |
| logger = logging.getLogger(__name__) | |
| # ── Paths ──────────────────────────────────────────────────────────────── | |
@dataclass
class EvidencePaths:
    """All evidence artifact paths for a training run.

    Only ``root`` is supplied by the caller. Every other attribute is
    excluded from ``__init__`` and derived from ``root`` in
    ``__post_init__``, so all artifacts live in a single directory.
    """

    # Directory holding every artifact; coerced to ``Path`` in __post_init__.
    root: Path
    training_log_csv: Path = field(init=False)
    checkpoint_evals_csv: Path = field(init=False)
    training_curve_png: Path = field(init=False)
    checkpoint_progression_png: Path = field(init=False)
    before_after_summary_png: Path = field(init=False)
    reward_distribution_png: Path = field(init=False)
    before_after_metrics_json: Path = field(init=False)
    sample_trajectories_md: Path = field(init=False)
    pre_eval_jsonl: Path = field(init=False)
    post_eval_jsonl: Path = field(init=False)

    def __post_init__(self) -> None:
        """Normalise ``root`` to a ``Path`` and derive each artifact path."""
        self.root = Path(self.root)
        self.training_log_csv = self.root / "training_log.csv"
        self.checkpoint_evals_csv = self.root / "checkpoint_evals.csv"
        self.training_curve_png = self.root / "training_curve.png"
        self.checkpoint_progression_png = self.root / "checkpoint_progression.png"
        self.before_after_summary_png = self.root / "before_after_summary.png"
        self.reward_distribution_png = self.root / "reward_distribution.png"
        self.before_after_metrics_json = self.root / "before_after_metrics.json"
        self.sample_trajectories_md = self.root / "sample_trajectories.md"
        self.pre_eval_jsonl = self.root / "pre_eval.jsonl"
        self.post_eval_jsonl = self.root / "post_eval.jsonl"

    def ensure(self) -> None:
        """Create the root directory (and missing parents) if it is absent."""
        self.root.mkdir(parents=True, exist_ok=True)
| # ── Per-step training log + curve ──────────────────────────────────────── | |
# Column order of the per-step training log CSV.
_LOG_FIELDS = [
    "step",
    "epoch",
    "loss",
    "reward",
    "reward_std",
    "kl",
    "grad_norm",
    "learning_rate",
    "wall_time_s",
]


class TrainingLogWriter:
    """Append-only CSV writer for per-step GRPO metrics.

    The file always uses the columns listed in ``_LOG_FIELDS``. A header row
    is written only when the file does not exist yet; an existing file is
    left as-is and new rows are appended after its current contents.
    """

    def __init__(self, path: Path) -> None:
        """Create the parent directory and, for a new file, write the header."""
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        # Serialises appends issued through this writer instance.
        self._lock = threading.Lock()
        if self.path.exists():
            return
        with open(self.path, "w", newline="") as handle:
            header_writer = csv.DictWriter(handle, fieldnames=_LOG_FIELDS)
            header_writer.writeheader()

    def append(self, row: Dict[str, Any]) -> None:
        """Append one row; unknown keys are dropped, absent columns stay empty."""
        with self._lock, open(self.path, "a", newline="") as handle:
            record = {}
            for column in _LOG_FIELDS:
                record[column] = row.get(column, "")
            csv.DictWriter(handle, fieldnames=_LOG_FIELDS).writerow(record)
def _try_import_matplotlib():
    """Return ``matplotlib.pyplot`` with the Agg backend selected, or ``None``.

    Any exception raised while importing or configuring matplotlib is logged
    as a warning and turned into ``None`` so callers can skip plotting.
    """
    try:
        import matplotlib  # type: ignore

        # Select the backend before pyplot is imported.
        matplotlib.use("Agg")
        from matplotlib import pyplot  # type: ignore
    except Exception as err:  # pragma: no cover
        logger.warning("matplotlib unavailable, skipping plot: %s", err)
        return None
    return pyplot
def render_training_curve(csv_path: Path, png_path: Path) -> Optional[Path]:
    """Render a 2-panel reward / loss curve from the training log CSV.

    Each panel plots only the rows where both ``step`` and the panel's metric
    are present, so x and y always have equal length and rows that carry only
    one of the two metrics do not break the other metric's line.

    Args:
        csv_path: Training log CSV with a ``step`` column and ``reward``
            and/or ``loss`` columns.
        png_path: Destination image; parent directories are created.

    Returns:
        ``png_path`` on success, or ``None`` when matplotlib is unavailable,
        the CSV does not exist, or no row could be parsed.
    """
    plt = _try_import_matplotlib()
    if plt is None or not csv_path.exists():
        return None

    rows: List[Dict[str, Any]] = []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            try:
                rows.append(
                    {k: (float(v) if v not in (None, "") else None) for k, v in row.items()}
                )
            except (TypeError, ValueError):
                # ValueError: a non-numeric cell. TypeError: the row has more
                # cells than the header, which DictReader stores as a list.
                continue
    if not rows:
        return None

    def _series(column: str):
        """Return aligned (steps, values) for rows having both step and column."""
        pairs = [
            (r["step"], r[column])
            for r in rows
            if r.get("step") is not None and r.get(column) is not None
        ]
        return [p[0] for p in pairs], [p[1] for p in pairs]

    fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
    try:
        reward_steps, rewards = _series("reward")
        if rewards:
            axes[0].plot(reward_steps, rewards, lw=1.6, color="#1d4ed8")
        axes[0].set_ylabel("mean reward")
        axes[0].set_title("CERNenv GRPO training — reward over steps")
        axes[0].grid(alpha=0.25)

        loss_steps, losses = _series("loss")
        if losses:
            axes[1].plot(loss_steps, losses, lw=1.6, color="#c026d3")
        axes[1].set_ylabel("GRPO loss")
        axes[1].set_xlabel("training step")
        axes[1].grid(alpha=0.25)

        fig.tight_layout()
        png_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(png_path, dpi=140)
    finally:
        # Release the figure even when layout or saving fails.
        plt.close(fig)
    return png_path
| # ── Mid-training checkpoint evaluations ────────────────────────────────── | |
# Column order of the mid-training checkpoint evaluation CSV.
_CHECKPOINT_FIELDS = [
    "step",
    "fraction_done",
    "episodes",
    "mean_reward",
    "success_rate",
    "mass_acc",
    "channel_acc",
]


class CheckpointEvalWriter:
    """Append-only CSV writer for periodic mid-training evaluations.

    The file always uses the columns listed in ``_CHECKPOINT_FIELDS``. A
    header row is written only when the file does not exist yet.
    """

    def __init__(self, path: Path) -> None:
        """Create the parent directory and, for a new file, write the header."""
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        # Serialises appends issued through this writer instance.
        self._lock = threading.Lock()
        if self.path.exists():
            return
        with open(self.path, "w", newline="") as handle:
            header_writer = csv.DictWriter(handle, fieldnames=_CHECKPOINT_FIELDS)
            header_writer.writeheader()

    def append(self, **row: Any) -> None:
        """Append one evaluation row given as keyword arguments.

        Keywords outside ``_CHECKPOINT_FIELDS`` are dropped; columns that are
        not supplied are written as empty cells.
        """
        with self._lock, open(self.path, "a", newline="") as handle:
            record = {}
            for column in _CHECKPOINT_FIELDS:
                record[column] = row.get(column, "")
            csv.DictWriter(handle, fieldnames=_CHECKPOINT_FIELDS).writerow(record)
def render_checkpoint_progression(csv_path: Path, png_path: Path) -> Optional[Path]:
    """Plot mid-training evaluation metrics against training step.

    The top panel shows mean reward. The bottom panel shows success rate,
    mass accuracy and channel accuracy with the y-axis fixed to
    ``[-0.02, 1.02]``.

    Returns ``png_path`` after saving, or ``None`` when matplotlib is
    unavailable, the CSV does not exist, or no row parses as numbers.
    """
    plt = _try_import_matplotlib()
    if plt is None:
        return None
    if not csv_path.exists():
        return None

    def _as_number(cell):
        # Empty / missing cells become None; everything else must be a float.
        return None if cell in (None, "") else float(cell)

    records = []
    with open(csv_path) as handle:
        for raw in csv.DictReader(handle):
            try:
                parsed = {name: _as_number(cell) for name, cell in raw.items()}
            except ValueError:
                # Skip rows that contain a non-numeric cell.
                continue
            records.append(parsed)
    if not records:
        return None

    xs = [rec["step"] for rec in records]

    def _column(name):
        return [rec.get(name) for rec in records]

    fig, (reward_ax, rate_ax) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)

    reward_ax.plot(xs, _column("mean_reward"), "o-", color="#1d4ed8", label="mean reward")
    reward_ax.set_ylabel("mean episode reward")
    reward_ax.set_title("CERNenv mid-training evaluation — progression")
    reward_ax.grid(alpha=0.25)
    reward_ax.legend(loc="lower right")

    # (column, line format, colour, legend label) for the bottom panel.
    rate_series = (
        ("success_rate", "o-", "#16a34a", "discovery success rate"),
        ("mass_acc", "s--", "#9333ea", "mass accuracy"),
        ("channel_acc", "^--", "#ea580c", "channel accuracy"),
    )
    for column, fmt, colour, legend_label in rate_series:
        rate_ax.plot(xs, _column(column), fmt, color=colour, label=legend_label)
    rate_ax.set_ylabel("rate")
    rate_ax.set_xlabel("training step")
    rate_ax.set_ylim(-0.02, 1.02)
    rate_ax.grid(alpha=0.25)
    rate_ax.legend(loc="lower right")

    fig.tight_layout()
    png_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(png_path, dpi=140)
    plt.close(fig)
    return png_path
| # ── Before/after summary ──────────────────────────────────────────────── | |
| def _load_jsonl(path: Path) -> List[Dict[str, Any]]: | |
| if not path.exists(): | |
| return [] | |
| out = [] | |
| with open(path) as f: | |
| for line in f: | |
| line = line.strip() | |
| if line: | |
| try: | |
| out.append(json.loads(line)) | |
| except json.JSONDecodeError: | |
| continue | |
| return out | |
| def _summarise_episodes(eps: Sequence[Dict[str, Any]]) -> Dict[str, float]: | |
| if not eps: | |
| return {"n": 0, "mean_reward": 0.0, "median_reward": 0.0, | |
| "success_rate": 0.0, "mass_acc": 0.0, "channel_acc": 0.0} | |
| rewards = sorted(float(e.get("cumulative_reward") or 0.0) for e in eps) | |
| mid = rewards[len(rewards) // 2] | |
| return { | |
| "n": len(eps), | |
| "mean_reward": sum(rewards) / len(rewards), | |
| "median_reward": mid, | |
| "success_rate": sum(1 for e in eps if e.get("discovered")) / len(eps), | |
| "mass_acc": sum(1 for e in eps if e.get("correct_mass")) / len(eps), | |
| "channel_acc": sum(1 for e in eps if e.get("correct_channel")) / len(eps), | |
| } | |
def render_before_after(
    *,
    pre_jsonl: Path,
    post_jsonl: Path,
    summary_png: Path,
    distribution_png: Path,
    metrics_json: Path,
) -> Dict[str, Any]:
    """Compare pre- and post-training evaluations and write the artifacts.

    ``metrics_json`` is always written and holds ``pre``, ``post`` and
    ``delta`` sections. When matplotlib can be imported, a grouped bar chart
    is saved to ``summary_png`` and overlaid reward histograms are saved to
    ``distribution_png``.

    Returns:
        The payload that was written to ``metrics_json``.
    """
    pre_eps = _load_jsonl(pre_jsonl)
    post_eps = _load_jsonl(post_jsonl)
    pre_stats = _summarise_episodes(pre_eps)
    post_stats = _summarise_episodes(post_eps)

    delta = {}
    for key in ("mean_reward", "median_reward", "success_rate", "mass_acc", "channel_acc"):
        delta[key] = post_stats[key] - pre_stats[key]
    payload = {"pre": pre_stats, "post": post_stats, "delta": delta}

    metrics_json.parent.mkdir(parents=True, exist_ok=True)
    metrics_json.write_text(json.dumps(payload, indent=2))

    plt = _try_import_matplotlib()
    if plt is None:
        # The JSON metrics are already on disk; only the plots are skipped.
        return payload

    # (legend tag, summary stats, raw episodes, colour) for each side.
    sides = (
        ("pre", pre_stats, pre_eps, "#94a3b8"),
        ("post", post_stats, post_eps, "#1d4ed8"),
    )

    # ── Figure 1: grouped bars, one pair per metric ──
    metrics = ["mean_reward", "success_rate", "mass_acc", "channel_acc"]
    positions = list(range(len(metrics)))
    width = 0.36
    fig, ax = plt.subplots(figsize=(8, 4.5))
    for (tag, stats, _episodes, colour), shift in zip(sides, (-width / 2, width / 2)):
        ax.bar(
            [pos + shift for pos in positions],
            [stats[name] for name in metrics],
            width=width,
            label=f"{tag} (n={stats['n']})",
            color=colour,
        )
    ax.set_xticks(positions)
    ax.set_xticklabels(["mean reward", "discovery rate", "mass acc.", "channel acc."])
    ax.set_title("CERNenv before vs after GRPO training")
    ax.legend()
    for pos, name in enumerate(metrics):
        # Signed change, placed just above the taller bar of the pair.
        taller = max(pre_stats[name], post_stats[name])
        ax.annotate(
            f"{delta[name]:+.2f}",
            xy=(pos, taller),
            xytext=(0, 4),
            textcoords="offset points",
            ha="center",
            fontsize=9,
            color="#0f172a",
        )
    fig.tight_layout()
    summary_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(summary_png, dpi=140)
    plt.close(fig)

    # ── Figure 2: overlaid reward histograms ──
    fig, ax = plt.subplots(figsize=(8, 4.5))
    for tag, stats, episodes, colour in sides:
        rewards = [float(ep.get("cumulative_reward") or 0.0) for ep in episodes]
        if not rewards:
            continue
        ax.hist(
            rewards,
            bins=15,
            alpha=0.55,
            label=f"{tag} (μ={stats['mean_reward']:+.2f})",
            color=colour,
        )
    ax.set_xlabel("episode cumulative reward")
    ax.set_ylabel("episode count")
    ax.set_title("Reward distribution: pre vs post training")
    ax.legend()
    fig.tight_layout()
    distribution_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(distribution_png, dpi=140)
    plt.close(fig)

    return payload
def render_sample_trajectories(
    *,
    pre_jsonl: Path,
    post_jsonl: Path,
    md_path: Path,
    n_samples: int = 3,
) -> None:
    """Pick representative pre vs post episodes and dump a markdown comparison.

    The ``n_samples`` lowest-reward pre-training episodes and the
    ``n_samples`` highest-reward post-training episodes are listed, each with
    its outcome flags and up to eight steps.

    Args:
        pre_jsonl: Pre-training evaluation episodes (JSON Lines).
        post_jsonl: Post-training evaluation episodes (JSON Lines).
        md_path: Destination markdown file; parent directories are created.
        n_samples: Episodes to show per section; negative values count as 0.
    """
    limit = max(n_samples, 0)  # a negative slice bound would drop from the end

    def _reward(ep: Dict[str, Any]) -> float:
        return float(ep.get("cumulative_reward") or 0.0)

    # Ignore records that are valid JSON but not objects (numbers, lists, ...).
    pre = [ep for ep in _load_jsonl(pre_jsonl) if isinstance(ep, dict)]
    post = [ep for ep in _load_jsonl(post_jsonl) if isinstance(ep, dict)]
    pre_sorted = sorted(pre, key=_reward)[:limit]
    post_sorted = sorted(post, key=_reward, reverse=True)[:limit]

    def _fmt(ep: Dict[str, Any]) -> str:
        """Format one episode as a markdown bullet with nested step bullets."""
        # Use the first of "steps" / "trajectory" that is a non-empty list, so
        # a non-list value under "steps" cannot hide the list or break slicing.
        steps: Sequence[Any] = []
        for key in ("steps", "trajectory"):
            candidate = ep.get(key)
            if isinstance(candidate, (list, tuple)) and candidate:
                steps = candidate
                break
        lines = [
            f"- **reward**: `{ep.get('cumulative_reward')}` "
            f"**discovered**: `{ep.get('discovered')}` "
            f"**correct_mass**: `{ep.get('correct_mass')}` "
            f"**correct_channel**: `{ep.get('correct_channel')}`",
        ]
        for i, st in enumerate(steps[:8]):
            act = st.get("action") if isinstance(st, dict) else None
            r = st.get("reward") if isinstance(st, dict) else None
            if isinstance(act, dict):
                lines.append(f"  - step {i}: `{act.get('action_type')}` → reward `{r}`")
            else:
                lines.append(f"  - step {i}: {act} → reward `{r}`")
        if len(steps) > 8:
            lines.append(f"  - ... ({len(steps) - 8} more steps)")
        return "\n".join(lines)

    md = ["# CERNenv — sample trajectories (pre vs post training)\n"]
    md.append("## Worst pre-training episodes\n")
    md.extend(_fmt(ep) + "\n" for ep in pre_sorted)
    md.append("## Best post-training episodes\n")
    md.extend(_fmt(ep) + "\n" for ep in post_sorted)

    md_path.parent.mkdir(parents=True, exist_ok=True)
    # The text contains non-ASCII characters ("—", "→"); write UTF-8 explicitly
    # so the result does not depend on the platform's default encoding.
    md_path.write_text("\n".join(md), encoding="utf-8")
# Public API of this module.
__all__ = [
    # Path bundle
    "EvidencePaths",
    # CSV writers
    "TrainingLogWriter",
    "CheckpointEvalWriter",
    # Plot / report renderers
    "render_training_curve",
    "render_checkpoint_progression",
    "render_before_after",
    "render_sample_trajectories",
]