fix(prithvi_live): build v2 model from v2 yaml, not base config
C5 originally tried to load the v2 ckpt weights into a model built from
the IBM-NASA base config.yaml. The two are architecturally different:
v2 ships a UNetDecoder plus a 2-class head, while the base ships UperNet
(PSP / FPN). Loading produced a giant size-mismatch RuntimeError on
head.head.2 and dozens of missing/unexpected keys in decoder.fpn1 /
psp_modules / lateral_convs.
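For context, a minimal sketch (toy nn.Sequential stand-ins, not the real
Prithvi/terratorch classes) of why a strict state-dict load across mismatched
architectures dies exactly like that:

    # Toy illustration: pushing weights from a 2-class head into a model
    # built with a different head shape raises one RuntimeError that lists
    # every size mismatch and missing/unexpected key.
    import torch.nn as nn

    base_style = nn.Sequential(nn.Linear(16, 8), nn.Linear(8, 4))  # stand-in for the base head
    v2_style = nn.Sequential(nn.Linear(16, 8), nn.Linear(8, 2))    # stand-in for the v2 2-class head

    try:
        base_style.load_state_dict(v2_style.state_dict(), strict=True)
    except RuntimeError as e:
        print(e)  # "size mismatch for 1.weight ..." and friends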
Fix: when the active REPO is not BASE_REPO, download the v2 yaml +
v2 ckpt directly from the published HF artefact and let
LightningInferenceModel.from_config build the architecture from the
v2 yaml itself. The yaml's data: section points at training-droplet
paths that don't exist locally, but the
GenericNonGeoSegmentationDataModule constructor only records paths;
splits aren't read until setup(), which we never call.
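That lazy-construction behaviour is the load-bearing assumption; here is a
sketch of the generic LightningDataModule contract it relies on (a made-up
PathOnlyDataModule, not the real terratorch class or its signature):

    from lightning.pytorch import LightningDataModule

    class PathOnlyDataModule(LightningDataModule):
        def __init__(self, train_dir: str):
            super().__init__()
            self.train_dir = train_dir  # path is recorded, nothing is opened here

        def setup(self, stage=None):
            # splits are only read here; inference-only code never calls setup()
            raise FileNotFoundError(self.train_dir)

    dm = PathOnlyDataModule("/root/terramind_nyc/...")  # constructs fine even though the path is bogus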
Falls back to the proven base path on any v2 failure (yaml not in
repo, datamodule constructor strict, etc.) so the specialist degrades
to v1 behaviour rather than no-opping.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- app/flood_layers/prithvi_live.py +70 -69
@@ -98,21 +98,30 @@ def warm():
 
 
 def _ensure_model():
-    """Load Prithvi-EO 2.0 once into RAM.
-    [… 14 lines of the old docstring not captured in this view …]
+    """Load Prithvi-EO 2.0 once into RAM.
+
+    The v2 NYC Pluvial fine-tune (`msradam/Prithvi-EO-2.0-NYC-Pluvial`)
+    is **architecturally distinct** from the IBM-NASA Sen1Floods11
+    base: v2 ships a `UNetDecoder` + 2-class head, the base ships a
+    UperNet with PSP / FPN. The model has to be built from each
+    repo's own config.yaml — there's no key-mapping shim that bridges
+    them.
+
+    Strategy:
+
+    1. If the active REPO != BASE_REPO, try to build from the v2
+       yaml + v2 ckpt. The v2 yaml's data: paths point at the
+       training droplet's filesystem (`/root/terramind_nyc/...`)
+       which doesn't exist locally; that's fine — the
+       GenericNonGeoSegmentationDataModule constructor only
+       records the paths, splits aren't read until `setup()`.
+    2. On any v2 failure (yaml not present, datamodule constructor
+       strict, weights mismatch), fall back to the base yaml + base
+       ckpt. The base path is the proven pre-C5 behaviour.
+
+    The shared `inference.run_model` helper is only published by the
+    IBM-NASA base repo; we always pull it from there.
+    """
     global _MODEL, _RUN_MODEL
     if _MODEL is not None:
         return _MODEL, _RUN_MODEL
@@ -121,63 +130,57 @@ def _ensure_model():
         return _MODEL, _RUN_MODEL
     import importlib.util
 
-    from huggingface_hub import hf_hub_download
+    from huggingface_hub import hf_hub_download
+    from terratorch.cli_tools import LightningInferenceModel
    log.info("prithvi_live: loading model from %s", REPO)
 
-    # [… rest of the old comment not captured in this view …]
+    # Inference helper only lives in the IBM-NASA base repo.
+    inference_py = hf_hub_download(BASE_REPO, "inference.py")
+
     m = None
-    [… old lines 129-148: the Lightning-ckpt attempt; not captured in this view …]
+    # ---- v2 path: yaml + ckpt from the published repo ----------
+    if REPO != BASE_REPO:
+        try:
+            # The v2 repo publishes `prithvi_nyc_phase14.yaml` and
+            # `prithvi_nyc_pluvial_v2.ckpt`. Be tolerant of small
+            # naming drift (best_val_loss.ckpt etc.) by probing.
+            v2_yaml = None
+            for name in ("prithvi_nyc_phase14.yaml",
+                         "config.yaml", "phase14.yaml",
+                         "prithvi_nyc_v2.yaml"):
+                try:
+                    v2_yaml = hf_hub_download(REPO, name)
+                    break
+                except Exception:
+                    continue
+            v2_ckpt = None
+            for name in ("prithvi_nyc_pluvial_v2.ckpt",
+                         "best_val_loss.ckpt", "model.ckpt",
+                         "last.ckpt"):
+                try:
+                    v2_ckpt = hf_hub_download(REPO, name)
                     break
-    [… old lines 150-165: rest of the Lightning-ckpt attempt; not captured in this view …]
-            self.datamodule = None
-    [… old line 167 not captured in this view …]
-            m = _LightningTaskWrapper(task)
-        except Exception as e:
-            log.warning("prithvi_live: Lightning-ckpt load failed (%s); "
-                        "falling back to raw-weights path", e)
-
-    # ---- Fallback: raw .pt + config.yaml (Sen1Floods11 base) ----
+                except Exception:
+                    continue
+            if v2_yaml and v2_ckpt:
+                log.info("prithvi_live: building v2 model from "
+                         "yaml=%s ckpt=%s", v2_yaml, v2_ckpt)
+                m = LightningInferenceModel.from_config(v2_yaml, v2_ckpt)
+            else:
+                log.warning("prithvi_live: v2 yaml/ckpt not "
+                            "discoverable in %s; falling back to base",
+                            REPO)
+        except Exception as e:
+            log.warning("prithvi_live: v2 build failed (%s); "
+                        "falling back to base", e)
+            m = None
+
+    # ---- base path: proven IBM-NASA Sen1Floods11 fine-tune -----
     if m is None:
-        [… old lines 175-178: base config/ckpt download; not captured in this view …]
-            base, "Prithvi-EO-V2-300M-TL-Sen1Floods11.pt")
-        m = LightningInferenceModel.from_config(config_path, checkpoint)
+        base_config = hf_hub_download(BASE_REPO, "config.yaml")
+        base_ckpt = hf_hub_download(
+            BASE_REPO, "Prithvi-EO-V2-300M-TL-Sen1Floods11.pt")
+        m = LightningInferenceModel.from_config(base_config, base_ckpt)
 
     m.model.eval()
     if DEVICE == "cuda":
@@ -188,8 +191,6 @@ def _ensure_model():
     except Exception:
         log.exception("prithvi_live: cuda move failed")
 
-    # Inference helper lives only in the IBM-NASA base repo.
-    inference_py = hf_hub_download(BASE_REPO, "inference.py")
     spec = importlib.util.spec_from_file_location("_prithvi_inference",
                                                   inference_py)
     mod = importlib.util.module_from_spec(spec)