# qubit_medic/server/openenv_adapter.py
# Author: ronitraj
# Real env: openenv-core wrapped DecoderEnvironment + /healthz + /decode
# (commit 195f87e, verified)
"""OpenEnv-compliant adapter around :class:`DecoderEnvironment`.
This wrapper satisfies the submission requirement *"Use OpenEnv (latest
release). Build on top of the framework; don't reinvent the wheel."* by
exposing our underlying :class:`qubit_medic.server.environment.DecoderEnvironment`
through the official ``openenv.core.Environment`` base class.
The adapter is intentionally thin: it just translates between OpenEnv's
``Action`` / ``Observation`` / ``State`` Pydantic shapes and our internal
``DecoderObservation`` / ``DecoderAction`` / ``StepResult``. All the
physics, reward scoring, curriculum, and episode bookkeeping continue to
live in :class:`DecoderEnvironment` - that code is *the* tested,
production path.
Usage
-----
The OpenEnv-compliant FastAPI app is created with::
from openenv.core import create_fastapi_app
from qubit_medic.server.openenv_adapter import (
QubitMedicEnvironment, QubitMedicAction, QubitMedicObservation,
)
app = create_fastapi_app(
env=QubitMedicEnvironment,
action_cls=QubitMedicAction,
observation_cls=QubitMedicObservation,
)
This registers the canonical OpenEnv routes:
* ``POST /reset`` - body ``{"seed": int?, "episode_id": str?}``
* ``POST /step`` - body ``{"action": {...QubitMedicAction...},
"timeout_s": float?, "request_id": str?}``
* ``GET /state`` - returns the current :class:`QubitMedicState`
* ``GET /health`` - liveness probe
* ``GET /schema`` - JSON Schema for the action/observation models
* ``GET /metadata`` - environment metadata
* ``POST /mcp`` - Model Context Protocol endpoint
* ``GET /docs`` - Swagger UI (auto-generated by FastAPI)
In ``qubit_medic.server.app`` we additionally mount our own ``/healthz``
(Day-0 contract) and ``/decode`` (PyMatching baseline demo) routes on the
app returned by ``create_fastapi_app``.
"""
from __future__ import annotations

import threading
from typing import Any, Optional

from openenv.core import Action, Environment, Observation, State
from openenv.core.env_server.types import EnvironmentMetadata
from pydantic import ConfigDict, Field

from qubit_medic.server.environment import DecoderEnvironment
# --------------------------------------------------------------------------- #
# Process-wide singleton #
# --------------------------------------------------------------------------- #
# OpenEnv's HTTP server (simulation mode) instantiates a *fresh* Environment
# via the factory on every /reset and /step call. Our episode bookkeeping
# (the `_active` dict) lives inside DecoderEnvironment, so we route every
# QubitMedicEnvironment instance through the same DecoderEnvironment.
# This keeps reset() -> step() pairing intact across stateless HTTP calls
# while remaining fully compatible with OpenEnv's WebSocket session model
# (each WS session still gets its own QubitMedicEnvironment wrapper).
_INNER_SINGLETON: Optional[DecoderEnvironment] = None
# Guards the lazy build below. The surrounding comment notes that OpenEnv's
# HTTP server reaches this from stateless, potentially concurrent requests;
# without the lock two first-callers could each construct (and pre-warm) a
# DecoderEnvironment, splitting the `_active` episode bookkeeping this
# singleton exists to keep in one place.
_SINGLETON_LOCK = threading.Lock()


def _get_shared_inner() -> DecoderEnvironment:
    """Return the process-wide :class:`DecoderEnvironment`, building it lazily.

    The first caller constructs the environment and pre-warms the caches for
    the ``L1_warmup`` and ``L2_target`` curriculum levels so the first real
    episode does not pay that latency; every later caller receives the same
    instance. Initialization is double-checked under ``_SINGLETON_LOCK`` so
    the common already-built path stays lock-free while concurrent first
    calls cannot build two environments.
    """
    global _INNER_SINGLETON
    if _INNER_SINGLETON is None:
        with _SINGLETON_LOCK:
            # Re-check inside the lock: another thread may have won the race.
            if _INNER_SINGLETON is None:
                env = DecoderEnvironment()
                env._cache_for("L1_warmup")  # noqa: SLF001 - intentional pre-warm
                env._cache_for("L2_target")  # noqa: SLF001
                # Publish only the fully pre-warmed instance.
                _INNER_SINGLETON = env
    return _INNER_SINGLETON
