""" GRPO Smoke Test — 10 gradient steps, M4 Mac MPS (or CUDA/CPU). PURPOSE Validate the full TRL training loop (model → rollout → reward → gradient) works end-to-end with BudgetRouterGRPOEnv before a full training run. NOT for actual learning — 10 steps is statistical noise. USAGE Requires optional GRPO deps (`uv sync --extra grpo`), then e.g.: PYTORCH_ENABLE_MPS_FALLBACK=1 uv run python train/smoke_test.py EXPECTED RUNTIME ~5-10 min on M4 Mac 48 GB (MPS, Qwen2.5-0.5B-Instruct) HYPERPARAMETERS (source) - learning_rate, beta, temperature: DeepSeek-R1 GRPO paper + TRL Wordle example - num_generations=4: minimum GRPO group; 8+ for real training - max_completion_length=512: enough for ~10 multi-turn tool calls at 0.5B - optim=adamw_torch: paged_adamw_8bit is CUDA-only - No vLLM, no load_in_4bit: both CUDA-only PASS CRITERIA - 10 gradient steps complete without exception - reward_mean is a finite float (0.0 acceptable — model is untrained) - loss is finite """ from __future__ import annotations import math import os import sys import time # Must be set before importing torch — causes MPS to fall back to CPU for # unsupported Metal ops (e.g. some GRPOTrainer matmul variants). os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1") # Suppress tokenizer parallelism warnings os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") try: import torch from datasets import Dataset from peft import LoraConfig from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback from trl import GRPOConfig, GRPOTrainer except ModuleNotFoundError as exc: name = getattr(exc, "name", None) or str(exc) print( "\nGRPO smoke test requires optional packages (torch, datasets, trl, …).\n" f"Missing: {name}\n\n" "Install with:\n" " uv sync --extra grpo\n\n" "Then re-run this script.\n", file=sys.stderr, ) raise SystemExit(1) from exc from budget_router.reward import grade_episode from train.grpo_env import BudgetRouterGRPOEnv # ── Constants ──────────────────────────────────────────────────────────────── # Smallest Qwen2.5 with validated function-calling support. # Smoke test only — use Qwen2.5-1.5B-Instruct for real training. MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" SYSTEM_PROMPT = ( "You are a budget-aware API router. " "Use the available tools to route each request to the best provider. " "Adapt when providers degrade — switch away from failing providers early." ) # ── Reward function ────────────────────────────────────────────────────────── def reward_func(environments, **kwargs): """ TRL reads env instances after each rollout. Returns List[float] in [0, 1]. grade_episode() is the calibrated grader used by the eval pipeline — keeps training and eval metrics consistent. """ rewards = [] for env in environments: history = env._env._internal.history if not history: # Model made no tool calls — assign 0, not an error rewards.append(0.0) else: rewards.append(float(grade_episode(history)["overall_score"])) return rewards # ── Dataset ────────────────────────────────────────────────────────────────── def build_dataset(n: int = 32) -> Dataset: """ Minimal dataset. Columns become **kwargs in BudgetRouterGRPOEnv.reset(). 'prompt' is required by GRPOTrainer (messages format). 'scenario' and 'seed' are passed to reset() for episode configuration. 
""" return Dataset.from_list([ { "prompt": [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": "Route the incoming requests optimally."}, ], "scenario": "hard_multi", "seed": i, } for i in range(n) ]) # ── Step logger ────────────────────────────────────────────────────────────── class SmokeTestCallback(TrainerCallback): """Captures per-step metrics for PASS/FAIL evaluation.""" def __init__(self): self.steps: list[dict] = [] def on_log(self, args, state, control, logs=None, **kwargs): if not logs or state.global_step == 0: return # TRL 1.x logs reward under "reward" or "train/reward" reward_mean = logs.get("reward", logs.get("train/reward", float("nan"))) reward_std = logs.get("reward_std", logs.get("train/reward_std", float("nan"))) loss = logs.get("loss", logs.get("train/loss", float("nan"))) entry = { "step": state.global_step, "reward_mean": float(reward_mean), "reward_std": float(reward_std), "loss": float(loss), } self.steps.append(entry) print( f" Step {entry['step']:02d}/10 | " f"loss={entry['loss']:.4f} | " f"reward_mean={entry['reward_mean']:.4f} | " f"reward_std={entry['reward_std']:.4f}" ) # ── Main ───────────────────────────────────────────────────────────────────── def main(): t0 = time.time() # Device detection if torch.backends.mps.is_available(): device = "mps" elif torch.cuda.is_available(): device = "cuda" else: device = "cpu" print("=" * 62) print("GRPO Smoke Test — Budget Router") print("=" * 62) print(f"Device : {device.upper()}") print(f"Model : {MODEL_NAME}") print(f"Steps : 10 (num_generations=4 → 40 rollouts total)") print(f"Torch : {torch.__version__}") if device == "cpu": print("⚠️ WARNING: Running on CPU. Expect ~30-60 min for 10 steps.") print("=" * 62) # Load model — explicit dtype for MPS (bfloat16 supported on M-series) print("\nLoading model (may download on first run)...") dtype = torch.bfloat16 if device in ("mps", "cuda") else torch.float32 model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, torch_dtype=dtype, trust_remote_code=True, ) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # LoRA: small rank for smoke test — keeps memory and step time low peft_config = LoraConfig( r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM", ) # GRPOConfig — hyperparams per TRL/OpenEnv Wordle example + DeepSeek-R1 # Source: https://huggingface.co/docs/trl/openenv (Wordle section) # DeepSeek-R1 paper: lr=1e-6, temp=1.0, beta=0.001 args = GRPOConfig( max_steps=10, per_device_train_batch_size=1, gradient_accumulation_steps=1, num_generations=4, # min for GRPO; use 8 for real runs generation_batch_size=4, # TRL 1.x: must be divisible by num_generations (see learn_experiment.py) max_completion_length=512, # ~10 multi-turn tool-call turns temperature=1.0, # diverse exploration (DeepSeek-R1) beta=0.001, # KL penalty; small for verifiable tasks learning_rate=5e-7, # conservative; real training: 1e-6 optim="adamw_torch", # paged_adamw_8bit is CUDA-only report_to="none", # no WandB prompt logging_steps=1, # log every step for smoke visibility remove_unused_columns=False, # CRITICAL: keeps scenario/seed cols for reset() dataloader_num_workers=0, # avoid MPS multiprocessing issues output_dir="/tmp/grpo_smoke", ) dataset = build_dataset(n=32) logger = SmokeTestCallback() trainer = GRPOTrainer( model=model, processing_class=tokenizer, reward_funcs=reward_func, train_dataset=dataset, 
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_func,
        train_dataset=dataset,
        args=args,
        peft_config=peft_config,
        environment_factory=BudgetRouterGRPOEnv,
        callbacks=[logger],
    )

    print("\nStarting training loop...\n")
    try:
        trainer.train()
    except Exception as exc:
        elapsed = time.time() - t0
        print(f"\n❌ Training loop raised {type(exc).__name__} after {elapsed:.0f}s:")
        print(f"   {exc}")
        print("\n=== SMOKE TEST: FAIL ===")
        sys.exit(1)

    elapsed = time.time() - t0

    # Evaluate against the PASS CRITERIA in the module docstring.
    if not logger.steps:
        print("\n❌ No steps were logged — the trainer may have exited early.")
        print("=== SMOKE TEST: FAIL ===")
        sys.exit(1)

    last = logger.steps[-1]
    reward_mean = last["reward_mean"]
    reward_std = last["reward_std"]
    loss = last["loss"]

    passed = (
        len(logger.steps) >= 10
        # isfinite rules out both NaN and ±inf, matching the "finite" criteria.
        and math.isfinite(reward_mean)
        and math.isfinite(loss)
    )

    print("\n" + "=" * 62)
    print("SMOKE TEST RESULT")
    print("=" * 62)
    print(f"Steps completed : {len(logger.steps)}/10")
    print(f"reward_mean     : {reward_mean:.4f}")
    print(f"reward_std      : {reward_std:.4f}")
    print(f"loss            : {loss:.4f}")
    print(f"elapsed         : {elapsed:.0f}s")
    print()
    if passed:
        print("✅ PASS — Loop is functional. Scale up with Qwen2.5-1.5B + num_generations=8.")
    else:
        print("❌ FAIL — Fix the issues above before a full training run.")
    print("=" * 62)

    if not passed:
        sys.exit(1)


if __name__ == "__main__":
    main()
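
# Scaling up beyond the smoke test (a sketch assembled from the notes above;
# none of this is exercised here, and the exact values are assumptions):
#   - MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
#   - num_generations=8 with generation_batch_size=8 (keeps the TRL 1.x
#     divisibility constraint satisfied), learning_rate=1e-6
#   - on CUDA: optim="paged_adamw_8bit", and optionally vLLM / load_in_4bit
#   - max_steps set by the real training schedule; 10 is deliberately just
#     a loop check, not a learning run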