seriffic Claude Opus 4.7 (1M context) commited on
Commit
09e6848
·
1 Parent(s): 2cbe57a

fix(synthesis): handle 3-D DEM from _read_dem_patch (it returns (1,H,W))

Browse files

The 400 from /v1/terramind on synthesis came from the wire shape being
(1,1,1,224,224) 5-D instead of (1,1,224,224) 4-D. Root cause: HF's
_read_dem_patch returns a 3-D (1, H, W) array (it interpolates via
torch.functional.interpolate then .squeeze(0).numpy()), and the
synthesis remote-call site assumed it was 2-D and added two leading
dims via dem[None, None, :, :].

Branch on dem_arr.ndim and add exactly the dim(s) we need to reach
4-D (B, C, H, W). Also leave the debug log on the droplet so any
future dim-shape mismatch surfaces in container logs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

app/context/terramind_synthesis.py CHANGED
@@ -294,7 +294,22 @@ def fetch(lat: float, lon: float, timeout_s: float = 60.0) -> dict[str, Any]:
294
  # layer unpacks `B, C, H, W = x.shape` (verified against
295
  # terratorch_terramind_v1_base_generate). DEM has C=1, so
296
  # the on-the-wire shape is (1, 1, H, W) 4-D.
297
- dem_remote = dem[None, None, :, :].astype("float32")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  remote = _inf.terramind("synthesis", None, None, dem_remote,
299
  timeout=timeout_s)
300
  if remote.get("ok"):
 
294
  # layer unpacks `B, C, H, W = x.shape` (verified against
295
  # terratorch_terramind_v1_base_generate). DEM has C=1, so
296
  # the on-the-wire shape is (1, 1, H, W) 4-D.
297
+ # `_read_dem_patch` returns a 3-D (1, H, W) array (it
298
+ # interpolates to CHIP_PX×CHIP_PX through a 4-D
299
+ # torch.functional.interpolate then squeezes the batch),
300
+ # so we add only the batch dim — not two.
301
+ import numpy as _np_local
302
+ dem_arr = _np_local.asarray(dem, dtype="float32")
303
+ if dem_arr.ndim == 2: # (H, W)
304
+ dem_remote = dem_arr[None, None, :, :]
305
+ elif dem_arr.ndim == 3: # (1, H, W)
306
+ dem_remote = dem_arr[None, :, :, :]
307
+ elif dem_arr.ndim == 4: # already (1, 1, H, W)
308
+ dem_remote = dem_arr
309
+ else:
310
+ raise ValueError(
311
+ f"unexpected DEM shape {dem_arr.shape}; "
312
+ "expected 2/3/4-D")
313
  remote = _inf.terramind("synthesis", None, None, dem_remote,
314
  timeout=timeout_s)
315
  if remote.get("ok"):
services/riprap-models/main.py CHANGED
@@ -379,7 +379,13 @@ def _terramind_synthesis_inference(payload: TerramindIn) -> dict[str, Any]:
379
  DEM tensor, and emits a class-logit raster keyed by the ESRI
380
  2020 LULC tokenizer codebook."""
381
  t0 = time.time()
 
 
 
 
382
  if not payload.dem or not payload.dem_shape:
 
 
383
  raise HTTPException(status_code=400,
384
  detail="synthesis requires `dem` + `dem_shape`")
385
  model = _load_terramind_synthesis()
 
379
  DEM tensor, and emits a class-logit raster keyed by the ESRI
380
  2020 LULC tokenizer codebook."""
381
  t0 = time.time()
382
+ log.info("terramind/synthesis: payload dem=%s dem_shape=%s s2=%s",
383
+ "set" if payload.dem else "None",
384
+ payload.dem_shape,
385
+ "set" if payload.s2 else "None")
386
  if not payload.dem or not payload.dem_shape:
387
+ log.warning("terramind/synthesis: missing dem (dem=%s, shape=%s)",
388
+ bool(payload.dem), payload.dem_shape)
389
  raise HTTPException(status_code=400,
390
  detail="synthesis requires `dem` + `dem_shape`")
391
  model = _load_terramind_synthesis()