|
|
| """Local preflight: validate every component the H200 training job touches
|
| WITHOUT spending GPU time. Each test prints PASS/FAIL with a short reason.
|
|
|
| Run::
|
|
|
| python scripts/preflight_check.py
|
|
|
| The script exits non-zero if any required test fails. Optional tests
|
| (network/Hub) print SKIP if HF_TOKEN is not set or the env Space is down.
|
| """
|
| from __future__ import annotations
|
|
|
| import json
|
| import os
|
| import sys
|
| import tempfile
|
| import traceback
|
| from pathlib import Path
|
| from typing import Callable
|
|
|
| REPO_ROOT = Path(__file__).resolve().parents[1]
|
| sys.path.insert(0, str(REPO_ROOT))
|
|
|
| PASS = "[PASS]"
|
| FAIL = "[FAIL]"
|
| SKIP = "[SKIP]"
|
|
|
| _results: list[tuple[str, str, str]] = []
|
|
|
|
|
def _run(label: str, fn: Callable[[], str | None], required: bool = True) -> None:
    """Execute one preflight check, record its outcome, and print it.

    ``fn`` returns an optional detail string on success, raises ``_Skip``
    to mark itself skipped, or raises anything else to fail. A failure of
    an optional check (``required=False``) is downgraded to SKIP.
    """
    try:
        note = fn() or ""
    except _Skip as skip:
        _results.append((SKIP, label, str(skip)))
        print(f"{SKIP} {label} {skip}", flush=True)
    except Exception as exc:
        tag = FAIL if required else SKIP
        summary = f"{type(exc).__name__}: {exc}"
        _results.append((tag, label, summary))
        print(f"{tag} {label} {summary}", flush=True)
        # Full traceback only for required checks, so optional-network noise
        # stays out of the log.
        if required:
            traceback.print_exc()
    else:
        _results.append((PASS, label, note))
        print(f"{PASS} {label} {note}", flush=True)
|
|
|
|
|
| class _Skip(Exception):
|
| pass
|
|
|
|
|
def t1_imports() -> str:
    """Import every third-party package and project symbol the job uses.

    Any ImportError here fails fast, before GPU time is spent. Returns a
    short version string for the two libraries whose APIs drift most.
    """
    # Third-party training stack.
    import forgeenv
    import trl
    import peft
    import datasets
    import transformers
    import accelerate

    # Project symbols the training entrypoint touches, imported by name so a
    # rename/removal anywhere in the package surfaces here.
    from forgeenv.training.grpo_repair import (
        run_grpo,
        reward_repair_function,
    )
    from forgeenv.training.plots import (
        plot_baseline_vs_trained,
        plot_reward_curve,
        plot_success_rate_by_category,
    )
    from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction
    from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff
    from forgeenv.env.forge_environment import ForgeEnvironment
    from forgeenv.roles.repair_agent import extract_diff
    from forgeenv.tasks.task_sampler import TaskSampler

    return f"trl={trl.__version__} transformers={transformers.__version__}"
|
|
|
|
|
def t1b_openenv_job_extras() -> str:
    """On HF Jobs we ``pip install openenv-core --no-deps`` then add the
    packages openenv lists as requirements so ``import openenv.core`` works."""
    # Importing fastmcp here proves the manually-added requirement landed;
    # without it, ``import openenv.core`` would fail on the job.
    import fastmcp

    return "fastmcp (required by openenv.core.env_server on import)"
|
|
|
|
|
def t2_dataset_load_and_format() -> str:
    """Load the SFT warm-start JSONL and sanity-check its chat format."""
    import datasets as ds

    jsonl_path = REPO_ROOT / "warmstart" / "data" / "repair_pairs.jsonl"
    if not jsonl_path.exists():
        raise FileNotFoundError(jsonl_path)

    train_split = ds.load_dataset("json", data_files=str(jsonl_path), split="train")
    n_rows = len(train_split)
    if n_rows < 10:
        raise ValueError(f"too few rows in repair_pairs.jsonl: {n_rows}")

    # Spot-check the first row: must be a chat-formatted example covering
    # all three roles so the SFT chat template can render it.
    first = train_split[0]
    if "messages" not in first or not first["messages"]:
        raise KeyError("row missing 'messages' field")
    seen_roles = {msg["role"] for msg in first["messages"]}
    if not {"system", "user", "assistant"}.issubset(seen_roles):
        raise ValueError(f"unexpected role set: {seen_roles}")
    return f"rows={n_rows} roles={sorted(seen_roles)}"
