"""
BudgetRouterGRPOEnv — TRL environment_factory-compatible class for GRPO training.
Usage with TRL GRPOTrainer:
from datasets import Dataset
from train.grpo_env import BudgetRouterGRPOEnv
from budget_router.reward import grade_episode
# Dataset: columns become **kwargs in reset(). "prompt" drives the model's initial message.
dataset = Dataset.from_list([
{"prompt": [[{"role": "user", "content": "Route requests using the available tools."}]],
"scenario": "hard_multi", "seed": i}
for i in range(200)
])
# reward_funcs with an `environments` parameter is the CORRECT TRL pattern when using
# environment_factory. TRL inspects the signature and passes env instances (not completions).
# This is explicitly documented in the official TRL/OpenEnv integration guide.
# Alternatively, env.reward is set on the instance and TRL reads it directly if
# reward_funcs is omitted — but the explicit function gives more control.
def reward_func(environments, **kwargs):
return [float(grade_episode(env._env._internal.history)["overall_score"])
for env in environments]
trainer = GRPOTrainer(
model=model,
reward_funcs=reward_func,
train_dataset=dataset,
args=GRPOConfig(num_generations=8, max_completion_length=2048),
environment_factory=BudgetRouterGRPOEnv,
)
Design Constraints (do NOT violate):
- Tool methods MUST call self._env.step() and never construct custom step_info dicts.
environment.py writes actual_degradation_start (jittered per-episode) into step_info.
grade_episode() reads degradation_start_step from that dict to compute adaptation windows.
Custom dicts would write the config constant (e.g. 0) instead of the jittered value,
silently corrupting adaptation scores with no crash (see the sketch after this list).
- History is authoritative at self._env._internal.history — never maintain a separate copy.
- Reward is computed once at episode end via grade_episode()["overall_score"] (float in [0,1]).
- Raise ValueError("Episode complete.") when done — TRL catches this and ends the rollout.
Mac / MPS Notes:
- Unsloth does NOT support Mac for training (CUDA-only as of Apr 2026).
- Use TRL + PyTorch MPS instead: no load_in_4bit, no vLLM, no paged_adamw_8bit (see the sketch after this list).
- PYTORCH_ENABLE_MPS_FALLBACK=1 required for ops not yet on Metal.
- Recommended models for Mac: Qwen2.5-1.5B (fits 8GB+), Qwen2.5-3B (fits 16GB+).
- For Colab/cloud: Unsloth + vLLM work normally on NVIDIA T4/A100.
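A minimal MPS setup sketch (the model id and exact config fields are illustrative and may
vary with your transformers/TRL versions):
import os
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")  # set before torch touches MPS
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig
model_id = "Qwen/Qwen2.5-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("mps")
tokenizer = AutoTokenizer.from_pretrained(model_id)
args = GRPOConfig(
    num_generations=8,
    max_completion_length=2048,
    per_device_train_batch_size=8,   # chosen so the batch divides evenly by num_generations
    optim="adamw_torch",             # not paged_adamw_8bit (bitsandbytes is CUDA-only)
    use_vllm=False,                  # vLLM has no MPS backend
)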
Reward Variance Note:
- GRPO advantages are group-normalized, so when group_std ≈ 0 every advantage (and thus the gradient) collapses to ~0. Use the hard_multi scenario (not easy); see the sketch below.
- hard_multi has jitter + dual degradation → wider inter-rollout score spread.
- num_generations=8 (not 4) recommended to get better group variance estimates.
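A toy illustration of the variance point (scores are made up; this mirrors the per-group
normalization GRPO applies to rewards):
import statistics
group = [0.62, 0.71, 0.55, 0.68, 0.74, 0.60, 0.66, 0.58]   # 8 hypothetical rollout scores
mu, sigma = statistics.mean(group), statistics.pstdev(group)
advantages = [(r - mu) / (sigma + 1e-8) for r in group]
# If every rollout scored the same (sigma ~= 0), all advantages are ~0 and the update is a no-op.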
"""
from __future__ import annotations
from typing import Optional
from budget_router.environment import BudgetRouterEnv
from budget_router.models import Action, ActionType
from budget_router.reward import grade_episode
from budget_router.tasks import HARD_MULTI, TASK_PRESETS
class BudgetRouterGRPOEnv:
"""
TRL environment_factory-compatible wrapper around BudgetRouterEnv.
Exposes four named tool methods: route_to_a, route_to_b, route_to_c, shed_load.
The LLM calls these via function-calling. TRL discovers them automatically.
Episode lifecycle (a local rollout sketch follows the list):
1. reset(**kwargs) → returns rich text observation (initial state).
2. Model calls tool methods N times until episode ends.
3. Tool method raises ValueError("Episode complete.") when obs.done is True.
4. TRL reads self.reward from the reward_func after the episode.
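A minimal local rollout sketch (no TRL; the fixed route_to_a policy is for illustration only):
env = BudgetRouterGRPOEnv()
print(env.reset(scenario="hard_multi", seed=0))
try:
    while True:
        print(env.route_to_a())    # a real policy would choose among the four tools
except ValueError as done:         # raised when the episode ends
    print(done)
    print(f"final reward: {env.reward:.3f}")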
"""
def __init__(self) -> None:
self._env = BudgetRouterEnv()
self.reward: float = 0.0
# ─── TRL lifecycle ──────────────────────────────────────────────────
def reset(self, **kwargs) -> str:
"""
Reset the environment. Called by TRL at the start of each episode.
Accepts dataset columns as kwargs:
scenario (str): one of "easy", "medium", "hard", "hard_multi" (default).
seed (int): optional fixed seed for reproducibility.
Returns:
str: Initial observation text including provider status, budget, and task brief.
"""
scenario_name = str(kwargs.get("scenario", "hard_multi"))
scenario = TASK_PRESETS.get(scenario_name, HARD_MULTI)
seed: Optional[int] = kwargs.get("seed", None)
if seed is not None:
seed = int(seed)
self._env.reset(seed=seed, scenario=scenario)
self.reward = 0.0
return self._format_observation(is_initial=True)
# ─── Tool methods (TRL discovers all public non-reset methods) ───────
def route_to_a(self) -> str:
"""
Route the current request to Provider A ($0.01/req, cheapest, lowest base reliability).
Args:
(none)
Returns:
Outcome feedback: success/failure, latency, budget remaining, provider health update.
"""
return self._step(ActionType.ROUTE_TO_A)
def route_to_b(self) -> str:
"""
Route the current request to Provider B ($0.05/req, balanced cost and reliability).
Args:
(none)
Returns:
Outcome feedback: success/failure, latency, budget remaining, provider health update.
"""
return self._step(ActionType.ROUTE_TO_B)
def route_to_c(self) -> str:
"""
Route the current request to Provider C ($0.10/req, most expensive, highest base reliability).
Args:
(none)
Returns:
Outcome feedback: success/failure, latency, budget remaining, provider health update.
"""
return self._step(ActionType.ROUTE_TO_C)
def shed_load(self) -> str:
"""
Shed the current request — reject it without routing to any provider.
Use when all providers appear degraded or budget is critically low.
Penalty: -0.5 step reward. Slightly reduces queue backlog.
Args:
(none)
Returns:
Outcome feedback: load shed confirmation, budget remaining, current state.
"""
return self._step(ActionType.SHED_LOAD)
# ─── Internal step dispatch ──────────────────────────────────────────
def _step(self, action_type: ActionType) -> str:
"""
Execute one environment step.
CRITICAL: Delegates entirely to self._env.step(). Never constructs a custom
step_info dict. environment.py writes actual_degradation_start (the jittered
per-episode onset) into step_info; grade_episode() reads this to compute
adaptation scores. A custom dict would use the config constant instead,
silently breaking adaptation scoring.
"""
if self._env._internal.episode_done:
# Guard: called after done — reuse last reward, signal TRL to stop
raise ValueError(
f"Episode already complete. Final score: {self.reward:.3f}"
)
action = Action(action_type=action_type)
obs = self._env.step(action) # step_info written to self._env._internal.history
# Format response text BEFORE checking done (obs fields still valid)
response = self._format_step_result(obs)
if obs.done:
# History is authoritative at self._env._internal.history
self.reward = float(
grade_episode(self._env._internal.history)["overall_score"]
)
raise ValueError(
f"Episode complete. Score: {self.reward:.3f}. {response}"
)
return response
# ─── Observation / response formatters ──────────────────────────────
def _format_observation(self, is_initial: bool = False) -> str:
"""Format current env state as a rich text observation string."""
obs = self._env._get_obs()
s = self._env._internal
config = self._env._config
steps_remaining = max(0, s.max_steps - s.current_step)
budget_dollars = s.budget_dollars
budget_pct = obs.budget_remaining * 100.0
lines = []
if is_initial:
lines.append(
f"=== Budget Router — {config.name.upper()} ===\n"
f"Budget: ${budget_dollars:.3f} ({budget_pct:.1f}% remaining) | "
f"Steps remaining: {steps_remaining}/{s.max_steps}\n"
f"Providers: A=$0.01/req (cheapest), B=$0.05/req, C=$0.10/req (most reliable)\n"
f"Goal: Maximize successful routed requests. Budget exhaustion = heavy penalty.\n"
)
else:
lines.append(
f"Budget: ${budget_dollars:.3f} ({budget_pct:.1f}%) | "
f"Steps remaining: {steps_remaining}"
)
lines.append(
f"Provider health (windowed success rate; 0.5 = unobserved):\n"
f" A: {obs.provider_a_status:.3f} | B: {obs.provider_b_status:.3f} | C: {obs.provider_c_status:.3f}\n"
f"Queue backlog: {obs.queue_backlog:.3f} (normalized) | "
f"System latency: {obs.system_latency:.3f} (normalized to SLA)\n"
)
if is_initial:
lines.append(
"Choose a routing action: route_to_a / route_to_b / route_to_c / shed_load"
)
return "\n".join(lines)
def _format_step_result(self, obs) -> str:
"""Format step outcome as text returned to the model."""
s = self._env._internal
history = s.history
if not history:
return self._format_observation()
last = history[-1]
action_type = last.get("action_type", "unknown")
succeeded = last.get("request_succeeded", False)
provider = last.get("provider")
latency = last.get("latency_ms", 0.0)
cost = last.get("cost", 0.0)
budget_exhausted = last.get("budget_exhausted", False)
queue_overflow = last.get("queue_overflow", False)
if action_type == "shed_load":
result = "shed"
elif budget_exhausted:
result = "budget_exhausted"
elif succeeded:
result = "ok"
else:
result = "fail"
overflow_note = " overflow=1" if queue_overflow else ""
step_num = last.get("step", s.current_step)
obs_obj = self._env._get_obs()
budget_pct = obs_obj.budget_remaining * 100.0
steps_remaining = max(0, s.max_steps - s.current_step)
return (
f"step={step_num} action={action_type} result={result} p={provider or '-'} "
f"lat={latency:.0f} cost={cost:.3f} budget={budget_pct:.1f}% "
f"steps_left={steps_remaining} health=A{obs_obj.provider_a_status:.2f}/"
f"B{obs_obj.provider_b_status:.2f}/C{obs_obj.provider_c_status:.2f} "
f"queue={obs_obj.queue_backlog:.2f}{overflow_note}"
)