ronitraj committed on
Commit
16c627e
·
verified ·
1 Parent(s): ff28459

Upload scripts/generate_sft_data.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/generate_sft_data.py +440 -0
scripts/generate_sft_data.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """scripts/generate_sft_data.py - SFT dataset generator (master spec, sec. 1).
2
+
3
+ Locked configuration:
4
+ * Train split: 3,000 examples (default seed 42).
5
+ * Held-out split: 100 examples (seed = train seed + 4242 - independent stream).
6
+ * Curriculum mix: 40% L1_warmup, 50% L2_target, 10% L3_stretch.
7
+
8
+ For each example:
9
+ 1. Pick a curriculum level by the locked mixture.
10
+ 2. Sample a noisy syndrome from Stim (SI1000 noise model).
11
+ 3. Run PyMatching to get the canonical correction (Pauli frame).
12
+ 4. Format the locked prompt + target completion.
13
+ 5. Emit one JSONL record per sample. Records carry ``true_x_errors``,
14
+ ``true_z_errors``, ``actual_observable_flip``, and curriculum info
15
+ so the SFT validation callback can compute every spec metric
16
+ (logical_correction_rate, exact_match_pymatching, hamming_overlap,
17
+ syndrome_consistency, ...) without re-sampling.
18
+
19
+ Output:
20
+ data/sft_dataset.jsonl - training set (3,000 rows)
21
+ data/sft_validation.jsonl - held-out validation (100 rows)
22
+ data/sft_dataset_sample.jsonl - 50-row preview for repo commit
23
+
24
+ Run::
25
+
26
+ python -m scripts.generate_sft_data \
27
+ --n 3000 --val-n 100 \
28
+ --out data/sft_dataset.jsonl \
29
+ --val-out data/sft_validation.jsonl
30
+ """
31
+ from __future__ import annotations
32
+
33
+ import argparse
34
+ import json
35
+ import random
36
+ import sys
37
+ from pathlib import Path
38
+ from typing import Iterable
39
+
40
+ import numpy as np
41
+ import pymatching
42
+
43
+ from qubit_medic.config import (
44
+ PRIMARY_SEED,
45
+ SFT_DATASET_SIZE,
46
+ SFT_VAL_HOLDOUT,
47
+ level_by_name,
48
+ )
49
+ from qubit_medic.prompts import build_prompt, format_completion
50
+ from qubit_medic.server.physics import (
51
+ build_circuit,
52
+ build_dem,
53
+ extract_layout,
54
+ per_round_x_z_counts,
55
+ pymatching_predicted_pauli_frame,
56
+ rectify_pauli_frame_to_observable,
57
+ )
58
+
59
+
60
+ # --------------------------------------------------------------------------- #
61
+ # Optional reasoning helper #
62
+ # --------------------------------------------------------------------------- #
63
+ # Earlier revisions emitted a short reasoning sentence before the canonical
64
+ # format line. The step-5 / step-15 raw outputs showed Qwen copying that too
65
+ # eagerly: it spent the whole 128-token eval budget on generic analysis and
66
+ # never reached ``X_ERRORS=[...] Z_ERRORS=[...]``. SFT warmup needs to teach
67
+ # the parser contract first, so the active target below is format-line-only.
68
+
69
+
70
+ def _build_reasoning(px: list[int], pz: list[int]) -> str:
71
+ """Deterministic 1-sentence reasoning that matches the format line."""
72
+ if not px and not pz:
73
+ return ("All stabilizer measurements report no detector firings, "
74
+ "indicating no data-qubit errors.")
75
+ if px and not pz:
76
+ ids = ", ".join(str(q) for q in sorted(set(px)))
77
+ return f"Z-stabilizer firings localize X-errors to qubit(s) {ids}."
78
+ if pz and not px:
79
+ ids = ", ".join(str(q) for q in sorted(set(pz)))
80
+ return f"X-stabilizer firings localize Z-errors to qubit(s) {ids}."
81
+ x_ids = ", ".join(str(q) for q in sorted(set(px)))
82
+ z_ids = ", ".join(str(q) for q in sorted(set(pz)))
83
+ return (f"X-stabilizer firings localize Z-errors to qubit(s) {z_ids}, "
84
+ f"and Z-stabilizer firings localize X-errors to qubit(s) {x_ids}.")
85
+
86
+
87
+ # Quota-based generation (master spec, section 1, plus the dataset-audit
88
+ # fix): instead of weighted sampling + global rejection (which biased L1
89
+ # down because L1 produces mostly trivial syndromes), we generate a fixed
90
+ # count per level with per-level non-empty floors. This guarantees the
91
+ # 40/50/10 curriculum split exactly while still hitting an overall
92
+ # non-empty fraction in the 65-75% target band.
93
+ LEVEL_QUOTAS_TRAIN: dict[str, int] = {
94
+ "L1_warmup": 1200, # 40% of 3000
95
+ "L2_target": 1500, # 50%
96
+ "L3_stretch": 300, # 10%
97
+ }
98
+
99
+ LEVEL_QUOTAS_VAL: dict[str, int] = {
100
+ "L1_warmup": 40, # 40% of 100
101
+ "L2_target": 50, # 50%
102
+ "L3_stretch": 10, # 10%
103
+ }
104
+
105
+ # Per-level minimum non-empty correction fraction. The math (with the
106
+ # configured 40/50/10 quota mix) gives:
107
+ # 0.40*0.50 (L1) + 0.50*0.80 (L2) + 0.10*0.90 (L3) = 0.20 + 0.40 + 0.09 = 0.69
108
+ # which lands solidly inside the audit's 65-75% target band. ``None``
109
+ # would mean "accept all draws naturally" but the natural non-empty rate
110
+ # at L1's p=0.0005 (~3.5%) is too low to satisfy the audit, so we enforce
111
+ # an explicit floor here too.
112
+ PER_LEVEL_NONEMPTY_FLOOR: dict[str, float | None] = {
113
+ "L1_warmup": 0.50, # ~600 non-empty + 600 empty per 1200
114
+ "L2_target": 0.80, # ~1200 non-empty + 300 empty per 1500
115
+ "L3_stretch": 0.90, # ~270 non-empty + 30 empty per 300
116
+ }
117
+
118
+ # Held-out validation runs from a disjoint seed stream so it is truly
119
+ # independent of the train split.
120
+ VALIDATION_SEED_OFFSET: int = 4_242
121
+
122
+
123
+ def _quotas_from_total(total: int, base: dict[str, int]) -> dict[str, int]:
124
+ """Scale ``base`` quota proportions to sum to ``total``.
125
+
126
+ When the user passes ``--n`` or ``--val-n`` overriding the default
127
+ sizes, we keep the 40/50/10 curriculum proportions and absorb any
128
+ rounding remainder into the largest level (L2) so the file row count
129
+ matches ``total`` exactly.
130
+ """
131
+ base_sum = sum(base.values())
132
+ if base_sum == 0:
133
+ return {k: 0 for k in base}
134
+ scaled = {k: int(round(v * total / base_sum)) for k, v in base.items()}
135
+ diff = total - sum(scaled.values())
136
+ if diff != 0:
137
+ # Largest level absorbs the remainder.
138
+ target = max(scaled, key=scaled.get)
139
+ scaled[target] += diff
140
+ return scaled
141
+
142
+
143
+ def _build_caches() -> dict[str, dict]:
144
+ """Pre-compile circuits / matchers once per level."""
145
+ caches: dict[str, dict] = {}
146
+ for name in LEVEL_QUOTAS_TRAIN.keys():
147
+ lvl = level_by_name(name)
148
+ c = build_circuit(lvl)
149
+ dem = build_dem(c)
150
+ m = pymatching.Matching.from_detector_error_model(dem)
151
+ layout = extract_layout(c)
152
+ n_x, n_z = per_round_x_z_counts(layout)
153
+ caches[name] = {
154
+ "level": lvl,
155
+ "circuit": c,
156
+ "dem": dem,
157
+ "matching": m,
158
+ "layout": layout,
159
+ "n_x_stab": n_x,
160
+ "n_z_stab": n_z,
161
+ }
162
+ return caches
163
+
164
+
165
+ # Per-level seed offsets so each level draws an independent shot stream
166
+ # from a distinct RNG. Without this, switching from L1 to L2 with the
167
+ # same `seed` would produce identical syndromes (Stim's RNG is per-sampler).
168
+ _LEVEL_SEED_OFFSETS: dict[str, int] = {
169
+ "L1_warmup": 0,
170
+ "L2_target": 100_000,
171
+ "L3_stretch": 200_000,
172
+ }
173
+
174
+ # Safety cap on shots per level. With L1 floor=0.50 at p=0.0005 (~3.5%
175
+ # natural non-empty rate) we expect ~17k shots; 1M is a generous ceiling
176
+ # that triggers a descriptive error if generation can't converge -- e.g.
177
+ # someone bumped a level's floor too aggressively for its physical error
178
+ # rate.
179
+ _MAX_SHOTS_PER_LEVEL: int = 1_000_000
180
+
181
+ # Stim's compile_detector_sampler is the slow step (~ms per call); once
182
+ # compiled, sample(N) is essentially free. We sample in chunks of this
183
+ # size to amortise the compile cost across thousands of shots.
184
+ _SHOT_BATCH_SIZE: int = 4096
185
+
186
+
187
def _level_shot_stream(cache: dict, base_seed: int):
    """Yield ``(det_row, obs_row)`` tuples lazily from a level's circuit.

    The detector sampler is compiled exactly once (on the first ``next()``)
    and shots are then pulled in batches of :data:`_SHOT_BATCH_SIZE`. The
    generator never terminates on its own; the caller decides when to stop
    drawing.

    Args:
        cache: One entry of the dict built by ``_build_caches``; only
            ``cache["circuit"]`` is read here.
        base_seed: Seed handed to ``compile_detector_sampler``.

    Yields:
        ``det_row``: 1-D ``np.uint8`` array of detector activations for one
        shot. ``obs_row``: the matching 1-D observables row, as returned by
        the sampler.

    Determinism: the same ``base_seed`` reproduces the same shot sequence
    only for an identical series of ``sample()`` calls. Stim's seed
    documentation warns that results may differ when the number of shots
    per call changes, so editing :data:`_SHOT_BATCH_SIZE` may change the
    generated dataset even with an unchanged seed (the same caveat applies
    across Stim versions and machines).
    """
    sampler = cache["circuit"].compile_detector_sampler(seed=base_seed)
    while True:
        det, obs = sampler.sample(_SHOT_BATCH_SIZE, separate_observables=True)
        # Convert the whole batch once instead of once per yielded row.
        det_u8 = det.astype(np.uint8)
        yield from zip(det_u8, obs)