|
|
|
|
|
def t3_trl_configs_accept_our_kwargs() -> str:
    """Validate every kwarg name the job passes is accepted by the
    current TRL Config classes. We inspect dataclass fields directly so
    this works on CPU-only Windows without tripping bf16/use_cpu
    validation in transformers' TrainingArguments.__post_init__."""
    import dataclasses

    from trl import GRPOConfig, SFTConfig

    # Exact kwargs the H200 job passes to SFTConfig.
    sft_kwargs = {
        "output_dir": "/tmp/forge_sft",
        "max_steps": 10,
        "per_device_train_batch_size": 4,
        "gradient_accumulation_steps": 4,
        "learning_rate": 2e-4,
        "logging_steps": 25,
        "save_steps": 250,
        "bf16": True,
        "fp16": False,
        "max_length": 2048,
        "report_to": [],
    }
    # Exact kwargs the H200 job passes to GRPOConfig.
    grpo_kwargs = {
        "output_dir": "/tmp/forge_grpo",
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 4,
        "learning_rate": 5e-6,
        "max_steps": 5,
        "num_generations": 4,
        "max_completion_length": 1024,
        "logging_steps": 5,
        "save_steps": 50,
        "save_total_limit": 2,
        "seed": 0,
        "report_to": "none",
        "beta": 0.04,
    }

    def _field_names(cls) -> set[str]:
        # Collect dataclass fields across the whole MRO, since TRL configs
        # inherit most fields from transformers' TrainingArguments.
        names: set[str] = set()
        for c in cls.__mro__:
            if dataclasses.is_dataclass(c):
                names.update(f.name for f in dataclasses.fields(c))
        return names

    sft_fields = _field_names(SFTConfig)
    missing_sft = [k for k in sft_kwargs if k not in sft_fields]
    if missing_sft:
        raise TypeError(f"SFTConfig missing fields: {missing_sft}")

    grpo_fields = _field_names(GRPOConfig)
    missing_grpo = [k for k in grpo_kwargs if k not in grpo_fields]
    if missing_grpo:
        raise TypeError(f"GRPOConfig missing fields: {missing_grpo}")

    # Best-effort instantiation on CPU. bf16 must be overridden to False for
    # the CPU-only check; doing it via dict merge avoids the previous bug
    # where ``SFTConfig(**sft_kwargs, bf16=False)`` raised
    # "got multiple values for keyword argument 'bf16'" (sft_kwargs already
    # contains bf16), so instantiation was *always* skipped.
    try:
        SFTConfig(**{**sft_kwargs, "bf16": False}, use_cpu=True)
        GRPOConfig(**grpo_kwargs, use_cpu=True)
        instantiated = "instantiated OK"
    except Exception as e:
        instantiated = f"field-check OK; instantiation skipped ({type(e).__name__})"

    return (
        f"SFT/GRPO kwargs all valid; sft_fields={len(sft_fields)} "
        f"grpo_fields={len(grpo_fields)}; {instantiated}"
    )
|
|
|
|
|
def t4_reward_function_returns_float() -> str:
    """Push one hand-written diff completion through the GRPO reward fn
    and check it yields exactly one plain float."""
    from forgeenv.training.grpo_repair import reward_repair_function
    from forgeenv.tasks.task_sampler import TaskSampler

    sampler = TaskSampler()
    if not sampler.tasks:
        raise RuntimeError("TaskSampler has no tasks")
    first_task = sampler.tasks[0].task_id

    broken_src = "x = 1\nprint(x)\n"
    # A minimal but well-formed unified diff for the broken script above.
    stub_diff = (
        "--- a/train.py\n"
        "+++ b/train.py\n"
        "@@ -1,2 +1,2 @@\n"
        "-x = 1\n"
        "+x = 2\n"
        " print(x)\n"
    )

    scores = reward_repair_function(
        completions=[stub_diff],
        prompts=[[]],
        task_id=[first_task],
        broken_script=[broken_src],
    )
    if len(scores) != 1:
        raise ValueError(f"expected 1 reward got {len(scores)}")
    if not isinstance(scores[0], float):
        raise TypeError(f"reward not float: {type(scores[0])}")
    return f"reward={scores[0]:.3f} (single fake completion)"
