# adema5051 — "Update main.py" (commit 7dc690c, verified)
# main.py - FastAPI application for Flood Vulnerability Assessment
from fastapi import FastAPI, File, UploadFile, HTTPException, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from typing import Dict
import pandas as pd
import io
import asyncio
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from api.models import SingleAssessment
from api.batch import process_single_row, process_single_row_multihazard
from spatial_queries import get_terrain_metrics, distance_to_water
from vulnerability import calculate_vulnerability_index, calculate_multi_hazard_vulnerability
from gee_auth import initialize_gee
from height_predictor.inference import get_predictor
# SHAP explainer setup. This is optional: if the import or construction fails,
# `explainer` stays None and /explain returns no explanation.
explainer = None
try:
    from explainability import VulnerabilityExplainer

    # Per the original author, this loads rf_explainer.pkl if it is present.
    explainer = VulnerabilityExplainer()
    print("✅ SHAP model initialized successfully.")
except Exception as exc:
    print(f"⚠️ SHAP explainer not available: {exc}")
    explainer = None
# Initialize Google Earth Engine once at startup. A failure is non-fatal (the
# app still starts), but the outcome is recorded in GEE_INITIALIZED so it can
# be reported rather than assumed.
GEE_INITIALIZED = False
try:
    initialize_gee()
    # Set before printing so the flag reflects the init result itself.
    GEE_INITIALIZED = True
    print("✅ GEE initialized once at startup.")
except Exception as e:
    print(f"⚠️ GEE initialization failed at startup: {e}")
# ---- Application setup -------------------------------------------------------
app = FastAPI(
    title="Flood Vulnerability Assessment API",
    version="1.0",
)

# Jinja2 templates used to render the browser frontend at "/".
templates = Jinja2Templates(directory="templates")

# Files in ./static are served under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Worker pool used by the CSV batch endpoints to run row processing off the
# event loop. (The GBA height getter is set up in api/batch.py, not here.)
executor = ThreadPoolExecutor(max_workers=10)
# ---- Frontend ----------------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render templates/index.html as the web UI."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
# ---- API routes --------------------------------------------------------------
@app.get("/api")
async def root() -> Dict:
    """Describe the service and list the API endpoints defined in this module.

    Returns:
        Dict with the service name, version and an ``endpoints`` mapping of
        "METHOD /path" to a short description.
    """
    # Keep this listing in sync with the @app.<method> routes in this file.
    return {
        "service": "Flood Vulnerability Assessment API",
        "version": "1.0",
        "endpoints": {
            "POST /assess": "Assess single location",
            "POST /assess_multihazard": "Multi-hazard assessment (fluvial + coastal + pluvial) for a single location",
            "POST /explain": "Assess single location with SHAP explanation",
            "POST /predict_height": "Predict height from coordinates",
            "POST /get_height_gba": "Look up height from GBA data",
            "POST /assess_batch": "Assess batch from CSV file",
            "POST /assess_batch_multihazard": "Multi-hazard assessment for batch from CSV file",
            "GET /health": "Health check",
        },
    }
@app.post("/assess")
async def assess_single(data: SingleAssessment) -> Dict:
    """Score flood vulnerability for one location without blocking the event loop.

    Terrain metrics and distance to water are fetched in the default thread
    pool, then combined with the submitted height/basement values. Any
    failure is reported as HTTP 500.
    """
    lat, lon = data.latitude, data.longitude

    def fetch_site_inputs():
        # Both lookups run back-to-back in the same worker thread.
        return get_terrain_metrics(lat, lon), distance_to_water(lat, lon)

    event_loop = asyncio.get_event_loop()
    try:
        terrain, water_dist = await event_loop.run_in_executor(None, fetch_site_inputs)
        assessment = calculate_vulnerability_index(
            lat=lat,
            lon=lon,
            height=data.height,
            basement=data.basement,
            terrain_metrics=terrain,
            water_distance=water_dist,
        )
        return {"status": "success", "input": data.dict(), "assessment": assessment}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Assessment failed: {e}")
