"""Run Prithvi-EO-2.0-300M-TL-Sen1Floods11 once on a low-cloud HLS scene
over NYC. Save the resulting water mask as a vectorized GeoJSON for use
as a Riprap flood-layer specialist.
This script defers to IBM's official inference.py (downloaded from the
model repo) rather than reimplementing the inference loop — that file
knows about the temporal/location-coord embeddings, the per-window
albumentations stack, and the upernet decoder output shape, all of
which are easy to get wrong.
python scripts/run_prithvi_flood.py
"""
from __future__ import annotations
import importlib.util
import json
import sys
import warnings
from pathlib import Path
# One-shot CLI script: blanket-suppress noisy third-party warnings.
warnings.filterwarnings("ignore")

# Repo root (this file lives one directory below it) and the data/ output dir.
ROOT = Path(__file__).resolve().parent.parent
OUT_DIR = ROOT / "data"
OUT_DIR.mkdir(exist_ok=True, parents=True)

# NYC needs two MGRS tiles to cover everything:
# T18TWL covers Manhattan, Bronx, western Brooklyn, Newark Bay
# T18TXK covers eastern Brooklyn, Queens, Far Rockaway, Jamaica Bay, Long Island Sound
# Each entry is (HLS scene id, acquisition date) — the date is also used in
# the staged-stack filename, so it must be unique per scene.
SCENES = [
    ("HLS.S30.T18TWL.2024247T153941.v2.0", "2024-09-04"),  # 1% cloud, central NYC
    ("HLS.S30.T18TXK.2024252T153819.v2.0", "2024-09-08"),  # 0% cloud, eastern NYC
]
SCENE_ID, SCENE_DATE = SCENES[0]  # back-compat for legacy users

# Hugging Face repo holding the fine-tuned model, its config, and inference.py.
MODEL_REPO = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
# Six HLS asset keys staged (in this order) as the model's input bands.
PRITHVI_BAND_NAMES = ["B02", "B03", "B04", "B8A", "B11", "B12"]
def _stage_stack(out_path: Path, scene_id: str = SCENE_ID) -> bool:
    """Stage a six-band HLS stack for *scene_id* as a float32 GeoTIFF.

    Fetches the bands in PRITHVI_BAND_NAMES from Microsoft Planetary
    Computer, rescales DN to 0..1 reflectance, and writes the result to
    *out_path*. Returns True when the stack is (or already was) on disk,
    False when the scene could not be retrieved.
    """
    if out_path.exists():
        return True

    # Heavy deps imported lazily so the already-staged fast path stays cheap.
    import numpy as np
    import planetary_computer
    import pystac_client
    import rasterio

    print(f"fetching scene {scene_id}...", file=sys.stderr)
    client = pystac_client.Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1",
        modifier=planetary_computer.sign_inplace,
    )
    item = client.get_collection("hls2-s30").get_item(scene_id)
    if item is None:
        print(" scene not retrievable", file=sys.stderr)
        return False

    # Read each band; keep the first band's profile as the output template.
    band_arrays = []
    meta = None
    for band_name in PRITHVI_BAND_NAMES:
        with rasterio.open(item.assets[band_name].href) as src:
            band_arrays.append(src.read(1))
            if meta is None:
                meta = src.profile.copy()

    cube = np.stack(band_arrays, axis=0).astype("float32")
    # Replace nodata -9999 with the inference.py NO_DATA_FLOAT sentinel (0.0001).
    # inference.py only treats nodata correctly when explicit mean/std are
    # configured — for this Sen1Floods11 fine-tune mean/std are None, so we
    # do the substitution upstream and write a clean float32 raster in 0..1
    # reflectance units (constant_scale=0.0001 in config => DN/10000).
    cube[cube <= -9000] = 0.0
    cube = np.clip(cube / 10000.0, 0.0, 1.0).astype("float32")

    meta.update(count=6, dtype="float32",
                compress="DEFLATE", tiled=True,
                blockxsize=256, blockysize=256, nodata=0.0)
    with rasterio.open(out_path, "w", **meta) as dst:
        for band_index, plane in enumerate(cube, start=1):
            dst.write(plane, band_index)

    print(f" wrote {out_path} ({out_path.stat().st_size // (1024*1024)} MB) "
          f"(reflectance units, nodata→0)", file=sys.stderr)
    return True
