Spaces:
Running
Running
| import os | |
| import time | |
| import hashlib | |
| from fastapi import WebSocket, WebSocketDisconnect | |
| from collections import defaultdict, deque | |
| import json | |
| from fastapi import FastAPI, Request, HTTPException, status, Header | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import ( | |
| Response, | |
| JSONResponse, | |
| StreamingResponse, | |
| RedirectResponse, | |
| ) | |
| import httpx | |
| from bs4 import BeautifulSoup | |
| from typing import List, Dict, Any | |
| import asyncio | |
| import re | |
| import random | |
| from urllib.parse import quote | |
| import base64 | |
| from helper.subscriptions import ( | |
| fetch_subscription, | |
| normalize_plan_key, | |
| TIER_CONFIG, | |
| PLAN_ORDER, | |
| ) | |
| from typing import Optional | |
| from helper.keywords import * | |
| from helper.assets import asset_router | |
| from helper.ratelimit import ( | |
| enforce_rate_limit, | |
| resolve_rate_limit_identity, | |
| check_audio_rate_limit, | |
| check_video_rate_limit, | |
| check_image_rate_limit, | |
| MAX_CHAT_PROMPT_BYTES, | |
| MAX_CHAT_PROMPT_CHARS, | |
| MAX_GROQ_PROMPT_BYTES, | |
| MAX_GROQ_PROMPT_CHARS, | |
| MAX_MEDIA_PROMPT_BYTES, | |
| MAX_MEDIA_PROMPT_CHARS, | |
| extract_user_text, | |
| calculate_messages_size, | |
| normalize_prompt_value, | |
| enforce_prompt_size, | |
| resolve_bound_subject, | |
| get_usage_snapshot_for_subject, | |
| ) | |
| from status import router as status_router | |
| from gen import router as gen_router | |
# Application object; CORS middleware and the imported routers are attached below.
app = FastAPI()

# Shared key compared against the first message on the chat websocket.
# None when the environment variable is not set.
WEBSOCKET_KEY = os.environ.get("WEBSOCKET_KEY")

# Per-IP sliding-window state used by check_ws_auth_rate_limit: each IP maps
# to a deque of attempt timestamps (seconds since the epoch).
AUTH_WINDOW_SECONDS = 60
AUTH_MAX_ATTEMPTS = 10
AUTH_ATTEMPTS = defaultdict(deque)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST", "HEAD"],
    allow_headers=["*"],
)

# Router registration order is kept as-is.
app.include_router(asset_router)
app.include_router(status_router)
app.include_router(gen_router)
def check_ws_auth_rate_limit(ip: str):
    """Record an attempt for *ip* and report whether it is within the limit.

    Timestamps older than ``AUTH_WINDOW_SECONDS`` are discarded first. If
    ``AUTH_MAX_ATTEMPTS`` timestamps still remain, the attempt is rejected
    and NOT recorded; otherwise the current time is appended.

    Args:
        ip: Key under which attempts are tracked in ``AUTH_ATTEMPTS``.

    Returns:
        ``True`` when the attempt is allowed, ``False`` when rate-limited.
    """
    # NOTE(review): nothing in this function removes an IP's entry from
    # AUTH_ATTEMPTS, so keys accumulate unless they are pruned elsewhere.
    current = time.time()
    attempts = AUTH_ATTEMPTS[ip]

    # Oldest timestamps sit at the left; drop those that aged out.
    while attempts:
        if current - attempts[0] <= AUTH_WINDOW_SECONDS:
            break
        attempts.popleft()

    allowed = len(attempts) < AUTH_MAX_ATTEMPTS
    if allowed:
        attempts.append(current)
    return allowed
async def reroute_to_home():
    """Return a 308 (permanent) redirect to https://inference.js.org."""
    home_url = "https://inference.js.org"
    return RedirectResponse(
        url=home_url,
        status_code=status.HTTP_308_PERMANENT_REDIRECT,
    )
OLLAMA_LIBRARY_URL = "https://ollama.com/library"


