|
|
| """Submit ForgeEnv training as a HF Jobs run on A100 (or any flavor).
|
|
|
| Two stages:
|
|
|
| 1. **Publish source**: uploads the full ``forgeenv`` repo (code + warmstart
|
| data + artifacts) to ``<user>/forgeenv-source`` so the job can clone it.
|
| 2. **Submit job**: launches ``scripts/jobs/train_repair_agent.py`` on the
|
| chosen hardware via ``HfApi.run_uv_job``. Streams the job logs back to
|
| your terminal until completion.
|
|
|
| Usage::
|
|
|
| $env:HF_TOKEN = "hf_..."
|
| python scripts/submit_training_job.py --user akhiilll --flavor a100-large
|
| # add --dry-run to skip the actual submission and just publish source
|
| # add --skip-publish to reuse the existing forgeenv-source repo
|
| # tweak --sft-steps / --grpo-steps for a smoke test
|
|
|
| Costs (Hub jobs, before hackathon credits):
|
| a100-large $0.0417/min (~$2.50/hr; full training ~$10-15)
|
| a10g-large $0.0250/min (~$1.50/hr; full training ~$6-9, slower)
|
| t4-small $0.0067/min (~$0.40/hr; smoke tests only)
|
| """
|
| from __future__ import annotations
|
|
|
| import argparse
|
| import os
|
| import sys
|
| import time
|
| from pathlib import Path
|
|
|
| from huggingface_hub import HfApi, JobInfo
|
|
|
| REPO_ROOT = Path(__file__).resolve().parents[1]
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build and run the CLI argument parser for the launcher.

    Returns the parsed namespace; the module docstring is reused verbatim
    as the ``--help`` epilogue via ``RawDescriptionHelpFormatter``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # Repo ownership / hardware selection.
    parser.add_argument("--user", default="akhiilll", help="HF username (owner of source/model repos)")
    parser.add_argument("--flavor", default="a100-large", help="HF Jobs hardware flavor")
    # Training knobs forwarded to the job via environment variables.
    parser.add_argument("--sft-steps", type=int, default=1000)
    parser.add_argument("--grpo-steps", type=int, default=200)
    parser.add_argument("--base-model", default="Qwen/Qwen2.5-3B-Instruct")
    parser.add_argument("--timeout", default="6h", help="job timeout (e.g. 30m, 2h, 6h)")
    # Workflow toggles.
    parser.add_argument("--skip-publish", action="store_true", help="reuse existing forgeenv-source repo")
    parser.add_argument("--dry-run", action="store_true", help="publish source but do not launch the job")
    parser.add_argument("--no-tail", action="store_true", help="skip log streaming after submission")
    return parser.parse_args()
|
|
|
|
|
def publish_source(api: HfApi, token: str, user: str) -> str:
    """Upload a snapshot of the local repo to ``<user>/forgeenv-source``.

    Creates the (public, model-type) repo if needed, pushes the working
    tree minus caches/venvs/outputs, and returns the repo id so the
    caller can reference it in the job.
    """
    source_repo = f"{user}/forgeenv-source"
    print(f"[launcher] publishing source -> {source_repo}", flush=True)

    api.create_repo(repo_id=source_repo, repo_type="model", token=token, exist_ok=True, private=False)

    # Skip anything that is local tooling residue rather than source/data.
    skip_patterns = [
        "__pycache__",
        "*.pyc",
        ".pytest_cache",
        ".venv",
        "venv",
        "*.egg-info",
        ".git",
        ".github",
        "outputs",
        "wandb",
        "*.log",
    ]
    api.upload_folder(
        folder_path=str(REPO_ROOT),
        repo_id=source_repo,
        repo_type="model",
        token=token,
        commit_message="forgeenv source snapshot for training job",
        ignore_patterns=skip_patterns,
    )

    print(f"[launcher] source live at https://huggingface.co/{source_repo}", flush=True)
    return source_repo
|
|
|
|
|
def submit_job(
    api: HfApi,
    token: str,
    user: str,
    flavor: str,
    sft_steps: int,
    grpo_steps: int,
    base_model: str,
    timeout: str,
) -> JobInfo:
    """Launch the training script as a uv job on HF Jobs.

    The job fetches ``train_repair_agent.py`` straight from the published
    source repo; all training configuration travels via environment
    variables, and the HF token is injected as a job secret.
    """
    script_url = (
        f"https://huggingface.co/{user}/forgeenv-source/"
        "resolve/main/scripts/jobs/train_repair_agent.py"
    )

    # Configuration consumed by train_repair_agent.py inside the job.
    job_env = {
        "HF_USERNAME": user,
        "ENV_URL": f"https://{user}-forgeenv.hf.space",
        "SOURCE_REPO": f"{user}/forgeenv-source",
        "MODEL_REPO": f"{user}/forgeenv-repair-agent",
        "BASE_MODEL": base_model,
        "SFT_STEPS": str(sft_steps),
        "GRPO_STEPS": str(grpo_steps),
        "PYTHONUNBUFFERED": "1",
    }

    return api.run_uv_job(
        script=script_url,
        dependencies=["huggingface_hub>=0.27", "requests"],
        flavor=flavor,
        timeout=timeout,
        namespace=user,
        env=job_env,
        secrets={"HF_TOKEN": token},  # keep the token out of plain env
        token=token,
    )
