ClauseGuard / api /auth.py
gaurv007's picture
v3: All missing features — auth callback, settings, PDF export, email (Resend), JWT auth, /api/history, SaulLM integration, extension icons, Supabase 0.10 breaking changes, Stripe v22
e3f2df1 verified
raw
history blame
3.24 kB
"""
ClauseGuard — JWT Authentication for FastAPI
Validates Supabase JWTs against the project's JWKS, falling back to the shared HS256 JWT secret.
"""
import logging
import os
import time
from functools import lru_cache
from typing import Optional

import httpx
from fastapi import Depends, HTTPException, Header
from jose import jwt, JWTError, jwk
from jose.utils import base64url_decode
# Supabase project settings, read once at import time.
# An empty string means the setting is not configured.
SUPABASE_URL: str = os.getenv("SUPABASE_URL", "")
SUPABASE_JWT_SECRET: str = os.getenv("SUPABASE_JWT_SECRET", "")

# Module-level JWKS document cache, filled lazily by _get_jwks().
_jwks_cache: Optional[dict] = None
async def _get_jwks() -> dict:
"""Fetch and cache JWKS from Supabase."""
global _jwks_cache
if _jwks_cache:
return _jwks_cache
if not SUPABASE_URL:
return {}
async with httpx.AsyncClient() as client:
resp = await client.get(f"{SUPABASE_URL}/auth/v1/jwks")
resp.raise_for_status()
_jwks_cache = resp.json()
return _jwks_cache
def _verify_with_secret(token: str) -> dict:
    """Verify an HS256 JWT with the shared Supabase JWT secret.

    Returns:
        The decoded claims dict.

    Raises:
        HTTPException(500): SUPABASE_JWT_SECRET is not configured.
        HTTPException(401): signature, expiry or audience validation failed,
            or the token lacks an ``aud``, ``exp`` or ``sub`` claim.
    """
    if not SUPABASE_JWT_SECRET:
        raise HTTPException(status_code=500, detail="JWT secret not configured")
    try:
        return jwt.decode(
            token,
            SUPABASE_JWT_SECRET,
            algorithms=["HS256"],
            audience="authenticated",
            # python-jose silently skips the audience check when the token has
            # no "aud" claim. Other tokens signed with this secret (presumably
            # Supabase's API keys, which carry no aud/sub) would otherwise be
            # accepted as user sessions, so require these claims explicitly.
            options={"require_aud": True, "require_exp": True, "require_sub": True},
        )
    except JWTError as e:
        raise HTTPException(status_code=401, detail=f"Invalid token: {e}") from e
# Minimum seconds between JWKS refetches triggered by an unknown "kid".
# This bounds outbound requests when clients send tokens with bogus key IDs.
_JWKS_REFRESH_COOLDOWN = 60.0
_jwks_last_forced_refresh: float = float("-inf")


def _find_jwk(jwks: dict, kid: Optional[str]) -> Optional[dict]:
    """Return the key in ``jwks["keys"]`` whose "kid" equals *kid*, or None."""
    if not isinstance(jwks, dict):
        return None
    keys = jwks.get("keys") or []
    return next(
        (k for k in keys if isinstance(k, dict) and k.get("kid") == kid),
        None,
    )


async def _verify_with_jwks(token: str) -> dict:
    """Verify a Supabase JWT, choosing the method from the token's header.

    * ``alg == "HS256"`` tokens are verified with the shared secret.
    * If no JWKS keys are available, the secret is also used (as a fallback).
    * Otherwise the key whose ``kid`` matches the header is taken from the
      JWKS and the token is verified as RS256 or ES256. On an unknown kid
      the JWKS is refetched at most once per cooldown window, so rotated
      keys are picked up.

    Returns:
        The decoded claims dict.

    Raises:
        HTTPException(401): malformed token, unknown key ID, or failed
            signature/claim validation.
        HTTPException(500): propagated from the secret path when no secret
            is configured.
    """
    global _jwks_cache, _jwks_last_forced_refresh

    try:
        header = jwt.get_unverified_header(token)
    except JWTError as e:
        raise HTTPException(status_code=401, detail=f"Invalid token: {e}") from e

    # HS256 tokens are signed with the shared secret and cannot be matched
    # against public JWKS keys. Each path pins its own algorithm list, so
    # routing on the unverified header does not allow algorithm confusion.
    if header.get("alg") == "HS256":
        return _verify_with_secret(token)

    jwks = await _get_jwks()
    # An empty "keys" list means no asymmetric keys are published.
    if not isinstance(jwks, dict) or not jwks.get("keys"):
        return _verify_with_secret(token)

    kid = header.get("kid")
    key = _find_jwk(jwks, kid)
    if key is None:
        now = time.monotonic()
        if now - _jwks_last_forced_refresh >= _JWKS_REFRESH_COOLDOWN:
            _jwks_last_forced_refresh = now
            stale = _jwks_cache
            _jwks_cache = None
            try:
                jwks = await _get_jwks()
            finally:
                # Keep the previous key set if the refetch failed or came back empty.
                if not _jwks_cache:
                    _jwks_cache = stale
            key = _find_jwk(jwks, kid)
    if key is None:
        raise HTTPException(status_code=401, detail="Token key ID not found in JWKS")

    try:
        return jwt.decode(
            token,
            key,
            algorithms=["RS256", "ES256"],
            audience="authenticated",
            # python-jose skips the audience check when "aud" is absent; require it.
            options={"require_aud": True, "require_exp": True, "require_sub": True},
        )
    except JWTError as e:
        raise HTTPException(status_code=401, detail=f"Invalid token: {e}") from e
async def get_current_user(authorization: Optional[str] = Header(None)) -> Optional[dict]:
    """
    Extract and validate the user from the ``Authorization`` header.

    Accepts ``Bearer <token>`` (scheme matched case-insensitively). A header
    without a Bearer scheme is used verbatim as the token.

    Returns:
        None when the header is absent or carries no token (free tier).
        Otherwise ``{"id": sub, "email": email or None, "role": role}``,
        where role defaults to "authenticated".

    Raises:
        HTTPException(401): the token is invalid or has no ``sub`` claim.
            Unexpected errors during verification (e.g. an HTTP failure while
            fetching the JWKS) are also reported as 401.
        HTTPException: other statuses raised by the verifier pass through.
    """
    if not authorization:
        return None
    value = authorization.strip()
    scheme, _, credentials = value.partition(" ")
    token = credentials.strip() if scheme.lower() == "bearer" else value
    if not token:
        return None
    try:
        payload = await _verify_with_jwks(token)
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=401, detail="Authentication failed") from exc

    user_id = payload.get("sub")
    if not user_id:
        # Without a subject the request cannot be tied to a user. Returning a
        # dict with id=None would still satisfy require_auth's truthiness check.
        raise HTTPException(status_code=401, detail="Token has no subject")
    return {
        "id": user_id,
        "email": payload.get("email"),
        "role": payload.get("role", "authenticated"),
    }
async def require_auth(user: Optional[dict] = Depends(get_current_user)) -> dict:
    """
    FastAPI dependency that only lets authenticated callers through.

    Resolves the caller via get_current_user and returns that user unchanged.
    A missing (falsy) user results in HTTPException(401, "Authentication required").
    """
    if user:
        return user
    raise HTTPException(status_code=401, detail="Authentication required")