# main.py - FastAPI application for Flood Vulnerability Assessment
import asyncio
import io
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, Tuple

import numpy as np
import pandas as pd
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

# Local imports kept in their original order in case any of them has
# import-time side effects.
from api.models import SingleAssessment
from api.batch import process_single_row, process_single_row_multihazard
from spatial_queries import get_terrain_metrics, distance_to_water
from vulnerability import calculate_vulnerability_index, calculate_multi_hazard_vulnerability
from gee_auth import initialize_gee
from height_predictor.inference import get_predictor

# SHAP Explainer Initialization
try:
    from explainability import VulnerabilityExplainer

    explainer = VulnerabilityExplainer()  # Automatically loads rf_explainer.pkl if present
    print("✅ SHAP model initialized successfully.")
except Exception as e:
    print(f"⚠️ SHAP explainer not available: {e}")
    explainer = None

# Initialize GEE once at startup and remember whether it worked, so that
# /health reports the real state instead of a hard-coded value.
try:
    initialize_gee()
    GEE_INITIALIZED = True
    print("✅ GEE initialized once at startup.")
except Exception as e:
    GEE_INITIALIZED = False
    print(f"⚠️ GEE initialization failed at startup: {e}")

# APP INITIALIZATION
app = FastAPI(title="Flood Vulnerability Assessment API", version="1.0")

# Mount static files directory
app.mount("/static", StaticFiles(directory="static"), name="static")

# Frontend templates setup
templates = Jinja2Templates(directory="templates")

# Thread pool for batch processing
executor = ThreadPoolExecutor(max_workers=10)

# GBA getter is initialized in api/batch.py
# Batch processing functions (single- and multi-hazard) live in api/batch.py


# INTERNAL HELPERS

async def _fetch_site_context(latitude: float, longitude: float) -> Tuple[Any, Any]:
    """Return ``(terrain_metrics, water_distance)`` for one point.

    Both lookups run sequentially in the default thread pool so the event
    loop is not blocked while they execute.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None,
        lambda: (
            get_terrain_metrics(latitude, longitude),
            distance_to_water(latitude, longitude),
        ),
    )


async def _read_batch_csv(file: UploadFile) -> pd.DataFrame:
    """Read an uploaded CSV and return the rows with usable coordinates.

    Rows whose latitude/longitude are non-numeric, missing, or outside
    -90..90 / -180..180 are dropped. Missing ``height``/``basement``
    columns are filled with 0.0.

    Raises:
        HTTPException(400): undecodable/unparsable file, missing coordinate
            columns, or no valid rows left after filtering.
    """
    contents = await file.read()
    try:
        # utf-8-sig also accepts plain UTF-8 but strips a leading BOM, which
        # would otherwise corrupt the first column name (e.g. "latitude").
        df = pd.read_csv(io.StringIO(contents.decode("utf-8-sig")))
    except UnicodeDecodeError:
        raise HTTPException(status_code=400, detail="CSV must be UTF-8 encoded") from None
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as e:
        raise HTTPException(status_code=400, detail=f"Could not parse CSV: {e}") from None

    if "latitude" not in df.columns or "longitude" not in df.columns:
        raise HTTPException(
            status_code=400,
            detail="CSV must contain 'latitude' and 'longitude' columns",
        )

    # Coerce so a single malformed cell becomes NaN (and is filtered below)
    # instead of turning the whole column into strings and crashing np.abs.
    df["latitude"] = pd.to_numeric(df["latitude"], errors="coerce")
    df["longitude"] = pd.to_numeric(df["longitude"], errors="coerce")
    # NaN compares False, so missing/invalid coordinates are dropped here too.
    df = df[(np.abs(df["latitude"]) <= 90) & (np.abs(df["longitude"]) <= 180)].copy()
    if len(df) == 0:
        raise HTTPException(status_code=400, detail="No valid coordinates in CSV (lat -90..90, lon -180..180)")

    # Set defaults for optional columns
    for column in ("height", "basement"):
        if column not in df.columns:
            df[column] = 0.0
    return df


async def _run_batch(
    df: pd.DataFrame,
    row_fn: Callable[..., Any],
    use_predicted_height: bool,
    use_gba_height: bool,
) -> pd.DataFrame:
    """Apply ``row_fn`` to every row on the batch executor and collect the results."""
    loop = asyncio.get_running_loop()
    results = await loop.run_in_executor(
        executor,
        lambda: [row_fn(row, use_predicted_height, use_gba_height) for _, row in df.iterrows()],
    )
    return pd.DataFrame(results)


def _csv_response(results_df: pd.DataFrame, filename: str) -> StreamingResponse:
    """Serialize ``results_df`` as a UTF-8 CSV download named ``filename``."""
    payload = results_df.to_csv(index=False).encode("utf-8")
    return StreamingResponse(
        io.BytesIO(payload),
        media_type="text/csv",
        headers={
            "Content-Disposition": (
                f"attachment; filename={filename}; "
                f"filename*=UTF-8''{filename}"
            )
        },
    )


# FRONTEND ROUTE

@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve the main web interface"""
    return templates.TemplateResponse("index.html", {"request": request})


# API ROUTES

