#!/usr/bin/env python
"""Local preflight: validate every component the H200 training job touches
WITHOUT spending GPU time.

Each test prints PASS/FAIL with a short reason. Run::

    python scripts/preflight_check.py

The script exits non-zero if any required test fails. Optional tests
(network/Hub) print SKIP if HF_TOKEN is not set or the env Space is down.
"""
from __future__ import annotations

import json
import os
import sys
import tempfile
import traceback
from pathlib import Path
from typing import Callable

REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))

PASS = "[PASS]"
FAIL = "[FAIL]"
SKIP = "[SKIP]"

# One (tag, label, detail) tuple per executed check, in execution order.
_results: list[tuple[str, str, str]] = []


class _Skip(Exception):
    """Raised by a check to report SKIP instead of PASS/FAIL."""


def _run(label: str, fn: Callable[[], str | None], required: bool = True) -> None:
    """Execute one check, print its outcome and record it in ``_results``.

    Args:
        label: Human-readable name printed next to the result tag.
        fn: The check. Its return value (or ``""`` when falsy) is the detail
            text shown on PASS.
        required: When True an exception is reported as FAIL and the
            traceback is printed; when False it is reported as SKIP.

    A ``_Skip`` raised by *fn* is always reported as SKIP.
    """
    try:
        detail = fn() or ""
        _results.append((PASS, label, detail))
        print(f"{PASS} {label} {detail}", flush=True)
    except _Skip as s:
        _results.append((SKIP, label, str(s)))
        print(f"{SKIP} {label} {s}", flush=True)
    except Exception as e:  # noqa: BLE001
        # Boundary handler: a single broken check must not abort the others.
        tag = FAIL if required else SKIP
        _results.append((tag, label, f"{type(e).__name__}: {e}"))
        print(f"{tag} {label} {type(e).__name__}: {e}", flush=True)
        if required:
            traceback.print_exc()


def _with_overrides(kwargs: dict[str, object], **overrides: object) -> dict[str, object]:
    """Return a copy of *kwargs* with *overrides* applied (overrides win).

    The input mapping is left untouched.
    """
    return {**kwargs, **overrides}


def t1_imports() -> str:
    """Import every third-party and project module named below.

    Returns:
        The installed ``trl`` and ``transformers`` version strings.
    """
    import forgeenv  # noqa: F401
    import trl  # noqa: F401
    import peft  # noqa: F401
    import datasets  # noqa: F401
    import transformers  # noqa: F401
    import accelerate  # noqa: F401

    from forgeenv.training.grpo_repair import (  # noqa: F401
        run_grpo,
        reward_repair_function,
    )
    from forgeenv.training.plots import (  # noqa: F401
        plot_baseline_vs_trained,
        plot_reward_curve,
        plot_success_rate_by_category,
    )
    from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction  # noqa: F401
    from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff  # noqa: F401
    from forgeenv.env.forge_environment import ForgeEnvironment  # noqa: F401
    from forgeenv.roles.repair_agent import extract_diff  # noqa: F401
    from forgeenv.tasks.task_sampler import TaskSampler  # noqa: F401

    return f"trl={trl.__version__} transformers={transformers.__version__}"


def t1b_openenv_job_extras() -> str:
    """On HF Jobs we ``pip install openenv-core --no-deps`` then add the packages
    openenv lists as requirements so ``import openenv.core`` works."""
    import fastmcp  # noqa: F401

    return "fastmcp (required by openenv.core.env_server on import)"


def t2_dataset_load_and_format() -> str:
    """Load ``warmstart/data/repair_pairs.jsonl`` and sanity-check its first row.

    Returns:
        Row count and the sorted set of roles found in the first row.

    Raises:
        FileNotFoundError: The JSONL file does not exist.
        ValueError: Fewer than 10 rows, or a system/user/assistant role is
            missing from the first row's messages.
        KeyError: The first row has no (or an empty) ``messages`` field.
    """
    import datasets as ds

    p = REPO_ROOT / "warmstart" / "data" / "repair_pairs.jsonl"
    if not p.exists():
        raise FileNotFoundError(p)
    sft_ds = ds.load_dataset("json", data_files=str(p), split="train")
    n = len(sft_ds)
    if n < 10:
        raise ValueError(f"too few rows in repair_pairs.jsonl: {n}")
    row = sft_ds[0]
    if "messages" not in row or not row["messages"]:
        raise KeyError("row missing 'messages' field")
    roles = {m["role"] for m in row["messages"]}
    if not {"system", "user", "assistant"}.issubset(roles):
        raise ValueError(f"unexpected role set: {roles}")
    return f"rows={n} roles={sorted(roles)}"


