import json
import os
from pathlib import Path
from typing import Any, Dict, Iterable, List, Literal

import typer
from openai import OpenAI

from budget_router.environment import BudgetRouterEnv
from budget_router.models import Action, ActionType, Observation, TaskConfig
from budget_router.policies import heuristic_baseline_policy
from budget_router.reward import episode_metrics, grade_episode
from budget_router.tasks import EASY, HARD, HARD_MULTI, MEDIUM

_VALID_ACTIONS = ["route_to_a", "route_to_b", "route_to_c", "shed_load"]


def _parse_llm_action(response_text: str) -> str:
    """Extract a valid action from LLM output. Falls back to shed_load — never raises."""
    text = response_text.strip().lower()
    for action in _VALID_ACTIONS:
        if action in text:
            return action
    return "shed_load"  # safe fallback: always valid, never crashes


SYSTEM_PROMPT = """
You are a cost-aware LLM API routing agent managing a production system.
At each step, output EXACTLY ONE action string. Nothing else.

ENVIRONMENT:
Three providers: A ($0.01/req, cheapest), B ($0.05/req), C ($0.10/req, most reliable).
provider_X_status = windowed success rate [0=always fails, 1=always succeeds].
IMPORTANT: A status of exactly 0.500 means this provider has NEVER been routed to
in this episode — it is unobserved, not confirmed healthy. Route to it once to get
a real reading. Do not treat 0.500 as a health signal.
budget_remaining: fraction of budget left. Reaching 0 = catastrophic -10 penalty.
step_count [0→1], steps_remaining: episode progress (20 steps total).

VALID ACTIONS (output ONLY one):
route_to_a | route_to_b | route_to_c | shed_load

GOLDEN RULE — DEFAULT STRATEGY:
Stay on the CHEAPEST provider whose status > 0.52. Only deviate if there is CLEAR,
SUSTAINED evidence of degradation (defined below). Unnecessary switching to expensive
providers burns budget and reduces your score.

NOISE CALIBRATION (critical):
- Status fluctuates due to Bernoulli sampling noise. Single-step dips are not reliable signals.
- Use the provided 2-step trend (avg/step): a sustained negative trend across multiple steps
  indicates real degradation; a trend near 0 means the provider is stable. Do NOT switch on noise.
- REAL degradation signal: a sustained negative trend AND a current status that has visibly dropped.
- Only when both conditions hold across consecutive observations should you consider early switching.
- On stable tasks, trends hover near zero. Switching on noise burns budget without benefit.

WHEN TO SWITCH (use your conversation history):
A → B: when trend_a is clearly and consistently negative AND status_a is approaching unreliable,
OR status_a is already below 0.52 (failure probability exceeds success probability).
B → C: same principle — sustained decline signals, not single-step noise.
Never switch based on a single bad observation — noise causes occasional dips.

BUDGET RUNWAY — HARD CONSTRAINT:
budget_runway_at_current_rate shows how many more steps you can afford at the current spend rate.
If budget_runway_at_current_rate < steps_remaining: switch to a cheaper provider IMMEDIATELY.
If budget_remaining < 0.15 (less than 15% left): treat C as OFF-LIMITS unless A and B are
both below 0.30 status. Prefer shed_load over routing to C when budget is this low.
NEVER route to any provider if doing so would leave budget_remaining below the cost of
that provider times the steps_remaining. The -10 bankruptcy penalty destroys all episode
value accumulated so far — budget survival is non-negotiable.

TASK PROFILES (the task name appears in each observation — use it):
easy: Stable environment. Trend fluctuations are mostly noise. Stay on the cheapest
provider unless its trend shows a catastrophic, sustained decline.
medium: Dynamic environment. A provider may degrade mid-episode. Monitor trends and
switch to the next cheapest healthy fallback if the primary fails.
hard / hard_multi: Hostile, multi-failure environments. Multiple providers may degrade
at unexpected times in unpredictable sequences.

Your runbook: always map traffic to the lowest-cost healthy provider (A=$0.01, B=$0.05, C=$0.10).
Watch your conversation history: if your currently active provider shows a clear, sustained
negative trend, switch early to the next cheapest option that is healthy.
CRITICAL: before switching to expensive fallbacks (like C), use budget_runway to verify you
can afford them and avoid budget exhaustion.

Output only the action string."""
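
# Worked example of the hard budget rule above (illustrative numbers, assuming
# dollar units): with $0.40 of budget left and 5 steps remaining, one more C
# request ($0.10) leaves $0.30, below C's worst-case cost for the rest of the
# episode (0.10 * 5 = $0.50), so the rule forbids routing to C here.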

OBJECTIVE_FEEDBACK_PROMPT = """
You are a budget-aware API routing agent managing a production system.
At each step, output EXACTLY ONE action string and nothing else.

VALID ACTIONS:
route_to_a | route_to_b | route_to_c | shed_load

ENVIRONMENT INTUITION:
You route one request per step across three providers:
- A: cheapest ($0.01), useful for conserving budget, but may degrade.
- B: medium cost ($0.05), usually the bridge provider.
- C: expensive ($0.10), most reliable, but using it too early can bankrupt the episode.
- shed_load: reject this request; use only when routing is likely worse than abstaining
  or the budget is critically unsafe.

OBSERVATIONS:
provider_X_status is a rolling estimate of recent successes, not true health.
- Exactly 0.500 means unobserved in this episode: no evidence yet.
- After one probe, the status may jump to 0.000 or 1.000; that is weak evidence
  because it reflects only a single sample.
- Repeated outcomes, a sustained negative trend, worsening reward, or repeated
  failures are stronger evidence.
Do not treat a single success as proof of health or a single failure as proof of degradation.

PRIMARY OBJECTIVE:
Maximize the full-episode grader score:
- successful routed requests
- low latency and SLA health
- budget preservation
- adaptation after provider degradation

DECISION POLICY:
1. Budget survival is mandatory. Avoid the -10 budget-exhaustion cliff.
2. Prefer the cheapest provider that is plausibly healthy, but do not blindly follow a fixed threshold.
3. Probe unknown providers when information is valuable and affordable.
4. Switch away from a provider when there is repeated failure, sustained decline, or clearly bad status.
5. Use C as late-phase reliability capacity, not as the default.
6. Prefer shed_load when all available routed choices look likely to fail or when routing would risk budget exhaustion.

TASK-SPECIFIC STRATEGY:
easy:
- Stable task. Prefer cheap routing. Do not overreact to noise.
medium:
- A may degrade after the early steps. Start cheap, then move to B when A shows repeated weakness.
hard:
- A may degrade from the beginning. Probe cheaply, but react faster to repeated A failures or sustained decline.
hard_multi:
- This is a sequential cascade: A degrades early; B can degrade later.
- Early phase: conserve budget. Use/probe A only while the evidence supports it.
- Middle phase: B is often the bridge provider after A weakens.
- Late phase: preserve enough runway for C in case B begins failing.
- Do not burn C too early; do not stay on A/B after repeated failures.

BUDGET RUNWAY:
If budget_runway_at_current_rate < steps_remaining, current spending is unsafe.
If budget_remaining < 0.15, treat C as emergency-only unless A and B both look very poor.
If routing risks budget exhaustion, shed_load is better than bankruptcy.

Use previous_step_feedback when present:
- previous_success=false, negative reward, high latency, or repeated same-provider
  failures are evidence to update your belief.
- previous_success=true is useful evidence but not proof, especially after only one sample.

Output only one valid action string.
"""

app = typer.Typer(add_completion=False)

API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
LLM_TIMEOUT_SECONDS = float(os.getenv("LLM_TIMEOUT_SECONDS") or "25")
LLM_MAX_RETRIES = int(os.getenv("LLM_MAX_RETRIES") or "1")
LLM_LOG_RAW = (os.getenv("LLM_LOG_RAW") or "").strip().lower() in {"1", "true", "yes", "y", "on"}
LLM_LOG_RAW_MAX_CHARS = int(os.getenv("LLM_LOG_RAW_MAX_CHARS") or "220")
LLM_POLICY_MODE = (os.getenv("LLM_POLICY_MODE") or "baseline").strip().lower()
BENCHMARK_NAME = os.getenv("BENCHMARK_NAME") or "budget_router"
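
# Typical configuration, set in the environment before launch (values are placeholders):
#   export HF_TOKEN=hf_xxx                        # or API_KEY
#   export MODEL_NAME=Qwen/Qwen2.5-72B-Instruct   # any chat model behind the router
#   export LLM_POLICY_MODE=objective_feedback     # or "baseline" (the default)
#   export LLM_LOG_RAW=1                          # include truncated raw completions in step logs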

SEED_SETS: Dict[str, List[int]] = {
    "development": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    "heldout": [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
}

TASKS: Dict[str, TaskConfig] = {
    "easy": EASY,
    "medium": MEDIUM,
    "hard": HARD,
    "hard_multi": HARD_MULTI,
}

VALID_ACTIONS = [action.value for action in ActionType]


class LLMRouter:
    def __init__(
        self,
        api_base_url: str,
        model_name: str,
        api_key: str,
        prompt_mode: str | None = None,
    ) -> None:
        self._client = OpenAI(
            base_url=api_base_url,
            api_key=api_key,
            timeout=LLM_TIMEOUT_SECONDS,
            max_retries=LLM_MAX_RETRIES,
        )
        self._model_name = model_name
        self._prompt_mode = (prompt_mode or LLM_POLICY_MODE or "baseline").strip().lower()
        self._messages: List[Dict[str, str]] = []
        self.last_error: str | None = None
        self.last_raw_output: str | None = None
        self.last_parsed_action: str | None = None
        self._prev_obs: dict | None = None
        self._prev2_obs: dict | None = None
        self._task_name: str = ""
        self.reset()

    def reset(self, task_name: str = "") -> None:
        prompt = OBJECTIVE_FEEDBACK_PROMPT if self._prompt_mode == "objective_feedback" else SYSTEM_PROMPT
        self._messages = [{"role": "system", "content": prompt}]
        self.last_error = None
        self.last_raw_output = None
        self.last_parsed_action = None
        self._prev_obs = None
        self._prev2_obs = None
        self._task_name = task_name

    def choose_action(self, observation: Observation) -> Action:
        obs = observation
        if not self._messages:
            self.reset(task_name=self._task_name)
        elif obs.step_count == 0.0 and len(self._messages) > 1:
            self.reset(task_name=self._task_name)

        # ── Compute 2-step trend (more noise-robust than single-step delta) ──
        trend_text = ""
        budget_runway_text = ""
        if self._prev2_obs is not None:
            # Average per-step change over 2 steps — variance is ~30% lower than 1-step delta
            ta = (obs.provider_a_status - self._prev2_obs["a"]) / 2.0
            tb = (obs.provider_b_status - self._prev2_obs["b"]) / 2.0
            tc = (obs.provider_c_status - self._prev2_obs["c"]) / 2.0
            trend_text = f"\ntrend (avg/step, 2-step): A:{ta:+.3f} B:{tb:+.3f} C:{tc:+.3f}"
        elif self._prev_obs is not None:
            # First step — single-step delta only, label as less reliable
            ta = obs.provider_a_status - self._prev_obs["a"]
            tb = obs.provider_b_status - self._prev_obs["b"]
            tc = obs.provider_c_status - self._prev_obs["c"]
            trend_text = f"\ntrend (1-step only, noisy): A:{ta:+.3f} B:{tb:+.3f} C:{tc:+.3f}"
        if self._prev_obs is not None:
            budget_spent = self._prev_obs["budget"] - obs.budget_remaining
            if budget_spent > 0.001:
                runway = int(obs.budget_remaining / budget_spent)
                budget_runway_text = f"\nbudget_runway_at_current_rate: ~{runway} steps"
            else:
                budget_runway_text = "\nbudget_runway_at_current_rate: >20 steps"

        steps_total = 20
        steps_remaining = max(1, steps_total - int(round(obs.step_count * steps_total)))
        task_line = f"\ntask: {self._task_name}" if self._task_name else ""
        obs_text = "\n".join([
            f"provider_a_status: {obs.provider_a_status:.3f}",
            f"provider_b_status: {obs.provider_b_status:.3f}",
            f"provider_c_status: {obs.provider_c_status:.3f}",
            f"budget_remaining: {obs.budget_remaining:.3f}",
            f"queue_backlog: {obs.queue_backlog:.3f}",
            f"system_latency: {obs.system_latency:.3f}",
            f"step_count: {obs.step_count:.3f}",
            f"steps_remaining: {steps_remaining}",
        ])
        obs_text += trend_text + budget_runway_text + task_line
        if self._prompt_mode == "objective_feedback":
            feedback_lines = self._previous_step_feedback(observation=obs)
            if feedback_lines:
                obs_text += "\n" + feedback_lines
        user_prompt = f"Current observation:\n{obs_text}\n\nYour action:"

        # Shift history: prev becomes prev2, current becomes prev
        self._prev2_obs = self._prev_obs
        self._prev_obs = {
            "a": obs.provider_a_status,
            "b": obs.provider_b_status,
            "c": obs.provider_c_status,
            "budget": obs.budget_remaining,
        }

        client = self._client
        model_name = self._model_name
        self._messages.append({"role": "user", "content": user_prompt})
        try:
            response = client.with_options(timeout=LLM_TIMEOUT_SECONDS).chat.completions.create(
                model=model_name,
                messages=self._messages,
                max_tokens=30,
                temperature=0,
            )
            raw = response.choices[0].message.content or ""
            action_str = _parse_llm_action(raw)
            action_str = self._apply_budget_safety_guard(action_str=action_str, observation=obs)
            self.last_raw_output = raw
            self.last_parsed_action = action_str
            self.last_error = None
        except Exception as e:
            # Any API failure degrades to a safe, valid action instead of crashing the episode.
            self.last_error = str(e)
            action_str = "shed_load"
            self.last_raw_output = None
            self.last_parsed_action = action_str
        self._messages.append({"role": "assistant", "content": action_str})
        return Action(action_type=ActionType(action_str))
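
    # A user message assembled above looks roughly like this (values are
    # illustrative, not from a real episode):
    #   Current observation:
    #   provider_a_status: 0.812
    #   provider_b_status: 0.500
    #   provider_c_status: 0.500
    #   budget_remaining: 0.742
    #   queue_backlog: 0.100
    #   system_latency: 0.230
    #   step_count: 0.150
    #   steps_remaining: 17
    #   trend (avg/step, 2-step): A:-0.021 B:+0.000 C:+0.000
    #   budget_runway_at_current_rate: ~14 steps
    #   task: medium
    #
    #   Your action: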

    def _apply_budget_safety_guard(self, action_str: str, observation: Observation) -> str:
        """Prevent only actions that would immediately exhaust the public remaining budget."""
        if action_str == "shed_load":
            return action_str
        scenario = TASKS.get(self._task_name)
        if scenario is None:
            return action_str
        action_costs = {
            "route_to_a": scenario.cost_a,
            "route_to_b": scenario.cost_b,
            "route_to_c": scenario.cost_c,
        }
        selected_cost = action_costs.get(action_str, 0.0)
        budget_dollars = float(observation.budget_remaining) * float(scenario.initial_budget)
        if selected_cost >= budget_dollars - 1e-9:
            return "shed_load"
        return action_str
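
    # Worked guard example (hypothetical scenario values): with cost_c = 0.10,
    # initial_budget = 1.00, and budget_remaining = 0.08, budget_dollars is
    # 0.08, so a chosen route_to_c (0.10 >= 0.08 - 1e-9) is rewritten to
    # shed_load, while route_to_a (0.01) passes through unchanged.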

    def _previous_step_feedback(self, observation: Observation) -> str:
        metadata = getattr(observation, "metadata", None) or {}
        if not metadata:
            return ""
        previous_action = metadata.get("action_type")
        if not previous_action:
            return ""
        reward = observation.reward
        latency = metadata.get("latency_ms")
        cost = metadata.get("cost")
        succeeded = metadata.get("request_succeeded")
        budget_exhausted = metadata.get("budget_exhausted", False)
        feedback_parts = [
            "previous_step_feedback:",
            f"  previous_action: {previous_action}",
        ]
        if reward is not None:
            feedback_parts.append(f"  previous_reward: {float(reward):+.2f}")
        if succeeded is not None:
            feedback_parts.append(f"  previous_success: {str(bool(succeeded)).lower()}")
        if cost is not None:
            feedback_parts.append(f"  previous_cost: {float(cost):.2f}")
        if latency is not None:
            feedback_parts.append(f"  previous_latency_ms: {float(latency):.2f}")
        if budget_exhausted:
            feedback_parts.append("  previous_budget_exhausted: true")
        return "\n".join(feedback_parts)


def _single_line(value: str | None) -> str:
    if not value:
        return "null"
    return str(value).replace("\n", " ").replace("\r", " ")


def _truncate_and_sanitize(value: str | None, max_chars: int) -> str:
    if not value:
        return "null"
    s = _single_line(value).strip()
    if len(s) <= max_chars:
        return s
    return s[: max(0, max_chars - 3)] + "..."
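
# e.g. _truncate_and_sanitize("route_to_a\nbecause A looks healthy", 20)
# flattens the newline, then truncates: -> "route_to_a becaus..."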


def _observation_payload(observation: Observation) -> Dict[str, float]:
    return {
        "provider_a_status": float(observation.provider_a_status),
        "provider_b_status": float(observation.provider_b_status),
        "provider_c_status": float(observation.provider_c_status),
        "budget_remaining": float(observation.budget_remaining),
        "queue_backlog": float(observation.queue_backlog),
        "system_latency": float(observation.system_latency),
        "step_count": float(observation.step_count),
    }


def _reported_score(value: float) -> float:
    return min(max(float(value), 0.001), 0.999)
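
# Reported scores are clamped away from the exact 0/1 endpoints, e.g.
# _reported_score(1.25) -> 0.999 and _reported_score(-3.0) -> 0.001.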


def log_start(task: str, env: str, model: str) -> None:
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(
    step: int,
    action: str,
    reward: float,
    done: bool,
    error: str | None,
    llm_raw: str | None = None,
    llm_parsed: str | None = None,
) -> None:
    base = (
        f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} "
        f"error={_single_line(error)}"
    )
    if LLM_LOG_RAW:
        raw_s = _truncate_and_sanitize(llm_raw, max_chars=max(20, LLM_LOG_RAW_MAX_CHARS))
        parsed_s = _single_line(llm_parsed)
        base += f" llm_raw={raw_s} llm_parsed={parsed_s}"
    print(base, flush=True)


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    rewards_str = ",".join(f"{reward:.2f}" for reward in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} score={_reported_score(score):.3f} rewards={rewards_str}",
        flush=True,
    )
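
# One illustrative line of each log kind these produce (values made up):
#   [START] task=easy env=budget_router model=Qwen/Qwen2.5-72B-Instruct
#   [STEP] step=1 action=route_to_a reward=0.85 done=false error=null
#   [END] success=true steps=20 score=0.812 rewards=0.85,0.85,...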


def select_policy(policy_name: Literal["heuristic", "llm"]) -> object:
    if policy_name == "heuristic":
        return heuristic_baseline_policy
    if not API_KEY or not API_BASE_URL:
        raise RuntimeError(
            "The llm policy requires API_KEY (or HF_TOKEN) in the environment; "
            "API_BASE_URL and MODEL_NAME are read from the environment as well."
        )
    return LLMRouter(api_base_url=API_BASE_URL, model_name=MODEL_NAME, api_key=API_KEY)


def choose_action(policy_name: Literal["heuristic", "llm"], policy: object, observation: Observation) -> Action:
    if policy_name == "heuristic":
        return policy(observation)
    return policy.choose_action(observation)


def run_episode(
    env: BudgetRouterEnv,
    scenario: TaskConfig,
    seed: int,
    episode: int,
    policy_name: Literal["heuristic", "llm"],
    policy: object,
) -> Dict[str, Any]:
    total_reward = 0.0
    grader_score: float | None = None
    rewards: List[float] = []
    steps_taken = 0
    success = False
    if policy_name == "llm":
        policy.reset(task_name=scenario.name)
    log_start(task=scenario.name, env=BENCHMARK_NAME, model=MODEL_NAME)
    try:
        observation = env.reset(seed=seed, scenario=scenario)
        while not observation.done:
            action = choose_action(policy_name=policy_name, policy=policy, observation=observation)
            action_name = action.action_type.value
            observation = env.step(action)
            reward = float(observation.reward or 0.0)
            total_reward += reward
            rewards.append(reward)
            steps_taken = env._internal.current_step
            step_error = getattr(policy, "last_error", None) if policy_name == "llm" else None
            llm_raw = getattr(policy, "last_raw_output", None) if policy_name == "llm" else None
            llm_parsed = getattr(policy, "last_parsed_action", None) if policy_name == "llm" else None
            log_step(
                step=env._internal.current_step,
                action=action_name,
                reward=reward,
                done=bool(observation.done),
                error=step_error,
                llm_raw=llm_raw,
                llm_parsed=llm_parsed,
            )
        metrics = episode_metrics(env._internal.history)
        metrics["seed"] = seed
        metrics["episode"] = episode
        metrics["total_reward"] = round(total_reward, 4)
        metrics["episode_length"] = env._internal.current_step
        grader = grade_episode(env._internal.history)
        grader_score = float(grader["overall_score"])
        success = grader_score > 0.0
        metrics["grader_score"] = grader_score
        metrics["grader_breakdown"] = grader
        return metrics
    finally:
        close_fn = getattr(env, "close", None)
        if callable(close_fn):
            close_fn()
        if grader_score is None:
            grader_score = float(grade_episode(env._internal.history)["overall_score"])
            success = grader_score > 0.0
        log_end(success=success, steps=steps_taken, score=grader_score, rewards=rewards)


def summarize(metrics: Iterable[Dict[str, float]]) -> Dict[str, float]:
    rows = list(metrics)
    return {
        "mean_reward": round(sum(row["total_reward"] for row in rows) / len(rows), 4),
        "mean_success_rate": round(sum(row["success_rate"] for row in rows) / len(rows), 4),
        "mean_cost": round(sum(row["total_cost_spent"] for row in rows) / len(rows), 4),
        "mean_latency_ms": round(sum(row["average_latency_ms"] for row in rows) / len(rows), 2),
        "mean_grader_score": round(sum(row["grader_score"] for row in rows) / len(rows), 4),
    }
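
# Example per-task summary (illustrative values):
#   {"mean_reward": 12.3456, "mean_success_rate": 0.9, "mean_cost": 0.42,
#    "mean_latency_ms": 205.17, "mean_grader_score": 0.7321}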


@app.command()
def main(
    policy: Literal["heuristic", "llm"] = typer.Option("llm" if API_KEY and API_BASE_URL else "heuristic"),
    seed_set: Literal["development", "heldout"] = typer.Option("development"),
    scenario: Literal["all", "easy", "medium", "hard", "hard_multi"] = typer.Option("all"),
    max_seeds: int = typer.Option(1),
    output_path: Path = typer.Option(Path("baseline_results.json")),
) -> None:
    selected_policy = select_policy(policy)
    selected_tasks = TASKS if scenario == "all" else {scenario: TASKS[scenario]}
    selected_seeds = SEED_SETS[seed_set][: max(1, max_seeds)]
    results: Dict[str, Dict[str, object]] = {}
    episode = 1
    for task_name, task_config in selected_tasks.items():
        task_metrics = []
        for seed in selected_seeds:
            env = BudgetRouterEnv()
            task_metrics.append(
                run_episode(
                    env=env,
                    scenario=task_config,
                    seed=seed,
                    episode=episode,
                    policy_name=policy,
                    policy=selected_policy,
                )
            )
            episode += 1
        results[task_name] = {
            "policy": policy,
            "seed_set": seed_set,
            "summary": summarize(task_metrics),
            "episodes": task_metrics,
        }
    output_path.write_text(json.dumps(results, indent=2) + "\n", encoding="utf-8")


if __name__ == "__main__":
    app()
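
# Typical invocations (the script filename here is illustrative):
#   python baseline.py --policy heuristic --scenario easy --max-seeds 3
#   python baseline.py --policy llm --seed-set heldout --output-path llm_results.json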