Spaces:
Sleeping
Sleeping
"""Locked experiment configuration (Section 1.4 of the plan).

Every magic number in the project lives here. Do not hard-code circuit
parameters, noise rates, or model identifiers anywhere else; import them
from this module instead.

Sections, in order: quantum code geometry, the SI1000 noise model,
curriculum levels, reward weights, reproducibility seeds, model + training
hyper-parameters (LoRA, SFT, GRPO), server / deployment settings, Weights &
Biases defaults, and small accessor helpers for the curriculum.

Cited literature
----------------
Bausch et al., AlphaQubit, *Nature* 635:834 (2024)
    DOI: 10.1038/s41586-024-08148-8
    https://www.nature.com/articles/s41586-024-08148-8
Acharya et al. (Google QAI), *Willow*, arXiv:2408.13687 (2024)
    https://arxiv.org/abs/2408.13687
Gidney & Fowler, *SI1000*, arXiv:2108.10457 (2021)
    https://arxiv.org/abs/2108.10457
Higgott & Gidney, *PyMatching v2*, arXiv:2303.15933 (2023)
    https://arxiv.org/abs/2303.15933
Shao et al., *DeepSeekMath / GRPO*, arXiv:2402.03300 (2024)
    https://arxiv.org/abs/2402.03300
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Mapping
# --------------------------------------------------------------------------- #
#                           Quantum code geometry                             #
# --------------------------------------------------------------------------- #
CODE_TASK = "surface_code:rotated_memory_z"
"""Stim circuit-generator task name: the rotated surface code running a
Z-basis memory experiment, i.e. the code family AlphaQubit and Willow
report results on."""

DISTANCE_PRIMARY: int = 3
"""Code distance of the primary benchmark configuration (AlphaQubit Fig. 2b)."""

DISTANCE_STRETCH: int = 5
"""Code distance of the Section 4.3 stretch-goal configuration."""

ROUNDS_FACTOR: int = 1
"""Multiplier in ``rounds = ROUNDS_FACTOR * distance``. A value of 1
reproduces the AlphaQubit protocol where the round count equals the
distance."""
# --------------------------------------------------------------------------- #
#       Noise model: SI1000 sub-rates (Gidney & Fowler 2021, Table 1)         #
# --------------------------------------------------------------------------- #
# SI1000 maps a single physical error budget ``p`` to operation-specific
# sub-rates. The factors below come from arXiv:2108.10457 Table 1 and are the
# *same* values Google's QAI uses in their Willow analyses.
#
# Stim's surface_code:rotated_memory_z generator accepts four matching knobs:
#   after_clifford_depolarization     (Clifford gate noise)
#   before_round_data_depolarization  (idle data-qubit noise per round)
#   before_measure_flip_probability   (measurement noise)
#   after_reset_flip_probability      (reset noise)
#
# NOTE: four knobs cannot express every SI1000 channel, so this mapping is an
# approximation of the full model. In particular stim applies
# ``after_clifford_depolarization`` after single-qubit Cliffords as well as
# two-qubit ones, so single-qubit gates see rate ``p`` here.
@dataclass(frozen=True)
class SI1000Rates:
    """Per-operation error rates derived from a single budget ``p``.

    Instances are immutable value objects; build them with :meth:`from_p`
    and pass :meth:`as_stim_kwargs` to ``stim.Circuit.generated``.
    """

    after_clifford_depolarization: float
    before_round_data_depolarization: float
    before_measure_flip_probability: float
    after_reset_flip_probability: float

    @classmethod
    def from_p(cls, p: float) -> "SI1000Rates":
        """Build SI1000 sub-rates from the headline budget ``p``.

        Multipliers follow the SI1000 model of Gidney & Fowler 2021
        (Table 1): gates ``p``, data idle ``p/10``, measurement ``5p``,
        reset ``2p``.

        Args:
            p: Physical error budget.

        Returns:
            The corresponding immutable ``SI1000Rates``.

        Raises:
            ValueError: if ``p`` is negative, NaN, or so large that the
                measurement flip rate ``5p`` would exceed 1.
        """
        # 5p is the largest derived rate; reject budgets that would turn it
        # into an invalid probability (also rejects NaN, which fails both
        # comparisons).
        if not 0.0 <= p * 5.0 <= 1.0:
            raise ValueError(
                f"SI1000 budget p must satisfy 0 <= 5*p <= 1 (got {p!r})"
            )
        return cls(
            after_clifford_depolarization=p,
            before_round_data_depolarization=p / 10.0,
            before_measure_flip_probability=p * 5.0,
            after_reset_flip_probability=p * 2.0,
        )

    def as_stim_kwargs(self) -> Mapping[str, float]:
        """Return the kwargs dict accepted by ``stim.Circuit.generated``."""
        return {
            "after_clifford_depolarization": self.after_clifford_depolarization,
            "before_round_data_depolarization": self.before_round_data_depolarization,
            "before_measure_flip_probability": self.before_measure_flip_probability,
            "after_reset_flip_probability": self.after_reset_flip_probability,
        }
# --------------------------------------------------------------------------- #
#                        Curriculum levels (Section 4)                        #
# --------------------------------------------------------------------------- #
@dataclass(frozen=True)
class CurriculumLevel:
    """One rung on the difficulty ladder.

    A frozen dataclass, so shared level definitions cannot be mutated in
    place by callers, and instances are hashable and comparable by value.
    """

    name: str  # level identifier, e.g. "L2_target"
    distance: int  # surface-code distance
    rounds: int  # number of syndrome-extraction rounds
    p: float  # headline physical error budget (cf. SI1000Rates.from_p)
    promotion_threshold: float  # logical-correction rate at which we move on
    eval_size: int  # held-out shots used to test promotion
# Levels are listed from easiest to hardest. ``rounds`` follows the documented
# ``ROUNDS_FACTOR * distance`` rule everywhere except the L1 warmup.
CURRICULUM: tuple[CurriculumLevel, ...] = (
    CurriculumLevel(
        name="L1_warmup",
        distance=DISTANCE_PRIMARY,
        # Deliberately a single round rather than ROUNDS_FACTOR * distance,
        # keeping the warmup level easier than L2.
        rounds=1,
        # 0.0005 (was 0.0001) — at the original budget, L1 syndromes were
        # almost always trivial, dragging the SFT class balance down even
        # under per-level rejection sampling. Bumping to 0.0005 keeps L1
        # strictly easier than L2 (p=0.001) while giving the model real
        # non-empty examples to learn from at the warmup stage.
        p=0.0005,
        promotion_threshold=0.80,
        eval_size=100,
    ),
    CurriculumLevel(
        name="L2_target",
        distance=DISTANCE_PRIMARY,
        rounds=ROUNDS_FACTOR * DISTANCE_PRIMARY,
        p=0.001,
        promotion_threshold=0.70,
        eval_size=200,
    ),
    CurriculumLevel(
        name="L3_stretch",
        distance=DISTANCE_STRETCH,
        rounds=ROUNDS_FACTOR * DISTANCE_STRETCH,
        p=0.001,
        promotion_threshold=0.30,  # stretch goal - even partial counts
        eval_size=200,
    ),
    # 2026-04 evaluation-only stress level. Same geometry as L3 but 5x the
    # noise rate so:
    #   * zeros policy drops to ~50-60% LCR
    #   * pymatching drops to ~80-90% LCR
    # leaving real headroom for trained-model differentiation. NOT used
    # during training (curriculum scheduler ignores it because it isn't
    # in the SFT/GRPO mixes); only invoked via --level L4_stress on
    # scripts/eval.py and scripts/eval_remote.py.
    #
    # NOTE: the deployed HF Space (the canonical remote /reset target)
    # was built before this level existed. Remote eval against the Space
    # for this level will fail until the Space container is rebuilt; run
    # locally via `python -m scripts.eval --level L4_stress ...` instead.
    CurriculumLevel(
        name="L4_stress",
        distance=DISTANCE_STRETCH,
        rounds=ROUNDS_FACTOR * DISTANCE_STRETCH,
        p=0.005,
        promotion_threshold=0.20,  # eval-only; promotion never triggered
        eval_size=200,
    ),
)
# --------------------------------------------------------------------------- #
#          Reward weights (Section 3) - sum to 1.0 by construction            #
# --------------------------------------------------------------------------- #
REWARD_WEIGHTS: dict[str, float] = {
    "logical_correction": 0.35,  # Reward 1 - the unfakeable ground truth
    "hamming_overlap": 0.25,  # Reward 3 - dense partial credit
    "syndrome_consistency": 0.20,  # Reward 2 - prevents lucky-guess attacks
    "format_compliance": 0.10,  # Reward 4 - parser must succeed
    "pymatching_beat": 0.10,  # Reward 5 - the headline metric
}
# Checked with an explicit raise rather than ``assert``: assert statements
# are stripped under ``python -O``, which would silently disable the guard.
if not abs(sum(REWARD_WEIGHTS.values()) - 1.0) < 1e-9:
    raise ValueError("reward weights must sum to 1")
# --------------------------------------------------------------------------- #
#                               Reproducibility                               #
# --------------------------------------------------------------------------- #
SEEDS: tuple[int, ...] = (42, 1337, 2024)
"""The three seeds behind every error bar; runs never use any other seed."""

PRIMARY_SEED: int = SEEDS[0]  # first entry of SEEDS
# --------------------------------------------------------------------------- #
#                              Model + training                               #
# --------------------------------------------------------------------------- #
MODEL_ID: str = "Qwen/Qwen2.5-3B-Instruct"
"""Locked primary model (3B parameters). Quantised to 4 bits with LoRA
adapters it fits on a Colab T4. Fall back to ``Qwen/Qwen2.5-7B-Instruct``
only when the format test scores below 30%."""

MODEL_BACKUP_ID: str = "Qwen/Qwen2.5-7B-Instruct"
"""Fallback model: switch to it only if the pre-onsite format test fails."""
# ---- LoRA (shared SFT + GRPO) -------------------------------------------- #
LORA_R: int = 16
LORA_ALPHA: int = 2 * LORA_R  # the usual alpha = 2 x rank ratio (32)
LORA_DROPOUT: float = 0.10
"""Raised 0.05 -> 0.10 during the 2026-04 SFT regularisation pass. Earlier
SFT runs collapsed to a single output mode (output_diversity=1 at every
checkpoint), leaving GRPO with zero within-group reward variance. 0.10 is
the spec's first-pass dropout; the post-SFT diversity preflight raises it to
0.15 if needed."""
LORA_TARGET_MODULES: tuple[str, ...] = (
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
)
# ---- SFT warmup phase (master spec, section 1; 2026-04 regularisation) -- #
# 2026-04 changes (diversity-preserving regularisation): SFT collapsed to
# a constant-output model under the prior settings (LR=2e-4 + dropout=0.05
# + max_steps=200 left every checkpoint at output_diversity=1). New
# defaults trade some ceiling LCR for diversity headroom so GRPO has a
# reward signal to climb.
SFT_EPOCHS: int = 1
# NOTE: SFT_MAX_STEPS (below) is the binding limit. 50 steps at an effective
# batch of 16 (single device assumed) is ~800 examples, well under one pass
# over the 3,000-example SFT set. With HF TrainingArguments, max_steps > 0
# overrides num_train_epochs, so this epoch count is never reached.
SFT_BATCH_SIZE: int = 4
SFT_GRAD_ACCUM: int = 4  # effective batch = 16
SFT_LR: float = 1e-4
"""Halved 2e-4 -> 1e-4 to slow the slide into mode collapse."""
SFT_LR_SCHEDULER: str = "constant_with_warmup"  # 20-step warmup then constant
SFT_WARMUP_STEPS: int = 20
SFT_WEIGHT_DECAY: float = 0.01
SFT_LABEL_SMOOTHING: float = 0.05
"""TrainingArguments.label_smoothing_factor; spreads the loss across
non-target tokens so the model is less rewarded for memorising the
single highest-likelihood completion."""
SFT_OPTIMIZER: str = "adamw_8bit"
SFT_DATASET_SIZE: int = 3_000  # 3,000 train + 100 held-out validation
SFT_VAL_HOLDOUT: int = 100
SFT_MAX_SEQ_LEN: int = 1024  # ~300 prompt + ~80 completion + headroom
SFT_MAX_STEPS: int = 50
"""Cut 200 -> 50 so SFT stops well before the model can grind itself
into a single-output mode. The format-only knowledge fits in <50
steps and post-SFT diversity preflight is the gate to GRPO."""
SFT_EVAL_EVERY: int = 25  # legacy fallback if no schedule given
SFT_SAVE_EVERY: int = 25
SFT_LOG_EVERY: int = 10
SFT_PREFLIGHT_DIVERSITY_FLOOR: int = 2
"""eval/output_diversity threshold. If SFT_DIVERSITY_COLLAPSE_RUN_LEN
consecutive evals (currently two) all report output_diversity below this
floor, the diversity-collapse early stop fires and SFT exits with
reason=diversity_collapse."""
SFT_DIVERSITY_COLLAPSE_RUN_LEN: int = 2
"""Number of consecutive sub-floor evals required before stopping."""
SFT_MAX_NEW_TOKENS: int = 200  # generation cap during eval
# Was 128; bumped to 200 because Qwen2.5-Instruct's cold-start reasoning
# (### Analysis: 1. ... 2. ... 3. ...) regularly runs to 100+ tokens
# before reaching the format line in early SFT steps. With 128, every
# step-5 sample truncated mid-reasoning and format_compliance read 0.
# 200 leaves headroom past such a cold-start completion (100+ reasoning
# tokens plus the format line), so truncation does not mask the model's
# real behaviour. A bare format-only completion is far shorter (~70-80
# tokens; cf. the SFT_MAX_SEQ_LEN budget).
# --- Variable eval cadence ------------------------------------------------- #
# Early evals are quick sanity checks (small sample, format-only) so a
# broken parser / generation drift gets caught before much compute is
# burned. Later evals are real measurements with the full sample size.
# Catching a format-compliance failure early pays off: when the first probe
# sat at step 15 instead of step 50 it saved ~7 minutes per fire, and the
# first probe now runs at step 5.
#
# Each entry: (step, sample_size, mode) where mode is "format_only" or
# "full". format_only skips the diversity probe and the physics-heavy
# logical_correction / hamming / syndrome metrics, so the eval costs
# ~30 seconds instead of ~2 minutes.
SFT_EVAL_SCHEDULE: tuple[tuple[int, int, str], ...] = (
    # 2026-04: schedule rebuilt to fit the SFT_MAX_STEPS=50 budget: one fast
    # format probe (step 5) followed by four full evals (steps 15, 25, 40,
    # 50). Only full evals run the diversity probe, so this yields four
    # diversity data points, enough consecutive readings for the
    # run-length-2 diversity-collapse stop rule to fire before the run ends.
    (5, 30, "format_only"),
    (15, 50, "full"),
    (25, 100, "full"),
    (40, 100, "full"),
    (50, 100, "full"),
)
SFT_PRINT_SAMPLE_OUTPUTS: int = 5  # raw generations echoed at every eval

# Early-stop thresholds (master spec, section 3).
SFT_EARLY_STOP_FORMAT: float = 0.95
SFT_EARLY_STOP_CORRECTION: float = 0.80
SFT_EARLY_STOP_DIVERSITY: int = 3
SFT_MAX_WALL_SECONDS: float = 1_800.0  # hard ceiling of 30 minutes

# Trainer subfolder (the step-50 save) that GRPO starts from. The "-50"
# suffix corresponds to SFT_MAX_STEPS = 50 with SFT_SAVE_EVERY = 25; keep
# them in sync. ``python -m scripts.train_grpo`` defaults to this path and
# the pipeline scripts also pass it explicitly.
SFT_CHECKPOINT_PATH_FOR_GRPO: str = "checkpoints/sft_warmup/checkpoint-50"
# ---- GRPO RL phase (master spec, section 5; 2026-04 spec rewrite) -------- #
# Every value in this block was re-pinned by the 2026-04 GRPO spec. The old
# defaults (GRPO_STEPS=2000, LR=1e-5, KL=0.04, max_prompt=512,
# max_completion=256, temperature=0.7) drove the policy to a degenerate
# "always say []" answer within 100 steps: reward variance collapsed and the
# KL term saturated the loss. The current defaults favour diversity:
#
#   - temperature 1.2 plus top_k and repetition_penalty -> rollouts that do
#     not collapse
#   - max_completion_length 50 -> the answer is a single short line anyway
#   - max_prompt_length 1500 -> distance-3 syndromes already take ~280
#     tokens and distance-5 / curriculum L3 needs more room
#   - KL coefficient 0.02 -> KL drift no longer dominates the reward signal
#   - 1500 steps -> the run fits inside the 13h wall-clock cap with margin
GRPO_STEPS: int = 1_500
GRPO_GEN_PER_PROMPT: int = 4  # completions per prompt; advantages need >= 2
GRPO_BATCH_SIZE: int = 1  # prompts per device per step
GRPO_GRAD_ACCUM: int = 8  # effective batch: 1 x 8 = 8 prompts
GRPO_LR: float = 2e-5  # raised from 1e-5 because the reward signal is sparse
GRPO_LR_SCHEDULER: str = "constant"  # neither warmup nor decay
GRPO_KL_COEF: float = 0.02  # half the TRL default; alarm when KL > 0.3
GRPO_MAX_PROMPT_LEN: int = 1_500  # surface-code prompts can be long
GRPO_MAX_COMPLETION_LEN: int = 50  # one line: X_ERRORS=[..] Z_ERRORS=[..]
# ---- Diversity-focused rollout sampling (critical) ----------------------- #
# These apply to GRPO ROLLOUT generation only. Eval uses temperature=0
# (greedy) regardless of these. The combination temperature=1.2 + top_p=0.95
# + top_k=50 + repetition_penalty=1.1 was selected because:
#   * temperature=1.2 broadens the per-token distribution past the SFT
#     mode-collapsed favourite ("X_ERRORS=[] Z_ERRORS=[]").
#   * top_p=0.95 keeps tail tokens in but truncates the long tail.
#   * top_k=50 caps the candidate set so we don't sample garbage.
#   * repetition_penalty=1.1 down-weights tokens that already occur in the
#     SAME sequence (prompt + completion so far). It does not compare the
#     4 completions of a group with each other, so its effect on the
#     "all 4 generations identical" rate (which kills GRPO's gradient) is
#     only indirect, via more varied individual completions.
#     NOTE(review): because prompt tokens count too, the penalty also hits
#     tokens the answer may legitimately copy from the prompt — confirm
#     this does not bias answers.
GRPO_TEMPERATURE: float = 1.2
GRPO_TOP_P: float = 0.95
GRPO_TOP_K: int = 50
GRPO_REPETITION_PENALTY: float = 1.1
GRPO_DO_SAMPLE: bool = True
# ---- Checkpoint cadence + retention -------------------------------------- #
GRPO_CHECKPOINT_EVERY: int = 100
GRPO_SAVE_TOTAL_LIMIT: int = 3  # only the 3 newest rolling checkpoints survive
GRPO_LOG_EVERY: int = 5  # log every 5 steps for near-real-time visibility
GRPO_OPTIMIZER: str = "adamw_8bit"
GRPO_KL_ALARM: float = 0.3  # exceeding this calls for manual triage
GRPO_KL_HARD_CEIL: float = 0.5  # exceeding this means the run is killed

# ---- Wall-clock safety --------------------------------------------------- #
GRPO_WALL_SECONDS: float = 13 * 60 * 60.0  # 13 h (46,800 s); save + exit past it
# ---- Frozen eval set ----------------------------------------------------- #
# At GRPO start the 200-syndrome eval set is regenerated from the env using
# the seed below — the same seed as SFT validation (sft_validation.jsonl),
# so SFT and GRPO eval distributions line up. The result is cached at
# data/grpo_validation.jsonl, so reruns see exactly the same syndromes.
GRPO_VAL_SEED: int = 4_284
GRPO_VAL_EPISODES: int = 200
GRPO_VAL_PATH: str = "data/grpo_validation.jsonl"
# ---- Sample-table logging ------------------------------------------------ #
# Same values as WANDB_LOG_GENERATIONS_EVERY / WANDB_SAMPLE_GENERATIONS.
GRPO_SAMPLE_LOG_EVERY: int = 50  # steps between sample tables
GRPO_SAMPLE_LOG_N: int = 5  # rows per sample table
# ---- Anti-hacking: mode-collapse inspection hook ------------------------- #
# Every GRPO_INSPECTION_HOOK_EVERY steps, take the rollout groups of the
# GRPO_INSPECTION_SAMPLE_N most recent prompts and count the prompts whose
# generations were ALL identical (4 per prompt, cf. GRPO_GEN_PER_PROMPT).
# If more than GRPO_INSPECTION_COLLAPSE_THRESHOLD of them collapsed, raise
# the rollout temperature by GRPO_TEMP_BUMP_ON_COLLAPSE.
GRPO_INSPECTION_HOOK_EVERY: int = 100
GRPO_INSPECTION_SAMPLE_N: int = 10
GRPO_INSPECTION_COLLAPSE_THRESHOLD: int = 7  # "> 7 of 10"
GRPO_TEMP_BUMP_ON_COLLAPSE: float = 0.2
# ---- Decision-rule thresholds (warnings only; no auto-action) ----------- #
# Reward-std check.
GRPO_DECISION_REWARD_STD_FLOOR: float = 0.03
GRPO_DECISION_REWARD_STD_CHECK_STEP: int = 50
# Beat-rate check.
GRPO_DECISION_BEAT_RATE_CHECK_STEP: int = 500
# Format-compliance floor.
GRPO_DECISION_FORMAT_FLOOR: float = 0.95
# Gradient-norm ceiling, which must be exceeded on this many consecutive logs.
GRPO_DECISION_GRAD_NORM_CEIL: float = 50.0
GRPO_DECISION_GRAD_NORM_RUN_LEN: int = 3
# Sampling parameters for evaluation / format-test generation. They only take
# effect where generation runs with do_sample=True; greedy (do_sample=False)
# eval paths ignore them. GRPO rollouts use the separate GRPO_* sampling
# settings instead.
SAMPLE_TEMPERATURE: float = 0.7
SAMPLE_TOP_P: float = 0.95
# --------------------------------------------------------------------------- #
#                             Server / deployment                             #
# --------------------------------------------------------------------------- #
EPISODE_TIMEOUT_SECONDS: float = 30.0
"""Per-episode wall-clock budget in seconds (Section 2.6)."""

DEFAULT_HOST: str = "0.0.0.0"  # bind on all interfaces
DEFAULT_PORT: int = 7860  # the port Hugging Face Spaces exposes by default
| # --------------------------------------------------------------------------- # | |
| # Weights & Biases # | |
| # --------------------------------------------------------------------------- # | |
| # Centralised so the SFT trainer, GRPO trainer, eval script, and notebook | |
| # all log to the same project / dashboard. Override per-run on the CLI. | |
| import os as _os # noqa: E402 (local import to keep top of module clean) | |
| WANDB_PROJECT: str = _os.environ.get("WANDB_PROJECT", "QuantumScribe-GRPO") | |
| """Default W&B project name. Override with ``WANDB_PROJECT=...``. | |
| Changed 2026-04 from ``"QuantumScribe"`` to ``"QuantumScribe-GRPO"`` per | |
| the GRPO spec rewrite. SFT runs that should land in the original project | |
| should set ``WANDB_PROJECT=QuantumScribe`` at the shell.""" | |
| WANDB_ENTITY: str | None = _os.environ.get("WANDB_ENTITY", "ronitraj") or None | |
| """W&B team or username. ``None`` -> wandb's default entity for the user.""" | |
WANDB_DEFAULT_TAGS: tuple[str, ...] = (
    "qubit-medic",
    "quantum-error-correction",
    "openenv",
    "distance-" + str(DISTANCE_PRIMARY),
    "si1000",
)
"""Tags attached to every W&B run; each script appends its own on top."""
WANDB_LOG_GENERATIONS_EVERY: int = 50
"""GRPO steps between sample-completion tables (master spec sec. 7)."""

