# forgeenv-source snapshot: scripts/preflight_check.py
# (source snapshot for training job, commit 8be8ee0; ~13.4 kB)
#!/usr/bin/env python
"""Local preflight: validate every component the H200 training job touches
WITHOUT spending GPU time. Each test prints PASS/FAIL with a short reason.
Run::
python scripts/preflight_check.py
The script exits non-zero if any required test fails. Optional tests
(network/Hub) print SKIP if HF_TOKEN is not set or the env Space is down.
"""
from __future__ import annotations
import json
import os
import sys
import tempfile
import traceback
from pathlib import Path
from typing import Callable
# Repository root (one level above scripts/); prepend it to sys.path so the
# in-repo ``forgeenv`` package resolves no matter where the script is run from.
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))
# Status tags printed in front of every test result line.
PASS = "[PASS]"
FAIL = "[FAIL]"
SKIP = "[SKIP]"
# Accumulated (tag, label, detail) tuples; main() replays them in the summary.
_results: list[tuple[str, str, str]] = []
def _run(label: str, fn: Callable[[], str | None], required: bool = True) -> None:
    """Execute one preflight test and record + print its outcome.

    A ``_Skip`` raised by *fn* is recorded as SKIP.  Any other exception is
    recorded as FAIL when the test is *required*, otherwise as SKIP; required
    failures also get a full traceback printed.  A truthy return value from
    *fn* becomes the detail text shown after the label.
    """
    try:
        detail = fn() or ""
    except _Skip as skip:
        _results.append((SKIP, label, str(skip)))
        print(f"{SKIP} {label} {skip}", flush=True)
    except Exception as exc:  # noqa: BLE001
        tag = FAIL if required else SKIP
        summary = f"{type(exc).__name__}: {exc}"
        _results.append((tag, label, summary))
        print(f"{tag} {label} {summary}", flush=True)
        if required:
            traceback.print_exc()
    else:
        _results.append((PASS, label, detail))
        print(f"{PASS} {label} {detail}", flush=True)
class _Skip(Exception):
    """Raised inside a test function to mark the test SKIP instead of FAIL."""
    pass
def t1_imports() -> str:
    """Import every third-party and in-repo module the training job touches.

    A failure here means the environment is missing a dependency or the
    forgeenv package layout is broken; no other test can pass without these.
    Returns the trl/transformers versions for the detail line.
    """
    import forgeenv  # noqa: F401
    import trl  # noqa: F401
    import peft  # noqa: F401
    import datasets  # noqa: F401
    import transformers  # noqa: F401
    import accelerate  # noqa: F401
    from forgeenv.training.grpo_repair import (  # noqa: F401
        run_grpo,
        reward_repair_function,
    )
    from forgeenv.training.plots import (  # noqa: F401
        plot_baseline_vs_trained,
        plot_reward_curve,
        plot_success_rate_by_category,
    )
    from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction  # noqa: F401
    from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff  # noqa: F401
    from forgeenv.env.forge_environment import ForgeEnvironment  # noqa: F401
    from forgeenv.roles.repair_agent import extract_diff  # noqa: F401
    from forgeenv.tasks.task_sampler import TaskSampler  # noqa: F401
    return f"trl={trl.__version__} transformers={transformers.__version__}"
def t1b_openenv_job_extras() -> str:
    """Check the extra packages needed after ``pip install openenv-core --no-deps``.

    On HF Jobs we install openenv-core without its dependencies and then add
    the packages it lists as requirements ourselves, so ``import openenv.core``
    works.  Only ``fastmcp`` is checked here.
    """
    import fastmcp  # noqa: F401
    return "fastmcp (required by openenv.core.env_server on import)"
def t2_dataset_load_and_format() -> str:
    """Load the warmstart SFT dataset and sanity-check its chat format.

    Verifies the JSONL file exists, has at least 10 rows, and that the first
    row carries a non-empty ``messages`` list covering the system/user/
    assistant roles.
    """
    import datasets as ds

    jsonl_path = REPO_ROOT / "warmstart" / "data" / "repair_pairs.jsonl"
    if not jsonl_path.exists():
        raise FileNotFoundError(jsonl_path)
    sft_ds = ds.load_dataset("json", data_files=str(jsonl_path), split="train")
    n_rows = len(sft_ds)
    if n_rows < 10:
        raise ValueError(f"too few rows in repair_pairs.jsonl: {n_rows}")
    first_row = sft_ds[0]
    if "messages" not in first_row or not first_row["messages"]:
        raise KeyError("row missing 'messages' field")
    roles = {msg["role"] for msg in first_row["messages"]}
    if not {"system", "user", "assistant"}.issubset(roles):
        raise ValueError(f"unexpected role set: {roles}")
    return f"rows={n_rows} roles={sorted(roles)}"