|
|
|
|
|
def t5_diff_utils_roundtrip() -> str:
    """Round-trip: make a diff, bury it in prose, extract it, apply it."""
    from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff
    from forgeenv.roles.repair_agent import extract_diff

    before = "x = 1\nprint(x)\n"
    after = "x = 2\nprint(x)\n"

    diff_text = make_unified_diff(before, after)
    if not diff_text.strip():
        raise ValueError("make_unified_diff returned empty")

    # Simulate a model completion: prose around a fenced ```diff block.
    wrapped = "Some thinking...\n```diff\n" + diff_text + "\n```\nmore prose"
    pulled = extract_diff(wrapped)
    if not pulled.strip():
        raise ValueError("extract_diff failed to find diff in fenced block")

    patched = apply_unified_diff(before, pulled)
    if "x = 2" not in patched:
        raise ValueError(f"apply_unified_diff failed: {patched!r}")
    return f"diff_len={len(diff_text)} extract+apply OK"
|
|
|
|
|
def t6_live_env_health() -> str:
    """Hit the deployed ForgeEnv Space's /health endpoint (optional check).

    Network errors become SKIP; an HTTP error status is a real failure.
    """
    import requests

    user = os.environ.get("HF_USERNAME", "akhiilll")
    url = f"https://{user}-forgeenv.hf.space/health"
    try:
        resp = requests.get(url, timeout=15)
    except Exception as e:
        # Can't reach the Space at all — treat as skipped, not failed.
        raise _Skip(f"network: {e}")
    if resp.status_code >= 400:
        raise RuntimeError(f"{url} -> {resp.status_code} {resp.text[:80]}")
    return f"{resp.status_code} {resp.text[:60]!r}"
|
|
|
|
|
def t7_source_repo_exists() -> str:
    """Check the Hub source repo contains the job entrypoint script."""
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise _Skip("HF_TOKEN not set")
    from huggingface_hub import HfApi

    user = os.environ.get("HF_USERNAME", "akhiilll")
    repo_id = f"{user}/forgeenv-source"
    files = HfApi().list_repo_files(repo_id=repo_id, repo_type="model", token=token)

    needed = "scripts/jobs/train_repair_agent.py"
    if needed not in files:
        raise FileNotFoundError(f"{needed} missing from {repo_id} (files: {len(files)})")
    return f"{repo_id} has {len(files)} files incl. train_repair_agent.py"
|
|
|
|
|
def t8_qwen_tokenizer_loads() -> str:
    """Fetch the base-model tokenizer and render a ChatML conversation."""
    from transformers import AutoTokenizer

    base = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-3B-Instruct")
    hf_token = os.environ.get("HF_TOKEN")
    tok = AutoTokenizer.from_pretrained(base, token=hf_token, trust_remote_code=False)

    conversation = [
        {"role": "system", "content": "you are a repair agent"},
        {"role": "user", "content": "fix this"},
        {"role": "assistant", "content": "--- a/train.py\n+++ b/train.py\n"},
    ]
    rendered = tok.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=False
    )
    # Qwen chat templates emit ChatML markers; their absence means the wrong
    # template (or none) was loaded.
    if "<|im_start|>" not in rendered:
        raise ValueError("ChatML tokens missing from rendered template")
    if "fix this" not in rendered:
        raise ValueError("user content not in rendered template")
    return f"{base} chat_template renders ChatML ({len(rendered)} chars)"
|
|
|
|
|
def t9_hfapi_auth_and_namespace() -> str:
    """whoami() with HF_TOKEN; warn (but still pass) on namespace mismatch."""
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise _Skip("HF_TOKEN not set")
    from huggingface_hub import HfApi

    identity = HfApi().whoami(token=token)
    user = identity.get("name") or identity.get("fullname")
    if not user:
        raise RuntimeError(f"whoami returned no name: {identity}")

    expected = os.environ.get("HF_USERNAME", "akhiilll")
    if user == expected:
        return f"authed as {user}"
    # Mismatch is suspicious but not fatal — surface it in the detail text.
    return f"WARN: token user={user} but HF_USERNAME={expected}"
