| """Run Prithvi-EO-2.0-300M-TL-Sen1Floods11 once on a low-cloud HLS scene |
| over NYC. Save the resulting water mask as a vectorized GeoJSON for use |
| as a Riprap flood-layer specialist. |
| |
| This script defers to IBM's official inference.py (downloaded from the |
| model repo) rather than reimplementing the inference loop — that file |
| knows about the temporal/location-coord embeddings, the per-window |
| albumentations stack, and the upernet decoder output shape, all of |
| which are easy to get wrong. |
| |
| python scripts/run_prithvi_flood.py |
| """ |
| from __future__ import annotations |
|
|
| import importlib.util |
| import json |
| import sys |
| import warnings |
| from pathlib import Path |
|
|
# Silence third-party warnings (rasterio/pystac/torch are noisy on import/open).
warnings.filterwarnings("ignore")


# Repo root (this file lives one directory below it); outputs staged in data/.
ROOT = Path(__file__).resolve().parent.parent
OUT_DIR = ROOT / "data"
OUT_DIR.mkdir(exist_ok=True, parents=True)


# (HLS granule id, acquisition date) pairs to process — per the module
# docstring these are low-cloud Sentinel-2 (S30) tiles over the NYC area.
SCENES = [
    ("HLS.S30.T18TWL.2024247T153941.v2.0", "2024-09-04"),
    ("HLS.S30.T18TXK.2024252T153819.v2.0", "2024-09-08"),
]
# Default scene, used when _stage_stack is called without an explicit id.
SCENE_ID, SCENE_DATE = SCENES[0]
MODEL_REPO = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
# The six HLS bands the Sen1Floods11 fine-tune consumes, in model order
# (blue, green, red, narrow NIR, SWIR1, SWIR2 — TODO confirm against config.yaml).
PRITHVI_BAND_NAMES = ["B02", "B03", "B04", "B8A", "B11", "B12"]
|
|
|
|
def _stage_stack(out_path: Path, scene_id: str = SCENE_ID) -> bool:
    """Download one HLS S30 granule and stage it as a multi-band GeoTIFF.

    Fetches the bands listed in PRITHVI_BAND_NAMES from the Microsoft
    Planetary Computer STAC API, scales DN to reflectance (divide by
    10000, clip to [0, 1]) with fill values (<= -9000) mapped to 0, and
    writes a tiled, DEFLATE-compressed float32 stack to *out_path*.

    Returns True on success (or when *out_path* already exists), False
    when the scene cannot be retrieved.
    """
    # Idempotent: reuse a previously staged stack.
    if out_path.exists():
        return True
    # Heavy geo dependencies are imported lazily so importing this module
    # does not require them.
    import numpy as np
    import planetary_computer
    import pystac_client
    import rasterio

    print(f"fetching scene {scene_id}...", file=sys.stderr)
    catalog = pystac_client.Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1",
        modifier=planetary_computer.sign_inplace,
    )
    item = catalog.get_collection("hls2-s30").get_item(scene_id)
    if item is None:
        print(" scene not retrievable", file=sys.stderr)
        return False

    arrays = []
    profile = None
    for band in PRITHVI_BAND_NAMES:
        with rasterio.open(item.assets[band].href) as ds:
            arrays.append(ds.read(1))
            # Keep the first band's georeferencing for the output profile.
            if profile is None:
                profile = ds.profile.copy()
    stack = np.stack(arrays, axis=0).astype("float32")

    # HLS fill values are large negatives (-9999); zero them before scaling
    # so they coincide with the declared nodata value below.
    stack[stack <= -9000] = 0.0
    stack = stack / 10000.0  # DN -> surface reflectance
    stack = np.clip(stack, 0.0, 1.0).astype("float32")

    # Derive the band count from the band list (was hardcoded to 6) so the
    # output always matches PRITHVI_BAND_NAMES.
    n_bands = len(PRITHVI_BAND_NAMES)
    profile.update(count=n_bands, dtype="float32",
                   compress="DEFLATE", tiled=True,
                   blockxsize=256, blockysize=256, nodata=0.0)
    with rasterio.open(out_path, "w", **profile) as ds:
        for i in range(n_bands):
            ds.write(stack[i], i + 1)  # rasterio bands are 1-indexed
    print(f" wrote {out_path} ({out_path.stat().st_size // (1024*1024)} MB) "
          f"(reflectance units, nodata→0)", file=sys.stderr)
    return True
