Spaces:
Running
Running
"""
Phase 1: Masked PPO on district_backlog_easy.
Phase 2: Curriculum Masked PPO across all 3 tasks.

Usage:
    python -m rl.train_ppo --phase 1 --timesteps 200000
    python -m rl.train_ppo --phase 2 --timesteps 500000
    python -m rl.train_ppo --phase 1 --task district_backlog_easy --n_envs 4
"""
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import yaml | |
| from stable_baselines3.common.vec_env import DummyVecEnv | |
| from stable_baselines3.common.monitor import Monitor | |
| from sb3_contrib import MaskablePPO | |
| from rl.gov_workflow_env import GovWorkflowGymEnv | |
| from rl.callbacks import GovWorkflowEvalCallback, CostMonitorCallback | |
| from rl.curriculum import CurriculumScheduler, CurriculumConfig | |
# Make sure every output location exists before any training run writes to it.
for _results_dir in ("results/runs", "results/best_model", "results/eval_logs"):
    os.makedirs(_results_dir, exist_ok=True)

# Phase 1 trains exclusively on this single task.
PHASE1_TASK_ID = "district_backlog_easy"
| def _load_cfg(path: str) -> dict: | |
| if os.path.exists(path): | |
| # `utf-8-sig` safely handles files with/without UTF-8 BOM. | |
| with open(path, encoding="utf-8-sig") as f: | |
| return yaml.safe_load(f) | |
| return {} | |
| def _resolve_checkpoint_path(path_like: str | None) -> str | None: | |
| if not path_like: | |
| return None | |
| if os.path.exists(path_like): | |
| return path_like | |
| zip_path = f"{path_like}.zip" | |
| if os.path.exists(zip_path): | |
| return zip_path | |
| return None | |
| # --------------------------------------------------------------------------- | |
| # Phase 1 — single task easy | |
| # --------------------------------------------------------------------------- | |
def train_phase1(
    total_timesteps: int = 200_000,
    n_envs: int = 4,
    seed: int = 42,
    config_path: str = "rl/configs/ppo_easy.yaml",
    eval_freq_override: int | None = None,
    n_eval_episodes_override: int | None = None,
    disable_eval_callback: bool = False,
    no_progress_bar: bool = False,
    grader_eval_freq_multiplier_override: int | None = None,
    resume_path: str | None = None,
) -> MaskablePPO:
    """Train Masked PPO on the single 'district_backlog_easy' task (Phase 1).

    Precedence for each eval-related knob: explicit ``*_override`` argument,
    then the YAML config's ``training:`` section, then a hard-coded default.

    Args:
        total_timesteps: Total environment steps passed to ``model.learn``.
        n_envs: Number of in-process parallel training envs in the DummyVecEnv.
        seed: Base RNG seed; each env rank adds its offset, eval env uses +1000.
        config_path: YAML config supplying ``hyperparameters:`` / ``training:``.
        eval_freq_override: Overrides config ``eval_freq`` (in callback calls).
        n_eval_episodes_override: Overrides config ``n_eval_episodes``.
        disable_eval_callback: Force-disable evaluation even if config enables it.
        no_progress_bar: Force-disable the progress bar even if config enables it.
        grader_eval_freq_multiplier_override: Overrides config
            ``grader_eval_freq_multiplier``.
        resume_path: Optional checkpoint to resume from (with or without .zip).

    Returns:
        The trained MaskablePPO model (also saved under results/best_model/).

    Raises:
        FileNotFoundError: If ``resume_path`` is given but no checkpoint exists.
    """
    cfg = _load_cfg(config_path)
    hp = cfg.get("hyperparameters", {})
    tr_c = cfg.get("training", {})

    def _make(rank: int):
        # Factory-of-factories: DummyVecEnv expects zero-argument constructors.
        def _init():
            # Per-rank seed offset decorrelates parallel env rollouts.
            return Monitor(GovWorkflowGymEnv("district_backlog_easy", seed=seed + rank))
        return _init

    train_env = DummyVecEnv([_make(i) for i in range(n_envs)])

    # Override > config > default resolution for every eval knob.
    eval_freq = int(eval_freq_override if eval_freq_override is not None else tr_c.get("eval_freq", max(16_384 // n_envs, 1)))
    n_eval_episodes = int(n_eval_episodes_override if n_eval_episodes_override is not None else tr_c.get("n_eval_episodes", 2))
    eval_callback_enabled = bool(tr_c.get("enable_eval_callback", True)) and (not disable_eval_callback)
    grader_eval_freq_multiplier = int(
        grader_eval_freq_multiplier_override
        if grader_eval_freq_multiplier_override is not None
        else tr_c.get("grader_eval_freq_multiplier", 4)
    )
    callback_verbose = int(tr_c.get("callback_verbose", 0))
    model_verbose = int(tr_c.get("model_verbose", 0))
    progress_bar_enabled = (not no_progress_bar) and bool(tr_c.get("progress_bar", False))

    callbacks = [CostMonitorCallback()]
    if eval_callback_enabled:
        # NOTE(review): hard_action_mask=True presumably enforces strict masking
        # during evaluation only — confirm against GovWorkflowGymEnv.
        eval_env = GovWorkflowGymEnv("district_backlog_easy", seed=seed + 1000, hard_action_mask=True)
        eval_cb = GovWorkflowEvalCallback(
            eval_env=eval_env,
            eval_freq=max(eval_freq, 1),
            n_eval_episodes=max(n_eval_episodes, 1),
            grader_eval_freq_multiplier=max(grader_eval_freq_multiplier, 1),
            best_model_save_path="results/best_model",
            log_path="results/eval_logs",
            task_id="district_backlog_easy",
            verbose=callback_verbose,
        )
        # Insert at index 0 so the eval callback runs before CostMonitorCallback.
        callbacks.insert(0, eval_cb)

    resolved_resume = _resolve_checkpoint_path(resume_path)
    if resume_path and resolved_resume is None:
        # Fail loudly: a requested-but-missing checkpoint should never silently
        # fall back to training from scratch.
        raise FileNotFoundError(
            f"Phase 1 resume checkpoint not found: {resume_path} (or {resume_path}.zip)"
        )
    if resolved_resume:
        print(f"[Phase 1] Resuming from {resolved_resume}")
        model = MaskablePPO.load(resolved_resume, env=train_env)
    else:
        model = MaskablePPO(
            policy="MlpPolicy",
            env=train_env,
            learning_rate=float(hp.get("learning_rate", 3e-4)),
            n_steps=int(hp.get("n_steps", 512)),
            batch_size=int(hp.get("batch_size", 64)),
            n_epochs=int(hp.get("n_epochs", 10)),
            gamma=float(hp.get("gamma", 0.99)),
            gae_lambda=float(hp.get("gae_lambda", 0.95)),
            clip_range=float(hp.get("clip_range", 0.2)),
            ent_coef=float(hp.get("ent_coef", 0.01)),
            vf_coef=float(hp.get("vf_coef", 0.5)),
            max_grad_norm=float(hp.get("max_grad_norm", 0.5)),
            policy_kwargs=dict(net_arch=hp.get("net_arch", [256, 256])),
            tensorboard_log="results/runs/phase1_masked_ppo",
            verbose=model_verbose,
            seed=seed,
        )

    print(
        f"\n[Phase 1] Masked PPO | timesteps={total_timesteps} | n_envs={n_envs} "
        f"| eval_cb={'on' if eval_callback_enabled else 'off'} "
        f"| eval_freq={max(eval_freq,1)} | n_eval_episodes={max(n_eval_episodes,1)} "
        f"| grader_eval_x{max(grader_eval_freq_multiplier, 1)}"
    )
    model.learn(
        total_timesteps=total_timesteps,
        callback=callbacks,
        tb_log_name="masked_ppo_easy",
        # Keep the timestep counter running across a resumed checkpoint.
        reset_num_timesteps=not bool(resolved_resume),
        progress_bar=progress_bar_enabled,
    )
    model.save("results/best_model/phase1_final")
    print("[Phase 1] Done -> results/best_model/phase1_final")
    return model
| # --------------------------------------------------------------------------- | |
| # Phase 2 — curriculum across all tasks | |
| # --------------------------------------------------------------------------- | |
def train_phase2(
    total_timesteps: int = 500_000,
    n_envs: int = 4,
    seed: int = 42,
    config_path: str = "rl/configs/curriculum.yaml",
) -> MaskablePPO:
    """Train Masked PPO with a task curriculum across all tasks (Phase 2).

    Task selection is delegated to a CurriculumScheduler; each training env
    pulls its next task from the scheduler via ``set_task_sampler``. The model
    is warm-started from the Phase 1 checkpoint when one is available.

    Args:
        total_timesteps: Total environment steps passed to ``model.learn``.
        n_envs: Number of in-process parallel training envs.
        seed: Base RNG seed; env rank offsets added, eval env uses +999.
        config_path: YAML config with ``hyperparameters:`` / ``curriculum:`` /
            ``training:`` sections.

    Returns:
        The trained MaskablePPO model (also saved under results/best_model/).
    """
    cfg = _load_cfg(config_path)
    if not cfg and config_path.endswith("curriculum.yaml"):
        # Backward compatibility with previous filename.
        cfg = _load_cfg("rl/configs/ppo_curriculum.yaml")
    hp = cfg.get("hyperparameters", {})
    cur_c = cfg.get("curriculum", {})
    tr_c = cfg.get("training", {})

    # Stage boundaries are fractions of total_timesteps; stage 3 mixes all
    # tasks with the given sampling weights.
    scheduler = CurriculumScheduler(
        total_timesteps=total_timesteps,
        config=CurriculumConfig(
            stage1_end_frac=float(cur_c.get("stage1_end_frac", 0.30)),
            stage2_end_frac=float(cur_c.get("stage2_end_frac", 0.70)),
            stage3_weights=tuple(cur_c.get("stage3_weights", [0.20, 0.40, 0.40])),
        ),
        rng_seed=seed,
    )

    # Single-element list used as a mutable cell so the sampler closure and all
    # envs share (and can advance) one global step count.
    global_step_counter = [0]

    def _sample_task() -> str:
        # Ask the scheduler which task the current global step falls into.
        return scheduler.sample_task(global_step_counter[0])

    def _make_curr(rank: int):
        # Zero-arg env constructor for DummyVecEnv; starts on the easy task and
        # switches via the curriculum sampler afterwards.
        def _init():
            env = GovWorkflowGymEnv(
                task_id="district_backlog_easy",
                seed=seed + rank,
            )
            env.set_task_sampler(_sample_task, global_step_counter)
            return Monitor(env)
        return _init

    train_env = DummyVecEnv([_make_curr(i) for i in range(n_envs)])

    # Evaluation runs on a fixed task regardless of the curriculum stage.
    eval_task_id = str(tr_c.get("eval_task_id", "mixed_urgency_medium"))
    eval_env = GovWorkflowGymEnv(eval_task_id, seed=seed + 999, hard_action_mask=True)
    eval_cb = GovWorkflowEvalCallback(
        eval_env=eval_env,
        eval_freq=int(tr_c.get("eval_freq", max(4096 // n_envs, 1))),
        n_eval_episodes=int(tr_c.get("n_eval_episodes", 3)),
        grader_eval_freq_multiplier=int(tr_c.get("grader_eval_freq_multiplier", 4)),
        best_model_save_path="results/best_model",
        log_path="results/eval_logs",
        task_id=eval_task_id,
        verbose=1,
    )

    warm_start_from = str(tr_c.get("warm_start_from", "results/best_model/phase1_final"))
    warm_start_path = _resolve_checkpoint_path(warm_start_from)
    # (The extra os.path.exists check is redundant: _resolve_checkpoint_path
    # only returns paths that already exist.)
    if warm_start_path and os.path.exists(warm_start_path):
        print(f"[Phase 2] Warm-starting from {warm_start_path}")
        model = MaskablePPO.load(warm_start_path, env=train_env)
    else:
        model = MaskablePPO(
            policy="MlpPolicy",
            env=train_env,
            learning_rate=float(hp.get("learning_rate", 2e-4)),
            n_steps=int(hp.get("n_steps", 512)),
            batch_size=int(hp.get("batch_size", 64)),
            n_epochs=int(hp.get("n_epochs", 10)),
            gamma=float(hp.get("gamma", 0.99)),
            gae_lambda=float(hp.get("gae_lambda", 0.95)),
            clip_range=float(hp.get("clip_range", 0.2)),
            ent_coef=float(hp.get("ent_coef", 0.005)),
            policy_kwargs=dict(net_arch=hp.get("net_arch", [256, 256])),
            tensorboard_log="results/runs/phase2_curriculum_ppo",
            verbose=1,
            seed=seed,
        )

    print(f"\n[Phase 2] Curriculum PPO | timesteps={total_timesteps}")
    model.learn(
        total_timesteps=total_timesteps,
        callback=[eval_cb, CostMonitorCallback()],
        tb_log_name="curriculum_ppo",
        progress_bar=True,
    )
    model.save("results/best_model/phase2_final")
    print("[Phase 2] Done -> results/best_model/phase2_final")
    return model
def main() -> None:
    """CLI entry point: parse arguments and dispatch to Phase 1 or Phase 2."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--phase", type=int, choices=[1, 2], default=1)
    cli.add_argument("--timesteps", type=int, default=200_000)
    # Accept both dash and underscore spellings for backwards compatibility.
    cli.add_argument("--n-envs", "--n_envs", dest="n_envs", type=int, default=4)
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument(
        "--task",
        default=None,
        help=f"CLI compatibility alias. Phase 1 supports only '{PHASE1_TASK_ID}'. Phase 2 ignores this flag.",
    )
    cli.add_argument(
        "--phase1-config",
        default="rl/configs/ppo_easy.yaml",
        help="Config file for Phase 1 training.",
    )
    cli.add_argument(
        "--phase1-eval-freq",
        type=int,
        default=None,
        help="Override Phase 1 eval callback frequency (in calls).",
    )
    cli.add_argument(
        "--phase1-n-eval-episodes",
        type=int,
        default=None,
        help="Override Phase 1 eval callback episodes per eval.",
    )
    cli.add_argument(
        "--phase1-disable-eval-callback",
        action="store_true",
        help="Disable Phase 1 evaluation callback to avoid pause-heavy eval blocks.",
    )
    cli.add_argument(
        "--phase1-no-progress-bar",
        action="store_true",
        help="Disable tqdm progress bar rendering for Phase 1.",
    )
    cli.add_argument(
        "--phase1-grader-eval-freq-multiplier",
        type=int,
        default=None,
        help="Run grader eval every N * eval_freq callback ticks for Phase 1.",
    )
    cli.add_argument(
        "--resume",
        default=None,
        help="Resume Phase 1 from checkpoint path (with or without .zip suffix).",
    )
    cli.add_argument(
        "--phase2-config",
        default="rl/configs/curriculum.yaml",
        help="Config file for Phase 2 curriculum training.",
    )
    opts = cli.parse_args()

    requested_task = opts.task
    if opts.phase == 1 and requested_task and requested_task != PHASE1_TASK_ID:
        raise ValueError(
            f"Phase 1 currently supports only task '{PHASE1_TASK_ID}', got '{requested_task}'."
        )
    if opts.phase == 2 and requested_task:
        print(f"[Phase 2] Ignoring --task={requested_task}; curriculum scheduler controls task sampling.")

    # Early-return dispatch: Phase 2 first, Phase 1 is the remaining case.
    if opts.phase == 2:
        train_phase2(
            total_timesteps=opts.timesteps,
            n_envs=opts.n_envs,
            seed=opts.seed,
            config_path=opts.phase2_config,
        )
        return
    train_phase1(
        total_timesteps=opts.timesteps,
        n_envs=opts.n_envs,
        seed=opts.seed,
        config_path=opts.phase1_config,
        eval_freq_override=opts.phase1_eval_freq,
        n_eval_episodes_override=opts.phase1_n_eval_episodes,
        disable_eval_callback=opts.phase1_disable_eval_callback,
        no_progress_bar=opts.phase1_no_progress_bar,
        grader_eval_freq_multiplier_override=opts.phase1_grader_eval_freq_multiplier,
        resume_path=opts.resume,
    )
if __name__ == "__main__":
    # Supports both `python -m rl.train_ppo ...` and direct script execution.
    main()