| """Per-query EO chip cache — Sentinel-2 L2A, Sentinel-1 RTC, DEM. |
| |
| Fetches a co-registered (S2L2A, S1RTC, DEM) chip centered on (lat, lon) |
| and returns a dict of torch tensors ready for TerraMind-NYC inference. |
| The TerraMind base was trained with `temporal_n_timestamps=4`, so this |
| helper expands a single S2/S1 acquisition to T=4 by repetition along |
| the temporal axis. Single-timestep nowcasting trades some training- |
| distribution match for a much simpler runtime — the published LoRA |
| adapters still produce sensible argmax masks at T=1 / tiled. |
| |
| Failure semantics mirror prithvi_live: every dependency or network |
| failure is converted to a clean `{ok: False, skipped: <reason>}` |
| result, never a raised exception. Callers (FSM specialists) that |
| chain off the chip can short-circuit on `ok=False` and skip the |
| specialist instead of surfacing a noisy error. |
| """ |
| from __future__ import annotations |
|
|
| import concurrent.futures |
| import logging |
| import os |
| import threading |
| import time |
| from typing import Any |
|
|
| log = logging.getLogger("riprap.eo_chip_cache") |
|
|
# Feature flag + tuning knobs; each is overridable via environment.
ENABLE = os.environ.get("RIPRAP_EO_CHIP_ENABLE", "1").lower() in ("1", "true", "yes")
# How far back (days) to search for a usable Sentinel-2 scene.
SEARCH_DAYS = int(os.environ.get("RIPRAP_EO_CHIP_SEARCH_DAYS", "120"))
# Maximum acceptable eo:cloud_cover (percent) for the S2 scene.
MAX_CLOUD_PCT = float(os.environ.get("RIPRAP_EO_CHIP_MAX_CLOUD", "30"))
# Chip edge length in pixels (output grid is CHIP_PX x CHIP_PX).
CHIP_PX = int(os.environ.get("RIPRAP_EO_CHIP_PX", "224"))
# Ground sampling distance of the chip grid, meters per pixel.
PIXEL_M = 10
# Temporal slots TerraMind expects (see module docstring: single
# acquisitions are tiled to T=4).
N_TIMESTEPS = 4


# Sentinel-2 L2A band asset keys fetched, 12 channels in this order.
S2_BANDS = ["B01", "B02", "B03", "B04", "B05", "B06", "B07",
            "B08", "B8A", "B09", "B11", "B12"]


# Sentinel-1 RTC polarization asset keys on Planetary Computer items.
S1_BANDS = ["vv", "vh"]
|
|
|
|
| def _has_required_deps() -> tuple[bool, str | None]: |
| missing: list[str] = [] |
| for name in ("planetary_computer", "pystac_client", |
| "rioxarray", "xarray", "torch", "numpy"): |
| try: |
| __import__(name) |
| except ImportError: |
| missing.append(name) |
| if missing: |
| return False, ", ".join(missing) |
| return True, None |
|
|
|
|
# Probe the optional dependencies once at import time; fetch()
# short-circuits with a clean "skipped" result when they are absent.
_DEPS_OK, _DEPS_MISSING = _has_required_deps()
# Held by _fetch_and_build for the entire fetch, so chip fetches from
# concurrent callers are serialized.
_FETCH_LOCK = threading.Lock()
|
|
|
|
def _search_s2(lat: float, lon: float):
    """Return (item, cloud_cover) for the best recent low-cloud S2L2A
    acquisition near (lat, lon), or (None, None) if no scene exists.

    "Best" = lowest eo:cloud_cover, ties broken by most recent
    acquisition, restricted to the last SEARCH_DAYS days and under
    MAX_CLOUD_PCT percent cloud.
    """
    import datetime as dt

    import planetary_computer as pc
    from pystac_client import Client

    # datetime.utcnow() is deprecated (Python 3.12+); use an aware UTC
    # "now" instead — .date() yields the identical calendar date.
    end = dt.datetime.now(dt.timezone.utc).date()
    start = end - dt.timedelta(days=SEARCH_DAYS)
    client = Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1",
        modifier=pc.sign_inplace,  # signs asset hrefs for download
    )
    delta = 0.02  # search box half-width in degrees (~2 km)
    search = client.search(
        collections=["sentinel-2-l2a"],
        bbox=[lon - delta, lat - delta, lon + delta, lat + delta],
        datetime=f"{start}/{end}",
        query={"eo:cloud_cover": {"lt": MAX_CLOUD_PCT}},
        max_items=20,
    )
    # Sort by (cloud asc, datetime desc); items without a datetime sort
    # as oldest within their cloud bucket.
    items = sorted(
        search.items(),
        key=lambda it: (it.properties.get("eo:cloud_cover", 100),
                        -(it.datetime.timestamp() if it.datetime else 0)),
    )
    if not items:
        return None, None
    item = items[0]
    # -1 signals "cloud cover unknown" to downstream metadata consumers.
    cc = float(item.properties.get("eo:cloud_cover", -1))
    return item, cc
