"""Two equivalent client implementations:
* :class:`DecoderClient` - hits an HTTP endpoint (HF Spaces deployment).
Speaks the **OpenEnv** wire format:
- ``POST /reset`` body ``{"seed": int?, "episode_id": str?}``
- ``POST /step`` body ``{"action": {"raw_response": "...", ...},
"timeout_s": float?, "request_id": str?}``
* :class:`LocalDecoderClient` - calls the in-process env directly. Use this
in tests, in CI, and during local Colab runs to skip HTTP overhead.
Both expose the same ``reset`` / ``step`` / ``state`` / ``health`` / ``close``
API, so the training scripts can swap between them via a single env var
(``QUBIT_MEDIC_URL``).
"""
from __future__ import annotations
import os
from typing import Optional, Protocol
import httpx
from qubit_medic.models import (
DecoderObservation,
StepResult,
)
class _ClientProtocol(Protocol):
    """Structural interface shared by the HTTP and in-process clients.

    This is the declared return type of :func:`make_default_client`, so
    training code can treat either client through the same surface.
    """

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        forced_level: Optional[str] = None,
    ) -> DecoderObservation:
        """Reset the environment and return the resulting observation."""

    def step(self, *, raw_response: str, episode_id: int) -> StepResult:
        """Submit ``raw_response`` for ``episode_id`` and return the result."""

    # Compliance Section 3 (audit, 2026-04): the client surface must
    # mirror the server endpoints. state() returns a JSON-serialisable
    # snapshot; close() releases per-episode bookkeeping.
    def state(self) -> dict:
        """Return a JSON-serialisable snapshot of the environment state."""

    def health(self) -> dict:
        """Return the health-check payload."""

    def close(self) -> None:
        """Release per-episode bookkeeping."""
def _obs_from_openenv(payload: dict) -> DecoderObservation:
    """Re-hydrate our internal :class:`DecoderObservation` from the
    OpenEnv response body.

    The OpenEnv wrapper inlines all our fields onto the observation, so
    this is a 1-1 field mapping. A missing key and an explicit JSON
    ``null`` are treated the same way: the field falls back to its
    default. This keeps the ``int()`` / ``float()`` / ``list()``
    coercions below from raising ``TypeError`` on ``None``.

    Args:
        payload: The decoded observation object from the server.

    Returns:
        A :class:`DecoderObservation` with numeric fields coerced to
        ``int`` / ``float`` and ``syndrome_bits`` copied into a new list.
    """

    def field(key: str, default):
        # ``payload.get(key, default)`` alone would pass an explicit
        # ``None`` straight through to the coercions below.
        value = payload.get(key)
        return default if value is None else value

    return DecoderObservation(
        prompt=field("prompt", ""),
        syndrome_bits=list(field("syndrome_bits", [])),
        distance=int(field("distance", 0)),
        rounds=int(field("rounds", 0)),
        p=float(field("p", 0.0)),
        curriculum_level=field("curriculum_level", ""),
        episode_id=int(field("episode_id", 0)),
        dem_digest=field("dem_digest", ""),
    )
