| """ | |
| Validation harness for the Budget Router environment. | |
| - run_validation(): runs all policies across all tasks and seed sets | |
| - run_manual_trace(): step-by-step debug trace | |
| - assert_all_checks(): hard assertions that must pass before submission | |
| - print_results_table(): formatted results display | |
| """ | |
| from __future__ import annotations | |
| import math | |
| import random | |
| from typing import Any, Callable, Dict, List, Optional, Tuple | |
| from .environment import BudgetRouterEnv | |
| from .models import Action, ActionType, InternalState, Observation, TaskConfig | |
| from .policies import ( | |
| always_route_a_policy, | |
| always_route_b_policy, | |
| always_route_c_policy, | |
| always_shed_load_policy, | |
| debug_upper_bound_policy, | |
| heuristic_baseline_policy, | |
| random_policy, | |
| ) | |
| from .reward import episode_metrics | |
| from .tasks import EASY, HARD, HARD_MULTI, MEDIUM, TASK_PRESETS | |
| # ─── Seed sets ────────────────────────────────────────────────────────── | |
| DEVELOPMENT_SEEDS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] | |
| HELDOUT_SEEDS = [100, 101, 102, 103, 104] | |
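# HELDOUT_SEEDS is disjoint from DEVELOPMENT_SEEDS, so the held-out sweep in
# assert_all_checks() measures robustness on seeds the heuristic was not tuned on.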
# ─── Episode runner ─────────────────────────────────────────────────────

def run_episode(
    env: BudgetRouterEnv,
    policy_fn: Callable,
    seed: int,
    scenario: TaskConfig,
    policy_name: str = "",
) -> Dict[str, Any]:
    """Run a single episode and return metrics."""
    obs = env.reset(seed=seed, scenario=scenario)
    # For random policy, seed a separate RNG
    policy_rng = random.Random(seed + 10000) if "random" in policy_name else None
    total_reward = 0.0
    steps = 0
    while not obs.done and steps < scenario.max_steps:
        # Select action based on policy
        if "upper_bound" in policy_name:
            action = policy_fn(obs, env._internal)
        elif "random" in policy_name:
            action = policy_fn(obs, rng=policy_rng)
        else:
            action = policy_fn(obs)
        obs = env.step(action)
        total_reward += (obs.reward or 0.0)
        steps += 1
    metrics = episode_metrics(env._internal.history)
    metrics["total_reward"] = round(total_reward, 4)
    metrics["episode_length"] = steps
    return metrics
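# Usage sketch (illustrative; all names are defined in this module/package):
#   env = BudgetRouterEnv()
#   m = run_episode(env, heuristic_baseline_policy, seed=0, scenario=EASY,
#                   policy_name="heuristic_baseline")
#   print(m["total_reward"], m["success_rate"])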
# ─── Validation runner ──────────────────────────────────────────────────

def run_validation(seed_set_name: str = "development") -> Dict[str, Dict[str, Dict[str, Any]]]:
    """
    Run all 7 policies on all 4 tasks for the given seed set.
    Returns:
        Nested dict: results[task_name][policy_name] = {
            "mean_reward", "std_reward", "min_reward", "max_reward",
            "success_rate", "average_cost", "average_latency",
            "all_rewards", "all_lengths"
        }
    """
    seeds = DEVELOPMENT_SEEDS if seed_set_name == "development" else HELDOUT_SEEDS
    policies = {
        "random": random_policy,
        "heuristic_baseline": heuristic_baseline_policy,
        "upper_bound": debug_upper_bound_policy,
        "always_route_a": always_route_a_policy,
        "always_route_b": always_route_b_policy,
        "always_route_c": always_route_c_policy,
        "always_shed_load": always_shed_load_policy,
    }
    tasks = {"easy": EASY, "medium": MEDIUM, "hard": HARD, "hard_multi": HARD_MULTI}
    results: Dict[str, Dict[str, Dict[str, Any]]] = {}
    env = BudgetRouterEnv()
    for task_name, task_config in tasks.items():
        results[task_name] = {}
        for policy_name, policy_fn in policies.items():
            all_rewards = []
            all_success_rates = []
            all_costs = []
            all_latencies = []
            all_lengths = []
            for seed in seeds:
                metrics = run_episode(
                    env, policy_fn, seed, task_config, policy_name=policy_name
                )
                all_rewards.append(metrics["total_reward"])
                all_success_rates.append(metrics["success_rate"])
                all_costs.append(metrics["total_cost_spent"])
                all_latencies.append(metrics["average_latency_ms"])
                all_lengths.append(metrics["episode_length"])
            mean_r = sum(all_rewards) / len(all_rewards)
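            # Population standard deviation (divide by N, not N-1)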
            std_r = (
                sum((r - mean_r) ** 2 for r in all_rewards) / len(all_rewards)
            ) ** 0.5
            results[task_name][policy_name] = {
                "mean_reward": round(mean_r, 4),
                "std_reward": round(std_r, 4),
                "min_reward": round(min(all_rewards), 4),
                "max_reward": round(max(all_rewards), 4),
                "success_rate": round(
                    sum(all_success_rates) / len(all_success_rates), 4
                ),
                "average_cost": round(sum(all_costs) / len(all_costs), 4),
                "average_latency": round(
                    sum(all_latencies) / len(all_latencies), 2
                ),
                "all_rewards": all_rewards,
                "all_lengths": all_lengths,
            }
    return results
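# Usage sketch:
#   dev = run_validation("development")
#   print(dev["easy"]["heuristic_baseline"]["mean_reward"])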
# ─── Results printer ────────────────────────────────────────────────────

def print_results_table(results: Dict, seed_set_name: str = "development") -> None:
    """Print formatted results table."""
    print(f"\n{'='*90}")
    print(f" VALIDATION RESULTS — {seed_set_name.upper()} SEEDS")
    print(f"{'='*90}")
    for task_name, policies in results.items():
        print(f"\n Task: {task_name.upper()}")
        print(f" {'Policy':<20} {'Mean':>8} {'Std':>8} {'Min':>8} {'Max':>8} {'SucRate':>8} {'Cost':>8} {'Lat(ms)':>8}")
        print(f" {'-'*76}")
        for policy_name, stats in policies.items():
            print(
                f" {policy_name:<20} "
                f"{stats['mean_reward']:>8.2f} "
                f"{stats['std_reward']:>8.2f} "
                f"{stats['min_reward']:>8.2f} "
                f"{stats['max_reward']:>8.2f} "
                f"{stats['success_rate']:>8.2f} "
                f"{stats['average_cost']:>8.4f} "
                f"{stats['average_latency']:>8.1f}"
            )
    print(f"\n{'='*90}")
# ─── Manual Trace ───────────────────────────────────────────────────────

def run_manual_trace(
    seed: int = 42,
    scenario_name: str = "medium",
    policy_fn: Optional[Callable] = None,
    policy_name: str = "heuristic_baseline",
) -> None:
    """
    Run a single episode with step-by-step trace in raw internal units.
    PRIMARY debugging tool.
    """
    scenario = TASK_PRESETS[scenario_name]
    policy = policy_fn or heuristic_baseline_policy
    env = BudgetRouterEnv()
    obs = env.reset(seed=seed, scenario=scenario)
    policy_rng = random.Random(seed + 10000)
    print(f"\n{'─'*95}")
    print(f" MANUAL TRACE — Scenario: {scenario_name.upper()}, Seed: {seed}, Policy: {policy_name}")
    print(f"{'─'*95}")
    print(
        f" {'Step':>4} | {'Action':<10} | {'A_health':>8} | {'B_health':>8} | {'C_health':>8} | "
        f"{'Latency':>8} | {'Budget$':>8} | {'Reward':>7} | {'Cumul':>7}"
    )
    print(f" {'─'*91}")
    cumulative = 0.0
    steps = 0
    while not obs.done and steps < scenario.max_steps:
        if "upper_bound" in policy_name:
            action = policy(obs, env._internal)
        elif "random" in policy_name:
            action = policy(obs, rng=policy_rng)
        else:
            action = policy(obs)
        obs = env.step(action)
        steps += 1
        reward = obs.reward or 0.0
        cumulative += reward
        # Read raw internal state for trace
        s = env._internal
        a_health = s.providers["A"].current_health
        b_health = s.providers["B"].current_health
        c_health = s.providers["C"].current_health
        latency_ms = s.last_latency_ms
        budget = s.budget_dollars
        print(
            f" {steps:>4} | {action.action_type.value:<10} | "
            f"{a_health:>8.3f} | {b_health:>8.3f} | {c_health:>8.3f} | "
            f"{latency_ms:>6.0f}ms | ${budget:>7.2f} | "
            f"{reward:>+7.2f} | {cumulative:>+7.2f}"
        )
    print(f" {'─'*91}")
    metrics = episode_metrics(env._internal.history)
    print(
        f" EPISODE END | "
        f"success_rate={metrics['success_rate']:.2f} | "
        f"total_cost=${metrics['total_cost_spent']:.4f} | "
        f"sla_met={metrics['sla_met']} | "
        f"total_reward={cumulative:.2f}"
    )
    print(f"{'─'*95}\n")
# ─── Hard Assertions ────────────────────────────────────────────────────

def assert_all_checks(
    dev_results: Dict[str, Dict[str, Dict[str, Any]]],
    heldout_results: Dict[str, Dict[str, Dict[str, Any]]],
) -> None:
    """
    Run all hard assertions. All must pass before submission.
    If any fails, fix the environment — do not weaken the assertion.
    """
    print("\n" + "=" * 60)
    print(" RUNNING HARD ASSERTION CHECKS")
    print("=" * 60)
    passed = 0
    failed = 0
    total = 0

    def check(condition: bool, msg: str) -> None:
        nonlocal passed, failed, total
        total += 1
        if condition:
            passed += 1
            print(f" ✅ PASS: {msg}")
        else:
            failed += 1
            print(f" ❌ FAIL: {msg}")

    # ── Policy ordering (BOTH seed sets, ALL tasks) ──
    # Note: hard_multi baseline > random only required on dev seeds —
    # heldout random can occasionally beat the deterministic heuristic on hard_multi
    for seed_set_name, results in [("dev", dev_results), ("heldout", heldout_results)]:
        for task in ["easy", "medium", "hard"]:
            baseline_mean = results[task]["heuristic_baseline"]["mean_reward"]
            random_mean = results[task]["random"]["mean_reward"]
            upper_bound_mean = results[task]["upper_bound"]["mean_reward"]
            check(
                baseline_mean > random_mean,
                f"[{seed_set_name}/{task}] baseline ({baseline_mean:.2f}) > random ({random_mean:.2f})",
            )
            check(
                upper_bound_mean >= baseline_mean,
                f"[{seed_set_name}/{task}] oracle ({upper_bound_mean:.2f}) >= baseline ({baseline_mean:.2f})",
            )
        # hard_multi: only check oracle >= baseline (heuristic fails by design)
        hm_baseline = results["hard_multi"]["heuristic_baseline"]["mean_reward"]
        hm_oracle = results["hard_multi"]["upper_bound"]["mean_reward"]
        check(
            hm_oracle >= hm_baseline,
            f"[{seed_set_name}/hard_multi] oracle ({hm_oracle:.2f}) >= baseline ({hm_baseline:.2f})",
        )

    # ── Non-triviality ──
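    # Relative gap when the random mean is nonzero, absolute gap otherwise.
    # Worked example: baseline=12.0, random=10.0 -> gap = 2.0 / 10.0 = 0.20 (not > 0.20).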
    found_nontrivial = False
    for task in ["easy", "medium", "hard", "hard_multi"]:
        baseline_mean = dev_results[task]["heuristic_baseline"]["mean_reward"]
        random_mean = dev_results[task]["random"]["mean_reward"]
        if abs(random_mean) > 0:
            gap = (baseline_mean - random_mean) / abs(random_mean)
        else:
            gap = abs(baseline_mean - random_mean)
        if gap > 0.20:
            found_nontrivial = True
            break
    check(found_nontrivial, "At least one task has >20% gap between baseline and random")
    # ── Solvability ──
    easy_ub_reward = dev_results["easy"]["upper_bound"]["mean_reward"]
    easy_ub_sr = dev_results["easy"]["upper_bound"]["success_rate"]
    check(easy_ub_reward > 0, f"Oracle positive reward on easy ({easy_ub_reward:.2f})")
    check(easy_ub_sr > 0.5, f"Oracle success rate on easy ({easy_ub_sr:.2f}) > 0.5")

    # ── Anti-gaming checks (hard_multi excluded — heuristic fails by design) ──
    for task in ["easy", "medium", "hard"]:
        baseline_mean = dev_results[task]["heuristic_baseline"]["mean_reward"]
        always_a_mean = dev_results[task]["always_route_a"]["mean_reward"]
        always_b_mean = dev_results[task]["always_route_b"]["mean_reward"]
        always_shed_mean = dev_results[task]["always_shed_load"]["mean_reward"]
        check(
            baseline_mean >= always_a_mean,
            f"[dev/{task}] baseline ({baseline_mean:.2f}) >= always_a ({always_a_mean:.2f})",
        )
        check(
            baseline_mean >= always_b_mean,
            f"[dev/{task}] baseline ({baseline_mean:.2f}) >= always_b ({always_b_mean:.2f})",
        )
        check(
            baseline_mean >= always_shed_mean,
            f"[dev/{task}] baseline ({baseline_mean:.2f}) >= always_shed ({always_shed_mean:.2f})",
        )

    # Check that NOT all degenerate policies dominate baseline
    for task in ["easy", "medium", "hard", "hard_multi"]:
        baseline_mean = dev_results[task]["heuristic_baseline"]["mean_reward"]
        always_a = dev_results[task]["always_route_a"]["mean_reward"]
        always_b = dev_results[task]["always_route_b"]["mean_reward"]
        always_c = dev_results[task]["always_route_c"]["mean_reward"]
        always_shed = dev_results[task]["always_shed_load"]["mean_reward"]
        check(
            not (
                always_a >= baseline_mean
                and always_b >= baseline_mean
                and always_c >= baseline_mean
                and always_shed >= baseline_mean
            ),
            f"[dev/{task}] heuristic provides strategic advantage over degenerate policies",
        )

    # ── Held-out robustness ──
    for task in ["easy", "medium", "hard", "hard_multi"]:
        baseline_dev = dev_results[task]["heuristic_baseline"]["mean_reward"]
        baseline_heldout = heldout_results[task]["heuristic_baseline"]["mean_reward"]
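        # Stability margin: at least 2.0 reward units, or 40% of the dev-seed mean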
        margin = max(2.0, 0.40 * abs(baseline_dev))
        check(
            abs(baseline_heldout - baseline_dev) <= margin,
            f"[{task}] baseline stable: dev={baseline_dev:.2f}, heldout={baseline_heldout:.2f}, margin={margin:.2f}",
        )
    # ── Safety: NaN rewards and runaway episode lengths ──
    all_rewards = []
    all_lengths = []
    for seed_set_name, results in [("dev", dev_results), ("heldout", heldout_results)]:
        for task in ["easy", "medium", "hard", "hard_multi"]:
            for policy_name, stats in results[task].items():
                all_rewards.extend(stats["all_rewards"])
                all_lengths.extend(stats["all_lengths"])
    check(
        all(not math.isnan(r) for r in all_rewards),
        f"No NaN rewards across {len(all_rewards)} episodes",
    )
    check(
        all(ep_len <= 20 for ep_len in all_lengths),
        f"No episode exceeds 20 steps (max seen: {max(all_lengths) if all_lengths else 0})",
    )
    # ── Summary ──
    print(f"\n{'='*60}")
    print(f" RESULTS: {passed}/{total} passed, {failed}/{total} failed")
    print(f"{'='*60}")
    if failed > 0:
        print(f"\n ⚠️ {failed} assertion(s) FAILED. Fix the environment before submission.")
    else:
        print("\n 🎉 All assertions passed! Environment is ready for submission.")
# ─── Main entry point ───────────────────────────────────────────────────

def main() -> None:
    """Run full validation suite."""
    # Run both seed sets
    print("Running validation on DEVELOPMENT seeds...")
    dev_results = run_validation("development")
    print_results_table(dev_results, "development")
    print("\nRunning validation on HELD-OUT seeds...")
    heldout_results = run_validation("heldout")
    print_results_table(heldout_results, "heldout")
    # Manual trace
    run_manual_trace(seed=42, scenario_name="medium")
    run_manual_trace(seed=42, scenario_name="hard_multi")
    # Hard assertions
    assert_all_checks(dev_results, heldout_results)


if __name__ == "__main__":
    main()