def t3_trl_configs_accept_our_kwargs() -> str:
    """Validate every kwarg name the job passes is accepted by the current TRL
    Config classes.

    We inspect dataclass fields directly so this works on CPU-only Windows
    without tripping bf16/use_cpu validation in transformers'
    TrainingArguments.__post_init__.

    Returns:
        Field counts for both config classes and whether the best-effort
        instantiation succeeded.

    Raises:
        TypeError: A kwarg is not a dataclass field of SFTConfig / GRPOConfig.
    """
    import dataclasses

    from trl import GRPOConfig, SFTConfig

    sft_kwargs = {
        "output_dir": "/tmp/forge_sft",
        "max_steps": 10,
        "per_device_train_batch_size": 4,
        "gradient_accumulation_steps": 4,
        "learning_rate": 2e-4,
        "logging_steps": 25,
        "save_steps": 250,
        "bf16": True,
        "fp16": False,
        "max_length": 2048,
        "report_to": [],
    }
    grpo_kwargs = {
        "output_dir": "/tmp/forge_grpo",
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 4,
        "learning_rate": 5e-6,
        "max_steps": 5,
        "num_generations": 4,
        "max_completion_length": 1024,
        "logging_steps": 5,
        "save_steps": 50,
        "save_total_limit": 2,
        "seed": 0,
        "report_to": "none",
        "beta": 0.04,
    }

    def _field_names(cls) -> set[str]:
        """Collect dataclass field names from *cls* and its dataclass bases."""
        names: set[str] = set()
        for c in cls.__mro__:
            if dataclasses.is_dataclass(c):
                names.update(f.name for f in dataclasses.fields(c))
        return names

    sft_fields = _field_names(SFTConfig)
    missing_sft = [k for k in sft_kwargs if k not in sft_fields]
    if missing_sft:
        raise TypeError(f"SFTConfig missing fields: {missing_sft}")

    grpo_fields = _field_names(GRPOConfig)
    missing_grpo = [k for k in grpo_kwargs if k not in grpo_fields]
    if missing_grpo:
        raise TypeError(f"GRPOConfig missing fields: {missing_grpo}")

    # Best-effort: try actually instantiating with use_cpu=True so even
    # __post_init__ runs cleanly under our preflight conditions.
    #
    # The overrides are merged into a copy of the kwargs first. Writing
    # ``SFTConfig(**sft_kwargs, bf16=False)`` would pass ``bf16`` twice
    # (sft_kwargs already carries it) and raise TypeError before the config
    # class is even called, so the instantiation would never be exercised.
    try:
        SFTConfig(**_with_overrides(sft_kwargs, use_cpu=True, bf16=False))
        GRPOConfig(**_with_overrides(grpo_kwargs, use_cpu=True))
        instantiated = "instantiated OK"
    except Exception as e:  # noqa: BLE001
        instantiated = f"field-check OK; instantiation skipped ({type(e).__name__})"

    return (
        f"SFT/GRPO kwargs all valid; sft_fields={len(sft_fields)} "
        f"grpo_fields={len(grpo_fields)}; {instantiated}"
    )


def t4_reward_function_returns_float() -> str:
    """Call ``reward_repair_function`` on one fake diff and check the result.

    Returns:
        The single reward formatted to three decimals.

    Raises:
        RuntimeError: ``TaskSampler().tasks`` is empty.
        ValueError: The reward list does not have exactly one element.
        TypeError: The reward is not a ``float``.
    """
    from forgeenv.training.grpo_repair import reward_repair_function
    from forgeenv.tasks.task_sampler import TaskSampler

    sampler = TaskSampler()
    if not sampler.tasks:
        raise RuntimeError("TaskSampler has no tasks")
    task_id = sampler.tasks[0].task_id
    broken = "x = 1\nprint(x)\n"
    fake_completion = (
        "--- a/train.py\n"
        "+++ b/train.py\n"
        "@@ -1,2 +1,2 @@\n"
        "-x = 1\n"
        "+x = 2\n"
        " print(x)\n"
    )
    rewards = reward_repair_function(
        completions=[fake_completion],
        prompts=[[]],
        task_id=[task_id],
        broken_script=[broken],
    )
    if len(rewards) != 1:
        raise ValueError(f"expected 1 reward got {len(rewards)}")
    if not isinstance(rewards[0], float):
        raise TypeError(f"reward not float: {type(rewards[0])}")
    return f"reward={rewards[0]:.3f} (single fake completion)"


