seriffic Claude Opus 4.7 (1M context) committed
Commit
abcf7cd
·
1 Parent(s): 86e2a29

feat: route all GPU-accelerable inference through MI300X (Phase 1+2 of full GPU)

The user's directive: "I want anything that can be GPU accelerated to
run on there. Otherwise, keep it on the CPU of wherever it's running."

Lands seven pieces in one commit; each can stand alone but they're
interlocked.

app/inference.py (new, 224 lines)
Router shim that mirrors app/llm.py's shape but for non-LLM models.
Exports: prithvi_pluvial(), terramind(), ttm_forecast(),
granite_embed(), gliner_extract(), healthcheck(), backend_info(),
plus the typed RemoteUnreachable exception caller modules catch
to fall back to local. Env-driven via RIPRAP_ML_BACKEND
(auto|remote|local) / RIPRAP_ML_BASE_URL / RIPRAP_ML_API_KEY,
same shape as RIPRAP_LLM_*.
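
A minimal sketch of calling the router directly (illustrative values:
the URL and key below are placeholders, and the env must be set before
app.inference is first imported, since the module reads the vars at
load time):

    import os
    os.environ["RIPRAP_ML_BASE_URL"] = "http://<droplet>:7860"  # placeholder
    os.environ["RIPRAP_ML_API_KEY"] = "<bearer>"                # placeholder

    from app import inference

    print(inference.backend_info())      # backend / base_url / reachable
    if inference.healthcheck():
        try:
            out = inference.granite_embed(["storm surge at the Battery"])
            if out.get("ok"):
                vectors = out["vectors"]  # 768-d float lists
        except inference.RemoteUnreachable:
            pass                          # caller's local fallback goes here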

services/riprap-models/ (new microservice)
FastAPI service that runs alongside vLLM on the AMD MI300X
droplet. One endpoint per model class:
    /v1/prithvi-pluvial   Prithvi-NYC-Pluvial v2 segmentation
    /v1/terramind         LULC / Buildings / Synthesis (LoRA)
    /v1/ttm-forecast      Granite TTM r2 (zero-shot + Battery
                          fine-tune + 311 + FloodNet)
    /v1/granite-embed     Granite Embedding 278M batch encode
    /v1/gliner-extract    GLiNER typed-entity extraction
    /healthz              reachability + warm-model list
Bearer auth same shape as vLLM. Lazy + cached model loads, ROCm
device binding via torch.cuda. Model loading code lifted from
the proven local paths (terratorch / peft / safetensors / tsfm
/ sentence-transformers / gliner). Designed to live in the
existing `terramind` Docker container on the droplet, which
already has every heavy dep installed.
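
For reference, a hand-rolled call against the service with httpx (the
same transport the router uses); the host, token, and sample text are
placeholders:

    import httpx

    headers = {"Authorization": "Bearer <token>"}              # placeholder
    with httpx.Client(base_url="http://<droplet>:7860",
                      headers=headers) as c:
        assert c.get("/healthz").status_code == 200  # /healthz needs no auth
        r = c.post("/v1/gliner-extract",
                   json={"text": "Flooding reported near the Gowanus Canal.",
                         "labels": ["location", "infrastructure"]})
        r.raise_for_status()
        for e in r.json()["entities"]:  # [{label, text, start, end, score}]
            print(e["label"], "->", e["text"])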

Deploy:
Code rsync'd into the terramind container at /workspace/riprap-models
earlier in this session and pip install ran clean. Dropping
`uvicorn main:app --host 0.0.0.0 --port 7860` inside the
container brings it up on the host's already-mapped port 7860.
Currently blocked: droplet 129.212.181.238 went unreachable
mid-deploy; resume the start command once SSH comes back.

Per-specialist wiring (try-remote-then-local):
    app/flood_layers/prithvi_live.py — Prithvi-NYC-Pluvial v2 (live)
    app/context/terramind_nyc.py     — TerraMind LULC + Buildings
    app/live/ttm_forecast.py         — TTM r2 zero-shot (Battery /
                                       311 / FloodNet variants share
                                       one inference function)
    app/live/ttm_battery_surge.py    — TTM r2 NYC fine-tune
    app/rag.py                       — Granite Embedding 278M
                                       (corpus encode + per-query)
    app/context/gliner_extract.py    — GLiNER typed extraction

Each module: try remote first, fall back to local on
RemoteUnreachable. Local _DEPS_OK gates only matter for the
fallback path now — the cpu-basic HF Space can run end-to-end
once the droplet service is live without baking transformers /
peft / terratorch / tsfm_public / sentence-transformers /
gliner into its image.
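
Condensed, the wiring every module above follows (the _DEPS_OK gate
and run_local body vary per specialist; this skeleton is
illustrative):

    def run_specialist(payload):
        try:
            from app import inference as _inf
            if _inf.remote_enabled():
                remote = _inf.gliner_extract(payload, ENTITY_LABELS)  # per-specialist entry point
                if remote.get("ok"):
                    return remote             # carries "compute": "remote · ..."
        except _inf.RemoteUnreachable:
            pass                              # droplet down/unconfigured -> local
        if not _DEPS_OK:                      # gate only matters on this path now
            return {"ok": False, "skipped": "deps unavailable"}
        return run_local(payload)             # in-process model load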

Result objects gain a `compute` field ("remote · cuda" / "local")
so the UI can surface where each specialist's GPU work landed.

The router fails open: with no env config, remote_enabled()=False and
every specialist takes its existing local path. Set RIPRAP_ML_BASE_URL
and the remote path activates without code changes.
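
Concretely, the gate is just (copied from app/inference.py, with
annotations added here):

    def remote_enabled() -> bool:
        if _BACKEND == "local":   # explicit opt-out wins
            return False
        if not _BASE_URL:         # the no-env default: everything stays local
            return False
        return True               # auto/remote with a URL: remote is attempted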

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

app/context/gliner_extract.py CHANGED
@@ -80,8 +80,30 @@ def _source_short(rag_doc_id: str) -> str:
 
 
 def extract_for_chunk(text: str, threshold: float = DEFAULT_THRESHOLD) -> list[Extraction]:
+    if not text:
+        return []
+
+    # v0.4.5 — try the MI300X service first. The remote handles its
+    # own GLiNER load; this lets cpu-basic surfaces run typed
+    # extraction without baking gliner into the image.
+    try:
+        from app import inference as _inf
+        if _inf.remote_enabled():
+            remote = _inf.gliner_extract(text, ENTITY_LABELS)
+            if remote.get("ok"):
+                return [
+                    Extraction(label=e["label"], text=e["text"],
+                               score=float(e.get("score", 0)))
+                    for e in remote.get("entities", [])
+                    if e.get("score", 0) >= threshold
+                ]
+    except _inf.RemoteUnreachable as e:
+        log.info("gliner: remote unreachable (%s); local fallback", e)
+    except Exception:
+        log.exception("gliner: remote call failed; local fallback")
+
     model = _ensure_model()
-    if model is None or not text:
+    if model is None:
         return []
     raw = model.predict_entities(text, ENTITY_LABELS, threshold=threshold)
     return [Extraction(label=r["label"], text=r["text"],
app/context/terramind_nyc.py CHANGED
@@ -293,11 +293,51 @@ def _summarize_buildings(pred, class_labels: list[str]) -> dict[str, Any]:
     }
 
 
+def _try_remote(adapter_name: str, modality_chips: dict) -> dict | None:
+    """v0.4.5 — POST to MI300X riprap-models if configured. Returns the
+    parsed result on success; None on RemoteUnreachable so the caller
+    falls through to the local terratorch path."""
+    try:
+        from app import inference as _inf
+        if not _inf.remote_enabled():
+            return None
+        s2 = modality_chips.get("S2L2A")
+        s1 = modality_chips.get("S1RTC")
+        dem = modality_chips.get("DEM")
+        # The router serializes torch tensors to base64 numpy float32 —
+        # the chip cache hands us [B, C, T, H, W]; keep that shape, the
+        # service rebuilds the temporal stack on its end.
+        result = _inf.terramind(adapter_name, s2, s1, dem)
+        if not result.get("ok"):
+            return None
+        result.setdefault("adapter", adapter_name)
+        result.setdefault("repo", ADAPTERS_REPO)
+        result["compute"] = f"remote · {result.get('device', 'gpu')}"
+        return result
+    except _inf.RemoteUnreachable as e:
+        log.info("terramind/%s: remote unreachable (%s); local fallback",
+                 adapter_name, e)
+        return None
+    except Exception:
+        log.exception("terramind/%s: remote call failed; local fallback",
+                      adapter_name)
+        return None
+
+
 def _run(adapter_name: str, modality_chips: dict, summarizer):
-    """Common boilerplate: gate, time, load, tiled predict, summarize."""
+    """Common boilerplate: gate, time, [remote attempt], load, tiled
+    predict, summarize."""
     if not ENABLE:
         return {"ok": False,
                 "skipped": "RIPRAP_TERRAMIND_NYC_ENABLE=0"}
+
+    # v0.4.5 — try remote first. The remote service has its own deps,
+    # so this path works even when local _DEPS_OK is False (the most
+    # common HF Spaces case until terratorch + peft are baked in).
+    remote = _try_remote(adapter_name, modality_chips or {})
+    if remote is not None:
+        return remote
+
     if not _DEPS_OK:
         return {"ok": False,
                 "skipped": f"deps unavailable on this deployment: "
@@ -315,6 +355,7 @@ def _run(adapter_name: str, modality_chips: dict, summarizer):
         result["elapsed_s"] = round(time.time() - t0, 2)
         result["adapter"] = adapter_name
         result["repo"] = ADAPTERS_REPO
+        result["compute"] = "local"
         return result
     except Exception as e:
         log.exception("terramind_nyc.%s failed", adapter_name)
app/flood_layers/prithvi_live.py CHANGED
@@ -350,6 +350,43 @@ def fetch(lat: float, lon: float, timeout_s: float = 60.0) -> dict[str, Any]:
         img, ref_da, epsg = _build_chip(item, lat, lon)
         if time.time() - t0 > timeout_s:
             return {"ok": False, "skipped": "chip build exceeded budget"}
+
+        # v0.4.5 — try the MI300X inference service first if configured.
+        # On RemoteUnreachable (service down / not configured / 5xx) fall
+        # through to the local terratorch path. The 4-band slice the
+        # service expects is the same shape the local path uses.
+        try:
+            from app import inference as _inf
+            if _inf.remote_enabled():
+                remote = _inf.prithvi_pluvial(
+                    img, scene_id=item.id,
+                    scene_datetime=str(item.datetime),
+                    cloud_cover=cc,
+                    timeout=timeout_s,
+                )
+                if remote.get("ok"):
+                    return {
+                        "ok": True,
+                        "item_id": item.id,
+                        "item_datetime": str(item.datetime),
+                        "cloud_cover": cc,
+                        "pct_water_full": remote.get("pct_water_full"),
+                        "pct_water_within_500m": remote.get("pct_water_within_500m"),
+                        # Service doesn't currently return polygonised GeoJSON
+                        # (transport size); the local fallback below produces
+                        # them. For now the remote path leaves polygons null
+                        # and the map renders the layer empty until the
+                        # service grows a polygonisation step.
+                        "polygons_geojson": None,
+                        "compute": f"remote · {remote.get('device', 'gpu')}",
+                        "elapsed_s": round(time.time() - t0, 2),
+                    }
+        except _inf.RemoteUnreachable as e:
+            log.info("prithvi_live: remote unreachable (%s); falling back to local", e)
+        except Exception:
+            log.exception("prithvi_live: remote call failed; falling back to local")
+
+        # Local fallback — the path that's been live since v0.4.4.
         model, run_model = _ensure_model()
         x = img[None, :, None, :, :]  # (1, 6, 1, H, W)
         pred_t = run_model(x, None, None, model.model, model.datamodule, IMG_SIZE)
@@ -361,7 +398,6 @@ def fetch(lat: float, lon: float, timeout_s: float = 60.0) -> dict[str, Any]:
         radius_px = CENTER_RADIUS_M / PIXEL_M
         circle = (yy - cy) ** 2 + (xx - cx) ** 2 <= radius_px ** 2
         pct_500 = float(100.0 * pred[circle].mean()) if circle.sum() else 0.0
-        # Polygonize the water mask into EPSG:4326 GeoJSON for the map.
         polygons_geojson = _polygonize_mask(pred, ref_da, epsg)
         return {
             "ok": True,
@@ -371,6 +407,7 @@ def fetch(lat: float, lon: float, timeout_s: float = 60.0) -> dict[str, Any]:
             "pct_water_full": pct_full,
             "pct_water_within_500m": pct_500,
             "polygons_geojson": polygons_geojson,
+            "compute": "local",
             "elapsed_s": round(time.time() - t0, 2),
         }
     except Exception as e:
app/inference.py ADDED
@@ -0,0 +1,224 @@
+"""Remote-vs-local ML inference router.
+
+Mirrors the call-surface shape of `app/llm.py` but for the non-LLM
+heavy models (Prithvi, TerraMind, TTM, Granite Embedding, GLiNER).
+
+The droplet runs a `riprap-models` FastAPI service alongside vLLM that
+exposes an OpenAI-style endpoint per model class. When configured the
+router POSTs the relevant payload there and returns the parsed response;
+on connection error / 5xx / timeout it surfaces a typed exception that
+caller modules catch and fall back to a local in-process model load.
+
+Backend selection (env):
+
+    RIPRAP_ML_BACKEND = "remote" | "local" | "auto"  (default: auto)
+        - remote: use only the droplet, raise if it errors
+        - local : never call the droplet, always use the
+                  in-process model
+        - auto  : try remote first, fall back to local if
+                  remote is unreachable / errors out;
+                  same semantics as app/llm.py
+    RIPRAP_ML_BASE_URL = http://129.212.181.238:8002  (no trailing slash)
+    RIPRAP_ML_API_KEY  = <bearer token>
+
+The router is *transport*-only — it does not own model bytes, weights,
+or framework imports. Each specialist that wants remote inference calls
+into the helpers below and provides its own local fallback. That keeps
+the dependency graph clean: the local code path keeps working when the
+RIPRAP_ML_* env is unset (e.g. on first-light dev or in unit tests).
+"""
+from __future__ import annotations
+
+import base64
+import io
+import logging
+import os
+from typing import Any, Iterable
+
+import httpx
+
+log = logging.getLogger("riprap.inference")
+
+_BACKEND = os.environ.get("RIPRAP_ML_BACKEND", "auto").lower()
+_BASE_URL = os.environ.get("RIPRAP_ML_BASE_URL", "").rstrip("/")
+_API_KEY = os.environ.get("RIPRAP_ML_API_KEY", "")
+_DEFAULT_TIMEOUT = float(os.environ.get("RIPRAP_ML_TIMEOUT_S", "60"))
+
+
+class RemoteUnreachable(RuntimeError):
+    """Raised when the remote inference service is unconfigured, down,
+    times out, or returns 5xx. Callers catch this to fall through to a
+    local model load. 4xx errors propagate as the generic exception so
+    a caller bug doesn't get masked by a "fallback to local" path."""
+
+
+def remote_enabled() -> bool:
+    """True iff the router is configured to attempt remote calls.
+    Returns False under explicit `local` mode or when the base URL is
+    empty (the auto-default with no env config)."""
+    if _BACKEND == "local":
+        return False
+    if not _BASE_URL:
+        return False
+    return True
+
+
+def _client(timeout: float | None = None) -> httpx.Client:
+    headers = {"User-Agent": "riprap-app/0.4.5"}
+    if _API_KEY:
+        headers["Authorization"] = f"Bearer {_API_KEY}"
+    return httpx.Client(
+        base_url=_BASE_URL,
+        headers=headers,
+        timeout=timeout if timeout is not None else _DEFAULT_TIMEOUT,
+    )
+
+
+def _post(path: str, payload: dict[str, Any], timeout: float | None = None) -> dict:
+    """POST {payload} as JSON to the remote service's `path`. Returns the
+    parsed JSON body. Raises RemoteUnreachable on transport errors;
+    raises HTTPStatusError on 4xx so caller bugs surface."""
+    if not remote_enabled():
+        raise RemoteUnreachable("remote ML backend not configured "
+                                "(RIPRAP_ML_BASE_URL empty or BACKEND=local)")
+    try:
+        with _client(timeout) as c:
+            r = c.post(path, json=payload)
+    except (httpx.ConnectError, httpx.ReadError, httpx.WriteError,
+            httpx.TimeoutException, httpx.RemoteProtocolError) as e:
+        raise RemoteUnreachable(f"{type(e).__name__}: {e}") from e
+    if r.status_code >= 500:
+        raise RemoteUnreachable(f"HTTP {r.status_code} from {path}: {r.text[:200]}")
+    r.raise_for_status()
+    return r.json()
+
+
+def _serialize_array(arr) -> str:
+    """numpy/torch tensor → base64-encoded float32 raw bytes for transport.
+    Each remote handler decodes to (shape, dtype=float32) and reconstructs.
+    Reasonable round-trip for chips up to a few MB; large rasters should
+    use compressed numpy-savez instead — TODO when a model needs > 8 MB."""
+    import numpy as np
+    np_arr = arr if isinstance(arr, np.ndarray) else _to_numpy(arr)
+    np_arr = np_arr.astype("float32", copy=False)
+    return base64.b64encode(np_arr.tobytes()).decode("ascii")
+
+
+def _to_numpy(t):
+    """Best-effort tensor → numpy. Accepts torch.Tensor or numpy already."""
+    try:
+        import torch
+        if isinstance(t, torch.Tensor):
+            return t.detach().cpu().numpy()
+    except ImportError:
+        pass
+    import numpy as np
+    return np.asarray(t)
+
+
+def _deserialize_array(b64: str, shape: list[int]):
+    """Inverse of _serialize_array — bytes → numpy float32 with given shape."""
+    import numpy as np
+    raw = base64.b64decode(b64)
+    return np.frombuffer(raw, dtype="float32").reshape(shape)
+
+
+# ---- Public router entry points -------------------------------------------
+
+def healthcheck(timeout: float = 3.0) -> bool:
+    """Quick reachability probe. True if the service responds 200 to GET
+    /healthz within `timeout` seconds. Used by /api/backend so the UI can
+    show whether the remote ML backend is currently live."""
+    if not remote_enabled():
+        return False
+    try:
+        with _client(timeout) as c:
+            r = c.get("/healthz")
+            return r.status_code == 200
+    except Exception:
+        return False
+
+
+def backend_info() -> dict[str, Any]:
+    """Snapshot for /api/backend — what the UI should advertise."""
+    return {
+        "backend": _BACKEND,
+        "base_url": _BASE_URL or None,
+        "remote_enabled": remote_enabled(),
+        "reachable": healthcheck() if remote_enabled() else False,
+    }
+
+
+def prithvi_pluvial(s2_chip, *, scene_id: str | None = None,
+                    scene_datetime: str | None = None,
+                    cloud_cover: float | None = None,
+                    timeout: float | None = None) -> dict[str, Any]:
+    """Remote forward pass through Prithvi-NYC-Pluvial v2.
+    Input: 6-band Sentinel-2 chip (numpy or torch, shape [6, H, W]).
+    Output: { ok, pct_water_within_500m, pct_water_full, scene_id, ... }.
+    Raises RemoteUnreachable if the service is down."""
+    arr = _to_numpy(s2_chip)
+    return _post("/v1/prithvi-pluvial", {
+        "s2": _serialize_array(arr),
+        "shape": list(arr.shape),
+        "scene_id": scene_id,
+        "scene_datetime": scene_datetime,
+        "cloud_cover": cloud_cover,
+    }, timeout=timeout)
+
+
+def terramind(adapter: str, s2l2a, s1rtc=None, dem=None, *,
+              timeout: float | None = None) -> dict[str, Any]:
+    """Remote forward through TerraMind-NYC-Adapters (LULC or Buildings)
+    or the v1 base (synthetic). `adapter` is one of: lulc, buildings,
+    synthesis. Each modality is a numpy array or None."""
+    payload: dict[str, Any] = {"adapter": adapter}
+    s2_np = _to_numpy(s2l2a)
+    payload["s2"] = _serialize_array(s2_np)
+    payload["s2_shape"] = list(s2_np.shape)
+    if s1rtc is not None:
+        s1_np = _to_numpy(s1rtc)
+        payload["s1"] = _serialize_array(s1_np)
+        payload["s1_shape"] = list(s1_np.shape)
+    if dem is not None:
+        dem_np = _to_numpy(dem)
+        payload["dem"] = _serialize_array(dem_np)
+        payload["dem_shape"] = list(dem_np.shape)
+    return _post("/v1/terramind", payload, timeout=timeout)
+
+
+def ttm_forecast(model: str, history: Iterable[float], *,
+                 context_length: int, prediction_length: int,
+                 cadence: str = "h",
+                 timeout: float | None = None) -> dict[str, Any]:
+    """Remote Granite TTM r2 forecast.
+    `model` is one of: zero_shot_battery, fine_tune_battery, weekly_311,
+    floodnet_recurrence — the service decides which checkpoint to use.
+    `history` is a 1-D iterable of floats (the time series); `cadence`
+    is for the service's labelling (h / d / w / 6m). Output shape is
+    `{ ok, forecast: [...], peak_index, peak_value }`."""
+    series = list(map(float, history))
+    return _post("/v1/ttm-forecast", {
+        "model": model,
+        "history": series,
+        "context_length": context_length,
+        "prediction_length": prediction_length,
+        "cadence": cadence,
+    }, timeout=timeout)
+
+
+def granite_embed(texts: list[str], *,
+                  timeout: float | None = None) -> dict[str, Any]:
+    """Remote Granite Embedding 278M batch encode.
+    Output: { ok, vectors: [[float, ...], ...] }. Vector dimension fixed
+    at 768 (granite-embedding-278m-multilingual)."""
+    return _post("/v1/granite-embed", {"texts": list(texts)}, timeout=timeout)
+
+
+def gliner_extract(text: str, labels: list[str], *,
+                   timeout: float | None = None) -> dict[str, Any]:
+    """Remote GLiNER typed-entity extraction.
+    Output: { ok, entities: [{label, text, start, end, score}, ...] }."""
+    return _post("/v1/gliner-extract", {
+        "text": text, "labels": list(labels),
+    }, timeout=timeout)
app/live/ttm_battery_surge.py CHANGED
@@ -230,10 +230,7 @@ def fetch(timeout_s: float = 60.0) -> dict[str, Any]:
     if not ENABLE:
         return {"available": False,
                 "reason": "RIPRAP_TTM_BATTERY_SURGE_ENABLE=0"}
-    if not _DEPS_OK:
-        return {"available": False,
-                "reason": f"deps unavailable on this deployment: "
-                          f"{_DEPS_MISSING}"}
+
     t0 = time.time()
     try:
         df = _fetch_battery_history(CONTEXT_LENGTH)
@@ -245,21 +242,51 @@ def fetch(timeout_s: float = 60.0) -> dict[str, Any]:
             return {"available": False,
                     "reason": "NOAA fetch exceeded budget"}
 
-        import torch
-        model = _ensure_model()
-        # [B=1, T=1024, C=1] tensor of metres surge residual.
         residuals = df["surge_residual_m"].to_numpy().astype("float32")
-        past = torch.from_numpy(residuals).unsqueeze(0).unsqueeze(-1)
-        if DEVICE == "cuda":
-            try:
-                if torch.cuda.is_available():
-                    past = past.cuda()
-            except Exception:
-                log.exception("ttm_battery_surge: cuda move failed")
-        with torch.no_grad():
-            out = model(past_values=past)
-        forecast = out.prediction_outputs.squeeze(-1).squeeze(0).cpu().numpy()
+
+        # v0.4.5 — try the MI300X service first. The remote handles its
+        # own model loading; if it's reachable we never need local
+        # tsfm_public, which lets the HF Space drop the granite-tsfm
+        # bake from the image.
+        forecast = None
+        compute = "local"
+        try:
+            from app import inference as _inf
+            if _inf.remote_enabled():
+                remote = _inf.ttm_forecast(
+                    "fine_tune_battery", residuals.tolist(),
+                    context_length=CONTEXT_LENGTH,
+                    prediction_length=PREDICTION_LENGTH,
+                    cadence="h",
+                    timeout=timeout_s,
+                )
+                if remote.get("ok"):
+                    import numpy as np
+                    forecast = np.asarray(remote["forecast"], dtype="float32")
+                    compute = f"remote · {remote.get('device', 'gpu')}"
+        except _inf.RemoteUnreachable as e:
+            log.info("ttm_battery_surge: remote unreachable (%s); local", e)
+
+        if forecast is None:
+            if not _DEPS_OK:
+                return {"available": False,
+                        "reason": f"deps unavailable on this deployment: "
+                                  f"{_DEPS_MISSING}"}
+            import torch
+            model = _ensure_model()
+            past = torch.from_numpy(residuals).unsqueeze(0).unsqueeze(-1)
+            if DEVICE == "cuda":
+                try:
+                    if torch.cuda.is_available():
+                        past = past.cuda()
+                except Exception:
+                    log.exception("ttm_battery_surge: cuda move failed")
+            with torch.no_grad():
+                out = model(past_values=past)
+            forecast = out.prediction_outputs.squeeze(-1).squeeze(0).cpu().numpy()
+
         result = _summarize(df, forecast)
+        result["compute"] = compute
         result["elapsed_s"] = round(time.time() - t0, 2)
         return result
     except Exception as e:
app/live/ttm_forecast.py CHANGED
@@ -180,16 +180,44 @@ def _residual_series(station_id: str,
 
 def _run_ttm(history: np.ndarray,
              context_length: int = CONTEXT_LENGTH,
-             prediction_length: int = PREDICTION_LENGTH) -> np.ndarray | None:
+             prediction_length: int = PREDICTION_LENGTH,
+             cadence: str = "h") -> np.ndarray | None:
     """Channel-wise standardize, run model, de-standardize. Returns a
-    `prediction_length`-step de-standardized forecast in input units."""
+    `prediction_length`-step de-standardized forecast in input units.
+
+    v0.4.5 — tries the MI300X riprap-models service first; falls back
+    to the local in-process model on RemoteUnreachable. The
+    standardize / de-standardize math is owned by THIS function so the
+    remote service stays a thin "given a series, give me a forecast"
+    contract.
+    """
+    mu = float(history.mean())
+    sigma = float(history.std() + 1e-6)
+    normed = (history - mu) / sigma
+
+    # Try remote first
+    try:
+        from app import inference as _inf
+        if _inf.remote_enabled():
+            remote = _inf.ttm_forecast(
+                "zero_shot_battery", normed.tolist(),
+                context_length=context_length,
+                prediction_length=prediction_length,
+                cadence=cadence,
+            )
+            if remote.get("ok"):
+                pred = np.asarray(remote["forecast"], dtype=np.float32)
+                return pred * sigma + mu
+    except _inf.RemoteUnreachable as e:
+        log.info("TTM zero-shot: remote unreachable (%s); local fallback", e)
+    except Exception:
+        log.exception("TTM zero-shot remote call failed; local fallback")
+
+    # Local fallback
     model = _load_model(context_length, prediction_length)
     if model is None:
        return None
    import torch
-    mu = float(history.mean())
-    sigma = float(history.std() + 1e-6)
-    normed = (history - mu) / sigma
     x = torch.from_numpy(normed.astype(np.float32))[None, :, None]
     try:
         with torch.no_grad():
app/rag.py CHANGED
@@ -132,15 +132,38 @@ def _ensure_index():
         _INDEX = {"chunks": [], "embs": None, "model": None}
         return _INDEX
 
-    from sentence_transformers import SentenceTransformer
-    log.info("rag: loading %s", EMBED_MODEL_NAME)
-    model = SentenceTransformer(EMBED_MODEL_NAME)
-
     texts = [c.text for c in chunks]
     log.info("rag: embedding %d chunks", len(texts))
-    embs = model.encode(texts, batch_size=32, show_progress_bar=False,
-                        convert_to_numpy=True, normalize_embeddings=True)
-    _INDEX = {"chunks": chunks, "embs": embs.astype("float32"), "model": model}
+
+    # v0.4.5 — try the MI300X service first. Avoids loading
+    # sentence-transformers + the granite-embedding weights on a
+    # cpu-basic surface (HF Space). Falls back to local on
+    # RemoteUnreachable so dev laptops keep working with no env.
+    embs = None
+    model = None
+    try:
+        from app import inference as _inf
+        if _inf.remote_enabled():
+            log.info("rag: encoding via remote MI300X")
+            remote = _inf.granite_embed(texts, timeout=120.0)
+            if remote.get("ok"):
+                embs = np.asarray(remote["vectors"], dtype="float32")
+                # Per-query encodes will also route through remote;
+                # `model` stays None and `retrieve()` checks for it.
+    except _inf.RemoteUnreachable as e:
+        log.info("rag: remote unreachable (%s); local fallback", e)
+    except Exception:
+        log.exception("rag: remote encode failed; local fallback")
+
+    if embs is None:
+        from sentence_transformers import SentenceTransformer
+        log.info("rag: loading %s (local fallback)", EMBED_MODEL_NAME)
+        model = SentenceTransformer(EMBED_MODEL_NAME)
+        embs = model.encode(texts, batch_size=32, show_progress_bar=False,
+                            convert_to_numpy=True, normalize_embeddings=True)
+    embs = embs.astype("float32")
+
+    _INDEX = {"chunks": chunks, "embs": embs, "model": model}
     log.info("rag: index ready (%s)", embs.shape)
     return _INDEX
 
@@ -173,8 +196,35 @@ def retrieve(query: str, k: int = 4, min_score: float = 0.30) -> list[dict]:
     idx = _ensure_index()
     if idx["embs"] is None or not idx["chunks"]:
         return []
-    qv = idx["model"].encode([query], convert_to_numpy=True,
-                             normalize_embeddings=True).astype("float32")
+
+    # v0.4.5 — encode query via remote when corpus was embedded remotely.
+    # `_ensure_index` leaves `model = None` when it took the remote
+    # path, so this branch handles both:
+    #   - model present → local SentenceTransformer.encode (fast, in-mem)
+    #   - model is None → POST to MI300X, fallback to a one-shot local
+    #     SentenceTransformer load if remote is down.
+    if idx["model"] is not None:
+        qv = idx["model"].encode([query], convert_to_numpy=True,
+                                 normalize_embeddings=True).astype("float32")
+    else:
+        qv = None
+        try:
+            from app import inference as _inf
+            if _inf.remote_enabled():
+                remote = _inf.granite_embed([query])
+                if remote.get("ok"):
+                    qv = np.asarray(remote["vectors"], dtype="float32")
+        except _inf.RemoteUnreachable as e:
+            log.info("rag: per-query encode remote unreachable (%s)", e)
+        if qv is None:
+            from sentence_transformers import SentenceTransformer
+            log.info("rag: cold-loading %s for per-query encode (remote down)",
+                     EMBED_MODEL_NAME)
+            local = SentenceTransformer(EMBED_MODEL_NAME)
+            qv = local.encode([query], convert_to_numpy=True,
+                              normalize_embeddings=True).astype("float32")
+            # Cache so subsequent queries don't re-load
+            idx["model"] = local
     sims = (idx["embs"] @ qv.T).ravel()
 
     reranker = _ensure_reranker()
services/riprap-models/README.md ADDED
@@ -0,0 +1,66 @@
+# Riprap Models — droplet inference service
+
+GPU inference microservice that runs alongside vLLM on the AMD MI300X
+droplet. Exposes one HTTP endpoint per model class consumed by the
+Riprap FastAPI app's specialists, so all GPU-accelerable forward
+passes (Prithvi-NYC-Pluvial, TerraMind LULC + Buildings, Granite TTM
+r2, Granite Embedding 278M, GLiNER) run on the MI300X regardless of
+which surface — laptop or HF Space — hosts the FastAPI process.
+
+## Service contract
+
+| Method | Path | Purpose |
+|---|---|---|
+| GET | `/healthz` | reachability probe + which models are warm |
+| POST | `/v1/prithvi-pluvial` | Prithvi-NYC-Pluvial v2 segmentation |
+| POST | `/v1/terramind` | TerraMind LULC / Buildings / Synthesis (adapter-dispatched) |
+| POST | `/v1/ttm-forecast` | Granite TTM r2 (zero-shot Battery, fine-tune Battery, weekly 311, FloodNet recurrence) |
+| POST | `/v1/granite-embed` | Granite Embedding 278M batch encode |
+| POST | `/v1/gliner-extract` | GLiNER typed-entity extraction |
+
+Auth: bearer token on every `/v1/*` route via `RIPRAP_MODELS_API_KEY`.
+Same shape as vLLM. `/healthz` is open so liveness probes don't need
+auth.
+
+## Deploy
+
+The droplet's existing `terramind` container already has
+`torch+ROCm 7.0`, `terratorch 1.2.7`, `granite-tsfm 0.3.6`,
+`transformers 4.57`, `peft`, `safetensors`, `fastapi`, `uvicorn`. The
+service code lands under `/workspace/riprap-models/`; only deltas
+need installing.
+
+```bash
+# Copy code (run from project root)
+ssh root@129.212.181.238 'mkdir -p /workspace/riprap-models'
+rsync -av --delete services/riprap-models/ \
+    root@129.212.181.238:/workspace/riprap-models/
+
+# Install deltas + start uvicorn inside the terramind container
+ssh root@129.212.181.238 bash <<'REMOTE'
+docker cp /workspace/riprap-models terramind:/workspace/
+docker exec -d -e RIPRAP_MODELS_API_KEY="$RIPRAP_MODELS_API_KEY" terramind \
+    bash -c "cd /workspace/riprap-models && \
+             pip install --no-cache-dir -r requirements.txt && \
+             uvicorn main:app --host 0.0.0.0 --port 7860 --log-level info \
+             > /workspace/riprap-models.log 2>&1"
+REMOTE
+```
+
+Service binds inside the container at `:7860`; the host port
+mapping was set when the `terramind` container was created
+(`docker run -p 7860:7860 ...`), so externally the service is at
+`http://129.212.181.238:7860`.
+
+## Local app config
+
+Set in either env or HF Space variables:
+
+```
+RIPRAP_ML_BACKEND  = remote
+RIPRAP_ML_BASE_URL = http://129.212.181.238:7860
+RIPRAP_ML_API_KEY  = <bearer>
+```
+
+`app/inference.py` posts to those endpoints; specialists fall back
+to local in-process model loads when the service is unreachable.
services/riprap-models/main.py ADDED
@@ -0,0 +1,561 @@
+"""Riprap Models — GPU inference microservice.
+
+Runs on the AMD MI300X droplet alongside vLLM, exposes one HTTP
+endpoint per model class consumed by the Riprap FastAPI app's
+specialists. The local app routes through this service when
+RIPRAP_ML_BACKEND=remote (or =auto with the service reachable),
+keeping all GPU-accelerable forward passes on the MI300X — Granite
+4.1 (LLM), Prithvi-NYC-Pluvial (segmentation), TerraMind LULC +
+Buildings + Synthesis (LoRA), Granite TTM r2 (forecasts), Granite
+Embedding 278M (RAG), and GLiNER (typed extraction).
+
+Authoritative bearer-token auth same as vLLM. Same env-var shape so
+the same secret can be reused across both services on a Space.
+
+Service contract (mirrors app/inference.py):
+
+    GET  /healthz            → {ok: true, models_loaded: [...]}
+    POST /v1/prithvi-pluvial → see _prithvi_pluvial below
+    POST /v1/terramind       → adapter dispatch (lulc/buildings/synth)
+    POST /v1/ttm-forecast    → model dispatch (zero_shot_battery, ...)
+    POST /v1/granite-embed   → batch text → 768-d vectors
+    POST /v1/gliner-extract  → text + labels → typed entities
+
+Model loading is lazy + cached per-process. The first call to a given
+model pays the cold-load cost (~5-30 s); subsequent calls reuse the
+in-memory instance. ROCm device binding goes through torch's CUDA
+shim — `cuda` is the ROCm device when running on a ROCm-built torch.
+"""
+from __future__ import annotations
+
+import base64
+import logging
+import os
+import threading
+import time
+from contextlib import asynccontextmanager
+from typing import Any
+
+import numpy as np
+from fastapi import Depends, FastAPI, HTTPException, Header
+from pydantic import BaseModel
+
+log = logging.getLogger("riprap.models")
+logging.basicConfig(
+    level=os.environ.get("RIPRAP_MODELS_LOG", "INFO").upper(),
+    format="%(asctime)s %(levelname)-5s %(name)s: %(message)s",
+)
+
+# Auth — same shape as vLLM. Set RIPRAP_MODELS_API_KEY in the
+# `docker run` env. When empty, the service runs unauthenticated
+# (only sane for localhost-only deployments).
+_AUTH_TOKEN = os.environ.get("RIPRAP_MODELS_API_KEY", "")
+
+# Device. ROCm-built torch reports CUDA-style symbols; "cuda" maps to
+# the first ROCm device on the MI300X.
+_DEVICE = os.environ.get("RIPRAP_MODELS_DEVICE", "cuda")
+
+
+def _require_auth(authorization: str | None = Header(default=None)) -> None:
+    if not _AUTH_TOKEN:
+        return
+    if not authorization or not authorization.startswith("Bearer "):
+        raise HTTPException(status_code=401, detail="Missing bearer token")
+    if authorization[7:].strip() != _AUTH_TOKEN:
+        raise HTTPException(status_code=401, detail="Invalid bearer token")
+
+
+# ---- Lazy model singletons --------------------------------------------------
+#
+# Each model has a `_load_<name>()` that returns the in-memory instance
+# (locking on a per-model threading.Lock so concurrent first-call
+# requests don't double-load). Callers grab via `_get_<name>()`.
+
+_LOCKS = {
+    "prithvi": threading.Lock(),
+    "terramind_lulc": threading.Lock(),
+    "terramind_buildings": threading.Lock(),
+    "terramind_synth": threading.Lock(),
+    "ttm": threading.Lock(),
+    "granite_embed": threading.Lock(),
+    "gliner": threading.Lock(),
+}
+_INSTANCES: dict[str, Any] = {}
+
+
+def _decode_array(b64: str, shape: list[int], dtype: str = "float32") -> np.ndarray:
+    raw = base64.b64decode(b64)
+    return np.frombuffer(raw, dtype=dtype).reshape(shape)
+
+
+def _to_device(t):
+    """Move a torch tensor to the configured device. No-op for CPU."""
+    if _DEVICE == "cpu":
+        return t
+    try:
+        import torch
+        if torch.cuda.is_available():
+            return t.to("cuda")
+    except Exception as e:
+        log.warning("device move skipped: %s", e)
+    return t
+
+
+# ---- Prithvi-NYC-Pluvial v2 -------------------------------------------------
+
+def _load_prithvi():
+    if "prithvi" in _INSTANCES:
+        return _INSTANCES["prithvi"]
+    with _LOCKS["prithvi"]:
+        if "prithvi" in _INSTANCES:
+            return _INSTANCES["prithvi"]
+        log.info("prithvi: cold load (msradam/Prithvi-EO-2.0-NYC-Pluvial)")
+        import importlib.util
+
+        from huggingface_hub import hf_hub_download
+        from terratorch.cli_tools import LightningInferenceModel
+
+        BASE_REPO = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
+        V2_REPO = "msradam/Prithvi-EO-2.0-NYC-Pluvial"
+
+        # Use the IBM-NASA base config + v2 ckpt. Mirrors
+        # app/flood_layers/prithvi_live.py:_ensure_model().
+        base_config = hf_hub_download(BASE_REPO, "config.yaml")
+        inference_py = hf_hub_download(BASE_REPO, "inference.py")
+
+        v2_yaml = None
+        v2_ckpt = None
+        for name in ("prithvi_nyc_phase14.yaml", "config.yaml"):
+            try:
+                v2_yaml = hf_hub_download(V2_REPO, name); break
+            except Exception:
+                continue
+        for name in ("prithvi_nyc_pluvial_v2.ckpt", "best_val_loss.ckpt", "model.ckpt"):
+            try:
+                v2_ckpt = hf_hub_download(V2_REPO, name); break
+            except Exception:
+                continue
+        if v2_yaml and v2_ckpt:
+            log.info("prithvi: building from v2 yaml=%s ckpt=%s", v2_yaml, v2_ckpt)
+            m = LightningInferenceModel.from_config(v2_yaml, v2_ckpt)
+        else:
+            log.info("prithvi: v2 unavailable, falling back to base")
+            base_ckpt = hf_hub_download(
+                BASE_REPO, "Prithvi-EO-V2-300M-TL-Sen1Floods11.pt")
+            m = LightningInferenceModel.from_config(base_config, base_ckpt)
+        m.model.eval()
+        try:
+            import torch
+            if _DEVICE == "cuda" and torch.cuda.is_available():
+                m.model.cuda()
+        except Exception:
+            log.exception("prithvi: cuda move failed; staying on cpu")
+
+        spec = importlib.util.spec_from_file_location("_prithvi_inference",
+                                                      inference_py)
+        mod = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(mod)
+        _INSTANCES["prithvi"] = (m, mod.run_model)
+        log.info("prithvi: ready")
+    return _INSTANCES["prithvi"]
+
+
+class PrithviIn(BaseModel):
+    s2: str
+    shape: list[int]
+    scene_id: str | None = None
+    scene_datetime: str | None = None
+    cloud_cover: float | None = None
+
+
+def _prithvi_pluvial(payload: PrithviIn) -> dict[str, Any]:
+    t0 = time.time()
+    m, run_model = _load_prithvi()
+    chip = _decode_array(payload.s2, payload.shape, "float32")
+    # Sen1Floods11 expects [1, 6, 1, H, W]
+    if chip.ndim == 3:
+        chip = chip[None, :, None, :, :]
+    pred_t = run_model(chip, None, None, m.model, m.datamodule, chip.shape[-1])
+    pred = pred_t[0].cpu().numpy().astype("uint8")
+    pct_full = float(100.0 * pred.mean())
+    # Center-disk fraction (500 m at 10 m/px → 50 px radius from chip center).
+    h, w = pred.shape
+    yy, xx = np.indices(pred.shape)
+    cy, cx = h // 2, w // 2
+    dist = np.sqrt((yy - cy) ** 2 + (xx - cx) ** 2)
+    mask = dist <= min(50, min(h, w) // 4)
+    pct_500m = float(100.0 * pred[mask].mean()) if mask.any() else pct_full
+    return {
+        "ok": True,
+        "elapsed_s": round(time.time() - t0, 2),
+        "device": _DEVICE,
+        "pct_water_within_500m": round(pct_500m, 3),
+        "pct_water_full": round(pct_full, 3),
+        "scene_id": payload.scene_id,
+        "scene_datetime": payload.scene_datetime,
+        "cloud_cover": payload.cloud_cover,
+        "shape": [int(h), int(w)],
+    }
+
+
+# ---- TerraMind (lulc / buildings / synthesis) -------------------------------
+
+_TERRAMIND_REPO = "msradam/TerraMind-NYC-Adapters"
+_TERRAMIND_SPECS = {
+    "lulc": {"subdir": "lulc_nyc", "num_classes": 5,
+             "labels": ["Trees", "Cropland", "Built", "Bare", "Water"]},
+    "buildings": {"subdir": "buildings_nyc", "num_classes": 2,
+                  "labels": ["Background", "Building"]},
+}
+
+
+def _load_terramind(adapter: str):
+    key = f"terramind_{adapter}"
+    if key in _INSTANCES:
+        return _INSTANCES[key]
+    with _LOCKS.get(key, _LOCKS.get("terramind_lulc")):
+        if key in _INSTANCES:
+            return _INSTANCES[key]
+        log.info("terramind/%s: cold load", adapter)
+        from huggingface_hub import snapshot_download
+        from peft import LoraConfig, inject_adapter_in_model
+        from safetensors.torch import load_file
+        from terratorch.tasks import SemanticSegmentationTask
+
+        spec = _TERRAMIND_SPECS[adapter]
+        adapter_root = snapshot_download(
+            _TERRAMIND_REPO, allow_patterns=[f"{spec['subdir']}/*"])
+        task = SemanticSegmentationTask(
+            model_factory="EncoderDecoderFactory",
+            model_args=dict(
+                backbone="terramind_v1_base",
+                backbone_pretrained=True,
+                backbone_modalities=["S2L2A", "S1RTC", "DEM"],
+                backbone_use_temporal=True,
+                backbone_temporal_pooling="concat",
+                backbone_temporal_n_timestamps=4,
+                necks=[
+                    {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
+                    {"name": "ReshapeTokensToImage", "remove_cls_token": False},
+                    {"name": "LearnedInterpolateToPyramidal"},
+                ],
+                decoder="UNetDecoder",
+                decoder_channels=[512, 256, 128, 64],
+                head_dropout=0.1,
+                num_classes=spec["num_classes"],
+            ),
+            loss="ce", lr=1e-4, freeze_backbone=False, freeze_decoder=False,
+        )
+        inject_adapter_in_model(LoraConfig(
+            r=16, lora_alpha=32, lora_dropout=0.05,
+            target_modules=["attn.qkv", "attn.proj"], bias="none",
+        ), task.model.encoder)
+        adapter_dir = f"{adapter_root}/{spec['subdir']}"
+        lora = load_file(f"{adapter_dir}/adapter_model.safetensors")
+        head = load_file(f"{adapter_dir}/decoder_head.safetensors")
+        task.model.encoder.load_state_dict(
+            {k.removeprefix("encoder."): v for k, v in lora.items()
+             if k.startswith("encoder.")}, strict=False)
+        for sub in ("decoder", "neck", "head", "aux_heads"):
+            ss = {k[len(sub) + 1:]: v for k, v in head.items()
+                  if k.startswith(sub + ".")}
+            if ss and hasattr(task.model, sub):
+                getattr(task.model, sub).load_state_dict(ss, strict=False)
+        try:
+            import torch
+            if _DEVICE == "cuda" and torch.cuda.is_available():
+                task = task.to("cuda")
+        except Exception:
+            log.exception("terramind: cuda move failed")
+        task.eval()
+        _INSTANCES[key] = task
+        log.info("terramind/%s: ready", adapter)
+    return task
+
+
+class TerramindIn(BaseModel):
+    adapter: str  # "lulc" | "buildings" | "synthesis"
+    s2: str
+    s2_shape: list[int]
+    s1: str | None = None
+    s1_shape: list[int] | None = None
+    dem: str | None = None
+    dem_shape: list[int] | None = None
+
+
+def _build_chip_tensor(np_arr, n_timesteps: int = 4):
+    import torch
+    t = torch.from_numpy(np_arr).float().unsqueeze(1)  # add T dim
+    if t.shape[1] == 1:
+        t = t.repeat(1, n_timesteps, 1, 1)
+    return t.unsqueeze(0)  # add batch
+
+
+def _terramind_inference(payload: TerramindIn) -> dict[str, Any]:
+    t0 = time.time()
+    if payload.adapter not in _TERRAMIND_SPECS:
+        raise HTTPException(status_code=400,
+                            detail=f"unknown adapter {payload.adapter!r}")
+    task = _load_terramind(payload.adapter)
+    spec = _TERRAMIND_SPECS[payload.adapter]
+
+    s2 = _decode_array(payload.s2, payload.s2_shape)
+    chips = {"S2L2A": _to_device(_build_chip_tensor(s2))}
+    if payload.s1 and payload.s1_shape:
+        s1 = _decode_array(payload.s1, payload.s1_shape)
+        chips["S1RTC"] = _to_device(_build_chip_tensor(s1))
+    if payload.dem and payload.dem_shape:
+        dem = _decode_array(payload.dem, payload.dem_shape)
+        chips["DEM"] = _to_device(_build_chip_tensor(dem))
+
+    import torch
+    from terratorch.tasks.tiled_inference import tiled_inference
+
+    def _forward(x, **_extra):
+        out = task.model(x)
+        return out.output if hasattr(out, "output") else out
+    with torch.no_grad():
+        logits = tiled_inference(
+            _forward, chips, out_channels=spec["num_classes"],
+            h_crop=224, w_crop=224, h_stride=128, w_stride=128,
+            average_patches=True, blend_overlaps=True, padding="reflect",
+        )
+    pred = logits.argmax(dim=1).squeeze(0).cpu().numpy().astype("uint8")
+    n = max(int(pred.size), 1)
+    fractions = {
+        spec["labels"][i]: round(100.0 * float((pred == i).sum()) / n, 2)
+        for i in range(spec["num_classes"])
+    }
+    fractions = {k: v for k, v in fractions.items() if v > 0}
+    dom_idx = int(max(range(spec["num_classes"]),
+                      key=lambda i: int((pred == i).sum())))
+    return {
+        "ok": True,
+        "adapter": payload.adapter,
+        "elapsed_s": round(time.time() - t0, 2),
+        "device": _DEVICE,
+        "shape": list(pred.shape),
+        "n_pixels": int(pred.size),
+        "class_fractions": fractions,
+        "dominant_class": spec["labels"][dom_idx],
+        "dominant_pct": fractions.get(spec["labels"][dom_idx], 0.0),
+        # Buildings-specific stat (NaN-safe; 0 when not the buildings adapter).
+        "pct_buildings": round(100.0 * float((pred == 1).sum()) / n, 2)
+        if payload.adapter == "buildings" else None,
+    }
+
+
+# ---- Granite TTM r2 ---------------------------------------------------------
+
+_TTM_MODELS = {
+    "zero_shot_battery": "ibm-granite/granite-timeseries-ttm-r2",
+    "fine_tune_battery": "msradam/Granite-TTM-r2-Battery-Surge",
+    "weekly_311": "ibm-granite/granite-timeseries-ttm-r2",
+    "floodnet_recurrence": "ibm-granite/granite-timeseries-ttm-r2",
+}
+
+
+def _load_ttm(model_key: str):
+    key = f"ttm:{model_key}"
+    if key in _INSTANCES:
+        return _INSTANCES[key]
+    with _LOCKS["ttm"]:
+        if key in _INSTANCES:
+            return _INSTANCES[key]
+        log.info("ttm/%s: cold load", model_key)
+        if model_key == "fine_tune_battery":
+            from huggingface_hub import snapshot_download
+            from tsfm_public import TinyTimeMixerForPrediction
+            local_dir = snapshot_download(_TTM_MODELS[model_key])
+            m = TinyTimeMixerForPrediction.from_pretrained(local_dir).eval()
+        else:
+            from tsfm_public.toolkit.get_model import get_model
+            # Caller passes (context_length, prediction_length) — for the
+            # zero-shot & 311 & FloodNet specialists we let the toolkit
+            # pick the best matching pretrained config. Cache one per
+            # model_key to avoid duplicate loads.
+            m = get_model(_TTM_MODELS[model_key],
+                          context_length=512, prediction_length=96).eval()
+        try:
+            import torch
+            if _DEVICE == "cuda" and torch.cuda.is_available():
+                m = m.to("cuda")
+        except Exception:
+            log.exception("ttm: cuda move failed")
+        _INSTANCES[key] = m
+        log.info("ttm/%s: ready", model_key)
+    return m
+
+
+class TtmIn(BaseModel):
+    model: str  # zero_shot_battery | fine_tune_battery | weekly_311 | floodnet_recurrence
+    history: list[float]
+    context_length: int
+    prediction_length: int
+    cadence: str = "h"
+
+
+def _ttm_forecast(payload: TtmIn) -> dict[str, Any]:
+    t0 = time.time()
+    if payload.model not in _TTM_MODELS:
+        raise HTTPException(status_code=400,
+                            detail=f"unknown model {payload.model!r}")
+    m = _load_ttm(payload.model)
+    import torch
+    series = np.array(payload.history, dtype="float32")
+    if len(series) < payload.context_length:
+        # Front-pad with the leading value so the model gets the right
+        # shape — caller-side fills are NaN-clean already, so this only
+        # extends a series whose history is shorter than context.
+        pad = np.full(payload.context_length - len(series), series[0]
+                      if len(series) else 0.0, dtype="float32")
+        series = np.concatenate([pad, series])
+    series = series[-payload.context_length:]
+    x = torch.from_numpy(series).float().unsqueeze(0).unsqueeze(-1)
+    x = _to_device(x)
+    with torch.no_grad():
+        out = m(past_values=x)
+    fc = out.prediction_outputs.squeeze(-1).squeeze(0).cpu().numpy()
+    peak_idx = int(np.argmax(np.abs(fc)))
+    return {
+        "ok": True,
+        "model": payload.model,
+        "elapsed_s": round(time.time() - t0, 2),
+        "device": _DEVICE,
+        "context_length": payload.context_length,
+        "prediction_length": payload.prediction_length,
+        "cadence": payload.cadence,
+        "forecast": [round(float(v), 6) for v in fc.tolist()],
+        "peak_index": peak_idx,
+        "peak_value": round(float(fc[peak_idx]), 6),
+    }
+
+
+# ---- Granite Embedding 278M -------------------------------------------------
+
+_EMBED_REPO = "ibm-granite/granite-embedding-278m-multilingual"
+
+
+def _load_embed():
+    if "granite_embed" in _INSTANCES:
+        return _INSTANCES["granite_embed"]
+    with _LOCKS["granite_embed"]:
+        if "granite_embed" in _INSTANCES:
+            return _INSTANCES["granite_embed"]
+        log.info("granite-embed: cold load")
+        from sentence_transformers import SentenceTransformer
+        m = SentenceTransformer(_EMBED_REPO,
+                                device="cuda" if _DEVICE == "cuda" else "cpu")
+        _INSTANCES["granite_embed"] = m
+        log.info("granite-embed: ready")
+    return m
+
+
+class EmbedIn(BaseModel):
+    texts: list[str]
+
+
+def _granite_embed(payload: EmbedIn) -> dict[str, Any]:
+    t0 = time.time()
+    m = _load_embed()
+    vecs = m.encode(payload.texts, normalize_embeddings=True,
+                    show_progress_bar=False)
+    return {
+        "ok": True,
+        "elapsed_s": round(time.time() - t0, 2),
+        "device": _DEVICE,
+        "n": len(payload.texts),
+        "dim": int(vecs.shape[-1]) if hasattr(vecs, "shape") else len(vecs[0]),
+        "vectors": [list(map(float, v)) for v in vecs],
+    }
+
+
+# ---- GLiNER ----------------------------------------------------------------
+
+_GLINER_REPO = "urchade/gliner_medium-v2.1"
+
+
+def _load_gliner():
+    if "gliner" in _INSTANCES:
+        return _INSTANCES["gliner"]
+    with _LOCKS["gliner"]:
+        if "gliner" in _INSTANCES:
+            return _INSTANCES["gliner"]
+        log.info("gliner: cold load")
+        from gliner import GLiNER
+        m = GLiNER.from_pretrained(_GLINER_REPO)
+        try:
+            import torch
+            if _DEVICE == "cuda" and torch.cuda.is_available():
+                m = m.to("cuda")
+        except Exception:
+            log.exception("gliner: cuda move failed")
+        _INSTANCES["gliner"] = m
+        log.info("gliner: ready")
+    return m
+
+
+class GlinerIn(BaseModel):
+    text: str
+    labels: list[str]
+
+
+def _gliner_extract(payload: GlinerIn) -> dict[str, Any]:
+    t0 = time.time()
+    m = _load_gliner()
+    ents = m.predict_entities(payload.text, payload.labels)
+    return {
+        "ok": True,
+        "elapsed_s": round(time.time() - t0, 2),
+        "device": _DEVICE,
+        "entities": [
+            {"label": e["label"], "text": e["text"],
+             "start": int(e.get("start", 0)), "end": int(e.get("end", 0)),
+             "score": float(e.get("score", 0))}
+            for e in ents
+        ],
+    }
+
+
+# ---- FastAPI app ------------------------------------------------------------
+
+@asynccontextmanager
+async def lifespan(_app: FastAPI):
+    log.info("riprap-models starting on device=%s auth=%s",
+             _DEVICE, "yes" if _AUTH_TOKEN else "no")
+    yield
+    log.info("riprap-models stopping")
+
+
+app = FastAPI(title="riprap-models", version="0.4.5", lifespan=lifespan)
+
+
+@app.get("/healthz")
+def healthz():
+    return {"ok": True, "device": _DEVICE,
+            "models_loaded": sorted(_INSTANCES.keys())}
+
+
+@app.post("/v1/prithvi-pluvial", dependencies=[Depends(_require_auth)])
+def prithvi_pluvial_route(payload: PrithviIn):
+    return _prithvi_pluvial(payload)
+
+
+@app.post("/v1/terramind", dependencies=[Depends(_require_auth)])
+def terramind_route(payload: TerramindIn):
+    return _terramind_inference(payload)
+
+
+@app.post("/v1/ttm-forecast", dependencies=[Depends(_require_auth)])
+def ttm_forecast_route(payload: TtmIn):
+    return _ttm_forecast(payload)
+
+
+@app.post("/v1/granite-embed", dependencies=[Depends(_require_auth)])
+def granite_embed_route(payload: EmbedIn):
+    return _granite_embed(payload)
+
+
+@app.post("/v1/gliner-extract", dependencies=[Depends(_require_auth)])
+def gliner_extract_route(payload: GlinerIn):
+    return _gliner_extract(payload)
services/riprap-models/requirements.txt ADDED
@@ -0,0 +1,12 @@
+# Riprap Models — droplet inference service.
+#
+# Most heavy deps (torch+ROCm, terratorch, granite-tsfm, transformers,
+# peft, safetensors, fastapi, uvicorn) are already in the `terramind`
+# container's image. This list is only the deltas the service needs
+# beyond that base — install with:
+#
+#   docker exec terramind pip install -r /workspace/riprap-models/requirements.txt
+fastapi-cli >= 0.0.5
+gliner >= 0.2.6
+sentence-transformers >= 5.0.0
+huggingface_hub >= 0.34