"""scripts/generate_sft_data.py - SFT dataset generator (master spec, sec. 1). Locked configuration: * Train split: 3,000 examples (default seed 42). * Held-out split: 100 examples (seed 4242 - independent stream). * Curriculum mix: 40% L1_warmup, 50% L2_target, 10% L3_stretch. For each example: 1. Pick a curriculum level by the locked mixture. 2. Sample a noisy syndrome from Stim (SI1000 noise model). 3. Run PyMatching to get the canonical correction (Pauli frame). 4. Format the locked prompt + target completion. 5. Emit one JSONL record per sample. Records carry ``true_x_errors``, ``true_z_errors``, ``actual_observable_flip``, and curriculum info so the SFT validation callback can compute every spec metric (logical_correction_rate, exact_match_pymatching, hamming_overlap, syndrome_consistency, ...) without re-sampling. Output: data/sft_dataset.jsonl - training set (3,000 rows) data/sft_validation.jsonl - held-out validation (100 rows) data/sft_dataset_sample.jsonl - 50-row preview for repo commit Run:: python -m scripts.generate_sft_data \ --n 3000 --val-n 100 \ --out data/sft_dataset.jsonl \ --val-out data/sft_validation.jsonl """ from __future__ import annotations import argparse import json import random import sys from pathlib import Path from typing import Iterable import numpy as np import pymatching from qubit_medic.config import ( PRIMARY_SEED, SFT_DATASET_SIZE, SFT_VAL_HOLDOUT, level_by_name, ) from qubit_medic.prompts import build_prompt, format_completion from qubit_medic.server.physics import ( build_circuit, build_dem, extract_layout, per_round_x_z_counts, pymatching_predicted_pauli_frame, rectify_pauli_frame_to_observable, ) # --------------------------------------------------------------------------- # # Optional reasoning helper # # --------------------------------------------------------------------------- # # Earlier revisions emitted a short reasoning sentence before the canonical # format line. The step-5 / step-15 raw outputs showed Qwen copying that too # eagerly: it spent the whole 128-token eval budget on generic analysis and # never reached ``X_ERRORS=[...] Z_ERRORS=[...]``. SFT warmup needs to teach # the parser contract first, so the active target below is format-line-only. def _build_reasoning(px: list[int], pz: list[int]) -> str: """Deterministic 1-sentence reasoning that matches the format line.""" if not px and not pz: return ("All stabilizer measurements report no detector firings, " "indicating no data-qubit errors.") if px and not pz: ids = ", ".join(str(q) for q in sorted(set(px))) return f"Z-stabilizer firings localize X-errors to qubit(s) {ids}." if pz and not px: ids = ", ".join(str(q) for q in sorted(set(pz))) return f"X-stabilizer firings localize Z-errors to qubit(s) {ids}." x_ids = ", ".join(str(q) for q in sorted(set(px))) z_ids = ", ".join(str(q) for q in sorted(set(pz))) return (f"X-stabilizer firings localize Z-errors to qubit(s) {z_ids}, " f"and Z-stabilizer firings localize X-errors to qubit(s) {x_ids}.") # Quota-based generation (master spec, section 1, plus the dataset-audit # fix): instead of weighted sampling + global rejection (which biased L1 # down because L1 produces mostly trivial syndromes), we generate a fixed # count per level with per-level non-empty floors. This guarantees the # 40/50/10 curriculum split exactly while still hitting an overall # non-empty fraction in the 65-75% target band. 
LEVEL_QUOTAS_TRAIN: dict[str, int] = {
    "L1_warmup": 1200,   # 40% of 3000
    "L2_target": 1500,   # 50%
    "L3_stretch": 300,   # 10%
}

LEVEL_QUOTAS_VAL: dict[str, int] = {
    "L1_warmup": 40,     # 40% of 100
    "L2_target": 50,     # 50%
    "L3_stretch": 10,    # 10%
}

# Per-level minimum non-empty correction fraction. With the configured
# 40/50/10 quota mix, floors of 0.50 (L1), 0.80 (L2) and 0.90 (L3) give
#     0.40*0.50 + 0.50*0.80 + 0.10*0.90 = 0.69,
# which lands solidly inside the audit's 65-75% target band. ``None``
# would mean "accept all draws naturally", but the natural non-empty rate
# at L1's p=0.0005 (~3.5%) is too low to satisfy the audit, so we enforce
# an explicit floor for L1 as well.
PER_LEVEL_NONEMPTY_FLOOR: dict[str, float | None] = {
    "L1_warmup": 0.50,   # ~600 non-empty + 600 empty per 1200
    "L2_target": 0.80,   # ~1200 non-empty + 300 empty per 1500
    "L3_stretch": 0.90,  # ~270 non-empty + 30 empty per 300
}

# Held-out validation runs from a disjoint seed stream so it is truly
# independent of the train split.
VALIDATION_SEED_OFFSET: int = 4_242


def _quotas_from_total(total: int, base: dict[str, int]) -> dict[str, int]:
    """Scale ``base`` quota proportions to sum to ``total``.

    When the user passes ``--n`` or ``--val-n`` overriding the default
    sizes, we keep the 40/50/10 curriculum proportions and absorb any
    rounding remainder into the largest level (L2) so the file row count
    matches ``total`` exactly.
    """
    base_sum = sum(base.values())
    if base_sum == 0:
        return {k: 0 for k in base}
    scaled = {k: int(round(v * total / base_sum)) for k, v in base.items()}
    diff = total - sum(scaled.values())
    if diff != 0:
        # Largest level absorbs the remainder.
        target = max(scaled, key=scaled.get)
        scaled[target] += diff
    return scaled


def _build_caches() -> dict[str, dict]:
    """Pre-compile circuits / matchers once per level."""
    caches: dict[str, dict] = {}
    for name in LEVEL_QUOTAS_TRAIN.keys():
        lvl = level_by_name(name)
        c = build_circuit(lvl)
        dem = build_dem(c)
        m = pymatching.Matching.from_detector_error_model(dem)
        layout = extract_layout(c)
        n_x, n_z = per_round_x_z_counts(layout)
        caches[name] = {
            "level": lvl,
            "circuit": c,
            "dem": dem,
            "matching": m,
            "layout": layout,
            "n_x_stab": n_x,
            "n_z_stab": n_z,
        }
    return caches


# Per-level seed offsets so each level draws an independent shot stream
# from a distinct RNG. Without this, switching from L1 to L2 with the
# same `seed` would produce identical syndromes (Stim's RNG is per-sampler).
_LEVEL_SEED_OFFSETS: dict[str, int] = {
    "L1_warmup": 0,
    "L2_target": 100_000,
    "L3_stretch": 200_000,
}

# Safety cap on shots per level. With L1's floor=0.50 at p=0.0005 (~3.5%
# natural non-empty rate) we expect ~17k shots; 1M is a generous ceiling
# that triggers a descriptive error if generation can't converge -- e.g.
# someone bumped a level's floor too aggressively for its physical error
# rate.
_MAX_SHOTS_PER_LEVEL: int = 1_000_000

# Stim's compile_detector_sampler is the slow step (~ms per call); once
# compiled, sample(N) is essentially free. We sample in chunks of this
# size to amortise the compile cost across thousands of shots.
_SHOT_BATCH_SIZE: int = 4096


def _level_shot_stream(cache: dict, base_seed: int):
    """Yield ``(det_row, obs_row)`` tuples lazily from a level's circuit.

    Compiles the detector sampler exactly ONCE per level and then pulls
    shots in batches of :data:`_SHOT_BATCH_SIZE`. ``det_row`` is a
    ``np.uint8`` 1-D array (the detector activations); ``obs_row`` is the
    1-D observables vector for the same shot.
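
    Usage sketch (illustrative only; ``caches`` comes from
    :func:`_build_caches` and the seed value is arbitrary)::

        shots = _level_shot_stream(caches["L2_target"], base_seed=42)
        det_row, obs_row = next(shots)  # one shot at a time, lazily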

    Determinism: the same ``base_seed`` always produces the same shot
    sequence regardless of batch size (Stim's per-sampler RNG advances
    deterministically across each ``sample()`` call).
    """
    sampler = cache["circuit"].compile_detector_sampler(seed=base_seed)
    while True:
        det, obs = sampler.sample(_SHOT_BATCH_SIZE, separate_observables=True)
        for i in range(_SHOT_BATCH_SIZE):
            yield det[i].astype(np.uint8), obs[i]


def _generate_split(
    *,
    quotas: dict[str, int],
    seed: int,
    caches: dict[str, dict],
    out_path: Path,
    rng: random.Random,
) -> tuple[int, int, int]:
    """Quota-based generator with per-level non-empty floors.

    Returns ``(n_written, n_syndrome, n_errors)``.

    For each level in ``quotas`` we generate exactly ``quotas[level]`` rows.
    Within each level, :data:`PER_LEVEL_NONEMPTY_FLOOR` controls the
    non-empty/empty split:

    * ``floor=None`` -> accept every draw until the quota is filled (mostly
      empty for low-p levels).
    * ``floor=f`` -> accept exactly ``round(level_n * f)`` non-empty rows and
      ``level_n - round(level_n * f)`` empty rows. Surplus on either side is
      dropped; draws continue until both sub-quotas are filled or
      :data:`_MAX_SHOTS_PER_LEVEL` is exceeded.

    Stim sampling is batched per level (single ``compile_detector_sampler``
    call, chunked ``sample()``), so generation takes ~1 second per level even
    when the floor demands tens of thousands of shots.
    """
    n_with_syndrome = n_with_errors = 0
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Buffer all records in memory, then shuffle before writing. This is
    # critical: per-level generation produces L1-block / L2-block / L3-block
    # contiguously, which (a) makes SFTTrainer's first batches all-L1 even
    # though Trainer shuffles per-epoch, and (b) makes the validation
    # callback's "first N samples" display all-L1 -- hiding model behaviour
    # on L2/L3 prompts. A deterministic shuffle keyed off `rng` (the
    # caller-passed random.Random) gives us level-mixed streams while
    # keeping `--seed N` fully reproducible.
    records: list[dict] = []

    for level_name, level_n in quotas.items():
        cache = caches[level_name]
        layout = cache["layout"]
        floor = PER_LEVEL_NONEMPTY_FLOOR.get(level_name)
        if floor is None:
            target_nonempty = None
            target_empty = None
        else:
            target_nonempty = int(round(level_n * floor))
            target_empty = level_n - target_nonempty

        level_nonempty = 0
        level_empty = 0
        shots_drawn = 0
        level_seed = seed + _LEVEL_SEED_OFFSETS.get(level_name, 0)
        shots = _level_shot_stream(cache, level_seed)

        while (level_nonempty + level_empty) < level_n:
            if shots_drawn >= _MAX_SHOTS_PER_LEVEL:
                raise RuntimeError(
                    f"[gen] level {level_name}: exceeded "
                    f"_MAX_SHOTS_PER_LEVEL={_MAX_SHOTS_PER_LEVEL} with "
                    f"only {level_nonempty} non-empty + {level_empty} "
                    f"empty rows (target: {target_nonempty} non-empty + "
                    f"{target_empty} empty). Either lower "
                    f"PER_LEVEL_NONEMPTY_FLOOR[{level_name!r}] or "
                    f"raise the level's physical error rate in "
                    f"qubit_medic/config.py."
                )
            det_row, obs_row = next(shots)
            shots_drawn += 1

            # Optimal correction via PyMatching (X + Z Pauli frame).
            px_stim, pz_stim = pymatching_predicted_pauli_frame(
                cache["matching"], det_row, layout,
            )
            pm_obs = int(cache["matching"].decode(det_row)[0])
            px_stim, pz_stim = rectify_pauli_frame_to_observable(
                px_stim, pz_stim, pm_obs, layout,
            )

            # LLM ID space (consecutive 0..N-1).
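            # (Illustrative: a correction on a raw Stim data-qubit index
            # such as 12 might become LLM qubit ID 7, depending on the
            # layout. The ``true_*_errors`` fields and the completion below
            # are stored in this consecutive LLM ID space.)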
            px = layout.stim_to_llm(px_stim)
            pz = layout.stim_to_llm(pz_stim)
            is_nonempty = bool(px or pz)

            # Per-level quota acceptance:
            if floor is None:
                pass  # accept anything until level_n is filled
            elif is_nonempty:
                if level_nonempty >= target_nonempty:
                    continue  # surplus non-empty for this level
            else:
                if level_empty >= target_empty:
                    continue  # surplus empty for this level

            actual_obs = int(obs_row[0]) if obs_row.shape[0] else 0

            prompt = build_prompt(
                distance=cache["level"].distance,
                rounds=cache["level"].rounds,
                p=cache["level"].p,
                syndrome_bits=det_row.tolist(),
                num_x_stabilizers=cache["n_x_stab"],
                num_z_stabilizers=cache["n_z_stab"],
                num_data_qubits=layout.num_data_qubits,
            )
            completion = format_completion(px, pz)

            record = {
                "prompt": prompt,
                "completion": completion,
                "level": level_name,
                "distance": cache["level"].distance,
                "rounds": cache["level"].rounds,
                "p": cache["level"].p,
                "num_data_qubits": int(layout.num_data_qubits),
                "num_x_stabilizers": int(cache["n_x_stab"]),
                "num_z_stabilizers": int(cache["n_z_stab"]),
                "syndrome_bits": [int(b) for b in det_row.tolist()],
                "true_x_errors": list(map(int, px)),
                "true_z_errors": list(map(int, pz)),
                "actual_observable_flip": actual_obs,
                "pymatching_observable_pred": pm_obs,
                "had_syndrome": bool(det_row.any()),
                "had_errors": bool(px or pz),
            }
            records.append(record)

            if record["had_errors"]:
                n_with_errors += 1
                level_nonempty += 1
            else:
                level_empty += 1
            if record["had_syndrome"]:
                n_with_syndrome += 1

        print(f"  [{level_name}] {level_nonempty} non-empty + "
              f"{level_empty} empty (drew {shots_drawn} shots, "
              f"natural non-empty rate "
              f"~{level_nonempty / max(1, shots_drawn):.1%})")

    # Deterministic shuffle: same `seed` -> same row order, but no longer
    # blocked by level. SFTTrainer's per-epoch shuffle still applies on top
    # of this; the buffer-shuffle ensures every batch (and every eval
    # display window) sees a representative L1/L2/L3 mix.
    rng.shuffle(records)

    with out_path.open("w") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")

    return len(records), n_with_syndrome, n_with_errors


def main(argv: Iterable[str] = ()) -> int:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--n", type=int, default=SFT_DATASET_SIZE,
                        help=f"train split size (default {SFT_DATASET_SIZE})")
    parser.add_argument("--val-n", type=int, default=SFT_VAL_HOLDOUT,
                        help=f"held-out validation size (default {SFT_VAL_HOLDOUT})")
    parser.add_argument("--out", type=str, default="data/sft_dataset.jsonl")
    parser.add_argument("--val-out", type=str, default="data/sft_validation.jsonl")
    parser.add_argument("--sample-out", type=str,
                        default="data/sft_dataset_sample.jsonl",
                        help="optional small JSONL committed to the repo")
    parser.add_argument("--sample-size", type=int, default=50)
    parser.add_argument("--seed", type=int, default=PRIMARY_SEED,
                        help=f"deterministic seed (default {PRIMARY_SEED})")
    parser.add_argument("--no-validation", action="store_true",
                        help="skip writing the held-out validation split")
    args = parser.parse_args(list(argv))

    train_path = Path(args.out)
    val_path = Path(args.val_out)
    sample_path = Path(args.sample_out)
    sample_path.parent.mkdir(parents=True, exist_ok=True)

    caches = _build_caches()
    print(f"prepared caches for {len(caches)} levels")

    # ---- training split ------------------------------------------------ #
    train_quotas = _quotas_from_total(args.n, LEVEL_QUOTAS_TRAIN)
    train_rng = random.Random(args.seed)
    print(f"writing TRAIN split: n={args.n}, seed={args.seed}, "
          f"quotas={train_quotas} -> {train_path}")
    train_written, train_syn, train_err = _generate_split(
        quotas=train_quotas,
        seed=args.seed,
        caches=caches,
        out_path=train_path,
        rng=train_rng,
    )
    print(f"  wrote {train_written}; "
          f"syndrome-fraction={train_syn / max(1, train_written):.3f}; "
          f"non-empty-correction-fraction={train_err / max(1, train_written):.3f}")

    # ---- validation split (disjoint seed stream) ---------------------- #
    if not args.no_validation:
        val_quotas = _quotas_from_total(args.val_n, LEVEL_QUOTAS_VAL)
        val_seed = args.seed + VALIDATION_SEED_OFFSET
        val_rng = random.Random(val_seed)
        print(f"writing VAL split: n={args.val_n}, seed={val_seed}, "
              f"quotas={val_quotas} -> {val_path}")
        val_written, val_syn, val_err = _generate_split(
            quotas=val_quotas,
            seed=val_seed,
            caches=caches,
            out_path=val_path,
            rng=val_rng,
        )
        print(f"  wrote {val_written}; "
              f"syndrome-fraction={val_syn / max(1, val_written):.3f}; "
              f"non-empty-correction-fraction={val_err / max(1, val_written):.3f}")

    # ---- sample preview (for repo commit / eyeball QC) ---------------- #
    sample_records: list[dict] = []
    with train_path.open() as src:
        for line in src:
            sample_records.append(json.loads(line))
            if len(sample_records) >= args.sample_size:
                break
    with sample_path.open("w") as sf:
        for r in sample_records:
            sf.write(json.dumps(r) + "\n")
    print(f"wrote {len(sample_records)} sample records to {sample_path}")

    # ---- self-audit (fail fast on bad regen) -------------------------- #
    # Run the same audit train_sft.py runs at startup, so a regen that
    # silently produced bad data exits non-zero immediately rather than
    # waiting until the next training launch. Lazy import so we don't
    # pull in train_sft's heavy ML deps at import time.
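    # (Assumption: audit_sft_dataset raises on a failed audit; its return
    # value is ignored below, so an uncaught exception is what turns a bad
    # regen into a non-zero exit.)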
    if not args.no_validation:
        try:
            from scripts.train_sft import audit_sft_dataset
        except ImportError as exc:
            print(f"[gen] could not run self-audit: {exc}", file=sys.stderr)
            return 0
        print()  # blank line before banner
        audit_sft_dataset(str(train_path), str(val_path))

    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))