#!/usr/bin/env python
"""Submit ForgeEnv training as an HF Jobs run on A100 (or any flavor).

Two stages:

1. **Publish source**: uploads the full ``forgeenv`` repo (code + warmstart
   data + artifacts) to ``<user>/forgeenv-source`` so the job can clone it.
2. **Submit job**: launches ``scripts/jobs/train_repair_agent.py`` on the
   chosen hardware via ``HfApi.run_uv_job``, then streams the job logs back
   to your terminal until completion.

Usage (PowerShell)::

    $env:HF_TOKEN = "hf_..."
    python scripts/submit_training_job.py --user akhiilll --flavor a100-large
    # add --dry-run to skip the actual submission and just publish source
    # add --skip-publish to reuse the existing forgeenv-source repo
    # tweak --sft-steps / --grpo-steps for a smoke test

Costs (Hub jobs, before hackathon credits):

    a100-large  $0.0417/min (~$2.50/hr; full training ~$10-15)
    a10g-large  $0.0250/min (~$1.50/hr; full training ~$6-9, slower)
    t4-small    $0.0067/min (~$0.40/hr; smoke tests only)
"""
from __future__ import annotations

import argparse
import enum
import os
import sys
import time
from pathlib import Path

from huggingface_hub import HfApi, JobInfo
# Repository root. This launcher is run as scripts/submit_training_job.py
# (see the usage line in the module docstring), so the root is one directory
# above this file's parent. publish_source uploads everything under it.
REPO_ROOT = Path(__file__).resolve().parents[1]
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the launcher's command-line options.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, as argparse normally does.

    Returns:
        The parsed namespace.

    Raises:
        SystemExit: with status 2 (argparse's usage-error code) on invalid
            arguments, including a negative ``--sft-steps``/``--grpo-steps``.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--user", default="akhiilll", help="HF username (owner of source/model repos)")
    ap.add_argument("--flavor", default="a100-large", help="HF Jobs hardware flavor")
    ap.add_argument("--sft-steps", type=int, default=1000, help="SFT steps (forwarded to the job as SFT_STEPS)")
    ap.add_argument("--grpo-steps", type=int, default=200, help="GRPO steps (forwarded to the job as GRPO_STEPS)")
    ap.add_argument("--base-model", default="Qwen/Qwen2.5-3B-Instruct")
    ap.add_argument("--timeout", default="6h", help="job timeout (e.g. 30m, 2h, 6h)")
    ap.add_argument("--skip-publish", action="store_true", help="reuse existing forgeenv-source repo")
    ap.add_argument("--dry-run", action="store_true", help="publish source but do not launch the job")
    ap.add_argument("--no-tail", action="store_true", help="skip log streaming after submission")
    args = ap.parse_args(argv)
    # The step counts are forwarded verbatim to the remote job; reject
    # nonsense locally instead of after paying for hardware.
    if args.sft_steps < 0 or args.grpo_steps < 0:
        ap.error("--sft-steps and --grpo-steps must be non-negative")
    return args
def _source_ignore_patterns() -> list[str]:
    """Return the ``upload_folder`` ignore patterns for the source snapshot."""
    # upload_folder matches each pattern with fnmatch against every file's
    # POSIX path relative to the uploaded folder, and ``*`` also matches "/".
    # A bare name like ".venv" therefore only matches a *file* literally
    # named ".venv" -- never the files inside that directory -- so each
    # excluded directory needs an explicit "<dir>/*" form.
    # Directories excluded wherever they appear in the tree.
    nested_dirs = ("__pycache__", ".pytest_cache", "*.egg-info", ".git", ".venv")
    # Directories excluded only at the repository root.
    root_dirs = ("venv", ".github", "outputs", "wandb")
    patterns = ["*.pyc", "*.log"]
    for name in nested_dirs:
        patterns += [name, f"{name}/*", f"*/{name}/*"]
    for name in root_dirs:
        patterns += [name, f"{name}/*"]
    return patterns


def publish_source(api: HfApi, token: str, user: str) -> str:
    """Upload the local checkout to ``<user>/forgeenv-source`` as one commit.

    Creates the model repo if it does not exist yet. Then uploads everything
    under ``REPO_ROOT`` except caches, virtualenvs, VCS/CI metadata, run
    outputs and logs (see ``_source_ignore_patterns``).

    Args:
        api: Hub client used for ``create_repo`` and ``upload_folder``.
        token: HF access token, passed explicitly to each Hub call.
        user: Namespace that owns the source repo.

    Returns:
        The ``repo_id`` that was published.
    """
    repo_id = f"{user}/forgeenv-source"
    print(f"[launcher] publishing source -> {repo_id}", flush=True)
    # NOTE(review): created public, presumably so the job can fetch the
    # training script from its plain resolve/main URL (see submit_job) --
    # confirm, and keep secrets out of the uploaded tree.
    api.create_repo(repo_id=repo_id, repo_type="model", token=token, exist_ok=True, private=False)
    api.upload_folder(
        folder_path=str(REPO_ROOT),
        repo_id=repo_id,
        repo_type="model",
        token=token,
        commit_message="forgeenv source snapshot for training job",
        ignore_patterns=_source_ignore_patterns(),
    )
    print(f"[launcher] source live at https://huggingface.co/{repo_id}", flush=True)
    return repo_id
def submit_job(
    api: HfApi,
    token: str,
    user: str,
    flavor: str,
    sft_steps: int,
    grpo_steps: int,
    base_model: str,
    timeout: str,
) -> JobInfo:
    """Launch the repair-agent training script as a uv job on HF Jobs.

    The script is not sent inline. ``run_uv_job`` receives the raw Hub URL of
    ``scripts/jobs/train_repair_agent.py`` inside ``<user>/forgeenv-source``,
    so that repo must already contain the file (see ``publish_source``).
    Run configuration is passed as environment variables, and ``token`` is
    passed as the ``HF_TOKEN`` secret.

    Returns:
        Whatever ``HfApi.run_uv_job`` returns for the submitted job.
    """
    source_repo = f"{user}/forgeenv-source"
    script_url = f"https://huggingface.co/{source_repo}/resolve/main/scripts/jobs/train_repair_agent.py"

    job_env = {
        "HF_USERNAME": user,
        "ENV_URL": f"https://{user}-forgeenv.hf.space",
        "SOURCE_REPO": source_repo,
        "MODEL_REPO": f"{user}/forgeenv-repair-agent",
        "BASE_MODEL": base_model,
        "SFT_STEPS": str(sft_steps),
        "GRPO_STEPS": str(grpo_steps),
        "PYTHONUNBUFFERED": "1",
    }
    return api.run_uv_job(
        script=script_url,
        dependencies=["huggingface_hub>=0.27", "requests"],
        flavor=flavor,
        timeout=timeout,
        namespace=user,
        env=job_env,
        secrets={"HF_TOKEN": token},
        token=token,
    )
_TERMINAL_STAGES = {"COMPLETED", "FAILED", "CANCELLED", "ERROR", "DELETED"}
def _stage_of(info) -> str:
status = getattr(info, "status", None)
if status is None:
return "UNKNOWN"
stage = getattr(status, "stage", None)
if stage is None:
return str(status)
return str(stage)
def tail_logs(
    api: HfApi,
    token: str,
    job_id: str,
    namespace: str | None = None,
    poll_interval: float = 20.0,
) -> int:
    """Stream a job's logs, then poll its status until it reaches a terminal stage.

    The log stream is best effort. Ctrl-C or any error while streaming falls
    through to status polling. Pressing Ctrl-C while polling stops waiting
    locally; the job itself is not cancelled.

    Args:
        api: Hub client used for ``fetch_job_logs`` / ``inspect_job``.
        token: HF access token, passed explicitly to each call.
        job_id: Id of the submitted job.
        namespace: Namespace the job runs under.
        poll_interval: Seconds to sleep between status checks.

    Returns:
        0 if the final stage is "COMPLETED", 1 for any other terminal stage,
        130 (the conventional SIGINT exit status) if polling was interrupted.
    """
    print(f"\n[launcher] streaming logs for job {job_id} (Ctrl-C to stop tailing) ...\n", flush=True)
    try:
        for line in api.fetch_job_logs(job_id=job_id, namespace=namespace, token=token):
            print(line, flush=True)
    except KeyboardInterrupt:
        print("\n[launcher] log stream interrupted by user.", flush=True)
    except Exception as e:  # noqa: BLE001
        # Deliberately non-fatal: the status poll below decides the exit code.
        print(f"\n[launcher] log stream ended ({e}); polling status ...", flush=True)

    last_stage: str | None = None
    try:
        while True:
            info = api.inspect_job(job_id=job_id, namespace=namespace, token=token)
            stage = _stage_of(info)
            if stage != last_stage:
                print(f"[launcher] status: {stage}", flush=True)
                last_stage = stage
            if stage in _TERMINAL_STAGES:
                break
            time.sleep(poll_interval)
    except KeyboardInterrupt:
        # A second Ctrl-C used to escape as a traceback; detach cleanly instead.
        print(
            f"\n[launcher] stopped polling (last status: {last_stage}); "
            "the job was not cancelled.",
            flush=True,
        )
        return 130
    print(f"[launcher] final status: {last_stage}", flush=True)
    return 0 if last_stage == "COMPLETED" else 1
def main() -> int:
    """Run the launcher: publish the source snapshot, submit the job, tail it.

    Returns the process exit code:
    - 2 if ``HF_TOKEN`` is unset.
    - 0 for ``--dry-run`` or ``--no-tail``.
    - 1 if the submitted job's id cannot be determined.
    - Otherwise, the result of ``tail_logs``.
    """
    args = parse_args()
    token = os.environ.get("HF_TOKEN")
    if not token:
        print("ERROR: set HF_TOKEN in the environment first.", file=sys.stderr)
        return 2
    api = HfApi()
    if not args.skip_publish:
        publish_source(api, token, args.user)
    if args.dry_run:
        print("[launcher] --dry-run set; not submitting job.", flush=True)
        return 0
    print(
        f"[launcher] submitting job (flavor={args.flavor}, sft={args.sft_steps}, "
        f"grpo={args.grpo_steps}, timeout={args.timeout}) ...",
        flush=True,
    )
    job = submit_job(
        api=api,
        token=token,
        user=args.user,
        flavor=args.flavor,
        sft_steps=args.sft_steps,
        grpo_steps=args.grpo_steps,
        base_model=args.base_model,
        timeout=args.timeout,
    )
    # Accept either attribute name for the job id on the returned object.
    job_id = getattr(job, "id", None) or getattr(job, "job_id", None)
    print(f"[launcher] job submitted: id={job_id}", flush=True)
    print(f"[launcher] dashboard: https://huggingface.co/jobs/{args.user}", flush=True)
    if args.no_tail:
        return 0
    if not job_id:
        # Without an id there is nothing to tail; don't pass None into the
        # log/inspect calls.
        print(
            f"ERROR: could not determine the job id from {job!r}; check the dashboard.",
            file=sys.stderr,
        )
        return 1
    return tail_logs(api, token, job_id, namespace=args.user)
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())