"""GRPO (Group-Relative Policy Optimization) training script for CERNenv.
Uses Hugging Face TRL (Transformer Reinforcement Learning) ``GRPOTrainer`` to
fine-tune a small instruction-tuned model on full episodes of the CERN
environment. Each ``query`` is a prompt sampled from a freshly-reset env;
the reward function rolls the model's response through the environment and
returns the per-step + (optional) terminal reward.
This script is intentionally CPU-friendly and self-contained. For
GPU-accelerated training with LoRA, prefer ``training_unsloth.py``.
Run:
python -m training.training_script \
--model_name HuggingFaceTB/SmolLM2-360M-Instruct \
--total_episodes 200 --max_steps 18 --output_dir training/grpo-output
"""
from __future__ import annotations

import argparse
import logging
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, TYPE_CHECKING
# Heavy ML deps (torch, datasets, transformers) are imported lazily inside
# ``main`` and ``build_dataset`` so the lightweight helpers — reward
# function, curriculum schedule, format-validity bonus — remain importable
# in environments that only have the env's runtime dependencies (numpy,
# pydantic, openenv-core). This keeps ``tests/`` runnable on CPU.
from models import ExperimentAction
from server.environment import CERNCollisionEnvironment
from training.llm_agent import (
LLMAgentConfig,
build_chat,
parse_action,
safe_default_action,
)
if TYPE_CHECKING: # pragma: no cover
from datasets import Dataset
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
# ── Episode reward harness ───────────────────────────────────────────────
@dataclass
class EpisodeContext:
"""Per-prompt reusable env + default rollout config.
``seed`` and ``difficulty`` here are *fallback* values used when the
TRL reward function does not receive per-prompt overrides via dataset
columns. With a curriculum-aware dataset we always pass per-prompt
``seed``/``difficulty`` so the reward truly corresponds to the
scored prompt.
"""
env: CERNCollisionEnvironment
seed: int
scenario: Optional[str]
difficulty: Optional[str]
@dataclass
class EpisodeStats:
"""Per-rollout reward breakdown surfaced for component-level logging.
The hackathon FAQ (Q17, Q43, Q52) repeatedly warns: "watch individual
reward function columns, not just average reward". This struct gives
the EvidenceCallback enough information to log each component on its
own column so a reviewer (or you) can see *which* reward terms drove
the policy update at any given training step.
"""
cumulative_reward: float = 0.0
terminal_reward: float = 0.0
step_shaping: float = 0.0 # cumulative_reward - terminal_reward
discovered: bool = False
correct_mass: bool = False
correct_channel: bool = False
correct_spin: bool = False
parsed_ok: bool = False
n_steps: int = 0
difficulty: Optional[str] = None
def _stepwise_reward(
*,
completion_text: str,
ctx: EpisodeContext,
seed: Optional[int] = None,
difficulty: Optional[str] = None,
scenario: Optional[str] = None,
out_stats: Optional[EpisodeStats] = None,
) -> float:
"""Roll the model's first response through one full episode and
return the cumulative reward (per-step + terminal).
The completion is interpreted as the first action only; subsequent
steps fall back to the safe default policy. This keeps the reward
bandwidth high for early-exploration training without requiring
multi-turn rollouts inside GRPO.
If ``out_stats`` is provided, it is populated in-place with a
per-rollout breakdown (terminal vs shaping reward, success flags)
so the caller can stream component-level metrics into the evidence
log instead of relying only on aggregate reward.
"""
env = ctx.env
obs = env.reset(
seed=seed if seed is not None else ctx.seed,
scenario=scenario if scenario is not None else ctx.scenario,
difficulty=difficulty if difficulty is not None else ctx.difficulty,
)
parsed = parse_action(completion_text)
action = parsed or safe_default_action(obs)
obs = env.step(action)
cumulative = float(obs.reward or 0.0)
n_steps = 1
while not obs.done:
fallback = safe_default_action(obs)
obs = env.step(fallback)
cumulative += float(obs.reward or 0.0)
n_steps += 1
if out_stats is not None:
st = env.state
terminal = float(st.terminal_reward or 0.0)
out_stats.cumulative_reward = cumulative
out_stats.terminal_reward = terminal
out_stats.step_shaping = cumulative - terminal
out_stats.discovered = bool(st.discovered) if st.discovered is not None else False
out_stats.correct_mass = bool(st.correct_mass) if st.correct_mass is not None else False
out_stats.correct_channel = (
bool(st.correct_channel) if st.correct_channel is not None else False
)
out_stats.correct_spin = bool(st.correct_spin) if st.correct_spin is not None else False
out_stats.parsed_ok = parsed is not None
out_stats.n_steps = n_steps
out_stats.difficulty = st.difficulty
return cumulative
# ── Reward-component accumulator (used by EvidenceCallback) ──────────────
class RewardComponentAccumulator:
"""Thread-safe rolling buffer of per-rollout ``EpisodeStats``.
The reward function appends to this; the EvidenceCallback drains it
on each ``on_log`` and writes one summary row to
``evidence/reward_components.csv``. By pairing each row with the
matching GRPO ``state.global_step``, we can plot per-component reward
curves *aligned* with the loss curve.
"""
def __init__(self) -> None:
self._lock = threading.Lock()
self._buf: List[EpisodeStats] = []
def append(self, stats: EpisodeStats) -> None:
with self._lock:
self._buf.append(stats)
def drain(self) -> List[EpisodeStats]:
with self._lock:
out, self._buf = self._buf, []
return out
@staticmethod
def summarise(stats: List[EpisodeStats]) -> Dict[str, float]:
if not stats:
return {
"n": 0,
"mean_cumulative": 0.0,
"mean_terminal": 0.0,
"mean_step_shaping": 0.0,
"discovered_rate": 0.0,
"mass_correct_rate": 0.0,
"channel_correct_rate": 0.0,
"spin_correct_rate": 0.0,
"parsed_rate": 0.0,
"mean_n_steps": 0.0,
}
n = len(stats)
return {
"n": n,
"mean_cumulative": sum(s.cumulative_reward for s in stats) / n,
"mean_terminal": sum(s.terminal_reward for s in stats) / n,
"mean_step_shaping": sum(s.step_shaping for s in stats) / n,
"discovered_rate": sum(1 for s in stats if s.discovered) / n,
"mass_correct_rate": sum(1 for s in stats if s.correct_mass) / n,
"channel_correct_rate": sum(1 for s in stats if s.correct_channel) / n,
"spin_correct_rate": sum(1 for s in stats if s.correct_spin) / n,
"parsed_rate": sum(1 for s in stats if s.parsed_ok) / n,
"mean_n_steps": sum(s.n_steps for s in stats) / n,
}
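# Illustrative flow (a sketch; the numbers are hypothetical): the reward
# function appends one ``EpisodeStats`` per rollout, and the EvidenceCallback
# drains + summarises the buffer on each ``on_log``:
#   acc = RewardComponentAccumulator()
#   acc.append(EpisodeStats(cumulative_reward=1.2, terminal_reward=1.0, discovered=True))
#   RewardComponentAccumulator.summarise(acc.drain())
#   # -> {"n": 1, "mean_cumulative": 1.2, "mean_terminal": 1.0,
#   #     "discovered_rate": 1.0, "parsed_rate": 0.0, ...}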
FORMAT_BONUS_VALID = 0.05 # was 0.15 — Fix #3 (lower per-step floor)
FORMAT_BONUS_INVALID = -0.20 # kept punitive so unparseable completions still hurt
def _format_validity_bonus(completion_text: str) -> float:
"""Small ± nudge for emitting a structured action.
Kept intentionally small (≪ terminal_scale) so the policy can't be
dominated by a "spam well-formed JSON" objective. After Fix #3 the
positive branch is 1/3 of its v1 value (0.05 vs 0.15) — combined
with the lower step_reward_clip and the heavier repeat-action
penalty, this means a model can no longer farm ~+0.22/step by
looping a single well-formed action.
"""
return FORMAT_BONUS_VALID if parse_action(completion_text) is not None else FORMAT_BONUS_INVALID
def make_reward_fn(
ctx: EpisodeContext,
accumulator: Optional[RewardComponentAccumulator] = None,
):
"""Return a TRL-compatible reward function.
TRL forwards extra dataset columns (e.g. ``seed``, ``difficulty``)
as ``kwargs`` aligned 1-to-1 with ``prompts``/``completions``. We
use those here so the rollout used to score completion ``i`` matches
the prompt that produced it, which also unlocks curriculum training.
If ``accumulator`` is provided, every rollout's ``EpisodeStats`` is
appended to it so the trainer's ``on_log`` callback can flush a
per-component summary into the evidence CSV — that's what produces
the "watch individual reward function columns" view recommended in
the hackathon FAQ.
"""
def reward_fn(
prompts: List[str],
completions: List[str],
**kwargs: Any,
) -> List[float]:
seeds = kwargs.get("seed")
diffs = kwargs.get("difficulty")
scenarios = kwargs.get("scenario")
rewards: List[float] = []
for i, completion in enumerate(completions):
stats = EpisodeStats() if accumulator is not None else None
r = _stepwise_reward(
completion_text=completion,
ctx=ctx,
seed=int(seeds[i]) if seeds is not None else None,
difficulty=diffs[i] if diffs is not None else None,
scenario=scenarios[i] if scenarios is not None else None,
out_stats=stats,
)
r += _format_validity_bonus(completion)
rewards.append(float(r))
if accumulator is not None and stats is not None:
accumulator.append(stats)
return rewards
return reward_fn
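# Call-shape sketch (values hypothetical). As noted in the docstring above, TRL
# forwards the extra dataset columns as kwargs aligned index-by-index with the
# completions being scored:
#   rewards = reward_fn(
#       prompts=["<prompt 0>", "<prompt 1>"],
#       completions=["<well-formed action>", "<unparseable text>"],
#       seed=[42, 43],
#       difficulty=["easy", "easy"],
#   )
#   # -> one float per completion, e.g. [0.73, -0.45]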
# ── Prompt dataset ───────────────────────────────────────────────────────
DEFAULT_CURRICULUM_SCHEDULE: List[tuple] = [
("easy", 0.50),
("medium", 0.30),
("hard", 0.20),
]
def curriculum_difficulty_for(
idx: int,
n_prompts: int,
schedule: Optional[List[tuple]] = None,
) -> str:
"""Map an episode index to a difficulty using a deterministic ramp.
A simple "easy first → harder later" schedule (FAQ Q14, help-guide §6)
is enough to keep early-training success rate non-zero, which is the
whole point of curriculum: the policy must occasionally see positive
reward before RL can move probability mass toward it.
"""
sched = schedule or DEFAULT_CURRICULUM_SCHEDULE
boundaries: List[tuple] = []
cumulative = 0.0
for diff, frac in sched:
cumulative += frac
boundaries.append((diff, cumulative * n_prompts))
for diff, upper in boundaries:
if idx < upper:
return diff
return boundaries[-1][0]
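# Worked example with the default 50/30/20 schedule and 200 prompts:
# indices 0-99 -> "easy", 100-159 -> "medium", 160-199 -> "hard".
#   curriculum_difficulty_for(0, 200)    # -> "easy"
#   curriculum_difficulty_for(120, 200)  # -> "medium"
#   curriculum_difficulty_for(199, 200)  # -> "hard"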
def build_dataset(
*,
tokenizer,
n_prompts: int,
seed: int,
scenario: Optional[str],
difficulty: Optional[str],
curriculum: bool = False,
schedule: Optional[List[tuple]] = None,
) -> "Dataset":
from datasets import Dataset # lazy: heavy import path
env = CERNCollisionEnvironment()
prompts: List[str] = []
seeds: List[int] = []
diffs: List[str] = []
for i in range(n_prompts):
ep_seed = seed + i
ep_diff = (
curriculum_difficulty_for(i, n_prompts, schedule)
if curriculum else (difficulty or "easy")
)
obs = env.reset(seed=ep_seed, scenario=scenario, difficulty=ep_diff)
chat = build_chat(obs)
prompt = tokenizer.apply_chat_template(
chat, add_generation_prompt=True, tokenize=False
)
prompts.append(prompt)
seeds.append(ep_seed)
diffs.append(ep_diff)
return Dataset.from_dict({
"prompt": prompts,
"seed": seeds,
"difficulty": diffs,
})
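# Resulting columns, sketched (assumes ``tok`` is an already-loaded tokenizer;
# the rendered prompt text depends on its chat template):
#   ds = build_dataset(tokenizer=tok, n_prompts=2, seed=42, scenario=None,
#                      difficulty="easy", curriculum=False)
#   ds[0]  # -> {"prompt": "<chat-templated observation>", "seed": 42,
#          #     "difficulty": "easy"}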
# ── Main ─────────────────────────────────────────────────────────────────
def main() -> None: # pragma: no cover - training entrypoint
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="HuggingFaceTB/SmolLM2-360M-Instruct")
parser.add_argument("--scenario", default=None)
parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
parser.add_argument(
"--curriculum",
action="store_true",
help="Build the prompt set with an easy→medium→hard ramp.",
)
parser.add_argument("--total_episodes", type=int, default=200)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--max_steps", type=int, default=18)
parser.add_argument("--num_generations", type=int, default=4)
parser.add_argument("--learning_rate", type=float, default=1e-5)
parser.add_argument("--max_prompt_length", type=int, default=1024)
parser.add_argument("--max_completion_length", type=int, default=256)
parser.add_argument("--output_dir", default="training/grpo-output")
parser.add_argument(
"--evidence_dir",
default="evidence",
help="Directory for training_log.csv, reward_components.csv, "
"checkpoint_evals.csv and the corresponding *.png plots.",
)
parser.add_argument(
"--checkpoint_eval_steps",
type=int,
default=25,
help="Run a held-out eval every N GRPO updates for the progression curve.",
)
parser.add_argument(
"--checkpoint_eval_episodes",
type=int,
default=8,
help="Number of held-out episodes per mid-training eval.",
)
args = parser.parse_args()
try:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
except ImportError as exc: # pragma: no cover
raise SystemExit(
"TRL (Transformer Reinforcement Learning) is required: "
"pip install -r requirements-train.txt"
) from exc
logger.info("Loading tokenizer + model: %s", args.model_name)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
args.model_name,
torch_dtype=torch.float32,
)
logger.info(
"Building prompt dataset (%d prompts, curriculum=%s)",
args.total_episodes, args.curriculum,
)
dataset = build_dataset(
tokenizer=tokenizer,
n_prompts=args.total_episodes,
seed=args.seed,
scenario=args.scenario,
difficulty=args.difficulty,
curriculum=args.curriculum,
)
env = CERNCollisionEnvironment(max_steps=args.max_steps)
ctx = EpisodeContext(
env=env,
seed=args.seed,
scenario=args.scenario,
difficulty=args.difficulty,
)
# ── Evidence wiring (training_log.csv / reward_components.csv /
# checkpoint_evals.csv + PNG plots). Mirrors training_unsloth.py so the
# vanilla GRPO backend hydrates the same dashboard cards. The render
# helpers are best-effort: matplotlib import failures are swallowed and
# the corresponding PNG is skipped, never crashing training.
import time as _time
from transformers import TrainerCallback
from training.evidence import (
CheckpointEvalWriter,
EvidencePaths,
RewardComponentLogWriter,
TrainingLogWriter,
render_checkpoint_progression,
render_reward_components,
render_training_curve,
)
from training.llm_agent import LLMAgentConfig
from training.rollouts import collect_episode
paths = EvidencePaths(root=Path(args.evidence_dir))
paths.ensure()
log_writer = TrainingLogWriter(paths.training_log_csv)
ckpt_writer = CheckpointEvalWriter(paths.checkpoint_evals_csv)
component_writer = RewardComponentLogWriter(paths.reward_components_csv)
component_accumulator = RewardComponentAccumulator()
held_out_seeds = list(range(900_000, 900_000 + args.checkpoint_eval_episodes))
reward_fn = make_reward_fn(ctx, accumulator=component_accumulator)
cfg = GRPOConfig(
output_dir=args.output_dir,
per_device_train_batch_size=2,
gradient_accumulation_steps=2,
num_generations=args.num_generations,
learning_rate=args.learning_rate,
max_prompt_length=args.max_prompt_length,
max_completion_length=args.max_completion_length,
logging_steps=5,
save_steps=50,
seed=args.seed,
bf16=False,
fp16=False,
report_to=[],
)
class EvidenceCallback(TrainerCallback):
"""Stream training metrics + run periodic mid-training evals.
Backported from training/training_unsloth.py so the vanilla GRPO
path produces the same evidence/*.csv + *.png artefacts the
dashboard reads. Differs from the Unsloth version only in the
train/eval mode toggle: plain transformers uses model.eval() /
model.train() instead of FastLanguageModel.for_inference().
"""
def __init__(self) -> None:
self._t0 = _time.time()
self._last_eval_step = -1
def on_log(self, _args, state, control, logs=None, **kw):
logs = logs or {}
row = {
"step": state.global_step,
"epoch": logs.get("epoch"),
"loss": logs.get("loss"),
"reward": logs.get("reward") or logs.get("rewards/mean"),
"reward_std": logs.get("reward_std") or logs.get("rewards/std"),
"kl": logs.get("kl"),
"grad_norm": logs.get("grad_norm"),
"learning_rate": logs.get("learning_rate"),
"wall_time_s": round(_time.time() - self._t0, 2),
}
if any(v is not None for k, v in row.items() if k != "step"):
log_writer.append(row)
try:
render_training_curve(paths.training_log_csv, paths.training_curve_png)
except Exception as exc: # pragma: no cover - plotting is best-effort
logger.warning("training curve render failed: %s", exc)
drained = component_accumulator.drain()
if drained:
summary = RewardComponentAccumulator.summarise(drained)
summary["step"] = state.global_step
component_writer.append(summary)
try:
render_reward_components(
paths.reward_components_csv, paths.reward_components_png,
)
except Exception as exc: # pragma: no cover
logger.warning("reward components render failed: %s", exc)
def on_step_end(self, _args, state, control, **kw):
step = state.global_step
if step <= 0 or step == self._last_eval_step:
return control
if step % args.checkpoint_eval_steps != 0:
return control
self._last_eval_step = step
try:
self._run_checkpoint_eval(step, state)
except Exception as exc:
logger.warning("checkpoint eval failed at step %d: %s", step, exc)
return control
def _run_checkpoint_eval(self, step: int, state) -> None:
was_training = model.training
model.eval()
try:
episodes = []
for s in held_out_seeds:
ep = self._rollout_one(seed=s)
if ep is not None:
episodes.append(ep)
if not episodes:
return
rewards = [e.cumulative_reward for e in episodes]
success_rate = sum(1 for e in episodes if e.discovered) / len(episodes)
ckpt_writer.append(
step=step,
fraction_done=round(step / max(state.max_steps or step, 1), 4),
episodes=len(episodes),
mean_reward=round(sum(rewards) / len(rewards), 4),
success_rate=round(success_rate, 4),
mass_acc=round(
sum(1 for e in episodes if e.correct_mass) / len(episodes), 4,
),
channel_acc=round(
sum(1 for e in episodes if e.correct_channel) / len(episodes), 4,
),
)
try:
render_checkpoint_progression(
paths.checkpoint_evals_csv,
paths.checkpoint_progression_png,
)
except Exception as exc: # pragma: no cover
logger.warning("checkpoint progression render failed: %s", exc)
logger.info(
"[checkpoint-eval step=%d] reward=%.3f success=%.2f",
step,
sum(rewards) / len(rewards) if rewards else 0.0,
success_rate,
)
finally:
if was_training:
model.train()
def _rollout_one(self, seed: int):
def prompt_fn(chat):
return tokenizer.apply_chat_template(
chat, add_generation_prompt=True, tokenize=False,
)
def generate_fn(prompt: str, _config) -> str:
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=args.max_completion_length,
do_sample=True,
temperature=0.7,
top_p=0.95,
pad_token_id=tokenizer.pad_token_id,
)
gen = outputs[0][inputs["input_ids"].shape[1]:]
return tokenizer.decode(gen, skip_special_tokens=True)
return collect_episode(
env=env,
seed=seed,
scenario=args.scenario,
difficulty=args.difficulty,
prompt_fn=prompt_fn,
generate_fn=generate_fn,
config=LLMAgentConfig(),
)
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
train_dataset=dataset,
reward_funcs=[reward_fn],
args=cfg,
callbacks=[EvidenceCallback()],
)
logger.info("Starting GRPO training")
trainer.train()
# Drain any rollouts the final on_log didn't catch so the last row of
# reward_components.csv reflects the end-of-training state.
final_drain = component_accumulator.drain()
if final_drain:
summary = RewardComponentAccumulator.summarise(final_drain)
summary["step"] = trainer.state.global_step
component_writer.append(summary)
try:
render_reward_components(
paths.reward_components_csv, paths.reward_components_png,
)
except Exception as exc: # pragma: no cover
logger.warning("final reward components render failed: %s", exc)
trainer.save_model(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
logger.info("Saved model to %s", args.output_dir)
logger.info("Evidence artifacts in %s", paths.root)
if __name__ == "__main__": # pragma: no cover
main()