ronitraj commited on
Commit
fa68719
·
verified ·
1 Parent(s): 1c5a86c

deploy via scripts/deploy_to_space.py

Browse files
qubit_medic/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (818 Bytes). View file
 
qubit_medic/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (815 Bytes). View file
 
qubit_medic/__pycache__/config.cpython-312.pyc ADDED
Binary file (9.72 kB). View file
 
qubit_medic/__pycache__/config.cpython-314.pyc ADDED
Binary file (10.4 kB). View file
 
qubit_medic/__pycache__/models.cpython-312.pyc ADDED
Binary file (5.47 kB). View file
 
qubit_medic/__pycache__/models.cpython-314.pyc ADDED
Binary file (5.62 kB). View file
 
qubit_medic/__pycache__/prompts.cpython-312.pyc ADDED
Binary file (13.6 kB). View file
 
qubit_medic/__pycache__/training_stack.cpython-312.pyc ADDED
Binary file (6.59 kB). View file
 
qubit_medic/__pycache__/wandb_utils.cpython-312.pyc ADDED
Binary file (20.3 kB). View file
 
qubit_medic/client/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (414 Bytes). View file
 
qubit_medic/client/__pycache__/client.cpython-312.pyc ADDED
Binary file (9.23 kB). View file
 
qubit_medic/client/client.py CHANGED
@@ -28,6 +28,10 @@ class _ClientProtocol(Protocol):
28
  def reset(self, *, seed: Optional[int] = None,
29
  forced_level: Optional[str] = None) -> DecoderObservation: ...
30
  def step(self, *, raw_response: str, episode_id: int) -> StepResult: ...
 
 
 
 
31
  def health(self) -> dict: ...
32
  def close(self) -> None: ...
33
 
@@ -89,6 +93,20 @@ class DecoderClient:
89
  info=dict(obs_payload.get("info", {})),
90
  )
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def health(self) -> dict:
93
  r = self._client.get("/health")
94
  r.raise_for_status()
@@ -100,6 +118,14 @@ class DecoderClient:
100
  return r.json()
101
 
102
  def close(self) -> None:
 
 
 
 
 
 
 
 
103
  self._client.close()
104
 
105
 
@@ -117,11 +143,23 @@ class LocalDecoderClient:
117
  def step(self, *, raw_response: str, episode_id: int) -> StepResult:
118
  return self._env.step(raw_response=raw_response, episode_id=episode_id)
119
 
 
 
 
 
 
 
120
  def health(self) -> dict:
121
  return self._env.health()
122
 
123
- def close(self) -> None: # nothing to clean up
124
- pass
 
 
 
 
 
 
125
 
126
 
127
  def make_default_client() -> _ClientProtocol:
 
28
  def reset(self, *, seed: Optional[int] = None,
29
  forced_level: Optional[str] = None) -> DecoderObservation: ...
30
  def step(self, *, raw_response: str, episode_id: int) -> StepResult: ...
31
+ # Compliance Section 3 (audit, 2026-04): the client surface must
32
+ # mirror the server endpoints. state() returns a JSON-serialisable
33
+ # snapshot; close() releases per-episode bookkeeping.
34
+ def state(self) -> dict: ...
35
  def health(self) -> dict: ...
36
  def close(self) -> None: ...
37
 
 
93
  info=dict(obs_payload.get("info", {})),
94
  )
95
 
96
+ def state(self) -> dict:
97
+ """GET /state on the OpenEnv server.
98
+
99
+ Compliance Section 3 (audit, 2026-04): the client must mirror
100
+ the server endpoints. We use GET (the OpenEnv canonical method)
101
+ first, then fall back to POST (the audit-required method we
102
+ also mounted) if some server build only exposes one of them.
103
+ """
104
+ r = self._client.get("/state")
105
+ if r.status_code == 405: # method not allowed -> try POST
106
+ r = self._client.post("/state")
107
+ r.raise_for_status()
108
+ return r.json()
109
+
110
  def health(self) -> dict:
111
  r = self._client.get("/health")
112
  r.raise_for_status()
 
118
  return r.json()
119
 
120
  def close(self) -> None:
121
+ # Best-effort: tell the server we're done (the POST /close route
122
+ # is mounted by qubit_medic.server.app) and then release the
123
+ # local httpx connection pool. If the server doesn't expose
124
+ # /close, swallow the 404 - this remains an idempotent cleanup.
125
+ try:
126
+ self._client.post("/close")
127
+ except Exception:
128
+ pass
129
  self._client.close()
130
 
131
 
 
143
  def step(self, *, raw_response: str, episode_id: int) -> StepResult:
144
  return self._env.step(raw_response=raw_response, episode_id=episode_id)
145
 
146
+ def state(self) -> dict:
147
+ """Compliance Section 3 (audit, 2026-04): expose env state via
148
+ the same client surface as the HTTP variant. Delegates to the
149
+ in-process :meth:`DecoderEnvironment.state`."""
150
+ return self._env.state()
151
+
152
  def health(self) -> dict:
153
  return self._env.health()
154
 
155
+ def close(self) -> None:
156
+ # Compliance Section 3 (audit, 2026-04): close releases any
157
+ # per-episode bookkeeping on the inner DecoderEnvironment so a
158
+ # subsequent reset() starts from a clean active-episode dict.
159
+ try:
160
+ self._env.close()
161
+ except Exception:
162
+ pass
163
 
164
 
165
  def make_default_client() -> _ClientProtocol:
qubit_medic/config.py CHANGED
@@ -111,7 +111,12 @@ CURRICULUM: tuple[CurriculumLevel, ...] = (
111
  name="L1_warmup",
112
  distance=DISTANCE_PRIMARY,
113
  rounds=1,
114
- p=0.0001,
 
 
 
 
 
115
  promotion_threshold=0.80,
116
  eval_size=100,
117
  ),
