| """GRPO (Group-Relative Policy Optimization) training script for CERNenv.
|
|
|
| Uses Hugging Face TRL (Transformer Reinforcement Learning) ``GRPOTrainer`` to
|
| fine-tune a small instruction-tuned model on full episodes of the CERN
|
| environment. Each ``query`` is a prompt sampled from a freshly-reset env;
|
the reward function plays the model's response as the first action of an
episode (later steps use a safe default policy) and returns the cumulative
per-step + (optional) terminal reward plus a small format-validity bonus.
|
|
|
| This script is intentionally CPU-friendly and self-contained. For
|
| GPU-accelerated training with LoRA, prefer ``training_unsloth.py``.
|
|
|
| Run:
|
| python -m training.training_script \
|
| --model_name HuggingFaceTB/SmolLM2-360M-Instruct \
|
| --total_episodes 200 --max_steps 18 --output_dir training/grpo-output
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import argparse
|
| import logging
|
| import math
|
| import os
|
| from dataclasses import dataclass
|
| from typing import Any, Dict, List, Optional
|
|
|
| import torch
|
| from datasets import Dataset
|
| from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
| from models import ExperimentAction
|
| from server.environment import CERNCollisionEnvironment
|
| from training.llm_agent import (
|
| LLMAgentConfig,
|
| build_chat,
|
| parse_action,
|
| safe_default_action,
|
| )
|
|
|
|
|
# Module logger; the root handler below gives every record a timestamp + level.
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class EpisodeContext:
    """Environment handle plus the reset parameters used by the reward function.

    Holds no observation or episode state of its own: the same ``env``
    instance is reused across reward calls, and each rollout resets it with
    ``seed`` / ``scenario`` / ``difficulty`` before stepping.
    """

    # Environment instance that reward rollouts are played in (shared, reset per rollout).
    env: CERNCollisionEnvironment
    # Seed forwarded to ``env.reset``.
    seed: int
    # Scenario forwarded to ``env.reset``; ``None`` presumably lets the env pick one — confirm.
    scenario: Optional[str]
    # Difficulty forwarded to ``env.reset`` (the CLI restricts it to easy/medium/hard).
    difficulty: Optional[str]
|
|
|
|
|
def _stepwise_reward(
    *,
    completion_text: str,
    ctx: EpisodeContext,
) -> float:
    """Score one completion by playing a whole episode from ``ctx``.

    The environment is reset with ``ctx.seed`` / ``ctx.scenario`` /
    ``ctx.difficulty``. The parsed completion becomes the opening action
    (``safe_default_action`` stands in when parsing yields a falsy value),
    and every later action comes from ``safe_default_action`` until the
    environment reports ``done``. Scoring only the model's first action
    this way yields a dense, full-episode signal without needing multi-turn
    rollouts inside GRPO.

    Returns:
        Sum of ``obs.reward`` over all steps, with falsy rewards counted as 0.
    """
    environment = ctx.env
    observation = environment.reset(
        seed=ctx.seed, scenario=ctx.scenario, difficulty=ctx.difficulty
    )
    next_action = parse_action(completion_text) or safe_default_action(observation)

    step_rewards: List[float] = []
    while True:
        observation = environment.step(next_action)
        step_rewards.append(float(observation.reward or 0.0))
        if observation.done:
            break
        next_action = safe_default_action(observation)

    # Accumulate left-to-right, seeded with the first step's reward.
    episode_return = step_rewards[0]
    for reward in step_rewards[1:]:
        episode_return += reward
    return episode_return
|
|
|
|
|
def _format_validity_bonus(completion_text: str) -> float:
    """Shape rewards toward well-formed output.

    Returns -0.5 when ``parse_action`` cannot produce an action from the
    text (returns ``None``), and +0.5 otherwise.
    """
    if parse_action(completion_text) is None:
        return -0.5
    return 0.5
|
|
|
|
|
def make_reward_fn(ctx: EpisodeContext):
    """Return a TRL-compatible reward function that closes over ``ctx``.

    For each completion the returned function plays one episode in
    ``ctx.env`` via ``_stepwise_reward`` and adds ``_format_validity_bonus``.

    Seeding: if the training dataset has a ``"seed"`` column, TRL's
    ``GRPOTrainer`` forwards it as ``kwargs["seed"]`` (one entry per
    completion). Each episode is then reset with that row's seed, so the
    reward is computed on the same episode the prompt was rendered from.
    Without the column, every episode uses ``ctx.seed``.

    Raises (from the returned function):
        ValueError: if a ``seed`` list is supplied whose length differs from
            the number of completions.
    """
    from dataclasses import replace

    def reward_fn(prompts: List[str], completions: List[str], **kwargs: Any) -> List[float]:
        seeds = kwargs.get("seed")
        if seeds is not None and len(seeds) != len(completions):
            raise ValueError(
                f"expected one seed per completion, got {len(seeds)} seeds "
                f"for {len(completions)} completions"
            )

        rewards: List[float] = []
        for idx, completion in enumerate(completions):
            # Score against the episode this prompt came from rather than a
            # single fixed seed; the shared env instance is kept.
            episode_ctx = ctx if seeds is None else replace(ctx, seed=int(seeds[idx]))
            reward = _stepwise_reward(completion_text=completion, ctx=episode_ctx)
            reward += _format_validity_bonus(completion)
            rewards.append(float(reward))
        return rewards

    return reward_fn
|
|
|
|
|
|
|
|
|
|
|
def build_dataset(
    *,
    tokenizer,
    n_prompts: int,
    seed: int,
    scenario: Optional[str],
    difficulty: Optional[str],
) -> Dataset:
    """Render ``n_prompts`` chat prompts from freshly-reset environments.

    Prompt ``i`` is built from the observation returned by
    ``env.reset(seed=seed + i, scenario=scenario, difficulty=difficulty)`` and
    formatted with the tokenizer's chat template (generation prompt appended,
    untokenized).

    The dataset has two columns: ``"prompt"`` (the rendered text) and
    ``"seed"`` (the env seed that produced it). TRL's ``GRPOTrainer`` passes
    extra dataset columns to reward functions as keyword arguments, so a
    reward function can reset the env to the exact episode the prompt
    describes. Reward functions that ignore ``**kwargs`` are unaffected.

    Raises:
        ValueError: if ``n_prompts`` is less than 1, since an empty dataset
            cannot be trained on.
    """
    if n_prompts < 1:
        raise ValueError(f"n_prompts must be >= 1, got {n_prompts}")

    # NOTE(review): uses the env's default max_steps, while the reward env in
    # main() is built with --max_steps; confirm observations don't depend on it.
    env = CERNCollisionEnvironment()
    prompts: List[str] = []
    seeds: List[int] = []
    for episode_seed in range(seed, seed + n_prompts):
        obs = env.reset(seed=episode_seed, scenario=scenario, difficulty=difficulty)
        prompt = tokenizer.apply_chat_template(
            build_chat(obs), add_generation_prompt=True, tokenize=False
        )
        prompts.append(prompt)
        seeds.append(episode_seed)
    return Dataset.from_dict({"prompt": prompts, "seed": seeds})
|
|
|
|
|
|
|
|
|
|
|
| def main() -> None:
|
| parser = argparse.ArgumentParser()
|
| parser.add_argument("--model_name", default="HuggingFaceTB/SmolLM2-360M-Instruct")
|
| parser.add_argument("--scenario", default=None)
|
| parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
|
| parser.add_argument("--total_episodes", type=int, default=200)
|
| parser.add_argument("--seed", type=int, default=42)
|
| parser.add_argument("--max_steps", type=int, default=18)
|
| parser.add_argument("--num_generations", type=int, default=4)
|
| parser.add_argument("--learning_rate", type=float, default=1e-5)
|
| parser.add_argument("--max_prompt_length", type=int, default=1024)
|
| parser.add_argument("--max_completion_length", type=int, default=256)
|
| parser.add_argument("--output_dir", default="training/grpo-output")
|
| args = parser.parse_args()
|
|
|
| try:
|
| from trl import GRPOConfig, GRPOTrainer
|
| except ImportError as exc:
|
| raise SystemExit(
|
| "TRL (Transformer Reinforcement Learning) is required: "
|
| "pip install -r requirements-train.txt"
|
| ) from exc
|
|
|
| logger.info("Loading tokenizer + model: %s", args.model_name)
|
| tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
| if tokenizer.pad_token is None:
|
| tokenizer.pad_token = tokenizer.eos_token
|
| model = AutoModelForCausalLM.from_pretrained(
|
| args.model_name,
|
| torch_dtype=torch.float32,
|
| )
|
|
|
| logger.info("Building prompt dataset (%d prompts)", args.total_episodes)
|
| dataset = build_dataset(
|
| tokenizer=tokenizer,
|
| n_prompts=args.total_episodes,
|
| seed=args.seed,
|
| scenario=args.scenario,
|
| difficulty=args.difficulty,
|
| )
|
|
|
| env = CERNCollisionEnvironment(max_steps=args.max_steps)
|
| ctx = EpisodeContext(
|
| env=env,
|
| seed=args.seed,
|
| scenario=args.scenario,
|
| difficulty=args.difficulty,
|
| )
|
| reward_fn = make_reward_fn(ctx)
|
|
|
| cfg = GRPOConfig(
|
| output_dir=args.output_dir,
|
| per_device_train_batch_size=2,
|
| gradient_accumulation_steps=2,
|
| num_generations=args.num_generations,
|
| learning_rate=args.learning_rate,
|
| max_prompt_length=args.max_prompt_length,
|
| max_completion_length=args.max_completion_length,
|
| logging_steps=5,
|
| save_steps=50,
|
| seed=args.seed,
|
| bf16=False,
|
| fp16=False,
|
| report_to=[],
|
| )
|
|
|
| trainer = GRPOTrainer(
|
| model=model,
|
| processing_class=tokenizer,
|
| train_dataset=dataset,
|
| reward_funcs=[reward_fn],
|
| args=cfg,
|
| )
|
| logger.info("Starting GRPO training")
|
| trainer.train()
|
| trainer.save_model(args.output_dir)
|
| tokenizer.save_pretrained(args.output_dir)
|
| logger.info("Saved model to %s", args.output_dir)
|
|
|
|
|
| if __name__ == "__main__":
|
| main()
|
|
|