| """GRPO (Group-Relative Policy Optimization) training script for CERNenv.
|
|
|
| Uses Hugging Face TRL (Transformer Reinforcement Learning) ``GRPOTrainer`` to
|
| fine-tune a small instruction-tuned model on full episodes of the CERN
|
| environment. Each ``query`` is a prompt sampled from a freshly-reset env;
|
| the reward function rolls the model's response through the environment and
|
| returns the per-step + (optional) terminal reward.
|
|
|
| This script is intentionally CPU-friendly and self-contained. For
|
| GPU-accelerated training with LoRA, prefer ``training_unsloth.py``.
|
|
|
| Run:
|
| python -m training.training_script \
|
| --model_name HuggingFaceTB/SmolLM2-360M-Instruct \
|
| --total_episodes 200 --max_steps 18 --output_dir training/grpo-output
|
| """

from __future__ import annotations

import argparse
import logging
import threading
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, TYPE_CHECKING

from server.environment import CERNCollisionEnvironment
from training.llm_agent import (
    build_chat,
    parse_action,
    safe_default_action,
)

if TYPE_CHECKING:
    from datasets import Dataset

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)


@dataclass
class EpisodeContext:
    """Per-prompt reusable env + default rollout config.

    ``seed`` and ``difficulty`` here are *fallback* values used when the
    TRL reward function does not receive per-prompt overrides via dataset
    columns. With a curriculum-aware dataset we always pass per-prompt
    ``seed``/``difficulty`` so the reward truly corresponds to the
    scored prompt.
    """

    env: CERNCollisionEnvironment
    seed: int
    scenario: Optional[str]
    difficulty: Optional[str]


@dataclass
class EpisodeStats:
    """Per-rollout reward breakdown surfaced for component-level logging.

    The hackathon FAQ (Q17, Q43, Q52) repeatedly warns: "watch individual
    reward function columns, not just average reward". This struct gives
    the EvidenceCallback enough information to log each component on its
    own column so a reviewer (or you) can see *which* reward terms drove
    the policy update at any given training step.
    """

    cumulative_reward: float = 0.0
    terminal_reward: float = 0.0
    step_shaping: float = 0.0
    discovered: bool = False
    correct_mass: bool = False
    correct_channel: bool = False
    correct_spin: bool = False
    parsed_ok: bool = False
    n_steps: int = 0
    difficulty: Optional[str] = None


def _stepwise_reward(
    *,
    completion_text: str,
    ctx: EpisodeContext,
    seed: Optional[int] = None,
    difficulty: Optional[str] = None,
    scenario: Optional[str] = None,
    out_stats: Optional[EpisodeStats] = None,
) -> float:
    """Roll the model's first response through one full episode and
    return the cumulative reward (per-step + terminal).

    The completion is interpreted as the first action only; subsequent
    steps fall back to the safe default policy. This keeps the reward
    bandwidth high for early-exploration training without requiring
    multi-turn rollouts inside GRPO.

    If ``out_stats`` is provided, it is populated in-place with a
    per-rollout breakdown (terminal vs shaping reward, success flags)
    so the caller can stream component-level metrics into the evidence
    log instead of relying only on aggregate reward.
    """
    env = ctx.env
    obs = env.reset(
        seed=seed if seed is not None else ctx.seed,
        scenario=scenario if scenario is not None else ctx.scenario,
        difficulty=difficulty if difficulty is not None else ctx.difficulty,
    )

    parsed = parse_action(completion_text)
    action = parsed or safe_default_action(obs)
    obs = env.step(action)
    cumulative = float(obs.reward or 0.0)
    n_steps = 1

    while not obs.done:
        fallback = safe_default_action(obs)
        obs = env.step(fallback)
        cumulative += float(obs.reward or 0.0)
        n_steps += 1

    if out_stats is not None:
        st = env.state
        terminal = float(st.terminal_reward or 0.0)
        out_stats.cumulative_reward = cumulative
        out_stats.terminal_reward = terminal
        out_stats.step_shaping = cumulative - terminal
        out_stats.discovered = bool(st.discovered) if st.discovered is not None else False
        out_stats.correct_mass = bool(st.correct_mass) if st.correct_mass is not None else False
        out_stats.correct_channel = (
            bool(st.correct_channel) if st.correct_channel is not None else False
        )
        out_stats.correct_spin = bool(st.correct_spin) if st.correct_spin is not None else False
        out_stats.parsed_ok = parsed is not None
        out_stats.n_steps = n_steps
        out_stats.difficulty = st.difficulty

    return cumulative


class RewardComponentAccumulator:
    """Thread-safe rolling buffer of per-rollout ``EpisodeStats``.

    The reward function appends to this; the EvidenceCallback drains it
    on each ``on_log`` and writes one summary row to
    ``evidence/reward_components.csv``. By pairing each row with the
    matching GRPO ``state.global_step``, we can plot per-component reward
    curves *aligned* with the loss curve.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._buf: List[EpisodeStats] = []

    def append(self, stats: EpisodeStats) -> None:
        with self._lock:
            self._buf.append(stats)

    def drain(self) -> List[EpisodeStats]:
        with self._lock:
            out, self._buf = self._buf, []
        return out

    @staticmethod
    def summarise(stats: List[EpisodeStats]) -> Dict[str, float]:
        if not stats:
            return {
                "n": 0,
                "mean_cumulative": 0.0,
                "mean_terminal": 0.0,
                "mean_step_shaping": 0.0,
                "discovered_rate": 0.0,
                "mass_correct_rate": 0.0,
                "channel_correct_rate": 0.0,
                "spin_correct_rate": 0.0,
                "parsed_rate": 0.0,
                "mean_n_steps": 0.0,
            }
        n = len(stats)
        return {
            "n": n,
            "mean_cumulative": sum(s.cumulative_reward for s in stats) / n,
            "mean_terminal": sum(s.terminal_reward for s in stats) / n,
            "mean_step_shaping": sum(s.step_shaping for s in stats) / n,
            "discovered_rate": sum(1 for s in stats if s.discovered) / n,
            "mass_correct_rate": sum(1 for s in stats if s.correct_mass) / n,
            "channel_correct_rate": sum(1 for s in stats if s.correct_channel) / n,
            "spin_correct_rate": sum(1 for s in stats if s.correct_spin) / n,
            "parsed_rate": sum(1 for s in stats if s.parsed_ok) / n,
            "mean_n_steps": sum(s.n_steps for s in stats) / n,
        }
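

# The EvidenceCallback mentioned in the docstrings above is not defined in
# this script. Purely as an illustrative sketch (assumed shape, not the
# actual implementation), a transformers ``TrainerCallback`` that drains the
# accumulator on every log event and appends one summary row per step could
# look like:
#
#     import csv, os
#     from transformers import TrainerCallback
#
#     class EvidenceCallback(TrainerCallback):
#         def __init__(self, accumulator, path="evidence/reward_components.csv"):
#             self.accumulator = accumulator
#             self.path = path
#
#         def on_log(self, args, state, control, logs=None, **kwargs):
#             row = RewardComponentAccumulator.summarise(self.accumulator.drain())
#             row["global_step"] = state.global_step
#             new_file = not os.path.exists(self.path)
#             with open(self.path, "a", newline="") as f:
#                 writer = csv.DictWriter(f, fieldnames=list(row))
#                 if new_file:
#                     writer.writeheader()
#                 writer.writerow(row)

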
FORMAT_BONUS_VALID = 0.15
FORMAT_BONUS_INVALID = -0.20


def _format_validity_bonus(completion_text: str) -> float:
    """Small ± nudge for emitting a structured action.

    Kept intentionally small (≪ terminal_scale) so the policy can't be
    dominated by a "spam well-formed JSON" objective. The negative branch
    is slightly larger than the positive branch so unparseable garbage
    is dispreferred without crowding out the actual task reward.
    """
    return FORMAT_BONUS_VALID if parse_action(completion_text) is not None else FORMAT_BONUS_INVALID
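

# For example, a completion that ``parse_action`` accepts earns +0.15, while
# free-form prose that fails to parse costs -0.20; both are small relative to
# the terminal reward scale, so the task signal still dominates.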


def make_reward_fn(
    ctx: EpisodeContext,
    accumulator: Optional[RewardComponentAccumulator] = None,
):
    """Return a TRL-compatible reward function.

    TRL forwards extra dataset columns (e.g. ``seed``, ``difficulty``)
    as ``kwargs`` aligned 1-to-1 with ``prompts``/``completions``. We
    use those here so the rollout used to score completion ``i`` matches
    the prompt that produced it, which also unlocks curriculum training.

    If ``accumulator`` is provided, every rollout's ``EpisodeStats`` is
    appended to it so the trainer's ``on_log`` callback can flush a
    per-component summary into the evidence CSV — that's what produces
    the "watch individual reward function columns" view recommended in
    the hackathon FAQ.
    """

    def reward_fn(
        prompts: List[str],
        completions: List[str],
        **kwargs: Any,
    ) -> List[float]:
        seeds = kwargs.get("seed")
        diffs = kwargs.get("difficulty")
        scenarios = kwargs.get("scenario")
        rewards: List[float] = []
        for i, completion in enumerate(completions):
            stats = EpisodeStats() if accumulator is not None else None
            r = _stepwise_reward(
                completion_text=completion,
                ctx=ctx,
                seed=int(seeds[i]) if seeds is not None else None,
                difficulty=diffs[i] if diffs is not None else None,
                scenario=scenarios[i] if scenarios is not None else None,
                out_stats=stats,
            )
            r += _format_validity_bonus(completion)
            rewards.append(float(r))
            if accumulator is not None and stats is not None:
                accumulator.append(stats)
        return rewards

    return reward_fn
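

# Illustration of TRL's column forwarding (values hypothetical): given a
# dataset with columns ``prompt``, ``seed`` and ``difficulty``, the trainer
# invokes
#
#     reward_fn(
#         prompts=["<prompt 0>", "<prompt 1>"],
#         completions=['{"action": ...}', "free-form text"],
#         seed=[42, 43],
#         difficulty=["easy", "easy"],
#     )
#
# so completion ``i`` is replayed on the exact (seed, difficulty) episode
# that generated prompt ``i``.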


DEFAULT_CURRICULUM_SCHEDULE: List[tuple] = [
    ("easy", 0.50),
    ("medium", 0.30),
    ("hard", 0.20),
]


def curriculum_difficulty_for(
    idx: int,
    n_prompts: int,
    schedule: Optional[List[tuple]] = None,
) -> str:
    """Map an episode index to a difficulty using a deterministic ramp.

    A simple "easy first → harder later" schedule (FAQ Q14, help-guide §6)
    is enough to keep early-training success rate non-zero, which is the
    whole point of curriculum: the policy must occasionally see positive
    reward before RL can move probability mass toward it.
    """
    sched = schedule or DEFAULT_CURRICULUM_SCHEDULE
    boundaries: List[tuple] = []
    cumulative = 0.0
    for diff, frac in sched:
        cumulative += frac
        boundaries.append((diff, cumulative * n_prompts))
    for diff, upper in boundaries:
        if idx < upper:
            return diff
    return boundaries[-1][0]
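

# Worked example with the default 50/30/20 schedule and n_prompts=10: the
# cumulative boundaries land at 5.0, 8.0 and 10.0, so indices 0-4 map to
# "easy", 5-7 to "medium", and 8-9 to "hard".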


def build_dataset(
    *,
    tokenizer,
    n_prompts: int,
    seed: int,
    scenario: Optional[str],
    difficulty: Optional[str],
    curriculum: bool = False,
    schedule: Optional[List[tuple]] = None,
) -> "Dataset":
    from datasets import Dataset

    env = CERNCollisionEnvironment()
    prompts: List[str] = []
    seeds: List[int] = []
    diffs: List[str] = []
    for i in range(n_prompts):
        ep_seed = seed + i
        ep_diff = (
            curriculum_difficulty_for(i, n_prompts, schedule)
            if curriculum else (difficulty or "easy")
        )
        obs = env.reset(seed=ep_seed, scenario=scenario, difficulty=ep_diff)
        chat = build_chat(obs)
        prompt = tokenizer.apply_chat_template(
            chat, add_generation_prompt=True, tokenize=False
        )
        prompts.append(prompt)
        seeds.append(ep_seed)
        diffs.append(ep_diff)
    return Dataset.from_dict({
        "prompt": prompts,
        "seed": seeds,
        "difficulty": diffs,
    })


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="HuggingFaceTB/SmolLM2-360M-Instruct")
    parser.add_argument("--scenario", default=None)
    parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
    parser.add_argument(
        "--curriculum",
        action="store_true",
        help="Build the prompt set with an easy→medium→hard ramp.",
    )
    parser.add_argument("--total_episodes", type=int, default=200)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max_steps", type=int, default=18)
    parser.add_argument("--num_generations", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--max_prompt_length", type=int, default=1024)
    parser.add_argument("--max_completion_length", type=int, default=256)
    parser.add_argument("--output_dir", default="training/grpo-output")
    args = parser.parse_args()

    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from trl import GRPOConfig, GRPOTrainer
    except ImportError as exc:
        raise SystemExit(
            "TRL (Transformer Reinforcement Learning) is required: "
            "pip install -r requirements-train.txt"
        ) from exc

    logger.info("Loading tokenizer + model: %s", args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=torch.float32,
    )

    logger.info(
        "Building prompt dataset (%d prompts, curriculum=%s)",
        args.total_episodes, args.curriculum,
    )
    dataset = build_dataset(
        tokenizer=tokenizer,
        n_prompts=args.total_episodes,
        seed=args.seed,
        scenario=args.scenario,
        difficulty=args.difficulty,
        curriculum=args.curriculum,
    )

    env = CERNCollisionEnvironment(max_steps=args.max_steps)
    ctx = EpisodeContext(
        env=env,
        seed=args.seed,
        scenario=args.scenario,
        difficulty=args.difficulty,
    )
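
    # To produce the per-component evidence CSV described above, one could
    # also wire in an accumulator here (sketch only, not done in this script):
    #     accumulator = RewardComponentAccumulator()
    #     reward_fn = make_reward_fn(ctx, accumulator=accumulator)
    # and register a callback that drains it on each log step.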
    reward_fn = make_reward_fn(ctx)

    cfg = GRPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        num_generations=args.num_generations,
        learning_rate=args.learning_rate,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        logging_steps=5,
        save_steps=50,
        seed=args.seed,
        bf16=False,
        fp16=False,
        report_to=[],
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        reward_funcs=[reward_fn],
        args=cfg,
    )
    logger.info("Starting GRPO training")
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    logger.info("Saved model to %s", args.output_dir)


if __name__ == "__main__":
    main()