@app.post("/predict_height")
async def predict_height(data: SingleAssessment) -> Dict:
    """Predict a height for the given coordinates using the height predictor.

    Returns:
        The predictor's result dict, unchanged, on success.

    Raises:
        HTTPException(500): if the result's ``status`` is ``'error'`` (the
            detail is the result's ``error`` message), or if anything else
            fails.
    """
    try:
        predictor = get_predictor()
        result = predictor.predict_from_coordinates(data.latitude, data.longitude)
        if result['status'] == 'error':
            raise HTTPException(status_code=500, detail=result['error'])
        return result
    except HTTPException:
        # Already a deliberate HTTP error. Without this clause the generic
        # handler below rewraps it, and the detail becomes "500: <message>".
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Per-row batch logic (process_single_row / process_single_row_multihazard)
# lives in api/batch.py. The helpers below handle CSV input/output shared by
# both batch endpoints.


def _load_batch_csv(contents: bytes) -> pd.DataFrame:
    """Parse an uploaded CSV and keep only rows with usable coordinates.

    Adds ``height``/``basement`` columns filled with 0.0 when they are absent.

    Raises:
        HTTPException(400): if the upload cannot be decoded or parsed as CSV,
            lacks ``latitude``/``longitude`` columns, or has no row with
            lat in -90..90 and lon in -180..180.
    """
    try:
        # utf-8-sig strips a leading BOM (common in Excel exports) that would
        # otherwise be glued onto the first column name.
        df = pd.read_csv(io.StringIO(contents.decode("utf-8-sig")))
    except (UnicodeDecodeError, pd.errors.EmptyDataError, pd.errors.ParserError) as e:
        raise HTTPException(status_code=400, detail=f"Could not read CSV: {e}") from e

    if 'latitude' not in df.columns or 'longitude' not in df.columns:
        raise HTTPException(
            status_code=400,
            detail="CSV must contain 'latitude' and 'longitude' columns",
        )

    # Non-numeric entries become NaN. NaN fails both range comparisons, so
    # those rows are dropped instead of crashing np.abs.
    df['latitude'] = pd.to_numeric(df['latitude'], errors='coerce')
    df['longitude'] = pd.to_numeric(df['longitude'], errors='coerce')
    # .copy() so the column defaults below write to a real frame, not a view.
    df = df[(np.abs(df['latitude']) <= 90) & (np.abs(df['longitude']) <= 180)].copy()
    if len(df) == 0:
        raise HTTPException(status_code=400, detail="No valid coordinates in CSV (lat -90..90, lon -180..180)")

    # Defaults for optional columns.
    if 'height' not in df.columns:
        df['height'] = 0.0
    if 'basement' not in df.columns:
        df['basement'] = 0.0
    return df