@app.get("/api")
async def root() -> Dict:
    """API info endpoint listing the available routes."""
    return {
        "service": "Flood Vulnerability Assessment API",
        "version": "1.0",
        "endpoints": {
            "POST /assess": "Assess single location",
            "POST /assess_batch": "Assess batch from CSV file",
            "POST /assess_multihazard": "Multi-hazard assessment for a single location",
            "POST /assess_batch_multihazard": "Multi-hazard assessment for a batch CSV file",
            "POST /explain": "Assess single location with SHAP explanation",
            "POST /predict_height": "Predict building height for a location",
            "POST /get_height_gba": "Look up building height from GBA data",
            "GET /health": "Health check",
        },
    }


@app.post("/assess")
async def assess_single(data: SingleAssessment) -> Dict:
    """Assess flood vulnerability for a single location (non-blocking).

    Raises:
        HTTPException(500): if the terrain/water lookup or scoring fails.
    """
    try:
        terrain, water_dist = await _fetch_site_context(data.latitude, data.longitude)
        # Calculate vulnerability after terrain + water distance retrieved
        result = calculate_vulnerability_index(
            lat=data.latitude,
            lon=data.longitude,
            height=data.height,
            basement=data.basement,
            terrain_metrics=terrain,
            water_distance=water_dist,
        )
        return {"status": "success", "input": data.dict(), "assessment": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Assessment failed: {e}")


@app.post("/predict_height")
async def predict_height(data: SingleAssessment) -> Dict:
    """Predict the building height at the given coordinates.

    Returns the predictor's result unchanged on success.

    Raises:
        HTTPException(500): if the predictor reports ``status == 'error'``
            or raises.
    """
    try:
        predictor = get_predictor()
        # Inference runs off the event loop so other requests keep flowing.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, lambda: predictor.predict_from_coordinates(data.latitude, data.longitude)
        )
        if result.get("status") == "error":
            raise HTTPException(status_code=500, detail=result.get("error", "Height prediction failed"))
        return result
    except HTTPException:
        # Don't re-wrap our own HTTP errors in the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/assess_batch")
async def assess_batch(file: UploadFile = File(...), use_predicted_height: bool = False, use_gba_height: bool = False) -> StreamingResponse:
    """Assess flood vulnerability for multiple locations from a CSV file.

    Raises:
        HTTPException(400): invalid CSV input (see ``_read_batch_csv``).
        HTTPException(500): any failure while processing rows.
    """
    try:
        df = await _read_batch_csv(file)
        results_df = await _run_batch(df, process_single_row, use_predicted_height, use_gba_height)
        return _csv_response(results_df, "vulnerability_results.csv")
    except HTTPException:
        # Preserve 400s raised during validation instead of turning them into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch processing failed: {str(e)}")


@app.post("/assess_batch_multihazard")
async def assess_batch_multihazard(file: UploadFile = File(...), use_predicted_height: bool = False, use_gba_height: bool = False) -> StreamingResponse:
    """Multi-hazard assessment for multiple locations from a CSV file.

    Raises:
        HTTPException(400): invalid CSV input (see ``_read_batch_csv``).
        HTTPException(500): any failure while processing rows.
    """
    try:
        df = await _read_batch_csv(file)
        results_df = await _run_batch(df, process_single_row_multihazard, use_predicted_height, use_gba_height)
        return _csv_response(results_df, "multihazard_results.csv")
    except HTTPException:
        # Preserve 400s raised during validation instead of turning them into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch multihazard failed: {str(e)}")


@app.post("/explain")
async def explain_assessment(data: SingleAssessment) -> Dict:
    """Assess vulnerability with SHAP explanation.

    ``explanation`` is None when the explainer was unavailable at startup or
    explaining this result raised; the assessment is still returned.
    """
    try:
        terrain, water_dist = await _fetch_site_context(data.latitude, data.longitude)
        result = calculate_vulnerability_index(
            lat=data.latitude,
            lon=data.longitude,
            height=data.height,
            basement=data.basement,
            terrain_metrics=terrain,
            water_distance=water_dist,
        )

        # Generate explanation if explainer available (best effort).
        explanation = None
        if explainer is not None:
            try:
                explanation = explainer.explain(result["components"])
            except Exception as e:
                print(f"SHAP explanation failed: {e}")

        return {
            "status": "success",
            "input": data.dict(),
            "assessment": result,
            "explanation": explanation,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Assessment failed: {e}")


@app.post("/assess_multihazard")
async def assess_multihazard(data: SingleAssessment) -> Dict:
    """Multi-hazard assessment (fluvial + coastal + pluvial)"""
    try:
        terrain, water_dist = await _fetch_site_context(data.latitude, data.longitude)
        result = calculate_multi_hazard_vulnerability(
            lat=data.latitude,
            lon=data.longitude,
            height=data.height,
            basement=data.basement,
            terrain_metrics=terrain,
            water_distance=water_dist,
        )
        return {"status": "success", "input": data.dict(), "assessment": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Assessment failed: {e}")


@app.post("/get_height_gba")
async def get_height_gba(data: SingleAssessment):
    """Look up the building height at a location via the shared GBA getter.

    Raises:
        HTTPException(404): the getter did not report ``status == 'success'``.
        HTTPException(500): any other failure.
    """
    try:
        # The shared getter instance is created in api/batch.py.
        from api.batch import gba_getter

        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, lambda: gba_getter.get_height_m(data.latitude, data.longitude, buffer_m=5.0)
        )
        if result.get("status") != "success":
            raise HTTPException(status_code=404, detail="GBA height not found for this location. Please try predicting the height.")
        return result
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check() -> Dict:
    """Health check endpoint.

    ``gee_initialized`` reports whether ``initialize_gee()`` succeeded at startup.
    """
    return {"status": "healthy", "gee_initialized": GEE_INITIALIZED}