# riprap-nyc — app/inference.py
# seriffic: feat: terramind_synthesis now routes through droplet remote inference (eea4d6e)
"""Remote-vs-local ML inference router.
Mirrors the call-surface shape of `app/llm.py` but for the non-LLM
heavy models (Prithvi, TerraMind, TTM, Granite Embedding, GLiNER).
The droplet runs a `riprap-models` FastAPI service alongside vLLM that
exposes an OpenAI-style endpoint per model class. When configured the
router POSTs the relevant payload there and returns the parsed response;
on connection error / 5xx / timeout it surfaces a typed exception that
caller modules catch and fall back to a local in-process model load.
Backend selection (env):
RIPRAP_ML_BACKEND = "remote" | "local" | "auto" (default: auto)
- remote: use only the droplet, raise if it errors
- local : never call the droplet, always use the
in-process model
- auto : try remote first, fall back to local if
remote is unreachable / errors out;
same semantics as app/llm.py
RIPRAP_ML_BASE_URL = http://129.212.181.238:8002 (no trailing slash)
RIPRAP_ML_API_KEY = <bearer token>
The router is *transport*-only — it does not own model bytes, weights,
or framework imports. Each specialist that wants remote inference calls
into the helpers below and provides its own local fallback. That keeps
the dependency graph clean: the local code path keeps working when the
RIPRAP_ML_* env is unset (e.g. on first-light dev or in unit tests).
"""
from __future__ import annotations
import base64
import logging
import os
from collections.abc import Iterable
from typing import Any
import httpx
log = logging.getLogger("riprap.inference")

# Backend configuration, read once at import time (see module docstring
# for the full semantics of each variable).
_BACKEND = os.environ.get("RIPRAP_ML_BACKEND", "auto").lower()  # "remote" | "local" | "auto"
_BASE_URL = os.environ.get("RIPRAP_ML_BASE_URL", "").rstrip("/")  # "" means unconfigured
_API_KEY = os.environ.get("RIPRAP_ML_API_KEY", "")  # empty -> no Authorization header
_DEFAULT_TIMEOUT = float(os.environ.get("RIPRAP_ML_TIMEOUT_S", "60"))  # seconds per request
class RemoteUnreachable(RuntimeError):
    """The remote inference service could not be used.

    Raised when the backend is unconfigured, the connection fails or
    times out, or the service answers 5xx. Callers treat this as the
    signal to fall back to an in-process model load. 4xx responses are
    deliberately NOT wrapped in this type, so a bad request (a caller
    bug) surfaces instead of silently triggering the local fallback.
    """
def remote_enabled() -> bool:
    """Report whether remote inference calls may be attempted.

    False under explicit `local` mode, and false when no base URL has
    been configured (the default when the RIPRAP_ML_* env is unset).
    """
    return _BACKEND != "local" and bool(_BASE_URL)
def _client(timeout: float | None = None) -> httpx.Client:
    """Build an httpx.Client for the droplet service.

    Adds the bearer token only when one is configured; the User-Agent
    identifies the app build in the service-side logs. `timeout=None`
    falls back to the module-level default.
    """
    request_headers: dict[str, str] = {"User-Agent": "riprap-app/0.4.5"}
    if _API_KEY:
        request_headers["Authorization"] = f"Bearer {_API_KEY}"
    effective_timeout = _DEFAULT_TIMEOUT if timeout is None else timeout
    return httpx.Client(
        base_url=_BASE_URL,
        headers=request_headers,
        timeout=effective_timeout,
    )
def _post(path: str, payload: dict[str, Any], timeout: float | None = None) -> dict:
    """POST `payload` as JSON to the remote service's `path`.

    Returns the parsed JSON body. Raises RemoteUnreachable when the
    backend is unconfigured, on any transport-level failure, or on a
    5xx response; 4xx responses raise httpx.HTTPStatusError so caller
    bugs surface instead of being masked by the local fallback.
    """
    if not remote_enabled():
        raise RemoteUnreachable("remote ML backend not configured "
                                "(RIPRAP_ML_BASE_URL empty or BACKEND=local)")
    try:
        with _client(timeout) as c:
            r = c.post(path, json=payload)
    except httpx.TransportError as e:
        # TransportError is the documented base for connect/read/write
        # errors, every timeout flavour, and protocol violations. The
        # previous hand-picked tuple missed ProxyError, CloseError and
        # UnsupportedProtocol, which are equally "service unreachable"
        # from the caller's point of view.
        raise RemoteUnreachable(f"{type(e).__name__}: {e}") from e
    if r.status_code >= 500:
        raise RemoteUnreachable(f"HTTP {r.status_code} from {path}: {r.text[:200]}")
    r.raise_for_status()  # 4xx -> httpx.HTTPStatusError (caller bug, don't mask)
    return r.json()
