Spaces:
Paused
Paused
| """ | |
FastAPI entry point for MicroClimate-X.

Endpoints
---------
GET /              — name / version / banner
GET /api/predict   — main prediction endpoint (?lat=&lon=&activity=)
GET /api/health    — JSON health + cache stats + DB latency
GET /api/version   — version metadata for clients

Lifespan
--------
* On startup: WAL-mode SQLite init, prune expired cache rows, load ML model.
* On shutdown: dispose of the shared httpx.AsyncClient.

Resilience
----------
* `RequestIDMiddleware` stamps every request with `X-Request-ID` for log
  correlation (taken from the incoming header if present, otherwise generated).
* All exceptions surface as an `errors.ErrorResponse` JSON document — no
  bare 500 HTML responses leak.
| """ | |
from __future__ import annotations

import asyncio
import datetime as _dt
import logging
import math
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any

import httpx
from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from tenacity import RetryError, retry, stop_after_attempt, wait_exponential

from . import cache, config, rule_engine, terrain
from .errors import (
    ERR_INTERNAL,
    ERR_INVALID_INPUT,
    ERR_UPSTREAM_FAILURE,
    ErrorResponse,
)
from .ml_engine import MLEngine
from .schemas import ActivityType, PredictionResponse
__version__ = "1.0.0"

# ─────────────────────────────────────────────────────────────────────────────
# Logging — structured records: ts | level | request_id | message
# ─────────────────────────────────────────────────────────────────────────────
class _RequestIDFilter(logging.Filter):
    """Make sure every log record exposes a ``request_id`` attribute.

    Records emitted outside a request, or logged without
    ``extra={"request_id": ...}``, receive the placeholder ``"-"`` so the
    handler's ``%(request_id)s`` format field always has a value.
    The filter never drops a record.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        # Keep an existing id; otherwise fall back to the "-" placeholder.
        record.request_id = getattr(record, "request_id", "-")
        return True
# A single stderr handler on the root logger; the filter guarantees the
# `request_id` field exists, and `force=True` replaces any handlers that were
# configured before this module was imported.
_handler = logging.StreamHandler()
_handler.addFilter(_RequestIDFilter())
_handler.setFormatter(
    logging.Formatter(
        fmt="%(asctime)s | %(levelname)-7s | %(request_id)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )
)
logging.basicConfig(handlers=[_handler], level=logging.INFO, force=True)

log = logging.getLogger("microclimate-x")
# ─────────────────────────────────────────────────────────────────────────────
# Lifespan: model + DB + HTTP client
# ─────────────────────────────────────────────────────────────────────────────
# Upper bound on how long shutdown waits for fire-and-forget cache/audit tasks.
_BG_DRAIN_TIMEOUT_S = 5.0


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up shared resources on startup and release them on shutdown.

    Startup:
      * ``cache.init_db()`` and ``cache.prune_expired()`` (logs pruned rows).
      * Loads the ML model; logs a warning when no trained model is found.
      * Stores the engine, an empty background-task set, a shared
        ``httpx.AsyncClient`` and the start timestamp on ``app.state``.

    Shutdown:
      * Waits (bounded by ``_BG_DRAIN_TIMEOUT_S``) for still-running tasks in
        ``app.state.bg_tasks`` so in-flight cache writes are not dropped.
      * Closes the shared HTTP client.
    """
    log.info("Starting MicroClimate-X backend (v%s)…", __version__)
    await cache.init_db()
    pruned = await cache.prune_expired()
    if pruned:
        log.info("Cache GC removed %d expired rows on startup.", pruned)

    engine = MLEngine()
    engine.load()
    if engine.is_loaded:
        log.info("ML model loaded from %s", engine.loaded_from)
    else:
        log.warning(
            "No trained model found — falling back to heuristic predictor. "
            "Run scripts/3_train_model.py to enable Random Forest."
        )

    app.state.ml = engine
    app.state.bg_tasks = set()
    app.state.http = httpx.AsyncClient(timeout=15.0, http2=False)
    app.state.start_ts = time.time()
    try:
        yield
    finally:
        # Drain background tasks before tearing anything down. asyncio.wait()
        # rejects an empty collection, hence the guard.
        pending = [t for t in getattr(app.state, "bg_tasks", ()) if not t.done()]
        if pending:
            _, still_running = await asyncio.wait(pending, timeout=_BG_DRAIN_TIMEOUT_S)
            if still_running:
                log.warning(
                    "%d background task(s) still running at shutdown.",
                    len(still_running),
                )
        await app.state.http.aclose()
        log.info("Shutdown complete.")