|
|
|
|
def _search_s1(item_dt, lat: float, lon: float):
    """Return the closest Sentinel-1 RTC acquisition to the given S2
    datetime, or None if Planetary Computer has nothing nearby."""
    import datetime as dt

    import planetary_computer as pc
    from pystac_client import Client

    # +/- 10 days around the anchor S2 acquisition.
    window = dt.timedelta(days=10)
    stac = Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1",
        modifier=pc.sign_inplace,
    )
    half_deg = 0.02
    candidates = list(stac.search(
        collections=["sentinel-1-rtc"],
        bbox=[lon - half_deg, lat - half_deg, lon + half_deg, lat + half_deg],
        datetime=f"{(item_dt - window).isoformat()}/{(item_dt + window).isoformat()}",
        max_items=10,
    ).items())
    if not candidates:
        return None

    def _time_gap(it):
        # Items lacking a datetime sort last.
        if it.datetime is None:
            return 1e18
        return abs((it.datetime - item_dt).total_seconds())

    return min(candidates, key=_time_gap)
|
|
|
|
def _read_band(href: str, bbox_xy_meters, epsg: int):
    """Read a single COG band, clipped to the bbox, and resample to
    CHIP_PX × CHIP_PX. Returns a numpy array (CHIP_PX, CHIP_PX) float32.

    `bbox_xy_meters` is (minx, miny, maxx, maxy) in the projected CRS
    identified by `epsg` (see the clip_box keyword mapping below).
    """
    import numpy as np
    import rioxarray
    # squeeze(drop=True) collapses the singleton band axis -> (y, x).
    da = rioxarray.open_rasterio(href, masked=False).squeeze(drop=True)
    da = da.rio.clip_box(minx=bbox_xy_meters[0], miny=bbox_xy_meters[1],
                         maxx=bbox_xy_meters[2], maxy=bbox_xy_meters[3])
    if da.shape[-2] != CHIP_PX or da.shape[-1] != CHIP_PX:
        # The clipped window rarely lands on exactly CHIP_PX pixels
        # (source bands have differing native resolutions), so force the
        # target grid by "reprojecting" into the same CRS with an
        # explicit output shape.
        # NOTE(review): resampling=0 corresponds to nearest-neighbour in
        # rasterio's Resampling enum — confirm that is intended for all
        # band types (reflectance, SAR backscatter, DEM).
        da = da.rio.reproject(
            f"EPSG:{epsg}", shape=(CHIP_PX, CHIP_PX), resampling=0
        )
    arr = da.values.astype("float32")
    # Guard against NaN nodata values leaking into downstream tensors.
    return np.nan_to_num(arr)
