Spaces:
Sleeping
Sleeping
"""OpenEnv-compliant adapter around :class:`DecoderEnvironment`.

This wrapper satisfies the submission requirement *"Use OpenEnv (latest
release). Build on top of the framework; don't reinvent the wheel."* by
exposing our underlying :class:`qubit_medic.server.environment.DecoderEnvironment`
through the official ``openenv.core.Environment`` base class.

The adapter is intentionally thin: it just translates between OpenEnv's
``Action`` / ``Observation`` / ``State`` Pydantic shapes and our internal
``DecoderObservation`` / ``DecoderAction`` / ``StepResult``. All the
physics, reward scoring, curriculum, and episode bookkeeping continue to
live in :class:`DecoderEnvironment` - that code is *the* tested,
production path.

Usage
-----
The OpenEnv-compliant FastAPI app is created with::

    from openenv.core import create_fastapi_app
    from qubit_medic.server.openenv_adapter import (
        QubitMedicEnvironment, QubitMedicAction, QubitMedicObservation,
    )

    app = create_fastapi_app(
        env=QubitMedicEnvironment,
        action_cls=QubitMedicAction,
        observation_cls=QubitMedicObservation,
    )

This registers the canonical OpenEnv routes:

* ``POST /reset`` - body ``{"seed": int?, "episode_id": str?}``
* ``POST /step`` - body ``{"action": {...QubitMedicAction...},
  "timeout_s": float?, "request_id": str?}``
* ``GET /state`` - returns the current :class:`QubitMedicState`
* ``GET /health`` - liveness probe
* ``GET /schema`` - JSON Schema for the action/observation models
* ``GET /metadata`` - environment metadata
* ``POST /mcp`` - Model Context Protocol endpoint
* ``GET /docs`` - Swagger UI (auto-generated by FastAPI)

We additionally mount our own ``/healthz`` (Day-0 contract) and
``/decode`` (PyMatching baseline demo) on the returned app from
``qubit_medic.server.app``.
"""

from __future__ import annotations

from typing import Any, Optional

from openenv.core import Action, Environment, Observation, State
from openenv.core.env_server.types import EnvironmentMetadata
from pydantic import ConfigDict, Field

from qubit_medic.server.environment import DecoderEnvironment
# --------------------------------------------------------------------------- #
# Process-wide singleton                                                      #
# --------------------------------------------------------------------------- #
# OpenEnv's HTTP server (simulation mode) instantiates a *fresh* Environment
# via the factory on every /reset and /step call.  Our episode bookkeeping
# (the `_active` dict) lives inside DecoderEnvironment, so we route every
# QubitMedicEnvironment instance through the same DecoderEnvironment.
# This keeps reset() -> step() pairing intact across stateless HTTP calls
# while remaining fully compatible with OpenEnv's WebSocket session model
# (each WS session still gets its own QubitMedicEnvironment wrapper).

_INNER_SINGLETON: Optional[DecoderEnvironment] = None


def _get_shared_inner() -> DecoderEnvironment:
    """Return the process-wide DecoderEnvironment, building it lazily."""
    global _INNER_SINGLETON
    if _INNER_SINGLETON is None:
        env = DecoderEnvironment()
        # Pre-warm the level caches so the first /reset doesn't pay the
        # construction cost (intentional use of the private helper).
        env._cache_for("L1_warmup")  # noqa: SLF001 - intentional pre-warm
        env._cache_for("L2_target")  # noqa: SLF001
        _INNER_SINGLETON = env
    return _INNER_SINGLETON
# --------------------------------------------------------------------------- #
# OpenEnv-flavoured Action / Observation / State                              #
# --------------------------------------------------------------------------- #


