ronitraj committed on
Commit 195f87e · verified · 1 Parent(s): 1d9e50c

Real env: openenv-core wrapped DecoderEnvironment + /healthz + /decode
Dockerfile CHANGED
@@ -1,25 +1,61 @@
- FROM python:3.11-slim
-
- ENV PYTHONUNBUFFERED=1 \
      PIP_NO_CACHE_DIR=1 \
      PIP_DISABLE_PIP_VERSION_CHECK=1

  RUN useradd -m -u 1000 user
  USER user
  ENV PATH="/home/user/.local/bin:$PATH"

  WORKDIR /app

- # Day-0 deployment-substrate dependencies (Section 11 of the plan):
- # stim + pymatching + fastapi + openenv-core, plus uvicorn for the server.
- RUN pip install --user --no-cache-dir \
-     "stim>=1.13,<2.0" \
-     "pymatching>=2.2,<3.0" \
-     "fastapi>=0.110" \
-     "uvicorn[standard]>=0.27" \
-     "openenv-core>=0.2.1"

- COPY --chown=user app.py /app/app.py

  EXPOSE 7860
- CMD ["python", "-m", "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]

+ # Qubit-Medic OpenEnv server container.
+ #
+ # This image ships ONLY the env-server code:
+ #   * stim + pymatching - quantum simulation + matching baseline
+ #   * fastapi + uvicorn - HTTP transport
+ #   * openenv-core      - canonical OpenEnv contract (/reset, /step,
+ #                         /state, /health, /schema, /metadata, /mcp, /docs)
+ #
+ # Heavy ML training deps (torch, transformers, trl, unsloth) are
+ # deliberately NOT installed - they live in requirements-train.txt and
+ # are installed only by the Colab training notebook. Keeping the Spaces
+ # image lean avoids the ~10 GB CUDA wheel that would blow the free tier.

+ FROM python:3.11-slim AS base
+
+ ENV PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1 \
      PIP_NO_CACHE_DIR=1 \
      PIP_DISABLE_PIP_VERSION_CHECK=1

+ # Stim and PyMatching ship manylinux wheels - no system C++ deps needed
+ # beyond libstdc++. We keep build-essential as a safety net for any
+ # unexpected source-fallback path on the build host.
+ RUN apt-get update \
+     && apt-get install -y --no-install-recommends \
+         build-essential \
+         ca-certificates \
+         curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ # HF Spaces best-practice: run as non-root user with UID 1000.
  RUN useradd -m -u 1000 user
  USER user
  ENV PATH="/home/user/.local/bin:$PATH"

  WORKDIR /app

+ COPY --chown=user requirements.txt /app/requirements.txt
+ RUN pip install --user --upgrade pip \
+     && pip install --user -r /app/requirements.txt
+
+ COPY --chown=user qubit_medic /app/qubit_medic
+ COPY --chown=user README.md /app/README.md

+ # Pre-warm Stim/PyMatching caches at build time so the first request
+ # after `docker run` has near-zero latency (Section 9.1 of the plan).
+ RUN python -c "from qubit_medic.server.environment import DecoderEnvironment; \
+     e = DecoderEnvironment(); \
+     e._cache_for('L1_warmup'); \
+     e._cache_for('L2_target')"

  EXPOSE 7860
+
+ ENV LOG_LEVEL=INFO \
+     QUBIT_MEDIC_HOST=0.0.0.0 \
+     QUBIT_MEDIC_PORT=7860
+
+ # Boots the FastAPI app (qubit_medic.server.app) which is built on top
+ # of openenv.core.create_fastapi_app.
+ CMD ["python", "-m", "qubit_medic.server.app"]
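The three runtime variables set at the bottom of the new stage would be consumed roughly like this (a sketch only; how the real `qubit_medic.server.app` entry point reads them is not shown in this diff):

```python
import os


def server_config() -> tuple[str, int, str]:
    """Read host/port/log-level, falling back to the Dockerfile defaults."""
    host = os.getenv("QUBIT_MEDIC_HOST", "0.0.0.0")
    port = int(os.getenv("QUBIT_MEDIC_PORT", "7860"))
    log_level = os.getenv("LOG_LEVEL", "INFO").upper()
    return host, port, log_level


# With no overrides set, the defaults line up with EXPOSE 7860 above.
host, port, log_level = server_config()
```

Keeping the defaults in both the `ENV` layer and the reader means a plain `docker run` works with no flags, while Spaces (or a local override) can still rebind the port.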
README.md CHANGED
@@ -7,43 +7,118 @@ sdk: docker
  app_port: 7860
  pinned: false
  license: mit
- short_description: Day-0 placeholder for the Qubit-Medic OpenEnv server.
  tags:
  - openenv
  - reinforcement-learning
  - quantum-error-correction
  - stim
  - pymatching
  ---

- # QuantumScribe — Day-0 placeholder

- This Space is the **deployment-substrate placeholder** for the
- **Qubit-Medic** OpenEnv server, an RL training environment that teaches
- an LLM to decode errors on the rotated surface code.

  ## Try it

- * `GET /` — root metadata.
- * `GET /healthz` — liveness probe; returns `{"ok": true, "stim_version": "...", ...}`.

  ```bash
  curl https://ronitraj-quantumscribe.hf.space/healthz
  ```

- ## What's coming

- This placeholder will be replaced by the real **Qubit-Medic** environment:

- | Endpoint | What it does |
- |---|---|
- | `POST /reset` | Sample a fresh syndrome + observation for the LLM. |
- | `POST /step` | Score the LLM's predicted Pauli frame with five independent rewards (logical correction, syndrome consistency, Hamming overlap, format compliance, PyMatching beat-rate). |
- | `GET /health` | Curriculum + episode statistics. |
- | `GET /healthz` | Liveness (already live on this placeholder). |

  ## Stack

- * [Stim](https://github.com/quantumlib/Stim) — Clifford circuit simulator.
- * [PyMatching](https://github.com/oscarhiggott/PyMatching) — minimum-weight matching baseline.
- * [FastAPI](https://fastapi.tiangolo.com/) + [openenv-core](https://pypi.org/project/openenv-core/) server + RL contract.

  app_port: 7860
  pinned: false
  license: mit
+ short_description: OpenEnv RL env that teaches an LLM to decode quantum errors.
  tags:
  - openenv
  - reinforcement-learning
  - quantum-error-correction
  - stim
  - pymatching
+ - grpo
+ - trl
+ - llm
  ---

+ # QuantumScribe — Qubit-Medic OpenEnv

+ > An [OpenEnv](https://meta-pytorch.github.io/OpenEnv/) reinforcement-learning
+ > environment that teaches a Large Language Model to decode errors on the
+ > rotated surface code, using **Stim** for physics-accurate noise sampling
+ > and **PyMatching** as the classical baseline to beat.
+
+ This Space hosts the **environment server only**. Training (SFT warm-up
+ + GRPO RL) runs on a separate Colab T4; the trained LoRA adapter is
+ loaded client-side, not on this Space.
+
+ ## Endpoints
+
+ This is the canonical **OpenEnv contract** registered by
+ `openenv.core.create_fastapi_app`, plus two extras of our own:
+
+ | Method | Path | Purpose |
+ |---|---|---|
+ | `POST` | `/reset` | Sample a fresh syndrome + observation. Body `{"seed": int?, "episode_id": str?}`. Optional `?forced_level=L1_warmup\|L2_target\|L3_stretch`. |
+ | `POST` | `/step` | Score the LLM's prediction with five independent rewards. Body `{"action": {"raw_response": "...", "episode_id": int}, "timeout_s": float?, "request_id": str?}`. |
+ | `GET` | `/state` | Curriculum + episode counters (no physics-truth fields). |
+ | `GET` | `/health` | OpenEnv liveness response. |
+ | `GET` | `/healthz` | Lightweight Day-0 probe — Stim/PyMatching/openenv versions. |
+ | `GET` | `/schema` | JSON Schema for `QubitMedicAction` / `QubitMedicObservation`. |
+ | `GET` | `/metadata` | Environment metadata (name, description, version). |
+ | `POST` | `/mcp` | Model Context Protocol endpoint. |
+ | `POST` | `/decode` | PyMatching baseline demo: pass a hand-crafted syndrome, get the matching-decoder result. |
+ | `GET` | `/docs` | Swagger UI for everything above. |

  ## Try it

+ Curl from anywhere:

  ```bash
  curl https://ronitraj-quantumscribe.hf.space/healthz
  ```

+ Use it from Python with the OpenEnv client:
+
+ ```python
+ from openenv.core import GenericEnvClient
+
+ with GenericEnvClient(base_url="https://ronitraj-quantumscribe.hf.space").sync() as env:
+     obs = env.reset(seed=42)
+     print(obs.observation["prompt"][:200])
+     result = env.step({"raw_response": "<answer>X: 0,3 | Z:</answer>", "episode_id": 1})
+     print("reward:", result.reward, "rewards breakdown:", result.observation["info"]["rewards"])
+ ```
+
+ Or hit the env directly with `httpx`:
+
+ ```python
+ import httpx
+
+ url = "https://ronitraj-quantumscribe.hf.space"
+ obs = httpx.post(f"{url}/reset", json={"seed": 42},
+                  params={"forced_level": "L2_target"}).json()["observation"]
+ print(obs["prompt"][:200])
+ res = httpx.post(f"{url}/step", json={
+     "action": {"raw_response": "<answer>X: | Z:</answer>",
+                "episode_id": obs["episode_id"]}
+ }).json()
+ print("reward =", res["reward"], "rewards =", res["observation"]["info"]["rewards"])
+ ```
+
+ ## What the rewards mean
+
+ Each `step` returns five *independent, verifiable* reward components:

+ | Reward | Weight | What it measures |
+ |---|---:|---|
+ | `logical_correction` | 0.40 | 1 if the predicted Pauli frame preserves the logical-Z observable. |
+ | `syndrome_consistency` | 0.20 | Hamming similarity over final-round detector parities. |
+ | `hamming_overlap` | 0.20 | Mean Jaccard similarity vs. the PyMatching reference Pauli frame. |
+ | `format_compliance` | 0.10 | 1 / 0.5 / 0 for full / partial / unparseable LLM output. |
+ | `pymatching_beat` | 0.10 | 1 iff PyMatching is wrong on this syndrome **and** the model is right. |

+ All five are computed from the same `(prompt, completion, syndrome)` tuple on every step (no redundant sampling — see the architecture notes in the GitHub README).

  ## Stack

+ * [openenv-core](https://pypi.org/project/openenv-core/) `>=0.2.1` — environment contract, FastAPI scaffolding, MCP, WebSocket sessions.
+ * [Stim](https://github.com/quantumlib/Stim) — Clifford circuit simulator and detector-error-model generator.
+ * [PyMatching](https://github.com/oscarhiggott/PyMatching) `>=2.2` — minimum-weight matching baseline (and ground truth for the `pymatching_beat` reward).
+ * FastAPI + Uvicorn — HTTP transport.
+
+ ## Curriculum
+
+ | Level | Distance | Rounds | Physical error rate | Promotion threshold |
+ |---|---:|---:|---:|---:|
+ | `L1_warmup` | 3 | 1 | 1e-4 | logical_correction ≥ 0.80 |
+ | `L2_target` | 3 | 3 | 1e-3 | logical_correction ≥ 0.70 |
+ | `L3_stretch` | 5 | 5 | 1e-3 | logical_correction ≥ 0.30 |
+
+ The server's curriculum scheduler tracks a moving average of `logical_correction` per level and promotes the agent when it crosses the threshold.
+
+ ## See also
+
+ * OpenEnv documentation — <https://meta-pytorch.github.io/OpenEnv/>
+ * TRL `environment_factory=` integration — <https://huggingface.co/docs/trl/main/openenv>
+
+ ---
+
+ *Built for the META RL hackathon. Source code, training notebook, and reproduction instructions: see the [GitHub repository](https://github.com/ronitraj/qubit-medic) (link to be updated).*
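The promotion rule the new README describes — a moving average of `logical_correction` per level, promoting when it crosses the threshold — can be sketched as follows (illustrative only; the server's actual scheduler keeps more state, e.g. the per-level `eval_size`):

```python
from collections import deque


class PromotionTracker:
    """Track a moving average of logical_correction for one curriculum level."""

    def __init__(self, threshold: float, window: int = 100) -> None:
        self.threshold = threshold
        self.scores: deque[float] = deque(maxlen=window)

    def record(self, logical_correction: float) -> bool:
        """Record one episode's score; True once a full window's average
        reaches the level's promotion threshold."""
        self.scores.append(logical_correction)
        avg = sum(self.scores) / len(self.scores)
        return len(self.scores) == self.scores.maxlen and avg >= self.threshold


# L1_warmup promotes at 0.80; early weak episodes hold the agent back,
# a sustained run of correct decodes promotes it.
tracker = PromotionTracker(threshold=0.80, window=10)
promoted = [tracker.record(s) for s in [0.6] * 10 + [1.0] * 10]
```

Requiring a full window before promoting prevents a lucky first episode from skipping a level.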
app.py DELETED
@@ -1,78 +0,0 @@
- """QuantumScribe - Day-0 deployment-substrate placeholder.
-
- This minimal FastAPI app proves the Hugging Face Spaces deployment
- substrate works for the Qubit-Medic / QuantumScribe project:
-
- * Stim + PyMatching + FastAPI + openenv-core install cleanly in HF's
-   build environment (Day-0 step 2).
- * GET /healthz returns {"ok": true, "stim_version": "..."}, proving the
-   server boots and the heavy quantum dependency loads (Day-0 step 3).
- * The endpoint is reachable from a browser (Day-0 step 4) and from a
-   Colab `requests.get(...)` call (Day-0 step 5).
-
- Once all Day-0 gates pass, this file is replaced by the real Qubit-Medic
- OpenEnv server (qubit_medic.server.app) at the same Space URL, inheriting
- the warm build cache.
- """
- from __future__ import annotations
-
- import sys
-
- import stim
- from fastapi import FastAPI
-
-
- app = FastAPI(
-     title="QuantumScribe (Qubit-Medic) - Hello Space",
-     description="Day-0 deployment-substrate placeholder for the Qubit-Medic "
-                 "OpenEnv server. Will be replaced by the real env shortly.",
-     version="0.0.1-placeholder",
- )
-
-
- _PYMATCHING_VERSION: str
- try:
-     import pymatching as _pm
-     _PYMATCHING_VERSION = getattr(_pm, "__version__", "unknown")
- except Exception as exc:
-     _PYMATCHING_VERSION = f"import-error: {exc}"
-
- _OPENENV_VERSION: str
- try:
-     import openenv as _oe
-     _OPENENV_VERSION = getattr(_oe, "__version__", "unknown")
- except Exception as exc:
-     _OPENENV_VERSION = f"import-error: {exc}"
-
-
- @app.get("/")
- def root() -> dict:
-     return {
-         "service": "QuantumScribe (Qubit-Medic)",
-         "status": "Day-0 placeholder live",
-         "next": "POST /reset and /step will become available once the real "
-                 "DecoderEnvironment is pushed.",
-         "endpoints": ["/", "/healthz"],
-         "links": {
-             "github": "https://github.com/ronitraj (replace with repo URL once public)",
-             "openenv_docs": "https://meta-pytorch.org/OpenEnv/",
-         },
-     }
-
-
- @app.get("/healthz")
- def healthz() -> dict:
-     """Day-0 liveness probe.
-
-     Returns the Stim version - so curl-ing this in a browser or from Colab
-     proves both that networking works AND that the heavy quantum deps
-     actually loaded. This is the literal endpoint the plan calls for.
-     """
-     return {
-         "ok": True,
-         "stim_version": stim.__version__,
-         "pymatching_version": _PYMATCHING_VERSION,
-         "openenv_version": _OPENENV_VERSION,
-         "python_version": sys.version.split()[0],
-         "service": "QuantumScribe",
-     }
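The Day-0 gate this deleted placeholder enforced — `/healthz` must return real version strings, not import errors — can be checked client-side with a small validator (a sketch; the payload shape is taken from the deleted `healthz` handler above, and the sample values below are made up):

```python
def healthz_ok(payload: dict) -> bool:
    """Accept only a payload shaped like the placeholder's /healthz response,
    rejecting any field that recorded an import failure."""
    required = {"ok", "stim_version", "pymatching_version",
                "openenv_version", "python_version", "service"}
    return (payload.get("ok") is True
            and required <= payload.keys()
            and not any(str(payload[k]).startswith("import-error")
                        for k in required - {"ok"}))


sample = {
    "ok": True, "stim_version": "1.13.0", "pymatching_version": "2.2.1",
    "openenv_version": "0.2.1", "python_version": "3.11.9",
    "service": "QuantumScribe",
}
```

Because the handler swallows import errors into strings instead of crashing, a probe that only checks HTTP 200 would miss a broken dependency; checking for the `import-error:` prefix closes that gap.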
qubit_medic/__init__.py ADDED
@@ -0,0 +1,15 @@
+ """Qubit-Medic: An LLM-trained quantum error-correction decoder.
+
+ The package is split into three layers (Section 0 of the plan):
+
+ * ``qubit_medic.config`` - the locked experiment configuration.
+ * ``qubit_medic.server`` - Stim physics, rewards, curriculum, FastAPI app.
+ * ``qubit_medic.client`` - the lightweight HTTP stub the trainer imports.
+
+ ``qubit_medic.models`` and ``qubit_medic.prompts`` are the contract both sides
+ agree on: what the LLM sees and what the LLM is allowed to emit.
+ """
+
+ from qubit_medic import config, models, prompts  # noqa: F401
+
+ __version__ = "1.0.0"
qubit_medic/client/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Qubit-Medic client - the lightweight HTTP stub the trainer imports."""
+
+ from qubit_medic.client.client import DecoderClient, LocalDecoderClient
+
+ __all__ = ["DecoderClient", "LocalDecoderClient"]
qubit_medic/client/client.py ADDED
@@ -0,0 +1,132 @@
+ """Two equivalent client implementations:
+
+ * :class:`DecoderClient` - hits an HTTP endpoint (HF Spaces deployment).
+   Speaks the **OpenEnv** wire format:
+     - ``POST /reset`` body ``{"seed": int?, "episode_id": str?}``
+     - ``POST /step`` body ``{"action": {"raw_response": "...", ...},
+       "timeout_s": float?, "request_id": str?}``
+ * :class:`LocalDecoderClient` - calls the in-process env directly. Use this
+   in tests, in CI, and during local Colab runs to skip HTTP overhead.
+
+ Both expose the same ``reset`` / ``step`` API so the training scripts can
+ swap between them via a single env var (``QUBIT_MEDIC_URL``).
+ """
+ from __future__ import annotations
+
+ import os
+ from typing import Optional, Protocol
+
+ import httpx
+
+ from qubit_medic.models import (
+     DecoderObservation,
+     StepResult,
+ )
+
+
+ class _ClientProtocol(Protocol):
+     def reset(self, *, seed: Optional[int] = None,
+               forced_level: Optional[str] = None) -> DecoderObservation: ...
+     def step(self, *, raw_response: str, episode_id: int) -> StepResult: ...
+     def health(self) -> dict: ...
+     def close(self) -> None: ...
+
+
+ def _obs_from_openenv(payload: dict) -> DecoderObservation:
+     """Re-hydrate our internal :class:`DecoderObservation` from the
+     OpenEnv response body. The OpenEnv wrapper inlines all our fields
+     onto the observation, so this is a 1-1 field mapping."""
+     return DecoderObservation(
+         prompt=payload.get("prompt", ""),
+         syndrome_bits=list(payload.get("syndrome_bits", [])),
+         distance=int(payload.get("distance", 0)),
+         rounds=int(payload.get("rounds", 0)),
+         p=float(payload.get("p", 0.0)),
+         curriculum_level=payload.get("curriculum_level", ""),
+         episode_id=int(payload.get("episode_id", 0)),
+         dem_digest=payload.get("dem_digest", ""),
+     )
+
+
+ class DecoderClient:
+     """HTTP client targeting a deployed FastAPI server (OpenEnv shape)."""
+
+     def __init__(self, base_url: str, *, timeout: float = 60.0) -> None:
+         self._client = httpx.Client(base_url=base_url.rstrip("/"), timeout=timeout)
+
+     def reset(self, *, seed: Optional[int] = None,
+               forced_level: Optional[str] = None) -> DecoderObservation:
+         # OpenEnv's ResetRequest only accepts seed + episode_id. We pass
+         # forced_level via the URL query string so adapters that honour
+         # it (our QubitMedicEnvironment via **kwargs) pick it up; servers
+         # that ignore it just get a default level.
+         body: dict = {}
+         if seed is not None:
+             body["seed"] = seed
+         params = {"forced_level": forced_level} if forced_level else None
+         r = self._client.post("/reset", json=body, params=params)
+         r.raise_for_status()
+         payload = r.json()
+         # OpenEnv returns {observation: {...}, reward, done}.
+         return _obs_from_openenv(payload.get("observation", payload))
+
+     def step(self, *, raw_response: str, episode_id: int) -> StepResult:
+         body = {
+             "action": {
+                 "raw_response": raw_response,
+                 "episode_id": episode_id,
+             },
+         }
+         r = self._client.post("/step", json=body)
+         r.raise_for_status()
+         payload = r.json()
+         obs_payload = payload.get("observation", {})
+         return StepResult(
+             observation=_obs_from_openenv(obs_payload),
+             reward=float(payload.get("reward", 0.0) or 0.0),
+             done=bool(payload.get("done", True)),
+             truncated=bool(obs_payload.get("info", {}).get("timed_out", False)),
+             info=dict(obs_payload.get("info", {})),
+         )
+
+     def health(self) -> dict:
+         r = self._client.get("/health")
+         r.raise_for_status()
+         return r.json()
+
+     def healthz(self) -> dict:
+         r = self._client.get("/healthz")
+         r.raise_for_status()
+         return r.json()
+
+     def close(self) -> None:
+         self._client.close()
+
+
+ class LocalDecoderClient:
+     """In-process client - calls :class:`DecoderEnvironment` directly."""
+
+     def __init__(self, env=None) -> None:
+         from qubit_medic.server.environment import DecoderEnvironment
+         self._env = env if env is not None else DecoderEnvironment()
+
+     def reset(self, *, seed: Optional[int] = None,
+               forced_level: Optional[str] = None) -> DecoderObservation:
+         return self._env.reset(seed=seed, forced_level=forced_level)
+
+     def step(self, *, raw_response: str, episode_id: int) -> StepResult:
+         return self._env.step(raw_response=raw_response, episode_id=episode_id)
+
+     def health(self) -> dict:
+         return self._env.health()
+
+     def close(self) -> None:  # nothing to clean up
+         pass
+
+
+ def make_default_client() -> _ClientProtocol:
+     """Return :class:`DecoderClient` if ``QUBIT_MEDIC_URL`` is set, else local."""
+     url = os.getenv("QUBIT_MEDIC_URL")
+     if url:
+         return DecoderClient(url)
+     return LocalDecoderClient()
qubit_medic/config.py ADDED
@@ -0,0 +1,254 @@
+ """Locked experiment configuration (Section 1.4 of the plan).
+
+ Every magic number in the project lives here. Do not hard-code circuit
+ parameters, noise rates, or model identifiers anywhere else; import them
+ from this module instead.
+
+ Cited literature
+ ----------------
+ Bausch et al., AlphaQubit, *Nature* 635:834 (2024)
+     DOI: 10.1038/s41586-024-08148-8
+     https://www.nature.com/articles/s41586-024-08148-8
+ Acharya et al. (Google QAI), *Willow*, arXiv:2408.13687 (2024)
+     https://arxiv.org/abs/2408.13687
+ Gidney & Fowler, *SI1000*, arXiv:2108.10457 (2021)
+     https://arxiv.org/abs/2108.10457
+ Higgott & Gidney, *PyMatching v2*, arXiv:2303.15933 (2023)
+     https://arxiv.org/abs/2303.15933
+ Shao et al., *DeepSeekMath / GRPO*, arXiv:2402.03300 (2024)
+     https://arxiv.org/abs/2402.03300
+ """
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Mapping
+
+
+ # --------------------------------------------------------------------------- #
+ # Quantum code geometry                                                       #
+ # --------------------------------------------------------------------------- #
+
+ CODE_TASK = "surface_code:rotated_memory_z"
+ """Stim task identifier. We always use the rotated surface code with a Z
+ memory experiment - same family AlphaQubit and Willow report on."""
+
+ DISTANCE_PRIMARY: int = 3
+ """Distance-3 is the primary benchmark configuration (AlphaQubit Fig. 2b)."""
+
+ DISTANCE_STRETCH: int = 5
+ """Distance-5 is the stretch-goal configuration for Section 4.3."""
+
+ ROUNDS_FACTOR: int = 1
+ """rounds = ROUNDS_FACTOR * distance. Value 1 matches the AlphaQubit
+ distance-equals-rounds protocol."""
+
+
+ # --------------------------------------------------------------------------- #
+ # Noise model: SI1000 sub-rates (Gidney & Fowler 2021, Table 1)               #
+ # --------------------------------------------------------------------------- #
+ # SI1000 maps a single physical error budget ``p`` to four operation-specific
+ # sub-rates. The factors below come from arXiv:2108.10457 Table 1 and are the
+ # *same* values Google's QAI uses in their Willow analyses.
+ #
+ # Stim's surface_code:rotated_memory_z generator accepts four matching knobs:
+ #     after_clifford_depolarization     (two-qubit gate noise)
+ #     before_round_data_depolarization  (idle data-qubit noise per round)
+ #     before_measure_flip_probability   (measurement noise)
+ #     after_reset_flip_probability      (reset noise)
+
+
+ @dataclass(frozen=True)
+ class SI1000Rates:
+     """Per-operation error rates derived from a single budget ``p``."""
+
+     after_clifford_depolarization: float
+     before_round_data_depolarization: float
+     before_measure_flip_probability: float
+     after_reset_flip_probability: float
+
+     @classmethod
+     def from_p(cls, p: float) -> "SI1000Rates":
+         """Build SI1000 sub-rates from the headline budget ``p``.
+
+         The factors are exactly Gidney & Fowler 2021 Table 1.
+         """
+         return cls(
+             after_clifford_depolarization=p,
+             before_round_data_depolarization=p / 10.0,
+             before_measure_flip_probability=p * 5.0,
+             after_reset_flip_probability=p * 2.0,
+         )
+
+     def as_stim_kwargs(self) -> Mapping[str, float]:
+         """Return the kwargs dict accepted by ``stim.Circuit.generated``."""
+         return {
+             "after_clifford_depolarization": self.after_clifford_depolarization,
+             "before_round_data_depolarization": self.before_round_data_depolarization,
+             "before_measure_flip_probability": self.before_measure_flip_probability,
+             "after_reset_flip_probability": self.after_reset_flip_probability,
+         }
+
+
+ # --------------------------------------------------------------------------- #
+ # Curriculum levels (Section 4)                                               #
+ # --------------------------------------------------------------------------- #
+
+
+ @dataclass(frozen=True)
+ class CurriculumLevel:
+     """One rung on the difficulty ladder."""
+
+     name: str
+     distance: int
+     rounds: int
+     p: float
+     promotion_threshold: float  # logical-correction rate at which we move on
+     eval_size: int              # held-out shots used to test promotion
+
+
+ CURRICULUM: tuple[CurriculumLevel, ...] = (
+     CurriculumLevel(
+         name="L1_warmup",
+         distance=DISTANCE_PRIMARY,
+         rounds=1,
+         p=0.0001,
+         promotion_threshold=0.80,
+         eval_size=100,
+     ),
+     CurriculumLevel(
+         name="L2_target",
+         distance=DISTANCE_PRIMARY,
+         rounds=DISTANCE_PRIMARY,
+         p=0.001,
+         promotion_threshold=0.70,
+         eval_size=200,
+     ),
+     CurriculumLevel(
+         name="L3_stretch",
+         distance=DISTANCE_STRETCH,
+         rounds=DISTANCE_STRETCH,
+         p=0.001,
+         promotion_threshold=0.30,  # stretch goal - even partial counts
+         eval_size=200,
+     ),
+ )
+
+
+ # --------------------------------------------------------------------------- #
+ # Reward weights (Section 3) - sum to 1.0 by construction                     #
+ # --------------------------------------------------------------------------- #
+
+ REWARD_WEIGHTS: dict[str, float] = {
+     "logical_correction": 0.40,    # Reward 1 - the unfakeable ground truth
+     "syndrome_consistency": 0.20,  # Reward 2 - prevents lucky-guess attacks
+     "hamming_overlap": 0.20,       # Reward 3 - dense partial credit
+     "format_compliance": 0.10,     # Reward 4 - parser must succeed
+     "pymatching_beat": 0.10,       # Reward 5 - the headline metric
+ }
+ assert abs(sum(REWARD_WEIGHTS.values()) - 1.0) < 1e-9, "reward weights must sum to 1"
+
+
+ # --------------------------------------------------------------------------- #
+ # Reproducibility                                                             #
+ # --------------------------------------------------------------------------- #
+
+ SEEDS: tuple[int, ...] = (42, 1337, 2024)
+ """Three seeds for error bars - never run with anything else."""
+
+ PRIMARY_SEED: int = SEEDS[0]
+
+
+ # --------------------------------------------------------------------------- #
+ # Model + training                                                            #
+ # --------------------------------------------------------------------------- #
+
+ MODEL_ID: str = "Qwen/Qwen2.5-3B-Instruct"
+ """3B params, 4-bit quantised + LoRA fits in a Colab T4."""
+
+ LORA_R: int = 16
+ LORA_ALPHA: int = 32
+ LORA_TARGET_MODULES: tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj")
+
+ SFT_EPOCHS: int = 1
+ SFT_BATCH_SIZE: int = 4
+ SFT_GRAD_ACCUM: int = 4
+ SFT_LR: float = 2e-4
+ SFT_DATASET_SIZE: int = 5_000
+ SFT_MAX_SEQ_LEN: int = 2048
+
+ GRPO_STEPS: int = 2_000
+ GRPO_GEN_PER_PROMPT: int = 4
+ GRPO_LR: float = 1e-5
+ GRPO_KL_COEF: float = 0.04
+ GRPO_MAX_PROMPT_LEN: int = 512
+ GRPO_MAX_COMPLETION_LEN: int = 256
+ GRPO_CHECKPOINT_EVERY: int = 250
+ GRPO_LOG_EVERY: int = 50
+
+ # Decoding sampler defaults at evaluation/format-test time.
+ SAMPLE_TEMPERATURE: float = 0.7
+ SAMPLE_TOP_P: float = 0.95
+
+
+ # --------------------------------------------------------------------------- #
+ # Server / deployment                                                         #
+ # --------------------------------------------------------------------------- #
+
+ EPISODE_TIMEOUT_SECONDS: float = 30.0
+ """Wall-clock budget per episode (Section 2.6)."""
+
+ DEFAULT_HOST: str = "0.0.0.0"
+ DEFAULT_PORT: int = 7860  # Hugging Face Spaces' default exposed port
+
+
+ # --------------------------------------------------------------------------- #
+ # Weights & Biases                                                            #
+ # --------------------------------------------------------------------------- #
+ # Centralised so the SFT trainer, GRPO trainer, eval script, and notebook
+ # all log to the same project / dashboard. Override per-run on the CLI.
+ import os as _os  # noqa: E402 (local import to keep top of module clean)
+
+ WANDB_PROJECT: str = _os.environ.get("WANDB_PROJECT", "qubit-medic")
+ """Default W&B project name. Override with ``WANDB_PROJECT=...``."""
+
+ WANDB_ENTITY: str | None = _os.environ.get("WANDB_ENTITY") or None
+ """W&B team or username. ``None`` -> wandb's default entity for the user."""
+
+ WANDB_DEFAULT_TAGS: tuple[str, ...] = (
+     "qubit-medic",
+     "quantum-error-correction",
+     "openenv",
+     f"distance-{DISTANCE_PRIMARY}",
+     "si1000",
+ )
+ """Tags applied to every W&B run (per-script tags appended on top)."""
+
+ WANDB_LOG_GENERATIONS_EVERY: int = 50
+ """Log a sample-completion table every N GRPO steps."""
+
+ WANDB_SAMPLE_GENERATIONS: int = 8
+ """Number of generations included in each sample-completion table."""
+
+ WANDB_INLOOP_EVAL_EVERY: int = 200
+ """Run an in-loop evaluation pass (deterministic, ``WANDB_INLOOP_EVAL_EPISODES``
+ syndromes) every N GRPO steps. Set to 0 to disable."""
+
+ WANDB_INLOOP_EVAL_EPISODES: int = 50
+ """Number of held-out syndromes per in-loop eval pass (kept small for speed)."""
+
+
+ # --------------------------------------------------------------------------- #
+ # Convenience accessors                                                       #
+ # --------------------------------------------------------------------------- #
+
+
+ def level_by_name(name: str) -> CurriculumLevel:
+     for lvl in CURRICULUM:
+         if lvl.name == name:
+             return lvl
+     raise KeyError(f"unknown curriculum level {name!r}")
+
+
+ def primary_level() -> CurriculumLevel:
+     """The L2 target benchmark - what the headline numbers come from."""
+     return level_by_name("L2_target")
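A quick numeric check of the Table-1 mapping that `SI1000Rates.from_p` encodes: for the L2 budget p = 1e-3, the four sub-rates come out as below (a standalone restatement of the factors, not an import of the module):

```python
def si1000(p: float) -> dict[str, float]:
    """Gidney & Fowler 2021 Table-1 factors, as in SI1000Rates.from_p."""
    return {
        "after_clifford_depolarization": p,          # two-qubit gates: 1x
        "before_round_data_depolarization": p / 10.0,  # idling: x/10
        "before_measure_flip_probability": p * 5.0,    # measurement: 5x
        "after_reset_flip_probability": p * 2.0,       # reset: 2x
    }


rates = si1000(1e-3)
# Measurement is the noisiest operation (5x the budget), idling the quietest.
```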
qubit_medic/models.py ADDED
@@ -0,0 +1,143 @@
+ """Pydantic data classes shared by client and server (Section 2.2 of the plan).
+
+ Three classes draw the trust boundary:
+
+ * ``DecoderObservation`` - what the LLM sees on each step.
+ * ``DecoderAction`` - what the LLM emits (after parsing).
+ * ``DecoderState`` - server-side state, never serialised to the client.
+
+ Keeping the wire schema explicit is what closes off reward-hacking attacks:
+ the LLM literally cannot reach into the ``true_error_pattern`` because that
+ field is not in any class it ever receives.
+ """
+ from __future__ import annotations
+
+ from typing import Any, Optional
+
+ from pydantic import BaseModel, ConfigDict, Field
+
+
+ # --------------------------------------------------------------------------- #
+ # Wire types - sent across the OpenEnv HTTP boundary                          #
+ # --------------------------------------------------------------------------- #
+
+
+ class DecoderObservation(BaseModel):
+     """The view the LLM (and only the LLM) sees on each step."""
+
+     model_config = ConfigDict(frozen=True)
+
+     prompt: str = Field(
+         ...,
+         description=(
+             "Pre-formatted prompt string. This is exactly what the trainer "
+             "passes to the policy - it appears verbatim in training logs."
+         ),
+     )
+     syndrome_bits: list[int] = Field(
+         ...,
+         description=(
+             "Raw detector activations (0/1). Provided for debugging and "
+             "reward-hacking audits; the LLM should be reading the prompt, not "
+             "this array."
+         ),
+     )
+     distance: int = Field(..., description="Code distance for this episode.")
+     rounds: int = Field(..., description="Number of stabiliser rounds.")
+     p: float = Field(..., description="Physical error budget (SI1000 base).")
+     curriculum_level: str = Field(..., description="Curriculum level name.")
+     episode_id: int = Field(..., description="Monotonic episode counter.")
+     dem_digest: str = Field(
+         ...,
+         description=(
+             "Short hash of the detector error model used this episode. The "
+             "trainer logs this so we can group rollouts by physics config."
+         ),
+     )
+
+
+ class DecoderAction(BaseModel):
+     """Action emitted by the LLM after parsing.
+
+     ``raw_response`` is preserved exactly so we can satisfy the participant
+     guide's *inspect generations* mandate (Section 2.5 of the plan).
+     """
+
+     model_config = ConfigDict(frozen=True)
+
+     x_error_qubits: list[int] = Field(default_factory=list)
+     z_error_qubits: list[int] = Field(default_factory=list)
+     raw_response: str = ""
+     parse_success: bool = True
+
+
+ class StepResult(BaseModel):
+     """Standard env step return (mirrors OpenEnv core/Gymnasium)."""
+
+     observation: DecoderObservation
+     reward: float
+     done: bool
+     truncated: bool = False
+     info: dict[str, Any] = Field(default_factory=dict)
+
+
+ class ResetRequest(BaseModel):
+     """Optional knobs the trainer can pass to ``reset``."""
+
+     seed: Optional[int] = None
+     forced_level: Optional[str] = Field(
+         default=None,
+         description=(
+             "Override the curriculum scheduler. Used by eval scripts that "
+             "want a specific (distance, rounds, p) configuration."
+         ),
+     )
+
+
+ class StepRequest(BaseModel):
+     """The trainer sends the LLM's raw text; the server parses + scores."""
+
+     raw_response: str
+     episode_id: int
+
+
+ # --------------------------------------------------------------------------- #
+ # Server-only state - intentionally NOT a wire type                           #
+ # --------------------------------------------------------------------------- #
+
+
+ class DecoderState(BaseModel):
+     """Per-episode state kept on the server; never sent to the client.
+
+     Pydantic ``arbitrary_types_allowed`` is on so this class can hold
+     non-JSON server-side objects if needed. The state is not serialised
+     over HTTP - it lives in the server's per-episode dict and is discarded
+     on ``done``.
+     """
+
+     model_config = ConfigDict(arbitrary_types_allowed=True, frozen=False)
+
+     episode_id: int
+     seed: int
+     curriculum_level: str
+     distance: int
+     rounds: int
+     p: float
+
+     syndrome_bits: list[int]
+     true_x_errors: list[int]
+     true_z_errors: list[int]
+     actual_observable_flip: int  # 0 or 1; the unfakeable ground truth
+     pymatching_observable_pred: int  # 0 or 1; baseline's prediction
+
+     # Pre-computed quantities the reward functions need.
+     x_observable_support: list[int]  # data qubits whose Z error flips X obs
+     z_observable_support: list[int]  # data qubits whose X error flips Z obs
+     num_data_qubits: int
+     num_stabilizers: int
+
+     # Stim/PyMatching inputs - stored as text so the state stays picklable.
+     circuit_text: str
+     dem_text: str
+
+     # Reward audit log.
+     last_reward_breakdown: dict[str, float] = Field(default_factory=dict)
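The trust boundary described above is structural rather than procedural: the ground truth simply has no field in any wire type. A minimal self-contained sketch of that property, using stdlib dataclasses as stand-ins for the pydantic models (the class and field names mirror the diff; the dataclass machinery is illustrative only):

```python
import dataclasses
import json

# Simplified stand-ins for the pydantic wire/server classes above.
@dataclasses.dataclass(frozen=True)
class Observation:              # wire type - crosses the HTTP boundary
    prompt: str
    syndrome_bits: list

@dataclasses.dataclass
class ServerState:              # server-only - never serialised to the client
    syndrome_bits: list
    actual_observable_flip: int  # ground truth the LLM must never see

state = ServerState(syndrome_bits=[0, 1, 0], actual_observable_flip=1)
obs = Observation(prompt="...", syndrome_bits=state.syndrome_bits)

wire_payload = json.dumps(dataclasses.asdict(obs))
# The ground truth is simply absent from anything that goes over the wire.
assert "actual_observable_flip" not in wire_payload
```

No audit of the serializer is needed: a field that does not exist on the wire type cannot leak.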
qubit_medic/prompts.py ADDED
@@ -0,0 +1,230 @@
+ """Prompt formatter and action parser (Section 2.3 + Section 2.5 of the plan).
+
+ The prompt is engineered around five sections:
+
+ 1. Role declaration
+ 2. Physics summary (~50 tokens, plain English)
+ 3. Syndrome data (round-by-round, labelled)
+ 4. Output format spec (one example included)
+ 5. Reasoning trigger ("think step by step ...")
+
+ Total budget ~250-300 tokens for the prompt; ~150 for the response.
+
+ The parser is deliberately permissive on whitespace and bracket style but
+ strict on the existence of the two key tokens ``X_ERRORS`` and ``Z_ERRORS``.
+ A partial-credit hook is exposed so Reward 4 can hand out 0.5 for "partly
+ parseable".
+ """
+ from __future__ import annotations
+
+ import re
+ from dataclasses import dataclass
+ from typing import Iterable
+
+
+ # --------------------------------------------------------------------------- #
+ # Prompt formatting                                                           #
+ # --------------------------------------------------------------------------- #
+
+ _ROLE = (
+     "You are a quantum error-correction decoder. You are decoding errors in "
+     "a distance-{distance} rotated surface code memory experiment."
+ )
+
+ _PHYSICS_SUMMARY = (
+     "Stabilizers are parity checks measured every round. A *syndrome bit* "
+     "is 1 when a stabilizer's measurement disagrees with its previous round, "
+     "indicating a nearby physical error. Your job is to look at the syndrome "
+     "history and output the smallest physical error pattern (X-flips and "
+     "Z-flips on data qubits, identified by integer IDs) that explains it."
+ )
+
+ _OUTPUT_SPEC = (
+     "Output format (REQUIRED, exact):\n"
+     "  X_ERRORS=[id1,id2,...] Z_ERRORS=[id1,id2,...]\n"
+     "Use empty lists when no errors of that type. Example with no errors:\n"
+     "  X_ERRORS=[] Z_ERRORS=[]"
+ )
+
+ _REASONING_TRIGGER = (
+     "Think step by step about which qubits could have caused this syndrome, "
+     "then output your prediction in the required format."
+ )
+
+
+ def format_syndrome_block(
+     syndrome_bits: Iterable[int],
+     rounds: int,
+     num_x_stabilizers: int,
+     num_z_stabilizers: int,
+ ) -> str:
+     """Render the detector activations round-by-round.
+
+     Stim emits detectors in a flat row-major order: round 0 stabilisers first,
+     then round 1, and so on. We label by round and stabiliser type so the LLM
+     can read the temporal structure.
+     """
+     bits = list(syndrome_bits)
+     per_round = num_x_stabilizers + num_z_stabilizers
+     lines = ["Syndrome (round-by-round):"]
+     if per_round == 0 or rounds == 0 or len(bits) == 0:
+         lines.append("  (no detectors fired)")
+         return "\n".join(lines)
+
+     for r in range(rounds):
+         offset = r * per_round
+         if offset >= len(bits):
+             break
+         chunk = bits[offset : offset + per_round]
+         x_chunk = chunk[:num_x_stabilizers]
+         z_chunk = chunk[num_x_stabilizers : num_x_stabilizers + num_z_stabilizers]
+         lines.append(
+             f"  Round {r + 1} X-stabilizers: "
+             + " ".join(str(b) for b in x_chunk)
+         )
+         lines.append(
+             f"  Round {r + 1} Z-stabilizers: "
+             + " ".join(str(b) for b in z_chunk)
+         )
+     # Trailing block for the final destructive measurement, if any extras.
+     used = rounds * per_round
+     if used < len(bits):
+         tail = bits[used:]
+         lines.append("  Final-round detectors: " + " ".join(str(b) for b in tail))
+     return "\n".join(lines)
+
+
+ def build_prompt(
+     *,
+     distance: int,
+     rounds: int,
+     p: float,
+     syndrome_bits: list[int],
+     num_x_stabilizers: int,
+     num_z_stabilizers: int,
+     num_data_qubits: int,
+ ) -> str:
+     """Assemble the full prompt the LLM sees on each step.
+
+     Keeping this function pure (no I/O, no globals) means the SFT pipeline
+     and the GRPO rollout use byte-identical inputs - a critical invariant.
+     """
+     syndrome_block = format_syndrome_block(
+         syndrome_bits=syndrome_bits,
+         rounds=rounds,
+         num_x_stabilizers=num_x_stabilizers,
+         num_z_stabilizers=num_z_stabilizers,
+     )
+     return (
+         _ROLE.format(distance=distance)
+         + "\n\n"
+         + _PHYSICS_SUMMARY
+         + "\n\n"
+         + f"Code parameters: distance={distance}, rounds={rounds}, "
+         + f"physical_error_rate={p:g}, data_qubits=0..{num_data_qubits - 1}.\n\n"
+         + syndrome_block
+         + "\n\n"
+         + _OUTPUT_SPEC
+         + "\n\n"
+         + _REASONING_TRIGGER
+     )
+
+
+ # --------------------------------------------------------------------------- #
+ # Output parsing                                                              #
+ # --------------------------------------------------------------------------- #
+
+ _X_PATTERN = re.compile(r"X_ERRORS\s*=\s*\[([^\]]*)\]", re.IGNORECASE)
+ _Z_PATTERN = re.compile(r"Z_ERRORS\s*=\s*\[([^\]]*)\]", re.IGNORECASE)
+
+
+ @dataclass(frozen=True)
+ class ParseResult:
+     x_errors: list[int]
+     z_errors: list[int]
+     parse_success: bool  # True iff BOTH X_ERRORS and Z_ERRORS parsed cleanly
+     parse_partial: bool  # True iff exactly one of the two parsed cleanly
+     raw_response: str
+
+     @property
+     def format_score(self) -> float:
+         """Score for Reward 4 (format compliance)."""
+         if self.parse_success:
+             return 1.0
+         if self.parse_partial:
+             return 0.5
+         return 0.0
+
+
+ def _parse_int_list(s: str, max_qubit: int) -> tuple[list[int], bool]:
+     """Parse a comma/space-separated integer list. Drops out-of-range and dups.
+
+     Returns ``(qubits_sorted_unique, all_tokens_were_valid)``.
+     """
+     if not s.strip():
+         return [], True
+     raw_tokens = re.split(r"[\s,]+", s.strip())
+     out: list[int] = []
+     all_clean = True
+     for tok in raw_tokens:
+         if not tok:
+             continue
+         try:
+             v = int(tok)
+         except ValueError:
+             all_clean = False
+             continue
+         if 0 <= v < max_qubit:
+             out.append(v)
+         else:
+             all_clean = False
+     return sorted(set(out)), all_clean
+
+
+ def parse_action(raw_response: str, num_data_qubits: int) -> ParseResult:
+     """Convert the LLM's raw text to a ``ParseResult``.
+
+     Tolerant of trailing chain-of-thought, surrounding code fences, and
+     casing, but strict on the existence of both ``X_ERRORS`` and ``Z_ERRORS``.
+     """
+     if not isinstance(raw_response, str):
+         return ParseResult([], [], False, False, raw_response="")
+
+     # If the model wrapped its answer in ```...``` blocks, focus on the last one.
+     fenced = re.findall(r"```(?:[^\n]*)\n(.*?)```", raw_response, re.DOTALL)
+     search_text = fenced[-1] if fenced else raw_response
+
+     x_match = _X_PATTERN.search(search_text)
+     z_match = _Z_PATTERN.search(search_text)
+
+     x_errors: list[int] = []
+     z_errors: list[int] = []
+     x_clean = z_clean = False
+
+     if x_match is not None:
+         x_errors, x_clean = _parse_int_list(x_match.group(1), num_data_qubits)
+     if z_match is not None:
+         z_errors, z_clean = _parse_int_list(z_match.group(1), num_data_qubits)
+
+     x_present = x_match is not None and x_clean
+     z_present = z_match is not None and z_clean
+     parse_success = x_present and z_present
+     parse_partial = (x_present ^ z_present) or (
+         # Both keys present but at least one had garbage tokens.
+         (x_match is not None and z_match is not None) and not parse_success
+     )
+
+     return ParseResult(
+         x_errors=x_errors,
+         z_errors=z_errors,
+         parse_success=parse_success,
+         parse_partial=parse_partial,
+         raw_response=raw_response,
+     )
+
+
+ def format_completion(x_errors: Iterable[int], z_errors: Iterable[int]) -> str:
+     """The canonical SFT target string. Inverse of :func:`parse_action`."""
+     x_str = ",".join(str(q) for q in sorted(set(x_errors)))
+     z_str = ",".join(str(q) for q in sorted(set(z_errors)))
+     return f"X_ERRORS=[{x_str}] Z_ERRORS=[{z_str}]"
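The "permissive on whitespace and casing, strict on the key tokens" behaviour can be sanity-checked in isolation. A standalone sketch using the same two regexes as the file above, with a simplified integer parser (range checks and partial-credit bookkeeping omitted):

```python
import re

# Same key-token regexes as in qubit_medic/prompts.py above.
X_PATTERN = re.compile(r"X_ERRORS\s*=\s*\[([^\]]*)\]", re.IGNORECASE)
Z_PATTERN = re.compile(r"Z_ERRORS\s*=\s*\[([^\]]*)\]", re.IGNORECASE)

def parse_ints(group: str) -> list[int]:
    # Permissive on separators, like _parse_int_list (range checks omitted).
    return sorted({int(t) for t in re.split(r"[\s,]+", group.strip()) if t})

completion = "X_ERRORS=[3,1] Z_ERRORS=[]"  # canonical SFT target shape
noisy = "I think qubit 3 flipped.\nx_errors = [ 3 , 1 ]  z_errors=[]"

for text in (completion, noisy):
    x = parse_ints(X_PATTERN.search(text).group(1))
    z = parse_ints(Z_PATTERN.search(text).group(1))
    assert (x, z) == ([1, 3], [])
```

Both the canonical completion and a lower-cased, whitespace-mangled variant with leading chain-of-thought resolve to the same sorted, deduplicated qubit lists.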
qubit_medic/server/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Server-side modules: physics, rewards, curriculum, FastAPI app.
+
+ Sub-modules are imported lazily on first attribute access to avoid
+ circular imports during partial initialisation.
+ """
qubit_medic/server/app.py ADDED
@@ -0,0 +1,169 @@
+ """Qubit-Medic FastAPI server.
+
+ Built on **openenv-core** ``create_fastapi_app`` so the canonical OpenEnv
+ routes (``/reset``, ``/step``, ``/state``, ``/health``, ``/schema``,
+ ``/metadata``, ``/mcp``) are wired automatically by the framework.
+
+ We add a few extras on top:
+
+ * ``GET /healthz`` - the Day-0 deployment-substrate liveness probe
+   (returns Stim/PyMatching/openenv versions). Used by the recurring
+   4-hour HF Spaces wakeup ping.
+ * ``POST /decode`` - PyMatching baseline demo: takes a hand-crafted
+   syndrome and returns the matching-decoder's prediction. Useful for
+   the Gradio playground.
+
+ Run with ``python -m qubit_medic.server.app`` or
+ ``uvicorn qubit_medic.server.app:app --host 0.0.0.0 --port 7860``.
+ """
+ from __future__ import annotations
+
+ import logging
+ import os
+ import sys
+ from typing import Optional
+
+ from fastapi import Body, HTTPException
+ from openenv.core import create_fastapi_app
+
+ from qubit_medic.config import DEFAULT_HOST, DEFAULT_PORT
+ from qubit_medic.server.environment import DecoderEnvironment
+ from qubit_medic.server.openenv_adapter import (
+     QubitMedicAction,
+     QubitMedicEnvironment,
+     QubitMedicObservation,
+ )
+
+
+ logger = logging.getLogger("qubit_medic.server")
+ logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
+
+
+ # --------------------------------------------------------------------------- #
+ # Build the OpenEnv-compliant FastAPI app                                     #
+ # --------------------------------------------------------------------------- #
+
+ app = create_fastapi_app(
+     env=QubitMedicEnvironment,
+     action_cls=QubitMedicAction,
+     observation_cls=QubitMedicObservation,
+ )
+ app.title = "Qubit-Medic OpenEnv"
+ app.version = os.getenv("QUBIT_MEDIC_VERSION", "1.0.0")
+ app.description = (
+     "RL training environment for LLM-based quantum error-correction "
+     "decoders. Built on Stim + PyMatching with five independent verifiable "
+     "rewards (logical correction, syndrome consistency, Hamming overlap, "
+     "format compliance, PyMatching beat-rate). Wraps "
+     "qubit_medic.server.environment.DecoderEnvironment in "
+     "openenv.core.Environment - see /metadata, /schema, /docs."
+ )
+
+
+ # --------------------------------------------------------------------------- #
+ # Day-0 + demo extras                                                         #
+ # --------------------------------------------------------------------------- #
+
+ # Lazy-built *legacy* DecoderEnvironment for /decode demos. The OpenEnv
+ # adapter has its own per-instance DecoderEnvironment; we keep this one
+ # around for the simple synchronous `/decode` baseline endpoint.
+ _legacy_env: Optional[DecoderEnvironment] = None
+
+
+ def _get_legacy_env() -> DecoderEnvironment:
+     global _legacy_env
+     if _legacy_env is None:
+         _legacy_env = DecoderEnvironment()
+         _legacy_env._cache_for("L1_warmup")  # noqa: SLF001
+         _legacy_env._cache_for("L2_target")  # noqa: SLF001
+     return _legacy_env
+
+
+ @app.get("/healthz")
+ def healthz() -> dict:
+     """Lightweight liveness probe (Day-0 deployment-substrate test).
+
+     Returns the versions of Stim, PyMatching, and openenv so curl-ing
+     this in a browser or from Colab proves both that networking works
+     AND that the heavy quantum + RL deps actually loaded. Used by the
+     recurring 4-hour HF Spaces wakeup ping.
+     """
+     import stim
+     try:
+         import pymatching as _pm
+         pm_v = getattr(_pm, "__version__", "unknown")
+     except Exception as exc:  # pragma: no cover - defensive
+         pm_v = f"import-error: {exc}"
+     try:
+         import openenv as _oe
+         oe_v = getattr(_oe, "__version__", "unknown")
+     except Exception as exc:  # pragma: no cover - defensive
+         oe_v = f"import-error: {exc}"
+     return {
+         "ok": True,
+         "service": "qubit-medic",
+         "version": app.version,
+         "stim_version": stim.__version__,
+         "pymatching_version": pm_v,
+         "openenv_version": oe_v,
+         "python_version": sys.version.split()[0],
+     }
+
+
+ @app.post("/decode")
+ def decode(
+     syndrome: list[int] = Body(..., embed=True),
+     level: str = Body("L2_target", embed=True),
+ ) -> dict:
+     """Decode an arbitrary syndrome with PyMatching (baseline) and return
+     its predicted Pauli frame and observable flip.
+
+     Intended for the live Gradio demo: a notebook or web page can POST a
+     hand-crafted syndrome here and visualise the matching-decoder result.
+     """
+     import numpy as np
+
+     env = _get_legacy_env()
+     cache = env._cache_for(level)  # noqa: SLF001
+     arr = np.asarray(syndrome, dtype=np.uint8)
+     if arr.shape[0] != cache.layout.num_detectors:
+         raise HTTPException(
+             status_code=400,
+             detail=f"syndrome length {arr.shape[0]} != "
+             f"{cache.layout.num_detectors} expected for {level}",
+         )
+     from qubit_medic.server.physics import (
+         predicted_observable_flip,
+         pymatching_predicted_pauli_frame,
+     )
+     pm_obs = int(cache.matching.decode(arr)[0])
+     px, pz = pymatching_predicted_pauli_frame(cache.matching, arr, cache.layout)
+     return {
+         "level": level,
+         "syndrome": syndrome,
+         "pymatching_observable_flip": pm_obs,
+         "pymatching_x_errors": px,
+         "pymatching_z_errors": pz,
+         "implied_observable_from_x_errors": predicted_observable_flip(
+             px, cache.layout
+         ),
+     }
+
+
+ # --------------------------------------------------------------------------- #
+ # Local entry point                                                           #
+ # --------------------------------------------------------------------------- #
+
+ def _main() -> None:
+     import uvicorn
+
+     uvicorn.run(
+         "qubit_medic.server.app:app",
+         host=os.getenv("QUBIT_MEDIC_HOST", DEFAULT_HOST),
+         port=int(os.getenv("QUBIT_MEDIC_PORT", str(DEFAULT_PORT))),
+         log_level=os.getenv("LOG_LEVEL", "info").lower(),
+     )
+
+
+ if __name__ == "__main__":
+     _main()
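Because `/decode` declares its parameters with `Body(..., embed=True)`, FastAPI expects one JSON object whose top-level keys are the parameter names. A sketch of building such a payload client-side; the detector count here is a placeholder (the real value must equal `cache.layout.num_detectors` for the chosen level, or the endpoint returns a 400):

```python
import json

# Request body shape implied by Body(..., embed=True) in /decode above:
# each parameter becomes a top-level key of one JSON object.
num_detectors = 24  # placeholder; must match the level's detector count
payload = {
    "syndrome": [0] * num_detectors,
    "level": "L2_target",
}
body = json.dumps(payload)

# e.g. POST this to https://<space-host>/decode with
# Content-Type: application/json
decoded = json.loads(body)
assert set(decoded) == {"syndrome", "level"}
```

Without `embed=True` a single body parameter would instead be sent as the bare JSON value; embedding keeps the all-zero-syndrome example above self-describing.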
qubit_medic/server/curriculum.py ADDED
@@ -0,0 +1,104 @@
+ """Adaptive curriculum scheduler (Section 4.4 of the plan).
+
+ Maintains a moving-average logical-correction rate per level and promotes
+ the agent to harder levels once the threshold is met. Implements the
+ Section 4.4 mixing rules:
+
+ * Stay at L1 until L1 hits 80%.
+ * Then mix L1/L2 with weights 30/70 until L2 hits 70%.
+ * Then unlock L3 at 30% weight (with L1/L2 sharing the remaining 70%).
+
+ The scheduler is *override-able* - eval scripts pass ``forced_level`` to
+ hold one configuration steady.
+ """
+ from __future__ import annotations
+
+ import random
+ from collections import deque
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ from qubit_medic.config import CURRICULUM, CurriculumLevel, level_by_name
+
+
+ # --------------------------------------------------------------------------- #
+ # Per-level moving average                                                    #
+ # --------------------------------------------------------------------------- #
+
+
+ @dataclass
+ class _MovingWindow:
+     window_size: int = 100
+     history: deque[float] = field(default_factory=deque)
+
+     def push(self, value: float) -> None:
+         self.history.append(value)
+         while len(self.history) > self.window_size:
+             self.history.popleft()
+
+     def mean(self) -> float:
+         return sum(self.history) / len(self.history) if self.history else 0.0
+
+     def __len__(self) -> int:
+         return len(self.history)
+
+
+ # --------------------------------------------------------------------------- #
+ # Scheduler                                                                   #
+ # --------------------------------------------------------------------------- #
+
+
+ @dataclass
+ class CurriculumScheduler:
+     """Picks a curriculum level for each new episode."""
+
+     rng: random.Random = field(default_factory=lambda: random.Random(42))
+     windows: dict[str, _MovingWindow] = field(default_factory=dict)
+
+     def __post_init__(self) -> None:
+         for lvl in CURRICULUM:
+             self.windows.setdefault(lvl.name, _MovingWindow())
+
+     # ----- public API -----------------------------------------------------
+
+     def update(self, level_name: str, logical_correction: float) -> None:
+         """Record one episode's logical-correction outcome."""
+         self.windows[level_name].push(float(logical_correction))
+
+     def sample(self, forced_level: Optional[str] = None) -> CurriculumLevel:
+         """Return the level to use for the next episode."""
+         if forced_level is not None:
+             return level_by_name(forced_level)
+
+         l1, l2, l3 = (level_by_name(n) for n in ("L1_warmup", "L2_target", "L3_stretch"))
+         l1_rate = self.windows["L1_warmup"].mean()
+         l2_rate = self.windows["L2_target"].mean()
+         l1_n = len(self.windows["L1_warmup"])
+         l2_n = len(self.windows["L2_target"])
+
+         # Phase A: still working on L1.
+         if l1_n < 30 or l1_rate < l1.promotion_threshold:
+             return l1
+
+         # Phase B: L1 unlocked, mixing L1 (30%) and L2 (70%).
+         if l2_n < 30 or l2_rate < l2.promotion_threshold:
+             return l1 if self.rng.random() < 0.30 else l2
+
+         # Phase C: L3 unlocked, splits 20% L1, 50% L2, 30% L3.
+         roll = self.rng.random()
+         if roll < 0.20:
+             return l1
+         if roll < 0.70:
+             return l2
+         return l3
+
+     # ----- introspection (used by /state endpoint and logs) ---------------
+
+     def stats(self) -> dict[str, dict[str, float]]:
+         return {
+             name: {
+                 "moving_mean": w.mean(),
+                 "samples": float(len(w)),
+             }
+             for name, w in self.windows.items()
+         }
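The Phase C threshold arithmetic above (`roll < 0.20`, `roll < 0.70`, else L3) is easy to misread as cumulative weights, so it is worth checking empirically. A standalone re-creation of just that roll logic, confirming the documented 20/50/30 split:

```python
import random

# Standalone copy of the Phase C roll from CurriculumScheduler.sample above.
def phase_c(roll: float) -> str:
    if roll < 0.20:
        return "L1_warmup"
    if roll < 0.70:
        return "L2_target"
    return "L3_stretch"

rng = random.Random(42)
counts = {"L1_warmup": 0, "L2_target": 0, "L3_stretch": 0}
n = 100_000
for _ in range(n):
    counts[phase_c(rng.random())] += 1

# Empirical fractions approach the documented 20/50/30 split.
assert abs(counts["L1_warmup"] / n - 0.20) < 0.01
assert abs(counts["L2_target"] / n - 0.50) < 0.01
assert abs(counts["L3_stretch"] / n - 0.30) < 0.01
```

The second guard uses `roll < 0.70` precisely because 0.20 + 0.50 = 0.70: each branch owns a disjoint interval of the unit line.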
qubit_medic/server/environment.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DecoderEnvironment: the OpenEnv-style env that the LLM trainer talks to.
2
+
3
+ This is the heart of the server (Sections 2.4 + 2.5 of the plan):
4
+
5
+ * ``reset()``: pick a curriculum level, build a circuit, sample a syndrome,
6
+ return a :class:`DecoderObservation`.
7
+ * ``step(raw_response)``: parse the LLM's text, score five rewards, return
8
+ a :class:`StepResult` whose ``info`` dict carries the per-component
9
+ breakdown.
10
+
11
+ Episodes are single-step (Section 2.5): the LLM emits one prediction and
12
+ the episode ends.
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import threading
17
+ import time
18
+ from dataclasses import dataclass, field
19
+ from typing import Optional
20
+
21
+ import pymatching
22
+
23
+ from qubit_medic.config import (
24
+ EPISODE_TIMEOUT_SECONDS,
25
+ PRIMARY_SEED,
26
+ REWARD_WEIGHTS,
27
+ )
28
+ from qubit_medic.models import (
29
+ DecoderAction,
30
+ DecoderObservation,
31
+ DecoderState,
32
+ StepResult,
33
+ )
34
+ from qubit_medic.prompts import build_prompt, parse_action
35
+ from qubit_medic.server import physics
36
+ from qubit_medic.server.curriculum import CurriculumScheduler
37
+ from qubit_medic.server.physics import (
38
+ CircuitLayout,
39
+ SyndromeSample,
40
+ build_circuit,
41
+ build_dem,
42
+ dem_digest,
43
+ extract_layout,
44
+ per_round_x_z_counts,
45
+ sample_episode,
46
+ )
47
+ from qubit_medic.server.rewards import (
48
+ RewardBreakdown,
49
+ compute_all_rewards,
50
+ compute_final_detector_supports,
51
+ )
52
+
53
+
54
+ # --------------------------------------------------------------------------- #
55
+ # Per-level cached compilation - building Stim/PyMatching is the slow step #
56
+ # --------------------------------------------------------------------------- #
57
+
58
+
59
+ @dataclass
60
+ class _LevelCache:
61
+ """Compiled Stim/PyMatching artefacts for one curriculum level."""
62
+ circuit: object
63
+ dem: object
64
+ matching: pymatching.Matching
65
+ layout: CircuitLayout
66
+ final_detector_supports: dict
67
+ dem_digest: str
68
+
69
+ @classmethod
70
+ def build(cls, level) -> "_LevelCache":
71
+ c = build_circuit(level)
72
+ d = build_dem(c)
73
+ m = pymatching.Matching.from_detector_error_model(d)
74
+ layout = extract_layout(c)
75
+ supports = compute_final_detector_supports(layout)
76
+ return cls(
77
+ circuit=c,
78
+ dem=d,
79
+ matching=m,
80
+ layout=layout,
81
+ final_detector_supports=supports,
82
+ dem_digest=dem_digest(d),
83
+ )
84
+
85
+
86
+ # --------------------------------------------------------------------------- #
87
+ # DecoderEnvironment #
88
+ # --------------------------------------------------------------------------- #
89
+
90
+
91
+ @dataclass
92
+ class _ActiveEpisode:
93
+ """In-flight episode bookkeeping."""
94
+ state: DecoderState
95
+ sample: SyndromeSample
96
+ layout: CircuitLayout
97
+ final_detector_supports: dict
98
+ started_at: float
99
+
100
+
101
+ class DecoderEnvironment:
102
+ """OpenEnv-style env for surface-code decoding.
103
+
104
+ Thread-safe by virtue of a single ``_lock``: the FastAPI server is
105
+ expected to be I/O bound, and per-call latency is well under a
106
+ millisecond, so a coarse lock is fine and dramatically simplifies the
107
+ state machine.
108
+ """
109
+
110
+ def __init__(self, *, base_seed: int = PRIMARY_SEED) -> None:
111
+ self._lock = threading.Lock()
112
+ self._scheduler = CurriculumScheduler(rng=__import__("random").Random(base_seed))
113
+ self._caches: dict[str, _LevelCache] = {}
114
+ self._episode_counter = 0
115
+ self._base_seed = base_seed
116
+ self._active: dict[int, _ActiveEpisode] = {}
117
+
118
+ # ----- cache helpers --------------------------------------------------
119
+
120
+ def _cache_for(self, level_name: str):
121
+ cache = self._caches.get(level_name)
122
+ if cache is not None:
123
+ return cache
124
+ from qubit_medic.config import level_by_name
125
+ cache = _LevelCache.build(level_by_name(level_name))
126
+ self._caches[level_name] = cache
127
+ return cache
128
+
129
+ # ----- public API -----------------------------------------------------
130
+
131
+ def reset(
132
+ self,
133
+ *,
134
+ seed: Optional[int] = None,
135
+ forced_level: Optional[str] = None,
136
+ ) -> DecoderObservation:
137
+ with self._lock:
138
+ self._episode_counter += 1
139
+ ep_id = self._episode_counter
140
+ shot_seed = seed if seed is not None else self._base_seed + ep_id
141
+ level = self._scheduler.sample(forced_level=forced_level)
142
+ cache = self._cache_for(level.name)
143
+
144
+ sample = sample_episode(
145
+ circuit=cache.circuit,
146
+ matching=cache.matching,
147
+ layout=cache.layout,
148
+ seed=shot_seed,
149
+ )
150
+
151
+ state = DecoderState(
152
+ episode_id=ep_id,
153
+ seed=shot_seed,
154
+ curriculum_level=level.name,
155
+ distance=level.distance,
156
+ rounds=level.rounds,
157
+ p=level.p,
158
+ syndrome_bits=sample.syndrome_bits,
159
+ true_x_errors=sample.pymatching_x_errors,
160
+ true_z_errors=sample.pymatching_z_errors,
161
+ actual_observable_flip=sample.actual_observable_flip,
162
+ pymatching_observable_pred=sample.pymatching_observable_pred,
163
+ x_observable_support=[], # memory_z task: no X observable
164
+ z_observable_support=list(cache.layout.z_observable_support),
165
+ num_data_qubits=cache.layout.num_data_qubits,
166
+ num_stabilizers=cache.layout.num_ancilla_qubits,
167
+ circuit_text=str(cache.circuit),
168
+ dem_text=str(cache.dem),
169
+ )
170
+ self._active[ep_id] = _ActiveEpisode(
171
+ state=state,
172
+ sample=sample,
173
+ layout=cache.layout,
174
+ final_detector_supports=cache.final_detector_supports,
175
+ started_at=time.monotonic(),
176
+ )
177
+
178
+ n_x, n_z = per_round_x_z_counts(cache.layout)
179
+ prompt = build_prompt(
180
+ distance=level.distance,
181
+ rounds=level.rounds,
182
+ p=level.p,
183
+ syndrome_bits=sample.syndrome_bits,
184
+ num_x_stabilizers=n_x,
185
+ num_z_stabilizers=n_z,
186
+ num_data_qubits=cache.layout.num_data_qubits,
187
+ )
188
+
189
+ return DecoderObservation(
190
+ prompt=prompt,
191
+ syndrome_bits=sample.syndrome_bits,
192
+ distance=level.distance,
193
+ rounds=level.rounds,
194
+ p=level.p,
195
+ curriculum_level=level.name,
196
+ episode_id=ep_id,
197
+ dem_digest=cache.dem_digest,
198
+ )
199
+
200
+ def step(self, raw_response: str, episode_id: int) -> StepResult:
201
+ with self._lock:
202
+ episode = self._active.pop(episode_id, None)
203
+ if episode is None:
204
+ # Calling step() on an unknown episode ID is a hard error -
+ # the trainer didn't follow reset/step pairing.
+ raise KeyError(f"unknown or already-finished episode {episode_id}")
+
+ elapsed = time.monotonic() - episode.started_at
+ timed_out = elapsed > EPISODE_TIMEOUT_SECONDS
+
+ parsed = parse_action(
+ raw_response=raw_response,
+ num_data_qubits=episode.layout.num_data_qubits,
+ )
+
+ if timed_out:
+ # Hard timeout: zero reward, mark format compliance as zero,
+ # close the episode cleanly (Section 2.6).
+ breakdown = RewardBreakdown(
+ logical_correction=0.0,
+ syndrome_consistency=0.0,
+ hamming_overlap=0.0,
+ format_compliance=0.0,
+ pymatching_beat=0.0,
+ total=0.0,
+ )
+ action = DecoderAction(
+ raw_response=raw_response,
+ parse_success=False,
+ )
+ else:
+ # Convert LLM-space qubit IDs (0..N-1) to Stim IDs before
+ # scoring; rewards operate in the Stim coordinate system.
+ from qubit_medic.prompts import ParseResult
+ parsed_stim = ParseResult(
+ x_errors=episode.layout.llm_to_stim(parsed.x_errors),
+ z_errors=episode.layout.llm_to_stim(parsed.z_errors),
+ parse_success=parsed.parse_success,
+ parse_partial=parsed.parse_partial,
+ raw_response=parsed.raw_response,
+ )
+ breakdown = compute_all_rewards(
+ parsed=parsed_stim,
+ sample=episode.sample,
+ layout=episode.layout,
+ final_detector_supports=episode.final_detector_supports,
+ weights=REWARD_WEIGHTS,
+ )
+ action = DecoderAction(
+ x_error_qubits=parsed.x_errors,
+ z_error_qubits=parsed.z_errors,
+ raw_response=raw_response,
+ parse_success=parsed.parse_success,
+ )
+
+ self._scheduler.update(
+ episode.state.curriculum_level,
+ logical_correction=breakdown.logical_correction,
+ )
+
+ episode.state.last_reward_breakdown = breakdown.as_dict()
+
+ n_x, n_z = per_round_x_z_counts(episode.layout)
+ prompt = build_prompt(
+ distance=episode.state.distance,
+ rounds=episode.state.rounds,
+ p=episode.state.p,
+ syndrome_bits=episode.state.syndrome_bits,
+ num_x_stabilizers=n_x,
+ num_z_stabilizers=n_z,
+ num_data_qubits=episode.layout.num_data_qubits,
+ )
+ obs = DecoderObservation(
+ prompt=prompt,
+ syndrome_bits=episode.state.syndrome_bits,
+ distance=episode.state.distance,
+ rounds=episode.state.rounds,
+ p=episode.state.p,
+ curriculum_level=episode.state.curriculum_level,
+ episode_id=episode.state.episode_id,
+ dem_digest=episode.state.dem_text[:8],
+ )
+
+ info = {
+ "rewards": breakdown.as_dict(),
+ "parsed_action": action.model_dump(),
+ "actual_observable_flip": episode.sample.actual_observable_flip,
+ "pymatching_observable_pred": episode.sample.pymatching_observable_pred,
+ "pymatching_x_errors": episode.sample.pymatching_x_errors,
+ "pymatching_z_errors": episode.sample.pymatching_z_errors,
+ "elapsed_seconds": elapsed,
+ "timed_out": timed_out,
+ "curriculum_stats": self._scheduler.stats(),
+ }
+
+ return StepResult(
+ observation=obs,
+ reward=breakdown.total,
+ done=True, # single-step episodes
+ truncated=timed_out,
+ info=info,
+ )
+
+ # ----- introspection --------------------------------------------------
+
+ def health(self) -> dict:
+ with self._lock:
+ return {
+ "ok": True,
+ "episodes_started": self._episode_counter,
+ "active_episodes": len(self._active),
+ "curriculum": self._scheduler.stats(),
+ "cached_levels": list(self._caches.keys()),
+ }
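The step() path above enforces a strict one-reset-one-step contract: reset() opens an episode, exactly one step() scores it and returns done=True, and a second step() on the same ID raises KeyError. A minimal standalone sketch of that contract (the MiniEnv class is a hypothetical stand-in, not the real DecoderEnvironment):

```python
# Toy illustration of the reset()/step() pairing the server enforces.
class MiniEnv:
    def __init__(self) -> None:
        self._counter = 0
        self._active: dict[int, dict] = {}

    def reset(self) -> int:
        # Each reset opens a fresh episode and hands its id to the trainer.
        self._counter += 1
        self._active[self._counter] = {"open": True}
        return self._counter

    def step(self, episode_id: int) -> dict:
        if episode_id not in self._active:
            raise KeyError(f"unknown or already-finished episode {episode_id}")
        self._active.pop(episode_id)  # single-step: the episode closes here
        return {"done": True}


env = MiniEnv()
ep = env.reset()
assert env.step(ep)["done"] is True
try:
    env.step(ep)  # stepping the same episode twice is a hard error
    raise AssertionError("expected KeyError")
except KeyError:
    pass
```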
qubit_medic/server/openenv_adapter.py ADDED
@@ -0,0 +1,289 @@
+ """OpenEnv-compliant adapter around :class:`DecoderEnvironment`.
+
+ This wrapper satisfies the submission requirement *"Use OpenEnv (latest
+ release). Build on top of the framework; don't reinvent the wheel."* by
+ exposing our underlying :class:`qubit_medic.server.environment.DecoderEnvironment`
+ through the official ``openenv.core.Environment`` base class.
+
+ The adapter is intentionally thin: it just translates between OpenEnv's
+ ``Action`` / ``Observation`` / ``State`` Pydantic shapes and our internal
+ ``DecoderObservation`` / ``DecoderAction`` / ``StepResult``. All the
+ physics, reward scoring, curriculum, and episode bookkeeping continue to
+ live in :class:`DecoderEnvironment` - that code is *the* tested,
+ production path.
+
+ Usage
+ -----
+
+ The OpenEnv-compliant FastAPI app is created with::
+
+ from openenv.core import create_fastapi_app
+ from qubit_medic.server.openenv_adapter import (
+ QubitMedicEnvironment, QubitMedicAction, QubitMedicObservation,
+ )
+
+ app = create_fastapi_app(
+ env=QubitMedicEnvironment,
+ action_cls=QubitMedicAction,
+ observation_cls=QubitMedicObservation,
+ )
+
+ This registers the canonical OpenEnv routes:
+
+ * ``POST /reset`` - body ``{"seed": int?, "episode_id": str?}``
+ * ``POST /step`` - body ``{"action": {...QubitMedicAction...},
+ "timeout_s": float?, "request_id": str?}``
+ * ``GET /state`` - returns the current :class:`QubitMedicState`
+ * ``GET /health`` - liveness probe
+ * ``GET /schema`` - JSON Schema for the action/observation models
+ * ``GET /metadata`` - environment metadata
+ * ``POST /mcp`` - Model Context Protocol endpoint
+ * ``GET /docs`` - Swagger UI (auto-generated by FastAPI)
+
+ We additionally mount our own ``/healthz`` (Day-0 contract) and
+ ``/decode`` (PyMatching baseline demo) on the returned app from
+ ``qubit_medic.server.app``.
+ """
+ from __future__ import annotations
+
+ from typing import Any, Optional
+
+ from openenv.core import Action, Environment, Observation, State
+ from openenv.core.env_server.types import EnvironmentMetadata
+ from pydantic import ConfigDict, Field
+
+ from qubit_medic.server.environment import DecoderEnvironment
+
+
+ # --------------------------------------------------------------------------- #
+ # Process-wide singleton                                                      #
+ # --------------------------------------------------------------------------- #
+ # OpenEnv's HTTP server (simulation mode) instantiates a *fresh* Environment
+ # via the factory on every /reset and /step call. Our episode bookkeeping
+ # (the `_active` dict) lives inside DecoderEnvironment, so we route every
+ # QubitMedicEnvironment instance through the same DecoderEnvironment.
+ # This keeps reset() -> step() pairing intact across stateless HTTP calls
+ # while remaining fully compatible with OpenEnv's WebSocket session model
+ # (each WS session still gets its own QubitMedicEnvironment wrapper).
+
+ _INNER_SINGLETON: Optional[DecoderEnvironment] = None
+
+
+ def _get_shared_inner() -> DecoderEnvironment:
+ """Return the process-wide DecoderEnvironment, building it lazily."""
+ global _INNER_SINGLETON
+ if _INNER_SINGLETON is None:
+ env = DecoderEnvironment()
+ env._cache_for("L1_warmup")  # noqa: SLF001 - intentional pre-warm
+ env._cache_for("L2_target")  # noqa: SLF001
+ _INNER_SINGLETON = env
+ return _INNER_SINGLETON
+
+
+ # --------------------------------------------------------------------------- #
+ # OpenEnv-flavoured Action / Observation / State                              #
+ # --------------------------------------------------------------------------- #
+
+
+ class QubitMedicAction(Action):
+ """LLM-emitted action: the raw text the model generated.
+
+ The server parses this into ``x_error_qubits`` / ``z_error_qubits`` via
+ :func:`qubit_medic.prompts.parse_action`. We keep the wire format
+ *just the raw string* so the server retains full control over parsing
+ (and so the trainer's reward function can audit unparseable outputs).
+
+ The trainer is also free to populate ``parsed_x_errors`` /
+ ``parsed_z_errors`` directly when it wants to bypass the LLM (useful
+ for baseline policies and unit tests).
+ """
+
+ # Inherit Action.model_config (extra='forbid', validate_assignment=True).
+ raw_response: str = Field(
+ default="",
+ description="Raw LLM completion text. Server parses to x/z error lists.",
+ )
+ parsed_x_errors: Optional[list[int]] = Field(
+ default=None,
+ description="Optional pre-parsed X-error qubit ids (LLM-space). "
+ "When provided, the server skips text parsing.",
+ )
+ parsed_z_errors: Optional[list[int]] = Field(
+ default=None,
+ description="Optional pre-parsed Z-error qubit ids (LLM-space).",
+ )
+ episode_id: Optional[int] = Field(
+ default=None,
+ description="Server-assigned episode id from the matching reset(). "
+ "If omitted, the most-recent active episode is used.",
+ )
+
+
+ class QubitMedicObservation(Observation):
+ """OpenEnv observation - mirrors :class:`DecoderObservation` plus the
+ standard OpenEnv ``done`` / ``reward`` fields.
+
+ The ``info`` dict (returned by ``step``) carries the per-component
+ reward breakdown, the ground-truth observable flip, and the PyMatching
+ baseline prediction so the trainer can score auxiliary metrics.
+ """
+
+ model_config = ConfigDict(extra="forbid", validate_assignment=True,
+ arbitrary_types_allowed=True)
+
+ prompt: str = Field(default="", description="Pre-formatted LLM prompt.")
+ syndrome_bits: list[int] = Field(default_factory=list,
+ description="Detector activations (0/1).")
+ distance: int = Field(default=0, description="Code distance for this episode.")
+ rounds: int = Field(default=0, description="Number of stabilizer rounds.")
+ p: float = Field(default=0.0, description="SI1000 base error rate.")
+ curriculum_level: str = Field(default="",
+ description="Curriculum level name.")
+ episode_id: int = Field(default=0,
+ description="Server-assigned episode counter.")
+ dem_digest: str = Field(default="",
+ description="Short hash of the detector error model.")
+ info: dict[str, Any] = Field(default_factory=dict,
+ description="Per-step extras (reward "
+ "breakdown, ground-truth flip, "
+ "PyMatching baseline, etc.).")
+
+
+ class QubitMedicState(State):
+ """Externally-visible state. We expose only the curriculum + episode
+ counters; physics-truth fields stay server-side to prevent reward
+ hacking (see :mod:`qubit_medic.models.DecoderState` doc-comment)."""
+
+ model_config = ConfigDict(extra="allow", validate_assignment=True,
+ arbitrary_types_allowed=True)
+
+ episodes_started: int = 0
+ active_episodes: int = 0
+ cached_levels: list[str] = Field(default_factory=list)
+ curriculum: dict[str, Any] = Field(default_factory=dict)
+ last_reward_breakdown: Optional[dict[str, float]] = None
+
+
+ # --------------------------------------------------------------------------- #
+ # Environment wrapper                                                         #
+ # --------------------------------------------------------------------------- #
+
+
+ class QubitMedicEnvironment(Environment[QubitMedicAction,
+ QubitMedicObservation,
+ QubitMedicState]):
+ """OpenEnv-compliant view of :class:`DecoderEnvironment`.
+
+ Single-step episodes (``done=True`` after every ``step``). The OpenEnv
+ HTTP server gets a fresh instance per WebSocket session if
+ ``SUPPORTS_CONCURRENT_SESSIONS=True``; we set it to ``False`` because
+ our DecoderEnvironment uses a single Stim cache + a coarse lock, which
+ is simpler than per-session state and good enough for the GRPO
+ training loop.
+ """
+
+ SUPPORTS_CONCURRENT_SESSIONS: bool = False
+
+ def __init__(self) -> None:
+ super().__init__()
+ # Share the underlying DecoderEnvironment across every wrapper
+ # instance the HTTP server creates - see _get_shared_inner.
+ self._inner = _get_shared_inner()
+ self._last_episode_id: Optional[int] = None
+ self._last_reward_breakdown: Optional[dict[str, float]] = None
+
+ # ----- abstract API --------------------------------------------------- #
+
+ def reset(
+ self,
+ seed: Optional[int] = None,
+ episode_id: Optional[str] = None,
+ **kwargs: Any,
+ ) -> QubitMedicObservation:
+ forced_level = kwargs.get("forced_level")
+ obs = self._inner.reset(seed=seed, forced_level=forced_level)
+ self._last_episode_id = obs.episode_id
+ self._last_reward_breakdown = None
+ return QubitMedicObservation(
+ prompt=obs.prompt,
+ syndrome_bits=list(obs.syndrome_bits),
+ distance=obs.distance,
+ rounds=obs.rounds,
+ p=obs.p,
+ curriculum_level=obs.curriculum_level,
+ episode_id=obs.episode_id,
+ dem_digest=obs.dem_digest,
+ done=False,
+ reward=None,
+ info={"event": "reset"},
+ )
+
+ def step(
+ self,
+ action: QubitMedicAction,
+ timeout_s: Optional[float] = None,
+ **kwargs: Any,
+ ) -> QubitMedicObservation:
+ ep = action.episode_id if action.episode_id is not None else self._last_episode_id
+ if ep is None:
+ raise RuntimeError(
+ "step() called before reset(); no active episode to score."
+ )
+
+ # If the trainer pre-parsed the action, format a synthetic raw
+ # response in the canonical "X: ... | Z: ..." shape so the server's
+ # parser produces the same x/z lists.
+ if action.parsed_x_errors is not None or action.parsed_z_errors is not None:
+ xs = action.parsed_x_errors or []
+ zs = action.parsed_z_errors or []
+ raw = f"<answer>X: {','.join(map(str, xs))} | Z: {','.join(map(str, zs))}</answer>"
+ else:
+ raw = action.raw_response
+
+ result = self._inner.step(raw_response=raw, episode_id=ep)
+ self._last_reward_breakdown = result.info.get("rewards")
+
+ return QubitMedicObservation(
+ prompt=result.observation.prompt,
+ syndrome_bits=list(result.observation.syndrome_bits),
+ distance=result.observation.distance,
+ rounds=result.observation.rounds,
+ p=result.observation.p,
+ curriculum_level=result.observation.curriculum_level,
+ episode_id=result.observation.episode_id,
+ dem_digest=result.observation.dem_digest,
+ done=result.done,
+ reward=float(result.reward),
+ info=result.info,
+ )
+
+ @property
+ def state(self) -> QubitMedicState:
+ h = self._inner.health()
+ return QubitMedicState(
+ episode_id=str(self._last_episode_id)
+ if self._last_episode_id is not None else None,
+ step_count=int(h.get("episodes_started", 0)),
+ episodes_started=int(h.get("episodes_started", 0)),
+ active_episodes=int(h.get("active_episodes", 0)),
+ cached_levels=list(h.get("cached_levels", [])),
+ curriculum=dict(h.get("curriculum", {})),
+ last_reward_breakdown=self._last_reward_breakdown,
+ )
+
+ # ----- nice-to-haves -------------------------------------------------- #
+
+ def get_metadata(self) -> EnvironmentMetadata:
+ return EnvironmentMetadata(
+ name="QubitMedicEnvironment",
+ description=(
+ "RL training environment for LLM-based quantum error-"
+ "correction decoders. Built on Stim + PyMatching. Five "
+ "verifiable rewards (logical correction, syndrome consistency, "
+ "Hamming overlap, format compliance, PyMatching beat-rate)."
+ ),
+ version="1.0.0",
+ )
+
+ def close(self) -> None:  # nothing to clean up
+ return None
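When the trainer supplies pre-parsed error lists, step() above synthesises a raw response in the canonical answer shape before handing it to the server's parser. A quick standalone check of that formatting (the helper name `format_answer` is ours, for illustration only; it mirrors the f-string in QubitMedicEnvironment.step()):

```python
def format_answer(xs: list[int], zs: list[int]) -> str:
    # Same "<answer>X: ... | Z: ...</answer>" shape the adapter builds
    # for pre-parsed actions, so the server-side parser sees one format.
    return f"<answer>X: {','.join(map(str, xs))} | Z: {','.join(map(str, zs))}</answer>"


assert format_answer([0, 4], [2]) == "<answer>X: 0,4 | Z: 2</answer>"
```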
qubit_medic/server/physics.py ADDED
@@ -0,0 +1,466 @@
+ """Stim + PyMatching wrapper - the physics engine (Section 2.4 of the plan).
2
+
3
+ This module never makes decoding decisions: it builds circuits, samples
4
+ syndromes, computes baselines, and exposes the observable's support on the
5
+ data qubits so the reward functions can score predictions deterministically.
6
+
7
+ Two design choices worth flagging up-front:
8
+
9
+ * The LLM's action is a **terminal Pauli frame** on data qubits (the X and Z
10
+ errors on each data qubit at the moment of final measurement). This
11
+ representation is exact for the rotated memory_z task and lets us reuse
12
+ Stim/PyMatching ground-truth machinery. The trade-off is documented in
13
+ ``rewards.py``: the syndrome-consistency reward (Reward 2) only constrains
14
+ the *final-round* detectors. Earlier rounds are silent w.r.t. an
15
+ end-of-circuit Pauli frame; that is intentional and made explicit in the
16
+ reward's docstring.
17
+
18
+ * "Ground-truth error pattern" for Reward 3 is taken to be the
19
+ **PyMatching-most-probable error pattern** explaining the syndrome
20
+ (extracted via ``Matching.decode_to_edges_array``). This is the
21
+ near-optimal canonical choice and matches what the AlphaQubit baseline
22
+ comparison uses. The README's *honesty note* repeats this.
23
+ """
24
+ from __future__ import annotations
25
+
26
+ import hashlib
27
+ from dataclasses import dataclass
28
+ from typing import Optional
29
+
30
+ import numpy as np
31
+ import pymatching
32
+ import stim
33
+
34
+ from qubit_medic.config import (
35
+ CODE_TASK,
36
+ CurriculumLevel,
37
+ SI1000Rates,
38
+ )
39
+
40
+
41
+ # --------------------------------------------------------------------------- #
42
+ # Circuit + DEM construction #
43
+ # --------------------------------------------------------------------------- #
44
+
45
+
46
+ def build_circuit(level: CurriculumLevel) -> stim.Circuit:
47
+ """Generate a Stim ``rotated_memory_z`` circuit at the given level."""
48
+ rates = SI1000Rates.from_p(level.p)
49
+ return stim.Circuit.generated(
50
+ CODE_TASK,
51
+ distance=level.distance,
52
+ rounds=level.rounds,
53
+ **rates.as_stim_kwargs(),
54
+ )
55
+
56
+
57
+ def build_dem(circuit: stim.Circuit) -> stim.DetectorErrorModel:
58
+ """Decompose-errors=True is mandatory for PyMatching."""
59
+ return circuit.detector_error_model(decompose_errors=True)
60
+
61
+
62
+ def dem_digest(dem: stim.DetectorErrorModel) -> str:
63
+ """8-char digest of the DEM, useful for grouping training logs."""
64
+ return hashlib.sha256(str(dem).encode("utf-8")).hexdigest()[:8]
65
+
66
+
67
+ # --------------------------------------------------------------------------- #
68
+ # Layout introspection - figure out data qubits, ancillas, observable support #
69
+ # --------------------------------------------------------------------------- #
70
+
71
+
72
+ @dataclass(frozen=True)
73
+ class CircuitLayout:
74
+ """Static facts about a circuit, computed once per episode.
75
+
76
+ Two indexings coexist:
77
+
78
+ * **Stim IDs** (``data_qubits``) are the physical qubit IDs Stim emits
79
+ (e.g. ``(1, 3, 5, 8, 10, 12, 15, 17, 19)`` for distance-3). These are
80
+ what Stim/PyMatching speak.
81
+ * **LLM IDs** are consecutive ``0..num_data_qubits-1``. These are what
82
+ the prompt advertises and what the LLM emits, because consecutive
83
+ small ints are dramatically easier for a language model to handle.
84
+
85
+ :meth:`llm_to_stim` and :meth:`stim_to_llm` perform the remap. *All*
86
+ server-internal scoring uses Stim IDs; the boundary at the prompt
87
+ formatter / parser converts.
88
+ """
89
+
90
+ data_qubits: tuple[int, ...]
91
+ """Stim IDs of data qubits (measured by terminal ``M``), sorted."""
92
+
93
+ data_qubit_coords: tuple[tuple[float, float], ...]
94
+ """(x, y) coordinate of each data qubit, in the same order as
95
+ ``data_qubits``. Used by Reward 3 to snap PyMatching edges to qubits."""
96
+
97
+ ancilla_qubits: tuple[int, ...]
98
+ """Physical qubit IDs that hold stabiliser measurements (``MR``)."""
99
+
100
+ z_observable_support: tuple[int, ...]
101
+ """Data qubits whose Z value is XOR'd into the logical Z observable.
102
+ An X error on any of these flips the observable."""
103
+
104
+ detector_round: tuple[int, ...]
105
+ """For each detector index, the round it nominally belongs to (0-based,
106
+ extracted from the ``DETECTOR(x, y, t)`` coordinate)."""
107
+
108
+ detector_coords: tuple[tuple[float, float], ...]
109
+ """(x, y) coordinate of each detector, used by Reward 3."""
110
+
111
+ detector_is_x_type: tuple[bool, ...]
112
+ """Whether the detector watches an X-stabiliser. For the rotated surface
113
+ code Stim places X-stabilisers at coordinates with ``(x + y) mod 4 == 2``
114
+ and Z-stabilisers at ``(x + y) mod 4 == 0`` (verified empirically against
115
+ Stim 1.15's ``surface_code:rotated_memory_z``)."""
116
+
117
+ final_detectors: tuple[int, ...]
118
+ """Indices of detectors that correspond to the *last* timeslice - those
119
+ are the only detectors a terminal Pauli frame can affect (Reward 2)."""
120
+
121
+ num_data_qubits: int
122
+ num_ancilla_qubits: int
123
+ num_detectors: int
124
+ num_observables: int
125
+
126
+ # ----- LLM <-> Stim qubit-ID remapping ---------------------------------
127
+
128
+ def llm_to_stim(self, llm_ids: list[int]) -> list[int]:
129
+ """Convert consecutive LLM IDs to physical Stim IDs.
130
+
131
+ Out-of-range IDs are silently dropped (the parser already enforces
132
+ the upper bound, but we double-check here as a defence-in-depth).
133
+ """
134
+ out: list[int] = []
135
+ n = len(self.data_qubits)
136
+ for i in llm_ids:
137
+ if 0 <= i < n:
138
+ out.append(self.data_qubits[i])
139
+ return out
140
+
141
+ def stim_to_llm(self, stim_ids: list[int]) -> list[int]:
142
+ """Inverse of :meth:`llm_to_stim` - used to render targets in the
143
+ SFT data and the imitator policy."""
144
+ lookup = {q: i for i, q in enumerate(self.data_qubits)}
145
+ return [lookup[q] for q in stim_ids if q in lookup]
146
+
147
+
148
+ def _walk_measurement_records(
149
+ circuit: stim.Circuit,
150
+ ) -> tuple[list[int], list[Optional[str]]]:
151
+ """Replay the circuit (no sampling) to map each measurement record to a
152
+ qubit. Returns parallel lists: qubits[i] = qubit id, instr[i] = gate."""
153
+ qubits: list[int] = []
154
+ instrs: list[Optional[str]] = []
155
+
156
+ def _walk(c: stim.Circuit, repeats: int = 1) -> None:
157
+ for _ in range(repeats):
158
+ for inst in c:
159
+ if isinstance(inst, stim.CircuitRepeatBlock):
160
+ _walk(inst.body_copy(), inst.repeat_count)
161
+ continue
162
+ name = inst.name
163
+ if name in {
164
+ "M", "MX", "MY", "MZ",
165
+ "MR", "MRX", "MRY", "MRZ",
166
+ "MPP",
167
+ }:
168
+ for t in inst.targets_copy():
169
+ if t.is_qubit_target:
170
+ qubits.append(t.qubit_value)
171
+ instrs.append(name)
172
+
173
+ _walk(circuit)
174
+ return qubits, instrs
175
+
176
+
177
+ def extract_layout(circuit: stim.Circuit) -> CircuitLayout:
178
+ """Walk the circuit once to build a full :class:`CircuitLayout`."""
179
+ flat = circuit.flattened()
180
+ measurement_qubits, measurement_instrs = _walk_measurement_records(circuit)
181
+
182
+ # Data qubits = those measured by terminal ``M`` (destructive, no reset).
183
+ data_qubits_in_order: list[int] = []
184
+ seen_data = set()
185
+ for q, instr in zip(measurement_qubits, measurement_instrs):
186
+ if instr == "M" and q not in seen_data:
187
+ data_qubits_in_order.append(q)
188
+ seen_data.add(q)
189
+ data_qubits = tuple(sorted(seen_data))
190
+
191
+ # Ancilla qubits = everything measured by MR (reset after measurement).
192
+ ancilla_qubits = tuple(
193
+ sorted({q for q, instr in zip(measurement_qubits, measurement_instrs)
194
+ if instr == "MR"})
195
+ )
196
+
197
+ # Observable support: walk OBSERVABLE_INCLUDE entries and resolve their
198
+ # rec[-k] back to qubit IDs via the measurement record table.
199
+ obs_support: dict[int, set[int]] = {}
200
+ for inst in flat:
201
+ if inst.name == "OBSERVABLE_INCLUDE":
202
+ args = inst.gate_args_copy()
203
+ obs_idx = int(args[0]) if args else 0
204
+ for t in inst.targets_copy():
205
+ if t.is_measurement_record_target:
206
+ actual = len(measurement_qubits) + t.value # value is negative
207
+ if 0 <= actual < len(measurement_qubits):
208
+ obs_support.setdefault(obs_idx, set()).add(
209
+ measurement_qubits[actual]
210
+ )
211
+ z_obs = tuple(sorted(obs_support.get(0, set())))
212
+
213
+ # Qubit coordinates from QUBIT_COORDS instructions.
214
+ qubit_coords: dict[int, tuple[float, float]] = {}
215
+ for inst in flat:
216
+ if inst.name == "QUBIT_COORDS":
217
+ args = inst.gate_args_copy()
218
+ x = float(args[0]) if len(args) >= 1 else 0.0
219
+ y = float(args[1]) if len(args) >= 2 else 0.0
220
+ for t in inst.targets_copy():
221
+ if t.is_qubit_target:
222
+ qubit_coords[t.qubit_value] = (x, y)
223
+ data_qubit_coords = tuple(qubit_coords.get(q, (0.0, 0.0)) for q in data_qubits)
224
+
225
+ # Detector coordinates - last value of the tuple is the round index.
226
+ det_coords_raw = circuit.get_detector_coordinates()
227
+ num_dets = circuit.num_detectors
228
+ rounds_per_det: list[int] = []
229
+ is_x_type: list[bool] = []
230
+ detector_coords: list[tuple[float, float]] = []
231
+ for i in range(num_dets):
232
+ c = det_coords_raw.get(i, ())
233
+ if not c:
234
+ rounds_per_det.append(0)
235
+ is_x_type.append(False)
236
+ detector_coords.append((0.0, 0.0))
237
+ continue
238
+ round_idx = int(c[-1]) if len(c) >= 3 else 0
239
+ rounds_per_det.append(round_idx)
240
+ x = float(c[0]) if len(c) >= 1 else 0.0
241
+ y = float(c[1]) if len(c) >= 2 else 0.0
242
+ detector_coords.append((x, y))
243
+ # X-stabilisers sit at (x + y) % 4 == 2 in Stim's generator.
244
+ is_x_type.append((int(x + y) % 4) == 2)
245
+
246
+ final_round = max(rounds_per_det) if rounds_per_det else 0
247
+ final_dets = tuple(i for i, r in enumerate(rounds_per_det) if r == final_round)
248
+
249
+ return CircuitLayout(
250
+ data_qubits=data_qubits,
251
+ data_qubit_coords=data_qubit_coords,
252
+ ancilla_qubits=ancilla_qubits,
253
+ z_observable_support=z_obs,
254
+ detector_round=tuple(rounds_per_det),
255
+ detector_coords=tuple(detector_coords),
256
+ detector_is_x_type=tuple(is_x_type),
257
+ final_detectors=final_dets,
258
+ num_data_qubits=len(data_qubits),
259
+ num_ancilla_qubits=len(ancilla_qubits),
260
+ num_detectors=num_dets,
261
+ num_observables=circuit.num_observables,
262
+ )
263
+
264
+
265
+ # --------------------------------------------------------------------------- #
266
+ # Sampling and decoding #
267
+ # --------------------------------------------------------------------------- #
268
+
269
+
270
+ @dataclass(frozen=True)
271
+ class SyndromeSample:
272
+ """One noisy episode: detector activations, ground-truth observable
273
+ flip, and PyMatching's prediction (used by Rewards 1 and 5)."""
274
+
275
+ syndrome_bits: list[int]
276
+ actual_observable_flip: int
277
+ pymatching_observable_pred: int
278
+ pymatching_x_errors: list[int] # Pauli frame at end of circuit (X part)
279
+ pymatching_z_errors: list[int] # Pauli frame at end of circuit (Z part)
280
+
281
+
282
+ def sample_episode(
283
+ circuit: stim.Circuit,
284
+ matching: pymatching.Matching,
285
+ layout: CircuitLayout,
286
+ seed: Optional[int] = None,
287
+ ) -> SyndromeSample:
288
+ """Sample one shot, decode it with PyMatching, and bundle the result."""
289
+ sampler = circuit.compile_detector_sampler(seed=seed)
290
+ detection, observables = sampler.sample(1, separate_observables=True)
291
+ detection_row = detection[0].astype(np.uint8)
292
+ observable_flip = int(observables[0, 0]) if observables.shape[1] else 0
293
+
294
+ # PyMatching's prediction (observable level).
295
+ pred_obs = int(matching.decode(detection_row)[0])
296
+
297
+ # PyMatching's predicted physical Pauli frame on data qubits.
298
+ pred_x, pred_z = pymatching_predicted_pauli_frame(
299
+ matching=matching, syndrome=detection_row, layout=layout,
300
+ )
301
+
302
+ return SyndromeSample(
303
+ syndrome_bits=detection_row.tolist(),
304
+ actual_observable_flip=observable_flip,
305
+ pymatching_observable_pred=pred_obs,
306
+ pymatching_x_errors=pred_x,
307
+ pymatching_z_errors=pred_z,
308
+ )
309
+
310
+
311
+ def pymatching_predicted_pauli_frame(
312
+ matching: pymatching.Matching,
313
+ syndrome: np.ndarray,
314
+ layout: CircuitLayout,
315
+ ) -> tuple[list[int], list[int]]:
316
+ """Convert PyMatching's per-edge prediction into a data-qubit Pauli frame.
317
+
318
+ The matching graph's edges correspond to error mechanisms in the DEM.
319
+ Each edge connects two detectors (or a detector and a boundary). The
320
+ data qubit responsible for the edge sits geometrically between the two
321
+ detectors on the surface-code grid - we recover it by snapping the
322
+ midpoint of the detector coordinates to the nearest data qubit.
323
+
324
+ This frame is used as ground-truth for Reward 3 (Hamming overlap).
325
+ Z-stabiliser endpoints (``(x+y) mod 4 == 0``) catch X errors on data
326
+ qubits; X-stabiliser endpoints catch Z errors. Boundary edges are
327
+ snapped to the unique data qubit adjacent to that boundary.
328
+ """
329
+ try:
330
+ edges = matching.decode_to_edges_array(syndrome)
331
+ except Exception:
332
+ return [], []
333
+
334
+ if edges is None or len(edges) == 0:
335
+ return [], []
336
+
337
+ data_qubits = layout.data_qubits
338
+ data_coords = layout.data_qubit_coords
339
+ det_coords = layout.detector_coords
340
+ det_is_x = layout.detector_is_x_type
341
+ n_dets = len(det_coords)
342
+
343
+ def _snap(x: float, y: float) -> int:
344
+ best_q = data_qubits[0]
345
+ best_d = float("inf")
346
+ for q, (qx, qy) in zip(data_qubits, data_coords):
347
+ d = (qx - x) ** 2 + (qy - y) ** 2
348
+ if d < best_d:
349
+ best_d = d
350
+ best_q = q
351
+ return best_q
352
+
353
+ x_errs: set[int] = set()
354
+ z_errs: set[int] = set()
355
+ for edge in edges:
356
+ a, b = int(edge[0]), int(edge[1])
357
+ ca = det_coords[a] if 0 <= a < n_dets else None
358
+ cb = det_coords[b] if 0 <= b < n_dets else None
359
+ if ca is None and cb is None:
360
+ continue
361
+ if cb is None:
362
+ mid_x, mid_y = ca
363
+ ref_is_x = det_is_x[a]
364
+ elif ca is None:
365
+ mid_x, mid_y = cb
366
+ ref_is_x = det_is_x[b]
367
+ else:
368
+ mid_x = (ca[0] + cb[0]) / 2.0
369
+ mid_y = (ca[1] + cb[1]) / 2.0
370
+ ref_is_x = det_is_x[a] if 0 <= a < n_dets else det_is_x[b]
371
+ snap = _snap(mid_x, mid_y)
372
+ if ref_is_x:
373
+ z_errs.add(snap)
374
+ else:
375
+ x_errs.add(snap)
376
+
377
+ return sorted(x_errs), sorted(z_errs)
378
+
379
+
380
+# --------------------------------------------------------------------------- #
+# Predicted-observable computation (used by Reward 1)                         #
+# --------------------------------------------------------------------------- #
+
+
+def predicted_observable_flip(
+    predicted_x_qubits: list[int],
+    layout: CircuitLayout,
+) -> int:
+    """Compute the implied logical-Z flip from a predicted Pauli frame.
+
+    Only X errors on data qubits in ``z_observable_support`` matter for the
+    Z observable - Z errors on data qubits commute with the destructive Z
+    measurement and so cannot flip the observable.
+    """
+    support = set(layout.z_observable_support)
+    parity = 0
+    for q in predicted_x_qubits:
+        if q in support:
+            parity ^= 1
+    return parity
+
+
+def rectify_pauli_frame_to_observable(
+    x_errors: list[int],
+    z_errors: list[int],
+    target_observable_flip: int,
+    layout: CircuitLayout,
+) -> tuple[list[int], list[int]]:
+    """Adjust a predicted X-error frame so its implied observable matches.
+
+    Used by the SFT data generator and the PyMatching imitator policy: the
+    snap-to-data-qubit edge mapping (:func:`pymatching_predicted_pauli_frame`)
+    is only ~95% faithful, but PyMatching's *observable* prediction is exact.
+    We patch the X frame by toggling the smallest-index data qubit on the
+    observable support whenever the implied parity disagrees with the
+    target. Z errors are untouched because they don't affect a Z observable.
+    """
+    implied = predicted_observable_flip(x_errors, layout)
+    if implied == target_observable_flip:
+        return list(x_errors), list(z_errors)
+
+    support = list(layout.z_observable_support)
+    if not support:
+        return list(x_errors), list(z_errors)
+
+    x_set = set(x_errors)
+    intersect = sorted(x_set & set(support))
+    if intersect:
+        # Remove the smallest one currently flipping the observable.
+        x_set.discard(intersect[0])
+    else:
+        # Add the smallest support qubit to introduce a flip.
+        x_set.add(support[0])
+    return sorted(x_set), list(z_errors)
+
+
+# --------------------------------------------------------------------------- #
+# Stabiliser counts - derived from layout                                     #
+# --------------------------------------------------------------------------- #
+
+
+def detector_round_split(layout: CircuitLayout, syndrome_bits: list[int]) -> dict[int, list[int]]:
+    """Group detector bits by their nominal round (used for prompt formatting)."""
+    out: dict[int, list[int]] = {}
+    for idx, bit in enumerate(syndrome_bits):
+        r = layout.detector_round[idx] if idx < len(layout.detector_round) else 0
+        out.setdefault(r, []).append(bit)
+    return out
+
+
+def per_round_x_z_counts(layout: CircuitLayout) -> tuple[int, int]:
+    """Best-effort count of X-type and Z-type stabiliser detectors per round.
+
+    For a rotated surface code at distance d there are (d^2 - 1)/2 of each
+    type. We compute that from the layout to be robust.
+    """
+    # Take one fully-populated round (the one with the most detectors).
+    round_counts: dict[int, list[bool]] = {}
+    for idx, r in enumerate(layout.detector_round):
+        round_counts.setdefault(r, []).append(layout.detector_is_x_type[idx])
+    if not round_counts:
+        return 0, 0
+    full_round = max(round_counts.values(), key=len)
+    n_x = sum(1 for v in full_round if v)
+    n_z = sum(1 for v in full_round if not v)
+    return n_x, n_z
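The parity rule behind ``predicted_observable_flip`` is easy to exercise in isolation. The sketch below uses a hypothetical ``StubLayout`` stand-in carrying only the ``z_observable_support`` attribute (the real ``CircuitLayout`` carries much more), but the XOR-parity logic mirrors the function above:

```python
from dataclasses import dataclass, field


@dataclass
class StubLayout:
    """Hypothetical stand-in for CircuitLayout: only the Z-observable support."""
    z_observable_support: list[int] = field(default_factory=list)


def predicted_observable_flip(predicted_x_qubits: list[int], layout: StubLayout) -> int:
    # XOR-parity of the predicted X errors restricted to the observable support;
    # X errors outside the support (and all Z errors) are ignored.
    support = set(layout.z_observable_support)
    parity = 0
    for q in predicted_x_qubits:
        if q in support:
            parity ^= 1
    return parity


layout = StubLayout(z_observable_support=[0, 2, 4])
print(predicted_observable_flip([0, 2, 7], layout))  # 0 - two support hits cancel
print(predicted_observable_flip([4, 9], layout))     # 1 - one support hit flips
```

An even number of X errors on the support cancels out, which is exactly why only the parity (not the exact frame) determines the logical outcome.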
qubit_medic/server/rewards.py ADDED
@@ -0,0 +1,264 @@
+"""The five reward functions (Section 3 of the plan).
+
+Design contract (from Section 3.6):
+
+* Each reward is a pure function ``(action, state, layout) -> float in [0, 1]``.
+* Rewards never observe each other - they're independent by construction, so
+  the LLM can't satisfy one at the expense of another without genuine task
+  understanding.
+* The combined reward is a weighted sum (weights in :mod:`qubit_medic.config`)
+  clamped to ``[0, 1]``.
+* Every per-component score is reported in the ``info`` dict so logs can
+  surface reward-hacking early (Section 3.7).
+
+A note on Reward 2 and Reward 3 ground truth - see ``physics.py``: the LLM
+predicts a *terminal Pauli frame*, which fully determines the logical-Z
+observable but only constrains the *final-round* detectors. Earlier rounds'
+detectors are intentionally unscored. Reward 3 compares against PyMatching's
+near-optimal Pauli-frame prediction (the canonical decoder reference used in
+AlphaQubit's Nature paper).
+"""
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from qubit_medic.config import REWARD_WEIGHTS
+from qubit_medic.prompts import ParseResult
+from qubit_medic.server.physics import (
+    CircuitLayout,
+    SyndromeSample,
+    predicted_observable_flip,
+)
+
+
+# --------------------------------------------------------------------------- #
+# Reward 1: logical correction success                                        #
+# --------------------------------------------------------------------------- #
+
+
+def reward_logical_correction(
+    parsed: ParseResult,
+    sample: SyndromeSample,
+    layout: CircuitLayout,
+) -> float:
+    """Did the predicted correction preserve the logical state?
+
+    Apply the predicted X errors as a Pauli frame at end-of-circuit and
+    compute the implied observable flip. If this matches the actual
+    observable flip recorded by Stim, the logical state was preserved.
+    Outputs 1.0 if so, else 0.0.
+
+    This is the unfakeable reward - it depends only on Stim's ground truth.
+    """
+    implied = predicted_observable_flip(parsed.x_errors, layout)
+    return 1.0 if implied == sample.actual_observable_flip else 0.0
+
+
+# --------------------------------------------------------------------------- #
+# Reward 2: syndrome consistency                                              #
+# --------------------------------------------------------------------------- #
+
+
+def _syndrome_from_pauli_frame(
+    x_errors: list[int],
+    layout: CircuitLayout,
+    final_detector_supports: dict[int, frozenset[int]],
+) -> dict[int, int]:
+    """Compute the implied bits for FINAL-round detectors only.
+
+    A terminal X error on data qubit ``q`` flips a final-round Z-stabiliser
+    detector iff ``q`` is in that detector's support.
+    """
+    out: dict[int, int] = {}
+    x_set = set(x_errors)
+    for det_idx, support in final_detector_supports.items():
+        out[det_idx] = 1 if len(x_set & support) % 2 == 1 else 0
+    return out
+
+
+def reward_syndrome_consistency(
+    parsed: ParseResult,
+    sample: SyndromeSample,
+    layout: CircuitLayout,
+    final_detector_supports: dict[int, frozenset[int]],
+) -> float:
+    """How well does the predicted Pauli frame reproduce the FINAL detectors?
+
+    Computes Hamming similarity between ``predicted_final_bits`` (induced by
+    the predicted X errors) and ``observed_final_bits``. Returns
+    ``1 - hamming_distance / num_final_detectors``.
+
+    Rationale (Section 3.2): without this term, an LLM that guesses the
+    right qubits by luck could get Reward 1 occasionally; this signal forces
+    it to also explain the data the syndrome carries.
+    """
+    final_dets = layout.final_detectors
+    if not final_dets:
+        return 0.0
+    implied = _syndrome_from_pauli_frame(
+        parsed.x_errors, layout, final_detector_supports
+    )
+    distance = 0
+    for det_idx in final_dets:
+        observed = sample.syndrome_bits[det_idx]
+        predicted = implied.get(det_idx, 0)
+        if observed != predicted:
+            distance += 1
+    return 1.0 - distance / len(final_dets)
+
+
+def compute_final_detector_supports(
+    layout: CircuitLayout,
+    syndrome_bits_unused: list[int] | None = None,  # API symmetry
+    *,
+    detector_to_data_qubits: dict[int, frozenset[int]] | None = None,
+) -> dict[int, frozenset[int]]:
+    """Map each final-round detector to the set of data qubits whose
+    terminal X error flips it.
+
+    For the rotated memory_z code, each Z-stabiliser final detector watches
+    the four (or two/one on the boundary) data qubits adjacent to it on the
+    grid. We compute adjacency by Euclidean distance; data qubits at
+    distance ``sqrt(2)`` from a Z-stabiliser ancilla coordinate are
+    incident.
+    """
+    if detector_to_data_qubits is not None:
+        return detector_to_data_qubits
+
+    out: dict[int, frozenset[int]] = {}
+    for det_idx in layout.final_detectors:
+        dx, dy = layout.detector_coords[det_idx]
+        adj: set[int] = set()
+        for q, (qx, qy) in zip(layout.data_qubits, layout.data_qubit_coords):
+            if abs((qx - dx) ** 2 + (qy - dy) ** 2 - 2.0) < 1e-6:
+                adj.add(q)
+        out[det_idx] = frozenset(adj)
+    return out
+
+
+# --------------------------------------------------------------------------- #
+# Reward 3: Hamming overlap with reference Pauli frame                        #
+# --------------------------------------------------------------------------- #
+
+
+def _jaccard(a: list[int], b: list[int]) -> float:
+    """Jaccard index. Returns 1.0 when both sets are empty (perfect agreement)."""
+    sa, sb = set(a), set(b)
+    if not sa and not sb:
+        return 1.0
+    inter = len(sa & sb)
+    union = len(sa | sb)
+    return inter / union if union else 1.0
+
+
+def reward_hamming_overlap(
+    parsed: ParseResult,
+    sample: SyndromeSample,
+    layout: CircuitLayout,
+) -> float:
+    """Average of Jaccard(X) and Jaccard(Z) against the reference frame.
+
+    Reference is PyMatching's per-edge predicted Pauli frame
+    (``sample.pymatching_x_errors`` / ``..._z_errors``). This is the dense
+    partial-credit signal of Section 3.3 - even if Reward 1 scores zero,
+    being *close* to the canonical solution still earns credit, smoothing
+    the reward landscape during early training.
+    """
+    jx = _jaccard(parsed.x_errors, sample.pymatching_x_errors)
+    jz = _jaccard(parsed.z_errors, sample.pymatching_z_errors)
+    return 0.5 * (jx + jz)
+
+
+# --------------------------------------------------------------------------- #
+# Reward 4: format compliance                                                 #
+# --------------------------------------------------------------------------- #
+
+
+def reward_format_compliance(parsed: ParseResult) -> float:
+    """1.0 if both keys parsed, 0.5 if exactly one, 0.0 if neither."""
+    return parsed.format_score
+
+
+# --------------------------------------------------------------------------- #
+# Reward 5: PyMatching beat-rate bonus                                        #
+# --------------------------------------------------------------------------- #
+
+
+def reward_pymatching_beat(
+    parsed: ParseResult,
+    sample: SyndromeSample,
+    layout: CircuitLayout,
+) -> float:
+    """1.0 iff PyMatching got this syndrome wrong AND the LLM got it right.
+
+    This is the headline metric (Section 3.5). For most of training it will
+    be near zero; the trajectory of its mean over steps is the proof we've
+    moved past pure imitation.
+    """
+    pm_correct = sample.pymatching_observable_pred == sample.actual_observable_flip
+    if pm_correct:
+        return 0.0
+    llm_implied = predicted_observable_flip(parsed.x_errors, layout)
+    return 1.0 if llm_implied == sample.actual_observable_flip else 0.0
+
+
+# --------------------------------------------------------------------------- #
+# Combined reward                                                             #
+# --------------------------------------------------------------------------- #
+
+
+@dataclass(frozen=True)
+class RewardBreakdown:
+    """Per-component scores plus the weighted total."""
+
+    logical_correction: float
+    syndrome_consistency: float
+    hamming_overlap: float
+    format_compliance: float
+    pymatching_beat: float
+    total: float
+
+    def as_dict(self) -> dict[str, float]:
+        return {
+            "logical_correction": self.logical_correction,
+            "syndrome_consistency": self.syndrome_consistency,
+            "hamming_overlap": self.hamming_overlap,
+            "format_compliance": self.format_compliance,
+            "pymatching_beat": self.pymatching_beat,
+            "total": self.total,
+        }
+
+
+def compute_all_rewards(
+    parsed: ParseResult,
+    sample: SyndromeSample,
+    layout: CircuitLayout,
+    final_detector_supports: dict[int, frozenset[int]],
+    weights: dict[str, float] = REWARD_WEIGHTS,
+) -> RewardBreakdown:
+    """Compute all five rewards and the weighted total.
+
+    Returns a :class:`RewardBreakdown` whose ``as_dict`` is what the env's
+    ``info`` payload contains. The trainer logs each component separately.
+    """
+    r1 = reward_logical_correction(parsed, sample, layout)
+    r2 = reward_syndrome_consistency(parsed, sample, layout, final_detector_supports)
+    r3 = reward_hamming_overlap(parsed, sample, layout)
+    r4 = reward_format_compliance(parsed)
+    r5 = reward_pymatching_beat(parsed, sample, layout)
+    total = (
+        weights["logical_correction"] * r1
+        + weights["syndrome_consistency"] * r2
+        + weights["hamming_overlap"] * r3
+        + weights["format_compliance"] * r4
+        + weights["pymatching_beat"] * r5
+    )
+    total = max(0.0, min(1.0, total))
+    return RewardBreakdown(
+        logical_correction=r1,
+        syndrome_consistency=r2,
+        hamming_overlap=r3,
+        format_compliance=r4,
+        pymatching_beat=r5,
+        total=total,
+    )
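The partial-credit behaviour of the Jaccard overlap used by Reward 3 can be checked standalone. This sketch re-implements the same set-overlap metric as ``_jaccard`` (the names below are illustrative, not the module's private API):

```python
def jaccard(a: list[int], b: list[int]) -> float:
    """Jaccard index over qubit-index sets; empty-vs-empty counts as full agreement."""
    sa, sb = set(a), set(b)
    if not sa and not sb:
        # Both frames predict "no errors" - that's a perfect match, not 0/0.
        return 1.0
    union = len(sa | sb)
    return len(sa & sb) / union if union else 1.0


print(jaccard([], []))          # 1.0 - agreeing on "no errors" earns full credit
print(jaccard([1, 2], [2, 3]))  # |{2}| / |{1, 2, 3}| = 1/3
print(jaccard([5], [5]))        # 1.0 - exact match
```

Reward 3 then averages the X-frame and Z-frame scores, so a prediction that nails one Pauli type but misses the other still lands at 0.5 rather than 0, which is what smooths the early-training reward landscape.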
qubit_medic/wandb_utils.py ADDED
@@ -0,0 +1,482 @@
+"""Central Weights & Biases integration for Qubit-Medic.
+
+Design goals
+------------
+1. **Single source of truth** for the W&B project name, default tags, and
+   the ``config`` dump that every run logs. Trainers, eval scripts, and
+   notebooks all funnel through :func:`init_run` so dashboards always
+   line up.
+
+2. **Safe to import without wandb installed.** The package's training
+   deps (``wandb``) live in ``requirements-train.txt`` and are absent on
+   the lean Spaces image. Everything in this module degrades gracefully
+   when the import fails - the rest of the project keeps working.
+
+3. **Disable-by-env-var.** Set ``WANDB_DISABLED=1`` (or
+   ``QUBIT_MEDIC_WANDB=0``) and every helper here becomes a no-op,
+   regardless of whether the package is installed. CI runs and offline
+   testing rely on this.
+
+4. **Rich first-class logging.** We expose dedicated helpers for the
+   things this project cares about:
+
+   * Per-reward component scalars (5 lines per step, not just the total)
+   * Curriculum-level moving averages (one line per level)
+   * Parse success / partial / failure rates
+   * Generation sample tables (prompt / completion / per-reward)
+   * Eval summary tables (one row per (policy, level))
+
+   The trainers and eval script only have to call these helpers; the
+   Pythonic context manager handles init, summary, and finish.
+"""
+from __future__ import annotations
+
+import contextlib
+import dataclasses
+import os
+import socket
+import sys
+import time
+from typing import Any, Iterable, Mapping, Optional, Sequence
+
+from qubit_medic.config import (
+    CURRICULUM, GRPO_GEN_PER_PROMPT, GRPO_KL_COEF, GRPO_LR, GRPO_MAX_COMPLETION_LEN,
+    GRPO_MAX_PROMPT_LEN, GRPO_STEPS, LORA_ALPHA, LORA_R, LORA_TARGET_MODULES,
+    MODEL_ID, PRIMARY_SEED, REWARD_WEIGHTS, SFT_BATCH_SIZE, SFT_EPOCHS,
+    SFT_GRAD_ACCUM, SFT_LR, SFT_MAX_SEQ_LEN, WANDB_DEFAULT_TAGS, WANDB_ENTITY,
+    WANDB_LOG_GENERATIONS_EVERY, WANDB_PROJECT, WANDB_SAMPLE_GENERATIONS,
+)
+
+
+# --------------------------------------------------------------------------- #
+# Lazy import + on/off toggle                                                 #
+# --------------------------------------------------------------------------- #
+
+
+_WANDB_MODULE = None
+_RUN: Any = None
+
+
+def _import_wandb():
+    """Import wandb on first use. Returns ``None`` if it isn't installed."""
+    global _WANDB_MODULE
+    if _WANDB_MODULE is None:
+        try:
+            import wandb  # type: ignore[import-not-found]
+            _WANDB_MODULE = wandb
+        except ImportError:
+            _WANDB_MODULE = False  # sentinel: "tried, failed"
+    return _WANDB_MODULE if _WANDB_MODULE is not False else None
+
+
+def is_disabled() -> bool:
+    """Honours ``WANDB_DISABLED`` and ``QUBIT_MEDIC_WANDB=0``."""
+    if os.environ.get("WANDB_DISABLED", "").lower() in {"1", "true", "yes"}:
+        return True
+    if os.environ.get("QUBIT_MEDIC_WANDB", "1").lower() in {"0", "false", "no"}:
+        return True
+    return False
+
+
+def is_available() -> bool:
+    """``True`` iff wandb is importable AND not disabled by env var."""
+    return _import_wandb() is not None and not is_disabled()
+
+
+def get_run():
+    """Return the active W&B run object, or ``None`` if not initialised."""
+    return _RUN
+
+
+# --------------------------------------------------------------------------- #
+# Init / finish                                                               #
+# --------------------------------------------------------------------------- #
+
+
+def _system_metadata() -> dict:
+    """Static metadata that's helpful on the dashboard but isn't a hyperparam."""
+    info = {
+        "python_version": sys.version.split()[0],
+        "hostname": socket.gethostname(),
+        "argv": " ".join(sys.argv),
+        "pid": os.getpid(),
+    }
+    try:
+        import torch
+        info["torch_version"] = torch.__version__
+        info["cuda_available"] = bool(torch.cuda.is_available())
+        if torch.cuda.is_available():
+            info["cuda_device"] = torch.cuda.get_device_name(0)
+    except Exception:
+        pass
+    try:
+        import stim
+        info["stim_version"] = stim.__version__
+    except Exception:
+        pass
+    try:
+        import pymatching
+        info["pymatching_version"] = pymatching.__version__
+    except Exception:
+        pass
+    try:
+        import trl, transformers, peft
+        info["trl_version"] = trl.__version__
+        info["transformers_version"] = transformers.__version__
+        info["peft_version"] = peft.__version__
+    except Exception:
+        pass
+    return info
+
+
+def _build_default_config(extra: Optional[Mapping[str, Any]] = None) -> dict:
+    """The config every run logs - hyperparameters + reward weights + curriculum."""
+    cfg: dict[str, Any] = {
+        "model_id": MODEL_ID,
+        "primary_seed": PRIMARY_SEED,
+        "lora_r": LORA_R,
+        "lora_alpha": LORA_ALPHA,
+        "lora_target_modules": list(LORA_TARGET_MODULES),
+        "sft": {
+            "epochs": SFT_EPOCHS,
+            "batch_size": SFT_BATCH_SIZE,
+            "grad_accum": SFT_GRAD_ACCUM,
+            "lr": SFT_LR,
+            "max_seq_len": SFT_MAX_SEQ_LEN,
+        },
+        "grpo": {
+            "steps": GRPO_STEPS,
+            "gen_per_prompt": GRPO_GEN_PER_PROMPT,
+            "lr": GRPO_LR,
+            "kl_coef": GRPO_KL_COEF,
+            "max_prompt_len": GRPO_MAX_PROMPT_LEN,
+            "max_completion_len": GRPO_MAX_COMPLETION_LEN,
+        },
+        "reward_weights": dict(REWARD_WEIGHTS),
+        "curriculum": [
+            {
+                "name": lvl.name, "distance": lvl.distance, "rounds": lvl.rounds,
+                "p": lvl.p, "promotion_threshold": lvl.promotion_threshold,
+            }
+            for lvl in CURRICULUM
+        ],
+        "system": _system_metadata(),
+    }
+    if extra:
+        cfg.update(extra)
+    return cfg
+
+
+def init_run(
+    run_name: str,
+    job_type: str,
+    *,
+    tags: Optional[Sequence[str]] = None,
+    extra_config: Optional[Mapping[str, Any]] = None,
+    notes: Optional[str] = None,
+    group: Optional[str] = None,
+):
+    """Initialise (or no-op) a W&B run.
+
+    Parameters
+    ----------
+    run_name:
+        Human-readable run name (e.g. ``"sft-warmup-2026-04-25"``).
+    job_type:
+        One of ``"sft"``, ``"grpo"``, ``"eval"``, ``"format-test"``,
+        ``"baseline"``. Used to group runs on the dashboard.
+    tags:
+        Extra tags appended to :data:`qubit_medic.config.WANDB_DEFAULT_TAGS`.
+    extra_config:
+        Hyperparameters specific to this run (e.g. SFT epochs override).
+    notes:
+        Free-text notes shown on the dashboard.
+    group:
+        Optional W&B group (used to bundle SFT + GRPO + eval runs of the
+        same experiment).
+
+    Returns
+    -------
+    The wandb Run object, or ``None`` if W&B is unavailable / disabled.
+    """
+    global _RUN
+    wandb = _import_wandb()
+    if wandb is None or is_disabled():
+        if wandb is None:
+            print("[wandb] not installed; skipping init "
+                  "(install with `pip install wandb` to enable logging)",
+                  file=sys.stderr)
+        else:
+            print("[wandb] disabled by env var; skipping init", file=sys.stderr)
+        _RUN = None
+        return None
+
+    all_tags = list(WANDB_DEFAULT_TAGS) + list(tags or [])
+    cfg = _build_default_config(extra=extra_config)
+    cfg["job_type"] = job_type
+
+    _RUN = wandb.init(
+        project=WANDB_PROJECT,
+        entity=WANDB_ENTITY,
+        name=run_name,
+        job_type=job_type,
+        tags=all_tags,
+        config=cfg,
+        notes=notes,
+        group=group,
+        reinit=True,
+    )
+    print(f"[wandb] run live at {_RUN.url}", file=sys.stderr)
+    return _RUN
+
+
+def finish_run() -> None:
+    """Cleanly close the current W&B run, if any."""
+    global _RUN
+    wandb = _import_wandb()
+    if wandb is None or _RUN is None:
+        _RUN = None
+        return
+    try:
+        wandb.finish()
+    finally:
+        _RUN = None
+
+
+@contextlib.contextmanager
+def run_context(run_name: str, job_type: str, **kwargs):
+    """Context-manager wrapper around :func:`init_run` / :func:`finish_run`."""
+    init_run(run_name, job_type, **kwargs)
+    try:
+        yield get_run()
+    finally:
+        finish_run()
+
+
+# --------------------------------------------------------------------------- #
+# Generic logging helpers                                                     #
+# --------------------------------------------------------------------------- #
+
+
+def log(metrics: Mapping[str, Any], *, step: Optional[int] = None,
+        commit: bool = True) -> None:
+    """No-op-safe ``wandb.log`` wrapper."""
+    wandb = _import_wandb()
+    if wandb is None or _RUN is None:
+        return
+    try:
+        wandb.log(dict(metrics), step=step, commit=commit)
+    except Exception as exc:  # pragma: no cover - defensive
+        print(f"[wandb] log failed: {exc}", file=sys.stderr)
+
+
+def update_summary(values: Mapping[str, Any]) -> None:
+    """Write to ``run.summary`` (the run's headline numbers)."""
+    if _RUN is None:
+        return
+    try:
+        for k, v in values.items():
+            _RUN.summary[k] = v
+    except Exception as exc:  # pragma: no cover
+        print(f"[wandb] summary update failed: {exc}", file=sys.stderr)
+
+
+# --------------------------------------------------------------------------- #
+# Project-specific helpers                                                    #
+# --------------------------------------------------------------------------- #
+
+
+_REWARD_KEYS = (
+    "logical_correction",
+    "syndrome_consistency",
+    "hamming_overlap",
+    "format_compliance",
+    "pymatching_beat",
+    "total",
+)
+
+
+def log_reward_breakdown(
+    breakdowns: Sequence[Mapping[str, float]],
+    *,
+    step: Optional[int] = None,
+    prefix: str = "rl",
+) -> None:
+    """Log mean / std / min / max for each reward component (plus the total).
+
+    ``breakdowns`` is a list of dicts, one per generation in the most-recent
+    GRPO step (length = ``GRPO_GEN_PER_PROMPT * batch``). We log the mean
+    alongside the spread statistics so the dashboard has both signal and
+    noise.
+    """
+    if not breakdowns or _RUN is None:
+        return
+    out: dict[str, float] = {}
+    for k in _REWARD_KEYS:
+        vals = [float(b.get(k, 0.0)) for b in breakdowns]
+        n = max(1, len(vals))
+        mean = sum(vals) / n
+        var = sum((v - mean) ** 2 for v in vals) / n
+        out[f"{prefix}/reward/{k}_mean"] = mean
+        out[f"{prefix}/reward/{k}_std"] = var ** 0.5
+        out[f"{prefix}/reward/{k}_max"] = max(vals)
+        out[f"{prefix}/reward/{k}_min"] = min(vals)
+    log(out, step=step)
+
+
+def log_parse_stats(
+    parse_results: Iterable,  # iterable of qubit_medic.prompts.ParseResult
+    *,
+    step: Optional[int] = None,
+    prefix: str = "rl",
+) -> None:
+    """Log parse-success / partial / failure rates for the most-recent batch."""
+    if _RUN is None:
+        return
+    parse_results = list(parse_results)
+    n = max(1, len(parse_results))
+    success = sum(1 for r in parse_results if getattr(r, "parse_success", False))
+    partial = sum(1 for r in parse_results
+                  if not getattr(r, "parse_success", False)
+                  and getattr(r, "parse_partial", False))
+    log({
+        f"{prefix}/parse/success_rate": success / n,
+        f"{prefix}/parse/partial_rate": partial / n,
+        f"{prefix}/parse/failure_rate": (n - success - partial) / n,
+        f"{prefix}/parse/sample_count": n,
+    }, step=step)
+
+
+def log_curriculum(
+    curriculum_stats: Mapping[str, Mapping[str, float]],
+    *,
+    step: Optional[int] = None,
+    prefix: str = "rl",
+) -> None:
+    """Log the per-level moving averages from the env health endpoint.
+
+    ``curriculum_stats`` is what
+    :meth:`qubit_medic.server.curriculum.CurriculumScheduler.snapshot`
+    returns; one inner dict per level with keys ``moving_mean`` / ``samples``.
+    """
+    if _RUN is None or not curriculum_stats:
+        return
+    out: dict[str, float] = {}
+    for level_name, stats in curriculum_stats.items():
+        out[f"{prefix}/curriculum/{level_name}_mean"] = float(stats.get("moving_mean", 0.0))
+        out[f"{prefix}/curriculum/{level_name}_samples"] = float(stats.get("samples", 0.0))
+    log(out, step=step)
+
+
+def log_generation_table(
+    rows: Sequence[Mapping[str, Any]],
+    *,
+    step: Optional[int],
+    table_name: str = "rl/generations",
+    columns: Optional[Sequence[str]] = None,
+) -> None:
+    """Log a wandb.Table of (prompt, completion, reward, ...) rows.
+
+    Each row is a flat dict; the column set is the union of all keys (or
+    the explicit ``columns`` arg). Used to surface qualitative samples
+    in addition to the scalar curves.
+    """
+    wandb = _import_wandb()
+    if wandb is None or _RUN is None or not rows:
+        return
+    cols = list(columns) if columns is not None else sorted(
+        {k for row in rows for k in row.keys()}
+    )
+    try:
+        table = wandb.Table(columns=cols)
+        for row in rows:
+            table.add_data(*[row.get(c, None) for c in cols])
+        log({table_name: table}, step=step)
+    except Exception as exc:  # pragma: no cover
+        print(f"[wandb] table log failed: {exc}", file=sys.stderr)
+
+
+def log_eval_summary(
+    summary: Mapping[str, Any],
+    *,
+    step: Optional[int] = None,
+    prefix: str = "eval",
+) -> None:
+    """Log the dict produced by ``scripts/eval._summary`` as scalars."""
+    if _RUN is None:
+        return
+    out: dict[str, Any] = {}
+    for k, v in summary.items():
+        if isinstance(v, (int, float)):
+            out[f"{prefix}/{k}"] = v
+    log(out, step=step)
+    update_summary({f"{prefix}/{k}": v for k, v in summary.items()
+                    if isinstance(v, (int, float, str, bool))})
+
+
+def log_artifact(
+    path: str, *, name: str, artifact_type: str,
+    description: Optional[str] = None,
+) -> None:
+    """Save a file or directory as a W&B artifact."""
+    wandb = _import_wandb()
+    if wandb is None or _RUN is None:
+        return
+    try:
+        art = wandb.Artifact(name, type=artifact_type, description=description)
+        if os.path.isdir(path):
+            art.add_dir(path)
+        else:
+            art.add_file(path)
+        _RUN.log_artifact(art)
+    except Exception as exc:  # pragma: no cover
+        print(f"[wandb] artifact log failed: {exc}", file=sys.stderr)
+
+
+# --------------------------------------------------------------------------- #
+# CLI integration helpers                                                     #
+# --------------------------------------------------------------------------- #
+
+
+def derive_report_to(report_to: str) -> str:
+    """Translate the user-facing ``--report-to`` flag.
+
+    If the user passes ``"wandb"`` but wandb is unavailable, fall back to
+    ``"none"`` rather than crashing the trainer. Lets the same script run
+    on Colab (with wandb) and CI (without).
+    """
+    if report_to == "wandb" and not is_available():
+        print("[wandb] requested via --report-to but unavailable; falling back to 'none'",
+              file=sys.stderr)
+        return "none"
+    return report_to
+
+
+def make_run_name(prefix: str, suffix: Optional[str] = None) -> str:
+    """Build a default run name like ``sft-warmup-20260425-2105``."""
+    stamp = time.strftime("%Y%m%d-%H%M%S")
+    bits = [prefix, stamp]
+    if suffix:
+        bits.append(suffix)
+    return "-".join(bits)
+
+
+__all__ = [
+    "derive_report_to",
+    "finish_run",
+    "get_run",
+    "init_run",
+    "is_available",
+    "is_disabled",
+    "log",
+    "log_artifact",
+    "log_curriculum",
+    "log_eval_summary",
+    "log_generation_table",
+    "log_parse_stats",
+    "log_reward_breakdown",
+    "make_run_name",
+    "run_context",
+    "update_summary",
+    "WANDB_LOG_GENERATIONS_EVERY",
+    "WANDB_SAMPLE_GENERATIONS",
+]
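The disable-by-env-var contract (``WANDB_DISABLED`` wins outright; ``QUBIT_MEDIC_WANDB`` defaults to enabled) can be sketched standalone. Unlike the module's ``is_disabled``, this illustrative version takes the environment as an explicit dict so it is testable without touching ``os.environ``:

```python
def wandb_disabled(env: dict[str, str]) -> bool:
    # WANDB_DISABLED set to any truthy string disables logging outright ...
    if env.get("WANDB_DISABLED", "").lower() in {"1", "true", "yes"}:
        return True
    # ... while QUBIT_MEDIC_WANDB defaults to "1" (enabled) unless set falsy.
    if env.get("QUBIT_MEDIC_WANDB", "1").lower() in {"0", "false", "no"}:
        return True
    return False


print(wandb_disabled({}))                           # False - enabled by default
print(wandb_disabled({"WANDB_DISABLED": "true"}))   # True
print(wandb_disabled({"QUBIT_MEDIC_WANDB": "0"}))   # True
```

Having two variables means CI can export one global kill switch (``WANDB_DISABLED``) while per-project opt-out stays available without clobbering other tools that also read ``WANDB_DISABLED``.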
requirements.txt ADDED
@@ -0,0 +1,35 @@
+# Pin the versions called out in the plan (Section 1.1).
+# These versions are known compatible with Python 3.11 + CUDA 12.x.
+# DO NOT bump without re-running scripts/validate_env.py.
+
+# --- Quantum simulation ---
+stim>=1.13,<2.0
+pymatching>=2.2,<3.0
+
+# --- Environment / serving ---
+fastapi>=0.110
+uvicorn[standard]>=0.27
+pydantic>=2.5,<3.0
+httpx>=0.27
+numpy>=1.26,<2.1
+
+# --- Plotting (used by scripts/plot_results.py) ---
+matplotlib>=3.8
+pillow>=10
+
+# --- Test runner ---
+pytest>=8
+
+# --- OpenEnv (HuggingFace's RL-env framework). Required by the submission
+#     guidelines ("Use OpenEnv (latest release). Build on top of the
+#     framework; don't reinvent the wheel."). Our server wraps the
+#     DecoderEnvironment with `openenv.core.Environment`-compatible
+#     Action/Observation/State models so TRL can drive it via
+#     `environment_factory=` (see qubit_medic/server/openenv_adapter.py
+#     and scripts/train_grpo.py). Lightweight; safe in the Spaces image.
+openenv-core>=0.2.1
+
+# --- Training (omit on CPU-only deploy) ---
+# Heavy ML deps live in requirements-train.txt; the env server itself does
+# not import any of them. Keeping them out of the Docker image keeps the
+# Spaces upload under the free-tier image-size limit.