fix(synthesis): handle 3-D DEM from _read_dem_patch (it returns (1,H,W))
Browse filesThe 400 from /v1/terramind on synthesis came from the wire shape being
(1,1,1,224,224) 5-D instead of (1,1,224,224) 4-D. Root cause: HF's
_read_dem_patch returns a 3-D (1, H, W) array (it interpolates via
torch.functional.interpolate then .squeeze(0).numpy()), and the
synthesis remote-call site assumed it was 2-D and added two leading
dims via dem[None, None, :, :].
Branch on dem_arr.ndim and add exactly the dim(s) we need to reach
4-D (B, C, H, W). Also leave the debug log on the droplet so any
future dim-shape mismatch surfaces in container logs.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
app/context/terramind_synthesis.py
CHANGED
|
@@ -294,7 +294,22 @@ def fetch(lat: float, lon: float, timeout_s: float = 60.0) -> dict[str, Any]:
|
|
| 294 |
# layer unpacks `B, C, H, W = x.shape` (verified against
|
| 295 |
# terratorch_terramind_v1_base_generate). DEM has C=1, so
|
| 296 |
# the on-the-wire shape is (1, 1, H, W) 4-D.
|
| 297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
remote = _inf.terramind("synthesis", None, None, dem_remote,
|
| 299 |
timeout=timeout_s)
|
| 300 |
if remote.get("ok"):
|
|
|
|
| 294 |
# layer unpacks `B, C, H, W = x.shape` (verified against
|
| 295 |
# terratorch_terramind_v1_base_generate). DEM has C=1, so
|
| 296 |
# the on-the-wire shape is (1, 1, H, W) 4-D.
|
| 297 |
+
# `_read_dem_patch` returns a 3-D (1, H, W) array (it
|
| 298 |
+
# interpolates to CHIP_PX×CHIP_PX through a 4-D
|
| 299 |
+
# torch.functional.interpolate then squeezes the batch),
|
| 300 |
+
# so we add only the batch dim — not two.
|
| 301 |
+
import numpy as _np_local
|
| 302 |
+
dem_arr = _np_local.asarray(dem, dtype="float32")
|
| 303 |
+
if dem_arr.ndim == 2: # (H, W)
|
| 304 |
+
dem_remote = dem_arr[None, None, :, :]
|
| 305 |
+
elif dem_arr.ndim == 3: # (1, H, W)
|
| 306 |
+
dem_remote = dem_arr[None, :, :, :]
|
| 307 |
+
elif dem_arr.ndim == 4: # already (1, 1, H, W)
|
| 308 |
+
dem_remote = dem_arr
|
| 309 |
+
else:
|
| 310 |
+
raise ValueError(
|
| 311 |
+
f"unexpected DEM shape {dem_arr.shape}; "
|
| 312 |
+
"expected 2/3/4-D")
|
| 313 |
remote = _inf.terramind("synthesis", None, None, dem_remote,
|
| 314 |
timeout=timeout_s)
|
| 315 |
if remote.get("ok"):
|
services/riprap-models/main.py
CHANGED
|
@@ -379,7 +379,13 @@ def _terramind_synthesis_inference(payload: TerramindIn) -> dict[str, Any]:
|
|
| 379 |
DEM tensor, and emits a class-logit raster keyed by the ESRI
|
| 380 |
2020 LULC tokenizer codebook."""
|
| 381 |
t0 = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
if not payload.dem or not payload.dem_shape:
|
|
|
|
|
|
|
| 383 |
raise HTTPException(status_code=400,
|
| 384 |
detail="synthesis requires `dem` + `dem_shape`")
|
| 385 |
model = _load_terramind_synthesis()
|
|
|
|
| 379 |
DEM tensor, and emits a class-logit raster keyed by the ESRI
|
| 380 |
2020 LULC tokenizer codebook."""
|
| 381 |
t0 = time.time()
|
| 382 |
+
log.info("terramind/synthesis: payload dem=%s dem_shape=%s s2=%s",
|
| 383 |
+
"set" if payload.dem else "None",
|
| 384 |
+
payload.dem_shape,
|
| 385 |
+
"set" if payload.s2 else "None")
|
| 386 |
if not payload.dem or not payload.dem_shape:
|
| 387 |
+
log.warning("terramind/synthesis: missing dem (dem=%s, shape=%s)",
|
| 388 |
+
bool(payload.dem), payload.dem_shape)
|
| 389 |
raise HTTPException(status_code=400,
|
| 390 |
detail="synthesis requires `dem` + `dem_shape`")
|
| 391 |
model = _load_terramind_synthesis()
|