class QubitMedicAction(Action):
    """LLM-emitted action: the raw text the model generated.

    The server parses this into ``x_error_qubits`` / ``z_error_qubits`` via
    :func:`qubit_medic.prompts.parse_action`. We keep the wire format
    *just the raw string* so the server retains full control over parsing
    (and so the trainer's reward function can audit unparseable outputs).

    The trainer is also free to populate ``parsed_x_errors`` /
    ``parsed_z_errors`` directly when it wants to bypass the LLM (useful
    for baseline policies and unit tests).
    """

    # Inherit Action.model_config (extra='forbid', validate_assignment=True).
    raw_response: str = Field(
        default="",
        description="Raw LLM completion text. Server parses to x/z error lists.",
    )
    parsed_x_errors: Optional[list[int]] = Field(
        default=None,
        description="Optional pre-parsed X-error qubit ids (LLM-space). "
        "When provided, the server skips text parsing.",
    )
    parsed_z_errors: Optional[list[int]] = Field(
        default=None,
        description="Optional pre-parsed Z-error qubit ids (LLM-space).",
    )
    episode_id: Optional[int] = Field(
        default=None,
        description="Server-assigned episode id from the matching reset(). "
        "If omitted, the most-recent active episode is used.",
    )
class QubitMedicObservation(Observation):
    """OpenEnv observation - mirrors :class:`DecoderObservation` plus the
    standard OpenEnv ``done`` / ``reward`` fields.

    The ``info`` dict (returned by ``step``) carries the per-component
    reward breakdown, the ground-truth observable flip, and the PyMatching
    baseline prediction so the trainer can score auxiliary metrics.
    """

    model_config = ConfigDict(extra="forbid", validate_assignment=True,
                              arbitrary_types_allowed=True)

    prompt: str = Field(default="", description="Pre-formatted LLM prompt.")
    syndrome_bits: list[int] = Field(default_factory=list,
                                     description="Detector activations (0/1).")
    distance: int = Field(default=0, description="Code distance for this episode.")
    rounds: int = Field(default=0, description="Number of stabilizer rounds.")
    p: float = Field(default=0.0, description="SI1000 base error rate.")
    curriculum_level: str = Field(default="",
                                  description="Curriculum level name.")
    episode_id: int = Field(default=0,
                            description="Server-assigned episode counter.")
    dem_digest: str = Field(default="",
                            description="Short hash of the detector error model.")
    info: dict[str, Any] = Field(default_factory=dict,
                                 description="Per-step extras (reward "
                                             "breakdown, ground-truth flip, "
                                             "PyMatching baseline, etc.).")
class QubitMedicState(State):
    """Externally-visible state. We expose only the curriculum + episode
    counters; physics-truth fields stay server-side to prevent reward
    hacking (see :mod:`qubit_medic.models.DecoderState` doc-comment)."""

    model_config = ConfigDict(extra="allow", validate_assignment=True,
                              arbitrary_types_allowed=True)

    # Counters mirrored from DecoderEnvironment.health() (see state()).
    episodes_started: int = 0
    active_episodes: int = 0
    cached_levels: list[str] = Field(default_factory=list)
    curriculum: dict[str, Any] = Field(default_factory=dict)
    # Reward breakdown of the most recent step() on this wrapper, if any.
    last_reward_breakdown: Optional[dict[str, float]] = None
# --------------------------------------------------------------------------- #
# Environment wrapper                                                         #
# --------------------------------------------------------------------------- #


