import asyncio
import base64
import hashlib
import hmac
import json
import os
import random
import re
import time
from collections import defaultdict, deque
from typing import Any, Dict, List, Optional
from urllib.parse import quote

import httpx
from bs4 import BeautifulSoup
from fastapi import (
    FastAPI,
    Header,
    HTTPException,
    Request,
    WebSocket,
    WebSocketDisconnect,
    status,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (
    Response,
    JSONResponse,
    StreamingResponse,
    RedirectResponse,
)

from helper.subscriptions import (
    fetch_subscription,
    normalize_plan_key,
    TIER_CONFIG,
    PLAN_ORDER,
)
from helper.keywords import *
from helper.assets import asset_router
from helper.ratelimit import (
    enforce_rate_limit,
    resolve_rate_limit_identity,
    check_audio_rate_limit,
    check_video_rate_limit,
    check_image_rate_limit,
    MAX_CHAT_PROMPT_BYTES,
    MAX_CHAT_PROMPT_CHARS,
    MAX_GROQ_PROMPT_BYTES,
    MAX_GROQ_PROMPT_CHARS,
    MAX_MEDIA_PROMPT_BYTES,
    MAX_MEDIA_PROMPT_CHARS,
    extract_user_text,
    calculate_messages_size,
    normalize_prompt_value,
    enforce_prompt_size,
    resolve_bound_subject,
    get_usage_snapshot_for_subject,
)
from status import router as status_router
from gen import router as gen_router

app = FastAPI()

# Shared secret a /ws/chat client must send in its first frame; when unset,
# every WebSocket authentication attempt is rejected.
WEBSOCKET_KEY = os.getenv("WEBSOCKET_KEY")

# authentication attempt tracking: client IP -> time.time() stamps of recent
# /ws/chat connection attempts inside the sliding window.
AUTH_ATTEMPTS = defaultdict(deque)
AUTH_WINDOW_SECONDS = 60
AUTH_MAX_ATTEMPTS = 10

# time.time() of the last sweep of idle entries out of AUTH_ATTEMPTS.
_AUTH_LAST_SWEEP = 0.0

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST", "HEAD"],
    allow_headers=["*"],
)

app.include_router(asset_router)
app.include_router(status_router)
app.include_router(gen_router)


def _sweep_idle_auth_attempts(now: float) -> None:
    """Forget IPs whose every recorded attempt has left the window.

    Runs at most once per AUTH_WINDOW_SECONDS. Without it, AUTH_ATTEMPTS keeps
    one entry per distinct IP ever seen. Removed IPs would have had all their
    timestamps purged on their next check anyway, so limiting is unaffected.
    """
    global _AUTH_LAST_SWEEP
    if now - _AUTH_LAST_SWEEP < AUTH_WINDOW_SECONDS:
        return
    _AUTH_LAST_SWEEP = now
    idle = [
        ip
        for ip, attempts in AUTH_ATTEMPTS.items()
        if not attempts or now - attempts[-1] > AUTH_WINDOW_SECONDS
    ]
    for ip in idle:
        del AUTH_ATTEMPTS[ip]


def check_ws_auth_rate_limit(ip: str) -> bool:
    """Record a connection attempt from *ip* and report whether it is allowed.

    Sliding window: at most AUTH_MAX_ATTEMPTS accepted attempts per
    AUTH_WINDOW_SECONDS. Rejected attempts are not recorded.
    """
    now = time.time()
    _sweep_idle_auth_attempts(now)
    q = AUTH_ATTEMPTS[ip]
    # purge old attempts
    while q and now - q[0] > AUTH_WINDOW_SECONDS:
        q.popleft()
    if len(q) >= AUTH_MAX_ATTEMPTS:
        return False
    q.append(now)
    return True


def _ws_key_matches(provided_key: Any) -> bool:
    """Return True when *provided_key* equals WEBSOCKET_KEY, compared in constant time.

    Always False when WEBSOCKET_KEY is unset or the key is not a string.
    """
    if not WEBSOCKET_KEY or not isinstance(provided_key, str):
        return False
    # Compare bytes: compare_digest rejects non-ASCII str, and JSON may carry
    # lone surrogates, hence "surrogatepass".
    return hmac.compare_digest(
        provided_key.encode("utf-8", "surrogatepass"),
        WEBSOCKET_KEY.encode("utf-8", "surrogatepass"),
    )


@app.get("/")
async def reroute_to_home():
    """Permanently redirect the API root to the public website."""
    return RedirectResponse(
        url="https://inference.js.org",
        status_code=status.HTTP_308_PERMANENT_REDIRECT,
    )


OLLAMA_LIBRARY_URL = "https://ollama.com/library"


