| """GRPO (Group-Relative Policy Optimization) training script for CERNenv. | |
| Uses Hugging Face TRL (Transformer Reinforcement Learning) ``GRPOTrainer`` to | |
| fine-tune a small instruction-tuned model on full episodes of the CERN | |
| environment. Each ``query`` is a prompt sampled from a freshly-reset env; | |
| the reward function rolls the model's response through the environment and | |
| returns the per-step + (optional) terminal reward. | |
| This script is intentionally CPU-friendly and self-contained. For | |
| GPU-accelerated training with LoRA, prefer ``training_unsloth.py``. | |
| Run: | |
| python -m training.training_script \ | |
| --model_name HuggingFaceTB/SmolLM2-360M-Instruct \ | |
| --total_episodes 200 --max_steps 18 --output_dir training/grpo-output | |
| """ | |

from __future__ import annotations

import argparse
import logging

from dataclasses import dataclass
from typing import Any, List, Optional

import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from server.environment import CERNCollisionEnvironment
from training.llm_agent import (
    build_chat,
    parse_action,
    safe_default_action,
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)


# ── Episode reward harness ─────────────────────────────────────────────────


@dataclass
class EpisodeContext:
    """Per-prompt reusable env + observation snapshot used by the reward fn."""

    env: CERNCollisionEnvironment
    seed: int
    scenario: Optional[str]
    difficulty: Optional[str]
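

# A single environment instance is reused across reward calls; each call
# re-seeds it via ``reset``, so episodes stay independent (assuming the
# environment's reset fully restores initial state).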


def _stepwise_reward(
    *,
    completion_text: str,
    ctx: EpisodeContext,
) -> float:
    """Roll the model's first response through one full episode and
    return the cumulative reward (per-step + terminal).

    The completion is interpreted as the first action only; subsequent
    steps fall back to the safe default policy. This keeps the reward
    bandwidth high for early-exploration training without requiring
    multi-turn rollouts inside GRPO.
    """
    env = ctx.env
    obs = env.reset(seed=ctx.seed, scenario=ctx.scenario, difficulty=ctx.difficulty)
    action = parse_action(completion_text) or safe_default_action(obs)
    obs = env.step(action)
    cumulative = float(obs.reward or 0.0)
    while not obs.done:
        fallback = safe_default_action(obs)
        obs = env.step(fallback)
        cumulative += float(obs.reward or 0.0)
    return cumulative
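

# Illustrative arithmetic (values hypothetical): if the parsed first action
# earns 0.3, the fallback steps accumulate a further 1.0, and the terminal
# reward adds 0.5, ``_stepwise_reward`` returns 0.3 + 1.0 + 0.5 = 1.8.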


def _format_validity_bonus(completion_text: str) -> float:
    """Small shaping term: +0.5 for a parseable action, -0.5 otherwise, so
    format compliance is rewarded even when the episode return is flat."""
    return 0.5 if parse_action(completion_text) is not None else -0.5


def make_reward_fn(ctx: EpisodeContext):
    """Return a TRL-compatible reward function (closes over ``ctx``)."""

    def reward_fn(prompts: List[str], completions: List[str], **kwargs: Any) -> List[float]:
        rewards: List[float] = []
        for completion in completions:
            r = _stepwise_reward(completion_text=completion, ctx=ctx)
            r += _format_validity_bonus(completion)
            rewards.append(float(r))
        return rewards

    return reward_fn
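

# Sketch of how GRPOTrainer invokes the reward function (the strings below are
# hypothetical placeholders, not the real prompt/response format):
#
#   rf = make_reward_fn(ctx)
#   rf(prompts=["<templated obs>"], completions=["<model response>"])
#   # -> one float per completion: episode return plus the ±0.5 format bonus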


# ── Prompt dataset ─────────────────────────────────────────────────────────


def build_dataset(
    *,
    tokenizer,
    n_prompts: int,
    seed: int,
    scenario: Optional[str],
    difficulty: Optional[str],
) -> Dataset:
    env = CERNCollisionEnvironment()
    prompts: List[str] = []
    for i in range(n_prompts):
        obs = env.reset(seed=seed + i, scenario=scenario, difficulty=difficulty)
        chat = build_chat(obs)
        prompt = tokenizer.apply_chat_template(
            chat, add_generation_prompt=True, tokenize=False
        )
        prompts.append(prompt)
    return Dataset.from_dict({"prompt": prompts})
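

# Each row holds one templated chat string under the "prompt" column, the
# column name GRPOTrainer reads. Prompts vary by seed (seed, seed + 1, ...),
# while ``_stepwise_reward`` always resets with the fixed ``ctx.seed``, so
# every scored rollout starts from the same initial state regardless of which
# prompt produced the completion.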


# ── Main ───────────────────────────────────────────────────────────────────


def main() -> None:  # pragma: no cover - training entrypoint
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="HuggingFaceTB/SmolLM2-360M-Instruct")
    parser.add_argument("--scenario", default=None)
    parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
    parser.add_argument("--total_episodes", type=int, default=200)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max_steps", type=int, default=18)
    parser.add_argument("--num_generations", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--max_prompt_length", type=int, default=1024)
    parser.add_argument("--max_completion_length", type=int, default=256)
    parser.add_argument("--output_dir", default="training/grpo-output")
    args = parser.parse_args()

    try:
        from trl import GRPOConfig, GRPOTrainer
    except ImportError as exc:  # pragma: no cover
        raise SystemExit(
            "TRL (Transformer Reinforcement Learning) is required: "
            "pip install -r requirements-train.txt"
        ) from exc
| logger.info("Loading tokenizer + model: %s", args.model_name) | |
| tokenizer = AutoTokenizer.from_pretrained(args.model_name) | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| model = AutoModelForCausalLM.from_pretrained( | |
| args.model_name, | |
| torch_dtype=torch.float32, | |
| ) | |
| logger.info("Building prompt dataset (%d prompts)", args.total_episodes) | |
| dataset = build_dataset( | |
| tokenizer=tokenizer, | |
| n_prompts=args.total_episodes, | |
| seed=args.seed, | |
| scenario=args.scenario, | |
| difficulty=args.difficulty, | |
| ) | |
| env = CERNCollisionEnvironment(max_steps=args.max_steps) | |
| ctx = EpisodeContext( | |
| env=env, | |
| seed=args.seed, | |
| scenario=args.scenario, | |
| difficulty=args.difficulty, | |
| ) | |
| reward_fn = make_reward_fn(ctx) | |
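
    # With per_device_train_batch_size=2 and gradient_accumulation_steps=2,
    # each optimizer step consumes 4 completions, matching the default
    # num_generations so a step covers one full GRPO group. Recent TRL
    # releases enforce that the effective batch size divides evenly by
    # num_generations; if GRPOTrainer raises a divisibility error, adjust
    # these values (behaviour is version-dependent).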
    cfg = GRPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        num_generations=args.num_generations,
        learning_rate=args.learning_rate,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        logging_steps=5,
        save_steps=50,
        seed=args.seed,
        bf16=False,
        fp16=False,
        report_to=[],
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        reward_funcs=[reward_fn],
        args=cfg,
    )

    logger.info("Starting GRPO training")
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    logger.info("Saved model to %s", args.output_dir)


if __name__ == "__main__":  # pragma: no cover
    main()