def t3_trl_configs_accept_our_kwargs() -> str:
    """Validate every kwarg name the job passes is accepted by the
    current TRL Config classes. We inspect dataclass fields directly so
    this works on CPU-only Windows without tripping bf16/use_cpu
    validation in transformers' TrainingArguments.__post_init__.

    Returns a detail string with the field counts and whether the
    best-effort instantiation also succeeded.
    Raises TypeError if any kwarg name is unknown to the config class.
    """
    import dataclasses
    from trl import GRPOConfig, SFTConfig

    # Exact kwargs the H200 job passes to each TRL config.
    sft_kwargs = {
        "output_dir": "/tmp/forge_sft",
        "max_steps": 10,
        "per_device_train_batch_size": 4,
        "gradient_accumulation_steps": 4,
        "learning_rate": 2e-4,
        "logging_steps": 25,
        "save_steps": 250,
        "bf16": True,
        "fp16": False,
        "max_length": 2048,
        "report_to": [],
    }
    grpo_kwargs = {
        "output_dir": "/tmp/forge_grpo",
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 4,
        "learning_rate": 5e-6,
        "max_steps": 5,
        "num_generations": 4,
        "max_completion_length": 1024,
        "logging_steps": 5,
        "save_steps": 50,
        "save_total_limit": 2,
        "seed": 0,
        "report_to": "none",
        "beta": 0.04,
    }

    def _field_names(cls) -> set[str]:
        # Collect dataclass field names across the whole MRO, since TRL
        # configs inherit most fields from transformers' TrainingArguments.
        names: set[str] = set()
        for c in cls.__mro__:
            if dataclasses.is_dataclass(c):
                names.update(f.name for f in dataclasses.fields(c))
        return names

    sft_fields = _field_names(SFTConfig)
    missing_sft = [k for k in sft_kwargs if k not in sft_fields]
    if missing_sft:
        raise TypeError(f"SFTConfig missing fields: {missing_sft}")
    grpo_fields = _field_names(GRPOConfig)
    missing_grpo = [k for k in grpo_kwargs if k not in grpo_fields]
    if missing_grpo:
        raise TypeError(f"GRPOConfig missing fields: {missing_grpo}")
    # Best-effort: try actually instantiating with use_cpu=True so even
    # __post_init__ runs cleanly under our preflight conditions.
    try:
        # BUGFIX: sft_kwargs already sets "bf16": True, so the previous
        # ``SFTConfig(**sft_kwargs, use_cpu=True, bf16=False)`` always raised
        # "got multiple values for keyword argument 'bf16'" and the
        # instantiation check was silently skipped.  Override bf16 in a
        # merged dict instead so the override actually takes effect.
        SFTConfig(**{**sft_kwargs, "bf16": False}, use_cpu=True)
        GRPOConfig(**grpo_kwargs, use_cpu=True)
        instantiated = "instantiated OK"
    except Exception as e:  # noqa: BLE001
        instantiated = f"field-check OK; instantiation skipped ({type(e).__name__})"
    return (
        f"SFT/GRPO kwargs all valid; sft_fields={len(sft_fields)} "
        f"grpo_fields={len(grpo_fields)}; {instantiated}"
    )
def t4_reward_function_returns_float() -> str:
    """Call the GRPO reward function once with a synthetic diff completion
    and check it returns exactly one Python float."""
    from forgeenv.training.grpo_repair import reward_repair_function
    from forgeenv.tasks.task_sampler import TaskSampler

    sampler = TaskSampler()
    if not sampler.tasks:
        raise RuntimeError("TaskSampler has no tasks")
    first_task_id = sampler.tasks[0].task_id
    broken_src = "x = 1\nprint(x)\n"
    # A hand-written unified diff that the reward fn should be able to apply.
    synthetic_diff = (
        "--- a/train.py\n"
        "+++ b/train.py\n"
        "@@ -1,2 +1,2 @@\n"
        "-x = 1\n"
        "+x = 2\n"
        " print(x)\n"
    )
    rewards = reward_repair_function(
        completions=[synthetic_diff],
        prompts=[[]],
        task_id=[first_task_id],
        broken_script=[broken_src],
    )
    if len(rewards) != 1:
        raise ValueError(f"expected 1 reward got {len(rewards)}")
    reward = rewards[0]
    if not isinstance(reward, float):
        raise TypeError(f"reward not float: {type(reward)}")
    return f"reward={reward:.3f} (single fake completion)"
