Spaces:
Running
Running
| """ | |
| ClauseGuard — JWT Authentication for FastAPI | |
| Validates Supabase JWTs using JWKS (RS256). | |
| """ | |
| import os | |
| from functools import lru_cache | |
| from typing import Optional | |
| import httpx | |
| from fastapi import Depends, HTTPException, Header | |
| from jose import jwt, JWTError, jwk | |
| from jose.utils import base64url_decode | |
# Supabase project settings, read once at import time; "" means "not configured".
# The URL is stripped of surrounding whitespace and trailing slashes so that
# f"{SUPABASE_URL}/auth/v1/..." never produces a "//" path or an invalid URL.
SUPABASE_URL = os.environ.get("SUPABASE_URL", "").strip().rstrip("/")
SUPABASE_JWT_SECRET = os.environ.get("SUPABASE_JWT_SECRET", "")
# Parsed JWKS document from the last successful fetch (None until then).
_jwks_cache: Optional[dict] = None
# time.monotonic() timestamp of the last fetch (or stale-cache back-off).
_jwks_fetched_at: float = 0.0
# Seconds a cached JWKS is trusted before it is re-fetched, so rotated or
# revoked signing keys are picked up without restarting the process.
_JWKS_TTL_SECONDS: float = 600.0


async def _get_jwks() -> dict:
    """Return the Supabase JWKS, re-fetching it when the cache is empty or stale.

    Returns:
        The parsed JWKS response, or ``{}`` when ``SUPABASE_URL`` is not
        configured (the caller then falls back to HS256 secret verification).

    Raises:
        httpx.HTTPError: if the request fails or returns an error status and
            no previously fetched JWKS is available to fall back on.
        ValueError: if the response body is not valid JSON and no previously
            fetched JWKS is available.
    """
    import time

    global _jwks_cache, _jwks_fetched_at
    now = time.monotonic()
    if _jwks_cache is not None and now - _jwks_fetched_at < _JWKS_TTL_SECONDS:
        return _jwks_cache
    if not SUPABASE_URL:
        return {}
    # NOTE(review): Supabase documents its JWKS at
    # /auth/v1/.well-known/jwks.json — confirm /auth/v1/jwks is also served.
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(f"{SUPABASE_URL}/auth/v1/jwks")
            resp.raise_for_status()
            jwks = resp.json()
    except (httpx.HTTPError, ValueError):
        if _jwks_cache is None:
            raise
        # Keep serving the last known keys during an outage, and back off for
        # one TTL rather than re-trying (and waiting on) every request.
        _jwks_fetched_at = now
        return _jwks_cache
    _jwks_cache = jwks
    _jwks_fetched_at = now
    return _jwks_cache
def _verify_with_secret(token: str) -> dict:
    """Verify an HS256 JWT with the shared Supabase JWT secret.

    Used as the fallback when no JWKS is available.

    Args:
        token: The raw encoded JWT (no ``Bearer`` prefix).

    Returns:
        The decoded claims.

    Raises:
        HTTPException: 500 if ``SUPABASE_JWT_SECRET`` is not configured;
            401 if the token is malformed, has a bad signature, is expired,
            has the wrong audience, or lacks an ``aud`` or ``exp`` claim.
    """
    if not SUPABASE_JWT_SECRET:
        raise HTTPException(status_code=500, detail="JWT secret not configured")
    try:
        return jwt.decode(
            token,
            SUPABASE_JWT_SECRET,
            algorithms=["HS256"],
            audience="authenticated",
            # python-jose skips the audience check entirely when the token has
            # no "aud" claim, and never expires a token without "exp". Require
            # both so audience-less or non-expiring tokens signed with the same
            # secret are rejected.
            options={"require_aud": True, "require_exp": True},
        )
    except JWTError as e:
        raise HTTPException(status_code=401, detail=f"Invalid token: {e}") from e
def _find_jwk(jwks: dict, kid: Optional[str]) -> Optional[dict]:
    """Return the entry of ``jwks["keys"]`` whose ``kid`` equals *kid*, or None."""
    return next((k for k in jwks.get("keys", []) if k.get("kid") == kid), None)