app = FastAPI(
    title="MicroClimate-X API",
    version=__version__,
    description=(
        "Hybrid microclimate risk assessment for complex terrain. "
        "Combines a Random Forest macro-rain predictor with a topographic "
        "rule-based expert system (Veto cascade + R1-R4 decision table "
        "+ activity-aware composite). "
        "Implements proposal §3.7 — sub-process P4.1 through P4.6."
    ),
    lifespan=lifespan,
)

# Any origin may call this GET-only API; the request-ID and timing headers are
# listed in `expose_headers` so browser clients are allowed to read them.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET"],
    allow_headers=["*"],
    expose_headers=["X-Request-ID", "X-Response-Time-ms"],
)
# ─────────────────────────────────────────────────────────────────────────────
# Request-ID + timing middleware
# ─────────────────────────────────────────────────────────────────────────────
# Incoming request IDs longer than this are ignored and replaced.
_MAX_REQUEST_ID_LEN = 64


def _sanitize_request_id(raw: str | None) -> str:
    """Return *raw* if it is a short, log-safe token; otherwise a fresh ID.

    The ``X-Request-ID`` header is client-controlled and is echoed into log
    lines and the response header, so only ASCII letters/digits and ``-_.:``
    up to ``_MAX_REQUEST_ID_LEN`` characters are accepted.
    """
    if (
        raw
        and len(raw) <= _MAX_REQUEST_ID_LEN
        and all(ch.isascii() and (ch.isalnum() or ch in "-_.:") for ch in raw)
    ):
        return raw
    return uuid.uuid4().hex[:12]


class RequestIDMiddleware(BaseHTTPMiddleware):
    """Tag every request with `X-Request-ID` and measure latency.

    The ID propagates from the incoming header (so a load balancer / front end
    can supply one) when it passes `_sanitize_request_id`, and otherwise falls
    back to a 12-hex-char UUID4 prefix. Both the ID and the elapsed time are
    returned as response headers, including on the 500 fallback path.
    """

    async def dispatch(self, request: Request, call_next):
        req_id = _sanitize_request_id(request.headers.get("x-request-id"))
        # Stash on request state so handlers can read it.
        request.state.request_id = req_id
        start = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception:  # pragma: no cover
            elapsed_ms = int((time.perf_counter() - start) * 1000)
            log.exception(
                "unhandled exception",
                extra={"request_id": req_id, "path": request.url.path,
                       "elapsed_ms": elapsed_ms},
            )
            response = _json_error(
                req_id, 500, ERR_INTERNAL,
                "Internal server error — please retry.",
            )
        else:
            elapsed_ms = int((time.perf_counter() - start) * 1000)

        response.headers["X-Request-ID"] = req_id
        response.headers["X-Response-Time-ms"] = str(elapsed_ms)

        # Noise control: log only API/root calls, and skip CORS preflights.
        path = request.url.path
        if request.method != "OPTIONS" and (path == "/" or path.startswith("/api/")):
            log.info(
                "%s %s -> %d (%d ms)",
                request.method, path, response.status_code, elapsed_ms,
                extra={"request_id": req_id},
            )
        return response


app.add_middleware(RequestIDMiddleware)
# ─────────────────────────────────────────────────────────────────────────────
# Exception handlers — every error follows the ErrorResponse schema
# ─────────────────────────────────────────────────────────────────────────────
def _json_error(req_id: str | None, status: int, code: str, detail: str,
                ctx: dict[str, Any] | None = None) -> JSONResponse:
    """Wrap an error in an `ErrorResponse` document and return it as JSON.

    Fields whose value is ``None`` (e.g. a missing request id or context) are
    omitted from the body.
    """
    body = ErrorResponse(
        error=code,
        detail=detail,
        request_id=req_id,
        context=ctx,
    ).model_dump(exclude_none=True)
    return JSONResponse(content=body, status_code=status)
@app.exception_handler(RequestValidationError)
async def _on_validation_error(request: Request, exc: RequestValidationError):
    """Render a 422 validation failure as an `ErrorResponse`.

    At most the first five validation errors are returned in ``context``. They
    are passed through `jsonable_encoder` because pydantic error entries can
    carry non-JSON objects (e.g. an exception instance under ``ctx``), which
    would otherwise break serialisation of the error response itself.
    """
    req_id = getattr(request.state, "request_id", None)
    return _json_error(
        req_id, 422, ERR_INVALID_INPUT,
        "One or more query parameters failed validation.",
        ctx={"errors": jsonable_encoder(exc.errors()[:5])},
    )
