"""
BudgetRouterGRPOEnv — TRL environment_factory-compatible class for GRPO training.

Usage with TRL GRPOTrainer:

    from datasets import Dataset
    from train.grpo_env import BudgetRouterGRPOEnv
    from budget_router.reward import grade_episode

    # Dataset: columns become **kwargs in reset(). "prompt" drives the model's initial message.
    # With Dataset.from_list each row holds ONE conversation, i.e. a single list of messages.
    dataset = Dataset.from_list([
        {"prompt": [{"role": "user", "content": "Route requests using the available tools."}],
         "scenario": "hard_multi", "seed": i}
        for i in range(200)
    ])

    # reward_funcs with an `environments` parameter is the CORRECT TRL pattern when using
    # environment_factory. TRL inspects the signature and passes env instances (not completions).
    # This is explicitly documented in the official TRL/OpenEnv integration guide.
    # Alternatively, env.reward is set on the instance and TRL reads it directly if
    # reward_funcs is omitted — but the explicit function gives more control.
    def reward_func(environments, **kwargs):
        return [float(grade_episode(env._env._internal.history)["overall_score"])
                for env in environments]

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_func,
        train_dataset=dataset,
        args=GRPOConfig(num_generations=8, max_completion_length=2048),
        environment_factory=BudgetRouterGRPOEnv,
    )

Design Constraints (do NOT violate):
  - Tool methods MUST call self._env.step() and never construct custom step_info dicts.
    environment.py writes actual_degradation_start (jittered per-episode) into step_info.
    grade_episode() reads degradation_start_step from that dict to compute adaptation windows.
    Custom dicts would write the config constant (e.g. 0) instead of the jittered value,
    silently corrupting adaptation scores with no crash.
  - History is authoritative at self._env._internal.history — never maintain a separate copy.
  - Reward is computed once at episode end via grade_episode()["overall_score"] (float in [0,1]).
  - Raise ValueError("Episode complete.") when done — TRL catches this and ends the rollout.

Mac / MPS Notes:
  - Unsloth does NOT support Mac for training (CUDA-only as of Apr 2026).
  - Use TRL + PyTorch MPS: no load_in_4bit, no vLLM, no paged_adamw_8bit.
  - PYTORCH_ENABLE_MPS_FALLBACK=1 required for ops not yet on Metal.
  - Recommended models for Mac: Qwen2.5-1.5B (fits 8GB+), Qwen2.5-3B (fits 16GB+).
  - For Colab/cloud: Unsloth + vLLM work normally on NVIDIA T4/A100.

Reward Variance Note:
  - GRPO gradient = 0 when group_std ≈ 0. Use hard_multi scenario (not easy).
  - hard_multi has jitter + dual degradation → wider inter-rollout score spread.
  - num_generations=8 (not 4) recommended to get better group variance estimates.
"""

from __future__ import annotations

import warnings
from typing import Optional

from budget_router.environment import BudgetRouterEnv
from budget_router.models import Action, ActionType
from budget_router.reward import grade_episode
from budget_router.tasks import HARD_MULTI, TASK_PRESETS