def t5_diff_utils_roundtrip() -> str:
    """Round-trip a tiny edit through make -> extract -> apply.

    A diff is generated, wrapped in prose plus a fenced ``diff`` block,
    extracted again and applied to the original text.

    Raises:
        ValueError: Any of the three stages yields an empty/incorrect result.
    """
    from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff
    from forgeenv.roles.repair_agent import extract_diff

    a = "x = 1\nprint(x)\n"
    b = "x = 2\nprint(x)\n"
    d = make_unified_diff(a, b)
    if not d.strip():
        raise ValueError("make_unified_diff returned empty")
    blob = "Some thinking...\n```diff\n" + d + "\n```\nmore prose"
    extracted = extract_diff(blob)
    if not extracted.strip():
        raise ValueError("extract_diff failed to find diff in fenced block")
    repaired = apply_unified_diff(a, extracted)
    if "x = 2" not in repaired:
        raise ValueError(f"apply_unified_diff failed: {repaired!r}")
    return f"diff_len={len(d)} extract+apply OK"


def t6_live_env_health() -> str:
    """GET the Space's ``/health`` endpoint (user taken from ``HF_USERNAME``).

    Raises:
        _Skip: The request itself failed (reported as a network problem).
        RuntimeError: The endpoint answered with status >= 400.
    """
    import requests

    user = os.environ.get("HF_USERNAME", "akhiilll")
    url = f"https://{user}-forgeenv.hf.space/health"
    try:
        r = requests.get(url, timeout=15)
    except Exception as e:  # noqa: BLE001
        raise _Skip(f"network: {e}") from e
    if r.status_code >= 400:
        raise RuntimeError(f"{url} -> {r.status_code} {r.text[:80]}")
    return f"{r.status_code} {r.text[:60]!r}"


def t7_source_repo_exists() -> str:
    """Check that ``<user>/forgeenv-source`` on the Hub contains the job script.

    Raises:
        _Skip: ``HF_TOKEN`` is not set.
        FileNotFoundError: ``scripts/jobs/train_repair_agent.py`` is not in
            the repo's file listing.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise _Skip("HF_TOKEN not set")
    from huggingface_hub import HfApi

    api = HfApi()
    user = os.environ.get("HF_USERNAME", "akhiilll")
    repo_id = f"{user}/forgeenv-source"
    files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=token)
    needed = "scripts/jobs/train_repair_agent.py"
    if needed not in files:
        raise FileNotFoundError(f"{needed} missing from {repo_id} (files: {len(files)})")
    return f"{repo_id} has {len(files)} files incl. train_repair_agent.py"


def t8_qwen_tokenizer_loads() -> str:
    """Load the base model's tokenizer and render a three-message chat.

    The model id comes from ``BASE_MODEL`` (default
    ``Qwen/Qwen2.5-Coder-3B-Instruct``); ``HF_TOKEN`` is passed through when
    set.

    Raises:
        ValueError: The rendered text lacks ``<|im_start|>`` or the user
            message content.
    """
    base = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-3B-Instruct")
    token = os.environ.get("HF_TOKEN")
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(base, token=token, trust_remote_code=False)
    msgs = [
        {"role": "system", "content": "you are a repair agent"},
        {"role": "user", "content": "fix this"},
        {"role": "assistant", "content": "--- a/train.py\n+++ b/train.py\n"},
    ]
    text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
    if "<|im_start|>" not in text:
        raise ValueError("ChatML tokens missing from rendered template")
    if "fix this" not in text:
        raise ValueError("user content not in rendered template")
    return f"{base} chat_template renders ChatML ({len(text)} chars)"


def t9_hfapi_auth_and_namespace() -> str:
    """Verify ``HF_TOKEN`` authenticates and compare the user to ``HF_USERNAME``.

    A user mismatch is returned as a ``WARN: ...`` detail rather than raised,
    so the check is still recorded as PASS.

    Raises:
        _Skip: ``HF_TOKEN`` is not set.
        RuntimeError: ``whoami`` returned neither ``name`` nor ``fullname``.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise _Skip("HF_TOKEN not set")
    from huggingface_hub import HfApi

    api = HfApi()
    info = api.whoami(token=token)
    user = info.get("name") or info.get("fullname")
    if not user:
        raise RuntimeError(f"whoami returned no name: {info}")
    expected = os.environ.get("HF_USERNAME", "akhiilll")
    if user != expected:
        return f"WARN: token user={user} but HF_USERNAME={expected}"
    return f"authed as {user}"


