| """Remote-vs-local ML inference router. |
| |
| Mirrors the call-surface shape of `app/llm.py` but for the non-LLM |
| heavy models (Prithvi, TerraMind, TTM, Granite Embedding, GLiNER). |
| |
| The droplet runs a `riprap-models` FastAPI service alongside vLLM that |
| exposes an OpenAI-style endpoint per model class. When configured the |
| router POSTs the relevant payload there and returns the parsed response; |
| on connection error / 5xx / timeout it surfaces a typed exception that |
| caller modules catch and fall back to a local in-process model load. |
| |
| Backend selection (env): |
| |
| RIPRAP_ML_BACKEND = "remote" | "local" | "auto" (default: auto) |
| - remote: use only the droplet, raise if it errors |
| - local : never call the droplet, always use the |
| in-process model |
| - auto : try remote first, fall back to local if |
| remote is unreachable / errors out; |
| same semantics as app/llm.py |
| RIPRAP_ML_BASE_URL = http://129.212.181.238:8002 (no trailing slash) |
| RIPRAP_ML_API_KEY = <bearer token> |
| |
| The router is *transport*-only β it does not own model bytes, weights, |
| or framework imports. Each specialist that wants remote inference calls |
| into the helpers below and provides its own local fallback. That keeps |
| the dependency graph clean: the local code path keeps working when the |
| RIPRAP_ML_* env is unset (e.g. on first-light dev or in unit tests). |
| """ |
from __future__ import annotations

import base64
import logging
import os
from collections.abc import Iterable
from typing import Any

import httpx

log = logging.getLogger("riprap.inference")

_BACKEND = os.environ.get("RIPRAP_ML_BACKEND", "auto").lower()
_BASE_URL = os.environ.get("RIPRAP_ML_BASE_URL", "").rstrip("/")
_API_KEY = os.environ.get("RIPRAP_ML_API_KEY", "")
_DEFAULT_TIMEOUT = float(os.environ.get("RIPRAP_ML_TIMEOUT_S", "60"))


class RemoteUnreachable(RuntimeError):
    """Raised when the remote inference service is unconfigured, down,
    times out, or returns 5xx. Callers catch this to fall through to a
    local model load. 4xx errors propagate as httpx.HTTPStatusError so
    a caller bug doesn't get masked by a "fallback to local" path."""


def remote_enabled() -> bool:
    """True iff the router is configured to attempt remote calls.
    Returns False under explicit `local` mode or when the base URL is
    empty (the auto-default with no env config)."""
    if _BACKEND == "local":
        return False
    if not _BASE_URL:
        return False
    return True


def _client(timeout: float | None = None) -> httpx.Client:
    headers = {"User-Agent": "riprap-app/0.4.5"}
    if _API_KEY:
        headers["Authorization"] = f"Bearer {_API_KEY}"
    return httpx.Client(
        base_url=_BASE_URL,
        headers=headers,
        timeout=timeout if timeout is not None else _DEFAULT_TIMEOUT,
    )


def _post(path: str, payload: dict[str, Any], timeout: float | None = None) -> dict:
    """POST `payload` as JSON to the remote service's `path`. Returns the
    parsed JSON body. Raises RemoteUnreachable on transport errors and
    5xx; raises httpx.HTTPStatusError on 4xx so caller bugs surface."""
    if not remote_enabled():
        raise RemoteUnreachable("remote ML backend not configured "
                                "(RIPRAP_ML_BASE_URL empty or BACKEND=local)")
    try:
        with _client(timeout) as c:
            r = c.post(path, json=payload)
    except (httpx.ConnectError, httpx.ReadError, httpx.WriteError,
            httpx.TimeoutException, httpx.RemoteProtocolError) as e:
        raise RemoteUnreachable(f"{type(e).__name__}: {e}") from e
    if r.status_code >= 500:
        raise RemoteUnreachable(f"HTTP {r.status_code} from {path}: {r.text[:200]}")
    r.raise_for_status()
    return r.json()
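

# A minimal sketch of the caller-side contract the module docstring
# describes: try remote, fall through to a local load only on
# RemoteUnreachable (4xx caller bugs still raise). `_load_local_embedder`
# is a hypothetical specialist-module fallback, not defined here.
#
#     def embed_texts(texts: list[str]) -> list[list[float]]:
#         try:
#             return granite_embed(texts)["vectors"]
#         except RemoteUnreachable as e:
#             log.info("remote ML backend down (%s); using local model", e)
#             return _load_local_embedder().encode(texts).tolist()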


def _serialize_array(arr) -> str:
    """numpy/torch tensor -> base64-encoded float32 raw bytes for transport.
    Each remote handler decodes to (shape, dtype=float32) and reconstructs.
    Reasonable round-trip for chips up to a few MB; large rasters should
    use np.savez_compressed instead (TODO when a model needs > 8 MB)."""
    import numpy as np
    np_arr = arr if isinstance(arr, np.ndarray) else _to_numpy(arr)
    np_arr = np_arr.astype("float32", copy=False)
    return base64.b64encode(np_arr.tobytes()).decode("ascii")


def _to_numpy(t):
    """Best-effort tensor -> numpy. Accepts torch.Tensor or numpy already."""
    try:
        import torch
        if isinstance(t, torch.Tensor):
            return t.detach().cpu().numpy()
    except ImportError:
        pass
    import numpy as np
    return np.asarray(t)


def _deserialize_array(b64: str, shape: list[int]):
    """Inverse of _serialize_array: base64 -> numpy float32 with given shape."""
    import numpy as np
    raw = base64.b64decode(b64)
    return np.frombuffer(raw, dtype="float32").reshape(shape)
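

# A round-trip sketch of the wire format, assuming the remote handler
# decodes exactly as _deserialize_array does. Illustrative only; nothing
# in this module calls it.
def _example_array_roundtrip() -> bool:
    import numpy as np
    chip = np.random.rand(6, 64, 64).astype("float32")  # fake S2-like chip
    b64 = _serialize_array(chip)
    back = _deserialize_array(b64, list(chip.shape))
    return bool(np.allclose(chip, back))  # identical bytes -> True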


def healthcheck(timeout: float = 3.0) -> bool:
    """Quick reachability probe. True if the service responds 200 to GET
    /healthz within `timeout` seconds. Used by /api/backend so the UI can
    show whether the remote ML backend is currently live."""
    if not remote_enabled():
        return False
    try:
        with _client(timeout) as c:
            r = c.get("/healthz")
        return r.status_code == 200
    except Exception:
        return False


def backend_info() -> dict[str, Any]:
    """Snapshot for /api/backend: what the UI should advertise."""
    return {
        "backend": _BACKEND,
        "base_url": _BASE_URL or None,
        "remote_enabled": remote_enabled(),
        "reachable": healthcheck() if remote_enabled() else False,
    }
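

# Hedged sketch of the /api/backend wiring the docstrings refer to; the
# FastAPI route below is illustrative, not this app's actual router module.
#
#     @app.get("/api/backend")
#     def api_backend() -> dict[str, Any]:
#         return backend_info()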


def prithvi_pluvial(s2_chip, *, scene_id: str | None = None,
                    scene_datetime: str | None = None,
                    cloud_cover: float | None = None,
                    timeout: float | None = None) -> dict[str, Any]:
    """Remote forward pass through Prithvi-NYC-Pluvial v2.
    Input: 6-band Sentinel-2 chip (numpy or torch, shape [6, H, W]).
    Output: { ok, pct_water_within_500m, pct_water_full, scene_id, ... }.
    Raises RemoteUnreachable if the service is down."""
    arr = _to_numpy(s2_chip)
    return _post("/v1/prithvi-pluvial", {
        "s2": _serialize_array(arr),
        "shape": list(arr.shape),
        "scene_id": scene_id,
        "scene_datetime": scene_datetime,
        "cloud_cover": cloud_cover,
    }, timeout=timeout)
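

# Usage sketch: the response keys come from the docstring above, and the
# scene id is a made-up placeholder. A real caller would catch
# RemoteUnreachable and run its local Prithvi fallback instead of letting
# it propagate.
def _example_prithvi(chip) -> float:
    out = prithvi_pluvial(chip, scene_id="S2_SCENE_PLACEHOLDER")
    return float(out["pct_water_within_500m"])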


def terramind(adapter: str, s2l2a=None, s1rtc=None, dem=None, *,
              timeout: float | None = None) -> dict[str, Any]:
    """Remote forward through TerraMind-NYC-Adapters (LULC or Buildings)
    or the v1 base generative path (synthesis). `adapter` is one of:
    lulc, buildings, synthesis. Each modality is a numpy array, torch
    tensor, or None: `synthesis` only needs DEM; the LoRA adapters
    need at minimum S2L2A."""
    payload: dict[str, Any] = {"adapter": adapter}
    if s2l2a is not None:
        s2_np = _to_numpy(s2l2a)
        payload["s2"] = _serialize_array(s2_np)
        payload["s2_shape"] = list(s2_np.shape)
    if s1rtc is not None:
        s1_np = _to_numpy(s1rtc)
        payload["s1"] = _serialize_array(s1_np)
        payload["s1_shape"] = list(s1_np.shape)
    if dem is not None:
        dem_np = _to_numpy(dem)
        payload["dem"] = _serialize_array(dem_np)
        payload["dem_shape"] = list(dem_np.shape)
    return _post("/v1/terramind", payload, timeout=timeout)
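

# Usage sketch for the modality rules above: the LoRA adapters need at
# least S2L2A, synthesis needs only DEM. Chip shapes are whatever the
# service expects; nothing here validates them.
def _example_terramind(s2_chip, dem_tile) -> tuple[dict[str, Any], dict[str, Any]]:
    lulc = terramind("lulc", s2l2a=s2_chip)
    synth = terramind("synthesis", dem=dem_tile)
    return lulc, synth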


def ttm_forecast(model: str, history: Iterable[float], *,
                 context_length: int, prediction_length: int,
                 cadence: str = "h",
                 timeout: float | None = None) -> dict[str, Any]:
    """Remote Granite TTM r2 forecast.
    `model` is one of: zero_shot_battery, fine_tune_battery, weekly_311,
    floodnet_recurrence; the service decides which checkpoint to use.
    `history` is a 1-D iterable of floats (the time series); `cadence`
    is for the service's labelling (h / d / w / 6m). Output shape is
    `{ ok, forecast: [...], peak_index, peak_value }`."""
    series = list(map(float, history))
    return _post("/v1/ttm-forecast", {
        "model": model,
        "history": series,
        "context_length": context_length,
        "prediction_length": prediction_length,
        "cadence": cadence,
    }, timeout=timeout)
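

# Usage sketch: forecast 24 hourly steps from a window of history. The
# 512/24 pairing is illustrative; the real context/prediction lengths are
# whatever the chosen checkpoint was trained with.
def _example_ttm(hourly_history: list[float]) -> float:
    out = ttm_forecast("zero_shot_battery", hourly_history,
                       context_length=512, prediction_length=24, cadence="h")
    return float(out["peak_value"])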


def granite_embed(texts: list[str], *,
                  timeout: float | None = None) -> dict[str, Any]:
    """Remote Granite Embedding 278M batch encode.
    Output: { ok, vectors: [[float, ...], ...] }. Vector dimension fixed
    at 768 (granite-embedding-278m-multilingual)."""
    return _post("/v1/granite-embed", {"texts": list(texts)}, timeout=timeout)
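

# Sketch: cosine similarity between two texts via the remote encoder,
# relying only on the documented { ok, vectors } response shape.
def _example_embed_similarity(a: str, b: str) -> float:
    import math
    va, vb = granite_embed([a, b])["vectors"]
    dot = sum(x * y for x, y in zip(va, vb))
    norm = math.sqrt(sum(x * x for x in va)) * math.sqrt(sum(x * x for x in vb))
    return dot / norm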


def gliner_extract(text: str, labels: list[str], *,
                   timeout: float | None = None) -> dict[str, Any]:
    """Remote GLiNER typed-entity extraction.
    Output: { ok, entities: [{label, text, start, end, score}, ...] }."""
    return _post("/v1/gliner-extract", {
        "text": text, "labels": list(labels),
    }, timeout=timeout)
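

# Sketch: pull entity surface strings out of the documented response
# shape; the label set here is illustrative, not a fixed vocabulary.
def _example_gliner(text: str) -> list[str]:
    res = gliner_extract(text, ["address", "hazard", "date"])
    return [e["text"] for e in res["entities"]]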