File size: 6,263 Bytes
79b2fcc
 
 
 
 
67b16c6
 
 
79b2fcc
67b16c6
 
 
6d7f53f
79b2fcc
67b16c6
 
 
 
 
 
 
 
 
79b2fcc
 
67b16c6
 
 
79b2fcc
 
 
 
 
 
 
 
67b16c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79b2fcc
 
 
67b16c6
 
79b2fcc
 
 
 
67b16c6
 
 
 
 
79b2fcc
67b16c6
 
79b2fcc
 
 
67b16c6
 
 
 
 
 
 
8460d28
 
 
67b16c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79b2fcc
 
 
 
 
67b16c6
79b2fcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5af3ab5
79b2fcc
 
 
67b16c6
 
 
 
79b2fcc
 
 
 
67b16c6
 
79b2fcc
 
 
 
67b16c6
 
79b2fcc
 
 
 
 
 
 
6d7f53f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""Authentication routes for HF OAuth.

Handles the OAuth 2.0 authorization code flow with HF as provider.
After successful auth, sets an HttpOnly cookie with the access token.
"""

import os
import secrets
import time
from urllib.parse import urlencode

import httpx
from dependencies import AUTH_ENABLED, check_org_membership, get_current_user
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import RedirectResponse

# Router that collects the authentication endpoints of this module.
# Every route registered on it is served under the "/auth" prefix and tagged "auth".
router = APIRouter(prefix="/auth", tags=["auth"])

# OAuth settings, read from the process environment once at import time.
# Client id and secret default to the empty string when unset.
OAUTH_CLIENT_ID = os.getenv("OAUTH_CLIENT_ID", "")
OAUTH_CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET", "")
OPENID_PROVIDER_URL = os.getenv("OPENID_PROVIDER_URL", "https://huggingface.co")

# Pending OAuth ``state`` values, held in process memory.
# Each entry is given a lifetime of 5 minutes (expressed in seconds).
_OAUTH_STATE_TTL = 5 * 60
oauth_states: dict[str, dict] = dict()


def _cleanup_expired_states() -> None:
    """Drop every entry of ``oauth_states`` whose expiry time has passed.

    An entry without an ``expires_at`` key is treated as expiring at time 0,
    so it is removed as well. Keeps the in-memory store from growing forever.
    """
    cutoff = time.time()
    # Walk a snapshot so entries can be deleted from the dict while scanning.
    for state_key, entry in list(oauth_states.items()):
        if entry.get("expires_at", 0) < cutoff:
            del oauth_states[state_key]


def get_redirect_uri(request: Request) -> str:
    """Return the absolute URL the OAuth provider should redirect back to.

    When the ``SPACE_HOST`` environment variable is set to a non-empty value,
    the URI is ``https://<SPACE_HOST>/auth/callback``. Otherwise it is derived
    from the incoming request via the route named ``oauth_callback``.
    """
    space_host = os.environ.get("SPACE_HOST")
    if not space_host:
        # No Space host configured: let the framework build the URL
        return str(request.url_for("oauth_callback"))
    return f"https://{space_host}/auth/callback"


@router.get("/login")
async def oauth_login(request: Request) -> RedirectResponse:
    """Start the OAuth authorization-code flow by redirecting to the provider.

    Generates a random ``state`` value, records it in ``oauth_states`` together
    with the callback URI and an expiry timestamp, and redirects the browser to
    the provider's ``/oauth/authorize`` endpoint.

    Args:
        request: Incoming request; used to derive the callback URI.

    Returns:
        A redirect to the provider's authorization URL.

    Raises:
        HTTPException: 500 when ``OAUTH_CLIENT_ID`` is empty.
    """
    if not OAUTH_CLIENT_ID:
        raise HTTPException(
            status_code=500,
            detail="OAuth not configured. Set OAUTH_CLIENT_ID environment variable.",
        )

    # Purge stale entries on every login so the in-memory store stays bounded
    _cleanup_expired_states()

    # Resolve the callback URI exactly once: the value stored for the later
    # token exchange must be the same one sent to the provider.
    redirect_uri = get_redirect_uri(request)

    # Random state ties the callback to this login attempt (CSRF protection)
    state = secrets.token_urlsafe(32)
    oauth_states[state] = {
        "redirect_uri": redirect_uri,
        "expires_at": time.time() + _OAUTH_STATE_TTL,
    }

    # Build authorization URL
    params = {
        "client_id": OAUTH_CLIENT_ID,
        "redirect_uri": redirect_uri,
        "scope": "openid profile read-repos write-repos contribute-repos manage-repos inference-api jobs write-discussions",
        "response_type": "code",
        "state": state,
        "orgIds": os.environ.get(
            "HF_OAUTH_ORG_ID", "698dbf55845d85df163175f1"
        ),  # ml-agent-explorers
    }
    auth_url = f"{OPENID_PROVIDER_URL}/oauth/authorize?{urlencode(params)}"

    return RedirectResponse(url=auth_url)


@router.get("/callback")
async def oauth_callback(
    request: Request, code: str = "", state: str = ""
) -> RedirectResponse:
    """Handle the provider's redirect back to this app and finish the login.

    Validates the ``state`` value issued by the login endpoint, exchanges the
    authorization ``code`` for an access token at the provider's
    ``/oauth/token`` endpoint, and stores that token in an HttpOnly cookie.

    Args:
        request: Incoming request (not read by this handler).
        code: Authorization code supplied by the provider.
        state: State value that was sent to the provider during login.

    Returns:
        A 302 redirect to ``/`` carrying the ``hf_access_token`` cookie.

    Raises:
        HTTPException: 400 when ``state`` is unknown or expired, or ``code`` is
            empty; 500 when the token exchange fails or yields no access token.
    """
    # pop() makes every state single-use, whether or not the rest succeeds
    stored_state = oauth_states.pop(state, None)
    if stored_state is None:
        raise HTTPException(status_code=400, detail="Invalid state parameter")
    # Expired entries are only purged when a login starts, so the TTL has to
    # be enforced here as well; otherwise a stale state would still be accepted.
    if time.time() > stored_state.get("expires_at", 0):
        raise HTTPException(status_code=400, detail="Expired state parameter")
    redirect_uri = stored_state["redirect_uri"]

    if not code:
        raise HTTPException(status_code=400, detail="No authorization code provided")

    token_url = f"{OPENID_PROVIDER_URL}/oauth/token"
    # One client serves both provider calls
    async with httpx.AsyncClient() as client:
        # Exchange code for token
        try:
            token_response = await client.post(
                token_url,
                data={
                    "grant_type": "authorization_code",
                    "code": code,
                    "redirect_uri": redirect_uri,
                    "client_id": OAUTH_CLIENT_ID,
                    "client_secret": OAUTH_CLIENT_SECRET,
                },
            )
            token_response.raise_for_status()
            token_data = token_response.json()
        except (httpx.HTTPError, ValueError) as e:
            # ValueError covers a response body that is not valid JSON
            raise HTTPException(
                status_code=500, detail=f"Token exchange failed: {e}"
            ) from e

        # A JSON body that is not an object cannot carry an access token
        access_token = (
            token_data.get("access_token") if isinstance(token_data, dict) else None
        )
        if not access_token:
            raise HTTPException(
                status_code=500,
                detail="Token exchange succeeded but no access_token was returned.",
            )

        # Fetch user info (optional — failure is not fatal).
        # NOTE(review): the response is never used; confirm this call is still wanted.
        try:
            userinfo_response = await client.get(
                f"{OPENID_PROVIDER_URL}/oauth/userinfo",
                headers={"Authorization": f"Bearer {access_token}"},
            )
            userinfo_response.raise_for_status()
        except httpx.HTTPError:
            pass  # deliberate best-effort: user info is not required for the auth flow

    # Set access token as HttpOnly cookie (not in URL — avoids leaks via
    # Referrer headers, browser history, and server logs)
    is_production = bool(os.environ.get("SPACE_HOST"))
    redirect = RedirectResponse(url="/", status_code=302)
    redirect.set_cookie(
        key="hf_access_token",
        value=access_token,
        httponly=True,
        secure=is_production,  # Secure flag only when SPACE_HOST is set (HTTPS)
        samesite="lax",
        max_age=3600 * 24 * 7,  # 7 days
        path="/",
    )
    return redirect


@router.get("/logout")
async def logout() -> RedirectResponse:
    """Clear the ``hf_access_token`` cookie and redirect the browser to ``/``."""
    redirect = RedirectResponse(url="/")
    # Path must match the one used when the cookie was set
    redirect.delete_cookie(path="/", key="hf_access_token")
    return redirect


@router.get("/status")
async def auth_status() -> dict:
    """Report whether authentication is enabled, as ``{"auth_enabled": <bool>}``."""
    payload = {"auth_enabled": AUTH_ENABLED}
    return payload


@router.get("/me")
async def get_me(user: dict = Depends(get_current_user)) -> dict:
    """Return the user dict resolved by the ``get_current_user`` dependency.

    The endpoint adds no logic of its own: whatever the dependency yields is
    returned unchanged as the response body.

    NOTE(review): the earlier docstring stated that the shared dependency
    accepts both the cookie and a Bearer token and may yield a dev user. That
    logic lives in ``dependencies`` and is not visible here — confirm there.
    """
    return user


# Organization whose membership is checked by the /org-membership endpoint
ORG_NAME = "ml-agent-explorers"


@router.get("/org-membership")
async def org_membership(
    request: Request, user: dict = Depends(get_current_user)
) -> dict:
    """Check if the authenticated user belongs to the ml-agent-explorers org.

    The access token is taken from the ``hf_access_token`` cookie, falling back
    to an ``Authorization: Bearer <token>`` header so that Bearer-authenticated
    callers are not always reported as non-members.

    Args:
        request: Incoming request; source of the cookie and headers.
        user: Result of the shared auth dependency. Not read here; declaring it
            makes the dependency run for this endpoint.

    Returns:
        ``{"is_member": True}`` when auth is disabled, ``{"is_member": False}``
        when no token is available, otherwise the result of
        ``check_org_membership`` for ``ORG_NAME``.
    """
    if not AUTH_ENABLED:
        return {"is_member": True}

    token = request.cookies.get("hf_access_token") or ""
    if not token:
        # Cookie takes priority; only look at the header when it is absent
        scheme, _, credentials = request.headers.get("Authorization", "").partition(" ")
        if scheme.lower() == "bearer":
            token = credentials.strip()
    if not token:
        return {"is_member": False}

    is_member = await check_org_membership(token, ORG_NAME)
    return {"is_member": is_member}