Spaces:
Sleeping
Sleeping
| """ | |
| Tests for the validation harness. | |
| Covers: policy ordering, solvability, NaN safety, baseline stability, | |
| and hard task crash resistance. | |
| """ | |
import math
import random
import statistics

import pytest

from budget_router.environment import BudgetRouterEnv
from budget_router.models import Action, ActionType
from budget_router.policies import (
    always_route_a_policy,
    always_route_b_policy,
    always_route_c_policy,
    always_shed_load_policy,
    debug_upper_bound_policy,
    heuristic_baseline_policy,
    random_policy,
)
from budget_router.tasks import EASY, HARD, MEDIUM
from budget_router.validation import DEVELOPMENT_SEEDS, HELDOUT_SEEDS, run_episode
# ─── Helpers ────────────────────────────────────────────────────────────
def mean_reward_over_seeds(policy_fn, scenario, seeds, policy_name=""):
    """Run ``policy_fn`` once per seed and summarize the total rewards.

    A single environment instance is reused across all seeds (seeding is
    handled inside ``run_episode``).

    Args:
        policy_fn: Policy callable, forwarded to ``run_episode``.
        scenario: Task definition (e.g. ``EASY``, ``MEDIUM``, ``HARD``).
        seeds: Iterable of integer RNG seeds; must be non-empty.
        policy_name: Optional label forwarded to ``run_episode``.

    Returns:
        Tuple ``(mean_reward, rewards)`` where ``rewards`` is the list of
        per-seed ``total_reward`` values in seed order.

    Raises:
        ValueError: If ``seeds`` is empty (previously surfaced as an
            unhelpful ``ZeroDivisionError`` from the mean computation).
    """
    env = BudgetRouterEnv()
    rewards = [
        run_episode(env, policy_fn, seed, scenario, policy_name=policy_name)[
            "total_reward"
        ]
        for seed in seeds
    ]
    if not rewards:
        raise ValueError("seeds must be non-empty to compute a mean reward")
    # fmean: numerically robust float mean (avoids sum/len rounding drift).
    return statistics.fmean(rewards), rewards
# ─── Validation Tests ──────────────────────────────────────────────────
class TestValidation:
    """Validation-level checks for tasks, policies, and seed sets.

    Covers policy ordering (baseline vs. random / oracle / trivial policies),
    task solvability, NaN safety across all combinations, crash resistance on
    the hard task, and dev-vs-heldout stability of the baseline.
    """

    def test_baseline_beats_random_easy_dev(self):
        """Baseline beats random on the easy task across development seeds."""
        baseline_mean, _ = mean_reward_over_seeds(
            heuristic_baseline_policy, EASY, DEVELOPMENT_SEEDS
        )
        random_mean, _ = mean_reward_over_seeds(
            random_policy, EASY, DEVELOPMENT_SEEDS, policy_name="random"
        )
        assert baseline_mean > random_mean, (
            f"baseline ({baseline_mean:.2f}) <= random ({random_mean:.2f}) on easy"
        )

    def test_upper_bound_beats_baseline_easy_dev(self):
        """Upper bound beats or matches baseline on easy task across dev seeds."""
        baseline_mean, _ = mean_reward_over_seeds(
            heuristic_baseline_policy, EASY, DEVELOPMENT_SEEDS
        )
        ub_mean, _ = mean_reward_over_seeds(
            debug_upper_bound_policy, EASY, DEVELOPMENT_SEEDS, policy_name="upper_bound"
        )
        assert ub_mean >= baseline_mean, (
            f"oracle ({ub_mean:.2f}) < baseline ({baseline_mean:.2f}) on easy"
        )

    def test_easy_solvable_positive_reward(self):
        """Easy task is solvable: baseline achieves positive total reward on seed=42."""
        env = BudgetRouterEnv()
        metrics = run_episode(env, heuristic_baseline_policy, seed=42, scenario=EASY)
        assert metrics["total_reward"] > 0, (
            f"baseline achieves {metrics['total_reward']:.2f} on easy/seed=42"
        )

    def test_hard_no_crash_dev_seeds(self):
        """Hard task terminates without environment crash on development seeds."""
        env = BudgetRouterEnv()
        for seed in DEVELOPMENT_SEEDS:
            try:
                metrics = run_episode(
                    env, heuristic_baseline_policy, seed=seed, scenario=HARD
                )
            except Exception as e:  # environment crash, not a test assertion
                pytest.fail(f"Hard task crashed on seed {seed}: {e}")
            # BUGFIX: this assert previously lived inside the try block, so a
            # failing assertion raised AssertionError, was caught by
            # `except Exception`, and was misreported as "Hard task crashed".
            # Checking after the except lets a length violation surface as a
            # genuine assertion failure.
            assert metrics["episode_length"] <= 20, (
                f"episode_length {metrics['episode_length']} > 20 on seed {seed}"
            )

    def test_no_nan_rewards_all_combos(self):
        """No reward is NaN across all (task, policy, seed_set) combinations."""
        env = BudgetRouterEnv()
        policies = {
            "random": random_policy,
            "heuristic_baseline": heuristic_baseline_policy,
            "upper_bound": debug_upper_bound_policy,
            "always_route_a": always_route_a_policy,
            "always_route_b": always_route_b_policy,
            "always_route_c": always_route_c_policy,
            "always_shed_load": always_shed_load_policy,
        }
        for scenario in [EASY, MEDIUM, HARD]:
            for policy_name, policy_fn in policies.items():
                for seed in DEVELOPMENT_SEEDS[:3]:  # subset for speed
                    metrics = run_episode(
                        env, policy_fn, seed, scenario, policy_name=policy_name
                    )
                    assert not math.isnan(metrics["total_reward"]), (
                        f"NaN reward: {scenario.name}/{policy_name}/seed={seed}"
                    )

    def test_baseline_stability_heldout(self):
        """Baseline remains within a reasonable stability margin on heldout seeds."""
        for scenario in [EASY, MEDIUM, HARD]:
            dev_mean, _ = mean_reward_over_seeds(
                heuristic_baseline_policy, scenario, DEVELOPMENT_SEEDS
            )
            heldout_mean, _ = mean_reward_over_seeds(
                heuristic_baseline_policy, scenario, HELDOUT_SEEDS
            )
            # Margin: at least 2.0 absolute, else 40% of the dev mean's
            # magnitude — tolerates seed noise without hiding regressions.
            margin = max(2.0, 0.40 * abs(dev_mean))
            assert abs(heldout_mean - dev_mean) <= margin, (
                f"Baseline unstable on {scenario.name}: "
                f"dev={dev_mean:.2f}, heldout={heldout_mean:.2f}, margin={margin:.2f}"
            )

    def test_baseline_beats_always_route_b_dev(self):
        """Baseline beats always_route_b on all tasks across development seeds."""
        for scenario in [EASY, MEDIUM, HARD]:
            baseline_mean, _ = mean_reward_over_seeds(
                heuristic_baseline_policy, scenario, DEVELOPMENT_SEEDS
            )
            always_b_mean, _ = mean_reward_over_seeds(
                always_route_b_policy, scenario, DEVELOPMENT_SEEDS
            )
            assert baseline_mean >= always_b_mean, (
                f"baseline ({baseline_mean:.2f}) < always_route_b ({always_b_mean:.2f}) on {scenario.name}"
            )