seriffic Claude Opus 4.7 (1M context) committed on
Commit
eea4d6e
·
1 Parent(s): 123d27f

feat: terramind_synthesis now routes through droplet remote inference

Browse files

Synthesis (DEM -> generative LULC via TerraMind v1 base) was the only
EO specialist still local-only. On the HF Space, terratorch's import
chain crashes with `RuntimeError: operator torchvision::nms does not
exist` because torchvision's C extension can't load against our CPU
torch wheel — so synthesis returned a clean skip rather than running.

Add a remote dispatch path:

- services/riprap-models/main.py: extend `_TERRAMIND_SPECS` with the
`synthesis` entry (10 ESRI LULC labels), add `_load_terramind_synthesis`
that pulls the v1 base generate model from terratorch's
FULL_MODEL_REGISTRY, and `_terramind_synthesis_inference` that
takes a 4-D (B, 1, H, W) DEM and emits class fractions matching the
local code's response shape. Same /v1/terramind dispatch handles
all three adapter names now (lulc / buildings / synthesis).

- TerramindIn schema: make all modality fields optional. lulc/buildings
still require s2; synthesis requires dem; the dispatch enforces.

- app/inference.py: terramind() accepts None for s2l2a so the synthesis
call site doesn't fabricate a placeholder S2 chip.

- app/context/terramind_synthesis.py: try the remote path first via
app.inference.terramind('synthesis', dem=...). Local path is kept as
a fallback for environments where terratorch DOES load.

Hot-patched on the droplet (docker cp + docker restart). This source
commit ensures scripts/deploy_droplet.sh picks the change up on the next bring-up.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

app/context/terramind_synthesis.py CHANGED
@@ -273,8 +273,6 @@ def fetch(lat: float, lon: float, timeout_s: float = 60.0) -> dict[str, Any]:
273
  """
274
  if not ENABLE:
275
  return {"ok": False, "skipped": "RIPRAP_TERRAMIND_ENABLE=0"}
276
- if not _DEPS_OK:
277
- return {"ok": False, "skipped": f"deps unavailable: {_DEPS_MISSING}"}
278
  t0 = time.time()
279
  try:
280
  import numpy as np
@@ -284,6 +282,58 @@ def fetch(lat: float, lon: float, timeout_s: float = 60.0) -> dict[str, Any]:
284
  dem, bounds_4326 = patch
285
  dem_mean = float(dem.mean())
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  import torch
288
  random.seed(DEFAULT_SEED)
289
  torch.manual_seed(DEFAULT_SEED)
 
273
  """
274
  if not ENABLE:
275
  return {"ok": False, "skipped": "RIPRAP_TERRAMIND_ENABLE=0"}
 
 
276
  t0 = time.time()
277
  try:
278
  import numpy as np
 
282
  dem, bounds_4326 = patch
283
  dem_mean = float(dem.mean())
284
 
285
+ # v0.4.5+ — try the MI300X inference service first if configured.
286
+ # The droplet's /v1/terramind dispatch handles adapter='synthesis'
287
+ # via _terramind_synthesis_inference (DEM -> generative LULC). On
288
+ # the HF Space terratorch's torchvision binary doesn't load, so
289
+ # this is the only working path there.
290
+ try:
291
+ from app import inference as _inf
292
+ if _inf.remote_enabled():
293
+ # Local code uses (1, 1, H, W); send the same shape.
294
+ dem_remote = dem[None, None, :, :].astype("float32")
295
+ remote = _inf.terramind("synthesis", None, None, dem_remote,
296
+ timeout=timeout_s)
297
+ if remote.get("ok"):
298
+ elapsed = round(time.time() - t0, 2)
299
+ out = {
300
+ "ok": True,
301
+ "synthetic_modality": True,
302
+ "tim_chain": ["DEM", "LULC_synthetic"],
303
+ "diffusion_steps": remote.get("diffusion_steps",
304
+ DEFAULT_STEPS),
305
+ "diffusion_seed": DEFAULT_SEED,
306
+ "dem_mean_m": round(dem_mean, 2),
307
+ "class_fractions": remote.get("class_fractions") or {},
308
+ "dominant_class": remote.get("dominant_class") or "unknown",
309
+ "dominant_pct": remote.get("dominant_pct") or 0.0,
310
+ "n_classes_observed": remote.get("n_classes_observed") or 0,
311
+ "chip_shape": remote.get("shape") or [],
312
+ "bounds_4326": list(bounds_4326),
313
+ "polygons_geojson": None,
314
+ "label_schema": remote.get("label_schema") or "",
315
+ "compute": f"remote · {remote.get('device', 'gpu')}",
316
+ "elapsed_s": elapsed,
317
+ }
318
+ return out
319
+ # remote returned non-ok — surface that signal directly
320
+ return {"ok": False,
321
+ "skipped": f"remote terramind synthesis non-ok: "
322
+ f"{remote.get('error') or remote.get('detail') or 'unknown'}",
323
+ "elapsed_s": round(time.time() - t0, 2)}
324
+ except _inf.RemoteUnreachable as e:
325
+ log.info("terramind_synthesis: remote unreachable (%s); local fallback", e)
326
+ except Exception as e:
327
+ log.exception("terramind_synthesis: remote call failed")
328
+ return {"ok": False,
329
+ "skipped": f"remote terramind synthesis error: "
330
+ f"{type(e).__name__}: {e}",
331
+ "elapsed_s": round(time.time() - t0, 2)}
332
+
333
+ # Local fallback — original path; only available where terratorch
334
+ # imports without the torchvision::nms RuntimeError.
335
+ if not _DEPS_OK:
336
+ return {"ok": False, "skipped": f"deps unavailable: {_DEPS_MISSING}"}
337
  import torch
