| |
| """Submit ForgeEnv training as a HF Jobs run on A100 (or any flavor). |
| |
| Two stages: |
| |
| 1. **Publish source**: uploads the full ``forgeenv`` repo (code + warmstart |
| data + artifacts) to ``<user>/forgeenv-source`` so the job can clone it. |
| 2. **Submit job**: launches ``scripts/jobs/train_repair_agent.py`` on the |
| chosen hardware via ``HfApi.run_uv_job``. Streams the job logs back to |
| your terminal until completion. |
| |
| Usage:: |
| |
| $env:HF_TOKEN = "hf_..." |
| python scripts/submit_training_job.py --user akhiilll --flavor a100-large |
| # add --dry-run to skip the actual submission and just publish source |
| # add --skip-publish to reuse the existing forgeenv-source repo |
| # tweak --sft-steps / --grpo-steps for a smoke test |
| |
| Costs (Hub jobs, before hackathon credits): |
| a100-large $0.0417/min (~$2.50/hr; full training ~$10-15) |
| a10g-large $0.0250/min (~$1.50/hr; full training ~$6-9, slower) |
| t4-small $0.0067/min (~$0.40/hr; smoke tests only) |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import os |
| import sys |
| import time |
| from pathlib import Path |
|
|
| from huggingface_hub import HfApi, JobInfo |
|
|
# Directory one level above the folder that contains this file (resolved to an
# absolute path). publish_source() uploads this directory as the source snapshot.
REPO_ROOT = Path(__file__).resolve().parents[1]
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the launcher's command-line options from ``sys.argv``.

    The module docstring is used verbatim as the ``--help`` description, which
    is why the raw-description formatter is selected.

    Returns:
        The populated namespace with ``user``, ``flavor``, ``sft_steps``,
        ``grpo_steps``, ``base_model``, ``timeout``, ``skip_publish``,
        ``dry_run`` and ``no_tail`` attributes.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Options that carry a value: (flag, keyword arguments for add_argument).
    valued_options = (
        ("--user", {"default": "akhiilll", "help": "HF username (owner of source/model repos)"}),
        ("--flavor", {"default": "a100-large", "help": "HF Jobs hardware flavor"}),
        ("--sft-steps", {"type": int, "default": 1000}),
        ("--grpo-steps", {"type": int, "default": 200}),
        ("--base-model", {"default": "Qwen/Qwen2.5-3B-Instruct"}),
        ("--timeout", {"default": "6h", "help": "job timeout (e.g. 30m, 2h, 6h)"}),
    )
    for flag, options in valued_options:
        parser.add_argument(flag, **options)

    # Boolean switches: (flag, help text); all default to False.
    switches = (
        ("--skip-publish", "reuse existing forgeenv-source repo"),
        ("--dry-run", "publish source but do not launch the job"),
        ("--no-tail", "skip log streaming after submission"),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", help=help_text)

    return parser.parse_args()
|
|
|
|
def publish_source(api: HfApi, token: str, user: str) -> str:
    """Upload the repository at ``REPO_ROOT`` to ``<user>/forgeenv-source``.

    The target repo is created as a public model repo if it does not exist
    yet, then the folder is uploaded with caches, virtualenvs, VCS metadata,
    run outputs and logs filtered out.

    Args:
        api: Hub client used for ``create_repo`` and ``upload_folder``.
        token: Hugging Face token passed to both calls.
        user: Hub username that owns the source repo.

    Returns:
        The repo id, ``"<user>/forgeenv-source"``.
    """
    repo_id = f"{user}/forgeenv-source"
    print(f"[launcher] publishing source -> {repo_id}", flush=True)
    api.create_repo(repo_id=repo_id, repo_type="model", token=token, exist_ok=True, private=False)

    # upload_folder matches each pattern against the repo-relative *file* path
    # with fnmatch, so a bare directory name such as "outputs" never matches
    # "outputs/run1/model.bin". Expand every directory name into patterns that
    # cover its contents at the top level and at any nesting depth; the bare
    # name is kept so a plain file with that exact name is still skipped.
    ignored_dir_names = (
        "__pycache__",
        ".pytest_cache",
        ".venv",
        "venv",
        "*.egg-info",
        ".git",
        ".github",
        "outputs",
        "wandb",
    )
    ignore_patterns = ["*.pyc", "*.log"]
    for dir_name in ignored_dir_names:
        ignore_patterns.extend((dir_name, f"{dir_name}/*", f"*/{dir_name}/*"))

    api.upload_folder(
        folder_path=str(REPO_ROOT),
        repo_id=repo_id,
        repo_type="model",
        token=token,
        commit_message="forgeenv source snapshot for training job",
        ignore_patterns=ignore_patterns,
    )
    print(f"[launcher] source live at https://huggingface.co/{repo_id}", flush=True)
    return repo_id
|
|
|
|
def submit_job(
    api: HfApi,
    token: str,
    user: str,
    flavor: str,
    sft_steps: int,
    grpo_steps: int,
    base_model: str,
    timeout: str,
) -> JobInfo:
    """Launch the training script as a uv job and return the job handle.

    The script is referenced by its ``resolve/main`` URL inside the
    ``<user>/forgeenv-source`` repo, so that repo must already contain
    ``scripts/jobs/train_repair_agent.py``.

    Args:
        api: Hub client whose ``run_uv_job`` is called.
        token: Hugging Face token; used for the call itself and forwarded to
            the job as the ``HF_TOKEN`` secret.
        user: Hub username; used as job namespace and to derive repo/Space names.
        flavor: Hardware flavor to run on.
        sft_steps: Value exported to the job as ``SFT_STEPS``.
        grpo_steps: Value exported to the job as ``GRPO_STEPS``.
        base_model: Value exported to the job as ``BASE_MODEL``.
        timeout: Job timeout string, passed through unchanged.

    Returns:
        Whatever ``api.run_uv_job`` returns.
    """
    source_repo = f"{user}/forgeenv-source"
    script_url = (
        f"https://huggingface.co/{source_repo}/"
        "resolve/main/scripts/jobs/train_repair_agent.py"
    )

    # Environment variables are always strings, hence the str() on step counts.
    job_env = {
        "HF_USERNAME": user,
        "ENV_URL": f"https://{user}-forgeenv.hf.space",
        "SOURCE_REPO": source_repo,
        "MODEL_REPO": f"{user}/forgeenv-repair-agent",
        "BASE_MODEL": base_model,
        "SFT_STEPS": str(sft_steps),
        "GRPO_STEPS": str(grpo_steps),
        "PYTHONUNBUFFERED": "1",
    }
    launcher_dependencies = ["huggingface_hub>=0.27", "requests"]

    return api.run_uv_job(
        script=script_url,
        dependencies=launcher_dependencies,
        flavor=flavor,
        timeout=timeout,
        namespace=user,
        env=job_env,
        secrets={"HF_TOKEN": token},
        token=token,
    )
|
|
|
|
# Job stages after which polling stops. The Hub spells the cancelled stage
# "CANCELED" (one L); the two-L spelling and "FAILED" are kept so any caller or
# backend that reports them is still treated as finished.
_TERMINAL_STAGES = {"COMPLETED", "FAILED", "CANCELED", "CANCELLED", "ERROR", "DELETED"}
|
|
|
|
def _stage_of(info) -> str:
    """Return the stage of a job-info object as a plain string.

    Args:
        info: Object that may carry a ``status`` attribute, which in turn may
            carry a ``stage`` attribute.

    Returns:
        ``"UNKNOWN"`` when ``info`` has no ``status`` (or it is ``None``);
        otherwise the stage (or, when ``status`` has no ``stage``, the status
        itself) converted to ``str``. Enum members are unwrapped to their
        value first.
    """
    from enum import Enum

    status = getattr(info, "status", None)
    if status is None:
        return "UNKNOWN"

    stage = getattr(status, "stage", None)
    if stage is None:
        stage = status

    # str() on an Enum member (even a str-mixin one) yields "ClassName.MEMBER",
    # which would never compare equal to names such as "COMPLETED".
    if isinstance(stage, Enum):
        stage = stage.value
    return str(stage)
|
|
|
|
def tail_logs(api: HfApi, token: str, job_id: str, namespace: str | None = None) -> int:
    """Stream a job's logs, then poll its status until it reaches a terminal stage.

    Log lines from ``api.fetch_job_logs`` are echoed as they arrive. When the
    stream ends, is interrupted with Ctrl-C, or raises, the function falls
    through to polling ``api.inspect_job`` every 20 seconds, printing the
    stage whenever it changes.

    Args:
        api: Hub client providing ``fetch_job_logs`` and ``inspect_job``.
        token: Hugging Face token passed to both calls.
        job_id: Identifier of the job to follow.
        namespace: Namespace the job runs under, passed through unchanged.

    Returns:
        ``0`` if the final stage is ``"COMPLETED"``, otherwise ``1``.
    """
    print(f"\n[launcher] streaming logs for job {job_id} (Ctrl-C to stop tailing) ...\n", flush=True)
    try:
        log_stream = api.fetch_job_logs(job_id=job_id, namespace=namespace, token=token)
        for entry in log_stream:
            print(entry, flush=True)
    except KeyboardInterrupt:
        print("\n[launcher] log stream interrupted by user.", flush=True)
    except Exception as exc:
        # Best effort: a broken log stream must not hide the job's final status.
        print(f"\n[launcher] log stream ended ({exc}); polling status ...", flush=True)

    previous = None
    finished = False
    while not finished:
        job_info = api.inspect_job(job_id=job_id, namespace=namespace, token=token)
        current = _stage_of(job_info)
        if current != previous:
            # Only report transitions so a long-running job does not spam the terminal.
            print(f"[launcher] status: {current}", flush=True)
            previous = current
        finished = current in _TERMINAL_STAGES
        if not finished:
            time.sleep(20)

    print(f"[launcher] final status: {previous}", flush=True)
    if previous == "COMPLETED":
        return 0
    return 1
|
|
|
|
def main() -> int:
    """Publish the source snapshot, submit the training job and follow it.

    Returns:
        Process exit code: ``2`` when ``HF_TOKEN`` is not set; ``0`` after a
        ``--dry-run`` or after a submission with ``--no-tail``; ``1`` when the
        submitted job exposes no id and therefore cannot be followed;
        otherwise the value returned by ``tail_logs``.
    """
    args = parse_args()
    token = os.environ.get("HF_TOKEN")
    if not token:
        print("ERROR: set HF_TOKEN in the environment first.", file=sys.stderr)
        return 2

    api = HfApi()

    if not args.skip_publish:
        publish_source(api, token, args.user)

    if args.dry_run:
        print("[launcher] --dry-run set; not submitting job.", flush=True)
        return 0

    print(
        f"[launcher] submitting job (flavor={args.flavor}, sft={args.sft_steps}, "
        f"grpo={args.grpo_steps}, timeout={args.timeout}) ...",
        flush=True,
    )
    job = submit_job(
        api=api,
        token=token,
        user=args.user,
        flavor=args.flavor,
        sft_steps=args.sft_steps,
        grpo_steps=args.grpo_steps,
        base_model=args.base_model,
        timeout=args.timeout,
    )
    # Accept either attribute name for the identifier on the returned object.
    job_id = getattr(job, "id", None) or getattr(job, "job_id", None)
    print(f"[launcher] job submitted: id={job_id}", flush=True)
    print(f"[launcher] dashboard: https://huggingface.co/jobs/{args.user}", flush=True)

    if args.no_tail:
        return 0

    # Without an id neither log streaming nor status polling can address the
    # job; fail with a clear message instead of an error from deep inside
    # tail_logs.
    if not job_id:
        print(
            "ERROR: submitted job has no id; cannot stream logs or poll status.",
            file=sys.stderr,
        )
        return 1
    return tail_logs(api, token, job_id, namespace=args.user)
|
|
|
|
# When executed as a script, run the launcher and use main()'s return value as
# the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())
|
|