|
|
|
|
def _fetch_modalities(lat: float, lon: float, timeout_s: float = 60.0) -> dict[str, Any]:
    """Fetch S2L2A + S1RTC + DEM as numpy arrays, resampled to a common
    CHIP_PX × CHIP_PX grid centered on (lat, lon).

    Only S2 is mandatory: a failed S2 search/download aborts with
    `{ok: False, skipped|err: ...}`.  S1 and DEM are best-effort and
    stay None when their fetch fails or the soft time budget is spent.
    On success the dict carries `s2` (12, H, W), `s1` (2, H, W) | None,
    `dem` (1, H, W) | None, plus per-scene metadata and `elapsed_s`.
    """
    import numpy as np
    from pyproj import Transformer

    t0 = time.time()
    item, cc = _search_s2(lat, lon)
    if item is None:
        return {"ok": False,
                "skipped": f"no <{MAX_CLOUD_PCT}% cloud S2 in last "
                           f"{SEARCH_DAYS}d"}
    # CRS discovery: older STAC items carry integer proj:epsg; newer
    # ones only proj:code ("EPSG:NNNNN").  Accept either form.
    if "proj:epsg" in item.properties:
        epsg = int(item.properties["proj:epsg"])
    else:
        code = item.properties.get("proj:code", "")
        if not code.startswith("EPSG:"):
            return {"ok": False,
                    "skipped": "STAC item missing proj:epsg / proj:code"}
        epsg = int(code.split(":", 1)[1])

    # Chip bounding box in the scene's projected CRS (meters): CHIP_PX
    # pixels at PIXEL_M m/pixel, centered on the query point.
    fwd = Transformer.from_crs("EPSG:4326", f"EPSG:{epsg}", always_xy=True)
    cx, cy = fwd.transform(lon, lat)
    half_m = CHIP_PX / 2 * PIXEL_M
    bbox = (cx - half_m, cy - half_m, cx + half_m, cy + half_m)

    # Soft budget check before the (slow) COG band downloads start.
    if time.time() - t0 > timeout_s:
        return {"ok": False, "skipped": "STAC search exceeded budget"}

    # ---- Sentinel-2 bands (required) ---------------------------------
    s2_arrs = []
    try:
        for b in S2_BANDS:
            href = item.assets[b].href
            s2_arrs.append(_read_band(href, bbox, epsg))
    except Exception as e:
        log.warning("eo_chip: S2 band fetch failed (%s); aborting", e)
        return {"ok": False, "err": f"S2 fetch failed: {type(e).__name__}: {e}"}
    s2 = np.stack(s2_arrs)
    # Heuristic rescale: L2A assets store reflectance * 10000, so a
    # chip mean > 1.0 indicates raw digital numbers.  NOTE(review):
    # data already in [0, 1] passes through unscaled — confirm intended.
    if s2.mean() > 1.0:
        s2 = s2 / 10000.0

    # ---- Sentinel-1 RTC (best-effort) --------------------------------
    s1: np.ndarray | None = None
    s1_meta: dict[str, Any] = {}
    if time.time() - t0 < timeout_s:
        try:
            s1_item = _search_s1(item.datetime, lat, lon)
            if s1_item is not None:
                s1_arrs = []
                for b in S1_BANDS:
                    href = s1_item.assets[b].href
                    s1_arrs.append(_read_band(href, bbox, epsg))
                s1 = np.stack(s1_arrs)
                s1_meta = {
                    "scene_id": s1_item.id,
                    "datetime": (s1_item.datetime.isoformat()
                                 if s1_item.datetime else None),
                }
        except Exception as e:
            # Best-effort: a missing S1 layer must not sink the chip.
            log.warning("eo_chip: S1 fetch best-effort failed: %s", e)

    # ---- Copernicus DEM GLO-30 (best-effort) -------------------------
    dem: np.ndarray | None = None
    if time.time() - t0 < timeout_s:
        try:
            import planetary_computer as pc
            from pystac_client import Client
            client = Client.open(
                "https://planetarycomputer.microsoft.com/api/stac/v1",
                modifier=pc.sign_inplace,
            )
            dem_search = client.search(
                collections=["cop-dem-glo-30"],
                bbox=[lon - 0.02, lat - 0.02, lon + 0.02, lat + 0.02],
                max_items=1,
            )
            dem_items = list(dem_search.items())
            if dem_items:
                href = dem_items[0].assets["data"].href
                dem = _read_band(href, bbox, epsg)
                # Add a leading channel axis: (H, W) -> (1, H, W).
                dem = dem[None, :, :]
        except Exception as e:
            log.warning("eo_chip: DEM fetch best-effort failed: %s", e)

    return {
        "ok": True,
        "lat": lat, "lon": lon,
        "epsg": epsg, "chip_px": CHIP_PX, "pixel_m": PIXEL_M,
        "s2": s2, "s1": s1, "dem": dem,
        "s2_meta": {
            "scene_id": item.id,
            "datetime": (item.datetime.isoformat() if item.datetime else None),
            "cloud_cover": cc,
        },
        "s1_meta": s1_meta,
        "elapsed_s": round(time.time() - t0, 2),
    }
