fix(prithvi): dict-aware normalizer; expose riprap-models last_errors
Two follow-ups to the previous EO-fix commit. The TerraMind LoRA
fix (drop tiled_inference at 224×224) shipped clean and is now
firing on both Stones; Prithvi-NYC-Pluvial v2 was still erroring
with the same `'list' object has no attribute 'view'` despite
the tensor-typed Normalize.
Real root cause: IBM's run_model in inference.py calls
`datamodule.aug({'image': tensor})['image']` — passing a dict
and indexing the result. The previous patch wrapped a kornia
AugmentationSequential there, which in 0.7+ doesn't accept dict
input cleanly and trips the AttributeError deep inside its
internal storage on first augmentation apply.
Fix: drop kornia entirely from the v2 patch path. Replace it
with a 12-line hand-rolled `_DictNormalize` that does the same
arithmetic — `(img - mean) / std` — and explicitly handles both
dict and tensor input shapes. Identical math, fewer moving
parts, no kornia version skew. Applied symmetrically in:
services/riprap-models/main.py
app/flood_layers/prithvi_live.py
Also: surface diagnostics through the proxy so operators don't
need to grep container logs.
inference-vllm/proxy.py:
- GET /healthz now bubbles up `models_loaded` and
`last_errors` from riprap-models's healthz body
- GET /v1/diag (auth-required) forwards riprap-models's
operator-only diagnostic snapshot — what's loaded, last
per-stage error with traceback tail, CUDA memory state
per device
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- app/flood_layers/prithvi_live.py +30 -11
- services/riprap-models/main.py +31 -11
|
@@ -205,21 +205,40 @@ def _ensure_model():
|
|
| 205 |
if getattr(getattr(m, 'datamodule', None),
|
| 206 |
'test_transform', None) is None:
|
| 207 |
import albumentations as A
|
| 208 |
-
import
|
| 209 |
from albumentations.pytorch import ToTensorV2
|
| 210 |
m.datamodule.test_transform = A.Compose([ToTensorV2()])
|
| 211 |
_old = m.datamodule.aug
|
| 212 |
-
|
| 213 |
-
#
|
| 214 |
-
# .
|
| 215 |
-
#
|
| 216 |
-
#
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
log.info("prithvi_live: patched v2 datamodule transforms "
|
| 222 |
-
"for IBM inference.py compat")
|
| 223 |
else:
|
| 224 |
log.warning("prithvi_live: v2 yaml/ckpt not "
|
| 225 |
"discoverable in %s; falling back to base",
|
|
|
|
| 205 |
if getattr(getattr(m, 'datamodule', None),
|
| 206 |
'test_transform', None) is None:
|
| 207 |
import albumentations as A
|
| 208 |
+
import torch as _torch
|
| 209 |
from albumentations.pytorch import ToTensorV2
|
| 210 |
m.datamodule.test_transform = A.Compose([ToTensorV2()])
|
| 211 |
_old = m.datamodule.aug
|
| 212 |
+
|
| 213 |
+
# IBM's inference.py:188 calls
|
| 214 |
+
# `datamodule.aug({'image': tensor})['image']`.
|
| 215 |
+
# kornia's AugmentationSequential doesn't accept
|
| 216 |
+
# dict input cleanly and tripped the
|
| 217 |
+
# `'list' object has no attribute 'view'`
|
| 218 |
+
# error on the L4 deploy. Use a hand-rolled
|
| 219 |
+
# dict-aware normalizer instead — same math,
|
| 220 |
+
# fewer moving parts, no kornia version skew.
|
| 221 |
+
class _DictNormalize:
|
| 222 |
+
def __init__(self, mean, std):
|
| 223 |
+
self.mean = _torch.as_tensor(mean).view(-1, 1, 1).float()
|
| 224 |
+
self.std = _torch.as_tensor(std).view(-1, 1, 1).float()
|
| 225 |
+
|
| 226 |
+
def __call__(self, sample):
|
| 227 |
+
if isinstance(sample, dict):
|
| 228 |
+
img = sample["image"]
|
| 229 |
+
mean = self.mean.to(img.device)
|
| 230 |
+
std = self.std.to(img.device)
|
| 231 |
+
return {**sample, "image": (img - mean) / std}
|
| 232 |
+
mean = self.mean.to(sample.device)
|
| 233 |
+
std = self.std.to(sample.device)
|
| 234 |
+
return (sample - mean) / std
|
| 235 |
+
|
| 236 |
+
m.datamodule.aug = _DictNormalize(
|
| 237 |
+
_old.means.view(-1).detach().clone(),
|
| 238 |
+
_old.stds.view(-1).detach().clone(),
|
| 239 |
+
)
|
| 240 |
log.info("prithvi_live: patched v2 datamodule transforms "
|
| 241 |
+
"for IBM inference.py compat (dict-aware Normalize)")
|
| 242 |
else:
|
| 243 |
log.warning("prithvi_live: v2 yaml/ckpt not "
|
| 244 |
"discoverable in %s; falling back to base",
|
|
@@ -144,21 +144,41 @@ def _load_prithvi():
|
|
| 144 |
if getattr(getattr(m, 'datamodule', None),
|
| 145 |
'test_transform', None) is None:
|
| 146 |
import albumentations as A
|
| 147 |
-
import
|
| 148 |
from albumentations.pytorch import ToTensorV2
|
| 149 |
m.datamodule.test_transform = A.Compose([ToTensorV2()])
|
| 150 |
_old = m.datamodule.aug
|
| 151 |
-
|
| 152 |
-
#
|
| 153 |
-
#
|
| 154 |
-
#
|
| 155 |
-
#
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
log.info("prithvi: patched v2 datamodule transforms "
|
| 161 |
-
"for IBM inference.py compat")
|
| 162 |
else:
|
| 163 |
log.info("prithvi: v2 unavailable, falling back to base")
|
| 164 |
base_ckpt = hf_hub_download(
|
|
|
|
| 144 |
if getattr(getattr(m, 'datamodule', None),
|
| 145 |
'test_transform', None) is None:
|
| 146 |
import albumentations as A
|
| 147 |
+
import torch as _torch
|
| 148 |
from albumentations.pytorch import ToTensorV2
|
| 149 |
m.datamodule.test_transform = A.Compose([ToTensorV2()])
|
| 150 |
_old = m.datamodule.aug
|
| 151 |
+
|
| 152 |
+
# IBM's inference.py:188 calls
|
| 153 |
+
# `datamodule.aug({'image': tensor})['image']` —
|
| 154 |
+
# passing a dict and indexing the result. The previous
|
| 155 |
+
# patch wrapped a kornia AugmentationSequential here,
|
| 156 |
+
# which doesn't natively accept dict input and tripped
|
| 157 |
+
# `'list' object has no attribute 'view'` deep inside
|
| 158 |
+
# kornia's internal storage on first inference. Drop
|
| 159 |
+
# kornia entirely and use a hand-rolled dict-aware
|
| 160 |
+
# normalizer — fewer moving parts, identical math.
|
| 161 |
+
class _DictNormalize:
|
| 162 |
+
def __init__(self, mean, std):
|
| 163 |
+
self.mean = _torch.as_tensor(mean).view(-1, 1, 1).float()
|
| 164 |
+
self.std = _torch.as_tensor(std).view(-1, 1, 1).float()
|
| 165 |
+
|
| 166 |
+
def __call__(self, sample):
|
| 167 |
+
if isinstance(sample, dict):
|
| 168 |
+
img = sample["image"]
|
| 169 |
+
mean = self.mean.to(img.device)
|
| 170 |
+
std = self.std.to(img.device)
|
| 171 |
+
return {**sample, "image": (img - mean) / std}
|
| 172 |
+
mean = self.mean.to(sample.device)
|
| 173 |
+
std = self.std.to(sample.device)
|
| 174 |
+
return (sample - mean) / std
|
| 175 |
+
|
| 176 |
+
m.datamodule.aug = _DictNormalize(
|
| 177 |
+
_old.means.view(-1).detach().clone(),
|
| 178 |
+
_old.stds.view(-1).detach().clone(),
|
| 179 |
+
)
|
| 180 |
log.info("prithvi: patched v2 datamodule transforms "
|
| 181 |
+
"for IBM inference.py compat (dict-aware Normalize)")
|
| 182 |
else:
|
| 183 |
log.info("prithvi: v2 unavailable, falling back to base")
|
| 184 |
base_ckpt = hf_hub_download(
|