"""GRPO (Group-Relative Policy Optimization) training script for CERNenv.
Uses Hugging Face TRL (Transformer Reinforcement Learning) ``GRPOTrainer`` to
fine-tune a small instruction-tuned model on full episodes of the CERN
environment. Each ``prompt`` is sampled from a freshly-reset env; the reward
function replays the model's response through that same seeded episode and
returns the cumulative per-step + (optional) terminal reward.
This script is intentionally CPU-friendly and self-contained. For
GPU-accelerated training with LoRA, prefer ``training_unsloth.py``.
Run:
python -m training.training_script \
--model_name HuggingFaceTB/SmolLM2-360M-Instruct \
--total_episodes 200 --max_steps 18 --output_dir training/grpo-output
"""
from __future__ import annotations
import argparse
import logging
from dataclasses import dataclass
from typing import Any, List, Optional
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from server.environment import CERNCollisionEnvironment
from training.llm_agent import (
LLMAgentConfig,
build_chat,
parse_action,
safe_default_action,
)
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
# ── Episode reward harness ───────────────────────────────────────────────
@dataclass
class EpisodeContext:
"""Per-prompt reusable env + observation snapshot used by the reward fn."""
env: CERNCollisionEnvironment
seed: int
scenario: Optional[str]
difficulty: Optional[str]
def _stepwise_reward(
    *,
    completion_text: str,
    ctx: EpisodeContext,
    seed: Optional[int] = None,
) -> float:
"""Roll the model's first response through one full episode and
return the cumulative reward (per-step + terminal).
The completion is interpreted as the first action only; subsequent
steps fall back to the safe default policy. This keeps the reward
bandwidth high for early-exploration training without requiring
multi-turn rollouts inside GRPO.
"""
env = ctx.env
    obs = env.reset(
        seed=ctx.seed if seed is None else seed,
        scenario=ctx.scenario,
        difficulty=ctx.difficulty,
    )
action = parse_action(completion_text) or safe_default_action(obs)
obs = env.step(action)
cumulative = float(obs.reward or 0.0)
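    # Finish the episode with the deterministic fallback policy so every
    # completion is scored on a complete trajectory.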
while not obs.done:
fallback = safe_default_action(obs)
obs = env.step(fallback)
cumulative += float(obs.reward or 0.0)
return cumulative
def _format_validity_bonus(completion_text: str) -> float:
    """Small shaping term: reward parseable actions, penalise malformed output."""
    return 0.5 if parse_action(completion_text) is not None else -0.5
def make_reward_fn(ctx: EpisodeContext):
"""Return a TRL-compatible reward function (closes over ``ctx``)."""
    def reward_fn(prompts: List[str], completions: List[str], **kwargs: Any) -> List[float]:
        # GRPOTrainer forwards extra dataset columns (here ``seed``) as keyword
        # arguments aligned with ``completions``; fall back to the shared seed
        # when the column is missing.
        seeds = kwargs.get("seed") or [ctx.seed] * len(completions)
        rewards: List[float] = []
        for completion, ep_seed in zip(completions, seeds):
            r = _stepwise_reward(completion_text=completion, ctx=ctx, seed=ep_seed)
            r += _format_validity_bonus(completion)
            rewards.append(float(r))
        return rewards
return reward_fn
# ── Prompt dataset ───────────────────────────────────────────────────────
def build_dataset(
*,
tokenizer,
n_prompts: int,
seed: int,
scenario: Optional[str],
difficulty: Optional[str],
) -> Dataset:
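    """Sample ``n_prompts`` fresh observations and render each one as a chat prompt."""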
env = CERNCollisionEnvironment()
    prompts: List[str] = []
    seeds: List[int] = []
for i in range(n_prompts):
obs = env.reset(seed=seed + i, scenario=scenario, difficulty=difficulty)
chat = build_chat(obs)
prompt = tokenizer.apply_chat_template(
chat, add_generation_prompt=True, tokenize=False
)
        prompts.append(prompt)
        seeds.append(seed + i)
    # The ``seed`` column is forwarded to the reward function by GRPOTrainer, so
    # each completion is scored against the exact episode its prompt came from.
    return Dataset.from_dict({"prompt": prompts, "seed": seeds})
# ── Main ─────────────────────────────────────────────────────────────────
def main() -> None: # pragma: no cover - training entrypoint
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="HuggingFaceTB/SmolLM2-360M-Instruct")
parser.add_argument("--scenario", default=None)
parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
parser.add_argument("--total_episodes", type=int, default=200)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--max_steps", type=int, default=18)
parser.add_argument("--num_generations", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=1e-5)
parser.add_argument("--max_prompt_length", type=int, default=1024)
parser.add_argument("--max_completion_length", type=int, default=256)
parser.add_argument("--output_dir", default="training/grpo-output")
args = parser.parse_args()
try:
from trl import GRPOConfig, GRPOTrainer
except ImportError as exc: # pragma: no cover
raise SystemExit(
"TRL (Transformer Reinforcement Learning) is required: "
"pip install -r requirements-train.txt"
) from exc
logger.info("Loading tokenizer + model: %s", args.model_name)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
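    # Small instruct models often ship without a pad token; reuse EOS for padding.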
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
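    # Full fp32 keeps the run CPU-friendly; use a lower-precision dtype on GPU.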
model = AutoModelForCausalLM.from_pretrained(
args.model_name,
torch_dtype=torch.float32,
)
logger.info("Building prompt dataset (%d prompts)", args.total_episodes)
dataset = build_dataset(
tokenizer=tokenizer,
n_prompts=args.total_episodes,
seed=args.seed,
scenario=args.scenario,
difficulty=args.difficulty,
)
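    # One reusable env instance; the reward fn re-seeds it for every rollout.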
env = CERNCollisionEnvironment(max_steps=args.max_steps)
ctx = EpisodeContext(
env=env,
seed=args.seed,
scenario=args.scenario,
difficulty=args.difficulty,
)
reward_fn = make_reward_fn(ctx)
cfg = GRPOConfig(
output_dir=args.output_dir,
        # TRL requires the effective generation batch to be divisible by
        # ``num_generations``; matching them keeps single-device runs valid.
        per_device_train_batch_size=args.num_generations,
        gradient_accumulation_steps=2,
        num_generations=args.num_generations,
learning_rate=args.learning_rate,
max_prompt_length=args.max_prompt_length,
max_completion_length=args.max_completion_length,
logging_steps=5,
save_steps=50,
seed=args.seed,
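        # Keep mixed precision off so the defaults work on CPU-only machines.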
bf16=False,
fp16=False,
report_to=[],
)
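    # When multiple reward functions are supplied, GRPOTrainer sums their outputs;
    # here a single function returns the full episodic reward plus a format bonus.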
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
train_dataset=dataset,
reward_funcs=[reward_fn],
args=cfg,
)
logger.info("Starting GRPO training")
trainer.train()
trainer.save_model(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
logger.info("Saved model to %s", args.output_dir)
if __name__ == "__main__": # pragma: no cover
main()