|
|
|
|
|
| _TERMINAL_STAGES = {"COMPLETED", "FAILED", "CANCELLED", "ERROR", "DELETED"}
|
|
|
|
|
def _stage_of(info) -> str:
    """Extract a job's lifecycle stage as a string, defensively.

    Handles three shapes of ``info``: no ``status`` attribute at all
    (-> ``"UNKNOWN"``), a plain status value (-> its string form), or a
    status object carrying a ``stage`` attribute (-> that stage).
    """
    status = getattr(info, "status", None)
    if status is None:
        return "UNKNOWN"
    stage = getattr(status, "stage", None)
    return str(status) if stage is None else str(stage)
|
|
|
|
|
def tail_logs(api: HfApi, token: str, job_id: str, namespace: str | None = None) -> int:
    """Stream a job's logs, then poll until it reaches a terminal stage.

    Log streaming is best-effort: a Ctrl-C or any stream error falls
    through to status polling rather than aborting. Returns 0 when the
    job finished ``COMPLETED``, 1 for any other terminal stage.
    """
    print(f"\n[launcher] streaming logs for job {job_id} (Ctrl-C to stop tailing) ...\n", flush=True)
    try:
        for entry in api.fetch_job_logs(job_id=job_id, namespace=namespace, token=token):
            print(entry, flush=True)
    except KeyboardInterrupt:
        print("\n[launcher] log stream interrupted by user.", flush=True)
    except Exception as exc:
        # Stream can drop before the job ends; fall back to polling.
        print(f"\n[launcher] log stream ended ({exc}); polling status ...", flush=True)

    previous: str | None = None
    while True:
        current = _stage_of(api.inspect_job(job_id=job_id, namespace=namespace, token=token))
        if current != previous:
            print(f"[launcher] status: {current}", flush=True)
            previous = current
        if current in _TERMINAL_STAGES:
            break
        time.sleep(20)

    print(f"[launcher] final status: {previous}", flush=True)
    return 0 if previous == "COMPLETED" else 1
|
|
|
|
|
def main() -> int:
    """Entry point: publish source, submit the training job, tail its logs.

    Returns a process exit code: 0 on success (or dry-run / --no-tail),
    1 when the job ends in a non-COMPLETED terminal stage, 2 when
    HF_TOKEN is missing from the environment.
    """
    args = parse_args()
    token = os.environ.get("HF_TOKEN")
    if not token:
        print("ERROR: set HF_TOKEN in the environment first.", file=sys.stderr)
        return 2

    api = HfApi()

    if not args.skip_publish:
        publish_source(api, token, args.user)

    if args.dry_run:
        print("[launcher] --dry-run set; not submitting job.", flush=True)
        return 0

    print(
        f"[launcher] submitting job (flavor={args.flavor}, sft={args.sft_steps}, "
        f"grpo={args.grpo_steps}, timeout={args.timeout}) ...",
        flush=True,
    )
    job = submit_job(
        api=api,
        token=token,
        user=args.user,
        flavor=args.flavor,
        sft_steps=args.sft_steps,
        grpo_steps=args.grpo_steps,
        base_model=args.base_model,
        timeout=args.timeout,
    )
    # JobInfo has carried the id under different attribute names across
    # huggingface_hub versions; try both.
    job_id = getattr(job, "id", None) or getattr(job, "job_id", None)
    print(f"[launcher] job submitted: id={job_id}", flush=True)
    print(f"[launcher] dashboard: https://huggingface.co/jobs/{args.user}", flush=True)

    if args.no_tail:
        return 0
    if job_id is None:
        # Bug fix: previously a None id was passed straight to tail_logs,
        # which would then query the Jobs API with job_id=None. Without an
        # id we cannot stream logs or poll status, so stop cleanly here.
        print("[launcher] could not determine job id; skipping log tail.", file=sys.stderr)
        return 0
    return tail_logs(api, token, job_id, namespace=args.user)
|
|
|
|
|
| if __name__ == "__main__":
|
| raise SystemExit(main())
|
|
|