def _process_one(scene_id: str, scene_date: str) -> list[dict]:
    """Stage one MGRS tile, run Prithvi, vectorise to features. Returns
    a list of GeoJSON Features in EPSG:4326 (so they can be merged across
    tiles in different UTM zones)."""
    stack_path = OUT_DIR / f"hls_stack_{scene_date}.tif"
    if not _stage_stack(stack_path, scene_id=scene_id):
        return []

    # Pull IBM's official inference driver, config, and weights from the Hub
    # and import inference.py as a module so its main() can be called directly.
    from huggingface_hub import hf_hub_download
    inference_src = hf_hub_download(MODEL_REPO, "inference.py")
    config_path = hf_hub_download(MODEL_REPO, "config.yaml")
    weights_path = hf_hub_download(MODEL_REPO, "Prithvi-EO-V2-300M-TL-Sen1Floods11.pt")
    module_spec = importlib.util.spec_from_file_location("prithvi_inf", inference_src)
    prithvi = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(prithvi)

    run_dir = OUT_DIR / "prithvi_runs"
    run_dir.mkdir(exist_ok=True)
    pred_path = run_dir / f"pred_{stack_path.stem}.tiff"
    if pred_path.exists():
        print(f" reusing existing pred: {pred_path}", file=sys.stderr)
    else:
        print(f"running Prithvi on {scene_id}...", file=sys.stderr)
        prithvi.main(data_file=str(stack_path), config=config_path,
                     checkpoint=weights_path, output_dir=str(run_dir),
                     rgb_outputs=False, input_indices=None)
    if not pred_path.exists():
        # inference.py's output naming can vary — fall back to a prefix glob.
        candidates = list(run_dir.glob(f"pred_{stack_path.stem}*"))
        pred_path = candidates[0] if candidates else None
    if pred_path is None or not pred_path.exists():
        print(f" no prediction tiff for {scene_id}", file=sys.stderr)
        return []

    import geopandas as gpd
    import rasterio
    from rasterio.features import shapes
    from shapely.geometry import mapping, shape

    with rasterio.open(pred_path) as src:
        pred = src.read(1)
        transform = src.transform
        src_crs = src.crs
    water_mask = pred == 255
    n_water = int(water_mask.sum())
    print(f" {scene_id}: {n_water} water px "
          f"({100*n_water/pred.size:.2f}%)", file=sys.stderr)

    # Vectorise the binary mask into per-polygon GeoJSON Features.
    features = []
    for geom, value in shapes(water_mask.astype("uint8"),
                              mask=water_mask, transform=transform):
        if value != 1:
            continue
        polygon = shape(geom)
        if polygon.area > 0:
            features.append({"type": "Feature",
                             "geometry": mapping(polygon),
                             "properties": {"class": "water",
                                            "scene_id": scene_id,
                                            "scene_date": scene_date}})
    if not features:
        return []
    # Reproject to EPSG:4326 for cross-tile merging
    frame = gpd.GeoDataFrame.from_features(features, crs=src_crs)
    frame = frame.to_crs("EPSG:4326")
    return json.loads(frame.to_json())["features"]
def main() -> int:
    """Run every scene in SCENES and merge the water polygons into a single
    GeoJSON FeatureCollection at data/prithvi_flood_nyc.geojson. Skips all
    work if that file already exists. Returns a process exit code (0)."""
    out_geojson = OUT_DIR / "prithvi_flood_nyc.geojson"
    if out_geojson.exists():
        print(f"already exists: {out_geojson}", file=sys.stderr)
        return 0

    merged: list[dict] = []
    used_ids: list[str] = []
    used_dates: list[str] = []
    for sid, sdate in SCENES:
        scene_feats = _process_one(sid, sdate)
        merged.extend(scene_feats)
        if scene_feats:
            # Only record scenes that actually contributed polygons.
            used_ids.append(sid)
            used_dates.append(sdate)

    payload = {"type": "FeatureCollection", "features": merged,
               "scene_ids": used_ids, "scene_dates": used_dates,
               "model": MODEL_REPO, "crs": "EPSG:4326"}
    out_geojson.write_text(json.dumps(payload))
    print(f"\nwrote {len(merged)} water polygons across "
          f"{len(used_ids)} scenes -> {out_geojson} "
          f"({out_geojson.stat().st_size // 1024} KB)", file=sys.stderr)
    return 0
if __name__ == "__main__":
sys.exit(main())
|