fix(prithvi): patch v2 datamodule transforms for IBM inference.py compat
Browse filesprithvi_nyc_phase14.yaml uses GenericNonGeoSegmentationDataModule which
omits test_transform (leaves it None) and sets aug to terratorch Normalize
(only handles 4D/5D tensors). IBM inference.py:run_model() calls both on a
3D image dict, producing TypeError then ValueError for all neighborhood
queries locally.
Patch inside _ensure_model() when test_transform is None:
- test_transform → albumentations.Compose([ToTensorV2()]) matching IBM base
- aug → kornia.AugmentationSequential with v2-specific means/stds read from
the existing terratorch Normalize, preserving NYC Pluvial normalization
Full transform chain smoke-tested: (H,W,C)→CHW→BCHW dtype=float32.
Blast radius was all neighborhood intent queries, not Hollis specifically.
See /tmp/PRITHVI-HOLLIS-DIAGNOSIS.md for full analysis.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
|
@@ -167,6 +167,25 @@ def _ensure_model():
|
|
| 167 |
log.info("prithvi_live: building v2 model from "
|
| 168 |
"yaml=%s ckpt=%s", v2_yaml, v2_ckpt)
|
| 169 |
m = LightningInferenceModel.from_config(v2_yaml, v2_ckpt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
else:
|
| 171 |
log.warning("prithvi_live: v2 yaml/ckpt not "
|
| 172 |
"discoverable in %s; falling back to base",
|
|
|
|
| 167 |
log.info("prithvi_live: building v2 model from "
|
| 168 |
"yaml=%s ckpt=%s", v2_yaml, v2_ckpt)
|
| 169 |
m = LightningInferenceModel.from_config(v2_yaml, v2_ckpt)
|
| 170 |
+
# prithvi_nyc_phase14.yaml uses GenericNonGeoSegmentationDataModule
|
| 171 |
+
# which omits test_transform (→ None) and uses terratorch Normalize
|
| 172 |
+
# for aug (only handles 4D/5D). IBM inference.py:run_model() calls
|
| 173 |
+
# both on a 3D dict. Patch both to match the IBM base contract:
|
| 174 |
+
# ToTensorV2 for test_transform; Kornia AugmentationSequential
|
| 175 |
+
# (accepts dict input, adds batch dim) for aug.
|
| 176 |
+
if getattr(getattr(m, 'datamodule', None),
|
| 177 |
+
'test_transform', None) is None:
|
| 178 |
+
import albumentations as A
|
| 179 |
+
import kornia.augmentation as _Ka
|
| 180 |
+
from albumentations.pytorch import ToTensorV2
|
| 181 |
+
m.datamodule.test_transform = A.Compose([ToTensorV2()])
|
| 182 |
+
_old = m.datamodule.aug
|
| 183 |
+
m.datamodule.aug = _Ka.AugmentationSequential(
|
| 184 |
+
_Ka.Normalize(_old.means.view(-1).tolist(),
|
| 185 |
+
_old.stds.view(-1).tolist()),
|
| 186 |
+
data_keys=None)
|
| 187 |
+
log.info("prithvi_live: patched v2 datamodule transforms "
|
| 188 |
+
"for IBM inference.py compat")
|
| 189 |
else:
|
| 190 |
log.warning("prithvi_live: v2 yaml/ckpt not "
|
| 191 |
"discoverable in %s; falling back to base",
|