Spaces:
Sleeping
Sleeping
Real env: openenv-core wrapped DecoderEnvironment + /healthz + /decode
Browse files- Dockerfile +48 -12
- README.md +93 -18
- app.py +0 -78
- qubit_medic/__init__.py +15 -0
- qubit_medic/client/__init__.py +5 -0
- qubit_medic/client/client.py +132 -0
- qubit_medic/config.py +254 -0
- qubit_medic/models.py +143 -0
- qubit_medic/prompts.py +230 -0
- qubit_medic/server/__init__.py +5 -0
- qubit_medic/server/app.py +169 -0
- qubit_medic/server/curriculum.py +104 -0
- qubit_medic/server/environment.py +314 -0
- qubit_medic/server/openenv_adapter.py +289 -0
- qubit_medic/server/physics.py +466 -0
- qubit_medic/server/rewards.py +264 -0
- qubit_medic/wandb_utils.py +482 -0
- requirements.txt +35 -0
Dockerfile
CHANGED
|
@@ -1,25 +1,61 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
| 4 |
PIP_NO_CACHE_DIR=1 \
|
| 5 |
PIP_DISABLE_PIP_VERSION_CHECK=1
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
RUN useradd -m -u 1000 user
|
| 8 |
USER user
|
| 9 |
ENV PATH="/home/user/.local/bin:$PATH"
|
| 10 |
|
| 11 |
WORKDIR /app
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
"uvicorn[standard]>=0.27" \
|
| 20 |
-
"openenv-core>=0.2.1"
|
| 21 |
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
EXPOSE 7860
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Qubit-Medic OpenEnv server container.
|
| 2 |
+
#
|
| 3 |
+
# This image ships ONLY the env-server code:
|
| 4 |
+
# * stim + pymatching - quantum simulation + matching baseline
|
| 5 |
+
# * fastapi + uvicorn - HTTP transport
|
| 6 |
+
# * openenv-core - canonical OpenEnv contract (/reset, /step,
|
| 7 |
+
# /state, /health, /schema, /metadata, /mcp,
|
| 8 |
+
# /docs)
|
| 9 |
+
#
|
| 10 |
+
# Heavy ML training deps (torch, transformers, trl, unsloth) are
|
| 11 |
+
# deliberately NOT installed - they live in requirements-train.txt and
|
| 12 |
+
# are installed only by the Colab training notebook. Keeping the Spaces
|
| 13 |
+
# image lean avoids the ~10 GB CUDA wheel that would blow the free tier.
|
| 14 |
|
| 15 |
+
FROM python:3.11-slim AS base
|
| 16 |
+
|
| 17 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 18 |
+
PYTHONUNBUFFERED=1 \
|
| 19 |
PIP_NO_CACHE_DIR=1 \
|
| 20 |
PIP_DISABLE_PIP_VERSION_CHECK=1
|
| 21 |
|
| 22 |
+
# Stim and PyMatching ship manylinux wheels - no system C++ deps needed
|
| 23 |
+
# beyond libstdc++. We keep build-essential as a safety net for any
|
| 24 |
+
# unexpected source-fallback path on the build host.
|
| 25 |
+
RUN apt-get update \
|
| 26 |
+
&& apt-get install -y --no-install-recommends \
|
| 27 |
+
build-essential \
|
| 28 |
+
ca-certificates \
|
| 29 |
+
curl \
|
| 30 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 31 |
+
|
| 32 |
+
# HF Spaces best-practice: run as non-root user with UID 1000.
|
| 33 |
RUN useradd -m -u 1000 user
|
| 34 |
USER user
|
| 35 |
ENV PATH="/home/user/.local/bin:$PATH"
|
| 36 |
|
| 37 |
WORKDIR /app
|
| 38 |
|
| 39 |
+
COPY --chown=user requirements.txt /app/requirements.txt
|
| 40 |
+
RUN pip install --user --upgrade pip \
|
| 41 |
+
&& pip install --user -r /app/requirements.txt
|
| 42 |
+
|
| 43 |
+
COPY --chown=user qubit_medic /app/qubit_medic
|
| 44 |
+
COPY --chown=user README.md /app/README.md
|
|
|
|
|
|
|
| 45 |
|
| 46 |
+
# Pre-warm Stim/PyMatching caches at build time so the first request
|
| 47 |
+
# after `docker run` has near-zero latency (Section 9.1 of the plan).
|
| 48 |
+
RUN python -c "from qubit_medic.server.environment import DecoderEnvironment; \
|
| 49 |
+
e = DecoderEnvironment(); \
|
| 50 |
+
e._cache_for('L1_warmup'); \
|
| 51 |
+
e._cache_for('L2_target')"
|
| 52 |
|
| 53 |
EXPOSE 7860
|
| 54 |
+
|
| 55 |
+
ENV LOG_LEVEL=INFO \
|
| 56 |
+
QUBIT_MEDIC_HOST=0.0.0.0 \
|
| 57 |
+
QUBIT_MEDIC_PORT=7860
|
| 58 |
+
|
| 59 |
+
# Boots the FastAPI app (qubit_medic.server.app) which is built on top
|
| 60 |
+
# of openenv.core.create_fastapi_app.
|
| 61 |
+
CMD ["python", "-m", "qubit_medic.server.app"]
|
README.md
CHANGED
|
@@ -7,43 +7,118 @@ sdk: docker
|
|
| 7 |
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
license: mit
|
| 10 |
-
short_description:
|
| 11 |
tags:
|
| 12 |
- openenv
|
| 13 |
- reinforcement-learning
|
| 14 |
- quantum-error-correction
|
| 15 |
- stim
|
| 16 |
- pymatching
|
|
|
|
|
|
|
|
|
|
| 17 |
---
|
| 18 |
|
| 19 |
-
# QuantumScribe —
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
## Try it
|
| 26 |
|
| 27 |
-
|
| 28 |
-
* `GET /healthz` — liveness probe; returns `{"ok": true, "stim_version": "...", ...}`.
|
| 29 |
|
| 30 |
```bash
|
| 31 |
curl https://ronitraj-quantumscribe.hf.space/healthz
|
| 32 |
```
|
| 33 |
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|---|---|
|
| 40 |
-
| `POST /reset` | Sample a fresh syndrome + observation for the LLM. |
|
| 41 |
-
| `POST /step` | Score the LLM's predicted Pauli frame with five independent rewards (logical correction, syndrome consistency, Hamming overlap, format compliance, PyMatching beat-rate). |
|
| 42 |
-
| `GET /health` | Curriculum + episode statistics. |
|
| 43 |
-
| `GET /healthz` | Liveness (already live on this placeholder). |
|
| 44 |
|
| 45 |
## Stack
|
| 46 |
|
| 47 |
-
* [
|
| 48 |
-
* [
|
| 49 |
-
* [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
license: mit
|
| 10 |
+
short_description: OpenEnv RL env that teaches an LLM to decode quantum errors.
|
| 11 |
tags:
|
| 12 |
- openenv
|
| 13 |
- reinforcement-learning
|
| 14 |
- quantum-error-correction
|
| 15 |
- stim
|
| 16 |
- pymatching
|
| 17 |
+
- grpo
|
| 18 |
+
- trl
|
| 19 |
+
- llm
|
| 20 |
---
|
| 21 |
|
| 22 |
+
# QuantumScribe — Qubit-Medic OpenEnv
|
| 23 |
|
| 24 |
+
> An [OpenEnv](https://meta-pytorch.github.io/OpenEnv/) reinforcement-learning
|
| 25 |
+
> environment that teaches a Large Language Model to decode errors on the
|
| 26 |
+
> rotated surface code, using **Stim** for physics-accurate noise sampling
|
| 27 |
+
> and **PyMatching** as the classical baseline to beat.
|
| 28 |
+
|
| 29 |
+
This Space hosts the **environment server only**. Training (SFT warm-up
|
| 30 |
+
+ GRPO RL) runs on a separate Colab T4; the trained LoRA adapter is
|
| 31 |
+
loaded client-side, not on this Space.
|
| 32 |
+
|
| 33 |
+
## Endpoints
|
| 34 |
+
|
| 35 |
+
This is the canonical **OpenEnv contract** registered by
|
| 36 |
+
`openenv.core.create_fastapi_app`, plus two extras of our own:
|
| 37 |
+
|
| 38 |
+
| Method | Path | Purpose |
|
| 39 |
+
|---|---|---|
|
| 40 |
+
| `POST` | `/reset` | Sample a fresh syndrome + observation. Body `{"seed": int?, "episode_id": str?}`. Optional `?forced_level=L1_warmup\|L2_target\|L3_stretch`. |
|
| 41 |
+
| `POST` | `/step` | Score the LLM's prediction with five independent rewards. Body `{"action": {"raw_response": "...", "episode_id": int}, "timeout_s": float?, "request_id": str?}`. |
|
| 42 |
+
| `GET` | `/state` | Curriculum + episode counters (no physics-truth fields). |
|
| 43 |
+
| `GET` | `/health` | OpenEnv liveness response. |
|
| 44 |
+
| `GET` | `/healthz` | Lightweight Day-0 probe — Stim/PyMatching/openenv versions. |
|
| 45 |
+
| `GET` | `/schema` | JSON Schema for `QubitMedicAction` / `QubitMedicObservation`. |
|
| 46 |
+
| `GET` | `/metadata` | Environment metadata (name, description, version). |
|
| 47 |
+
| `POST` | `/mcp` | Model Context Protocol endpoint. |
|
| 48 |
+
| `POST` | `/decode` | PyMatching baseline demo: pass a hand-crafted syndrome, get the matching-decoder result. |
|
| 49 |
+
| `GET` | `/docs` | Swagger UI for everything above. |
|
| 50 |
|
| 51 |
## Try it
|
| 52 |
|
| 53 |
+
Curl from anywhere:
|
|
|
|
| 54 |
|
| 55 |
```bash
|
| 56 |
curl https://ronitraj-quantumscribe.hf.space/healthz
|
| 57 |
```
|
| 58 |
|
| 59 |
+
Use it from Python with the OpenEnv client:
|
| 60 |
+
|
| 61 |
+
```python
|
| 62 |
+
from openenv.core import GenericEnvClient
|
| 63 |
+
|
| 64 |
+
with GenericEnvClient(base_url="https://ronitraj-quantumscribe.hf.space").sync() as env:
|
| 65 |
+
obs = env.reset(seed=42)
|
| 66 |
+
print(obs.observation["prompt"][:200])
|
| 67 |
+
result = env.step({"raw_response": "<answer>X: 0,3 | Z:</answer>", "episode_id": 1})
|
| 68 |
+
print("reward:", result.reward, "rewards breakdown:", result.observation["info"]["rewards"])
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
Or hit the env directly with `httpx`:
|
| 72 |
+
|
| 73 |
+
```python
|
| 74 |
+
import httpx
|
| 75 |
+
url = "https://ronitraj-quantumscribe.hf.space"
|
| 76 |
+
obs = httpx.post(f"{url}/reset", json={"seed": 42},
|
| 77 |
+
params={"forced_level": "L2_target"}).json()["observation"]
|
| 78 |
+
print(obs["prompt"][:200])
|
| 79 |
+
res = httpx.post(f"{url}/step", json={
|
| 80 |
+
"action": {"raw_response": "<answer>X: | Z:</answer>",
|
| 81 |
+
"episode_id": obs["episode_id"]}
|
| 82 |
+
}).json()
|
| 83 |
+
print("reward =", res["reward"], "rewards =", res["observation"]["info"]["rewards"])
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
## What the rewards mean
|
| 87 |
+
|
| 88 |
+
Each `step` returns five *independent, verifiable* reward components:
|
| 89 |
|
| 90 |
+
| Reward | Weight | What it measures |
|
| 91 |
+
|---|---:|---|
|
| 92 |
+
| `logical_correction` | 0.40 | 1 if the predicted Pauli frame preserves the logical-Z observable. |
|
| 93 |
+
| `syndrome_consistency` | 0.20 | Hamming similarity over final-round detector parities. |
|
| 94 |
+
| `hamming_overlap` | 0.20 | Mean Jaccard similarity vs. the PyMatching reference Pauli frame. |
|
| 95 |
+
| `format_compliance` | 0.10 | 1 / 0.5 / 0 for full / partial / unparseable LLM output. |
|
| 96 |
+
| `pymatching_beat` | 0.10 | 1 iff PyMatching is wrong on this syndrome **and** the model is right. |
|
| 97 |
|
| 98 |
+
All five are computed from the same `(prompt, completion, syndrome)` tuple on every step (no redundant sampling — see the architecture notes in the GitHub README).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
## Stack
|
| 101 |
|
| 102 |
+
* [openenv-core](https://pypi.org/project/openenv-core/) `>=0.2.1` — environment contract, FastAPI scaffolding, MCP, WebSocket sessions.
|
| 103 |
+
* [Stim](https://github.com/quantumlib/Stim) — Clifford circuit simulator and detector-error-model generator.
|
| 104 |
+
* [PyMatching](https://github.com/oscarhiggott/PyMatching) `>=2.2` — minimum-weight matching baseline (and ground-truth for the `pymatching_beat` reward).
|
| 105 |
+
* FastAPI + Uvicorn — HTTP transport.
|
| 106 |
+
|
| 107 |
+
## Curriculum
|
| 108 |
+
|
| 109 |
+
| Level | Distance | Rounds | Physical error rate | Promotion threshold |
|
| 110 |
+
|---|---:|---:|---:|---:|
|
| 111 |
+
| `L1_warmup` | 3 | 1 | 1e-4 | logical_correction ≥ 0.80 |
|
| 112 |
+
| `L2_target` | 3 | 3 | 1e-3 | logical_correction ≥ 0.70 |
|
| 113 |
+
| `L3_stretch` | 5 | 5 | 1e-3 | logical_correction ≥ 0.30 |
|
| 114 |
+
|
| 115 |
+
The server's curriculum scheduler tracks a moving average of `logical_correction` per level and promotes the agent when it crosses the threshold.
|
| 116 |
+
|
| 117 |
+
## See also
|
| 118 |
+
|
| 119 |
+
* OpenEnv documentation — <https://meta-pytorch.github.io/OpenEnv/>
|
| 120 |
+
* TRL `environment_factory=` integration — <https://huggingface.co/docs/trl/main/openenv>
|
| 121 |
+
|
| 122 |
+
---
|
| 123 |
+
|
| 124 |
+
*Built for the META RL hackathon. Source code, training notebook, and reproduction instructions: see the [GitHub repository](https://github.com/ronitraj/qubit-medic) (link to be updated).*
|
app.py
DELETED
|
@@ -1,78 +0,0 @@
|
|
| 1 |
-
"""QuantumScribe - Day-0 deployment-substrate placeholder.
|
| 2 |
-
|
| 3 |
-
This minimal FastAPI app proves the Hugging Face Spaces deployment
|
| 4 |
-
substrate works for the Qubit-Medic / QuantumScribe project:
|
| 5 |
-
|
| 6 |
-
* Stim + PyMatching + FastAPI + openenv-core install cleanly in HF's
|
| 7 |
-
build environment (Day-0 step 2).
|
| 8 |
-
* GET /healthz returns {"ok": true, "stim_version": "..."}, proving the
|
| 9 |
-
server boots and the heavy quantum dependency loads (Day-0 step 3).
|
| 10 |
-
* The endpoint is reachable from a browser (Day-0 step 4) and from a
|
| 11 |
-
Colab `requests.get(...)` call (Day-0 step 5).
|
| 12 |
-
|
| 13 |
-
Once all Day-0 gates pass, this file is replaced by the real Qubit-Medic
|
| 14 |
-
OpenEnv server (qubit_medic.server.app) at the same Space URL, inheriting
|
| 15 |
-
the warm build cache.
|
| 16 |
-
"""
|
| 17 |
-
from __future__ import annotations
|
| 18 |
-
|
| 19 |
-
import sys
|
| 20 |
-
|
| 21 |
-
import stim
|
| 22 |
-
from fastapi import FastAPI
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
app = FastAPI(
|
| 26 |
-
title="QuantumScribe (Qubit-Medic) - Hello Space",
|
| 27 |
-
description="Day-0 deployment-substrate placeholder for the Qubit-Medic "
|
| 28 |
-
"OpenEnv server. Will be replaced by the real env shortly.",
|
| 29 |
-
version="0.0.1-placeholder",
|
| 30 |
-
)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
_PYMATCHING_VERSION: str
|
| 34 |
-
try:
|
| 35 |
-
import pymatching as _pm
|
| 36 |
-
_PYMATCHING_VERSION = getattr(_pm, "__version__", "unknown")
|
| 37 |
-
except Exception as exc:
|
| 38 |
-
_PYMATCHING_VERSION = f"import-error: {exc}"
|
| 39 |
-
|
| 40 |
-
_OPENENV_VERSION: str
|
| 41 |
-
try:
|
| 42 |
-
import openenv as _oe
|
| 43 |
-
_OPENENV_VERSION = getattr(_oe, "__version__", "unknown")
|
| 44 |
-
except Exception as exc:
|
| 45 |
-
_OPENENV_VERSION = f"import-error: {exc}"
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
@app.get("/")
|
| 49 |
-
def root() -> dict:
|
| 50 |
-
return {
|
| 51 |
-
"service": "QuantumScribe (Qubit-Medic)",
|
| 52 |
-
"status": "Day-0 placeholder live",
|
| 53 |
-
"next": "POST /reset and /step will become available once the real "
|
| 54 |
-
"DecoderEnvironment is pushed.",
|
| 55 |
-
"endpoints": ["/", "/healthz"],
|
| 56 |
-
"links": {
|
| 57 |
-
"github": "https://github.com/ronitraj (replace with repo URL once public)",
|
| 58 |
-
"openenv_docs": "https://meta-pytorch.org/OpenEnv/",
|
| 59 |
-
},
|
| 60 |
-
}
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
@app.get("/healthz")
|
| 64 |
-
def healthz() -> dict:
|
| 65 |
-
"""Day-0 liveness probe.
|
| 66 |
-
|
| 67 |
-
Returns the Stim version - so curl-ing this in a browser or from Colab
|
| 68 |
-
proves both that networking works AND that the heavy quantum deps
|
| 69 |
-
actually loaded. This is the literal endpoint the plan calls for.
|
| 70 |
-
"""
|
| 71 |
-
return {
|
| 72 |
-
"ok": True,
|
| 73 |
-
"stim_version": stim.__version__,
|
| 74 |
-
"pymatching_version": _PYMATCHING_VERSION,
|
| 75 |
-
"openenv_version": _OPENENV_VERSION,
|
| 76 |
-
"python_version": sys.version.split()[0],
|
| 77 |
-
"service": "QuantumScribe",
|
| 78 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qubit_medic/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Qubit-Medic: An LLM-trained quantum error-correction decoder.
|
| 2 |
+
|
| 3 |
+
The package is split into three layers (Section 0 of the plan):
|
| 4 |
+
|
| 5 |
+
* ``qubit_medic.config`` - the locked experiment configuration.
|
| 6 |
+
* ``qubit_medic.server`` - Stim physics, rewards, curriculum, FastAPI app.
|
| 7 |
+
* ``qubit_medic.client`` - the lightweight HTTP stub the trainer imports.
|
| 8 |
+
|
| 9 |
+
``qubit_medic.models`` and ``qubit_medic.prompts`` are the contract both sides
|
| 10 |
+
agree on: what the LLM sees and what the LLM is allowed to emit.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from qubit_medic import config, models, prompts # noqa: F401
|
| 14 |
+
|
| 15 |
+
__version__ = "1.0.0"
|
qubit_medic/client/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Qubit-Medic client - the lightweight HTTP stub the trainer imports."""
|
| 2 |
+
|
| 3 |
+
from qubit_medic.client.client import DecoderClient, LocalDecoderClient
|
| 4 |
+
|
| 5 |
+
__all__ = ["DecoderClient", "LocalDecoderClient"]
|
qubit_medic/client/client.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Two equivalent client implementations:
|
| 2 |
+
|
| 3 |
+
* :class:`DecoderClient` - hits an HTTP endpoint (HF Spaces deployment).
|
| 4 |
+
Speaks the **OpenEnv** wire format:
|
| 5 |
+
- ``POST /reset`` body ``{"seed": int?, "episode_id": str?}``
|
| 6 |
+
- ``POST /step`` body ``{"action": {"raw_response": "...", ...},
|
| 7 |
+
"timeout_s": float?, "request_id": str?}``
|
| 8 |
+
* :class:`LocalDecoderClient` - calls the in-process env directly. Use this
|
| 9 |
+
in tests, in CI, and during local Colab runs to skip HTTP overhead.
|
| 10 |
+
|
| 11 |
+
Both expose the same ``reset`` / ``step`` API so the training scripts can
|
| 12 |
+
swap between them via a single env var (``QUBIT_MEDIC_URL``).
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
from typing import Optional, Protocol
|
| 18 |
+
|
| 19 |
+
import httpx
|
| 20 |
+
|
| 21 |
+
from qubit_medic.models import (
|
| 22 |
+
DecoderObservation,
|
| 23 |
+
StepResult,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class _ClientProtocol(Protocol):
|
| 28 |
+
def reset(self, *, seed: Optional[int] = None,
|
| 29 |
+
forced_level: Optional[str] = None) -> DecoderObservation: ...
|
| 30 |
+
def step(self, *, raw_response: str, episode_id: int) -> StepResult: ...
|
| 31 |
+
def health(self) -> dict: ...
|
| 32 |
+
def close(self) -> None: ...
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _obs_from_openenv(payload: dict) -> DecoderObservation:
|
| 36 |
+
"""Re-hydrate our internal :class:`DecoderObservation` from the
|
| 37 |
+
OpenEnv response body. The OpenEnv wrapper inlines all our fields
|
| 38 |
+
onto the observation, so this is a 1-1 field mapping."""
|
| 39 |
+
return DecoderObservation(
|
| 40 |
+
prompt=payload.get("prompt", ""),
|
| 41 |
+
syndrome_bits=list(payload.get("syndrome_bits", [])),
|
| 42 |
+
distance=int(payload.get("distance", 0)),
|
| 43 |
+
rounds=int(payload.get("rounds", 0)),
|
| 44 |
+
p=float(payload.get("p", 0.0)),
|
| 45 |
+
curriculum_level=payload.get("curriculum_level", ""),
|
| 46 |
+
episode_id=int(payload.get("episode_id", 0)),
|
| 47 |
+
dem_digest=payload.get("dem_digest", ""),
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class DecoderClient:
|
| 52 |
+
"""HTTP client targeting a deployed FastAPI server (OpenEnv shape)."""
|
| 53 |
+
|
| 54 |
+
def __init__(self, base_url: str, *, timeout: float = 60.0) -> None:
|
| 55 |
+
self._client = httpx.Client(base_url=base_url.rstrip("/"), timeout=timeout)
|
| 56 |
+
|
| 57 |
+
def reset(self, *, seed: Optional[int] = None,
|
| 58 |
+
forced_level: Optional[str] = None) -> DecoderObservation:
|
| 59 |
+
# OpenEnv's ResetRequest only accepts seed + episode_id. We pass
|
| 60 |
+
# forced_level via the URL query string so adapters that honour
|
| 61 |
+
# it (our QubitMedicEnvironment via **kwargs) pick it up; servers
|
| 62 |
+
# that ignore it just get a default level.
|
| 63 |
+
body: dict = {}
|
| 64 |
+
if seed is not None:
|
| 65 |
+
body["seed"] = seed
|
| 66 |
+
params = {"forced_level": forced_level} if forced_level else None
|
| 67 |
+
r = self._client.post("/reset", json=body, params=params)
|
| 68 |
+
r.raise_for_status()
|
| 69 |
+
payload = r.json()
|
| 70 |
+
# OpenEnv returns {observation: {...}, reward, done}.
|
| 71 |
+
return _obs_from_openenv(payload.get("observation", payload))
|
| 72 |
+
|
| 73 |
+
def step(self, *, raw_response: str, episode_id: int) -> StepResult:
|
| 74 |
+
body = {
|
| 75 |
+
"action": {
|
| 76 |
+
"raw_response": raw_response,
|
| 77 |
+
"episode_id": episode_id,
|
| 78 |
+
},
|
| 79 |
+
}
|
| 80 |
+
r = self._client.post("/step", json=body)
|
| 81 |
+
r.raise_for_status()
|
| 82 |
+
payload = r.json()
|
| 83 |
+
obs_payload = payload.get("observation", {})
|
| 84 |
+
return StepResult(
|
| 85 |
+
observation=_obs_from_openenv(obs_payload),
|
| 86 |
+
reward=float(payload.get("reward", 0.0) or 0.0),
|
| 87 |
+
done=bool(payload.get("done", True)),
|
| 88 |
+
truncated=bool(obs_payload.get("info", {}).get("timed_out", False)),
|
| 89 |
+
info=dict(obs_payload.get("info", {})),
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
def health(self) -> dict:
|
| 93 |
+
r = self._client.get("/health")
|
| 94 |
+
r.raise_for_status()
|
| 95 |
+
return r.json()
|
| 96 |
+
|
| 97 |
+
def healthz(self) -> dict:
|
| 98 |
+
r = self._client.get("/healthz")
|
| 99 |
+
r.raise_for_status()
|
| 100 |
+
return r.json()
|
| 101 |
+
|
| 102 |
+
def close(self) -> None:
|
| 103 |
+
self._client.close()
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class LocalDecoderClient:
|
| 107 |
+
"""In-process client - calls :class:`DecoderEnvironment` directly."""
|
| 108 |
+
|
| 109 |
+
def __init__(self, env=None) -> None:
|
| 110 |
+
from qubit_medic.server.environment import DecoderEnvironment
|
| 111 |
+
self._env = env if env is not None else DecoderEnvironment()
|
| 112 |
+
|
| 113 |
+
def reset(self, *, seed: Optional[int] = None,
|
| 114 |
+
forced_level: Optional[str] = None) -> DecoderObservation:
|
| 115 |
+
return self._env.reset(seed=seed, forced_level=forced_level)
|
| 116 |
+
|
| 117 |
+
def step(self, *, raw_response: str, episode_id: int) -> StepResult:
|
| 118 |
+
return self._env.step(raw_response=raw_response, episode_id=episode_id)
|
| 119 |
+
|
| 120 |
+
def health(self) -> dict:
|
| 121 |
+
return self._env.health()
|
| 122 |
+
|
| 123 |
+
def close(self) -> None: # nothing to clean up
|
| 124 |
+
pass
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def make_default_client() -> _ClientProtocol:
|
| 128 |
+
"""Return :class:`DecoderClient` if ``QUBIT_MEDIC_URL`` is set, else local."""
|
| 129 |
+
url = os.getenv("QUBIT_MEDIC_URL")
|
| 130 |
+
if url:
|
| 131 |
+
return DecoderClient(url)
|
| 132 |
+
return LocalDecoderClient()
|
qubit_medic/config.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Locked experiment configuration (Section 1.4 of the plan).
|
| 2 |
+
|
| 3 |
+
Every magic number in the project lives here. Do not hard-code circuit
|
| 4 |
+
parameters, noise rates, or model identifiers anywhere else; import them
|
| 5 |
+
from this module instead.
|
| 6 |
+
|
| 7 |
+
Cited literature
|
| 8 |
+
----------------
|
| 9 |
+
Bausch et al., AlphaQubit, *Nature* 635:834 (2024)
|
| 10 |
+
DOI: 10.1038/s41586-024-08148-8
|
| 11 |
+
https://www.nature.com/articles/s41586-024-08148-8
|
| 12 |
+
Acharya et al. (Google QAI), *Willow*, arXiv:2408.13687 (2024)
|
| 13 |
+
https://arxiv.org/abs/2408.13687
|
| 14 |
+
Gidney & Fowler, *SI1000*, arXiv:2108.10457 (2021)
|
| 15 |
+
https://arxiv.org/abs/2108.10457
|
| 16 |
+
Higgott & Gidney, *PyMatching v2*, arXiv:2303.15933 (2023)
|
| 17 |
+
https://arxiv.org/abs/2303.15933
|
| 18 |
+
Shao et al., *DeepSeekMath / GRPO*, arXiv:2402.03300 (2024)
|
| 19 |
+
https://arxiv.org/abs/2402.03300
|
| 20 |
+
"""
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
from dataclasses import dataclass, field
|
| 24 |
+
from typing import Mapping
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# --------------------------------------------------------------------------- #
|
| 28 |
+
# Quantum code geometry #
|
| 29 |
+
# --------------------------------------------------------------------------- #
|
| 30 |
+
|
| 31 |
+
CODE_TASK = "surface_code:rotated_memory_z"
|
| 32 |
+
"""Stim task identifier. We always use the rotated surface code with a Z
|
| 33 |
+
memory experiment - same family AlphaQubit and Willow report on."""
|
| 34 |
+
|
| 35 |
+
DISTANCE_PRIMARY: int = 3
|
| 36 |
+
"""Distance-3 is the primary benchmark configuration (AlphaQubit Fig. 2b)."""
|
| 37 |
+
|
| 38 |
+
DISTANCE_STRETCH: int = 5
|
| 39 |
+
"""Distance-5 is the stretch-goal configuration for Section 4.3."""
|
| 40 |
+
|
| 41 |
+
ROUNDS_FACTOR: int = 1
|
| 42 |
+
"""rounds = ROUNDS_FACTOR * distance. Value 1 matches the AlphaQubit
|
| 43 |
+
distance-equals-rounds protocol."""
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# --------------------------------------------------------------------------- #
|
| 47 |
+
# Noise model: SI1000 sub-rates (Gidney & Fowler 2021, Table 1) #
|
| 48 |
+
# --------------------------------------------------------------------------- #
|
| 49 |
+
# SI1000 maps a single physical error budget ``p`` to four operation-specific
|
| 50 |
+
# sub-rates. The factors below come from arXiv:2108.10457 Table 1 and are the
|
| 51 |
+
# *same* values Google's QAI uses in their Willow analyses.
|
| 52 |
+
#
|
| 53 |
+
# Stim's surface_code:rotated_memory_z generator accepts four matching knobs:
|
| 54 |
+
# after_clifford_depolarization (two-qubit gate noise)
|
| 55 |
+
# before_round_data_depolarization (idle data-qubit noise per round)
|
| 56 |
+
# before_measure_flip_probability (measurement noise)
|
| 57 |
+
# after_reset_flip_probability (reset noise)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclass(frozen=True)
|
| 61 |
+
class SI1000Rates:
|
| 62 |
+
"""Per-operation error rates derived from a single budget ``p``."""
|
| 63 |
+
|
| 64 |
+
after_clifford_depolarization: float
|
| 65 |
+
before_round_data_depolarization: float
|
| 66 |
+
before_measure_flip_probability: float
|
| 67 |
+
after_reset_flip_probability: float
|
| 68 |
+
|
| 69 |
+
@classmethod
|
| 70 |
+
def from_p(cls, p: float) -> "SI1000Rates":
|
| 71 |
+
"""Build SI1000 sub-rates from the headline budget ``p``.
|
| 72 |
+
|
| 73 |
+
The factors are exactly Gidney & Fowler 2021 Table 1.
|
| 74 |
+
"""
|
| 75 |
+
return cls(
|
| 76 |
+
after_clifford_depolarization=p,
|
| 77 |
+
before_round_data_depolarization=p / 10.0,
|
| 78 |
+
before_measure_flip_probability=p * 5.0,
|
| 79 |
+
after_reset_flip_probability=p * 2.0,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
def as_stim_kwargs(self) -> Mapping[str, float]:
|
| 83 |
+
"""Return the kwargs dict accepted by ``stim.Circuit.generated``."""
|
| 84 |
+
return {
|
| 85 |
+
"after_clifford_depolarization": self.after_clifford_depolarization,
|
| 86 |
+
"before_round_data_depolarization": self.before_round_data_depolarization,
|
| 87 |
+
"before_measure_flip_probability": self.before_measure_flip_probability,
|
| 88 |
+
"after_reset_flip_probability": self.after_reset_flip_probability,
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# --------------------------------------------------------------------------- #
|
| 93 |
+
# Curriculum levels (Section 4) #
|
| 94 |
+
# --------------------------------------------------------------------------- #
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@dataclass(frozen=True)
|
| 98 |
+
class CurriculumLevel:
|
| 99 |
+
"""One rung on the difficulty ladder."""
|
| 100 |
+
|
| 101 |
+
name: str
|
| 102 |
+
distance: int
|
| 103 |
+
rounds: int
|
| 104 |
+
p: float
|
| 105 |
+
promotion_threshold: float # logical-correction rate at which we move on
|
| 106 |
+
eval_size: int # held-out shots used to test promotion
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
CURRICULUM: tuple[CurriculumLevel, ...] = (
|
| 110 |
+
CurriculumLevel(
|
| 111 |
+
name="L1_warmup",
|
| 112 |
+
distance=DISTANCE_PRIMARY,
|
| 113 |
+
rounds=1,
|
| 114 |
+
p=0.0001,
|
| 115 |
+
promotion_threshold=0.80,
|
| 116 |
+
eval_size=100,
|
| 117 |
+
),
|
| 118 |
+
CurriculumLevel(
|
| 119 |
+
name="L2_target",
|
| 120 |
+
distance=DISTANCE_PRIMARY,
|
| 121 |
+
rounds=DISTANCE_PRIMARY,
|
| 122 |
+
p=0.001,
|
| 123 |
+
promotion_threshold=0.70,
|
| 124 |
+
eval_size=200,
|
| 125 |
+
),
|
| 126 |
+
CurriculumLevel(
|
| 127 |
+
name="L3_stretch",
|
| 128 |
+
distance=DISTANCE_STRETCH,
|
| 129 |
+
rounds=DISTANCE_STRETCH,
|
| 130 |
+
p=0.001,
|
| 131 |
+
promotion_threshold=0.30, # stretch goal - even partial counts
|
| 132 |
+
eval_size=200,
|
| 133 |
+
),
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# --------------------------------------------------------------------------- #
|
| 138 |
+
# Reward weights (Section 3) - sum to 1.0 by construction #
|
| 139 |
+
# --------------------------------------------------------------------------- #
|
| 140 |
+
|
| 141 |
+
REWARD_WEIGHTS: dict[str, float] = {
|
| 142 |
+
"logical_correction": 0.40, # Reward 1 - the unfakeable ground truth
|
| 143 |
+
"syndrome_consistency": 0.20, # Reward 2 - prevents lucky-guess attacks
|
| 144 |
+
"hamming_overlap": 0.20, # Reward 3 - dense partial credit
|
| 145 |
+
"format_compliance": 0.10, # Reward 4 - parser must succeed
|
| 146 |
+
"pymatching_beat": 0.10, # Reward 5 - the headline metric
|
| 147 |
+
}
|
| 148 |
+
assert abs(sum(REWARD_WEIGHTS.values()) - 1.0) < 1e-9, "reward weights must sum to 1"
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# --------------------------------------------------------------------------- #
|
| 152 |
+
# Reproducibility #
|
| 153 |
+
# --------------------------------------------------------------------------- #
|
| 154 |
+
|
| 155 |
+
SEEDS: tuple[int, ...] = (42, 1337, 2024)
|
| 156 |
+
"""Three seeds for error bars - never run with anything else."""
|
| 157 |
+
|
| 158 |
+
PRIMARY_SEED: int = SEEDS[0]
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# --------------------------------------------------------------------------- #
|
| 162 |
+
# Model + training #
|
| 163 |
+
# --------------------------------------------------------------------------- #
|
| 164 |
+
|
| 165 |
+
# Policy model + LoRA adapter hyper-parameters.
MODEL_ID: str = "Qwen/Qwen2.5-3B-Instruct"
"""3B params, 4-bit quantised + LoRA fits in a Colab T4."""

LORA_R: int = 16  # LoRA rank
LORA_ALPHA: int = 32  # LoRA scaling factor
LORA_TARGET_MODULES: tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj")

# Supervised fine-tuning (SFT) schedule.
SFT_EPOCHS: int = 1
SFT_BATCH_SIZE: int = 4
SFT_GRAD_ACCUM: int = 4  # presumably multiplies SFT_BATCH_SIZE for the effective batch
SFT_LR: float = 2e-4
SFT_DATASET_SIZE: int = 5_000
SFT_MAX_SEQ_LEN: int = 2048

# GRPO reinforcement-learning schedule.
GRPO_STEPS: int = 2_000
GRPO_GEN_PER_PROMPT: int = 4  # completions sampled per prompt each step
GRPO_LR: float = 1e-5
GRPO_KL_COEF: float = 0.04  # KL penalty coefficient
GRPO_MAX_PROMPT_LEN: int = 512
GRPO_MAX_COMPLETION_LEN: int = 256
GRPO_CHECKPOINT_EVERY: int = 250
GRPO_LOG_EVERY: int = 50

# Decoding sampler defaults at evaluation/format-test time.
SAMPLE_TEMPERATURE: float = 0.7
SAMPLE_TOP_P: float = 0.95


# --------------------------------------------------------------------------- #
#                             Server / deployment                             #
# --------------------------------------------------------------------------- #

EPISODE_TIMEOUT_SECONDS: float = 30.0
"""Wall-clock budget per episode (Section 2.6)."""

DEFAULT_HOST: str = "0.0.0.0"
DEFAULT_PORT: int = 7860  # Hugging Face Spaces' default exposed port


# --------------------------------------------------------------------------- #
#                              Weights & Biases                               #
# --------------------------------------------------------------------------- #
# Centralised so the SFT trainer, GRPO trainer, eval script, and notebook
# all log to the same project / dashboard. Override per-run on the CLI.
import os as _os  # noqa: E402 (local import to keep top of module clean)

WANDB_PROJECT: str = _os.environ.get("WANDB_PROJECT", "qubit-medic")
"""Default W&B project name. Override with ``WANDB_PROJECT=...``."""

WANDB_ENTITY: str | None = _os.environ.get("WANDB_ENTITY") or None
"""W&B team or username. ``None`` -> wandb's default entity for the user."""

WANDB_DEFAULT_TAGS: tuple[str, ...] = (
    "qubit-medic",
    "quantum-error-correction",
    "openenv",
    f"distance-{DISTANCE_PRIMARY}",
    "si1000",
)
"""Tags applied to every W&B run (per-script tags appended on top)."""

WANDB_LOG_GENERATIONS_EVERY: int = 50
"""Log a sample-completion table every N GRPO steps."""

WANDB_SAMPLE_GENERATIONS: int = 8
"""Number of generations included in each sample-completion table."""

WANDB_INLOOP_EVAL_EVERY: int = 200
"""Run an in-loop evaluation pass (deterministic, ``WANDB_INLOOP_EVAL_EPISODES``
syndromes) every N GRPO steps. Set to 0 to disable."""

WANDB_INLOOP_EVAL_EPISODES: int = 50
"""Number of held-out syndromes per in-loop eval pass (kept small for speed)."""
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# --------------------------------------------------------------------------- #
|
| 241 |
+
# Convenience accessors #
|
| 242 |
+
# --------------------------------------------------------------------------- #
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def level_by_name(name: str) -> CurriculumLevel:
    """Look up a curriculum level by its unique ``name``.

    Raises:
        KeyError: if no level in ``CURRICULUM`` carries that name.
    """
    found = next((lvl for lvl in CURRICULUM if lvl.name == name), None)
    if found is None:
        raise KeyError(f"unknown curriculum level {name!r}")
    return found
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def primary_level() -> CurriculumLevel:
    """Return the L2 target benchmark - the source of the headline numbers."""
    target_name = "L2_target"
    return level_by_name(target_name)
|
qubit_medic/models.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic data classes shared by client and server (Section 2.2 of the plan).
|
| 2 |
+
|
| 3 |
+
Three classes draw the trust boundary:
|
| 4 |
+
|
| 5 |
+
* ``DecoderObservation`` - what the LLM sees on each step.
|
| 6 |
+
* ``DecoderAction`` - what the LLM emits (after parsing).
|
| 7 |
+
* ``DecoderState`` - server-side state, never serialised to the client.
|
| 8 |
+
|
| 9 |
+
Keeping the wire schema explicit is what closes off reward-hacking attacks:
|
| 10 |
+
the LLM literally cannot reach into the ``true_error_pattern`` because that
|
| 11 |
+
field is not in any class it ever receives.
|
| 12 |
+
"""
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
from typing import Any, Optional
|
| 16 |
+
|
| 17 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# --------------------------------------------------------------------------- #
|
| 21 |
+
# Wire types - sent across the OpenEnv HTTP boundary #
|
| 22 |
+
# --------------------------------------------------------------------------- #
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class DecoderObservation(BaseModel):
    """The view the LLM (and only the LLM) sees on each step.

    Frozen (immutable) by design: an observation is a snapshot emitted by
    the server, and nothing downstream may edit it afterwards.
    """

    model_config = ConfigDict(frozen=True)

    prompt: str = Field(
        ...,
        description=(
            "Pre-formatted prompt string. This is exactly what the trainer "
            "passes to the policy - it appears verbatim in training logs."
        ),
    )
    syndrome_bits: list[int] = Field(
        ...,
        description=(
            "Raw detector activations (0/1). Provided for debugging and "
            "reward-hacking audits; the LLM should be reading the prompt, not "
            "this array."
        ),
    )
    distance: int = Field(..., description="Code distance for this episode.")
    rounds: int = Field(..., description="Number of stabiliser rounds.")
    p: float = Field(..., description="Physical error budget (SI1000 base).")
    curriculum_level: str = Field(..., description="Curriculum level name.")
    episode_id: int = Field(..., description="Monotonic episode counter.")
    dem_digest: str = Field(
        ...,
        description=(
            "Short hash of the detector error model used this episode. The "
            "trainer logs this so we can group rollouts by physics config."
        ),
    )
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class DecoderAction(BaseModel):
    """Action emitted by the LLM after parsing.

    ``raw_response`` is preserved exactly so we can satisfy the participant
    guide's *inspect generations* mandate (Section 2.5 of the plan).
    """

    model_config = ConfigDict(frozen=True)

    # Data-qubit indices the model claims were hit by X / Z flips.
    x_error_qubits: list[int] = Field(default_factory=list)
    z_error_qubits: list[int] = Field(default_factory=list)
    # Verbatim model output, kept for generation audits.
    raw_response: str = ""
    # False when the parser could not cleanly extract both error lists.
    parse_success: bool = True
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class StepResult(BaseModel):
    """Standard env step return (mirrors OpenEnv core/Gymnasium)."""

    observation: DecoderObservation
    reward: float
    # done: natural episode end; truncated: cut short (presumably the
    # per-episode wall-clock budget - confirm in environment.py).
    done: bool
    truncated: bool = False
    # Free-form per-step diagnostics.
    info: dict[str, Any] = Field(default_factory=dict)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class ResetRequest(BaseModel):
    """Optional knobs the trainer can pass to ``reset``."""

    # Episode RNG seed; None presumably lets the server pick - confirm in
    # environment.py.
    seed: Optional[int] = None
    forced_level: Optional[str] = Field(
        default=None,
        description=(
            "Override the curriculum scheduler. Used by eval scripts that "
            "want a specific (distance, rounds, p) configuration."
        ),
    )
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class StepRequest(BaseModel):
    """The trainer sends the LLM's raw text; the server parses + scores."""

    # Unparsed completion text; the server runs its tolerant parser on it.
    raw_response: str
    # Echoed episode counter; presumably used to match the step to the
    # server-side episode state - confirm in environment.py.
    episode_id: int
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# --------------------------------------------------------------------------- #
|
| 105 |
+
# Server-only state - intentionally NOT a wire type #
|
| 106 |
+
# --------------------------------------------------------------------------- #
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class DecoderState(BaseModel):
    """Per-episode state kept on the server; never sent to the client.

    Pydantic ``arbitrary_types_allowed`` is on because we hold a reference to
    a ``stim.Circuit`` object. The state is not serialised over HTTP - it
    lives in the server's per-episode dict and is discarded on ``done``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=False)

    # Episode identity + physics configuration.
    episode_id: int
    seed: int
    curriculum_level: str
    distance: int
    rounds: int
    p: float

    # Ground truth for the episode - deliberately absent from
    # DecoderObservation so the policy can never see it.
    syndrome_bits: list[int]
    true_x_errors: list[int]
    true_z_errors: list[int]
    actual_observable_flip: int  # 0 or 1; the unfakeable ground truth
    pymatching_observable_pred: int  # 0 or 1; baseline's prediction

    # Pre-computed quantities the reward functions need.
    x_observable_support: list[int]  # data qubits whose Z error flips X obs
    z_observable_support: list[int]  # data qubits whose X error flips Z obs
    num_data_qubits: int
    num_stabilizers: int

    # Stim/PyMatching objects - kept opaque (text form) to satisfy Pydantic.
    circuit_text: str
    dem_text: str

    # Reward audit log.
    last_reward_breakdown: dict[str, float] = Field(default_factory=dict)
|
qubit_medic/prompts.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Prompt formatter and action parser (Section 2.3 + Section 2.5 of the plan).
|
| 2 |
+
|
| 3 |
+
The prompt is engineered around five sections:
|
| 4 |
+
|
| 5 |
+
1. Role declaration
|
| 6 |
+
2. Physics summary (~50 tokens, plain English)
|
| 7 |
+
3. Syndrome data (round-by-round, labelled)
|
| 8 |
+
4. Output format spec (one example included)
|
| 9 |
+
5. Reasoning trigger ("think step by step ...")
|
| 10 |
+
|
| 11 |
+
Total budget ~250-300 tokens for the prompt; ~150 for the response.
|
| 12 |
+
|
| 13 |
+
The parser is deliberately permissive on whitespace and bracket style but
|
| 14 |
+
strict on the existence of the two key tokens ``X_ERRORS`` and ``Z_ERRORS``.
|
| 15 |
+
A partial-credit hook is exposed so Reward 4 can hand out 0.5 for "partly
|
| 16 |
+
parseable".
|
| 17 |
+
"""
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import re
|
| 21 |
+
from dataclasses import dataclass
|
| 22 |
+
from typing import Iterable
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# --------------------------------------------------------------------------- #
|
| 26 |
+
# Prompt formatting #
|
| 27 |
+
# --------------------------------------------------------------------------- #
|
| 28 |
+
|
| 29 |
+
# Section 1: role declaration ({distance} is filled per episode).
_ROLE = (
    "You are a quantum error-correction decoder. You are decoding errors in "
    "a distance-{distance} rotated surface code memory experiment."
)

# Section 2: ~50-token plain-English physics summary.
_PHYSICS_SUMMARY = (
    "Stabilizers are parity checks measured every round. A *syndrome bit* "
    "is 1 when a stabilizer's measurement disagrees with its previous round, "
    "indicating a nearby physical error. Your job is to look at the syndrome "
    "history and output the smallest physical error pattern (X-flips and "
    "Z-flips on data qubits, identified by integer IDs) that explains it."
)

# Section 4: output-format spec; must stay in sync with the parser regexes.
_OUTPUT_SPEC = (
    "Output format (REQUIRED, exact):\n"
    "  X_ERRORS=[id1,id2,...] Z_ERRORS=[id1,id2,...]\n"
    "Use empty lists when no errors of that type. Example with no errors:\n"
    "  X_ERRORS=[] Z_ERRORS=[]"
)

# Section 5: chain-of-thought trigger.
_REASONING_TRIGGER = (
    "Think step by step about which qubits could have caused this syndrome, "
    "then output your prediction in the required format."
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def format_syndrome_block(
    syndrome_bits: Iterable[int],
    rounds: int,
    num_x_stabilizers: int,
    num_z_stabilizers: int,
) -> str:
    """Render detector activations as a labelled, round-by-round listing.

    Stim orders detectors row-major (round 0 stabilisers first, then round
    1, ...), so each consecutive slice of ``num_x_stabilizers +
    num_z_stabilizers`` bits is one round. Any bits left over after the last
    full round belong to the final destructive measurement.
    """
    bits = list(syndrome_bits)
    stabs_per_round = num_x_stabilizers + num_z_stabilizers
    header = "Syndrome (round-by-round):"

    # Degenerate inputs: nothing to show.
    if not bits or rounds == 0 or stabs_per_round == 0:
        return "\n".join([header, "  (no detectors fired)"])

    rendered = [header]
    for rnd in range(rounds):
        start = rnd * stabs_per_round
        if start >= len(bits):
            break  # syndrome shorter than advertised round count
        window = bits[start : start + stabs_per_round]
        x_part = window[:num_x_stabilizers]
        z_part = window[num_x_stabilizers : num_x_stabilizers + num_z_stabilizers]
        rendered.append(
            f"  Round {rnd + 1} X-stabilizers: " + " ".join(str(b) for b in x_part)
        )
        rendered.append(
            f"  Round {rnd + 1} Z-stabilizers: " + " ".join(str(b) for b in z_part)
        )

    # Trailing block for the final destructive measurement, if any extras.
    consumed = rounds * stabs_per_round
    if consumed < len(bits):
        rendered.append(
            "  Final-round detectors: " + " ".join(str(b) for b in bits[consumed:])
        )
    return "\n".join(rendered)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def build_prompt(
    *,
    distance: int,
    rounds: int,
    p: float,
    syndrome_bits: list[int],
    num_x_stabilizers: int,
    num_z_stabilizers: int,
    num_data_qubits: int,
) -> str:
    """Assemble the full per-step prompt from its five sections.

    Pure function (no I/O, no globals mutated), so the SFT pipeline and the
    GRPO rollout are guaranteed byte-identical inputs.
    """
    sections = [
        _ROLE.format(distance=distance),
        _PHYSICS_SUMMARY,
        f"Code parameters: distance={distance}, rounds={rounds}, "
        f"physical_error_rate={p:g}, data_qubits=0..{num_data_qubits - 1}.",
        format_syndrome_block(
            syndrome_bits=syndrome_bits,
            rounds=rounds,
            num_x_stabilizers=num_x_stabilizers,
            num_z_stabilizers=num_z_stabilizers,
        ),
        _OUTPUT_SPEC,
        _REASONING_TRIGGER,
    ]
    return "\n\n".join(sections)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# --------------------------------------------------------------------------- #
|
| 134 |
+
# Output parsing #
|
| 135 |
+
# --------------------------------------------------------------------------- #
|
| 136 |
+
|
| 137 |
+
# Tolerant matchers for the two required answer keys: flexible whitespace
# around '=', case-insensitive key, `[^\]]*` captures the bracket interior.
_X_PATTERN = re.compile(r"X_ERRORS\s*=\s*\[([^\]]*)\]", re.IGNORECASE)
_Z_PATTERN = re.compile(r"Z_ERRORS\s*=\s*\[([^\]]*)\]", re.IGNORECASE)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@dataclass(frozen=True)
class ParseResult:
    """Outcome of parsing one LLM completion."""

    x_errors: list[int]
    z_errors: list[int]
    parse_success: bool  # True iff BOTH X_ERRORS and Z_ERRORS parsed cleanly
    parse_partial: bool  # True iff exactly one of the two parsed cleanly
    raw_response: str

    @property
    def format_score(self) -> float:
        """Format-compliance score for Reward 4: 1.0 / 0.5 / 0.0."""
        if self.parse_success:
            return 1.0
        return 0.5 if self.parse_partial else 0.0
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _parse_int_list(s: str, max_qubit: int) -> tuple[list[int], bool]:
|
| 160 |
+
"""Parse a comma/space-separated integer list. Drops out-of-range and dups.
|
| 161 |
+
|
| 162 |
+
Returns ``(qubits_sorted_unique, all_tokens_were_valid)``.
|
| 163 |
+
"""
|
| 164 |
+
if not s.strip():
|
| 165 |
+
return [], True
|
| 166 |
+
raw_tokens = re.split(r"[\s,]+", s.strip())
|
| 167 |
+
out: list[int] = []
|
| 168 |
+
all_clean = True
|
| 169 |
+
for tok in raw_tokens:
|
| 170 |
+
if not tok:
|
| 171 |
+
continue
|
| 172 |
+
try:
|
| 173 |
+
v = int(tok)
|
| 174 |
+
except ValueError:
|
| 175 |
+
all_clean = False
|
| 176 |
+
continue
|
| 177 |
+
if 0 <= v < max_qubit:
|
| 178 |
+
out.append(v)
|
| 179 |
+
else:
|
| 180 |
+
all_clean = False
|
| 181 |
+
return sorted(set(out)), all_clean
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def parse_action(raw_response: str, num_data_qubits: int) -> ParseResult:
    """Convert the LLM's raw text into a :class:`ParseResult`.

    Tolerant of trailing chain-of-thought, surrounding code fences, and
    casing, but strict on the existence of both ``X_ERRORS`` and
    ``Z_ERRORS`` keys.
    """
    if not isinstance(raw_response, str):
        return ParseResult([], [], False, False, raw_response="")

    # When the model fenced its answer in ```...``` blocks, prefer the last.
    fenced_blocks = re.findall(r"```(?:[^\n]*)\n(.*?)```", raw_response, re.DOTALL)
    haystack = fenced_blocks[-1] if fenced_blocks else raw_response

    def _extract(pattern: re.Pattern) -> tuple[list[int], bool, bool]:
        # -> (qubits, tokens_all_clean, key_was_present)
        m = pattern.search(haystack)
        if m is None:
            return [], False, False
        qubits, clean = _parse_int_list(m.group(1), num_data_qubits)
        return qubits, clean, True

    x_errors, x_clean, x_found = _extract(_X_PATTERN)
    z_errors, z_clean, z_found = _extract(_Z_PATTERN)

    x_ok = x_found and x_clean
    z_ok = z_found and z_clean
    success = x_ok and z_ok
    # Partial: exactly one side fully clean, OR both keys present but at
    # least one contained garbage tokens.
    partial = (x_ok != z_ok) or (x_found and z_found and not success)

    return ParseResult(
        x_errors=x_errors,
        z_errors=z_errors,
        parse_success=success,
        parse_partial=partial,
        raw_response=raw_response,
    )
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def format_completion(x_errors: Iterable[int], z_errors: Iterable[int]) -> str:
    """Render the canonical SFT target string (inverse of ``parse_action``)."""

    def _render(ids: Iterable[int]) -> str:
        # Deduplicate and sort so the target is deterministic.
        return ",".join(map(str, sorted(set(ids))))

    return f"X_ERRORS=[{_render(x_errors)}] Z_ERRORS=[{_render(z_errors)}]"
|
qubit_medic/server/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Server-side modules: physics, rewards, curriculum, FastAPI app.
|
| 2 |
+
|
| 3 |
+
Sub-modules are imported lazily on first attribute access to avoid
|
| 4 |
+
circular imports during partial initialisation.
|
| 5 |
+
"""
|
qubit_medic/server/app.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Qubit-Medic FastAPI server.
|
| 2 |
+
|
| 3 |
+
Built on **openenv-core** ``create_fastapi_app`` so the canonical OpenEnv
|
| 4 |
+
routes (``/reset``, ``/step``, ``/state``, ``/health``, ``/schema``,
|
| 5 |
+
``/metadata``, ``/mcp``) are wired automatically by the framework.
|
| 6 |
+
|
| 7 |
+
We add a few extras on top:
|
| 8 |
+
|
| 9 |
+
* ``GET /healthz`` - the Day-0 deployment-substrate liveness probe
|
| 10 |
+
(returns Stim/PyMatching/openenv versions). Used by the recurring
|
| 11 |
+
4-hour HF Spaces wakeup ping.
|
| 12 |
+
* ``POST /decode`` - PyMatching baseline demo: takes a hand-crafted
|
| 13 |
+
syndrome and returns the matching-decoder's prediction. Useful for
|
| 14 |
+
the Gradio playground.
|
| 15 |
+
|
| 16 |
+
Run with ``python -m qubit_medic.server.app`` or
|
| 17 |
+
``uvicorn qubit_medic.server.app:app --host 0.0.0.0 --port 7860``.
|
| 18 |
+
"""
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import logging
|
| 22 |
+
import os
|
| 23 |
+
import sys
|
| 24 |
+
from typing import Optional
|
| 25 |
+
|
| 26 |
+
from fastapi import Body, HTTPException
|
| 27 |
+
from openenv.core import create_fastapi_app
|
| 28 |
+
|
| 29 |
+
from qubit_medic.config import DEFAULT_HOST, DEFAULT_PORT
|
| 30 |
+
from qubit_medic.server.environment import DecoderEnvironment
|
| 31 |
+
from qubit_medic.server.openenv_adapter import (
|
| 32 |
+
QubitMedicAction,
|
| 33 |
+
QubitMedicEnvironment,
|
| 34 |
+
QubitMedicObservation,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
logger = logging.getLogger("qubit_medic.server")
|
| 39 |
+
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# --------------------------------------------------------------------------- #
|
| 43 |
+
# Build the OpenEnv-compliant FastAPI app #
|
| 44 |
+
# --------------------------------------------------------------------------- #
|
| 45 |
+
|
| 46 |
+
# The openenv-core factory wires the canonical OpenEnv routes (/reset, /step,
# /state, /health, /schema, /metadata, /mcp) around our adapter classes.
app = create_fastapi_app(
    env=QubitMedicEnvironment,
    action_cls=QubitMedicAction,
    observation_cls=QubitMedicObservation,
)
# Cosmetic metadata surfaced on /docs and (presumably) /metadata.
app.title = "Qubit-Medic OpenEnv"
app.version = os.getenv("QUBIT_MEDIC_VERSION", "1.0.0")
app.description = (
    "RL training environment for LLM-based quantum error-correction "
    "decoders. Built on Stim + PyMatching with five independent verifiable "
    "rewards (logical correction, syndrome consistency, Hamming overlap, "
    "format compliance, PyMatching beat-rate). Wraps "
    "qubit_medic.server.environment.DecoderEnvironment in "
    "openenv.core.Environment - see /metadata, /schema, /docs."
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# --------------------------------------------------------------------------- #
|
| 64 |
+
# Day-0 + demo extras #
|
| 65 |
+
# --------------------------------------------------------------------------- #
|
| 66 |
+
|
| 67 |
+
# Lazy singleton DecoderEnvironment backing the synchronous /decode demo.
# The OpenEnv adapter keeps its own per-instance DecoderEnvironment; this
# one exists solely for the simple baseline endpoint.
_legacy_env: Optional[DecoderEnvironment] = None


def _get_legacy_env() -> DecoderEnvironment:
    """Return the shared demo environment, building (and warming) it once."""
    global _legacy_env
    if _legacy_env is None:
        _legacy_env = DecoderEnvironment()
        # Pre-warm the two levels the demo most commonly requests.
        for warm_level in ("L1_warmup", "L2_target"):
            _legacy_env._cache_for(warm_level)  # noqa: SLF001
    return _legacy_env
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@app.get("/healthz")
def healthz() -> dict:
    """Lightweight liveness probe (Day-0 deployment-substrate test).

    Reports the versions of Stim, PyMatching, and openenv so a single curl
    from a browser or Colab proves both that networking works AND that the
    heavy quantum + RL deps actually loaded. Hit by the recurring 4-hour HF
    Spaces wakeup ping.
    """
    import stim

    def _safe_version(module_name: str) -> str:
        # Defensive: report an import failure instead of letting the
        # liveness probe itself blow up with a 500.
        try:
            module = __import__(module_name)
        except Exception as exc:  # pragma: no cover - defensive
            return f"import-error: {exc}"
        return getattr(module, "__version__", "unknown")

    return {
        "ok": True,
        "service": "qubit-medic",
        "version": app.version,
        "stim_version": stim.__version__,
        "pymatching_version": _safe_version("pymatching"),
        "openenv_version": _safe_version("openenv"),
        "python_version": sys.version.split()[0],
    }
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@app.post("/decode")
def decode(
    syndrome: list[int] = Body(..., embed=True),
    level: str = Body("L2_target", embed=True),
) -> dict:
    """Decode an arbitrary syndrome with PyMatching (baseline) and return
    its predicted Pauli frame and observable flip.

    Intended for the live Gradio demo: a notebook or web page can POST a
    hand-crafted syndrome here and visualise the matching-decoder result.
    """
    import numpy as np

    env = _get_legacy_env()
    # Private cache accessor on the legacy env; presumably bundles the
    # matching decoder + detector layout for this level - confirm in
    # environment.py.
    cache = env._cache_for(level)  # noqa: SLF001
    arr = np.asarray(syndrome, dtype=np.uint8)
    # Reject syndromes whose length doesn't match the level's detector count.
    if arr.shape[0] != cache.layout.num_detectors:
        raise HTTPException(
            status_code=400,
            detail=f"syndrome length {arr.shape[0]} != "
            f"{cache.layout.num_detectors} expected for {level}",
        )
    # Imported lazily - presumably to keep module import light / avoid a
    # cycle with the physics module.
    from qubit_medic.server.physics import (
        predicted_observable_flip,
        pymatching_predicted_pauli_frame,
    )
    pm_obs = int(cache.matching.decode(arr)[0])
    px, pz = pymatching_predicted_pauli_frame(cache.matching, arr, cache.layout)
    return {
        "level": level,
        "syndrome": syndrome,
        "pymatching_observable_flip": pm_obs,
        "pymatching_x_errors": px,
        "pymatching_z_errors": pz,
        "implied_observable_from_x_errors": predicted_observable_flip(
            px, cache.layout
        ),
    }
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# --------------------------------------------------------------------------- #
|
| 154 |
+
# Local entry point #
|
| 155 |
+
# --------------------------------------------------------------------------- #
|
| 156 |
+
|
| 157 |
+
def _main() -> None:
    """Run the app under uvicorn, honouring the QUBIT_MEDIC_* env overrides."""
    import uvicorn

    uvicorn.run(
        "qubit_medic.server.app:app",
        host=os.getenv("QUBIT_MEDIC_HOST", DEFAULT_HOST),
        port=int(os.getenv("QUBIT_MEDIC_PORT", str(DEFAULT_PORT))),
        log_level=os.getenv("LOG_LEVEL", "info").lower(),
    )


if __name__ == "__main__":
    _main()
|
qubit_medic/server/curriculum.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Adaptive curriculum scheduler (Section 4.4 of the plan).
|
| 2 |
+
|
| 3 |
+
Maintains a moving-average logical-correction rate per level and promotes
|
| 4 |
+
the agent to harder levels once the threshold is met. Implements the
|
| 5 |
+
Section 4.4 mixing rules:
|
| 6 |
+
|
| 7 |
+
* Stay at L1 until L1 hits 80%.
|
| 8 |
+
* Then mix L1/L2 with weights 30/70 until L2 hits 70%.
|
| 9 |
+
* Then unlock L3 at 30% weight (with L1/L2 sharing the remaining 70%).
|
| 10 |
+
|
| 11 |
+
The scheduler is *override-able* - eval scripts pass ``forced_level`` to
|
| 12 |
+
hold one configuration steady.
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import random
|
| 17 |
+
from collections import deque
|
| 18 |
+
from dataclasses import dataclass, field
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
from qubit_medic.config import CURRICULUM, CurriculumLevel, level_by_name
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# --------------------------------------------------------------------------- #
|
| 25 |
+
# Per-level moving average #
|
| 26 |
+
# --------------------------------------------------------------------------- #
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class _MovingWindow:
|
| 31 |
+
window_size: int = 100
|
| 32 |
+
history: deque[float] = field(default_factory=deque)
|
| 33 |
+
|
| 34 |
+
def push(self, value: float) -> None:
|
| 35 |
+
self.history.append(value)
|
| 36 |
+
while len(self.history) > self.window_size:
|
| 37 |
+
self.history.popleft()
|
| 38 |
+
|
| 39 |
+
def mean(self) -> float:
|
| 40 |
+
return sum(self.history) / len(self.history) if self.history else 0.0
|
| 41 |
+
|
| 42 |
+
def __len__(self) -> int:
|
| 43 |
+
return len(self.history)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# --------------------------------------------------------------------------- #
|
| 47 |
+
# Scheduler #
|
| 48 |
+
# --------------------------------------------------------------------------- #
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dataclass
class CurriculumScheduler:
    """Picks a curriculum level for each new episode.

    Implements the Section 4.4 mixing rules via three phases gated on the
    per-level moving-average logical-correction rate (each gate also
    requires at least 30 recorded episodes so one lucky streak cannot
    trigger a promotion).
    """

    # Fixed seed so level-mixing is reproducible across server restarts.
    rng: random.Random = field(default_factory=lambda: random.Random(42))
    windows: dict[str, _MovingWindow] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # One moving window per configured level; setdefault keeps any
        # windows injected by the caller.
        for lvl in CURRICULUM:
            self.windows.setdefault(lvl.name, _MovingWindow())

    # ----- public API -----------------------------------------------------

    def update(self, level_name: str, logical_correction: float) -> None:
        """Record one episode's logical-correction outcome."""
        self.windows[level_name].push(float(logical_correction))

    def sample(self, forced_level: Optional[str] = None) -> CurriculumLevel:
        """Return the level to use for the next episode.

        ``forced_level`` bypasses the scheduler entirely (eval scripts).
        """
        if forced_level is not None:
            return level_by_name(forced_level)

        l1, l2, l3 = (level_by_name(n) for n in ("L1_warmup", "L2_target", "L3_stretch"))
        l1_rate = self.windows["L1_warmup"].mean()
        l2_rate = self.windows["L2_target"].mean()
        l1_n = len(self.windows["L1_warmup"])
        l2_n = len(self.windows["L2_target"])

        # Phase A: still working on L1.
        if l1_n < 30 or l1_rate < l1.promotion_threshold:
            return l1

        # Phase B: L1 unlocked, mixing L1 (30%) and L2 (70%).
        if l2_n < 30 or l2_rate < l2.promotion_threshold:
            return l1 if self.rng.random() < 0.30 else l2

        # Phase C: L3 unlocked, splits 20% L1, 50% L2, 30% L3.
        roll = self.rng.random()
        if roll < 0.20:
            return l1
        if roll < 0.70:
            return l2
        return l3

    # ----- introspection (used by /state endpoint and logs) ---------------

    def stats(self) -> dict[str, dict[str, float]]:
        """Per-level moving mean + sample count, for dashboards/logging."""
        return {
            name: {
                "moving_mean": w.mean(),
                "samples": float(len(w)),
            }
            for name, w in self.windows.items()
        }
|
qubit_medic/server/environment.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DecoderEnvironment: the OpenEnv-style env that the LLM trainer talks to.
|
| 2 |
+
|
| 3 |
+
This is the heart of the server (Sections 2.4 + 2.5 of the plan):
|
| 4 |
+
|
| 5 |
+
* ``reset()``: pick a curriculum level, build a circuit, sample a syndrome,
|
| 6 |
+
return a :class:`DecoderObservation`.
|
| 7 |
+
* ``step(raw_response)``: parse the LLM's text, score five rewards, return
|
| 8 |
+
a :class:`StepResult` whose ``info`` dict carries the per-component
|
| 9 |
+
breakdown.
|
| 10 |
+
|
| 11 |
+
Episodes are single-step (Section 2.5): the LLM emits one prediction and
|
| 12 |
+
the episode ends.
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import threading
|
| 17 |
+
import time
|
| 18 |
+
from dataclasses import dataclass, field
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
import pymatching
|
| 22 |
+
|
| 23 |
+
from qubit_medic.config import (
|
| 24 |
+
EPISODE_TIMEOUT_SECONDS,
|
| 25 |
+
PRIMARY_SEED,
|
| 26 |
+
REWARD_WEIGHTS,
|
| 27 |
+
)
|
| 28 |
+
from qubit_medic.models import (
|
| 29 |
+
DecoderAction,
|
| 30 |
+
DecoderObservation,
|
| 31 |
+
DecoderState,
|
| 32 |
+
StepResult,
|
| 33 |
+
)
|
| 34 |
+
from qubit_medic.prompts import build_prompt, parse_action
|
| 35 |
+
from qubit_medic.server import physics
|
| 36 |
+
from qubit_medic.server.curriculum import CurriculumScheduler
|
| 37 |
+
from qubit_medic.server.physics import (
|
| 38 |
+
CircuitLayout,
|
| 39 |
+
SyndromeSample,
|
| 40 |
+
build_circuit,
|
| 41 |
+
build_dem,
|
| 42 |
+
dem_digest,
|
| 43 |
+
extract_layout,
|
| 44 |
+
per_round_x_z_counts,
|
| 45 |
+
sample_episode,
|
| 46 |
+
)
|
| 47 |
+
from qubit_medic.server.rewards import (
|
| 48 |
+
RewardBreakdown,
|
| 49 |
+
compute_all_rewards,
|
| 50 |
+
compute_final_detector_supports,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# --------------------------------------------------------------------------- #
|
| 55 |
+
# Per-level cached compilation - building Stim/PyMatching is the slow step #
|
| 56 |
+
# --------------------------------------------------------------------------- #
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass
class _LevelCache:
    """Compiled Stim/PyMatching artefacts for one curriculum level."""

    circuit: object  # result of physics.build_circuit (typed loosely)
    dem: object      # result of physics.build_dem(circuit)
    matching: pymatching.Matching
    layout: CircuitLayout
    final_detector_supports: dict
    dem_digest: str

    @classmethod
    def build(cls, level) -> "_LevelCache":
        """Compile all per-level artefacts in one pass; callers cache this."""
        circuit = build_circuit(level)
        dem = build_dem(circuit)
        layout = extract_layout(circuit)
        return cls(
            circuit=circuit,
            dem=dem,
            matching=pymatching.Matching.from_detector_error_model(dem),
            layout=layout,
            final_detector_supports=compute_final_detector_supports(layout),
            dem_digest=dem_digest(dem),
        )
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# --------------------------------------------------------------------------- #
|
| 87 |
+
# DecoderEnvironment #
|
| 88 |
+
# --------------------------------------------------------------------------- #
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@dataclass
class _ActiveEpisode:
    """In-flight episode bookkeeping.

    Created by ``DecoderEnvironment.reset`` and popped (consumed) by the
    matching ``step`` call; holds everything needed to score one episode.
    """
    # Externally-visible episode state; step() also rebuilds the prompt
    # from these fields.
    state: DecoderState
    # Ground-truth syndrome/error sample drawn at reset time.
    sample: SyndromeSample
    # Geometry and qubit-index mapping for this level's compiled circuit.
    layout: CircuitLayout
    # Precomputed per-detector supports handed to the reward functions.
    final_detector_supports: dict
    # time.monotonic() captured at reset; step() compares against it for
    # the episode timeout.
    started_at: float
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class DecoderEnvironment:
    """OpenEnv-style env for surface-code decoding.

    Thread-safe by virtue of a single ``_lock``: the FastAPI server is
    expected to be I/O bound, and per-call latency is well under a
    millisecond, so a coarse lock is fine and dramatically simplifies the
    state machine.

    Episodes are single-step: ``reset()`` issues an observation/prompt,
    the matching ``step()`` scores the response and closes the episode.
    """

    def __init__(self, *, base_seed: int = PRIMARY_SEED) -> None:
        """Create an empty environment.

        Args:
            base_seed: seeds the curriculum scheduler's RNG and anchors
                per-episode shot seeds (``base_seed + episode_id``).
        """
        # Plain local import replaces the previous unreadable
        # ``__import__("random")`` hack; behavior is identical.
        import random

        self._lock = threading.Lock()
        self._scheduler = CurriculumScheduler(rng=random.Random(base_seed))
        # Compiled Stim/PyMatching artefacts, keyed by curriculum level name.
        # Entries are never evicted, so once a level has been reset() its
        # cache is guaranteed present for the rest of the process lifetime.
        self._caches: dict[str, _LevelCache] = {}
        self._episode_counter = 0
        self._base_seed = base_seed
        # Episodes awaiting their step() call, keyed by episode id.
        self._active: dict[int, _ActiveEpisode] = {}

    # ----- cache helpers --------------------------------------------------

    def _cache_for(self, level_name: str) -> _LevelCache:
        """Return the compiled artefacts for *level_name*, building lazily."""
        cache = self._caches.get(level_name)
        if cache is not None:
            return cache
        # NOTE(review): function-scope import mirrors the original code;
        # presumably avoids an import cycle with qubit_medic.config - confirm.
        from qubit_medic.config import level_by_name

        cache = _LevelCache.build(level_by_name(level_name))
        self._caches[level_name] = cache
        return cache

    # ----- public API -----------------------------------------------------

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        forced_level: Optional[str] = None,
    ) -> DecoderObservation:
        """Start a new episode and return its observation/prompt.

        Args:
            seed: explicit shot seed; defaults to ``base_seed + episode_id``
                so runs are reproducible without caller effort.
            forced_level: bypass the curriculum scheduler and use this
                level name directly.
        """
        with self._lock:
            self._episode_counter += 1
            ep_id = self._episode_counter
            shot_seed = seed if seed is not None else self._base_seed + ep_id
            level = self._scheduler.sample(forced_level=forced_level)
            cache = self._cache_for(level.name)

            sample = sample_episode(
                circuit=cache.circuit,
                matching=cache.matching,
                layout=cache.layout,
                seed=shot_seed,
            )

            state = DecoderState(
                episode_id=ep_id,
                seed=shot_seed,
                curriculum_level=level.name,
                distance=level.distance,
                rounds=level.rounds,
                p=level.p,
                syndrome_bits=sample.syndrome_bits,
                true_x_errors=sample.pymatching_x_errors,
                true_z_errors=sample.pymatching_z_errors,
                actual_observable_flip=sample.actual_observable_flip,
                pymatching_observable_pred=sample.pymatching_observable_pred,
                x_observable_support=[],  # memory_z task: no X observable
                z_observable_support=list(cache.layout.z_observable_support),
                num_data_qubits=cache.layout.num_data_qubits,
                num_stabilizers=cache.layout.num_ancilla_qubits,
                circuit_text=str(cache.circuit),
                dem_text=str(cache.dem),
            )
            self._active[ep_id] = _ActiveEpisode(
                state=state,
                sample=sample,
                layout=cache.layout,
                final_detector_supports=cache.final_detector_supports,
                started_at=time.monotonic(),
            )

            n_x, n_z = per_round_x_z_counts(cache.layout)
            prompt = build_prompt(
                distance=level.distance,
                rounds=level.rounds,
                p=level.p,
                syndrome_bits=sample.syndrome_bits,
                num_x_stabilizers=n_x,
                num_z_stabilizers=n_z,
                num_data_qubits=cache.layout.num_data_qubits,
            )

            return DecoderObservation(
                prompt=prompt,
                syndrome_bits=sample.syndrome_bits,
                distance=level.distance,
                rounds=level.rounds,
                p=level.p,
                curriculum_level=level.name,
                episode_id=ep_id,
                dem_digest=cache.dem_digest,
            )

    def step(self, raw_response: str, episode_id: int) -> StepResult:
        """Score one LLM response and close the episode.

        Args:
            raw_response: the LLM's completion text.
            episode_id: id returned by the matching :meth:`reset`.

        Returns:
            A :class:`StepResult` with ``done=True`` (single-step episodes)
            and a per-component reward breakdown in ``info``.

        Raises:
            KeyError: if *episode_id* is unknown or already finished.
        """
        with self._lock:
            episode = self._active.pop(episode_id, None)
            if episode is None:
                # Calling step() on an unknown episode ID is a hard error -
                # the trainer didn't follow reset/step pairing.
                raise KeyError(f"unknown or already-finished episode {episode_id}")

            elapsed = time.monotonic() - episode.started_at
            timed_out = elapsed > EPISODE_TIMEOUT_SECONDS

            if timed_out:
                # Hard timeout: zero reward, mark format compliance as zero,
                # close the episode cleanly. Parsing is skipped entirely -
                # previously the response was parsed even on timeout and the
                # result thrown away.
                breakdown = RewardBreakdown(
                    logical_correction=0.0,
                    syndrome_consistency=0.0,
                    hamming_overlap=0.0,
                    format_compliance=0.0,
                    pymatching_beat=0.0,
                    total=0.0,
                )
                action = DecoderAction(
                    raw_response=raw_response,
                    parse_success=False,
                )
            else:
                parsed = parse_action(
                    raw_response=raw_response,
                    num_data_qubits=episode.layout.num_data_qubits,
                )
                # Convert LLM-space qubit IDs (0..N-1) to Stim IDs before
                # scoring; rewards operate in the Stim coordinate system.
                from qubit_medic.prompts import ParseResult

                parsed_stim = ParseResult(
                    x_errors=episode.layout.llm_to_stim(parsed.x_errors),
                    z_errors=episode.layout.llm_to_stim(parsed.z_errors),
                    parse_success=parsed.parse_success,
                    parse_partial=parsed.parse_partial,
                    raw_response=parsed.raw_response,
                )
                breakdown = compute_all_rewards(
                    parsed=parsed_stim,
                    sample=episode.sample,
                    layout=episode.layout,
                    final_detector_supports=episode.final_detector_supports,
                    weights=REWARD_WEIGHTS,
                )
                action = DecoderAction(
                    x_error_qubits=parsed.x_errors,
                    z_error_qubits=parsed.z_errors,
                    raw_response=raw_response,
                    parse_success=parsed.parse_success,
                )

            self._scheduler.update(
                episode.state.curriculum_level,
                logical_correction=breakdown.logical_correction,
            )

            episode.state.last_reward_breakdown = breakdown.as_dict()

            n_x, n_z = per_round_x_z_counts(episode.layout)
            prompt = build_prompt(
                distance=episode.state.distance,
                rounds=episode.state.rounds,
                p=episode.state.p,
                syndrome_bits=episode.state.syndrome_bits,
                num_x_stabilizers=n_x,
                num_z_stabilizers=n_z,
                num_data_qubits=episode.layout.num_data_qubits,
            )
            obs = DecoderObservation(
                prompt=prompt,
                syndrome_bits=episode.state.syndrome_bits,
                distance=episode.state.distance,
                rounds=episode.state.rounds,
                p=episode.state.p,
                curriculum_level=episode.state.curriculum_level,
                episode_id=episode.state.episode_id,
                # Bug fix: this previously returned episode.state.dem_text[:8]
                # (the first 8 characters of the DEM *text*), disagreeing with
                # the real digest reset() reports. The level cache that reset()
                # populated is never evicted, so the digest is always available.
                dem_digest=self._caches[episode.state.curriculum_level].dem_digest,
            )

            info = {
                "rewards": breakdown.as_dict(),
                "parsed_action": action.model_dump(),
                "actual_observable_flip": episode.sample.actual_observable_flip,
                "pymatching_observable_pred": episode.sample.pymatching_observable_pred,
                "pymatching_x_errors": episode.sample.pymatching_x_errors,
                "pymatching_z_errors": episode.sample.pymatching_z_errors,
                "elapsed_seconds": elapsed,
                "timed_out": timed_out,
                "curriculum_stats": self._scheduler.stats(),
            }

            return StepResult(
                observation=obs,
                reward=breakdown.total,
                done=True,  # single-step episodes
                truncated=timed_out,
                info=info,
            )

    # ----- introspection --------------------------------------------------

    def health(self) -> dict:
        """Liveness/diagnostics snapshot used by the health endpoints."""
        with self._lock:
            return {
                "ok": True,
                "episodes_started": self._episode_counter,
                "active_episodes": len(self._active),
                "curriculum": self._scheduler.stats(),
                "cached_levels": list(self._caches.keys()),
            }
|
qubit_medic/server/openenv_adapter.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenEnv-compliant adapter around :class:`DecoderEnvironment`.
|
| 2 |
+
|
| 3 |
+
This wrapper satisfies the submission requirement *"Use OpenEnv (latest
|
| 4 |
+
release). Build on top of the framework; don't reinvent the wheel."* by
|
| 5 |
+
exposing our underlying :class:`qubit_medic.server.environment.DecoderEnvironment`
|
| 6 |
+
through the official ``openenv.core.Environment`` base class.
|
| 7 |
+
|
| 8 |
+
The adapter is intentionally thin: it just translates between OpenEnv's
|
| 9 |
+
``Action`` / ``Observation`` / ``State`` Pydantic shapes and our internal
|
| 10 |
+
``DecoderObservation`` / ``DecoderAction`` / ``StepResult``. All the
|
| 11 |
+
physics, reward scoring, curriculum, and episode bookkeeping continue to
|
| 12 |
+
live in :class:`DecoderEnvironment` - that code is *the* tested,
|
| 13 |
+
production path.
|
| 14 |
+
|
| 15 |
+
Usage
|
| 16 |
+
-----
|
| 17 |
+
|
| 18 |
+
The OpenEnv-compliant FastAPI app is created with::
|
| 19 |
+
|
| 20 |
+
from openenv.core import create_fastapi_app
|
| 21 |
+
from qubit_medic.server.openenv_adapter import (
|
| 22 |
+
QubitMedicEnvironment, QubitMedicAction, QubitMedicObservation,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
app = create_fastapi_app(
|
| 26 |
+
env=QubitMedicEnvironment,
|
| 27 |
+
action_cls=QubitMedicAction,
|
| 28 |
+
observation_cls=QubitMedicObservation,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
This registers the canonical OpenEnv routes:
|
| 32 |
+
|
| 33 |
+
* ``POST /reset`` - body ``{"seed": int?, "episode_id": str?}``
|
| 34 |
+
* ``POST /step`` - body ``{"action": {...QubitMedicAction...},
|
| 35 |
+
"timeout_s": float?, "request_id": str?}``
|
| 36 |
+
* ``GET /state`` - returns the current :class:`QubitMedicState`
|
| 37 |
+
* ``GET /health`` - liveness probe
|
| 38 |
+
* ``GET /schema`` - JSON Schema for the action/observation models
|
| 39 |
+
* ``GET /metadata`` - environment metadata
|
| 40 |
+
* ``POST /mcp`` - Model Context Protocol endpoint
|
| 41 |
+
* ``GET /docs`` - Swagger UI (auto-generated by FastAPI)
|
| 42 |
+
|
| 43 |
+
We additionally mount our own ``/healthz`` (Day-0 contract) and
|
| 44 |
+
``/decode`` (PyMatching baseline demo) on the returned app from
|
| 45 |
+
``qubit_medic.server.app``.
|
| 46 |
+
"""
|
| 47 |
+
from __future__ import annotations
|
| 48 |
+
|
| 49 |
+
from typing import Any, Optional
|
| 50 |
+
|
| 51 |
+
from openenv.core import Action, Environment, Observation, State
|
| 52 |
+
from openenv.core.env_server.types import EnvironmentMetadata
|
| 53 |
+
from pydantic import ConfigDict, Field
|
| 54 |
+
|
| 55 |
+
from qubit_medic.server.environment import DecoderEnvironment
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# --------------------------------------------------------------------------- #
|
| 59 |
+
# Process-wide singleton #
|
| 60 |
+
# --------------------------------------------------------------------------- #
|
| 61 |
+
# OpenEnv's HTTP server (simulation mode) instantiates a *fresh* Environment
|
| 62 |
+
# via the factory on every /reset and /step call. Our episode bookkeeping
|
| 63 |
+
# (the `_active` dict) lives inside DecoderEnvironment, so we route every
|
| 64 |
+
# QubitMedicEnvironment instance through the same DecoderEnvironment.
|
| 65 |
+
# This keeps reset() -> step() pairing intact across stateless HTTP calls
|
| 66 |
+
# while remaining fully compatible with OpenEnv's WebSocket session model
|
| 67 |
+
# (each WS session still gets its own QubitMedicEnvironment wrapper).
|
| 68 |
+
|
| 69 |
+
_INNER_SINGLETON: Optional[DecoderEnvironment] = None


def _get_shared_inner() -> DecoderEnvironment:
    """Return the process-wide DecoderEnvironment, building it lazily."""
    global _INNER_SINGLETON
    if _INNER_SINGLETON is not None:
        return _INNER_SINGLETON
    env = DecoderEnvironment()
    # Pre-warm the compiled-circuit caches for the two levels every run
    # touches, so the first /reset is not paying compilation cost.
    for level_name in ("L1_warmup", "L2_target"):
        env._cache_for(level_name)  # noqa: SLF001 - intentional pre-warm
    _INNER_SINGLETON = env
    return _INNER_SINGLETON
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# --------------------------------------------------------------------------- #
|
| 84 |
+
# OpenEnv-flavoured Action / Observation / State #
|
| 85 |
+
# --------------------------------------------------------------------------- #
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class QubitMedicAction(Action):
    """LLM-emitted action: the raw text the model generated.

    The server parses this into ``x_error_qubits`` / ``z_error_qubits`` via
    :func:`qubit_medic.prompts.parse_action`. We keep the wire format
    *just the raw string* so the server retains full control over parsing
    (and so the trainer's reward function can audit unparseable outputs).

    The trainer is also free to populate ``parsed_x_errors`` /
    ``parsed_z_errors`` directly when it wants to bypass the LLM (useful
    for baseline policies and unit tests).
    """

    # Inherit Action.model_config (extra='forbid', validate_assignment=True).
    # NOTE(review): Field descriptions below surface in the served JSON
    # schema; changing their text changes the wire-visible schema.
    raw_response: str = Field(
        default="",
        description="Raw LLM completion text. Server parses to x/z error lists.",
    )
    # LLM-space ids (0..N-1); the environment maps them to Stim ids.
    parsed_x_errors: Optional[list[int]] = Field(
        default=None,
        description="Optional pre-parsed X-error qubit ids (LLM-space). "
        "When provided, the server skips text parsing.",
    )
    parsed_z_errors: Optional[list[int]] = Field(
        default=None,
        description="Optional pre-parsed Z-error qubit ids (LLM-space).",
    )
    # Ties this step() to the reset() that produced the episode; the
    # wrapper falls back to its most recent reset when omitted.
    episode_id: Optional[int] = Field(
        default=None,
        description="Server-assigned episode id from the matching reset(). "
        "If omitted, the most-recent active episode is used.",
    )
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class QubitMedicObservation(Observation):
    """OpenEnv observation - mirrors :class:`DecoderObservation` plus the
    standard OpenEnv ``done`` / ``reward`` fields.

    The ``info`` dict (returned by ``step``) carries the per-component
    reward breakdown, the ground-truth observable flip, and the PyMatching
    baseline prediction so the trainer can score auxiliary metrics.
    """

    # Strict wire shape: unknown fields rejected, assignments re-validated.
    # arbitrary_types_allowed covers the non-pydantic values inside `info`.
    model_config = ConfigDict(extra="forbid", validate_assignment=True,
                              arbitrary_types_allowed=True)

    prompt: str = Field(default="", description="Pre-formatted LLM prompt.")
    syndrome_bits: list[int] = Field(default_factory=list,
                                     description="Detector activations (0/1).")
    distance: int = Field(default=0, description="Code distance for this episode.")
    rounds: int = Field(default=0, description="Number of stabilizer rounds.")
    p: float = Field(default=0.0, description="SI1000 base error rate.")
    curriculum_level: str = Field(default="",
                                  description="Curriculum level name.")
    episode_id: int = Field(default=0,
                            description="Server-assigned episode counter.")
    dem_digest: str = Field(default="",
                            description="Short hash of the detector error model.")
    # Populated only on step(); reset() sends {"event": "reset"}.
    info: dict[str, Any] = Field(default_factory=dict,
                                 description="Per-step extras (reward "
                                             "breakdown, ground-truth flip, "
                                             "PyMatching baseline, etc.).")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class QubitMedicState(State):
    """Externally-visible state. We expose only the curriculum + episode
    counters; physics-truth fields stay server-side to prevent reward
    hacking (see :mod:`qubit_medic.models.DecoderState` doc-comment)."""

    # extra="allow" (unlike the observation) so the State base class's own
    # fields and any future additions pass through without schema breakage.
    model_config = ConfigDict(extra="allow", validate_assignment=True,
                              arbitrary_types_allowed=True)

    # Mirrors DecoderEnvironment.health()["episodes_started"].
    episodes_started: int = 0
    # Episodes that have been reset() but not yet step()-ed.
    active_episodes: int = 0
    # Curriculum level names whose Stim/PyMatching artefacts are compiled.
    cached_levels: list[str] = Field(default_factory=list)
    # Per-level scheduler stats (moving mean + sample count).
    curriculum: dict[str, Any] = Field(default_factory=dict)
    # Reward breakdown of the wrapper's most recent step(), if any.
    last_reward_breakdown: Optional[dict[str, float]] = None
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# --------------------------------------------------------------------------- #
|
| 168 |
+
# Environment wrapper #
|
| 169 |
+
# --------------------------------------------------------------------------- #
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class QubitMedicEnvironment(Environment[QubitMedicAction,
                                        QubitMedicObservation,
                                        QubitMedicState]):
    """OpenEnv-compliant view of :class:`DecoderEnvironment`.

    Single-step episodes (``done=True`` after every ``step``). The OpenEnv
    HTTP server gets a fresh instance per WebSocket session if
    ``SUPPORTS_CONCURRENT_SESSIONS=True``; we set it to ``False`` because
    our DecoderEnvironment uses a single Stim cache + a coarse lock, which
    is simpler than per-session state and good enough for the GRPO
    training loop.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = False

    def __init__(self) -> None:
        super().__init__()
        # Share the underlying DecoderEnvironment across every wrapper
        # instance the HTTP server creates - see _get_shared_inner.
        self._inner = _get_shared_inner()
        # Fallback episode id for step() calls that omit action.episode_id.
        self._last_episode_id: Optional[int] = None
        self._last_reward_breakdown: Optional[dict[str, float]] = None

    # ----- abstract API --------------------------------------------------- #

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> QubitMedicObservation:
        """Start a new episode on the shared inner environment.

        ``episode_id`` (OpenEnv's caller-supplied id) is accepted for
        contract compatibility but not used: the inner environment assigns
        its own integer episode ids.
        ``kwargs["forced_level"]`` may name a curriculum level to bypass
        the scheduler.
        """
        forced_level = kwargs.get("forced_level")
        obs = self._inner.reset(seed=seed, forced_level=forced_level)
        self._last_episode_id = obs.episode_id
        self._last_reward_breakdown = None
        return QubitMedicObservation(
            prompt=obs.prompt,
            syndrome_bits=list(obs.syndrome_bits),
            distance=obs.distance,
            rounds=obs.rounds,
            p=obs.p,
            curriculum_level=obs.curriculum_level,
            episode_id=obs.episode_id,
            dem_digest=obs.dem_digest,
            done=False,
            reward=None,
            info={"event": "reset"},
        )

    def step(
        self,
        action: QubitMedicAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> QubitMedicObservation:
        """Score one action and close its episode.

        Raises:
            RuntimeError: if no episode id is available (step before reset).
        """
        # Prefer the caller-supplied episode id; fall back to the most
        # recent reset() on this wrapper.
        ep = action.episode_id if action.episode_id is not None else self._last_episode_id
        if ep is None:
            raise RuntimeError(
                "step() called before reset(); no active episode to score."
            )

        # If the trainer pre-parsed the action, format a synthetic raw
        # response in the canonical "X: ... | Z: ..." shape so the server's
        # parser produces the same x/z lists.
        if action.parsed_x_errors is not None or action.parsed_z_errors is not None:
            xs = action.parsed_x_errors or []
            zs = action.parsed_z_errors or []
            raw = f"<answer>X: {','.join(map(str, xs))} | Z: {','.join(map(str, zs))}</answer>"
        else:
            raw = action.raw_response

        result = self._inner.step(raw_response=raw, episode_id=ep)
        # Cached so the /state endpoint can surface the latest breakdown.
        self._last_reward_breakdown = result.info.get("rewards")

        return QubitMedicObservation(
            prompt=result.observation.prompt,
            syndrome_bits=list(result.observation.syndrome_bits),
            distance=result.observation.distance,
            rounds=result.observation.rounds,
            p=result.observation.p,
            curriculum_level=result.observation.curriculum_level,
            episode_id=result.observation.episode_id,
            dem_digest=result.observation.dem_digest,
            done=result.done,
            reward=float(result.reward),
            info=result.info,
        )

    @property
    def state(self) -> QubitMedicState:
        """Snapshot of the shared inner environment's health counters."""
        h = self._inner.health()
        return QubitMedicState(
            episode_id=str(self._last_episode_id)
            if self._last_episode_id is not None else None,
            # NOTE(review): step_count is filled with the global episode
            # counter rather than a per-session step count - confirm this
            # matches how the State base class defines step_count.
            step_count=int(h.get("episodes_started", 0)),
            episodes_started=int(h.get("episodes_started", 0)),
            active_episodes=int(h.get("active_episodes", 0)),
            cached_levels=list(h.get("cached_levels", [])),
            curriculum=dict(h.get("curriculum", {})),
            last_reward_breakdown=self._last_reward_breakdown,
        )

    # ----- nice-to-haves -------------------------------------------------- #

    def get_metadata(self) -> EnvironmentMetadata:
        """Static descriptor served at /metadata."""
        return EnvironmentMetadata(
            name="QubitMedicEnvironment",
            description=(
                "RL training environment for LLM-based quantum error-"
                "correction decoders. Built on Stim + PyMatching. Five "
                "verifiable rewards (logical correction, syndrome consistency, "
                "Hamming overlap, format compliance, PyMatching beat-rate)."
            ),
            version="1.0.0",
        )

    def close(self) -> None:  # nothing to clean up
        return None
|
qubit_medic/server/physics.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Stim + PyMatching wrapper - the physics engine (Section 2.4 of the plan).
|
| 2 |
+
|
| 3 |
+
This module never makes decoding decisions: it builds circuits, samples
|
| 4 |
+
syndromes, computes baselines, and exposes the observable's support on the
|
| 5 |
+
data qubits so the reward functions can score predictions deterministically.
|
| 6 |
+
|
| 7 |
+
Two design choices worth flagging up-front:
|
| 8 |
+
|
| 9 |
+
* The LLM's action is a **terminal Pauli frame** on data qubits (the X and Z
|
| 10 |
+
errors on each data qubit at the moment of final measurement). This
|
| 11 |
+
representation is exact for the rotated memory_z task and lets us reuse
|
| 12 |
+
Stim/PyMatching ground-truth machinery. The trade-off is documented in
|
| 13 |
+
``rewards.py``: the syndrome-consistency reward (Reward 2) only constrains
|
| 14 |
+
the *final-round* detectors. Earlier rounds are silent w.r.t. an
|
| 15 |
+
end-of-circuit Pauli frame; that is intentional and made explicit in the
|
| 16 |
+
reward's docstring.
|
| 17 |
+
|
| 18 |
+
* "Ground-truth error pattern" for Reward 3 is taken to be the
|
| 19 |
+
**PyMatching-most-probable error pattern** explaining the syndrome
|
| 20 |
+
(extracted via ``Matching.decode_to_edges_array``). This is the
|
| 21 |
+
near-optimal canonical choice and matches what the AlphaQubit baseline
|
| 22 |
+
comparison uses. The README's *honesty note* repeats this.
|
| 23 |
+
"""
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import hashlib
|
| 27 |
+
from dataclasses import dataclass
|
| 28 |
+
from typing import Optional
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
import pymatching
|
| 32 |
+
import stim
|
| 33 |
+
|
| 34 |
+
from qubit_medic.config import (
|
| 35 |
+
CODE_TASK,
|
| 36 |
+
CurriculumLevel,
|
| 37 |
+
SI1000Rates,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# --------------------------------------------------------------------------- #
|
| 42 |
+
# Circuit + DEM construction #
|
| 43 |
+
# --------------------------------------------------------------------------- #
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def build_circuit(level: CurriculumLevel) -> stim.Circuit:
    """Generate a Stim ``rotated_memory_z`` circuit for one curriculum level.

    Noise strengths come from the SI1000 family parameterised by ``level.p``.
    """
    noise_kwargs = SI1000Rates.from_p(level.p).as_stim_kwargs()
    return stim.Circuit.generated(
        CODE_TASK,
        distance=level.distance,
        rounds=level.rounds,
        **noise_kwargs,
    )
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def build_dem(circuit: stim.Circuit) -> stim.DetectorErrorModel:
    """Derive the circuit's detector error model.

    ``decompose_errors=True`` is mandatory here because the DEM is consumed
    by PyMatching downstream.
    """
    return circuit.detector_error_model(decompose_errors=True)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def dem_digest(dem: stim.DetectorErrorModel) -> str:
    """Short 8-hex-char fingerprint of the DEM's textual form.

    Used to group training logs by identical error models.
    """
    payload = str(dem).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()[:8]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# --------------------------------------------------------------------------- #
|
| 68 |
+
# Layout introspection - figure out data qubits, ancillas, observable support #
|
| 69 |
+
# --------------------------------------------------------------------------- #
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@dataclass(frozen=True)
class CircuitLayout:
    """Static per-episode facts about one generated circuit.

    Two qubit indexings coexist:

    * **Stim IDs** (``data_qubits``) - the physical IDs Stim emits
      (e.g. ``(1, 3, 5, 8, 10, 12, 15, 17, 19)`` at distance 3); these
      are what Stim/PyMatching speak.
    * **LLM IDs** - consecutive ``0..num_data_qubits-1``, advertised in
      the prompt because small consecutive ints are far easier for a
      language model to emit reliably.

    :meth:`llm_to_stim` and :meth:`stim_to_llm` translate between them.
    All server-internal scoring uses Stim IDs; only the prompt formatter
    and parser convert at the boundary.
    """

    # Stim IDs of data qubits (terminal ``M`` targets), sorted ascending.
    data_qubits: tuple[int, ...]
    # (x, y) of each data qubit, parallel to ``data_qubits`` (Reward 3 snapping).
    data_qubit_coords: tuple[tuple[float, float], ...]
    # Physical qubit IDs holding stabiliser measurements (``MR``).
    ancilla_qubits: tuple[int, ...]
    # Data qubits XOR'd into the logical Z observable; an X error there flips it.
    z_observable_support: tuple[int, ...]
    # Nominal 0-based round of each detector (from the DETECTOR(x, y, t) coord).
    detector_round: tuple[int, ...]
    # (x, y) coordinate of each detector, used by Reward 3.
    detector_coords: tuple[tuple[float, float], ...]
    # True where the detector watches an X-stabiliser; for Stim's rotated
    # surface code X sits at (x + y) mod 4 == 2, Z at (x + y) mod 4 == 0.
    detector_is_x_type: tuple[bool, ...]
    # Detectors of the last timeslice - the only ones a terminal Pauli
    # frame can affect (Reward 2).
    final_detectors: tuple[int, ...]

    num_data_qubits: int
    num_ancilla_qubits: int
    num_detectors: int
    num_observables: int

    # ----- LLM <-> Stim qubit-ID remapping ---------------------------------

    def llm_to_stim(self, llm_ids: list[int]) -> list[int]:
        """Map consecutive LLM IDs onto physical Stim IDs.

        Out-of-range IDs are silently dropped (the parser already bounds
        them; this is defence-in-depth).
        """
        upper = len(self.data_qubits)
        return [self.data_qubits[i] for i in llm_ids if 0 <= i < upper]

    def stim_to_llm(self, stim_ids: list[int]) -> list[int]:
        """Inverse of :meth:`llm_to_stim`; unknown Stim IDs are dropped.

        Used to render targets in the SFT data and the imitator policy.
        """
        index_of = {q: i for i, q in enumerate(self.data_qubits)}
        return [index_of[q] for q in stim_ids if q in index_of]
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _walk_measurement_records(
    circuit: stim.Circuit,
) -> tuple[list[int], list[Optional[str]]]:
    """Replay the circuit (no sampling) to map each measurement record to
    the qubit it measured.

    Returns parallel lists: ``qubits[i]`` is the qubit id behind record
    ``i`` and ``instrs[i]`` the gate name that produced it.
    """
    record_qubits: list[int] = []
    record_gates: list[Optional[str]] = []
    measuring_gates = frozenset({
        "M", "MX", "MY", "MZ",
        "MR", "MRX", "MRY", "MRZ",
        "MPP",
    })

    def _visit(block: stim.Circuit, times: int = 1) -> None:
        for _ in range(times):
            for instruction in block:
                if isinstance(instruction, stim.CircuitRepeatBlock):
                    # Recurse into REPEAT bodies the right number of times.
                    _visit(instruction.body_copy(), instruction.repeat_count)
                elif instruction.name in measuring_gates:
                    gate = instruction.name
                    for target in instruction.targets_copy():
                        if target.is_qubit_target:
                            record_qubits.append(target.qubit_value)
                            record_gates.append(gate)

    _visit(circuit)
    return record_qubits, record_gates
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def extract_layout(circuit: stim.Circuit) -> CircuitLayout:
    """Walk the circuit once to build a full :class:`CircuitLayout`.

    Pure introspection, no sampling: everything derives from the instruction
    stream (measurement records, ``OBSERVABLE_INCLUDE``, ``QUBIT_COORDS``)
    plus the detector coordinate table.

    Fix vs. the previous revision: the first-seen-order list of data qubits
    (``data_qubits_in_order``) was built but never consumed, so it has been
    removed - only the sorted set of terminally-measured qubits is needed.
    """
    flat = circuit.flattened()
    measurement_qubits, measurement_instrs = _walk_measurement_records(circuit)

    # Data qubits = those measured by terminal ``M`` (destructive, no reset).
    data_qubits = tuple(sorted(
        {q for q, instr in zip(measurement_qubits, measurement_instrs)
         if instr == "M"}
    ))

    # Ancilla qubits = everything measured by MR (reset after measurement).
    ancilla_qubits = tuple(
        sorted({q for q, instr in zip(measurement_qubits, measurement_instrs)
                if instr == "MR"})
    )

    # Observable support: walk OBSERVABLE_INCLUDE entries and resolve their
    # rec[-k] back to qubit IDs via the measurement record table.
    obs_support: dict[int, set[int]] = {}
    for inst in flat:
        if inst.name == "OBSERVABLE_INCLUDE":
            args = inst.gate_args_copy()
            obs_idx = int(args[0]) if args else 0
            for t in inst.targets_copy():
                if t.is_measurement_record_target:
                    actual = len(measurement_qubits) + t.value  # value is negative
                    if 0 <= actual < len(measurement_qubits):
                        obs_support.setdefault(obs_idx, set()).add(
                            measurement_qubits[actual]
                        )
    z_obs = tuple(sorted(obs_support.get(0, set())))

    # Qubit coordinates from QUBIT_COORDS instructions.
    qubit_coords: dict[int, tuple[float, float]] = {}
    for inst in flat:
        if inst.name == "QUBIT_COORDS":
            args = inst.gate_args_copy()
            x = float(args[0]) if len(args) >= 1 else 0.0
            y = float(args[1]) if len(args) >= 2 else 0.0
            for t in inst.targets_copy():
                if t.is_qubit_target:
                    qubit_coords[t.qubit_value] = (x, y)
    data_qubit_coords = tuple(qubit_coords.get(q, (0.0, 0.0)) for q in data_qubits)

    # Detector coordinates - last value of the coordinate tuple is the round.
    det_coords_raw = circuit.get_detector_coordinates()
    num_dets = circuit.num_detectors
    rounds_per_det: list[int] = []
    is_x_type: list[bool] = []
    detector_coords: list[tuple[float, float]] = []
    for i in range(num_dets):
        c = det_coords_raw.get(i, ())
        if not c:
            # Detector without declared coordinates: park it in round 0.
            rounds_per_det.append(0)
            is_x_type.append(False)
            detector_coords.append((0.0, 0.0))
            continue
        round_idx = int(c[-1]) if len(c) >= 3 else 0
        rounds_per_det.append(round_idx)
        x = float(c[0]) if len(c) >= 1 else 0.0
        y = float(c[1]) if len(c) >= 2 else 0.0
        detector_coords.append((x, y))
        # X-stabilisers sit at (x + y) % 4 == 2 in Stim's generator.
        is_x_type.append((int(x + y) % 4) == 2)

    final_round = max(rounds_per_det) if rounds_per_det else 0
    final_dets = tuple(i for i, r in enumerate(rounds_per_det) if r == final_round)

    return CircuitLayout(
        data_qubits=data_qubits,
        data_qubit_coords=data_qubit_coords,
        ancilla_qubits=ancilla_qubits,
        z_observable_support=z_obs,
        detector_round=tuple(rounds_per_det),
        detector_coords=tuple(detector_coords),
        detector_is_x_type=tuple(is_x_type),
        final_detectors=final_dets,
        num_data_qubits=len(data_qubits),
        num_ancilla_qubits=len(ancilla_qubits),
        num_detectors=num_dets,
        num_observables=circuit.num_observables,
    )
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# --------------------------------------------------------------------------- #
|
| 266 |
+
# Sampling and decoding #
|
| 267 |
+
# --------------------------------------------------------------------------- #
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
@dataclass(frozen=True)
class SyndromeSample:
    """One noisy episode: detector activations, ground-truth observable
    flip, and PyMatching's prediction (used by Rewards 1 and 5)."""

    syndrome_bits: list[int]  # one 0/1 entry per detector of the circuit
    actual_observable_flip: int  # Stim's ground-truth logical flip (0/1)
    pymatching_observable_pred: int  # PyMatching's observable prediction (0/1)
    pymatching_x_errors: list[int]  # Pauli frame at end of circuit (X part)
    pymatching_z_errors: list[int]  # Pauli frame at end of circuit (Z part)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def sample_episode(
    circuit: stim.Circuit,
    matching: pymatching.Matching,
    layout: CircuitLayout,
    seed: Optional[int] = None,
) -> SyndromeSample:
    """Sample one shot, decode it with PyMatching, and bundle the result."""
    det_sampler = circuit.compile_detector_sampler(seed=seed)
    detections, observables = det_sampler.sample(1, separate_observables=True)
    shot = detections[0].astype(np.uint8)
    true_flip = int(observables[0, 0]) if observables.shape[1] else 0

    # Observable-level baseline prediction from the matching decoder.
    baseline_obs = int(matching.decode(shot)[0])

    # Physical Pauli-frame prediction (per-edge decoding snapped to qubits).
    frame_x, frame_z = pymatching_predicted_pauli_frame(
        matching=matching, syndrome=shot, layout=layout,
    )

    return SyndromeSample(
        syndrome_bits=shot.tolist(),
        actual_observable_flip=true_flip,
        pymatching_observable_pred=baseline_obs,
        pymatching_x_errors=frame_x,
        pymatching_z_errors=frame_z,
    )
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def pymatching_predicted_pauli_frame(
    matching: pymatching.Matching,
    syndrome: np.ndarray,
    layout: CircuitLayout,
) -> tuple[list[int], list[int]]:
    """Convert PyMatching's per-edge prediction into a data-qubit Pauli frame.

    Every matching-graph edge corresponds to a DEM error mechanism joining
    two detectors (or one detector and the boundary). The responsible data
    qubit sits geometrically between the endpoints on the surface-code
    grid, so we snap the midpoint of the endpoint coordinates to the
    nearest data qubit.

    This frame is the ground-truth reference for Reward 3 (Hamming
    overlap). Z-stabiliser endpoints (``(x+y) mod 4 == 0``) witness X
    errors on data qubits; X-stabiliser endpoints witness Z errors.
    Boundary edges snap from their single real endpoint.
    """
    try:
        edges = matching.decode_to_edges_array(syndrome)
    except Exception:
        # Defensive: an undecodable syndrome yields an empty frame.
        return [], []

    if edges is None or len(edges) == 0:
        return [], []

    data_qubits = layout.data_qubits
    data_coords = layout.data_qubit_coords
    det_coords = layout.detector_coords
    det_is_x = layout.detector_is_x_type
    n_dets = len(det_coords)

    def _snap(x: float, y: float) -> int:
        # Nearest data qubit by squared distance; first entry wins ties.
        pairs = zip(data_qubits, data_coords)
        return min(
            pairs,
            key=lambda qc: (qc[1][0] - x) ** 2 + (qc[1][1] - y) ** 2,
        )[0]

    x_errs: set[int] = set()
    z_errs: set[int] = set()
    for edge in edges:
        a, b = int(edge[0]), int(edge[1])
        ca = det_coords[a] if 0 <= a < n_dets else None
        cb = det_coords[b] if 0 <= b < n_dets else None
        if ca is None and cb is None:
            continue  # both endpoints are boundary sentinels
        if cb is None:
            mid_x, mid_y = ca
            ref_is_x = det_is_x[a]
        elif ca is None:
            mid_x, mid_y = cb
            ref_is_x = det_is_x[b]
        else:
            mid_x = (ca[0] + cb[0]) / 2.0
            mid_y = (ca[1] + cb[1]) / 2.0
            ref_is_x = det_is_x[a] if 0 <= a < n_dets else det_is_x[b]
        qubit = _snap(mid_x, mid_y)
        (z_errs if ref_is_x else x_errs).add(qubit)

    return sorted(x_errs), sorted(z_errs)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
# --------------------------------------------------------------------------- #
|
| 381 |
+
# Predicted-observable computation (used by Reward 1) #
|
| 382 |
+
# --------------------------------------------------------------------------- #
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def predicted_observable_flip(
    predicted_x_qubits: list[int],
    layout: CircuitLayout,
) -> int:
    """Logical-Z flip implied by a predicted terminal Pauli frame (0 or 1).

    Only X errors on data qubits in ``layout.z_observable_support`` matter:
    Z errors commute with the destructive Z measurement and can never flip
    the Z observable.
    """
    support = set(layout.z_observable_support)
    return sum(1 for q in predicted_x_qubits if q in support) % 2
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def rectify_pauli_frame_to_observable(
    x_errors: list[int],
    z_errors: list[int],
    target_observable_flip: int,
    layout: CircuitLayout,
) -> tuple[list[int], list[int]]:
    """Toggle one support qubit so the implied observable matches the target.

    Used by the SFT data generator and the PyMatching imitator policy: the
    snap-to-data-qubit edge mapping (:func:`pymatching_predicted_pauli_frame`)
    is only approximately faithful, while PyMatching's *observable*
    prediction is exact. When the implied parity disagrees with the target
    we patch the X frame by toggling the smallest eligible data qubit on
    the observable support. Z errors are returned untouched - they cannot
    affect a Z observable.
    """
    if predicted_observable_flip(x_errors, layout) == target_observable_flip:
        return list(x_errors), list(z_errors)

    support = list(layout.z_observable_support)
    if not support:
        # Nothing to toggle; return the frame unchanged.
        return list(x_errors), list(z_errors)

    frame = set(x_errors)
    on_support = sorted(frame & set(support))
    if on_support:
        # Remove the smallest support qubit currently flipping the observable.
        frame.discard(on_support[0])
    else:
        # Add the smallest support qubit to introduce the missing flip.
        frame.add(support[0])
    return sorted(frame), list(z_errors)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
# --------------------------------------------------------------------------- #
|
| 438 |
+
# Stabiliser counts - derived from layout #
|
| 439 |
+
# --------------------------------------------------------------------------- #
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def detector_round_split(layout: CircuitLayout, syndrome_bits: list[int]) -> dict[int, list[int]]:
    """Group detector bits by their nominal round (used for prompt formatting).

    Bits whose index exceeds the known detector table fall back to round 0.
    """
    rounds = layout.detector_round
    n_known = len(rounds)
    grouped: dict[int, list[int]] = {}
    for idx, bit in enumerate(syndrome_bits):
        key = rounds[idx] if idx < n_known else 0
        grouped.setdefault(key, []).append(bit)
    return grouped
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def per_round_x_z_counts(layout: CircuitLayout) -> tuple[int, int]:
    """Best-effort count of X-type and Z-type stabiliser detectors per round.

    For a rotated surface code at distance d there are (d^2 - 1)/2 of each
    type; rather than hard-code that, we read it off the fullest round in
    the layout's detector table.
    """
    by_round: dict[int, list[bool]] = {}
    for det_idx, rnd in enumerate(layout.detector_round):
        by_round.setdefault(rnd, []).append(layout.detector_is_x_type[det_idx])
    if not by_round:
        return 0, 0
    fullest = max(by_round.values(), key=len)
    x_count = sum(fullest)
    return x_count, len(fullest) - x_count
|
qubit_medic/server/rewards.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""The five reward functions (Section 3 of the plan).
|
| 2 |
+
|
| 3 |
+
Design contract (from Section 3.6):
|
| 4 |
+
|
| 5 |
+
* Each reward is a pure function ``(action, state, layout) -> float in [0, 1]``.
|
| 6 |
+
* Rewards never observe each other - they're independent by construction so
|
| 7 |
+
the LLM can't satisfy one at the expense of another without genuine task
|
| 8 |
+
understanding.
|
| 9 |
+
* The combined reward is a weighted sum (weights in :mod:`qubit_medic.config`)
|
| 10 |
+
clamped to ``[0, 1]``.
|
| 11 |
+
* Every per-component score is reported in the ``info`` dict so logs can
|
| 12 |
+
surface reward-hacking early (Section 3.7).
|
| 13 |
+
|
| 14 |
+
A note on Reward 2 and Reward 3 ground truth - see ``physics.py``: the LLM
|
| 15 |
+
predicts a *terminal Pauli frame*, which fully determines the logical-Z
|
| 16 |
+
observable but only constrains the *final-round* detectors. Earlier rounds'
|
| 17 |
+
detectors are intentionally unscored. Reward 3 compares against PyMatching's
|
| 18 |
+
near-optimal Pauli-frame prediction (the canonical decoder reference used in
|
| 19 |
+
AlphaQubit's Nature paper).
|
| 20 |
+
"""
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
|
| 25 |
+
from qubit_medic.config import REWARD_WEIGHTS
|
| 26 |
+
from qubit_medic.prompts import ParseResult
|
| 27 |
+
from qubit_medic.server.physics import (
|
| 28 |
+
CircuitLayout,
|
| 29 |
+
SyndromeSample,
|
| 30 |
+
predicted_observable_flip,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# --------------------------------------------------------------------------- #
|
| 35 |
+
# Reward 1: logical correction success #
|
| 36 |
+
# --------------------------------------------------------------------------- #
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def reward_logical_correction(
    parsed: ParseResult,
    sample: SyndromeSample,
    layout: CircuitLayout,
) -> float:
    """Reward 1 - did the predicted correction preserve the logical state?

    The predicted X errors are applied as an end-of-circuit Pauli frame;
    if the implied logical-Z flip equals the flip Stim actually recorded,
    the correction succeeded. Binary: 1.0 on success, 0.0 otherwise.

    This is the unfakeable reward - it depends only on Stim's ground truth.
    """
    implied_flip = predicted_observable_flip(parsed.x_errors, layout)
    return float(implied_flip == sample.actual_observable_flip)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# --------------------------------------------------------------------------- #
|
| 58 |
+
# Reward 2: syndrome consistency #
|
| 59 |
+
# --------------------------------------------------------------------------- #
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _syndrome_from_pauli_frame(
|
| 63 |
+
x_errors: list[int],
|
| 64 |
+
layout: CircuitLayout,
|
| 65 |
+
final_detector_supports: dict[int, frozenset[int]],
|
| 66 |
+
) -> dict[int, int]:
|
| 67 |
+
"""Compute the implied bits for FINAL-round detectors only.
|
| 68 |
+
|
| 69 |
+
A terminal X error on data qubit ``q`` flips a final-round Z-stabiliser
|
| 70 |
+
detector iff ``q`` is in that detector's support.
|
| 71 |
+
"""
|
| 72 |
+
out: dict[int, int] = {}
|
| 73 |
+
x_set = set(x_errors)
|
| 74 |
+
for det_idx, support in final_detector_supports.items():
|
| 75 |
+
out[det_idx] = 1 if len(x_set & support) % 2 == 1 else 0
|
| 76 |
+
return out
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def reward_syndrome_consistency(
    parsed: ParseResult,
    sample: SyndromeSample,
    layout: CircuitLayout,
    final_detector_supports: dict[int, frozenset[int]],
) -> float:
    """Reward 2 - how well the predicted frame reproduces the FINAL detectors.

    Returns Hamming similarity, ``1 - mismatches / num_final_detectors``,
    between the detector bits induced by the predicted X errors and the
    observed final-round bits.

    Rationale (Section 3.2): without this term an LLM could lucky-guess
    the right qubits and occasionally collect Reward 1; this signal forces
    it to also explain the data the syndrome carries.
    """
    final_dets = layout.final_detectors
    if not final_dets:
        return 0.0
    implied = _syndrome_from_pauli_frame(
        parsed.x_errors, layout, final_detector_supports
    )
    mismatches = sum(
        1
        for det_idx in final_dets
        if sample.syndrome_bits[det_idx] != implied.get(det_idx, 0)
    )
    return 1.0 - mismatches / len(final_dets)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def compute_final_detector_supports(
    layout: CircuitLayout,
    syndrome_bits_unused: list[int] | None = None,  # API symmetry
    *,
    detector_to_data_qubits: dict[int, frozenset[int]] | None = None,
) -> dict[int, frozenset[int]]:
    """Map each final-round detector to the data qubits whose terminal X
    error flips it.

    A caller-supplied ``detector_to_data_qubits`` short-circuits the
    computation. Otherwise adjacency is geometric: data qubits at squared
    distance 2.0 (i.e. ``sqrt(2)``) from a detector coordinate are
    incident - four per bulk Z-stabiliser, fewer on the boundary.
    """
    if detector_to_data_qubits is not None:
        return detector_to_data_qubits

    supports: dict[int, frozenset[int]] = {}
    for det_idx in layout.final_detectors:
        dx, dy = layout.detector_coords[det_idx]
        supports[det_idx] = frozenset(
            q
            for q, (qx, qy) in zip(layout.data_qubits, layout.data_qubit_coords)
            if abs((qx - dx) ** 2 + (qy - dy) ** 2 - 2.0) < 1e-6
        )
    return supports
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# --------------------------------------------------------------------------- #
|
| 140 |
+
# Reward 3: Hamming overlap with reference Pauli frame #
|
| 141 |
+
# --------------------------------------------------------------------------- #
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _jaccard(a: list[int], b: list[int]) -> float:
    """Jaccard index. Returns 1.0 when both sets are empty (perfect agreement)."""
    # Duplicates in the input lists are intentionally collapsed: the reward
    # compares *sets* of error locations, not multiplicities.
    sa, sb = set(a), set(b)
    if not sa and not sb:
        return 1.0
    inter = len(sa & sb)
    union = len(sa | sb)
    # ``union`` can only be 0 when both sets are empty, which is handled
    # above; the trailing guard is belt-and-braces.
    return inter / union if union else 1.0
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def reward_hamming_overlap(
    parsed: ParseResult,
    sample: SyndromeSample,
    layout: CircuitLayout,
) -> float:
    """Average of Jaccard(X) and Jaccard(Z) against the reference frame.

    Reference is PyMatching's per-edge predicted Pauli frame
    (``sample.pymatching_x_errors`` / ``..._z_errors``). This is the dense
    partial-credit signal of Section 3.3 - even if Reward 1 fires zero,
    being *close* to the canonical solution still gets credit, smoothing
    the reward landscape during early training.
    """

    def overlap(pred: list[int], ref: list[int]) -> float:
        # Jaccard index over error locations; two empty frames count as
        # perfect agreement (1.0).
        ps, rs = set(pred), set(ref)
        union = ps | rs
        if not union:
            return 1.0
        return len(ps & rs) / len(union)

    jx = overlap(parsed.x_errors, sample.pymatching_x_errors)
    jz = overlap(parsed.z_errors, sample.pymatching_z_errors)
    return 0.5 * (jx + jz)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# --------------------------------------------------------------------------- #
|
| 173 |
+
# Reward 4: format compliance #
|
| 174 |
+
# --------------------------------------------------------------------------- #
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def reward_format_compliance(parsed: ParseResult) -> float:
    """Format score: 1.0 if both keys parsed, 0.5 if exactly one, 0.0 if neither.

    The graded value is precomputed by the parser; this is a thin accessor
    kept as a function so the reward table stays uniform.
    """
    score = parsed.format_score
    return score
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# --------------------------------------------------------------------------- #
|
| 183 |
+
# Reward 5: PyMatching beat-rate bonus #
|
| 184 |
+
# --------------------------------------------------------------------------- #
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def reward_pymatching_beat(
    parsed: ParseResult,
    sample: SyndromeSample,
    layout: CircuitLayout,
) -> float:
    """1.0 iff PyMatching got this syndrome wrong AND the LLM got it right.

    This is the headline metric (Section 3.5). Most of training it'll be
    near zero; the trajectory of its mean over steps is the proof we've
    moved past pure imitation.
    """
    # If the matching baseline already predicts the observable correctly,
    # there is nothing to "beat": no bonus regardless of the LLM output.
    pm_correct = sample.pymatching_observable_pred == sample.actual_observable_flip
    if pm_correct:
        return 0.0
    # Baseline missed; credit the LLM only when the observable flip implied
    # by its parsed X errors matches the ground truth.
    llm_implied = predicted_observable_flip(parsed.x_errors, layout)
    return 1.0 if llm_implied == sample.actual_observable_flip else 0.0
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# --------------------------------------------------------------------------- #
|
| 206 |
+
# Combined reward #
|
| 207 |
+
# --------------------------------------------------------------------------- #
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@dataclass(frozen=True)
class RewardBreakdown:
    """Per-component scores plus the weighted total.

    Frozen so a breakdown can be passed around (and logged) without risk of
    downstream mutation. ``total`` is the weighted, clamped sum computed by
    :func:`compute_all_rewards`; the five components are the raw per-reward
    scores in [0, 1].
    """

    logical_correction: float   # Reward 1: implied observable flip correct
    syndrome_consistency: float # Reward 2: final-round detector agreement
    hamming_overlap: float      # Reward 3: Jaccard overlap with reference frame
    format_compliance: float    # Reward 4: output keys parsed
    pymatching_beat: float      # Reward 5: beat-the-baseline bonus
    total: float                # weighted sum, clamped to [0, 1]

    def as_dict(self) -> dict[str, float]:
        # Flat dict shape consumed by the env's ``info`` payload and the
        # W&B logging helpers.
        return {
            "logical_correction": self.logical_correction,
            "syndrome_consistency": self.syndrome_consistency,
            "hamming_overlap": self.hamming_overlap,
            "format_compliance": self.format_compliance,
            "pymatching_beat": self.pymatching_beat,
            "total": self.total,
        }
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def compute_all_rewards(
    parsed: ParseResult,
    sample: SyndromeSample,
    layout: CircuitLayout,
    final_detector_supports: dict[int, frozenset[int]],
    weights: dict[str, float] | None = None,
) -> RewardBreakdown:
    """Compute all five rewards and the weighted total.

    ``weights`` defaults to the project-wide :data:`REWARD_WEIGHTS` when
    ``None`` (the previous signature bound the module-level dict directly
    as the default, so any in-place mutation of that shared default would
    have silently changed behaviour for every caller; resolving it at call
    time avoids the mutable-default footgun and stays backward compatible).

    Returns a :class:`RewardBreakdown` whose ``as_dict`` is what the env's
    ``info`` payload contains. The trainer logs each component separately.
    The total is clamped to [0, 1] so reward scale stays bounded even for
    unusual weight configurations.
    """
    if weights is None:
        weights = REWARD_WEIGHTS
    r1 = reward_logical_correction(parsed, sample, layout)
    r2 = reward_syndrome_consistency(parsed, sample, layout, final_detector_supports)
    r3 = reward_hamming_overlap(parsed, sample, layout)
    r4 = reward_format_compliance(parsed)
    r5 = reward_pymatching_beat(parsed, sample, layout)
    total = (
        weights["logical_correction"] * r1
        + weights["syndrome_consistency"] * r2
        + weights["hamming_overlap"] * r3
        + weights["format_compliance"] * r4
        + weights["pymatching_beat"] * r5
    )
    total = max(0.0, min(1.0, total))
    return RewardBreakdown(
        logical_correction=r1,
        syndrome_consistency=r2,
        hamming_overlap=r3,
        format_compliance=r4,
        pymatching_beat=r5,
        total=total,
    )
|
qubit_medic/wandb_utils.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Central Weights & Biases integration for Qubit-Medic.
|
| 2 |
+
|
| 3 |
+
Design goals
|
| 4 |
+
------------
|
| 5 |
+
1. **Single source of truth** for the W&B project name, default tags, and
|
| 6 |
+
the ``config`` dump that every run logs. Trainers, eval scripts, and
|
| 7 |
+
notebooks all funnel through :func:`init_run` so dashboards always
|
| 8 |
+
line up.
|
| 9 |
+
|
| 10 |
+
2. **Safe to import without wandb installed.** The package's training
|
| 11 |
+
deps (``wandb``) live in ``requirements-train.txt`` and are absent on
|
| 12 |
+
the lean Spaces image. Anything in this module degrades gracefully
|
| 13 |
+
when the import fails - the rest of the project keeps working.
|
| 14 |
+
|
| 15 |
+
3. **Disable-by-env-var.** Set ``WANDB_DISABLED=1`` (or
|
| 16 |
+
``QUBIT_MEDIC_WANDB=0``) and every helper here becomes a no-op,
|
| 17 |
+
regardless of whether the package is installed. CI runs and offline
|
| 18 |
+
testing rely on this.
|
| 19 |
+
|
| 20 |
+
4. **Rich first-class logging.** We expose dedicated helpers for the
|
| 21 |
+
things this project cares about:
|
| 22 |
+
|
| 23 |
+
* Per-reward component scalars (5 lines per step, not just total)
|
| 24 |
+
* Curriculum-level moving averages (one line per level)
|
| 25 |
+
* Parse success / partial / failure rates
|
| 26 |
+
* Generation sample tables (prompt / completion / per-reward)
|
| 27 |
+
* Eval summary tables (one row per (policy, level))
|
| 28 |
+
|
| 29 |
+
The trainers and eval script only have to call these helpers; the
|
| 30 |
+
Pythonic context manager handles init, summary, and finish.
|
| 31 |
+
"""
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
|
| 34 |
+
import contextlib
|
| 35 |
+
import dataclasses
|
| 36 |
+
import os
|
| 37 |
+
import socket
|
| 38 |
+
import sys
|
| 39 |
+
import time
|
| 40 |
+
from typing import Any, Iterable, Mapping, Optional, Sequence
|
| 41 |
+
|
| 42 |
+
from qubit_medic.config import (
|
| 43 |
+
CURRICULUM, GRPO_GEN_PER_PROMPT, GRPO_KL_COEF, GRPO_LR, GRPO_MAX_COMPLETION_LEN,
|
| 44 |
+
GRPO_MAX_PROMPT_LEN, GRPO_STEPS, LORA_ALPHA, LORA_R, LORA_TARGET_MODULES,
|
| 45 |
+
MODEL_ID, PRIMARY_SEED, REWARD_WEIGHTS, SFT_BATCH_SIZE, SFT_EPOCHS,
|
| 46 |
+
SFT_GRAD_ACCUM, SFT_LR, SFT_MAX_SEQ_LEN, WANDB_DEFAULT_TAGS, WANDB_ENTITY,
|
| 47 |
+
WANDB_LOG_GENERATIONS_EVERY, WANDB_PROJECT, WANDB_SAMPLE_GENERATIONS,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# --------------------------------------------------------------------------- #
|
| 52 |
+
# Lazy import + on/off toggle #
|
| 53 |
+
# --------------------------------------------------------------------------- #
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Module-level cache for the lazy wandb import and the active run handle.
# _WANDB_MODULE: None = import not yet attempted, False = attempted and
# failed, otherwise the imported module object.
_WANDB_MODULE = None
_RUN: Any = None


def _import_wandb():
    """Import wandb on first use. Returns ``None`` if it isn't installed."""
    global _WANDB_MODULE
    if _WANDB_MODULE is None:
        # First call: attempt the import exactly once and cache the outcome.
        try:
            import wandb  # type: ignore[import-not-found]
            _WANDB_MODULE = wandb
        except ImportError:
            _WANDB_MODULE = False  # sentinel: "tried, failed"
    return _WANDB_MODULE if _WANDB_MODULE is not False else None
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def is_disabled() -> bool:
    """Honours ``WANDB_DISABLED`` and ``QUBIT_MEDIC_WANDB=0``."""
    # Two switches, either of which turns logging off:
    #   * WANDB_DISABLED    -- standard wandb kill-switch, off when unset
    #   * QUBIT_MEDIC_WANDB -- project toggle, enabled ("1") by default
    standard_off = os.environ.get("WANDB_DISABLED", "").lower() in {"1", "true", "yes"}
    project_off = os.environ.get("QUBIT_MEDIC_WANDB", "1").lower() in {"0", "false", "no"}
    return standard_off or project_off
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def is_available() -> bool:
    """``True`` iff wandb is importable AND not disabled by env var."""
    # Import attempt runs first (and caches) even when disabled; the env-var
    # check then gates the final answer.
    return _import_wandb() is not None and not is_disabled()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_run():
    """Return the active W&B run object, or ``None`` if not initialised."""
    # Read-only accessor for the module-level handle set by init_run().
    return _RUN
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# --------------------------------------------------------------------------- #
|
| 92 |
+
# Init / finish #
|
| 93 |
+
# --------------------------------------------------------------------------- #
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _system_metadata() -> dict:
    """Static metadata that's helpful on the dashboard but isn't a hyperparam."""
    info = {
        "python_version": sys.version.split()[0],
        "hostname": socket.gethostname(),
        "argv": " ".join(sys.argv),
        "pid": os.getpid(),
    }
    # Every probe below is best-effort: the lean Spaces image ships without
    # the training stack, so missing imports are silently skipped rather
    # than failing run init.
    try:
        import torch
        info["torch_version"] = torch.__version__
        info["cuda_available"] = bool(torch.cuda.is_available())
        if torch.cuda.is_available():
            info["cuda_device"] = torch.cuda.get_device_name(0)
    except Exception:
        pass
    try:
        import stim
        info["stim_version"] = stim.__version__
    except Exception:
        pass
    try:
        import pymatching
        info["pymatching_version"] = pymatching.__version__
    except Exception:
        pass
    try:
        import trl, transformers, peft
        info["trl_version"] = trl.__version__
        info["transformers_version"] = transformers.__version__
        info["peft_version"] = peft.__version__
    except Exception:
        pass
    return info
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _build_default_config(extra: Optional[Mapping[str, Any]] = None) -> dict:
    """The config every run logs - hyperparameters + reward weights + curriculum."""
    cfg: dict[str, Any] = {
        "model_id": MODEL_ID,
        "primary_seed": PRIMARY_SEED,
        "lora_r": LORA_R,
        "lora_alpha": LORA_ALPHA,
        # list()/dict() copies below keep the logged config independent of
        # later in-place mutation of the module-level constants.
        "lora_target_modules": list(LORA_TARGET_MODULES),
        "sft": {
            "epochs": SFT_EPOCHS,
            "batch_size": SFT_BATCH_SIZE,
            "grad_accum": SFT_GRAD_ACCUM,
            "lr": SFT_LR,
            "max_seq_len": SFT_MAX_SEQ_LEN,
        },
        "grpo": {
            "steps": GRPO_STEPS,
            "gen_per_prompt": GRPO_GEN_PER_PROMPT,
            "lr": GRPO_LR,
            "kl_coef": GRPO_KL_COEF,
            "max_prompt_len": GRPO_MAX_PROMPT_LEN,
            "max_completion_len": GRPO_MAX_COMPLETION_LEN,
        },
        "reward_weights": dict(REWARD_WEIGHTS),
        "curriculum": [
            {
                "name": lvl.name, "distance": lvl.distance, "rounds": lvl.rounds,
                "p": lvl.p, "promotion_threshold": lvl.promotion_threshold,
            }
            for lvl in CURRICULUM
        ],
        "system": _system_metadata(),
    }
    # Caller-supplied entries win on key collision (shallow merge).
    if extra:
        cfg.update(extra)
    return cfg
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def init_run(
    run_name: str,
    job_type: str,
    *,
    tags: Optional[Sequence[str]] = None,
    extra_config: Optional[Mapping[str, Any]] = None,
    notes: Optional[str] = None,
    group: Optional[str] = None,
):
    """Initialise (or no-op) a W&B run.

    Parameters
    ----------
    run_name:
        Human-readable run name (e.g. ``"sft-warmup-2026-04-25"``).
    job_type:
        One of ``"sft"``, ``"grpo"``, ``"eval"``, ``"format-test"``,
        ``"baseline"``. Used to group runs on the dashboard.
    tags:
        Extra tags appended to :data:`qubit_medic.config.WANDB_DEFAULT_TAGS`.
    extra_config:
        Hyperparameters specific to this run (e.g. SFT epochs override).
    notes:
        Free-text notes shown on the dashboard.
    group:
        Optional W&B group (used to bundle SFT + GRPO + eval runs of the
        same experiment).

    Returns
    -------
    The wandb Run object, or ``None`` if W&B is unavailable / disabled.
    """
    global _RUN
    wandb = _import_wandb()
    if wandb is None or is_disabled():
        # Unavailable or explicitly off: clear the handle so later helpers
        # see no active run, and say why on stderr (stdout may be parsed).
        if wandb is None:
            print("[wandb] not installed; skipping init "
                  "(install with `pip install wandb` to enable logging)",
                  file=sys.stderr)
        else:
            print("[wandb] disabled by env var; skipping init", file=sys.stderr)
        _RUN = None
        return None

    all_tags = list(WANDB_DEFAULT_TAGS) + list(tags or [])
    cfg = _build_default_config(extra=extra_config)
    cfg["job_type"] = job_type

    _RUN = wandb.init(
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
        name=run_name,
        job_type=job_type,
        tags=all_tags,
        config=cfg,
        notes=notes,
        group=group,
        # reinit=True so repeated init_run calls in one process (e.g. SFT
        # then GRPO in a notebook) each start a fresh run.
        reinit=True,
    )
    print(f"[wandb] run live at {_RUN.url}", file=sys.stderr)
    return _RUN
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def finish_run() -> None:
    """Cleanly close the current W&B run, if any."""
    global _RUN
    wandb = _import_wandb()
    if wandb is None or _RUN is None:
        _RUN = None
        return
    try:
        wandb.finish()
    finally:
        # Always drop the handle, even if finish() raised, so subsequent
        # helpers don't keep logging to a dead run.
        _RUN = None
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
@contextlib.contextmanager
def run_context(run_name: str, job_type: str, **kwargs):
    """Context-manager wrapper around :func:`init_run` / :func:`finish_run`.

    Yields the active run object (or ``None`` when W&B is unavailable /
    disabled) and guarantees the run is finished even if the body raises.
    """
    init_run(run_name, job_type, **kwargs)
    try:
        yield get_run()
    finally:
        finish_run()
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
# --------------------------------------------------------------------------- #
|
| 257 |
+
# Generic logging helpers #
|
| 258 |
+
# --------------------------------------------------------------------------- #
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def log(metrics: Mapping[str, Any], *, step: Optional[int] = None,
        commit: bool = True) -> None:
    """No-op-safe ``wandb.log`` wrapper.

    Silently returns when wandb is missing or no run is active; logging
    failures are reported to stderr rather than raised so a flaky network
    never kills a training run.
    """
    wandb = _import_wandb()
    if wandb is None or _RUN is None:
        return
    try:
        wandb.log(dict(metrics), step=step, commit=commit)
    except Exception as exc:  # pragma: no cover - defensive
        print(f"[wandb] log failed: {exc}", file=sys.stderr)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def update_summary(values: Mapping[str, Any]) -> None:
    """Write to ``run.summary`` (the run's headline numbers).

    No-op when no run is active; failures are logged, not raised.
    """
    if _RUN is None:
        return
    try:
        for k, v in values.items():
            _RUN.summary[k] = v
    except Exception as exc:  # pragma: no cover
        print(f"[wandb] summary update failed: {exc}", file=sys.stderr)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
# --------------------------------------------------------------------------- #
|
| 285 |
+
# Project-specific helpers #
|
| 286 |
+
# --------------------------------------------------------------------------- #
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
# Keys mirrored from RewardBreakdown.as_dict(); the order fixes the metric
# naming on the dashboard.
_REWARD_KEYS = (
    "logical_correction",
    "syndrome_consistency",
    "hamming_overlap",
    "format_compliance",
    "pymatching_beat",
    "total",
)


def log_reward_breakdown(
    breakdowns: Sequence[Mapping[str, float]],
    *,
    step: Optional[int] = None,
    prefix: str = "rl",
) -> None:
    """Log mean / min / max for each of the five reward components.

    ``breakdowns`` is a list of dicts, one per generation in the most-recent
    GRPO step (length = ``GRPO_GEN_PER_PROMPT * batch``). We log mean and
    standard deviation so the dashboard has both signal and noise.
    """
    if not breakdowns or _RUN is None:
        return
    out: dict[str, float] = {}
    for k in _REWARD_KEYS:
        # Missing components default to 0.0 so a malformed breakdown can't
        # crash the logging path.
        vals = [float(b.get(k, 0.0)) for b in breakdowns]
        n = max(1, len(vals))
        mean = sum(vals) / n
        # Population variance (divide by n, not n-1).
        var = sum((v - mean) ** 2 for v in vals) / n
        out[f"{prefix}/reward/{k}_mean"] = mean
        out[f"{prefix}/reward/{k}_std"] = var ** 0.5
        out[f"{prefix}/reward/{k}_max"] = max(vals)
        out[f"{prefix}/reward/{k}_min"] = min(vals)
    log(out, step=step)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def log_parse_stats(
    parse_results: Iterable,  # iterable of qubit_medic.prompts.ParseResult
    *,
    step: Optional[int] = None,
    prefix: str = "rl",
) -> None:
    """Log parse-success / partial / failure rates for the most-recent batch."""
    if _RUN is None:
        return
    # Materialise once: the iterable is consumed twice below.
    parse_results = list(parse_results)
    n = max(1, len(parse_results))
    success = sum(1 for r in parse_results if getattr(r, "parse_success", False))
    # "Partial" counts only non-successes, so the three rates sum to 1.
    partial = sum(1 for r in parse_results
                  if not getattr(r, "parse_success", False)
                  and getattr(r, "parse_partial", False))
    log({
        f"{prefix}/parse/success_rate": success / n,
        f"{prefix}/parse/partial_rate": partial / n,
        f"{prefix}/parse/failure_rate": (n - success - partial) / n,
        f"{prefix}/parse/sample_count": n,
    }, step=step)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def log_curriculum(
    curriculum_stats: Mapping[str, Mapping[str, float]],
    *,
    step: Optional[int] = None,
    prefix: str = "rl",
) -> None:
    """Log the per-level moving-average from the env health endpoint.

    ``curriculum_stats`` is what
    :meth:`qubit_medic.server.curriculum.CurriculumScheduler.snapshot`
    returns; one inner dict per level with keys ``moving_mean`` / ``samples``.
    """
    if _RUN is None or not curriculum_stats:
        return
    out: dict[str, float] = {}
    for level_name, stats in curriculum_stats.items():
        # Missing keys fall back to 0.0 so a partial snapshot still logs.
        out[f"{prefix}/curriculum/{level_name}_mean"] = float(stats.get("moving_mean", 0.0))
        out[f"{prefix}/curriculum/{level_name}_samples"] = float(stats.get("samples", 0.0))
    log(out, step=step)
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def log_generation_table(
    rows: Sequence[Mapping[str, Any]],
    *,
    step: Optional[int],
    table_name: str = "rl/generations",
    columns: Optional[Sequence[str]] = None,
) -> None:
    """Log a wandb.Table of (prompt, completion, reward, ...) rows.

    Each row is a flat dict; the column set is the union of all keys (or
    the explicit ``columns`` arg). Used to surface qualitative samples
    in addition to the scalar curves.
    """
    wandb = _import_wandb()
    if wandb is None or _RUN is None or not rows:
        return
    # Sorted union keeps the column order stable across steps.
    cols = list(columns) if columns is not None else sorted(
        {k for row in rows for k in row.keys()}
    )
    try:
        table = wandb.Table(columns=cols)
        for row in rows:
            # Rows missing a column get None so add_data arity always matches.
            table.add_data(*[row.get(c, None) for c in cols])
        log({table_name: table}, step=step)
    except Exception as exc:  # pragma: no cover
        print(f"[wandb] table log failed: {exc}", file=sys.stderr)
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def log_eval_summary(
    summary: Mapping[str, Any],
    *,
    step: Optional[int] = None,
    prefix: str = "eval",
) -> None:
    """Log the dict produced by ``scripts/eval._summary`` as scalars.

    Numeric entries go to the time-series log; numeric/str/bool entries are
    additionally written to ``run.summary`` as headline values.
    """
    if _RUN is None:
        return
    out: dict[str, Any] = {}
    for k, v in summary.items():
        if isinstance(v, (int, float)):
            out[f"{prefix}/{k}"] = v
    log(out, step=step)
    update_summary({f"{prefix}/{k}": v for k, v in summary.items()
                    if isinstance(v, (int, float, str, bool))})
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def log_artifact(
    path: str, *, name: str, artifact_type: str,
    description: Optional[str] = None,
) -> None:
    """Save a file or directory as a W&B artifact.

    No-op when wandb is missing or no run is active; upload failures are
    reported to stderr rather than raised.
    """
    wandb = _import_wandb()
    if wandb is None or _RUN is None:
        return
    try:
        art = wandb.Artifact(name, type=artifact_type, description=description)
        # Directories are added recursively; anything else is a single file.
        if os.path.isdir(path):
            art.add_dir(path)
        else:
            art.add_file(path)
        _RUN.log_artifact(art)
    except Exception as exc:  # pragma: no cover
        print(f"[wandb] artifact log failed: {exc}", file=sys.stderr)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
# --------------------------------------------------------------------------- #
|
| 436 |
+
# CLI integration helpers #
|
| 437 |
+
# --------------------------------------------------------------------------- #
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def derive_report_to(report_to: str) -> str:
    """Translate the user-facing ``--report-to`` flag.

    If the user passes ``"wandb"`` but wandb is unavailable, fall back to
    ``"none"`` rather than crashing the trainer. Lets the same script run
    on Colab (with wandb) and CI (without).
    """
    if report_to == "wandb" and not is_available():
        print("[wandb] requested via --report-to but unavailable; falling back to 'none'",
              file=sys.stderr)
        return "none"
    # Any other value (including "none") passes through untouched.
    return report_to
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def make_run_name(prefix: str, suffix: Optional[str] = None) -> str:
    """Build a default run name like ``sft-warmup-20260425-2105``.

    The name is ``prefix-YYYYMMDD-HHMMSS`` with an optional trailing
    ``-suffix`` segment.
    """
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    if suffix:
        return f"{prefix}-{timestamp}-{suffix}"
    return f"{prefix}-{timestamp}"
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
# Public API surface. The last two entries re-export config constants so
# trainers can import their logging cadence from one place.
__all__ = [
    "derive_report_to",
    "finish_run",
    "get_run",
    "init_run",
    "is_available",
    "is_disabled",
    "log",
    "log_artifact",
    "log_curriculum",
    "log_eval_summary",
    "log_generation_table",
    "log_parse_stats",
    "log_reward_breakdown",
    "make_run_name",
    "run_context",
    "update_summary",
    "WANDB_LOG_GENERATIONS_EVERY",
    "WANDB_SAMPLE_GENERATIONS",
]
|
requirements.txt
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Pin the versions called out in the plan (Section 1.1).
|
| 2 |
+
# These versions are known compatible with Python 3.11 + CUDA 12.x.
|
| 3 |
+
# DO NOT bump without re-running scripts/validate_env.py.
|
| 4 |
+
|
| 5 |
+
# --- Quantum simulation ---
|
| 6 |
+
stim>=1.13,<2.0
|
| 7 |
+
pymatching>=2.2,<3.0
|
| 8 |
+
|
| 9 |
+
# --- Environment / serving ---
|
| 10 |
+
fastapi>=0.110
|
| 11 |
+
uvicorn[standard]>=0.27
|
| 12 |
+
pydantic>=2.5,<3.0
|
| 13 |
+
httpx>=0.27
|
| 14 |
+
numpy>=1.26,<2.1
|
| 15 |
+
|
| 16 |
+
# --- Plotting (used by scripts/plot_results.py) ---
|
| 17 |
+
matplotlib>=3.8
|
| 18 |
+
pillow>=10
|
| 19 |
+
|
| 20 |
+
# --- Test runner ---
|
| 21 |
+
pytest>=8
|
| 22 |
+
|
| 23 |
+
# --- OpenEnv (HuggingFace's RL-env framework). Required by the submission
|
| 24 |
+
# guidelines ("Use OpenEnv (latest release). Build on top of the
|
| 25 |
+
# framework; don't reinvent the wheel."). Our server wraps the
|
| 26 |
+
# DecoderEnvironment with `openenv.core.Environment`-compatible
|
| 27 |
+
# Action/Observation/State models so TRL can drive it via
|
| 28 |
+
# `environment_factory=` (see qubit_medic/server/openenv_adapter.py
|
| 29 |
+
# and scripts/train_grpo.py). Lightweight; safe in the Spaces image.
|
| 30 |
+
openenv-core>=0.2.1
|
| 31 |
+
|
| 32 |
+
# --- Training (omit on CPU-only deploy) ---
|
| 33 |
+
# Heavy ML deps live in requirements-train.txt; the env server itself does
|
| 34 |
+
# not import any of them. Keeping them out of the Docker image keeps the
|
| 35 |
+
# Spaces upload under the free-tier image-size limit.
|