def _parse_ollama_library(html: str) -> List[Dict[str, Any]]:
    """Extract one summary dict per model card from the Ollama library HTML.

    Relies on the page's ``x-test-*`` attributes. A field missing from a card
    falls back to a placeholder instead of failing the whole list.
    """

    def text_or(element, default):
        return element.get_text(strip=True) if element else default

    soup = BeautifulSoup(html, "html.parser")
    models = []
    for item in soup.select("li[x-test-model]"):
        link = item.select_one("a")
        models.append(
            {
                "name": text_or(item.select_one("[x-test-model-title] span"), ""),
                "description": text_or(
                    item.select_one("p.max-w-lg"), "No description"
                ),
                "sizes": [
                    el.get_text(strip=True) for el in item.select("[x-test-size]")
                ],
                "pulls": text_or(item.select_one("[x-test-pull-count]"), "Unknown"),
                "tags": [
                    t.get_text(strip=True)
                    for t in item.select('span[class*="text-blue-600"]')
                ],
                "updated": text_or(item.select_one("[x-test-updated]"), "Unknown"),
                "link": link.get("href") if link else None,
            }
        )
    return models


@app.head("/models")
@app.get("/models")
async def get_models() -> List[Dict]:
    """Scrape the Ollama model library and return its model cards.

    Raises:
        HTTPException: 502 when the library page cannot be fetched or answers
            with a non-2xx status. Previously this silently returned an empty
            list.
    """
    try:
        async with httpx.AsyncClient(follow_redirects=True) as client:
            response = await client.get(OLLAMA_LIBRARY_URL)
            response.raise_for_status()
    except httpx.HTTPError as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="Failed to fetch the Ollama model library",
        ) from exc
    return _parse_ollama_library(response.text)


