| """Prithvi-EO 2.0 (NYC Pluvial v2 fine-tune) live water segmentation. |
| |
| A per-query specialist: pulls the most recent low-cloud Sentinel-2 L2A |
| scene over the address from Microsoft Planetary Computer, runs the |
| NYC-specialized fine-tune, and reports % water within 500 m. |
| |
| Distinct from `app/flood_layers/prithvi_water.py`, which serves the |
| offline-precomputed 2021 Ida polygons. This one is *fresh observation* |
| each query — same doc_id (`prithvi_live`), but the underlying model |
| has been swapped from the Sen1Floods11 base to |
| `msradam/Prithvi-EO-2.0-NYC-Pluvial` (Apache-2.0, fine-tuned on AMD |
| Instinct MI300X via AMD Developer Cloud — test flood IoU 0.5979, |
| 6× over the base). The base model is still loadable by setting |
| RIPRAP_PRITHVI_LIVE_REPO to the IBM repo as a fallback. |
| |
| Network calls (STAC search + COG band reads) and a 300M-param model |
| forward pass make this the slowest specialist after the LLM. Gated by |
| RIPRAP_PRITHVI_LIVE_ENABLE so deployments without the deps installed |
silently skip it. Scenes with 30%+ cloud cover are refused, honoring
the Sen1Floods11 training distribution.
| |
| License: Apache-2.0. See experiments/shared/licenses.md. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import concurrent.futures |
| import logging |
| import os |
| import threading |
| import time |
| from typing import Any |
|
|
| log = logging.getLogger("riprap.prithvi_live") |
|
|
# Master switch: unset/0 makes fetch() return a clean skip.
ENABLE = os.environ.get("RIPRAP_PRITHVI_LIVE_ENABLE", "1").lower() in ("1", "true", "yes")
# STAC lookback window (days) for the Sentinel-2 scene search.
SEARCH_DAYS = int(os.environ.get("RIPRAP_PRITHVI_LIVE_SEARCH_DAYS", "120"))
# Scene cloud-cover ceiling (percent) applied in the STAC query.
MAX_CLOUD_PCT = float(os.environ.get("RIPRAP_PRITHVI_LIVE_MAX_CLOUD", "30"))
# "cpu" or "cuda" — _ensure_model() only moves the model when "cuda".
DEVICE = os.environ.get("RIPRAP_PRITHVI_LIVE_DEVICE", "cpu")


# Active model repo; override via env to fall back to the IBM base.
REPO = os.environ.get(
    "RIPRAP_PRITHVI_LIVE_REPO",
    "msradam/Prithvi-EO-2.0-NYC-Pluvial",
)
# Always used for the shared inference.py helper and as the build fallback.
BASE_REPO = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"


# Sentinel-2 L2A asset keys, stacked in this order into the model input.
BANDS = ["B02", "B03", "B04", "B8A", "B11", "B12"]
IMG_SIZE = 512          # model input edge (px); chips are center-cropped to this
CHIP_PX = 1024          # full chip edge (px) read from the COGs before cropping
CHIP_M = CHIP_PX * 10   # chip edge in meters at 10 m/px
HALF_M = CHIP_M / 2     # half-width used for the CRS clip box around the point
CENTER_RADIUS_M = 500   # radius for the pct_water_within_500m statistic
PIXEL_M = 10            # ground sampling distance (m/px) of the reference grid


# Lazy singletons populated by _ensure_model(); _INIT_LOCK serializes the
# one-time (slow) model load across threads.
_MODEL = None
_RUN_MODEL = None
_INIT_LOCK = threading.Lock()
| |
|
|
|
|
def _has_required_deps() -> tuple[bool, str | None]:
    """Probe the two dependency tiers this specialist needs.

    Tier 1 (always required): the chip-fetch stack — planetary_computer,
    pystac_client, rioxarray, xarray, einops — since every query pulls a
    Sentinel-2 chip from Microsoft Planetary Computer regardless of
    where inference runs.

    Tier 2 (conditionally required): terratorch, for *local* inference
    only. When app.inference reports a usable remote endpoint (the HF
    Space case, where inference runs on the AMD MI300X), terratorch is
    not needed — its ~250 MB transitive dependency cone doesn't fit the
    HF build sandbox.

    Returns (ok, description-of-missing-deps-or-None).
    """
    fetch_stack = ("planetary_computer", "pystac_client",
                   "rioxarray", "xarray", "einops")
    absent = [mod for mod in fetch_stack if not _has_module(mod)]
    if absent:
        return False, ", ".join(absent)

    # Remote inference, when reachable, makes terratorch optional.
    remote_ok = False
    try:
        from app import inference as _inf
        remote_ok = bool(_inf.remote_enabled())
    except Exception:
        remote_ok = False
    if remote_ok:
        return True, None
    if _has_module("terratorch"):
        return True, None
    return False, "terratorch (local inference)"
|
|
|
|
| def _has_module(name: str) -> bool: |
| """True if `name` imports cleanly. ImportError → not installed. |
| Other exceptions (e.g. torchvision::nms RuntimeError on the HF |
| Space) → treat as unavailable too; we don't want a clean-skip |
| intent to crash the FSM at deps-probe time.""" |
| try: |
| __import__(name) |
| return True |
| except ImportError: |
| return False |
| except Exception as e: |
| log.warning("prithvi_live: %s import raised %s; treating as " |
| "unavailable", name, type(e).__name__) |
| return False |
|
|
|
|
| _DEPS_OK, _DEPS_MISSING = _has_required_deps() |
|
|
|
|
def warm():
    """Optionally pre-load the model at app boot.

    Pure optimization: _ensure_model() is also invoked lazily by the
    FSM action on first query, so this just amortizes that cost.
    Failures are logged and swallowed — the specialist will simply
    no-op later rather than take the app down at startup.
    """
    if ENABLE:
        try:
            _ensure_model()
        except Exception:
            log.exception("prithvi_live: warm() failed; specialist will no-op")
|
|
|
|
def _ensure_model():
    """Load Prithvi-EO 2.0 once into RAM.

    The v2 NYC Pluvial fine-tune (`msradam/Prithvi-EO-2.0-NYC-Pluvial`)
    is **architecturally distinct** from the IBM-NASA Sen1Floods11
    base: v2 ships a `UNetDecoder` + 2-class head, the base ships a
    UperNet with PSP / FPN. The model has to be built from each
    repo's own config.yaml — there's no key-mapping shim that bridges
    them.

    Strategy:

      1. If the active REPO != BASE_REPO, try to build from the v2
         yaml + v2 ckpt. The v2 yaml's data: paths point at the
         training droplet's filesystem (`/root/terramind_nyc/...`)
         which doesn't exist locally; that's fine — the
         GenericNonGeoSegmentationDataModule constructor only
         records the paths, splits aren't read until `setup()`.
      2. On any v2 failure (yaml not present, datamodule constructor
         strict, weights mismatch), fall back to the base yaml + base
         ckpt. The base path is the proven pre-C5 behaviour.

    The shared `inference.run_model` helper is only published by the
    IBM-NASA base repo; we always pull it from there.

    Returns (model, run_model) and caches both in the module globals.
    """
    global _MODEL, _RUN_MODEL
    if _MODEL is not None:
        return _MODEL, _RUN_MODEL
    with _INIT_LOCK:
        # Double-checked: another thread may have finished the load
        # while we were waiting on the lock.
        if _MODEL is not None:
            return _MODEL, _RUN_MODEL
        import importlib.util

        from huggingface_hub import hf_hub_download
        from terratorch.cli_tools import LightningInferenceModel
        log.info("prithvi_live: loading model from %s", REPO)

        # run_model ships only in the IBM base repo (see docstring),
        # regardless of which checkpoint we end up loading below.
        inference_py = hf_hub_download(BASE_REPO, "inference.py")

        m = None

        if REPO != BASE_REPO:
            try:
                # Config filenames have varied across uploads, so probe
                # a short list of known candidates until one downloads.
                v2_yaml = None
                for name in ("prithvi_nyc_phase14.yaml",
                             "config.yaml", "phase14.yaml",
                             "prithvi_nyc_v2.yaml"):
                    try:
                        v2_yaml = hf_hub_download(REPO, name)
                        break
                    except Exception:
                        continue
                # Same candidate-probing dance for the checkpoint.
                v2_ckpt = None
                for name in ("prithvi_nyc_pluvial_v2.ckpt",
                             "best_val_loss.ckpt", "model.ckpt",
                             "last.ckpt"):
                    try:
                        v2_ckpt = hf_hub_download(REPO, name)
                        break
                    except Exception:
                        continue
                if v2_yaml and v2_ckpt:
                    log.info("prithvi_live: building v2 model from "
                             "yaml=%s ckpt=%s", v2_yaml, v2_ckpt)
                    m = LightningInferenceModel.from_config(v2_yaml, v2_ckpt)

                    # Compat shim: when the v2 yaml defines no
                    # test_transform, install one plus a dict-aware
                    # normalizer so the IBM inference.py helper (which
                    # passes {"image": tensor, ...} samples) works with
                    # the v2 datamodule.
                    if getattr(getattr(m, 'datamodule', None),
                               'test_transform', None) is None:
                        import albumentations as A
                        import torch as _torch
                        from albumentations.pytorch import ToTensorV2
                        m.datamodule.test_transform = A.Compose([ToTensorV2()])
                        _old = m.datamodule.aug

                        # Replacement aug: applies the same per-band
                        # mean/std normalization as the yaml-configured
                        # one, but accepts either a bare tensor or the
                        # dict sample shape.
                        class _DictNormalize:
                            def __init__(self, mean, std):
                                # (C, 1, 1) so stats broadcast over H, W.
                                self.mean = _torch.as_tensor(mean).view(-1, 1, 1).float()
                                self.std = _torch.as_tensor(std).view(-1, 1, 1).float()

                            def __call__(self, sample):
                                if isinstance(sample, dict):
                                    img = sample["image"]
                                    mean = self.mean.to(img.device)
                                    std = self.std.to(img.device)
                                    return {**sample, "image": (img - mean) / std}
                                mean = self.mean.to(sample.device)
                                std = self.std.to(sample.device)
                                return (sample - mean) / std

                        # Reuse the band statistics from the original
                        # yaml-built aug object.
                        m.datamodule.aug = _DictNormalize(_old.means, _old.stds)
                        log.info("prithvi_live: patched v2 datamodule transforms "
                                 "for IBM inference.py compat (dict-aware Normalize)")
                else:
                    log.warning("prithvi_live: v2 yaml/ckpt not "
                                "discoverable in %s; falling back to base",
                                REPO)
            except Exception as e:
                log.warning("prithvi_live: v2 build failed (%s); "
                            "falling back to base", e)
                m = None

        # Fallback path (also taken when REPO == BASE_REPO): the proven
        # IBM base build.
        if m is None:
            base_config = hf_hub_download(BASE_REPO, "config.yaml")
            base_ckpt = hf_hub_download(
                BASE_REPO, "Prithvi-EO-V2-300M-TL-Sen1Floods11.pt")
            m = LightningInferenceModel.from_config(base_config, base_ckpt)

        m.model.eval()
        if DEVICE == "cuda":
            try:
                import torch
                if torch.cuda.is_available():
                    m.model.cuda()
            except Exception:
                # Best-effort: log and keep running on CPU.
                log.exception("prithvi_live: cuda move failed")

        # Import the downloaded inference.py as an ad-hoc module so we
        # can grab its run_model without packaging it.
        spec = importlib.util.spec_from_file_location("_prithvi_inference",
                                                      inference_py)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        _MODEL = m
        _RUN_MODEL = mod.run_model
        return _MODEL, _RUN_MODEL
|
|
|
|
def _search_recent_scene(lat: float, lon: float):
    """Find a recent low-cloud Sentinel-2 L2A item near (lat, lon).

    Searches Microsoft Planetary Computer over the last SEARCH_DAYS
    days in a ~0.02-degree box around the point, filtering to scenes
    under MAX_CLOUD_PCT cloud cover. Candidates are ranked
    clearest-first with recency as the tie-breaker; returns the best
    item, or None when nothing qualifies.
    """
    import datetime as dt

    import planetary_computer as pc
    from pystac_client import Client

    # datetime.utcnow() is deprecated (Python 3.12+); use an aware UTC
    # timestamp and take its date for the STAC datetime range.
    end = dt.datetime.now(dt.timezone.utc).date()
    start = end - dt.timedelta(days=SEARCH_DAYS)
    client = Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1",
        modifier=pc.sign_inplace,
    )
    delta = 0.02  # half-width of the search bbox, ~2 km at NYC latitudes
    search = client.search(
        collections=["sentinel-2-l2a"],
        bbox=[lon - delta, lat - delta, lon + delta, lat + delta],
        datetime=f"{start}/{end}",
        query={"eo:cloud_cover": {"lt": MAX_CLOUD_PCT}},
        max_items=20,
    )
    # Sort key: (cloud cover ascending, timestamp descending) — prefer
    # the clearest scene, breaking ties toward the most recent one.
    items = sorted(
        search.items(),
        key=lambda it: (it.properties.get("eo:cloud_cover", 100),
                        -(it.datetime.timestamp() if it.datetime else 0)),
    )
    return items[0] if items else None
|
|
|
|
def _build_chip(item, lat: float, lon: float):
    """Returns (img, ref_da, epsg) — img is the (6, H, W) center-cropped
    float32 array; ref_da is the rioxarray DataArray of the reference
    band BEFORE the center crop (kept so we can compute the affine
    transform for polygonization in EPSG:4326)."""
    import numpy as np
    import rioxarray
    import xarray as xr
    from pyproj import Transformer
    # Scene CRS: some STAC items expose proj:epsg, others only the
    # proj:code string form — handle both.
    if "proj:epsg" in item.properties:
        epsg = int(item.properties["proj:epsg"])
    else:
        code = item.properties.get("proj:code", "")
        if code.startswith("EPSG:"):
            epsg = int(code.split(":", 1)[1])
        else:
            raise RuntimeError("STAC item missing proj:epsg / proj:code")
    # Project the query point into the scene CRS and take a CHIP_M-wide
    # square window around it.
    fwd = Transformer.from_crs("EPSG:4326", f"EPSG:{epsg}", always_xy=True)
    cx, cy = fwd.transform(lon, lat)
    xmin, xmax = cx - HALF_M, cx + HALF_M
    ymin, ymax = cy - HALF_M, cy + HALF_M
    # First band defines the reference grid; the isel caps the chip at
    # CHIP_PX per side in case clip_box returns an extra row/column.
    ref = rioxarray.open_rasterio(item.assets[BANDS[0]].href, masked=False).squeeze(drop=True)
    ref = ref.rio.clip_box(minx=xmin, miny=ymin, maxx=xmax, maxy=ymax)
    ref = ref.isel(y=slice(0, CHIP_PX), x=slice(0, CHIP_PX))
    arrs = [ref.astype("float32")]
    for b in BANDS[1:]:
        da = rioxarray.open_rasterio(item.assets[b].href, masked=False).squeeze(drop=True)
        da = da.rio.clip_box(minx=xmin, miny=ymin, maxx=xmax, maxy=ymax)
        if da.shape != ref.shape:
            # Resample onto the reference grid when a band comes back at
            # a different resolution. NOTE(review): presumably the 20 m
            # bands (B8A/B11/B12) — confirm against the collection's
            # asset metadata.
            da = da.rio.reproject_match(ref)
        arrs.append(da.astype("float32"))
    stacked = xr.concat(arrs, dim="band", join="override").assign_coords(band=BANDS)
    img = stacked.values

    # Center-crop to the model's IMG_SIZE x IMG_SIZE input window.
    _, h, w = img.shape
    sy, sx = (h - IMG_SIZE) // 2, (w - IMG_SIZE) // 2
    img = img[:, sy:sy + IMG_SIZE, sx:sx + IMG_SIZE]
    # Heuristic reflectance scaling: a mean > 1 means the chip is still
    # in raw digital numbers, so divide by the L2A scale factor.
    if img.mean() > 1:
        img = img / 10000.0
    return np.nan_to_num(img.astype("float32")), ref, epsg
|
|
|
|
def _polygonize_mask(pred, ref_da, epsg: int) -> dict | None:
    """Vectorize the binary water mask into an EPSG:4326 GeoJSON
    FeatureCollection so the frontend can paint it on the MapLibre
    map. Returns None on failure (best-effort — never raises into the
    caller path)."""
    try:
        import json

        import geopandas as gpd
        from rasterio.features import shapes
        from rasterio.transform import from_origin
        from shapely.geometry import shape

        # Recover where the IMG_SIZE center crop sits inside the full
        # chip grid (same offset arithmetic as _build_chip's crop).
        xs = ref_da.x.values
        ys = ref_da.y.values
        if len(xs) < IMG_SIZE or len(ys) < IMG_SIZE:
            return None

        sy = (len(ys) - IMG_SIZE) // 2
        sx = (len(xs) - IMG_SIZE) // 2

        # Shift from pixel-center coords to the pixel-edge origin that
        # from_origin expects. NOTE(review): the +half-pixel on y
        # assumes a north-up grid (ys descending) — confirm for
        # rotated / southern-hemisphere grids.
        top_y = float(ys[sy]) + (PIXEL_M / 2.0)
        left_x = float(xs[sx]) - (PIXEL_M / 2.0)
        transform = from_origin(left_x, top_y, PIXEL_M, PIXEL_M)

        # Extract polygons only where the mask is 1 (water).
        mask = (pred == 1).astype("uint8")
        polys = []
        for geom, value in shapes(mask, mask=mask.astype(bool),
                                  transform=transform):
            if value != 1:
                continue
            polys.append(shape(geom))
        if not polys:
            return {"type": "FeatureCollection", "features": []}
        gdf = gpd.GeoDataFrame({"geometry": polys},
                               crs=f"EPSG:{epsg}").to_crs("EPSG:4326")

        # ~5 m tolerance at these latitudes: trims payload size without
        # visibly changing shoreline shapes.
        gdf["geometry"] = gdf.geometry.simplify(0.00005, preserve_topology=True)
        return json.loads(gdf.to_json())
    except Exception:
        log.exception("prithvi_live: polygonize failed")
        return None
|
|
|
|
def _fetch_inner(lat: float, lon: float, timeout_s: float) -> dict[str, Any]:
    """Core fetch logic — run inside a bounded thread via fetch().

    Pipeline: STAC scene search → 6-band chip build → inference
    (remote via app.inference when enabled, else the local terratorch
    model) → water-percentage stats + GeoJSON polygons. Always returns
    a dict; failures come back as ok=False with `skipped` (expected
    condition) or `err` (unexpected exception)."""
    t0 = time.time()
    try:
        item = _search_recent_scene(lat, lon)
        if item is None:
            return {"ok": False, "skipped": f"no <{MAX_CLOUD_PCT}% cloud "
                                            f"S2 in last {SEARCH_DAYS}d"}
        cc = float(item.properties.get("eo:cloud_cover", -1))
        # Soft budget checks between the expensive stages: a slow STAC
        # search or COG read degrades into a clean skip, not an overrun.
        if time.time() - t0 > timeout_s:
            return {"ok": False, "skipped": "stac search exceeded budget"}
        img, ref_da, epsg = _build_chip(item, lat, lon)
        if time.time() - t0 > timeout_s:
            return {"ok": False, "skipped": "chip build exceeded budget"}

        # Remote-first inference. Once a remote call has actually been
        # attempted we never fall back to local — only an unavailable /
        # disabled remote falls through to the local model below.
        remote_attempted = False
        try:
            from app import inference as _inf
            if _inf.remote_enabled():
                remote_attempted = True
                remote = _inf.prithvi_pluvial(
                    img, scene_id=item.id,
                    scene_datetime=str(item.datetime),
                    cloud_cover=cc,
                    timeout=timeout_s,
                )
                if remote.get("ok"):
                    # Remote returns the mask as base64 + shape; turn it
                    # into map polygons here (best-effort — the stats
                    # below are returned even if this fails).
                    polys = None
                    pred_b64 = remote.get("pred_b64")
                    pred_shape = remote.get("pred_shape")
                    if pred_b64 and pred_shape:
                        try:
                            xs = ref_da.x.values
                            ys = ref_da.y.values
                            from pyproj import Transformer
                            t_inv = Transformer.from_crs(
                                f"EPSG:{epsg}", "EPSG:4326",
                                always_xy=True)
                            # Lon/lat bbox of the FULL chip.
                            # NOTE(review): `img` sent to the remote is
                            # the IMG_SIZE center crop, but this bbox
                            # spans the whole ref_da chip — verify that
                            # polygonize_binary_mask expects the full
                            # extent; otherwise the polygons are scaled
                            # (compare _polygonize_mask, which applies
                            # the crop offsets).
                            minx, maxx = float(xs.min()), float(xs.max())
                            miny, maxy = float(ys.min()), float(ys.max())
                            minlon, minlat = t_inv.transform(minx, miny)
                            maxlon, maxlat = t_inv.transform(maxx, maxy)
                            from app.context._polygonize import (
                                polygonize_binary_mask,
                            )
                            polys = polygonize_binary_mask(
                                pred_b64, pred_shape,
                                (minlon, minlat, maxlon, maxlat),
                                label="water", fill_color="#1F77B4",
                                simplify_tolerance=2e-5,
                            )
                        except Exception:
                            log.exception("prithvi_live: remote polygonize failed")
                            polys = None
                    return {
                        "ok": True,
                        "item_id": item.id,
                        "item_datetime": str(item.datetime),
                        "cloud_cover": cc,
                        "pct_water_full": remote.get("pct_water_full"),
                        "pct_water_within_500m": remote.get("pct_water_within_500m"),
                        "polygons_geojson": polys,
                        "compute": f"remote · {remote.get('device', 'gpu')}",
                        "elapsed_s": round(time.time() - t0, 2),
                    }
                # Remote answered but reported failure; surface whichever
                # reason field it populated.
                err = (remote.get("err")
                       or remote.get("error")
                       or remote.get("skipped")
                       or "unknown")
                return {"ok": False,
                        "skipped": f"remote prithvi-pluvial non-ok: {err}",
                        "elapsed_s": round(time.time() - t0, 2)}
        # NOTE(review): if the `from app import inference` import itself
        # fails, evaluating `_inf.RemoteUnreachable` below raises
        # NameError, which the outer `except Exception` reports as an
        # ok=False `err` instead of falling through to local inference —
        # confirm app.inference is always importable here.
        except _inf.RemoteUnreachable as e:
            log.info("prithvi_live: remote unreachable (%s)", e)
            if remote_attempted:
                # The call was in flight when the remote vanished; do
                # not silently redo the work locally.
                return {"ok": False,
                        "skipped": f"remote prithvi-pluvial unreachable: {e}",
                        "elapsed_s": round(time.time() - t0, 2)}
        except Exception as e:
            log.exception("prithvi_live: remote call failed")
            if remote_attempted:
                return {"ok": False,
                        "skipped": f"remote prithvi-pluvial error: "
                                   f"{type(e).__name__}: {e}",
                        "elapsed_s": round(time.time() - t0, 2)}

        # Local inference path (terratorch model, lazily loaded).
        model, run_model = _ensure_model()
        # Add batch and time axes: (1, bands, 1, H, W) — presumably the
        # layout run_model expects; confirm against the IBM inference.py.
        x = img[None, :, None, :, :]
        pred_t = run_model(x, None, None, model.model, model.datamodule, IMG_SIZE)
        import numpy as np
        pred = pred_t[0].cpu().numpy().astype("uint8")
        pct_full = float(100.0 * pred.mean())
        # % water inside a CENTER_RADIUS_M-radius disc around the chip
        # center (i.e. around the queried address).
        yy, xx = np.indices(pred.shape)
        cy, cx = pred.shape[0] // 2, pred.shape[1] // 2
        radius_px = CENTER_RADIUS_M / PIXEL_M
        circle = (yy - cy) ** 2 + (xx - cx) ** 2 <= radius_px ** 2
        pct_500 = float(100.0 * pred[circle].mean()) if circle.sum() else 0.0
        polygons_geojson = _polygonize_mask(pred, ref_da, epsg)
        return {
            "ok": True,
            "item_id": item.id,
            "item_datetime": str(item.datetime),
            "cloud_cover": cc,
            "pct_water_full": pct_full,
            "pct_water_within_500m": pct_500,
            "polygons_geojson": polygons_geojson,
            "compute": "local",
            "elapsed_s": round(time.time() - t0, 2),
        }
    except Exception as e:
        log.exception("prithvi_live: fetch failed")
        return {"ok": False, "err": f"{type(e).__name__}: {e}",
                "elapsed_s": round(time.time() - t0, 2)}
|
|
|
|
def fetch(lat: float, lon: float, timeout_s: float = 60.0) -> dict[str, Any]:
    """Run the specialist. Wraps _fetch_inner in a bounded thread so that
    STAC searches and COG band reads (which lack per-request HTTP timeouts)
    cannot hang the FSM indefinitely.

    Returns a dict with at minimum:
      { "ok": bool, "skipped": str | None, "item_id": str | None,
        "cloud_cover": float | None, "pct_water_within_500m": float | None }
    Designed to never raise; failures show up as ok=False with an `err`.
    """
    if not ENABLE:
        return {"ok": False, "skipped": "RIPRAP_PRITHVI_LIVE_ENABLE=0"}
    if not _DEPS_OK:
        return {"ok": False,
                "skipped": f"deps unavailable on this deployment: "
                           f"{_DEPS_MISSING}"}
    hard_timeout = timeout_s + 15.0
    from app import emissions as _emissions
    # Propagate the caller's emissions tracker into the worker thread.
    _parent_tracker = _emissions.current()
    # Deliberately NOT `with ThreadPoolExecutor(...)`: __exit__ calls
    # shutdown(wait=True), which joins the very worker we are trying to
    # abandon when it hangs in a STAC/COG read — defeating the hard
    # timeout. shutdown(wait=False) returns immediately instead; a truly
    # hung thread may delay interpreter shutdown, but it no longer
    # blocks the request path.
    pool = concurrent.futures.ThreadPoolExecutor(
        max_workers=1,
        initializer=lambda t=_parent_tracker: _emissions.install(t),
    )
    try:
        future = pool.submit(_fetch_inner, lat, lon, timeout_s)
        try:
            return future.result(timeout=hard_timeout)
        except concurrent.futures.TimeoutError:
            log.warning("prithvi_live: hard timeout after %.0fs (STAC/COG hung)",
                        hard_timeout)
            return {"ok": False,
                    "skipped": f"prithvi_live timed out after {hard_timeout:.0f}s"}
    finally:
        pool.shutdown(wait=False, cancel_futures=True)
|
|