async def get_models() -> List[Dict]:
    """Fetch the Ollama library page and extract one dict per listed model.

    The page is parsed with BeautifulSoup; every ``li[x-test-model]`` element
    yields a dict with the keys ``name``, ``description``, ``sizes``,
    ``pulls``, ``tags``, ``updated`` and ``link``. Missing elements fall back
    to ``""``, ``"No description"``, ``"Unknown"`` or ``None`` respectively.

    Returns:
        The list of model dicts in page order (empty if no items match).
    """

    def text_of(node, fallback):
        # Stripped text of a matched element, or the fallback when absent.
        return node.get_text(strip=True) if node else fallback

    async with httpx.AsyncClient() as client:
        page = await client.get(OLLAMA_LIBRARY_URL)

    document = BeautifulSoup(page.text, "html.parser")

    models = []
    for entry in document.select("li[x-test-model]"):
        size_labels = [
            node.get_text(strip=True) for node in entry.select("[x-test-size]")
        ]
        tag_labels = [
            node.get_text(strip=True)
            for node in entry.select('span[class*="text-blue-600"]')
        ]
        anchor = entry.select_one("a")

        models.append(
            {
                "name": text_of(entry.select_one("[x-test-model-title] span"), ""),
                "description": text_of(
                    entry.select_one("p.max-w-lg"), "No description"
                ),
                "sizes": size_labels,
                "pulls": text_of(entry.select_one("[x-test-pull-count]"), "Unknown"),
                "tags": tag_labels,
                "updated": text_of(entry.select_one("[x-test-updated]"), "Unknown"),
                "link": anchor.get("href") if anchor else None,
            }
        )
    return models
async def get_subscription(authorization: Optional[str] = Header(None)):
    """Look up the subscription for the bearer token in ``Authorization``.

    The token after ``"Bearer "`` is passed to ``fetch_subscription``. The
    returned mapping is updated in place with a normalized ``plan_key`` and
    the matching ``plan_name`` from ``TIER_CONFIG`` (falling back to the
    ``"free"`` tier when the key has no entry).

    Raises:
        HTTPException: 401 when the header is missing or not a bearer header,
            or when the lookup result contains an ``"error"`` key.
    """
    has_bearer = bool(authorization) and authorization.startswith("Bearer ")
    if not has_bearer:
        raise HTTPException(401, "Missing or invalid Authorization header")

    # The prefix check guarantees a space, so the split always yields two parts.
    _, token = authorization.split(" ", 1)

    subscription = await fetch_subscription(token)
    if "error" in subscription:
        raise HTTPException(401, subscription["error"])

    plan_key = normalize_plan_key(subscription.get("plan_key"))
    subscription["plan_key"] = plan_key
    tier = TIER_CONFIG.get(plan_key) or TIER_CONFIG["free"]
    subscription["plan_name"] = tier["name"]
    return subscription