def _serialize_array(arr) -> str:
"""numpy/torch tensor β†’ base64-encoded float32 raw bytes for transport.
Each remote handler decodes to (shape, dtype=float32) and reconstructs.
Reasonable round-trip for chips up to a few MB; large rasters should
use compressed numpy-savez instead β€” TODO when a model needs > 8 MB."""
import numpy as np
np_arr = arr if isinstance(arr, np.ndarray) else _to_numpy(arr)
np_arr = np_arr.astype("float32", copy=False)
return base64.b64encode(np_arr.tobytes()).decode("ascii")
def _to_numpy(t):
"""Best-effort tensor β†’ numpy. Accepts torch.Tensor or numpy already."""
try:
import torch
if isinstance(t, torch.Tensor):
return t.detach().cpu().numpy()
except ImportError:
pass
import numpy as np
return np.asarray(t)
def _deserialize_array(b64: str, shape: list[int]):
"""Inverse of _serialize_array β€” bytes β†’ numpy float32 with given shape."""
import numpy as np
raw = base64.b64decode(b64)
return np.frombuffer(raw, dtype="float32").reshape(shape)
# ---- Public router entry points -------------------------------------------
def healthcheck(timeout: float = 3.0) -> bool:
    """Cheap liveness probe for the droplet service.

    True only when GET /healthz answers 200 within `timeout` seconds.
    /api/backend uses this so the UI can show whether the remote ML
    backend is currently live.
    """
    if not remote_enabled():
        return False
    try:
        with _client(timeout) as conn:
            return conn.get("/healthz").status_code == 200
    except Exception:
        # Any failure (DNS, refused, TLS, timeout) just means "not live".
        return False
def backend_info() -> dict[str, Any]:
    """Snapshot for /api/backend describing what the UI should advertise."""
    enabled = remote_enabled()
    return {
        "backend": _BACKEND,
        "base_url": _BASE_URL or None,
        "remote_enabled": enabled,
        # Only probe the network when remote mode is actually on.
        "reachable": healthcheck() if enabled else False,
    }
def prithvi_pluvial(s2_chip, *, scene_id: str | None = None,
                    scene_datetime: str | None = None,
                    cloud_cover: float | None = None,
                    timeout: float | None = None) -> dict[str, Any]:
    """Run a Prithvi-NYC-Pluvial v2 forward pass on the droplet.

    `s2_chip` is a 6-band Sentinel-2 chip (numpy or torch, shape
    [6, H, W]). Returns the service dict:
    { ok, pct_water_within_500m, pct_water_full, scene_id, ... }.
    Raises RemoteUnreachable when the service is down.
    """
    chip = _to_numpy(s2_chip)
    payload = {
        "s2": _serialize_array(chip),
        "shape": list(chip.shape),
        "scene_id": scene_id,
        "scene_datetime": scene_datetime,
        "cloud_cover": cloud_cover,
    }
    return _post("/v1/prithvi-pluvial", payload, timeout=timeout)
def terramind(adapter: str, s2l2a=None, s1rtc=None, dem=None, *,
              timeout: float | None = None) -> dict[str, Any]:
    """Run TerraMind-NYC-Adapters (LULC / Buildings) or the v1 base
    generative path (synthesis) remotely.

    `adapter` is one of: lulc, buildings, synthesis. Each modality
    (s2l2a, s1rtc, dem) is a numpy array, torch tensor, or None —
    synthesis needs only DEM, while the LoRA adapters need S2L2A at
    minimum.
    """
    payload: dict[str, Any] = {"adapter": adapter}
    # Serialize each supplied modality under its wire key plus a
    # companion "<key>_shape" entry; absent modalities are omitted.
    for wire_key, modality in (("s2", s2l2a), ("s1", s1rtc), ("dem", dem)):
        if modality is None:
            continue
        mod_np = _to_numpy(modality)
        payload[wire_key] = _serialize_array(mod_np)
        payload[f"{wire_key}_shape"] = list(mod_np.shape)
    return _post("/v1/terramind", payload, timeout=timeout)
def ttm_forecast(model: str, history: Iterable[float], *,
                 context_length: int, prediction_length: int,
                 cadence: str = "h",
                 timeout: float | None = None) -> dict[str, Any]:
    """Run a Granite TTM r2 forecast on the droplet.

    `model` selects the checkpoint service-side: zero_shot_battery,
    fine_tune_battery, weekly_311, or floodnet_recurrence. `history`
    is a 1-D iterable of floats (the time series) and `cadence` labels
    its step (h / d / w / 6m). The service answers
    { ok, forecast: [...], peak_index, peak_value }.
    """
    payload = {
        "model": model,
        "history": [float(v) for v in history],
        "context_length": context_length,
        "prediction_length": prediction_length,
        "cadence": cadence,
    }
    return _post("/v1/ttm-forecast", payload, timeout=timeout)
def granite_embed(texts: list[str], *,
                  timeout: float | None = None) -> dict[str, Any]:
    """Batch-encode `texts` with Granite Embedding 278M remotely.

    Answer shape: { ok, vectors: [[float, ...], ...] }; vector dim is
    fixed at 768 (granite-embedding-278m-multilingual).
    """
    payload = {"texts": [*texts]}
    return _post("/v1/granite-embed", payload, timeout=timeout)
def gliner_extract(text: str, labels: list[str], *,
                   timeout: float | None = None) -> dict[str, Any]:
    """Run GLiNER typed-entity extraction remotely.

    Answer shape: { ok, entities: [{label, text, start, end, score}, ...] }.
    """
    payload = {"text": text, "labels": [*labels]}
    return _post("/v1/gliner-extract", payload, timeout=timeout)