class QubitMedicEnvironment(Environment[QubitMedicAction,
                                        QubitMedicObservation,
                                        QubitMedicState]):
    """OpenEnv-compliant view of :class:`DecoderEnvironment`.

    Single-step episodes (``done=True`` after every ``step``). The OpenEnv
    HTTP server gets a fresh instance per WebSocket session if
    ``SUPPORTS_CONCURRENT_SESSIONS=True``; we set it to ``False`` because
    our DecoderEnvironment uses a single Stim cache + a coarse lock, which
    is simpler than per-session state and good enough for the GRPO
    training loop.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = False

    def __init__(self) -> None:
        super().__init__()
        # Share the underlying DecoderEnvironment across every wrapper
        # instance the HTTP server creates - see _get_shared_inner.
        self._inner = _get_shared_inner()
        self._last_episode_id: Optional[int] = None
        self._last_reward_breakdown: Optional[dict[str, float]] = None

    # ----- abstract API --------------------------------------------------- #

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> QubitMedicObservation:
        """Start a fresh episode on the shared inner environment.

        ``forced_level`` may be passed via kwargs to pin the curriculum
        level; otherwise the inner environment chooses it.
        """
        forced_level = kwargs.get("forced_level")
        obs = self._inner.reset(seed=seed, forced_level=forced_level)
        self._last_episode_id = obs.episode_id
        self._last_reward_breakdown = None
        return QubitMedicObservation(
            prompt=obs.prompt,
            syndrome_bits=list(obs.syndrome_bits),
            distance=obs.distance,
            rounds=obs.rounds,
            p=obs.p,
            curriculum_level=obs.curriculum_level,
            episode_id=obs.episode_id,
            dem_digest=obs.dem_digest,
            done=False,
            reward=None,
            info={"event": "reset"},
        )

    def step(
        self,
        action: QubitMedicAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> QubitMedicObservation:
        """Score *action* against the active episode and return the result.

        Raises
        ------
        RuntimeError
            If no episode id is available (``step`` before ``reset``).
        """
        ep = action.episode_id if action.episode_id is not None else self._last_episode_id
        if ep is None:
            raise RuntimeError(
                "step() called before reset(); no active episode to score."
            )
        # If the trainer pre-parsed the action, format a synthetic raw
        # response in the canonical "X: ... | Z: ..." shape so the server's
        # parser produces the same x/z lists.
        if action.parsed_x_errors is not None or action.parsed_z_errors is not None:
            xs = action.parsed_x_errors or []
            zs = action.parsed_z_errors or []
            raw = f"<answer>X: {','.join(map(str, xs))} | Z: {','.join(map(str, zs))}</answer>"
        else:
            raw = action.raw_response
        result = self._inner.step(raw_response=raw, episode_id=ep)
        self._last_reward_breakdown = result.info.get("rewards")
        return QubitMedicObservation(
            prompt=result.observation.prompt,
            syndrome_bits=list(result.observation.syndrome_bits),
            distance=result.observation.distance,
            rounds=result.observation.rounds,
            p=result.observation.p,
            curriculum_level=result.observation.curriculum_level,
            episode_id=result.observation.episode_id,
            dem_digest=result.observation.dem_digest,
            done=result.done,
            reward=float(result.reward),
            info=result.info,
        )

    def state(self) -> QubitMedicState:
        """Snapshot counters from the inner environment's health() dict."""
        h = self._inner.health()
        return QubitMedicState(
            episode_id=str(self._last_episode_id)
            if self._last_episode_id is not None else None,
            step_count=int(h.get("episodes_started", 0)),
            episodes_started=int(h.get("episodes_started", 0)),
            active_episodes=int(h.get("active_episodes", 0)),
            cached_levels=list(h.get("cached_levels", [])),
            curriculum=dict(h.get("curriculum", {})),
            last_reward_breakdown=self._last_reward_breakdown,
        )

    # ----- nice-to-haves -------------------------------------------------- #

    def get_metadata(self) -> EnvironmentMetadata:
        """Static descriptive metadata served at ``GET /metadata``."""
        return EnvironmentMetadata(
            name="QubitMedicEnvironment",
            description=(
                "RL training environment for LLM-based quantum error-"
                "correction decoders. Built on Stim + PyMatching. Five "
                "verifiable rewards (logical correction, syndrome consistency, "
                "Hamming overlap, format compliance, PyMatching beat-rate)."
            ),
            version="1.0.0",
        )

    def close(self) -> None:  # nothing to clean up
        return None