Spaces:
Sleeping
Sleeping
| """ | |
| Evaluate the trained PPO agent against the heuristic baseline. | |
| Usage: | |
| uv run python train/eval_ppo.py | |
| Loads trained_models/ppo_easy_50k.zip and runs 10 episodes (seeds 0-9), | |
| reporting per-episode and mean grader scores. | |
| """ | |
| from __future__ import annotations | |
| import statistics | |
| from pathlib import Path | |
| from stable_baselines3 import PPO | |
| from train.gym_wrapper import BudgetRouterGymEnv | |
| from budget_router.environment import BudgetRouterEnv | |
| from budget_router.models import Action, ActionType | |
| from budget_router.policies import heuristic_baseline_policy | |
| from budget_router.reward import grade_episode | |
| from budget_router.tasks import EASY | |
| MODEL_PATH = "trained_models/ppo_easy_50k.zip" | |
| EVAL_SEEDS = list(range(10)) # seeds 0-9 (development set) | |
| HEURISTIC_BASELINE = 0.7958 # confirmed grader score from README | |
| def _grader_score_from_history(history: list[dict]) -> float: | |
| """Compute grader score directly from the environment's history dict.""" | |
| return float(grade_episode(history)["overall_score"]) | |
| def eval_ppo(model: PPO, seeds: list[int]) -> list[float]: | |
| """Run PPO policy for each seed, return list of grader scores.""" | |
| scores = [] | |
| for seed in seeds: | |
| env = BudgetRouterGymEnv(scenario=EASY, seed=seed) | |
| inner_env = env._env # direct access to BudgetRouterEnv for history | |
| obs, _ = env.reset() | |
| done = False | |
| while not done: | |
| action_idx, _ = model.predict(obs, deterministic=True) | |
| obs, _, terminated, truncated, _ = env.step(int(action_idx)) | |
| done = terminated or truncated | |
| score = _grader_score_from_history(inner_env._internal.history) | |
| scores.append(score) | |
| print(f" seed={seed:2d} grader={score:.4f}") | |
| return scores | |
| def eval_heuristic(seeds: list[int]) -> list[float]: | |
| """Run heuristic policy for each seed, return list of grader scores.""" | |
| scores = [] | |
| for seed in seeds: | |
| env = BudgetRouterEnv() | |
| obs = env.reset(seed=seed, scenario=EASY) | |
| while not obs.done: | |
| action = heuristic_baseline_policy(obs) | |
| obs = env.step(action) | |
| score = _grader_score_from_history(env._internal.history) | |
| scores.append(score) | |
| return scores | |
| def main() -> None: | |
| if not Path(MODEL_PATH).exists(): | |
| print(f"[eval] Model not found at {MODEL_PATH}. Run train/train_ppo.py first.") | |
| return | |
| print(f"[eval] Loading {MODEL_PATH}") | |
| model = PPO.load(MODEL_PATH) | |
| print("\n[eval] PPO agent (deterministic):") | |
| ppo_scores = eval_ppo(model, EVAL_SEEDS) | |
| ppo_mean = statistics.mean(ppo_scores) | |
| print("\n[eval] Heuristic baseline:") | |
| heuristic_scores = eval_heuristic(EVAL_SEEDS) | |
| heuristic_mean = statistics.mean(heuristic_scores) | |
| print("\n── Results ──────────────────────────────────") | |
| print(f" PPO mean grader score : {ppo_mean:.4f}") | |
| print(f" Heuristic mean grader : {heuristic_mean:.4f} (expected ≈ {HEURISTIC_BASELINE})") | |
| delta = ppo_mean - heuristic_mean | |
| sign = "+" if delta >= 0 else "" | |
| print(f" Delta (PPO - heuristic): {sign}{delta:.4f}") | |
| if ppo_mean > 0.60: | |
| print(" ✅ PPO > 0.60 threshold — README update warranted.") | |
| else: | |
| print(" ⚠️ PPO < 0.60 — keep scaffolding but skip README PPO row.") | |
| if __name__ == "__main__": | |
| main() | |