File size: 9,203 Bytes
abcf7cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9a10ad
 
abcf7cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eea4d6e
abcf7cd
 
eea4d6e
 
 
 
abcf7cd
eea4d6e
 
 
 
abcf7cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
"""Remote-vs-local ML inference router.

Mirrors the call-surface shape of `app/llm.py` but for the non-LLM
heavy models (Prithvi, TerraMind, TTM, Granite Embedding, GLiNER).

The droplet runs a `riprap-models` FastAPI service alongside vLLM that
exposes an OpenAI-style endpoint per model class. When configured the
router POSTs the relevant payload there and returns the parsed response;
on connection error / 5xx / timeout it surfaces a typed exception that
caller modules catch and fall back to a local in-process model load.

Backend selection (env):

  RIPRAP_ML_BACKEND   = "remote" | "local" | "auto"  (default: auto)
                        - remote: use only the droplet, raise if it errors
                        - local : never call the droplet, always use the
                                  in-process model
                        - auto  : try remote first, fall back to local if
                                  remote is unreachable / errors out;
                                  same semantics as app/llm.py
  RIPRAP_ML_BASE_URL  = http://129.212.181.238:8002    (no trailing slash)
  RIPRAP_ML_API_KEY   = <bearer token>

The router is *transport*-only — it does not own model bytes, weights,
or framework imports. Each specialist that wants remote inference calls
into the helpers below and provides its own local fallback. That keeps
the dependency graph clean: the local code path keeps working when the
RIPRAP_ML_* env is unset (e.g. on first-light dev or in unit tests).
"""
from __future__ import annotations

import base64
import logging
import os
from collections.abc import Iterable
from typing import Any

import httpx

# Module-level logger for the inference router.
log = logging.getLogger("riprap.inference")

# Router configuration. All four values are read from the environment once,
# at import time, so later changes to os.environ are not picked up.
# Backend mode, lower-cased; defaults to "auto" when the variable is unset.
_BACKEND = os.environ.get("RIPRAP_ML_BACKEND", "auto").lower()
# Service root with trailing slashes stripped; empty string means "not set",
# which remote_enabled() treats as remote being disabled.
_BASE_URL = os.environ.get("RIPRAP_ML_BASE_URL", "").rstrip("/")
# Optional bearer token; _client() only adds an Authorization header when
# this is non-empty.
_API_KEY = os.environ.get("RIPRAP_ML_API_KEY", "")
# Fallback request timeout used when a caller passes timeout=None.
# NOTE(review): a non-numeric RIPRAP_ML_TIMEOUT_S raises ValueError at import.
_DEFAULT_TIMEOUT = float(os.environ.get("RIPRAP_ML_TIMEOUT_S", "60"))


class RemoteUnreachable(RuntimeError):
    """Signal that the remote inference service cannot be used right now.

    `_post` raises this when the remote backend is not configured, when
    the HTTP transport fails or times out, and when the service answers
    with a 5xx status. Callers catch it and fall through to a local
    model load. HTTP 4xx responses are deliberately *not* mapped to this
    type: they propagate as the underlying HTTP error so that a caller
    bug is not hidden behind a silent "fallback to local" path."""


def remote_enabled() -> bool:
    """Report whether remote calls should be attempted at all.

    False when the backend mode is explicitly `local`, and also False
    when no base URL is configured (which is the situation under the
    `auto` default with no env config). True otherwise."""
    return _BACKEND != "local" and bool(_BASE_URL)


def _client(timeout: float | None = None) -> httpx.Client:
    """Build a fresh httpx.Client bound to the configured base URL.

    `timeout` overrides the module default when given; None selects
    `_DEFAULT_TIMEOUT`. A bearer Authorization header is attached only
    when an API key is configured. The caller owns the returned client
    and is responsible for closing it (the helpers here use `with`)."""
    effective_timeout = _DEFAULT_TIMEOUT if timeout is None else timeout
    request_headers = {"User-Agent": "riprap-app/0.4.5"}
    if _API_KEY:
        request_headers["Authorization"] = f"Bearer {_API_KEY}"
    return httpx.Client(
        base_url=_BASE_URL,
        headers=request_headers,
        timeout=effective_timeout,
    )


