"""Automated end-to-end training script for HF Spaces.

Runs: scenario download → SFT warmstart → GRPO fine-tuning → push to HF Hub.
All config comes from environment variables (set as Space secrets).
Optimised for A10G Large (46 GB VRAM, 12 vCPU).

Required env vars:
    HF_TOKEN      - HuggingFace write token
    HF_USERNAME   - your HF username
    WANDB_API_KEY - Weights & Biases API key

Optional:
    HF_SCENARIOS_REPO   - default: {HF_USERNAME}/ci-triage-scenarios
    HF_SFT_DATASET_REPO - default: {HF_USERNAME}/ci-triage-sft
    HF_MODEL_REPO       - default: {HF_USERNAME}/ci-triage-agent
    GRPO_STEPS          - default: 100 (set lower to finish faster, higher for more training)
    SKIP_SFT            - set to "1" to skip SFT and jump straight to GRPO (if checkpoint exists)
"""
from __future__ import annotations
import os
import sys
from pathlib import Path
# ── resolve config ────────────────────────────────────────────────────────────
# Everything is driven by env vars (Space secrets) so no code edits are needed
# to retune a run. Fail fast with a readable message instead of a bare
# KeyError traceback when a required secret is missing/empty.
_missing = [v for v in ("HF_TOKEN", "HF_USERNAME") if not os.environ.get(v)]
if _missing:
    sys.exit(f"Missing required environment variable(s): {', '.join(_missing)}")

HF_TOKEN = os.environ["HF_TOKEN"]
HF_USERNAME = os.environ["HF_USERNAME"]
WANDB_KEY = os.environ.get("WANDB_API_KEY", "")

SCENARIOS_REPO = os.environ.get("HF_SCENARIOS_REPO", f"{HF_USERNAME}/ci-triage-scenarios")
SFT_DATASET_REPO = os.environ.get("HF_SFT_DATASET_REPO", f"{HF_USERNAME}/ci-triage-sft")
MODEL_REPO = os.environ.get("HF_MODEL_REPO", f"{HF_USERNAME}/ci-triage-agent")
GRPO_STEPS = int(os.environ.get("GRPO_STEPS", "100"))
SKIP_SFT = os.environ.get("SKIP_SFT", "0") == "1"

# /data — presumably the Space's persistent volume, so downloads and
# checkpoints survive restarts (TODO confirm against Space settings).
DATA_ROOT = Path("/data")
SCEN_DIR = DATA_ROOT / "scenarios"
SFT_DS_DIR = DATA_ROOT / "sft_dataset"
SFT_CKPT = DATA_ROOT / "checkpoints" / "sft"
GRPO_CKPT = DATA_ROOT / "checkpoints" / "grpo"
# ── auth ──────────────────────────────────────────────────────────────────────
# Authenticate with HF Hub (required for downloads and pushes below) and,
# when a W&B key was provided, with Weights & Biases for run monitoring.
from huggingface_hub import login
login(token=HF_TOKEN)
if WANDB_KEY:
    import wandb
    wandb.login(key=WANDB_KEY)
    os.environ["WANDB_PROJECT"] = "ci-triage-env"
else:
    # No key → disable W&B entirely so trainers don't prompt or error.
    os.environ["WANDB_DISABLED"] = "true"
# ── Step 1: download scenario corpus ──────────────────────────────────────────
# Skip the download when the persistent dir already holds at least one JSON
# scenario; `any(rglob(...))` short-circuits on the first hit.
if not SCEN_DIR.exists() or not any(SCEN_DIR.rglob("*.json")):
    print(f"\n[1/4] Downloading scenarios from {SCENARIOS_REPO} …")
    from huggingface_hub import snapshot_download
    snapshot_download(
        repo_id=SCENARIOS_REPO,
        repo_type="dataset",
        local_dir=str(SCEN_DIR),
        token=HF_TOKEN,
    )
else:
    n = sum(1 for _ in SCEN_DIR.rglob("*.json"))
    print(f"\n[1/4] Scenarios already present ({n} files) — skipping download.")