def t5_diff_utils_roundtrip() -> str:
    """Round-trip a diff: make -> embed in prose -> extract -> apply.

    Verifies the one-line edit survives the full pipeline the repair agent
    uses (fenced-block extraction followed by unified-diff application).
    """
    from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff
    from forgeenv.roles.repair_agent import extract_diff

    before = "x = 1\nprint(x)\n"
    after = "x = 2\nprint(x)\n"
    diff_text = make_unified_diff(before, after)
    if not diff_text.strip():
        raise ValueError("make_unified_diff returned empty")
    # Bury the diff in surrounding prose the way a model completion would.
    wrapped = "Some thinking...\n```diff\n" + diff_text + "\n```\nmore prose"
    recovered = extract_diff(wrapped)
    if not recovered.strip():
        raise ValueError("extract_diff failed to find diff in fenced block")
    patched = apply_unified_diff(before, recovered)
    if "x = 2" not in patched:
        raise ValueError(f"apply_unified_diff failed: {patched!r}")
    return f"diff_len={len(diff_text)} extract+apply OK"
def t6_live_env_health() -> str:
    """Ping the deployed env Space's /health endpoint (optional test).

    Network failures raise ``_Skip`` so offline runs report SKIP, not FAIL;
    HTTP errors (>= 400) are real failures.
    """
    import requests

    space_owner = os.environ.get("HF_USERNAME", "akhiilll")
    health_url = f"https://{space_owner}-forgeenv.hf.space/health"
    try:
        resp = requests.get(health_url, timeout=15)
    except Exception as exc:  # noqa: BLE001
        raise _Skip(f"network: {exc}")
    if resp.status_code >= 400:
        raise RuntimeError(f"{health_url} -> {resp.status_code} {resp.text[:80]}")
    return f"{resp.status_code} {resp.text[:60]!r}"
def t7_source_repo_exists() -> str:
    """Confirm the Hub source-snapshot repo contains the training entrypoint.

    Skips when HF_TOKEN is unset (no Hub access); fails when the repo is
    missing ``scripts/jobs/train_repair_agent.py``.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise _Skip("HF_TOKEN not set")
    from huggingface_hub import HfApi

    owner = os.environ.get("HF_USERNAME", "akhiilll")
    repo_id = f"{owner}/forgeenv-source"
    repo_files = HfApi().list_repo_files(repo_id=repo_id, repo_type="model", token=token)
    entrypoint = "scripts/jobs/train_repair_agent.py"
    if entrypoint not in repo_files:
        raise FileNotFoundError(f"{entrypoint} missing from {repo_id} (files: {len(repo_files)})")
    return f"{repo_id} has {len(repo_files)} files incl. train_repair_agent.py"
def t8_qwen_tokenizer_loads() -> str:
    """Load the base model's tokenizer and verify its chat template.

    Renders a three-message conversation and checks for the ChatML
    ``<|im_start|>`` marker and the user content in the output.
    """
    from transformers import AutoTokenizer

    base_model = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-3B-Instruct")
    hf_token = os.environ.get("HF_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token, trust_remote_code=False)
    conversation = [
        {"role": "system", "content": "you are a repair agent"},
        {"role": "user", "content": "fix this"},
        {"role": "assistant", "content": "--- a/train.py\n+++ b/train.py\n"},
    ]
    rendered = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=False
    )
    if "<|im_start|>" not in rendered:
        raise ValueError("ChatML tokens missing from rendered template")
    if "fix this" not in rendered:
        raise ValueError("user content not in rendered template")
    return f"{base_model} chat_template renders ChatML ({len(rendered)} chars)"
def t9_hfapi_auth_and_namespace() -> str:
    """Verify HF_TOKEN authenticates and matches the expected namespace.

    A username mismatch is reported as a PASS with a WARN detail rather
    than a failure, since pushing to a different namespace may be intended.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise _Skip("HF_TOKEN not set")
    from huggingface_hub import HfApi

    identity = HfApi().whoami(token=token)
    username = identity.get("name") or identity.get("fullname")
    if not username:
        raise RuntimeError(f"whoami returned no name: {identity}")
    expected_user = os.environ.get("HF_USERNAME", "akhiilll")
    if username != expected_user:
        return f"WARN: token user={username} but HF_USERNAME={expected_user}"
    return f"authed as {username}"
