Spaces:
Sleeping
Sleeping
| """Two equivalent client implementations: | |
| * :class:`DecoderClient` - hits an HTTP endpoint (HF Spaces deployment). | |
| Speaks the **OpenEnv** wire format: | |
| - ``POST /reset`` body ``{"seed": int?, "episode_id": str?}`` | |
| - ``POST /step`` body ``{"action": {"raw_response": "...", ...}, | |
| "timeout_s": float?, "request_id": str?}`` | |
| * :class:`LocalDecoderClient` - calls the in-process env directly. Use this | |
| in tests, in CI, and during local Colab runs to skip HTTP overhead. | |
| Both expose the same ``reset`` / ``step`` API so the training scripts can | |
| swap between them via a single env var (``QUBIT_MEDIC_URL``). | |
| """ | |
| from __future__ import annotations | |
| import os | |
| from typing import Optional, Protocol | |
| import httpx | |
| from qubit_medic.models import ( | |
| DecoderObservation, | |
| StepResult, | |
| ) | |
| class _ClientProtocol(Protocol): | |
| def reset(self, *, seed: Optional[int] = None, | |
| forced_level: Optional[str] = None) -> DecoderObservation: ... | |
| def step(self, *, raw_response: str, episode_id: int) -> StepResult: ... | |
| # Compliance Section 3 (audit, 2026-04): the client surface must | |
| # mirror the server endpoints. state() returns a JSON-serialisable | |
| # snapshot; close() releases per-episode bookkeeping. | |
| def state(self) -> dict: ... | |
| def health(self) -> dict: ... | |
| def close(self) -> None: ... | |
| def _obs_from_openenv(payload: dict) -> DecoderObservation: | |
| """Re-hydrate our internal :class:`DecoderObservation` from the | |
| OpenEnv response body. The OpenEnv wrapper inlines all our fields | |
| onto the observation, so this is a 1-1 field mapping.""" | |
| return DecoderObservation( | |
| prompt=payload.get("prompt", ""), | |
| syndrome_bits=list(payload.get("syndrome_bits", [])), | |
| distance=int(payload.get("distance", 0)), | |
| rounds=int(payload.get("rounds", 0)), | |
| p=float(payload.get("p", 0.0)), | |
| curriculum_level=payload.get("curriculum_level", ""), | |
| episode_id=int(payload.get("episode_id", 0)), | |
| dem_digest=payload.get("dem_digest", ""), | |
| ) | |
| class DecoderClient: | |
| """HTTP client targeting a deployed FastAPI server (OpenEnv shape).""" | |
| def __init__(self, base_url: str, *, timeout: float = 60.0) -> None: | |
| self._client = httpx.Client(base_url=base_url.rstrip("/"), timeout=timeout) | |
| def reset(self, *, seed: Optional[int] = None, | |
| forced_level: Optional[str] = None) -> DecoderObservation: | |
| # OpenEnv's ResetRequest only accepts seed + episode_id. We pass | |
| # forced_level via the URL query string so adapters that honour | |
| # it (our QubitMedicEnvironment via **kwargs) pick it up; servers | |
| # that ignore it just get a default level. | |
| body: dict = {} | |
| if seed is not None: | |
| body["seed"] = seed | |
| params = {"forced_level": forced_level} if forced_level else None | |
| r = self._client.post("/reset", json=body, params=params) | |
| r.raise_for_status() | |
| payload = r.json() | |
| # OpenEnv returns {observation: {...}, reward, done}. | |
| return _obs_from_openenv(payload.get("observation", payload)) | |
| def step(self, *, raw_response: str, episode_id: int) -> StepResult: | |
| body = { | |
| "action": { | |
| "raw_response": raw_response, | |
| "episode_id": episode_id, | |
| }, | |
| } | |
| r = self._client.post("/step", json=body) | |
| r.raise_for_status() | |
| payload = r.json() | |
| obs_payload = payload.get("observation", {}) | |
| return StepResult( | |
| observation=_obs_from_openenv(obs_payload), | |
| reward=float(payload.get("reward", 0.0) or 0.0), | |
| done=bool(payload.get("done", True)), | |
| truncated=bool(obs_payload.get("info", {}).get("timed_out", False)), | |
| info=dict(obs_payload.get("info", {})), | |
| ) | |
| def state(self) -> dict: | |
| """GET /state on the OpenEnv server. | |
| Compliance Section 3 (audit, 2026-04): the client must mirror | |
| the server endpoints. We use GET (the OpenEnv canonical method) | |
| first, then fall back to POST (the audit-required method we | |
| also mounted) if some server build only exposes one of them. | |
| """ | |
| r = self._client.get("/state") | |
| if r.status_code == 405: # method not allowed -> try POST | |
| r = self._client.post("/state") | |
| r.raise_for_status() | |
| return r.json() | |
| def health(self) -> dict: | |
| r = self._client.get("/health") | |
| r.raise_for_status() | |
| return r.json() | |
| def healthz(self) -> dict: | |
| r = self._client.get("/healthz") | |
| r.raise_for_status() | |
| return r.json() | |
| def close(self) -> None: | |
| # Best-effort: tell the server we're done (the POST /close route | |
| # is mounted by qubit_medic.server.app) and then release the | |
| # local httpx connection pool. If the server doesn't expose | |
| # /close, swallow the 404 - this remains an idempotent cleanup. | |
| try: | |
| self._client.post("/close") | |
| except Exception: | |
| pass | |
| self._client.close() | |
| class LocalDecoderClient: | |
| """In-process client - calls :class:`DecoderEnvironment` directly.""" | |
| def __init__(self, env=None) -> None: | |
| from qubit_medic.server.environment import DecoderEnvironment | |
| self._env = env if env is not None else DecoderEnvironment() | |
| def reset(self, *, seed: Optional[int] = None, | |
| forced_level: Optional[str] = None) -> DecoderObservation: | |
| return self._env.reset(seed=seed, forced_level=forced_level) | |
| def step(self, *, raw_response: str, episode_id: int) -> StepResult: | |
| return self._env.step(raw_response=raw_response, episode_id=episode_id) | |
| def state(self) -> dict: | |
| """Compliance Section 3 (audit, 2026-04): expose env state via | |
| the same client surface as the HTTP variant. Delegates to the | |
| in-process :meth:`DecoderEnvironment.state`.""" | |
| return self._env.state() | |
| def health(self) -> dict: | |
| return self._env.health() | |
| def close(self) -> None: | |
| # Compliance Section 3 (audit, 2026-04): close releases any | |
| # per-episode bookkeeping on the inner DecoderEnvironment so a | |
| # subsequent reset() starts from a clean active-episode dict. | |
| try: | |
| self._env.close() | |
| except Exception: | |
| pass | |
| def make_default_client() -> _ClientProtocol: | |
| """Return :class:`DecoderClient` if ``QUBIT_MEDIC_URL`` is set, else local.""" | |
| url = os.getenv("QUBIT_MEDIC_URL") | |
| if url: | |
| return DecoderClient(url) | |
| return LocalDecoderClient() | |