# Source: cernenv / training / training_script.py
# (Hugging Face Space "CERNenv", commit f28409b "Update CERNenv Space" by anugrahhu)
"""GRPO (Group-Relative Policy Optimization) training script for CERNenv.
Uses Hugging Face TRL (Transformer Reinforcement Learning) ``GRPOTrainer`` to
fine-tune a small instruction-tuned model on full episodes of the CERN
environment. Each ``query`` is a prompt sampled from a freshly-reset env;
the reward function rolls the model's response through the environment and
returns the per-step + (optional) terminal reward.
This script is intentionally CPU-friendly and self-contained. For
GPU-accelerated training with LoRA, prefer ``training_unsloth.py``.
Run:
python -m training.training_script \
--model_name HuggingFaceTB/SmolLM2-360M-Instruct \
--total_episodes 200 --max_steps 18 --output_dir training/grpo-output
"""
from __future__ import annotations
import argparse
import logging
import math
import os
import threading
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, TYPE_CHECKING
# Heavy ML deps (torch, datasets, transformers) are imported lazily inside
# ``main`` and ``build_dataset`` so the lightweight helpers — reward
# function, curriculum schedule, format-validity bonus — remain importable
# in environments that only have the env's runtime dependencies (numpy,
# pydantic, openenv-core). This keeps ``tests/`` runnable on CPU.
from models import ExperimentAction
from server.environment import CERNCollisionEnvironment
from training.llm_agent import (
LLMAgentConfig,
build_chat,
parse_action,
safe_default_action,
)
if TYPE_CHECKING: # pragma: no cover
from datasets import Dataset
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
# ── Episode reward harness ───────────────────────────────────────────────
@dataclass
class EpisodeContext:
    """Per-prompt reusable env + default rollout config.

    ``seed`` and ``difficulty`` here are *fallback* values used when the
    TRL reward function does not receive per-prompt overrides via dataset
    columns. With a curriculum-aware dataset we always pass per-prompt
    ``seed``/``difficulty`` so the reward truly corresponds to the
    scored prompt.
    """
    # Shared environment instance; it is reset at the start of every rollout,
    # so reuse across prompts is safe as long as rollouts are not concurrent.
    env: CERNCollisionEnvironment
    # Fallback reset seed, used when a dataset row carries no "seed" column.
    seed: int
    # Optional scenario name forwarded to ``env.reset`` (None = env default).
    scenario: Optional[str]
    # Fallback difficulty, used when a dataset row carries no "difficulty".
    difficulty: Optional[str]
@dataclass
class EpisodeStats:
    """Per-rollout reward breakdown surfaced for component-level logging.

    The hackathon FAQ (Q17, Q43, Q52) repeatedly warns: "watch individual
    reward function columns, not just average reward". This struct gives
    the EvidenceCallback enough information to log each component on its
    own column so a reviewer (or you) can see *which* reward terms drove
    the policy update at any given training step.
    """
    # Total reward over the whole episode (per-step shaping + terminal).
    cumulative_reward: float = 0.0
    # Terminal-only component, read from ``env.state.terminal_reward``.
    terminal_reward: float = 0.0
    step_shaping: float = 0.0  # cumulative_reward - terminal_reward
    # Success flags copied from the final env state (False when unset).
    discovered: bool = False
    correct_mass: bool = False
    correct_channel: bool = False
    correct_spin: bool = False
    # True when the model's completion parsed into a structured action.
    parsed_ok: bool = False
    # Number of ``env.step`` calls taken during the rollout.
    n_steps: int = 0
    # Difficulty the env actually used for this episode.
    difficulty: Optional[str] = None
