File size: 3,241 Bytes
e3f2df1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
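ClauseGuard — JWT Authentication for FastAPI
Validates Supabase JWTs against the project's JWKS (asymmetric keys such as
RS256), falling back to the HS256 JWT secret when no JWKS is available.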
"""

import os
import time
from functools import lru_cache
from typing import Optional

import httpx
from fastapi import Depends, HTTPException, Header
from jose import jwt, JWTError, jwk
from jose.utils import base64url_decode

# Supabase project base URL; "" means not configured. A trailing slash is
# stripped so that f"{SUPABASE_URL}/path" never produces a "//path" URL.
SUPABASE_URL = os.environ.get("SUPABASE_URL", "").rstrip("/")
# Shared HS256 secret used when no JWKS is available; "" means not configured.
SUPABASE_JWT_SECRET = os.environ.get("SUPABASE_JWT_SECRET", "")

# Parsed JWKS document; None until the first successful fetch.
_jwks_cache: Optional[dict] = None


# How long fetched signing keys are trusted before being re-fetched, so that
# key rotation on the Supabase side is picked up without a process restart.
_JWKS_TTL_SECONDS = 600.0
# time.monotonic() timestamp of the last fetch attempt that set the cache age.
_jwks_fetched_at: float = 0.0


async def _get_jwks() -> dict:
    """Return the Supabase JWKS, fetching it when the cache is empty or stale.

    The document from ``{SUPABASE_URL}/auth/v1/jwks`` is cached at module level
    for ``_JWKS_TTL_SECONDS``. If a refresh fails while an older copy is
    cached, the older copy keeps being served (and the refresh is retried
    after another TTL) instead of failing every request.

    Returns:
        The parsed JWKS JSON, or ``{}`` when ``SUPABASE_URL`` is not configured.

    Raises:
        httpx.HTTPError: the fetch failed (transport error or non-2xx status)
            and no previously cached copy exists.
        ValueError: the response body was not valid JSON and no previously
            cached copy exists.
    """
    global _jwks_cache, _jwks_fetched_at

    now = time.monotonic()
    if _jwks_cache and now - _jwks_fetched_at < _JWKS_TTL_SECONDS:
        return _jwks_cache

    if not SUPABASE_URL:
        return {}

    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(f"{SUPABASE_URL}/auth/v1/jwks")
            resp.raise_for_status()
            fetched = resp.json()
    except (httpx.HTTPError, ValueError):
        if not _jwks_cache:
            raise
        # Keep the previous keys and wait another TTL before retrying, rather
        # than hitting a failing endpoint on every incoming request.
        _jwks_fetched_at = now
        return _jwks_cache

    _jwks_cache = fetched
    _jwks_fetched_at = now
    return _jwks_cache


def _verify_with_secret(token: str) -> dict:
    """Decode *token* as an HS256 JWT signed with ``SUPABASE_JWT_SECRET``.

    This is the fallback path when no JWKS is usable. The token's audience
    is checked against ``"authenticated"``.

    Returns:
        The decoded claims.

    Raises:
        HTTPException: 500 when ``SUPABASE_JWT_SECRET`` is empty; 401 when
            python-jose rejects the token (the jose error text is included).
    """
    secret = SUPABASE_JWT_SECRET
    if secret:
        try:
            return jwt.decode(
                token,
                secret,
                algorithms=["HS256"],
                audience="authenticated",
            )
        except JWTError as exc:
            raise HTTPException(status_code=401, detail=f"Invalid token: {exc}")
    raise HTTPException(status_code=500, detail="JWT secret not configured")


async def _verify_with_jwks(token: str) -> dict:
    """Verify *token* against the Supabase JWKS, with an HS256 fallback.

    The signing key is selected by matching the token header's ``kid`` against
    the published keys. If the JWKS is unavailable (e.g. ``SUPABASE_URL`` is
    unset), is not a JSON object, or publishes no keys at all, verification
    falls back to the HS256 shared secret.

    Returns:
        The decoded claims.

    Raises:
        HTTPException: 401 when the ``kid`` is unknown, the key's algorithm is
            not an accepted asymmetric one, or python-jose rejects the token;
            the HS256 fallback may raise 500 if no secret is configured.
    """
    # Accepted asymmetric algorithms. The algorithm is taken from the trusted
    # JWK, never from the token header, so a forged header cannot switch it.
    allowed_algs = ("RS256", "ES256")

    jwks = await _get_jwks()
    keys = jwks.get("keys") if isinstance(jwks, dict) else None
    if not keys:
        # Missing *or empty* key set: nothing to verify against, use the secret.
        return _verify_with_secret(token)

    try:
        kid = jwt.get_unverified_header(token).get("kid")
        key = next((k for k in keys if k.get("kid") == kid), None)
        if not key:
            raise HTTPException(status_code=401, detail="Token key ID not found in JWKS")

        # Keys without an "alg" member keep the historical RS256 behaviour.
        alg = key.get("alg", "RS256")
        if alg not in allowed_algs:
            raise HTTPException(status_code=401, detail=f"Unsupported signing algorithm: {alg}")

        return jwt.decode(
            token,
            key,
            algorithms=[alg],
            audience="authenticated",
        )
    except JWTError as e:
        raise HTTPException(status_code=401, detail=f"Invalid token: {e}")


async def get_current_user(authorization: Optional[str] = Header(None)) -> Optional[dict]:
    """
    Extract and validate user from the Authorization header.

    The header may be ``Bearer <token>`` (scheme matched case-insensitively,
    as RFC 7235 requires) or a bare token.

    Returns:
        ``{"id", "email", "role"}`` taken from the verified claims, or None for
        unauthenticated requests (free tier): no header, or an empty token.

    Raises:
        HTTPException: 401 for invalid tokens, for tokens without a ``sub``
            claim, and for any unexpected verification error; other
            HTTPExceptions raised during verification propagate unchanged.
    """
    if not authorization:
        return None

    value = authorization.strip()
    scheme, sep, credentials = value.partition(" ")
    # Only strip a real "Bearer" scheme; anything else is passed through as-is
    # and will fail verification if it is not a JWT.
    token = credentials.strip() if sep and scheme.lower() == "bearer" else value
    if not token:
        return None

    try:
        payload = await _verify_with_jwks(token)
        user_id = payload.get("sub")
        if not user_id:
            # A signature-valid token with no subject names no user (presumably
            # a project-level API key). Returning id=None would make
            # require_auth treat it as a logged-in user.
            raise HTTPException(status_code=401, detail="Token has no subject")
        return {
            "id": user_id,
            "email": payload.get("email"),
            "role": payload.get("role", "authenticated"),
        }
    except HTTPException:
        raise
    except Exception:
        # Anything unexpected (e.g. JWKS fetch failure) is reported as 401.
        raise HTTPException(status_code=401, detail="Authentication failed")


async def require_auth(user: Optional[dict] = Depends(get_current_user)) -> dict:
    """FastAPI dependency: return the resolved user, or reject with 401.

    Wraps ``get_current_user`` so that anonymous (None) callers are refused.
    """
    if user:
        return user
    raise HTTPException(status_code=401, detail="Authentication required")