Spaces:
Sleeping
Sleeping
deploy via scripts/deploy_to_space.py
Browse files- qubit_medic/__pycache__/__init__.cpython-312.pyc +0 -0
- qubit_medic/__pycache__/__init__.cpython-314.pyc +0 -0
- qubit_medic/__pycache__/config.cpython-312.pyc +0 -0
- qubit_medic/__pycache__/config.cpython-314.pyc +0 -0
- qubit_medic/__pycache__/models.cpython-312.pyc +0 -0
- qubit_medic/__pycache__/models.cpython-314.pyc +0 -0
- qubit_medic/__pycache__/prompts.cpython-312.pyc +0 -0
- qubit_medic/__pycache__/training_stack.cpython-312.pyc +0 -0
- qubit_medic/__pycache__/wandb_utils.cpython-312.pyc +0 -0
- qubit_medic/client/__pycache__/__init__.cpython-312.pyc +0 -0
- qubit_medic/client/__pycache__/client.cpython-312.pyc +0 -0
- qubit_medic/client/client.py +40 -2
- qubit_medic/config.py +201 -28
- qubit_medic/prompts.py +190 -87
- qubit_medic/server/__pycache__/__init__.cpython-312.pyc +0 -0
- qubit_medic/server/__pycache__/app.cpython-312.pyc +0 -0
- qubit_medic/server/__pycache__/curriculum.cpython-312.pyc +0 -0
- qubit_medic/server/__pycache__/environment.cpython-312.pyc +0 -0
- qubit_medic/server/__pycache__/openenv_adapter.cpython-312.pyc +0 -0
- qubit_medic/server/__pycache__/physics.cpython-312.pyc +0 -0
- qubit_medic/server/__pycache__/rewards.cpython-312.pyc +0 -0
- qubit_medic/server/app.py +76 -0
- qubit_medic/server/environment.py +44 -3
- qubit_medic/server/rewards.py +76 -23
- qubit_medic/wandb_utils.py +12 -2
qubit_medic/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (818 Bytes). View file
|
|
|
qubit_medic/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (815 Bytes). View file
|
|
|
qubit_medic/__pycache__/config.cpython-312.pyc
ADDED
|
Binary file (9.72 kB). View file
|
|
|
qubit_medic/__pycache__/config.cpython-314.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
qubit_medic/__pycache__/models.cpython-312.pyc
ADDED
|
Binary file (5.47 kB). View file
|
|
|
qubit_medic/__pycache__/models.cpython-314.pyc
ADDED
|
Binary file (5.62 kB). View file
|
|
|
qubit_medic/__pycache__/prompts.cpython-312.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
qubit_medic/__pycache__/training_stack.cpython-312.pyc
ADDED
|
Binary file (6.59 kB). View file
|
|
|
qubit_medic/__pycache__/wandb_utils.cpython-312.pyc
ADDED
|
Binary file (20.3 kB). View file
|
|
|
qubit_medic/client/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (414 Bytes). View file
|
|
|
qubit_medic/client/__pycache__/client.cpython-312.pyc
ADDED
|
Binary file (9.23 kB). View file
|
|
|
qubit_medic/client/client.py
CHANGED
|
@@ -28,6 +28,10 @@ class _ClientProtocol(Protocol):
|
|
| 28 |
def reset(self, *, seed: Optional[int] = None,
|
| 29 |
forced_level: Optional[str] = None) -> DecoderObservation: ...
|
| 30 |
def step(self, *, raw_response: str, episode_id: int) -> StepResult: ...
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
def health(self) -> dict: ...
|
| 32 |
def close(self) -> None: ...
|
| 33 |
|
|
@@ -89,6 +93,20 @@ class DecoderClient:
|
|
| 89 |
info=dict(obs_payload.get("info", {})),
|
| 90 |
)
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
def health(self) -> dict:
|
| 93 |
r = self._client.get("/health")
|
| 94 |
r.raise_for_status()
|
|
@@ -100,6 +118,14 @@ class DecoderClient:
|
|
| 100 |
return r.json()
|
| 101 |
|
| 102 |
def close(self) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
self._client.close()
|
| 104 |
|
| 105 |
|
|
@@ -117,11 +143,23 @@ class LocalDecoderClient:
|
|
| 117 |
def step(self, *, raw_response: str, episode_id: int) -> StepResult:
|
| 118 |
return self._env.step(raw_response=raw_response, episode_id=episode_id)
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
def health(self) -> dict:
|
| 121 |
return self._env.health()
|
| 122 |
|
| 123 |
-
def close(self) -> None:
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
|
| 127 |
def make_default_client() -> _ClientProtocol:
|
|
|
|
| 28 |
def reset(self, *, seed: Optional[int] = None,
|
| 29 |
forced_level: Optional[str] = None) -> DecoderObservation: ...
|
| 30 |
def step(self, *, raw_response: str, episode_id: int) -> StepResult: ...
|
| 31 |
+
# Compliance Section 3 (audit, 2026-04): the client surface must
|
| 32 |
+
# mirror the server endpoints. state() returns a JSON-serialisable
|
| 33 |
+
# snapshot; close() releases per-episode bookkeeping.
|
| 34 |
+
def state(self) -> dict: ...
|
| 35 |
def health(self) -> dict: ...
|
| 36 |
def close(self) -> None: ...
|
| 37 |
|
|
|
|
| 93 |
info=dict(obs_payload.get("info", {})),
|
| 94 |
)
|
| 95 |
|
| 96 |
+
def state(self) -> dict:
|
| 97 |
+
"""GET /state on the OpenEnv server.
|
| 98 |
+
|
| 99 |
+
Compliance Section 3 (audit, 2026-04): the client must mirror
|
| 100 |
+
the server endpoints. We use GET (the OpenEnv canonical method)
|
| 101 |
+
first, then fall back to POST (the audit-required method we
|
| 102 |
+
also mounted) if some server build only exposes one of them.
|
| 103 |
+
"""
|
| 104 |
+
r = self._client.get("/state")
|
| 105 |
+
if r.status_code == 405: # method not allowed -> try POST
|
| 106 |
+
r = self._client.post("/state")
|
| 107 |
+
r.raise_for_status()
|
| 108 |
+
return r.json()
|
| 109 |
+
|
| 110 |
def health(self) -> dict:
|
| 111 |
r = self._client.get("/health")
|
| 112 |
r.raise_for_status()
|
|
|
|
| 118 |
return r.json()
|
| 119 |
|
| 120 |
def close(self) -> None:
|
| 121 |
+
# Best-effort: tell the server we're done (the POST /close route
|
| 122 |
+
# is mounted by qubit_medic.server.app) and then release the
|
| 123 |
+
# local httpx connection pool. If the server doesn't expose
|
| 124 |
+
# /close, swallow the 404 - this remains an idempotent cleanup.
|
| 125 |
+
try:
|
| 126 |
+
self._client.post("/close")
|
| 127 |
+
except Exception:
|
| 128 |
+
pass
|
| 129 |
self._client.close()
|
| 130 |
|
| 131 |
|
|
|
|
| 143 |
def step(self, *, raw_response: str, episode_id: int) -> StepResult:
|
| 144 |
return self._env.step(raw_response=raw_response, episode_id=episode_id)
|
| 145 |
|
| 146 |
+
def state(self) -> dict:
|
| 147 |
+
"""Compliance Section 3 (audit, 2026-04): expose env state via
|
| 148 |
+
the same client surface as the HTTP variant. Delegates to the
|
| 149 |
+
in-process :meth:`DecoderEnvironment.state`."""
|
| 150 |
+
return self._env.state()
|
| 151 |
+
|
| 152 |
def health(self) -> dict:
|
| 153 |
return self._env.health()
|
| 154 |
|
| 155 |
+
def close(self) -> None:
|
| 156 |
+
# Compliance Section 3 (audit, 2026-04): close releases any
|
| 157 |
+
# per-episode bookkeeping on the inner DecoderEnvironment so a
|
| 158 |
+
# subsequent reset() starts from a clean active-episode dict.
|
| 159 |
+
try:
|
| 160 |
+
self._env.close()
|
| 161 |
+
except Exception:
|
| 162 |
+
pass
|
| 163 |
|
| 164 |
|
| 165 |
def make_default_client() -> _ClientProtocol:
|
qubit_medic/config.py
CHANGED
|
@@ -111,7 +111,12 @@ CURRICULUM: tuple[CurriculumLevel, ...] = (
|
|
| 111 |
name="L1_warmup",
|
| 112 |
distance=DISTANCE_PRIMARY,
|
| 113 |
rounds=1,
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
promotion_threshold=0.80,
|
| 116 |
eval_size=100,
|
| 117 |
),
|
|
@@ -139,9 +144,9 @@ CURRICULUM: tuple[CurriculumLevel, ...] = (
|
|
| 139 |
# --------------------------------------------------------------------------- #
|
| 140 |
|
| 141 |
REWARD_WEIGHTS: dict[str, float] = {
|
| 142 |
-
"logical_correction": 0.
|
|
|
|
| 143 |
"syndrome_consistency": 0.20, # Reward 2 - prevents lucky-guess attacks
|
| 144 |
-
"hamming_overlap": 0.20, # Reward 3 - dense partial credit
|
| 145 |
"format_compliance": 0.10, # Reward 4 - parser must succeed
|
| 146 |
"pymatching_beat": 0.10, # Reward 5 - the headline metric
|
| 147 |
}
|
|
@@ -163,29 +168,185 @@ PRIMARY_SEED: int = SEEDS[0]
|
|
| 163 |
# --------------------------------------------------------------------------- #
|
| 164 |
|
| 165 |
MODEL_ID: str = "Qwen/Qwen2.5-3B-Instruct"
|
| 166 |
-
"""3B params, 4-bit quantised + LoRA fits in a Colab T4.
|
|
|
|
| 167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
LORA_R: int = 16
|
| 169 |
-
LORA_ALPHA: int = 32
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
LORA_TARGET_MODULES: tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj")
|
| 171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
SFT_EPOCHS: int = 1
|
| 173 |
SFT_BATCH_SIZE: int = 4
|
| 174 |
-
SFT_GRAD_ACCUM: int = 4
|
| 175 |
-
SFT_LR: float =
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
# Decoding sampler defaults at evaluation/format-test time.
|
|
|
|
| 189 |
SAMPLE_TEMPERATURE: float = 0.7
|
| 190 |
SAMPLE_TOP_P: float = 0.95
|
| 191 |
|
|
@@ -208,10 +369,14 @@ DEFAULT_PORT: int = 7860 # Hugging Face Spaces' default exposed port
|
|
| 208 |
# all log to the same project / dashboard. Override per-run on the CLI.
|
| 209 |
import os as _os # noqa: E402 (local import to keep top of module clean)
|
| 210 |
|
| 211 |
-
WANDB_PROJECT: str = _os.environ.get("WANDB_PROJECT", "
|
| 212 |
-
"""Default W&B project name. Override with ``WANDB_PROJECT=...``.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
-
WANDB_ENTITY: str | None = _os.environ.get("WANDB_ENTITY") or None
|
| 215 |
"""W&B team or username. ``None`` -> wandb's default entity for the user."""
|
| 216 |
|
| 217 |
WANDB_DEFAULT_TAGS: tuple[str, ...] = (
|
|
@@ -224,17 +389,25 @@ WANDB_DEFAULT_TAGS: tuple[str, ...] = (
|
|
| 224 |
"""Tags applied to every W&B run (per-script tags appended on top)."""
|
| 225 |
|
| 226 |
WANDB_LOG_GENERATIONS_EVERY: int = 50
|
| 227 |
-
"""Log a sample-completion table every N GRPO steps."""
|
| 228 |
|
| 229 |
-
WANDB_SAMPLE_GENERATIONS: int =
|
| 230 |
-
"""Number of generations included in each sample-completion table.
|
|
|
|
| 231 |
|
| 232 |
-
WANDB_INLOOP_EVAL_EVERY: int =
|
| 233 |
"""Run an in-loop evaluation pass (deterministic, ``WANDB_INLOOP_EVAL_EPISODES``
|
| 234 |
-
syndromes) every N GRPO steps.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
-
|
| 237 |
-
"""
|
| 238 |
|
| 239 |
|
| 240 |
# --------------------------------------------------------------------------- #
|
|
|
|
| 111 |
name="L1_warmup",
|
| 112 |
distance=DISTANCE_PRIMARY,
|
| 113 |
rounds=1,
|
| 114 |
+
# 0.0005 (was 0.0001) — at the original budget, L1 syndromes were
|
| 115 |
+
# almost always trivial, dragging the SFT class balance down even
|
| 116 |
+
# under per-level rejection sampling. Bumping to 0.0005 keeps L1
|
| 117 |
+
# strictly easier than L2 (p=0.001) while giving the model real
|
| 118 |
+
# non-empty examples to learn from at the warmup stage.
|
| 119 |
+
p=0.0005,
|
| 120 |
promotion_threshold=0.80,
|
| 121 |
eval_size=100,
|
| 122 |
),
|
|
|
|
| 144 |
# --------------------------------------------------------------------------- #
|
| 145 |
|
| 146 |
REWARD_WEIGHTS: dict[str, float] = {
|
| 147 |
+
"logical_correction": 0.35, # Reward 1 - the unfakeable ground truth
|
| 148 |
+
"hamming_overlap": 0.25, # Reward 3 - dense partial credit
|
| 149 |
"syndrome_consistency": 0.20, # Reward 2 - prevents lucky-guess attacks
|
|
|
|
| 150 |
"format_compliance": 0.10, # Reward 4 - parser must succeed
|
| 151 |
"pymatching_beat": 0.10, # Reward 5 - the headline metric
|
| 152 |
}
|
|
|
|
| 168 |
# --------------------------------------------------------------------------- #
|
| 169 |
|
| 170 |
MODEL_ID: str = "Qwen/Qwen2.5-3B-Instruct"
|
| 171 |
+
"""Locked primary model. 3B params, 4-bit quantised + LoRA fits in a Colab T4.
|
| 172 |
+
Backup is ``Qwen/Qwen2.5-7B-Instruct`` - only swap if format-test < 30%."""
|
| 173 |
|
| 174 |
+
MODEL_BACKUP_ID: str = "Qwen/Qwen2.5-7B-Instruct"
|
| 175 |
+
"""Only swap to this if the pre-onsite format test fails."""
|
| 176 |
+
|
| 177 |
+
# ---- LoRA (shared SFT + GRPO) -------------------------------------------- #
|
| 178 |
LORA_R: int = 16
|
| 179 |
+
LORA_ALPHA: int = 32 # 2x rank, standard ratio
|
| 180 |
+
LORA_DROPOUT: float = 0.10
|
| 181 |
+
"""Bumped 0.05 -> 0.10 (2026-04 SFT regularisation) because the prior
|
| 182 |
+
SFT runs converged to a single-output mode (every checkpoint reported
|
| 183 |
+
output_diversity=1) which left GRPO unable to compute non-zero
|
| 184 |
+
within-group reward variance. 0.10 is the spec's first-pass dropout;
|
| 185 |
+
the post-SFT diversity preflight will bump to 0.15 if needed."""
|
| 186 |
LORA_TARGET_MODULES: tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj")
|
| 187 |
|
| 188 |
+
# ---- SFT warmup phase (master spec, section 1; 2026-04 regularisation) -- #
|
| 189 |
+
# 2026-04 changes (diversity-preserving regularisation): SFT collapsed to
|
| 190 |
+
# a constant-output model under the prior settings (LR=2e-4 + dropout=0.05
|
| 191 |
+
# + max_steps=200 left every checkpoint at output_diversity=1). New
|
| 192 |
+
# defaults trade some ceiling LCR for diversity headroom so GRPO has a
|
| 193 |
+
# reward signal to climb.
|
| 194 |
SFT_EPOCHS: int = 1
|
| 195 |
SFT_BATCH_SIZE: int = 4
|
| 196 |
+
SFT_GRAD_ACCUM: int = 4 # effective batch = 16
|
| 197 |
+
SFT_LR: float = 1e-4
|
| 198 |
+
"""Halved 2e-4 -> 1e-4 to slow the slide into mode collapse."""
|
| 199 |
+
SFT_LR_SCHEDULER: str = "constant_with_warmup" # 20-step warmup then constant
|
| 200 |
+
SFT_WARMUP_STEPS: int = 20
|
| 201 |
+
SFT_WEIGHT_DECAY: float = 0.01
|
| 202 |
+
SFT_LABEL_SMOOTHING: float = 0.05
|
| 203 |
+
"""TrainingArguments.label_smoothing_factor; spreads the loss across
|
| 204 |
+
non-target tokens so the model is less rewarded for memorising the
|
| 205 |
+
single highest-likelihood completion."""
|
| 206 |
+
SFT_OPTIMIZER: str = "adamw_8bit"
|
| 207 |
+
SFT_DATASET_SIZE: int = 3_000 # 3,000 train + 100 held-out validation
|
| 208 |
+
SFT_VAL_HOLDOUT: int = 100
|
| 209 |
+
SFT_MAX_SEQ_LEN: int = 1024 # ~300 prompt + ~80 completion + headroom
|
| 210 |
+
SFT_MAX_STEPS: int = 50
|
| 211 |
+
"""Cut 200 -> 50 so SFT stops well before the model can grind itself
|
| 212 |
+
into a single-output mode. The format-only knowledge fits in <50
|
| 213 |
+
steps and post-SFT diversity preflight is the gate to GRPO."""
|
| 214 |
+
SFT_EVAL_EVERY: int = 25 # legacy fallback if no schedule given
|
| 215 |
+
SFT_SAVE_EVERY: int = 25
|
| 216 |
+
SFT_LOG_EVERY: int = 10
|
| 217 |
+
SFT_PREFLIGHT_DIVERSITY_FLOOR: int = 2
|
| 218 |
+
"""eval/output_diversity threshold. If two consecutive evals both report
|
| 219 |
+
output_diversity below this floor, the diversity-collapse early stop
|
| 220 |
+
fires and SFT exits with reason=diversity_collapse."""
|
| 221 |
+
SFT_DIVERSITY_COLLAPSE_RUN_LEN: int = 2
|
| 222 |
+
"""Number of consecutive sub-floor evals required before stopping."""
|
| 223 |
+
SFT_MAX_NEW_TOKENS: int = 200 # generation cap during eval
|
| 224 |
+
# Was 128; bumped to 200 because Qwen2.5-Instruct's cold-start reasoning
|
| 225 |
+
# (### Analysis: 1. ... 2. ... 3. ...) regularly runs to 100+ tokens
|
| 226 |
+
# before reaching the format line in early SFT steps. With 128, every
|
| 227 |
+
# step-5 sample truncated mid-reasoning and format_compliance read 0.
|
| 228 |
+
# 200 gives ~70 tokens of headroom past a typical reasoning + format
|
| 229 |
+
# completion (~70 tokens total) so truncation never masks the model's
|
| 230 |
+
# real behaviour.
|
| 231 |
+
|
| 232 |
+
# --- Variable eval cadence ------------------------------------------------- #
|
| 233 |
+
# Early evals are quick sanity checks (small sample, format-only) so a
|
| 234 |
+
# broken parser / generation drift gets caught before ~10 min of compute is
|
| 235 |
+
# burned. Late evals are real measurements with the full sample size.
|
| 236 |
+
# Catching format-compliance failure at step 15 instead of step 50 saves
|
| 237 |
+
# ~7 minutes per fire.
|
| 238 |
+
#
|
| 239 |
+
# Each entry: (step, sample_size, mode) where mode is "format_only" or
|
| 240 |
+
# "full". format_only skips the diversity probe and the physics-heavy
|
| 241 |
+
# logical_correction / hamming / syndrome metrics, so the eval costs
|
| 242 |
+
# ~30 seconds instead of ~2 minutes.
|
| 243 |
+
SFT_EVAL_SCHEDULE: tuple[tuple[int, int, str], ...] = (
|
| 244 |
+
# 2026-04: schedule rebuilt to fit the SFT_MAX_STEPS=50 budget. Two
|
| 245 |
+
# full evals plus a fast format probe gives the diversity-collapse
|
| 246 |
+
# early-stop two consecutive data points before the run ends, which
|
| 247 |
+
# is the minimum to fire the new run-length-2 stop rule.
|
| 248 |
+
(5, 30, "format_only"),
|
| 249 |
+
(15, 50, "full"),
|
| 250 |
+
(25, 100, "full"),
|
| 251 |
+
(40, 100, "full"),
|
| 252 |
+
(50, 100, "full"),
|
| 253 |
+
)
|
| 254 |
+
SFT_PRINT_SAMPLE_OUTPUTS: int = 5 # raw outputs printed at each eval
|
| 255 |
+
|
| 256 |
+
# Early-stop thresholds (master spec, section 3).
|
| 257 |
+
SFT_EARLY_STOP_FORMAT: float = 0.95
|
| 258 |
+
SFT_EARLY_STOP_CORRECTION: float = 0.80
|
| 259 |
+
SFT_EARLY_STOP_DIVERSITY: int = 3
|
| 260 |
+
SFT_MAX_WALL_SECONDS: float = 30 * 60.0 # 30-minute hard ceiling
|
| 261 |
+
|
| 262 |
+
# HuggingFace Trainer subfolder (step-50 save) used to initialise GRPO.
|
| 263 |
+
# ``python -m scripts.train_grpo`` defaults to this path; pipeline scripts
|
| 264 |
+
# also pass it explicitly.
|
| 265 |
+
SFT_CHECKPOINT_PATH_FOR_GRPO: str = "checkpoints/sft_warmup/checkpoint-50"
|
| 266 |
+
|
| 267 |
+
# ---- GRPO RL phase (master spec, section 5; 2026-04 spec rewrite) -------- #
|
| 268 |
+
# All numbers below were re-pinned by the 2026-04 GRPO spec. The previous
|
| 269 |
+
# defaults (GRPO_STEPS=2000, LR=1e-5, KL=0.04, max_prompt=512,
|
| 270 |
+
# max_completion=256, temperature=0.7) produced a degenerate "always say
|
| 271 |
+
# []" policy in <100 steps because reward variance collapsed and KL
|
| 272 |
+
# saturated the loss. The new defaults emphasise diversity:
|
| 273 |
+
#
|
| 274 |
+
# - higher temperature (1.2) + top_k + repetition_penalty -> non-collapsed rollouts
|
| 275 |
+
# - shorter max_completion_length (50) -> the answer is one short line anyway
|
| 276 |
+
# - longer max_prompt_length (1500) -> distance-3 syndromes already use
|
| 277 |
+
# ~280 tokens; distance-5 / curriculum L3 needs the headroom
|
| 278 |
+
# - lower KL coefficient (0.02) -> reward signal not dominated by KL drift
|
| 279 |
+
# - 1500 steps -> wall-clock fits the 13h cap with margin
|
| 280 |
+
GRPO_STEPS: int = 1_500
|
| 281 |
+
GRPO_GEN_PER_PROMPT: int = 4 # GRPO needs >=2 for advantage
|
| 282 |
+
GRPO_BATCH_SIZE: int = 1 # per-device prompts per step
|
| 283 |
+
GRPO_GRAD_ACCUM: int = 8 # effective batch = 8 prompts
|
| 284 |
+
GRPO_LR: float = 2e-5 # bumped from 1e-5; reward signal is sparse
|
| 285 |
+
GRPO_LR_SCHEDULER: str = "constant" # no warmup, no decay
|
| 286 |
+
GRPO_KL_COEF: float = 0.02 # half the TRL default; alarm if KL > 0.3
|
| 287 |
+
GRPO_MAX_PROMPT_LEN: int = 1_500 # surface-code prompts can run long
|
| 288 |
+
GRPO_MAX_COMPLETION_LEN: int = 50 # answer is one line: X_ERRORS=[..] Z_ERRORS=[..]
|
| 289 |
+
|
| 290 |
+
# ---- Diversity-focused rollout sampling (critical) ----------------------- #
|
| 291 |
+
# These apply to GRPO ROLLOUT generation only. Eval uses temperature=0
|
| 292 |
+
# (greedy) regardless of these. The combination temperature=1.2 + top_p=0.95
|
| 293 |
+
# + top_k=50 + repetition_penalty=1.1 was selected because:
|
| 294 |
+
# * temperature=1.2 broadens the per-token distribution past the SFT
|
| 295 |
+
# mode-collapsed favourite ("X_ERRORS=[] Z_ERRORS=[]").
|
| 296 |
+
# * top_p=0.95 keeps tail tokens in but truncates the long tail.
|
| 297 |
+
# * top_k=50 caps the candidate set so we don't sample garbage.
|
| 298 |
+
# * repetition_penalty=1.1 discourages the model from repeating the
|
| 299 |
+
# exact same byte sequence within a 4-completion group (reduces
|
| 300 |
+
# "all 4 generations identical" rate, which kills GRPO's gradient).
|
| 301 |
+
GRPO_TEMPERATURE: float = 1.2
|
| 302 |
+
GRPO_TOP_P: float = 0.95
|
| 303 |
+
GRPO_TOP_K: int = 50
|
| 304 |
+
GRPO_REPETITION_PENALTY: float = 1.1
|
| 305 |
+
GRPO_DO_SAMPLE: bool = True
|
| 306 |
+
|
| 307 |
+
# ---- Checkpoint cadence + retention -------------------------------------- #
|
| 308 |
+
GRPO_CHECKPOINT_EVERY: int = 100
|
| 309 |
+
GRPO_SAVE_TOTAL_LIMIT: int = 3 # keep 3 most recent rolling checkpoints
|
| 310 |
+
GRPO_LOG_EVERY: int = 5 # real-time visibility (every 5 steps)
|
| 311 |
+
GRPO_OPTIMIZER: str = "adamw_8bit"
|
| 312 |
+
GRPO_KL_ALARM: float = 0.3 # >this triggers manual triage
|
| 313 |
+
GRPO_KL_HARD_CEIL: float = 0.5 # >this -> kill the run
|
| 314 |
+
|
| 315 |
+
# ---- Wall-clock safety --------------------------------------------------- #
|
| 316 |
+
GRPO_WALL_SECONDS: float = 46_800.0 # 13 hours. Save+exit if exceeded.
|
| 317 |
+
|
| 318 |
+
# ---- Frozen eval set ----------------------------------------------------- #
|
| 319 |
+
# The 200-syndrome eval set is regenerated from the env at GRPO start with
|
| 320 |
+
# this seed. Same seed as SFT validation (sft_validation.jsonl) so eval
|
| 321 |
+
# distributions are comparable across SFT and GRPO. The set is cached on
|
| 322 |
+
# disk under data/grpo_validation.jsonl so reruns hit identical syndromes.
|
| 323 |
+
GRPO_VAL_SEED: int = 4_284
|
| 324 |
+
GRPO_VAL_EPISODES: int = 200
|
| 325 |
+
GRPO_VAL_PATH: str = "data/grpo_validation.jsonl"
|
| 326 |
+
|
| 327 |
+
# ---- Sample-table logging ------------------------------------------------ #
|
| 328 |
+
GRPO_SAMPLE_LOG_EVERY: int = 50
|
| 329 |
+
GRPO_SAMPLE_LOG_N: int = 5
|
| 330 |
+
|
| 331 |
+
# ---- Anti-hacking: mode-collapse inspection hook ------------------------- #
|
| 332 |
+
# Every N steps, we sample the most-recent N rollouts and check what
|
| 333 |
+
# fraction of prompts had ALL 4 generations identical. If too many
|
| 334 |
+
# prompts collapsed, raise the rollout temperature by a fixed step.
|
| 335 |
+
GRPO_INSPECTION_HOOK_EVERY: int = 100
|
| 336 |
+
GRPO_INSPECTION_SAMPLE_N: int = 10
|
| 337 |
+
GRPO_INSPECTION_COLLAPSE_THRESHOLD: int = 7 # "> 7 of 10"
|
| 338 |
+
GRPO_TEMP_BUMP_ON_COLLAPSE: float = 0.2
|
| 339 |
+
|
| 340 |
+
# ---- Decision-rule thresholds (warnings only; no auto-action) ----------- #
|
| 341 |
+
GRPO_DECISION_REWARD_STD_FLOOR: float = 0.03
|
| 342 |
+
GRPO_DECISION_REWARD_STD_CHECK_STEP: int = 50
|
| 343 |
+
GRPO_DECISION_BEAT_RATE_CHECK_STEP: int = 500
|
| 344 |
+
GRPO_DECISION_FORMAT_FLOOR: float = 0.95
|
| 345 |
+
GRPO_DECISION_GRAD_NORM_CEIL: float = 50.0
|
| 346 |
+
GRPO_DECISION_GRAD_NORM_RUN_LEN: int = 3 # consecutive logs
|
| 347 |
|
| 348 |
# Decoding sampler defaults at evaluation/format-test time.
|
| 349 |
+
# (Used by greedy eval paths: temp/top_p only matter when do_sample=True.)
|
| 350 |
SAMPLE_TEMPERATURE: float = 0.7
|
| 351 |
SAMPLE_TOP_P: float = 0.95
|
| 352 |
|
|
|
|
| 369 |
# all log to the same project / dashboard. Override per-run on the CLI.
|
| 370 |
import os as _os # noqa: E402 (local import to keep top of module clean)
|
| 371 |
|
| 372 |
+
WANDB_PROJECT: str = _os.environ.get("WANDB_PROJECT", "QuantumScribe-GRPO")
|
| 373 |
+
"""Default W&B project name. Override with ``WANDB_PROJECT=...``.
|
| 374 |
+
|
| 375 |
+
Changed 2026-04 from ``"QuantumScribe"`` to ``"QuantumScribe-GRPO"`` per
|
| 376 |
+
the GRPO spec rewrite. SFT runs that should land in the original project
|
| 377 |
+
should set ``WANDB_PROJECT=QuantumScribe`` at the shell."""
|
| 378 |
|
| 379 |
+
WANDB_ENTITY: str | None = _os.environ.get("WANDB_ENTITY", "ronitraj") or None
|
| 380 |
"""W&B team or username. ``None`` -> wandb's default entity for the user."""
|
| 381 |
|
| 382 |
WANDB_DEFAULT_TAGS: tuple[str, ...] = (
|
|
|
|
| 389 |
"""Tags applied to every W&B run (per-script tags appended on top)."""
|
| 390 |
|
| 391 |
WANDB_LOG_GENERATIONS_EVERY: int = 50
|
| 392 |
+
"""Log a sample-completion table every N GRPO steps (master spec sec. 7)."""
|
| 393 |
|
| 394 |
+
WANDB_SAMPLE_GENERATIONS: int = 5
|
| 395 |
+
"""Number of generations included in each sample-completion table.
|
| 396 |
+
Master spec, section 7: 'Save 5 randomly sampled rollouts ... and their rewards.'"""
|
| 397 |
|
| 398 |
+
WANDB_INLOOP_EVAL_EVERY: int = 100
|
| 399 |
"""Run an in-loop evaluation pass (deterministic, ``WANDB_INLOOP_EVAL_EPISODES``
|
| 400 |
+
syndromes) every N GRPO steps. Tightened from 250 -> 100 by the 2026-04 GRPO
|
| 401 |
+
spec rewrite so collapse / drift gets caught within a 5-minute window
|
| 402 |
+
instead of a 15-minute window."""
|
| 403 |
+
|
| 404 |
+
WANDB_INLOOP_EVAL_EPISODES: int = 200
|
| 405 |
+
"""Held-out syndromes per in-loop eval pass. Bumped from 100 -> 200 by the
|
| 406 |
+
2026-04 spec rewrite so eval-stat error bars are tight enough to read
|
| 407 |
+
pymatching_beat_rate movement (which is sub-5% in early training)."""
|
| 408 |
|
| 409 |
+
WANDB_COMPARE_EVERY: int = 500
|
| 410 |
+
"""Run the PyMatching head-to-head comparison every N steps (master spec sec. 7)."""
|
| 411 |
|
| 412 |
|
| 413 |
# --------------------------------------------------------------------------- #
|
qubit_medic/prompts.py
CHANGED
|
@@ -1,19 +1,23 @@
|
|
| 1 |
-
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
"""
|
| 18 |
from __future__ import annotations
|
| 19 |
|
|
@@ -23,33 +27,30 @@ from typing import Iterable
|
|
| 23 |
|
| 24 |
|
| 25 |
# --------------------------------------------------------------------------- #
|
| 26 |
-
# Prompt
|
| 27 |
# --------------------------------------------------------------------------- #
|
| 28 |
|
| 29 |
-
|
| 30 |
-
"You are a quantum error-correction decoder. You are decoding errors in "
|
| 31 |
-
"a distance-{distance} rotated surface code memory experiment."
|
| 32 |
-
)
|
| 33 |
|
| 34 |
-
|
| 35 |
-
"Stabilizers are parity checks measured every round. A *syndrome bit* "
|
| 36 |
-
"is 1 when a stabilizer's measurement disagrees with its previous round, "
|
| 37 |
-
"indicating a nearby physical error. Your job is to look at the syndrome "
|
| 38 |
-
"history and output the smallest physical error pattern (X-flips and "
|
| 39 |
-
"Z-flips on data qubits, identified by integer IDs) that explains it."
|
| 40 |
-
)
|
| 41 |
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
)
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
|
| 55 |
def format_syndrome_block(
|
|
@@ -58,19 +59,33 @@ def format_syndrome_block(
|
|
| 58 |
num_x_stabilizers: int,
|
| 59 |
num_z_stabilizers: int,
|
| 60 |
) -> str:
|
| 61 |
-
"""Render
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
"""
|
| 67 |
bits = list(syndrome_bits)
|
| 68 |
per_round = num_x_stabilizers + num_z_stabilizers
|
| 69 |
-
lines = ["Syndrome (round-by-round):"]
|
| 70 |
if per_round == 0 or rounds == 0 or len(bits) == 0:
|
| 71 |
-
|
| 72 |
-
return "\n".join(lines)
|
| 73 |
|
|
|
|
| 74 |
for r in range(rounds):
|
| 75 |
offset = r * per_round
|
| 76 |
if offset >= len(bits):
|
|
@@ -79,18 +94,15 @@ def format_syndrome_block(
|
|
| 79 |
x_chunk = chunk[:num_x_stabilizers]
|
| 80 |
z_chunk = chunk[num_x_stabilizers : num_x_stabilizers + num_z_stabilizers]
|
| 81 |
lines.append(
|
| 82 |
-
f"
|
| 83 |
-
+ " ".join(str(b) for b in x_chunk)
|
| 84 |
)
|
| 85 |
lines.append(
|
| 86 |
-
f"
|
| 87 |
-
+ " ".join(str(b) for b in z_chunk)
|
| 88 |
)
|
| 89 |
-
# Trailing block for the final destructive measurement, if any extras.
|
| 90 |
used = rounds * per_round
|
| 91 |
if used < len(bits):
|
| 92 |
tail = bits[used:]
|
| 93 |
-
lines.append("
|
| 94 |
return "\n".join(lines)
|
| 95 |
|
| 96 |
|
|
@@ -104,10 +116,10 @@ def build_prompt(
|
|
| 104 |
num_z_stabilizers: int,
|
| 105 |
num_data_qubits: int,
|
| 106 |
) -> str:
|
| 107 |
-
"""Assemble the
|
| 108 |
|
| 109 |
-
|
| 110 |
-
|
| 111 |
"""
|
| 112 |
syndrome_block = format_syndrome_block(
|
| 113 |
syndrome_bits=syndrome_bits,
|
|
@@ -115,27 +127,53 @@ def build_prompt(
|
|
| 115 |
num_x_stabilizers=num_x_stabilizers,
|
| 116 |
num_z_stabilizers=num_z_stabilizers,
|
| 117 |
)
|
| 118 |
-
return (
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
+ _OUTPUT_SPEC
|
| 128 |
-
+ "\n\n"
|
| 129 |
-
+ _REASONING_TRIGGER
|
| 130 |
)
|
| 131 |
|
| 132 |
|
| 133 |
# --------------------------------------------------------------------------- #
|
| 134 |
-
# Output parsing
|
| 135 |
# --------------------------------------------------------------------------- #
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
-
|
| 138 |
-
|
|
|
|
| 139 |
|
| 140 |
|
| 141 |
@dataclass(frozen=True)
|
|
@@ -145,13 +183,19 @@ class ParseResult:
|
|
| 145 |
parse_success: bool # True iff BOTH X_ERRORS and Z_ERRORS parsed cleanly
|
| 146 |
parse_partial: bool # True iff exactly one of the two parsed cleanly
|
| 147 |
raw_response: str
|
|
|
|
| 148 |
|
| 149 |
@property
|
| 150 |
def format_score(self) -> float:
|
| 151 |
-
"""Score for Reward 4 (format compliance).
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
return 1.0
|
| 154 |
-
if self.parse_partial:
|
| 155 |
return 0.5
|
| 156 |
return 0.0
|
| 157 |
|
|
@@ -160,6 +204,8 @@ def _parse_int_list(s: str, max_qubit: int) -> tuple[list[int], bool]:
|
|
| 160 |
"""Parse a comma/space-separated integer list. Drops out-of-range and dups.
|
| 161 |
|
| 162 |
Returns ``(qubits_sorted_unique, all_tokens_were_valid)``.
|
|
|
|
|
|
|
| 163 |
"""
|
| 164 |
if not s.strip():
|
| 165 |
return [], True
|
|
@@ -182,25 +228,77 @@ def _parse_int_list(s: str, max_qubit: int) -> tuple[list[int], bool]:
|
|
| 182 |
|
| 183 |
|
| 184 |
def parse_action(raw_response: str, num_data_qubits: int) -> ParseResult:
|
| 185 |
-
"""Convert the LLM's raw text to a `
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
"""
|
| 190 |
if not isinstance(raw_response, str):
|
| 191 |
-
return ParseResult([], [], False, False, raw_response="")
|
| 192 |
-
|
| 193 |
-
#
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
x_errors: list[int] = []
|
| 201 |
z_errors: list[int] = []
|
| 202 |
x_clean = z_clean = False
|
| 203 |
-
|
| 204 |
if x_match is not None:
|
| 205 |
x_errors, x_clean = _parse_int_list(x_match.group(1), num_data_qubits)
|
| 206 |
if z_match is not None:
|
|
@@ -214,12 +312,17 @@ def parse_action(raw_response: str, num_data_qubits: int) -> ParseResult:
|
|
| 214 |
(x_match is not None and z_match is not None) and not parse_success
|
| 215 |
)
|
| 216 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
return ParseResult(
|
| 218 |
x_errors=x_errors,
|
| 219 |
z_errors=z_errors,
|
| 220 |
parse_success=parse_success,
|
| 221 |
parse_partial=parse_partial,
|
| 222 |
raw_response=raw_response,
|
|
|
|
| 223 |
)
|
| 224 |
|
| 225 |
|
|
|
|
| 1 |
+
"""Locked prompt template + parser (master spec, sections 4 + parser).
|
| 2 |
+
|
| 3 |
+
This module is the *single source of truth* for what the LLM sees during
|
| 4 |
+
SFT and GRPO. The exact wording is fixed: anything that drifts the prompt
|
| 5 |
+
between phases throws away the SFT investment because RL builds on the
|
| 6 |
+
format SFT taught.
|
| 7 |
+
|
| 8 |
+
Spec sections honoured:
|
| 9 |
+
* Section 4 - "The exact prompt template (locked, for both SFT and RL)"
|
| 10 |
+
* Section 4 - "The {syndrome_block} format" (round-by-round, X first then Z)
|
| 11 |
+
* Section 4 - "The parser specification (critical)"
|
| 12 |
+
|
| 13 |
+
Parser highlights
|
| 14 |
+
-----------------
|
| 15 |
+
* Case-insensitive on ``X_ERRORS``/``Z_ERRORS`` keys.
|
| 16 |
+
* Tolerant of trailing chain-of-thought, code fences, and whitespace.
|
| 17 |
+
* **Takes the LAST occurrence** of ``X_ERRORS`` so the literal example
|
| 18 |
+
inside the prompt's "Examples:" block is never confused for the answer.
|
| 19 |
+
* Validates each id against ``[0, max_qubit_id]`` and dedups within a list.
|
| 20 |
+
* Returns a partial-credit score (1.0 / 0.5 / 0.0) for Reward 4.
|
| 21 |
"""
|
| 22 |
from __future__ import annotations
|
| 23 |
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
# --------------------------------------------------------------------------- #
|
| 30 |
+
# Prompt template (LOCKED - see master spec, section 4) #
|
| 31 |
# --------------------------------------------------------------------------- #
|
| 32 |
|
| 33 |
+
_PROMPT_TEMPLATE = """You are an expert quantum error correction decoder. Your job is to identify which data qubits experienced errors based on syndrome measurements.
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
+
A surface code protects 1 logical qubit using {num_data_qubits} data qubits arranged in a {distance}x{distance} grid. Stabilizer measurements detect errors: a '1' means that stabilizer fired (detected something wrong nearby); a '0' means it looks fine. Errors must be deduced from the pattern of stabilizers that fired.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
+
Code distance: {distance}
|
| 38 |
+
Number of stabilizer rounds: {rounds}
|
| 39 |
+
Physical error rate: {p}
|
| 40 |
+
X-stabilizer count per round: {num_x_stabilizers}
|
| 41 |
+
Z-stabilizer count per round: {num_z_stabilizers}
|
|
|
|
| 42 |
|
| 43 |
+
{syndrome_block}
|
| 44 |
+
|
| 45 |
+
Identify which data qubits (numbered 0-{max_qubit_id}) had X-errors and Z-errors. Most syndromes have 0-2 errors; an empty list means no errors of that type.
|
| 46 |
+
|
| 47 |
+
Output exactly ONE line and nothing else. Do not write reasoning, markdown, bullets, analysis, or explanations. Your entire response must match this exact format:
|
| 48 |
+
X_ERRORS=[qubit_ids] Z_ERRORS=[qubit_ids]
|
| 49 |
+
|
| 50 |
+
Valid one-line examples:
|
| 51 |
+
X_ERRORS=[] Z_ERRORS=[]
|
| 52 |
+
X_ERRORS=[] Z_ERRORS=[4]
|
| 53 |
+
X_ERRORS=[2] Z_ERRORS=[5,6]"""
|
| 54 |
|
| 55 |
|
| 56 |
def format_syndrome_block(
|
|
|
|
| 59 |
num_x_stabilizers: int,
|
| 60 |
num_z_stabilizers: int,
|
| 61 |
) -> str:
|
| 62 |
+
"""Render detector activations round-by-round, exactly per the spec.
|
| 63 |
+
|
| 64 |
+
Format example for distance-3, rounds=3:
|
| 65 |
+
|
| 66 |
+
Round 1 X-stabilizers: 0 0 1 0
|
| 67 |
+
Round 1 Z-stabilizers: 0 0 0 0
|
| 68 |
+
Round 2 X-stabilizers: 0 0 1 0
|
| 69 |
+
Round 2 Z-stabilizers: 0 0 0 0
|
| 70 |
+
Round 3 X-stabilizers: 0 0 0 0
|
| 71 |
+
Round 3 Z-stabilizers: 0 0 0 0
|
| 72 |
+
|
| 73 |
+
Every round on its own line, X first then Z, space-separated bits, no
|
| 74 |
+
indent, no commas. Rounds are always emitted in full even when all
|
| 75 |
+
bits are zero so the LLM sees consistent shape.
|
| 76 |
+
|
| 77 |
+
Stim's detector layout for the rotated-memory experiment is row-major:
|
| 78 |
+
round 0 stabilizers first, then round 1, and so on. For each round it
|
| 79 |
+
interleaves the per-type detectors in the order Stim's circuit was
|
| 80 |
+
generated (we treat the first ``num_x_stabilizers`` per round as X
|
| 81 |
+
and the rest as Z, matching ``per_round_x_z_counts``).
|
| 82 |
"""
|
| 83 |
bits = list(syndrome_bits)
|
| 84 |
per_round = num_x_stabilizers + num_z_stabilizers
|
|
|
|
| 85 |
if per_round == 0 or rounds == 0 or len(bits) == 0:
|
| 86 |
+
return "(no detectors fired)"
|
|
|
|
| 87 |
|
| 88 |
+
lines: list[str] = []
|
| 89 |
for r in range(rounds):
|
| 90 |
offset = r * per_round
|
| 91 |
if offset >= len(bits):
|
|
|
|
| 94 |
x_chunk = chunk[:num_x_stabilizers]
|
| 95 |
z_chunk = chunk[num_x_stabilizers : num_x_stabilizers + num_z_stabilizers]
|
| 96 |
lines.append(
|
| 97 |
+
f"Round {r + 1} X-stabilizers: " + " ".join(str(int(b)) for b in x_chunk)
|
|
|
|
| 98 |
)
|
| 99 |
lines.append(
|
| 100 |
+
f"Round {r + 1} Z-stabilizers: " + " ".join(str(int(b)) for b in z_chunk)
|
|
|
|
| 101 |
)
|
|
|
|
| 102 |
used = rounds * per_round
|
| 103 |
if used < len(bits):
|
| 104 |
tail = bits[used:]
|
| 105 |
+
lines.append("Final-round detectors: " + " ".join(str(int(b)) for b in tail))
|
| 106 |
return "\n".join(lines)
|
| 107 |
|
| 108 |
|
|
|
|
| 116 |
num_z_stabilizers: int,
|
| 117 |
num_data_qubits: int,
|
| 118 |
) -> str:
|
| 119 |
+
"""Assemble the locked prompt the LLM sees on each step.
|
| 120 |
|
| 121 |
+
Pure function (no I/O, no globals) so the SFT pipeline and GRPO
|
| 122 |
+
rollout produce byte-identical prompt strings - a critical invariant.
|
| 123 |
"""
|
| 124 |
syndrome_block = format_syndrome_block(
|
| 125 |
syndrome_bits=syndrome_bits,
|
|
|
|
| 127 |
num_x_stabilizers=num_x_stabilizers,
|
| 128 |
num_z_stabilizers=num_z_stabilizers,
|
| 129 |
)
|
| 130 |
+
return _PROMPT_TEMPLATE.format(
|
| 131 |
+
num_data_qubits=num_data_qubits,
|
| 132 |
+
distance=distance,
|
| 133 |
+
rounds=rounds,
|
| 134 |
+
p=p,
|
| 135 |
+
num_x_stabilizers=num_x_stabilizers,
|
| 136 |
+
num_z_stabilizers=num_z_stabilizers,
|
| 137 |
+
syndrome_block=syndrome_block,
|
| 138 |
+
max_qubit_id=num_data_qubits - 1,
|
|
|
|
|
|
|
|
|
|
| 139 |
)
|
| 140 |
|
| 141 |
|
| 142 |
# --------------------------------------------------------------------------- #
|
| 143 |
+
# Output parsing (LOCKED - see master spec, section 4 "Parser specification") #
|
| 144 |
# --------------------------------------------------------------------------- #
|
| 145 |
+
#
|
| 146 |
+
# Two-tier parser:
|
| 147 |
+
# * STRICT - canonical "X_ERRORS=[...] Z_ERRORS=[...]". Only this form
|
| 148 |
+
# scores 1.0 on Reward 4 (format_compliance), so the GRPO signal still
|
| 149 |
+
# pushes the model toward the locked spec wording.
|
| 150 |
+
# * LENIENT - also accepts ":" instead of "=", "()" instead of "[]",
|
| 151 |
+
# "X-ERRORS" / "X ERRORS" key spellings, and tolerates
|
| 152 |
+
# \boxed{...} / **...** wrapping. Used so eval/metrics see
|
| 153 |
+
# the model's actual *answer* whenever it is extractable,
|
| 154 |
+
# instead of silently treating parse failures as
|
| 155 |
+
# "predict no errors" (which hides the bug at p=0.001 where
|
| 156 |
+
# ~95% of syndromes are trivial and an empty prediction is
|
| 157 |
+
# accidentally correct).
|
| 158 |
+
|
| 159 |
+
# Strict canonical form: "=" + "[]" - required for Reward 4 = 1.0.
|
| 160 |
+
_X_PATTERN_STRICT = re.compile(r"X_ERRORS\s*=\s*\[([^\]]*)\]", re.IGNORECASE)
|
| 161 |
+
_Z_PATTERN_STRICT = re.compile(r"Z_ERRORS\s*=\s*\[([^\]]*)\]", re.IGNORECASE)
|
| 162 |
+
|
| 163 |
+
# Lenient form: "=" or ":" separator, "[]" or "()" brackets, and the key may
|
| 164 |
+
# be spelt "X_ERRORS" / "X-ERRORS" / "X ERRORS" / "XERRORS".
|
| 165 |
+
_X_PATTERN_LENIENT = re.compile(
|
| 166 |
+
r"X[\s_\-]*ERRORS\s*[=:]\s*[\[\(]([^\]\)]*)[\]\)]",
|
| 167 |
+
re.IGNORECASE,
|
| 168 |
+
)
|
| 169 |
+
_Z_PATTERN_LENIENT = re.compile(
|
| 170 |
+
r"Z[\s_\-]*ERRORS\s*[=:]\s*[\[\(]([^\]\)]*)[\]\)]",
|
| 171 |
+
re.IGNORECASE,
|
| 172 |
+
)
|
| 173 |
|
| 174 |
+
# Key locator (lenient) - finds where any X-errors keyword starts so we can
|
| 175 |
+
# slice past in-prompt examples and home in on the model's actual answer.
|
| 176 |
+
_X_KEY = re.compile(r"X[\s_\-]*ERRORS", re.IGNORECASE)
|
| 177 |
|
| 178 |
|
| 179 |
@dataclass(frozen=True)
|
|
|
|
| 183 |
parse_success: bool # True iff BOTH X_ERRORS and Z_ERRORS parsed cleanly
|
| 184 |
parse_partial: bool # True iff exactly one of the two parsed cleanly
|
| 185 |
raw_response: str
|
| 186 |
+
strict_format: bool = False # True iff matched the canonical "=" + "[]" form
|
| 187 |
|
| 188 |
@property
|
| 189 |
def format_score(self) -> float:
|
| 190 |
+
"""Score for Reward 4 (format compliance).
|
| 191 |
+
|
| 192 |
+
Only the canonical strict form earns 1.0, so the GRPO reward stays
|
| 193 |
+
anchored to the locked spec wording. Lenient parses or partials
|
| 194 |
+
score 0.5; total miss scores 0.0.
|
| 195 |
+
"""
|
| 196 |
+
if self.parse_success and self.strict_format:
|
| 197 |
return 1.0
|
| 198 |
+
if self.parse_success or self.parse_partial:
|
| 199 |
return 0.5
|
| 200 |
return 0.0
|
| 201 |
|
|
|
|
| 204 |
"""Parse a comma/space-separated integer list. Drops out-of-range and dups.
|
| 205 |
|
| 206 |
Returns ``(qubits_sorted_unique, all_tokens_were_valid)``.
|
| 207 |
+
A token is "invalid" if it isn't an integer or falls outside ``[0, max_qubit)``.
|
| 208 |
+
Duplicates within a list count as silently de-duped, not invalid.
|
| 209 |
"""
|
| 210 |
if not s.strip():
|
| 211 |
return [], True
|
|
|
|
| 228 |
|
| 229 |
|
| 230 |
def parse_action(raw_response: str, num_data_qubits: int) -> ParseResult:
|
| 231 |
+
"""Convert the LLM's raw text to a :class:`ParseResult`.
|
| 232 |
+
|
| 233 |
+
Two-pass algorithm:
|
| 234 |
+
1. Receive the full model response string; normalise common LaTeX/
|
| 235 |
+
markdown wrappers (``\\boxed{...}``, ``**bold**``).
|
| 236 |
+
2. If the model wrapped output in fenced code blocks, focus on the
|
| 237 |
+
LAST fenced block.
|
| 238 |
+
3. Locate all X-errors keys; slice forward from the LAST one (so the
|
| 239 |
+
example block in the prompt never wins).
|
| 240 |
+
4. Try the STRICT pattern (``X_ERRORS=[...]``) first. If both X and Z
|
| 241 |
+
lists match, ``strict_format=True``.
|
| 242 |
+
5. Otherwise try the LENIENT pattern (``=`` or ``:``, ``[]`` or ``()``)
|
| 243 |
+
so a near-miss like ``X_ERRORS: [1]`` still surfaces the model's
|
| 244 |
+
intended prediction.
|
| 245 |
+
6. Validate every parsed integer is in ``[0, max_qubit_id]``; reject
|
| 246 |
+
duplicates within a list.
|
| 247 |
+
7. ``parse_success`` requires BOTH lists to parse cleanly;
|
| 248 |
+
``parse_partial`` is set when exactly one parsed cleanly OR both
|
| 249 |
+
keys appear but tokens were dirty.
|
| 250 |
+
|
| 251 |
+
The lenient fallback exists for *eval/diagnostic honesty*, not to
|
| 252 |
+
weaken the training signal: ``format_score`` (Reward 4) only returns
|
| 253 |
+
1.0 when ``strict_format`` is also True.
|
| 254 |
"""
|
| 255 |
if not isinstance(raw_response, str):
|
| 256 |
+
return ParseResult([], [], False, False, raw_response="", strict_format=False)
|
| 257 |
+
|
| 258 |
+
# 1: normalise common wrappers so the regex sees the inner content.
|
| 259 |
+
normalised = raw_response
|
| 260 |
+
# Strip \boxed{...} (LaTeX) - keep inner text.
|
| 261 |
+
normalised = re.sub(r"\\boxed\{([^{}]*)\}", r"\1", normalised)
|
| 262 |
+
# Strip surrounding **bold** / *italic* markers around the format block.
|
| 263 |
+
normalised = re.sub(r"\*+([A-Za-z_][^*]{0,40})\*+", r"\1", normalised)
|
| 264 |
+
|
| 265 |
+
# 2: fence handling - prefer last fenced block if present.
|
| 266 |
+
fenced = re.findall(r"```(?:[^\n]*)\n(.*?)```", normalised, re.DOTALL)
|
| 267 |
+
search_text = fenced[-1] if fenced else normalised
|
| 268 |
+
|
| 269 |
+
# 3: find the LAST X-errors key occurrence.
|
| 270 |
+
x_keys = list(_X_KEY.finditer(search_text))
|
| 271 |
+
if x_keys:
|
| 272 |
+
last_x_pos = x_keys[-1].start()
|
| 273 |
+
slice_text = search_text[last_x_pos:]
|
| 274 |
+
# If the last key has no payload (truncated), fall back one.
|
| 275 |
+
if (
|
| 276 |
+
not _X_PATTERN_STRICT.search(slice_text)
|
| 277 |
+
and not _X_PATTERN_LENIENT.search(slice_text)
|
| 278 |
+
and len(x_keys) > 1
|
| 279 |
+
):
|
| 280 |
+
last_x_pos = x_keys[-2].start()
|
| 281 |
+
slice_text = search_text[last_x_pos:]
|
| 282 |
+
else:
|
| 283 |
+
slice_text = search_text
|
| 284 |
+
|
| 285 |
+
# 4-5: try strict, then lenient.
|
| 286 |
+
x_match = _X_PATTERN_STRICT.search(slice_text)
|
| 287 |
+
z_matches_strict = list(_Z_PATTERN_STRICT.finditer(slice_text))
|
| 288 |
+
z_match = z_matches_strict[-1] if z_matches_strict else None
|
| 289 |
+
strict_x = x_match is not None
|
| 290 |
+
strict_z = z_match is not None
|
| 291 |
+
|
| 292 |
+
if x_match is None:
|
| 293 |
+
x_match = _X_PATTERN_LENIENT.search(slice_text)
|
| 294 |
+
if z_match is None:
|
| 295 |
+
z_matches_lenient = list(_Z_PATTERN_LENIENT.finditer(slice_text))
|
| 296 |
+
z_match = z_matches_lenient[-1] if z_matches_lenient else None
|
| 297 |
+
|
| 298 |
+
# 6: extract + validate qubit IDs.
|
| 299 |
x_errors: list[int] = []
|
| 300 |
z_errors: list[int] = []
|
| 301 |
x_clean = z_clean = False
|
|
|
|
| 302 |
if x_match is not None:
|
| 303 |
x_errors, x_clean = _parse_int_list(x_match.group(1), num_data_qubits)
|
| 304 |
if z_match is not None:
|
|
|
|
| 312 |
(x_match is not None and z_match is not None) and not parse_success
|
| 313 |
)
|
| 314 |
|
| 315 |
+
# strict_format is true only when BOTH X and Z hit the canonical pattern
|
| 316 |
+
# cleanly (no garbage tokens, no out-of-range qubits).
|
| 317 |
+
strict_format = bool(strict_x and strict_z and parse_success)
|
| 318 |
+
|
| 319 |
return ParseResult(
|
| 320 |
x_errors=x_errors,
|
| 321 |
z_errors=z_errors,
|
| 322 |
parse_success=parse_success,
|
| 323 |
parse_partial=parse_partial,
|
| 324 |
raw_response=raw_response,
|
| 325 |
+
strict_format=strict_format,
|
| 326 |
)
|
| 327 |
|
| 328 |
|
qubit_medic/server/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (393 Bytes). View file
|
|
|
qubit_medic/server/__pycache__/app.cpython-312.pyc
ADDED
|
Binary file (9.51 kB). View file
|
|
|
qubit_medic/server/__pycache__/curriculum.cpython-312.pyc
ADDED
|
Binary file (5.55 kB). View file
|
|
|
qubit_medic/server/__pycache__/environment.cpython-312.pyc
ADDED
|
Binary file (14.2 kB). View file
|
|
|
qubit_medic/server/__pycache__/openenv_adapter.cpython-312.pyc
ADDED
|
Binary file (12.4 kB). View file
|
|
|
qubit_medic/server/__pycache__/physics.cpython-312.pyc
ADDED
|
Binary file (19.9 kB). View file
|
|
|
qubit_medic/server/__pycache__/rewards.cpython-312.pyc
ADDED
|
Binary file (12.7 kB). View file
|
|
|
qubit_medic/server/app.py
CHANGED
|
@@ -6,6 +6,8 @@ routes (``/reset``, ``/step``, ``/state``, ``/health``, ``/schema``,
|
|
| 6 |
|
| 7 |
We add a few extras on top:
|
| 8 |
|
|
|
|
|
|
|
| 9 |
* ``GET /healthz`` - the Day-0 deployment-substrate liveness probe
|
| 10 |
(returns Stim/PyMatching/openenv versions). Used by the recurring
|
| 11 |
4-hour HF Spaces wakeup ping.
|
|
@@ -24,6 +26,7 @@ import sys
|
|
| 24 |
from typing import Optional
|
| 25 |
|
| 26 |
from fastapi import Body, HTTPException
|
|
|
|
| 27 |
from openenv.core import create_fastapi_app
|
| 28 |
|
| 29 |
from qubit_medic.config import DEFAULT_HOST, DEFAULT_PORT
|
|
@@ -60,6 +63,44 @@ app.description = (
|
|
| 60 |
)
|
| 61 |
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
# --------------------------------------------------------------------------- #
|
| 64 |
# Day-0 + demo extras #
|
| 65 |
# --------------------------------------------------------------------------- #
|
|
@@ -79,6 +120,41 @@ def _get_legacy_env() -> DecoderEnvironment:
|
|
| 79 |
return _legacy_env
|
| 80 |
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
@app.get("/healthz")
|
| 83 |
def healthz() -> dict:
|
| 84 |
"""Lightweight liveness probe (Day-0 deployment-substrate test).
|
|
|
|
| 6 |
|
| 7 |
We add a few extras on top:
|
| 8 |
|
| 9 |
+
* ``GET /`` - HTML landing page (HF Spaces **App** tab); links to
|
| 10 |
+
``/docs``, ``/healthz``, ``/metadata`` (avoids 404 on the root URL).
|
| 11 |
* ``GET /healthz`` - the Day-0 deployment-substrate liveness probe
|
| 12 |
(returns Stim/PyMatching/openenv versions). Used by the recurring
|
| 13 |
4-hour HF Spaces wakeup ping.
|
|
|
|
| 26 |
from typing import Optional
|
| 27 |
|
| 28 |
from fastapi import Body, HTTPException
|
| 29 |
+
from fastapi.responses import HTMLResponse
|
| 30 |
from openenv.core import create_fastapi_app
|
| 31 |
|
| 32 |
from qubit_medic.config import DEFAULT_HOST, DEFAULT_PORT
|
|
|
|
| 63 |
)
|
| 64 |
|
| 65 |
|
| 66 |
+
@app.get("/", response_class=HTMLResponse, include_in_schema=False)
|
| 67 |
+
def root() -> str:
|
| 68 |
+
"""Space + browser landing page (HF opens ``/`` in the App tab).
|
| 69 |
+
|
| 70 |
+
The OpenEnv API lives under ``/reset``, ``/step``, etc.; there was no
|
| 71 |
+
root handler, so visitors saw 404. This page links to docs and health.
|
| 72 |
+
"""
|
| 73 |
+
return """<!DOCTYPE html>
|
| 74 |
+
<html lang="en">
|
| 75 |
+
<head>
|
| 76 |
+
<meta charset="utf-8"/>
|
| 77 |
+
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
| 78 |
+
<title>Qubit-Medic OpenEnv</title>
|
| 79 |
+
<style>
|
| 80 |
+
body { font-family: system-ui, sans-serif; max-width: 40rem; margin: 2rem auto; padding: 0 1rem; line-height: 1.5; color: #1e293b; }
|
| 81 |
+
h1 { font-size: 1.5rem; }
|
| 82 |
+
ul { padding-left: 1.2rem; }
|
| 83 |
+
a { color: #2563eb; }
|
| 84 |
+
code { background: #f1f5f9; padding: 0.1em 0.3em; border-radius: 4px; }
|
| 85 |
+
</style>
|
| 86 |
+
</head>
|
| 87 |
+
<body>
|
| 88 |
+
<h1>Qubit-Medic — OpenEnv server</h1>
|
| 89 |
+
<p>This Space exposes a <strong>JSON API</strong> for the quantum error-decoding
|
| 90 |
+
environment (Stim + PyMatching, OpenEnv contract). There is no full-page
|
| 91 |
+
Gradio UI here; use the links below.</p>
|
| 92 |
+
<ul>
|
| 93 |
+
<li><a href="/docs">Interactive API docs (Swagger)</a></li>
|
| 94 |
+
<li><a href="/redoc">ReDoc</a></li>
|
| 95 |
+
<li><a href="/healthz">Liveness <code>GET /healthz</code></a> — versions probe</li>
|
| 96 |
+
<li><a href="/metadata">OpenEnv <code>GET /metadata</code></a></li>
|
| 97 |
+
</ul>
|
| 98 |
+
<p>Typical flow: <code>POST /reset</code> then <code>POST /step</code> with
|
| 99 |
+
the model’s text action — see the schema in <code>/docs</code>.</p>
|
| 100 |
+
</body>
|
| 101 |
+
</html>"""
|
| 102 |
+
|
| 103 |
+
|
| 104 |
# --------------------------------------------------------------------------- #
|
| 105 |
# Day-0 + demo extras #
|
| 106 |
# --------------------------------------------------------------------------- #
|
|
|
|
| 120 |
return _legacy_env
|
| 121 |
|
| 122 |
|
| 123 |
+
# --------------------------------------------------------------------------- #
|
| 124 |
+
# Compliance Section 2 (audit 2026-04): POST /state and POST /close. #
|
| 125 |
+
# --------------------------------------------------------------------------- #
|
| 126 |
+
# OpenEnv's create_fastapi_app already mounts GET /state and (via the
|
| 127 |
+
# canonical contract) does not expose /close at all. The participant-guide
|
| 128 |
+
# audit explicitly requires POST /state and POST /close, so we surface
|
| 129 |
+
# both as additional routes that delegate to the legacy DecoderEnvironment
|
| 130 |
+
# singleton (the same one /decode already uses). The OpenEnv-canonical
|
| 131 |
+
# GET /state route is preserved untouched.
|
| 132 |
+
# --------------------------------------------------------------------------- #
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@app.post("/state")
|
| 136 |
+
def post_state() -> dict:
|
| 137 |
+
"""POST mirror of the OpenEnv GET /state route.
|
| 138 |
+
|
| 139 |
+
Returns a JSON-serialisable snapshot of env state. Uses the inner
|
| 140 |
+
:meth:`DecoderEnvironment.state` (added in Section 1 compliance work)
|
| 141 |
+
which excludes ground-truth fields by construction.
|
| 142 |
+
"""
|
| 143 |
+
return _get_legacy_env().state()
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@app.post("/close")
|
| 147 |
+
def post_close() -> dict:
|
| 148 |
+
"""POST /close: drop in-flight episodes on the legacy env singleton.
|
| 149 |
+
|
| 150 |
+
The singleton is rebuilt lazily on the next /reset, so calling /close
|
| 151 |
+
repeatedly is idempotent. Returns a small JSON dict so the caller can
|
| 152 |
+
confirm the request landed.
|
| 153 |
+
"""
|
| 154 |
+
_get_legacy_env().close()
|
| 155 |
+
return {"ok": True, "closed": True}
|
| 156 |
+
|
| 157 |
+
|
| 158 |
@app.get("/healthz")
|
| 159 |
def healthz() -> dict:
|
| 160 |
"""Lightweight liveness probe (Day-0 deployment-substrate test).
|
qubit_medic/server/environment.py
CHANGED
|
@@ -201,9 +201,17 @@ class DecoderEnvironment:
|
|
| 201 |
with self._lock:
|
| 202 |
episode = self._active.pop(episode_id, None)
|
| 203 |
if episode is None:
|
| 204 |
-
# Calling step() on an unknown episode ID is a
|
| 205 |
-
#
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
elapsed = time.monotonic() - episode.started_at
|
| 209 |
timed_out = elapsed > EPISODE_TIMEOUT_SECONDS
|
|
@@ -312,3 +320,36 @@ class DecoderEnvironment:
|
|
| 312 |
"curriculum": self._scheduler.stats(),
|
| 313 |
"cached_levels": list(self._caches.keys()),
|
| 314 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
with self._lock:
|
| 202 |
episode = self._active.pop(episode_id, None)
|
| 203 |
if episode is None:
|
| 204 |
+
# Calling step() on an unknown episode ID is a clean
|
| 205 |
+
# ValueError (compliance Section 1 of the participant-guide
|
| 206 |
+
# audit: the env must "raise a clean ValueError, not a
|
| 207 |
+
# Python traceback"). The trainer didn't follow reset/step
|
| 208 |
+
# pairing, or the episode already ended; either way we
|
| 209 |
+
# surface a typed exception so the FastAPI layer can turn
|
| 210 |
+
# it into a 400 response instead of a 500.
|
| 211 |
+
raise ValueError(
|
| 212 |
+
f"unknown or already-finished episode {episode_id}; "
|
| 213 |
+
f"call reset() before step()."
|
| 214 |
+
)
|
| 215 |
|
| 216 |
elapsed = time.monotonic() - episode.started_at
|
| 217 |
timed_out = elapsed > EPISODE_TIMEOUT_SECONDS
|
|
|
|
| 320 |
"curriculum": self._scheduler.stats(),
|
| 321 |
"cached_levels": list(self._caches.keys()),
|
| 322 |
}
|
| 323 |
+
|
| 324 |
+
def state(self) -> dict:
|
| 325 |
+
"""Return a JSON-serialisable snapshot of the env's externally-
|
| 326 |
+
visible state (compliance Section 1 of the participant-guide
|
| 327 |
+
audit: ``state()`` returns a JSON-serialisable object, not a raw
|
| 328 |
+
Python object).
|
| 329 |
+
|
| 330 |
+
Crucially this never includes the ground-truth fields stored on
|
| 331 |
+
the per-episode :class:`DecoderState` (true error patterns,
|
| 332 |
+
actual_observable_flip, pymatching_observable_pred, circuit_text,
|
| 333 |
+
dem_text). Those stay in ``self._active[ep].state`` and are only
|
| 334 |
+
consumed by the reward functions.
|
| 335 |
+
"""
|
| 336 |
+
with self._lock:
|
| 337 |
+
return {
|
| 338 |
+
"episodes_started": int(self._episode_counter),
|
| 339 |
+
"active_episodes": int(len(self._active)),
|
| 340 |
+
"active_episode_ids": [int(ep) for ep in self._active.keys()],
|
| 341 |
+
"cached_levels": list(self._caches.keys()),
|
| 342 |
+
"curriculum": self._scheduler.stats(),
|
| 343 |
+
"base_seed": int(self._base_seed),
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
def close(self) -> None:
|
| 347 |
+
"""Drop any in-flight episodes and clear caches.
|
| 348 |
+
|
| 349 |
+
Compliance Section 1: the gym-style API requires ``close()``.
|
| 350 |
+
After ``close()`` the env can still be re-used by calling
|
| 351 |
+
``reset()`` again - we don't tear down the curriculum scheduler
|
| 352 |
+
or release the lock; we only release per-episode bookkeeping.
|
| 353 |
+
"""
|
| 354 |
+
with self._lock:
|
| 355 |
+
self._active.clear()
|
qubit_medic/server/rewards.py
CHANGED
|
@@ -84,13 +84,21 @@ def reward_syndrome_consistency(
|
|
| 84 |
) -> float:
|
| 85 |
"""How well does the predicted Pauli frame reproduce the FINAL detectors?
|
| 86 |
|
| 87 |
-
Computes Hamming similarity between ``predicted_final_bits`` (induced
|
| 88 |
-
the predicted X errors) and ``observed_final_bits``. Returns
|
| 89 |
``1 - hamming_distance / num_final_detectors``.
|
| 90 |
|
| 91 |
Rationale (Section 3.2): without this term, an LLM that lucky-guesses
|
| 92 |
-
the right qubits could get Reward 1 occasionally; this signal forces
|
| 93 |
-
to also explain the data the syndrome carries.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
"""
|
| 95 |
final_dets = layout.final_detectors
|
| 96 |
if not final_dets:
|
|
@@ -104,7 +112,17 @@ def reward_syndrome_consistency(
|
|
| 104 |
predicted = implied.get(det_idx, 0)
|
| 105 |
if observed != predicted:
|
| 106 |
distance += 1
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
|
| 110 |
def compute_final_detector_supports(
|
|
@@ -141,13 +159,37 @@ def compute_final_detector_supports(
|
|
| 141 |
# --------------------------------------------------------------------------- #
|
| 142 |
|
| 143 |
|
| 144 |
-
def
|
| 145 |
-
"""Jaccard
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
return inter / union if union else 1.0
|
| 152 |
|
| 153 |
|
|
@@ -156,16 +198,19 @@ def reward_hamming_overlap(
|
|
| 156 |
sample: SyndromeSample,
|
| 157 |
layout: CircuitLayout,
|
| 158 |
) -> float:
|
| 159 |
-
"""Average of Jaccard(X) and Jaccard(Z) against
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
the
|
|
|
|
|
|
|
|
|
|
| 166 |
"""
|
| 167 |
-
jx =
|
| 168 |
-
jz =
|
| 169 |
return 0.5 * (jx + jz)
|
| 170 |
|
| 171 |
|
|
@@ -175,8 +220,16 @@ def reward_hamming_overlap(
|
|
| 175 |
|
| 176 |
|
| 177 |
def reward_format_compliance(parsed: ParseResult) -> float:
|
| 178 |
-
"""
|
| 179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
|
| 182 |
# --------------------------------------------------------------------------- #
|
|
|
|
| 84 |
) -> float:
|
| 85 |
"""How well does the predicted Pauli frame reproduce the FINAL detectors?
|
| 86 |
|
| 87 |
+
Computes Hamming similarity between ``predicted_final_bits`` (induced
|
| 88 |
+
by the predicted X errors) and ``observed_final_bits``. Returns
|
| 89 |
``1 - hamming_distance / num_final_detectors``.
|
| 90 |
|
| 91 |
Rationale (Section 3.2): without this term, an LLM that lucky-guesses
|
| 92 |
+
the right qubits could get Reward 1 occasionally; this signal forces
|
| 93 |
+
it to also explain the data the syndrome carries.
|
| 94 |
+
|
| 95 |
+
2026-04 anti-collapse cap (FIX 1, RL spec rewrite): if the prediction
|
| 96 |
+
is empty AND the observed syndrome is non-empty (at least one
|
| 97 |
+
detector fired), cap the score at 0.5. Without this cap, the
|
| 98 |
+
"always predict empty" policy can still pull a high syndrome-
|
| 99 |
+
consistency score on the prompts where the implied final-round bits
|
| 100 |
+
happen to coincide with zeros, which kept GRPO trapped in the
|
| 101 |
+
constant-empty mode.
|
| 102 |
"""
|
| 103 |
final_dets = layout.final_detectors
|
| 104 |
if not final_dets:
|
|
|
|
| 112 |
predicted = implied.get(det_idx, 0)
|
| 113 |
if observed != predicted:
|
| 114 |
distance += 1
|
| 115 |
+
base = 1.0 - distance / len(final_dets)
|
| 116 |
+
|
| 117 |
+
# Anti-collapse cap: empty prediction + non-empty observed syndrome
|
| 118 |
+
# is a "did nothing while alarms were firing" failure mode. Cap at
|
| 119 |
+
# 0.5 so the empty policy can never approach the full 1.0 even when
|
| 120 |
+
# the implied final-round bits happen to coincide.
|
| 121 |
+
pred_is_empty = (not parsed.x_errors) and (not parsed.z_errors)
|
| 122 |
+
has_active_syndrome = any(int(b) != 0 for b in sample.syndrome_bits)
|
| 123 |
+
if pred_is_empty and has_active_syndrome:
|
| 124 |
+
return min(base, 0.5)
|
| 125 |
+
return base
|
| 126 |
|
| 127 |
|
| 128 |
def compute_final_detector_supports(
|
|
|
|
| 159 |
# --------------------------------------------------------------------------- #
|
| 160 |
|
| 161 |
|
| 162 |
+
def _set_aware_jaccard(true_set: list[int], pred_set: list[int]) -> float:
|
| 163 |
+
"""Set-aware Jaccard: penalises BOTH false alarms and missed errors.
|
| 164 |
+
|
| 165 |
+
2026-04 spec rewrite (FIX 1). The four-case rule is what makes
|
| 166 |
+
"predict empty everywhere" stop being a near-optimal strategy:
|
| 167 |
+
|
| 168 |
+
+-------------+-----------+-----------------------------------------+
|
| 169 |
+
| true_set | pred_set | score |
|
| 170 |
+
+-------------+-----------+-----------------------------------------+
|
| 171 |
+
| empty | empty | 1.0 (perfect, "no errors -> no edit") |
|
| 172 |
+
| empty | non-empty | 0.0 false alarm |
|
| 173 |
+
| non-empty | empty | 0.0 missed errors <-- the key change |
|
| 174 |
+
| non-empty | non-empty | |inter| / |union| (standard Jaccard) |
|
| 175 |
+
+-------------+-----------+-----------------------------------------+
|
| 176 |
+
|
| 177 |
+
Critically the third case used to score 1.0 under the prior plain
|
| 178 |
+
Jaccard (because both sets were treated symmetrically; "everything
|
| 179 |
+
correct, just nothing predicted" was indistinguishable from "perfect
|
| 180 |
+
agreement"). Under this rule a missed-error answer scores 0.0,
|
| 181 |
+
which moves the GRPO reward landscape so a non-trivial prediction
|
| 182 |
+
can climb out of the empty-everywhere local optimum.
|
| 183 |
+
"""
|
| 184 |
+
sa, sp = set(true_set), set(pred_set)
|
| 185 |
+
if not sa and not sp:
|
| 186 |
+
return 1.0 # perfect agreement: no true errors AND no claimed errors
|
| 187 |
+
if not sa and sp:
|
| 188 |
+
return 0.0 # false alarm: claimed errors that were not there
|
| 189 |
+
if sa and not sp:
|
| 190 |
+
return 0.0 # missed errors: alarms fired but model said nothing
|
| 191 |
+
inter = len(sa & sp)
|
| 192 |
+
union = len(sa | sp)
|
| 193 |
return inter / union if union else 1.0
|
| 194 |
|
| 195 |
|
|
|
|
| 198 |
sample: SyndromeSample,
|
| 199 |
layout: CircuitLayout,
|
| 200 |
) -> float:
|
| 201 |
+
"""Average of set-aware Jaccard(X) and set-aware Jaccard(Z) against
|
| 202 |
+
the reference Pauli frame carried by ``SyndromeSample``.
|
| 203 |
+
|
| 204 |
+
The reference frame lives on
|
| 205 |
+
``sample.pymatching_x_errors`` / ``sample.pymatching_z_errors`` —
|
| 206 |
+
in this codebase that frame is treated as the ground-truth target
|
| 207 |
+
(the SFT/GRPO dataset builders fill it from the same source as the
|
| 208 |
+
JSONL ``true_x_errors`` / ``true_z_errors`` fields). Per-axis score
|
| 209 |
+
uses the set-aware rule (see :func:`_set_aware_jaccard`), so missed
|
| 210 |
+
errors no longer score 1.0 just because the prediction set is empty.
|
| 211 |
"""
|
| 212 |
+
jx = _set_aware_jaccard(sample.pymatching_x_errors, parsed.x_errors)
|
| 213 |
+
jz = _set_aware_jaccard(sample.pymatching_z_errors, parsed.z_errors)
|
| 214 |
return 0.5 * (jx + jz)
|
| 215 |
|
| 216 |
|
|
|
|
| 220 |
|
| 221 |
|
| 222 |
def reward_format_compliance(parsed: ParseResult) -> float:
    """All-or-nothing format reward: 1.0 for a cleanly parsed answer, else 0.0.

    2026-04 spec rewrite (FIX 1): the old 0.5 partial-credit tier is gone.
    With partial credit enabled, garbage outputs that merely resembled the
    canonical form could still collect roughly half of the format weight,
    which helped keep the reward landscape too flat for GRPO to escape the
    empty-everywhere mode.  Only a fully extracted pair of lists earns the
    format reward now.
    """
    # bool() -> float() maps any truthy flag to 1.0 and any falsy one to 0.0,
    # matching the original conditional expression exactly.
    return float(bool(parsed.parse_success))
|
| 233 |
|
| 234 |
|
| 235 |
# --------------------------------------------------------------------------- #
|
qubit_medic/wandb_utils.py
CHANGED
|
@@ -260,12 +260,22 @@ def run_context(run_name: str, job_type: str, **kwargs):
|
|
| 260 |
|
| 261 |
def log(metrics: Mapping[str, Any], *, step: Optional[int] = None,
|
| 262 |
commit: bool = True) -> None:
|
| 263 |
-
"""No-op-safe ``wandb.log`` wrapper.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
wandb = _import_wandb()
|
| 265 |
if wandb is None or _RUN is None:
|
| 266 |
return
|
| 267 |
try:
|
| 268 |
-
|
|
|
|
|
|
|
|
|
|
| 269 |
except Exception as exc: # pragma: no cover - defensive
|
| 270 |
print(f"[wandb] log failed: {exc}", file=sys.stderr)
|
| 271 |
|
|
|
|
| 260 |
|
| 261 |
def log(metrics: Mapping[str, Any], *, step: Optional[int] = None,
        commit: bool = True) -> None:
    """Forward *metrics* to the active W&B run; silently no-op otherwise.

    Training-step alignment is recorded as an ordinary scalar
    ``train/global_step`` rather than via W&B's reserved ``step=`` argument.
    HuggingFace/TRL may advance W&B's internal step before our callback logs,
    which otherwise produces "Tried to log to step N that is less than the
    current step N+1" and drops eval metrics.
    """
    wb = _import_wandb()
    # Bail out quietly when wandb is unavailable or no run has been started.
    if wb is None or _RUN is None:
        return
    try:
        row = dict(metrics)
        if step is not None:
            # Caller-provided keys win; only fill the scalar when absent.
            row.setdefault("train/global_step", int(step))
        wb.log(row, commit=commit)
    except Exception as exc:  # pragma: no cover - defensive
        print(f"[wandb] log failed: {exc}", file=sys.stderr)
|
| 281 |
|