def _stepwise_reward(
    *,
    completion_text: str,
    ctx: EpisodeContext,
    seed: Optional[int] = None,
    difficulty: Optional[str] = None,
    scenario: Optional[str] = None,
    out_stats: Optional[EpisodeStats] = None,
) -> float:
    """Score one completion by rolling it through a full episode.

    The completion supplies only the *first* action; every subsequent
    step falls back to the safe default policy, which keeps the reward
    bandwidth high for early-exploration training without multi-turn
    rollouts inside GRPO. Returns the cumulative reward (per-step
    shaping + terminal).

    Explicit ``seed``/``difficulty``/``scenario`` arguments override the
    fallbacks stored on ``ctx``. When ``out_stats`` is given, it is
    populated in-place with a per-rollout breakdown (terminal vs shaping
    reward, success flags, parse status) so the caller can stream
    component-level metrics into the evidence log.
    """
    environment = ctx.env
    obs = environment.reset(
        seed=ctx.seed if seed is None else seed,
        scenario=ctx.scenario if scenario is None else scenario,
        difficulty=ctx.difficulty if difficulty is None else difficulty,
    )
    parsed = parse_action(completion_text)
    first_action = parsed if parsed else safe_default_action(obs)
    obs = environment.step(first_action)
    total = float(obs.reward or 0.0)
    step_count = 1
    # Finish the episode with the conservative default policy.
    while not obs.done:
        obs = environment.step(safe_default_action(obs))
        total += float(obs.reward or 0.0)
        step_count += 1
    if out_stats is not None:
        final_state = environment.state
        terminal = float(final_state.terminal_reward or 0.0)
        out_stats.cumulative_reward = total
        out_stats.terminal_reward = terminal
        out_stats.step_shaping = total - terminal
        # bool(None) is False, so unset flags collapse to False here.
        out_stats.discovered = bool(final_state.discovered)
        out_stats.correct_mass = bool(final_state.correct_mass)
        out_stats.correct_channel = bool(final_state.correct_channel)
        out_stats.correct_spin = bool(final_state.correct_spin)
        out_stats.parsed_ok = parsed is not None
        out_stats.n_steps = step_count
        out_stats.difficulty = final_state.difficulty
    return total
# ── Reward-component accumulator (used by EvidenceCallback) ──────────────
class RewardComponentAccumulator:
    """Thread-safe rolling buffer of per-rollout ``EpisodeStats``.

    The reward function appends to this; the EvidenceCallback drains it
    on each ``on_log`` and writes one summary row to
    ``evidence/reward_components.csv``. Pairing each row with the
    matching GRPO ``state.global_step`` lets per-component reward curves
    be plotted *aligned* with the loss curve.
    """

    def __init__(self) -> None:
        # The lock guards ``_buf``: appends may race with drains.
        self._lock = threading.Lock()
        self._buf: List[EpisodeStats] = []

    def append(self, stats: EpisodeStats) -> None:
        """Record one rollout's stats (thread-safe)."""
        with self._lock:
            self._buf.append(stats)

    def drain(self) -> List[EpisodeStats]:
        """Atomically take ownership of — and clear — the buffer."""
        with self._lock:
            drained = self._buf
            self._buf = []
        return drained

    @staticmethod
    def summarise(stats: List[EpisodeStats]) -> Dict[str, float]:
        """Aggregate a batch of rollout stats into flat summary columns."""
        column_names = [
            "n", "mean_cumulative", "mean_terminal", "mean_step_shaping",
            "discovered_rate", "mass_correct_rate", "channel_correct_rate",
            "spin_correct_rate", "parsed_rate", "mean_n_steps",
        ]
        n = len(stats)
        if n == 0:
            # All-zero row with the same key order as the populated case.
            empty = dict.fromkeys(column_names, 0.0)
            empty["n"] = 0
            return empty

        def mean(values) -> float:
            return sum(values) / n

        return {
            "n": n,
            "mean_cumulative": mean(s.cumulative_reward for s in stats),
            "mean_terminal": mean(s.terminal_reward for s in stats),
            "mean_step_shaping": mean(s.step_shaping for s in stats),
            "discovered_rate": mean(1 for s in stats if s.discovered) if any(s.discovered for s in stats) else 0.0,
            "mass_correct_rate": sum(1 for s in stats if s.correct_mass) / n,
            "channel_correct_rate": sum(1 for s in stats if s.correct_channel) / n,
            "spin_correct_rate": sum(1 for s in stats if s.correct_spin) / n,
            "parsed_rate": sum(1 for s in stats if s.parsed_ok) / n,
            "mean_n_steps": mean(s.n_steps for s in stats),
        }
