Spaces:
Sleeping
Sleeping
Prasham.Jain
feat(training): A10G-optimised pipeline — auto train.py, Dockerfile.train, GH Action sync
11f97d8 | """Automated end-to-end training script for HF Spaces. | |
| Runs: scenario download β SFT warmstart β GRPO fine-tuning β push to HF Hub. | |
| All config comes from environment variables (set as Space secrets). | |
| Optimised for A10G Large (46 GB VRAM, 12 vCPU). | |
| Required env vars: | |
| HF_TOKEN - HuggingFace write token | |
| HF_USERNAME - your HF username | |
| WANDB_API_KEY - Weights & Biases API key | |
| Optional: | |
| HF_SCENARIOS_REPO - default: {HF_USERNAME}/ci-triage-scenarios | |
| HF_SFT_DATASET_REPO - default: {HF_USERNAME}/ci-triage-sft | |
| HF_MODEL_REPO - default: {HF_USERNAME}/ci-triage-agent | |
| GRPO_STEPS - default: 100 (set lower to finish faster, higher for more training) | |
| SKIP_SFT - set to "1" to skip SFT and jump straight to GRPO (if checkpoint exists) | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import sys | |
| from pathlib import Path | |
# ── resolve config ────────────────────────────────────────────────────────────
# Required secrets: fail fast with KeyError if either is missing.
HF_TOKEN = os.environ["HF_TOKEN"]
HF_USERNAME = os.environ["HF_USERNAME"]

# Optional settings, each with a sensible default derived from the username.
WANDB_KEY = os.getenv("WANDB_API_KEY", "")
SCENARIOS_REPO = os.getenv("HF_SCENARIOS_REPO", f"{HF_USERNAME}/ci-triage-scenarios")
SFT_DATASET_REPO = os.getenv("HF_SFT_DATASET_REPO", f"{HF_USERNAME}/ci-triage-sft")
MODEL_REPO = os.getenv("HF_MODEL_REPO", f"{HF_USERNAME}/ci-triage-agent")
GRPO_STEPS = int(os.getenv("GRPO_STEPS", "100"))
SKIP_SFT = os.getenv("SKIP_SFT", "0") == "1"

# Persistent-storage layout (Spaces mounts persistent storage at /data).
DATA_ROOT = Path("/data")
SCEN_DIR = DATA_ROOT / "scenarios"
SFT_DS_DIR = DATA_ROOT / "sft_dataset"
_CKPT_ROOT = DATA_ROOT / "checkpoints"
SFT_CKPT = _CKPT_ROOT / "sft"
GRPO_CKPT = _CKPT_ROOT / "grpo"
# ── auth ──────────────────────────────────────────────────────────────────────
# Log in to the HF Hub first; wandb is optional and disabled when no key is set.
from huggingface_hub import login

login(token=HF_TOKEN)

if not WANDB_KEY:
    os.environ["WANDB_DISABLED"] = "true"
else:
    import wandb

    wandb.login(key=WANDB_KEY)
    os.environ["WANDB_PROJECT"] = "ci-triage-env"
# ── Step 1: download scenario corpus ──────────────────────────────────────────
# Skip the download when any scenario JSON is already on persistent storage.
if not SCEN_DIR.exists() or not any(SCEN_DIR.rglob("*.json")):
    print(f"\n[1/4] Downloading scenarios from {SCENARIOS_REPO} …")
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id=SCENARIOS_REPO,
        repo_type="dataset",
        local_dir=str(SCEN_DIR),
        token=HF_TOKEN,
    )
else:
    n = sum(1 for _ in SCEN_DIR.rglob("*.json"))
    print(f"\n[1/4] Scenarios already present ({n} files) — skipping download.")

# Prefer scenarios under a train/ split; fall back to the whole corpus if the
# repo is not split into train/eval directories.
train_scen = list(SCEN_DIR.rglob("train/**/*.json")) or list(SCEN_DIR.rglob("*.json"))
print(f"    Train scenarios available: {len(train_scen)}")
# ── Step 2: download SFT dataset ──────────────────────────────────────────────
# Cache the dataset on persistent storage so restarts don't re-download it.
if not SFT_DS_DIR.exists():
    print(f"\n[2/4] Downloading SFT dataset from {SFT_DATASET_REPO} …")
    from datasets import load_dataset

    ds = load_dataset(SFT_DATASET_REPO, split="train", token=HF_TOKEN)
    SFT_DS_DIR.mkdir(parents=True, exist_ok=True)
    ds.save_to_disk(str(SFT_DS_DIR))
    print(f"    {len(ds)} SFT examples saved.")
else:
    from datasets import load_from_disk

    ds = load_from_disk(str(SFT_DS_DIR))
    print(f"\n[2/4] SFT dataset already present ({len(ds)} examples) — skipping download.")
# ── Step 3: SFT warmstart ─────────────────────────────────────────────────────
# SKIP_SFT only takes effect when a checkpoint actually exists; otherwise SFT
# runs regardless so GRPO always has a warmstart to resume from.
if SKIP_SFT and SFT_CKPT.exists():
    print(f"\n[3/4] SKIP_SFT=1 and checkpoint found at {SFT_CKPT} — skipping SFT.")
else:
    print(f"\n[3/4] SFT warmstart — {len(ds)} examples, A10G-optimised settings …")
    from ci_triage_env.training.sft import run_sft

    run_sft(
        dataset_path=str(SFT_DS_DIR),
        output_dir=str(SFT_CKPT),
        num_epochs=2,
        per_device_batch_size=4,  # 46 GB VRAM → fits 4 sequences comfortably
        gradient_accumulation_steps=4,  # effective batch = 16
    )
    print(f"    SFT done → {SFT_CKPT}")

    # Push SFT checkpoint immediately so it's saved even if GRPO fails later.
    print("    Pushing SFT checkpoint to HF Hub …")
    from huggingface_hub import upload_folder

    upload_folder(
        folder_path=str(SFT_CKPT),
        repo_id=MODEL_REPO + "-sft",
        repo_type="model",
        token=HF_TOKEN,
        commit_message="SFT warmstart checkpoint",
    )
# ── Step 4: GRPO fine-tuning ──────────────────────────────────────────────────
print(f"\n[4/4] GRPO training — {GRPO_STEPS} steps, MockEnvClient in-process …")
print("    Monitoring: https://wandb.ai (search project ci-triage-env)")
from ci_triage_env.training.mock_env_client import MockEnvClient
from ci_triage_env.training.grpo import run_grpo

env_client = MockEnvClient(scenarios_dir=str(SCEN_DIR / "train"))
print(f"    Loaded {len(env_client.scenario_ids)} train scenarios into MockEnvClient")

# A10G Large optimised hyperparams.
# max_turns=4 + max_completion_length=256 keeps each rollout to ~15 sec, so
# 100 steps × 4 rollouts ≈ 100 min total — fits the 2-3 hour budget.
run_grpo(
    sft_checkpoint_dir=str(SFT_CKPT),
    output_dir=str(GRPO_CKPT),
    total_steps=GRPO_STEPS,
    env_client=env_client,
    scenarios_train_path=str(SCEN_DIR / "train"),
    hyperparams={
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 4,  # effective batch = 4
        "num_generations": 4,
        "max_prompt_length": 2048,
        "max_completion_length": 256,
        "learning_rate": 5e-6,
        "kl_coef": 0.04,
        "temperature": 0.8,
        "top_p": 0.95,
        "logging_steps": 5,
        "save_steps": 50,
        "report_to": "wandb" if WANDB_KEY else "none",
    },
)
print(f"    GRPO done → {GRPO_CKPT}")
# ── Push final model ──────────────────────────────────────────────────────────
print(f"\n[done] Pushing final model to {MODEL_REPO} …")
from huggingface_hub import upload_folder

upload_folder(
    folder_path=str(GRPO_CKPT),
    repo_id=MODEL_REPO,
    repo_type="model",
    token=HF_TOKEN,
    commit_message=f"GRPO-trained adapter — {GRPO_STEPS} steps",
)
print(f"    Model at: https://huggingface.co/{MODEL_REPO}")
print("\nTraining complete.")