Gov_Workflow_RL / rl /train_ppo.py
Siddharaj Shirke
deploy: clean code-only snapshot for HF Space
df97e68
"""
Phase 1: Masked PPO on district_backlog_easy.
Phase 2: Curriculum Masked PPO across all 3 tasks.
Usage:
python -m rl.train_ppo --phase 1 --timesteps 200000
python -m rl.train_ppo --phase 2 --timesteps 500000
python -m rl.train_ppo --phase 1 --task district_backlog_easy --n_envs 4
"""
from __future__ import annotations
import argparse
import os
import yaml
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor
from sb3_contrib import MaskablePPO
from rl.gov_workflow_env import GovWorkflowGymEnv
from rl.callbacks import GovWorkflowEvalCallback, CostMonitorCallback
from rl.curriculum import CurriculumScheduler, CurriculumConfig
os.makedirs("results/runs", exist_ok=True)
os.makedirs("results/best_model", exist_ok=True)
os.makedirs("results/eval_logs", exist_ok=True)
PHASE1_TASK_ID = "district_backlog_easy"
def _load_cfg(path: str) -> dict:
if os.path.exists(path):
# `utf-8-sig` safely handles files with/without UTF-8 BOM.
with open(path, encoding="utf-8-sig") as f:
return yaml.safe_load(f)
return {}
def _resolve_checkpoint_path(path_like: str | None) -> str | None:
if not path_like:
return None
if os.path.exists(path_like):
return path_like
zip_path = f"{path_like}.zip"
if os.path.exists(zip_path):
return zip_path
return None
# ---------------------------------------------------------------------------
# Phase 1 — single task easy
# ---------------------------------------------------------------------------
def train_phase1(
total_timesteps: int = 200_000,
n_envs: int = 4,
seed: int = 42,
config_path: str = "rl/configs/ppo_easy.yaml",
eval_freq_override: int | None = None,
n_eval_episodes_override: int | None = None,
disable_eval_callback: bool = False,
no_progress_bar: bool = False,
grader_eval_freq_multiplier_override: int | None = None,
resume_path: str | None = None,
) -> MaskablePPO:
cfg = _load_cfg(config_path)
hp = cfg.get("hyperparameters", {})
tr_c = cfg.get("training", {})
def _make(rank: int):
def _init():
return Monitor(GovWorkflowGymEnv("district_backlog_easy", seed=seed + rank))
return _init
train_env = DummyVecEnv([_make(i) for i in range(n_envs)])
eval_freq = int(eval_freq_override if eval_freq_override is not None else tr_c.get("eval_freq", max(16_384 // n_envs, 1)))
n_eval_episodes = int(n_eval_episodes_override if n_eval_episodes_override is not None else tr_c.get("n_eval_episodes", 2))
eval_callback_enabled = bool(tr_c.get("enable_eval_callback", True)) and (not disable_eval_callback)
grader_eval_freq_multiplier = int(
grader_eval_freq_multiplier_override
if grader_eval_freq_multiplier_override is not None
else tr_c.get("grader_eval_freq_multiplier", 4)
)
callback_verbose = int(tr_c.get("callback_verbose", 0))
model_verbose = int(tr_c.get("model_verbose", 0))
progress_bar_enabled = (not no_progress_bar) and bool(tr_c.get("progress_bar", False))
callbacks = [CostMonitorCallback()]
if eval_callback_enabled:
eval_env = GovWorkflowGymEnv("district_backlog_easy", seed=seed + 1000, hard_action_mask=True)
eval_cb = GovWorkflowEvalCallback(
eval_env=eval_env,
eval_freq=max(eval_freq, 1),
n_eval_episodes=max(n_eval_episodes, 1),
grader_eval_freq_multiplier=max(grader_eval_freq_multiplier, 1),
best_model_save_path="results/best_model",
log_path="results/eval_logs",
task_id="district_backlog_easy",
verbose=callback_verbose,
)
callbacks.insert(0, eval_cb)
resolved_resume = _resolve_checkpoint_path(resume_path)
if resume_path and resolved_resume is None:
raise FileNotFoundError(
f"Phase 1 resume checkpoint not found: {resume_path} (or {resume_path}.zip)"
)
if resolved_resume:
print(f"[Phase 1] Resuming from {resolved_resume}")
model = MaskablePPO.load(resolved_resume, env=train_env)
else:
model = MaskablePPO(
policy="MlpPolicy",
env=train_env,
learning_rate=float(hp.get("learning_rate", 3e-4)),
n_steps=int(hp.get("n_steps", 512)),
batch_size=int(hp.get("batch_size", 64)),
n_epochs=int(hp.get("n_epochs", 10)),
gamma=float(hp.get("gamma", 0.99)),
gae_lambda=float(hp.get("gae_lambda", 0.95)),
clip_range=float(hp.get("clip_range", 0.2)),
ent_coef=float(hp.get("ent_coef", 0.01)),
vf_coef=float(hp.get("vf_coef", 0.5)),
max_grad_norm=float(hp.get("max_grad_norm", 0.5)),
policy_kwargs=dict(net_arch=hp.get("net_arch", [256, 256])),
tensorboard_log="results/runs/phase1_masked_ppo",
verbose=model_verbose,
seed=seed,
)
print(
f"\n[Phase 1] Masked PPO | timesteps={total_timesteps} | n_envs={n_envs} "
f"| eval_cb={'on' if eval_callback_enabled else 'off'} "
f"| eval_freq={max(eval_freq,1)} | n_eval_episodes={max(n_eval_episodes,1)} "
f"| grader_eval_x{max(grader_eval_freq_multiplier, 1)}"
)
model.learn(
total_timesteps=total_timesteps,
callback=callbacks,
tb_log_name="masked_ppo_easy",
reset_num_timesteps=not bool(resolved_resume),
progress_bar=progress_bar_enabled,
)
model.save("results/best_model/phase1_final")
print("[Phase 1] Done -> results/best_model/phase1_final")
return model
# ---------------------------------------------------------------------------
# Phase 2 — curriculum across all tasks
# ---------------------------------------------------------------------------
def train_phase2(
total_timesteps: int = 500_000,
n_envs: int = 4,
seed: int = 42,
config_path: str = "rl/configs/curriculum.yaml",
) -> MaskablePPO:
cfg = _load_cfg(config_path)
if not cfg and config_path.endswith("curriculum.yaml"):
# Backward compatibility with previous filename.
cfg = _load_cfg("rl/configs/ppo_curriculum.yaml")
hp = cfg.get("hyperparameters", {})
cur_c = cfg.get("curriculum", {})
tr_c = cfg.get("training", {})
scheduler = CurriculumScheduler(
total_timesteps=total_timesteps,
config=CurriculumConfig(
stage1_end_frac=float(cur_c.get("stage1_end_frac", 0.30)),
stage2_end_frac=float(cur_c.get("stage2_end_frac", 0.70)),
stage3_weights=tuple(cur_c.get("stage3_weights", [0.20, 0.40, 0.40])),
),
rng_seed=seed,
)
global_step_counter = [0]
def _sample_task() -> str:
return scheduler.sample_task(global_step_counter[0])
def _make_curr(rank: int):
def _init():
env = GovWorkflowGymEnv(
task_id="district_backlog_easy",
seed=seed + rank,
)
env.set_task_sampler(_sample_task, global_step_counter)
return Monitor(env)
return _init
train_env = DummyVecEnv([_make_curr(i) for i in range(n_envs)])
eval_task_id = str(tr_c.get("eval_task_id", "mixed_urgency_medium"))
eval_env = GovWorkflowGymEnv(eval_task_id, seed=seed + 999, hard_action_mask=True)
eval_cb = GovWorkflowEvalCallback(
eval_env=eval_env,
eval_freq=int(tr_c.get("eval_freq", max(4096 // n_envs, 1))),
n_eval_episodes=int(tr_c.get("n_eval_episodes", 3)),
grader_eval_freq_multiplier=int(tr_c.get("grader_eval_freq_multiplier", 4)),
best_model_save_path="results/best_model",
log_path="results/eval_logs",
task_id=eval_task_id,
verbose=1,
)
warm_start_from = str(tr_c.get("warm_start_from", "results/best_model/phase1_final"))
warm_start_path = _resolve_checkpoint_path(warm_start_from)
if warm_start_path and os.path.exists(warm_start_path):
print(f"[Phase 2] Warm-starting from {warm_start_path}")
model = MaskablePPO.load(warm_start_path, env=train_env)
else:
model = MaskablePPO(
policy="MlpPolicy",
env=train_env,
learning_rate=float(hp.get("learning_rate", 2e-4)),
n_steps=int(hp.get("n_steps", 512)),
batch_size=int(hp.get("batch_size", 64)),
n_epochs=int(hp.get("n_epochs", 10)),
gamma=float(hp.get("gamma", 0.99)),
gae_lambda=float(hp.get("gae_lambda", 0.95)),
clip_range=float(hp.get("clip_range", 0.2)),
ent_coef=float(hp.get("ent_coef", 0.005)),
policy_kwargs=dict(net_arch=hp.get("net_arch", [256, 256])),
tensorboard_log="results/runs/phase2_curriculum_ppo",
verbose=1,
seed=seed,
)
print(f"\n[Phase 2] Curriculum PPO | timesteps={total_timesteps}")
model.learn(
total_timesteps=total_timesteps,
callback=[eval_cb, CostMonitorCallback()],
tb_log_name="curriculum_ppo",
progress_bar=True,
)
model.save("results/best_model/phase2_final")
print("[Phase 2] Done -> results/best_model/phase2_final")
return model
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--phase", type=int, choices=[1, 2], default=1)
parser.add_argument("--timesteps", type=int, default=200_000)
parser.add_argument("--n-envs", "--n_envs", dest="n_envs", type=int, default=4)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument(
"--task",
default=None,
help=(
"CLI compatibility alias. Phase 1 supports only "
f"'{PHASE1_TASK_ID}'. Phase 2 ignores this flag."
),
)
parser.add_argument(
"--phase1-config",
default="rl/configs/ppo_easy.yaml",
help="Config file for Phase 1 training.",
)
parser.add_argument(
"--phase1-eval-freq",
type=int,
default=None,
help="Override Phase 1 eval callback frequency (in calls).",
)
parser.add_argument(
"--phase1-n-eval-episodes",
type=int,
default=None,
help="Override Phase 1 eval callback episodes per eval.",
)
parser.add_argument(
"--phase1-disable-eval-callback",
action="store_true",
help="Disable Phase 1 evaluation callback to avoid pause-heavy eval blocks.",
)
parser.add_argument(
"--phase1-no-progress-bar",
action="store_true",
help="Disable tqdm progress bar rendering for Phase 1.",
)
parser.add_argument(
"--phase1-grader-eval-freq-multiplier",
type=int,
default=None,
help="Run grader eval every N * eval_freq callback ticks for Phase 1.",
)
parser.add_argument(
"--resume",
default=None,
help="Resume Phase 1 from checkpoint path (with or without .zip suffix).",
)
parser.add_argument(
"--phase2-config",
default="rl/configs/curriculum.yaml",
help="Config file for Phase 2 curriculum training.",
)
args = parser.parse_args()
if args.phase == 1 and args.task and args.task != PHASE1_TASK_ID:
raise ValueError(
f"Phase 1 currently supports only task '{PHASE1_TASK_ID}', got '{args.task}'."
)
if args.phase == 2 and args.task:
print(f"[Phase 2] Ignoring --task={args.task}; curriculum scheduler controls task sampling.")
if args.phase == 1:
train_phase1(
total_timesteps=args.timesteps,
n_envs=args.n_envs,
seed=args.seed,
config_path=args.phase1_config,
eval_freq_override=args.phase1_eval_freq,
n_eval_episodes_override=args.phase1_n_eval_episodes,
disable_eval_callback=args.phase1_disable_eval_callback,
no_progress_bar=args.phase1_no_progress_bar,
grader_eval_freq_multiplier_override=args.phase1_grader_eval_freq_multiplier,
resume_path=args.resume,
)
else:
train_phase2(
total_timesteps=args.timesteps,
n_envs=args.n_envs,
seed=args.seed,
config_path=args.phase2_config,
)
if __name__ == "__main__":
main()