"""
GRPO Learning Experiment — Budget Router (~90 min on M4 Mac MPS)
GOAL
Determine (80-90% confidence) whether BudgetRouterGRPOEnv provides a
learnable signal for GRPO on a small Qwen3 model (default: Qwen/Qwen3-1.7B,
configurable via --model-name).
TIMING (from smoke test)
Outer iteration = 1 rollout step (~50s) + 3 gradient-only steps (~6s total) = ~56s
max_steps=360 → ~90 outer iterations → ~84 min on M4 48GB MPS
VERDICT LOGIC
Compare mean reward of first 25% vs last 25% of rollout steps.
- LEARNING DETECTED : last_quarter > first_quarter + 0.05
- NOT LEARNING : last_quarter < first_quarter - 0.02
- INCONCLUSIVE : otherwise (high variance, too few tool calls)
USAGE
PYTORCH_ENABLE_MPS_FALLBACK=1 uv run python train/learn_experiment.py
KEY DIFFERENCES FROM smoke_test.py
- max_steps=360 (vs 10)
- learning_rate=5e-6 (vs 5e-7 in smoke_test.py, which was too conservative for 360 steps)
- Proper rollout-only callback: separates data-collection from gradient steps
- Rolling 10-step reward average on every rollout line; trend summary every 10 rollout steps
- VERDICT analysis at the end
"""
from __future__ import annotations
import argparse
import math
import os
import sys
import time
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
import torch
from collections import Counter
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from trl import GRPOConfig, GRPOTrainer
from budget_router.reward import grade_episode
from train.grpo_env import BudgetRouterGRPOEnv
# ── Config ───────────────────────────────────────────────────────────────────
# NOTE: model is configurable via --model-name (defaults to a stronger tool caller).
DEFAULT_MODEL_NAME = "Qwen/Qwen3-1.7B"
PROMPT_STYLES = ("strict", "explore")
STRICT_SYSTEM_PROMPT = (
    "You are a budget-aware API router. "
    "Use the available tools to route each request to the best provider. "
    "Providers can degrade mid-episode — monitor health and switch early.\n\n"
    "TOOL CALL FORMAT — you must use this EXACT format to call a tool:\n"
    "<tool_call>\n{\"name\": \"route_to_a\", \"arguments\": {}}\n</tool_call>\n\n"
    "EXAMPLES:\n"
    " Observation: A=0.833, B=0.500, C=0.500, budget=95%\n"
    " Action: <tool_call>\n{\"name\": \"route_to_a\", \"arguments\": {}}\n</tool_call>\n\n"
    " Observation: A=0.100 (degraded!), B=0.833, C=0.500, budget=70%\n"
    " Action: <tool_call>\n{\"name\": \"route_to_b\", \"arguments\": {}}\n</tool_call>\n\n"
    "Route to the cheapest healthy provider (A=$0.01 < B=$0.05 < C=$0.10). "
    "Switch early if a provider's health drops below 0.52."
)
EXPLORE_SYSTEM_PROMPT = (
    "You are a budget-aware API router. "
    "Use exactly one available tool after each observation to route the current request.\n\n"
    "TOOL CALL FORMAT — respond with only this XML-wrapped JSON shape:\n"
    "<tool_call>\n{\"name\": \"route_to_a\", \"arguments\": {}}\n</tool_call>\n\n"
    "Available tools: route_to_a, route_to_b, route_to_c, shed_load.\n"
    "Objective: maximize successful routed requests while preserving budget and avoiding "
    "unhealthy or overloaded providers. Providers can degrade mid-episode, so use the "
    "latest observed health, latency, queue, budget, and steps-left information.\n"
    "Do not follow a fixed provider cycle. Change actions only when the observations "
    "make a different provider or shedding load look better."
)
def build_system_prompt(prompt_style: str = "strict") -> str:
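    """Return the system prompt for the requested style ("strict" or "explore")."""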
if prompt_style == "strict":
return STRICT_SYSTEM_PROMPT
if prompt_style == "explore":
return EXPLORE_SYSTEM_PROMPT
raise ValueError(f"Unknown prompt_style={prompt_style!r}; expected one of {PROMPT_STYLES}")
# ── Reward function ──────────────────────────────────────────────────────────
LAST_ROLLOUT_DIAGNOSTICS: dict[str, object] = {}
def episode_training_reward(env: BudgetRouterGRPOEnv) -> float:
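    """
    Shaped per-episode reward: the grader's overall_score for completed episodes;
    unfinished episodes get the same score scaled by fractional progress, so
    truncated rollouts are not rewarded as if they had run to completion.
    """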
internal = env._env._internal
history = internal.history
if not history:
return 0.0
grader = float(grade_episode(history)["overall_score"])
if internal.episode_done:
return grader
progress = internal.current_step / max(1, internal.max_steps)
return grader * progress
def _mean(values: list[float]) -> float:
return sum(values) / len(values) if values else 0.0
def _action_sequence(history) -> str:
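    """Join an episode's action types into one string so identical policies are easy to spot."""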
actions = [str(step.get("action_type", "unknown")) for step in history]
return " ".join(actions) if actions else ""
def summarize_training_rollout(environments) -> dict[str, object]:
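    """
    Aggregate diagnostics across one rollout group: episode length and progress,
    raw grader score, shaped training reward, completion and budget-exhaustion
    rates, and how many distinct action sequences the group produced.
    """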
env_steps = []
progress = []
raw_graders = []
training_rewards = []
completions = []
budget_exhaustions = []
action_sequences = []
for env in environments:
internal = env._env._internal
history = internal.history
env_steps.append(float(internal.current_step))
progress.append(float(internal.current_step / max(1, internal.max_steps)))
raw_graders.append(float(grade_episode(history)["overall_score"]) if history else 0.0)
training_rewards.append(episode_training_reward(env))
completions.append(1.0 if internal.episode_done else 0.0)
budget_exhaustions.append(
1.0 if any(step.get("budget_exhausted", False) for step in history) else 0.0
)
action_sequences.append(_action_sequence(history))
sequence_counts = dict(Counter(action_sequences))
return {
"env_steps_mean": _mean(env_steps),
"env_steps_min": min(env_steps) if env_steps else 0.0,
"env_steps_max": max(env_steps) if env_steps else 0.0,
"progress_mean": _mean(progress),
"raw_grader_mean": _mean(raw_graders),
"training_reward_mean": _mean(training_rewards),
"training_rewards": training_rewards,
"episode_completion_rate": _mean(completions),
"budget_exhaustion_rate": _mean(budget_exhaustions),
"action_sequences": action_sequences,
"unique_action_sequences": len(sequence_counts),
"action_sequence_counts": sequence_counts,
}
def reward_func(environments, **kwargs):
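    """
    GRPO reward hook: record group-level diagnostics for the logging callback,
    then return one shaped training reward per environment.
    """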
global LAST_ROLLOUT_DIAGNOSTICS
LAST_ROLLOUT_DIAGNOSTICS = summarize_training_rollout(environments)
return [episode_training_reward(env) for env in environments]
# ── Dataset ──────────────────────────────────────────────────────────────────
def build_dataset(n: int = 200, prompt_style: str = "strict") -> Dataset:
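    """
    Build a prompt dataset of n episodes; the extra "scenario" and "seed" columns
    survive collation because GRPOConfig sets remove_unused_columns=False.
    """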
system_prompt = build_system_prompt(prompt_style)
return Dataset.from_list([
{
"prompt": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": "Route the incoming requests optimally."},
],
"scenario": "hard_multi",
"seed": i,
}
for i in range(n)
])
# ── Callback ─────────────────────────────────────────────────────────────────
class LearnCallback(TrainerCallback):
"""
Tracks rollout steps (data-collection) separately from gradient-only steps.
GRPO pattern: 1 rollout step + 3 gradient-only steps = 1 outer iteration.
Only rollout steps carry reward/tool_call metrics.
"""
    ROLLOUT_PRINT_EVERY = 10  # print a trend summary every N rollout steps
def __init__(self):
self.rollout_rewards: list[float] = []
self.total_grad_steps: int = 0
self.tool_call_freqs: list[float] = []
self.env_step_means: list[float] = []
self.completion_rates: list[float] = []
def on_log(self, args, state, control, logs=None, **kwargs):
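        """
        Handle one trainer log event: gradient-only steps (no "reward" key) get a
        one-line loss print, rollout steps get full reward and environment diagnostics.
        """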
if not logs or state.global_step == 0:
return
if "train_runtime" in logs:
return
step = state.global_step
self.total_grad_steps = max(self.total_grad_steps, step)
loss = float(logs.get("loss", float("nan")))
# Gradient-only steps have no reward key
if "reward" not in logs:
print(f" [{step:03d}] grad-only | loss={loss:.4f}")
return
reward = float(logs.get("reward", float("nan")))
reward_std = float(logs.get("reward_std", float("nan")))
tool_freq = float(logs.get("tools/call_frequency", float("nan")))
diagnostics = LAST_ROLLOUT_DIAGNOSTICS
if not math.isnan(reward):
self.rollout_rewards.append(reward)
if not math.isnan(tool_freq):
self.tool_call_freqs.append(tool_freq)
if diagnostics:
self.env_step_means.append(float(diagnostics["env_steps_mean"]))
self.completion_rates.append(float(diagnostics["episode_completion_rate"]))
rollout_n = len(self.rollout_rewards)
rolling_avg = (
sum(self.rollout_rewards[-10:]) / len(self.rollout_rewards[-10:])
if self.rollout_rewards else float("nan")
)
sequence_counts = diagnostics.get("action_sequence_counts", {}) if diagnostics else {}
unique_sequences = int(diagnostics.get("unique_action_sequences", 0)) if diagnostics else 0
group_size = len(diagnostics.get("action_sequences", [])) if diagnostics else 0
reward_values = diagnostics.get("training_rewards", []) if diagnostics else []
reward_preview = ",".join(f"{float(v):.4f}" for v in reward_values)
print(
f" [{step:03d}] ROLLOUT #{rollout_n:02d} | "
f"reward={reward:.4f} | std={reward_std:.4f} | "
f"tool_freq={tool_freq:.2f} | "
f"env_steps={diagnostics.get('env_steps_mean', float('nan')):.1f} | "
f"complete={diagnostics.get('episode_completion_rate', float('nan')):.2f} | "
f"raw={diagnostics.get('raw_grader_mean', float('nan')):.4f} | "
f"seqs={unique_sequences}/{group_size} | "
f"rewards=[{reward_preview}] | "
f"rolling10={rolling_avg:.4f} | loss={loss:.4f}"
)
if unique_sequences <= 3 and sequence_counts:
counts = " || ".join(
f"{count}x {sequence}" for sequence, count in sequence_counts.items()
)
print(f" action_sequences: {counts}")
if rollout_n % self.ROLLOUT_PRINT_EVERY == 0:
self._print_trend_summary()
def _print_trend_summary(self):
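        """Print first-quarter vs last-quarter reward means plus running metric averages."""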
n = len(self.rollout_rewards)
if n < 4:
return
q = max(1, n // 4)
first_q = sum(self.rollout_rewards[:q]) / q
last_q = sum(self.rollout_rewards[-q:]) / q
avg_tool = (
sum(self.tool_call_freqs) / len(self.tool_call_freqs)
if self.tool_call_freqs else float("nan")
)
avg_env_steps = (
sum(self.env_step_means) / len(self.env_step_means)
if self.env_step_means else float("nan")
)
avg_completion = (
sum(self.completion_rates) / len(self.completion_rates)
if self.completion_rates else float("nan")
)
print(
f"\n ── Trend @ rollout {n} ──\n"
f" first-quarter mean : {first_q:.4f}\n"
f" last-quarter mean : {last_q:.4f}\n"
f" delta : {last_q - first_q:+.4f}\n"
f" avg tool_call_freq : {avg_tool:.3f}\n"
f" avg env_steps : {avg_env_steps:.2f}\n"
f" avg completion_rate : {avg_completion:.3f}\n"
)
# ── Main ─────────────────────────────────────────────────────────────────────
def main():
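    """Parse CLI args, run the GRPO experiment, save the merged model, and print a verdict."""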
t0 = time.time()
parser = argparse.ArgumentParser(description="GRPO Learning Experiment — Budget Router")
parser.add_argument(
"--model-name",
type=str,
default=DEFAULT_MODEL_NAME,
help="HF model id to train (default: Qwen/Qwen3-1.7B).",
)
parser.add_argument("--max-steps", type=int, default=360, help="Total GRPO max_steps (outer iterations).")
parser.add_argument("--dataset-n", type=int, default=200, help="Number of episodes in training dataset.")
parser.add_argument("--save-steps", type=int, default=60, help="Checkpoint save frequency in steps.")
parser.add_argument("--num-generations", type=int, default=4, help="GRPO generations per prompt.")
parser.add_argument("--max-completion-length", type=int, default=1024, help="Completion token budget across tool loop.")
parser.add_argument("--max-tool-calling-iterations", type=int, default=20, help="Maximum tool loop iterations per rollout.")
parser.add_argument(
"--prompt-style",
choices=PROMPT_STYLES,
default="strict",
help="Prompt style: strict preserves the original heuristic prompt; explore reduces deterministic policy bias.",
)
parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature for GRPO rollouts.")
parser.add_argument("--top-p", type=float, default=1.0, help="Nucleus sampling top-p for GRPO rollouts.")
cli = parser.parse_args()
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
print("=" * 68)
print("GRPO Learning Experiment — Budget Router")
print("=" * 68)
print(f"Device : {device.upper()}")
print(f"Model : {cli.model_name}")
print(f"Steps : {cli.max_steps}")
print(f"Prompt : {cli.prompt_style}")
print(f"Sampling : temperature={cli.temperature} top_p={cli.top_p}")
print(f"Torch : {torch.__version__}")
print("=" * 68)
print("\nLoading model...")
dtype = torch.bfloat16 if device in ("mps", "cuda") else torch.float32
model = AutoModelForCausalLM.from_pretrained(
cli.model_name, dtype=dtype, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(cli.model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
peft_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
    # Hyperparams: TRL/OpenEnv Wordle example + DeepSeek-R1 paper
    # lr=5e-6: the smoke test used 5e-7, too conservative to show a trend within 360 steps
    # enable_thinking=False: reclaims the ~400 tokens/step Qwen3 would otherwise spend on
    # <think> reasoning blocks, leaving more of max_completion_length for actual tool calls.
args = GRPOConfig(
max_steps=cli.max_steps,
per_device_train_batch_size=1,
gradient_accumulation_steps=1,
num_generations=cli.num_generations,
generation_batch_size=cli.num_generations,
max_completion_length=cli.max_completion_length,
max_tool_calling_iterations=cli.max_tool_calling_iterations,
temperature=cli.temperature,
top_p=cli.top_p,
beta=0.001,
learning_rate=5e-6,
optim="adamw_torch",
report_to="none",
logging_steps=1,
remove_unused_columns=False,
dataloader_num_workers=0,
save_steps=cli.save_steps,
output_dir="trained_models/grpo_checkpoints",
        chat_template_kwargs={"enable_thinking": False},  # Qwen3: skip <think> blocks
)
dataset = build_dataset(n=cli.dataset_n, prompt_style=cli.prompt_style)
cb = LearnCallback()
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=reward_func,
train_dataset=dataset,
args=args,
peft_config=peft_config,
environment_factory=BudgetRouterGRPOEnv,
callbacks=[cb],
)
def _save_merged(label: str) -> None:
"""
Merge LoRA weights into the base model and save to disk.
The merged model is a plain HuggingFace model — loadable with
AutoModelForCausalLM.from_pretrained() without any PEFT dependency.
"""
safe_name = (
cli.model_name.replace("/", "_")
.replace(":", "_")
.replace("@", "_")
)
save_path = f"trained_models/grpo_{safe_name}"
print(f"\n[{label}] Merging LoRA into base model and saving to {save_path} ...")
try:
merged = trainer.model.merge_and_unload()
merged.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
print(f"[{label}] ✅ Saved. Load with: AutoModelForCausalLM.from_pretrained('{save_path}')")
except Exception as e:
print(f"[{label}] ⚠️ Save failed: {e}")
print("\nStarting experiment (Ctrl+C to stop early — partial results still printed)...\n")
try:
trainer.train()
_save_merged("END")
except KeyboardInterrupt:
print("\n[Interrupted by user — computing verdict on partial results...]")
_save_merged("INTERRUPT")
except Exception as exc:
print(f"\n❌ Training error: {type(exc).__name__}: {exc}")
sys.exit(1)
elapsed = time.time() - t0
# ── VERDICT ──────────────────────────────────────────────────────────────
rewards = cb.rollout_rewards
tool_freqs = cb.tool_call_freqs
env_step_means = cb.env_step_means
completion_rates = cb.completion_rates
print("\n" + "=" * 68)
print("LEARNING EXPERIMENT — VERDICT")
print("=" * 68)
print(f"Grad steps completed : {cb.total_grad_steps}/{cli.max_steps}")
print(f"Rollout steps : {len(rewards)}")
print(f"Elapsed : {elapsed/60:.1f} min")
if len(rewards) < 4:
print("\n⚠️ Too few rollout steps for a verdict. Run longer.")
sys.exit(0)
q = max(1, len(rewards) // 4)
first_q_mean = sum(rewards[:q]) / q
last_q_mean = sum(rewards[-q:]) / q
delta = last_q_mean - first_q_mean
avg_tool_freq = sum(tool_freqs) / len(tool_freqs) if tool_freqs else 0.0
avg_env_steps = sum(env_step_means) / len(env_step_means) if env_step_means else 0.0
avg_completion_rate = sum(completion_rates) / len(completion_rates) if completion_rates else 0.0
overall_mean = sum(rewards) / len(rewards)
print(f"\nReward summary:")
print(f" First-quarter mean : {first_q_mean:.4f}")
print(f" Last-quarter mean : {last_q_mean:.4f}")
print(f" Delta (improvement): {delta:+.4f}")
print(f" Overall mean : {overall_mean:.4f}")
print(f" Avg tool_call_freq : {avg_tool_freq:.3f}")
print(f" Avg env_steps : {avg_env_steps:.2f}")
print(f" Avg completion_rate: {avg_completion_rate:.3f}")
print(f"\nHeuristic baseline : ~0.60-0.65 (from environment benchmark)")
print(f"Random agent : ~0.15-0.20")
print()
if delta > 0.05:
print("✅ VERDICT: LEARNING SIGNAL DETECTED")
print(" Reward improved meaningfully across the run.")
print(" Recommend: scale up to Qwen2.5-1.5B-Instruct + num_generations=8.")
elif delta < -0.02:
print("❌ VERDICT: NOT LEARNING")
print(" Reward trended downward. Possible causes:")
print(" - lr too high (try 5e-7)")
print(" - too few tool calls (model not generating tool syntax reliably)")
print(" - environment reward too sparse")
else:
print("⚠️ VERDICT: INCONCLUSIVE")
print(f" Delta={delta:+.4f} is within noise range.")
if avg_tool_freq < 0.3:
print(" Root cause likely: tool_call_freq too low — model rarely uses tools.")
print(" Fix: add few-shot tool-call examples to the system prompt.")
elif avg_env_steps < 10:
print(" Root cause likely: rollouts are too short for 20-step routing.")
print(" Fix: inspect completion budget, compact tool responses, and tool-loop limits.")
else:
print(" More steps needed. Try 600-800 steps for clearer trend.")
print("=" * 68)
if __name__ == "__main__":
main()