ronitraj's picture
deploy via scripts/deploy_to_space.py
fa68719 verified
"""Two equivalent client implementations:
* :class:`DecoderClient` - hits an HTTP endpoint (HF Spaces deployment).
Speaks the **OpenEnv** wire format:
- ``POST /reset`` body ``{"seed": int?, "episode_id": str?}``
- ``POST /step`` body ``{"action": {"raw_response": "...", ...},
"timeout_s": float?, "request_id": str?}``
* :class:`LocalDecoderClient` - calls the in-process env directly. Use this
in tests, in CI, and during local Colab runs to skip HTTP overhead.
Both expose the same ``reset`` / ``step`` API so the training scripts can
swap between them via a single env var (``QUBIT_MEDIC_URL``).
"""
from __future__ import annotations

import logging
import os
from typing import Optional, Protocol

import httpx

from qubit_medic.models import (
    DecoderObservation,
    StepResult,
)
class _ClientProtocol(Protocol):
def reset(self, *, seed: Optional[int] = None,
forced_level: Optional[str] = None) -> DecoderObservation: ...
def step(self, *, raw_response: str, episode_id: int) -> StepResult: ...
# Compliance Section 3 (audit, 2026-04): the client surface must
# mirror the server endpoints. state() returns a JSON-serialisable
# snapshot; close() releases per-episode bookkeeping.
def state(self) -> dict: ...
def health(self) -> dict: ...
def close(self) -> None: ...
def _obs_from_openenv(payload: dict) -> DecoderObservation:
    """Re-hydrate our internal :class:`DecoderObservation` from the
    OpenEnv response body.

    The OpenEnv wrapper inlines all our fields onto the observation, so
    this is a 1-1 field mapping. A key that is missing *or* explicitly
    ``null`` falls back to the same default, so a server that serialises
    unset fields as JSON ``null`` does not crash the ``int()`` /
    ``float()`` / ``list()`` coercions below.

    Args:
        payload: The observation dict taken from the OpenEnv response.

    Returns:
        A :class:`DecoderObservation` populated from *payload*.
    """

    def _get(name: str, default):
        # Treat JSON null exactly like an absent key.
        value = payload.get(name)
        return default if value is None else value

    return DecoderObservation(
        prompt=_get("prompt", ""),
        syndrome_bits=list(_get("syndrome_bits", [])),
        distance=int(_get("distance", 0)),
        rounds=int(_get("rounds", 0)),
        p=float(_get("p", 0.0)),
        curriculum_level=_get("curriculum_level", ""),
        episode_id=int(_get("episode_id", 0)),
        dem_digest=_get("dem_digest", ""),
    )