def t10_find_trainer_state() -> str:
    """Exercise the training script's trainer_state discovery logic.

    Synthesizes a ``checkpoint-80/trainer_state.json`` in a temp dir and
    checks that the "prefer GRPO_DIR/trainer_state.json, else newest
    checkpoint-*" finder locates and loads it.  The training script module
    itself is only spec'd, never executed (it has top-level CUDA/HF effects),
    so the finder is re-implemented here from its source.

    BUGFIX: removed the dead ``sys.path.insert(0, REPO_ROOT/"scripts"/"jobs")``
    — nothing was ever imported from that path (the script is located via
    ``spec_from_file_location``), and the insert polluted sys.path on every
    invocation.
    """
    with tempfile.TemporaryDirectory() as td:
        td_p = Path(td)
        ckpt = td_p / "checkpoint-80"
        ckpt.mkdir()
        state = {
            "log_history": [
                {"step": 5, "rewards/reward_repair_function/mean": 0.12},
                {"step": 10, "rewards/reward_repair_function/mean": 0.34},
            ]
        }
        (ckpt / "trainer_state.json").write_text(json.dumps(state))
        from importlib import util as _util
        spec = _util.spec_from_file_location(
            "_train_mod", REPO_ROOT / "scripts" / "jobs" / "train_repair_agent.py"
        )
        if spec is None or spec.loader is None:
            raise RuntimeError("can't spec the training script")
        # Don't actually load the module (it has top-level CUDA/HF effects).
        # Re-implement the same finder here from source.
        direct = td_p / "trainer_state.json"
        if direct.exists():
            found = direct
        else:
            # Newest checkpoint wins; sort by the numeric suffix, not
            # lexicographically (checkpoint-100 > checkpoint-80).
            ckpts = sorted(
                (p for p in td_p.glob("checkpoint-*") if (p / "trainer_state.json").exists()),
                key=lambda p: int(p.name.split("-")[-1]),
            )
            found = (ckpts[-1] / "trainer_state.json") if ckpts else None
        if found is None or not found.exists():
            raise RuntimeError("finder did not locate the synthesized state")
        loaded = json.loads(found.read_text())
        if len(loaded["log_history"]) != 2:
            raise RuntimeError("finder loaded wrong file")
    return "checkpoint-N/trainer_state.json discoverable"
def main() -> int:
    """Run every preflight test in order, print a summary, and return an
    exit code: 0 when no required test failed, 1 otherwise."""
    print(f"\n=== ForgeEnv preflight (repo: {REPO_ROOT}) ===\n", flush=True)
    # (label, test function, required?) — optional tests may SKIP freely.
    checks = [
        ("01 imports", t1_imports, True),
        ("01b openenv extras (job: after --no-deps)", t1b_openenv_job_extras, True),
        ("02 dataset load + format", t2_dataset_load_and_format, True),
        ("03 TRL configs (SFT/GRPO) accept kwargs", t3_trl_configs_accept_our_kwargs, True),
        ("04 reward fn returns float", t4_reward_function_returns_float, True),
        ("05 diff utils round-trip", t5_diff_utils_roundtrip, True),
        ("06 live env /health", t6_live_env_health, False),
        ("07 forgeenv-source repo on Hub", t7_source_repo_exists, False),
        ("08 Qwen tokenizer + ChatML", t8_qwen_tokenizer_loads, True),
        ("09 HfApi auth", t9_hfapi_auth_and_namespace, False),
        ("10 _find_trainer_state logic", t10_find_trainer_state, True),
    ]
    for label, test_fn, is_required in checks:
        _run(label, test_fn, required=is_required)
    print("\n=== Summary ===")
    tags = [tag for tag, _label, _detail in _results]
    n_pass = tags.count(PASS)
    n_fail = tags.count(FAIL)
    n_skip = tags.count(SKIP)
    for tag, label, _detail in _results:
        print(f"{tag} {label}")
    print(f"\n{n_pass} passed, {n_fail} failed, {n_skip} skipped")
    return 0 if n_fail == 0 else 1
if __name__ == "__main__":
    # Propagate main()'s result as the process exit code (0 = all required pass).
    sys.exit(main())