Prasham.Jain Claude Sonnet 4.6 commited on
Commit
11f97d8
·
1 Parent(s): 1134123

feat(training): A10G-optimised pipeline — auto train.py, Dockerfile.train, GH Action sync

Browse files

- Add fastmcp to pyproject.toml (was imported but not declared — env server Dockerfile failed)
- Add train.py: fully automated SFT→GRPO→push script for HF Space auto-run
- Add Dockerfile.train: training Space image (JupyterLab on :7860 or auto train.py)
- Add train-entrypoint.sh: START_MODE=jupyter|auto switch
- Add .github/workflows/sync_hf_space.yml: push main → HF env-server Space on every commit
- Rewrite train_grpo.ipynb: remove google.colab, fix for HF Spaces env vars, remove
unnecessary env server subprocess, tune hyperparams for 46 GB VRAM
- grpo.py: pass max_turns through hyperparams (default 4 for fast GRPO episodes)

Timing targets on A10G Large:
SFT (2 epochs, batch 4, grad_accum 4): ~45 min
GRPO (100 steps, 4 rollouts, max_turns=4, 256 completion tokens): ~90 min
Total: ~2.5 hours

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

.github/workflows/sync_hf_space.yml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Sync → HF Space (env server)
2
+
3
+ # Pushes the repo to the env-server HF Space on every commit to main.
4
+ # The Space rebuilds its Docker image automatically, which downloads
5
+ # scenarios from HF Hub and restarts the env server.
6
+ #
7
+ # Required GitHub secrets:
8
+ # HF_TOKEN - HuggingFace write token (Settings → Secrets → Actions)
9
+ # HF_USERNAME - your HuggingFace username
10
+ # ENV_SPACE_NAME - name of your env-server Space (e.g. "ci-triage-env")
11
+ #
12
+ # The training Space is NOT auto-synced here (rebuilding mid-training would
13
+ # kill a running job). Manually push to it when you want to update.
14
+
15
+ on:
16
+ push:
17
+ branches: [main]
18
+
19
+ jobs:
20
+ sync-env-space:
21
+ runs-on: ubuntu-latest
22
+ steps:
23
+ - name: Checkout repo (full history)
24
+ uses: actions/checkout@v4
25
+ with:
26
+ fetch-depth: 0
27
+ lfs: true
28
+
29
+ - name: Push to HF Space
30
+ env:
31
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
32
+ HF_USERNAME: ${{ secrets.HF_USERNAME }}
33
+ ENV_SPACE_NAME: ${{ secrets.ENV_SPACE_NAME }}
34
+ run: |
35
+ git config --global user.email "github-action@ci-triage"
36
+ git config --global user.name "CI Triage Sync"
37
+
38
+ REMOTE="https://${HF_USERNAME}:${HF_TOKEN}@huggingface.co/spaces/${HF_USERNAME}/${ENV_SPACE_NAME}"
39
+ git remote add hf-env "$REMOTE" 2>/dev/null || git remote set-url hf-env "$REMOTE"
40
+
41
+ # Force-push main → Space repo (Space will auto-rebuild Docker image)
42
+ git push hf-env HEAD:main --force
43
+ echo "✓ Pushed to https://huggingface.co/spaces/${HF_USERNAME}/${ENV_SPACE_NAME}"
Dockerfile.train ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Training Space Dockerfile — JupyterLab + auto-run on port 7860.
2
+ #
3
+ # Two modes (controlled by START_MODE env var in Space settings):
4
+ # START_MODE=jupyter → opens JupyterLab so you can run train_grpo.ipynb manually
5
+ # START_MODE=auto → runs train.py immediately, no interaction needed
6
+ #
7
+ # HF Space secrets to set:
8
+ # HF_TOKEN, HF_USERNAME, WANDB_API_KEY
9
+ # HF_SCENARIOS_REPO, HF_SFT_DATASET_REPO, HF_MODEL_REPO (optional)
10
+ # GRPO_STEPS (optional, default 100)
11
+
12
+ FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-devel
13
+
14
+ ENV DEBIAN_FRONTEND=noninteractive
15
+ ENV PYTHONUNBUFFERED=1
16
+
17
+ RUN apt-get update && apt-get install -y --no-install-recommends \
18
+ git curl build-essential \
19
+ && rm -rf /var/lib/apt/lists/*
20
+
21
+ WORKDIR /workspace
22
+
23
+ # 1. Install unsloth (must come after torch, hence not in pyproject extras)
24
+ RUN pip install --no-cache-dir \
25
+ "unsloth[cu121-torch240] @ git+https://github.com/unslothai/unsloth.git"
26
+
27
+ # 2. Install project + all training deps
28
+ COPY pyproject.toml ./
29
+ COPY src/ src/
30
+ RUN pip install --no-cache-dir -e ".[data,training]"
31
+
32
+ # 3. JupyterLab for interactive mode
33
+ RUN pip install --no-cache-dir jupyterlab ipywidgets
34
+
35
+ # 4. Copy notebooks and training scripts
36
+ COPY notebooks/ notebooks/
37
+ COPY train.py ./
38
+
39
+ # Persistent storage expected at /data (attach 20 GB disk in Space settings)
40
+ RUN mkdir -p /data/checkpoints /data/scenarios /data/sft_dataset
41
+
42
+ EXPOSE 7860
43
+
44
+ ENV START_MODE=jupyter
45
+
46
+ COPY train-entrypoint.sh /train-entrypoint.sh
47
+ RUN chmod +x /train-entrypoint.sh
48
+ ENTRYPOINT ["/train-entrypoint.sh"]
notebooks/train_grpo.ipynb CHANGED
@@ -2,176 +2,250 @@
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
 
5
  "metadata": {},
6
  "source": [
7
- "# CI-Triage-Env — GRPO Training Notebook\n",
8
  "\n",
9
- "Colab-runnable end-to-end training pipeline:\n",
 
 
10
  "1. Install dependencies\n",
11
- "2. Pull scenario corpus from HF Hub\n",
12
- "3. Start env server\n",
13
- "4. SFT warmstart on C3 trajectory dataset\n",
14
- "5. GRPO smoke test (100 steps)\n",
15
- "6. Full GRPO (3000 steps)\n",
16
- "7. Push adapter to HF Hub\n",
17
- "\n",
18
- "**Prerequisites**: `HF_TOKEN`, `OPENAI_API_KEY`, `WANDB_API_KEY` set as Colab secrets."
 
 
 
 
 
 
19
  ]
20
  },
21
  {
22
  "cell_type": "code",
23
  "execution_count": null,
 
24
  "metadata": {},
25
  "outputs": [],
26
  "source": [
27
- "# Cell 1: Install dependencies\n",
28
- "!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121\n",
29
- "!pip install -q unsloth trl transformers accelerate peft\n",
30
- "!pip install -q wandb datasets huggingface_hub openai httpx fastapi uvicorn pydantic jsonschema\n",
31
- "!pip install -q -e . # install ci_triage_env package in editable mode"
 
 
 
 
 
 
 
 
 
 
 
32
  ]
33
  },
34
  {
35
  "cell_type": "code",
36
  "execution_count": null,
 
37
  "metadata": {},
38
  "outputs": [],
39
  "source": [
40
- "# Cell 2: Environment setup\n",
41
  "import os\n",
42
- "from google.colab import userdata\n",
 
43
  "\n",
44
- "os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')\n",
45
- "os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')\n",
46
- "os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')\n",
47
- "os.environ['WANDB_PROJECT'] = 'ci-triage-env'\n",
48
  "\n",
49
- "import wandb\n",
50
- "wandb.login()"
51
- ]
52
- },
53
- {
54
- "cell_type": "code",
55
- "execution_count": null,
56
- "metadata": {},
57
- "outputs": [],
58
- "source": [
59
- "# Cell 3: Pull scenario corpus from HF dataset hub\n",
60
- "# Replace YOUR_ORG with your HuggingFace org/username\n",
61
- "HF_DATASET_REPO = 'YOUR_ORG/ci-triage-scenarios'\n",
62
- "HF_MODEL_REPO = 'YOUR_ORG/ci-triage-trained-qwen3.5-4b'\n",
63
  "\n",
64
- "from huggingface_hub import snapshot_download\n",
65
- "scen_dir = snapshot_download(HF_DATASET_REPO, repo_type='dataset',\n",
66
- " local_dir='data_artifacts/scenarios')\n",
67
- "print(f'Scenarios downloaded to {scen_dir}')"
 
68
  ]
69
  },
70
  {
71
  "cell_type": "code",
72
  "execution_count": null,
 
73
  "metadata": {},
74
  "outputs": [],
75
  "source": [
76
- "# Cell 4: Start env server in background\n",
77
- "import subprocess, time\n",
78
- "server_proc = subprocess.Popen(\n",
79
- " ['python', '-m', 'ci_triage_env.env.server'],\n",
80
- " stdout=subprocess.PIPE, stderr=subprocess.PIPE\n",
81
- ")\n",
82
- "time.sleep(4) # give server time to start\n",
83
- "print('Env server started, PID:', server_proc.pid)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  ]
85
  },
86
  {
87
  "cell_type": "code",
88
  "execution_count": null,
 
89
  "metadata": {},
90
  "outputs": [],
91
  "source": [
92
- "# Cell 5: Generate SFT trajectories (skip if already done)\n",
93
- "import os\n",
94
- "if not os.path.exists('data_artifacts/sft_dataset'):\n",
95
- " from ci_triage_env.training.trajectory_gen import main as traj_main\n",
96
- " traj_main([\n",
97
- " '--count', '600',\n",
98
- " '--model', 'gpt-4o-mini',\n",
99
- " '--budget', '25.0',\n",
100
- " '--output', 'data_artifacts/sft_dataset/',\n",
101
- " ])\n",
102
  "else:\n",
103
- " print('SFT dataset already exists, skipping generation.')"
 
 
 
104
  ]
105
  },
106
  {
107
  "cell_type": "code",
108
  "execution_count": null,
 
109
  "metadata": {},
110
  "outputs": [],
111
  "source": [
112
- "# Cell 6: SFT warmstart\n",
 
 
113
  "from ci_triage_env.training.sft import run_sft\n",
114
  "\n",
115
- "run_sft(\n",
116
- " dataset_path='data_artifacts/sft_dataset/',\n",
117
- " output_dir='checkpoints/sft/',\n",
118
- " num_epochs=3,\n",
119
- " per_device_batch_size=1,\n",
120
- " gradient_accumulation_steps=4,\n",
121
- ")\n",
122
- "print('SFT complete. Checkpoint at checkpoints/sft/')"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  ]
124
  },
125
  {
126
  "cell_type": "code",
127
  "execution_count": null,
 
128
  "metadata": {},
129
  "outputs": [],
130
  "source": [
131
- "# Cell 7: GRPO smoke test (100 steps, ~30 min)\n",
 
 
 
 
 
 
 
 
 
132
  "from ci_triage_env.training.grpo import run_grpo\n",
133
  "\n",
 
 
 
 
 
 
 
 
134
  "run_grpo(\n",
135
- " sft_checkpoint_dir='checkpoints/sft/',\n",
136
- " output_dir='checkpoints/grpo_smoke/',\n",
137
- " total_steps=100,\n",
138
- ")\n",
139
- "print('Smoke test complete. Check W&B for reward curve.')"
140
- ]
141
- },
142
- {
143
- "cell_type": "code",
144
- "execution_count": null,
145
- "metadata": {},
146
- "outputs": [],
147
- "source": [
148
- "# Cell 8: Full GRPO run (3000 steps, ~30h wall-clock)\n",
149
- "# Monitor: https://wandb.ai/<entity>/ci-triage-env\n",
150
- "# Hard-stop rules: see plan/branch-c-reward-training/phase-c4.md\n",
151
- "run_grpo(\n",
152
- " sft_checkpoint_dir='checkpoints/sft/',\n",
153
- " output_dir='checkpoints/grpo_full/',\n",
154
- " total_steps=3000,\n",
 
 
155
  ")\n",
156
- "print('Full GRPO complete.')"
157
  ]
158
  },
159
  {
160
  "cell_type": "code",
161
  "execution_count": null,
 
162
  "metadata": {},
163
  "outputs": [],
164
  "source": [
165
- "# Cell 9: Push trained adapter to HF Hub\n",
166
  "from huggingface_hub import upload_folder\n",
167
  "\n",
168
  "upload_folder(\n",
169
- " folder_path='checkpoints/grpo_full/',\n",
170
- " repo_id=HF_MODEL_REPO,\n",
171
  " repo_type='model',\n",
172
- " commit_message='CI-Triage-Env GRPO-trained Qwen3.5-4B adapter',\n",
 
173
  ")\n",
174
- "print(f'Adapter pushed to https://huggingface.co/{HF_MODEL_REPO}')"
175
  ]
176
  }
177
  ],
 
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
5
+ "id": "intro",
6
  "metadata": {},
7
  "source": [
8
+ "# CI-Triage-Env — GRPO Training\n",
9
  "\n",
10
+ "**Hardware target**: A10G Large (46 GB VRAM, 12 vCPU) — HuggingFace Space\n",
11
+ "\n",
12
+ "Pipeline:\n",
13
  "1. Install dependencies\n",
14
+ "2. Authenticate (HF + W&B)\n",
15
+ "3. Pull scenario corpus from HF Hub\n",
16
+ "4. Pull SFT dataset from HF Hub\n",
17
+ "5. SFT warmstart (~45 min)\n",
18
+ "6. GRPO fine-tuning (~90 min for 100 steps)\n",
19
+ "7. Push final model to HF Hub\n",
20
+ "\n",
21
+ "**Set these as Space secrets** (Settings Variables and secrets):\n",
22
+ "- `HF_TOKEN` — HuggingFace write token\n",
23
+ "- `HF_USERNAME` — your HF username\n",
24
+ "- `WANDB_API_KEY` — Weights & Biases key (get free at wandb.ai)\n",
25
+ "\n",
26
+ "**Time budget**: SFT≈45 min + GRPO≈90 min = ~2.5 hours on A10G Large.\n",
27
+ "Monitor training live at https://wandb.ai (project: `ci-triage-env`)."
28
  ]
29
  },
30
  {
31
  "cell_type": "code",
32
  "execution_count": null,
33
+ "id": "cell-install",
34
  "metadata": {},
35
  "outputs": [],
36
  "source": [
37
+ "# Cell 1 Install deps (run once; ~10 min including unsloth compile)\n",
38
+ "import subprocess, sys\n",
39
+ "\n",
40
+ "def run(cmd):\n",
41
+ " result = subprocess.run(cmd, shell=True, capture_output=True, text=True)\n",
42
+ " if result.returncode != 0:\n",
43
+ " print(result.stderr[-2000:])\n",
44
+ " raise RuntimeError(f'Command failed: {cmd}')\n",
45
+ " return result.stdout\n",
46
+ "\n",
47
+ "# PyTorch is pre-installed in the Space Docker image; install the rest\n",
48
+ "run('pip install -q \"unsloth[cu121-torch240] @ git+https://github.com/unslothai/unsloth.git\"')\n",
49
+ "run('pip install -q trl>=0.11 transformers>=4.45 accelerate>=0.30 peft')\n",
50
+ "run('pip install -q wandb datasets huggingface_hub')\n",
51
+ "run('pip install -q -e /workspace') # install ci_triage_env package\n",
52
+ "print('All dependencies installed.')"
53
  ]
54
  },
55
  {
56
  "cell_type": "code",
57
  "execution_count": null,
58
+ "id": "cell-auth",
59
  "metadata": {},
60
  "outputs": [],
61
  "source": [
62
+ "# Cell 2 Authenticate\n",
63
  "import os\n",
64
+ "from huggingface_hub import login\n",
65
+ "import wandb\n",
66
  "\n",
67
+ "HF_TOKEN = os.environ['HF_TOKEN']\n",
68
+ "HF_USERNAME = os.environ['HF_USERNAME']\n",
69
+ "WANDB_KEY = os.environ.get('WANDB_API_KEY', '')\n",
 
70
  "\n",
71
+ "login(token=HF_TOKEN)\n",
72
+ "if WANDB_KEY:\n",
73
+ " wandb.login(key=WANDB_KEY)\n",
74
+ " os.environ['WANDB_PROJECT'] = 'ci-triage-env'\n",
75
+ "else:\n",
76
+ " os.environ['WANDB_DISABLED'] = 'true'\n",
77
+ " print('W&B disabled — set WANDB_API_KEY secret to enable')\n",
 
 
 
 
 
 
 
78
  "\n",
79
+ "# Repo names (edit if you used different names)\n",
80
+ "SCENARIOS_REPO = f'{HF_USERNAME}/ci-triage-scenarios'\n",
81
+ "SFT_DATASET_REPO = f'{HF_USERNAME}/ci-triage-sft'\n",
82
+ "MODEL_REPO = f'{HF_USERNAME}/ci-triage-agent'\n",
83
+ "print(f'Authenticated as {HF_USERNAME}')"
84
  ]
85
  },
86
  {
87
  "cell_type": "code",
88
  "execution_count": null,
89
+ "id": "cell-scenarios",
90
  "metadata": {},
91
  "outputs": [],
92
  "source": [
93
+ "# Cell 3 Download scenario corpus from HF Hub\n",
94
+ "from pathlib import Path\n",
95
+ "from huggingface_hub import snapshot_download\n",
96
+ "\n",
97
+ "SCEN_DIR = Path('/data/scenarios')\n",
98
+ "SCEN_DIR.mkdir(parents=True, exist_ok=True)\n",
99
+ "\n",
100
+ "existing = list(SCEN_DIR.rglob('*.json'))\n",
101
+ "if existing:\n",
102
+ " print(f'Scenarios already present: {len(existing)} files — skipping download')\n",
103
+ "else:\n",
104
+ " snapshot_download(\n",
105
+ " repo_id=SCENARIOS_REPO,\n",
106
+ " repo_type='dataset',\n",
107
+ " local_dir=str(SCEN_DIR),\n",
108
+ " token=HF_TOKEN,\n",
109
+ " )\n",
110
+ " n = len(list(SCEN_DIR.rglob('*.json')))\n",
111
+ " print(f'Downloaded {n} scenario files')\n",
112
+ "\n",
113
+ "train_dir = SCEN_DIR / 'train'\n",
114
+ "print(f'Train scenarios: {len(list(train_dir.rglob(\"*.json\")))}')"
115
  ]
116
  },
117
  {
118
  "cell_type": "code",
119
  "execution_count": null,
120
+ "id": "cell-sft-ds",
121
  "metadata": {},
122
  "outputs": [],
123
  "source": [
124
+ "# Cell 4 Download SFT dataset from HF Hub\n",
125
+ "from datasets import load_dataset, load_from_disk\n",
126
+ "\n",
127
+ "SFT_DS_DIR = Path('/data/sft_dataset')\n",
128
+ "\n",
129
+ "if SFT_DS_DIR.exists():\n",
130
+ " ds = load_from_disk(str(SFT_DS_DIR))\n",
131
+ " print(f'SFT dataset already present: {len(ds)} examples')\n",
 
 
132
  "else:\n",
133
+ " ds = load_dataset(SFT_DATASET_REPO, split='train', token=HF_TOKEN)\n",
134
+ " SFT_DS_DIR.mkdir(parents=True, exist_ok=True)\n",
135
+ " ds.save_to_disk(str(SFT_DS_DIR))\n",
136
+ " print(f'Downloaded {len(ds)} SFT examples')"
137
  ]
138
  },
139
  {
140
  "cell_type": "code",
141
  "execution_count": null,
142
+ "id": "cell-sft",
143
  "metadata": {},
144
  "outputs": [],
145
  "source": [
146
+ "# Cell 5 SFT warmstart\n",
147
+ "# Estimated time: ~45 min on A10G Large\n",
148
+ "# Optimised: batch_size=4, grad_accum=4 → effective batch 16, bf16 via unsloth\n",
149
  "from ci_triage_env.training.sft import run_sft\n",
150
  "\n",
151
+ "SFT_CKPT = Path('/data/checkpoints/sft')\n",
152
+ "\n",
153
+ "if SFT_CKPT.exists():\n",
154
+ " print(f'SFT checkpoint found at {SFT_CKPT} — skipping (delete to retrain)')\n",
155
+ "else:\n",
156
+ " run_sft(\n",
157
+ " dataset_path=str(SFT_DS_DIR),\n",
158
+ " output_dir=str(SFT_CKPT),\n",
159
+ " num_epochs=2,\n",
160
+ " per_device_batch_size=4, # A10G Large has 46 GB — fits 4 sequences\n",
161
+ " gradient_accumulation_steps=4, # effective batch = 16\n",
162
+ " )\n",
163
+ " print(f'SFT done → {SFT_CKPT}')\n",
164
+ "\n",
165
+ " # Push immediately so checkpoint is safe even if GRPO fails\n",
166
+ " from huggingface_hub import upload_folder\n",
167
+ " upload_folder(\n",
168
+ " folder_path=str(SFT_CKPT),\n",
169
+ " repo_id=MODEL_REPO + '-sft',\n",
170
+ " repo_type='model',\n",
171
+ " token=HF_TOKEN,\n",
172
+ " commit_message='SFT warmstart checkpoint (Qwen3.5-4B + LoRA)',\n",
173
+ " )\n",
174
+ " print(f'SFT checkpoint pushed to {MODEL_REPO}-sft')"
175
  ]
176
  },
177
  {
178
  "cell_type": "code",
179
  "execution_count": null,
180
+ "id": "cell-grpo",
181
  "metadata": {},
182
  "outputs": [],
183
  "source": [
184
+ "# Cell 6 GRPO fine-tuning\n",
185
+ "# Estimated time: ~90 min for 100 steps on A10G Large\n",
186
+ "#\n",
187
+ "# Why 100 steps? Each step = 4 multi-turn rollouts (max 4 tool calls each).\n",
188
+ "# Sequential rollout with model.generate() is the bottleneck: ~50 sec/step.\n",
189
+ "# Increase GRPO_STEPS if you have more time budget.\n",
190
+ "#\n",
191
+ "# MockEnvClient is used in-process — no server needed, full speed.\n",
192
+ "\n",
193
+ "from ci_triage_env.training.mock_env_client import MockEnvClient\n",
194
  "from ci_triage_env.training.grpo import run_grpo\n",
195
  "\n",
196
+ "GRPO_CKPT = Path('/data/checkpoints/grpo')\n",
197
+ "GRPO_STEPS = 100 # increase to 200 if you have ~3 hours total\n",
198
+ "\n",
199
+ "env_client = MockEnvClient(scenarios_dir=str(SCEN_DIR / 'train'))\n",
200
+ "print(f'MockEnvClient loaded {len(env_client.scenario_ids)} train scenarios')\n",
201
+ "print(f'Starting GRPO — {GRPO_STEPS} steps, ~{GRPO_STEPS * 50 // 60} min estimated')\n",
202
+ "print('Monitor: https://wandb.ai (project: ci-triage-env)')\n",
203
+ "\n",
204
  "run_grpo(\n",
205
+ " sft_checkpoint_dir=str(SFT_CKPT),\n",
206
+ " output_dir=str(GRPO_CKPT),\n",
207
+ " total_steps=GRPO_STEPS,\n",
208
+ " env_client=env_client,\n",
209
+ " scenarios_train_path=str(SCEN_DIR / 'train'),\n",
210
+ " hyperparams={\n",
211
+ " # ── training update (fast) ──────────────────────\n",
212
+ " 'per_device_train_batch_size': 1,\n",
213
+ " 'gradient_accumulation_steps': 4, # effective batch = 4\n",
214
+ " 'learning_rate': 5e-6,\n",
215
+ " 'kl_coef': 0.04,\n",
216
+ " # ── rollout generation (bottleneck) ────────────\n",
217
+ " 'num_generations': 4, # 4 rollouts per training sample\n",
218
+ " 'max_prompt_length': 2048,\n",
219
+ " 'max_completion_length': 256, # short = fast; CI responses are concise\n",
220
+ " 'temperature': 0.8,\n",
221
+ " 'top_p': 0.95,\n",
222
+ " # ── logging ────────────────────────────────────\n",
223
+ " 'logging_steps': 5,\n",
224
+ " 'save_steps': 50,\n",
225
+ " 'report_to': 'wandb' if WANDB_KEY else 'none',\n",
226
+ " },\n",
227
  ")\n",
228
+ "print(f'GRPO done → {GRPO_CKPT}')"
229
  ]
230
  },
231
  {
232
  "cell_type": "code",
233
  "execution_count": null,
234
+ "id": "cell-push",
235
  "metadata": {},
236
  "outputs": [],
237
  "source": [
238
+ "# Cell 7 Push final model to HF Hub\n",
239
  "from huggingface_hub import upload_folder\n",
240
  "\n",
241
  "upload_folder(\n",
242
+ " folder_path=str(GRPO_CKPT),\n",
243
+ " repo_id=MODEL_REPO,\n",
244
  " repo_type='model',\n",
245
+ " token=HF_TOKEN,\n",
246
+ " commit_message=f'GRPO-trained adapter — {GRPO_STEPS} steps on A10G Large',\n",
247
  ")\n",
248
+ "print(f'Final model: https://huggingface.co/{MODEL_REPO}')"
249
  ]
250
  }
251
  ],
pyproject.toml CHANGED
@@ -14,6 +14,7 @@ dependencies = [
14
  "huggingface_hub>=0.23",
15
  "jsonschema>=4.21",
16
  "openenv-core>=0.2.3",
 
17
  ]
18
 
19
  [project.optional-dependencies]
 
14
  "huggingface_hub>=0.23",
15
  "jsonschema>=4.21",
16
  "openenv-core>=0.2.3",
17
+ "fastmcp>=0.4",
18
  ]
19
 
20
  [project.optional-dependencies]
src/ci_triage_env/training/grpo.py CHANGED
@@ -64,10 +64,12 @@ def run_grpo(
64
  train_dir = Path(scenarios_train_path)
65
  scenario_ids = [p.stem for p in train_dir.rglob("*.json")] if train_dir.exists() else []
66
 
 
67
  rollout = TrainingRollout(
68
  env_client=env_client,
69
  scenarios_train=scenario_ids,
70
  weights=weights_override,
 
71
  )
72
 
73
  model, tokenizer = load_model_for_sft(model_name=sft_checkpoint_dir)
 
64
  train_dir = Path(scenarios_train_path)
65
  scenario_ids = [p.stem for p in train_dir.rglob("*.json")] if train_dir.exists() else []
66
 
67
+ max_turns = hp.pop("max_turns", 4) # short episodes for faster GRPO
68
  rollout = TrainingRollout(
69
  env_client=env_client,
70
  scenarios_train=scenario_ids,
71
  weights=weights_override,
72
+ max_turns=max_turns,
73
  )
74
 
75
  model, tokenizer = load_model_for_sft(model_name=sft_checkpoint_dir)
train-entrypoint.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ if [[ "${START_MODE:-jupyter}" == "auto" ]]; then
5
+ echo "[train-entrypoint] START_MODE=auto — running train.py"
6
+ exec python /workspace/train.py
7
+ else
8
+ echo "[train-entrypoint] START_MODE=jupyter — launching JupyterLab on :7860"
9
+ exec jupyter lab \
10
+ --ip=0.0.0.0 \
11
+ --port=7860 \
12
+ --no-browser \
13
+ --allow-root \
14
+ --NotebookApp.token="" \
15
+ --NotebookApp.password="" \
16
+ --notebook-dir=/workspace
17
+ fi
train.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Automated end-to-end training script for HF Spaces.
2
+
3
+ Runs: scenario download → SFT warmstart → GRPO fine-tuning → push to HF Hub.
4
+ All config comes from environment variables (set as Space secrets).
5
+
6
+ Optimised for A10G Large (46 GB VRAM, 12 vCPU).
7
+
8
+ Required env vars:
9
+ HF_TOKEN - HuggingFace write token
10
+ HF_USERNAME - your HF username
11
+ WANDB_API_KEY - Weights & Biases API key
12
+
13
+ Optional:
14
+ HF_SCENARIOS_REPO - default: {HF_USERNAME}/ci-triage-scenarios
15
+ HF_SFT_DATASET_REPO - default: {HF_USERNAME}/ci-triage-sft
16
+ HF_MODEL_REPO - default: {HF_USERNAME}/ci-triage-agent
17
+ GRPO_STEPS - default: 100 (set lower to finish faster, higher for more training)
18
+ SKIP_SFT - set to "1" to skip SFT and jump straight to GRPO (if checkpoint exists)
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import os
24
+ import sys
25
+ from pathlib import Path
26
+
27
+ # ── resolve config ────────────────────────────────────────────────────────────
28
+ HF_TOKEN = os.environ["HF_TOKEN"]
29
+ HF_USERNAME = os.environ["HF_USERNAME"]
30
+ WANDB_KEY = os.environ.get("WANDB_API_KEY", "")
31
+
32
+ SCENARIOS_REPO = os.environ.get("HF_SCENARIOS_REPO", f"{HF_USERNAME}/ci-triage-scenarios")
33
+ SFT_DATASET_REPO = os.environ.get("HF_SFT_DATASET_REPO", f"{HF_USERNAME}/ci-triage-sft")
34
+ MODEL_REPO = os.environ.get("HF_MODEL_REPO", f"{HF_USERNAME}/ci-triage-agent")
35
+ GRPO_STEPS = int(os.environ.get("GRPO_STEPS", "100"))
36
+ SKIP_SFT = os.environ.get("SKIP_SFT", "0") == "1"
37
+
38
+ DATA_ROOT = Path("/data")
39
+ SCEN_DIR = DATA_ROOT / "scenarios"
40
+ SFT_DS_DIR = DATA_ROOT / "sft_dataset"
41
+ SFT_CKPT = DATA_ROOT / "checkpoints" / "sft"
42
+ GRPO_CKPT = DATA_ROOT / "checkpoints" / "grpo"
43
+
44
+ # ── auth ──────────────────────────────────────────────────────────────────────
45
+ from huggingface_hub import login
46
+ login(token=HF_TOKEN)
47
+
48
+ if WANDB_KEY:
49
+ import wandb
50
+ wandb.login(key=WANDB_KEY)
51
+ os.environ["WANDB_PROJECT"] = "ci-triage-env"
52
+ else:
53
+ os.environ["WANDB_DISABLED"] = "true"
54
+
55
+ # ── Step 1: download scenario corpus ─────────────────────────────────────────
56
+ if not SCEN_DIR.exists() or not any(SCEN_DIR.rglob("*.json")):
57
+ print(f"\n[1/4] Downloading scenarios from {SCENARIOS_REPO} …")
58
+ from huggingface_hub import snapshot_download
59
+ snapshot_download(
60
+ repo_id=SCENARIOS_REPO,
61
+ repo_type="dataset",
62
+ local_dir=str(SCEN_DIR),
63
+ token=HF_TOKEN,
64
+ )
65
+ else:
66
+ n = sum(1 for _ in SCEN_DIR.rglob("*.json"))
67
+ print(f"\n[1/4] Scenarios already present ({n} files) — skipping download.")
68
+
69
+ train_scen = list(SCEN_DIR.rglob("train/**/*.json")) or list(SCEN_DIR.rglob("*.json"))
70
+ print(f" Train scenarios available: {len(train_scen)}")
71
+
72
+ # ── Step 2: download SFT dataset ─────────────────────────────────────────────
73
+ if not SFT_DS_DIR.exists():
74
+ print(f"\n[2/4] Downloading SFT dataset from {SFT_DATASET_REPO} …")
75
+ from datasets import load_dataset
76
+ ds = load_dataset(SFT_DATASET_REPO, split="train", token=HF_TOKEN)
77
+ SFT_DS_DIR.mkdir(parents=True, exist_ok=True)
78
+ ds.save_to_disk(str(SFT_DS_DIR))
79
+ print(f" {len(ds)} SFT examples saved.")
80
+ else:
81
+ from datasets import load_from_disk
82
+ ds = load_from_disk(str(SFT_DS_DIR))
83
+ print(f"\n[2/4] SFT dataset already present ({len(ds)} examples) — skipping download.")
84
+
85
+ # ── Step 3: SFT warmstart ─────────────────────────────────────────────────────
86
+ if SKIP_SFT and SFT_CKPT.exists():
87
+ print(f"\n[3/4] SKIP_SFT=1 and checkpoint found at {SFT_CKPT} — skipping SFT.")
88
+ else:
89
+ print(f"\n[3/4] SFT warmstart — {len(ds)} examples, A10G-optimised settings …")
90
+ from ci_triage_env.training.sft import run_sft
91
+ run_sft(
92
+ dataset_path=str(SFT_DS_DIR),
93
+ output_dir=str(SFT_CKPT),
94
+ num_epochs=2,
95
+ per_device_batch_size=4, # 46 GB → fit 4 sequences comfortably
96
+ gradient_accumulation_steps=4, # effective batch = 16
97
+ )
98
+ print(f" SFT done → {SFT_CKPT}")
99
+
100
+ # Push SFT checkpoint immediately so it's saved even if GRPO fails
101
+ print(" Pushing SFT checkpoint to HF Hub …")
102
+ from huggingface_hub import upload_folder
103
+ upload_folder(
104
+ folder_path=str(SFT_CKPT),
105
+ repo_id=MODEL_REPO + "-sft",
106
+ repo_type="model",
107
+ token=HF_TOKEN,
108
+ commit_message="SFT warmstart checkpoint",
109
+ )
110
+
111
+ # ── Step 4: GRPO fine-tuning ──────────────────────────────────────────────────
112
+ print(f"\n[4/4] GRPO training — {GRPO_STEPS} steps, MockEnvClient in-process …")
113
+ print(" Monitoring: https://wandb.ai (search project ci-triage-env)")
114
+
115
+ from ci_triage_env.training.mock_env_client import MockEnvClient
116
+ from ci_triage_env.training.grpo import run_grpo
117
+
118
+ env_client = MockEnvClient(scenarios_dir=str(SCEN_DIR / "train"))
119
+ print(f" Loaded {len(env_client.scenario_ids)} train scenarios into MockEnvClient")
120
+
121
+ # A10G Large optimised hyperparams.
122
+ # max_turns=4 + max_completion_length=256 keeps each rollout to ~15 sec so
123
+ # 100 steps × 4 rollouts ≈ 100 min total — fits the 2-3 hour budget.
124
+ run_grpo(
125
+ sft_checkpoint_dir=str(SFT_CKPT),
126
+ output_dir=str(GRPO_CKPT),
127
+ total_steps=GRPO_STEPS,
128
+ env_client=env_client,
129
+ scenarios_train_path=str(SCEN_DIR / "train"),
130
+ hyperparams={
131
+ "per_device_train_batch_size": 1,
132
+ "gradient_accumulation_steps": 4, # effective batch = 4
133
+ "num_generations": 4,
134
+ "max_prompt_length": 2048,
135
+ "max_completion_length": 256,
136
+ "learning_rate": 5e-6,
137
+ "kl_coef": 0.04,
138
+ "temperature": 0.8,
139
+ "top_p": 0.95,
140
+ "logging_steps": 5,
141
+ "save_steps": 50,
142
+ "report_to": "wandb" if WANDB_KEY else "none",
143
+ },
144
+ )
145
+ print(f" GRPO done → {GRPO_CKPT}")
146
+
147
+ # ── Push final model ──────────────────────────────────────────────────────────
148
+ print(f"\n[done] Pushing final model to {MODEL_REPO} …")
149
+ from huggingface_hub import upload_folder
150
+ upload_folder(
151
+ folder_path=str(GRPO_CKPT),
152
+ repo_id=MODEL_REPO,
153
+ repo_type="model",
154
+ token=HF_TOKEN,
155
+ commit_message=f"GRPO-trained adapter — {GRPO_STEPS} steps",
156
+ )
157
+ print(f" Model at: https://huggingface.co/{MODEL_REPO}")
158
+ print("\nTraining complete.")