338
  random.seed(DEFAULT_SEED)
339
  torch.manual_seed(DEFAULT_SEED)
app/inference.py CHANGED
@@ -167,15 +167,18 @@ def prithvi_pluvial(s2_chip, *, scene_id: str | None = None,
167
  }, timeout=timeout)
168
 
169
 
170
- def terramind(adapter: str, s2l2a, s1rtc=None, dem=None, *,
171
  timeout: float | None = None) -> dict[str, Any]:
172
  """Remote forward through TerraMind-NYC-Adapters (LULC or Buildings)
173
- or the v1 base (synthetic). `adapter` is one of: lulc, buildings,
174
- synthesis. Each modality is a numpy array or None."""
 
 
175
  payload: dict[str, Any] = {"adapter": adapter}
176
- s2_np = _to_numpy(s2l2a)
177
- payload["s2"] = _serialize_array(s2_np)
178
- payload["s2_shape"] = list(s2_np.shape)
 
179
  if s1rtc is not None:
180
  s1_np = _to_numpy(s1rtc)
181
  payload["s1"] = _serialize_array(s1_np)
 
167
  }, timeout=timeout)
168
 
169
 
170
+ def terramind(adapter: str, s2l2a=None, s1rtc=None, dem=None, *,
171
  timeout: float | None = None) -> dict[str, Any]:
172
  """Remote forward through TerraMind-NYC-Adapters (LULC or Buildings)
173
+ or the v1 base generative path (synthesis). `adapter` is one of:
174
+ lulc, buildings, synthesis. Each modality is a numpy array, torch
175
+ tensor, or None — `synthesis` only needs DEM; the LoRA adapters
176
+ need at minimum S2L2A."""
177
  payload: dict[str, Any] = {"adapter": adapter}
178
+ if s2l2a is not None:
179
+ s2_np = _to_numpy(s2l2a)
180
+ payload["s2"] = _serialize_array(s2_np)
181
+ payload["s2_shape"] = list(s2_np.shape)
182
  if s1rtc is not None:
183
  s1_np = _to_numpy(s1rtc)
184
  payload["s1"] = _serialize_array(s1_np)
services/riprap-models/main.py CHANGED
@@ -222,10 +222,54 @@ _TERRAMIND_SPECS = {
222
  "labels": ["Trees", "Cropland", "Built", "Bare", "Water"]},
223
  "buildings": {"subdir": "buildings_nyc", "num_classes": 2,
224
  "labels": ["Background", "Building"]},
 
 
 
 
 
 
 
225
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
 
228
  def _load_terramind(adapter: str):
 
 
229
  key = f"terramind_{adapter}"
230
  if key in _INSTANCES:
231
  return _INSTANCES[key]
@@ -291,8 +335,10 @@ def _load_terramind(adapter: str):
291
 
292
  class TerramindIn(BaseModel):
293
  adapter: str # "lulc" | "buildings" | "synthesis"
294
- s2: str
295
- s2_shape: list[int]
 
 
296
  s1: str | None = None
