"""GRPO (Group-Relative Policy Optimization) training script for CERNenv. Uses Hugging Face TRL (Transformer Reinforcement Learning) ``GRPOTrainer`` to fine-tune a small instruction-tuned model on full episodes of the CERN environment. Each ``query`` is a prompt sampled from a freshly-reset env; the reward function rolls the model's response through the environment and returns the per-step + (optional) terminal reward. This script is intentionally CPU-friendly and self-contained. For GPU-accelerated training with LoRA, prefer ``training_unsloth.py``. Run: python -m training.training_script \ --model_name HuggingFaceTB/SmolLM2-360M-Instruct \ --total_episodes 200 --max_steps 18 --output_dir training/grpo-output """ from __future__ import annotations import argparse import logging import math import os import threading from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, TYPE_CHECKING # Heavy ML deps (torch, datasets, transformers) are imported lazily inside # ``main`` and ``build_dataset`` so the lightweight helpers — reward # function, curriculum schedule, format-validity bonus — remain importable # in environments that only have the env's runtime dependencies (numpy, # pydantic, openenv-core). This keeps ``tests/`` runnable on CPU. from models import ExperimentAction from server.environment import CERNCollisionEnvironment from training.llm_agent import ( LLMAgentConfig, build_chat, parse_action, safe_default_action, ) if TYPE_CHECKING: # pragma: no cover from datasets import Dataset logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") logger = logging.getLogger(__name__) # ── Episode reward harness ─────────────────────────────────────────────── @dataclass class EpisodeContext: """Per-prompt reusable env + default rollout config. ``seed`` and ``difficulty`` here are *fallback* values used when the TRL reward function does not receive per-prompt overrides via dataset columns. With a curriculum-aware dataset we always pass per-prompt ``seed``/``difficulty`` so the reward truly corresponds to the scored prompt. """ env: CERNCollisionEnvironment seed: int scenario: Optional[str] difficulty: Optional[str] @dataclass class EpisodeStats: """Per-rollout reward breakdown surfaced for component-level logging. The hackathon FAQ (Q17, Q43, Q52) repeatedly warns: "watch individual reward function columns, not just average reward". This struct gives the EvidenceCallback enough information to log each component on its own column so a reviewer (or you) can see *which* reward terms drove the policy update at any given training step. """ cumulative_reward: float = 0.0 terminal_reward: float = 0.0 step_shaping: float = 0.0 # cumulative_reward - terminal_reward discovered: bool = False correct_mass: bool = False correct_channel: bool = False correct_spin: bool = False parsed_ok: bool = False n_steps: int = 0 difficulty: Optional[str] = None def _stepwise_reward( *, completion_text: str, ctx: EpisodeContext, seed: Optional[int] = None, difficulty: Optional[str] = None, scenario: Optional[str] = None, out_stats: Optional[EpisodeStats] = None, ) -> float: """Roll the model's first response through one full episode and return the cumulative reward (per-step + terminal). The completion is interpreted as the first action only; subsequent steps fall back to the safe default policy. This keeps the reward bandwidth high for early-exploration training without requiring multi-turn rollouts inside GRPO. If ``out_stats`` is provided, it is populated in-place with a per-rollout breakdown (terminal vs shaping reward, success flags) so the caller can stream component-level metrics into the evidence log instead of relying only on aggregate reward. """ env = ctx.env obs = env.reset( seed=seed if seed is not None else ctx.seed, scenario=scenario if scenario is not None else ctx.scenario, difficulty=difficulty if difficulty is not None else ctx.difficulty, ) parsed = parse_action(completion_text) action = parsed or safe_default_action(obs) obs = env.step(action) cumulative = float(obs.reward or 0.0) n_steps = 1 while not obs.done: fallback = safe_default_action(obs) obs = env.step(fallback) cumulative += float(obs.reward or 0.0) n_steps += 1 if out_stats is not None: st = env.state terminal = float(st.terminal_reward or 0.0) out_stats.cumulative_reward = cumulative out_stats.terminal_reward = terminal out_stats.step_shaping = cumulative - terminal out_stats.discovered = bool(st.discovered) if st.discovered is not None else False out_stats.correct_mass = bool(st.correct_mass) if st.correct_mass is not None else False out_stats.correct_channel = ( bool(st.correct_channel) if st.correct_channel is not None else False ) out_stats.correct_spin = bool(st.correct_spin) if st.correct_spin is not None else False out_stats.parsed_ok = parsed is not None out_stats.n_steps = n_steps out_stats.difficulty = st.difficulty return cumulative # ── Reward-component accumulator (used by EvidenceCallback) ────────────── class RewardComponentAccumulator: """Thread-safe rolling buffer of per-rollout ``EpisodeStats``. The reward function appends to this; the EvidenceCallback drains it on each ``on_log`` and writes one summary row to ``evidence/reward_components.csv``. By pairing each row with the matching GRPO ``state.global_step``, we can plot per-component reward curves *aligned* with the loss curve. """ def __init__(self) -> None: self._lock = threading.Lock() self._buf: List[EpisodeStats] = [] def append(self, stats: EpisodeStats) -> None: with self._lock: self._buf.append(stats) def drain(self) -> List[EpisodeStats]: with self._lock: out, self._buf = self._buf, [] return out @staticmethod def summarise(stats: List[EpisodeStats]) -> Dict[str, float]: if not stats: return { "n": 0, "mean_cumulative": 0.0, "mean_terminal": 0.0, "mean_step_shaping": 0.0, "discovered_rate": 0.0, "mass_correct_rate": 0.0, "channel_correct_rate": 0.0, "spin_correct_rate": 0.0, "parsed_rate": 0.0, "mean_n_steps": 0.0, } n = len(stats) return { "n": n, "mean_cumulative": sum(s.cumulative_reward for s in stats) / n, "mean_terminal": sum(s.terminal_reward for s in stats) / n, "mean_step_shaping": sum(s.step_shaping for s in stats) / n, "discovered_rate": sum(1 for s in stats if s.discovered) / n, "mass_correct_rate": sum(1 for s in stats if s.correct_mass) / n, "channel_correct_rate": sum(1 for s in stats if s.correct_channel) / n, "spin_correct_rate": sum(1 for s in stats if s.correct_spin) / n, "parsed_rate": sum(1 for s in stats if s.parsed_ok) / n, "mean_n_steps": sum(s.n_steps for s in stats) / n, } FORMAT_BONUS_VALID = 0.05 # was 0.15 — Fix #3 (lower per-step floor) FORMAT_BONUS_INVALID = -0.20 # kept punitive so unparseable completions still hurt def _format_validity_bonus(completion_text: str) -> float: """Small ± nudge for emitting a structured action. Kept intentionally small (≪ terminal_scale) so the policy can't be dominated by a "spam well-formed JSON" objective. After Fix #3 the positive branch is 1/3 of its v1 value (0.05 vs 0.15) — combined with the lower step_reward_clip and the heavier repeat-action penalty, this means a model can no longer farm ~+0.22/step by looping a single well-formed action. """ return FORMAT_BONUS_VALID if parse_action(completion_text) is not None else FORMAT_BONUS_INVALID def make_reward_fn( ctx: EpisodeContext, accumulator: Optional[RewardComponentAccumulator] = None, ): """Return a TRL-compatible reward function. TRL forwards extra dataset columns (e.g. ``seed``, ``difficulty``) as ``kwargs`` aligned 1-to-1 with ``prompts``/``completions``. We use those here so the rollout used to score completion ``i`` matches the prompt that produced it, which also unlocks curriculum training. If ``accumulator`` is provided, every rollout's ``EpisodeStats`` is appended to it so the trainer's ``on_log`` callback can flush a per-component summary into the evidence CSV — that's what produces the "watch individual reward function columns" view recommended in the hackathon FAQ. """ def reward_fn( prompts: List[str], completions: List[str], **kwargs: Any, ) -> List[float]: seeds = kwargs.get("seed") diffs = kwargs.get("difficulty") scenarios = kwargs.get("scenario") rewards: List[float] = [] for i, completion in enumerate(completions): stats = EpisodeStats() if accumulator is not None else None r = _stepwise_reward( completion_text=completion, ctx=ctx, seed=int(seeds[i]) if seeds is not None else None, difficulty=diffs[i] if diffs is not None else None, scenario=scenarios[i] if scenarios is not None else None, out_stats=stats, ) r += _format_validity_bonus(completion) rewards.append(float(r)) if accumulator is not None and stats is not None: accumulator.append(stats) return rewards return reward_fn # ── Prompt dataset ─────────────────────────────────────────────────────── DEFAULT_CURRICULUM_SCHEDULE: List[tuple] = [ ("easy", 0.50), ("medium", 0.30), ("hard", 0.20), ] def curriculum_difficulty_for( idx: int, n_prompts: int, schedule: Optional[List[tuple]] = None, ) -> str: """Map an episode index to a difficulty using a deterministic ramp. A simple "easy first → harder later" schedule (FAQ Q14, help-guide §6) is enough to keep early-training success rate non-zero, which is the whole point of curriculum: the policy must occasionally see positive reward before RL can move probability mass toward it. """ sched = schedule or DEFAULT_CURRICULUM_SCHEDULE boundaries: List[tuple] = [] cumulative = 0.0 for diff, frac in sched: cumulative += frac boundaries.append((diff, cumulative * n_prompts)) for diff, upper in boundaries: if idx < upper: return diff return boundaries[-1][0] def build_dataset( *, tokenizer, n_prompts: int, seed: int, scenario: Optional[str], difficulty: Optional[str], curriculum: bool = False, schedule: Optional[List[tuple]] = None, ) -> "Dataset": from datasets import Dataset # lazy: heavy import path env = CERNCollisionEnvironment() prompts: List[str] = [] seeds: List[int] = [] diffs: List[str] = [] for i in range(n_prompts): ep_seed = seed + i ep_diff = ( curriculum_difficulty_for(i, n_prompts, schedule) if curriculum else (difficulty or "easy") ) obs = env.reset(seed=ep_seed, scenario=scenario, difficulty=ep_diff) chat = build_chat(obs) prompt = tokenizer.apply_chat_template( chat, add_generation_prompt=True, tokenize=False ) prompts.append(prompt) seeds.append(ep_seed) diffs.append(ep_diff) return Dataset.from_dict({ "prompt": prompts, "seed": seeds, "difficulty": diffs, }) # ── Main ───────────────────────────────────────────────────────────────── def main() -> None: # pragma: no cover - training entrypoint parser = argparse.ArgumentParser() parser.add_argument("--model_name", default="HuggingFaceTB/SmolLM2-360M-Instruct") parser.add_argument("--scenario", default=None) parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy") parser.add_argument( "--curriculum", action="store_true", help="Build the prompt set with an easy→medium→hard ramp.", ) parser.add_argument("--total_episodes", type=int, default=200) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--max_steps", type=int, default=18) parser.add_argument("--num_generations", type=int, default=4) parser.add_argument("--learning_rate", type=float, default=1e-5) parser.add_argument("--max_prompt_length", type=int, default=1024) parser.add_argument("--max_completion_length", type=int, default=256) parser.add_argument("--output_dir", default="training/grpo-output") parser.add_argument( "--evidence_dir", default="evidence", help="Directory for training_log.csv, reward_components.csv, " "checkpoint_evals.csv and the corresponding *.png plots.", ) parser.add_argument( "--checkpoint_eval_steps", type=int, default=25, help="Run a held-out eval every N GRPO updates for the progression curve.", ) parser.add_argument( "--checkpoint_eval_episodes", type=int, default=8, help="Number of held-out episodes per mid-training eval.", ) args = parser.parse_args() try: import torch from transformers import AutoModelForCausalLM, AutoTokenizer from trl import GRPOConfig, GRPOTrainer except ImportError as exc: # pragma: no cover raise SystemExit( "TRL (Transformer Reinforcement Learning) is required: " "pip install -r requirements-train.txt" ) from exc logger.info("Loading tokenizer + model: %s", args.model_name) tokenizer = AutoTokenizer.from_pretrained(args.model_name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( args.model_name, torch_dtype=torch.float32, ) logger.info( "Building prompt dataset (%d prompts, curriculum=%s)", args.total_episodes, args.curriculum, ) dataset = build_dataset( tokenizer=tokenizer, n_prompts=args.total_episodes, seed=args.seed, scenario=args.scenario, difficulty=args.difficulty, curriculum=args.curriculum, ) env = CERNCollisionEnvironment(max_steps=args.max_steps) ctx = EpisodeContext( env=env, seed=args.seed, scenario=args.scenario, difficulty=args.difficulty, ) # ── Evidence wiring (training_log.csv / reward_components.csv / # checkpoint_evals.csv + PNG plots). Mirrors training_unsloth.py so the # vanilla GRPO backend hydrates the same dashboard cards. The render # helpers are best-effort: matplotlib import failures are swallowed and # the corresponding PNG is skipped, never crashing training. import time as _time from transformers import TrainerCallback from training.evidence import ( CheckpointEvalWriter, EvidencePaths, RewardComponentLogWriter, TrainingLogWriter, render_checkpoint_progression, render_reward_components, render_training_curve, ) from training.llm_agent import LLMAgentConfig from training.rollouts import collect_episode paths = EvidencePaths(root=Path(args.evidence_dir)) paths.ensure() log_writer = TrainingLogWriter(paths.training_log_csv) ckpt_writer = CheckpointEvalWriter(paths.checkpoint_evals_csv) component_writer = RewardComponentLogWriter(paths.reward_components_csv) component_accumulator = RewardComponentAccumulator() held_out_seeds = list(range(900_000, 900_000 + args.checkpoint_eval_episodes)) reward_fn = make_reward_fn(ctx, accumulator=component_accumulator) cfg = GRPOConfig( output_dir=args.output_dir, per_device_train_batch_size=2, gradient_accumulation_steps=2, num_generations=args.num_generations, learning_rate=args.learning_rate, max_prompt_length=args.max_prompt_length, max_completion_length=args.max_completion_length, logging_steps=5, save_steps=50, seed=args.seed, bf16=False, fp16=False, report_to=[], ) class EvidenceCallback(TrainerCallback): """Stream training metrics + run periodic mid-training evals. Backported from training/training_unsloth.py so the vanilla GRPO path produces the same evidence/*.csv + *.png artefacts the dashboard reads. Differs from the Unsloth version only in the train/eval mode toggle: plain transformers uses model.eval() / model.train() instead of FastLanguageModel.for_inference(). """ def __init__(self) -> None: self._t0 = _time.time() self._last_eval_step = -1 def on_log(self, _args, state, control, logs=None, **kw): logs = logs or {} row = { "step": state.global_step, "epoch": logs.get("epoch"), "loss": logs.get("loss"), "reward": logs.get("reward") or logs.get("rewards/mean"), "reward_std": logs.get("reward_std") or logs.get("rewards/std"), "kl": logs.get("kl"), "grad_norm": logs.get("grad_norm"), "learning_rate": logs.get("learning_rate"), "wall_time_s": round(_time.time() - self._t0, 2), } if any(v is not None for k, v in row.items() if k != "step"): log_writer.append(row) try: render_training_curve(paths.training_log_csv, paths.training_curve_png) except Exception as exc: # pragma: no cover - plotting is best-effort logger.warning("training curve render failed: %s", exc) drained = component_accumulator.drain() if drained: summary = RewardComponentAccumulator.summarise(drained) summary["step"] = state.global_step component_writer.append(summary) try: render_reward_components( paths.reward_components_csv, paths.reward_components_png, ) except Exception as exc: # pragma: no cover logger.warning("reward components render failed: %s", exc) def on_step_end(self, _args, state, control, **kw): step = state.global_step if step <= 0 or step == self._last_eval_step: return control if step % args.checkpoint_eval_steps != 0: return control self._last_eval_step = step try: self._run_checkpoint_eval(step, state) except Exception as exc: logger.warning("checkpoint eval failed at step %d: %s", step, exc) return control def _run_checkpoint_eval(self, step: int, state) -> None: was_training = model.training model.eval() try: episodes = [] for s in held_out_seeds: ep = self._rollout_one(seed=s) if ep is not None: episodes.append(ep) if not episodes: return rewards = [e.cumulative_reward for e in episodes] success_rate = sum(1 for e in episodes if e.discovered) / len(episodes) ckpt_writer.append( step=step, fraction_done=round(step / max(state.max_steps or step, 1), 4), episodes=len(episodes), mean_reward=round(sum(rewards) / len(rewards), 4), success_rate=round(success_rate, 4), mass_acc=round( sum(1 for e in episodes if e.correct_mass) / len(episodes), 4, ), channel_acc=round( sum(1 for e in episodes if e.correct_channel) / len(episodes), 4, ), ) try: render_checkpoint_progression( paths.checkpoint_evals_csv, paths.checkpoint_progression_png, ) except Exception as exc: # pragma: no cover logger.warning("checkpoint progression render failed: %s", exc) logger.info( "[checkpoint-eval step=%d] reward=%.3f success=%.2f", step, sum(rewards) / len(rewards) if rewards else 0.0, success_rate, ) finally: if was_training: model.train() def _rollout_one(self, seed: int): def prompt_fn(chat): return tokenizer.apply_chat_template( chat, add_generation_prompt=True, tokenize=False, ) def generate_fn(prompt: str, _config) -> str: inputs = tokenizer(prompt, return_tensors="pt").to(model.device) with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=args.max_completion_length, do_sample=True, temperature=0.7, top_p=0.95, pad_token_id=tokenizer.pad_token_id, ) gen = outputs[0][inputs["input_ids"].shape[1]:] return tokenizer.decode(gen, skip_special_tokens=True) return collect_episode( env=env, seed=seed, scenario=args.scenario, difficulty=args.difficulty, prompt_fn=prompt_fn, generate_fn=generate_fn, config=LLMAgentConfig(), ) trainer = GRPOTrainer( model=model, processing_class=tokenizer, train_dataset=dataset, reward_funcs=[reward_fn], args=cfg, callbacks=[EvidenceCallback()], ) logger.info("Starting GRPO training") trainer.train() # Drain any rollouts the final on_log didn't catch so the last row of # reward_components.csv reflects the end-of-training state. final_drain = component_accumulator.drain() if final_drain: summary = RewardComponentAccumulator.summarise(final_drain) summary["step"] = trainer.state.global_step component_writer.append(summary) try: render_reward_components( paths.reward_components_csv, paths.reward_components_png, ) except Exception as exc: # pragma: no cover logger.warning("final reward components render failed: %s", exc) trainer.save_model(args.output_dir) tokenizer.save_pretrained(args.output_dir) logger.info("Saved model to %s", args.output_dir) logger.info("Evidence artifacts in %s", paths.root) if __name__ == "__main__": # pragma: no cover main()