"""Unsloth + LoRA (Low-Rank Adaptation) GRPO training for CERNenv.
This is the recommended path for Colab / single- or multi-GPU runs because
Unsloth's fused kernels and 4-bit loading let us train 2B–8B models with
limited VRAM, while TRL's GRPO (Group-Relative Policy Optimization) loop
handles the policy-gradient math.
The trainer is wired up to produce **all** "training-progress evidence"
artifacts demanded by the OpenEnv hackathon's scoring rubric:
* per-step training log + reward/loss curve PNG (Portable Network Graphics)
* mid-training checkpoint evaluations + progression curve PNG
* (post-run) before/after summary + reward-distribution PNG
All artifacts land in ``--evidence_dir`` (default: ``evidence/``).
Run on Colab / single GPU:
!python -m training.training_unsloth \
--model_name unsloth/Qwen2.5-3B-Instruct \
--total_episodes 400 --num_generations 4 --output_dir runs/unsloth-grpo
Run on a 4×A100 Hugging Face Space (multi-GPU via accelerate):
accelerate launch --num_processes 4 -m training.training_unsloth \
--total_episodes 1500 --num_generations 8 --output_dir runs/unsloth-grpo
"""
from __future__ import annotations
import argparse
import logging
import time
from pathlib import Path
from typing import Optional
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
def _build_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="unsloth/Qwen2.5-3B-Instruct")
parser.add_argument("--scenario", default=None)
parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
parser.add_argument(
"--curriculum",
action="store_true",
help=(
"Enable adaptive curriculum: start at --difficulty and promote "
"to medium/hard once held-out success rate clears the threshold "
"(see training/curriculum.py)."
),
)
parser.add_argument("--curriculum_promote", type=float, default=0.55)
parser.add_argument("--curriculum_demote", type=float, default=0.10)
parser.add_argument("--total_episodes", type=int, default=400)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--max_steps", type=int, default=18)
parser.add_argument("--num_generations", type=int, default=4)
parser.add_argument("--max_prompt_length", type=int, default=2048)
parser.add_argument("--max_completion_length", type=int, default=384)
parser.add_argument("--learning_rate", type=float, default=5e-6)
parser.add_argument("--load_in_4bit", action="store_true", default=True)
parser.add_argument("--lora_rank", type=int, default=16)
parser.add_argument("--lora_alpha", type=int, default=16)
parser.add_argument("--per_device_batch_size", type=int, default=1)
parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
parser.add_argument("--logging_steps", type=int, default=2)
parser.add_argument("--save_steps", type=int, default=50)
parser.add_argument("--checkpoint_eval_steps", type=int, default=25,
help="Run a held-out eval every N updates for the progression curve.")
parser.add_argument("--checkpoint_eval_episodes", type=int, default=8,
help="Number of held-out episodes per mid-training eval.")
parser.add_argument("--output_dir", default="runs/unsloth-grpo")
parser.add_argument("--evidence_dir", default="evidence")
return parser.parse_args()
def main() -> None: # pragma: no cover - heavy GPU path
args = _build_args()
# IMPORTANT: Unsloth MUST be imported before transformers / trl. It
# patches transformers' lazy ``_import_structure`` to register a few
# symbols (notably ``PreTrainedModel`` under torch-aware paths). If trl
# loads transformers first, the lazy loader will fail with a confusing
# ``ImportError: cannot import name 'PreTrainedModel' from 'transformers'``
# at GRPOTrainer import time — which is exactly what we hit on the
# trainer Space before this reorder.
# See: https://github.com/unslothai/unsloth and the matching
# transformers issue #42548 for the lazy-import root cause.
from unsloth import FastLanguageModel
from transformers import TrainerCallback
from trl import GRPOConfig, GRPOTrainer
from server.environment import CERNCollisionEnvironment
from training.curriculum import CurriculumConfig, CurriculumManager
from training.evidence import (
CheckpointEvalWriter,
EvidencePaths,
RewardComponentLogWriter,
TrainingLogWriter,
render_checkpoint_progression,
render_reward_components,
render_training_curve,
)
from training.llm_agent import LLMAgentConfig
from training.rollouts import collect_episode
from training.training_script import (
EpisodeContext,
RewardComponentAccumulator,
)
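    # Evidence setup: each writer appends one CSV row per event and the matching
    # render_* helper re-draws its PNG, so the rubric artifacts listed in the
    # module docstring stay current throughout the run.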
paths = EvidencePaths(root=Path(args.evidence_dir))
paths.ensure()
log_writer = TrainingLogWriter(paths.training_log_csv)
ckpt_writer = CheckpointEvalWriter(paths.checkpoint_evals_csv)
component_writer = RewardComponentLogWriter(paths.reward_components_csv)
component_accumulator = RewardComponentAccumulator()
curriculum: Optional[CurriculumManager] = None
if args.curriculum:
curriculum = CurriculumManager(
CurriculumConfig(
start_difficulty=args.difficulty,
promote_threshold=args.curriculum_promote,
demote_threshold=args.curriculum_demote,
)
)
logger.info("Curriculum enabled: start=%s promote≥%.2f demote≤%.2f",
args.difficulty, args.curriculum_promote, args.curriculum_demote)
logger.info("Loading Unsloth model: %s", args.model_name)
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=args.model_name,
max_seq_length=args.max_prompt_length + args.max_completion_length,
load_in_4bit=args.load_in_4bit,
        # fast_inference requires vLLM, which is not in requirements, so plain
        # transformers generation is used instead. Re-enable after pinning vllm
        # in space/training/requirements.txt.
fast_inference=False,
)
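    # Attach LoRA adapters to every attention and MLP projection matrix; only
    # these low-rank matrices are trained while the 4-bit base weights stay frozen.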
model = FastLanguageModel.get_peft_model(
model,
r=args.lora_rank,
lora_alpha=args.lora_alpha,
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
use_gradient_checkpointing="unsloth",
)
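    # Many chat-tuned checkpoints ship without a dedicated pad token; batched
    # generation inside GRPO needs one, so fall back to EOS.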
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
from training.training_script import build_dataset, make_reward_fn
env = CERNCollisionEnvironment(max_steps=args.max_steps)
dataset = build_dataset(
tokenizer=tokenizer,
n_prompts=args.total_episodes,
seed=args.seed,
scenario=args.scenario,
difficulty=args.difficulty,
curriculum=args.curriculum,
)
ctx = EpisodeContext(
env=env, seed=args.seed,
scenario=args.scenario, difficulty=args.difficulty,
)
reward_fn = make_reward_fn(ctx, accumulator=component_accumulator)
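    # GRPO samples ``num_generations`` completions per prompt and scores each
    # completion against its group's mean reward (group-relative advantages),
    # so no separate value model is needed; the reward itself comes from
    # ``reward_fn``, built over the environment context above.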
cfg = GRPOConfig(
output_dir=args.output_dir,
per_device_train_batch_size=args.per_device_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
num_generations=args.num_generations,
learning_rate=args.learning_rate,
max_prompt_length=args.max_prompt_length,
max_completion_length=args.max_completion_length,
logging_steps=args.logging_steps,
save_steps=args.save_steps,
seed=args.seed,
bf16=True,
report_to=[],
)
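    # Held-out eval episodes use a fixed seed block far above the training
    # prompt seeds, so mid-training checkpoints are scored on unseen collisions.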
held_out_seeds = list(range(900_000, 900_000 + args.checkpoint_eval_episodes))
class EvidenceCallback(TrainerCallback):
"""Stream training metrics + run periodic mid-training evals."""
def __init__(self) -> None:
self._t0 = time.time()
self._last_eval_step = -1
def on_log(self, _args, state, control, logs=None, **kw):
logs = logs or {}
row = {
"step": state.global_step,
"epoch": logs.get("epoch"),
"loss": logs.get("loss"),
"reward": logs.get("reward") or logs.get("rewards/mean"),
"reward_std": logs.get("reward_std") or logs.get("rewards/std"),
"kl": logs.get("kl"),
"grad_norm": logs.get("grad_norm"),
"learning_rate": logs.get("learning_rate"),
"wall_time_s": round(time.time() - self._t0, 2),
}
            if any(v is not None for k, v in row.items() if k not in ("step", "wall_time_s")):
log_writer.append(row)
render_training_curve(paths.training_log_csv, paths.training_curve_png)
# Per-component reward summary (FAQ Q17, Q43, Q52: don't watch
# only the mean reward — track terminal vs shaping, success
# rates, and parse rate so verifier hacks become visible).
drained = component_accumulator.drain()
if drained:
summary = RewardComponentAccumulator.summarise(drained)
summary["step"] = state.global_step
component_writer.append(summary)
render_reward_components(
paths.reward_components_csv, paths.reward_components_png,
)
def on_step_end(self, _args, state, control, **kw):
step = state.global_step
if step <= 0 or step == self._last_eval_step:
return control
if step % args.checkpoint_eval_steps != 0:
return control
self._last_eval_step = step
try:
self._run_checkpoint_eval(step, state)
except Exception as exc:
logger.warning("checkpoint eval failed at step %d: %s", step, exc)
return control
def _run_checkpoint_eval(self, step: int, state) -> None:
FastLanguageModel.for_inference(model)
try:
# When curriculum is enabled, evaluate at whatever tier the
# adaptive manager currently considers appropriate. Otherwise
# use the static --difficulty.
eval_difficulty = (
curriculum.next_difficulty()
if curriculum is not None
else args.difficulty
)
episodes = []
for s in held_out_seeds:
ep = self._rollout_one(seed=s, difficulty=eval_difficulty)
if ep is not None:
episodes.append(ep)
if not episodes:
return
rewards = [e.cumulative_reward for e in episodes]
success_rate = sum(1 for e in episodes if e.discovered) / len(episodes)
ckpt_writer.append(
step=step,
fraction_done=round(step / max(state.max_steps or step, 1), 4),
episodes=len(episodes),
mean_reward=round(sum(rewards) / len(rewards), 4),
success_rate=round(success_rate, 4),
mass_acc=round(sum(1 for e in episodes if e.correct_mass) / len(episodes), 4),
channel_acc=round(sum(1 for e in episodes if e.correct_channel) / len(episodes), 4),
)
render_checkpoint_progression(
paths.checkpoint_evals_csv,
paths.checkpoint_progression_png,
)
if curriculum is not None:
snap = curriculum.record(
success=success_rate >= 0.5,
reward=sum(rewards) / len(rewards),
)
curriculum.save(paths.root / "curriculum_state.json")
if snap.get("event"):
logger.info(
"[curriculum] %s @ step=%d → tier=%s (rolling=%.2f)",
snap["event"], step, snap["current"], snap["rolling_success"],
)
                logger.info(
                    "[checkpoint-eval step=%d difficulty=%s] reward=%.3f success=%.2f",
                    step, eval_difficulty,
                    sum(rewards) / len(rewards),
                    success_rate,
                )
finally:
FastLanguageModel.for_training(model)
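        # One full environment episode with plain ``model.generate`` (vLLM
        # fast_inference is disabled above), used only for held-out evals.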
def _rollout_one(self, seed: int, difficulty: Optional[str] = None):
def prompt_fn(chat):
return tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
def generate_fn(prompt: str, _config) -> str:
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=args.max_completion_length,
do_sample=True, temperature=0.7, top_p=0.95,
pad_token_id=tokenizer.pad_token_id,
)
gen = outputs[0][inputs["input_ids"].shape[1]:]
return tokenizer.decode(gen, skip_special_tokens=True)
return collect_episode(
env=env, seed=seed,
scenario=args.scenario,
difficulty=difficulty or args.difficulty,
prompt_fn=prompt_fn, generate_fn=generate_fn,
config=LLMAgentConfig(),
)
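    # ``processing_class`` hands the tokenizer to TRL for prompt encoding and
    # completion decoding; EvidenceCallback streams the training logs and runs
    # the periodic held-out evals defined above.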
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
train_dataset=dataset,
reward_funcs=[reward_fn],
args=cfg,
callbacks=[EvidenceCallback()],
)
logger.info("Starting Unsloth + LoRA GRPO training")
trainer.train()
# Drain whatever rollouts the final on_log didn't catch so the last
# row of reward_components.csv is correct.
final_drain = component_accumulator.drain()
if final_drain:
summary = RewardComponentAccumulator.summarise(final_drain)
summary["step"] = trainer.state.global_step
component_writer.append(summary)
render_reward_components(
paths.reward_components_csv, paths.reward_components_png,
)
trainer.save_model(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
logger.info("Saved adapters to %s", args.output_dir)
logger.info("Evidence artifacts in %s", paths.root)
if __name__ == "__main__": # pragma: no cover
main()