"""GRPO (Group-Relative Policy Optimization) training script for CERNenv. Uses Hugging Face TRL (Transformer Reinforcement Learning) ``GRPOTrainer`` to fine-tune a small instruction-tuned model on full episodes of the CERN environment. Each ``query`` is a prompt sampled from a freshly-reset env; the reward function rolls the model's response through the environment and returns the per-step + (optional) terminal reward. This script is intentionally CPU-friendly and self-contained. For GPU-accelerated training with LoRA, prefer ``training_unsloth.py``. Run: python -m training.training_script \ --model_name HuggingFaceTB/SmolLM2-360M-Instruct \ --total_episodes 200 --max_steps 18 --output_dir training/grpo-output """ from __future__ import annotations import argparse import logging import math import os import threading from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, TYPE_CHECKING # Heavy ML deps (torch, datasets, transformers) are imported lazily inside # ``main`` and ``build_dataset`` so the lightweight helpers — reward # function, curriculum schedule, format-validity bonus — remain importable # in environments that only have the env's runtime dependencies (numpy, # pydantic, openenv-core). This keeps ``tests/`` runnable on CPU. from models import ExperimentAction from server.environment import CERNCollisionEnvironment from training.llm_agent import ( LLMAgentConfig, build_chat, parse_action, safe_default_action, ) if TYPE_CHECKING: # pragma: no cover from datasets import Dataset logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") logger = logging.getLogger(__name__) # ── Episode reward harness ─────────────────────────────────────────────── @dataclass class EpisodeContext: """Per-prompt reusable env + default rollout config. ``seed`` and ``difficulty`` here are *fallback* values used when the TRL reward function does not receive per-prompt overrides via dataset columns. With a curriculum-aware dataset we always pass per-prompt ``seed``/``difficulty`` so the reward truly corresponds to the scored prompt. """ env: CERNCollisionEnvironment seed: int scenario: Optional[str] difficulty: Optional[str] @dataclass class EpisodeStats: """Per-rollout reward breakdown surfaced for component-level logging. The hackathon FAQ (Q17, Q43, Q52) repeatedly warns: "watch individual reward function columns, not just average reward". This struct gives the EvidenceCallback enough information to log each component on its own column so a reviewer (or you) can see *which* reward terms drove the policy update at any given training step. """ cumulative_reward: float = 0.0 terminal_reward: float = 0.0 step_shaping: float = 0.0 # cumulative_reward - terminal_reward discovered: bool = False correct_mass: bool = False correct_channel: bool = False correct_spin: bool = False parsed_ok: bool = False n_steps: int = 0 difficulty: Optional[str] = None def _stepwise_reward( *, completion_text: str, ctx: EpisodeContext, seed: Optional[int] = None, difficulty: Optional[str] = None, scenario: Optional[str] = None, out_stats: Optional[EpisodeStats] = None, ) -> float: """Roll the model's first response through one full episode and return the cumulative reward (per-step + terminal). The completion is interpreted as the first action only; subsequent steps fall back to the safe default policy. This keeps the reward bandwidth high for early-exploration training without requiring multi-turn rollouts inside GRPO. If ``out_stats`` is provided, it is populated in-place with a per-rollout breakdown (terminal vs shaping reward, success flags) so the caller can stream component-level metrics into the evidence log instead of relying only on aggregate reward. """ env = ctx.env obs = env.reset( seed=seed if seed is not None else ctx.seed, scenario=scenario if scenario is not None else ctx.scenario, difficulty=difficulty if difficulty is not None else ctx.difficulty, ) parsed = parse_action(completion_text) action = parsed or safe_default_action(obs) obs = env.step(action) cumulative = float(obs.reward or 0.0) n_steps = 1 while not obs.done: fallback = safe_default_action(obs) obs = env.step(fallback) cumulative += float(obs.reward or 0.0) n_steps += 1 if out_stats is not None: st = env.state terminal = float(st.terminal_reward or 0.0) out_stats.cumulative_reward = cumulative out_stats.terminal_reward = terminal out_stats.step_shaping = cumulative - terminal out_stats.discovered = bool(st.discovered) if st.discovered is not None else False out_stats.correct_mass = bool(st.correct_mass) if st.correct_mass is not None else False out_stats.correct_channel = ( bool(st.correct_channel) if st.correct_channel is not None else False ) out_stats.correct_spin = bool(st.correct_spin) if st.correct_spin is not None else False out_stats.parsed_ok = parsed is not None out_stats.n_steps = n_steps out_stats.difficulty = st.difficulty return cumulative # ── Reward-component accumulator (used by EvidenceCallback) ────────────── class RewardComponentAccumulator: """Thread-safe rolling buffer of per-rollout ``EpisodeStats``. The reward function appends to this; the EvidenceCallback drains it on each ``on_log`` and writes one summary row to ``evidence/reward_components.csv``. By pairing each row with the matching GRPO ``state.global_step``, we can plot per-component reward curves *aligned* with the loss curve. """ def __init__(self) -> None: self._lock = threading.Lock() self._buf: List[EpisodeStats] = [] def append(self, stats: EpisodeStats) -> None: with self._lock: self._buf.append(stats) def drain(self) -> List[EpisodeStats]: with self._lock: out, self._buf = self._buf, [] return out @staticmethod def summarise(stats: List[EpisodeStats]) -> Dict[str, float]: if not stats: return { "n": 0, "mean_cumulative": 0.0, "mean_terminal": 0.0, "mean_step_shaping": 0.0, "discovered_rate": 0.0, "mass_correct_rate": 0.0, "channel_correct_rate": 0.0, "spin_correct_rate": 0.0, "parsed_rate": 0.0, "mean_n_steps": 0.0, } n = len(stats) return { "n": n, "mean_cumulative": sum(s.cumulative_reward for s in stats) / n, "mean_terminal": sum(s.terminal_reward for s in stats) / n, "mean_step_shaping": sum(s.step_shaping for s in stats) / n, "discovered_rate": sum(1 for s in stats if s.discovered) / n, "mass_correct_rate": sum(1 for s in stats if s.correct_mass) / n, "channel_correct_rate": sum(1 for s in stats if s.correct_channel) / n, "spin_correct_rate": sum(1 for s in stats if s.correct_spin) / n, "parsed_rate": sum(1 for s in stats if s.parsed_ok) / n, "mean_n_steps": sum(s.n_steps for s in stats) / n, } FORMAT_BONUS_VALID = 0.15 FORMAT_BONUS_INVALID = -0.20 def _format_validity_bonus(completion_text: str) -> float: """Small ± nudge for emitting a structured action. Kept intentionally small (≪ terminal_scale) so the policy can't be dominated by a "spam well-formed JSON" objective. The negative branch is slightly larger than the positive branch so unparseable garbage is dispreferred without crowding out the actual task reward. """ return FORMAT_BONUS_VALID if parse_action(completion_text) is not None else FORMAT_BONUS_INVALID def make_reward_fn( ctx: EpisodeContext, accumulator: Optional[RewardComponentAccumulator] = None, ): """Return a TRL-compatible reward function. TRL forwards extra dataset columns (e.g. ``seed``, ``difficulty``) as ``kwargs`` aligned 1-to-1 with ``prompts``/``completions``. We use those here so the rollout used to score completion ``i`` matches the prompt that produced it, which also unlocks curriculum training. If ``accumulator`` is provided, every rollout's ``EpisodeStats`` is appended to it so the trainer's ``on_log`` callback can flush a per-component summary into the evidence CSV — that's what produces the "watch individual reward function columns" view recommended in the hackathon FAQ. """ def reward_fn( prompts: List[str], completions: List[str], **kwargs: Any, ) -> List[float]: seeds = kwargs.get("seed") diffs = kwargs.get("difficulty") scenarios = kwargs.get("scenario") rewards: List[float] = [] for i, completion in enumerate(completions): stats = EpisodeStats() if accumulator is not None else None r = _stepwise_reward( completion_text=completion, ctx=ctx, seed=int(seeds[i]) if seeds is not None else None, difficulty=diffs[i] if diffs is not None else None, scenario=scenarios[i] if scenarios is not None else None, out_stats=stats, ) r += _format_validity_bonus(completion) rewards.append(float(r)) if accumulator is not None and stats is not None: accumulator.append(stats) return rewards return reward_fn # ── Prompt dataset ─────────────────────────────────────────────────────── DEFAULT_CURRICULUM_SCHEDULE: List[tuple] = [ ("easy", 0.50), ("medium", 0.30), ("hard", 0.20), ] def curriculum_difficulty_for( idx: int, n_prompts: int, schedule: Optional[List[tuple]] = None, ) -> str: """Map an episode index to a difficulty using a deterministic ramp. A simple "easy first → harder later" schedule (FAQ Q14, help-guide §6) is enough to keep early-training success rate non-zero, which is the whole point of curriculum: the policy must occasionally see positive reward before RL can move probability mass toward it. """ sched = schedule or DEFAULT_CURRICULUM_SCHEDULE boundaries: List[tuple] = [] cumulative = 0.0 for diff, frac in sched: cumulative += frac boundaries.append((diff, cumulative * n_prompts)) for diff, upper in boundaries: if idx < upper: return diff return boundaries[-1][0] def build_dataset( *, tokenizer, n_prompts: int, seed: int, scenario: Optional[str], difficulty: Optional[str], curriculum: bool = False, schedule: Optional[List[tuple]] = None, ) -> "Dataset": from datasets import Dataset # lazy: heavy import path env = CERNCollisionEnvironment() prompts: List[str] = [] seeds: List[int] = [] diffs: List[str] = [] for i in range(n_prompts): ep_seed = seed + i ep_diff = ( curriculum_difficulty_for(i, n_prompts, schedule) if curriculum else (difficulty or "easy") ) obs = env.reset(seed=ep_seed, scenario=scenario, difficulty=ep_diff) chat = build_chat(obs) prompt = tokenizer.apply_chat_template( chat, add_generation_prompt=True, tokenize=False ) prompts.append(prompt) seeds.append(ep_seed) diffs.append(ep_diff) return Dataset.from_dict({ "prompt": prompts, "seed": seeds, "difficulty": diffs, }) # ── Main ───────────────────────────────────────────────────────────────── def main() -> None: # pragma: no cover - training entrypoint parser = argparse.ArgumentParser() parser.add_argument("--model_name", default="HuggingFaceTB/SmolLM2-360M-Instruct") parser.add_argument("--scenario", default=None) parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy") parser.add_argument( "--curriculum", action="store_true", help="Build the prompt set with an easy→medium→hard ramp.", ) parser.add_argument("--total_episodes", type=int, default=200) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--max_steps", type=int, default=18) parser.add_argument("--num_generations", type=int, default=4) parser.add_argument("--learning_rate", type=float, default=1e-5) parser.add_argument("--max_prompt_length", type=int, default=1024) parser.add_argument("--max_completion_length", type=int, default=256) parser.add_argument("--output_dir", default="training/grpo-output") args = parser.parse_args() try: import torch from transformers import AutoModelForCausalLM, AutoTokenizer from trl import GRPOConfig, GRPOTrainer except ImportError as exc: # pragma: no cover raise SystemExit( "TRL (Transformer Reinforcement Learning) is required: " "pip install -r requirements-train.txt" ) from exc logger.info("Loading tokenizer + model: %s", args.model_name) tokenizer = AutoTokenizer.from_pretrained(args.model_name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( args.model_name, torch_dtype=torch.float32, ) logger.info( "Building prompt dataset (%d prompts, curriculum=%s)", args.total_episodes, args.curriculum, ) dataset = build_dataset( tokenizer=tokenizer, n_prompts=args.total_episodes, seed=args.seed, scenario=args.scenario, difficulty=args.difficulty, curriculum=args.curriculum, ) env = CERNCollisionEnvironment(max_steps=args.max_steps) ctx = EpisodeContext( env=env, seed=args.seed, scenario=args.scenario, difficulty=args.difficulty, ) reward_fn = make_reward_fn(ctx) cfg = GRPOConfig( output_dir=args.output_dir, per_device_train_batch_size=2, gradient_accumulation_steps=2, num_generations=args.num_generations, learning_rate=args.learning_rate, max_prompt_length=args.max_prompt_length, max_completion_length=args.max_completion_length, logging_steps=5, save_steps=50, seed=args.seed, bf16=False, fp16=False, report_to=[], ) trainer = GRPOTrainer( model=model, processing_class=tokenizer, train_dataset=dataset, reward_funcs=[reward_fn], args=cfg, ) logger.info("Starting GRPO training") trainer.train() trainer.save_model(args.output_dir) tokenizer.save_pretrained(args.output_dir) logger.info("Saved model to %s", args.output_dir) if __name__ == "__main__": # pragma: no cover main()