File size: 3,434 Bytes
98a5a8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""
Evaluate the trained PPO agent against the heuristic baseline.

Usage:
    uv run python train/eval_ppo.py

Loads trained_models/ppo_easy_50k.zip and runs 10 episodes (seeds 0-9) for the
PPO agent and for the heuristic baseline, reporting per-episode and mean grader
scores and the PPO-minus-heuristic delta.
"""
from __future__ import annotations

import statistics
from pathlib import Path

from stable_baselines3 import PPO

from train.gym_wrapper import BudgetRouterGymEnv
from budget_router.environment import BudgetRouterEnv
from budget_router.models import Action, ActionType
from budget_router.policies import heuristic_baseline_policy
from budget_router.reward import grade_episode
from budget_router.tasks import EASY

# Checkpoint written by train/train_ppo.py; a relative path, so it is resolved
# against the current working directory (the repo root per the usage above).
MODEL_PATH = "trained_models/ppo_easy_50k.zip"

# Development-set seeds 0 through 9 inclusive.
EVAL_SEEDS = [*range(10)]

# Heuristic grader score quoted in the README; in this script it is only printed
# next to the freshly measured heuristic mean as a sanity reference.
HEURISTIC_BASELINE = 0.7958


def _grader_score_from_history(history: list[dict]) -> float:
    """Return the ``overall_score`` reported by ``grade_episode`` as a float.

    ``history`` is the episode history recorded by the environment (callers
    pass ``<BudgetRouterEnv>._internal.history``); it is handed to the grader
    unchanged.
    """
    report = grade_episode(history)
    overall = report["overall_score"]
    return float(overall)


def eval_ppo(model: PPO, seeds: list[int]) -> list[float]:
    """Roll out the PPO policy once per seed and return the grader scores.

    Each seed gets a fresh ``BudgetRouterGymEnv`` on the EASY scenario, and
    actions are taken with ``deterministic=True``. Each per-seed score is
    printed as soon as it is computed.

    Args:
        model: Trained Stable-Baselines3 PPO model.
        seeds: Environment seeds; one episode is run for each.

    Returns:
        Grader scores, in the same order as ``seeds``.
    """
    scores: list[float] = []
    for seed in seeds:
        env = BudgetRouterGymEnv(scenario=EASY, seed=seed)
        try:
            obs, _ = env.reset()
            done = False
            while not done:
                action_idx, _ = model.predict(obs, deterministic=True)
                obs, _, terminated, truncated, _ = env.step(int(action_idx))
                done = terminated or truncated
            # Grader history lives on the wrapped BudgetRouterEnv (private
            # ``_env``). Read it after the episode, and before close(), so it
            # reflects the finished episode even if reset() swapped internals.
            score = _grader_score_from_history(env._env._internal.history)
        finally:
            # Release wrapper resources even if a step or the grader raises.
            env.close()
        scores.append(score)
        print(f"  seed={seed:2d}  grader={score:.4f}")
    return scores


def eval_heuristic(seeds: list[int]) -> list[float]:
    """Roll out the heuristic baseline once per seed and return the grader scores.

    Uses the native ``BudgetRouterEnv`` API, where ``reset``/``step`` return an
    observation carrying ``done``, rather than the gym wrapper. Per-seed scores
    are printed in the same ``seed=.. grader=..`` format as the PPO evaluation,
    so both sections of the report line up.

    Args:
        seeds: Environment seeds; one episode is run for each on EASY.

    Returns:
        Grader scores, in the same order as ``seeds``.
    """
    scores: list[float] = []
    for seed in seeds:
        env = BudgetRouterEnv()
        obs = env.reset(seed=seed, scenario=EASY)
        while not obs.done:
            obs = env.step(heuristic_baseline_policy(obs))
        score = _grader_score_from_history(env._internal.history)
        scores.append(score)
        print(f"  seed={seed:2d}  grader={score:.4f}")
    return scores


def main() -> None:
    """Evaluate the saved PPO model against the heuristic baseline and print a summary.

    If the checkpoint at ``MODEL_PATH`` does not exist, prints a hint and
    returns without evaluating anything.
    """
    if not Path(MODEL_PATH).exists():
        print(f"[eval] Model not found at {MODEL_PATH}. Run train/train_ppo.py first.")
        return

    print(f"[eval] Loading {MODEL_PATH}")
    model = PPO.load(MODEL_PATH)

    print("\n[eval] PPO agent (deterministic):")
    ppo_mean = statistics.mean(eval_ppo(model, EVAL_SEEDS))

    print("\n[eval] Heuristic baseline:")
    heuristic_mean = statistics.mean(eval_heuristic(EVAL_SEEDS))

    # The PPO mean must strictly exceed this for a PPO row to go in the README.
    readme_threshold = 0.60
    delta = ppo_mean - heuristic_mean

    print("\n── Results ──────────────────────────────────")
    print(f"  PPO mean grader score  : {ppo_mean:.4f}")
    print(f"  Heuristic mean grader  : {heuristic_mean:.4f}  (expected ≈ {HEURISTIC_BASELINE})")
    # The '+' format spec prints an explicit sign for positive values.
    print(f"  Delta (PPO - heuristic): {delta:+.4f}")
    if ppo_mean > readme_threshold:
        print(f"  ✅ PPO > {readme_threshold:.2f} threshold — README update warranted.")
    else:
        # The else branch also covers ppo_mean == threshold, hence "≤".
        print(f"  ⚠️  PPO ≤ {readme_threshold:.2f} — keep scaffolding but skip README PPO row.")


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()