async def _run_batch(row_fn, df: pd.DataFrame, use_predicted_height: bool, use_gba_height: bool) -> list:
    """Apply ``row_fn(row, use_predicted_height, use_gba_height)`` to each row in the batch executor."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        executor,
        lambda: [row_fn(row, use_predicted_height, use_gba_height) for _, row in df.iterrows()],
    )


def _csv_download(records: list, filename: str) -> StreamingResponse:
    """Serialize result records to CSV and return them as a file download."""
    output = io.StringIO()
    pd.DataFrame(records).to_csv(output, index=False)
    return StreamingResponse(
        io.BytesIO(output.getvalue().encode('utf-8')),
        media_type="text/csv",
        headers={
            "Content-Disposition": (
                f"attachment; filename={filename}; "
                f"filename*=UTF-8''{filename}"
            )
        },
    )


@app.post("/assess_batch")
async def assess_batch(file: UploadFile = File(...), use_predicted_height: bool = False, use_gba_height: bool = False) -> StreamingResponse:
    """Assess flood vulnerability for multiple locations from a CSV file.

    Returns vulnerability_results.csv. Input problems are reported as 400;
    processing failures as 500.
    """
    try:
        df = _load_batch_csv(await file.read())
        results = await _run_batch(process_single_row, df, use_predicted_height, use_gba_height)
        return _csv_download(results, "vulnerability_results.csv")
    except HTTPException:
        # Keep deliberate 4xx responses instead of converting them to 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch processing failed: {str(e)}")


@app.post("/assess_batch_multihazard")
async def assess_batch_multihazard(file: UploadFile = File(...), use_predicted_height: bool = False, use_gba_height: bool = False) -> StreamingResponse:
    """Multi-hazard assessment for multiple locations from a CSV file.

    Returns multihazard_results.csv. Input problems are reported as 400;
    processing failures as 500.
    """
    try:
        df = _load_batch_csv(await file.read())
        results = await _run_batch(process_single_row_multihazard, df, use_predicted_height, use_gba_height)
        return _csv_download(results, "multihazard_results.csv")
    except HTTPException:
        # Keep deliberate 4xx responses instead of converting them to 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch multihazard failed: {str(e)}")
@app.post("/explain")
async def explain_assessment(data: SingleAssessment) -> Dict:
    """Run the single-location assessment and attach a SHAP explanation.

    ``explanation`` is None when no explainer was loaded at startup or when
    explaining fails. An explanation failure is printed, not raised. Any
    other failure becomes HTTP 500.
    """
    event_loop = asyncio.get_event_loop()
    try:
        lat, lon = data.latitude, data.longitude
        # Slow terrain + water lookups run together in one worker thread.
        terrain, water_dist = await event_loop.run_in_executor(
            None,
            lambda: (get_terrain_metrics(lat, lon), distance_to_water(lat, lon)),
        )
        assessment = calculate_vulnerability_index(
            lat=lat,
            lon=lon,
            height=data.height,
            basement=data.basement,
            terrain_metrics=terrain,
            water_distance=water_dist,
        )

        def build_explanation():
            # Best effort: a missing explainer or a failure yields None.
            if not explainer:
                return None
            try:
                return explainer.explain(assessment['components'])
            except Exception as err:
                print(f"SHAP explanation failed: {err}")
                return None

        explanation = build_explanation()
        return {
            "status": "success",
            "input": data.dict(),
            "assessment": assessment,
            "explanation": explanation,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Assessment failed: {e}")
# The per-row multi-hazard batch logic lives in api/batch.py; this is the
# single-location endpoint.
@app.post("/assess_multihazard")
async def assess_multihazard(data: SingleAssessment) -> Dict:
    """Multi-hazard assessment (fluvial + coastal + pluvial) for one location.

    Terrain metrics and distance to water are fetched in the default thread
    pool before scoring. Any failure is reported as HTTP 500.
    """
    lat, lon = data.latitude, data.longitude

    def gather_site_inputs():
        terrain_metrics = get_terrain_metrics(lat, lon)
        water_distance = distance_to_water(lat, lon)
        return terrain_metrics, water_distance

    event_loop = asyncio.get_event_loop()
    try:
        terrain, water_dist = await event_loop.run_in_executor(None, gather_site_inputs)
        assessment = calculate_multi_hazard_vulnerability(
            lat=lat,
            lon=lon,
            height=data.height,
            basement=data.basement,
            terrain_metrics=terrain,
            water_distance=water_dist,
        )
        return {"status": "success", "input": data.dict(), "assessment": assessment}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Assessment failed: {e}")
@app.post("/get_height_gba")
async def get_height_gba(data: SingleAssessment):
    """Look up a height for the coordinates via the GBA getter from api/batch.py.

    Queries with ``buffer_m=5.0``. Returns the getter's result dict when its
    ``status`` is ``"success"``. Otherwise returns HTTP 404 advising a height
    prediction instead. Unexpected errors become HTTP 500.
    """
    try:
        # The shared getter instance comes from api/batch.py.
        from api.batch import gba_getter

        lookup = gba_getter.get_height_m(data.latitude, data.longitude, buffer_m=5.0)
        found = lookup.get("status") == "success"
        if not found:
            raise HTTPException(status_code=404, detail="GBA height not found for this location. Please try predicting the height.")
        return lookup
    except HTTPException:
        # Pass the 404 above through untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check() -> Dict:
    """Health check endpoint.

    ``gee_initialized`` reports the startup Earth Engine initialization
    outcome when the module records it in a ``GEE_INITIALIZED`` global.
    Otherwise it falls back to the previous hard-coded ``True``.
    """
    # Looked up via globals() so this endpoint still works if the flag is absent.
    gee_ok = bool(globals().get("GEE_INITIALIZED", True))
    return {"status": "healthy", "gee_initialized": gee_ok}