seriffic Claude Opus 4.7 (1M context) committed on
Commit
bd05c73
·
1 Parent(s): 2e8df60

fix(prithvi_live): build v2 model from v2 yaml, not base config

Browse files

C5 originally tried to load v2 ckpt weights into a model built from
the IBM-NASA base config.yaml. They're architecturally different —
v2 ships UNetDecoder + 2-class head; the base ships UperNet (PSP /
FPN). Loading produced a giant size-mismatch RuntimeError on
head.head.2 and dozens of missing/unexpected keys in decoder.fpn1 /
psp_modules / lateral_convs.

Fix: when the active REPO is not BASE_REPO, download the v2 yaml +
v2 ckpt directly from the published HF artefact and let
LightningInferenceModel.from_config build the architecture from the
v2 yaml itself. The yaml's data: section points at training-droplet
paths that don't exist locally, but the
GenericNonGeoSegmentationDataModule constructor only records paths;
splits aren't read until setup(), which we never call.

Falls back to the proven base path on any v2 failure (yaml not in
repo, datamodule constructor strict, etc.) so the specialist degrades
to v1 behaviour rather than no-opping.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Files changed (1) hide show
  1. app/flood_layers/prithvi_live.py +70 -69
app/flood_layers/prithvi_live.py CHANGED
@@ -98,21 +98,30 @@ def warm():
98
 
99
 
100
  def _ensure_model():
101
- """Load Prithvi-EO 2.0 once into RAM. Two artifact shapes are
102
- supported, in priority order:
103
-
104
- 1) **NYC Pluvial v2** (`msradam/Prithvi-EO-2.0-NYC-Pluvial`)
105
- Lightning checkpoint (`*.ckpt`) restored via
106
- `SemanticSegmentationTask.load_from_checkpoint`. Full task
107
- (config + weights) lives inside the ckpt.
108
- 2) **Sen1Floods11 base** (`ibm-nasa-geospatial/...`) — raw `.pt`
109
- weights + a separate `config.yaml`, loaded via
110
- `LightningInferenceModel.from_config(config, ckpt)`. This is
111
- the path the original prithvi_live.py used.
112
-
113
- The shared inference helper (`run_model`) only ships in the IBM-NASA
114
- base repo; for the v2 path we monkey-import it from the base repo
115
- so a single code path drives prediction either way."""
 
 
 
 
 
 
 
 
 
116
  global _MODEL, _RUN_MODEL
117
  if _MODEL is not None:
118
  return _MODEL, _RUN_MODEL
@@ -121,63 +130,57 @@ def _ensure_model():
121
  return _MODEL, _RUN_MODEL
122
  import importlib.util
123
 
124
- from huggingface_hub import hf_hub_download, snapshot_download
 
125
  log.info("prithvi_live: loading model from %s", REPO)
126
 
127
- # ---- Try the v2 / Lightning-ckpt path first -----------------
 
 
128
  m = None
129
- try:
130
- from terratorch.tasks import SemanticSegmentationTask
131
- local_dir = snapshot_download(REPO)
132
- ckpt = None
133
- # Lightning saves under various conventional names; probe
134
- # the most likely candidates rather than trusting one path.
135
- for name in ("best_val_loss.ckpt", "model.ckpt",
136
- "last.ckpt"):
137
- candidate = os.path.join(local_dir, name)
138
- if os.path.exists(candidate):
139
- ckpt = candidate
140
- break
141
- if ckpt is None:
142
- # Walk the snapshot for any *.ckpt file.
143
- for root, _, files in os.walk(local_dir):
144
- for f in files:
145
- if f.endswith(".ckpt"):
146
- ckpt = os.path.join(root, f)
147
- break
148
- if ckpt:
 
149
  break
150
- if ckpt is not None:
151
- log.info("prithvi_live: loading Lightning ckpt %s", ckpt)
152
- map_loc = "cuda" if (DEVICE == "cuda") else "cpu"
153
- task = SemanticSegmentationTask.load_from_checkpoint(
154
- ckpt, map_location=map_loc, strict=False,
155
- )
156
- task.eval()
157
-
158
- # Mimic LightningInferenceModel's surface so the rest
159
- # of the file (which expects `.model` and `.datamodule`)
160
- # keeps working. datamodule isn't strictly needed by
161
- # run_model in current terratorch but we set it to None
162
- # explicitly so a missing-attr access surfaces clearly.
163
- class _LightningTaskWrapper:
164
- def __init__(self, task):
165
- self.model = task
166
- self.datamodule = None
167
-
168
- m = _LightningTaskWrapper(task)
169
- except Exception as e:
170
- log.warning("prithvi_live: Lightning-ckpt load failed (%s); "
171
- "falling back to raw-weights path", e)
172
-
173
- # ---- Fallback: raw .pt + config.yaml (Sen1Floods11 base) ----
174
  if m is None:
175
- from terratorch.cli_tools import LightningInferenceModel
176
- base = REPO if REPO == BASE_REPO else BASE_REPO
177
- config_path = hf_hub_download(base, "config.yaml")
178
- checkpoint = hf_hub_download(
179
- base, "Prithvi-EO-V2-300M-TL-Sen1Floods11.pt")
180
- m = LightningInferenceModel.from_config(config_path, checkpoint)
181
 