204
+
205
+
206
def _generate_split(
    *,
    quotas: dict[str, int],
    seed: int,
    caches: dict[str, dict],
    out_path: Path,
    rng: random.Random,
) -> tuple[int, int, int]:
    """Quota-based generator with per-level non-empty floors.

    For each level in ``quotas`` exactly ``quotas[level]`` rows are
    generated. Within a level, :data:`PER_LEVEL_NONEMPTY_FLOOR` controls the
    non-empty/empty split:

    * ``floor=None`` -> every draw is accepted until the quota is filled.
    * ``floor=f``    -> exactly ``round(level_n * f)`` non-empty rows and
      ``level_n - round(level_n * f)`` empty rows are accepted. Surplus on
      either side is dropped, and draws continue until both sub-quotas are
      filled or :data:`_MAX_SHOTS_PER_LEVEL` is reached.

    All rows are buffered, shuffled with ``rng`` and then written to
    ``out_path`` as JSONL (one record per line).

    Args:
        quotas: Number of rows to emit per level name.
        seed: Base seed; each level samples from
            ``seed + _LEVEL_SEED_OFFSETS[level]``.
        caches: Per-level artefacts as built by ``_build_caches``.
        out_path: Destination JSONL file; parent directories are created.
        rng: Random generator used only for the final shuffle.

    Returns:
        ``(n_written, n_with_syndrome, n_with_errors)``.

    Raises:
        RuntimeError: If a level cannot fill its quota within
            :data:`_MAX_SHOTS_PER_LEVEL` shots.
    """
    n_with_syndrome = n_with_errors = 0
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Rows are produced level by level (L1 block, L2 block, L3 block), so
    # they are buffered and shuffled before writing. The shuffle is keyed off
    # the caller-supplied `rng`, which keeps `--seed N` reproducible while
    # giving a level-mixed file.
    records: list[dict] = []

    for level_name, level_n in quotas.items():
        cache = caches[level_name]
        level = cache["level"]
        matching = cache["matching"]
        layout = cache["layout"]
        floor = PER_LEVEL_NONEMPTY_FLOOR.get(level_name)

        if floor is None:
            target_nonempty = None
            target_empty = None
        else:
            target_nonempty = int(round(level_n * floor))
            target_empty = level_n - target_nonempty

        level_nonempty = 0   # accepted rows with a non-empty correction
        level_empty = 0      # accepted rows with an empty correction
        shots_drawn = 0      # every shot pulled from the sampler
        nonempty_drawn = 0   # every non-empty shot, accepted or rejected
        level_seed = seed + _LEVEL_SEED_OFFSETS.get(level_name, 0)
        shots = _level_shot_stream(cache, level_seed)

        while (level_nonempty + level_empty) < level_n:
            if shots_drawn >= _MAX_SHOTS_PER_LEVEL:
                raise RuntimeError(
                    f"[gen] level {level_name}: exceeded "
                    f"_MAX_SHOTS_PER_LEVEL={_MAX_SHOTS_PER_LEVEL} with "
                    f"only {level_nonempty} non-empty + {level_empty} "
                    f"empty rows (target: {target_nonempty} non-empty + "
                    f"{target_empty} empty). Either lower "
                    f"PER_LEVEL_NONEMPTY_FLOOR[{level_name!r}] or "
                    f"raise the level's physical error rate in "
                    f"qubit_medic/config.py."
                )
            det_row, obs_row = next(shots)
            shots_drawn += 1

            # PyMatching correction (X + Z Pauli frame), then rectified
            # against PyMatching's own observable prediction.
            px_stim, pz_stim = pymatching_predicted_pauli_frame(
                matching, det_row, layout,
            )
            pm_obs = int(matching.decode(det_row)[0])
            px_stim, pz_stim = rectify_pauli_frame_to_observable(
                px_stim, pz_stim, pm_obs, layout,
            )
            # LLM ID space (consecutive 0..N-1).
            px = layout.stim_to_llm(px_stim)
            pz = layout.stim_to_llm(pz_stim)
            is_nonempty = bool(px or pz)
            if is_nonempty:
                # Counted before the quota filter so the rate reported below
                # reflects what the sampler produced, not what was kept.
                nonempty_drawn += 1

            # Per-level quota acceptance.
            if floor is None:
                pass  # accept anything until level_n is filled
            elif is_nonempty:
                if level_nonempty >= target_nonempty:
                    continue  # surplus non-empty for this level
            else:
                if level_empty >= target_empty:
                    continue  # surplus empty for this level

            actual_obs = int(obs_row[0]) if obs_row.shape[0] else 0

            prompt = build_prompt(
                distance=level.distance,
                rounds=level.rounds,
                p=level.p,
                syndrome_bits=det_row.tolist(),
                num_x_stabilizers=cache["n_x_stab"],
                num_z_stabilizers=cache["n_z_stab"],
                num_data_qubits=layout.num_data_qubits,
            )
            completion = format_completion(px, pz)
            record = {
                "prompt": prompt,
                "completion": completion,
                "level": level_name,
                "distance": level.distance,
                "rounds": level.rounds,
                "p": level.p,
                "num_data_qubits": int(layout.num_data_qubits),
                "num_x_stabilizers": int(cache["n_x_stab"]),
                "num_z_stabilizers": int(cache["n_z_stab"]),
                "syndrome_bits": [int(b) for b in det_row.tolist()],
                "true_x_errors": list(map(int, px)),
                "true_z_errors": list(map(int, pz)),
                "actual_observable_flip": actual_obs,
                "pymatching_observable_pred": pm_obs,
                "had_syndrome": bool(det_row.any()),
                "had_errors": is_nonempty,
            }
            records.append(record)
            if is_nonempty:
                n_with_errors += 1
                level_nonempty += 1
            else:
                level_empty += 1
            if record["had_syndrome"]:
                n_with_syndrome += 1

        print(f" [{level_name}] {level_nonempty} non-empty + "
              f"{level_empty} empty (drew {shots_drawn} shots, "
              f"natural non-empty rate "
              f"~{nonempty_drawn / max(1, shots_drawn):.1%})")

    # Deterministic shuffle: same `rng` state -> same row order, but no
    # longer blocked by level.
    rng.shuffle(records)

    with out_path.open("w", encoding="utf-8") as f:
        f.writelines(json.dumps(record) + "\n" for record in records)

    return len(records), n_with_syndrome, n_with_errors