async def _verify_with_jwks(token: str) -> dict:
    """Verify an RS256 JWT against the Supabase JWKS.

    Falls back to ``_verify_with_secret`` (HS256) when no JWKS is available.
    If an RS256 token's ``kid`` is not in the cached JWKS, the cache is
    dropped and the JWKS re-fetched once. Keys rotated in after the first
    fetch are therefore still accepted.

    Args:
        token: The raw encoded JWT (no ``Bearer`` prefix).

    Returns:
        The decoded claims.

    Raises:
        HTTPException: 401 if the header is malformed, no JWK matches the
            token's ``kid``, or the signature or claims fail verification
            (including a missing ``aud`` or ``exp`` claim). Anything raised by
            ``_get_jwks`` or ``_verify_with_secret`` propagates unchanged.
    """
    global _jwks_cache
    jwks = await _get_jwks()
    if not jwks or "keys" not in jwks:
        return _verify_with_secret(token)
    try:
        header = jwt.get_unverified_header(token)
        kid = header.get("kid")
        key = _find_jwk(jwks, kid)
        if key is None and header.get("alg") == "RS256":
            # The cached JWKS may predate a key rotation: force one re-fetch.
            # Skipped for other algs (rejected below anyway) so arbitrary
            # tokens cannot trigger outbound fetches.
            _jwks_cache = None
            key = _find_jwk(await _get_jwks(), kid)
        if key is None:
            raise HTTPException(status_code=401, detail="Token key ID not found in JWKS")
        return jwt.decode(
            token,
            key,
            algorithms=["RS256"],
            audience="authenticated",
            # python-jose skips the audience check when "aud" is absent and
            # never expires a token without "exp"; require both.
            options={"require_aud": True, "require_exp": True},
        )
    except JWTError as e:
        raise HTTPException(status_code=401, detail=f"Invalid token: {e}") from e
async def get_current_user(authorization: Optional[str] = Header(None)) -> Optional[dict]:
    """Extract and validate the user from the ``Authorization`` header.

    The header may be ``Bearer <token>``, with the scheme matched
    case-insensitively as RFC 7235 specifies. As before, a bare token with
    no scheme is also accepted.

    Returns:
        ``None`` for unauthenticated requests (free tier): no header, or a
        header that carries no token. Otherwise a dict with ``id`` (the
        token's ``sub`` claim), ``email``, and ``role`` (default
        ``"authenticated"``).

    Raises:
        HTTPException: 401 for an invalid token or one without a ``sub``
            claim; 503 if the JWKS could not be fetched. HTTPExceptions raised
            by the verifiers (e.g. 500 when the JWT secret is not configured)
            pass through unchanged.
    """
    if not authorization:
        return None
    parts = authorization.split(None, 1)
    if parts and parts[0].lower() == "bearer":
        token = parts[1].strip() if len(parts) > 1 else ""
    else:
        # No recognised scheme: treat the whole value as the token, which
        # keeps bare-token clients working as before.
        token = authorization.strip()
    if not token:
        return None
    try:
        payload = await _verify_with_jwks(token)
        user_id = payload.get("sub")
        if not user_id:
            # A token that verifies but names no subject does not identify a
            # user; without this check require_auth would accept it with id=None.
            raise HTTPException(status_code=401, detail="Invalid token: missing subject")
        return {
            "id": user_id,
            "email": payload.get("email"),
            "role": payload.get("role", "authenticated"),
        }
    except HTTPException:
        raise
    except httpx.HTTPError as exc:
        # Fetching the JWKS failed: an upstream problem, not bad credentials,
        # so don't tell the client its token is invalid.
        raise HTTPException(status_code=503, detail="Authentication service unavailable") from exc
    except Exception as exc:
        raise HTTPException(status_code=401, detail="Authentication failed") from exc
async def require_auth(user: Optional[dict] = Depends(get_current_user)) -> dict:
    """FastAPI dependency: pass through the resolved user, or reject the request.

    Args:
        user: Result of ``get_current_user``; ``None`` for anonymous callers.

    Returns:
        The same user dict, when one was resolved.

    Raises:
        HTTPException: 401 when ``user`` is ``None`` or otherwise falsy.
    """
    if user:
        return user
    raise HTTPException(status_code=401, detail="Authentication required")