# Prefer scenarios under a train/ subtree; fall back to every JSON in the corpus.
train_scen = list(SCEN_DIR.rglob("train/**/*.json")) or list(SCEN_DIR.rglob("*.json"))
print(f" Train scenarios available: {len(train_scen)}")
# ── Step 2: download SFT dataset ──────────────────────────────────────────────
# Either path leaves `ds` bound for use in Step 3.
if not SFT_DS_DIR.exists():
    print(f"\n[2/4] Downloading SFT dataset from {SFT_DATASET_REPO} …")
    from datasets import load_dataset
    ds = load_dataset(SFT_DATASET_REPO, split="train", token=HF_TOKEN)
    SFT_DS_DIR.mkdir(parents=True, exist_ok=True)
    # Cache to the persistent volume so a restart reloads from disk instead.
    ds.save_to_disk(str(SFT_DS_DIR))
    print(f" {len(ds)} SFT examples saved.")
else:
    from datasets import load_from_disk
    ds = load_from_disk(str(SFT_DS_DIR))
    print(f"\n[2/4] SFT dataset already present ({len(ds)} examples) — skipping download.")
# ── Step 3: SFT warmstart ─────────────────────────────────────────────────────
if SKIP_SFT and SFT_CKPT.exists():
    print(f"\n[3/4] SKIP_SFT=1 and checkpoint found at {SFT_CKPT} — skipping SFT.")
else:
    print(f"\n[3/4] SFT warmstart — {len(ds)} examples, A10G-optimised settings …")
    from ci_triage_env.training.sft import run_sft
    run_sft(
        dataset_path=str(SFT_DS_DIR),
        output_dir=str(SFT_CKPT),
        num_epochs=2,
        per_device_batch_size=4,  # 46 GB → fit 4 sequences comfortably
        gradient_accumulation_steps=4,  # effective batch = 16
    )
    print(f" SFT done → {SFT_CKPT}")
    # Push SFT checkpoint immediately so it's saved even if GRPO fails
    print(" Pushing SFT checkpoint to HF Hub …")
    from huggingface_hub import create_repo, upload_folder
    # upload_folder does not create the target repo — ensure it exists first
    # (exist_ok=True makes this a no-op on reruns).
    create_repo(MODEL_REPO + "-sft", repo_type="model", token=HF_TOKEN, exist_ok=True)
    upload_folder(
        folder_path=str(SFT_CKPT),
        repo_id=MODEL_REPO + "-sft",
        repo_type="model",
        token=HF_TOKEN,
        commit_message="SFT warmstart checkpoint",
    )
# ── Step 4: GRPO fine-tuning ──────────────────────────────────────────────────
print(f"\n[4/4] GRPO training — {GRPO_STEPS} steps, MockEnvClient in-process …")
print(" Monitoring: https://wandb.ai (search project ci-triage-env)")
from ci_triage_env.training.mock_env_client import MockEnvClient
from ci_triage_env.training.grpo import run_grpo

env_client = MockEnvClient(scenarios_dir=str(SCEN_DIR / "train"))
print(f" Loaded {len(env_client.scenario_ids)} train scenarios into MockEnvClient")

# A10G Large optimised hyperparams.
# max_turns=4 + max_completion_length=256 keeps each rollout to ~15 sec so
# 100 steps × 4 rollouts ≈ 100 min total → fits the 2-3 hour budget.
run_grpo(
    sft_checkpoint_dir=str(SFT_CKPT),
    output_dir=str(GRPO_CKPT),
    total_steps=GRPO_STEPS,
    env_client=env_client,
    scenarios_train_path=str(SCEN_DIR / "train"),
    hyperparams={
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 4,  # effective batch = 4
        "num_generations": 4,
        "max_prompt_length": 2048,
        "max_completion_length": 256,
        "learning_rate": 5e-6,
        "kl_coef": 0.04,
        "temperature": 0.8,
        "top_p": 0.95,
        "logging_steps": 5,
        "save_steps": 50,
        # Only report to W&B when a key was actually provided.
        "report_to": "wandb" if WANDB_KEY else "none",
    },
)
print(f" GRPO done → {GRPO_CKPT}")
# ── Push final model ──────────────────────────────────────────────────────────
print(f"\n[done] Pushing final model to {MODEL_REPO} …")
from huggingface_hub import create_repo, upload_folder

# upload_folder does not create the target repo — ensure it exists first
# (exist_ok=True makes this a no-op on reruns).
create_repo(MODEL_REPO, repo_type="model", token=HF_TOKEN, exist_ok=True)
upload_folder(
    folder_path=str(GRPO_CKPT),
    repo_id=MODEL_REPO,
    repo_type="model",
    token=HF_TOKEN,
    commit_message=f"GRPO-trained adapter — {GRPO_STEPS} steps",
)
print(f" Model at: https://huggingface.co/{MODEL_REPO}")
print("\nTraining complete.")