from datetime import datetime, timezone
from typing import Optional

from fastapi import HTTPException, Request, status
from fastapi.security.http import HTTPBearer
from jose import JWTError, jwt
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request as StarletteRequest
from starlette.responses import JSONResponse, Response

from src.auth.security import verify_token
from src.core.config import settings
|
|
|
|
class JWTMiddleware(BaseHTTPMiddleware):
    """
    Middleware to verify JWT tokens for protected routes.

    A request is passed straight to the next handler when it is an
    ``OPTIONS`` request or its path is public (see ``PUBLIC_ROUTES``).
    Every other request must carry an ``Authorization: Bearer <token>``
    header whose token is accepted by ``verify_token``. On success the
    decoded claims are stored on ``request.state``:
    ``user_id``, ``user_role`` (default ``"user"``) and ``token_payload``.

    Authentication failures are answered directly with a 401 JSON response
    ``{"detail": <message>}`` and a ``WWW-Authenticate: Bearer`` header.
    Responses are returned instead of raising ``HTTPException`` because
    exceptions raised inside a ``BaseHTTPMiddleware`` are not routed through
    FastAPI's exception handlers and would reach the client as a 500.
    """

    # Paths reachable without a token. "/" matches only the root path
    # itself. Every other entry matches the exact path and anything below it
    # ("/docs" also covers "/docs/oauth2-redirect"), but not look-alike
    # siblings such as "/docs-internal". Subclasses may override this tuple.
    PUBLIC_ROUTES = (
        "/",
        "/docs",
        "/redoc",
        "/openapi.json",
        "/health",
        "/api/v1/login",
        "/api/v1/register",
        "/login",
        "/register",
    )

    def __init__(self, app):
        super().__init__(app)
        # Not used by ``dispatch``; kept as a public attribute in case other
        # code reads it.
        self.http_bearer = HTTPBearer(auto_error=False)

    @classmethod
    def _is_public_route(cls, path: str) -> bool:
        """Return True if *path* equals a public route or lies beneath one."""
        for route in cls.PUBLIC_ROUTES:
            if path == route:
                return True
            # Prefix match only at a path-segment boundary. "/" is excluded
            # because every path starts with it, which would make all routes
            # public.
            if route != "/" and path.startswith(route.rstrip("/") + "/"):
                return True
        return False

    @staticmethod
    def _unauthorized(detail: str) -> JSONResponse:
        """Build the 401 response used for every authentication failure."""
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content={"detail": detail},
            headers={"WWW-Authenticate": "Bearer"},
        )

    async def dispatch(self, request: Request, call_next):
        """
        Authenticate *request* and forward it to *call_next*.

        Returns the downstream response for public, ``OPTIONS`` or
        successfully authenticated requests. Otherwise returns a 401
        ``JSONResponse`` describing why authentication failed.
        """
        if request.method == "OPTIONS" or self._is_public_route(request.url.path):
            return await call_next(request)

        auth_header = request.headers.get("Authorization")
        if auth_header is None:
            return self._unauthorized("Authorization header is missing")

        # Expect exactly "<scheme> <token>". split() tolerates repeated
        # whitespace and rejects an empty token such as "Bearer ".
        parts = auth_header.split()
        if len(parts) != 2:
            return self._unauthorized("Invalid authorization header format")
        scheme, token = parts
        if scheme.lower() != "bearer":
            return self._unauthorized("Authorization scheme must be Bearer")

        try:
            payload = verify_token(token)
        except JWTError:
            # NOTE(review): verify_token may already swallow JWTError and
            # return None; this guard only prevents a 500 if it does not.
            payload = None
        if payload is None:
            return self._unauthorized("Invalid or expired token")

        exp_time = payload.get("exp")
        if exp_time is not None:
            # Use an aware UTC "now". datetime.utcnow().timestamp() interprets
            # the naive UTC value as local time and is off by the server's
            # UTC offset.
            now = datetime.now(timezone.utc).timestamp()
            try:
                expired = now >= float(exp_time)
            except (TypeError, ValueError):
                # Non-numeric "exp" claim: reject the token instead of crashing.
                return self._unauthorized("Invalid or expired token")
            if expired:
                return self._unauthorized("Token has expired")

        request.state.user_id = payload.get("user_id")
        request.state.user_role = payload.get("role", "user")
        request.state.token_payload = payload

        return await call_next(request)
|
|
|
|
| |
def get_jwt_middleware():
    """
    Hand back the ``JWTMiddleware`` class itself, not an instance.

    NOTE(review): callers presumably register the returned class with the
    application (e.g. via ``add_middleware``); confirm at the call site.
    """
    middleware_class = JWTMiddleware
    return middleware_class
|
|
|
|
| |
# Module-level reference to the middleware class (the factory returns the
# class, not an instance). NOTE(review): presumably imported elsewhere for
# registration; confirm before removing or renaming.
jwt_middleware = get_jwt_middleware()
|
|