| """GRPO (Group-Relative Policy Optimization) training script for CERNenv. | |
| Uses Hugging Face TRL (Transformer Reinforcement Learning) ``GRPOTrainer`` to | |
| fine-tune a small instruction-tuned model on full episodes of the CERN | |
| environment. Each ``query`` is a prompt sampled from a freshly-reset env; | |
| the reward function rolls the model's response through the environment and | |
| returns the per-step + (optional) terminal reward. | |
| This script is intentionally CPU-friendly and self-contained. For | |
| GPU-accelerated training with LoRA, prefer ``training_unsloth.py``. | |
| Run: | |
| python -m training.training_script \ | |
| --model_name HuggingFaceTB/SmolLM2-360M-Instruct \ | |
| --total_episodes 200 --max_steps 18 --output_dir training/grpo-output | |
| """ | |
from __future__ import annotations

import argparse
import logging
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, TYPE_CHECKING
# Heavy ML deps (torch, datasets, transformers) are imported lazily inside
# ``main`` and ``build_dataset`` so the lightweight helpers — reward
# function, curriculum schedule, format-validity bonus — remain importable
# in environments that only have the env's runtime dependencies (numpy,
# pydantic, openenv-core). This keeps ``tests/`` runnable on CPU.
from models import ExperimentAction
from server.environment import CERNCollisionEnvironment
from training.llm_agent import (
    LLMAgentConfig,
    build_chat,
    parse_action,
    safe_default_action,
)

if TYPE_CHECKING:  # pragma: no cover
    from datasets import Dataset

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)


# ── Episode reward harness ───────────────────────────────────────────────
@dataclass
class EpisodeContext:
    """Per-prompt reusable env + default rollout config.

    ``seed`` and ``difficulty`` here are *fallback* values used when the
    TRL reward function does not receive per-prompt overrides via dataset
    columns. With a curriculum-aware dataset we always pass per-prompt
    ``seed``/``difficulty`` so the reward truly corresponds to the
    scored prompt.
    """

    env: CERNCollisionEnvironment
    seed: int
    scenario: Optional[str]
    difficulty: Optional[str]


@dataclass
class EpisodeStats:
    """Per-rollout reward breakdown surfaced for component-level logging.

    The hackathon FAQ (Q17, Q43, Q52) repeatedly warns: "watch individual
    reward function columns, not just average reward". This struct gives
    the EvidenceCallback enough information to log each component on its
    own column so a reviewer (or you) can see *which* reward terms drove
    the policy update at any given training step.
    """

    cumulative_reward: float = 0.0
    terminal_reward: float = 0.0
    step_shaping: float = 0.0  # cumulative_reward - terminal_reward
    discovered: bool = False
    correct_mass: bool = False
    correct_channel: bool = False
    correct_spin: bool = False
    parsed_ok: bool = False
    n_steps: int = 0
    difficulty: Optional[str] = None


def _stepwise_reward(
    *,
    completion_text: str,
    ctx: EpisodeContext,
    seed: Optional[int] = None,
    difficulty: Optional[str] = None,
    scenario: Optional[str] = None,
    out_stats: Optional[EpisodeStats] = None,
) -> float:
    """Roll the model's first response through one full episode and
    return the cumulative reward (per-step + terminal).

    The completion is interpreted as the first action only; subsequent
    steps fall back to the safe default policy. This keeps the reward
    bandwidth high for early-exploration training without requiring
    multi-turn rollouts inside GRPO.

    If ``out_stats`` is provided, it is populated in-place with a
    per-rollout breakdown (terminal vs shaping reward, success flags)
    so the caller can stream component-level metrics into the evidence
    log instead of relying only on aggregate reward.
    """
    env = ctx.env
    obs = env.reset(
        seed=seed if seed is not None else ctx.seed,
        scenario=scenario if scenario is not None else ctx.scenario,
        difficulty=difficulty if difficulty is not None else ctx.difficulty,
    )
    parsed = parse_action(completion_text)
    action = parsed or safe_default_action(obs)
    obs = env.step(action)
    cumulative = float(obs.reward or 0.0)
    n_steps = 1
    while not obs.done:
        fallback = safe_default_action(obs)
        obs = env.step(fallback)
        cumulative += float(obs.reward or 0.0)
        n_steps += 1
    if out_stats is not None:
        st = env.state
        terminal = float(st.terminal_reward or 0.0)
        out_stats.cumulative_reward = cumulative
        out_stats.terminal_reward = terminal
        out_stats.step_shaping = cumulative - terminal
        out_stats.discovered = bool(st.discovered) if st.discovered is not None else False
        out_stats.correct_mass = bool(st.correct_mass) if st.correct_mass is not None else False
        out_stats.correct_channel = (
            bool(st.correct_channel) if st.correct_channel is not None else False
        )
        out_stats.correct_spin = bool(st.correct_spin) if st.correct_spin is not None else False
        out_stats.parsed_ok = parsed is not None
        out_stats.n_steps = n_steps
        out_stats.difficulty = st.difficulty
    return cumulative
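
# A minimal usage sketch (assumes a locally constructed environment; the
# ``model_response`` string is a placeholder, not a real action payload):
#
#     ctx = EpisodeContext(
#         env=CERNCollisionEnvironment(), seed=0, scenario=None, difficulty="easy",
#     )
#     stats = EpisodeStats()
#     r = _stepwise_reward(completion_text=model_response, ctx=ctx, out_stats=stats)
#     assert r == stats.cumulative_reward
#     # stats.parsed_ok tells you whether the completion (vs. the safe
#     # fallback) actually drove step 1.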


# ── Reward-component accumulator (used by EvidenceCallback) ──────────────
class RewardComponentAccumulator:
    """Thread-safe rolling buffer of per-rollout ``EpisodeStats``.

    The reward function appends to this; the EvidenceCallback drains it
    on each ``on_log`` and writes one summary row to
    ``evidence/reward_components.csv``. By pairing each row with the
    matching GRPO ``state.global_step``, we can plot per-component reward
    curves *aligned* with the loss curve.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._buf: List[EpisodeStats] = []

    def append(self, stats: EpisodeStats) -> None:
        with self._lock:
            self._buf.append(stats)

    def drain(self) -> List[EpisodeStats]:
        with self._lock:
            out, self._buf = self._buf, []
        return out
    @staticmethod
    def summarise(stats: List[EpisodeStats]) -> Dict[str, float]:
        if not stats:
            return {
                "n": 0,
                "mean_cumulative": 0.0,
                "mean_terminal": 0.0,
                "mean_step_shaping": 0.0,
                "discovered_rate": 0.0,
                "mass_correct_rate": 0.0,
                "channel_correct_rate": 0.0,
                "spin_correct_rate": 0.0,
                "parsed_rate": 0.0,
                "mean_n_steps": 0.0,
            }
        n = len(stats)
        return {
            "n": n,
            "mean_cumulative": sum(s.cumulative_reward for s in stats) / n,
            "mean_terminal": sum(s.terminal_reward for s in stats) / n,
            "mean_step_shaping": sum(s.step_shaping for s in stats) / n,
            "discovered_rate": sum(1 for s in stats if s.discovered) / n,
            "mass_correct_rate": sum(1 for s in stats if s.correct_mass) / n,
            "channel_correct_rate": sum(1 for s in stats if s.correct_channel) / n,
            "spin_correct_rate": sum(1 for s in stats if s.correct_spin) / n,
            "parsed_rate": sum(1 for s in stats if s.parsed_ok) / n,
            "mean_n_steps": sum(s.n_steps for s in stats) / n,
        }
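
# Append/drain/summarise round-trip (sketch):
#
#     acc = RewardComponentAccumulator()
#     acc.append(EpisodeStats(cumulative_reward=1.0, discovered=True, n_steps=3))
#     summary = RewardComponentAccumulator.summarise(acc.drain())
#     # summary["discovered_rate"] == 1.0; acc is now empty, so the next
#     # on_log flush only sees rollouts scored after this point.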


FORMAT_BONUS_VALID = 0.05  # was 0.15 — Fix #3 (lower per-step floor)
FORMAT_BONUS_INVALID = -0.20  # kept punitive so unparseable completions still hurt


def _format_validity_bonus(completion_text: str) -> float:
    """Small ± nudge for emitting a structured action.

    Kept intentionally small (≪ terminal_scale) so the policy can't be
    dominated by a "spam well-formed JSON" objective. After Fix #3 the
    positive branch is 1/3 of its v1 value (0.05 vs 0.15) — combined
    with the lower step_reward_clip and the heavier repeat-action
    penalty, this means a model can no longer farm ~+0.22/step by
    looping a single well-formed action.
    """
    return FORMAT_BONUS_VALID if parse_action(completion_text) is not None else FORMAT_BONUS_INVALID


def make_reward_fn(
    ctx: EpisodeContext,
    accumulator: Optional[RewardComponentAccumulator] = None,
):
    """Return a TRL-compatible reward function.

    TRL forwards extra dataset columns (e.g. ``seed``, ``difficulty``)
    as ``kwargs`` aligned 1-to-1 with ``prompts``/``completions``. We
    use those here so the rollout used to score completion ``i`` matches
    the prompt that produced it, which also unlocks curriculum training.

    If ``accumulator`` is provided, every rollout's ``EpisodeStats`` is
    appended to it so the trainer's ``on_log`` callback can flush a
    per-component summary into the evidence CSV — that's what produces
    the "watch individual reward function columns" view recommended in
    the hackathon FAQ.
    """

    def reward_fn(
        prompts: List[str],
        completions: List[str],
        **kwargs: Any,
    ) -> List[float]:
        seeds = kwargs.get("seed")
        diffs = kwargs.get("difficulty")
        scenarios = kwargs.get("scenario")
        rewards: List[float] = []
        for i, completion in enumerate(completions):
            stats = EpisodeStats() if accumulator is not None else None
            r = _stepwise_reward(
                completion_text=completion,
                ctx=ctx,
                seed=int(seeds[i]) if seeds is not None else None,
                difficulty=diffs[i] if diffs is not None else None,
                scenario=scenarios[i] if scenarios is not None else None,
                out_stats=stats,
            )
            r += _format_validity_bonus(completion)
            rewards.append(float(r))
            if accumulator is not None and stats is not None:
                accumulator.append(stats)
        return rewards

    return reward_fn
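
# How TRL ends up calling the closure above (sketch; the values shown are
# made up). The extra keyword lists come straight from the dataset columns
# and are index-aligned with prompts/completions:
#
#     fn = make_reward_fn(ctx)
#     rs = fn(
#         prompts=["<prompt 0>", "<prompt 1>"],
#         completions=["<well-formed action>", "free text"],
#         seed=[42, 43],
#         difficulty=["easy", "easy"],
#     )
#     # rs[1] absorbs FORMAT_BONUS_INVALID because parse_action rejects it.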


# ── Prompt dataset ───────────────────────────────────────────────────────
DEFAULT_CURRICULUM_SCHEDULE: List[tuple] = [
    ("easy", 0.50),
    ("medium", 0.30),
    ("hard", 0.20),
]


def curriculum_difficulty_for(
    idx: int,
    n_prompts: int,
    schedule: Optional[List[tuple]] = None,
) -> str:
    """Map an episode index to a difficulty using a deterministic ramp.

    A simple "easy first → harder later" schedule (FAQ Q14, help-guide §6)
    is enough to keep early-training success rate non-zero, which is the
    whole point of curriculum: the policy must occasionally see positive
    reward before RL can move probability mass toward it.
    """
    sched = schedule or DEFAULT_CURRICULUM_SCHEDULE
    boundaries: List[tuple] = []
    cumulative = 0.0
    for diff, frac in sched:
        cumulative += frac
        boundaries.append((diff, cumulative * n_prompts))
    for diff, upper in boundaries:
        if idx < upper:
            return diff
    return boundaries[-1][0]
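
# Worked example with the default 50/30/20 schedule and n_prompts=200:
# the cumulative boundaries land at 100 and 160, so indices 0-99 map to
# "easy", 100-159 to "medium", and 160-199 to "hard".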


def build_dataset(
    *,
    tokenizer,
    n_prompts: int,
    seed: int,
    scenario: Optional[str],
    difficulty: Optional[str],
    curriculum: bool = False,
    schedule: Optional[List[tuple]] = None,
) -> "Dataset":
    from datasets import Dataset  # lazy: heavy import path

    env = CERNCollisionEnvironment()
    prompts: List[str] = []
    seeds: List[int] = []
    diffs: List[str] = []
    for i in range(n_prompts):
        ep_seed = seed + i
        ep_diff = (
            curriculum_difficulty_for(i, n_prompts, schedule)
            if curriculum else (difficulty or "easy")
        )
        obs = env.reset(seed=ep_seed, scenario=scenario, difficulty=ep_diff)
        chat = build_chat(obs)
        prompt = tokenizer.apply_chat_template(
            chat, add_generation_prompt=True, tokenize=False
        )
        prompts.append(prompt)
        seeds.append(ep_seed)
        diffs.append(ep_diff)
    return Dataset.from_dict({
        "prompt": prompts,
        "seed": seeds,
        "difficulty": diffs,
    })
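
# Quick smoke test (sketch; any tokenizer exposing apply_chat_template works):
#
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-360M-Instruct")
#     ds = build_dataset(tokenizer=tok, n_prompts=4, seed=0, scenario=None,
#                        difficulty=None, curriculum=True)
#     assert ds.column_names == ["prompt", "seed", "difficulty"]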


# ── Main ─────────────────────────────────────────────────────────────────
def main() -> None:  # pragma: no cover - training entrypoint
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="HuggingFaceTB/SmolLM2-360M-Instruct")
    parser.add_argument("--scenario", default=None)
    parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
    parser.add_argument(
        "--curriculum",
        action="store_true",
        help="Build the prompt set with an easy→medium→hard ramp.",
    )
    parser.add_argument("--total_episodes", type=int, default=200)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max_steps", type=int, default=18)
    parser.add_argument("--num_generations", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--max_prompt_length", type=int, default=1024)
    parser.add_argument("--max_completion_length", type=int, default=256)
    parser.add_argument("--output_dir", default="training/grpo-output")
    parser.add_argument(
        "--evidence_dir",
        default="evidence",
        help="Directory for training_log.csv, reward_components.csv, "
        "checkpoint_evals.csv and the corresponding *.png plots.",
    )
    parser.add_argument(
        "--checkpoint_eval_steps",
        type=int,
        default=25,
        help="Run a held-out eval every N GRPO updates for the progression curve.",
    )
    parser.add_argument(
        "--checkpoint_eval_episodes",
        type=int,
        default=8,
        help="Number of held-out episodes per mid-training eval.",
    )
    args = parser.parse_args()

    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from trl import GRPOConfig, GRPOTrainer
    except ImportError as exc:  # pragma: no cover
        raise SystemExit(
            "TRL (Transformer Reinforcement Learning) is required: "
            "pip install -r requirements-train.txt"
        ) from exc
| logger.info("Loading tokenizer + model: %s", args.model_name) | |
| tokenizer = AutoTokenizer.from_pretrained(args.model_name) | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| model = AutoModelForCausalLM.from_pretrained( | |
| args.model_name, | |
| torch_dtype=torch.float32, | |
| ) | |
| logger.info( | |
| "Building prompt dataset (%d prompts, curriculum=%s)", | |
| args.total_episodes, args.curriculum, | |
| ) | |
| dataset = build_dataset( | |
| tokenizer=tokenizer, | |
| n_prompts=args.total_episodes, | |
| seed=args.seed, | |
| scenario=args.scenario, | |
| difficulty=args.difficulty, | |
| curriculum=args.curriculum, | |
| ) | |
| env = CERNCollisionEnvironment(max_steps=args.max_steps) | |
| ctx = EpisodeContext( | |
| env=env, | |
| seed=args.seed, | |
| scenario=args.scenario, | |
| difficulty=args.difficulty, | |
| ) | |

    # ── Evidence wiring (training_log.csv / reward_components.csv /
    # checkpoint_evals.csv + PNG plots). Mirrors training_unsloth.py so the
    # vanilla GRPO backend hydrates the same dashboard cards. The render
    # helpers are best-effort: matplotlib import failures are swallowed and
    # the corresponding PNG is skipped, never crashing training.
    import time as _time

    from transformers import TrainerCallback

    from training.evidence import (
        CheckpointEvalWriter,
        EvidencePaths,
        RewardComponentLogWriter,
        TrainingLogWriter,
        render_checkpoint_progression,
        render_reward_components,
        render_training_curve,
    )
    from training.rollouts import collect_episode

    paths = EvidencePaths(root=Path(args.evidence_dir))
    paths.ensure()
    log_writer = TrainingLogWriter(paths.training_log_csv)
    ckpt_writer = CheckpointEvalWriter(paths.checkpoint_evals_csv)
    component_writer = RewardComponentLogWriter(paths.reward_components_csv)
    component_accumulator = RewardComponentAccumulator()
    held_out_seeds = list(range(900_000, 900_000 + args.checkpoint_eval_episodes))

    reward_fn = make_reward_fn(ctx, accumulator=component_accumulator)
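
    # Note (an assumption about TRL's divisibility check, which has varied
    # across releases): GRPO needs the effective generation batch to be
    # divisible by num_generations. If GRPOTrainer raises a divisibility
    # error with the 2 x 2 defaults below, raise per_device_train_batch_size
    # to match --num_generations.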
    cfg = GRPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        num_generations=args.num_generations,
        learning_rate=args.learning_rate,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        logging_steps=5,
        save_steps=50,
        seed=args.seed,
        bf16=False,
        fp16=False,
        report_to=[],
    )

    class EvidenceCallback(TrainerCallback):
        """Stream training metrics + run periodic mid-training evals.

        Backported from training/training_unsloth.py so the vanilla GRPO
        path produces the same evidence/*.csv + *.png artefacts the
        dashboard reads. Differs from the Unsloth version only in the
        train/eval mode toggle: plain transformers uses model.eval() /
        model.train() instead of FastLanguageModel.for_inference().
        """

        def __init__(self) -> None:
            self._t0 = _time.time()
            self._last_eval_step = -1

        def on_log(self, _args, state, control, logs=None, **kw):
            logs = logs or {}

            def _first(*keys: str):
                # ``or``-chaining would swallow a legitimate 0.0 reward,
                # so pick the first key whose value is not None instead.
                for key in keys:
                    if logs.get(key) is not None:
                        return logs[key]
                return None

            row = {
                "step": state.global_step,
                "epoch": logs.get("epoch"),
                "loss": logs.get("loss"),
                "reward": _first("reward", "rewards/mean"),
                "reward_std": _first("reward_std", "rewards/std"),
                "kl": logs.get("kl"),
                "grad_norm": logs.get("grad_norm"),
                "learning_rate": logs.get("learning_rate"),
                "wall_time_s": round(_time.time() - self._t0, 2),
            }
            if any(v is not None for k, v in row.items() if k != "step"):
                log_writer.append(row)
                try:
                    render_training_curve(paths.training_log_csv, paths.training_curve_png)
                except Exception as exc:  # pragma: no cover - plotting is best-effort
                    logger.warning("training curve render failed: %s", exc)
            drained = component_accumulator.drain()
            if drained:
                summary = RewardComponentAccumulator.summarise(drained)
                summary["step"] = state.global_step
                component_writer.append(summary)
                try:
                    render_reward_components(
                        paths.reward_components_csv, paths.reward_components_png,
                    )
                except Exception as exc:  # pragma: no cover
                    logger.warning("reward components render failed: %s", exc)

        def on_step_end(self, _args, state, control, **kw):
            step = state.global_step
            if step <= 0 or step == self._last_eval_step:
                return control
            if step % args.checkpoint_eval_steps != 0:
                return control
            self._last_eval_step = step
            try:
                self._run_checkpoint_eval(step, state)
            except Exception as exc:
                logger.warning("checkpoint eval failed at step %d: %s", step, exc)
            return control

        def _run_checkpoint_eval(self, step: int, state) -> None:
            was_training = model.training
            model.eval()
            try:
                episodes = []
                for s in held_out_seeds:
                    ep = self._rollout_one(seed=s)
                    if ep is not None:
                        episodes.append(ep)
                if not episodes:
                    return
                rewards = [e.cumulative_reward for e in episodes]
                success_rate = sum(1 for e in episodes if e.discovered) / len(episodes)
                ckpt_writer.append(
                    step=step,
                    fraction_done=round(step / max(state.max_steps or step, 1), 4),
                    episodes=len(episodes),
                    mean_reward=round(sum(rewards) / len(rewards), 4),
                    success_rate=round(success_rate, 4),
                    mass_acc=round(
                        sum(1 for e in episodes if e.correct_mass) / len(episodes), 4,
                    ),
                    channel_acc=round(
                        sum(1 for e in episodes if e.correct_channel) / len(episodes), 4,
                    ),
                )
                try:
                    render_checkpoint_progression(
                        paths.checkpoint_evals_csv,
                        paths.checkpoint_progression_png,
                    )
                except Exception as exc:  # pragma: no cover
                    logger.warning("checkpoint progression render failed: %s", exc)
                logger.info(
                    "[checkpoint-eval step=%d] reward=%.3f success=%.2f",
                    step,
                    sum(rewards) / len(rewards) if rewards else 0.0,
                    success_rate,
                )
            finally:
                if was_training:
                    model.train()

        def _rollout_one(self, seed: int):
            def prompt_fn(chat):
                return tokenizer.apply_chat_template(
                    chat, add_generation_prompt=True, tokenize=False,
                )

            def generate_fn(prompt: str, _config) -> str:
                inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=args.max_completion_length,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.95,
                        pad_token_id=tokenizer.pad_token_id,
                    )
                gen = outputs[0][inputs["input_ids"].shape[1]:]
                return tokenizer.decode(gen, skip_special_tokens=True)

            return collect_episode(
                env=env,
                seed=seed,
                scenario=args.scenario,
                difficulty=args.difficulty,
                prompt_fn=prompt_fn,
                generate_fn=generate_fn,
                config=LLMAgentConfig(),
            )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        reward_funcs=[reward_fn],
        args=cfg,
        callbacks=[EvidenceCallback()],
    )

    logger.info("Starting GRPO training")
    trainer.train()

    # Drain any rollouts the final on_log didn't catch so the last row of
    # reward_components.csv reflects the end-of-training state.
    final_drain = component_accumulator.drain()
    if final_drain:
        summary = RewardComponentAccumulator.summarise(final_drain)
        summary["step"] = trainer.state.global_step
        component_writer.append(summary)
        try:
            render_reward_components(
                paths.reward_components_csv, paths.reward_components_png,
            )
        except Exception as exc:  # pragma: no cover
            logger.warning("final reward components render failed: %s", exc)

    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    logger.info("Saved model to %s", args.output_dir)
    logger.info("Evidence artifacts in %s", paths.root)


if __name__ == "__main__":  # pragma: no cover
    main()