182
  m.model.eval()
183
  if DEVICE == "cuda":
@@ -188,8 +191,6 @@ def _ensure_model():
188
  except Exception:
189
  log.exception("prithvi_live: cuda move failed")
190
 
191
- # Inference helper lives only in the IBM-NASA base repo.
192
- inference_py = hf_hub_download(BASE_REPO, "inference.py")
193
  spec = importlib.util.spec_from_file_location("_prithvi_inference",
194
  inference_py)
195
  mod = importlib.util.module_from_spec(spec)
 
98
 
99
 
100
  def _ensure_model():
101
+ """Load Prithvi-EO 2.0 once into RAM.
102
+
103
+ The v2 NYC Pluvial fine-tune (`msradam/Prithvi-EO-2.0-NYC-Pluvial`)
104
+ is **architecturally distinct** from the IBM-NASA Sen1Floods11
105
+ base: v2 ships a `UNetDecoder` + 2-class head, the base ships a
106
+ UperNet with PSP / FPN. The model has to be built from each
107
+ repo's own config.yaml — there's no key-mapping shim that bridges
108
+ them.
109
+
110
+ Strategy:
111
+
112
+ 1. If the active REPO != BASE_REPO, try to build from the v2
113
+ yaml + v2 ckpt. The v2 yaml's data: paths point at the
114
+ training droplet's filesystem (`/root/terramind_nyc/...`)
115
+ which doesn't exist locally; that's fine — the
116
+ GenericNonGeoSegmentationDataModule constructor only
117
+ records the paths, splits aren't read until `setup()`.
118
+ 2. On any v2 failure (yaml not present, datamodule constructor
119
+ strict, weights mismatch), fall back to the base yaml + base
120
+ ckpt. The base path is the proven pre-C5 behaviour.
121
+
122
+ The shared `inference.run_model` helper is only published by the
123
+ IBM-NASA base repo; we always pull it from there.
124
+ """
125
  global _MODEL, _RUN_MODEL
126
  if _MODEL is not None:
127
  return _MODEL, _RUN_MODEL
 
130
  return _MODEL, _RUN_MODEL
131
  import importlib.util
132
 
133
+ from huggingface_hub import hf_hub_download
134
+ from terratorch.cli_tools import LightningInferenceModel
135
  log.info("prithvi_live: loading model from %s", REPO)
136
 
137
+ # Inference helper only lives in the IBM-NASA base repo.
138
+ inference_py = hf_hub_download(BASE_REPO, "inference.py")
139
+
140
  m = None
141
+ # ---- v2 path: yaml + ckpt from the published repo ----------
142
+ if REPO != BASE_REPO:
143
+ try:
144
+ # The v2 repo publishes `prithvi_nyc_phase14.yaml` and
145
+ # `prithvi_nyc_pluvial_v2.ckpt`. Be tolerant of small
146
+ # naming drift (best_val_loss.ckpt etc.) by probing.
147
+ v2_yaml = None
148
+ for name in ("prithvi_nyc_phase14.yaml",
149
+ "config.yaml", "phase14.yaml",
150
+ "prithvi_nyc_v2.yaml"):
151
+ try:
152
+ v2_yaml = hf_hub_download(REPO, name)
153
+ break
154
+ except Exception:
155
+ continue
156
+ v2_ckpt = None
157
+ for name in ("prithvi_nyc_pluvial_v2.ckpt",
158
+ "best_val_loss.ckpt", "model.ckpt",
159
+ "last.ckpt"):
160
+ try:
161
+ v2_ckpt = hf_hub_download(REPO, name)
162
  break
163
+ except Exception:
164
+ continue
165
+ if v2_yaml and v2_ckpt:
166
+ log.info("prithvi_live: building v2 model from "
167
+ "yaml=%s ckpt=%s", v2_yaml, v2_ckpt)
168
+ m = LightningInferenceModel.from_config(v2_yaml, v2_ckpt)
169
+ else:
170
+ log.warning("prithvi_live: v2 yaml/ckpt not "
171
+ "discoverable in %s; falling back to base",
172
+ REPO)
173
+ except Exception as e:
174
+ log.warning("prithvi_live: v2 build failed (%s); "
175
+ "falling back to base", e)
176
+ m = None
177
+
178
+ # ---- base path: proven IBM-NASA Sen1Floods11 fine-tune -----
 
 
 
 
 
 
 
 
179
  if m is None:
180
+ base_config = hf_hub_download(BASE_REPO, "config.yaml")
181
+ base_ckpt = hf_hub_download(
182
+ BASE_REPO, "Prithvi-EO-V2-300M-TL-Sen1Floods11.pt")
183
+ m = LightningInferenceModel.from_config(base_config, base_ckpt)
 
 
184
 
185
  m.model.eval()
186
  if DEVICE == "cuda":
 
191
  except Exception:
192
  log.exception("prithvi_live: cuda move failed")
193
 
 
 
194
  spec = importlib.util.spec_from_file_location("_prithvi_inference",
195
  inference_py)
196
  mod = importlib.util.module_from_spec(spec)