|
|
|
|
|
def t10_find_trainer_state() -> str:
    """Synthesize a fake Trainer output dir and verify the
    checkpoint-discovery logic can locate ``trainer_state.json``.

    NOTE(review): the finder logic below is an inline copy, not a call into
    train_repair_agent.py — the spec check only proves the script file can
    be spec'd, it is never loaded or executed. If the job's real
    ``_find_trainer_state`` drifts from this copy, the test will not
    notice; consider loading the module and calling the actual function.
    """
    # Make the job scripts importable (side effect persists after this test).
    sys.path.insert(0, str(REPO_ROOT / "scripts" / "jobs"))
    with tempfile.TemporaryDirectory() as td:
        td_p = Path(td)
        # Fake Trainer output: one checkpoint-80 dir holding a state file
        # with two GRPO reward log entries.
        ckpt = td_p / "checkpoint-80"
        ckpt.mkdir()
        state = {
            "log_history": [
                {"step": 5, "rewards/reward_repair_function/mean": 0.12},
                {"step": 10, "rewards/reward_repair_function/mean": 0.34},
            ]
        }
        (ckpt / "trainer_state.json").write_text(json.dumps(state))
        from importlib import util as _util

        # Existence/speccability check on the training script — the module
        # is deliberately never exec'd (that would import GPU-only deps).
        spec = _util.spec_from_file_location(
            "_train_mod", REPO_ROOT / "scripts" / "jobs" / "train_repair_agent.py"
        )
        if spec is None or spec.loader is None:
            raise RuntimeError("can't spec the training script")

        # Finder: prefer a top-level trainer_state.json, else fall back to
        # the highest-numbered checkpoint-N directory that has one.
        direct = td_p / "trainer_state.json"
        if direct.exists():
            found = direct
        else:
            ckpts = sorted(
                (p for p in td_p.glob("checkpoint-*") if (p / "trainer_state.json").exists()),
                key=lambda p: int(p.name.split("-")[-1]),
            )
            found = (ckpts[-1] / "trainer_state.json") if ckpts else None
        if found is None or not found.exists():
            raise RuntimeError("finder did not locate the synthesized state")
        loaded = json.loads(found.read_text())
        if len(loaded["log_history"]) != 2:
            raise RuntimeError("finder loaded wrong file")
    return "checkpoint-N/trainer_state.json discoverable"
|
|
|
|
|
def main() -> int:
    """Run every preflight check; return 0 iff no required check failed."""
    print(f"\n=== ForgeEnv preflight (repo: {REPO_ROOT}) ===\n", flush=True)

    # (label, check fn, required?) — order matters: cheapest/most-fundamental
    # checks first so an import failure is the first thing printed.
    checks: list[tuple[str, Callable[[], str | None], bool]] = [
        ("01 imports", t1_imports, True),
        ("01b openenv extras (job: after --no-deps)", t1b_openenv_job_extras, True),
        ("02 dataset load + format", t2_dataset_load_and_format, True),
        ("03 TRL configs (SFT/GRPO) accept kwargs", t3_trl_configs_accept_our_kwargs, True),
        ("04 reward fn returns float", t4_reward_function_returns_float, True),
        ("05 diff utils round-trip", t5_diff_utils_roundtrip, True),
        ("06 live env /health", t6_live_env_health, False),
        ("07 forgeenv-source repo on Hub", t7_source_repo_exists, False),
        ("08 Qwen tokenizer + ChatML", t8_qwen_tokenizer_loads, True),
        ("09 HfApi auth", t9_hfapi_auth_and_namespace, False),
        ("10 _find_trainer_state logic", t10_find_trainer_state, True),
    ]
    for label, fn, required in checks:
        _run(label, fn, required=required)

    print("\n=== Summary ===")
    tallies = {PASS: 0, FAIL: 0, SKIP: 0}
    for tag, label, _ in _results:
        tallies[tag] += 1
        print(f"{tag} {label}")
    print(f"\n{tallies[PASS]} passed, {tallies[FAIL]} failed, {tallies[SKIP]} skipped")
    return 0 if tallies[FAIL] == 0 else 1
|
|
|
|
|
| if __name__ == "__main__":
|
| sys.exit(main())
|
|
|