seriffic Claude Opus 4.7 (1M context) commited on
Commit
ac8fbc5
·
1 Parent(s): 7ad8df4

fix(synthesis): DEM shape is (1, 1, H, W) per encoder probe

Browse files

Empirically verified against terratorch_terramind_v1_base_generate by
calling model({'DEM': torch.zeros(...)}) with each shape and reading
the encoder error:

(1, 224, 224) -> not enough values to unpack (got 3)
(1, 1, 224, 224) -> OK, LULC out (1, 10, 224, 224)
(1, 1, 1, 224, 224) -> too many values to unpack (got 5)

The encoder unpacks 'B, C, H, W = x.shape' so 4-D is required; DEM
has 1 channel, hence (1, 1, H, W). Local code's comment said this
but the implementation was off by one (.unsqueeze(0) only adds one
dim). Fix the local path too so it doesn't break if anyone runs
synthesis off the AMD path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

app/context/terramind_synthesis.py CHANGED
@@ -290,12 +290,11 @@ def fetch(lat: float, lon: float, timeout_s: float = 60.0) -> dict[str, Any]:
290
  try:
291
  from app import inference as _inf
292
  if _inf.remote_enabled():
293
- # Local code does `torch.from_numpy(dem).unsqueeze(0)`
294
- # i.e. 2-D (H, W) → 3-D (1, H, W). The terramind v1 base
295
- # generative encoder adds the batch dim internally; sending
296
- # an extra leading dim makes its embedding layer trip on
297
- # `B, C, H, W = x.shape` (5-D in, expects 4). Match local.
298
- dem_remote = dem[None, :, :].astype("float32")
299
  remote = _inf.terramind("synthesis", None, None, dem_remote,
300
  timeout=timeout_s)
301
  if remote.get("ok"):
@@ -343,7 +342,9 @@ def fetch(lat: float, lon: float, timeout_s: float = 60.0) -> dict[str, Any]:
343
  torch.manual_seed(DEFAULT_SEED)
344
 
345
  model = _ensure_model()
346
- dem_t = torch.from_numpy(dem).unsqueeze(0).float() # (1, 1, H, W)
 
 
347
  if time.time() - t0 > timeout_s:
348
  return {"ok": False, "skipped": "terramind exceeded budget"}
349
 
 
290
  try:
291
  from app import inference as _inf
292
  if _inf.remote_enabled():
293
+ # The terramind v1 base generative encoder embedding
294
+ # layer unpacks `B, C, H, W = x.shape` (verified against
295
+ # terratorch_terramind_v1_base_generate). DEM has C=1, so
296
+ # the on-the-wire shape is (1, 1, H, W) 4-D.
297
+ dem_remote = dem[None, None, :, :].astype("float32")
 
298
  remote = _inf.terramind("synthesis", None, None, dem_remote,
299
  timeout=timeout_s)
300
  if remote.get("ok"):
 
342
  torch.manual_seed(DEFAULT_SEED)
343
 
344
  model = _ensure_model()
345
+ # `dem` is 2-D (H, W) from `_read_dem_patch.src.read(1, ...)`. The
346
+ # terramind v1 base generative encoder wants (B=1, C=1, H, W) 4-D.
347
+ dem_t = torch.from_numpy(dem).unsqueeze(0).unsqueeze(0).float()
348
  if time.time() - t0 > timeout_s:
349
  return {"ok": False, "skipped": "terramind exceeded budget"}
350
 
services/riprap-models/main.py CHANGED
@@ -381,19 +381,17 @@ def _terramind_synthesis_inference(payload: TerramindIn) -> dict[str, Any]:
381
  import numpy as np
382
  import torch
383
  dem_t = torch.from_numpy(dem_np).float()
384
- # Match the local-inference shape contract from
385
- # app/context/terramind_synthesis.py:_ensure_model the v1 base
386
- # generative encoder wants 3-D (1, H, W) and adds the batch dim
387
- # internally. Anything more triggers `B, C, H, W = x.shape` to
388
- # unpack 5-D and fail in the embedding layer.
389
  if dem_t.ndim == 2:
390
- dem_t = dem_t.unsqueeze(0) # (H, W) -> (1, H, W)
391
- elif dem_t.ndim == 4 and dem_t.shape[0] == 1 and dem_t.shape[1] == 1:
392
- dem_t = dem_t.squeeze(0) # (1, 1, H, W) -> (1, H, W)
393
- elif dem_t.ndim != 3:
394
  raise HTTPException(status_code=400,
395
  detail=f"unexpected DEM shape {tuple(dem_t.shape)}; "
396
- f"expected (H, W) or (1, H, W)")
397
  dem_t = _to_device(dem_t)
398
 
399
  spec = _TERRAMIND_SPECS["synthesis"]
 
381
  import numpy as np
382
  import torch
383
  dem_t = torch.from_numpy(dem_np).float()
384
+ # The v1 base generative encoder unpacks `B, C, H, W = x.shape` —
385
+ # 4-D required. DEM has C=1, so canonical shape is (1, 1, H, W).
386
+ # Verified empirically against terratorch_terramind_v1_base_generate.
 
 
387
  if dem_t.ndim == 2:
388
+ dem_t = dem_t.unsqueeze(0).unsqueeze(0) # (H, W) -> (1, 1, H, W)
389
+ elif dem_t.ndim == 3:
390
+ dem_t = dem_t.unsqueeze(0) # (1, H, W) -> (1, 1, H, W)
391
+ elif dem_t.ndim != 4:
392
  raise HTTPException(status_code=400,
393
  detail=f"unexpected DEM shape {tuple(dem_t.shape)}; "
394
+ f"expected 4-D (B, C, H, W)")
395
  dem_t = _to_device(dem_t)
396
 
397
  spec = _TERRAMIND_SPECS["synthesis"]