fix(eo): unblock TerraMind LoRA + Prithvi v2 inference on the L4 Space
The diag wrapper from the previous commit surfaced two real
upstream bugs that the legacy "deps unavailable" / generic 500
masking had been hiding:
1. TerraMind LULC + Buildings: "Expected size 12 but got size 2
for tensor number 1" inside terratorch's tiled_inference.
Root cause: terratorch.tasks.tiled_inference doesn't handle
the 5-D (B, C, T, H, W) modality tensor shape that
backbone_use_temporal=True / backbone_temporal_n_timestamps=4
produces, so it slices/concats incorrectly when fusing the
per-modality patches and trips on the 12 (S2 bands) vs 2 (S1
bands) channel mismatch.
Fix: the canonical chip from app/context/eo_chip_cache.py is
already exactly 224×224 — the model's native input
resolution. Tiling is unnecessary at that size. Mirror the
training-time inference at experiments/18_terramind_nyc_lora/
shared/inference_ensemble.py:155 (`task.model(x)` direct on
the modality dict). tiled_inference is preserved as a
fallback for chips larger than 224×224.
2. Prithvi-NYC-Pluvial v2: "AttributeError: 'list' object has no
attribute 'view'" on first inference.
Root cause: when patching the v2 datamodule's missing
`test_transform`, we replaced its kornia AugmentationSequential
with a Normalize built from .means.view(-1).tolist() and
.stds.view(-1).tolist(). kornia ≥ 0.7 stores those values
as-is and calls .view() on them at augment-apply time;
passing a Python list crashes with the AttributeError above.
Fix: pass the underlying torch.Tensor directly via
.view(-1).detach().clone() — same numeric data, but kornia
gets the type it expects.
Same patch applied to the local fallback at
app/flood_layers/prithvi_live.py for parity (the local path is
unreachable on the cpu-basic UI Space for unrelated reasons but
will be live on any deployment with a working CUDA torch).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
@@ -209,9 +209,14 @@ def _ensure_model():
|
|
| 209 |
from albumentations.pytorch import ToTensorV2
|
| 210 |
m.datamodule.test_transform = A.Compose([ToTensorV2()])
|
| 211 |
_old = m.datamodule.aug
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
m.datamodule.aug = _Ka.AugmentationSequential(
|
| 213 |
-
_Ka.Normalize(_old.means.view(-1).tolist(),
|
| 214 |
-
_old.stds.view(-1).tolist()),
|
| 215 |
data_keys=None)
|
| 216 |
log.info("prithvi_live: patched v2 datamodule transforms "
|
| 217 |
"for IBM inference.py compat")
|
|
|
|
| 209 |
from albumentations.pytorch import ToTensorV2
|
| 210 |
m.datamodule.test_transform = A.Compose([ToTensorV2()])
|
| 211 |
_old = m.datamodule.aug
|
| 212 |
+
# Pass torch.Tensor (not list via .tolist()).
|
| 213 |
+
# kornia 0.7+ stores values as-is and calls
|
| 214 |
+
# .view() on them at apply time; passing a
|
| 215 |
+
# Python list crashes with `AttributeError:
|
| 216 |
+
# 'list' object has no attribute 'view'`.
|
| 217 |
m.datamodule.aug = _Ka.AugmentationSequential(
|
| 218 |
+
_Ka.Normalize(_old.means.view(-1).detach().clone(),
|
| 219 |
+
_old.stds.view(-1).detach().clone()),
|
| 220 |
data_keys=None)
|
| 221 |
log.info("prithvi_live: patched v2 datamodule transforms "
|
| 222 |
"for IBM inference.py compat")
|
|
@@ -148,9 +148,14 @@ def _load_prithvi():
|
|
| 148 |
from albumentations.pytorch import ToTensorV2
|
| 149 |
m.datamodule.test_transform = A.Compose([ToTensorV2()])
|
| 150 |
_old = m.datamodule.aug
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
m.datamodule.aug = _Ka.AugmentationSequential(
|
| 152 |
-
_Ka.Normalize(_old.means.view(-1).tolist(),
|
| 153 |
-
_old.stds.view(-1).tolist()),
|
| 154 |
data_keys=None)
|
| 155 |
log.info("prithvi: patched v2 datamodule transforms "
|
| 156 |
"for IBM inference.py compat")
|
|
@@ -478,17 +483,35 @@ def _terramind_inference(payload: TerramindIn) -> dict[str, Any]:
|
|
| 478 |
chips["DEM"] = _to_device(_build_chip_tensor(dem))
|
| 479 |
|
| 480 |
import torch
|
| 481 |
-
from terratorch.tasks.tiled_inference import tiled_inference
|
| 482 |
|
| 483 |
-
def _forward(x):
|
| 484 |
out = task.model(x)
|
| 485 |
return out.output if hasattr(out, "output") else out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
with torch.no_grad():
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
pred = logits.argmax(dim=1).squeeze(0).cpu().numpy().astype("uint8")
|
| 493 |
n = max(int(pred.size), 1)
|
| 494 |
fractions = {
|
|
|
|
| 148 |
from albumentations.pytorch import ToTensorV2
|
| 149 |
m.datamodule.test_transform = A.Compose([ToTensorV2()])
|
| 150 |
_old = m.datamodule.aug
|
| 151 |
+
# Pass torch.Tensor (not Python list via .tolist()) —
|
| 152 |
+
# kornia 0.7+ stores the values as-is and calls .view()
|
| 153 |
+
# on them at apply time. With a list that fails with
|
| 154 |
+
# `AttributeError: 'list' object has no attribute 'view'`.
|
| 155 |
+
# Cloning detaches from the source datamodule's params.
|
| 156 |
m.datamodule.aug = _Ka.AugmentationSequential(
|
| 157 |
+
_Ka.Normalize(_old.means.view(-1).detach().clone(),
|
| 158 |
+
_old.stds.view(-1).detach().clone()),
|
| 159 |
data_keys=None)
|
| 160 |
log.info("prithvi: patched v2 datamodule transforms "
|
| 161 |
"for IBM inference.py compat")
|
|
|
|
| 483 |
chips["DEM"] = _to_device(_build_chip_tensor(dem))
|
| 484 |
|
| 485 |
import torch
|
|
|
|
| 486 |
|
| 487 |
+
def _forward(x):
|
| 488 |
out = task.model(x)
|
| 489 |
return out.output if hasattr(out, "output") else out
|
| 490 |
+
|
| 491 |
+
# Call the model directly — same shape contract as the
|
| 492 |
+
# training-time inference at
|
| 493 |
+
# experiments/18_terramind_nyc_lora/shared/inference_ensemble.py:
|
| 494 |
+
# the canonical chip is already the model's native 224×224 input
|
| 495 |
+
# in (B, C, T, H, W) form, so terratorch's `tiled_inference` is
|
| 496 |
+
# unnecessary and was the cause of the "Expected size 12 but got
|
| 497 |
+
# size 2" 5-D handling regression we hit on the L4 deploy.
|
| 498 |
+
# Tile only when the chip is bigger than the model resolution.
|
| 499 |
+
s2_t = chips["S2L2A"]
|
| 500 |
+
h_chip, w_chip = int(s2_t.shape[-2]), int(s2_t.shape[-1])
|
| 501 |
with torch.no_grad():
|
| 502 |
+
if h_chip == 224 and w_chip == 224:
|
| 503 |
+
logits = _forward(chips)
|
| 504 |
+
else:
|
| 505 |
+
from terratorch.tasks.tiled_inference import tiled_inference
|
| 506 |
+
|
| 507 |
+
def _forward_tile(x, **_extra):
|
| 508 |
+
return _forward(x)
|
| 509 |
+
|
| 510 |
+
logits = tiled_inference(
|
| 511 |
+
_forward_tile, chips, out_channels=spec["num_classes"],
|
| 512 |
+
h_crop=224, w_crop=224, h_stride=128, w_stride=128,
|
| 513 |
+
average_patches=True, blend_overlaps=True, padding="reflect",
|
| 514 |
+
)
|
| 515 |
pred = logits.argmax(dim=1).squeeze(0).cpu().numpy().astype("uint8")
|
| 516 |
n = max(int(pred.size), 1)
|
| 517 |
fractions = {
|