import json
import os
from pathlib import Path
from typing import Any, Dict, Iterable, List, Literal

import typer
from openai import OpenAI

from budget_router.environment import BudgetRouterEnv
from budget_router.models import Action, ActionType, Observation, TaskConfig
from budget_router.policies import heuristic_baseline_policy
from budget_router.reward import episode_metrics, grade_episode
from budget_router.tasks import EASY, HARD, HARD_MULTI, MEDIUM

_VALID_ACTIONS = ["route_to_a", "route_to_b", "route_to_c", "shed_load"]


def _parse_llm_action(response_text: str) -> str:
    """Extract a valid action from LLM output. Falls back to shed_load — never raises."""
    text = response_text.strip().lower()
    for action in _VALID_ACTIONS:
        if action in text:
            return action
    return "shed_load"  # safe fallback: always valid, never crashes


SYSTEM_PROMPT = """
You are a cost-aware LLM API routing agent managing a production system.
At each step, output EXACTLY ONE action string. Nothing else.

ENVIRONMENT:
  Three providers: A ($0.01/req, cheapest), B ($0.05/req), C ($0.10/req, most reliable).
  provider_X_status = windowed success rate [0=always fails, 1=always succeeds].
    IMPORTANT: A status of exactly 0.500 means this provider has NEVER been routed to
    in this episode — it is unobserved, not confirmed healthy. Route to it once to get
    a real reading. Do not treat 0.500 as a health signal.
  budget_remaining: fraction of budget left. Reaching 0 = catastrophic -10 penalty.
  step_count [0→1], steps_remaining: episode progress (20 steps total).

VALID ACTIONS (output ONLY one):
  route_to_a | route_to_b | route_to_c | shed_load

GOLDEN RULE — DEFAULT STRATEGY:
  Stay on the CHEAPEST provider whose status > 0.52. Only deviate if there is CLEAR, SUSTAINED evidence of degradation (defined below). Unnecessary switching to expensive providers burns budget and reduces your score.

NOISE CALIBRATION (critical):
- Status fluctuates due to Bernoulli sampling noise. Single-step dips are not reliable signals.
- Use the provided 2-step trend (avg/step): a sustained negative trend across multiple steps
indicates real degradation; a trend near 0 means the provider is stable. Do NOT switch on noise.
- REAL degradation signal: sustained negative trend AND current status is visibly declining.
- Only when both conditions hold across consecutive observations should you consider early switching.
- On stable tasks, trends hover near zero. Switching on noise burns budget without benefit.


WHEN TO SWITCH (use your conversation history):
A → B: When trend_a is clearly and consistently negative AND status_a is approaching unreliable,
           OR status_a is already below 0.52 (failure probability exceeds success probability).
B → C: Same principle — sustained decline signals, not single-step noise.
Never switch based on a single bad observation — noise causes occasional dips.

BUDGET RUNWAY — HARD CONSTRAINT:
budget_runway_at_current_rate shows how many more steps you can afford at current spend rate.
If budget_runway_at_current_rate < steps_remaining: switch to a cheaper provider IMMEDIATELY.
If budget_remaining < 0.15 (less than 15% left): treat C as OFF-LIMITS unless A and B are
  both below 0.30 status. Prefer shed_load over routing C when budget is this low.
NEVER route to any provider if doing so would leave budget_remaining below the cost of
that provider times the steps_remaining. The -10 bankruptcy penalty destroys all episode
value accumulated so far; budget survival is non-negotiable.

TASK PROFILES (the task name appears in each observation; use it):
  easy:       Stable environment. Trend fluctuations are mostly noise. Stay on the cheapest provider unless its trend is severely and persistently negative.
  medium:     Dynamic environment. A provider may degrade mid-episode. Monitor trends and switch to the next cheapest healthy fallback if the primary fails.
  hard / hard_multi: Hostile, multi-failure environments. Multiple providers may degrade at unexpected times in unpredictable sequences.
              Your Runbook: Always map traffic to the lowest-cost healthy provider (A=$0.01, B=$0.05, C=$0.10).
              Watch your conversation history: if your currently active provider shows a clear, sustained negative trend, switch early to the next cheapest option that is healthy.
              CRITICAL: Before switching to expensive fallbacks (like C), use budget_runway to verify you can afford them to prevent budget exhaustion.

Output only the action string."""

OBJECTIVE_FEEDBACK_PROMPT = """
You are a budget-aware API routing agent managing a production system.
At each step, output EXACTLY ONE action string and nothing else.

VALID ACTIONS:
route_to_a | route_to_b | route_to_c | shed_load

ENVIRONMENT INTUITION:
You route one request per step across three providers:
- A: cheapest ($0.01), useful for conserving budget, but may degrade.
- B: medium cost ($0.05), usually the bridge provider.
- C: expensive ($0.10), most reliable, but using it too early can bankrupt the episode.
- shed_load: reject this request; use only when routing is likely worse than abstaining or budget is critically unsafe.

OBSERVATIONS:
provider_X_status is a rolling recent success estimate, not true health.
- Exactly 0.500 means unobserved in this episode: no evidence yet.
- After one probe, the status may jump to 0.000 or 1.000; that is weak evidence because it may be only one sample.
- Repeated outcomes, sustained negative trend, worsening reward, or repeated failures are stronger evidence.
Do not treat a single success as proof of health or a single failure as proof of degradation.

PRIMARY OBJECTIVE:
Maximize full-episode grader score:
- successful routed requests
- low latency and SLA health
- budget preservation
- adaptation after provider degradation

DECISION POLICY:
1. Budget survival is mandatory. Avoid the -10 budget exhaustion cliff.
2. Prefer the cheapest provider that is plausibly healthy, but do not blindly follow a fixed threshold.
3. Probe unknown providers when information is valuable and affordable.
4. Switch away from a provider when there is repeated failure, sustained decline, or clearly bad status.
5. Use C as late-phase reliability capacity, not as the default.
6. Prefer shed_load when all available routed choices look likely to fail or when routing would risk budget exhaustion.

TASK-SPECIFIC STRATEGY:
easy:
- Stable task. Prefer cheap routing. Do not overreact to noise.

medium:
- A may degrade after early steps. Start cheap, then move to B when A shows repeated weakness.

hard:
- A may degrade from the beginning. Probe cheaply, but react faster to repeated A failures or sustained decline.

hard_multi:
- This is a sequential cascade: A degrades early; B can degrade later.
- Early phase: conserve budget. Use/probe A only while evidence supports it.
- Middle phase: B is often the bridge provider after A weakens.
- Late phase: preserve enough runway for C if B begins failing.
- Do not burn C too early; do not stay on A/B after repeated failures.

BUDGET RUNWAY:
If budget_runway_at_current_rate < steps_remaining, current spending is unsafe.
If budget_remaining < 0.15, treat C as emergency-only unless A and B both look very poor.
If routing risks budget exhaustion, shed_load is better than bankruptcy.

Use previous_step_feedback when present:
- previous_success=false, negative reward, high latency, or repeated same-provider failures are evidence to update your belief.
- previous_success=true is useful evidence but not proof, especially after only one sample.

Output only one valid action string.
"""

app = typer.Typer(add_completion=False)

API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")

API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
LLM_TIMEOUT_SECONDS = float(os.getenv("LLM_TIMEOUT_SECONDS") or "25")
LLM_MAX_RETRIES = int(os.getenv("LLM_MAX_RETRIES") or "1")
LLM_LOG_RAW = (os.getenv("LLM_LOG_RAW") or "").strip().lower() in {"1", "true", "yes", "y", "on"}
LLM_LOG_RAW_MAX_CHARS = int(os.getenv("LLM_LOG_RAW_MAX_CHARS") or "220")
LLM_POLICY_MODE = (os.getenv("LLM_POLICY_MODE") or "baseline").strip().lower()
BENCHMARK_NAME = os.getenv("BENCHMARK_NAME") or "budget_router"

SEED_SETS: Dict[str, List[int]] = {
    "development": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    "heldout": [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
}
TASKS: Dict[str, TaskConfig] = {
    "easy": EASY,
    "medium": MEDIUM,
    "hard": HARD,
    "hard_multi": HARD_MULTI,
}
VALID_ACTIONS = [action.value for action in ActionType]


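class LLMRouter:
    """Stateful LLM routing policy that keeps per-episode chat history and the two most recent observations."""
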
    def __init__(
        self,
        api_base_url: str,
        model_name: str,
        api_key: str,
        prompt_mode: str | None = None,
    ) -> None:
        self._client = OpenAI(
            base_url=api_base_url,
            api_key=api_key,
            timeout=LLM_TIMEOUT_SECONDS,
            max_retries=LLM_MAX_RETRIES,
        )
        self._model_name = model_name
        self._prompt_mode = (prompt_mode or LLM_POLICY_MODE or "baseline").strip().lower()
        self._messages: List[Dict[str, str]] = []
        self.last_error: str | None = None
        self.last_raw_output: str | None = None
        self.last_parsed_action: str | None = None
        self._prev_obs: dict | None = None
        self._prev2_obs: dict | None = None
        self._task_name: str = ""
        self.reset()

    def reset(self, task_name: str = "") -> None:
        prompt = OBJECTIVE_FEEDBACK_PROMPT if self._prompt_mode == "objective_feedback" else SYSTEM_PROMPT
        self._messages = [{"role": "system", "content": prompt}]
        self.last_error = None
        self.last_raw_output = None
        self.last_parsed_action = None
        self._prev_obs = None
        self._prev2_obs = None
        self._task_name = task_name

    def choose_action(self, observation: Observation) -> Action:
        obs = observation
        if not self._messages:
            self.reset(task_name=self._task_name)
        elif obs.step_count == 0.0 and len(self._messages) > 1:
            self.reset(task_name=self._task_name)

        # ── Compute 2-step trend (more noise-robust than single-step delta) ──
        trend_text = ""
        budget_runway_text = ""

        if self._prev2_obs is not None:
            # Average per-step change over 2 steps; this damps single-step sampling noise
            ta = (obs.provider_a_status - self._prev2_obs["a"]) / 2.0
            tb = (obs.provider_b_status - self._prev2_obs["b"]) / 2.0
            tc = (obs.provider_c_status - self._prev2_obs["c"]) / 2.0
            trend_text = f"\ntrend (avg/step, 2-step):  A:{ta:+.3f}  B:{tb:+.3f}  C:{tc:+.3f}"
        elif self._prev_obs is not None:
            # Only one earlier observation so far: use a single-step delta and label it as noisier
            ta = obs.provider_a_status - self._prev_obs["a"]
            tb = obs.provider_b_status - self._prev_obs["b"]
            tc = obs.provider_c_status - self._prev_obs["c"]
            trend_text = f"\ntrend (1-step only, noisy): A:{ta:+.3f}  B:{tb:+.3f}  C:{tc:+.3f}"

        if self._prev_obs is not None:
            budget_spent = self._prev_obs["budget"] - obs.budget_remaining
            if budget_spent > 0.001:
                runway = int(obs.budget_remaining / budget_spent)
                budget_runway_text = f"\nbudget_runway_at_current_rate: ~{runway} steps"
            else:
                budget_runway_text = "\nbudget_runway_at_current_rate: >20 steps"

        steps_total = 20
        steps_remaining = max(1, steps_total - int(round(obs.step_count * steps_total)))

        task_line = f"\ntask: {self._task_name}" if self._task_name else ""
        obs_text = "\n".join([
            f"provider_a_status: {obs.provider_a_status:.3f}",
            f"provider_b_status: {obs.provider_b_status:.3f}",
            f"provider_c_status: {obs.provider_c_status:.3f}",
            f"budget_remaining:  {obs.budget_remaining:.3f}",
            f"queue_backlog:     {obs.queue_backlog:.3f}",
            f"system_latency:    {obs.system_latency:.3f}",
            f"step_count:        {obs.step_count:.3f}",
            f"steps_remaining:   {steps_remaining}",
        ])
        obs_text += trend_text + budget_runway_text + task_line
        if self._prompt_mode == "objective_feedback":
            feedback_lines = self._previous_step_feedback(observation=obs)
            if feedback_lines:
                obs_text += "\n" + feedback_lines
        user_prompt = f"Current observation:\n{obs_text}\n\nYour action:"

        # Shift history: prev becomes prev2, current becomes prev
        self._prev2_obs = self._prev_obs
        self._prev_obs = {
            "a": obs.provider_a_status,
            "b": obs.provider_b_status,
            "c": obs.provider_c_status,
            "budget": obs.budget_remaining,
        }

        client = self._client
        model_name = self._model_name
        self._messages.append({"role": "user", "content": user_prompt})
        try:
            response = client.with_options(timeout=LLM_TIMEOUT_SECONDS).chat.completions.create(
                model=model_name,
                messages=self._messages,
                max_tokens=30,
                temperature=0,
            )
            raw = response.choices[0].message.content or ""
            action_str = _parse_llm_action(raw)
            action_str = self._apply_budget_safety_guard(action_str=action_str, observation=obs)
            self.last_raw_output = raw
            self.last_parsed_action = action_str
            self.last_error = None
        except Exception as e:
            self.last_error = str(e)
            action_str = "shed_load"
            self.last_raw_output = None
            self.last_parsed_action = action_str
        self._messages.append({"role": "assistant", "content": action_str})
        return Action(action_type=ActionType(action_str))

    def _apply_budget_safety_guard(self, action_str: str, observation: Observation) -> str:
        """Prevent only actions that would immediately exhaust the public remaining budget."""
        if action_str == "shed_load":
            return action_str

        scenario = TASKS.get(self._task_name)
        if scenario is None:
            return action_str

        action_costs = {
            "route_to_a": scenario.cost_a,
            "route_to_b": scenario.cost_b,
            "route_to_c": scenario.cost_c,
        }
        selected_cost = action_costs.get(action_str, 0.0)
        budget_dollars = float(observation.budget_remaining) * float(scenario.initial_budget)

        if selected_cost >= budget_dollars - 1e-9:
            return "shed_load"
        return action_str

    def _previous_step_feedback(self, observation: Observation) -> str:
        metadata = getattr(observation, "metadata", None) or {}
        if not metadata:
            return ""

        previous_action = metadata.get("action_type")
        if not previous_action:
            return ""

        reward = observation.reward
        latency = metadata.get("latency_ms")
        cost = metadata.get("cost")
        succeeded = metadata.get("request_succeeded")
        budget_exhausted = metadata.get("budget_exhausted", False)

        feedback_parts = [
            "previous_step_feedback:",
            f"  previous_action: {previous_action}",
        ]
        if reward is not None:
            feedback_parts.append(f"  previous_reward: {float(reward):+.2f}")
        if succeeded is not None:
            feedback_parts.append(f"  previous_success: {str(bool(succeeded)).lower()}")
        if cost is not None:
            feedback_parts.append(f"  previous_cost: {float(cost):.2f}")
        if latency is not None:
            feedback_parts.append(f"  previous_latency_ms: {float(latency):.2f}")
        if budget_exhausted:
            feedback_parts.append("  previous_budget_exhausted: true")
        return "\n".join(feedback_parts)


def _single_line(value: str | None) -> str:
    if not value:
        return "null"
    return str(value).replace("\n", " ").replace("\r", " ")


def _truncate_and_sanitize(value: str | None, max_chars: int) -> str:
    if not value:
        return "null"
    s = _single_line(value).strip()
    if len(s) <= max_chars:
        return s
    return s[: max(0, max_chars - 3)] + "..."


def _observation_payload(observation: Observation) -> Dict[str, float]:
    return {
        "provider_a_status": float(observation.provider_a_status),
        "provider_b_status": float(observation.provider_b_status),
        "provider_c_status": float(observation.provider_c_status),
        "budget_remaining": float(observation.budget_remaining),
        "queue_backlog": float(observation.queue_backlog),
        "system_latency": float(observation.system_latency),
        "step_count": float(observation.step_count),
    }


def _reported_score(value: float) -> float:
    """Clamp a score into [0.001, 0.999] for log output."""
    return min(max(float(value), 0.001), 0.999)


def log_start(task: str, env: str, model: str) -> None:
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(
    step: int,
    action: str,
    reward: float,
    done: bool,
    error: str | None,
    llm_raw: str | None = None,
    llm_parsed: str | None = None,
) -> None:
    base = (
        f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} "
        f"error={_single_line(error)}"
    )
    if LLM_LOG_RAW:
        raw_s = _truncate_and_sanitize(llm_raw, max_chars=max(20, LLM_LOG_RAW_MAX_CHARS))
        parsed_s = _single_line(llm_parsed)
        base += f" llm_raw={raw_s} llm_parsed={parsed_s}"
    print(base, flush=True)


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    rewards_str = ",".join(f"{reward:.2f}" for reward in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} score={_reported_score(score):.3f} rewards={rewards_str}",
        flush=True,
    )


def select_policy(policy_name: Literal["heuristic", "llm"]) -> object:
    if policy_name == "heuristic":
        return heuristic_baseline_policy

    if not API_KEY or not API_BASE_URL:
        raise RuntimeError(
            "LLM policy requires an API key: set HF_TOKEN or API_KEY. "
            "API_BASE_URL and MODEL_NAME are read from the environment and fall back to defaults."
        )
    return LLMRouter(api_base_url=API_BASE_URL, model_name=MODEL_NAME, api_key=API_KEY)


def choose_action(policy_name: Literal["heuristic", "llm"], policy: object, observation: Observation) -> Action:
    if policy_name == "heuristic":
        return policy(observation)
    return policy.choose_action(observation)


def run_episode(
    env: BudgetRouterEnv,
    scenario: TaskConfig,
    seed: int,
    episode: int,
    policy_name: Literal["heuristic", "llm"],
    policy: object,
) -> Dict[str, Any]:
    total_reward = 0.0
    grader_score: float | None = None
    rewards: List[float] = []
    steps_taken = 0
    success = False

    if policy_name == "llm":
        policy.reset(task_name=scenario.name)

    log_start(task=scenario.name, env=BENCHMARK_NAME, model=MODEL_NAME)

    try:
        observation = env.reset(seed=seed, scenario=scenario)
        while not observation.done:
            action = choose_action(policy_name=policy_name, policy=policy, observation=observation)
            action_name = action.action_type.value
            observation = env.step(action)
            reward = float(observation.reward or 0.0)
            total_reward += reward
            rewards.append(reward)
            steps_taken = env._internal.current_step
            step_error = getattr(policy, "last_error", None) if policy_name == "llm" else None
            llm_raw = getattr(policy, "last_raw_output", None) if policy_name == "llm" else None
            llm_parsed = getattr(policy, "last_parsed_action", None) if policy_name == "llm" else None
            log_step(
                step=env._internal.current_step,
                action=action_name,
                reward=reward,
                done=bool(observation.done),
                error=step_error,
                llm_raw=llm_raw,
                llm_parsed=llm_parsed,
            )

        metrics = episode_metrics(env._internal.history)
        metrics["seed"] = seed
        metrics["episode"] = episode
        metrics["total_reward"] = round(total_reward, 4)
        metrics["episode_length"] = env._internal.current_step
        grader = grade_episode(env._internal.history)
        grader_score = float(grader["overall_score"])
        success = grader_score > 0.0
        metrics["grader_score"] = grader_score
        metrics["grader_breakdown"] = grader
        return metrics
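    finally:
        # Grade before closing so the fallback score can still read the episode history.
        if grader_score is None:
            try:
                grader_score = float(grade_episode(env._internal.history)["overall_score"])
            except Exception:
                # Do not mask the original error if the episode never produced a gradable history.
                grader_score = 0.0
            success = grader_score > 0.0
        close_fn = getattr(env, "close", None)
        if callable(close_fn):
            close_fn()
        log_end(success=success, steps=steps_taken, score=grader_score, rewards=rewards)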


def summarize(metrics: Iterable[Dict[str, float]]) -> Dict[str, float]:
    rows = list(metrics)
    return {
        "mean_reward": round(sum(row["total_reward"] for row in rows) / len(rows), 4),
        "mean_success_rate": round(sum(row["success_rate"] for row in rows) / len(rows), 4),
        "mean_cost": round(sum(row["total_cost_spent"] for row in rows) / len(rows), 4),
        "mean_latency_ms": round(sum(row["average_latency_ms"] for row in rows) / len(rows), 2),
        "mean_grader_score": round(sum(row["grader_score"] for row in rows) / len(rows), 4),
    }


@app.command()
def main(
    policy: Literal["heuristic", "llm"] = typer.Option("llm" if API_KEY and API_BASE_URL else "heuristic"),
    seed_set: Literal["development", "heldout"] = typer.Option("development"),
    scenario: Literal["all", "easy", "medium", "hard", "hard_multi"] = typer.Option("all"),
    max_seeds: int = typer.Option(1),
    output_path: Path = typer.Option(Path("baseline_results.json")),
) -> None:
    selected_policy = select_policy(policy)
    selected_tasks = TASKS if scenario == "all" else {scenario: TASKS[scenario]}
    selected_seeds = SEED_SETS[seed_set][: max(1, max_seeds)]
    results: Dict[str, Dict[str, object]] = {}
    episode = 1

    for task_name, task_config in selected_tasks.items():
        task_metrics = []
        for seed in selected_seeds:
            env = BudgetRouterEnv()
            task_metrics.append(
                run_episode(
                    env=env,
                    scenario=task_config,
                    seed=seed,
                    episode=episode,
                    policy_name=policy,
                    policy=selected_policy,
                )
            )
            episode += 1
        results[task_name] = {
            "policy": policy,
            "seed_set": seed_set,
            "summary": summarize(task_metrics),
            "episodes": task_metrics,
        }

    output_path.write_text(json.dumps(results, indent=2) + "\n", encoding="utf-8")


if __name__ == "__main__":
    app()