"""GRPO (Group-Relative Policy Optimization) training script for CERNenv. Uses Hugging Face TRL (Transformer Reinforcement Learning) ``GRPOTrainer`` to fine-tune a small instruction-tuned model on full episodes of the CERN environment. Each ``query`` is a prompt sampled from a freshly-reset env; the reward function rolls the model's response through the environment and returns the per-step + (optional) terminal reward. This script is intentionally CPU-friendly and self-contained. For GPU-accelerated training with LoRA, prefer ``training_unsloth.py``. Run: python -m training.training_script \ --model_name HuggingFaceTB/SmolLM2-360M-Instruct \ --total_episodes 200 --max_steps 18 --output_dir training/grpo-output """ from __future__ import annotations import argparse import logging import math import os from dataclasses import dataclass from typing import Any, Dict, List, Optional import torch from datasets import Dataset from transformers import AutoModelForCausalLM, AutoTokenizer from models import ExperimentAction from server.environment import CERNCollisionEnvironment from training.llm_agent import ( LLMAgentConfig, build_chat, parse_action, safe_default_action, ) logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") logger = logging.getLogger(__name__) # ── Episode reward harness ─────────────────────────────────────────────── @dataclass class EpisodeContext: """Per-prompt reusable env + observation snapshot used by the reward fn.""" env: CERNCollisionEnvironment seed: int scenario: Optional[str] difficulty: Optional[str] def _stepwise_reward( *, completion_text: str, ctx: EpisodeContext, ) -> float: """Roll the model's first response through one full episode and return the cumulative reward (per-step + terminal). The completion is interpreted as the first action only; subsequent steps fall back to the safe default policy. This keeps the reward bandwidth high for early-exploration training without requiring multi-turn rollouts inside GRPO. 
""" env = ctx.env obs = env.reset(seed=ctx.seed, scenario=ctx.scenario, difficulty=ctx.difficulty) action = parse_action(completion_text) or safe_default_action(obs) obs = env.step(action) cumulative = float(obs.reward or 0.0) while not obs.done: fallback = safe_default_action(obs) obs = env.step(fallback) cumulative += float(obs.reward or 0.0) return cumulative def _format_validity_bonus(completion_text: str) -> float: return 0.5 if parse_action(completion_text) is not None else -0.5 def make_reward_fn(ctx: EpisodeContext): """Return a TRL-compatible reward function (closes over ``ctx``).""" def reward_fn(prompts: List[str], completions: List[str], **kwargs: Any) -> List[float]: rewards: List[float] = [] for completion in completions: r = _stepwise_reward(completion_text=completion, ctx=ctx) r += _format_validity_bonus(completion) rewards.append(float(r)) return rewards return reward_fn # ── Prompt dataset ─────────────────────────────────────────────────────── def build_dataset( *, tokenizer, n_prompts: int, seed: int, scenario: Optional[str], difficulty: Optional[str], ) -> Dataset: env = CERNCollisionEnvironment() prompts: List[str] = [] for i in range(n_prompts): obs = env.reset(seed=seed + i, scenario=scenario, difficulty=difficulty) chat = build_chat(obs) prompt = tokenizer.apply_chat_template( chat, add_generation_prompt=True, tokenize=False ) prompts.append(prompt) return Dataset.from_dict({"prompt": prompts}) # ── Main ───────────────────────────────────────────────────────────────── def main() -> None: # pragma: no cover - training entrypoint parser = argparse.ArgumentParser() parser.add_argument("--model_name", default="HuggingFaceTB/SmolLM2-360M-Instruct") parser.add_argument("--scenario", default=None) parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy") parser.add_argument("--total_episodes", type=int, default=200) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--max_steps", type=int, default=18) parser.add_argument("--num_generations", type=int, default=4) parser.add_argument("--learning_rate", type=float, default=1e-5) parser.add_argument("--max_prompt_length", type=int, default=1024) parser.add_argument("--max_completion_length", type=int, default=256) parser.add_argument("--output_dir", default="training/grpo-output") args = parser.parse_args() try: from trl import GRPOConfig, GRPOTrainer except ImportError as exc: # pragma: no cover raise SystemExit( "TRL (Transformer Reinforcement Learning) is required: " "pip install -r requirements-train.txt" ) from exc logger.info("Loading tokenizer + model: %s", args.model_name) tokenizer = AutoTokenizer.from_pretrained(args.model_name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( args.model_name, torch_dtype=torch.float32, ) logger.info("Building prompt dataset (%d prompts)", args.total_episodes) dataset = build_dataset( tokenizer=tokenizer, n_prompts=args.total_episodes, seed=args.seed, scenario=args.scenario, difficulty=args.difficulty, ) env = CERNCollisionEnvironment(max_steps=args.max_steps) ctx = EpisodeContext( env=env, seed=args.seed, scenario=args.scenario, difficulty=args.difficulty, ) reward_fn = make_reward_fn(ctx) cfg = GRPOConfig( output_dir=args.output_dir, per_device_train_batch_size=2, gradient_accumulation_steps=2, num_generations=args.num_generations, learning_rate=args.learning_rate, max_prompt_length=args.max_prompt_length, max_completion_length=args.max_completion_length, 
        logging_steps=5,
        save_steps=50,
        seed=args.seed,
        bf16=False,
        fp16=False,
        report_to=[],
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        reward_funcs=[reward_fn],
        args=cfg,
    )

    logger.info("Starting GRPO training")
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    logger.info("Saved model to %s", args.output_dir)


if __name__ == "__main__":  # pragma: no cover
    main()
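
# Smoke-testing the saved checkpoint (a minimal sketch; the paths, seed, and
# generation length are illustrative, not produced by this script):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("training/grpo-output")
#   mdl = AutoModelForCausalLM.from_pretrained("training/grpo-output")
#   obs = CERNCollisionEnvironment().reset(seed=0)
#   ids = tok.apply_chat_template(build_chat(obs), add_generation_prompt=True,
#                                 return_tensors="pt")
#   print(tok.decode(mdl.generate(ids, max_new_tokens=256)[0]))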