@@ -139,9 +144,9 @@ CURRICULUM: tuple[CurriculumLevel, ...] = (
139
  # --------------------------------------------------------------------------- #
140
 
141
  REWARD_WEIGHTS: dict[str, float] = {
142
- "logical_correction": 0.40, # Reward 1 - the unfakeable ground truth
 
143
  "syndrome_consistency": 0.20, # Reward 2 - prevents lucky-guess attacks
144
- "hamming_overlap": 0.20, # Reward 3 - dense partial credit
145
  "format_compliance": 0.10, # Reward 4 - parser must succeed
146
  "pymatching_beat": 0.10, # Reward 5 - the headline metric
147
  }
@@ -163,29 +168,185 @@ PRIMARY_SEED: int = SEEDS[0]
163
  # --------------------------------------------------------------------------- #
164
 
165
  MODEL_ID: str = "Qwen/Qwen2.5-3B-Instruct"
166
- """3B params, 4-bit quantised + LoRA fits in a Colab T4."""
 
167
 
 
 
 
 
168
  LORA_R: int = 16
169
- LORA_ALPHA: int = 32
 
 
 
 
 
 
170
  LORA_TARGET_MODULES: tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj")
171
 
 
 
 
 
 
 
172
  SFT_EPOCHS: int = 1
173
  SFT_BATCH_SIZE: int = 4
174
- SFT_GRAD_ACCUM: int = 4
175
- SFT_LR: float = 2e-4
176
- SFT_DATASET_SIZE: int = 5_000
177
- SFT_MAX_SEQ_LEN: int = 2048
178
-
179
- GRPO_STEPS: int = 2_000
180
- GRPO_GEN_PER_PROMPT: int = 4
181
- GRPO_LR: float = 1e-5
182
- GRPO_KL_COEF: float = 0.04
183
- GRPO_MAX_PROMPT_LEN: int = 512
184
- GRPO_MAX_COMPLETION_LEN: int = 256
185
- GRPO_CHECKPOINT_EVERY: int = 250
186
- GRPO_LOG_EVERY: int = 50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  # Decoding sampler defaults at evaluation/format-test time.
 
189
  SAMPLE_TEMPERATURE: float = 0.7
190
  SAMPLE_TOP_P: float = 0.95
191
 
@@ -208,10 +369,14 @@ DEFAULT_PORT: int = 7860 # Hugging Face Spaces' default exposed port
208
  # all log to the same project / dashboard. Override per-run on the CLI.
209
  import os as _os # noqa: E402 (local import to keep top of module clean)
210
 
211
- WANDB_PROJECT: str = _os.environ.get("WANDB_PROJECT", "qubit-medic")
212
- """Default W&B project name. Override with ``WANDB_PROJECT=...``."""
 
 
 
 
213
 
214
- WANDB_ENTITY: str | None = _os.environ.get("WANDB_ENTITY") or None
215
  """W&B team or username. ``None`` -> wandb's default entity for the user."""
216
 
217
  WANDB_DEFAULT_TAGS: tuple[str, ...] = (
@@ -224,17 +389,25 @@ WANDB_DEFAULT_TAGS: tuple[str, ...] = (
224
  """Tags applied to every W&B run (per-script tags appended on top)."""
225
 
226
  WANDB_LOG_GENERATIONS_EVERY: int = 50
227
- """Log a sample-completion table every N GRPO steps."""
228
 
229
- WANDB_SAMPLE_GENERATIONS: int = 8
230
- """Number of generations included in each sample-completion table."""
 
231
 
232
- WANDB_INLOOP_EVAL_EVERY: int = 200
233
  """Run an in-loop evaluation pass (deterministic, ``WANDB_INLOOP_EVAL_EPISODES``
234
- syndromes) every N GRPO steps. Set to 0 to disable."""
 
 
 
 
 
 
 
235
 
236
- WANDB_INLOOP_EVAL_EPISODES: int = 50
237
- """Number of held-out syndromes per in-loop eval pass (kept small for speed)."""
238
 
239
 
240
  # --------------------------------------------------------------------------- #
 
111
  name="L1_warmup",
112
  distance=DISTANCE_PRIMARY,
113
  rounds=1,
114
+ # 0.0005 (was 0.0001) — at the original budget, L1 syndromes were
115
+ # almost always trivial, dragging the SFT class balance down even
116
+ # under per-level rejection sampling. Bumping to 0.0005 keeps L1
117
+ # strictly easier than L2 (p=0.001) while giving the model real
118
+ # non-empty examples to learn from at the warmup stage.
119
+ p=0.0005,
120
  promotion_threshold=0.80,
121
  eval_size=100,
122
  ),
 
144
  # --------------------------------------------------------------------------- #
145
 
146
  REWARD_WEIGHTS: dict[str, float] = {
147
+ "logical_correction": 0.35, # Reward 1 - the unfakeable ground truth
148
+ "hamming_overlap": 0.25, # Reward 3 - dense partial credit
149
  "syndrome_consistency": 0.20, # Reward 2 - prevents lucky-guess attacks
 
150
  "format_compliance": 0.10, # Reward 4 - parser must succeed
151
  "pymatching_beat": 0.10, # Reward 5 - the headline metric
152
  }
 
168
  # --------------------------------------------------------------------------- #
169
 
170
  MODEL_ID: str = "Qwen/Qwen2.5-3B-Instruct"
171
+ """Locked primary model. 3B params, 4-bit quantised + LoRA fits in a Colab T4.
172
+ Backup is ``Qwen/Qwen2.5-7B-Instruct`` - only swap if format-test < 30%."""
173
 
174
+ MODEL_BACKUP_ID: str = "Qwen/Qwen2.5-7B-Instruct"
175
+ """Only swap to this if the pre-onsite format test fails."""
176
+
177
+ # ---- LoRA (shared SFT + GRPO) -------------------------------------------- #
178
  LORA_R: int = 16
179
+ LORA_ALPHA: int = 32 # 2x rank, standard ratio
180
+ LORA_DROPOUT: float = 0.10
181
+ """Bumped 0.05 -> 0.10 (2026-04 SFT regularisation) because the prior
182
+ SFT runs converged to a single-output mode (every checkpoint reported
183
+ output_diversity=1) which left GRPO unable to compute non-zero
184
+ within-group reward variance. 0.10 is the spec's first-pass dropout;
185
+ the post-SFT diversity preflight will bump to 0.15 if needed."""
186
  LORA_TARGET_MODULES: tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj")
187
 
188
+ # ---- SFT warmup phase (master spec, section 1; 2026-04 regularisation) -- #
189
+ # 2026-04 changes (diversity-preserving regularisation): SFT collapsed to
190
+ # a constant-output model under the prior settings (LR=2e-4 + dropout=0.05
191
+ # + max_steps=200 left every checkpoint at output_diversity=1). New
192
+ # defaults trade some ceiling LCR for diversity headroom so GRPO has a
193
+ # reward signal to climb.
194
  SFT_EPOCHS: int = 1
195
  SFT_BATCH_SIZE: int = 4
196
+ SFT_GRAD_ACCUM: int = 4 # effective batch = 16
197
+ SFT_LR: float = 1e-4
198
+ """Halved 2e-4 -> 1e-4 to slow the slide into mode collapse."""
199
+ SFT_LR_SCHEDULER: str = "constant_with_warmup" # 20-step warmup then constant
200
+ SFT_WARMUP_STEPS: int = 20
201
+ SFT_WEIGHT_DECAY: float = 0.01
202
+ SFT_LABEL_SMOOTHING: float = 0.05
203
+ """TrainingArguments.label_smoothing_factor; spreads the loss across
204
+ non-target tokens so the model is less rewarded for memorising the
205
+ single highest-likelihood completion."""
206
+ SFT_OPTIMIZER: str = "adamw_8bit"
207
+ SFT_DATASET_SIZE: int = 3_000 # 3,000 train + 100 held-out validation
208
+ SFT_VAL_HOLDOUT: int = 100
209
+ SFT_MAX_SEQ_LEN: int = 1024 # ~300 prompt + ~80 completion + headroom
210
+ SFT_MAX_STEPS: int = 50
211
+ """Cut 200 -> 50 so SFT stops well before the model can grind itself
212
+ into a single-output mode. The format-only knowledge fits in <50
213
+ steps and post-SFT diversity preflight is the gate to GRPO."""
214
+ SFT_EVAL_EVERY: int = 25 # legacy fallback if no schedule given
215
+ SFT_SAVE_EVERY: int = 25
216
+ SFT_LOG_EVERY: int = 10
217
+ SFT_PREFLIGHT_DIVERSITY_FLOOR: int = 2
218
+ """eval/output_diversity threshold. If two consecutive evals both report
219
+ output_diversity below this floor, the diversity-collapse early stop
220
+ fires and SFT exits with reason=diversity_collapse."""
221
+ SFT_DIVERSITY_COLLAPSE_RUN_LEN: int = 2
222
+ """Number of consecutive sub-floor evals required before stopping."""
223
+ SFT_MAX_NEW_TOKENS: int = 200 # generation cap during eval
224
+ # Was 128; bumped to 200 because Qwen2.5-Instruct's cold-start reasoning
225
+ # (### Analysis: 1. ... 2. ... 3. ...) regularly runs to 100+ tokens
226
+ # before reaching the format line in early SFT steps. With 128, every
227
+ # step-5 sample truncated mid-reasoning and format_compliance read 0.
228
+ # 200 gives ~70 tokens of headroom past a typical reasoning + format
229
+ # completion (~70 tokens total) so truncation never masks the model's
230
+ # real behaviour.
231
+
232
+ # --- Variable eval cadence ------------------------------------------------- #
233
+ # Early evals are quick sanity checks (small sample, format-only) so a
234
+ # broken parser / generation drift gets caught before ~10 min of compute is
235
+ # burned. Late evals are real measurements with the full sample size.
236
+ # Catching format-compliance failure at step 15 instead of step 50 saves
237
+ # ~7 minutes per fire.
238
+ #
239
+ # Each entry: (step, sample_size, mode) where mode is "format_only" or
240
+ # "full". format_only skips the diversity probe and the physics-heavy
241
+ # logical_correction / hamming / syndrome metrics, so the eval costs
242
+ # ~30 seconds instead of ~2 minutes.
243
+ SFT_EVAL_SCHEDULE: tuple[tuple[int, int, str], ...] = (
244
+ # 2026-04: schedule rebuilt to fit the SFT_MAX_STEPS=50 budget. Two
245
+ # full evals plus a fast format probe gives the diversity-collapse
246
+ # early-stop two consecutive data points before the run ends, which
247
+ # is the minimum to fire the new run-length-2 stop rule.
248
+ (5, 30, "format_only"),
249
+ (15, 50, "full"),
250
+ (25, 100, "full"),
251
+ (40, 100, "full"),
252
+ (50, 100, "full"),
253
+ )
254
+ SFT_PRINT_SAMPLE_OUTPUTS: int = 5 # raw outputs printed at each eval
255
+
256
+ # Early-stop thresholds (master spec, section 3).
257
+ SFT_EARLY_STOP_FORMAT: float = 0.95
258
+ SFT_EARLY_STOP_CORRECTION: float = 0.80
259
+ SFT_EARLY_STOP_DIVERSITY: int = 3
260
+ SFT_MAX_WALL_SECONDS: float = 30 * 60.0 # 30-minute hard ceiling
261
+
262
+ # HuggingFace Trainer subfolder (step-50 save) used to initialise GRPO.
263
+ # ``python -m scripts.train_grpo`` defaults to this path; pipeline scripts
264
+ # also pass it explicitly.
265
+ SFT_CHECKPOINT_PATH_FOR_GRPO: str = "checkpoints/sft_warmup/checkpoint-50"
266
+
267
+ # ---- GRPO RL phase (master spec, section 5; 2026-04 spec rewrite) -------- #
268
+ # All numbers below were re-pinned by the 2026-04 GRPO spec. The previous
269
+ # defaults (GRPO_STEPS=2000, LR=1e-5, KL=0.04, max_prompt=512,
270
+ # max_completion=256, temperature=0.7) produced a degenerate "always say
271
+ # []" policy in <100 steps because reward variance collapsed and KL
272
+ # saturated the loss. The new defaults emphasise diversity:
273
+ #
274
+ # - higher temperature (1.2) + top_k + repetition_penalty -> non-collapsed rollouts
275
+ # - shorter max_completion_length (50) -> the answer is one short line anyway
276
+ # - longer max_prompt_length (1500) -> distance-3 syndromes already use
277
+ # ~280 tokens; distance-5 / curriculum L3 needs the headroom
278
+ # - lower KL coefficient (0.02) -> reward signal not dominated by KL drift
279
+ # - 1500 steps -> wall-clock fits the 13h cap with margin
280
+ GRPO_STEPS: int = 1_500
281
+ GRPO_GEN_PER_PROMPT: int = 4 # GRPO needs >=2 for advantage
282
+ GRPO_BATCH_SIZE: int = 1 # per-device prompts per step
283
+ GRPO_GRAD_ACCUM: int = 8 # effective batch = 8 prompts
284
+ GRPO_LR: float = 2e-5 # bumped from 1e-5; reward signal is sparse
285
+ GRPO_LR_SCHEDULER: str = "constant" # no warmup, no decay
286
+ GRPO_KL_COEF: float = 0.02 # half the TRL default; alarm if KL > 0.3
287
+ GRPO_MAX_PROMPT_LEN: int = 1_500 # surface-code prompts can run long
288
+ GRPO_MAX_COMPLETION_LEN: int = 50 # answer is one line: X_ERRORS=[..] Z_ERRORS=[..]
289
+
290
+ # ---- Diversity-focused rollout sampling (critical) ----------------------- #
291
+ # These apply to GRPO ROLLOUT generation only. Eval uses temperature=0
292
+ # (greedy) regardless of these. The combination temperature=1.2 + top_p=0.95
293
+ # + top_k=50 + repetition_penalty=1.1 was selected because:
294
+ # * temperature=1.2 broadens the per-token distribution past the SFT
295
+ # mode-collapsed favourite ("X_ERRORS=[] Z_ERRORS=[]").
296
+ # * top_p=0.95 keeps tail tokens in but truncates the long tail.
297
+ # * top_k=50 caps the candidate set so we don't sample garbage.
298
+ # * repetition_penalty=1.1 discourages the model from repeating the
299
+ # exact same byte sequence within a 4-completion group (reduces
300
+ # "all 4 generations identical" rate, which kills GRPO's gradient).
301
+ GRPO_TEMPERATURE: float = 1.2
302
+ GRPO_TOP_P: float = 0.95
303
+ GRPO_TOP_K: int = 50
304
+ GRPO_REPETITION_PENALTY: float = 1.1
305
+ GRPO_DO_SAMPLE: bool = True
306
+
307
+ # ---- Checkpoint cadence + retention -------------------------------------- #
308
+ GRPO_CHECKPOINT_EVERY: int = 100
309
+ GRPO_SAVE_TOTAL_LIMIT: int = 3 # keep 3 most recent rolling checkpoints
310
+ GRPO_LOG_EVERY: int = 5 # real-time visibility (every 5 steps)
311
+ GRPO_OPTIMIZER: str = "adamw_8bit"
312
+ GRPO_KL_ALARM: float = 0.3 # >this triggers manual triage
313
+ GRPO_KL_HARD_CEIL: float = 0.5 # >this -> kill the run
314
+
315
+ # ---- Wall-clock safety --------------------------------------------------- #
316
+ GRPO_WALL_SECONDS: float = 46_800.0 # 13 hours. Save+exit if exceeded.
317
+
318
+ # ---- Frozen eval set ----------------------------------------------------- #
319
+ # The 200-syndrome eval set is regenerated from the env at GRPO start with
320
+ # this seed. Same seed as SFT validation (sft_validation.jsonl) so eval
321
+ # distributions are comparable across SFT and GRPO. The set is cached on
322
+ # disk under data/grpo_validation.jsonl so reruns hit identical syndromes.
323
+ GRPO_VAL_SEED: int = 4_284
324
+ GRPO_VAL_EPISODES: int = 200
325
+ GRPO_VAL_PATH: str = "data/grpo_validation.jsonl"
326
+
327
+ # ---- Sample-table logging ------------------------------------------------ #
328
+ GRPO_SAMPLE_LOG_EVERY: int = 50
329
+ GRPO_SAMPLE_LOG_N: int = 5
330
+
331
+ # ---- Anti-hacking: mode-collapse inspection hook ------------------------- #
332
+ # Every N steps, we sample the most-recent N rollouts and check what
333
+ # fraction of prompts had ALL 4 generations identical. If too many
334
+ # prompts collapsed, raise the rollout temperature by a fixed step.
335
+ GRPO_INSPECTION_HOOK_EVERY: int = 100
336
+ GRPO_INSPECTION_SAMPLE_N: int = 10
337
+ GRPO_INSPECTION_COLLAPSE_THRESHOLD: int = 7 # "> 7 of 10"
338
+ GRPO_TEMP_BUMP_ON_COLLAPSE: float = 0.2
339
+
340
+ # ---- Decision-rule thresholds (warnings only; no auto-action) ----------- #
341
+ GRPO_DECISION_REWARD_STD_FLOOR: float = 0.03
342
+ GRPO_DECISION_REWARD_STD_CHECK_STEP: int = 50
343
+ GRPO_DECISION_BEAT_RATE_CHECK_STEP: int = 500
344
+ GRPO_DECISION_FORMAT_FLOOR: float = 0.95
345
+ GRPO_DECISION_GRAD_NORM_CEIL: float = 50.0
346
+ GRPO_DECISION_GRAD_NORM_RUN_LEN: int = 3 # consecutive logs
347
 
348
  # Decoding sampler defaults at evaluation/format-test time.
349
+ # (Used by greedy eval paths: temp/top_p only matter when do_sample=True.)
350
  SAMPLE_TEMPERATURE: float = 0.7
351
  SAMPLE_TOP_P: float = 0.95
352
 
 
369
  # all log to the same project / dashboard. Override per-run on the CLI.
370
  import os as _os # noqa: E402 (local import to keep top of module clean)
371
 
372
+ WANDB_PROJECT: str = _os.environ.get("WANDB_PROJECT", "QuantumScribe-GRPO")
373
+ """Default W&B project name. Override with ``WANDB_PROJECT=...``.
374
+
375
+ Changed 2026-04 from ``"QuantumScribe"`` to ``"QuantumScribe-GRPO"`` per
376
+ the GRPO spec rewrite. SFT runs that should land in the original project
377
+ should set ``WANDB_PROJECT=QuantumScribe`` at the shell."""
378
 
379
+ WANDB_ENTITY: str | None = _os.environ.get("WANDB_ENTITY", "ronitraj") or None
380
  """W&B team or username. ``None`` -> wandb's default entity for the user."""
381
 
382
  WANDB_DEFAULT_TAGS: tuple[str, ...] = (
 
389
  """Tags applied to every W&B run (per-script tags appended on top)."""
390
 
391
  WANDB_LOG_GENERATIONS_EVERY: int = 50
392
+ """Log a sample-completion table every N GRPO steps (master spec sec. 7)."""
393
 
394
+ WANDB_SAMPLE_GENERATIONS: int = 5
395
+ """Number of generations included in each sample-completion table.
396
+ Master spec, section 7: 'Save 5 randomly sampled rollouts ... and their rewards.'"""
397
 
398
+ WANDB_INLOOP_EVAL_EVERY: int = 100
399
  """Run an in-loop evaluation pass (deterministic, ``WANDB_INLOOP_EVAL_EPISODES``
400
+ syndromes) every N GRPO steps. Tightened from 250 -> 100 by the 2026-04 GRPO
401
+ spec rewrite so collapse / drift gets caught within a 5-minute window
402
+ instead of a 15-minute window."""
403
+
404
+ WANDB_INLOOP_EVAL_EPISODES: int = 200
405
+ """Held-out syndromes per in-loop eval pass. Bumped from 100 -> 200 by the
406
+ 2026-04 spec rewrite so eval-stat error bars are tight enough to read
407
+ pymatching_beat_rate movement (which is sub-5% in early training)."""
408
 
409
+ WANDB_COMPARE_EVERY: int = 500
410
+ """Run the PyMatching head-to-head comparison every N steps (master spec sec. 7)."""
411
 
412
 
413
  # --------------------------------------------------------------------------- #
qubit_medic/prompts.py CHANGED
@@ -1,19 +1,23 @@
1
- """Prompt formatter and action parser (Section 2.3 + Section 2.5 of the plan).
2
-
3
- The prompt is engineered around five sections:
4
-
5
- 1. Role declaration
6
- 2. Physics summary (~50 tokens, plain English)
7
- 3. Syndrome data (round-by-round, labelled)
8
- 4. Output format spec (one example included)
9
- 5. Reasoning trigger ("think step by step ...")
10
-
11
- Total budget ~250-300 tokens for the prompt; ~150 for the response.
12
-
13
- The parser is deliberately permissive on whitespace and bracket style but
14
- strict on the existence of the two key tokens ``X_ERRORS`` and ``Z_ERRORS``.
15
- A partial-credit hook is exposed so Reward 4 can hand out 0.5 for "partly
16
- parseable".
 
 
 
 
17
  """
18
  from __future__ import annotations
19
 
@@ -23,33 +27,30 @@ from typing import Iterable
23
 
24
 
25
  # --------------------------------------------------------------------------- #
26
- # Prompt formatting #
27
  # --------------------------------------------------------------------------- #
28
 
29
- _ROLE = (
30
- "You are a quantum error-correction decoder. You are decoding errors in "
31
- "a distance-{distance} rotated surface code memory experiment."
32
- )
33
 
34
- _PHYSICS_SUMMARY = (
35
- "Stabilizers are parity checks measured every round. A *syndrome bit* "
36
- "is 1 when a stabilizer's measurement disagrees with its previous round, "
37
- "indicating a nearby physical error. Your job is to look at the syndrome "
38
- "history and output the smallest physical error pattern (X-flips and "
39
- "Z-flips on data qubits, identified by integer IDs) that explains it."
40
- )
41
 
42
- _OUTPUT_SPEC = (
43
- "Output format (REQUIRED, exact):\n"
44
- " X_ERRORS=[id1,id2,...] Z_ERRORS=[id1,id2,...]\n"
45
- "Use empty lists when no errors of that type. Example with no errors:\n"
46
- " X_ERRORS=[] Z_ERRORS=[]"
47
- )
48
 
49
- _REASONING_TRIGGER = (
50
- "Think step by step about which qubits could have caused this syndrome, "
51
- "then output your prediction in the required format."
52
- )
 
 
 
 
 
 
 
53
 
54
 
55
  def format_syndrome_block(
@@ -58,19 +59,33 @@ def format_syndrome_block(
58
  num_x_stabilizers: int,
59
  num_z_stabilizers: int,
60
  ) -> str:
61
- """Render the detector activations round-by-round.
62
-
63
- Stim emits detectors in a flat row-major order: round 0 stabilisers first,
64
- then round 1, and so on. We label by round and stabiliser type so the LLM
65
- can read the temporal structure.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  """
67
  bits = list(syndrome_bits)
68
  per_round = num_x_stabilizers + num_z_stabilizers
69
- lines = ["Syndrome (round-by-round):"]
70
  if per_round == 0 or rounds == 0 or len(bits) == 0:
71
- lines.append(" (no detectors fired)")
72
- return "\n".join(lines)
73
 
 
74
  for r in range(rounds):
75
  offset = r * per_round
76
  if offset >= len(bits):
@@ -79,18 +94,15 @@ def format_syndrome_block(
79
  x_chunk = chunk[:num_x_stabilizers]
80
  z_chunk = chunk[num_x_stabilizers : num_x_stabilizers + num_z_stabilizers]
81
  lines.append(
82
- f" Round {r + 1} X-stabilizers: "
83
- + " ".join(str(b) for b in x_chunk)
84
  )
85
  lines.append(
86
- f" Round {r + 1} Z-stabilizers: "
87
- + " ".join(str(b) for b in z_chunk)
88
  )
89
- # Trailing block for the final destructive measurement, if any extras.
90
  used = rounds * per_round
91
  if used < len(bits):
92
  tail = bits[used:]
93
- lines.append(" Final-round detectors: " + " ".join(str(b) for b in tail))
94
  return "\n".join(lines)
95
 
96
 
@@ -104,10 +116,10 @@ def build_prompt(
104
  num_z_stabilizers: int,
105
  num_data_qubits: int,
106
  ) -> str:
107
- """Assemble the full prompt the LLM sees on each step.
108
 
109
- Keeping this function pure (no I/O, no globals) means the SFT pipeline
110
- and the GRPO rollout use byte-identical inputs - a critical invariant.
111
  """
112
  syndrome_block = format_syndrome_block(
113
  syndrome_bits=syndrome_bits,
@@ -115,27 +127,53 @@ def build_prompt(
115
  num_x_stabilizers=num_x_stabilizers,
116
  num_z_stabilizers=num_z_stabilizers,
117
  )
118
- return (
119
- _ROLE.format(distance=distance)
120
- + "\n\n"
121
- + _PHYSICS_SUMMARY
122
- + "\n\n"
123
- + f"Code parameters: distance={distance}, rounds={rounds}, "
124
- + f"physical_error_rate={p:g}, data_qubits=0..{num_data_qubits - 1}.\n\n"
125
- + syndrome_block
126
- + "\n\n"
127
- + _OUTPUT_SPEC
128
- + "\n\n"
129
- + _REASONING_TRIGGER
130
  )
131
 
132
 
133
  # --------------------------------------------------------------------------- #
134
- # Output parsing #
135
  # --------------------------------------------------------------------------- #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
- _X_PATTERN = re.compile(r"X_ERRORS\s*=\s*\[([^\]]*)\]", re.IGNORECASE)
138
- _Z_PATTERN = re.compile(r"Z_ERRORS\s*=\s*\[([^\]]*)\]", re.IGNORECASE)
 
139
 
140
 
141
  @dataclass(frozen=True)
@@ -145,13 +183,19 @@ class ParseResult:
145
  parse_success: bool # True iff BOTH X_ERRORS and Z_ERRORS parsed cleanly
146
  parse_partial: bool # True iff exactly one of the two parsed cleanly
147
  raw_response: str
 
148
 
149
  @property
150
  def format_score(self) -> float:
151
- """Score for Reward 4 (format compliance)."""
152
- if self.parse_success:
 
 
 
 
 
153
  return 1.0
154
- if self.parse_partial:
155
  return 0.5
156
  return 0.0
157
 
@@ -160,6 +204,8 @@ def _parse_int_list(s: str, max_qubit: int) -> tuple[list[int], bool]:
160
  """Parse a comma/space-separated integer list. Drops out-of-range and dups.
161
 
162
  Returns ``(qubits_sorted_unique, all_tokens_were_valid)``.
 
 
163
  """
164
  if not s.strip():
165
  return [], True
@@ -182,25 +228,77 @@ def _parse_int_list(s: str, max_qubit: int) -> tuple[list[int], bool]:
182
 
183
 
184
  def parse_action(raw_response: str, num_data_qubits: int) -> ParseResult:
185
- """Convert the LLM's raw text to a ``ParseResult``.
186
-
187
- Tolerant of trailing chain-of-thought, surrounding code fences, and
188
- casing, but strict on the existence of both ``X_ERRORS`` and ``Z_ERRORS``.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  """
190
  if not isinstance(raw_response, str):
191
- return ParseResult([], [], False, False, raw_response="")
192
-
193
- # If the model wrapped its answer in ```...``` blocks, focus on the last one.
194
- fenced = re.findall(r"```(?:[^\n]*)\n(.*?)```", raw_response, re.DOTALL)
195
- search_text = fenced[-1] if fenced else raw_response
196
-
197
- x_match = _X_PATTERN.search(search_text)
198
- z_match = _Z_PATTERN.search(search_text)
199
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  x_errors: list[int] = []
201
  z_errors: list[int] = []
202
  x_clean = z_clean = False
203
-
204
  if x_match is not None:
205
  x_errors, x_clean = _parse_int_list(x_match.group(1), num_data_qubits)
206
  if z_match is not None:
@@ -214,12 +312,17 @@ def parse_action(raw_response: str, num_data_qubits: int) -> ParseResult:
214
  (x_match is not None and z_match is not None) and not parse_success
215
  )
216
 
 
 
 
 
217
  return ParseResult(
218
  x_errors=x_errors,
219
  z_errors=z_errors,
220
  parse_success=parse_success,
221
  parse_partial=parse_partial,
222
  raw_response=raw_response,
 
223
  )
224
 
225
 
 
1
+ """Locked prompt template + parser (master spec, sections 4 + parser).
2
+
3
+ This module is the *single source of truth* for what the LLM sees during
4
+ SFT and GRPO. The exact wording is fixed: anything that drifts the prompt
5
+ between phases throws away the SFT investment because RL builds on the
6
+ format SFT taught.
7
+
8
+ Spec sections honoured:
9
+ * Section 4 - "The exact prompt template (locked, for both SFT and RL)"
10
+ * Section 4 - "The {syndrome_block} format" (round-by-round, X first then Z)
11
+ * Section 4 - "The parser specification (critical)"
12
+
13
+ Parser highlights
14
+ -----------------
15
+ * Case-insensitive on ``X_ERRORS``/``Z_ERRORS`` keys.
16
+ * Tolerant of trailing chain-of-thought, code fences, and whitespace.
17
+ * **Takes the LAST occurrence** of ``X_ERRORS`` so the literal example
18
+ inside the prompt's "Examples:" block is never confused for the answer.
19
+ * Validates each id against ``[0, max_qubit_id]`` and dedups within a list.
20
+ * Returns a partial-credit score (1.0 / 0.5 / 0.0) for Reward 4.
21
  """
22
  from __future__ import annotations
23
 
 
27
 
28
 
29
  # --------------------------------------------------------------------------- #
30
+ # Prompt template (LOCKED - see master spec, section 4) #
31
  # --------------------------------------------------------------------------- #
32
 
33
+ _PROMPT_TEMPLATE = """You are an expert quantum error correction decoder. Your job is to identify which data qubits experienced errors based on syndrome measurements.
 
 
 
34
 
35
+ A surface code protects 1 logical qubit using {num_data_qubits} data qubits arranged in a {distance}x{distance} grid. Stabilizer measurements detect errors: a '1' means that stabilizer fired (detected something wrong nearby); a '0' means it looks fine. Errors must be deduced from the pattern of stabilizers that fired.
 
 
 
 
 
 
36
 
37
+ Code distance: {distance}
38
+ Number of stabilizer rounds: {rounds}
39
+ Physical error rate: {p}
40
+ X-stabilizer count per round: {num_x_stabilizers}
41
+ Z-stabilizer count per round: {num_z_stabilizers}
 
42
 
43
+ {syndrome_block}
44
+
45
+ Identify which data qubits (numbered 0-{max_qubit_id}) had X-errors and Z-errors. Most syndromes have 0-2 errors; an empty list means no errors of that type.
46
+
47
+ Output exactly ONE line and nothing else. Do not write reasoning, markdown, bullets, analysis, or explanations. Your entire response must match this exact format:
48
+ X_ERRORS=[qubit_ids] Z_ERRORS=[qubit_ids]
49
+
50
+ Valid one-line examples:
51
+ X_ERRORS=[] Z_ERRORS=[]
52
+ X_ERRORS=[] Z_ERRORS=[4]
53
+ X_ERRORS=[2] Z_ERRORS=[5,6]"""
54
 
55
 
56
  def format_syndrome_block(
 
59
  num_x_stabilizers: int,
60
  num_z_stabilizers: int,
61
  ) -> str:
62
+ """Render detector activations round-by-round, exactly per the spec.
63
+
64
+ Format example for distance-3, rounds=3:
65
+
66
+ Round 1 X-stabilizers: 0 0 1 0
67
+ Round 1 Z-stabilizers: 0 0 0 0
68
+ Round 2 X-stabilizers: 0 0 1 0
69
+ Round 2 Z-stabilizers: 0 0 0 0
70
+ Round 3 X-stabilizers: 0 0 0 0
71
+ Round 3 Z-stabilizers: 0 0 0 0
72
+
73
+ Every round on its own line, X first then Z, space-separated bits, no
74
+ indent, no commas. Rounds are always emitted in full even when all
75
+ bits are zero so the LLM sees consistent shape.
76
+
77
+ Stim's detector layout for the rotated-memory experiment is row-major:
78
+ round 0 stabilizers first, then round 1, and so on. For each round it
79
+ interleaves the per-type detectors in the order Stim's circuit was
80
+ generated (we treat the first ``num_x_stabilizers`` per round as X
81
+ and the rest as Z, matching ``per_round_x_z_counts``).
82
  """
83
  bits = list(syndrome_bits)
84
  per_round = num_x_stabilizers + num_z_stabilizers
 
85
  if per_round == 0 or rounds == 0 or len(bits) == 0:
86
+ return "(no detectors fired)"
 
87
 
88
+ lines: list[str] = []
89
  for r in range(rounds):
90
  offset = r * per_round
91
  if offset >= len(bits):
 
94
  x_chunk = chunk[:num_x_stabilizers]
95
  z_chunk = chunk[num_x_stabilizers : num_x_stabilizers + num_z_stabilizers]
96
  lines.append(
97
+ f"Round {r + 1} X-stabilizers: " + " ".join(str(int(b)) for b in x_chunk)
 
98
  )
99
  lines.append(
100
+ f"Round {r + 1} Z-stabilizers: " + " ".join(str(int(b)) for b in z_chunk)
 
101
  )
 
102
  used = rounds * per_round
103
  if used < len(bits):
104
  tail = bits[used:]
105
+ lines.append("Final-round detectors: " + " ".join(str(int(b)) for b in tail))
106
  return "\n".join(lines)
107
 
108
 
 
116
  num_z_stabilizers: int,
117
  num_data_qubits: int,
118
  ) -> str:
119
+ """Assemble the locked prompt the LLM sees on each step.
120
 
121
+ Pure function (no I/O, no globals) so the SFT pipeline and GRPO
122
+ rollout produce byte-identical prompt strings - a critical invariant.
123
  """
124
  syndrome_block = format_syndrome_block(
125
  syndrome_bits=syndrome_bits,
 
127
  num_x_stabilizers=num_x_stabilizers,
128
  num_z_stabilizers=num_z_stabilizers,
129
  )
130
+ return _PROMPT_TEMPLATE.format(
131
+ num_data_qubits=num_data_qubits,
132
+ distance=distance,
133
+ rounds=rounds,
134
+ p=p,
135
+ num_x_stabilizers=num_x_stabilizers,
136
+ num_z_stabilizers=num_z_stabilizers,
137
+ syndrome_block=syndrome_block,
138
+ max_qubit_id=num_data_qubits - 1,
 
 
 
139
  )
140
 
141
 
142
  # --------------------------------------------------------------------------- #
143
+ # Output parsing (LOCKED - see master spec, section 4 "Parser specification") #
144
  # --------------------------------------------------------------------------- #
145
+ #
146
+ # Two-tier parser:
147
+ # * STRICT - canonical "X_ERRORS=[...] Z_ERRORS=[...]". Only this form
148
+ # scores 1.0 on Reward 4 (format_compliance), so the GRPO signal still
149
+ # pushes the model toward the locked spec wording.
150
+ # * LENIENT - also accepts ":" instead of "=", "()" instead of "[]",
151
+ # "X-ERRORS" / "X ERRORS" key spellings, and tolerates
152
+ # \boxed{...} / **...** wrapping. Used so eval/metrics see
153
+ # the model's actual *answer* whenever it is extractable,
154
+ # instead of silently treating parse failures as
155
+ # "predict no errors" (which hides the bug at p=0.001 where
156
+ # ~95% of syndromes are trivial and an empty prediction is
157
+ # accidentally correct).
158
+
159
+ # Strict canonical form: "=" + "[]" - required for Reward 4 = 1.0.
160
+ _X_PATTERN_STRICT = re.compile(r"X_ERRORS\s*=\s*\[([^\]]*)\]", re.IGNORECASE)
161
+ _Z_PATTERN_STRICT = re.compile(r"Z_ERRORS\s*=\s*\[([^\]]*)\]", re.IGNORECASE)
162
+
163
+ # Lenient form: "=" or ":" separator, "[]" or "()" brackets, and the key may
164
+ # be spelt "X_ERRORS" / "X-ERRORS" / "X ERRORS" / "XERRORS".
165
+ _X_PATTERN_LENIENT = re.compile(
166
+ r"X[\s_\-]*ERRORS\s*[=:]\s*[\[\(]([^\]\)]*)[\]\)]",
167
+ re.IGNORECASE,
168
+ )
169
+ _Z_PATTERN_LENIENT = re.compile(
170
+ r"Z[\s_\-]*ERRORS\s*[=:]\s*[\[\(]([^\]\)]*)[\]\)]",
171
+ re.IGNORECASE,
172
+ )
173
 
174
+ # Key locator (lenient) - finds where any X-errors keyword starts so we can
175
+ # slice past in-prompt examples and home in on the model's actual answer.
176
+ _X_KEY = re.compile(r"X[\s_\-]*ERRORS", re.IGNORECASE)
177
 
178
 
179
  @dataclass(frozen=True)
 
183
  parse_success: bool # True iff BOTH X_ERRORS and Z_ERRORS parsed cleanly
184
  parse_partial: bool # True iff exactly one of the two parsed cleanly
185
  raw_response: str
186
+ strict_format: bool = False # True iff matched the canonical "=" + "[]" form
187
 
188
  @property
189
  def format_score(self) -> float:
190
+ """Score for Reward 4 (format compliance).
191
+
192
+ Only the canonical strict form earns 1.0, so the GRPO reward stays
193
+ anchored to the locked spec wording. Lenient parses or partials
194
+ score 0.5; total miss scores 0.0.
195
+ """
196
+ if self.parse_success and self.strict_format:
197
  return 1.0
198
+ if self.parse_success or self.parse_partial:
199
  return 0.5
200
  return 0.0
201
 
 
204
  """Parse a comma/space-separated integer list. Drops out-of-range and dups.
205
 
206
  Returns ``(qubits_sorted_unique, all_tokens_were_valid)``.
207
+ A token is "invalid" if it isn't an integer or falls outside ``[0, max_qubit)``.
208
+ Duplicates within a list count as silently de-duped, not invalid.
209
  """
210
  if not s.strip():
211
  return [], True
 
228
 
229
 
230
  def parse_action(raw_response: str, num_data_qubits: int) -> ParseResult:
231
+ """Convert the LLM's raw text to a :class:`ParseResult`.
232
+
233
+ Two-pass algorithm:
234
+ 1. Receive the full model response string; normalise common LaTeX/
235
+ markdown wrappers (``\\boxed{...}``, ``**bold**``).
236
+ 2. If the model wrapped output in fenced code blocks, focus on the
237
+ LAST fenced block.
238
+ 3. Locate all X-errors keys; slice forward from the LAST one (so the
239
+ example block in the prompt never wins).
240
+ 4. Try the STRICT pattern (``X_ERRORS=[...]``) first. If both X and Z
241
+ lists match, ``strict_format=True``.
242
+ 5. Otherwise try the LENIENT pattern (``=`` or ``:``, ``[]`` or ``()``)
243
+ so a near-miss like ``X_ERRORS: [1]`` still surfaces the model's
244
+ intended prediction.
245
+ 6. Validate every parsed integer is in ``[0, max_qubit_id]``; reject
246
+ duplicates within a list.
247
+ 7. ``parse_success`` requires BOTH lists to parse cleanly;
248
+ ``parse_partial`` is set when exactly one parsed cleanly OR both
249
+ keys appear but tokens were dirty.
250
+
251
+ The lenient fallback exists for *eval/diagnostic honesty*, not to
252
+ weaken the training signal: ``format_score`` (Reward 4) only returns
253
+ 1.0 when ``strict_format`` is also True.
254
  """
255
  if not isinstance(raw_response, str):
256
+ return ParseResult([], [], False, False, raw_response="", strict_format=False)
257
+
258
+ # 1: normalise common wrappers so the regex sees the inner content.
259
+ normalised = raw_response
260
+ # Strip \boxed{...} (LaTeX) - keep inner text.
261
+ normalised = re.sub(r"\\boxed\{([^{}]*)\}", r"\1", normalised)
262
+ # Strip surrounding **bold** / *italic* markers around the format block.
263
+ normalised = re.sub(r"\*+([A-Za-z_][^*]{0,40})\*+", r"\1", normalised)
264
+
265
+ # 2: fence handling - prefer last fenced block if present.
266
+ fenced = re.findall(r"```(?:[^\n]*)\n(.*?)```", normalised, re.DOTALL)
267
+ search_text = fenced[-1] if fenced else normalised
268
+
269
+ # 3: find the LAST X-errors key occurrence.
270
+ x_keys = list(_X_KEY.finditer(search_text))
271
+ if x_keys:
272
+ last_x_pos = x_keys[-1].start()
273
+ slice_text = search_text[last_x_pos:]
274
+ # If the last key has no payload (truncated), fall back one.
275
+ if (
276
+ not _X_PATTERN_STRICT.search(slice_text)
277
+ and not _X_PATTERN_LENIENT.search(slice_text)
278
+ and len(x_keys) > 1
279
+ ):
280
+ last_x_pos = x_keys[-2].start()
281
+ slice_text = search_text[last_x_pos:]
282
+ else:
283
+ slice_text = search_text
284
+
285
+ # 4-5: try strict, then lenient.
286
+ x_match = _X_PATTERN_STRICT.search(slice_text)
287
+ z_matches_strict = list(_Z_PATTERN_STRICT.finditer(slice_text))
288
+ z_match = z_matches_strict[-1] if z_matches_strict else None
289
+ strict_x = x_match is not None
290
+ strict_z = z_match is not None
291
+
292
+ if x_match is None:
293
+ x_match = _X_PATTERN_LENIENT.search(slice_text)
294
+ if z_match is None:
295
+ z_matches_lenient = list(_Z_PATTERN_LENIENT.finditer(slice_text))
296
+ z_match = z_matches_lenient[-1] if z_matches_lenient else None
297
+
298
+ # 6: extract + validate qubit IDs.
299
  x_errors: list[int] = []
300
  z_errors: list[int] = []
301
  x_clean = z_clean = False
 
302
  if x_match is not None:
303
  x_errors, x_clean = _parse_int_list(x_match.group(1), num_data_qubits)
304
  if z_match is not None:
 
312
  (x_match is not None and z_match is not None) and not parse_success
313
  )
314
 
315
+ # strict_format is true only when BOTH X and Z hit the canonical pattern
316
+ # cleanly (no garbage tokens, no out-of-range qubits).
317
+ strict_format = bool(strict_x and strict_z and parse_success)
318
+
319
  return ParseResult(
320
  x_errors=x_errors,
321
  z_errors=z_errors,
322
  parse_success=parse_success,
323
  parse_partial=parse_partial,
324
  raw_response=raw_response,
325
+ strict_format=strict_format,
326
  )
327
 
328
 
qubit_medic/server/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (393 Bytes). View file
 
qubit_medic/server/__pycache__/app.cpython-312.pyc ADDED
Binary file (9.51 kB). View file
 
qubit_medic/server/__pycache__/curriculum.cpython-312.pyc ADDED
Binary file (5.55 kB). View file
 
qubit_medic/server/__pycache__/environment.cpython-312.pyc ADDED
Binary file (14.2 kB). View file
 
qubit_medic/server/__pycache__/openenv_adapter.cpython-312.pyc ADDED
Binary file (12.4 kB). View file
 
qubit_medic/server/__pycache__/physics.cpython-312.pyc ADDED
Binary file (19.9 kB). View file
 
qubit_medic/server/__pycache__/rewards.cpython-312.pyc ADDED
Binary file (12.7 kB). View file
 
qubit_medic/server/app.py CHANGED
@@ -6,6 +6,8 @@ routes (``/reset``, ``/step``, ``/state``, ``/health``, ``/schema``,
6
 
7
  We add a few extras on top:
8
 
 
 
9
  * ``GET /healthz`` - the Day-0 deployment-substrate liveness probe
10
  (returns Stim/PyMatching/openenv versions). Used by the recurring
11
  4-hour HF Spaces wakeup ping.
@@ -24,6 +26,7 @@ import sys
24
  from typing import Optional
25
 
26
  from fastapi import Body, HTTPException
 
27
  from openenv.core import create_fastapi_app
28
 
29
  from qubit_medic.config import DEFAULT_HOST, DEFAULT_PORT
@@ -60,6 +63,44 @@ app.description = (
60
  )
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  # --------------------------------------------------------------------------- #
64
  # Day-0 + demo extras #
65
  # --------------------------------------------------------------------------- #
@@ -79,6 +120,41 @@ def _get_legacy_env() -> DecoderEnvironment:
79
  return _legacy_env
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  @app.get("/healthz")
83
  def healthz() -> dict:
84
  """Lightweight liveness probe (Day-0 deployment-substrate test).
 
6
 
7
  We add a few extras on top:
8
 
9
+ * ``GET /`` - HTML landing page (HF Spaces **App** tab); links to
10
+ ``/docs``, ``/healthz``, ``/metadata`` (avoids 404 on the root URL).
11
  * ``GET /healthz`` - the Day-0 deployment-substrate liveness probe
12
  (returns Stim/PyMatching/openenv versions). Used by the recurring
13
  4-hour HF Spaces wakeup ping.
 
26
  from typing import Optional
27
 
28
  from fastapi import Body, HTTPException
29
+ from fastapi.responses import HTMLResponse
30
  from openenv.core import create_fastapi_app
31
 
32
  from qubit_medic.config import DEFAULT_HOST, DEFAULT_PORT
 
63
  )
64
 
65
 
66
+ @app.get("/", response_class=HTMLResponse, include_in_schema=False)
67
+ def root() -> str:
68
+ """Space + browser landing page (HF opens ``/`` in the App tab).
69
+
70
+ The OpenEnv API lives under ``/reset``, ``/step``, etc.; there was no
71
+ root handler, so visitors saw 404. This page links to docs and health.
72
+ """
73
+ return """<!DOCTYPE html>
74
+ <html lang="en">
75
+ <head>
76
+ <meta charset="utf-8"/>
77
+ <meta name="viewport" content="width=device-width, initial-scale=1"/>
78
+ <title>Qubit-Medic OpenEnv</title>
79
+ <style>
80
+ body { font-family: system-ui, sans-serif; max-width: 40rem; margin: 2rem auto; padding: 0 1rem; line-height: 1.5; color: #1e293b; }
81
+ h1 { font-size: 1.5rem; }
82
+ ul { padding-left: 1.2rem; }
83
+ a { color: #2563eb; }
84
+ code { background: #f1f5f9; padding: 0.1em 0.3em; border-radius: 4px; }
85
+ </style>
86
+ </head>
87
+ <body>
88
+ <h1>Qubit-Medic — OpenEnv server</h1>
89
+ <p>This Space exposes a <strong>JSON API</strong> for the quantum error-decoding
90
+ environment (Stim + PyMatching, OpenEnv contract). There is no full-page
91
+ Gradio UI here; use the links below.</p>
92
+ <ul>
93
+ <li><a href="/docs">Interactive API docs (Swagger)</a></li>
94
+ <li><a href="/redoc">ReDoc</a></li>
95
+ <li><a href="/healthz">Liveness <code>GET /healthz</code></a> — versions probe</li>
96
+ <li><a href="/metadata">OpenEnv <code>GET /metadata</code></a></li>
97
+ </ul>
98
+ <p>Typical flow: <code>POST /reset</code> then <code>POST /step</code> with
99
+ the model&rsquo;s text action — see the schema in <code>/docs</code>.</p>
100
+ </body>
101
+ </html>"""
102
+
103
+
104
  # --------------------------------------------------------------------------- #
105
  # Day-0 + demo extras #
106
  # --------------------------------------------------------------------------- #
 
120
  return _legacy_env
121
 
122
 
123
+ # --------------------------------------------------------------------------- #
124
+ # Compliance Section 2 (audit 2026-04): POST /state and POST /close. #
125
+ # --------------------------------------------------------------------------- #
126
+ # OpenEnv's create_fastapi_app already mounts GET /state and (via the
127
+ # canonical contract) does not expose /close at all. The participant-guide
128
+ # audit explicitly requires POST /state and POST /close, so we surface
129
+ # both as additional routes that delegate to the legacy DecoderEnvironment
130
+ # singleton (the same one /decode already uses). The OpenEnv-canonical
131
+ # GET /state route is preserved untouched.
132
+ # --------------------------------------------------------------------------- #
133
+
134
+
135
+ @app.post("/state")
136
+ def post_state() -> dict:
137
+ """POST mirror of the OpenEnv GET /state route.
138
+
139
+ Returns a JSON-serialisable snapshot of env state. Uses the inner
140
+ :meth:`DecoderEnvironment.state` (added in Section 1 compliance work)
141
+ which excludes ground-truth fields by construction.
142
+ """
143
+ return _get_legacy_env().state()
144
+
145
+
146
+ @app.post("/close")
147
+ def post_close() -> dict:
148
+ """POST /close: drop in-flight episodes on the legacy env singleton.
149
+
150
+ The singleton is rebuilt lazily on the next /reset, so calling /close
151
+ repeatedly is idempotent. Returns a small JSON dict so the caller can
152
+ confirm the request landed.
153
+ """
154
+ _get_legacy_env().close()
155
+ return {"ok": True, "closed": True}
156
+
157
+
158
  @app.get("/healthz")
159
  def healthz() -> dict:
160
  """Lightweight liveness probe (Day-0 deployment-substrate test).
qubit_medic/server/environment.py CHANGED
@@ -201,9 +201,17 @@ class DecoderEnvironment:
201
  with self._lock:
202
  episode = self._active.pop(episode_id, None)
203
  if episode is None:
204
- # Calling step() on an unknown episode ID is a hard error -
205
- # the trainer didn't follow reset/step pairing.
206
- raise KeyError(f"unknown or already-finished episode {episode_id}")
 
 
 
 
 
 
 
 
207
 
208
  elapsed = time.monotonic() - episode.started_at
209
  timed_out = elapsed > EPISODE_TIMEOUT_SECONDS
@@ -312,3 +320,36 @@ class DecoderEnvironment:
312
  "curriculum": self._scheduler.stats(),
313
  "cached_levels": list(self._caches.keys()),
314
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  with self._lock:
202
  episode = self._active.pop(episode_id, None)
203
  if episode is None:
204
+ # Calling step() on an unknown episode ID is a clean
205
+ # ValueError (compliance Section 1 of the participant-guide
206
+ # audit: the env must "raise a clean ValueError, not a
207
+ # Python traceback"). The trainer didn't follow reset/step
208
+ # pairing, or the episode already ended; either way we
209
+ # surface a typed exception so the FastAPI layer can turn
210
+ # it into a 400 response instead of a 500.
211
+ raise ValueError(
212
+ f"unknown or already-finished episode {episode_id}; "
213
+ f"call reset() before step()."
214
+ )
215
 
216
  elapsed = time.monotonic() - episode.started_at
217
  timed_out = elapsed > EPISODE_TIMEOUT_SECONDS
 
320
  "curriculum": self._scheduler.stats(),
321
  "cached_levels": list(self._caches.keys()),
322
  }
323
+
324
+ def state(self) -> dict:
325
+ """Return a JSON-serialisable snapshot of the env's externally-
326
+ visible state (compliance Section 1 of the participant-guide
327
+ audit: ``state()`` returns a JSON-serialisable object, not a raw
328
+ Python object).
329
+
330
+ Crucially this never includes the ground-truth fields stored on
331
+ the per-episode :class:`DecoderState` (true error patterns,
332
+ actual_observable_flip, pymatching_observable_pred, circuit_text,
333
+ dem_text). Those stay in ``self._active[ep].state`` and are only
334
+ consumed by the reward functions.
335
+ """
336
+ with self._lock:
337
+ return {
338
+ "episodes_started": int(self._episode_counter),
339
+ "active_episodes": int(len(self._active)),
340
+ "active_episode_ids": [int(ep) for ep in self._active.keys()],
341
+ "cached_levels": list(self._caches.keys()),
342
+ "curriculum": self._scheduler.stats(),
343
+ "base_seed": int(self._base_seed),
344
+ }
345
+
346
+ def close(self) -> None:
347
+ """Drop any in-flight episodes and clear caches.
348
+
349
+ Compliance Section 1: the gym-style API requires ``close()``.
350
+ After ``close()`` the env can still be re-used by calling
351
+ ``reset()`` again - we don't tear down the curriculum scheduler
352
+ or release the lock; we only release per-episode bookkeeping.
353
+ """
354
+ with self._lock:
355
+ self._active.clear()
qubit_medic/server/rewards.py CHANGED
@@ -84,13 +84,21 @@ def reward_syndrome_consistency(
84
  ) -> float:
85
  """How well does the predicted Pauli frame reproduce the FINAL detectors?
86
 
87
- Computes Hamming similarity between ``predicted_final_bits`` (induced by
88
- the predicted X errors) and ``observed_final_bits``. Returns
89
  ``1 - hamming_distance / num_final_detectors``.
90
 
91
  Rationale (Section 3.2): without this term, an LLM that lucky-guesses
92
- the right qubits could get Reward 1 occasionally; this signal forces it
93
- to also explain the data the syndrome carries.
 
 
 
 
 
 
 
 
94
  """
95
  final_dets = layout.final_detectors
96
  if not final_dets:
@@ -104,7 +112,17 @@ def reward_syndrome_consistency(
104
  predicted = implied.get(det_idx, 0)
105
  if observed != predicted:
106
  distance += 1
107
- return 1.0 - distance / len(final_dets)
 
 
 
 
 
 
 
 
 
 
108
 
109
 
110
  def compute_final_detector_supports(
@@ -141,13 +159,37 @@ def compute_final_detector_supports(
141
  # --------------------------------------------------------------------------- #
142
 
143
 
144
- def _jaccard(a: list[int], b: list[int]) -> float:
145
- """Jaccard index. Returns 1.0 when both sets are empty (perfect agreement)."""
146
- sa, sb = set(a), set(b)
147
- if not sa and not sb:
148
- return 1.0
149
- inter = len(sa & sb)
150
- union = len(sa | sb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  return inter / union if union else 1.0
152
 
153
 
@@ -156,16 +198,19 @@ def reward_hamming_overlap(
156
  sample: SyndromeSample,
157
  layout: CircuitLayout,
158
  ) -> float:
159
- """Average of Jaccard(X) and Jaccard(Z) against the reference frame.
160
-
161
- Reference is PyMatching's per-edge predicted Pauli frame
162
- (``sample.pymatching_x_errors`` / ``..._z_errors``). This is the dense
163
- partial-credit signal of Section 3.3 - even if Reward 1 fires zero,
164
- being *close* to the canonical solution still gets credit, smoothing
165
- the reward landscape during early training.
 
 
 
166
  """
167
- jx = _jaccard(parsed.x_errors, sample.pymatching_x_errors)
168
- jz = _jaccard(parsed.z_errors, sample.pymatching_z_errors)
169
  return 0.5 * (jx + jz)
170
 
171
 
@@ -175,8 +220,16 @@ def reward_hamming_overlap(
175
 
176
 
177
  def reward_format_compliance(parsed: ParseResult) -> float:
178
- """1.0 if both keys parsed, 0.5 if exactly one, 0.0 if neither."""
179
- return parsed.format_score
 
 
 
 
 
 
 
 
180
 
181
 
182
  # --------------------------------------------------------------------------- #
 
84
  ) -> float:
85
  """How well does the predicted Pauli frame reproduce the FINAL detectors?
86
 
87
+ Computes Hamming similarity between ``predicted_final_bits`` (induced
88
+ by the predicted X errors) and ``observed_final_bits``. Returns
89
  ``1 - hamming_distance / num_final_detectors``.
90
 
91
  Rationale (Section 3.2): without this term, an LLM that lucky-guesses
92
+ the right qubits could get Reward 1 occasionally; this signal forces
93
+ it to also explain the data the syndrome carries.
94
+
95
+ 2026-04 anti-collapse cap (FIX 1, RL spec rewrite): if the prediction
96
+ is empty AND the observed syndrome is non-empty (at least one
97
+ detector fired), cap the score at 0.5. Without this cap, the
98
+ "always predict empty" policy can still pull a high syndrome-
99
+ consistency score on the prompts where the implied final-round bits
100
+ happen to coincide with zeros, which kept GRPO trapped in the
101
+ constant-empty mode.
102
  """
103
  final_dets = layout.final_detectors
104
  if not final_dets:
 
112
  predicted = implied.get(det_idx, 0)
113
  if observed != predicted:
114
  distance += 1
115
+ base = 1.0 - distance / len(final_dets)
116
+
117
+ # Anti-collapse cap: empty prediction + non-empty observed syndrome
118
+ # is a "did nothing while alarms were firing" failure mode. Cap at
119
+ # 0.5 so the empty policy can never approach the full 1.0 even when
120
+ # the implied final-round bits happen to coincide.
121
+ pred_is_empty = (not parsed.x_errors) and (not parsed.z_errors)
122
+ has_active_syndrome = any(int(b) != 0 for b in sample.syndrome_bits)
123
+ if pred_is_empty and has_active_syndrome:
124
+ return min(base, 0.5)
125
+ return base
126
 
127
 
128
  def compute_final_detector_supports(
 
159
  # --------------------------------------------------------------------------- #
160
 
161
 
162
+ def _set_aware_jaccard(true_set: list[int], pred_set: list[int]) -> float:
163
+ """Set-aware Jaccard: penalises BOTH false alarms and missed errors.
164
+
165
+ 2026-04 spec rewrite (FIX 1). The four-case rule is what makes
166
+ "predict empty everywhere" stop being a near-optimal strategy:
167
+
168
+ +-------------+-----------+-----------------------------------------+
169
+ | true_set | pred_set | score |
170
+ +-------------+-----------+-----------------------------------------+
171
+ | empty | empty | 1.0 (perfect, "no errors -> no edit") |
172
+ | empty | non-empty | 0.0 false alarm |
173
+ | non-empty | empty | 0.0 missed errors <-- the key change |
174
+ | non-empty | non-empty | |inter| / |union| (standard Jaccard) |
175
+ +-------------+-----------+-----------------------------------------+
176
+
177
+ Critically the third case used to score 1.0 under the prior plain
178
+ Jaccard (because both sets were treated symmetrically; "everything
179
+ correct, just nothing predicted" was indistinguishable from "perfect
180
+ agreement"). Under this rule a missed-error answer scores 0.0,
181
+ which moves the GRPO reward landscape so a non-trivial prediction
182
+ can climb out of the empty-everywhere local optimum.
183
+ """
184
+ sa, sp = set(true_set), set(pred_set)
185
+ if not sa and not sp:
186
+ return 1.0 # perfect agreement: no true errors AND no claimed errors
187
+ if not sa and sp:
188
+ return 0.0 # false alarm: claimed errors that were not there
189
+ if sa and not sp:
190
+ return 0.0 # missed errors: alarms fired but model said nothing
191
+ inter = len(sa & sp)
192
+ union = len(sa | sp)
193
  return inter / union if union else 1.0
194
 
195
 
 
198
  sample: SyndromeSample,
199
  layout: CircuitLayout,
200
  ) -> float:
201
+ """Average of set-aware Jaccard(X) and set-aware Jaccard(Z) against
202
+ the reference Pauli frame carried by ``SyndromeSample``.
203
+
204
+ The reference frame lives on
205
+ ``sample.pymatching_x_errors`` / ``sample.pymatching_z_errors``
206
+ in this codebase that frame is treated as the ground-truth target
207
+ (the SFT/GRPO dataset builders fill it from the same source as the
208
+ JSONL ``true_x_errors`` / ``true_z_errors`` fields). Per-axis score
209
+ uses the set-aware rule (see :func:`_set_aware_jaccard`), so missed
210
+ errors no longer score 1.0 just because the prediction set is empty.
211
  """
212
+ jx = _set_aware_jaccard(sample.pymatching_x_errors, parsed.x_errors)
213
+ jz = _set_aware_jaccard(sample.pymatching_z_errors, parsed.z_errors)
214
  return 0.5 * (jx + jz)
215
 
216
 
 
220
 
221
 
222
  def reward_format_compliance(parsed: ParseResult) -> float:
223
+ """Binary {0.0, 1.0}: 1.0 iff the parser fully extracted both lists.
224
+
225
+ 2026-04 spec rewrite (FIX 1): partial credit (0.5) is removed. With
226
+ partial credit on, the model could still earn ~half the format
227
+ weight on garbage outputs that resembled the canonical form, which
228
+ is part of what kept the reward landscape too flat for GRPO to
229
+ escape the empty-everywhere mode. The new rule rewards only a
230
+ cleanly-parsed answer.
231
+ """
232
+ return 1.0 if parsed.parse_success else 0.0
233
 
234
 
235
  # --------------------------------------------------------------------------- #
qubit_medic/wandb_utils.py CHANGED
@@ -260,12 +260,22 @@ def run_context(run_name: str, job_type: str, **kwargs):
260
 
261
  def log(metrics: Mapping[str, Any], *, step: Optional[int] = None,
262
  commit: bool = True) -> None:
263
- """No-op-safe ``wandb.log`` wrapper."""
 
 
 
 
 
 
 
264
  wandb = _import_wandb()
265
  if wandb is None or _RUN is None:
266
  return
267
  try:
268
- wandb.log(dict(metrics), step=step, commit=commit)
 
 
 
269
  except Exception as exc: # pragma: no cover - defensive
270
  print(f"[wandb] log failed: {exc}", file=sys.stderr)
271
 
 
260
 
261
  def log(metrics: Mapping[str, Any], *, step: Optional[int] = None,
262
  commit: bool = True) -> None:
263
+ """No-op-safe ``wandb.log`` wrapper.
264
+
265
+ We store training-step alignment as an explicit scalar
266
+ ``train/global_step`` instead of passing W&B's reserved ``step=`` value.
267
+ HuggingFace/TRL may advance W&B's internal step before our callback logs,
268
+ which otherwise produces "Tried to log to step N that is less than the
269
+ current step N+1" and drops eval metrics.
270
+ """
271
  wandb = _import_wandb()
272
  if wandb is None or _RUN is None:
273
  return
274
  try:
275
+ payload = dict(metrics)
276
+ if step is not None and "train/global_step" not in payload:
277
+ payload["train/global_step"] = int(step)
278
+ wandb.log(payload, commit=commit)
279
  except Exception as exc: # pragma: no cover - defensive
280
  print(f"[wandb] log failed: {exc}", file=sys.stderr)
281