FORMAT_BONUS_VALID = 0.15
FORMAT_BONUS_INVALID = -0.20


def _format_validity_bonus(completion_text: str) -> float:
    """Small ± nudge for emitting a structured action.

    Kept intentionally small (≪ terminal_scale) so the policy can't be
    dominated by a "spam well-formed JSON" objective. The penalty branch
    is slightly larger in magnitude than the bonus so unparseable
    garbage is dispreferred without crowding out the task reward.
    """
    if parse_action(completion_text) is None:
        return FORMAT_BONUS_INVALID
    return FORMAT_BONUS_VALID
def make_reward_fn(
    ctx: EpisodeContext,
    accumulator: Optional[RewardComponentAccumulator] = None,
):
    """Return a TRL-compatible reward function.

    TRL forwards extra dataset columns (e.g. ``seed``, ``difficulty``)
    as ``kwargs`` aligned 1-to-1 with ``prompts``/``completions``; those
    per-row values are used here so the rollout that scores completion
    ``i`` matches the prompt that produced it, which also unlocks
    curriculum training.

    When ``accumulator`` is given, every rollout's ``EpisodeStats`` is
    appended to it so the trainer's ``on_log`` callback can flush a
    per-component summary into the evidence CSV — the "watch individual
    reward function columns" view recommended in the hackathon FAQ.
    """

    def reward_fn(
        prompts: List[str],
        completions: List[str],
        **kwargs: Any,
    ) -> List[float]:
        seed_col = kwargs.get("seed")
        diff_col = kwargs.get("difficulty")
        scen_col = kwargs.get("scenario")
        scores: List[float] = []
        for i, text in enumerate(completions):
            stats = EpisodeStats() if accumulator is not None else None
            episode_total = _stepwise_reward(
                completion_text=text,
                ctx=ctx,
                seed=None if seed_col is None else int(seed_col[i]),
                difficulty=None if diff_col is None else diff_col[i],
                scenario=None if scen_col is None else scen_col[i],
                out_stats=stats,
            )
            # Small format nudge layered on top of the environment reward.
            scores.append(float(episode_total + _format_validity_bonus(text)))
            if accumulator is not None and stats is not None:
                accumulator.append(stats)
        return scores

    return reward_fn
# ── Prompt dataset ───────────────────────────────────────────────────────
DEFAULT_CURRICULUM_SCHEDULE: List[tuple] = [
    ("easy", 0.50),
    ("medium", 0.30),
    ("hard", 0.20),
]


def curriculum_difficulty_for(
    idx: int,
    n_prompts: int,
    schedule: Optional[List[tuple]] = None,
) -> str:
    """Map an episode index to a difficulty using a deterministic ramp.

    A simple "easy first → harder later" schedule (FAQ Q14, help-guide §6)
    keeps early-training success rate non-zero, which is the whole point
    of curriculum: the policy must occasionally see positive reward
    before RL can move probability mass toward it.

    ``schedule`` is a list of ``(difficulty, fraction)`` pairs; the
    fractions partition ``[0, n_prompts)`` in order. Indices past the
    last boundary (e.g. from float rounding) get the last difficulty.
    """
    ramp = schedule or DEFAULT_CURRICULUM_SCHEDULE
    threshold = 0.0
    for name, fraction in ramp:
        threshold += fraction
        if idx < threshold * n_prompts:
            return name
    # Fallback for indices beyond the final boundary.
    return ramp[-1][0]
