"""
GRPO Smoke Test — 10 gradient steps, M4 Mac MPS (or CUDA/CPU).
PURPOSE
Validate the full TRL training loop (model → rollout → reward → gradient)
works end-to-end with BudgetRouterGRPOEnv before a full training run.
NOT for actual learning — 10 steps is statistical noise.
USAGE
Requires optional GRPO deps (`uv sync --extra grpo`), then e.g.:
PYTORCH_ENABLE_MPS_FALLBACK=1 uv run python train/smoke_test.py
EXPECTED RUNTIME
~5-10 min on M4 Mac 48 GB (MPS, Qwen2.5-0.5B-Instruct)
HYPERPARAMETERS (source)
- learning_rate, beta, temperature: DeepSeek-R1 GRPO paper + TRL Wordle example
- num_generations=4: minimum GRPO group; 8+ for real training
- max_completion_length=512: enough for ~10 multi-turn tool calls at 0.5B
- optim=adamw_torch: paged_adamw_8bit is CUDA-only
- No vLLM, no load_in_4bit: both CUDA-only
PASS CRITERIA
- 10 gradient steps complete without exception
- reward_mean is a finite float (0.0 acceptable — model is untrained)
- loss is finite
"""
from __future__ import annotations

import math
import os
import sys
import time

# Must be set before importing torch — causes MPS to fall back to CPU for
# unsupported Metal ops (e.g. some GRPOTrainer matmul variants).
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
# Suppress tokenizer parallelism warnings
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
try:
    import torch
    from datasets import Dataset
    from peft import LoraConfig
    from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
    from trl import GRPOConfig, GRPOTrainer
except ModuleNotFoundError as exc:
    name = getattr(exc, "name", None) or str(exc)
    print(
        "\nGRPO smoke test requires optional packages (torch, datasets, trl, …).\n"
        f"Missing: {name}\n\n"
        "Install with:\n"
        "  uv sync --extra grpo\n\n"
        "Then re-run this script.\n",
        file=sys.stderr,
    )
    raise SystemExit(1) from exc

from budget_router.reward import grade_episode
from train.grpo_env import BudgetRouterGRPOEnv
# ── Constants ────────────────────────────────────────────────────────────────
# Smallest Qwen2.5 with validated function-calling support.
# Smoke test only — use Qwen2.5-1.5B-Instruct for real training.
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
SYSTEM_PROMPT = (
    "You are a budget-aware API router. "
    "Use the available tools to route each request to the best provider. "
    "Adapt when providers degrade — switch away from failing providers early."
)
# ── Reward function ──────────────────────────────────────────────────────────
def reward_func(environments, **kwargs):
    """
    TRL reads env instances after each rollout. Returns List[float] in [0, 1].

    grade_episode() is the calibrated grader used by the eval pipeline — keeps
    training and eval metrics consistent.
    """
    rewards = []
    for env in environments:
        # NOTE: reaches into BudgetRouterGRPOEnv private internals; this will
        # break if the wrapper's attribute layout changes.
        history = env._env._internal.history
        if not history:
            # Model made no tool calls — assign 0, not an error
            rewards.append(0.0)
        else:
            rewards.append(float(grade_episode(history)["overall_score"]))
    return rewards
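# Shape contract (illustrative numbers): for two rollouts where the first made
# no tool calls, reward_func returns e.g. [0.0, 0.73], one float per env, which
# is what GRPOTrainer expects from a reward function.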
# ── Dataset ──────────────────────────────────────────────────────────────────
def build_dataset(n: int = 32) -> Dataset:
    """
    Minimal dataset. Columns become **kwargs in BudgetRouterGRPOEnv.reset().

    'prompt' is required by GRPOTrainer (messages format).
    'scenario' and 'seed' are passed to reset() for episode configuration.
    """
    return Dataset.from_list([
        {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": "Route the incoming requests optimally."},
            ],
            "scenario": "hard_multi",
            "seed": i,
        }
        for i in range(n)
    ])
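# Per the docstring above, row i is assumed to surface as
#   env.reset(scenario="hard_multi", seed=i)
# while 'prompt' is consumed by GRPOTrainer itself rather than forwarded.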
# ── Step logger ──────────────────────────────────────────────────────────────
class SmokeTestCallback(TrainerCallback):
    """Captures per-step metrics for PASS/FAIL evaluation."""

    def __init__(self):
        self.steps: list[dict] = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs or state.global_step == 0:
            return
        # TRL 1.x logs reward under "reward" or "train/reward"; a missing key
        # falls through to NaN, which the final PASS check treats as failure.
        reward_mean = logs.get("reward", logs.get("train/reward", float("nan")))
        reward_std = logs.get("reward_std", logs.get("train/reward_std", float("nan")))
        loss = logs.get("loss", logs.get("train/loss", float("nan")))
        entry = {
            "step": state.global_step,
            "reward_mean": float(reward_mean),
            "reward_std": float(reward_std),
            "loss": float(loss),
        }
        self.steps.append(entry)
        print(
            f" Step {entry['step']:02d}/10 | "
            f"loss={entry['loss']:.4f} | "
            f"reward_mean={entry['reward_mean']:.4f} | "
            f"reward_std={entry['reward_std']:.4f}"
        )
# ── Main ─────────────────────────────────────────────────────────────────────
def main():
    t0 = time.time()

    # Device detection
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    print("=" * 62)
    print("GRPO Smoke Test — Budget Router")
    print("=" * 62)
    print(f"Device : {device.upper()}")
    print(f"Model  : {MODEL_NAME}")
    print("Steps  : 10 (num_generations=4 → 40 rollouts total)")
    print(f"Torch  : {torch.__version__}")
    if device == "cpu":
        print("⚠️ WARNING: Running on CPU. Expect ~30-60 min for 10 steps.")
    print("=" * 62)
    # Load model — explicit dtype for MPS (bfloat16 supported on M-series)
    print("\nLoading model (may download on first run)...")
    dtype = torch.bfloat16 if device in ("mps", "cuda") else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
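    # Some checkpoints ship without a dedicated pad token; reusing EOS is the
    # standard fallback so batch padding works.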
    # LoRA: small rank for smoke test — keeps memory and step time low
    peft_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
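    # Targeting only q_proj/v_proj is the leanest common LoRA setup; fuller
    # runs often add k_proj, o_proj, and the MLP projections at the cost of
    # more trainable parameters.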
    # GRPOConfig — hyperparams per TRL/OpenEnv Wordle example + DeepSeek-R1
    # Source: https://huggingface.co/docs/trl/openenv (Wordle section)
    # DeepSeek-R1 paper: lr=1e-6, temp=1.0, beta=0.001
    args = GRPOConfig(
        max_steps=10,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        num_generations=4,            # min for GRPO; use 8 for real runs
        generation_batch_size=4,      # TRL 1.x: must be divisible by num_generations (see learn_experiment.py)
        max_completion_length=512,    # ~10 multi-turn tool-call turns
        temperature=1.0,              # diverse exploration (DeepSeek-R1)
        beta=0.001,                   # KL penalty; small for verifiable tasks
        learning_rate=5e-7,           # conservative; real training: 1e-6
        optim="adamw_torch",          # paged_adamw_8bit is CUDA-only
        report_to="none",             # no WandB prompt
        logging_steps=1,              # log every step for smoke visibility
        remove_unused_columns=False,  # CRITICAL: keeps scenario/seed cols for reset()
        dataloader_num_workers=0,     # avoid MPS multiprocessing issues
        output_dir="/tmp/grpo_smoke",
    )
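    # Sanity arithmetic: 4 completions per step × 10 steps = 40 rollouts,
    # matching the "40 rollouts total" figure in the banner above.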
    dataset = build_dataset(n=32)
    logger = SmokeTestCallback()
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_func,
        train_dataset=dataset,
        args=args,
        peft_config=peft_config,
        environment_factory=BudgetRouterGRPOEnv,
        callbacks=[logger],
    )
print("\nStarting training loop...\n")
try:
trainer.train()
except Exception as exc:
elapsed = time.time() - t0
print(f"\n❌ Training loop raised {type(exc).__name__} after {elapsed:.0f}s:")
print(f" {exc}")
print("\n=== SMOKE TEST: FAIL ===")
sys.exit(1)
    elapsed = time.time() - t0

    # Evaluate
    if not logger.steps:
        print("\n❌ No steps were logged — trainer may have exited early.")
        print("=== SMOKE TEST: FAIL ===")
        sys.exit(1)

    last = logger.steps[-1]
    reward_mean = last["reward_mean"]
    reward_std = last["reward_std"]
    loss = last["loss"]
    passed = (
        len(logger.steps) >= 10
        and math.isfinite(reward_mean)  # PASS CRITERIA require a *finite* reward_mean
        and math.isfinite(loss)
    )
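    # reward_std is logged but not gated: with an untrained model every rollout
    # in a group can legitimately score the same, giving a std of 0.0.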
print("\n" + "=" * 62)
print("SMOKE TEST RESULT")
print("=" * 62)
print(f"Steps completed : {len(logger.steps)}/10")
print(f"reward_mean : {reward_mean:.4f}")
print(f"reward_std : {reward_std:.4f}")
print(f"loss : {loss:.4f}")
print(f"elapsed : {elapsed:.0f}s")
print()
if passed:
print("✅ PASS — Loop is functional. Scale up with Qwen2.5-1.5B + num_generations=8.")
else:
print("❌ FAIL — Fix issues above before full training run.")
print("=" * 62)
if not passed:
sys.exit(1)

if __name__ == "__main__":
    main()