import os
import secrets

import gradio as gr
import uvicorn
from authlib.integrations.starlette_client import OAuth, OAuthError
from fastapi import FastAPI, Depends
from fastapi import Request
from starlette.config import Config
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import RedirectResponse
|
|
app = FastAPI()

# Google OAuth client credentials, taken from the environment so they stay
# out of source control. Authlib reads them through the Config built below.
GOOGLE_CLIENT_ID = os.environ.get("GOOGLE_CLIENT_ID")
GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET")

# Key used to sign the session cookie that carries the logged-in user.
# Handing None to SessionMiddleware is unsafe: Starlette stringifies the key,
# so every cookie would be signed with the guessable secret "None". Fall back
# to a random per-process key instead -- sessions then do not survive a
# restart (or span several workers), but they cannot be forged.
SECRET_KEY = os.environ.get("SECRET_KEY")
if not SECRET_KEY:
    print("SECRET_KEY not set; using a random per-process session key")
    SECRET_KEY = secrets.token_urlsafe(32)
|
|
| |
# Authlib looks up GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET in this Starlette
# Config for the client registered under the name "google"; Google's endpoints
# are discovered from its OpenID Connect metadata document.
config_data = dict(
    GOOGLE_CLIENT_ID=GOOGLE_CLIENT_ID,
    GOOGLE_CLIENT_SECRET=GOOGLE_CLIENT_SECRET,
)
starlette_config = Config(environ=config_data)

oauth = OAuth(starlette_config)
oauth.register(
    'google',
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={'scope': 'openid email profile'},
)

# Signed-cookie sessions keep the user dict (written by /auth) between requests.
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
|
|
| |
# Address suffix a Google account must have to be allowed into the app.
ALLOWED_EMAIL_DOMAIN = "@zalando.de"


def get_user(request: Request):
    """Return the logged-in user's display name, or None if access is denied.

    Used as the FastAPI dependency of ``public`` and passed to Gradio as the
    ``auth_dependency`` of the ``/gradio`` app. Reads the userinfo dict that
    ``/auth`` stores under ``request.session['user']``.

    Access is granted only when the email is marked verified by Google and
    ends (case-insensitively) with ``ALLOWED_EMAIL_DOMAIN``. Missing keys deny
    access instead of raising ``KeyError``.
    """
    user = request.session.get('user')
    if not isinstance(user, dict):
        return None

    email = str(user.get('email') or '').strip().lower()
    # An unverified address proves nothing about the account's domain, so it
    # must not be trusted for the domain check.
    if not user.get('email_verified'):
        return None
    if not email.endswith(ALLOWED_EMAIL_DOMAIN):
        return None

    # Fall back to the email so a profile without a display name is still
    # treated as logged in (a falsy return means "no user" to callers).
    return user.get('name') or email
|
|
@app.get('/')
def public(request: Request, user=Depends(get_user)):
    """Route visitors: authorised users to the Gradio app, others to the login page."""
    base = gr.route_utils.get_root_url(request, "/", None)
    # get_user yields None for anonymous or unauthorised visitors.
    destination = f'{base}/gradio/' if user else f'{base}/main/'
    return RedirectResponse(url=destination)
|
|
@app.get('/logout')
async def logout(request: Request):
    """Log the user out and redirect to the site root.

    The whole session is cleared, not just ``user``, so no other data stored
    in it during the login flow outlives the logout. The redirect is resolved
    against the public root URL, the same way the other routes build theirs.
    """
    request.session.clear()
    root_url = gr.route_utils.get_root_url(request, "/logout", None)
    return RedirectResponse(url=f"{root_url}/")
|
|
@app.get('/login')
async def login(request: Request):
    """Start the Google OAuth flow, asking Google to return to ``/auth``."""
    # Let Gradio compute the externally visible base URL for this request so
    # the callback URI is built relative to where the app is actually served.
    root_url = gr.route_utils.get_root_url(request, "/login", None)
    redirect_uri = f"{root_url}/auth"
    print("Redirecting to", redirect_uri)
    return await oauth.google.authorize_redirect(request, redirect_uri)
|
|
@app.get('/auth')
async def auth(request: Request):
    """OAuth callback: exchange the authorization code and remember the user.

    On success the ID-token userinfo is stored as ``session['user']`` and the
    visitor is sent to the Gradio app. On failure (OAuth error, or no userinfo
    in the token response) the visitor is sent back to the site root.
    """
    root_url = gr.route_utils.get_root_url(request, "/auth", None)
    try:
        token = await oauth.google.authorize_access_token(request)
    except OAuthError as error:
        # Log the raised instance (with its message), not the exception class.
        print("Error getting access token", str(error))
        return RedirectResponse(url=f"{root_url}/")

    userinfo = token.get("userinfo")
    if not userinfo:
        print("No userinfo in token response")
        return RedirectResponse(url=f"{root_url}/")

    # Store a plain dict so the session serialises cleanly into the cookie.
    request.session['user'] = dict(userinfo)
    print("Redirecting to /gradio")
    return RedirectResponse(url=f"{root_url}/gradio/")
|
|
# Landing page for visitors who are not logged in; mounted at /main.
with gr.Blocks() as login_demo:
    btn = gr.Button("Login")
    # Client-side handler: open /login (keeping the current query string) in a
    # new browsing context. `const` keeps `url` local instead of leaking an
    # implicit global. NOTE(review): the new tab is presumably there so the
    # OAuth redirect can escape an embedding iframe -- confirm.
    _js_redirect = """
    () => {
        const url = '/login' + window.location.search;
        window.open(url, '_blank');
    }
    """
    btn.click(None, js=_js_redirect)


app = gr.mount_gradio_app(app, login_demo, path="/main")
|
|
def greet(request: gr.Request):
    """Return the personalised welcome text for the current Gradio session.

    The request object is supplied by Gradio (the load event passes no inputs).
    """
    username = request.username
    return f"Welcome to Gradio, {username}"
|
|
# The protected application page; mounted at /gradio.
with gr.Blocks() as main_demo:
    m = gr.Markdown("Welcome to Gradio!")
    gr.Button("Logout", link="/logout")
    # On page load, replace the placeholder text with greet()'s output.
    main_demo.load(fn=greet, inputs=None, outputs=m)


# get_user is handed to Gradio as this app's auth dependency.
app = gr.mount_gradio_app(
    app,
    main_demo,
    path="/gradio",
    auth_dependency=get_user,
)
|
|
|
|
if __name__ == '__main__':
    # Serve the combined FastAPI + Gradio app with uvicorn's default host/port.
    uvicorn.run(app=app)
|
|