297
  s1_shape: list[int] | None = None
298
  dem: str | None = None
@@ -319,14 +365,84 @@ def _build_chip_tensor(np_arr, n_timesteps: int = 4):
319
  raise ValueError(f"unexpected chip shape {tuple(t.shape)}")
320
 
321
 
322
- def _terramind_inference(payload: TerramindIn) -> dict[str, Any]:
 
 
 
 
 
323
  t0 = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  if payload.adapter not in _TERRAMIND_SPECS:
325
  raise HTTPException(status_code=400,
326
  detail=f"unknown adapter {payload.adapter!r}")
 
 
 
327
  task = _load_terramind(payload.adapter)
328
  spec = _TERRAMIND_SPECS[payload.adapter]
329
 
 
 
 
 
330
  s2 = _decode_array(payload.s2, payload.s2_shape)
331
  chips = {"S2L2A": _to_device(_build_chip_tensor(s2))}
332
  if payload.s1 and payload.s1_shape:
 
222
  "labels": ["Trees", "Cropland", "Built", "Bare", "Water"]},
223
  "buildings": {"subdir": "buildings_nyc", "num_classes": 2,
224
  "labels": ["Background", "Building"]},
225
+ # Synthesis is the IBM/NASA base TerraMind generative path
226
+ # (DEM -> LULC), not a NYC fine-tune. Listed here so the same
227
+ # /v1/terramind dispatch handles it.
228
+ "synthesis": {"subdir": None, "num_classes": None,
229
+ "labels": ["Water", "Trees", "Grass", "Flooded vegetation",
230
+ "Crops", "Scrub/Shrub", "Built", "Bare ground",
231
+ "Snow/Ice", "Clouds"]},
232
  }
233
+ _TERRAMIND_SYNTH_TIMESTEPS = int(os.environ.get(
234
+ "RIPRAP_TERRAMIND_SYNTH_TIMESTEPS", "10"))
235
+
236
+
237
+ def _load_terramind_synthesis():
238
+ """Load the IBM/NASA base TerraMind v1 generative path
239
+ (DEM -> LULC) once. Different machinery from the LoRA adapters:
240
+ pulled via terratorch's FULL_MODEL_REGISTRY rather than
241
+ SemanticSegmentationTask + LoRA injection."""
242
+ key = "terramind_synthesis"
243
+ if key in _INSTANCES:
244
+ return _INSTANCES[key]
245
+ with _LOCKS.get(key, _LOCKS.get("terramind_lulc")):
246
+ if key in _INSTANCES:
247
+ return _INSTANCES[key]
248
+ log.info("terramind/synthesis: cold load (v1 base generate)")
249
+ import terratorch.models.backbones.terramind.model.terramind_register # noqa
250
+ from terratorch.registry import FULL_MODEL_REGISTRY
251
+ m = FULL_MODEL_REGISTRY.build(
252
+ "terratorch_terramind_v1_base_generate",
253
+ modalities=["DEM"],
254
+ output_modalities=["LULC"],
255
+ pretrained=True,
256
+ timesteps=_TERRAMIND_SYNTH_TIMESTEPS,
257
+ )
258
+ try:
259
+ import torch
260
+ if _DEVICE == "cuda" and torch.cuda.is_available():
261
+ m = m.to("cuda")
262
+ except Exception:
263
+ log.exception("terramind/synthesis: cuda move failed")
264
+ m.eval()
265
+ _INSTANCES[key] = m
266
+ log.info("terramind/synthesis: ready")
267
+ return m
268
 
269
 
270
  def _load_terramind(adapter: str):
271
+ if adapter == "synthesis":
272
+ return _load_terramind_synthesis()
273
  key = f"terramind_{adapter}"
274
  if key in _INSTANCES:
275
  return _INSTANCES[key]
 
335
 
336
  class TerramindIn(BaseModel):
337
  adapter: str # "lulc" | "buildings" | "synthesis"
338
+ # All modality fields optional — `synthesis` adapter only needs DEM,
339
+ # while lulc / buildings need at minimum S2L2A.
340
+ s2: str | None = None
341
+ s2_shape: list[int] | None = None
342
  s1: str | None = None
343
  s1_shape: list[int] | None = None
344
  dem: str | None = None
 
365
  raise ValueError(f"unexpected chip shape {tuple(t.shape)}")
366
 
367
 
