seriffic Claude Sonnet 4.6 commited on
Commit
be27626
·
1 Parent(s): f7bf63f

fix(prithvi): patch v2 datamodule transforms for IBM inference.py compat

Browse files

prithvi_nyc_phase14.yaml uses GenericNonGeoSegmentationDataModule which
omits test_transform (leaves it None) and sets aug to terratorch Normalize
(only handles 4D/5D tensors). IBM inference.py:run_model() calls both on a
3D image dict, producing TypeError then ValueError for all neighborhood
queries locally.

Patch inside _ensure_model() when test_transform is None:
- test_transform → albumentations.Compose([ToTensorV2()]) matching IBM base
- aug → kornia.AugmentationSequential with v2-specific means/stds read from
the existing terratorch Normalize, preserving NYC Pluvial normalization

Full transform chain smoke-tested: (H,W,C)→CHW→BCHW dtype=float32.
Blast radius was all neighborhood intent queries, not Hollis specifically.
See /tmp/PRITHVI-HOLLIS-DIAGNOSIS.md for full analysis.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app/flood_layers/prithvi_live.py +19 -0
app/flood_layers/prithvi_live.py CHANGED
@@ -167,6 +167,25 @@ def _ensure_model():
167
  log.info("prithvi_live: building v2 model from "
168
  "yaml=%s ckpt=%s", v2_yaml, v2_ckpt)
169
  m = LightningInferenceModel.from_config(v2_yaml, v2_ckpt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  else:
171
  log.warning("prithvi_live: v2 yaml/ckpt not "
172
  "discoverable in %s; falling back to base",
 
167
  log.info("prithvi_live: building v2 model from "
168
  "yaml=%s ckpt=%s", v2_yaml, v2_ckpt)
169
  m = LightningInferenceModel.from_config(v2_yaml, v2_ckpt)
170
+ # prithvi_nyc_phase14.yaml uses GenericNonGeoSegmentationDataModule
171
+ # which omits test_transform (→ None) and uses terratorch Normalize
172
+ # for aug (only handles 4D/5D). IBM inference.py:run_model() calls
173
+ # both on a 3D dict. Patch both to match the IBM base contract:
174
+ # ToTensorV2 for test_transform; Kornia AugmentationSequential
175
+ # (accepts dict input, adds batch dim) for aug.
176
+ if getattr(getattr(m, 'datamodule', None),
177
+ 'test_transform', None) is None:
178
+ import albumentations as A
179
+ import kornia.augmentation as _Ka
180
+ from albumentations.pytorch import ToTensorV2
181
+ m.datamodule.test_transform = A.Compose([ToTensorV2()])
182
+ _old = m.datamodule.aug
183
+ m.datamodule.aug = _Ka.AugmentationSequential(
184
+ _Ka.Normalize(_old.means.view(-1).tolist(),
185
+ _old.stds.view(-1).tolist()),
186
+ data_keys=None)
187
+ log.info("prithvi_live: patched v2 datamodule transforms "
188
+ "for IBM inference.py compat")
189
  else:
190
  log.warning("prithvi_live: v2 yaml/ckpt not "
191
  "discoverable in %s; falling back to base",