@app.exception_handler(StarletteHTTPException)
async def _on_http_exception(request: Request, exc: StarletteHTTPException):
    """Render any HTTP exception as an `ErrorResponse`.

    Registered for Starlette's base class so router-generated 404/405 errors
    are covered as well as FastAPI's `HTTPException` subclass. Mapping:
    502/503/504 → upstream failure, other 4xx → invalid input,
    anything else → internal. Headers carried by the exception (e.g.
    ``Allow`` on a 405) are preserved.
    """
    req_id = getattr(request.state, "request_id", None)
    status = exc.status_code
    if status in {502, 503, 504}:
        code = ERR_UPSTREAM_FAILURE
    elif 400 <= status < 500:
        code = ERR_INVALID_INPUT
    else:
        code = ERR_INTERNAL
    response = _json_error(req_id, status, code, str(exc.detail))
    if exc.headers:
        response.headers.update(exc.headers)
    return response
@app.exception_handler(Exception)
async def _on_unhandled(request: Request, exc: Exception):  # pragma: no cover
    """Last-resort handler: log the traceback and return a generic 500 body."""
    req_id = getattr(request.state, "request_id", None)
    # Pass the exception explicitly rather than relying on sys.exc_info().
    log.error("unhandled top-level exception", exc_info=exc,
              extra={"request_id": req_id or "-"})
    return _json_error(
        req_id, 500, ERR_INTERNAL,
        "Internal server error — please retry. If the problem persists, file an issue.",
    )
# ─────────────────────────────────────────────────────────────────────────────
# Frontend static files (optional — only if /frontend exists alongside backend)
# ─────────────────────────────────────────────────────────────────────────────
FRONTEND_DIR = config.ROOT / "frontend"
# is_dir() rather than exists(): StaticFiles rejects a non-directory path,
# which would otherwise abort module import.
if FRONTEND_DIR.is_dir():
    app.mount("/app", StaticFiles(directory=FRONTEND_DIR, html=True), name="frontend")
# ─────────────────────────────────────────────────────────────────────────────
# Health & version & root
# ─────────────────────────────────────────────────────────────────────────────
@app.get("/")
async def root() -> dict[str, Any]:
    """Banner: service name, version, model status and useful URLs."""
    return {
        "name": "MicroClimate-X",
        "version": __version__,
        "ml_loaded": app.state.ml.is_loaded,
        "frontend_url": "/app/",
        "docs_url": "/docs",
        "openapi_url": "/openapi.json",
    }
@app.get("/api/version")
async def version() -> dict[str, Any]:
    """Return version metadata for clients.

    ``ml_features`` lists the model's feature columns, truncated to the first
    five names plus an ``"…"`` marker when there are more than five. A missing
    (``None``) column list is reported as an empty list.
    """
    ml = app.state.ml
    cols = ml.feature_columns
    # list() also turns array-like column containers into JSON-friendly lists.
    features = [] if cols is None else list(cols)
    if len(features) > 5:
        features = [*features[:5], "…"]
    return {
        "version": __version__,
        "git_revision": config.GIT_REVISION,
        "ml_loaded": ml.is_loaded,
        "ml_loaded_from": ml.loaded_from,
        "ml_features": features,
    }
@app.get("/api/health")
async def health() -> dict[str, Any]:
    """Liveness report: uptime, model status, cache statistics and DB path."""
    stats = await cache.cache_stats()
    return {
        "status": "ok",
        "uptime_sec": int(time.time() - app.state.start_ts),
        "ml_loaded": app.state.ml.is_loaded,
        "cache": stats,
        "db_path": str(config.DB_PATH),
        "version": __version__,
    }
# ─────────────────────────────────────────────────────────────────────────────
# External fetching helpers
# ─────────────────────────────────────────────────────────────────────────────
# Variables requested from the forecast API's `current` block.
_CURRENT_WEATHER_VARS = (
    "temperature_2m", "relative_humidity_2m", "precipitation",
    "wind_speed_10m", "wind_direction_10m", "surface_pressure",
    "dew_point_2m", "cloud_cover", "cape", "visibility",
)