368
+ def _terramind_synthesis_inference(payload: TerramindIn) -> dict[str, Any]:
369
+ """DEM -> LULC generative path. Different machinery from the LoRA
370
+ adapters: model is the v1 base generate stack pulled from
371
+ terratorch's FULL_MODEL_REGISTRY, takes a single 4-D (B, 1, H, W)
372
+ DEM tensor, and emits a class-logit raster keyed by the ESRI
373
+ 2020 LULC tokenizer codebook."""
374
  t0 = time.time()
375
+ if not payload.dem or not payload.dem_shape:
376
+ raise HTTPException(status_code=400,
377
+ detail="synthesis requires `dem` + `dem_shape`")
378
+ model = _load_terramind_synthesis()
379
+ dem_np = _decode_array(payload.dem, payload.dem_shape)
380
+
381
+ import numpy as np
382
+ import torch
383
+ dem_t = torch.from_numpy(dem_np).float()
384
+ # Accept (H, W), (1, H, W), or (1, 1, H, W) — the local code builds
385
+ # (1, 1, H, W) so that's the most common.
386
+ while dem_t.ndim < 4:
387
+ dem_t = dem_t.unsqueeze(0)
388
+ dem_t = _to_device(dem_t)
389
+
390
+ spec = _TERRAMIND_SPECS["synthesis"]
391
+ with torch.no_grad():
392
+ out = model({"DEM": dem_t},
393
+ timesteps=_TERRAMIND_SYNTH_TIMESTEPS,
394
+ verbose=False)
395
+ lulc = out["LULC"]
396
+ if hasattr(lulc, "detach"):
397
+ lulc = lulc.detach().cpu().numpy()
398
+ if lulc.ndim == 4:
399
+ lulc = lulc[0] # (n_classes, H, W)
400
+ class_idx = lulc.argmax(axis=0) # (H, W) per-pixel class
401
+ unique, counts = np.unique(class_idx, return_counts=True)
402
+ total = float(class_idx.size) or 1.0
403
+ fractions: dict[str, float] = {}
404
+ for u, c in zip(unique, counts):
405
+ u = int(u)
406
+ label = spec["labels"][u] if 0 <= u < len(spec["labels"]) else f"class_{u}"
407
+ fractions[label] = round(100.0 * c / total, 2)
408
+ ordered = dict(sorted(fractions.items(),
409
+ key=lambda kv: kv[1], reverse=True))
410
+ dominant_class = next(iter(ordered)) if ordered else "unknown"
411
+ dominant_pct = ordered.get(dominant_class, 0.0)
412
+ return {
413
+ "ok": True,
414
+ "adapter": "synthesis",
415
+ "elapsed_s": round(time.time() - t0, 2),
416
+ "device": _DEVICE,
417
+ "synthetic_modality": True,
418
+ "tim_chain": ["DEM", "LULC_synthetic"],
419
+ "diffusion_steps": _TERRAMIND_SYNTH_TIMESTEPS,
420
+ "class_fractions": ordered,
421
+ "dominant_class": dominant_class,
422
+ "dominant_pct": dominant_pct,
423
+ "n_classes_observed": len(ordered),
424
+ "shape": list(lulc.shape),
425
+ "n_pixels": int(class_idx.size),
426
+ "label_schema": "ESRI 2020-2022 Land Cover (tentative — TerraMind "
427
+ "tokenizer source confirms ESRI but not exact "
428
+ "label-to-index mapping)",
429
+ }
430
+
431
+
432
+ def _terramind_inference(payload: TerramindIn) -> dict[str, Any]:
433
  if payload.adapter not in _TERRAMIND_SPECS:
434
  raise HTTPException(status_code=400,
435
  detail=f"unknown adapter {payload.adapter!r}")
436
+ if payload.adapter == "synthesis":
437
+ return _terramind_synthesis_inference(payload)
438
+ t0 = time.time()
439
  task = _load_terramind(payload.adapter)
440
  spec = _TERRAMIND_SPECS[payload.adapter]
441
 
442
+ if not payload.s2 or not payload.s2_shape:
443
+ raise HTTPException(status_code=400,
444
+ detail=f"adapter {payload.adapter!r} requires "
445
+ f"`s2` + `s2_shape`")
446
  s2 = _decode_array(payload.s2, payload.s2_shape)
447
  chips = {"S2L2A": _to_device(_build_chip_tensor(s2))}
448
  if payload.s1 and payload.s1_shape: