"""
Phase 3: Recurrent PPO (LSTM policy) training.
This trainer keeps the existing 28-action design and uses curriculum sampling
across tasks (easy -> medium -> hard). Because current sb3-contrib releases do
not provide MaskableRecurrentPPO, we enforce action masks in two places:
1) hard mask in GovWorkflowGymEnv before executing an action,
2) recurrent evaluation callback with masked action sanitization.
Usage:
python -m rl.train_recurrent --timesteps 600000
python -m rl.train_recurrent --task cross_department_hard --n_envs 4
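
Expected layout of rl/configs/recurrent.yaml (a sketch only; every key is
optional and falls back to the default shown in the code below):

    hyperparameters:
      learning_rate: 1.0e-4
      n_steps: 512
      batch_size: 128
      n_epochs: 10
      gamma: 0.995
      gae_lambda: 0.95
      clip_range: 0.2
      ent_coef: 0.002
      vf_coef: 0.5
      max_grad_norm: 0.5
      net_arch: [256, 256]
      lstm_hidden_size: 128
      n_lstm_layers: 1
      shared_lstm: false
      enable_critic_lstm: true
    curriculum:
      stage1_end_frac: 0.20
      stage2_end_frac: 0.55
      stage3_weights: [0.15, 0.35, 0.50]
    training:
      hard_action_mask_train: true
      hard_action_mask_eval: true
      eval_task_id: mixed_urgency_medium
      eval_freq: 1024            # code default: max(4096 // n_envs, 1)
      n_eval_episodes: 3
      warm_start_from: results/best_model/phase2_final
      transfer_flat_weights: true
      transfer_exclude_prefixes: ["action_net.", "value_net."]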
"""
from __future__ import annotations

import argparse
import os

import yaml
from sb3_contrib import MaskablePPO, RecurrentPPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv

from rl.callbacks import CostMonitorCallback, RecurrentEvalCallback
from rl.curriculum import CurriculumConfig, CurriculumScheduler
from rl.gov_workflow_env import GovWorkflowGymEnv
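
# Output directories for TensorBoard runs, checkpoints, and evaluation logs.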
os.makedirs("results/runs", exist_ok=True)
os.makedirs("results/best_model", exist_ok=True)
os.makedirs("results/eval_logs", exist_ok=True)


def _load_cfg(path: str) -> dict:
    if os.path.exists(path):
        # "utf-8-sig" tolerates a UTF-8 BOM that some Windows editors prepend.
        with open(path, encoding="utf-8-sig") as f:
            return yaml.safe_load(f) or {}  # safe_load returns None for an empty file
    return {}


def _transfer_matching_policy_weights(
    recurrent_model: RecurrentPPO,
    flat_model_path: str,
    exclude_prefixes: tuple[str, ...] = (),
) -> int:
    """
    Transfer compatible policy weights from a flat MaskablePPO checkpoint.

    Returns the number of copied tensors.
    """
    src_path = flat_model_path
    if not src_path.endswith(".zip"):
        src_path = f"{src_path}.zip"
    if not os.path.exists(src_path):
        return 0
    try:
        flat_model = MaskablePPO.load(src_path)
    except Exception as exc:
        print(
            "[Phase 3] Skipping flat-weight transfer, "
            f"could not load MaskablePPO from {src_path}: {exc}"
        )
        return 0
    src_state = flat_model.policy.state_dict()
    dst_state = recurrent_model.policy.state_dict()
    copied = 0
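    # Copy only tensors whose names and shapes match exactly; LSTM weights have
    # no counterpart in the flat checkpoint and are skipped automatically.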
    for key, dst_tensor in dst_state.items():
        if any(key.startswith(prefix) for prefix in exclude_prefixes):
            continue
        src_tensor = src_state.get(key)
        if src_tensor is None:
            continue
        if tuple(src_tensor.shape) != tuple(dst_tensor.shape):
            continue
        dst_state[key] = src_tensor
        copied += 1
    recurrent_model.policy.load_state_dict(dst_state, strict=False)
    return copied


def train_phase3(
    total_timesteps: int = 600_000,
    n_envs: int = 4,
    seed: int = 42,
    config_path: str = "rl/configs/recurrent.yaml",
    eval_task_id_override: str | None = None,
) -> RecurrentPPO:
    cfg = _load_cfg(config_path)
    hp = cfg.get("hyperparameters", {})
    cur_c = cfg.get("curriculum", {})
    tr_c = cfg.get("training", {})

    scheduler = CurriculumScheduler(
        total_timesteps=total_timesteps,
        config=CurriculumConfig(
            stage1_end_frac=float(cur_c.get("stage1_end_frac", 0.20)),
            stage2_end_frac=float(cur_c.get("stage2_end_frac", 0.55)),
            stage3_weights=tuple(cur_c.get("stage3_weights", [0.15, 0.35, 0.50])),
        ),
        rng_seed=seed,
    )
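
    # A one-element list serves as a mutable step counter shared by reference:
    # every env's task sampler and the curriculum scheduler read the same value.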
    global_step_counter = [0]
    hard_action_mask_train = bool(tr_c.get("hard_action_mask_train", True))
    hard_action_mask_eval = bool(tr_c.get("hard_action_mask_eval", True))

    def _sample_task() -> str:
        return scheduler.sample_task(global_step_counter[0])

    def _make_curr(rank: int):
        def _init():
            env = GovWorkflowGymEnv(
                task_id="district_backlog_easy",
                seed=seed + rank,
                hard_action_mask=hard_action_mask_train,
            )
            env.set_task_sampler(_sample_task, global_step_counter)
            return Monitor(env)

        return _init
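
    # DummyVecEnv runs all envs in-process, so they keep sharing the same
    # global_step_counter list; a subprocess vec env would pickle copies of it.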
    train_env = DummyVecEnv([_make_curr(i) for i in range(n_envs)])

    eval_task_id = str(eval_task_id_override or tr_c.get("eval_task_id", "mixed_urgency_medium"))
    eval_env = GovWorkflowGymEnv(eval_task_id, seed=seed + 999, hard_action_mask=hard_action_mask_eval)
    eval_cb = RecurrentEvalCallback(
        eval_env=eval_env,
        eval_freq=int(tr_c.get("eval_freq", max(4096 // n_envs, 1))),
        n_eval_episodes=int(tr_c.get("n_eval_episodes", 3)),
        best_model_save_path="results/best_model",
        log_path="results/eval_logs",
        task_id=eval_task_id,
        verbose=1,
    )
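
    # RecurrentPPO itself has no action masking; invalid actions are blocked by
    # the env-level hard mask and sanitized again inside RecurrentEvalCallback.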
    model = RecurrentPPO(
        policy="MlpLstmPolicy",
        env=train_env,
        learning_rate=float(hp.get("learning_rate", 1e-4)),
        n_steps=int(hp.get("n_steps", 512)),
        batch_size=int(hp.get("batch_size", 128)),
        n_epochs=int(hp.get("n_epochs", 10)),
        gamma=float(hp.get("gamma", 0.995)),
        gae_lambda=float(hp.get("gae_lambda", 0.95)),
        clip_range=float(hp.get("clip_range", 0.2)),
        ent_coef=float(hp.get("ent_coef", 0.002)),
        vf_coef=float(hp.get("vf_coef", 0.5)),
        max_grad_norm=float(hp.get("max_grad_norm", 0.5)),
        policy_kwargs=dict(
            net_arch=hp.get("net_arch", [256, 256]),
            lstm_hidden_size=int(hp.get("lstm_hidden_size", 128)),
            n_lstm_layers=int(hp.get("n_lstm_layers", 1)),
            shared_lstm=bool(hp.get("shared_lstm", False)),
            enable_critic_lstm=bool(hp.get("enable_critic_lstm", True)),
        ),
        tensorboard_log="results/runs/phase3_recurrent_ppo",
        verbose=1,
        seed=seed,
    )
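
    # Optional warm start: reuse compatible MLP-extractor weights from the flat
    # Phase 2 MaskablePPO checkpoint. The action/value heads are excluded by
    # default and the LSTM weights have no flat counterpart, so both start fresh.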
    warm_start_from = str(tr_c.get("warm_start_from", "results/best_model/phase2_final"))
    transfer_flat = bool(tr_c.get("transfer_flat_weights", True))
    transfer_exclude_prefixes = tuple(
        tr_c.get("transfer_exclude_prefixes", ["action_net.", "value_net."])
    )
    if transfer_flat:
        copied = _transfer_matching_policy_weights(
            model,
            warm_start_from,
            exclude_prefixes=transfer_exclude_prefixes,
        )
        if copied > 0:
            print(f"[Phase 3] Transferred {copied} compatible policy tensors from {warm_start_from}")
        else:
            print(f"[Phase 3] No compatible transfer tensors found from {warm_start_from}")
print(f"\n[Phase 3] Recurrent PPO | timesteps={total_timesteps} | n_envs={n_envs}")
model.learn(
total_timesteps=total_timesteps,
callback=[eval_cb, CostMonitorCallback()],
tb_log_name="recurrent_ppo",
progress_bar=True,
)
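    # The best-on-eval checkpoint is written separately by RecurrentEvalCallback
    # into results/best_model; this save captures the final (last-step) policy.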
model.save("results/best_model/phase3_final")
print("[Phase 3] Done -> results/best_model/phase3_final")
return model


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--timesteps", type=int, default=600_000)
    parser.add_argument("--n-envs", "--n_envs", dest="n_envs", type=int, default=4)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--config", default="rl/configs/recurrent.yaml")
    parser.add_argument(
        "--task",
        default=None,
        choices=["district_backlog_easy", "mixed_urgency_medium", "cross_department_hard"],
        help="Compatibility alias: overrides the evaluation task used by the recurrent eval callback.",
    )
    args = parser.parse_args()
    train_phase3(
        total_timesteps=args.timesteps,
        n_envs=args.n_envs,
        seed=args.seed,
        config_path=args.config,
        eval_task_id_override=args.task,
    )


if __name__ == "__main__":
    main()