ronitraj committed on
Commit
2d520b3
verified
1 Parent(s): 0139454

Upload notebooks/colab_train.ipynb with huggingface_hub

  1. notebooks/colab_train.ipynb +284 -0
notebooks/colab_train.ipynb ADDED
@@ -0,0 +1,284 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Qubit-Medic - end-to-end Colab notebook\n",
+ "\n",
+ "Runs SFT warm-up + GRPO RL on a single Colab T4. Total wall-clock: ~24 hours\n",
+ "(SFT ~30 min, GRPO ~22 hours, eval ~30 min). The notebook is structured so\n",
+ "every cell is idempotent and re-runnable.\n",
+ "\n",
+ "**W&B integration is on by default.** Every stage (format-test, SFT, GRPO,\n",
+ "eval) logs to the same W&B project (`qubit-medic`) and shares a `--wandb-group`\n",
+ "so the runs appear together in the dashboard. Set `WANDB_DISABLED=1` if you\n",
+ "want to skip W&B entirely."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Clone the repo and install"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%cd /content\n",
+ "!git clone https://github.com/qubit-medic/qubit-medic.git || (cd qubit-medic && git pull)\n",
+ "%cd qubit-medic\n",
+ "!pip install -q -r requirements.txt\n",
+ "!pip install -q -r requirements-train.txt\n",
+ "!pip install -q --no-deps unsloth"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Configure W&B\n",
+ "\n",
+ "Paste your API key from <https://wandb.ai/authorize>. The `EXPERIMENT_GROUP`\n",
+ "below is what bundles the format-test, SFT, GRPO, and eval runs together\n",
+ "on the dashboard - bump it for each new experiment."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os, datetime\n",
+ "EXPERIMENT_GROUP = f\"colab-{datetime.datetime.now(datetime.timezone.utc).strftime('%Y%m%d-%H%M')}\"\n",
+ "os.environ['WANDB_PROJECT'] = 'qubit-medic'\n",
+ "# os.environ['WANDB_ENTITY'] = 'your-team' # uncomment if you use a team\n",
+ "# os.environ['WANDB_DISABLED'] = '1' # uncomment to skip W&B\n",
+ "print('experiment group:', EXPERIMENT_GROUP)\n",
+ "!wandb login"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Validate the environment\n",
+ "\n",
+ "All five gates must pass before going further. (No W&B logging here - this\n",
+ "is a static check.)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!python -m scripts.validate_env"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Section 1.3 - format-test (existential go/no-go)\n",
+ "\n",
+ "If parseable rate is below 30%, SFT is mandatory. The result is logged to\n",
+ "W&B under `format_test/*`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!python -m scripts.format_test \\\n",
+ " --backend unsloth \\\n",
+ " --model Qwen/Qwen2.5-3B-Instruct \\\n",
+ " --syndromes 10 --samples-per 3 \\\n",
+ " --out data/format_test.json \\\n",
+ " --report-to wandb \\\n",
+ " --wandb-group {EXPERIMENT_GROUP}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Generate SFT data (5,000 syndromes, ~5 min)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!python -m scripts.generate_sft_data --n 5000 --out data/sft_dataset.jsonl"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 6. SFT warm-up (~30 min on T4)\n",
+ "\n",
+ "Logs `sft/loss`, `sft/parse_success_rate`, and a `sft/generations` table\n",
+ "every 100 steps. Uploads the LoRA adapter dir as a W&B artifact at the end."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!python -m scripts.train_sft \\\n",
+ " --dataset data/sft_dataset.jsonl \\\n",
+ " --output checkpoints/sft_warmup \\\n",
+ " --report-to wandb \\\n",
+ " --wandb-group {EXPERIMENT_GROUP} \\\n",
+ " --wandb-run-name sft-warmup-{EXPERIMENT_GROUP} \\\n",
+ " --wandb-notes 'SFT warm-up on PyMatching-derived syndromes' \\\n",
+ " --sample-every 100 --sample-count 4"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 7. SFT validation gate (Section 6.2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!python -m scripts.eval \\\n",
+ " --adapter checkpoints/sft_warmup \\\n",
+ " --episodes 100 \\\n",
+ " --out data/sft_eval.json \\\n",
+ " --report-to wandb \\\n",
+ " --wandb-group {EXPERIMENT_GROUP} \\\n",
+ " --wandb-run-name eval-sft-{EXPERIMENT_GROUP}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 8. GRPO RL training (~22 hours on T4)\n",
+ "\n",
+ "Logs `rl/reward/<component>_mean|std|min|max` for each of the five reward\n",
+ "components, `rl/parse/*`, `rl/curriculum/*`, plus a generation table and\n",
+ "an in-loop greedy eval every 200 steps. Uploads the trained adapter as a\n",
+ "W&B artifact at the end.\n",
+ "\n",
+ "Adjust `--steps` if your time budget is tighter (~90 steps/hour on a T4, per the ~22-hour estimate for 2,000 steps)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!python -m scripts.train_grpo \\\n",
+ " --sft-checkpoint checkpoints/sft_warmup \\\n",
+ " --output checkpoints/grpo \\\n",
+ " --steps 2000 \\\n",
+ " --report-to wandb \\\n",
+ " --wandb-group {EXPERIMENT_GROUP} \\\n",
+ " --wandb-run-name grpo-{EXPERIMENT_GROUP} \\\n",
+ " --wandb-notes 'GRPO with 5 verifiable rewards' \\\n",
+ " --sample-every 50 --sample-n 8 \\\n",
+ " --inloop-eval-every 200 --inloop-eval-episodes 50"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 9. Final evaluation + headline plots"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!python -m scripts.eval \\\n",
+ " --adapter checkpoints/grpo --episodes 500 \\\n",
+ " --out data/grpo_eval.json \\\n",
+ " --report-to wandb \\\n",
+ " --wandb-group {EXPERIMENT_GROUP} \\\n",
+ " --wandb-run-name eval-grpo-{EXPERIMENT_GROUP}\n",
+ "\n",
+ "!python -m scripts.baseline_policies --episodes 500 --out data/baseline_results.json\n",
+ "!python -m scripts.plot_results --baselines data/baseline_results.json --out-dir figures\n",
+ "!python -m scripts.animate_grid --frames 50"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 10. Optional: Willow real-chip cross-validation (Section 8)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Manually download from https://zenodo.org/record/13359217 and place at data/willow_d3.dem\n",
+ "!python -m scripts.willow_validation --dem data/willow_d3.dem --episodes 1000"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 11. Push to Hugging Face Spaces\n",
+ "\n",
+ "After successful training, push the env + adapters to a Space."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from huggingface_hub import HfApi, login\n",
+ "login() # paste your HF token\n",
+ "api = HfApi()\n",
+ "# Replace with your Space repo id.\n",
+ "api.upload_folder(folder_path='.', repo_id='your-team/qubit-medic', repo_type='space')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
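
The GRPO cell in the notebook suggests lowering `--steps` when the time budget is tight. Here is a minimal sketch of that arithmetic, assuming you measure throughput yourself from the first logged steps of your own run. The helper name `steps_for_budget` and the `round_to` parameter are hypothetical and not part of the repo; rounding down to a multiple of 50 is only an illustrative choice that lines up with `--sample-every 50`.

```python
def steps_for_budget(hours_available: float,
                     steps_per_hour: float,
                     round_to: int = 50) -> int:
    """Largest step count (a multiple of round_to) that fits the time budget.

    hours_available: wall-clock hours you can spend on GRPO training.
    steps_per_hour:  throughput measured on your own hardware (an assumption
                     you supply; it is not read from the training script).
    round_to:        hypothetical rounding granularity, e.g. the sampling interval.
    """
    if hours_available <= 0 or steps_per_hour <= 0 or round_to <= 0:
        raise ValueError("all arguments must be positive")
    raw_steps = hours_available * steps_per_hour
    # Round down so the run never exceeds the budget.
    return int(raw_steps // round_to) * round_to


if __name__ == "__main__":
    # Illustrative numbers only: 10 hours at a measured 90 steps/hour.
    print(steps_for_budget(10, 90))
```

The returned value can then be passed as `--steps` in the `scripts.train_grpo` cell.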