seriffic Claude Opus 4.7 (1M context) committed
Commit 0d831ce · 1 Parent(s): fee1c30

fix(eo): unblock TerraMind LoRA + Prithvi v2 inference on the L4 Space


The diagnostic wrapper from the previous commit surfaced two real
upstream bugs that the legacy "deps unavailable" / generic-500
error masking had been hiding:

1. TerraMind LULC + Buildings: "Expected size 12 but got size 2
for tensor nu" inside terratorch's tiled_inference.

Root cause: terratorch.tasks.tiled_inference doesn't handle
the 5-D (B, C, T, H, W) modality tensor shape that
backbone_use_temporal=True / backbone_temporal_n_timestamps=4
produces, so it slices/concats incorrectly when fusing the
per-modality patches and trips on the 12 (S2 bands) vs 2 (S1
bands) channel mismatch.
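The size-mismatch error above is torch's standard concat/stack shape rule tripping on the extra temporal axis. A minimal sketch of that rule, with a hypothetical `cat_shapes` helper and plain shape tuples standing in for real tensors:

```python
def cat_shapes(shapes, axis):
    """Mimic torch.cat's shape rule: every dim except `axis`
    must agree across inputs, or the op raises."""
    ref = list(shapes[0])
    for s in shapes[1:]:
        for d, (a, b) in enumerate(zip(ref, s)):
            if d != axis and a != b:
                raise ValueError(
                    f"Expected size {a} but got size {b} for dim {d}")
        ref[axis] += s[axis]
    return tuple(ref)

# 4-D (B, C, H, W) chips fused along width: channel dims agree, fine.
print(cat_shapes([(1, 12, 224, 128), (1, 12, 224, 96)], axis=3))

# 5-D (B, C, T, H, W) chips fused along the axis the 4-D code expects:
# S2's 12 bands vs S1's 2 bands now sit in a checked dim and collide.
try:
    cat_shapes([(1, 12, 4, 224, 224), (1, 2, 4, 224, 224)], axis=3)
except ValueError as e:
    print(e)  # Expected size 12 but got size 2 for dim 1
```

This is only an illustration of the shape arithmetic, not terratorch's actual fusion code.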

Fix: the canonical chip from app/context/eo_chip_cache.py is
already exactly 224×224, the model's native input resolution,
so tiling is unnecessary at that size. Mirror the training-time
inference at experiments/18_terramind_nyc_lora/shared/
inference_ensemble.py:155 (calling `task.model(x)` directly on
the modality dict). tiled_inference is kept as a fallback for
chips larger than 224×224.
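The resulting dispatch can be sketched in a few lines; everything here is a hypothetical stand-in (shape tuples for chips, toy callables for `task.model` and terratorch's `tiled_inference`), assuming the S2L2A chip's trailing dims are (H, W):

```python
NATIVE = 224  # model's native input resolution

def run_inference(chips, forward, tiled):
    """Call `forward` directly when the chip is already at the model's
    native 224x224 resolution; otherwise fall back to tiling."""
    h, w = chips["S2L2A"][-2], chips["S2L2A"][-1]  # (..., H, W)
    if h == NATIVE and w == NATIVE:
        return forward(chips)        # direct: no tiled fusion involved
    return tiled(forward, chips)     # larger scenes still get tiled

# Toy stand-ins that just record which path ran.
calls = []
forward = lambda c: calls.append("direct") or "logits"
tiled = lambda f, c: calls.append("tiled") or "logits"

run_inference({"S2L2A": (1, 12, 4, 224, 224)}, forward, tiled)
run_inference({"S2L2A": (1, 12, 4, 512, 512)}, forward, tiled)
print(calls)  # ['direct', 'tiled']
```

The point of the shape check is that only the oversized path ever touches the code with the 5-D handling bug.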

2. Prithvi-NYC-Pluvial v2: "AttributeError: 'list' object has no
attribute 'view'" on first inference.

Root cause: when patching the v2 datamodule's missing
`test_transform`, we replaced its kornia AugmentationSequential
with a Normalize built from .means.view(-1).tolist() and
.stds.view(-1).tolist(). kornia ≥ 0.7 stores those values
as-is and calls .view() on them at augment-apply time;
passing a Python list crashes with the AttributeError above.

Fix: pass the underlying torch.Tensor directly via
.view(-1).detach().clone() — same numeric data, but kornia
gets the type it expects.

Same patch applied to the local fallback at
app/flood_layers/prithvi_live.py for parity (the local path is
unreachable on the cpu-basic UI Space for unrelated reasons but
will be live on any deployment with a working CUDA torch).
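The kornia failure mode above reduces to store-now, `.view()`-later semantics. A minimal stand-in, with no kornia dependency (`FakeNormalize` and `FakeTensor` are hypothetical classes for illustration only):

```python
class FakeTensor:
    """Tiny tensor stand-in exposing just the .view() surface we need."""
    def __init__(self, vals): self.vals = list(vals)
    def view(self, *_): return self
    def __iter__(self): return iter(self.vals)

class FakeNormalize:
    """Stand-in for the augment: stores mean/std as-is and only
    calls .view() on them when the augment is applied."""
    def __init__(self, mean, std):
        self.mean, self.std = mean, std   # no conversion at construction
    def apply(self, x):
        m = self.mean.view(-1)            # crashes here if mean is a list
        s = self.std.view(-1)
        return [(v - mu) / sd for v, mu, sd in zip(x, m, s)]

# Plain list (the .tolist() path): the AttributeError fires at
# apply time, which is why it only showed up on first inference.
aug = FakeNormalize([0.5, 0.5], [0.25, 0.25])
try:
    aug.apply([1.0, 1.0])
except AttributeError as e:
    print(e)  # 'list' object has no attribute 'view'

# Tensor-like object (the .detach().clone() path): works.
aug = FakeNormalize(FakeTensor([0.5, 0.5]), FakeTensor([0.25, 0.25]))
print(aug.apply([1.0, 1.0]))  # [2.0, 2.0]
```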

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

app/flood_layers/prithvi_live.py CHANGED
@@ -209,9 +209,14 @@ def _ensure_model():
     from albumentations.pytorch import ToTensorV2
     m.datamodule.test_transform = A.Compose([ToTensorV2()])
     _old = m.datamodule.aug
+    # Pass torch.Tensor (not list via .tolist()).
+    # kornia 0.7+ stores values as-is and calls
+    # .view() on them at apply time; passing a
+    # Python list crashes with `AttributeError:
+    # 'list' object has no attribute 'view'`.
     m.datamodule.aug = _Ka.AugmentationSequential(
-        _Ka.Normalize(_old.means.view(-1).tolist(),
-                      _old.stds.view(-1).tolist()),
+        _Ka.Normalize(_old.means.view(-1).detach().clone(),
+                      _old.stds.view(-1).detach().clone()),
         data_keys=None)
     log.info("prithvi_live: patched v2 datamodule transforms "
              "for IBM inference.py compat")
services/riprap-models/main.py CHANGED
@@ -148,9 +148,14 @@ def _load_prithvi():
     from albumentations.pytorch import ToTensorV2
     m.datamodule.test_transform = A.Compose([ToTensorV2()])
     _old = m.datamodule.aug
+    # Pass torch.Tensor (not Python list via .tolist()) —
+    # kornia 0.7+ stores the values as-is and calls .view()
+    # on them at apply time. With a list that fails with
+    # `AttributeError: 'list' object has no attribute 'view'`.
+    # Cloning detaches from the source datamodule's params.
     m.datamodule.aug = _Ka.AugmentationSequential(
-        _Ka.Normalize(_old.means.view(-1).tolist(),
-                      _old.stds.view(-1).tolist()),
+        _Ka.Normalize(_old.means.view(-1).detach().clone(),
+                      _old.stds.view(-1).detach().clone()),
         data_keys=None)
     log.info("prithvi: patched v2 datamodule transforms "
              "for IBM inference.py compat")
@@ -478,17 +483,35 @@ def _terramind_inference(payload: TerramindIn) -> dict[str, Any]:
     chips["DEM"] = _to_device(_build_chip_tensor(dem))

     import torch
-    from terratorch.tasks.tiled_inference import tiled_inference

-    def _forward(x, **_extra):
+    def _forward(x):
         out = task.model(x)
         return out.output if hasattr(out, "output") else out
+
+    # Call the model directly — same shape contract as the
+    # training-time inference at
+    # experiments/18_terramind_nyc_lora/shared/inference_ensemble.py:
+    # the canonical chip is already the model's native 224×224 input
+    # in (B, C, T, H, W) form, so terratorch's `tiled_inference` is
+    # unnecessary and was the cause of the "Expected size 12 but got
+    # size 2" 5-D handling regression we hit on the L4 deploy.
+    # Tile only when the chip is bigger than the model resolution.
+    s2_t = chips["S2L2A"]
+    h_chip, w_chip = int(s2_t.shape[-2]), int(s2_t.shape[-1])
     with torch.no_grad():
-        logits = tiled_inference(
-            _forward, chips, out_channels=spec["num_classes"],
-            h_crop=224, w_crop=224, h_stride=128, w_stride=128,
-            average_patches=True, blend_overlaps=True, padding="reflect",
-        )
+        if h_chip == 224 and w_chip == 224:
+            logits = _forward(chips)
+        else:
+            from terratorch.tasks.tiled_inference import tiled_inference
+
+            def _forward_tile(x, **_extra):
+                return _forward(x)
+
+            logits = tiled_inference(
+                _forward_tile, chips, out_channels=spec["num_classes"],
+                h_crop=224, w_crop=224, h_stride=128, w_stride=128,
+                average_patches=True, blend_overlaps=True, padding="reflect",
+            )
     pred = logits.argmax(dim=1).squeeze(0).cpu().numpy().astype("uint8")
     n = max(int(pred.size), 1)
     fractions = {