"""Remote-vs-local ML inference router.
Mirrors the call-surface shape of `app/llm.py` but for the non-LLM
heavy models (Prithvi, TerraMind, TTM, Granite Embedding, GLiNER).
The droplet runs a `riprap-models` FastAPI service alongside vLLM that
exposes an OpenAI-style endpoint per model class. When configured the
router POSTs the relevant payload there and returns the parsed response;
on connection error / 5xx / timeout it surfaces a typed exception that
caller modules catch and fall back to a local in-process model load.
Backend selection (env):
RIPRAP_ML_BACKEND = "remote" | "local" | "auto" (default: auto)
- remote: use only the droplet, raise if it errors
- local : never call the droplet, always use the
in-process model
- auto : try remote first, fall back to local if
remote is unreachable / errors out;
same semantics as app/llm.py
RIPRAP_ML_BASE_URL = http://129.212.181.238:8002 (no trailing slash)
RIPRAP_ML_API_KEY = <bearer token>
The router is *transport*-only — it does not own model bytes, weights,
or framework imports. Each specialist that wants remote inference calls
into the helpers below and provides its own local fallback. That keeps
the dependency graph clean: the local code path keeps working when the
RIPRAP_ML_* env is unset (e.g. on first-light dev or in unit tests).
"""
from __future__ import annotations
import base64
import logging
import os
from collections.abc import Iterable
from typing import Any
import httpx
log = logging.getLogger("riprap.inference")
_BACKEND = os.environ.get("RIPRAP_ML_BACKEND", "auto").lower()
_BASE_URL = os.environ.get("RIPRAP_ML_BASE_URL", "").rstrip("/")
_API_KEY = os.environ.get("RIPRAP_ML_API_KEY", "")
_DEFAULT_TIMEOUT = float(os.environ.get("RIPRAP_ML_TIMEOUT_S", "60"))
class RemoteUnreachable(RuntimeError):
"""Raised when the remote inference service is unconfigured, down,
times out, or returns 5xx. Callers catch this to fall through to a
local model load. 4xx errors propagate as the generic exception so
a caller bug doesn't get masked by a "fallback to local" path."""
def remote_enabled() -> bool:
"""True iff the router is configured to attempt remote calls.
Returns False under explicit `local` mode or when the base URL is
empty (the auto-default with no env config)."""
if _BACKEND == "local":
return False
if not _BASE_URL:
return False
return True
def _client(timeout: float | None = None) -> httpx.Client:
headers = {"User-Agent": "riprap-app/0.4.5"}
if _API_KEY:
headers["Authorization"] = f"Bearer {_API_KEY}"
return httpx.Client(
base_url=_BASE_URL,
headers=headers,
timeout=timeout if timeout is not None else _DEFAULT_TIMEOUT,
)
def _post(path: str, payload: dict[str, Any], timeout: float | None = None) -> dict:
"""POST {payload} as JSON to the remote service's `path`. Returns the
parsed JSON body. Raises RemoteUnreachable on transport errors;
raises HTTPStatusError on 4xx so caller bugs surface."""
if not remote_enabled():
raise RemoteUnreachable("remote ML backend not configured "
"(RIPRAP_ML_BASE_URL empty or BACKEND=local)")
try:
with _client(timeout) as c:
r = c.post(path, json=payload)
except (httpx.ConnectError, httpx.ReadError, httpx.WriteError,
httpx.TimeoutException, httpx.RemoteProtocolError) as e:
raise RemoteUnreachable(f"{type(e).__name__}: {e}") from e
if r.status_code >= 500:
raise RemoteUnreachable(f"HTTP {r.status_code} from {path}: {r.text[:200]}")
r.raise_for_status()
return r.json()
def _serialize_array(arr) -> str:
"""numpy/torch tensor β base64-encoded float32 raw bytes for transport.
Each remote handler decodes to (shape, dtype=float32) and reconstructs.
Reasonable round-trip for chips up to a few MB; large rasters should
use compressed numpy-savez instead β TODO when a model needs > 8 MB."""
import numpy as np
np_arr = arr if isinstance(arr, np.ndarray) else _to_numpy(arr)
np_arr = np_arr.astype("float32", copy=False)
return base64.b64encode(np_arr.tobytes()).decode("ascii")
def _to_numpy(t):
"""Best-effort tensor β numpy. Accepts torch.Tensor or numpy already."""
try:
import torch
if isinstance(t, torch.Tensor):
return t.detach().cpu().numpy()
except ImportError:
pass
import numpy as np
return np.asarray(t)
def _deserialize_array(b64: str, shape: list[int]):
"""Inverse of _serialize_array β bytes β numpy float32 with given shape."""
import numpy as np
raw = base64.b64decode(b64)
return np.frombuffer(raw, dtype="float32").reshape(shape)
# ---- Public router entry points -------------------------------------------
def healthcheck(timeout: float = 3.0) -> bool:
"""Quick reachability probe. True if the service responds 200 to GET
/healthz within `timeout` seconds. Used by /api/backend so the UI can
show whether the remote ML backend is currently live."""
if not remote_enabled():
return False
try:
with _client(timeout) as c:
r = c.get("/healthz")
return r.status_code == 200
except Exception:
return False
def backend_info() -> dict[str, Any]:
"""Snapshot for /api/backend β what the UI should advertise."""
return {
"backend": _BACKEND,
"base_url": _BASE_URL or None,
"remote_enabled": remote_enabled(),
"reachable": healthcheck() if remote_enabled() else False,
}
def prithvi_pluvial(s2_chip, *, scene_id: str | None = None,
scene_datetime: str | None = None,
cloud_cover: float | None = None,
timeout: float | None = None) -> dict[str, Any]:
"""Remote forward pass through Prithvi-NYC-Pluvial v2.
Input: 6-band Sentinel-2 chip (numpy or torch, shape [6, H, W]).
Output: { ok, pct_water_within_500m, pct_water_full, scene_id, ... }.
Raises RemoteUnreachable if the service is down."""
arr = _to_numpy(s2_chip)
return _post("/v1/prithvi-pluvial", {
"s2": _serialize_array(arr),
"shape": list(arr.shape),
"scene_id": scene_id,
"scene_datetime": scene_datetime,
"cloud_cover": cloud_cover,
}, timeout=timeout)
def terramind(adapter: str, s2l2a=None, s1rtc=None, dem=None, *,
timeout: float | None = None) -> dict[str, Any]:
"""Remote forward through TerraMind-NYC-Adapters (LULC or Buildings)
or the v1 base generative path (synthesis). `adapter` is one of:
lulc, buildings, synthesis. Each modality is a numpy array, torch
tensor, or None β `synthesis` only needs DEM; the LoRA adapters
need at minimum S2L2A."""
payload: dict[str, Any] = {"adapter": adapter}
if s2l2a is not None:
s2_np = _to_numpy(s2l2a)
payload["s2"] = _serialize_array(s2_np)
payload["s2_shape"] = list(s2_np.shape)
if s1rtc is not None:
s1_np = _to_numpy(s1rtc)
payload["s1"] = _serialize_array(s1_np)
payload["s1_shape"] = list(s1_np.shape)
if dem is not None:
dem_np = _to_numpy(dem)
payload["dem"] = _serialize_array(dem_np)
payload["dem_shape"] = list(dem_np.shape)
return _post("/v1/terramind", payload, timeout=timeout)
def ttm_forecast(model: str, history: Iterable[float], *,
context_length: int, prediction_length: int,
cadence: str = "h",
timeout: float | None = None) -> dict[str, Any]:
"""Remote Granite TTM r2 forecast.
`model` is one of: zero_shot_battery, fine_tune_battery, weekly_311,
floodnet_recurrence β the service decides which checkpoint to use.
`history` is a 1-D iterable of floats (the time series); `cadence`
is for the service's labelling (h / d / w / 6m). Output shape is
`{ ok, forecast: [...], peak_index, peak_value }`."""
series = list(map(float, history))
return _post("/v1/ttm-forecast", {
"model": model,
"history": series,
"context_length": context_length,
"prediction_length": prediction_length,
"cadence": cadence,
}, timeout=timeout)
def granite_embed(texts: list[str], *,
timeout: float | None = None) -> dict[str, Any]:
"""Remote Granite Embedding 278M batch encode.
Output: { ok, vectors: [[float, ...], ...] }. Vector dimension fixed
at 768 (granite-embedding-278m-multilingual)."""
return _post("/v1/granite-embed", {"texts": list(texts)}, timeout=timeout)
def gliner_extract(text: str, labels: list[str], *,
timeout: float | None = None) -> dict[str, Any]:
"""Remote GLiNER typed-entity extraction.
Output: { ok, entities: [{label, text, start, end, score}, ...] }."""
return _post("/v1/gliner-extract", {
"text": text, "labels": list(labels),
}, timeout=timeout)