357
+
358
+
359
def main(argv: Iterable[str] = ()) -> int:
    """Command-line entry point: generate train, validation and sample files.

    Steps:
        1. Parse ``argv`` and build the per-level caches.
        2. Write the training split using ``--seed``.
        3. Unless ``--no-validation`` is given, write the held-out split
           using ``--seed + VALIDATION_SEED_OFFSET``.
        4. Copy the first ``--sample-size`` training rows to ``--sample-out``.
        5. Unless ``--no-validation`` is given, run ``audit_sft_dataset``
           from ``scripts.train_sft`` on both files; if that module cannot
           be imported, a warning is printed and the audit is skipped.

    Args:
        argv: Command-line arguments without the program name.

    Returns:
        ``0`` on completion. The audit's return value is not inspected here;
        presumably it raises or exits on failure -- confirm in
        ``scripts/train_sft.py``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--n", type=int, default=SFT_DATASET_SIZE,
                        help=f"train split size (default {SFT_DATASET_SIZE})")
    parser.add_argument("--val-n", type=int, default=SFT_VAL_HOLDOUT,
                        help=f"held-out validation size (default {SFT_VAL_HOLDOUT})")
    parser.add_argument("--out", type=str, default="data/sft_dataset.jsonl")
    parser.add_argument("--val-out", type=str, default="data/sft_validation.jsonl")
    parser.add_argument("--sample-out", type=str,
                        default="data/sft_dataset_sample.jsonl",
                        help="optional small JSONL committed to the repo")
    parser.add_argument("--sample-size", type=int, default=50)
    parser.add_argument("--seed", type=int, default=PRIMARY_SEED,
                        help=f"deterministic seed (default {PRIMARY_SEED})")
    parser.add_argument("--no-validation", action="store_true",
                        help="skip writing the held-out validation split")
    args = parser.parse_args(list(argv))

    train_path = Path(args.out)
    val_path = Path(args.val_out)
    sample_path = Path(args.sample_out)
    sample_path.parent.mkdir(parents=True, exist_ok=True)

    caches = _build_caches()
    print(f"prepared caches for {len(caches)} levels")

    # ---- training split ------------------------------------------------ #
    train_quotas = _quotas_from_total(args.n, LEVEL_QUOTAS_TRAIN)
    train_rng = random.Random(args.seed)
    print(f"writing TRAIN split: n={args.n}, seed={args.seed}, "
          f"quotas={train_quotas} -> {train_path}")
    train_written, train_syn, train_err = _generate_split(
        quotas=train_quotas, seed=args.seed, caches=caches,
        out_path=train_path, rng=train_rng,
    )
    print(f" wrote {train_written}; syndrome-fraction={train_syn / max(1, train_written):.3f}; "
          f"non-empty-correction-fraction={train_err / max(1, train_written):.3f}")

    # ---- validation split (disjoint seed stream) ---------------------- #
    if not args.no_validation:
        val_quotas = _quotas_from_total(args.val_n, LEVEL_QUOTAS_VAL)
        val_seed = args.seed + VALIDATION_SEED_OFFSET
        val_rng = random.Random(val_seed)
        print(f"writing VAL split: n={args.val_n}, seed={val_seed}, "
              f"quotas={val_quotas} -> {val_path}")
        val_written, val_syn, val_err = _generate_split(
            quotas=val_quotas, seed=val_seed, caches=caches,
            out_path=val_path, rng=val_rng,
        )
        print(f" wrote {val_written}; syndrome-fraction={val_syn / max(1, val_written):.3f}; "
              f"non-empty-correction-fraction={val_err / max(1, val_written):.3f}")

    # ---- sample preview (for repo commit / eyeball QC) ---------------- #
    sample_records: list[dict] = []
    with train_path.open(encoding="utf-8") as src:
        for line in src:
            # Check the limit *before* appending so that --sample-size 0
            # (or a negative value) yields an empty preview instead of one row.
            if len(sample_records) >= args.sample_size:
                break
            sample_records.append(json.loads(line))
    with sample_path.open("w", encoding="utf-8") as sf:
        sf.writelines(json.dumps(r) + "\n" for r in sample_records)
    print(f"wrote {len(sample_records)} sample records to {sample_path}")

    # ---- self-audit (fail fast on bad regen) -------------------------- #
    # The audit is imported lazily so that importing this module does not
    # pull in scripts.train_sft (and whatever it imports) up front.
    if not args.no_validation:
        try:
            from scripts.train_sft import audit_sft_dataset
        except ImportError as exc:
            print(f"[gen] could not run self-audit: {exc}", file=sys.stderr)
            return 0
        print()  # blank line before banner
        audit_sft_dataset(str(train_path), str(val_path))
    return 0
437
+
438
+
439
+ if __name__ == "__main__":
440
+ sys.exit(main(sys.argv[1:]))