def build_dataset(
    *,
    tokenizer,
    n_prompts: int,
    seed: int,
    scenario: Optional[str],
    difficulty: Optional[str],
    curriculum: bool = False,
    schedule: Optional[List[tuple]] = None,
) -> "Dataset":
    """Build the GRPO prompt dataset by resetting a fresh env per prompt.

    Each row carries the rendered chat ``prompt`` plus the ``seed`` and
    ``difficulty`` used to produce it, so the reward function can replay
    the exact same episode when scoring completions. When ``curriculum``
    is True, difficulties follow the easy→medium→hard ramp; otherwise
    every prompt uses ``difficulty`` (defaulting to "easy").
    """
    from datasets import Dataset  # lazy: heavy import path

    env = CERNCollisionEnvironment()
    columns: Dict[str, List[Any]] = {"prompt": [], "seed": [], "difficulty": []}
    for idx in range(n_prompts):
        episode_seed = seed + idx
        if curriculum:
            episode_diff = curriculum_difficulty_for(idx, n_prompts, schedule)
        else:
            episode_diff = difficulty or "easy"
        obs = env.reset(seed=episode_seed, scenario=scenario, difficulty=episode_diff)
        rendered = tokenizer.apply_chat_template(
            build_chat(obs), add_generation_prompt=True, tokenize=False
        )
        columns["prompt"].append(rendered)
        columns["seed"].append(episode_seed)
        columns["difficulty"].append(episode_diff)
    return Dataset.from_dict(columns)
# ── Main ─────────────────────────────────────────────────────────────────
def main() -> None:  # pragma: no cover - training entrypoint
    """Parse CLI args, build the prompt dataset, and run GRPO training."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="HuggingFaceTB/SmolLM2-360M-Instruct")
    parser.add_argument("--scenario", default=None)
    parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
    parser.add_argument(
        "--curriculum",
        action="store_true",
        help="Build the prompt set with an easy→medium→hard ramp.",
    )
    parser.add_argument("--total_episodes", type=int, default=200)
    parser.add_argument("--seed", type=int, default=42)
    # NOTE: --max_steps caps the *environment* episode length (passed to
    # CERNCollisionEnvironment below), not the number of trainer steps.
    parser.add_argument("--max_steps", type=int, default=18)
    parser.add_argument("--num_generations", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--max_prompt_length", type=int, default=1024)
    parser.add_argument("--max_completion_length", type=int, default=256)
    parser.add_argument("--output_dir", default="training/grpo-output")
    args = parser.parse_args()
    # Heavy ML stack imported lazily so the module stays importable with
    # only the env's runtime dependencies (see module docstring).
    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from trl import GRPOConfig, GRPOTrainer
    except ImportError as exc:  # pragma: no cover
        raise SystemExit(
            "TRL (Transformer Reinforcement Learning) is required: "
            "pip install -r requirements-train.txt"
        ) from exc
    logger.info("Loading tokenizer + model: %s", args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if tokenizer.pad_token is None:
        # Models without a dedicated pad token reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=torch.float32,  # full fp32: CPU-friendly; matches bf16/fp16=False below
    )
    logger.info(
        "Building prompt dataset (%d prompts, curriculum=%s)",
        args.total_episodes, args.curriculum,
    )
    dataset = build_dataset(
        tokenizer=tokenizer,
        n_prompts=args.total_episodes,
        seed=args.seed,
        scenario=args.scenario,
        difficulty=args.difficulty,
        curriculum=args.curriculum,
    )
    # One env instance is reused for every reward rollout (reset per episode).
    env = CERNCollisionEnvironment(max_steps=args.max_steps)
    ctx = EpisodeContext(
        env=env,
        seed=args.seed,
        scenario=args.scenario,
        difficulty=args.difficulty,
    )
    # NOTE(review): no RewardComponentAccumulator is passed here, so the
    # per-component evidence logging this module supports is not wired up
    # in this entrypoint — confirm whether that is intentional.
    reward_fn = make_reward_fn(ctx)
    cfg = GRPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        num_generations=args.num_generations,
        learning_rate=args.learning_rate,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        logging_steps=5,
        save_steps=50,
        seed=args.seed,
        bf16=False,
        fp16=False,
        report_to=[],  # no external experiment trackers
    )
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        reward_funcs=[reward_fn],
        args=cfg,
    )
    logger.info("Starting GRPO training")
    trainer.train()
    # Persist both model weights and tokenizer so the output dir is loadable.
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    logger.info("Saved model to %s", args.output_dir)
# Script entry: run GRPO training when invoked directly (python -m training.training_script).
if __name__ == "__main__":  # pragma: no cover
    main()