def t10_find_trainer_state() -> str:
    """Exercise the trainer-state lookup rule against a synthetic checkpoint.

    A temporary directory with ``checkpoint-80/trainer_state.json`` is
    created and searched with a local re-implementation of the finder.

    Raises:
        RuntimeError: The import spec could not be built, the state file was
            not located, or the located file has unexpected content.
    """
    sys.path.insert(0, str(REPO_ROOT / "scripts" / "jobs"))
    with tempfile.TemporaryDirectory() as td:
        td_p = Path(td)
        ckpt = td_p / "checkpoint-80"
        ckpt.mkdir()
        state = {
            "log_history": [
                {"step": 5, "rewards/reward_repair_function/mean": 0.12},
                {"step": 10, "rewards/reward_repair_function/mean": 0.34},
            ]
        }
        (ckpt / "trainer_state.json").write_text(json.dumps(state))

        from importlib import util as _util

        # NOTE(review): this only builds an import spec; the module is never
        # executed, so nothing inside the training script is validated here.
        spec = _util.spec_from_file_location(
            "_train_mod", REPO_ROOT / "scripts" / "jobs" / "train_repair_agent.py"
        )
        if spec is None or spec.loader is None:
            raise RuntimeError("can't spec the training script")
        # Don't actually load the module (it has top-level CUDA/HF effects).
        # Re-implement the same finder here from source.
        # The script uses: prefer GRPO_DIR/trainer_state.json, else newest checkpoint-*.
        direct = td_p / "trainer_state.json"
        if direct.exists():
            found = direct
        else:
            # "Newest" means the highest numeric suffix after the last dash.
            ckpts = sorted(
                (p for p in td_p.glob("checkpoint-*") if (p / "trainer_state.json").exists()),
                key=lambda p: int(p.name.split("-")[-1]),
            )
            found = (ckpts[-1] / "trainer_state.json") if ckpts else None
        if found is None or not found.exists():
            raise RuntimeError("finder did not locate the synthesized state")
        loaded = json.loads(found.read_text())
        if len(loaded["log_history"]) != 2:
            raise RuntimeError("finder loaded wrong file")
    return "checkpoint-N/trainer_state.json discoverable"


def main() -> int:
    """Run all checks in order, print a summary and return the exit code.

    Returns:
        0 when no check was recorded as FAIL, otherwise 1.
    """
    print(f"\n=== ForgeEnv preflight (repo: {REPO_ROOT}) ===\n", flush=True)
    _run("01 imports", t1_imports, required=True)
    _run("01b openenv extras (job: after --no-deps)", t1b_openenv_job_extras, required=True)
    _run("02 dataset load + format", t2_dataset_load_and_format, required=True)
    _run("03 TRL configs (SFT/GRPO) accept kwargs", t3_trl_configs_accept_our_kwargs, required=True)
    _run("04 reward fn returns float", t4_reward_function_returns_float, required=True)
    _run("05 diff utils round-trip", t5_diff_utils_roundtrip, required=True)
    _run("06 live env /health", t6_live_env_health, required=False)
    _run("07 forgeenv-source repo on Hub", t7_source_repo_exists, required=False)
    _run("08 Qwen tokenizer + ChatML", t8_qwen_tokenizer_loads, required=True)
    _run("09 HfApi auth", t9_hfapi_auth_and_namespace, required=False)
    _run("10 _find_trainer_state logic", t10_find_trainer_state, required=True)

    print("\n=== Summary ===")
    n_pass = sum(1 for r in _results if r[0] == PASS)
    n_fail = sum(1 for r in _results if r[0] == FAIL)
    n_skip = sum(1 for r in _results if r[0] == SKIP)
    # The summary intentionally repeats only tag and label, not the detail.
    for tag, label, _detail in _results:
        print(f"{tag} {label}")
    print(f"\n{n_pass} passed, {n_fail} failed, {n_skip} skipped")
    return 0 if n_fail == 0 else 1


if __name__ == "__main__":
    sys.exit(main())