def _post(path: str, payload: dict[str, Any], timeout: float | None = None) -> dict:
    """POST `payload` as JSON to the remote service's `path` and return
    the parsed JSON body.

    `timeout` is forwarded to `_client`; None selects the module default.

    Raises RemoteUnreachable when the remote backend is not configured,
    on transport failures (connect / read / write / timeout / protocol
    errors), on any 5xx status, and when a successful response carries a
    body that is not valid JSON. All of those are faults of the remote
    side, so callers can treat them uniformly and fall back to a local
    model. Raises httpx.HTTPStatusError on 4xx so caller bugs surface
    instead of being masked by the fallback."""
    if not remote_enabled():
        raise RemoteUnreachable("remote ML backend not configured "
                                "(RIPRAP_ML_BASE_URL empty or BACKEND=local)")
    try:
        with _client(timeout) as c:
            r = c.post(path, json=payload)
    except (httpx.ConnectError, httpx.ReadError, httpx.WriteError,
            httpx.TimeoutException, httpx.RemoteProtocolError) as e:
        raise RemoteUnreachable(f"{type(e).__name__}: {e}") from e
    if r.status_code >= 500:
        raise RemoteUnreachable(f"HTTP {r.status_code} from {path}: {r.text[:200]}")
    # Anything still failing here is a client-side error; let it propagate.
    r.raise_for_status()
    try:
        return r.json()
    except ValueError as e:
        # JSONDecodeError is a ValueError. A success status with an
        # unparseable body (e.g. an HTML page from an intermediary) means
        # the service is not really answering, so route it to the same
        # fallback path as a 5xx instead of leaking a decode error.
        raise RemoteUnreachable(
            f"invalid JSON body from {path}: {r.text[:200]}") from e


def _serialize_array(arr) -> str:
    """numpy/torch tensor β†’ base64-encoded float32 raw bytes for transport.
    Each remote handler decodes to (shape, dtype=float32) and reconstructs.
    Reasonable round-trip for chips up to a few MB; large rasters should
    use compressed numpy-savez instead β€” TODO when a model needs > 8 MB."""
    import numpy as np
    np_arr = arr if isinstance(arr, np.ndarray) else _to_numpy(arr)
    np_arr = np_arr.astype("float32", copy=False)
    return base64.b64encode(np_arr.tobytes()).decode("ascii")


def _to_numpy(t):
    """Best-effort tensor β†’ numpy. Accepts torch.Tensor or numpy already."""
    try:
        import torch
        if isinstance(t, torch.Tensor):
            return t.detach().cpu().numpy()
    except ImportError:
        pass
    import numpy as np
    return np.asarray(t)


def _deserialize_array(b64: str, shape: list[int]):
    """Inverse of _serialize_array β€” bytes β†’ numpy float32 with given shape."""
    import numpy as np
    raw = base64.b64decode(b64)
    return np.frombuffer(raw, dtype="float32").reshape(shape)


# ---- Public router entry points -------------------------------------------

def healthcheck(timeout: float = 3.0) -> bool:
    """Quick reachability probe for the remote service.

    Returns True only if remote calls are enabled and GET /healthz
    answers with status 200 within `timeout` seconds. Any failure at all
    (transport error, timeout, unexpected exception) is reported as
    False rather than raised, because this feeds /api/backend so the UI
    can show whether the remote ML backend is currently live."""
    if not remote_enabled():
        return False
    try:
        with _client(timeout) as probe:
            status = probe.get("/healthz").status_code
    except Exception:
        # Deliberately broad: a probe must never take the caller down.
        return False
    return status == 200


def backend_info() -> dict[str, Any]:
    """Snapshot for /api/backend: what the UI should advertise.

    Reports the configured backend mode, the base URL (None when unset),
    whether remote calls are enabled, and whether the service currently
    answers its health probe. The probe is skipped when remote calls are
    disabled, in which case `reachable` is False."""
    enabled = remote_enabled()
    reachable = healthcheck() if enabled else False
    return {
        "backend": _BACKEND,
        "base_url": _BASE_URL or None,
        "remote_enabled": enabled,
        "reachable": reachable,
    }


