| """ | |
| Pre-train checklist + GO/NO-GO gate for Gov Workflow RL Phase 1. | |
| This script validates the local training stack without running training. | |
| Use it before starting Phase 1 retraining. | |
| Usage: | |
| python scripts/pretrain_go_nogo.py | |
| python scripts/pretrain_go_nogo.py --run-tests | |
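    python scripts/pretrain_go_nogo.py --run-tests --json-out <report.json>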
| """ | |
from __future__ import annotations

import argparse
import importlib
import json
import subprocess
import sys
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Callable
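# Make the repository root importable so the rl.* and app.* packages resolve
# when this script is run directly from scripts/.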
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
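# Phase 1 contract expected from the gym wrapper: task id, observation vector
# width, and discrete action count.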
PHASE1_TASK = "district_backlog_easy"
EXPECTED_OBS_DIM = 84
EXPECTED_ACTIONS = 28
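# One checklist outcome; instances are built with keyword arguments and
# serialized via dataclasses.asdict for the optional JSON report.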
@dataclass
class CheckResult:
    name: str
    status: str  # PASS | WARN | FAIL
    detail: str
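# Run a command (from the repo root unless cwd is given), capturing output;
# callers inspect the return code rather than relying on exceptions.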
def _run_cmd(cmd: list[str], cwd: Path | None = None) -> tuple[int, str, str]:
    proc = subprocess.run(
        cmd,
        cwd=str(cwd or ROOT),
        capture_output=True,
        text=True,
    )
    return proc.returncode, proc.stdout, proc.stderr
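# File-presence gate: every script and config the Phase 1 run depends on must
# exist before training starts.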
def check_required_files() -> CheckResult:
    required = [
        "rl/train_ppo.py",
        "rl/train_recurrent.py",
        "rl/gov_workflow_env.py",
        "rl/feature_builder.py",
        "rl/action_mask.py",
        "rl/callbacks.py",
        "rl/curriculum.py",
        "rl/cost_tracker.py",
        "rl/evaluate.py",
        "rl/eval_grader.py",
        "rl/plot_training.py",
        "rl/configs/ppo_easy.yaml",
        "app/env.py",
        "app/models.py",
        "app/tasks.py",
        "app/reward.py",
        "app/graders.py",
    ]
    missing = [p for p in required if not (ROOT / p).exists()]
    if missing:
        return CheckResult(
            name="required_files",
            status="FAIL",
            detail="Missing files: " + ", ".join(missing),
        )
    return CheckResult(
        name="required_files",
        status="PASS",
        detail=f"{len(required)} required files present",
    )
def check_python_imports() -> CheckResult:
    modules = [
        "yaml",
        "numpy",
        "gymnasium",
        "torch",
        "stable_baselines3",
        "sb3_contrib",
        "tensorboard",
        "rl.train_ppo",
        "rl.train_recurrent",
        "rl.gov_workflow_env",
        "rl.feature_builder",
        "rl.action_mask",
        "rl.callbacks",
        "rl.evaluate",
        "rl.eval_grader",
        "app.env",
        "app.tasks",
        "app.graders",
    ]
    failed: list[str] = []
    for mod in modules:
        try:
            importlib.import_module(mod)
        except Exception:
            failed.append(mod)
    if failed:
        return CheckResult(
            name="python_imports",
            status="FAIL",
            detail="Import failures: " + ", ".join(failed),
        )
    return CheckResult(
        name="python_imports",
        status="PASS",
        detail=f"{len(modules)} modules import cleanly",
    )
def check_compile() -> CheckResult:
    targets = [
        "rl/train_ppo.py",
        "rl/train_recurrent.py",
        "rl/gov_workflow_env.py",
        "rl/feature_builder.py",
        "rl/action_mask.py",
        "rl/callbacks.py",
        "rl/evaluate.py",
        "rl/eval_grader.py",
        "app/env.py",
        "app/reward.py",
        "app/graders.py",
        "app/tasks.py",
    ]
    cmd = [sys.executable, "-m", "py_compile", *targets]
    rc, _out, err = _run_cmd(cmd, ROOT)
    if rc != 0:
        return CheckResult(
            name="py_compile",
            status="FAIL",
            detail=err.strip() or "py_compile failed",
        )
    return CheckResult(
        name="py_compile",
        status="PASS",
        detail=f"{len(targets)} files compiled successfully",
    )
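# Instantiate the gym wrapper once and verify its observation/action contract
# against the Phase 1 constants defined above.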
def check_env_contract() -> CheckResult:
    try:
        from rl.gov_workflow_env import GovWorkflowGymEnv

        env = GovWorkflowGymEnv(task_id=PHASE1_TASK, seed=42)
        obs, info = env.reset(seed=42)
        masks = env.action_masks()
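        # Take a single step with a fixed action index; this assumes index 18
        # is a valid (unmasked) action for the Phase 1 task right after reset.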
        _obs2, reward, terminated, truncated, step_info = env.step(18)
        problems: list[str] = []
        if tuple(obs.shape) != (EXPECTED_OBS_DIM,):
            problems.append(f"obs shape={tuple(obs.shape)} expected={(EXPECTED_OBS_DIM,)}")
        if int(env.action_space.n) != EXPECTED_ACTIONS:
            problems.append(f"action_space={env.action_space.n} expected={EXPECTED_ACTIONS}")
        if len(masks) != EXPECTED_ACTIONS:
            problems.append(f"mask_len={len(masks)} expected={EXPECTED_ACTIONS}")
        if int(sum(bool(x) for x in masks)) <= 0:
            problems.append("all actions masked")
        if not isinstance(info, dict):
            problems.append("reset info is not dict")
        if not isinstance(step_info, dict):
            problems.append("step info is not dict")
        try:
            float(reward)
        except (TypeError, ValueError):
            problems.append("reward not float-castable")
        if not isinstance(bool(terminated), bool) or not isinstance(bool(truncated), bool):
            problems.append("terminated/truncated invalid type")
        if problems:
            return CheckResult(
                name="gym_env_contract",
                status="FAIL",
                detail="; ".join(problems),
            )
        return CheckResult(
            name="gym_env_contract",
            status="PASS",
            detail=f"obs={obs.shape}, action_n={env.action_space.n}, valid_masks={int(sum(masks))}",
        )
    except Exception as exc:
        return CheckResult(
            name="gym_env_contract",
            status="FAIL",
            detail=f"{type(exc).__name__}: {exc}",
        )
def check_output_paths() -> CheckResult:
    needed_dirs = [
        ROOT / "results",
        ROOT / "results" / "best_model",
        ROOT / "results" / "runs",
        ROOT / "results" / "eval_logs",
        ROOT / "logs",
    ]
    try:
        for d in needed_dirs:
            d.mkdir(parents=True, exist_ok=True)
            probe = d / ".write_probe.tmp"
            probe.write_text("ok", encoding="utf-8")
            probe.unlink(missing_ok=True)
    except Exception as exc:
        return CheckResult(
            name="output_paths",
            status="FAIL",
            detail=f"{type(exc).__name__}: {exc}",
        )
    return CheckResult(
        name="output_paths",
        status="PASS",
        detail="results/ and logs/ are writable",
    )
def check_train_cli() -> CheckResult:
    cmd = [sys.executable, "-m", "rl.train_ppo", "--help"]
    rc, out, err = _run_cmd(cmd, ROOT)
    if rc != 0:
        return CheckResult(
            name="train_cli",
            status="FAIL",
            detail=err.strip() or "train_ppo --help failed",
        )
    needed_flags = [
        "--phase",
        "--timesteps",
        "--n_envs",
        "--task",
        "--phase1-eval-freq",
        "--phase1-n-eval-episodes",
        "--phase1-disable-eval-callback",
        "--phase1-grader-eval-freq-multiplier",
    ]
    missing = [f for f in needed_flags if f not in out]
    if missing:
        return CheckResult(
            name="train_cli",
            status="WARN",
            detail="Missing expected flags in help output: " + ", ".join(missing),
        )
    return CheckResult(
        name="train_cli",
        status="PASS",
        detail="train_ppo CLI flags detected",
    )
def check_config() -> CheckResult:
    try:
        import yaml

        cfg_path = ROOT / "rl" / "configs" / "ppo_easy.yaml"
        cfg = yaml.safe_load(cfg_path.read_text(encoding="utf-8-sig")) or {}
        hp = cfg.get("hyperparameters", {})
        tr = cfg.get("training", {})
        required_fields = [
            ("hyperparameters", "learning_rate"),
            ("hyperparameters", "n_steps"),
            ("hyperparameters", "batch_size"),
            ("training", "n_envs"),
            ("training", "seed"),
            ("training", "eval_freq"),
            ("training", "n_eval_episodes"),
        ]
        missing = []
        for section, key in required_fields:
            parent = hp if section == "hyperparameters" else tr
            if key not in parent:
                missing.append(f"{section}.{key}")
        if missing:
            return CheckResult(
                name="ppo_easy_config",
                status="FAIL",
                detail="Missing config fields: " + ", ".join(missing),
            )
        warnings = []
        if int(tr.get("eval_freq", 0)) < 2048:
            warnings.append("eval_freq is very low; may cause frequent pauses")
        if int(tr.get("n_eval_episodes", 0)) > 5:
            warnings.append("n_eval_episodes is high; callback cost may increase")
        if warnings:
            return CheckResult(
                name="ppo_easy_config",
                status="WARN",
                detail="; ".join(warnings),
            )
        return CheckResult(
            name="ppo_easy_config",
            status="PASS",
            detail="Phase 1 config fields are present and reasonable",
        )
    except Exception as exc:
        return CheckResult(
            name="ppo_easy_config",
            status="FAIL",
            detail=f"{type(exc).__name__}: {exc}",
        )
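# Device availability is informational: no CUDA downgrades to a WARN rather
# than blocking the gate, since CPU training is expected for Phase 1.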
def check_torch_device() -> CheckResult:
    try:
        import torch

        if torch.cuda.is_available():
            return CheckResult(
                name="torch_device",
                status="PASS",
                detail=f"CUDA available ({torch.cuda.get_device_name(0)})",
            )
        return CheckResult(
            name="torch_device",
            status="WARN",
            detail="CUDA not available; CPU training is expected",
        )
    except Exception as exc:
        return CheckResult(
            name="torch_device",
            status="WARN",
            detail=f"torch device check skipped: {type(exc).__name__}: {exc}",
        )
def run_targeted_tests() -> CheckResult:
    test_cmd = [
        sys.executable,
        "-m",
        "pytest",
        "tests/test_env.py",
        "tests/test_gym_wrapper.py",
        "tests/test_gym_wrapper_integration.py",
        "tests/test_feature_builder.py",
        "tests/test_action_mask.py",
        "tests/test_curriculum.py",
        "tests/test_rl_evaluate.py",
        "-q",
        "--tb=short",
    ]
    rc, out, err = _run_cmd(test_cmd, ROOT)
    if rc != 0:
        return CheckResult(
            name="targeted_tests",
            status="FAIL",
            detail=(out + "\n" + err).strip()[-1200:],
        )
    return CheckResult(
        name="targeted_tests",
        status="PASS",
        detail=out.strip().splitlines()[-1] if out.strip() else "targeted tests passed",
    )
def _print_results(results: list[CheckResult]) -> None:
    print("\n=== Pre-Train Checklist Results ===")
    for r in results:
        print(f"[{r.status}] {r.name}: {r.detail}")
    fail_count = sum(1 for r in results if r.status == "FAIL")
    warn_count = sum(1 for r in results if r.status == "WARN")
    print("\n=== Gate Decision ===")
    if fail_count > 0:
        print(f"NO-GO: {fail_count} failing check(s). Resolve failures before training.")
    else:
        print(
            "GO: no failing checks. "
            f"{warn_count} warning(s) can be reviewed but do not block training."
        )
def _print_next_commands(args: argparse.Namespace) -> None:
    print("\n=== Recommended Phase 1 Commands (Manual) ===")
    print(
        "python -m rl.train_ppo "
        f"--phase 1 --task {PHASE1_TASK} "
        f"--timesteps {args.timesteps} --n_envs {args.n_envs} --seed {args.seed} "
        "--phase1-no-progress-bar "
        "--phase1-eval-freq 16384 "
        "--phase1-n-eval-episodes 2 "
        "--phase1-grader-eval-freq-multiplier 4"
    )
    print(
        "python rl/eval_grader.py "
        "--model results/best_model/phase1_final "
        f"--task {PHASE1_TASK} --episodes 20 --seed {args.seed}"
    )
    print(
        "python rl/plot_training.py "
        f"--task {PHASE1_TASK} --phase 1"
    )
def main() -> int:
    parser = argparse.ArgumentParser(description="Pre-train checklist + GO/NO-GO gate")
    parser.add_argument("--run-tests", action="store_true", help="Run targeted RL tests")
    parser.add_argument("--timesteps", type=int, default=300000)
    parser.add_argument("--n-envs", "--n_envs", dest="n_envs", type=int, default=4)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--json-out", default=None, help="Optional path to write JSON report")
    args = parser.parse_args()
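    # The optional pytest suite is appended only when --run-tests is passed;
    # every other check always runs.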
    checks: list[Callable[[], CheckResult]] = [
        check_required_files,
        check_python_imports,
        check_compile,
        check_train_cli,
        check_config,
        check_env_contract,
        check_output_paths,
        check_torch_device,
    ]
    if args.run_tests:
        checks.append(run_targeted_tests)
    results = [fn() for fn in checks]
    _print_results(results)
    _print_next_commands(args)
    if args.json_out:
        out_path = Path(args.json_out)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "go_no_go": "NO-GO" if any(r.status == "FAIL" for r in results) else "GO",
            "results": [asdict(r) for r in results],
        }
        out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        print(f"\nJSON report written to: {out_path}")
    return 2 if any(r.status == "FAIL" for r in results) else 0
if __name__ == "__main__":
    raise SystemExit(main())