WANDB_SAMPLE_GENERATIONS: int = 5
"""Generations per sample-completion table. Master spec, section 7: 'Save 5
randomly sampled rollouts ... and their rewards.'"""

WANDB_INLOOP_EVAL_EVERY: int = 100
"""GRPO steps between deterministic in-loop evaluation passes (each over
``WANDB_INLOOP_EVAL_EPISODES`` syndromes). The 2026-04 GRPO spec rewrite
tightened this from 250 to 100 so collapse or drift shows up within about
5 minutes instead of about 15."""

WANDB_INLOOP_EVAL_EPISODES: int = 200
"""Held-out syndromes per in-loop eval pass. Raised from 100 to 200 by the
2026-04 spec rewrite so the error bars are tight enough to resolve
pymatching_beat_rate changes, which stay below 5% early in training."""

WANDB_COMPARE_EVERY: int = 500
"""Steps between PyMatching head-to-head comparisons (master spec sec. 7)."""
# --------------------------------------------------------------------------- #
#                            Convenience accessors                            #
# --------------------------------------------------------------------------- #
def level_by_name(name: str) -> CurriculumLevel:
    """Look up a curriculum level by its ``name``.

    Args:
        name: Level identifier, e.g. ``"L2_target"``.

    Returns:
        The first entry of ``CURRICULUM`` whose ``name`` matches.

    Raises:
        KeyError: if no curriculum level has that name.
    """
    found = next((level for level in CURRICULUM if level.name == name), None)
    if found is None:
        raise KeyError(f"unknown curriculum level {name!r}")
    return found
def primary_level() -> CurriculumLevel:
    """Return the L2 target benchmark level, which the headline numbers use."""
    target_name = "L2_target"
    return level_by_name(target_name)