def _parse_current_weather(raw: dict[str, Any] | None) -> dict[str, Any]:
    """Map the API's `current` block onto the internal weather dict.

    Keys with a default get that default both when the key is absent AND when
    its value is ``None``. Plain ``dict.get(key, default)`` would let an
    explicit JSON ``null`` through as ``None``, which later arithmetic (wind
    vectors, orographic lift) cannot handle. Keys without a default
    (temperature, humidity, pressure, dew point) stay ``None`` if missing.
    """
    raw = raw or {}

    def pick(key: str, default: Any = None) -> Any:
        value = raw.get(key)
        return default if value is None else value

    return {
        "temperature_c": pick("temperature_2m"),
        "humidity_pct": pick("relative_humidity_2m"),
        "precipitation_mm": pick("precipitation", 0.0),
        "wind_speed_kmh": pick("wind_speed_10m", 0.0),
        "wind_direction_deg": pick("wind_direction_10m", 0.0),
        "pressure_hpa": pick("surface_pressure"),
        "dew_point_c": pick("dew_point_2m"),
        "cloud_cover_pct": pick("cloud_cover", 0.0),
        "cape_jkg": pick("cape", 0.0),
        "visibility_m": pick("visibility", 10000.0),
    }


async def _fetch_current_weather(client: httpx.AsyncClient, lat: float, lon: float) -> dict[str, Any]:
    """Fetch current conditions at (lat, lon) and normalise them.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response (via raise_for_status).
        ValueError: if the body is not valid JSON.
    """
    resp = await client.get(
        config.OPEN_METEO_FORECAST_URL,
        params={
            "latitude": lat,
            "longitude": lon,
            "current": ",".join(_CURRENT_WEATHER_VARS),
            "windspeed_unit": "kmh",
            "timezone": "auto",
        },
        timeout=15.0,
    )
    resp.raise_for_status()
    # `or` also covers an explicit `"current": null` in the payload.
    return _parse_current_weather(resp.json().get("current"))
# ─────────────────────────────────────────────────────────────────────────────
# Main endpoint
# ─────────────────────────────────────────────────────────────────────────────
def _log_bg_failure(task: asyncio.Task[Any]) -> None:
    """Done-callback: log an exception raised by a fire-and-forget task.

    Without it, a failing cache write or audit log would only surface as
    asyncio's "exception was never retrieved" warning.
    """
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        log.error("background task failed", exc_info=exc)


@app.get("/api/predict", response_model=PredictionResponse)
async def predict(
    request: Request,
    lat: float = Query(..., ge=-90.0, le=90.0, description="Latitude (WGS84)"),
    lon: float = Query(..., ge=-180.0, le=180.0, description="Longitude (WGS84)"),
    activity: ActivityType = Query(
        "general",
        description="User activity context — affects composite score weighting (D5 §3.7 / P4.4).",
    ),
) -> PredictionResponse:
    """Return the microclimate risk assessment for (lat, lon, activity).

    Flow: cache lookup → parallel DEM + weather fetch → terrain classification
    → ML rain probability → rule engine → response. The result is then cached
    and audit-logged in background tasks that do not delay the response.

    Raises:
        HTTPException(502): an upstream DEM/weather call failed.
        HTTPException(500): model inference failed.
    """
    req_id = getattr(request.state, "request_id", "-")

    # ── Cache lookup first (per-coordinate + per-activity) ──
    hit = await cache.get(lat, lon, activity=activity)
    if hit is not None:
        payload, ttl_remaining = hit
        log.info("cache hit (ttl_remaining=%ds)", ttl_remaining, extra={"request_id": req_id})
        # Build a fresh dict instead of mutating whatever object the cache returned.
        return PredictionResponse(**{**payload, "cached": True, "cache_ttl": ttl_remaining})

    client: httpx.AsyncClient = app.state.http

    # ── Fetch DEM (terrain) and weather in parallel ──
    try:
        dem9, weather = await asyncio.gather(
            terrain.fetch_dem_3x3(lat, lon, client),
            _fetch_current_weather(client, lat, lon),
        )
    except (httpx.HTTPError, RetryError, ValueError) as exc:
        log.warning(
            "upstream API failure: %s",
            type(exc).__name__,
            extra={"request_id": req_id},
        )
        raise HTTPException(
            status_code=502,
            detail=f"Upstream weather/DEM service unavailable ({type(exc).__name__}). "
                   f"Please retry shortly.",
        ) from exc

    tinfo = terrain.classify_terrain(dem9)
    # `.get(key, 0.0)` would not replace an explicit None, so coalesce here.
    wind_dir = weather.get("wind_direction_deg") or 0.0
    orographic_dot = (
        terrain.orographic_lift_dot(wind_dir, tinfo.aspect_deg, tinfo.slope_deg)
        if tinfo.terrain == "Slope"
        else 0.0
    )

    # ── Build ML feature dict ──
    feats = _build_ml_features(weather, tinfo.elevation_m)
    try:
        ml_prob = app.state.ml.predict_rain_probability(feats)
    except Exception as exc:  # pragma: no cover
        log.exception("ML inference failed", extra={"request_id": req_id})
        # The traceback is logged above; don't echo the exception repr to clients.
        raise HTTPException(
            status_code=500,
            detail=f"Model inference failed ({type(exc).__name__}).",
        ) from exc

    # ── Apply Rule Engine ──
    rule_result = rule_engine.evaluate(
        lat=lat,
        lon=lon,
        elevation_m=tinfo.elevation_m,
        terrain=tinfo.terrain,
        weather=weather,
        ml_rain_prob=ml_prob,
        slope_deg=tinfo.slope_deg,
        aspect_deg=tinfo.aspect_deg,
        orographic_dot=orographic_dot,
        activity=activity,
    )

    # ── Assemble response ──
    ttl = cache.adaptive_ttl(rule_result.risk_score, rule_result.has_veto)
    response = PredictionResponse(
        latitude=lat,
        longitude=lon,
        elevation_m=tinfo.elevation_m,
        terrain=tinfo.terrain,
        ml_rain_probability=ml_prob,
        hazard_subscores=rule_result.hazard_subscores,
        decision_table_matches=rule_result.decision_table_matches,
        activity=rule_result.activity,
        risk_score=rule_result.risk_score,
        risk_level=rule_result.risk_level,
        veto_triggers=rule_result.veto_triggers,
        inference_log=rule_result.inference_log,
        advice_en=rule_result.advice_en,
        advice_zh=rule_result.advice_zh,
        cached=False,
        cache_ttl=ttl,
    )

    # ── Cache + audit-log (fire-and-forget — never blocks the response) ──
    payload_dump = response.model_dump(mode="json")
    # Strong references keep the tasks from being garbage-collected mid-flight.
    # Check for None explicitly: an empty (falsy) set must be reused, not replaced.
    bg_tasks: set[asyncio.Task[Any]] | None = getattr(request.app.state, "bg_tasks", None)
    if bg_tasks is None:
        bg_tasks = set()
        request.app.state.bg_tasks = bg_tasks
    for coro in (
        cache.set(lat, lon, payload_dump, ttl, activity=activity),
        cache.log_inference(
            lat, lon, rule_result.risk_score, rule_result.has_veto,
            rule_result.advice_en,
        ),
    ):
        task = asyncio.create_task(coro)
        bg_tasks.add(task)
        task.add_done_callback(bg_tasks.discard)
        task.add_done_callback(_log_bg_failure)
    return response
# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────
def _build_ml_features(
    weather: dict[str, Any],
    elevation_m: float,
    now: _dt.datetime | None = None,
) -> dict[str, float]:
    """Mirror of `scripts/2_preprocess.py` — keep features in sync with training.

    Args:
        weather: normalised weather dict; it is copied, never mutated.
        elevation_m: site elevation added as ``elevation_m``.
        now: timestamp for the cyclic hour/month features. Defaults to
            ``datetime.now()`` (naive server-local time). NOTE(review): the
            weather request uses ``timezone="auto"`` — confirm whether training
            used site-local hours rather than server time.

    Returns:
        The weather keys plus derived features: wind u/v components, sin/cos
        of hour and month, dew-point depression, and placeholder lag features.
    """
    if now is None:
        now = _dt.datetime.now()
    feats: dict[str, Any] = dict(weather)
    feats["elevation_m"] = elevation_m

    wind_kmh = weather.get("wind_speed_kmh", 0.0) or 0.0
    wind_dir = weather.get("wind_direction_deg", 0.0) or 0.0
    feats["wind_u"] = wind_kmh * math.sin(math.radians(wind_dir))
    feats["wind_v"] = wind_kmh * math.cos(math.radians(wind_dir))

    feats["hour_sin"] = math.sin(2 * math.pi * now.hour / 24.0)
    feats["hour_cos"] = math.cos(2 * math.pi * now.hour / 24.0)
    feats["month_sin"] = math.sin(2 * math.pi * now.month / 12.0)
    feats["month_cos"] = math.cos(2 * math.pi * now.month / 12.0)

    # `is None` (not `or`): 0.0 °C is a valid reading and must not be replaced.
    temp = weather.get("temperature_c")
    if temp is None:
        temp = 25.0
    dew = weather.get("dew_point_c")
    if dew is None:
        dew = temp
    feats["dew_point_depression"] = temp - dew

    feats["pressure_change_3h"] = 0.0  # set by historical training; 0 at inference
    feats["precipitation_lag_1h"] = weather.get("precipitation_mm", 0.0) or 0.0
    return feats