|
|
|
|
| def _to_terramind_tensors(modalities: dict[str, Any]) -> dict[str, Any]: |
| """Shape numpy modality arrays into the (B, C, T, H, W) tensors |
| TerraMind expects with `temporal_n_timestamps=4`. Single-timestep |
| fetches get tiled to T=4 — same observation in every slot. |
| """ |
| import torch |
| s2 = modalities["s2"] |
| s2_t = torch.from_numpy(s2).float().unsqueeze(1) |
| s2_t = s2_t.repeat(1, N_TIMESTEPS, 1, 1).unsqueeze(0) |
| chips = {"S2L2A": s2_t} |
| if modalities.get("s1") is not None: |
| s1 = modalities["s1"] |
| s1_t = torch.from_numpy(s1).float().unsqueeze(1) |
| s1_t = s1_t.repeat(1, N_TIMESTEPS, 1, 1).unsqueeze(0) |
| chips["S1RTC"] = s1_t |
| if modalities.get("dem") is not None: |
| dem = modalities["dem"] |
| dem_t = torch.from_numpy(dem).float().unsqueeze(1) |
| dem_t = dem_t.repeat(1, N_TIMESTEPS, 1, 1).unsqueeze(0) |
| chips["DEM"] = dem_t |
| return chips |
|
|
|
|
def _fetch_and_build(lat: float, lon: float, timeout_s: float) -> dict[str, Any]:
    """Inner fetch + tensor build, run inside a bounded thread.

    Serialized on _FETCH_LOCK; converts every exception into an
    `{ok: False, err: ...}` result instead of raising.
    """
    with _FETCH_LOCK:
        try:
            result = _fetch_modalities(lat, lon, timeout_s=timeout_s)
        except Exception as e:
            log.exception("eo_chip: fetch failed")
            return {"ok": False, "err": f"{type(e).__name__}: {e}"}
        if not result.get("ok"):
            return result
        try:
            result["tensors"] = _to_terramind_tensors(result)
        except Exception as e:
            log.exception("eo_chip: tensor build failed")
            return {"ok": False,
                    "err": f"tensor build failed: {type(e).__name__}: {e}"}

        # Best-effort: attach the chip footprint as a WGS84 bbox by
        # projecting the four UTM corners back to lon/lat.  A failure
        # here is logged but does not invalidate the chip.
        try:
            from pyproj import Transformer
            half = (CHIP_PX * PIXEL_M) / 2.0
            utm = f"EPSG:{result['epsg']}"
            to_utm = Transformer.from_crs("EPSG:4326", utm, always_xy=True)
            to_ll = Transformer.from_crs(utm, "EPSG:4326", always_xy=True)
            cx, cy = to_utm.transform(lon, lat)
            corner_lls = [
                to_ll.transform(cx + sx * half, cy + sy * half)
                for sx in (-1, 1)
                for sy in (-1, 1)
            ]
            xs = [p[0] for p in corner_lls]
            ys = [p[1] for p in corner_lls]
            result["bounds_4326"] = (min(xs), min(ys), max(xs), max(ys))
        except Exception:
            log.exception("eo_chip: bounds_4326 reprojection failed")
        return result
|
|
|
|
def fetch(lat: float, lon: float, timeout_s: float = 60.0) -> dict[str, Any]:
    """Run the chip pipeline. Always returns a dict with at minimum
    `{ok, skipped|err, ...}`; on success the dict carries the
    co-registered numpy arrays plus `tensors` (the TerraMind-shaped
    torch dict).

    Runs in a worker thread so that STAC searches and COG band downloads
    (which use requests/rioxarray without per-call timeouts) are bounded
    by a hard wall-clock deadline even when the network hangs.
    """
    if not ENABLE:
        return {"ok": False, "skipped": "RIPRAP_EO_CHIP_ENABLE=0"}
    if not _DEPS_OK:
        return {"ok": False,
                "skipped": f"deps unavailable on this deployment: "
                           f"{_DEPS_MISSING}"}

    # Slack past the soft budget enforced inside _fetch_modalities.
    hard_timeout = timeout_s + 15.0
    # BUG FIX: do NOT use the executor as a context manager here.
    # `with` calls shutdown(wait=True) on exit, which blocks on the hung
    # worker after a TimeoutError and defeats the hard deadline.
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        future = pool.submit(_fetch_and_build, lat, lon, timeout_s)
        try:
            return future.result(timeout=hard_timeout)
        except concurrent.futures.TimeoutError:
            log.warning("eo_chip: hard timeout after %.0fs (STAC/COG hung)", hard_timeout)
            return {"ok": False, "skipped": f"eo_chip timed out after {hard_timeout:.0f}s"}
    finally:
        # wait=False returns immediately; a stuck worker is abandoned
        # rather than joined, so callers get their result on time.
        pool.shutdown(wait=False)
|
|