async def get_usage(
    request: Request,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Return the caller's plan and current usage snapshot as JSON.

    The plan key and rate-limit subject come from
    ``resolve_rate_limit_identity``; the usage data comes from
    ``get_usage_snapshot_for_subject``. ``generated_at`` is the current UTC
    time formatted as ``YYYY-MM-DDTHH:MM:SSZ``.
    """
    plan_key, subject = await resolve_rate_limit_identity(
        request, authorization, x_client_id
    )

    # Unknown plan keys fall back to the free tier's configuration.
    plan = TIER_CONFIG.get(plan_key) or TIER_CONFIG["free"]
    usage = get_usage_snapshot_for_subject(plan_key, subject)

    payload = {
        "plan_key": plan_key,
        "plan_name": plan.get("name", "Free Tier"),
        "usage": usage,
        "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }
    return JSONResponse(status_code=200, content=payload)
async def tier_config():
    """Return every configured plan, in ``PLAN_ORDER`` order, as JSON.

    Keys in ``PLAN_ORDER`` without a (truthy) ``TIER_CONFIG`` entry are
    skipped. Each plan's ``order`` is its index within ``PLAN_ORDER``, so
    skipped keys leave gaps in the numbering.
    """
    copied_fields = ("name", "url", "price", "limits")

    plans = []
    for position, key in enumerate(PLAN_ORDER):
        plan = TIER_CONFIG.get(key)
        if plan:
            entry = {"key": key}
            entry.update((field, plan[field]) for field in copied_fields)
            entry["order"] = position
            plans.append(entry)

    body = {"defaultPlanKey": "free", "plans": plans}
    return JSONResponse(status_code=200, content=body)
async def tiers():
    """Return the non-free plans, in ``PLAN_ORDER`` order, as a JSON list.

    The ``"free"`` key and any key without a (truthy) ``TIER_CONFIG`` entry
    are omitted.
    """
    copied_fields = ("name", "url", "price", "limits")

    # Lazily pair each paid key with its config entry (possibly None).
    lookups = ((key, TIER_CONFIG.get(key)) for key in PLAN_ORDER if key != "free")

    paid_plans = [
        {"key": key, **{field: plan[field] for field in copied_fields}}
        for key, plan in lookups
        if plan
    ]
    return JSONResponse(status_code=200, content=paid_plans)
async def websocket_chat(ws: WebSocket):
    """Authenticate a websocket client and proxy its chat requests.

    Protocol, as implemented here:

    1. The connection is accepted, then rate-limited per client IP through
       ``check_ws_auth_rate_limit`` (close code 4408 when exceeded).
    2. The first text frame must be a JSON object whose ``"key"`` equals
       ``WEBSOCKET_KEY``. Anything that is not a JSON object closes the
       socket with 4400; a wrong key, or an unset ``WEBSOCKET_KEY``, closes
       it with 4403. On success ``{"type": "auth", "status": "ok"}`` is sent.
    3. Each later frame must be a JSON object with a non-empty ``"body"`` and
       optional ``"headers"``. It is POSTed to the internal chat-completions
       URL, and every non-empty response line is relayed as the text frame
       ``"<request_id>:<line>"``. Request ids count up from 1 per connection.
       Requests run concurrently.

    When the client disconnects, all in-flight upstream requests are
    cancelled before the shared HTTP client is closed.
    """
    import hmac  # function-scope import: only the key comparison needs it

    # Starlette types ``ws.client`` as optional; guard before reading
    # ``.host``. Connections without address info share one bucket.
    ip = ws.client.host if ws.client else "unknown"

    # Accept connection
    await ws.accept()

    # Basic connection rate limiting
    if not check_ws_auth_rate_limit(ip):
        await ws.close(code=4408)
        return

    try:
        # Expect an auth message first
        auth_msg = await ws.receive_text()
        try:
            auth_data = json.loads(auth_msg)
        except (ValueError, RecursionError):
            auth_data = None
        # Valid JSON that is not an object (list, string, number, ...) has
        # no ``.get``; treat it like malformed input.
        if not isinstance(auth_data, dict):
            await ws.close(code=4400)
            return

        provided_key = auth_data.get("key")
        # Constant-time comparison on bytes. ``surrogatepass`` keeps the
        # encoding from raising on lone surrogates that JSON can produce,
        # without mapping distinct inputs to the same bytes.
        key_matches = (
            bool(WEBSOCKET_KEY)
            and isinstance(provided_key, str)
            and hmac.compare_digest(
                provided_key.encode("utf-8", "surrogatepass"),
                WEBSOCKET_KEY.encode("utf-8", "surrogatepass"),
            )
        )
        if not key_matches:
            await ws.close(code=4403)
            return

        await ws.send_json({"type": "auth", "status": "ok"})

        internal_url = "http://127.0.0.1:7860/gen/chat/completions"

        # timeout=None: streamed completions may stay open indefinitely.
        async with httpx.AsyncClient(
            timeout=None,
            follow_redirects=False,
        ) as client:
            request_counter = 0
            # Strong references to running tasks: the event loop only keeps
            # weak ones, and we need them to cancel on disconnect.
            in_flight = set()

            async def process_request(client, request_id, body, headers):
                """POST one request upstream and relay its lines to the client."""
                try:
                    async with client.stream(
                        "POST",
                        internal_url,
                        json=body,
                        headers=headers,
                    ) as response:
                        # Handle upstream errors
                        if response.status_code >= 400:
                            error_text = ""
                            try:
                                raw = await response.aread()
                                error_text = raw.decode("utf-8", errors="replace")[:500]
                            except Exception:
                                # Best effort: report the status even if the
                                # error body cannot be read.
                                pass
                            error_payload = json.dumps({
                                "error": "Upstream request failed",
                                "status": response.status_code,
                                "detail": error_text,
                            })
                            try:
                                await ws.send_text(f"{request_id}:{error_payload}")
                            except RuntimeError:
                                # Socket already closed; nothing to report to.
                                pass
                            return

                        # Stream tokens/lines to the websocket
                        async for line in response.aiter_lines():
                            if not line:
                                continue
                            try:
                                await ws.send_text(f"{request_id}:{line}")
                            except RuntimeError:
                                return
                except Exception as stream_error:
                    # Cancellation is a BaseException and passes through.
                    try:
                        error_payload = json.dumps({
                            "error": "Streaming error",
                            "detail": str(stream_error),
                        })
                        await ws.send_text(f"{request_id}:{error_payload}")
                    except Exception:
                        pass

            async def handle_incoming_requests():
                """Read request frames until the client disconnects."""
                nonlocal request_counter
                while True:
                    try:
                        msg = await ws.receive_text()
                    except WebSocketDisconnect:
                        break

                    try:
                        data = json.loads(msg)
                    except (ValueError, RecursionError):
                        await ws.send_json({"error": "Invalid JSON"})
                        continue
                    if not isinstance(data, dict):
                        await ws.send_json({"error": "Request must be a JSON object"})
                        continue

                    body = data.get("body")
                    headers = data.get("headers") or {}
                    if not body:
                        await ws.send_json({"error": "Missing body"})
                        continue

                    request_counter += 1
                    request_id = request_counter

                    # Process this request concurrently
                    task = asyncio.create_task(
                        process_request(client, request_id, body, headers)
                    )
                    in_flight.add(task)
                    task.add_done_callback(in_flight.discard)

            try:
                # Start handling incoming requests
                await handle_incoming_requests()
            finally:
                # Stop outstanding upstream streams before the HTTP client
                # is closed by the enclosing ``async with``.
                pending = list(in_flight)
                for task in pending:
                    task.cancel()
                await asyncio.gather(*pending, return_exceptions=True)

    except WebSocketDisconnect:
        return
    except Exception as e:
        # Best effort: tell the client what happened, then close. Either
        # step may fail if the socket is already gone.
        try:
            await ws.send_json({
                "error": "Server error",
                "detail": str(e),
            })
        except Exception:
            pass
        try:
            await ws.close()
        except Exception:
            pass
async def redirect_to_protal(request: Request):
    """Send the caller to the Stripe billing-portal login page.

    For POST requests the JSON body may carry an ``"email"`` string. When it
    does, the response is JSON ``{"redirect_url": ...}`` where the URL has
    the address URL-encoded in the ``prefilled_email`` query parameter.

    In every other case (non-POST request, unreadable or non-object body,
    missing or non-string email) a 302 redirect to the bare portal URL is
    returned.
    """
    base_url = "https://billing.stripe.com/p/login/5kQdR9aIM3ts4steyabbG00"

    email = None
    if request.method == "POST":
        try:
            body = await request.json()
        except Exception:
            # Deliberate best effort: an unreadable body means "no email".
            # ``Exception`` (not a bare except) lets task cancellation
            # propagate.
            body = None
        if isinstance(body, dict):
            email = body.get("email")

    if not email or not isinstance(email, str):
        return RedirectResponse(url=base_url, status_code=status.HTTP_302_FOUND)

    # ``email`` can only be set for POST, so a prefilled URL is always
    # returned as JSON. Percent-encode it so characters such as ``&``, ``#``
    # or ``+`` cannot alter the query string.
    encoded_email = quote(email, safe="")
    return JSONResponse({"redirect_url": f"{base_url}?prefilled_email={encoded_email}"})