Spaces:
Sleeping
Sleeping
File size: 3,241 Bytes
e3f2df1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 | """
ClauseGuard — JWT Authentication for FastAPI
Validates Supabase JWTs against the project's JWKS (RS256), falling back to the HS256 JWT secret when no JWKS key set is available.
"""
import os
import time
from functools import lru_cache
from typing import Optional

import httpx
from fastapi import Depends, HTTPException, Header
from jose import jwt, JWTError, jwk
from jose.utils import base64url_decode
# Supabase project settings, read once from the environment at import time.
# Both default to the empty string, so "not configured" is a falsy value.
SUPABASE_URL = os.getenv("SUPABASE_URL", "")
SUPABASE_JWT_SECRET = os.getenv("SUPABASE_JWT_SECRET", "")
# Process-wide cache of the fetched JWKS document; None until it is fetched.
_jwks_cache: Optional[dict] = None
# Minimum seconds between forced JWKS re-fetches. A token whose ``kid`` is not
# in the cached key set triggers one refresh (so signing keys rotated in after
# startup are picked up without a restart); the throttle keeps forged ``kid``
# values from turning every request into an outbound HTTP call.
_JWKS_MIN_REFRESH_SECONDS = 60.0
# time.monotonic() value of the most recent JWKS fetch attempt.
_jwks_last_fetch_attempt: float = 0.0
# Claims every accepted token must contain. python-jose skips the audience
# check entirely when a token has no ``aud`` claim, and requires neither
# ``exp`` nor ``sub`` by default, so without these flags a validly signed
# token lacking them would pass ``audience="authenticated"``.
_REQUIRED_CLAIMS_OPTIONS = {
    "require_aud": True,
    "require_exp": True,
    "require_sub": True,
}


async def _get_jwks(force_refresh: bool = False) -> dict:
    """Fetch and cache the JWKS document from Supabase.

    Args:
        force_refresh: Re-fetch even when a cached document exists, unless the
            last fetch attempt was less than ``_JWKS_MIN_REFRESH_SECONDS`` ago
            (in which case the cached document is returned).

    Returns:
        The parsed JWKS JSON, or ``{}`` when SUPABASE_URL is not configured.

    Raises:
        httpx.HTTPError: if the request fails or returns an error status.
    """
    global _jwks_cache, _jwks_last_fetch_attempt
    if _jwks_cache is not None:
        recently_fetched = (
            time.monotonic() - _jwks_last_fetch_attempt < _JWKS_MIN_REFRESH_SECONDS
        )
        if not force_refresh or recently_fetched:
            return _jwks_cache
    if not SUPABASE_URL:
        return {}
    # Record the attempt before fetching so failed refreshes are throttled too.
    _jwks_last_fetch_attempt = time.monotonic()
    # NOTE(review): Supabase also documents the key set at
    # /auth/v1/.well-known/jwks.json — confirm this path is the one served.
    url = f"{SUPABASE_URL.rstrip('/')}/auth/v1/jwks"
    async with httpx.AsyncClient() as client:
        resp = await client.get(url)
        resp.raise_for_status()
        _jwks_cache = resp.json()
    return _jwks_cache


def _verify_with_secret(token: str) -> dict:
    """Verify *token* as an HS256 JWT signed with the Supabase JWT secret.

    This is the fallback used when no JWKS key set is available.

    Returns:
        The decoded claims.

    Raises:
        HTTPException: 500 if SUPABASE_JWT_SECRET is not configured; 401 if the
            signature, audience, expiry or required claims are invalid.
    """
    if not SUPABASE_JWT_SECRET:
        raise HTTPException(status_code=500, detail="JWT secret not configured")
    try:
        return jwt.decode(
            token,
            SUPABASE_JWT_SECRET,
            algorithms=["HS256"],
            audience="authenticated",
            options=dict(_REQUIRED_CLAIMS_OPTIONS),
        )
    except JWTError as e:
        raise HTTPException(status_code=401, detail=f"Invalid token: {e}") from e


