"""OpenEnv-compliant adapter around :class:`DecoderEnvironment`.

This wrapper satisfies the submission requirement *"Use OpenEnv (latest
release). Build on top of the framework; don't reinvent the wheel."* by
exposing our underlying :class:`qubit_medic.server.environment.DecoderEnvironment`
through the official ``openenv.core.Environment`` base class.

The adapter is intentionally thin: it just translates between OpenEnv's
``Action`` / ``Observation`` / ``State`` Pydantic shapes and our internal
``DecoderObservation`` / ``DecoderAction`` / ``StepResult``. All the
physics, reward scoring, curriculum, and episode bookkeeping continue to
live in :class:`DecoderEnvironment` - that code is *the* tested,
production path.

Usage
-----

The OpenEnv-compliant FastAPI app is created with::

    from openenv.core import create_fastapi_app
    from qubit_medic.server.openenv_adapter import (
        QubitMedicEnvironment, QubitMedicAction, QubitMedicObservation,
    )

    app = create_fastapi_app(
        env=QubitMedicEnvironment,
        action_cls=QubitMedicAction,
        observation_cls=QubitMedicObservation,
    )

This registers the canonical OpenEnv routes:

* ``POST /reset``    - body ``{"seed": int?, "episode_id": str?}``
* ``POST /step``     - body ``{"action": {...QubitMedicAction...},
  "timeout_s": float?, "request_id": str?}``
* ``GET  /state``    - returns the current :class:`QubitMedicState`
* ``GET  /health``   - liveness probe
* ``GET  /schema``   - JSON Schema for the action/observation models
* ``GET  /metadata`` - environment metadata
* ``POST /mcp``      - Model Context Protocol endpoint
* ``GET  /docs``     - Swagger UI (auto-generated by FastAPI)

In ``qubit_medic.server.app`` we additionally mount our own ``/healthz``
(Day-0 contract) and ``/decode`` (PyMatching baseline demo) routes on the
returned app.
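
As a rough illustration of the ``POST /step`` body shape listed above, the
sketch below assembles a request payload from the ``QubitMedicAction`` fields
defined in this module. ``build_step_body`` is a hypothetical helper, not part
of ``qubit_medic`` or OpenEnv, and it assumes the server accepts the action
fields as plain JSON keys; it only builds a dict and sends nothing.

```python
import json
from typing import Optional


def build_step_body(
    raw_response: str = "",
    parsed_x_errors: Optional[list] = None,
    parsed_z_errors: Optional[list] = None,
    episode_id: Optional[int] = None,
    timeout_s: Optional[float] = None,
) -> dict:
    # Hypothetical helper: mirrors the QubitMedicAction fields and the
    # documented /step envelope {"action": {...}, "timeout_s": float?}.
    action = {"raw_response": raw_response}
    # Optional fields are left out entirely when unset, so the server's
    # defaults apply.
    if parsed_x_errors is not None:
        action["parsed_x_errors"] = list(parsed_x_errors)
    if parsed_z_errors is not None:
        action["parsed_z_errors"] = list(parsed_z_errors)
    if episode_id is not None:
        action["episode_id"] = episode_id
    body = {"action": action}
    if timeout_s is not None:
        body["timeout_s"] = timeout_s
    return body


# Raw-text action in the "<answer>X: ... | Z: ...</answer>" shape that the
# adapter itself produces for pre-parsed actions.
step_body = build_step_body(
    raw_response="<answer>X: 1,4 | Z: </answer>",
    episode_id=3,
)
print(json.dumps(step_body, sort_keys=True))
```

Passing ``episode_id`` explicitly is the safer choice for stateless HTTP
calls, because the wrapper only remembers the last episode reset through
that same wrapper instance.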
"""
from __future__ import annotations

from typing import Any, Optional

from openenv.core import Action, Environment, Observation, State
from openenv.core.env_server.types import EnvironmentMetadata
from pydantic import ConfigDict, Field

from qubit_medic.server.environment import DecoderEnvironment


# --------------------------------------------------------------------------- #
# Process-wide singleton                                                      #
# --------------------------------------------------------------------------- #
# OpenEnv's HTTP server (simulation mode) instantiates a *fresh* Environment
# via the factory on every /reset and /step call. Our episode bookkeeping
# (the `_active` dict) lives inside DecoderEnvironment, so we route every
# QubitMedicEnvironment instance through the same DecoderEnvironment.
# This keeps reset() -> step() pairing intact across stateless HTTP calls
# while remaining fully compatible with OpenEnv's WebSocket session model
# (each WS session still gets its own QubitMedicEnvironment wrapper).

_INNER_SINGLETON: Optional[DecoderEnvironment] = None


def _get_shared_inner() -> DecoderEnvironment:
    """Return the process-wide DecoderEnvironment, building it lazily."""
    global _INNER_SINGLETON
    if _INNER_SINGLETON is None:
        env = DecoderEnvironment()
        env._cache_for("L1_warmup")  # noqa: SLF001 - intentional pre-warm
        env._cache_for("L2_target")  # noqa: SLF001
        _INNER_SINGLETON = env
    return _INNER_SINGLETON


# --------------------------------------------------------------------------- #
# OpenEnv-flavoured Action / Observation / State                              #
# --------------------------------------------------------------------------- #


class QubitMedicAction(Action):
    """LLM-emitted action: the raw text the model generated.

    The server parses this into ``x_error_qubits`` / ``z_error_qubits`` via
    :func:`qubit_medic.prompts.parse_action`. By default the wire format is
    *just the raw string*, so the server retains full control over parsing
    (and the trainer's reward function can audit unparseable outputs).

    The trainer is also free to populate ``parsed_x_errors`` /
    ``parsed_z_errors`` directly when it wants to bypass the LLM (useful
    for baseline policies and unit tests).
    """

    # Inherit Action.model_config (extra='forbid', validate_assignment=True).
    raw_response: str = Field(
        default="",
        description="Raw LLM completion text. Server parses to x/z error lists.",
    )
    parsed_x_errors: Optional[list[int]] = Field(
        default=None,
        description="Optional pre-parsed X-error qubit ids (LLM-space). "
                    "When provided, the server skips text parsing.",
    )
    parsed_z_errors: Optional[list[int]] = Field(
        default=None,
        description="Optional pre-parsed Z-error qubit ids (LLM-space).",
    )
    episode_id: Optional[int] = Field(
        default=None,
        description="Server-assigned episode id from the matching reset(). "
                    "If omitted, the episode most recently reset through "
                    "this wrapper instance is used.",
    )


class QubitMedicObservation(Observation):
    """OpenEnv observation - mirrors :class:`DecoderObservation` plus the
    standard OpenEnv ``done`` / ``reward`` fields.

    The ``info`` dict (returned by ``step``) carries the per-component
    reward breakdown, the ground-truth observable flip, and the PyMatching
    baseline prediction so the trainer can score auxiliary metrics.
    """

    model_config = ConfigDict(extra="forbid", validate_assignment=True,
                              arbitrary_types_allowed=True)

    prompt: str = Field(default="", description="Pre-formatted LLM prompt.")
    syndrome_bits: list[int] = Field(default_factory=list,
                                      description="Detector activations (0/1).")
    distance: int = Field(default=0, description="Code distance for this episode.")
    rounds: int = Field(default=0, description="Number of stabilizer rounds.")
    p: float = Field(default=0.0, description="SI1000 base error rate.")
    curriculum_level: str = Field(default="",
                                   description="Curriculum level name.")
    episode_id: int = Field(default=0,
                             description="Server-assigned episode counter.")
    dem_digest: str = Field(default="",
                             description="Short hash of the detector error model.")
    info: dict[str, Any] = Field(default_factory=dict,
                                  description="Per-step extras (reward "
                                              "breakdown, ground-truth flip, "
                                              "PyMatching baseline, etc.).")


class QubitMedicState(State):
    """Externally-visible state. We expose only the curriculum + episode
    counters; physics-truth fields stay server-side to prevent reward
    hacking (see the :class:`qubit_medic.models.DecoderState` docstring)."""

    model_config = ConfigDict(extra="allow", validate_assignment=True,
                              arbitrary_types_allowed=True)

    episodes_started: int = 0
    active_episodes: int = 0
    cached_levels: list[str] = Field(default_factory=list)
    curriculum: dict[str, Any] = Field(default_factory=dict)
    last_reward_breakdown: Optional[dict[str, float]] = None


# --------------------------------------------------------------------------- #
# Environment wrapper                                                         #
# --------------------------------------------------------------------------- #


class QubitMedicEnvironment(Environment[QubitMedicAction,
                                         QubitMedicObservation,
                                         QubitMedicState]):
    """OpenEnv-compliant view of :class:`DecoderEnvironment`.

    Episodes are single-step (``done=True`` after every ``step``). The
    OpenEnv HTTP server creates a fresh instance per WebSocket session when
    ``SUPPORTS_CONCURRENT_SESSIONS=True``; we set it to ``False`` because
    our DecoderEnvironment uses a single Stim cache and a coarse lock, which
    is simpler than per-session state and good enough for the GRPO
    training loop.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = False

    def __init__(self) -> None:
        super().__init__()
        # Share the underlying DecoderEnvironment across every wrapper
        # instance the HTTP server creates - see _get_shared_inner.
        self._inner = _get_shared_inner()
        self._last_episode_id: Optional[int] = None
        self._last_reward_breakdown: Optional[dict[str, float]] = None

    # ----- abstract API --------------------------------------------------- #

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> QubitMedicObservation:
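        # ``episode_id`` is accepted for signature compatibility but unused:
        # the inner environment assigns its own integer episode id.
        forced_level = kwargs.get("forced_level")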
        obs = self._inner.reset(seed=seed, forced_level=forced_level)
        self._last_episode_id = obs.episode_id
        self._last_reward_breakdown = None
        return QubitMedicObservation(
            prompt=obs.prompt,
            syndrome_bits=list(obs.syndrome_bits),
            distance=obs.distance,
            rounds=obs.rounds,
            p=obs.p,
            curriculum_level=obs.curriculum_level,
            episode_id=obs.episode_id,
            dem_digest=obs.dem_digest,
            done=False,
            reward=None,
            info={"event": "reset"},
        )

    def step(
        self,
        action: QubitMedicAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> QubitMedicObservation:
        ep = action.episode_id if action.episode_id is not None else self._last_episode_id
        if ep is None:
            raise RuntimeError(
                "step() called before reset() on this wrapper and no "
                "action.episode_id was given; no active episode to score."
            )

        # If the trainer pre-parsed the action, format a synthetic raw
        # response in the canonical "X: ... | Z: ..." shape so the server's
        # parser produces the same x/z lists.
        if action.parsed_x_errors is not None or action.parsed_z_errors is not None:
            xs = action.parsed_x_errors or []
            zs = action.parsed_z_errors or []
            x_part = ",".join(map(str, xs))
            z_part = ",".join(map(str, zs))
            raw = f"<answer>X: {x_part} | Z: {z_part}</answer>"
        else:
            raw = action.raw_response

        result = self._inner.step(raw_response=raw, episode_id=ep)
        self._last_reward_breakdown = result.info.get("rewards")

        return QubitMedicObservation(
            prompt=result.observation.prompt,
            syndrome_bits=list(result.observation.syndrome_bits),
            distance=result.observation.distance,
            rounds=result.observation.rounds,
            p=result.observation.p,
            curriculum_level=result.observation.curriculum_level,
            episode_id=result.observation.episode_id,
            dem_digest=result.observation.dem_digest,
            done=result.done,
            reward=float(result.reward),
            info=result.info,
        )

    @property
    def state(self) -> QubitMedicState:
        h = self._inner.health()
        return QubitMedicState(
            episode_id=(str(self._last_episode_id)
                        if self._last_episode_id is not None else None),
            step_count=int(h.get("episodes_started", 0)),
            episodes_started=int(h.get("episodes_started", 0)),
            active_episodes=int(h.get("active_episodes", 0)),
            cached_levels=list(h.get("cached_levels", [])),
            curriculum=dict(h.get("curriculum", {})),
            last_reward_breakdown=self._last_reward_breakdown,
        )

    # ----- nice-to-haves -------------------------------------------------- #

    def get_metadata(self) -> EnvironmentMetadata:
        return EnvironmentMetadata(
            name="QubitMedicEnvironment",
            description=(
                "RL training environment for LLM-based quantum error-"
                "correction decoders. Built on Stim + PyMatching. Five "
                "verifiable rewards (logical correction, syndrome consistency, "
                "Hamming overlap, format compliance, PyMatching beat-rate)."
            ),
            version="1.0.0",
        )

    def close(self) -> None:  # nothing to clean up
        return None