#!/usr/bin/env python
"""Job-side training entrypoint for ForgeEnv on HF Jobs A100.
Submitted via ``scripts/submit_training_job.py``. The launcher fills in
``HF_TOKEN``, ``HF_USERNAME``, ``ENV_URL`` as Job env vars. The job:
1. Clones ``<HF_USERNAME>/forgeenv-source`` (full project tree).
2. Installs the repo with training extras.
3. Sanity-pings the live env Space.
4. Runs warm-start SFT (TRL SFTTrainer + Unsloth, 4-bit LoRA).
5. Runs GRPO repair (TRL GRPOTrainer) starting from the SFT adapter.
6. Generates plots via ``forgeenv.training.plots``.
7. Pushes the LoRA + ``repair_library.json`` + plots to
``<HF_USERNAME>/forgeenv-repair-agent``.
The script is linear and prints big section markers so the streaming log
is easy to follow from the launcher.
"""
from __future__ import annotations
import json
import os
import shutil
import subprocess
import sys
from pathlib import Path
def _sh(cmd: list[str], **kwargs) -> None:
    print(f"[job] $ {' '.join(cmd)}", flush=True)
    subprocess.check_call(cmd, **kwargs)


def step(label: str) -> None:
    print(f"\n========== {label} ==========\n", flush=True)
HF_TOKEN = os.environ["HF_TOKEN"]
HF_USERNAME = os.environ.get("HF_USERNAME", "akhiilll")
ENV_URL = os.environ.get("ENV_URL", f"https://{HF_USERNAME}-forgeenv.hf.space")
SOURCE_REPO = os.environ.get("SOURCE_REPO", f"{HF_USERNAME}/forgeenv-source")
MODEL_REPO = os.environ.get("MODEL_REPO", f"{HF_USERNAME}/forgeenv-repair-agent")
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-3B-Instruct")
SFT_STEPS = int(os.environ.get("SFT_STEPS", "1000"))
GRPO_STEPS = int(os.environ.get("GRPO_STEPS", "200"))
WORK = Path("/tmp/forgeenv_work")
WORK.mkdir(parents=True, exist_ok=True)
OUT = WORK / "outputs"
OUT.mkdir(parents=True, exist_ok=True)
SFT_DIR = OUT / "sft"
GRPO_DIR = OUT / "grpo"
PLOTS_DIR = OUT / "plots"
PLOTS_DIR.mkdir(parents=True, exist_ok=True)
step("0. clone source from Hub")
src_dir = WORK / "src"
if src_dir.exists():
    shutil.rmtree(src_dir)
_sh([
    "git", "clone",
    f"https://USER:{HF_TOKEN}@huggingface.co/{SOURCE_REPO}",
    str(src_dir),
])
# Belt-and-braces: prepend the source dir to sys.path so `import forgeenv`
# resolves directly from the clone, without relying on an editable install
# persisting inside the uv-managed venv.
sys.path.insert(0, str(src_dir))
step("1. pin torch (cu124) + install GPU-stable deps")
# Force a CUDA 12.4 torch wheel BEFORE anything else so other packages'
# resolvers don't pull a cu130 wheel that mismatches the host driver
# (Error 802 on some HF Job flavors). TRL 1.2+ imports ``FSDPModule`` from
# ``torch.distributed.fsdp``, which exists only in PyTorch >= 2.6 — do not
# pin to 2.5.x.
_sh([
    sys.executable, "-m", "pip", "install",
    "--index-url", "https://download.pytorch.org/whl/cu124",
    "torch==2.6.0", "torchvision==0.21.0",
])
# `--no-deps` on openenv-core: it pins a different transformers/torch
# stack that we don't want. We still need its *runtime* imports:
# ``import forgeenv`` -> ``ForgeEnvironment`` -> ``openenv.core`` pulls in
# ``fastmcp`` (and friends) from ``openenv.core.env_server``.
_sh([
    sys.executable, "-m", "pip", "install", "--no-deps",
    "openenv-core>=0.2.0",
])
_sh([
    sys.executable, "-m", "pip", "install",
    "fastmcp>=3.0.0",
    "gradio>=4.0.0",
    "openai>=2.7.2",
    "tomli>=2.3.0",
    "tomli-w>=1.2.0",
    "websockets>=15.0.1",
])
_sh([
    sys.executable, "-m", "pip", "install",
    "trl==1.2.0", "peft", "accelerate", "datasets",
    "bitsandbytes",
    "matplotlib", "pyyaml", "nltk", "scikit-learn",
    "fastapi", "uvicorn", "pydantic", "requests",
    "sentencepiece", "protobuf",
])
try:
    # --no-deps is critical: prevents unsloth from re-resolving torch.
    _sh([sys.executable, "-m", "pip", "install", "--no-deps", "unsloth", "unsloth-zoo"])
except subprocess.CalledProcessError:
    print("[job] WARN: unsloth install failed — the SFT step will fail at its unsloth import.", flush=True)
import torch # noqa: E402
print(f"[job] torch: {torch.__version__}", flush=True)
print(f"[job] CUDA available: {torch.cuda.is_available()}", flush=True)
if torch.cuda.is_available():
    print(f"[job] GPU: {torch.cuda.get_device_name(0)}", flush=True)
    print(
        f"[job] VRAM: "
        f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB",
        flush=True,
    )
else:
    raise SystemExit("[job] FATAL: no CUDA — refusing to run training on CPU.")
step("2. ping live env Space + verify forgeenv import")
import requests # noqa: E402
try:
    r = requests.get(f"{ENV_URL}/health", timeout=20)
    print(f"[job] env /health -> {r.status_code} {r.text}", flush=True)
except Exception as e:  # noqa: BLE001
    print(f"[job] WARN: env ping failed: {e}", flush=True)
# Fail fast if forgeenv isn't on the path -- much cheaper to discover
# this here than after 8+ minutes of SFT.
import forgeenv # noqa: F401, E402
from forgeenv.training.grpo_repair import run_grpo # noqa: F401, E402
print("[job] forgeenv import OK", flush=True)
step("3. SFT: load Qwen + LoRA via Unsloth, train on warm-start pairs")
from unsloth import FastLanguageModel # noqa: E402
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL,
    max_seq_length=2048,
    load_in_4bit=True,
    dtype=None,
    token=HF_TOKEN,
)
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=32,
    lora_dropout=0,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    use_gradient_checkpointing="unsloth",
)
print(
    f"[job] trainable params: "
    f"{model.num_parameters(only_trainable=True):,}",
    flush=True,
)
import datasets as ds # noqa: E402
from trl import SFTConfig, SFTTrainer # noqa: E402
sft_jsonl = src_dir / "warmstart" / "data" / "repair_pairs.jsonl"
if not sft_jsonl.exists():
    sft_jsonl = src_dir / "warmstart" / "data" / "drift_pairs.jsonl"
print(f"[job] SFT pairs: {sft_jsonl}", flush=True)
def _format_chat(example):
    msgs = example.get("messages")
    if not msgs:
        return {"text": ""}
    return {
        "text": tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=False
        )
    }
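# Each warm-start JSONL row is assumed to carry a chat-style ``messages`` list,
# roughly shaped like the invented example below; rows without ``messages``
# collapse to empty text in ``_format_chat`` above.
#
#   {"messages": [{"role": "user", "content": "<broken config>"},
#                 {"role": "assistant", "content": "<repaired config>"}]}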
sft_ds = ds.load_dataset("json", data_files=str(sft_jsonl), split="train")
sft_ds = sft_ds.map(_format_chat, remove_columns=sft_ds.column_names)
sft_trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=sft_ds,
    args=SFTConfig(
        output_dir=str(SFT_DIR),
        max_steps=SFT_STEPS,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        logging_steps=25,
        save_steps=max(250, SFT_STEPS // 4),
        bf16=torch.cuda.is_bf16_supported(),
        fp16=not torch.cuda.is_bf16_supported(),
        max_length=2048,
        packing=True,
        packing_strategy="bfd",
        report_to=[],
    ),
)
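# Effective batch per optimizer step: per_device_train_batch_size (4)
# x gradient_accumulation_steps (4) = 16 packed sequences on the single A100.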
sft_trainer.train()
model.save_pretrained(str(SFT_DIR))
tokenizer.save_pretrained(str(SFT_DIR))
# free memory before GRPO reloads the model
del sft_trainer, model, tokenizer
import gc
gc.collect()
torch.cuda.empty_cache()
step("4. GRPO repair training (resumes from SFT adapter)")
from forgeenv.training.grpo_repair import run_grpo # noqa: E402
run_grpo(
    base_model=BASE_MODEL,
    adapter_path=str(SFT_DIR),
    output_dir=str(GRPO_DIR),
    total_episodes=GRPO_STEPS,
    group_size=4,
    learning_rate=5e-6,
)
step("5. generate plots from training logs")
from forgeenv.training.plots import (  # noqa: E402
    plot_baseline_vs_trained,
    plot_reward_curve,
    plot_success_rate_by_category,
)
from typing import Optional  # noqa: E402

# TRL writes trainer_state.json under each checkpoint dir, not directly
# at output_dir. Pick the latest checkpoint, fall back to output_dir.
def _find_trainer_state(grpo_dir: Path) -> Optional[Path]:
    direct = grpo_dir / "trainer_state.json"
    if direct.exists():
        return direct
    ckpts = sorted(
        (p for p in grpo_dir.glob("checkpoint-*") if (p / "trainer_state.json").exists()),
        key=lambda p: int(p.name.split("-")[-1]) if p.name.split("-")[-1].isdigit() else -1,
    )
    return (ckpts[-1] / "trainer_state.json") if ckpts else None
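# For orientation, a log_history entry in trainer_state.json looks roughly like
# the invented sketch below; exact reward key names vary across TRL versions,
# which is why the loop further down probes several candidates:
#
#   {"log_history": [
#       {"step": 10, "loss": 0.8, "rewards/reward_repair_function/mean": 0.42},
#       ...]}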
trainer_state = _find_trainer_state(GRPO_DIR)
print(f"[job] trainer_state path: {trainer_state}", flush=True)
training_rewards: list[float] = []
if trainer_state is not None and trainer_state.exists():
    state = json.loads(trainer_state.read_text())
    log_history = state.get("log_history", [])
    print(f"[job] log_history rows: {len(log_history)}", flush=True)
    if log_history:
        sample_keys = sorted(set().union(*(log.keys() for log in log_history)))
        print(f"[job] log keys present: {sample_keys}", flush=True)
    for log in log_history:
        # TRL emits a few different reward keys depending on version;
        # try the most specific first, then fall back.
        candidates = [
            "rewards/reward_repair_function/mean",
            "rewards/mean",
            "reward",
            "train/reward",
        ]
        # also pick up any key matching rewards/<name>/mean
        for k in list(log.keys()):
            if k.startswith("rewards/") and k.endswith("/mean") and k not in candidates:
                candidates.append(k)
        for k in candidates:
            if k in log:
                training_rewards.append(float(log[k]))
                break
print(f"[job] {len(training_rewards)} reward log points", flush=True)
if training_rewards:
    print(
        f"[job] reward range: {min(training_rewards):.3f}..{max(training_rewards):.3f}",
        flush=True,
    )
plot_reward_curve(
    training_rewards or [0.0],
    str(PLOTS_DIR / "training_reward_curve.png"),
)
# we keep the CPU artifacts for baseline_vs_trained; if you want a real
# eval pass post-training, run the rollout helper here. The artifact
# generator already produced these from the dry-run.
src_plots = src_dir / "artifacts" / "plots"
for name in ("baseline_vs_trained.png", "success_by_category.png"):
    src_p = src_plots / name
    if src_p.exists():
        shutil.copy(src_p, PLOTS_DIR / name)
step("6. push LoRA + artifacts to Hub")
final_dir = OUT / "final"
final_dir.mkdir(parents=True, exist_ok=True)
for item in GRPO_DIR.iterdir():
    if item.is_file() and (
        item.name.startswith("adapter_")
        or item.name.startswith("tokenizer")
        or item.name in {"special_tokens_map.json", "vocab.json", "merges.txt"}
    ):
        shutil.copy(item, final_dir / item.name)
repair_lib = src_dir / "artifacts" / "repair_library.json"
if repair_lib.exists():
    shutil.copy(repair_lib, final_dir / "repair_library.json")
from huggingface_hub import HfApi # noqa: E402
api = HfApi()
api.create_repo(
    repo_id=MODEL_REPO,
    repo_type="model",
    token=HF_TOKEN,
    exist_ok=True,
    private=False,
)
api.upload_folder(
    folder_path=str(final_dir),
    repo_id=MODEL_REPO,
    repo_type="model",
    token=HF_TOKEN,
    commit_message=f"GRPO LoRA (sft={SFT_STEPS}, grpo={GRPO_STEPS})",
    ignore_patterns=["__pycache__", "*.pyc"],
)
api.upload_folder(
    folder_path=str(PLOTS_DIR),
    repo_id=MODEL_REPO,
    repo_type="model",
    token=HF_TOKEN,
    path_in_repo="plots",
    commit_message="Training plots",
)
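# Downstream consumers can pull the adapter back down roughly as sketched below
# (plain transformers + peft loading, no 4-bit setup shown; not executed here):
#
#   from transformers import AutoModelForCausalLM
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
#   agent = PeftModel.from_pretrained(base, f"{HF_USERNAME}/forgeenv-repair-agent")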
print(
    f"\n[job] DONE. Model live at https://huggingface.co/{MODEL_REPO}",
    flush=True,
)
print(
    json.dumps(
        {
            "sft_steps": SFT_STEPS,
            "grpo_steps": GRPO_STEPS,
            "rewards_logged": len(training_rewards),
            "model_repo": MODEL_REPO,
        },
        indent=2,
    ),
    flush=True,
)