def prithvi_pluvial(
    s2_chip,
    *,
    scene_id: str | None = None,
    scene_datetime: str | None = None,
    cloud_cover: float | None = None,
    timeout: float | None = None,
) -> dict[str, Any]:
    """Remote forward pass through Prithvi-NYC-Pluvial v2.

    Expected input (not validated here): a 6-band Sentinel-2 chip, numpy
    or torch, shaped [6, H, W]. The chip is sent as base64 float32 bytes
    together with its shape; the scene metadata fields are forwarded
    unchanged, including None values.

    Returns the service's parsed JSON response, documented as
    { ok, pct_water_within_500m, pct_water_full, scene_id, ... }.
    Raises RemoteUnreachable if the service is down."""
    chip = _to_numpy(s2_chip)
    body = {
        "s2": _serialize_array(chip),
        "shape": list(chip.shape),
        "scene_id": scene_id,
        "scene_datetime": scene_datetime,
        "cloud_cover": cloud_cover,
    }
    return _post("/v1/prithvi-pluvial", body, timeout=timeout)


def terramind(adapter: str, s2l2a=None, s1rtc=None, dem=None, *,
              timeout: float | None = None) -> dict[str, Any]:
    """Remote forward through TerraMind-NYC-Adapters (LULC or Buildings)
    or the v1 base generative path (synthesis).

    `adapter` is one of: lulc, buildings, synthesis (not validated here).
    Each modality is a numpy array, a torch tensor, or None. Modalities
    left as None are omitted from the request; each one supplied is sent
    as base64 float32 bytes plus a matching `<name>_shape` field.
    `synthesis` only needs DEM; the LoRA adapters need at minimum S2L2A.

    Returns the service's parsed JSON response."""
    payload: dict[str, Any] = {"adapter": adapter}
    # (data key, shape key, value), in the order the fields are emitted.
    modalities = (
        ("s2", "s2_shape", s2l2a),
        ("s1", "s1_shape", s1rtc),
        ("dem", "dem_shape", dem),
    )
    for data_key, shape_key, tensor in modalities:
        if tensor is None:
            continue
        as_np = _to_numpy(tensor)
        payload[data_key] = _serialize_array(as_np)
        payload[shape_key] = list(as_np.shape)
    return _post("/v1/terramind", payload, timeout=timeout)


def ttm_forecast(
    model: str,
    history: Iterable[float],
    *,
    context_length: int,
    prediction_length: int,
    cadence: str = "h",
    timeout: float | None = None,
) -> dict[str, Any]:
    """Remote Granite TTM r2 forecast.

    `model` is one of: zero_shot_battery, fine_tune_battery, weekly_311,
    floodnet_recurrence; the service decides which checkpoint to use.
    `history` is a 1-D iterable of values (the time series); every item
    is coerced with float() before sending, so a non-numeric item raises
    here rather than on the service. `cadence` is for the service's
    labelling (h / d / w / 6m).

    Returns the service's parsed JSON response, documented as
    `{ ok, forecast: [...], peak_index, peak_value }`."""
    series = [float(value) for value in history]
    request_body = {
        "model": model,
        "history": series,
        "context_length": context_length,
        "prediction_length": prediction_length,
        "cadence": cadence,
    }
    return _post("/v1/ttm-forecast", request_body, timeout=timeout)


def granite_embed(
    texts: list[str],
    *,
    timeout: float | None = None,
) -> dict[str, Any]:
    """Remote Granite Embedding 278M batch encode.

    `texts` is copied into a plain list before sending. Returns the
    service's parsed JSON response, documented as
    { ok, vectors: [[float, ...], ...] } with vector dimension fixed at
    768 (granite-embedding-278m-multilingual); the dimension is not
    checked here."""
    request_body = {"texts": list(texts)}
    return _post("/v1/granite-embed", request_body, timeout=timeout)


def gliner_extract(
    text: str,
    labels: list[str],
    *,
    timeout: float | None = None,
) -> dict[str, Any]:
    """Remote GLiNER typed-entity extraction.

    Sends `text` and a list copy of `labels` to the service. Returns the
    parsed JSON response, documented as
    { ok, entities: [{label, text, start, end, score}, ...] }."""
    request_body = {"text": text, "labels": list(labels)}
    return _post("/v1/gliner-extract", request_body, timeout=timeout)