# --------------------------------------------------------------------------- #
# OpenEnv-flavoured Action / Observation / State #
# --------------------------------------------------------------------------- #
class QubitMedicAction(Action):
    """LLM-emitted action: the raw text the model generated.

    The server parses this into ``x_error_qubits`` / ``z_error_qubits`` via
    :func:`qubit_medic.prompts.parse_action`. We keep the wire format
    *just the raw string* so the server retains full control over parsing
    (and so the trainer's reward function can audit unparseable outputs).

    The trainer is also free to populate ``parsed_x_errors`` /
    ``parsed_z_errors`` directly when it wants to bypass the LLM (useful
    for baseline policies and unit tests). NOTE: when either parsed field
    is set, the wrapper's ``step`` re-serializes them into the canonical
    answer format and the raw text is ignored.
    """
    # Inherit Action.model_config (extra='forbid', validate_assignment=True).
    # Primary payload: the model's completion, parsed server-side.
    raw_response: str = Field(
        default="",
        description="Raw LLM completion text. Server parses to x/z error lists.",
    )
    # Optional bypass path for non-LLM policies; "LLM-space" presumably means
    # the qubit indexing used in the prompt - TODO confirm against parse_action.
    parsed_x_errors: Optional[list[int]] = Field(
        default=None,
        description="Optional pre-parsed X-error qubit ids (LLM-space). "
        "When provided, the server skips text parsing.",
    )
    parsed_z_errors: Optional[list[int]] = Field(
        default=None,
        description="Optional pre-parsed Z-error qubit ids (LLM-space).",
    )
    # Pairs this step with its reset(); None falls back to the wrapper's
    # most recently started episode.
    episode_id: Optional[int] = Field(
        default=None,
        description="Server-assigned episode id from the matching reset(). "
        "If omitted, the most-recent active episode is used.",
    )
class QubitMedicObservation(Observation):
    """OpenEnv observation - mirrors :class:`DecoderObservation` plus the
    standard OpenEnv ``done`` / ``reward`` fields (inherited from the base
    ``Observation``).

    The ``info`` dict (returned by ``step``) carries the per-component
    reward breakdown, the ground-truth observable flip, and the PyMatching
    baseline prediction so the trainer can score auxiliary metrics.
    """
    # Forbid unknown keys so wire-format drift fails loudly.
    model_config = ConfigDict(extra="forbid", validate_assignment=True,
                              arbitrary_types_allowed=True)
    # What the policy consumes.
    prompt: str = Field(default="", description="Pre-formatted LLM prompt.")
    syndrome_bits: list[int] = Field(default_factory=list,
                                     description="Detector activations (0/1).")
    # Episode parameters (code geometry + noise level).
    distance: int = Field(default=0, description="Code distance for this episode.")
    rounds: int = Field(default=0, description="Number of stabilizer rounds.")
    p: float = Field(default=0.0, description="SI1000 base error rate.")
    curriculum_level: str = Field(default="",
                                  description="Curriculum level name.")
    # Bookkeeping: ties step() back to reset(); dem_digest lets clients
    # detect when the underlying detector error model changed.
    episode_id: int = Field(default=0,
                            description="Server-assigned episode counter.")
    dem_digest: str = Field(default="",
                            description="Short hash of the detector error model.")
    info: dict[str, Any] = Field(default_factory=dict,
                                 description="Per-step extras (reward "
                                 "breakdown, ground-truth flip, "
                                 "PyMatching baseline, etc.).")
class QubitMedicState(State):
    """Externally-visible state. We expose only the curriculum + episode
    counters; physics-truth fields stay server-side to prevent reward
    hacking (see :mod:`qubit_medic.models.DecoderState` doc-comment)."""
    # extra="allow" so the `state` property can also pass through fields
    # like episode_id / step_count - presumably declared on the openenv
    # State base; TODO confirm, otherwise they land here as extras.
    model_config = ConfigDict(extra="allow", validate_assignment=True,
                              arbitrary_types_allowed=True)
    # Counters sourced from DecoderEnvironment.health().
    episodes_started: int = 0
    active_episodes: int = 0
    cached_levels: list[str] = Field(default_factory=list)
    curriculum: dict[str, Any] = Field(default_factory=dict)
    # Reward components from the most recent step(), None before any step.
    last_reward_breakdown: Optional[dict[str, float]] = None
