# forgeenv-source / scripts/submit_training_job.py
# (Hugging Face Hub page header captured with the file snapshot — not code:)
# akhiilll's picture
# forgeenv source snapshot for training job
# a15535e verified
#!/usr/bin/env python
"""Submit ForgeEnv training as a HF Jobs run on A100 (or any flavor).
Two stages:
1. **Publish source**: uploads the full ``forgeenv`` repo (code + warmstart
data + artifacts) to ``<user>/forgeenv-source`` so the job can clone it.
2. **Submit job**: launches ``scripts/jobs/train_repair_agent.py`` on the
chosen hardware via ``HfApi.run_uv_job``. Streams the job logs back to
your terminal until completion.
Usage::
$env:HF_TOKEN = "hf_..."
python scripts/submit_training_job.py --user akhiilll --flavor a100-large
# add --dry-run to skip the actual submission and just publish source
# add --skip-publish to reuse the existing forgeenv-source repo
# tweak --sft-steps / --grpo-steps for a smoke test
Costs (Hub jobs, before hackathon credits):
a100-large $0.0417/min (~$2.50/hr; full training ~$10-15)
a10g-large $0.0250/min (~$1.50/hr; full training ~$6-9, slower)
t4-small $0.0067/min (~$0.40/hr; smoke tests only)
"""
from __future__ import annotations
import argparse
import os
import sys
import time
from pathlib import Path
from huggingface_hub import HfApi, JobInfo
# Repository root: the parent of the scripts/ directory holding this file.
REPO_ROOT = Path(__file__).resolve().parents[1]
def parse_args() -> argparse.Namespace:
    """Parse launcher command-line options and return the namespace."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--user", default="akhiilll", help="HF username (owner of source/model repos)")
    parser.add_argument("--flavor", default="a100-large", help="HF Jobs hardware flavor")
    parser.add_argument("--sft-steps", type=int, default=1000)
    parser.add_argument("--grpo-steps", type=int, default=200)
    parser.add_argument("--base-model", default="Qwen/Qwen2.5-3B-Instruct")
    parser.add_argument("--timeout", default="6h", help="job timeout (e.g. 30m, 2h, 6h)")
    # Flags (all default to False).
    parser.add_argument("--skip-publish", action="store_true", help="reuse existing forgeenv-source repo")
    parser.add_argument("--dry-run", action="store_true", help="publish source but do not launch the job")
    parser.add_argument("--no-tail", action="store_true", help="skip log streaming after submission")
    return parser.parse_args()
def publish_source(api: HfApi, token: str, user: str) -> str:
    """Snapshot the local repo into ``<user>/forgeenv-source`` on the Hub.

    Creates the (public, model-type) repo if it does not exist, then uploads
    everything under ``REPO_ROOT`` except caches, virtualenvs, VCS metadata,
    and run outputs. Returns the repo id of the published source.
    """
    repo_id = f"{user}/forgeenv-source"
    print(f"[launcher] publishing source -> {repo_id}", flush=True)
    api.create_repo(repo_id=repo_id, repo_type="model", token=token, exist_ok=True, private=False)
    # Skip anything bulky or reproducible; the job only needs the source tree.
    excluded = [
        "__pycache__",
        "*.pyc",
        ".pytest_cache",
        ".venv",
        "venv",
        "*.egg-info",
        ".git",
        ".github",
        "outputs",
        "wandb",
        "*.log",
    ]
    api.upload_folder(
        folder_path=str(REPO_ROOT),
        repo_id=repo_id,
        repo_type="model",
        token=token,
        commit_message="forgeenv source snapshot for training job",
        ignore_patterns=excluded,
    )
    print(f"[launcher] source live at https://huggingface.co/{repo_id}", flush=True)
    return repo_id
def submit_job(
    api: HfApi,
    token: str,
    user: str,
    flavor: str,
    sft_steps: int,
    grpo_steps: int,
    base_model: str,
    timeout: str,
) -> JobInfo:
    """Launch the training script on HF Jobs and return its ``JobInfo``.

    The training script lives in the published source repo; ``run_uv_job``
    accepts a URL/path/command (not a script body), so we hand it the raw
    Hub URL of the script. The HF token travels as a job secret.
    """
    script_url = (
        f"https://huggingface.co/{user}/forgeenv-source/"
        "resolve/main/scripts/jobs/train_repair_agent.py"
    )
    # Configuration the remote training script reads from its environment.
    job_env = {
        "HF_USERNAME": user,
        "ENV_URL": f"https://{user}-forgeenv.hf.space",
        "SOURCE_REPO": f"{user}/forgeenv-source",
        "MODEL_REPO": f"{user}/forgeenv-repair-agent",
        "BASE_MODEL": base_model,
        "SFT_STEPS": str(sft_steps),
        "GRPO_STEPS": str(grpo_steps),
        "PYTHONUNBUFFERED": "1",
    }
    return api.run_uv_job(
        script=script_url,
        dependencies=["huggingface_hub>=0.27", "requests"],
        flavor=flavor,
        timeout=timeout,
        namespace=user,
        env=job_env,
        secrets={"HF_TOKEN": token},
        token=token,
    )
# Job lifecycle stages after which status polling can stop.
_TERMINAL_STAGES = {"COMPLETED", "FAILED", "CANCELLED", "ERROR", "DELETED"}
def _stage_of(info) -> str:
    """Best-effort extraction of a job's lifecycle stage as a string.

    Falls back to ``"UNKNOWN"`` when the object has no ``status`` attribute
    (or it is None), and to ``str(status)`` when the status object carries
    no ``stage`` attribute — different hub versions shape this differently.
    """
    status = getattr(info, "status", None)
    if status is None:
        return "UNKNOWN"
    stage = getattr(status, "stage", None)
    return str(status if stage is None else stage)
def tail_logs(api: HfApi, token: str, job_id: str, namespace: str | None = None) -> int:
    """Stream the job's logs, then poll its status until a terminal stage.

    Returns 0 when the job ends COMPLETED and 1 for any other terminal
    stage. Ctrl-C stops only the log tailing; status polling continues.
    """
    print(f"\n[launcher] streaming logs for job {job_id} (Ctrl-C to stop tailing) ...\n", flush=True)
    try:
        for log_line in api.fetch_job_logs(job_id=job_id, namespace=namespace, token=token):
            print(log_line, flush=True)
    except KeyboardInterrupt:
        print("\n[launcher] log stream interrupted by user.", flush=True)
    except Exception as exc:  # noqa: BLE001
        # The stream can drop for many transient reasons; fall back to polling.
        print(f"\n[launcher] log stream ended ({exc}); polling status ...", flush=True)

    # Logs often end before the job does — poll until a terminal stage shows up,
    # printing the stage only when it changes.
    seen: str | None = None
    while True:
        current = _stage_of(api.inspect_job(job_id=job_id, namespace=namespace, token=token))
        if current != seen:
            print(f"[launcher] status: {current}", flush=True)
            seen = current
        if current in _TERMINAL_STAGES:
            break
        time.sleep(20)
    print(f"[launcher] final status: {seen}", flush=True)
    return 0 if seen == "COMPLETED" else 1
def main() -> int:
    """Launcher entry point; returns the process exit code."""
    args = parse_args()

    if not (token := os.environ.get("HF_TOKEN")):
        print("ERROR: set HF_TOKEN in the environment first.", file=sys.stderr)
        return 2

    api = HfApi()

    if not args.skip_publish:
        publish_source(api, token, args.user)
    if args.dry_run:
        print("[launcher] --dry-run set; not submitting job.", flush=True)
        return 0

    print(
        f"[launcher] submitting job (flavor={args.flavor}, sft={args.sft_steps}, "
        f"grpo={args.grpo_steps}, timeout={args.timeout}) ...",
        flush=True,
    )
    submitted = submit_job(
        api=api,
        token=token,
        user=args.user,
        flavor=args.flavor,
        sft_steps=args.sft_steps,
        grpo_steps=args.grpo_steps,
        base_model=args.base_model,
        timeout=args.timeout,
    )
    # Different hub versions expose the job id under different attribute names.
    run_id = getattr(submitted, "id", None) or getattr(submitted, "job_id", None)
    print(f"[launcher] job submitted: id={run_id}", flush=True)
    print(f"[launcher] dashboard: https://huggingface.co/jobs/{args.user}", flush=True)

    if args.no_tail:
        return 0
    return tail_logs(api, token, run_id, namespace=args.user)
# Script entry point: propagate main()'s return value as the exit code.
if __name__ == "__main__":
    raise SystemExit(main())