"""OpenEnv-compliant adapter around :class:`DecoderEnvironment`.

This wrapper satisfies the submission requirement *"Use OpenEnv (latest
release). Build on top of the framework; don't reinvent the wheel."* by
exposing our underlying
:class:`qubit_medic.server.environment.DecoderEnvironment` through the
official ``openenv.core.Environment`` base class.

The adapter is intentionally thin: it only translates between OpenEnv's
``Action`` / ``Observation`` / ``State`` Pydantic shapes and our internal
``DecoderObservation`` / ``DecoderAction`` / ``StepResult``. All the physics,
reward scoring, curriculum, and episode bookkeeping continue to live in
:class:`DecoderEnvironment` - that code is *the* tested, production path.

Usage
-----
The OpenEnv-compliant FastAPI app is created with::

    from openenv.core import create_fastapi_app
    from qubit_medic.server.openenv_adapter import (
        QubitMedicEnvironment,
        QubitMedicAction,
        QubitMedicObservation,
    )

    app = create_fastapi_app(
        env=QubitMedicEnvironment,
        action_cls=QubitMedicAction,
        observation_cls=QubitMedicObservation,
    )

This registers the canonical OpenEnv routes:

* ``POST /reset``   - body ``{"seed": int?, "episode_id": str?}``
  (the adapter ignores a client-supplied ``episode_id``; ids are assigned
  by the server and returned in the observation)
* ``POST /step``    - body ``{"action": {...QubitMedicAction...},
  "timeout_s": float?, "request_id": str?}``
* ``GET  /state``   - returns the current :class:`QubitMedicState`
* ``GET  /health``  - liveness probe
* ``GET  /schema``  - JSON Schema for the action/observation models
* ``GET  /metadata``- environment metadata
* ``POST /mcp``     - Model Context Protocol endpoint
* ``GET  /docs``    - Swagger UI (auto-generated by FastAPI)

We additionally mount our own ``/healthz`` (Day-0 contract) and ``/decode``
(PyMatching baseline demo) on the returned app in ``qubit_medic.server.app``.
"""

from __future__ import annotations

import threading
from typing import Any, Optional

from openenv.core import Action, Environment, Observation, State
from openenv.core.env_server.types import EnvironmentMetadata
from pydantic import ConfigDict, Field

from qubit_medic.server.environment import DecoderEnvironment


# --------------------------------------------------------------------------- #
# Process-wide singleton                                                      #
# --------------------------------------------------------------------------- #
# OpenEnv's HTTP server (simulation mode) instantiates a *fresh* Environment
# via the factory on every /reset and /step call. Our episode bookkeeping
# (the `_active` dict) lives inside DecoderEnvironment, so we route every
# QubitMedicEnvironment instance through the same DecoderEnvironment.
# This keeps reset() -> step() pairing intact across stateless HTTP calls
# while remaining fully compatible with OpenEnv's WebSocket session model
# (each WS session still gets its own QubitMedicEnvironment wrapper).

_INNER_SINGLETON: Optional[DecoderEnvironment] = None

# Guards the lazy construction in _get_shared_inner. Sync request handlers can
# run concurrently in a thread pool; without the lock two "first" requests
# could each build their own DecoderEnvironment, and their wrappers would then
# see disjoint episode tables (a reset on one unknown to a step on the other).
_INNER_LOCK = threading.Lock()

# Process-wide "most recent" trackers. Because the HTTP server discards the
# wrapper after every call, per-instance attributes alone cannot carry the
# last reset() episode id into the next /step, nor the last reward breakdown
# into /state. The wrappers consult these only when they have no value of
# their own.
_LAST_EPISODE_ID: Optional[int] = None
_LAST_REWARD_BREAKDOWN: Optional[dict[str, float]] = None


def _get_shared_inner() -> DecoderEnvironment:
    """Return the process-wide DecoderEnvironment, building it lazily.

    The environment is constructed (and its caches for the ``L1_warmup`` and
    ``L2_target`` levels pre-warmed) at most once per process. Concurrent
    first callers wait on a lock and then all receive the same instance. If
    construction raises, nothing is cached and the next call retries.
    """
    global _INNER_SINGLETON
    inner = _INNER_SINGLETON
    if inner is not None:  # fast path: already built, no locking needed
        return inner
    with _INNER_LOCK:
        # Re-check under the lock: another thread may have finished first.
        if _INNER_SINGLETON is None:
            env = DecoderEnvironment()
            env._cache_for("L1_warmup")  # noqa: SLF001 - intentional pre-warm
            env._cache_for("L2_target")  # noqa: SLF001
            _INNER_SINGLETON = env
        return _INNER_SINGLETON


# --------------------------------------------------------------------------- #
# OpenEnv-flavoured Action / Observation / State                              #
# --------------------------------------------------------------------------- #