|
|
|
|
def _process_one(scene_id: str, scene_date: str) -> list[dict]:
    """Stage one MGRS tile, run Prithvi, vectorise to features. Returns
    a list of GeoJSON Features in EPSG:4326 (so they can be merged across
    tiles in different UTM zones)."""
    stack_path = OUT_DIR / f"hls_stack_{scene_date}.tif"
    if not _stage_stack(stack_path, scene_id=scene_id):
        return []

    # Pull IBM's official inference script plus config and checkpoint from
    # the model repo; hf_hub_download caches locally after the first fetch.
    from huggingface_hub import hf_hub_download
    inf_py = hf_hub_download(MODEL_REPO, "inference.py")
    cfg = hf_hub_download(MODEL_REPO, "config.yaml")
    ckpt = hf_hub_download(MODEL_REPO, "Prithvi-EO-V2-300M-TL-Sen1Floods11.pt")

    # Import the downloaded inference.py as a module so we can call its
    # main() directly instead of shelling out.
    spec = importlib.util.spec_from_file_location("prithvi_inf", inf_py)
    pm = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(pm)

    out_dir = OUT_DIR / "prithvi_runs"
    out_dir.mkdir(exist_ok=True)

    # inference.py is assumed to name its output pred_<input stem>.tiff;
    # a glob fallback below covers deviations from that guess.
    pred_path = out_dir / f"pred_{stack_path.stem}.tiff"
    if not pred_path.exists():
        print(f"running Prithvi on {scene_id}...", file=sys.stderr)
        pm.main(data_file=str(stack_path), config=cfg, checkpoint=ckpt,
                output_dir=str(out_dir), rgb_outputs=False, input_indices=None)
    else:
        print(f" reusing existing pred: {pred_path}", file=sys.stderr)

    # Fallback: accept any pred_<stem>* file if the exact-name guess missed
    # (e.g. a different extension).
    if not pred_path.exists():
        cands = list(out_dir.glob(f"pred_{stack_path.stem}*"))
        pred_path = cands[0] if cands else None
    if pred_path is None or not pred_path.exists():
        print(f" no prediction tiff for {scene_id}", file=sys.stderr)
        return []

    # Vectorisation dependencies, imported lazily like the staging ones.
    import geopandas as gpd
    import rasterio
    from rasterio.features import shapes
    from shapely.geometry import mapping, shape

    with rasterio.open(pred_path) as ds:
        pred = ds.read(1); transform = ds.transform; src_crs = ds.crs

    # NOTE(review): water pixels are assumed to be encoded as 255 in the
    # prediction raster — confirm against inference.py's output encoding.
    water_mask = pred == 255
    n_water = int(water_mask.sum())
    print(f" {scene_id}: {n_water} water px "
          f"({100*n_water/pred.size:.2f}%)", file=sys.stderr)

    # Vectorise: each connected region of water pixels becomes one Polygon
    # feature tagged with its source scene. With mask=water_mask, shapes()
    # only yields masked regions, so the val == 1 check is belt-and-braces.
    feats = []
    for geom, val in shapes(water_mask.astype("uint8"),
                            mask=water_mask, transform=transform):
        if val == 1:
            poly = shape(geom)
            if poly.area > 0:
                feats.append({"type": "Feature",
                              "geometry": mapping(poly),
                              "properties": {"class": "water",
                                             "scene_id": scene_id,
                                             "scene_date": scene_date}})

    if not feats:
        return []

    # Reproject from the tile's native (UTM) CRS to EPSG:4326 so features
    # from tiles in different UTM zones can be merged into one collection.
    g = gpd.GeoDataFrame.from_features(feats, crs=src_crs)
    g = g.to_crs("EPSG:4326")
    return json.loads(g.to_json())["features"]
|
|
|
|
def main() -> int:
    """Process every scene in SCENES and merge the resulting water
    polygons into one EPSG:4326 FeatureCollection on disk. Returns 0."""
    out_geojson = OUT_DIR / "prithvi_flood_nyc.geojson"
    # Idempotent: never redo inference when the merged output already exists.
    if out_geojson.exists():
        print(f"already exists: {out_geojson}", file=sys.stderr)
        return 0

    merged: list[dict] = []
    used_ids: list[str] = []
    used_dates: list[str] = []
    for sid, sdate in SCENES:
        features = _process_one(sid, sdate)
        # Only record scenes that actually contributed polygons.
        if features:
            used_ids.append(sid)
            used_dates.append(sdate)
        merged.extend(features)

    collection = {
        "type": "FeatureCollection",
        "features": merged,
        "scene_ids": used_ids,
        "scene_dates": used_dates,
        "model": MODEL_REPO,
        "crs": "EPSG:4326",
    }
    out_geojson.write_text(json.dumps(collection))
    print(f"\nwrote {len(merged)} water polygons across "
          f"{len(used_ids)} scenes -> {out_geojson} "
          f"({out_geojson.stat().st_size // 1024} KB)", file=sys.stderr)
    return 0
|
|
|
|
if __name__ == "__main__":
    # Script entry point: propagate main()'s exit status to the shell.
    sys.exit(main())
|
|