| from fastapi import Request, Response |
| from fastapi.responses import JSONResponse |
| from starlette.middleware.base import BaseHTTPMiddleware |
| from collections import defaultdict |
| import time |
| import asyncio |
|
|
| from Backend.core.logging import get_logger |
|
|
# Module logger from the project's logging helper; the rate limiter below uses
# it to record rejected requests.
logger = get_logger(__name__)
|
|
|
|
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach a baseline set of security headers to every response.

    The headers in ``DEFAULT_HEADERS`` are written onto every response,
    replacing any value the endpoint may have set. ``Strict-Transport-Security``
    is added only when the request URL scheme is ``https``.
    """

    # Header name -> value written onto every response.
    DEFAULT_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        # "0" turns the legacy browser XSS auditor off. The previous
        # "1; mode=block" is deprecated and the auditor's blocking mode can
        # itself introduce XSS / cross-site info-leak issues (OWASP, MDN).
        # Rely on a Content-Security-Policy for XSS mitigation instead.
        "X-XSS-Protection": "0",
        "Referrer-Policy": "strict-origin-when-cross-origin",
        "Permissions-Policy": "geolocation=(self), camera=(self)",
    }

    # HSTS value: one year, applied to subdomains as well.
    HSTS_VALUE = "max-age=31536000; includeSubDomains"

    async def dispatch(self, request: Request, call_next):
        """Run the downstream app, then decorate its response with security headers."""
        response = await call_next(request)

        for name, value in self.DEFAULT_HEADERS.items():
            response.headers[name] = value

        # HSTS is only honoured by browsers over TLS.
        # NOTE(review): behind a TLS-terminating proxy, request.url.scheme is
        # "https" only if forwarded headers are trusted (e.g. uvicorn
        # --proxy-headers); confirm for the deployment or HSTS is never sent.
        if request.url.scheme == "https":
            response.headers["Strict-Transport-Security"] = self.HSTS_VALUE

        return response
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client sliding-window rate limiter kept in process memory.

    Clients are keyed by ``request.client.host`` (or ``"unknown"`` when the
    server reports no client). Two limits are enforced per key:

    * at most ``requests_per_minute`` accepted requests in any 60-second window;
    * at most ``burst_limit`` accepted requests in any 1-second window.

    A request over either limit gets a 429 JSON response with a
    ``Retry-After`` header. It is not forwarded and is not counted toward
    later windows. State lives only in this instance's memory and is not
    shared across processes.

    NOTE(review): behind a reverse proxy ``request.client.host`` is likely the
    proxy's address, which would put all users in one bucket. Confirm how the
    app is deployed.
    """

    # Length of the main sliding window, in seconds.
    WINDOW_SECONDS = 60.0
    # Length of the burst window, in seconds.
    BURST_WINDOW_SECONDS = 1.0

    def __init__(self, app, requests_per_minute: int = 60, burst_limit: int = 10):
        super().__init__(app)
        self.requests_per_minute = requests_per_minute
        self.burst_limit = burst_limit
        # client key -> time.monotonic() stamps of accepted requests, oldest first.
        self.requests = defaultdict(list)
        # The critical section below contains no awaits. The lock is kept so
        # the bookkeeping stays atomic even if that changes.
        self.lock = asyncio.Lock()
        # Monotonic time of the last sweep for idle clients.
        self._last_sweep = time.monotonic()

    def _evict_idle_clients(self, now: float) -> None:
        """Forget clients with no accepted request inside the window.

        Runs at most once per window. Without it, every client key ever seen
        would stay in ``self.requests`` for the life of the process.
        """
        if now - self._last_sweep < self.WINDOW_SECONDS:
            return
        self._last_sweep = now
        idle = [
            key
            for key, stamps in self.requests.items()
            if not stamps or now - stamps[-1] >= self.WINDOW_SECONDS
        ]
        for key in idle:
            del self.requests[key]

    async def dispatch(self, request: Request, call_next):
        """Reject the request with 429 if the client is over a limit, else forward it."""
        client_ip = request.client.host if request.client else "unknown"
        # Monotonic clock: wall-clock adjustments (NTP corrections, manual
        # changes) must not stretch or collapse the windows.
        now = time.monotonic()

        async with self.lock:
            self._evict_idle_clients(now)

            # .get() avoids defaultdict creating an entry before we decide to keep it.
            window = [
                t for t in self.requests.get(client_ip, ())
                if now - t < self.WINDOW_SECONDS
            ]
            if window:
                self.requests[client_ip] = window
            else:
                # Don't keep an empty bucket around for an idle client.
                self.requests.pop(client_ip, None)

            if len(window) >= self.requests_per_minute:
                logger.warning("Rate limit exceeded for %s", client_ip)
                return JSONResponse(
                    status_code=429,
                    content={"detail": "Too many requests. Please slow down."},
                    headers={"Retry-After": "60"},
                )

            recent = sum(1 for t in window if now - t < self.BURST_WINDOW_SECONDS)
            if recent >= self.burst_limit:
                logger.warning("Burst limit exceeded for %s", client_ip)
                return JSONResponse(
                    status_code=429,
                    content={"detail": "Too many requests. Please slow down."},
                    headers={"Retry-After": "1"},
                )

            # Only accepted requests are recorded.
            window.append(now)
            self.requests[client_ip] = window

        return await call_next(request)
|
|
|
|
class RequestValidationMiddleware(BaseHTTPMiddleware):
    """Reject requests whose declared ``Content-Length`` is invalid or too large.

    Only the ``Content-Length`` header is inspected; the body is not read.
    An absent or empty header is let through unchecked.

    NOTE(review): a request that omits ``Content-Length`` (e.g. one sent with
    ``Transfer-Encoding: chunked``) is not size-limited here. Bound body size
    elsewhere if that matters.
    """

    # Largest accepted declared body size in bytes (50 MiB).
    MAX_CONTENT_LENGTH = 50 * 1024 * 1024

    async def dispatch(self, request: Request, call_next):
        """Return 400/413 for a bad or oversized Content-Length, else forward the request."""
        content_length = request.headers.get("content-length")
        if content_length:
            value = content_length.strip()
            # RFC 9110 defines Content-Length as 1*DIGIT. Plain int() would
            # raise on garbage (surfacing as a 500) and would also accept
            # "-5", "+5" or "1_000".
            if not (value.isascii() and value.isdigit()):
                return JSONResponse(
                    status_code=400,
                    content={"detail": "Invalid Content-Length header"},
                )
            if int(value) > self.MAX_CONTENT_LENGTH:
                return JSONResponse(
                    status_code=413,
                    content={"detail": "Request entity too large"},
                )

        return await call_next(request)
|
|