class QubitMedicAction(Action):
    """LLM-emitted action: the raw text the model generated.

    The server parses this into ``x_error_qubits`` / ``z_error_qubits`` via
    :func:`qubit_medic.prompts.parse_action`. We keep the wire format *just
    the raw string* so the server retains full control over parsing (and so
    the trainer's reward function can audit unparseable outputs).

    The trainer is also free to populate ``parsed_x_errors`` /
    ``parsed_z_errors`` directly when it wants to bypass the LLM (useful for
    baseline policies and unit tests). When either list is set, the adapter
    re-serialises both lists into the canonical ``"X: ... | Z: ..."`` text
    and ignores ``raw_response``; a missing list is treated as empty.
    """

    # Inherit Action.model_config (extra='forbid', validate_assignment=True).

    raw_response: str = Field(
        default="",
        description="Raw LLM completion text. Server parses to x/z error lists.",
    )
    parsed_x_errors: Optional[list[int]] = Field(
        default=None,
        description="Optional pre-parsed X-error qubit ids (LLM-space). "
        "When provided, the server skips text parsing.",
    )
    parsed_z_errors: Optional[list[int]] = Field(
        default=None,
        description="Optional pre-parsed Z-error qubit ids (LLM-space).",
    )
    episode_id: Optional[int] = Field(
        default=None,
        description="Server-assigned episode id from the matching reset(). "
        "If omitted, the most-recent active episode is used.",
    )


class QubitMedicObservation(Observation):
    """OpenEnv observation.

    Mirrors :class:`DecoderObservation` plus the standard OpenEnv ``done`` /
    ``reward`` fields. The ``info`` dict (returned by ``step``) carries the
    per-component reward breakdown, the ground-truth observable flip, and the
    PyMatching baseline prediction so the trainer can score auxiliary metrics.
    """

    model_config = ConfigDict(
        extra="forbid", validate_assignment=True, arbitrary_types_allowed=True
    )

    prompt: str = Field(default="", description="Pre-formatted LLM prompt.")
    syndrome_bits: list[int] = Field(
        default_factory=list, description="Detector activations (0/1)."
    )
    distance: int = Field(default=0, description="Code distance for this episode.")
    rounds: int = Field(default=0, description="Number of stabilizer rounds.")
    p: float = Field(default=0.0, description="SI1000 base error rate.")
    curriculum_level: str = Field(default="", description="Curriculum level name.")
    episode_id: int = Field(default=0, description="Server-assigned episode counter.")
    dem_digest: str = Field(
        default="", description="Short hash of the detector error model."
    )
    info: dict[str, Any] = Field(
        default_factory=dict,
        description="Per-step extras (reward "
        "breakdown, ground-truth flip, "
        "PyMatching baseline, etc.).",
    )


class QubitMedicState(State):
    """Externally-visible state.

    We expose only the curriculum + episode counters; physics-truth fields
    stay server-side to prevent reward hacking (see
    :mod:`qubit_medic.models.DecoderState` doc-comment).
    """

    model_config = ConfigDict(
        extra="allow", validate_assignment=True, arbitrary_types_allowed=True
    )

    episodes_started: int = 0
    active_episodes: int = 0
    cached_levels: list[str] = Field(default_factory=list)
    curriculum: dict[str, Any] = Field(default_factory=dict)
    last_reward_breakdown: Optional[dict[str, float]] = None


# --------------------------------------------------------------------------- #
# Environment wrapper                                                         #
# --------------------------------------------------------------------------- #


class QubitMedicEnvironment(
    Environment[QubitMedicAction, QubitMedicObservation, QubitMedicState]
):
    """OpenEnv-compliant view of :class:`DecoderEnvironment`.

    Single-step episodes (``done=True`` after every ``step``).

    The OpenEnv HTTP server gets a fresh instance per WebSocket session if
    ``SUPPORTS_CONCURRENT_SESSIONS=True``; we set it to ``False`` because our
    DecoderEnvironment uses a single Stim cache + a coarse lock, which is
    simpler than per-session state and good enough for the GRPO training loop.

    Every wrapper shares one process-wide DecoderEnvironment. Each wrapper
    remembers the episode id from its own last ``reset()``. A wrapper that
    has not reset (e.g. the fresh instance the HTTP server creates for each
    /step or /state call) falls back to the most recent ``reset()`` seen by
    any wrapper in the process.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = False

    def __init__(self) -> None:
        super().__init__()
        # Share the underlying DecoderEnvironment across every wrapper
        # instance the HTTP server creates - see _get_shared_inner.
        self._inner = _get_shared_inner()
        self._last_episode_id: Optional[int] = None
        self._last_reward_breakdown: Optional[dict[str, float]] = None

    # ----- helpers -------------------------------------------------------- #

    @staticmethod
    def _to_openenv(
        obs: Any,
        *,
        done: bool,
        reward: Optional[float],
        info: dict[str, Any],
    ) -> QubitMedicObservation:
        """Copy an inner observation's fields into a QubitMedicObservation.

        ``obs`` is the inner environment's observation object; only the
        attributes read below are required.
        """
        return QubitMedicObservation(
            prompt=obs.prompt,
            syndrome_bits=list(obs.syndrome_bits),
            distance=obs.distance,
            rounds=obs.rounds,
            p=obs.p,
            curriculum_level=obs.curriculum_level,
            episode_id=obs.episode_id,
            dem_digest=obs.dem_digest,
            done=done,
            reward=reward,
            info=info,
        )

    # ----- abstract API --------------------------------------------------- #

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> QubitMedicObservation:
        """Start a new episode on the shared inner environment.

        Args:
            seed: Forwarded to ``DecoderEnvironment.reset``.
            episode_id: Accepted for OpenEnv compatibility but ignored; the
                inner environment assigns the id, which is returned in the
                observation.
            **kwargs: Only ``forced_level`` is read (forwarded to the inner
                reset); other keys are ignored.

        Returns:
            The initial observation with ``done=False`` and ``reward=None``.
        """
        global _LAST_EPISODE_ID, _LAST_REWARD_BREAKDOWN
        forced_level = kwargs.get("forced_level")
        obs = self._inner.reset(seed=seed, forced_level=forced_level)
        self._last_episode_id = obs.episode_id
        self._last_reward_breakdown = None
        # Publish process-wide so a fresh wrapper (next HTTP call) can pair
        # its step()/state with this reset.
        _LAST_EPISODE_ID = obs.episode_id
        _LAST_REWARD_BREAKDOWN = None
        return self._to_openenv(obs, done=False, reward=None, info={"event": "reset"})

    def step(
        self,
        action: QubitMedicAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> QubitMedicObservation:
        """Score ``action`` against its episode and return the final observation.

        The episode is chosen in this order: ``action.episode_id``, then this
        wrapper's last reset, then the most recent reset in the process.
        ``timeout_s`` and extra kwargs are accepted for OpenEnv compatibility
        and ignored.

        Raises:
            RuntimeError: If no episode id can be resolved (no reset yet).
        """
        global _LAST_REWARD_BREAKDOWN
        ep = action.episode_id
        if ep is None:
            ep = self._last_episode_id
        if ep is None:
            # Fresh wrapper created for a stateless HTTP /step call.
            ep = _LAST_EPISODE_ID
        if ep is None:
            raise RuntimeError(
                "step() called before reset(); no active episode to score."
            )

        # If the trainer pre-parsed the action, format a synthetic raw
        # response in the canonical "X: ... | Z: ..." shape so the server's
        # parser produces the same x/z lists.
        if action.parsed_x_errors is not None or action.parsed_z_errors is not None:
            xs = action.parsed_x_errors or []
            zs = action.parsed_z_errors or []
            raw = f"X: {','.join(map(str, xs))} | Z: {','.join(map(str, zs))}"
        else:
            raw = action.raw_response

        result = self._inner.step(raw_response=raw, episode_id=ep)
        breakdown = result.info.get("rewards")
        self._last_reward_breakdown = breakdown
        _LAST_REWARD_BREAKDOWN = breakdown
        return self._to_openenv(
            result.observation,
            done=result.done,
            reward=float(result.reward),
            info=result.info,
        )

    @property
    def state(self) -> QubitMedicState:
        """Snapshot of the shared environment's counters plus the last episode.

        Uses this wrapper's own episode/breakdown if it has reset, otherwise
        the most recent values recorded by any wrapper in the process.
        """
        h = self._inner.health()
        own = self._last_episode_id is not None
        episode = self._last_episode_id if own else _LAST_EPISODE_ID
        breakdown = self._last_reward_breakdown if own else _LAST_REWARD_BREAKDOWN
        started = int(h.get("episodes_started", 0))
        return QubitMedicState(
            episode_id=str(episode) if episode is not None else None,
            # Episodes are single-step, so step_count mirrors episodes_started.
            step_count=started,
            episodes_started=started,
            active_episodes=int(h.get("active_episodes", 0)),
            cached_levels=list(h.get("cached_levels", [])),
            curriculum=dict(h.get("curriculum", {})),
            last_reward_breakdown=breakdown,
        )

    # ----- nice-to-haves -------------------------------------------------- #

    def get_metadata(self) -> EnvironmentMetadata:
        """Describe this environment for OpenEnv's ``/metadata`` route."""
        return EnvironmentMetadata(
            name="QubitMedicEnvironment",
            description=(
                "RL training environment for LLM-based quantum error-"
                "correction decoders. Built on Stim + PyMatching. Five "
                "verifiable rewards (logical correction, syndrome consistency, "
                "Hamming overlap, format compliance, PyMatching beat-rate)."
            ),
            version="1.0.0",
        )

    def close(self) -> None:
        """No-op: the shared inner environment outlives individual wrappers."""
        return None