class DecoderClient:
    """HTTP client targeting a deployed FastAPI server (OpenEnv shape).

    Owns an :class:`httpx.Client` connection pool. Release it with
    :meth:`close`, or use the instance as a context manager::

        with DecoderClient(url) as client:
            obs = client.reset(seed=0)
    """

    def __init__(self, base_url: str, *, timeout: float = 60.0) -> None:
        """Create a client for the server rooted at *base_url*.

        Args:
            base_url: Server root URL; a trailing ``/`` is stripped.
            timeout: Request timeout in seconds, forwarded to httpx.
        """
        self._client = httpx.Client(base_url=base_url.rstrip("/"), timeout=timeout)

    def __enter__(self) -> DecoderClient:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Never suppresses the in-flight exception (returns None).
        self.close()

    def reset(self, *, seed: Optional[int] = None,
              forced_level: Optional[str] = None) -> DecoderObservation:
        """``POST /reset`` and return the initial observation.

        Raises:
            httpx.HTTPStatusError: if the server answers with an error status.
        """
        # OpenEnv's ResetRequest only accepts seed + episode_id. We pass
        # forced_level via the URL query string so adapters that honour
        # it (our QubitMedicEnvironment via **kwargs) pick it up; servers
        # that ignore it just get a default level.
        body: dict = {}
        if seed is not None:
            body["seed"] = seed
        params = {"forced_level": forced_level} if forced_level else None
        r = self._client.post("/reset", json=body, params=params)
        r.raise_for_status()
        payload = r.json()
        # OpenEnv returns {observation: {...}, reward, done}; fall back to
        # treating the body itself as the observation if it is flat.
        return _obs_from_openenv(payload.get("observation", payload))

    def step(self, *, raw_response: str, episode_id: int) -> StepResult:
        """``POST /step`` with *raw_response* for *episode_id*.

        ``truncated`` is derived from the ``timed_out`` flag inside the
        observation's ``info`` dict.

        Raises:
            httpx.HTTPStatusError: if the server answers with an error status.
        """
        body = {
            "action": {
                "raw_response": raw_response,
                "episode_id": episode_id,
            },
        }
        r = self._client.post("/step", json=body)
        r.raise_for_status()
        payload = r.json()
        obs_payload = payload.get("observation", {})
        # Copy once; ``or {}`` also tolerates an explicit JSON null.
        info = dict(obs_payload.get("info") or {})
        return StepResult(
            observation=_obs_from_openenv(obs_payload),
            reward=float(payload.get("reward", 0.0) or 0.0),
            done=bool(payload.get("done", True)),
            truncated=bool(info.get("timed_out", False)),
            info=info,
        )

    def state(self) -> dict:
        """GET /state on the OpenEnv server.

        Compliance Section 3 (audit, 2026-04): the client must mirror
        the server endpoints. We use GET (the OpenEnv canonical method)
        first, then fall back to POST (the audit-required method we
        also mounted) if some server build only exposes one of them.
        """
        r = self._client.get("/state")
        if r.status_code == 405:  # method not allowed -> try POST
            r = self._client.post("/state")
        r.raise_for_status()
        return r.json()

    def health(self) -> dict:
        """GET /health and return the decoded JSON body."""
        r = self._client.get("/health")
        r.raise_for_status()
        return r.json()

    def healthz(self) -> dict:
        """GET /healthz and return the decoded JSON body."""
        r = self._client.get("/healthz")
        r.raise_for_status()
        return r.json()

    def close(self) -> None:
        """Tell the server we're done, then release the connection pool.

        Best-effort and idempotent:

        * ``POST /close`` is mounted by qubit_medic.server.app; an error
          status (e.g. 404 on servers without the route) is not raised
          because we never call ``raise_for_status`` here.
        * Transport-level failures (connection refused, timeouts) are
          logged at debug level and ignored.
        * Calling ``close()`` again after the pool is closed is a no-op.
        """
        if self._client.is_closed:
            return
        try:
            self._client.post("/close")
        except httpx.HTTPError:
            logging.getLogger(__name__).debug(
                "best-effort POST /close failed", exc_info=True
            )
        finally:
            # Always release local sockets, whatever happened above.
            self._client.close()
class LocalDecoderClient:
    """In-process client - calls :class:`DecoderEnvironment` directly.

    Exposes the same reset/step/state/health/close surface as the HTTP
    client, and can be used as a context manager so ``close()`` runs on
    exit.
    """

    def __init__(self, env=None) -> None:
        """Wrap *env*, or a freshly built ``DecoderEnvironment`` if None.

        The server package is imported only when we must build the env
        ourselves, so callers injecting their own env do not need it.
        """
        if env is None:
            from qubit_medic.server.environment import DecoderEnvironment
            env = DecoderEnvironment()
        self._env = env

    def __enter__(self) -> LocalDecoderClient:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Never suppresses the in-flight exception (returns None).
        self.close()

    def reset(self, *, seed: Optional[int] = None,
              forced_level: Optional[str] = None) -> DecoderObservation:
        """Delegate to the wrapped env's ``reset``."""
        return self._env.reset(seed=seed, forced_level=forced_level)

    def step(self, *, raw_response: str, episode_id: int) -> StepResult:
        """Delegate to the wrapped env's ``step``."""
        return self._env.step(raw_response=raw_response, episode_id=episode_id)

    def state(self) -> dict:
        """Compliance Section 3 (audit, 2026-04): expose env state via
        the same client surface as the HTTP variant. Delegates to the
        in-process :meth:`DecoderEnvironment.state`."""
        return self._env.state()

    def health(self) -> dict:
        """Delegate to the wrapped env's ``health``."""
        return self._env.health()

    def close(self) -> None:
        """Release per-episode bookkeeping on the wrapped env.

        Compliance Section 3 (audit, 2026-04): close lets a subsequent
        reset() start from a clean active-episode dict. Best-effort: an
        env without ``close`` is skipped, and an error raised by it is
        logged at debug level instead of propagating out of teardown.
        """
        closer = getattr(self._env, "close", None)
        if closer is None:
            return
        try:
            closer()
        except Exception:
            logging.getLogger(__name__).debug(
                "best-effort env.close() failed", exc_info=True
            )
def make_default_client() -> _ClientProtocol:
    """Pick a client based on the ``QUBIT_MEDIC_URL`` environment variable.

    A non-empty value selects the HTTP :class:`DecoderClient` pointed at
    that URL; an unset or empty value selects the in-process
    :class:`LocalDecoderClient`.
    """
    server_url = os.getenv("QUBIT_MEDIC_URL")
    return DecoderClient(server_url) if server_url else LocalDecoderClient()