"""
Pre-train checklist + GO/NO-GO gate for Gov Workflow RL Phase 1.

This script validates the local training stack (required files, imports,
env contract, config, CLI, output paths) without running any training.
Run it before starting Phase 1 retraining.

Usage:
    python scripts/pretrain_go_nogo.py
    python scripts/pretrain_go_nogo.py --run-tests
    python scripts/pretrain_go_nogo.py --json-out results/go_nogo_report.json

Exit code: 0 = GO, 2 = NO-GO (at least one failing check).
"""
from __future__ import annotations
import argparse
import importlib
import json
import subprocess
import sys
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Callable
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
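# Phase 1 contract used by the checks below: the Gym wrapper for the
# district_backlog_easy task is expected to expose an 84-dim observation
# vector and a 28-action discrete space.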
PHASE1_TASK = "district_backlog_easy"
EXPECTED_OBS_DIM = 84
EXPECTED_ACTIONS = 28
@dataclass
class CheckResult:
name: str
status: str # PASS | WARN | FAIL
detail: str
def _run_cmd(cmd: list[str], cwd: Path | None = None) -> tuple[int, str, str]:
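    """Run a command from the repo root and return (returncode, stdout, stderr)."""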
proc = subprocess.run(
cmd,
cwd=str(cwd or ROOT),
capture_output=True,
text=True,
)
return proc.returncode, proc.stdout, proc.stderr
def check_required_files() -> CheckResult:
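    """Verify that every training, config, and app file Phase 1 depends on exists on disk."""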
required = [
"rl/train_ppo.py",
"rl/train_recurrent.py",
"rl/gov_workflow_env.py",
"rl/feature_builder.py",
"rl/action_mask.py",
"rl/callbacks.py",
"rl/curriculum.py",
"rl/cost_tracker.py",
"rl/evaluate.py",
"rl/eval_grader.py",
"rl/plot_training.py",
"rl/configs/ppo_easy.yaml",
"app/env.py",
"app/models.py",
"app/tasks.py",
"app/reward.py",
"app/graders.py",
]
missing = [p for p in required if not (ROOT / p).exists()]
if missing:
return CheckResult(
name="required_files",
status="FAIL",
detail="Missing files: " + ", ".join(missing),
)
return CheckResult(
name="required_files",
status="PASS",
detail=f"{len(required)} required files present",
)
def check_python_imports() -> CheckResult:
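    """Import third-party dependencies and project modules to surface missing packages or import errors."""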
modules = [
"yaml",
"numpy",
"gymnasium",
"torch",
"stable_baselines3",
"sb3_contrib",
"tensorboard",
"rl.train_ppo",
"rl.train_recurrent",
"rl.gov_workflow_env",
"rl.feature_builder",
"rl.action_mask",
"rl.callbacks",
"rl.evaluate",
"rl.eval_grader",
"app.env",
"app.tasks",
"app.graders",
]
failed: list[str] = []
for mod in modules:
try:
importlib.import_module(mod)
except Exception:
failed.append(mod)
if failed:
return CheckResult(
name="python_imports",
status="FAIL",
detail="Import failures: " + ", ".join(failed),
)
return CheckResult(
name="python_imports",
status="PASS",
detail=f"{len(modules)} modules import cleanly",
)
def check_compile() -> CheckResult:
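    """Byte-compile the core training and evaluation modules to catch syntax errors."""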
targets = [
"rl/train_ppo.py",
"rl/train_recurrent.py",
"rl/gov_workflow_env.py",
"rl/feature_builder.py",
"rl/action_mask.py",
"rl/callbacks.py",
"rl/evaluate.py",
"rl/eval_grader.py",
"app/env.py",
"app/reward.py",
"app/graders.py",
"app/tasks.py",
]
cmd = [sys.executable, "-m", "py_compile", *targets]
rc, _out, err = _run_cmd(cmd, ROOT)
if rc != 0:
return CheckResult(
name="py_compile",
status="FAIL",
detail=err.strip() or "py_compile failed",
)
return CheckResult(
name="py_compile",
status="PASS",
detail=f"{len(targets)} files compiled successfully",
)
def check_env_contract() -> CheckResult:
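    """Instantiate the Gym env and validate the Phase 1 observation, action-space, and mask contract."""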
try:
from rl.gov_workflow_env import GovWorkflowGymEnv
env = GovWorkflowGymEnv(task_id=PHASE1_TASK, seed=42)
obs, info = env.reset(seed=42)
        masks = env.action_masks()
        valid_actions = [i for i, m in enumerate(masks) if m]
        # Step the first mask-valid action instead of a hardcoded index.
        _obs2, reward, terminated, truncated, step_info = env.step(valid_actions[0] if valid_actions else 0)
problems: list[str] = []
if tuple(obs.shape) != (EXPECTED_OBS_DIM,):
problems.append(f"obs shape={tuple(obs.shape)} expected={(EXPECTED_OBS_DIM,)}")
if int(env.action_space.n) != EXPECTED_ACTIONS:
problems.append(f"action_space={env.action_space.n} expected={EXPECTED_ACTIONS}")
if len(masks) != EXPECTED_ACTIONS:
problems.append(f"mask_len={len(masks)} expected={EXPECTED_ACTIONS}")
        if not any(bool(x) for x in masks):
problems.append("all actions masked")
if not isinstance(info, dict):
problems.append("reset info is not dict")
if not isinstance(step_info, dict):
problems.append("step info is not dict")
        try:
            float(reward)
        except (TypeError, ValueError):
            problems.append("reward not float-castable")
        try:
            bool(terminated), bool(truncated)
        except (TypeError, ValueError):
            problems.append("terminated/truncated not bool-castable")
if problems:
return CheckResult(
name="gym_env_contract",
status="FAIL",
detail="; ".join(problems),
)
return CheckResult(
name="gym_env_contract",
status="PASS",
detail=f"obs={obs.shape}, action_n={env.action_space.n}, valid_masks={int(sum(masks))}",
)
except Exception as exc:
return CheckResult(
name="gym_env_contract",
status="FAIL",
detail=f"{type(exc).__name__}: {exc}",
)
def check_output_paths() -> CheckResult:
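    """Create the expected results/ and logs/ directories and verify they are writable."""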
needed_dirs = [
ROOT / "results",
ROOT / "results" / "best_model",
ROOT / "results" / "runs",
ROOT / "results" / "eval_logs",
ROOT / "logs",
]
try:
for d in needed_dirs:
d.mkdir(parents=True, exist_ok=True)
probe = d / ".write_probe.tmp"
probe.write_text("ok", encoding="utf-8")
probe.unlink(missing_ok=True)
except Exception as exc:
return CheckResult(
name="output_paths",
status="FAIL",
detail=f"{type(exc).__name__}: {exc}",
)
return CheckResult(
name="output_paths",
status="PASS",
detail="results/ and logs/ are writable",
)
def check_train_cli() -> CheckResult:
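    """Run `rl.train_ppo --help` and confirm the Phase 1 CLI flags appear in the help output."""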
cmd = [sys.executable, "-m", "rl.train_ppo", "--help"]
rc, out, err = _run_cmd(cmd, ROOT)
if rc != 0:
return CheckResult(
name="train_cli",
status="FAIL",
detail=err.strip() or "train_ppo --help failed",
)
needed_flags = [
"--phase",
"--timesteps",
"--n_envs",
"--task",
"--phase1-eval-freq",
"--phase1-n-eval-episodes",
"--phase1-disable-eval-callback",
"--phase1-grader-eval-freq-multiplier",
]
missing = [f for f in needed_flags if f not in out]
if missing:
return CheckResult(
name="train_cli",
status="WARN",
detail="Missing expected flags in help output: " + ", ".join(missing),
)
return CheckResult(
name="train_cli",
status="PASS",
detail="train_ppo CLI flags detected",
)
def check_config() -> CheckResult:
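    """Validate that rl/configs/ppo_easy.yaml contains the hyperparameter and training fields Phase 1 reads."""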
try:
import yaml
cfg_path = ROOT / "rl" / "configs" / "ppo_easy.yaml"
cfg = yaml.safe_load(cfg_path.read_text(encoding="utf-8-sig")) or {}
hp = cfg.get("hyperparameters", {})
tr = cfg.get("training", {})
required_fields = [
("hyperparameters", "learning_rate"),
("hyperparameters", "n_steps"),
("hyperparameters", "batch_size"),
("training", "n_envs"),
("training", "seed"),
("training", "eval_freq"),
("training", "n_eval_episodes"),
]
missing = []
for section, key in required_fields:
parent = hp if section == "hyperparameters" else tr
if key not in parent:
missing.append(f"{section}.{key}")
if missing:
return CheckResult(
name="ppo_easy_config",
status="FAIL",
detail="Missing config fields: " + ", ".join(missing),
)
warnings = []
if int(tr.get("eval_freq", 0)) < 2048:
warnings.append("eval_freq is very low; may cause frequent pauses")
if int(tr.get("n_eval_episodes", 0)) > 5:
warnings.append("n_eval_episodes is high; callback cost may increase")
if warnings:
return CheckResult(
name="ppo_easy_config",
status="WARN",
detail="; ".join(warnings),
)
return CheckResult(
name="ppo_easy_config",
status="PASS",
detail="Phase 1 config fields are present and reasonable",
)
except Exception as exc:
return CheckResult(
name="ppo_easy_config",
status="FAIL",
detail=f"{type(exc).__name__}: {exc}",
)
def check_torch_device() -> CheckResult:
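    """Report whether CUDA is available; CPU-only training is a warning, not a blocker."""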
try:
import torch
if torch.cuda.is_available():
return CheckResult(
name="torch_device",
status="PASS",
detail=f"CUDA available ({torch.cuda.get_device_name(0)})",
)
return CheckResult(
name="torch_device",
status="WARN",
detail="CUDA not available; CPU training is expected",
)
except Exception as exc:
return CheckResult(
name="torch_device",
status="WARN",
detail=f"torch device check skipped: {type(exc).__name__}: {exc}",
)
def run_targeted_tests() -> CheckResult:
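    """Run the targeted RL test modules via pytest (enabled with --run-tests)."""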
test_cmd = [
sys.executable,
"-m",
"pytest",
"tests/test_env.py",
"tests/test_gym_wrapper.py",
"tests/test_gym_wrapper_integration.py",
"tests/test_feature_builder.py",
"tests/test_action_mask.py",
"tests/test_curriculum.py",
"tests/test_rl_evaluate.py",
"-q",
"--tb=short",
]
rc, out, err = _run_cmd(test_cmd, ROOT)
if rc != 0:
return CheckResult(
name="targeted_tests",
status="FAIL",
detail=(out + "\n" + err).strip()[-1200:],
)
return CheckResult(
name="targeted_tests",
status="PASS",
detail=out.strip().splitlines()[-1] if out.strip() else "targeted tests passed",
)
def _print_results(results: list[CheckResult]) -> None:
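    """Print per-check results followed by the overall GO/NO-GO gate decision."""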
print("\n=== Pre-Train Checklist Results ===")
for r in results:
print(f"[{r.status}] {r.name}: {r.detail}")
fail_count = sum(1 for r in results if r.status == "FAIL")
warn_count = sum(1 for r in results if r.status == "WARN")
print("\n=== Gate Decision ===")
if fail_count > 0:
print(f"NO-GO: {fail_count} failing check(s). Resolve failures before training.")
else:
print(
f"GO: no failing checks. "
f"{warn_count} warning(s) can be reviewed but do not block training."
)
def _print_next_commands(args: argparse.Namespace) -> None:
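    """Print the recommended Phase 1 training, grader evaluation, and plotting commands."""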
print("\n=== Recommended Phase 1 Commands (Manual) ===")
print(
"python -m rl.train_ppo "
f"--phase 1 --task {PHASE1_TASK} "
f"--timesteps {args.timesteps} --n_envs {args.n_envs} --seed {args.seed} "
"--phase1-no-progress-bar "
"--phase1-eval-freq 16384 "
"--phase1-n-eval-episodes 2 "
"--phase1-grader-eval-freq-multiplier 4"
)
print(
"python rl/eval_grader.py "
"--model results/best_model/phase1_final "
f"--task {PHASE1_TASK} --episodes 20 --seed {args.seed}"
)
print(
"python rl/plot_training.py "
f"--task {PHASE1_TASK} --phase 1"
)
def main() -> int:
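    """Run all checks, print the report, optionally write a JSON summary, and return the exit code."""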
parser = argparse.ArgumentParser(description="Pre-train checklist + GO/NO-GO gate")
parser.add_argument("--run-tests", action="store_true", help="Run targeted RL tests")
parser.add_argument("--timesteps", type=int, default=300000)
parser.add_argument("--n-envs", "--n_envs", dest="n_envs", type=int, default=4)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--json-out", default=None, help="Optional path to write JSON report")
args = parser.parse_args()
checks: list[Callable[[], CheckResult]] = [
check_required_files,
check_python_imports,
check_compile,
check_train_cli,
check_config,
check_env_contract,
check_output_paths,
check_torch_device,
]
if args.run_tests:
checks.append(run_targeted_tests)
results = [fn() for fn in checks]
_print_results(results)
_print_next_commands(args)
if args.json_out:
out_path = Path(args.json_out)
out_path.parent.mkdir(parents=True, exist_ok=True)
payload = {
"go_no_go": "NO-GO" if any(r.status == "FAIL" for r in results) else "GO",
"results": [asdict(r) for r in results],
}
out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
print(f"\nJSON report written to: {out_path}")
return 2 if any(r.status == "FAIL" for r in results) else 0
if __name__ == "__main__":
raise SystemExit(main())