fix(synthesis): DEM shape is (1, 1, H, W) per encoder probe
Browse filesEmpirically verified against terratorch_terramind_v1_base_generate by
calling model({'DEM': torch.zeros(...)}) with each shape and reading
the encoder error:
(1, 224, 224) -> not enough values to unpack (got 3)
(1, 1, 224, 224) -> OK, LULC out (1, 10, 224, 224)
(1, 1, 1, 224, 224) -> too many values to unpack (got 5)
The encoder unpacks 'B, C, H, W = x.shape' so 4-D is required; DEM
has 1 channel, hence (1, 1, H, W). Local code's comment said this
but the implementation was off by one (.unsqueeze(0) only adds one
dim). Fix the local path too so it doesn't break if anyone runs
synthesis off the AMD path.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
app/context/terramind_synthesis.py
CHANGED
|
@@ -290,12 +290,11 @@ def fetch(lat: float, lon: float, timeout_s: float = 60.0) -> dict[str, Any]:
|
|
| 290 |
try:
|
| 291 |
from app import inference as _inf
|
| 292 |
if _inf.remote_enabled():
|
| 293 |
-
#
|
| 294 |
-
#
|
| 295 |
-
#
|
| 296 |
-
#
|
| 297 |
-
|
| 298 |
-
dem_remote = dem[None, :, :].astype("float32")
|
| 299 |
remote = _inf.terramind("synthesis", None, None, dem_remote,
|
| 300 |
timeout=timeout_s)
|
| 301 |
if remote.get("ok"):
|
|
@@ -343,7 +342,9 @@ def fetch(lat: float, lon: float, timeout_s: float = 60.0) -> dict[str, Any]:
|
|
| 343 |
torch.manual_seed(DEFAULT_SEED)
|
| 344 |
|
| 345 |
model = _ensure_model()
|
| 346 |
-
|
|
|
|
|
|
|
| 347 |
if time.time() - t0 > timeout_s:
|
| 348 |
return {"ok": False, "skipped": "terramind exceeded budget"}
|
| 349 |
|
|
|
|
| 290 |
try:
|
| 291 |
from app import inference as _inf
|
| 292 |
if _inf.remote_enabled():
|
| 293 |
+
# The terramind v1 base generative encoder embedding
|
| 294 |
+
# layer unpacks `B, C, H, W = x.shape` (verified against
|
| 295 |
+
# terratorch_terramind_v1_base_generate). DEM has C=1, so
|
| 296 |
+
# the on-the-wire shape is (1, 1, H, W) 4-D.
|
| 297 |
+
dem_remote = dem[None, None, :, :].astype("float32")
|
|
|
|
| 298 |
remote = _inf.terramind("synthesis", None, None, dem_remote,
|
| 299 |
timeout=timeout_s)
|
| 300 |
if remote.get("ok"):
|
|
|
|
| 342 |
torch.manual_seed(DEFAULT_SEED)
|
| 343 |
|
| 344 |
model = _ensure_model()
|
| 345 |
+
# `dem` is 2-D (H, W) from `_read_dem_patch.src.read(1, ...)`. The
|
| 346 |
+
# terramind v1 base generative encoder wants (B=1, C=1, H, W) 4-D.
|
| 347 |
+
dem_t = torch.from_numpy(dem).unsqueeze(0).unsqueeze(0).float()
|
| 348 |
if time.time() - t0 > timeout_s:
|
| 349 |
return {"ok": False, "skipped": "terramind exceeded budget"}
|
| 350 |
|
services/riprap-models/main.py
CHANGED
|
@@ -381,19 +381,17 @@ def _terramind_synthesis_inference(payload: TerramindIn) -> dict[str, Any]:
|
|
| 381 |
import numpy as np
|
| 382 |
import torch
|
| 383 |
dem_t = torch.from_numpy(dem_np).float()
|
| 384 |
-
#
|
| 385 |
-
#
|
| 386 |
-
#
|
| 387 |
-
# internally. Anything more triggers `B, C, H, W = x.shape` to
|
| 388 |
-
# unpack 5-D and fail in the embedding layer.
|
| 389 |
if dem_t.ndim == 2:
|
| 390 |
-
dem_t = dem_t.unsqueeze(0)
|
| 391 |
-
elif dem_t.ndim ==
|
| 392 |
-
dem_t = dem_t.
|
| 393 |
-
elif dem_t.ndim !=
|
| 394 |
raise HTTPException(status_code=400,
|
| 395 |
detail=f"unexpected DEM shape {tuple(dem_t.shape)}; "
|
| 396 |
-
f"expected (
|
| 397 |
dem_t = _to_device(dem_t)
|
| 398 |
|
| 399 |
spec = _TERRAMIND_SPECS["synthesis"]
|
|
|
|
| 381 |
import numpy as np
|
| 382 |
import torch
|
| 383 |
dem_t = torch.from_numpy(dem_np).float()
|
| 384 |
+
# The v1 base generative encoder unpacks `B, C, H, W = x.shape` —
|
| 385 |
+
# 4-D required. DEM has C=1, so canonical shape is (1, 1, H, W).
|
| 386 |
+
# Verified empirically against terratorch_terramind_v1_base_generate.
|
|
|
|
|
|
|
| 387 |
if dem_t.ndim == 2:
|
| 388 |
+
dem_t = dem_t.unsqueeze(0).unsqueeze(0) # (H, W) -> (1, 1, H, W)
|
| 389 |
+
elif dem_t.ndim == 3:
|
| 390 |
+
dem_t = dem_t.unsqueeze(0) # (1, H, W) -> (1, 1, H, W)
|
| 391 |
+
elif dem_t.ndim != 4:
|
| 392 |
raise HTTPException(status_code=400,
|
| 393 |
detail=f"unexpected DEM shape {tuple(dem_t.shape)}; "
|
| 394 |
+
f"expected 4-D (B, C, H, W)")
|
| 395 |
dem_t = _to_device(dem_t)
|
| 396 |
|
| 397 |
spec = _TERRAMIND_SPECS["synthesis"]
|