File size: 6,615 Bytes
195f87e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa68719
 
 
 
195f87e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa68719
 
 
 
 
 
 
 
 
 
 
 
 
 
195f87e
 
 
 
 
 
 
 
 
 
 
fa68719
 
 
 
 
 
 
 
195f87e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa68719
 
 
 
 
 
195f87e
 
 
fa68719
 
 
 
 
 
 
 
195f87e
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""Two equivalent client implementations:

* :class:`DecoderClient`     - hits an HTTP endpoint (HF Spaces deployment).
  Speaks the **OpenEnv** wire format:
  - ``POST /reset`` body ``{"seed": int?, "episode_id": str?}``
  - ``POST /step``  body ``{"action": {"raw_response": "...", ...},
    "timeout_s": float?, "request_id": str?}``
* :class:`LocalDecoderClient` - calls the in-process env directly. Use this
  in tests, in CI, and during local Colab runs to skip HTTP overhead.

Both expose the same ``reset`` / ``step`` API so the training scripts can
swap between them via a single env var (``QUBIT_MEDIC_URL``).
"""
from __future__ import annotations

import os
from typing import Optional, Protocol

import httpx

from qubit_medic.models import (
    DecoderObservation,
    StepResult,
)


class _ClientProtocol(Protocol):
    def reset(self, *, seed: Optional[int] = None,
              forced_level: Optional[str] = None) -> DecoderObservation: ...
    def step(self, *, raw_response: str, episode_id: int) -> StepResult: ...
    # Compliance Section 3 (audit, 2026-04): the client surface must
    # mirror the server endpoints. state() returns a JSON-serialisable
    # snapshot; close() releases per-episode bookkeeping.
    def state(self) -> dict: ...
    def health(self) -> dict: ...
    def close(self) -> None: ...


def _obs_from_openenv(payload: dict) -> DecoderObservation:
    """Re-hydrate our internal :class:`DecoderObservation` from the
    OpenEnv response body. The OpenEnv wrapper inlines all our fields
    onto the observation, so this is a 1-1 field mapping."""
    return DecoderObservation(
        prompt=payload.get("prompt", ""),
        syndrome_bits=list(payload.get("syndrome_bits", [])),
        distance=int(payload.get("distance", 0)),
        rounds=int(payload.get("rounds", 0)),
        p=float(payload.get("p", 0.0)),
        curriculum_level=payload.get("curriculum_level", ""),
        episode_id=int(payload.get("episode_id", 0)),
        dem_digest=payload.get("dem_digest", ""),
    )


class DecoderClient:
    """HTTP client targeting a deployed FastAPI server (OpenEnv shape)."""

    def __init__(self, base_url: str, *, timeout: float = 60.0) -> None:
        self._client = httpx.Client(base_url=base_url.rstrip("/"), timeout=timeout)

    def reset(self, *, seed: Optional[int] = None,
              forced_level: Optional[str] = None) -> DecoderObservation:
        # OpenEnv's ResetRequest only accepts seed + episode_id. We pass
        # forced_level via the URL query string so adapters that honour
        # it (our QubitMedicEnvironment via **kwargs) pick it up; servers
        # that ignore it just get a default level.
        body: dict = {}
        if seed is not None:
            body["seed"] = seed
        params = {"forced_level": forced_level} if forced_level else None
        r = self._client.post("/reset", json=body, params=params)
        r.raise_for_status()
        payload = r.json()
        # OpenEnv returns {observation: {...}, reward, done}.
        return _obs_from_openenv(payload.get("observation", payload))

    def step(self, *, raw_response: str, episode_id: int) -> StepResult:
        body = {
            "action": {
                "raw_response": raw_response,
                "episode_id": episode_id,
            },
        }
        r = self._client.post("/step", json=body)
        r.raise_for_status()
        payload = r.json()
        obs_payload = payload.get("observation", {})
        return StepResult(
            observation=_obs_from_openenv(obs_payload),
            reward=float(payload.get("reward", 0.0) or 0.0),
            done=bool(payload.get("done", True)),
            truncated=bool(obs_payload.get("info", {}).get("timed_out", False)),
            info=dict(obs_payload.get("info", {})),
        )

    def state(self) -> dict:
        """GET /state on the OpenEnv server.

        Compliance Section 3 (audit, 2026-04): the client must mirror
        the server endpoints. We use GET (the OpenEnv canonical method)
        first, then fall back to POST (the audit-required method we
        also mounted) if some server build only exposes one of them.
        """
        r = self._client.get("/state")
        if r.status_code == 405:  # method not allowed -> try POST
            r = self._client.post("/state")
        r.raise_for_status()
        return r.json()

    def health(self) -> dict:
        r = self._client.get("/health")
        r.raise_for_status()
        return r.json()

    def healthz(self) -> dict:
        r = self._client.get("/healthz")
        r.raise_for_status()
        return r.json()

    def close(self) -> None:
        # Best-effort: tell the server we're done (the POST /close route
        # is mounted by qubit_medic.server.app) and then release the
        # local httpx connection pool. If the server doesn't expose
        # /close, swallow the 404 - this remains an idempotent cleanup.
        try:
            self._client.post("/close")
        except Exception:
            pass
        self._client.close()


class LocalDecoderClient:
    """In-process client - calls :class:`DecoderEnvironment` directly."""

    def __init__(self, env=None) -> None:
        from qubit_medic.server.environment import DecoderEnvironment
        self._env = env if env is not None else DecoderEnvironment()

    def reset(self, *, seed: Optional[int] = None,
              forced_level: Optional[str] = None) -> DecoderObservation:
        return self._env.reset(seed=seed, forced_level=forced_level)

    def step(self, *, raw_response: str, episode_id: int) -> StepResult:
        return self._env.step(raw_response=raw_response, episode_id=episode_id)

    def state(self) -> dict:
        """Compliance Section 3 (audit, 2026-04): expose env state via
        the same client surface as the HTTP variant. Delegates to the
        in-process :meth:`DecoderEnvironment.state`."""
        return self._env.state()

    def health(self) -> dict:
        return self._env.health()

    def close(self) -> None:
        # Compliance Section 3 (audit, 2026-04): close releases any
        # per-episode bookkeeping on the inner DecoderEnvironment so a
        # subsequent reset() starts from a clean active-episode dict.
        try:
            self._env.close()
        except Exception:
            pass


def make_default_client() -> _ClientProtocol:
    """Return :class:`DecoderClient` if ``QUBIT_MEDIC_URL`` is set, else local."""
    url = os.getenv("QUBIT_MEDIC_URL")
    if url:
        return DecoderClient(url)
    return LocalDecoderClient()