def _find_jwk(jwks: dict, kid: Optional[str]) -> Optional[dict]:
    """Return the key in the *jwks* document whose ``kid`` equals *kid*, or None."""
    keys = jwks.get("keys") if isinstance(jwks, dict) else None
    return next((k for k in (keys or []) if k.get("kid") == kid), None)


async def _verify_with_jwks(token: str) -> dict:
    """Verify *token* as an RS256 JWT against the Supabase JWKS.

    Falls back to the HS256 shared-secret check when the JWKS document is
    empty or publishes no keys. When the token's ``kid`` is not in the cached
    key set, the set is re-fetched once (throttled) before rejecting.

    Returns:
        The decoded claims.

    Raises:
        HTTPException: 401 for malformed/invalid tokens, an unknown ``kid`` or
            missing required claims; 500 from the secret fallback when the JWT
            secret is not configured.
        httpx.HTTPError: propagated if fetching the JWKS fails.
    """
    jwks = await _get_jwks()
    if not isinstance(jwks, dict) or not jwks.get("keys"):
        # No asymmetric keys published: use the shared secret instead of
        # rejecting every token for an unknown kid.
        return _verify_with_secret(token)
    try:
        kid = jwt.get_unverified_header(token).get("kid")
        key = _find_jwk(jwks, kid)
        if key is None:
            # The signing key may have been rotated in after the set was cached.
            key = _find_jwk(await _get_jwks(force_refresh=True), kid)
        if key is None:
            raise HTTPException(status_code=401, detail="Token key ID not found in JWKS")
        return jwt.decode(
            token,
            key,
            algorithms=["RS256"],  # fixed list: never trust the header's alg
            audience="authenticated",
            options=dict(_REQUIRED_CLAIMS_OPTIONS),
        )
    except JWTError as e:
        raise HTTPException(status_code=401, detail=f"Invalid token: {e}") from e
async def get_current_user(authorization: Optional[str] = Header(None)) -> Optional[dict]:
    """
    Extract and validate the user from the Authorization header.

    The header may be ``Bearer <token>`` (the scheme is matched
    case-insensitively, as RFC 6750 requires) or, for backward compatibility,
    the bare token itself.

    Returns:
        None when the header is missing or carries no token (unauthenticated
        requests / free tier); otherwise a dict with ``id`` (the ``sub``
        claim), ``email`` and ``role`` (defaulting to "authenticated").

    Raises:
        HTTPException: 401 for tokens that fail verification or carry no
            ``sub`` claim. Any HTTPException raised during verification
            (e.g. 500 when the JWT secret is not configured) is re-raised.
    """
    if not authorization:
        return None
    header = authorization.strip()
    scheme, _, credentials = header.partition(" ")
    # Strip only a leading "Bearer" scheme; anything else is treated as the
    # token itself, so non-Bearer schemes still fail verification with 401.
    token = credentials.strip() if scheme.lower() == "bearer" else header
    if not token:
        return None
    try:
        payload = await _verify_with_jwks(token)
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=401, detail="Authentication failed") from exc
    user_id = payload.get("sub")
    if not user_id:
        # A verified token without a subject identifies no user; returning a
        # dict with id None would still satisfy require_auth's truthiness check.
        raise HTTPException(status_code=401, detail="Invalid token: missing subject")
    return {
        "id": user_id,
        "email": payload.get("email"),
        "role": payload.get("role", "authenticated"),
    }
async def require_auth(user: Optional[dict] = Depends(get_current_user)) -> dict:
    """Resolve the current user and reject anonymous requests.

    Returns:
        The user dict produced by ``get_current_user``.

    Raises:
        HTTPException: 401 ("Authentication required") when no user was resolved.
    """
    if user:
        return user
    raise HTTPException(status_code=401, detail="Authentication required")
|