""" Phase 3: Recurrent PPO (LSTM policy) training. This trainer keeps the existing 28-action design and uses curriculum sampling across tasks (easy -> medium -> hard). Because current sb3-contrib releases do not provide MaskableRecurrentPPO, we enforce action masks in two places: 1) hard mask in GovWorkflowGymEnv before executing an action, 2) recurrent evaluation callback with masked action sanitization. Usage: python -m rl.train_recurrent --timesteps 600000 python -m rl.train_recurrent --task cross_department_hard --n_envs 4 """ from __future__ import annotations import argparse import os from typing import Any import yaml from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.vec_env import DummyVecEnv from sb3_contrib import MaskablePPO, RecurrentPPO from rl.callbacks import CostMonitorCallback, RecurrentEvalCallback from rl.curriculum import CurriculumConfig, CurriculumScheduler from rl.gov_workflow_env import GovWorkflowGymEnv os.makedirs("results/runs", exist_ok=True) os.makedirs("results/best_model", exist_ok=True) os.makedirs("results/eval_logs", exist_ok=True) def _load_cfg(path: str) -> dict: if os.path.exists(path): with open(path, encoding="utf-8-sig") as f: return yaml.safe_load(f) return {} def _transfer_matching_policy_weights( recurrent_model: RecurrentPPO, flat_model_path: str, exclude_prefixes: tuple[str, ...] = (), ) -> int: """ Transfer compatible policy weights from a flat MaskablePPO checkpoint. Returns number of copied tensors. """ src_path = flat_model_path if not src_path.endswith(".zip"): src_path = f"{src_path}.zip" if not os.path.exists(src_path): return 0 try: flat_model = MaskablePPO.load(src_path) except Exception as exc: print(f"[Phase 3] Skipping flat-weight transfer, could not load MaskablePPO from {src_path}: {exc}") return 0 src_state = flat_model.policy.state_dict() dst_state = recurrent_model.policy.state_dict() copied = 0 for key, dst_tensor in dst_state.items(): if any(key.startswith(prefix) for prefix in exclude_prefixes): continue src_tensor = src_state.get(key) if src_tensor is None: continue if tuple(src_tensor.shape) != tuple(dst_tensor.shape): continue dst_state[key] = src_tensor copied += 1 recurrent_model.policy.load_state_dict(dst_state, strict=False) return copied def train_phase3( total_timesteps: int = 600_000, n_envs: int = 4, seed: int = 42, config_path: str = "rl/configs/recurrent.yaml", eval_task_id_override: str | None = None, ) -> RecurrentPPO: cfg = _load_cfg(config_path) hp = cfg.get("hyperparameters", {}) cur_c = cfg.get("curriculum", {}) tr_c = cfg.get("training", {}) scheduler = CurriculumScheduler( total_timesteps=total_timesteps, config=CurriculumConfig( stage1_end_frac=float(cur_c.get("stage1_end_frac", 0.20)), stage2_end_frac=float(cur_c.get("stage2_end_frac", 0.55)), stage3_weights=tuple(cur_c.get("stage3_weights", [0.15, 0.35, 0.50])), ), rng_seed=seed, ) global_step_counter = [0] hard_action_mask_train = bool(tr_c.get("hard_action_mask_train", True)) hard_action_mask_eval = bool(tr_c.get("hard_action_mask_eval", True)) def _sample_task() -> str: return scheduler.sample_task(global_step_counter[0]) def _make_curr(rank: int): def _init(): env = GovWorkflowGymEnv( task_id="district_backlog_easy", seed=seed + rank, hard_action_mask=hard_action_mask_train, ) env.set_task_sampler(_sample_task, global_step_counter) return Monitor(env) return _init train_env = DummyVecEnv([_make_curr(i) for i in range(n_envs)]) eval_task_id = str(eval_task_id_override or tr_c.get("eval_task_id", "mixed_urgency_medium")) eval_env = GovWorkflowGymEnv(eval_task_id, seed=seed + 999, hard_action_mask=hard_action_mask_eval) eval_cb = RecurrentEvalCallback( eval_env=eval_env, eval_freq=int(tr_c.get("eval_freq", max(4096 // n_envs, 1))), n_eval_episodes=int(tr_c.get("n_eval_episodes", 3)), best_model_save_path="results/best_model", log_path="results/eval_logs", task_id=eval_task_id, verbose=1, ) model = RecurrentPPO( policy="MlpLstmPolicy", env=train_env, learning_rate=float(hp.get("learning_rate", 1e-4)), n_steps=int(hp.get("n_steps", 512)), batch_size=int(hp.get("batch_size", 128)), n_epochs=int(hp.get("n_epochs", 10)), gamma=float(hp.get("gamma", 0.995)), gae_lambda=float(hp.get("gae_lambda", 0.95)), clip_range=float(hp.get("clip_range", 0.2)), ent_coef=float(hp.get("ent_coef", 0.002)), vf_coef=float(hp.get("vf_coef", 0.5)), max_grad_norm=float(hp.get("max_grad_norm", 0.5)), policy_kwargs=dict( net_arch=hp.get("net_arch", [256, 256]), lstm_hidden_size=int(hp.get("lstm_hidden_size", 128)), n_lstm_layers=int(hp.get("n_lstm_layers", 1)), shared_lstm=bool(hp.get("shared_lstm", False)), enable_critic_lstm=bool(hp.get("enable_critic_lstm", True)), ), tensorboard_log="results/runs/phase3_recurrent_ppo", verbose=1, seed=seed, ) warm_start_from = str(tr_c.get("warm_start_from", "results/best_model/phase2_final")) transfer_flat = bool(tr_c.get("transfer_flat_weights", True)) transfer_exclude_prefixes = tuple( tr_c.get("transfer_exclude_prefixes", ["action_net.", "value_net."]) ) if transfer_flat: copied = _transfer_matching_policy_weights( model, warm_start_from, exclude_prefixes=transfer_exclude_prefixes, ) if copied > 0: print(f"[Phase 3] Transferred {copied} compatible policy tensors from {warm_start_from}") else: print(f"[Phase 3] No compatible transfer tensors found from {warm_start_from}") print(f"\n[Phase 3] Recurrent PPO | timesteps={total_timesteps} | n_envs={n_envs}") model.learn( total_timesteps=total_timesteps, callback=[eval_cb, CostMonitorCallback()], tb_log_name="recurrent_ppo", progress_bar=True, ) model.save("results/best_model/phase3_final") print("[Phase 3] Done -> results/best_model/phase3_final") return model def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--timesteps", type=int, default=600_000) parser.add_argument("--n-envs", "--n_envs", dest="n_envs", type=int, default=4) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--config", default="rl/configs/recurrent.yaml") parser.add_argument( "--task", default=None, choices=["district_backlog_easy", "mixed_urgency_medium", "cross_department_hard"], help="Compatibility alias for evaluation task used by recurrent eval callback.", ) args = parser.parse_args() train_phase3( total_timesteps=args.timesteps, n_envs=args.n_envs, seed=args.seed, config_path=args.config, eval_task_id_override=args.task, ) if __name__ == "__main__": main()