class BudgetRouterGRPOEnv:
    """
    TRL environment_factory-compatible wrapper around BudgetRouterEnv.

    Exposes four named tool methods: route_to_a, route_to_b, route_to_c, shed_load.
    The LLM calls these via function-calling. TRL discovers them automatically.

    Episode lifecycle:
      1. reset(**kwargs) → returns rich text observation (initial state).
      2. Model calls tool methods N times until episode ends.
      3. Tool method raises ValueError("Episode complete.") when obs.done is True.
      4. TRL reads self.reward from the reward_func after the episode.
    """

    def __init__(self) -> None:
        self._env = BudgetRouterEnv()
        # Final episode score; stays 0.0 until an episode actually finishes.
        self.reward: float = 0.0

    # ─── TRL lifecycle ──────────────────────────────────────────────────

    def reset(self, **kwargs) -> str:
        """
        Reset the environment. Called by TRL at the start of each episode.

        Accepts dataset columns as kwargs:
            scenario (str): one of "easy", "medium", "hard", "hard_multi" (default).
                A missing or ``None`` value selects the default. A name that is not a
                key of TASK_PRESETS still falls back to HARD_MULTI, but a
                RuntimeWarning is emitted so a typo in the dataset does not go
                unnoticed.
            seed (int): optional fixed seed for reproducibility.

        Returns:
            str: Initial observation text including provider status, budget, and task brief.
        """
        requested = kwargs.get("scenario")
        scenario_name = "hard_multi" if requested is None else str(requested)

        scenario = TASK_PRESETS.get(scenario_name)
        if scenario is None:
            # Keep the historical best-effort fallback, but make it visible.
            warnings.warn(
                f"Unknown scenario {scenario_name!r}; falling back to HARD_MULTI.",
                RuntimeWarning,
                stacklevel=2,
            )
            scenario = HARD_MULTI

        seed: Optional[int] = kwargs.get("seed", None)
        if seed is not None:
            seed = int(seed)

        self._env.reset(seed=seed, scenario=scenario)
        self.reward = 0.0
        return self._format_observation(is_initial=True)

    # ─── Tool methods (TRL discovers all public non-reset methods) ───────

    def route_to_a(self) -> str:
        """
        Route the current request to Provider A ($0.01/req, cheapest, lowest base reliability).

        Args:
            (none)

        Returns:
            Outcome feedback: success/failure, latency, budget remaining, provider health update.
        """
        return self._step(ActionType.ROUTE_TO_A)

    def route_to_b(self) -> str:
        """
        Route the current request to Provider B ($0.05/req, balanced cost and reliability).

        Args:
            (none)

        Returns:
            Outcome feedback: success/failure, latency, budget remaining, provider health update.
        """
        return self._step(ActionType.ROUTE_TO_B)

    def route_to_c(self) -> str:
        """
        Route the current request to Provider C ($0.10/req, most expensive, highest base reliability).

        Args:
            (none)

        Returns:
            Outcome feedback: success/failure, latency, budget remaining, provider health update.
        """
        return self._step(ActionType.ROUTE_TO_C)

    def shed_load(self) -> str:
        """
        Shed the current request — reject it without routing to any provider.
        Use when all providers appear degraded or budget is critically low.
        Penalty: -0.5 step reward. Slightly reduces queue backlog.

        Args:
            (none)

        Returns:
            Outcome feedback: load shed confirmation, budget remaining, current state.
        """
        return self._step(ActionType.SHED_LOAD)

    # ─── Internal step dispatch ──────────────────────────────────────────

    def _step(self, action_type: ActionType) -> str:
        """
        Execute one environment step.

        CRITICAL: Delegates entirely to self._env.step(). Never constructs a custom
        step_info dict. environment.py writes actual_degradation_start (the jittered
        per-episode onset) into step_info; grade_episode() reads this to compute
        adaptation scores. A custom dict would use the config constant instead,
        silently breaking adaptation scoring.

        Returns:
            str: Compact one-line outcome text for a non-terminal step.

        Raises:
            ValueError: when the step ends the episode (message carries the final
                score and the last step's outcome), or when called after the
                episode has already finished.
        """
        if self._env._internal.episode_done:
            # Guard: called after done — reuse last reward, signal TRL to stop
            raise ValueError(
                f"Episode already complete. Final score: {self.reward:.3f}"
            )

        action = Action(action_type=action_type)
        obs = self._env.step(action)  # step_info written to self._env._internal.history

        # Format response text BEFORE checking done (obs fields still valid)
        response = self._format_step_result(obs)

        if obs.done:
            # History is authoritative at self._env._internal.history
            self.reward = float(
                grade_episode(self._env._internal.history)["overall_score"]
            )
            raise ValueError(
                f"Episode complete. Score: {self.reward:.3f}. {response}"
            )

        return response

    # ─── Observation / response formatters ──────────────────────────────

    def _format_observation(self, is_initial: bool = False) -> str:
        """
        Format current env state as a rich text observation string.

        Args:
            is_initial: when True, prepend the task banner (scenario name, provider
                price list, goal) and append the action prompt; otherwise emit only
                the short budget/steps line followed by provider health.

        Returns:
            str: newline-joined observation text.
        """
        obs = self._env._get_obs()
        s = self._env._internal
        config = self._env._config

        steps_remaining = max(0, s.max_steps - s.current_step)
        budget_dollars = s.budget_dollars
        # obs.budget_remaining is scaled by 100 here to print a percentage.
        budget_pct = obs.budget_remaining * 100.0

        lines = []
        if is_initial:
            lines.append(
                f"=== Budget Router — {config.name.upper()} ===\n"
                f"Budget: ${budget_dollars:.3f} ({budget_pct:.1f}% remaining) | "
                f"Steps remaining: {steps_remaining}/{s.max_steps}\n"
                f"Providers: A=$0.01/req (cheapest), B=$0.05/req, C=$0.10/req (most reliable)\n"
                f"Goal: Maximize successful routed requests. Budget exhaustion = heavy penalty.\n"
            )
        else:
            lines.append(
                f"Budget: ${budget_dollars:.3f} ({budget_pct:.1f}%) | "
                f"Steps remaining: {steps_remaining}"
            )

        lines.append(
            f"Provider health (windowed success rate; 0.5 = unobserved):\n"
            f" A: {obs.provider_a_status:.3f} | B: {obs.provider_b_status:.3f} | C: {obs.provider_c_status:.3f}\n"
            f"Queue backlog: {obs.queue_backlog:.3f} (normalized) | "
            f"System latency: {obs.system_latency:.3f} (normalized to SLA)\n"
        )

        if is_initial:
            lines.append(
                "Choose a routing action: route_to_a / route_to_b / route_to_c / shed_load"
            )

        return "\n".join(lines)

    def _format_step_result(self, obs) -> str:
        """
        Format step outcome as text returned to the model.

        Reads the most recent entry of self._env._internal.history for the outcome
        fields and re-queries self._env._get_obs() for budget / health / queue.

        Args:
            obs: observation returned by the step. NOTE: currently not read by this
                method; it is kept so the call site in _step() stays unchanged.

        Returns:
            str: single-line ``key=value`` summary of the step, or the plain
                observation text when the history is empty.
        """
        s = self._env._internal
        history = s.history
        if not history:
            return self._format_observation()

        last = history[-1]
        action_type = last.get("action_type", "unknown")
        succeeded = last.get("request_succeeded", False)
        provider = last.get("provider")
        latency = last.get("latency_ms", 0.0)
        cost = last.get("cost", 0.0)
        budget_exhausted = last.get("budget_exhausted", False)
        queue_overflow = last.get("queue_overflow", False)

        # Precedence matters: a shed is reported as "shed" even if other flags are set,
        # and budget exhaustion outranks the success/failure of the request itself.
        if action_type == "shed_load":
            result = "shed"
        elif budget_exhausted:
            result = "budget_exhausted"
        elif succeeded:
            result = "ok"
        else:
            result = "fail"

        overflow_note = " overflow=1" if queue_overflow else ""
        step_num = last.get("step", s.current_step)

        obs_obj = self._env._get_obs()
        budget_pct = obs_obj.budget_remaining * 100.0
        steps_remaining = max(0, s.max_steps - s.current_step)

        return (
            f"step={step_num} action={action_type} result={result} p={provider or '-'} "
            f"lat={latency:.0f} cost={cost:.3f} budget={budget_pct:.1f}% "
            f"steps_left={steps_remaining} health=A{obs_obj.provider_a_status:.2f}/"
            f"B{obs_obj.provider_b_status:.2f}/C{obs_obj.provider_c_status:.2f} "
            f"queue={obs_obj.queue_backlog:.2f}{overflow_note}"
        )