class DecoderClient:
    """HTTP client targeting a deployed FastAPI server (OpenEnv shape).

    All requests go through one :class:`httpx.Client` bound to
    ``base_url``. Call :meth:`close` when finished, or use the instance
    as a context manager (``with DecoderClient(url) as client: ...``),
    which calls :meth:`close` on exit.
    """

    def __init__(self, base_url: str, *, timeout: float = 60.0) -> None:
        # Strip a trailing "/" so the relative paths used below ("/reset",
        # "/step", ...) join onto base_url without a double slash.
        self._client = httpx.Client(base_url=base_url.rstrip("/"), timeout=timeout)

    def __enter__(self) -> "DecoderClient":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release resources; returning None never suppresses the
        # caller's exception.
        self.close()

    def reset(self, *, seed: Optional[int] = None,
              forced_level: Optional[str] = None) -> DecoderObservation:
        """POST /reset and return the resulting observation.

        Args:
            seed: Optional seed, sent in the JSON body when not ``None``.
            forced_level: Optional curriculum level, sent as a query
                parameter (see the comment below). Falsy values are
                not sent.

        Raises:
            httpx.HTTPStatusError: if the server answers with a 4xx/5xx.
        """
        # OpenEnv's ResetRequest only accepts seed + episode_id. We pass
        # forced_level via the URL query string so adapters that honour
        # it (our QubitMedicEnvironment via **kwargs) pick it up; servers
        # that ignore it just get a default level.
        body: dict = {}
        if seed is not None:
            body["seed"] = seed
        params = {"forced_level": forced_level} if forced_level else None
        r = self._client.post("/reset", json=body, params=params)
        r.raise_for_status()
        payload = r.json()
        # OpenEnv returns {observation: {...}, reward, done}. Fall back to
        # the flat body when "observation" is absent or JSON null.
        obs_payload = payload.get("observation")
        if obs_payload is None:
            obs_payload = payload
        return _obs_from_openenv(obs_payload)

    def step(self, *, raw_response: str, episode_id: int) -> StepResult:
        """POST /step with the model's raw response and return the result.

        ``truncated`` is derived from ``observation.info.timed_out``. A
        missing ``done`` is treated as ``True``, and a missing or null
        ``reward`` as ``0.0``.

        Raises:
            httpx.HTTPStatusError: if the server answers with a 4xx/5xx.
        """
        body = {
            "action": {
                "raw_response": raw_response,
                "episode_id": episode_id,
            },
        }
        r = self._client.post("/step", json=body)
        r.raise_for_status()
        payload = r.json()
        # ``or {}`` treats a JSON null like a missing key, so the lookups
        # below never run on None.
        obs_payload = payload.get("observation") or {}
        info = dict(obs_payload.get("info") or {})
        return StepResult(
            observation=_obs_from_openenv(obs_payload),
            reward=float(payload.get("reward", 0.0) or 0.0),
            done=bool(payload.get("done", True)),
            truncated=bool(info.get("timed_out", False)),
            info=info,
        )

    def state(self) -> dict:
        """GET /state on the OpenEnv server.

        Compliance Section 3 (audit, 2026-04): the client must mirror
        the server endpoints. We use GET (the OpenEnv canonical method)
        first, then fall back to POST (the audit-required method we
        also mounted) if some server build only exposes one of them.
        """
        r = self._client.get("/state")
        if r.status_code == 405:  # method not allowed -> try POST
            r = self._client.post("/state")
        r.raise_for_status()
        return r.json()

    def health(self) -> dict:
        """GET /health and return the decoded JSON body."""
        r = self._client.get("/health")
        r.raise_for_status()
        return r.json()

    def healthz(self) -> dict:
        """GET /healthz and return the decoded JSON body."""
        r = self._client.get("/healthz")
        r.raise_for_status()
        return r.json()

    def close(self) -> None:
        """Best-effort server-side close, then release the local pool.

        ``POST /close`` (mounted by qubit_medic.server.app) is sent
        without ``raise_for_status``, so a server lacking the route
        (404) is tolerated. Errors raised while sending are swallowed
        because cleanup must not fail. The local httpx connection pool
        is released in ``finally``, so it is closed even when the
        request raises.
        """
        try:
            self._client.post("/close")
        except Exception:
            # Deliberate best-effort: an unreachable server or an
            # already-closed client must not make cleanup raise.
            pass
        finally:
            self._client.close()
class LocalDecoderClient:
    """In-process client - calls :class:`DecoderEnvironment` directly.

    Mirrors the HTTP client's surface (``reset`` / ``step`` / ``state`` /
    ``health`` / ``close``) without any network hop. It can be used as a
    context manager; leaving the ``with`` block calls :meth:`close`.
    """

    def __init__(self, env=None) -> None:
        """Wrap ``env``, or build a default :class:`DecoderEnvironment`.

        The server package is imported lazily and only when no ``env`` is
        injected. Callers that supply their own environment therefore do
        not need ``qubit_medic.server`` to be importable.
        """
        if env is None:
            from qubit_medic.server.environment import DecoderEnvironment
            env = DecoderEnvironment()
        self._env = env

    def __enter__(self) -> "LocalDecoderClient":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Returning None never suppresses the caller's exception.
        self.close()

    def reset(self, *, seed: Optional[int] = None,
              forced_level: Optional[str] = None) -> DecoderObservation:
        """Delegate to the wrapped environment's ``reset``."""
        return self._env.reset(seed=seed, forced_level=forced_level)

    def step(self, *, raw_response: str, episode_id: int) -> StepResult:
        """Delegate to the wrapped environment's ``step``."""
        return self._env.step(raw_response=raw_response, episode_id=episode_id)

    def state(self) -> dict:
        """Compliance Section 3 (audit, 2026-04): expose env state via
        the same client surface as the HTTP variant. Delegates to the
        in-process :meth:`DecoderEnvironment.state`."""
        return self._env.state()

    def health(self) -> dict:
        """Delegate to the wrapped environment's ``health``."""
        return self._env.health()

    def close(self) -> None:
        """Best-effort release of the wrapped environment's resources."""
        # Compliance Section 3 (audit, 2026-04): close releases any
        # per-episode bookkeeping on the inner DecoderEnvironment so a
        # subsequent reset() starts from a clean active-episode dict.
        try:
            self._env.close()
        except Exception:
            # Deliberate best-effort: cleanup must never raise.
            pass
def make_default_client() -> _ClientProtocol:
    """Choose a client based on the ``QUBIT_MEDIC_URL`` environment variable.

    A non-empty value selects the HTTP :class:`DecoderClient` pointed at
    that URL. An unset or empty value falls back to the in-process
    :class:`LocalDecoderClient`.
    """
    base_url = os.environ.get("QUBIT_MEDIC_URL")
    if not base_url:
        return LocalDecoderClient()
    return DecoderClient(base_url)
|