# --------------------------------------------------------------------------- #
# Environment wrapper #
# --------------------------------------------------------------------------- #
class QubitMedicEnvironment(Environment[QubitMedicAction,
                                        QubitMedicObservation,
                                        QubitMedicState]):
    """OpenEnv-compliant facade over :class:`DecoderEnvironment`.

    Episodes are single-step: the inner environment reports ``done`` after
    every ``step``. ``SUPPORTS_CONCURRENT_SESSIONS`` is ``False`` because
    the shared DecoderEnvironment keeps a single Stim cache behind a coarse
    lock - simpler than per-session state and sufficient for the GRPO
    training loop.
    """
    SUPPORTS_CONCURRENT_SESSIONS: bool = False

    def __init__(self) -> None:
        super().__init__()
        # Every wrapper instance the HTTP server creates shares one inner
        # DecoderEnvironment (see _get_shared_inner) so episode bookkeeping
        # survives stateless per-call instantiation.
        self._inner = _get_shared_inner()
        self._last_episode_id: Optional[int] = None
        self._last_reward_breakdown: Optional[dict[str, float]] = None

    # ----- internal helpers ----------------------------------------------- #
    def _to_wire_observation(
        self,
        src: Any,
        *,
        done: bool,
        reward: Optional[float],
        info: dict[str, Any],
    ) -> QubitMedicObservation:
        """Copy the inner DecoderObservation fields into the wire model."""
        return QubitMedicObservation(
            prompt=src.prompt,
            syndrome_bits=list(src.syndrome_bits),
            distance=src.distance,
            rounds=src.rounds,
            p=src.p,
            curriculum_level=src.curriculum_level,
            episode_id=src.episode_id,
            dem_digest=src.dem_digest,
            done=done,
            reward=reward,
            info=info,
        )

    # ----- abstract API --------------------------------------------------- #
    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> QubitMedicObservation:
        """Start a new episode on the shared inner environment.

        The OpenEnv ``episode_id`` argument is accepted for wire
        compatibility but not forwarded - the inner environment assigns
        its own integer episode ids. An optional ``forced_level`` kwarg
        pins the curriculum level for this episode.
        """
        inner_obs = self._inner.reset(
            seed=seed, forced_level=kwargs.get("forced_level")
        )
        self._last_episode_id = inner_obs.episode_id
        self._last_reward_breakdown = None
        return self._to_wire_observation(
            inner_obs, done=False, reward=None, info={"event": "reset"}
        )

    def step(
        self,
        action: QubitMedicAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> QubitMedicObservation:
        """Score one action against an active episode.

        Uses ``action.episode_id`` when supplied, otherwise the episode
        started by the most recent ``reset``; raises ``RuntimeError`` if
        neither is available.
        """
        episode = action.episode_id
        if episode is None:
            episode = self._last_episode_id
        if episode is None:
            raise RuntimeError(
                "step() called before reset(); no active episode to score."
            )
        has_parsed = (
            action.parsed_x_errors is not None
            or action.parsed_z_errors is not None
        )
        if has_parsed:
            # Re-serialize the pre-parsed lists into the canonical
            # "X: ... | Z: ..." answer shape so the server-side text parser
            # recovers exactly the same x/z lists.
            x_ids = action.parsed_x_errors or []
            z_ids = action.parsed_z_errors or []
            raw_text = (
                f"<answer>X: {','.join(map(str, x_ids))}"
                f" | Z: {','.join(map(str, z_ids))}</answer>"
            )
        else:
            raw_text = action.raw_response
        outcome = self._inner.step(raw_response=raw_text, episode_id=episode)
        self._last_reward_breakdown = outcome.info.get("rewards")
        return self._to_wire_observation(
            outcome.observation,
            done=outcome.done,
            reward=float(outcome.reward),
            info=outcome.info,
        )

    @property
    def state(self) -> QubitMedicState:
        """Snapshot the inner environment's curriculum/episode counters."""
        health = self._inner.health()
        last = self._last_episode_id
        started = int(health.get("episodes_started", 0))
        return QubitMedicState(
            episode_id=None if last is None else str(last),
            step_count=started,
            episodes_started=started,
            active_episodes=int(health.get("active_episodes", 0)),
            cached_levels=list(health.get("cached_levels", [])),
            curriculum=dict(health.get("curriculum", {})),
            last_reward_breakdown=self._last_reward_breakdown,
        )

    # ----- nice-to-haves -------------------------------------------------- #
    def get_metadata(self) -> EnvironmentMetadata:
        """Static environment metadata served on ``GET /metadata``."""
        return EnvironmentMetadata(
            name="QubitMedicEnvironment",
            description=(
                "RL training environment for LLM-based quantum error-"
                "correction decoders. Built on Stim + PyMatching. Five "
                "verifiable rewards (logical correction, syndrome consistency, "
                "Hamming overlap, format compliance, PyMatching beat-rate)."
            ),
            version="1.0.0",
        )

    def close(self) -> None:
        """Nothing to release: the inner environment is a shared singleton."""
        return None