"""
Train a PPO agent on the Hard_Multi scenario.
This is the key experiment: Hard_Multi has a secondary provider cascade at step 10
(Provider B degrades after A). A reactive heuristic cannot conserve budget in advance
and scores ~0.6094. An RL agent with access to step_count + budget_remaining can
learn anticipatory routing and should materially exceed the heuristic.
Usage:
uv run python train/train_ppo_hard_multi.py
Output:
trained_models/ppo_hard_multi_100k.zip — saved SB3 model
trained_models/ppo_hard_multi_100k_tb/ — TensorBoard logs
"""
from __future__ import annotations
import sys
from pathlib import Path
# Ensure project root is on sys.path when running as a script
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback
from train.gym_wrapper import BudgetRouterGymEnv
from budget_router.tasks import HARD_MULTI
# ── Config ──────────────────────────────────────────────────────────────────
# Number of environment copies handed to make_vec_env for training.
N_ENVS = 4
# Total training timesteps; Hard_Multi needs more signal than Easy.
TOTAL_STEPS = 100_000
# Model checkpoint stem (the ".zip" suffix is added on save) and the
# TensorBoard directory, which shares the same prefix.
SAVE_PATH = "trained_models/ppo_hard_multi_100k"
LOG_PATH = f"{SAVE_PATH}_tb"
# Prefer the MPS device when torch reports it as available; otherwise use CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
# ────────────────────────────────────────────────────────────────────────────
def main() -> None:
    """Train a PPO agent on the Hard_Multi scenario and save the model.

    Builds ``N_ENVS`` training environments and one evaluation environment
    (constructed with ``seed=99``), trains for ``TOTAL_STEPS`` timesteps with
    an ``EvalCallback`` attached, then saves the model to ``SAVE_PATH``.

    TensorBoard event files are written under ``LOG_PATH`` when the
    ``tensorboard`` package is importable; otherwise training proceeds
    without them and a notice is printed.
    """
    # Local import so this block brings its own dependency into scope.
    from importlib.util import find_spec

    print(f"[train:hard_multi] device={DEVICE} n_envs={N_ENVS} total_steps={TOTAL_STEPS:,}")
    print("[train:hard_multi] Scenario: Provider A degrades step 0, Provider B degrades step 10")
    print("[train:hard_multi] Heuristic baseline grader: 0.6094 (reactive, cannot conserve budget)")

    # LOG_PATH was previously defined but never handed to PPO, so the
    # TensorBoard logs promised in the module docstring were never produced.
    # SB3 refuses to start when tensorboard_log is set but tensorboard is not
    # installed, so only enable it when the package can be found.
    tensorboard_dir = LOG_PATH if find_spec("tensorboard") is not None else None
    if tensorboard_dir is None:
        print("[train:hard_multi] tensorboard not installed; TensorBoard logging disabled")

    train_env = make_vec_env(
        lambda: BudgetRouterGymEnv(scenario=HARD_MULTI),
        n_envs=N_ENVS,
    )
    eval_env = BudgetRouterGymEnv(scenario=HARD_MULTI, seed=99)

    eval_cb = EvalCallback(
        eval_env,
        # eval_freq counts callback calls, and each call advances all N_ENVS
        # environments, so dividing targets one evaluation per ~10k timesteps.
        eval_freq=max(10_000 // N_ENVS, 1),
        n_eval_episodes=10,
        verbose=1,
    )

    model = PPO(
        policy="MlpPolicy",
        env=train_env,
        n_steps=512,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        ent_coef=0.02,  # slightly higher entropy to encourage exploration on harder task
        learning_rate=3e-4,
        tensorboard_log=tensorboard_dir,
        verbose=1,
        device=DEVICE,
    )
    model.learn(
        total_timesteps=TOTAL_STEPS,
        callback=eval_cb,
        progress_bar=True,
    )

    model.save(SAVE_PATH)
    print(f"[train:hard_multi] Model saved → {SAVE_PATH}.zip")


# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|