@app.get("/subscription")
async def get_subscription(authorization: Optional[str] = Header(None)):
    """Return the caller's subscription with a normalized plan_key and plan_name.

    Expects ``Authorization: Bearer <jwt>``. Responds 401 when the header is
    missing or malformed, or when fetch_subscription reports an error.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(401, "Missing or invalid Authorization header")

    jwt = authorization.split(" ", 1)[1]
    result = await fetch_subscription(jwt)
    if "error" in result:
        raise HTTPException(401, result["error"])

    plan_key = normalize_plan_key(result.get("plan_key"))
    result["plan_key"] = plan_key
    # Unknown plan keys fall back to the free tier's display name.
    result["plan_name"] = (TIER_CONFIG.get(plan_key) or TIER_CONFIG["free"])["name"]
    return result


@app.get("/usage")
async def get_usage(
    request: Request,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Report the caller's current rate-limit usage for their plan."""
    plan_key, subject = await resolve_rate_limit_identity(
        request, authorization, x_client_id
    )
    plan = TIER_CONFIG.get(plan_key) or TIER_CONFIG["free"]
    usage = get_usage_snapshot_for_subject(plan_key, subject)
    return JSONResponse(
        status_code=200,
        content={
            "plan_key": plan_key,
            "plan_name": plan.get("name", "Free Tier"),
            "usage": usage,
            "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        },
    )


def _plan_summary(key: str, plan: Dict[str, Any]) -> Dict[str, Any]:
    """Return the publicly exposed fields of one TIER_CONFIG entry."""
    return {
        "key": key,
        "name": plan["name"],
        "url": plan["url"],
        "price": plan["price"],
        "limits": plan["limits"],
    }


@app.get("/tier-config")
async def tier_config():
    """List every configured plan in PLAN_ORDER, with its position as ``order``."""
    plans = []
    for idx, key in enumerate(PLAN_ORDER):
        plan = TIER_CONFIG.get(key)
        if not plan:
            continue
        plans.append({**_plan_summary(key, plan), "order": idx})
    return JSONResponse(
        status_code=200,
        content={
            "defaultPlanKey": "free",
            "plans": plans,
        },
    )


@app.get("/tiers")
async def tiers():
    """List the configured paid plans (everything in PLAN_ORDER except "free")."""
    paid_plans = []
    for key in PLAN_ORDER:
        if key == "free":
            continue
        plan = TIER_CONFIG.get(key)
        if not plan:
            continue
        paid_plans.append(_plan_summary(key, plan))
    return JSONResponse(
        status_code=200,
        content=paid_plans,
    )


@app.websocket("/ws/chat")
async def websocket_chat(ws: WebSocket):
    """Relay chat-completion requests from a WebSocket client to /gen locally.

    Protocol:
      1. The first text frame must be a JSON object whose ``key`` matches
         WEBSOCKET_KEY. Otherwise the socket is closed with 4400 (malformed)
         or 4403 (wrong key). Too many connection attempts from one IP close
         it with 4408.
      2. Each later text frame is a JSON object with ``body``, forwarded as the
         JSON POST body, and optional ``headers``.
      3. Every upstream line, or an error payload, is sent back prefixed with
         ``"<request_id>:"`` so concurrent requests can be demultiplexed.
    """
    # ws.client may be None for some ASGI transports; bucket those together.
    ip = ws.client.host if ws.client else "unknown"

    # Accept connection (needed so the custom close codes reach the client)
    await ws.accept()

    # Basic connection rate limiting
    if not check_ws_auth_rate_limit(ip):
        await ws.close(code=4408)
        return

    try:
        # Expect an auth message first
        auth_msg = await ws.receive_text()
        try:
            auth_data = json.loads(auth_msg)
        except (ValueError, RecursionError):
            await ws.close(code=4400)
            return
        if not isinstance(auth_data, dict):
            await ws.close(code=4400)
            return

        if not _ws_key_matches(auth_data.get("key")):
            await ws.close(code=4403)
            return

        await ws.send_json({"type": "auth", "status": "ok"})

        internal_url = "http://127.0.0.1:7860/gen/chat/completions"

        async with httpx.AsyncClient(timeout=None, follow_redirects=False) as client:
            # Strong references to in-flight relay tasks: the event loop keeps
            # only weak ones. They are cancelled before the client closes.
            pending_tasks = set()
            request_counter = 0

            async def process_request(request_id, body, headers):
                """Stream one upstream response back over the socket, tagged with request_id."""
                try:
                    async with client.stream(
                        "POST", internal_url, json=body, headers=headers
                    ) as response:
                        # Handle upstream errors
                        if response.status_code >= 400:
                            error_text = ""
                            try:
                                error_text = (await response.aread()).decode(
                                    "utf-8", errors="replace"
                                )[:500]
                            except Exception:
                                pass  # best effort: the status code alone is still reported
                            error_payload = json.dumps(
                                {
                                    "error": "Upstream request failed",
                                    "status": response.status_code,
                                    "detail": error_text,
                                }
                            )
                            try:
                                await ws.send_text(f"{request_id}:{error_payload}")
                            except RuntimeError:
                                return
                            return

                        # Stream tokens/lines to the websocket
                        async for line in response.aiter_lines():
                            if not line:
                                continue
                            try:
                                await ws.send_text(f"{request_id}:{line}")
                            except RuntimeError:
                                # Socket already closed; stop relaying.
                                return
                except Exception as stream_error:
                    try:
                        error_payload = json.dumps(
                            {"error": "Streaming error", "detail": str(stream_error)}
                        )
                        await ws.send_text(f"{request_id}:{error_payload}")
                    except Exception:
                        pass  # socket is gone; nothing left to notify

            try:
                while True:
                    try:
                        msg = await ws.receive_text()
                    except WebSocketDisconnect:
                        break

                    try:
                        data = json.loads(msg)
                    except (ValueError, RecursionError):
                        await ws.send_json({"error": "Invalid JSON"})
                        continue
                    if not isinstance(data, dict):
                        await ws.send_json({"error": "Expected a JSON object"})
                        continue

                    body = data.get("body")
                    headers = data.get("headers") or {}
                    if not body:
                        await ws.send_json({"error": "Missing body"})
                        continue

                    request_counter += 1
                    # Process this request concurrently
                    task = asyncio.create_task(
                        process_request(request_counter, body, headers)
                    )
                    pending_tasks.add(task)
                    task.add_done_callback(pending_tasks.discard)
            finally:
                # Nobody can receive further output once the loop ends, and the
                # client is about to close, so stop the relays.
                in_flight = list(pending_tasks)
                for task in in_flight:
                    task.cancel()
                await asyncio.gather(*in_flight, return_exceptions=True)

    except WebSocketDisconnect:
        return
    except Exception as e:
        try:
            await ws.send_json({"error": "Server error", "detail": str(e)})
        except Exception:
            pass  # socket already unusable

    try:
        await ws.close()
    except Exception:
        pass  # already closed by the peer or by an earlier close()


@app.get("/portal")
@app.post("/portal")
async def redirect_to_protal(request: Request):
    """Send the caller to the Stripe billing portal, optionally prefilling an email.

    - POST with JSON ``{"email": "..."}``: responds ``{"redirect_url": ...}``.
    - GET with ``?email=...``: responds with a 302 to the prefilled URL.
    - No usable email (either method): 302 to the bare portal URL.
    """
    email = None
    if request.method == "POST":
        try:
            body = await request.json()
        except Exception:
            # Empty or malformed body is not an error here: just skip the prefill.
            body = None
        if isinstance(body, dict):
            email = body.get("email")
    else:
        email = request.query_params.get("email")

    base_url = "https://billing.stripe.com/p/login/5kQdR9aIM3ts4steyabbG00"

    if not email or not isinstance(email, str):
        return RedirectResponse(url=base_url, status_code=status.HTTP_302_FOUND)

    # Percent-encode so characters like "+", "&" or "#" stay part of the email.
    prefilled_url = f"{base_url}?prefilled_email={quote(email, safe='@')}"

    if request.method == "POST":
        return JSONResponse({"redirect_url": prefilled_url})
    return RedirectResponse(url=prefilled_url, status_code=status.HTTP_302_FOUND)