"""
ClauseGuard — JWT Authentication for FastAPI
Validates Supabase JWTs using JWKS (RS256), falling back to the shared
HS256 secret when no JWKS keys are available.
"""

import os
import time
from functools import lru_cache
from typing import Optional

import httpx
from fastapi import Depends, HTTPException, Header
from jose import jwt, JWTError, jwk
from jose.utils import base64url_decode

SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
SUPABASE_JWT_SECRET = os.environ.get("SUPABASE_JWT_SECRET", "")

# Minimum age (seconds) the cached JWKS must reach before a forced refresh
# is honoured. Keeps requests carrying unknown key IDs from triggering an
# outbound fetch on every call.
_JWKS_MIN_REFRESH_SECONDS = 60.0

_jwks_cache: Optional[dict] = None
_jwks_fetched_at: float = 0.0  # time.monotonic() of the last successful fetch


async def _get_jwks(*, force_refresh: bool = False) -> dict:
    """Fetch and cache JWKS from Supabase.

    Args:
        force_refresh: When True, bypass the cache and refetch, provided the
            cached copy is at least ``_JWKS_MIN_REFRESH_SECONDS`` old. Used
            to pick up rotated signing keys without a process restart.

    Returns:
        The parsed JWKS document, or ``{}`` when ``SUPABASE_URL`` is unset.

    Raises:
        httpx.HTTPError: If the request fails or returns an error status.
    """
    global _jwks_cache, _jwks_fetched_at

    if _jwks_cache:
        cache_age = time.monotonic() - _jwks_fetched_at
        if not force_refresh or cache_age < _JWKS_MIN_REFRESH_SECONDS:
            return _jwks_cache

    if not SUPABASE_URL:
        return {}

    # NOTE(review): Supabase documents its JWKS endpoint as
    # /auth/v1/.well-known/jwks.json — confirm that this path is served
    # by the target project before relying on it.
    async with httpx.AsyncClient() as client:
        resp = await client.get(f"{SUPABASE_URL}/auth/v1/jwks")
        resp.raise_for_status()
        fetched = resp.json()

    _jwks_cache = fetched
    _jwks_fetched_at = time.monotonic()
    return _jwks_cache


def _verify_with_secret(token: str) -> dict:
    """Verify JWT using Supabase JWT secret (HS256 fallback).

    Args:
        token: The encoded JWT.

    Returns:
        The decoded claims.

    Raises:
        HTTPException: 500 if ``SUPABASE_JWT_SECRET`` is not configured,
            401 if the token fails validation.
    """
    if not SUPABASE_JWT_SECRET:
        raise HTTPException(status_code=500, detail="JWT secret not configured")
    try:
        return jwt.decode(
            token,
            SUPABASE_JWT_SECRET,
            algorithms=["HS256"],
            audience="authenticated",
        )
    except JWTError as e:
        raise HTTPException(status_code=401, detail=f"Invalid token: {e}") from e


def _jwks_keys(jwks: object) -> list:
    """Return the ``keys`` list of a JWKS document, or ``[]`` if absent."""
    if isinstance(jwks, dict):
        return jwks.get("keys") or []
    return []


def _find_jwk(keys: list, kid: Optional[str]) -> Optional[dict]:
    """Return the first JWK in *keys* whose ``kid`` equals *kid*, else None."""
    return next((k for k in keys if k.get("kid") == kid), None)


async def _verify_with_jwks(token: str) -> dict:
    """Verify JWT using JWKS endpoint (RS256).

    Falls back to the HS256 shared secret when the JWKS document is missing
    or contains no keys. If the token's key ID is not in the cached key set,
    the JWKS is refetched once (rate-limited) to handle key rotation.

    Args:
        token: The encoded JWT.

    Returns:
        The decoded claims.

    Raises:
        HTTPException: 401 if the token is malformed, its key ID is unknown,
            or validation fails; 500 from the HS256 fallback if no secret is
            configured.
        httpx.HTTPError: If the initial JWKS fetch fails.
    """
    keys = _jwks_keys(await _get_jwks())
    if not keys:
        # Covers both a missing document and an explicit empty key list.
        return _verify_with_secret(token)

    try:
        kid = jwt.get_unverified_header(token).get("kid")
    except JWTError as e:
        raise HTTPException(status_code=401, detail=f"Invalid token: {e}") from e

    key = _find_jwk(keys, kid)
    if not key:
        # The signing key may have been rotated since the cache was filled.
        try:
            keys = _jwks_keys(await _get_jwks(force_refresh=True))
        except httpx.HTTPError:
            keys = []  # refresh is best-effort; report the unknown kid below
        key = _find_jwk(keys, kid)

    if not key:
        raise HTTPException(status_code=401, detail="Token key ID not found in JWKS")

    try:
        # NOTE(review): only RS256 is accepted; if the project issues ES256
        # signing keys they will be rejected here — confirm with Supabase
        # project settings.
        return jwt.decode(
            token,
            key,
            algorithms=["RS256"],
            audience="authenticated",
        )
    except JWTError as e:
        raise HTTPException(status_code=401, detail=f"Invalid token: {e}") from e


async def get_current_user(authorization: Optional[str] = Header(None)) -> Optional[dict]:
    """
    Extract and validate user from Authorization header.

    Returns None for unauthenticated requests (free tier).
    Raises 401 for invalid tokens.

    The ``Bearer`` scheme is matched case-insensitively and only as a
    prefix; a header carrying just the raw token is still accepted.

    Args:
        authorization: Raw ``Authorization`` header value, if present.

    Returns:
        A dict with ``id``, ``email`` and ``role`` taken from the token
        claims, or None when no credentials were supplied.

    Raises:
        HTTPException: 401 when the token is invalid or verification fails
            for any other reason; other statuses raised by the verifiers
            are propagated unchanged.
    """
    if not authorization:
        return None

    header_value = authorization.strip()
    scheme, _, credentials = header_value.partition(" ")
    if scheme.lower() == "bearer":
        token = credentials.strip()
    else:
        # No recognised scheme prefix: treat the whole value as the token.
        token = header_value

    if not token:
        return None

    try:
        payload = await _verify_with_jwks(token)
    except HTTPException:
        raise
    except Exception as e:
        # Boundary handler: never leak internal failures to the client.
        raise HTTPException(status_code=401, detail="Authentication failed") from e

    return {
        "id": payload.get("sub"),
        "email": payload.get("email"),
        "role": payload.get("role", "authenticated"),
    }


async def require_auth(user: Optional[dict] = Depends(get_current_user)) -> dict:
    """Dependency that requires a valid authenticated user.

    Args:
        user: Result of ``get_current_user``; None when unauthenticated.

    Returns:
        The authenticated user dict.

    Raises:
        HTTPException: 401 when no authenticated user is present.
    """
    if not user:
        raise HTTPException(status_code=401, detail="Authentication required")
    return user