#!/usr/bin/env python3
# /// script
# dependencies = [
# "torch",
# "transformers>=4.45.0",
# "huggingface_hub>=0.24.0",
# "scipy",
# "budget-router @ git+https://huggingface.co/spaces/akshay4/budget-router-openenv",
# ]
# ///
"""Evaluate a Budget Router SFT model against the heuristic baseline."""
from __future__ import annotations
import argparse
import json
import math
import os
import time
from pathlib import Path
from typing import Any
from budget_router.environment import BudgetRouterEnv
from budget_router.models import Action, ActionType, Observation, TaskConfig
from budget_router.policies import heuristic_baseline_policy
from budget_router.reward import episode_metrics, grade_episode
from budget_router.tasks import HARD_MULTI, TASK_PRESETS
try:
from inference import SYSTEM_PROMPT
_SYSTEM_PROMPT_SOURCE = "inference"
except ModuleNotFoundError as exc:
if exc.name != "inference":
raise
SYSTEM_PROMPT = """
You are a cost-aware LLM API routing agent managing a production system.
At each step, output EXACTLY ONE action string. Nothing else.
ENVIRONMENT:
Three providers: A ($0.01/req, cheapest), B ($0.05/req), C ($0.10/req, most reliable).
provider_X_status = windowed success rate [0=always fails, 1=always succeeds].
IMPORTANT: A status of exactly 0.500 means this provider has NEVER been routed to
in this episode — it is unobserved, not confirmed healthy. Route to it once to get
a real reading. Do not treat 0.500 as a health signal.
budget_remaining: fraction of budget left. Reaching 0 = catastrophic -10 penalty.
step_count [0→1], steps_remaining: episode progress (20 steps total).
VALID ACTIONS (output ONLY one):
route_to_a | route_to_b | route_to_c | shed_load
GOLDEN RULE — DEFAULT STRATEGY:
Stay on the CHEAPEST provider whose status > 0.52. Only deviate if there is CLEAR, SUSTAINED evidence of degradation (defined below). Unnecessary switching to expensive providers burns budget and reduces your score.
NOISE CALIBRATION (critical):
- Status fluctuates due to Bernoulli sampling noise. Single-step dips are not reliable signals.
- Use the provided 2-step trend (avg/step): a sustained negative trend across multiple steps
indicates real degradation; a trend near 0 means the provider is stable. Do NOT switch on noise.
- REAL degradation signal: sustained negative trend AND current status is visibly declining.
- Only when both conditions hold across consecutive observations should you consider early switching.
- On stable tasks, trends hover near zero. Switching on noise burns budget without benefit.
WHEN TO SWITCH (use your conversation history):
A → B: When trend_a is clearly and consistently negative AND status_a is approaching unreliable,
OR status_a is already below 0.52 (failure probability exceeds success probability).
B → C: Same principle — sustained decline signals, not single-step noise.
Never switch based on a single bad observation — noise causes occasional dips.
BUDGET RUNWAY — HARD CONSTRAINT:
budget_runway_at_current_rate shows how many more steps you can afford at current spend rate.
If budget_runway_at_current_rate < steps_remaining: switch to a cheaper provider IMMEDIATELY.
If budget_remaining < 0.15 (less than 15% left): treat C as OFF-LIMITS unless A and B are
both below 0.30 status. Prefer shed_load over routing C when budget is this low.
NEVER route to any provider if doing so would leave budget_remaining below the cost of
that provider times the steps_remaining. The -10 bankruptcy penalty destroys all episode
value accumulated so far — budget survival is non-negotiable.
TASK PROFILES (the task name appears in each observation — use it):
easy: Stable environment. Trend fluctuations are mostly noise. Stay on the cheapest provider unless its trend is catastrophically and sustainedly negative.
medium: Dynamic environment. A provider may degrade mid-episode. Monitor trends and switch to the next cheapest healthy fallback if the primary fails.
hard / hard_multi: Hostile, multi-failure environments. Multiple providers may degrade at unexpected times in unpredictable sequences.
Your Runbook: Always map traffic to the lowest-cost healthy provider (A=$0.01, B=$0.05, C=$0.10).
Watch your conversation history: if your currently active provider shows a clear, sustained negative trend, switch early to the next cheapest option that is healthy.
CRITICAL: Before switching to expensive fallbacks (like C), use budget_runway to verify you can afford them to prevent budget exhaustion.
Output only the action string."""
_SYSTEM_PROMPT_SOURCE = "embedded_fallback"
_AGENT_DEBUG_LOG = "/Users/akshaybabbar/Desktop/work/.cursor/debug-e4cac3.log"
def _agent_debug_ndjson(payload: dict[str, object]) -> None:
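    """Append one NDJSON record to the local debug log, printing to stdout if the file cannot be written."""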
line = json.dumps(payload)
try:
with open(_AGENT_DEBUG_LOG, "a", encoding="utf-8") as f:
f.write(line + "\n")
except OSError:
print(f"[agent-debug] {line}", flush=True)
VALID_ACTIONS = ["route_to_a", "route_to_b", "route_to_c", "shed_load"]
DEFAULT_MODEL_REPO = "akshay4/budget-router-sft-qwen1.5b"
def _steps_remaining(obs: Observation, max_steps: int = 20) -> int:
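    """Recover the integer number of steps remaining from the normalized step_count (default 20-step episodes)."""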
elapsed = int(round(float(obs.step_count) * max_steps))
return max(0, max_steps - elapsed)
def _trend_text(obs: Observation, previous_obs: Observation | None, previous2_obs: Observation | None) -> str:
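    """Report per-provider status trends: a 2-step average when two prior observations exist, else a noisier 1-step delta."""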
if previous2_obs is not None:
ta = (obs.provider_a_status - previous2_obs.provider_a_status) / 2.0
tb = (obs.provider_b_status - previous2_obs.provider_b_status) / 2.0
tc = (obs.provider_c_status - previous2_obs.provider_c_status) / 2.0
return f"trend (avg/step, 2-step): A:{ta:+.3f} B:{tb:+.3f} C:{tc:+.3f}"
if previous_obs is not None:
ta = obs.provider_a_status - previous_obs.provider_a_status
tb = obs.provider_b_status - previous_obs.provider_b_status
tc = obs.provider_c_status - previous_obs.provider_c_status
return f"trend (1-step only, noisy): A:{ta:+.3f} B:{tb:+.3f} C:{tc:+.3f}"
return "trend: unavailable"
def _budget_runway_text(obs: Observation, previous_obs: Observation | None) -> str:
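    """Estimate how many more steps the budget covers at the last step's spend rate; reports '>20 steps' before any spend."""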
if previous_obs is None:
return "budget_runway_at_current_rate: >20 steps"
budget_spent = float(previous_obs.budget_remaining) - float(obs.budget_remaining)
if budget_spent <= 0.001:
return "budget_runway_at_current_rate: >20 steps"
runway = int(float(obs.budget_remaining) / budget_spent)
return f"budget_runway_at_current_rate: ~{runway} steps"
def _previous_step_feedback(obs: Observation) -> str:
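    """Render previous-step feedback (action, reward, success, cost, latency) from observation metadata, or '' if none."""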
metadata = getattr(obs, "metadata", None) or {}
if not metadata.get("action_type"):
return ""
parts = [
"previous_step_feedback:",
f" previous_action: {metadata.get('action_type')}",
]
if obs.reward is not None:
parts.append(f" previous_reward: {float(obs.reward):+.2f}")
if metadata.get("request_succeeded") is not None:
parts.append(f" previous_success: {str(bool(metadata.get('request_succeeded'))).lower()}")
if metadata.get("cost") is not None:
parts.append(f" previous_cost: {float(metadata.get('cost')):.2f}")
if metadata.get("latency_ms") is not None:
parts.append(f" previous_latency_ms: {float(metadata.get('latency_ms')):.2f}")
if metadata.get("budget_exhausted"):
parts.append(" previous_budget_exhausted: true")
return "\n".join(parts)
def format_observation_for_sft(
*,
obs: Observation,
task_name: str,
previous_obs: Observation | None,
previous2_obs: Observation | None,
) -> str:
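    """Render an observation as the plain-text block the SFT model expects, including trend and budget-runway lines."""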
lines = [
f"task: {task_name}",
f"provider_a_status: {obs.provider_a_status:.3f}",
f"provider_b_status: {obs.provider_b_status:.3f}",
f"provider_c_status: {obs.provider_c_status:.3f}",
f"budget_remaining: {obs.budget_remaining:.3f}",
f"queue_backlog: {obs.queue_backlog:.3f}",
f"system_latency: {obs.system_latency:.3f}",
f"step_count: {obs.step_count:.3f}",
f"steps_remaining: {_steps_remaining(obs)}",
_trend_text(obs, previous_obs, previous2_obs),
_budget_runway_text(obs, previous_obs),
]
feedback = _previous_step_feedback(obs)
if feedback:
lines.append(feedback)
return "\n".join(lines)
def parse_action(text: str) -> tuple[str, bool]:
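    """Scan the generated text for a valid action string; default to route_to_a and report the parse failure."""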
lowered = text.strip().lower()
for action in VALID_ACTIONS:
if action in lowered:
return action, True
return "route_to_a", False
def apply_budget_safety_guard(action_str: str, observation: Observation, task_cfg: TaskConfig) -> str:
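    """Replace a route action with shed_load when its dollar cost would (nearly) exhaust the remaining budget."""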
if action_str == "shed_load":
return action_str
costs = {
"route_to_a": task_cfg.cost_a,
"route_to_b": task_cfg.cost_b,
"route_to_c": task_cfg.cost_c,
}
selected_cost = costs.get(action_str, 0.0)
budget_dollars = float(observation.budget_remaining) * float(task_cfg.initial_budget)
if selected_cost >= budget_dollars - 1e-9:
return "shed_load"
return action_str
def run_heuristic_episode(task_cfg: TaskConfig, seed: int) -> dict[str, Any]:
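    """Roll out one episode under heuristic_baseline_policy and summarize its grader score, reward, and metrics."""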
env = BudgetRouterEnv()
obs = env.reset(seed=seed, scenario=task_cfg)
total_reward = 0.0
while not obs.done:
obs = env.step(heuristic_baseline_policy(obs))
total_reward += float(obs.reward or 0.0)
grader = grade_episode(env._internal.history)
metrics = episode_metrics(env._internal.history)
return {
"grader_score": float(grader["overall_score"]),
"total_reward": total_reward,
"episode_length": env._internal.current_step,
"grader": grader,
"metrics": metrics,
}
class SFTPolicy:
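    """Chat-based policy: replays the full message history each step and greedily decodes one action from the SFT model."""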
def __init__(self, model_repo: str, *, token: str | None, use_budget_guard: bool) -> None:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cuda":
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            dtype = torch.float32  # fp16 inference is poorly supported on CPU
self.model = AutoModelForCausalLM.from_pretrained(model_repo, torch_dtype=dtype, token=token)
self.model.to(self.device)
self.model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(model_repo, token=token)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.use_budget_guard = use_budget_guard
self.messages: list[dict[str, str]] = []
self.previous_obs: Observation | None = None
self.previous2_obs: Observation | None = None
self.parse_failures = 0
def reset(self) -> None:
self.messages = [{"role": "system", "content": SYSTEM_PROMPT}]
self.previous_obs = None
self.previous2_obs = None
self.parse_failures = 0
def choose_action(self, obs: Observation, *, task_name: str, task_cfg: TaskConfig) -> str:
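        """Format the observation, greedily decode up to 10 new tokens, parse the action, and apply the optional budget guard."""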
import torch
obs_text = format_observation_for_sft(
obs=obs,
task_name=task_name,
previous_obs=self.previous_obs,
previous2_obs=self.previous2_obs,
)
self.messages.append({"role": "user", "content": obs_text})
prompt = self.tokenizer.apply_chat_template(
self.messages,
tokenize=False,
add_generation_prompt=True,
)
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.no_grad():
output = self.model.generate(
**inputs,
max_new_tokens=10,
do_sample=False,
pad_token_id=self.tokenizer.eos_token_id,
)
generated = self.tokenizer.decode(
output[0][inputs["input_ids"].shape[1] :],
skip_special_tokens=True,
)
action_str, ok = parse_action(generated)
if not ok:
self.parse_failures += 1
if self.use_budget_guard:
action_str = apply_budget_safety_guard(action_str, obs, task_cfg)
self.messages.append({"role": "assistant", "content": action_str})
self.previous2_obs = self.previous_obs
self.previous_obs = obs
return action_str
def run_sft_episode(policy: SFTPolicy, task_name: str, task_cfg: TaskConfig, seed: int) -> dict[str, Any]:
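    """Roll out one SFT-policy episode; returns grader score, reward, actions taken, and the parse-failure count."""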
env = BudgetRouterEnv()
policy.reset()
obs = env.reset(seed=seed, scenario=task_cfg)
total_reward = 0.0
actions: list[str] = []
while not obs.done:
action_str = policy.choose_action(obs, task_name=task_name, task_cfg=task_cfg)
actions.append(action_str)
obs = env.step(Action(action_type=ActionType(action_str)))
total_reward += float(obs.reward or 0.0)
grader = grade_episode(env._internal.history)
metrics = episode_metrics(env._internal.history)
return {
"grader_score": float(grader["overall_score"]),
"total_reward": total_reward,
"episode_length": env._internal.current_step,
"grader": grader,
"metrics": metrics,
"actions": actions,
"parse_failures": policy.parse_failures,
}
def _mean(values: list[float]) -> float:
return float(sum(values) / len(values)) if values else 0.0
def _sample_std(values: list[float]) -> float:
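    """Bessel-corrected sample standard deviation; 0.0 when fewer than two values are given."""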
if len(values) < 2:
return 0.0
mean = _mean(values)
return float(math.sqrt(sum((v - mean) ** 2 for v in values) / (len(values) - 1)))
def compute_paired_stats(heuristic_scores: list[float], sft_scores: list[float]) -> dict[str, Any]:
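    """One-sided paired t-test (SFT > heuristic) over per-seed grader scores, with Cohen's d on the paired differences."""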
if len(heuristic_scores) != len(sft_scores):
raise ValueError("Paired stats require equal-length score lists.")
if not heuristic_scores:
raise ValueError("No scores provided.")
diffs = [s - h for h, s in zip(heuristic_scores, sft_scores)]
n = len(diffs)
delta = _mean(diffs)
std_diff = _sample_std(diffs)
if std_diff == 0.0:
t_stat = math.inf if delta > 0 else (-math.inf if delta < 0 else 0.0)
p_val = 0.0 if delta > 0 else 1.0
cohens_d = math.inf if delta > 0 else (-math.inf if delta < 0 else 0.0)
else:
try:
from scipy import stats
t_stat, p_val = stats.ttest_rel(sft_scores, heuristic_scores, alternative="greater")
cohens_d = delta / std_diff
except Exception:
t_stat = delta / (std_diff / math.sqrt(n))
p_val = float("nan")
cohens_d = delta / std_diff
return {
"n_seeds": n,
"mean_heuristic": _mean(heuristic_scores),
"mean_sft": _mean(sft_scores),
"std_heuristic": _sample_std(heuristic_scores),
"std_sft": _sample_std(sft_scores),
"delta": delta,
"t_stat": float(t_stat),
"p_val": float(p_val),
"cohens_d": float(cohens_d),
"significant": bool(delta > 0 and p_val < 0.05),
"wins": sum(1 for d in diffs if d > 0),
"ties": sum(1 for d in diffs if d == 0),
"losses": sum(1 for d in diffs if d < 0),
}
def _ci95(values: list[float]) -> tuple[float, float]:
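    """95% confidence interval for the mean via Student's t; falls back to a normal approximation (±1.96·se) without scipy."""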
n = len(values)
mean = _mean(values)
if n < 2:
return mean, mean
se = _sample_std(values) / math.sqrt(n)
try:
from scipy import stats
lo, hi = stats.t.interval(0.95, df=n - 1, loc=mean, scale=se)
return float(lo), float(hi)
except Exception:
return mean - 1.96 * se, mean + 1.96 * se
def _parse_seed_values(value: str | None, n_seeds: int) -> list[int]:
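    """Parse an explicit comma/space-separated seed list, else default to seeds 300..300+n_seeds-1."""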
if value:
return [int(part) for part in value.replace(",", " ").split()]
return list(range(300, 300 + n_seeds))
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Evaluate SFT Budget Router model.")
parser.add_argument("--model-repo", default=os.getenv("SFT_MODEL_REPO", DEFAULT_MODEL_REPO))
parser.add_argument("--task", default=os.getenv("TASK_NAME", "hard_multi"), choices=sorted(TASK_PRESETS))
parser.add_argument("--n-seeds", type=int, default=int(os.getenv("N_SEEDS", "10")))
parser.add_argument("--seed-values", default=os.getenv("EVAL_SEED_VALUES"))
parser.add_argument("--output-json", default=os.getenv("EVAL_OUTPUT_JSON", "eval_results_sft.json"))
parser.add_argument("--no-budget-guard", action="store_true")
parser.add_argument("--no-upload", action="store_true")
return parser.parse_args()
def main() -> None:
args = parse_args()
token = os.environ.get("HF_TOKEN")
task_cfg = TASK_PRESETS[args.task]
seeds = _parse_seed_values(args.seed_values, args.n_seeds)
# #region agent log
_agent_debug_ndjson(
{
"sessionId": "e4cac3",
"runId": os.environ.get("DEBUG_RUN_ID", "eval-import-fix"),
"hypothesisId": "H1",
"location": "eval_sft.py:main",
"message": "eval_startup",
"data": {
"system_prompt_source": _SYSTEM_PROMPT_SOURCE,
"model_repo": args.model_repo,
"task": args.task,
"n_seeds": len(seeds),
},
"timestamp": int(time.time() * 1000),
}
)
# #endregion
policy = SFTPolicy(args.model_repo, token=token, use_budget_guard=not args.no_budget_guard)
episodes: list[dict[str, Any]] = []
heuristic_scores: list[float] = []
sft_scores: list[float] = []
for seed in seeds:
heuristic_ep = run_heuristic_episode(task_cfg, seed)
sft_ep = run_sft_episode(policy, args.task, task_cfg, seed)
heuristic_scores.append(float(heuristic_ep["grader_score"]))
sft_scores.append(float(sft_ep["grader_score"]))
episodes.append({"seed": seed, "heuristic": heuristic_ep, "sft": sft_ep})
print(
f"[eval-sft] seed={seed} heuristic={heuristic_ep['grader_score']:.4f} "
f"sft={sft_ep['grader_score']:.4f} delta={sft_ep['grader_score'] - heuristic_ep['grader_score']:+.4f} "
f"parse_failures={sft_ep['parse_failures']}",
flush=True,
)
stats = compute_paired_stats(heuristic_scores, sft_scores)
heu_ci = _ci95(heuristic_scores)
sft_ci = _ci95(sft_scores)
result = {
**stats,
"task": args.task,
"seeds": seeds,
"heuristic_scores": heuristic_scores,
"sft_scores": sft_scores,
"heuristic_ci95": heu_ci,
"sft_ci95": sft_ci,
"budget_guard": not args.no_budget_guard,
"episodes": episodes,
}
Path(args.output_json).write_text(json.dumps(result, indent=2, sort_keys=True), encoding="utf-8")
print()
print("| Policy | Mean | Std | 95% CI | vs Heuristic |")
print("|---|---:|---:|---|---:|")
print(
f"| Heuristic | {stats['mean_heuristic']:.3f} | {stats['std_heuristic']:.3f} | "
f"[{heu_ci[0]:.3f}, {heu_ci[1]:.3f}] | baseline |"
)
print(
f"| SFT | {stats['mean_sft']:.3f} | {stats['std_sft']:.3f} | "
f"[{sft_ci[0]:.3f}, {sft_ci[1]:.3f}] | {stats['delta']:+.3f} |"
)
verdict = "SIGNIFICANT" if stats["significant"] else "NOT SIGNIFICANT"
print(
f"SFT: {stats['mean_sft']:.3f} vs Heuristic: {stats['mean_heuristic']:.3f} | "
f"delta={stats['delta']:+.3f} | t({stats['n_seeds'] - 1})={stats['t_stat']:.2f}, "
f"p={stats['p_val']:.4f} | {verdict} | Cohen's d={stats['cohens_d']:.2f} | "
f"wins/ties/losses={stats['wins']}/{stats['ties']}/{stats['losses']}"
)
if not args.no_upload:
if not token:
raise RuntimeError("HF_TOKEN must be set to upload eval JSON. Use --no-upload to skip.")
from huggingface_hub import upload_file
upload_file(
path_or_fileobj=args.output_json,
path_in_repo=Path(args.output_json).name,
repo_id=args.model_repo,
repo_type="model",
token=token,
)
print(f"[eval-sft] uploaded {args.output_json} to {args.model_repo}", flush=True)
if __name__ == "__main__":
main()