"""Locked experiment configuration (Section 1.4 of the plan).

Every magic number in the project lives here. Do not hard-code circuit
parameters, noise rates, or model identifiers anywhere else; import them
from this module instead.
Cited literature
----------------
Bausch et al., AlphaQubit, *Nature* 635:834 (2024)
DOI: 10.1038/s41586-024-08148-8
https://www.nature.com/articles/s41586-024-08148-8
Acharya et al. (Google QAI), *Willow*, arXiv:2408.13687 (2024)
https://arxiv.org/abs/2408.13687
Gidney & Fowler, *SI1000*, arXiv:2108.10457 (2021)
https://arxiv.org/abs/2108.10457
Higgott & Gidney, *PyMatching v2*, arXiv:2303.15933 (2023)
https://arxiv.org/abs/2303.15933
Shao et al., *DeepSeekMath / GRPO*, arXiv:2402.03300 (2024)
https://arxiv.org/abs/2402.03300
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Mapping
# --------------------------------------------------------------------------- #
# Quantum code geometry #
# --------------------------------------------------------------------------- #
# Task name for the circuit generator; every level in this file shares it.
CODE_TASK = "surface_code:rotated_memory_z"
"""Stim task identifier. We always use the rotated surface code with a Z
memory experiment - same family AlphaQubit and Willow report on."""
DISTANCE_PRIMARY: int = 3
"""Distance-3 is the primary benchmark configuration (AlphaQubit Fig. 2b)."""
DISTANCE_STRETCH: int = 5
"""Distance-5 is the stretch-goal configuration for Section 4.3."""
# NOTE(review): the CURRICULUM entries in this file set ``rounds`` directly
# (1 for L1_warmup, the distance itself for the other levels) instead of
# computing ROUNDS_FACTOR * distance, so changing this value alone does not
# change the curriculum -- confirm whether other modules consume it.
ROUNDS_FACTOR: int = 1
"""rounds = ROUNDS_FACTOR * distance. Value 1 matches the AlphaQubit
distance-equals-rounds protocol."""
# --------------------------------------------------------------------------- #
# Noise model: SI1000 sub-rates (Gidney & Fowler 2021, Table 1) #
# --------------------------------------------------------------------------- #
# SI1000 maps a single physical error budget ``p`` to four operation-specific
# sub-rates. The factors below come from arXiv:2108.10457 Table 1 and are the
# *same* values Google's QAI uses in their Willow analyses.
#
# Stim's surface_code:rotated_memory_z generator accepts four matching knobs:
# after_clifford_depolarization (two-qubit gate noise)
# before_round_data_depolarization (idle data-qubit noise per round)
# before_measure_flip_probability (measurement noise)
# after_reset_flip_probability (reset noise)
@dataclass(frozen=True)
class SI1000Rates:
    """Per-operation error rates derived from a single budget ``p``.

    The four fields carry the same names as the keyword arguments returned
    by :meth:`as_stim_kwargs`. Instances are immutable (``frozen=True``).
    Use :meth:`from_p` to build a consistent set of rates; constructing the
    class directly performs no validation.
    """

    after_clifford_depolarization: float  # gate noise; from_p sets p
    before_round_data_depolarization: float  # idle data noise; from_p sets p / 10
    before_measure_flip_probability: float  # measurement noise; from_p sets 5 * p
    after_reset_flip_probability: float  # reset noise; from_p sets 2 * p

    @classmethod
    def from_p(cls, p: float) -> "SI1000Rates":
        """Build SI1000 sub-rates from the headline budget ``p``.

        The factors are exactly Gidney & Fowler 2021 Table 1.

        ``p`` must lie in ``[0, 0.2]``: the largest multiplier is 5 (the
        measurement flip), so any larger budget would produce a
        "probability" above 1. Negative and NaN budgets are rejected for the
        same reason - they cannot describe a valid noise channel.

        Raises:
            ValueError: if ``p`` is NaN or outside ``[0, 0.2]``.
        """
        # The chained comparison is False for NaN, so NaN is rejected too.
        if not 0.0 <= p <= 0.2:
            raise ValueError(
                f"SI1000 budget p must be within [0, 0.2] so that every "
                f"derived rate is a valid probability; got {p!r}"
            )
        return cls(
            after_clifford_depolarization=p,
            before_round_data_depolarization=p / 10.0,
            before_measure_flip_probability=p * 5.0,
            after_reset_flip_probability=p * 2.0,
        )

    def as_stim_kwargs(self) -> Mapping[str, float]:
        """Return the kwargs dict accepted by ``stim.Circuit.generated``.

        A fresh dict is built on every call, keyed by the four field names
        in declaration order.
        """
        return {
            "after_clifford_depolarization": self.after_clifford_depolarization,
            "before_round_data_depolarization": self.before_round_data_depolarization,
            "before_measure_flip_probability": self.before_measure_flip_probability,
            "after_reset_flip_probability": self.after_reset_flip_probability,
        }
# --------------------------------------------------------------------------- #
# Curriculum levels (Section 4) #
# --------------------------------------------------------------------------- #
@dataclass(frozen=True)
class CurriculumLevel:
    """A single step of the difficulty ladder.

    Plain immutable record (``frozen=True``): the fields are stored as given,
    with no validation or derived values.
    """

    # Label used to look the level up (see the names used in CURRICULUM).
    name: str
    # Code distance for this level.
    distance: int
    # Number of rounds for this level.
    rounds: int
    # Noise budget for this level (presumably the SI1000 headline ``p`` --
    # confirm against the code that builds circuits).
    p: float
    # Logical-correction rate at which we move on to the next level.
    promotion_threshold: float
    # Held-out shots used to test promotion.
    eval_size: int
# The difficulty ladder. ``level_by_name`` looks entries up by their ``name``
# field and ``primary_level`` returns the "L2_target" entry.
CURRICULUM: tuple[CurriculumLevel, ...] = (
    CurriculumLevel(
        name="L1_warmup",
        distance=DISTANCE_PRIMARY,
        # A single round, unlike the other levels where rounds == distance.
        rounds=1,
        # 0.0005 (was 0.0001) — at the original budget, L1 syndromes were
        # almost always trivial, dragging the SFT class balance down even
        # under per-level rejection sampling. Bumping to 0.0005 keeps L1
        # strictly easier than L2 (p=0.001) while giving the model real
        # non-empty examples to learn from at the warmup stage.
        p=0.0005,
        promotion_threshold=0.80,
        eval_size=100,
    ),
    CurriculumLevel(
        name="L2_target",
        distance=DISTANCE_PRIMARY,
        rounds=DISTANCE_PRIMARY,
        p=0.001,
        promotion_threshold=0.70,
        eval_size=200,
    ),
    CurriculumLevel(
        name="L3_stretch",
        distance=DISTANCE_STRETCH,
        rounds=DISTANCE_STRETCH,
        p=0.001,
        promotion_threshold=0.30,  # stretch goal - even partial counts
        eval_size=200,
    ),
    # 2026-04 evaluation-only stress level. Same geometry as L3 but 5x the
    # noise rate (p=0.005 vs 0.001) so:
    #   * zeros policy drops to ~50-60% LCR
    #   * pymatching drops to ~80-90% LCR
    # leaving real headroom for trained-model differentiation. NOT used
    # during training (curriculum scheduler ignores it because it isn't
    # in the SFT/GRPO mixes); only invoked via --level L4_stress on
    # scripts/eval.py and scripts/eval_remote.py.
    #
    # NOTE: the deployed HF Space (the canonical remote /reset target)
    # was built before this level existed. Remote eval against the Space
    # for this level will fail until the Space container is rebuilt; run
    # locally via `python -m scripts.eval --level L4_stress ...` instead.
    CurriculumLevel(
        name="L4_stress",
        distance=DISTANCE_STRETCH,
        rounds=DISTANCE_STRETCH,
        p=0.005,
        promotion_threshold=0.20,  # eval-only; promotion never triggered
        eval_size=200,
    ),
)
# --------------------------------------------------------------------------- #
# Reward weights (Section 3) - sum to 1.0 by construction #
# --------------------------------------------------------------------------- #
# Weight of each reward component, keyed by component name. The "Reward N"
# numbers in the trailing comments follow the plan's numbering, which is not
# the same as the order of the entries here.
REWARD_WEIGHTS: dict[str, float] = {
    "logical_correction": 0.35,  # Reward 1 - the unfakeable ground truth
    "hamming_overlap": 0.25,  # Reward 3 - dense partial credit
    "syndrome_consistency": 0.20,  # Reward 2 - prevents lucky-guess attacks
    "format_compliance": 0.10,  # Reward 4 - parser must succeed
    "pymatching_beat": 0.10,  # Reward 5 - the headline metric
}
# Import-time sanity check; the 1e-9 tolerance absorbs float rounding in the
# sum. Note that ``assert`` statements are skipped when Python runs with -O,
# so the check is not enforced in optimised mode.
assert abs(sum(REWARD_WEIGHTS.values()) - 1.0) < 1e-9, "reward weights must sum to 1"
# --------------------------------------------------------------------------- #
# Reproducibility #
# --------------------------------------------------------------------------- #
SEEDS: tuple[int, ...] = (42, 1337, 2024)
"""Three seeds for error bars - never run with anything else."""
# The first entry of SEEDS (42).
PRIMARY_SEED: int = SEEDS[0]
# --------------------------------------------------------------------------- #
# Model + training #
# --------------------------------------------------------------------------- #
MODEL_ID: str = "Qwen/Qwen2.5-3B-Instruct"
"""Locked primary model. 3B params, 4-bit quantised + LoRA fits in a Colab T4.
Backup is ``Qwen/Qwen2.5-7B-Instruct`` - only swap if format-test < 30%."""
# NOTE(review): the docstring above gives the swap criterion as
# "format-test < 30%" while the one below says "the pre-onsite format test
# fails" -- confirm these describe the same gate.
MODEL_BACKUP_ID: str = "Qwen/Qwen2.5-7B-Instruct"
"""Only swap to this if the pre-onsite format test fails."""
# ---- LoRA (shared SFT + GRPO) -------------------------------------------- #
# Adapter rank; LORA_ALPHA below is kept at twice this value.
LORA_R: int = 16
LORA_ALPHA: int = 32  # 2x rank, standard ratio
LORA_DROPOUT: float = 0.10
"""Bumped 0.05 -> 0.10 (2026-04 SFT regularisation) because the prior
SFT runs converged to a single-output mode (every checkpoint reported
output_diversity=1) which left GRPO unable to compute non-zero
within-group reward variance. 0.10 is the spec's first-pass dropout;
the post-SFT diversity preflight will bump to 0.15 if needed."""
# Names of the modules that receive adapters (presumably the attention
# query/key/value/output projections of the base model -- confirm against
# the model's module names).
LORA_TARGET_MODULES: tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj")
# ---- SFT warmup phase (master spec, section 1; 2026-04 regularisation) -- #
# 2026-04 changes (diversity-preserving regularisation): SFT collapsed to
# a constant-output model under the prior settings (LR=2e-4 + dropout=0.05
# + max_steps=200 left every checkpoint at output_diversity=1). New
# defaults trade some ceiling LCR for diversity headroom so GRPO has a
# reward signal to climb.
#
# NOTE(review): with an effective batch of 16 and SFT_MAX_STEPS=50 the run
# consumes 800 sequences, fewer than SFT_DATASET_SIZE, so the step cap (not
# SFT_EPOCHS) is presumably what ends training -- confirm in the trainer.
SFT_EPOCHS: int = 1
SFT_BATCH_SIZE: int = 4
SFT_GRAD_ACCUM: int = 4  # effective batch = 4 * 4 = 16
SFT_LR: float = 1e-4
"""Halved 2e-4 -> 1e-4 to slow the slide into mode collapse."""
SFT_LR_SCHEDULER: str = "constant_with_warmup"  # SFT_WARMUP_STEPS of warmup then constant
SFT_WARMUP_STEPS: int = 20
SFT_WEIGHT_DECAY: float = 0.01
SFT_LABEL_SMOOTHING: float = 0.05
"""TrainingArguments.label_smoothing_factor; spreads the loss across
non-target tokens so the model is less rewarded for memorising the
single highest-likelihood completion."""
SFT_OPTIMIZER: str = "adamw_8bit"
SFT_DATASET_SIZE: int = 3_000  # 3,000 train + 100 held-out validation
SFT_VAL_HOLDOUT: int = 100
SFT_MAX_SEQ_LEN: int = 1024  # ~300 prompt + ~80 completion + headroom
SFT_MAX_STEPS: int = 50
"""Cut 200 -> 50 so SFT stops well before the model can grind itself
into a single-output mode. The format-only knowledge fits in <50
steps and post-SFT diversity preflight is the gate to GRPO."""
SFT_EVAL_EVERY: int = 25  # legacy fallback if no schedule given
SFT_SAVE_EVERY: int = 25  # with SFT_MAX_STEPS=50 this lands on steps 25 and 50
SFT_LOG_EVERY: int = 10
SFT_PREFLIGHT_DIVERSITY_FLOOR: int = 2
"""eval/output_diversity threshold. If two consecutive evals both report
output_diversity below this floor, the diversity-collapse early stop
fires and SFT exits with reason=diversity_collapse."""
SFT_DIVERSITY_COLLAPSE_RUN_LEN: int = 2
"""Number of consecutive sub-floor evals required before stopping."""
SFT_MAX_NEW_TOKENS: int = 200  # generation cap during eval
# Was 128; bumped to 200 because Qwen2.5-Instruct's cold-start reasoning
# (### Analysis: 1. ... 2. ... 3. ...) regularly runs to 100+ tokens
# before reaching the format line in early SFT steps. With 128, every
# step-5 sample truncated mid-reasoning and format_compliance read 0.
# 200 gives ~70 tokens of headroom past a typical reasoning + format
# completion (~70 tokens total) so truncation never masks the model's
# real behaviour.
# NOTE(review): the figures above do not add up as written -- reasoning of
# 100+ tokens cannot fit in a "~70 tokens total" completion, and 200 - 70
# leaves ~130 tokens of headroom, not ~70. The total was probably meant to
# be ~130 tokens; confirm against measured completion lengths.
# --- Variable eval cadence ------------------------------------------------- #
# Early evals are quick sanity checks (small sample, format-only) so a
# broken parser / generation drift gets caught before ~10 min of compute is
# burned. Late evals are real measurements with the full sample size.
# Catching format-compliance failure at step 15 instead of step 50 saves
# ~7 minutes per fire.
#
# Each entry: (step, sample_size, mode) where mode is "format_only" or
# "full". format_only skips the diversity probe and the physics-heavy
# logical_correction / hamming / syndrome metrics, so the eval costs
# ~30 seconds instead of ~2 minutes.
SFT_EVAL_SCHEDULE: tuple[tuple[int, int, str], ...] = (
    # 2026-04: schedule rebuilt to fit the SFT_MAX_STEPS=50 budget. One fast
    # format probe (step 5) followed by four full evals (steps 15, 25, 40
    # and 50). Only "full" evals run the diversity probe, so this gives the
    # diversity-collapse early-stop at least two consecutive data points
    # before the run ends, which is the minimum to fire the run-length-2
    # stop rule.
    (5, 30, "format_only"),
    (15, 50, "full"),
    (25, 100, "full"),
    (40, 100, "full"),
    (50, 100, "full"),
)
SFT_PRINT_SAMPLE_OUTPUTS: int = 5  # raw outputs printed at each eval
# Early-stop thresholds (master spec, section 3).
# NOTE(review): how the three thresholds combine (all must hold vs. any one)
# is decided by the trainer, not here -- confirm before tuning.
SFT_EARLY_STOP_FORMAT: float = 0.95
SFT_EARLY_STOP_CORRECTION: float = 0.80
SFT_EARLY_STOP_DIVERSITY: int = 3
SFT_MAX_WALL_SECONDS: float = 30 * 60.0  # 30-minute hard ceiling (1800 s)
# HuggingFace Trainer subfolder (step-50 save) used to initialise GRPO.
# ``python -m scripts.train_grpo`` defaults to this path; pipeline scripts
# also pass it explicitly. The "checkpoint-50" suffix mirrors
# SFT_MAX_STEPS=50; keep the two in sync when changing either.
SFT_CHECKPOINT_PATH_FOR_GRPO: str = "checkpoints/sft_warmup/checkpoint-50"
# ---- GRPO RL phase (master spec, section 5; 2026-04 spec rewrite) -------- #
# All numbers below were re-pinned by the 2026-04 GRPO spec. The previous
# defaults (GRPO_STEPS=2000, LR=1e-5, KL=0.04, max_prompt=512,
# max_completion=256, temperature=0.7) produced a degenerate "always say
# []" policy in <100 steps because reward variance collapsed and KL
# saturated the loss. The new defaults emphasise diversity:
#
# - higher temperature (1.2) + top_k + repetition_penalty -> non-collapsed rollouts
# - shorter max_completion_length (50) -> the answer is one short line anyway
# - longer max_prompt_length (1500) -> distance-3 syndromes already use
#   ~280 tokens; distance-5 / curriculum L3 needs the headroom
# - lower KL coefficient (0.02) -> reward signal not dominated by KL drift
# - 1500 steps -> wall-clock fits the 13h cap with margin
GRPO_STEPS: int = 1_500
GRPO_GEN_PER_PROMPT: int = 4  # GRPO needs >=2 for advantage
GRPO_BATCH_SIZE: int = 1  # per-device prompts per step
GRPO_GRAD_ACCUM: int = 8  # effective batch = 1 * 8 = 8 prompts
GRPO_LR: float = 2e-5  # bumped from 1e-5; reward signal is sparse
GRPO_LR_SCHEDULER: str = "constant"  # no warmup, no decay
GRPO_KL_COEF: float = 0.02  # half the previous 0.04 (the TRL default); alarm if KL > 0.3
GRPO_MAX_PROMPT_LEN: int = 1_500  # surface-code prompts can run long
GRPO_MAX_COMPLETION_LEN: int = 50  # answer is one line: X_ERRORS=[..] Z_ERRORS=[..]
# ---- Diversity-focused rollout sampling (critical) ----------------------- #
# These apply to GRPO ROLLOUT generation only. Eval uses temperature=0
# (greedy) regardless of these. The combination temperature=1.2 + top_p=0.95
# + top_k=50 + repetition_penalty=1.1 was selected because:
#   * temperature=1.2 broadens the per-token distribution past the SFT
#     mode-collapsed favourite ("X_ERRORS=[] Z_ERRORS=[]").
#   * top_p=0.95 keeps tail tokens in but truncates the long tail.
#   * top_k=50 caps the candidate set so we don't sample garbage.
#   * repetition_penalty=1.1 discourages the model from repeating the
#     exact same byte sequence within a 4-completion group (reduces
#     "all 4 generations identical" rate, which kills GRPO's gradient).
GRPO_TEMPERATURE: float = 1.2
GRPO_TOP_P: float = 0.95
GRPO_TOP_K: int = 50
GRPO_REPETITION_PENALTY: float = 1.1
# Sampling has to be on for the temperature / top-p / top-k settings above to
# have any effect on generation.
GRPO_DO_SAMPLE: bool = True
# ---- Checkpoint cadence + retention -------------------------------------- #
# Saving every 100 steps and keeping 3 checkpoints retains roughly the last
# 300 steps of training on disk.
GRPO_CHECKPOINT_EVERY: int = 100
GRPO_SAVE_TOTAL_LIMIT: int = 3  # keep 3 most recent rolling checkpoints
GRPO_LOG_EVERY: int = 5  # real-time visibility (every 5 steps)
GRPO_OPTIMIZER: str = "adamw_8bit"
# Two-tier KL guard: the alarm level sits below the hard ceiling.
GRPO_KL_ALARM: float = 0.3  # >this triggers manual triage
GRPO_KL_HARD_CEIL: float = 0.5  # >this -> kill the run
# ---- Wall-clock safety --------------------------------------------------- #
GRPO_WALL_SECONDS: float = 46_800.0  # 13 hours (13 * 3600 s). Save+exit if exceeded.
# ---- Frozen eval set ----------------------------------------------------- #
# The 200-syndrome eval set is regenerated from the env at GRPO start with
# this seed. Same seed as SFT validation (sft_validation.jsonl) so eval
# distributions are comparable across SFT and GRPO. The set is cached on
# disk under data/grpo_validation.jsonl so reruns hit identical syndromes.
GRPO_VAL_SEED: int = 4_284
GRPO_VAL_EPISODES: int = 200  # the "200-syndrome" size referred to above
GRPO_VAL_PATH: str = "data/grpo_validation.jsonl"
# ---- Sample-table logging ------------------------------------------------ #
# NOTE(review): these two values equal WANDB_LOG_GENERATIONS_EVERY (50) and
# WANDB_SAMPLE_GENERATIONS (5) defined later in this file -- confirm which
# pair the trainer actually reads, and whether they must stay in sync.
GRPO_SAMPLE_LOG_EVERY: int = 50
GRPO_SAMPLE_LOG_N: int = 5
# ---- Anti-hacking: mode-collapse inspection hook ------------------------- #
# Every GRPO_INSPECTION_HOOK_EVERY steps, we sample the most recent
# GRPO_INSPECTION_SAMPLE_N rollouts and check what fraction of prompts had
# ALL 4 generations identical. If more than
# GRPO_INSPECTION_COLLAPSE_THRESHOLD of them collapsed, raise the rollout
# temperature by the fixed step GRPO_TEMP_BUMP_ON_COLLAPSE.
GRPO_INSPECTION_HOOK_EVERY: int = 100
GRPO_INSPECTION_SAMPLE_N: int = 10
GRPO_INSPECTION_COLLAPSE_THRESHOLD: int = 7  # "> 7 of 10"
GRPO_TEMP_BUMP_ON_COLLAPSE: float = 0.2
# ---- Decision-rule thresholds (warnings only; no auto-action) ----------- #
GRPO_DECISION_REWARD_STD_FLOOR: float = 0.03
GRPO_DECISION_REWARD_STD_CHECK_STEP: int = 50
GRPO_DECISION_BEAT_RATE_CHECK_STEP: int = 500
GRPO_DECISION_FORMAT_FLOOR: float = 0.95
GRPO_DECISION_GRAD_NORM_CEIL: float = 50.0
GRPO_DECISION_GRAD_NORM_RUN_LEN: int = 3  # consecutive logs
# Decoding sampler defaults at evaluation/format-test time.
# (Used by greedy eval paths: temp/top_p only matter when do_sample=True.)
# These are separate from the GRPO_* rollout sampling values.
SAMPLE_TEMPERATURE: float = 0.7
SAMPLE_TOP_P: float = 0.95
# --------------------------------------------------------------------------- #
# Server / deployment #
# --------------------------------------------------------------------------- #
EPISODE_TIMEOUT_SECONDS: float = 30.0
"""Wall-clock budget per episode (Section 2.6)."""
# Presumably the server bind address; "0.0.0.0" means all IPv4 interfaces, so
# the service is reachable from outside the container -- confirm that is
# intended wherever this is run outside a Space.
DEFAULT_HOST: str = "0.0.0.0"
DEFAULT_PORT: int = 7860  # Hugging Face Spaces' default exposed port
# --------------------------------------------------------------------------- #
# Weights & Biases #
# --------------------------------------------------------------------------- #
# Centralised so the SFT trainer, GRPO trainer, eval script, and notebook
# all log to the same project / dashboard. Override per-run on the CLI.
# Both WANDB_PROJECT and WANDB_ENTITY are read from the environment once, at
# import time; later changes to os.environ are not picked up.
import os as _os # noqa: E402 (local import to keep top of module clean)
WANDB_PROJECT: str = _os.environ.get("WANDB_PROJECT", "QuantumScribe-GRPO")
"""Default W&B project name. Override with ``WANDB_PROJECT=...``.
Changed 2026-04 from ``"QuantumScribe"`` to ``"QuantumScribe-GRPO"`` per
the GRPO spec rewrite. SFT runs that should land in the original project
should set ``WANDB_PROJECT=QuantumScribe`` at the shell."""
# Falls back to "ronitraj" when WANDB_ENTITY is unset. The trailing
# ``or None`` only takes effect when the variable is set to an empty string,
# which is therefore the way to request wandb's default entity.
WANDB_ENTITY: str | None = _os.environ.get("WANDB_ENTITY", "ronitraj") or None
"""W&B team or username. ``None`` -> wandb's default entity for the user."""
WANDB_DEFAULT_TAGS: tuple[str, ...] = (
    "qubit-medic",
    "quantum-error-correction",
    "openenv",
    f"distance-{DISTANCE_PRIMARY}",
    "si1000",
)
"""Tags applied to every W&B run (per-script tags appended on top)."""
WANDB_LOG_GENERATIONS_EVERY: int = 50
"""Log a sample-completion table every N GRPO steps (master spec sec. 7)."""
WANDB_SAMPLE_GENERATIONS: int = 5
"""Number of generations included in each sample-completion table.
Master spec, section 7: 'Save 5 randomly sampled rollouts ... and their rewards.'"""
WANDB_INLOOP_EVAL_EVERY: int = 100
"""Run an in-loop evaluation pass (deterministic, ``WANDB_INLOOP_EVAL_EPISODES``
syndromes) every N GRPO steps. Tightened from 250 -> 100 by the 2026-04 GRPO
spec rewrite so collapse / drift gets caught within a 5-minute window
instead of a 15-minute window."""
WANDB_INLOOP_EVAL_EPISODES: int = 200
"""Held-out syndromes per in-loop eval pass. Bumped from 100 -> 200 by the
2026-04 spec rewrite so eval-stat error bars are tight enough to read
pymatching_beat_rate movement (which is sub-5% in early training)."""
WANDB_COMPARE_EVERY: int = 500
"""Run the PyMatching head-to-head comparison every N steps (master spec sec. 7)."""
# --------------------------------------------------------------------------- #
# Convenience accessors #
# --------------------------------------------------------------------------- #
def level_by_name(name: str) -> CurriculumLevel:
    """Return the first ``CURRICULUM`` entry whose ``name`` equals *name*.

    Raises:
        KeyError: if no curriculum level carries that name.
    """
    candidates = (entry for entry in CURRICULUM if entry.name == name)
    found = next(candidates, None)
    if found is None:
        raise KeyError(f"unknown curriculum level {name!r}")
    return found
def primary_level() -> CurriculumLevel:
    """Return the "L2_target" level, the source of the headline numbers.

    Delegates to ``level_by_name``, so a ``KeyError`` propagates if that
    entry is ever removed from the curriculum.
    """
    headline_level = level_by_name("L2_target")
    return headline_level
|