"""Locked experiment configuration (Section 1.4 of the plan).
Every magic number in the project lives here. Do not hard-code circuit
parameters, noise rates, or model identifiers anywhere else; import them
from this module instead.
Cited literature
----------------
Bausch et al., AlphaQubit, *Nature* 635:834 (2024)
DOI: 10.1038/s41586-024-08148-8
https://www.nature.com/articles/s41586-024-08148-8
Acharya et al. (Google QAI), *Willow*, arXiv:2408.13687 (2024)
https://arxiv.org/abs/2408.13687
Gidney & Fowler, *SI1000*, arXiv:2108.10457 (2021)
https://arxiv.org/abs/2108.10457
Higgott & Gidney, *PyMatching v2*, arXiv:2303.15933 (2023)
https://arxiv.org/abs/2303.15933
Shao et al., *DeepSeekMath / GRPO*, arXiv:2402.03300 (2024)
https://arxiv.org/abs/2402.03300
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Mapping
# --------------------------------------------------------------------------- #
# Quantum code geometry #
# --------------------------------------------------------------------------- #
CODE_TASK = "surface_code:rotated_memory_z"
"""Stim task identifier. We always use the rotated surface code with a Z
memory experiment - same family AlphaQubit and Willow report on."""
DISTANCE_PRIMARY: int = 3
"""Distance-3 is the primary benchmark configuration (AlphaQubit Fig. 2b)."""
DISTANCE_STRETCH: int = 5
"""Distance-5 is the stretch-goal configuration for Section 4.3."""
ROUNDS_FACTOR: int = 1
"""rounds = ROUNDS_FACTOR * distance. Value 1 matches the AlphaQubit
distance-equals-rounds protocol."""
# --------------------------------------------------------------------------- #
# Noise model: SI1000 sub-rates (Gidney & Fowler 2021, Table 1) #
# --------------------------------------------------------------------------- #
# SI1000 maps a single physical error budget ``p`` to four operation-specific
# sub-rates. The factors below come from arXiv:2108.10457 Table 1 and are the
# *same* values Google QAI uses in the Willow analyses.
#
# Stim's surface_code:rotated_memory_z generator accepts four matching knobs:
#   after_clifford_depolarization     (two-qubit gate noise)
#   before_round_data_depolarization  (idle data-qubit noise per round)
#   before_measure_flip_probability   (measurement noise)
#   after_reset_flip_probability      (reset noise)
@dataclass(frozen=True)
class SI1000Rates:
    """Per-operation error rates derived from a single budget ``p``."""

    after_clifford_depolarization: float
    before_round_data_depolarization: float
    before_measure_flip_probability: float
    after_reset_flip_probability: float

    @classmethod
    def from_p(cls, p: float) -> "SI1000Rates":
        """Build SI1000 sub-rates from the headline budget ``p``.

        The factors are exactly Gidney & Fowler 2021 Table 1.
        """
        return cls(
            after_clifford_depolarization=p,
            before_round_data_depolarization=p / 10.0,
            before_measure_flip_probability=p * 5.0,
            after_reset_flip_probability=p * 2.0,
        )

    def as_stim_kwargs(self) -> Mapping[str, float]:
        """Return the kwargs dict accepted by ``stim.Circuit.generated``."""
        return {
            "after_clifford_depolarization": self.after_clifford_depolarization,
            "before_round_data_depolarization": self.before_round_data_depolarization,
            "before_measure_flip_probability": self.before_measure_flip_probability,
            "after_reset_flip_probability": self.after_reset_flip_probability,
        }
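# Worked example (the values follow directly from the factors above): at
# the L2 budget p = 0.001, ``SI1000Rates.from_p`` yields
#
#     SI1000Rates.from_p(0.001)
#     # after_clifford_depolarization    = 0.001
#     # before_round_data_depolarization = 0.0001
#     # before_measure_flip_probability  = 0.005
#     # after_reset_flip_probability     = 0.002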
# --------------------------------------------------------------------------- #
# Curriculum levels (Section 4) #
# --------------------------------------------------------------------------- #
@dataclass(frozen=True)
class CurriculumLevel:
    """One rung on the difficulty ladder."""

    name: str
    distance: int
    rounds: int
    p: float
    promotion_threshold: float  # logical-correction rate at which we move on
    eval_size: int              # held-out shots used to test promotion
CURRICULUM: tuple[CurriculumLevel, ...] = (
    CurriculumLevel(
        name="L1_warmup",
        distance=DISTANCE_PRIMARY,
        rounds=1,
        # 0.0005 (was 0.0001) - at the original budget, L1 syndromes were
        # almost always trivial, dragging the SFT class balance down even
        # under per-level rejection sampling. Bumping to 0.0005 keeps L1
        # strictly easier than L2 (p=0.001) while giving the model real
        # non-empty examples to learn from at the warmup stage.
        p=0.0005,
        promotion_threshold=0.80,
        eval_size=100,
    ),
    CurriculumLevel(
        name="L2_target",
        distance=DISTANCE_PRIMARY,
        rounds=DISTANCE_PRIMARY,
        p=0.001,
        promotion_threshold=0.70,
        eval_size=200,
    ),
    CurriculumLevel(
        name="L3_stretch",
        distance=DISTANCE_STRETCH,
        rounds=DISTANCE_STRETCH,
        p=0.001,
        promotion_threshold=0.30,  # stretch goal - even partial counts
        eval_size=200,
    ),
    # 2026-04 evaluation-only stress level. Same geometry as L3 but 5x the
    # noise rate so:
    #   * the zeros policy drops to ~50-60% LCR
    #   * pymatching drops to ~80-90% LCR
    # leaving real headroom for trained-model differentiation. NOT used
    # during training (the curriculum scheduler ignores it because it isn't
    # in the SFT/GRPO mixes); only invoked via --level L4_stress on
    # scripts/eval.py and scripts/eval_remote.py.
    #
    # NOTE: the deployed HF Space (the canonical remote /reset target)
    # was built before this level existed. Remote eval against the Space
    # for this level will fail until the Space container is rebuilt; run
    # locally via `python -m scripts.eval --level L4_stress ...` instead.
    CurriculumLevel(
        name="L4_stress",
        distance=DISTANCE_STRETCH,
        rounds=DISTANCE_STRETCH,
        p=0.005,
        promotion_threshold=0.20,  # eval-only; promotion never triggered
        eval_size=200,
    ),
)
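# Illustrative promotion sketch (not the real scheduler; ``measure_lcr``
# is a hypothetical stand-in for the project's eval entry point): a level
# is passed once the logical-correction rate measured on ``eval_size``
# held-out shots clears ``promotion_threshold``.
#
#     for lvl in CURRICULUM[:3]:  # L4_stress is eval-only, never trained on
#         lcr = measure_lcr(lvl, shots=lvl.eval_size)
#         if lcr < lvl.promotion_threshold:
#             break  # stay on this level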
# --------------------------------------------------------------------------- #
# Reward weights (Section 3) - sum to 1.0 by construction #
# --------------------------------------------------------------------------- #
REWARD_WEIGHTS: dict[str, float] = {
    "logical_correction": 0.35,    # Reward 1 - the unfakeable ground truth
    "hamming_overlap": 0.25,       # Reward 3 - dense partial credit
    "syndrome_consistency": 0.20,  # Reward 2 - prevents lucky-guess attacks
    "format_compliance": 0.10,     # Reward 4 - parser must succeed
    "pymatching_beat": 0.10,       # Reward 5 - the headline metric
}
assert abs(sum(REWARD_WEIGHTS.values()) - 1.0) < 1e-9, "reward weights must sum to 1"
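# Example (sketch, assuming each component score is normalised to [0, 1];
# the ``scores`` dict below is illustrative, not real output): the scalar
# reward is the weighted sum over matching keys.
#
#     scores = {"logical_correction": 1.0, "hamming_overlap": 0.8,
#               "syndrome_consistency": 1.0, "format_compliance": 1.0,
#               "pymatching_beat": 0.0}
#     reward = sum(w * scores[k] for k, w in REWARD_WEIGHTS.items())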
# --------------------------------------------------------------------------- #
# Reproducibility #
# --------------------------------------------------------------------------- #
SEEDS: tuple[int, ...] = (42, 1337, 2024)
"""Three seeds for error bars - never run with anything else."""
PRIMARY_SEED: int = SEEDS[0]
# --------------------------------------------------------------------------- #
# Model + training #
# --------------------------------------------------------------------------- #
MODEL_ID: str = "Qwen/Qwen2.5-3B-Instruct"
"""Locked primary model. 3B params, 4-bit quantised + LoRA fits in a Colab T4.
Backup is ``Qwen/Qwen2.5-7B-Instruct`` - only swap if format-test < 30%."""
MODEL_BACKUP_ID: str = "Qwen/Qwen2.5-7B-Instruct"
"""Only swap to this if the pre-onsite format test fails."""
# ---- LoRA (shared SFT + GRPO) -------------------------------------------- #
LORA_R: int = 16
LORA_ALPHA: int = 32 # 2x rank, standard ratio
LORA_DROPOUT: float = 0.10
"""Bumped 0.05 -> 0.10 (2026-04 SFT regularisation) because the prior
SFT runs converged to a single-output mode (every checkpoint reported
output_diversity=1) which left GRPO unable to compute non-zero
within-group reward variance. 0.10 is the spec's first-pass dropout;
the post-SFT diversity preflight will bump to 0.15 if needed."""
LORA_TARGET_MODULES: tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj")
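# Example (sketch): these knobs materialised as a ``peft.LoraConfig``
# (assumes the ``peft`` package is installed; the task_type value is an
# assumption matching a causal LM, not pinned anywhere in this module):
#
#     from peft import LoraConfig
#     lora = LoraConfig(
#         r=LORA_R,
#         lora_alpha=LORA_ALPHA,
#         lora_dropout=LORA_DROPOUT,
#         target_modules=list(LORA_TARGET_MODULES),
#         task_type="CAUSAL_LM",
#     )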
# ---- SFT warmup phase (master spec, section 1; 2026-04 regularisation) -- #
# 2026-04 changes (diversity-preserving regularisation): SFT collapsed to
# a constant-output model under the prior settings (LR=2e-4 + dropout=0.05
# + max_steps=200 left every checkpoint at output_diversity=1). New
# defaults trade some ceiling LCR for diversity headroom so GRPO has a
# reward signal to climb.
SFT_EPOCHS: int = 1
SFT_BATCH_SIZE: int = 4
SFT_GRAD_ACCUM: int = 4 # effective batch = 16
SFT_LR: float = 1e-4
"""Halved 2e-4 -> 1e-4 to slow the slide into mode collapse."""
SFT_LR_SCHEDULER: str = "constant_with_warmup" # 20-step warmup then constant
SFT_WARMUP_STEPS: int = 20
SFT_WEIGHT_DECAY: float = 0.01
SFT_LABEL_SMOOTHING: float = 0.05
"""TrainingArguments.label_smoothing_factor; spreads the loss across
non-target tokens so the model is less rewarded for memorising the
single highest-likelihood completion."""
SFT_OPTIMIZER: str = "adamw_8bit"
SFT_DATASET_SIZE: int = 3_000 # 3,000 train + 100 held-out validation
SFT_VAL_HOLDOUT: int = 100
SFT_MAX_SEQ_LEN: int = 1024 # ~300 prompt + ~80 completion + headroom
SFT_MAX_STEPS: int = 50
"""Cut 200 -> 50 so SFT stops well before the model can grind itself
into a single-output mode. The format-only knowledge fits in <50
steps and post-SFT diversity preflight is the gate to GRPO."""
SFT_EVAL_EVERY: int = 25 # legacy fallback if no schedule given
SFT_SAVE_EVERY: int = 25
SFT_LOG_EVERY: int = 10
SFT_PREFLIGHT_DIVERSITY_FLOOR: int = 2
"""eval/output_diversity threshold. If two consecutive evals both report
output_diversity below this floor, the diversity-collapse early stop
fires and SFT exits with reason=diversity_collapse."""
SFT_DIVERSITY_COLLAPSE_RUN_LEN: int = 2
"""Number of consecutive sub-floor evals required before stopping."""
SFT_MAX_NEW_TOKENS: int = 200 # generation cap during eval
# Was 128; bumped to 200 because Qwen2.5-Instruct's cold-start reasoning
# (### Analysis: 1. ... 2. ... 3. ...) regularly runs to 100+ tokens
# before reaching the format line in early SFT steps. With 128, every
# step-5 sample truncated mid-reasoning and format_compliance read 0.
# 200 gives ~70 tokens of headroom past a typical reasoning + format
# completion (~130 tokens total) so truncation never masks the model's
# real behaviour.
# --- Variable eval cadence ------------------------------------------------- #
# Early evals are quick sanity checks (small sample, format-only) so a
# broken parser / generation drift gets caught before ~10 min of compute is
# burned. Late evals are real measurements with the full sample size.
# Catching format-compliance failure at step 15 instead of step 50 saves
# ~7 minutes per fire.
#
# Each entry: (step, sample_size, mode) where mode is "format_only" or
# "full". format_only skips the diversity probe and the physics-heavy
# logical_correction / hamming / syndrome metrics, so the eval costs
# ~30 seconds instead of ~2 minutes.
SFT_EVAL_SCHEDULE: tuple[tuple[int, int, str], ...] = (
    # 2026-04: schedule rebuilt to fit the SFT_MAX_STEPS=50 budget. Two
    # full evals plus a fast format probe gives the diversity-collapse
    # early-stop two consecutive data points before the run ends, which
    # is the minimum to fire the new run-length-2 stop rule.
    (5, 30, "format_only"),
    (15, 50, "full"),
    (25, 100, "full"),
    (40, 100, "full"),
    (50, 100, "full"),
)
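# Example (sketch): how a trainer callback might consume the schedule.
# ``state.global_step`` follows the HF Trainer callback convention;
# ``run_eval`` is a hypothetical stand-in for the real eval routine.
#
#     schedule = {step: (n, mode) for step, n, mode in SFT_EVAL_SCHEDULE}
#     if state.global_step in schedule:
#         n_samples, mode = schedule[state.global_step]
#         run_eval(n_samples, format_only=(mode == "format_only"))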
SFT_PRINT_SAMPLE_OUTPUTS: int = 5 # raw outputs printed at each eval
# Early-stop thresholds (master spec, section 3).
SFT_EARLY_STOP_FORMAT: float = 0.95
SFT_EARLY_STOP_CORRECTION: float = 0.80
SFT_EARLY_STOP_DIVERSITY: int = 3
SFT_MAX_WALL_SECONDS: float = 30 * 60.0 # 30-minute hard ceiling
# HuggingFace Trainer subfolder (step-50 save) used to initialise GRPO.
# ``python -m scripts.train_grpo`` defaults to this path; pipeline scripts
# also pass it explicitly.
SFT_CHECKPOINT_PATH_FOR_GRPO: str = "checkpoints/sft_warmup/checkpoint-50"
# ---- GRPO RL phase (master spec, section 5; 2026-04 spec rewrite) -------- #
# All numbers below were re-pinned by the 2026-04 GRPO spec. The previous
# defaults (GRPO_STEPS=2000, LR=1e-5, KL=0.04, max_prompt=512,
# max_completion=256, temperature=0.7) produced a degenerate "always say
# []" policy in <100 steps because reward variance collapsed and KL
# saturated the loss. The new defaults emphasise diversity:
#
# - higher temperature (1.2) + top_k + repetition_penalty -> non-collapsed rollouts
# - shorter max_completion_length (50) -> the answer is one short line anyway
# - longer max_prompt_length (1500) -> distance-3 syndromes already use
# ~280 tokens; distance-5 / curriculum L3 needs the headroom
# - lower KL coefficient (0.02) -> reward signal not dominated by KL drift
# - 1500 steps -> wall-clock fits the 13h cap with margin
GRPO_STEPS: int = 1_500
GRPO_GEN_PER_PROMPT: int = 4 # GRPO needs >=2 for advantage
GRPO_BATCH_SIZE: int = 1 # per-device prompts per step
GRPO_GRAD_ACCUM: int = 8 # effective batch = 8 prompts
GRPO_LR: float = 2e-5 # bumped from 1e-5; reward signal is sparse
GRPO_LR_SCHEDULER: str = "constant" # no warmup, no decay
GRPO_KL_COEF: float = 0.02 # half the TRL default; alarm if KL > 0.3
GRPO_MAX_PROMPT_LEN: int = 1_500 # surface-code prompts can run long
GRPO_MAX_COMPLETION_LEN: int = 50 # answer is one line: X_ERRORS=[..] Z_ERRORS=[..]
# ---- Diversity-focused rollout sampling (critical) ----------------------- #
# These apply to GRPO ROLLOUT generation only. Eval uses temperature=0
# (greedy) regardless of these. The combination temperature=1.2 + top_p=0.95
# + top_k=50 + repetition_penalty=1.1 was selected because:
# * temperature=1.2 broadens the per-token distribution past the SFT
# mode-collapsed favourite ("X_ERRORS=[] Z_ERRORS=[]").
# * top_p=0.95 keeps tail tokens in but truncates the long tail.
# * top_k=50 caps the candidate set so we don't sample garbage.
# * repetition_penalty=1.1 discourages the model from repeating the
# exact same byte sequence within a 4-completion group (reduces
# "all 4 generations identical" rate, which kills GRPO's gradient).
GRPO_TEMPERATURE: float = 1.2
GRPO_TOP_P: float = 0.95
GRPO_TOP_K: int = 50
GRPO_REPETITION_PENALTY: float = 1.1
GRPO_DO_SAMPLE: bool = True
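# Example (sketch): packing the rollout knobs into a
# ``transformers.GenerationConfig`` (assumes the transformers package;
# remember that eval paths stay greedy and ignore these):
#
#     from transformers import GenerationConfig
#     rollout_cfg = GenerationConfig(
#         do_sample=GRPO_DO_SAMPLE,
#         temperature=GRPO_TEMPERATURE,
#         top_p=GRPO_TOP_P,
#         top_k=GRPO_TOP_K,
#         repetition_penalty=GRPO_REPETITION_PENALTY,
#         max_new_tokens=GRPO_MAX_COMPLETION_LEN,
#     )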
# ---- Checkpoint cadence + retention -------------------------------------- #
GRPO_CHECKPOINT_EVERY: int = 100
GRPO_SAVE_TOTAL_LIMIT: int = 3 # keep 3 most recent rolling checkpoints
GRPO_LOG_EVERY: int = 5 # real-time visibility (every 5 steps)
GRPO_OPTIMIZER: str = "adamw_8bit"
GRPO_KL_ALARM: float = 0.3 # >this triggers manual triage
GRPO_KL_HARD_CEIL: float = 0.5 # >this -> kill the run
# ---- Wall-clock safety --------------------------------------------------- #
GRPO_WALL_SECONDS: float = 46_800.0 # 13 hours. Save+exit if exceeded.
# ---- Frozen eval set ----------------------------------------------------- #
# The 200-syndrome eval set is regenerated from the env at GRPO start with
# this seed. Same seed as SFT validation (sft_validation.jsonl) so eval
# distributions are comparable across SFT and GRPO. The set is cached on
# disk under data/grpo_validation.jsonl so reruns hit identical syndromes.
GRPO_VAL_SEED: int = 4_284
GRPO_VAL_EPISODES: int = 200
GRPO_VAL_PATH: str = "data/grpo_validation.jsonl"
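# Sketch of the cache-or-regenerate rule (``generate_val_set`` is a
# hypothetical stand-in for the env's sampler):
#
#     from pathlib import Path
#     if not Path(GRPO_VAL_PATH).exists():
#         generate_val_set(seed=GRPO_VAL_SEED, n=GRPO_VAL_EPISODES,
#                          out=GRPO_VAL_PATH)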
# ---- Sample-table logging ------------------------------------------------ #
GRPO_SAMPLE_LOG_EVERY: int = 50
GRPO_SAMPLE_LOG_N: int = 5
# ---- Anti-hacking: mode-collapse inspection hook ------------------------- #
# Every N steps, we sample the most-recent N rollouts and check what
# fraction of prompts had ALL 4 generations identical. If too many
# prompts collapsed, raise the rollout temperature by a fixed step.
GRPO_INSPECTION_HOOK_EVERY: int = 100
GRPO_INSPECTION_SAMPLE_N: int = 10
GRPO_INSPECTION_COLLAPSE_THRESHOLD: int = 7 # "> 7 of 10"
GRPO_TEMP_BUMP_ON_COLLAPSE: float = 0.2
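# Sketch of the hook's check (``groups`` is hypothetical: the sampled
# GRPO_INSPECTION_SAMPLE_N prompts, each a list of 4 completions):
#
#     n_collapsed = sum(1 for g in groups if len(set(g)) == 1)
#     if n_collapsed > GRPO_INSPECTION_COLLAPSE_THRESHOLD:  # "> 7 of 10"
#         rollout_temperature += GRPO_TEMP_BUMP_ON_COLLAPSE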
# ---- Decision-rule thresholds (warnings only; no auto-action) ----------- #
GRPO_DECISION_REWARD_STD_FLOOR: float = 0.03
GRPO_DECISION_REWARD_STD_CHECK_STEP: int = 50
GRPO_DECISION_BEAT_RATE_CHECK_STEP: int = 500
GRPO_DECISION_FORMAT_FLOOR: float = 0.95
GRPO_DECISION_GRAD_NORM_CEIL: float = 50.0
GRPO_DECISION_GRAD_NORM_RUN_LEN: int = 3 # consecutive logs
# Decoding sampler defaults at evaluation/format-test time.
# (Used by greedy eval paths: temp/top_p only matter when do_sample=True.)
SAMPLE_TEMPERATURE: float = 0.7
SAMPLE_TOP_P: float = 0.95
# --------------------------------------------------------------------------- #
# Server / deployment #
# --------------------------------------------------------------------------- #
EPISODE_TIMEOUT_SECONDS: float = 30.0
"""Wall-clock budget per episode (Section 2.6)."""
DEFAULT_HOST: str = "0.0.0.0"
DEFAULT_PORT: int = 7860 # Hugging Face Spaces' default exposed port
# --------------------------------------------------------------------------- #
# Weights & Biases #
# --------------------------------------------------------------------------- #
# Centralised so the SFT trainer, GRPO trainer, eval script, and notebook
# all log to the same project / dashboard. Override per-run on the CLI.
import os as _os # noqa: E402 (local import to keep top of module clean)
WANDB_PROJECT: str = _os.environ.get("WANDB_PROJECT", "QuantumScribe-GRPO")
"""Default W&B project name. Override with ``WANDB_PROJECT=...``.
Changed 2026-04 from ``"QuantumScribe"`` to ``"QuantumScribe-GRPO"`` per
the GRPO spec rewrite. SFT runs that should land in the original project
should set ``WANDB_PROJECT=QuantumScribe`` at the shell."""
WANDB_ENTITY: str | None = _os.environ.get("WANDB_ENTITY", "ronitraj") or None
"""W&B team or username. ``None`` -> wandb's default entity for the user."""
WANDB_DEFAULT_TAGS: tuple[str, ...] = (
    "qubit-medic",
    "quantum-error-correction",
    "openenv",
    f"distance-{DISTANCE_PRIMARY}",
    "si1000",
)
"""Tags applied to every W&B run (per-script tags appended on top)."""
WANDB_LOG_GENERATIONS_EVERY: int = 50
"""Log a sample-completion table every N GRPO steps (master spec sec. 7)."""
WANDB_SAMPLE_GENERATIONS: int = 5
"""Number of generations included in each sample-completion table.
Master spec, section 7: 'Save 5 randomly sampled rollouts ... and their rewards.'"""
WANDB_INLOOP_EVAL_EVERY: int = 100
"""Run an in-loop evaluation pass (deterministic, ``WANDB_INLOOP_EVAL_EPISODES``
syndromes) every N GRPO steps. Tightened from 250 -> 100 by the 2026-04 GRPO
spec rewrite so collapse / drift gets caught within a 5-minute window
instead of a 15-minute window."""
WANDB_INLOOP_EVAL_EPISODES: int = 200
"""Held-out syndromes per in-loop eval pass. Bumped from 100 -> 200 by the
2026-04 spec rewrite so eval-stat error bars are tight enough to read
pymatching_beat_rate movement (which is sub-5% in early training)."""
WANDB_COMPARE_EVERY: int = 500
"""Run the PyMatching head-to-head comparison every N steps (master spec sec. 7)."""
# --------------------------------------------------------------------------- #
# Convenience accessors #
# --------------------------------------------------------------------------- #
def level_by_name(name: str) -> CurriculumLevel:
    """Look up a curriculum level by its ``name`` field."""
    for lvl in CURRICULUM:
        if lvl.name == name:
            return lvl
    raise KeyError(f"unknown curriculum level {name!r}")


def primary_level() -> CurriculumLevel:
    """The L2 target benchmark - what the headline numbers come from."""
    return level_by_name("L2_target")
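# Example (sketch): the headline benchmark circuit assembled from this
# module's accessors (assumes the stim package; ``stim.Circuit.generated``
# accepts exactly the kwargs emitted by ``as_stim_kwargs``):
#
#     import stim
#     lvl = primary_level()
#     circuit = stim.Circuit.generated(
#         CODE_TASK,
#         distance=lvl.distance,
#         rounds=lvl.rounds,
#         **SI1000Rates.from_p(lvl.p).as_stim_kwargs(),
#     )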