lightning / app.py
sharktide's picture
Update app.py
2c6dd2f verified
import os
import time
import hashlib
from fastapi import WebSocket, WebSocketDisconnect
from collections import defaultdict, deque
import json
from fastapi import FastAPI, Request, HTTPException, status, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (
Response,
JSONResponse,
StreamingResponse,
RedirectResponse,
)
import httpx
from bs4 import BeautifulSoup
from typing import List, Dict, Any
import asyncio
import re
import random
from urllib.parse import quote
import base64
from helper.subscriptions import (
fetch_subscription,
normalize_plan_key,
TIER_CONFIG,
PLAN_ORDER,
)
from typing import Optional
from helper.keywords import *
from helper.assets import asset_router
from helper.ratelimit import (
enforce_rate_limit,
resolve_rate_limit_identity,
check_audio_rate_limit,
check_video_rate_limit,
check_image_rate_limit,
MAX_CHAT_PROMPT_BYTES,
MAX_CHAT_PROMPT_CHARS,
MAX_GROQ_PROMPT_BYTES,
MAX_GROQ_PROMPT_CHARS,
MAX_MEDIA_PROMPT_BYTES,
MAX_MEDIA_PROMPT_CHARS,
extract_user_text,
calculate_messages_size,
normalize_prompt_value,
enforce_prompt_size,
resolve_bound_subject,
get_usage_snapshot_for_subject,
)
from status import router as status_router
from gen import router as gen_router
# Application instance that every route and router below is registered on.
app = FastAPI()

# Shared secret that /ws/chat clients must send in their first message;
# read from the environment once, at import time.
WEBSOCKET_KEY = os.getenv("WEBSOCKET_KEY")

# authentication attempt tracking: per-IP timestamps of recent attempts
AUTH_ATTEMPTS = defaultdict(deque)
AUTH_WINDOW_SECONDS = 60
AUTH_MAX_ATTEMPTS = 10

# CORS: any origin and any header, limited to the listed methods.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST", "HEAD"],
    allow_headers=["*"],
)

# Routers defined in sibling modules.
app.include_router(asset_router)
app.include_router(status_router)
app.include_router(gen_router)
def check_ws_auth_rate_limit(ip: str):
    """Record an attempt for ``ip`` and report whether it is allowed.

    Timestamps older than ``AUTH_WINDOW_SECONDS`` are discarded first. If
    ``AUTH_MAX_ATTEMPTS`` timestamps remain, the attempt is rejected and
    nothing is recorded; otherwise the current time is stored.

    Returns:
        ``True`` when the attempt is accepted, ``False`` when rate limited.
    """
    current = time.time()
    attempts = AUTH_ATTEMPTS[ip]

    # Expire entries that have fallen out of the sliding window.
    while attempts and current - attempts[0] > AUTH_WINDOW_SECONDS:
        attempts.popleft()

    allowed = len(attempts) < AUTH_MAX_ATTEMPTS
    if allowed:
        attempts.append(current)
    return allowed
@app.get("/")
async def reroute_to_home():
    """Send visitors of the API root to the project home page (HTTP 308)."""
    home_url = "https://inference.js.org"
    return RedirectResponse(
        url=home_url,
        status_code=status.HTTP_308_PERMANENT_REDIRECT,
    )
OLLAMA_LIBRARY_URL = "https://ollama.com/library"


@app.head("/models")
@app.get("/models")
async def get_models() -> List[Dict]:
    """Scrape the Ollama library page and return one dict per listed model.

    Each entry carries ``name``, ``description``, ``sizes``, ``pulls``,
    ``tags``, ``updated`` and ``link``, extracted from the ``li[x-test-model]``
    elements of the page. Missing pieces fall back to ``""``,
    ``"No description"``, ``"Unknown"`` or ``None`` respectively.

    Raises:
        HTTPException: 502 when the library page cannot be fetched or the
            upstream answers with a non-success status. Without this check an
            upstream error page would be parsed and silently yield an empty
            model list.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(OLLAMA_LIBRARY_URL)
            response.raise_for_status()
    except httpx.HTTPError as exc:
        # Covers both transport failures and non-success status codes.
        raise HTTPException(
            status.HTTP_502_BAD_GATEWAY,
            "Failed to fetch the Ollama model library",
        ) from exc

    soup = BeautifulSoup(response.text, "html.parser")

    def text_or(node, fallback: str) -> str:
        """Return the stripped text of ``node`` or ``fallback`` if absent."""
        return node.get_text(strip=True) if node else fallback

    models = []
    for item in soup.select("li[x-test-model]"):
        link = item.select_one("a")
        models.append(
            {
                "name": text_or(item.select_one("[x-test-model-title] span"), ""),
                "description": text_or(
                    item.select_one("p.max-w-lg"), "No description"
                ),
                "sizes": [
                    el.get_text(strip=True) for el in item.select("[x-test-size]")
                ],
                "pulls": text_or(item.select_one("[x-test-pull-count]"), "Unknown"),
                "tags": [
                    t.get_text(strip=True)
                    for t in item.select('span[class*="text-blue-600"]')
                ],
                "updated": text_or(item.select_one("[x-test-updated]"), "Unknown"),
                "link": link.get("href") if link else None,
            }
        )
    return models
@app.get("/subscription")
async def get_subscription(authorization: Optional[str] = Header(None)):
    """Return the caller's subscription, annotated with plan key and name.

    The bearer token from the ``Authorization`` header is handed to
    ``fetch_subscription``. The returned mapping gets its ``plan_key``
    replaced by the normalized key and a ``plan_name`` taken from
    ``TIER_CONFIG`` (falling back to the ``free`` tier entry).

    Raises:
        HTTPException: 401 when the header is missing or not a bearer token,
            or when the lookup result contains an ``error`` entry.
    """
    has_bearer = bool(authorization) and authorization.startswith("Bearer ")
    if not has_bearer:
        raise HTTPException(401, "Missing or invalid Authorization header")

    # Everything after the first space is the token.
    _, _, token = authorization.partition(" ")
    subscription = await fetch_subscription(token)
    if "error" in subscription:
        raise HTTPException(401, subscription["error"])

    normalized_key = normalize_plan_key(subscription.get("plan_key"))
    subscription["plan_key"] = normalized_key
    tier = TIER_CONFIG.get(normalized_key) or TIER_CONFIG["free"]
    subscription["plan_name"] = tier["name"]
    return subscription
@app.get("/usage")
async def get_usage(
    request: Request,
    authorization: Optional[str] = Header(None),
    x_client_id: Optional[str] = Header(None),
):
    """Report the current usage snapshot for the calling identity.

    The plan key and subject come from ``resolve_rate_limit_identity``; the
    snapshot comes from ``get_usage_snapshot_for_subject``. The response also
    carries the plan's display name and a UTC generation timestamp.
    """
    plan_key, subject = await resolve_rate_limit_identity(
        request, authorization, x_client_id
    )
    tier = TIER_CONFIG.get(plan_key) or TIER_CONFIG["free"]
    snapshot = get_usage_snapshot_for_subject(plan_key, subject)

    payload = {
        "plan_key": plan_key,
        "plan_name": tier.get("name", "Free Tier"),
        "usage": snapshot,
        "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }
    return JSONResponse(status_code=200, content=payload)
@app.get("/tier-config")
async def tier_config():
    """Return every configured plan, in ``PLAN_ORDER``, plus the default key.

    Keys from ``PLAN_ORDER`` without a ``TIER_CONFIG`` entry are skipped;
    ``order`` is the key's position in ``PLAN_ORDER``.
    """
    indexed = (
        (position, key, TIER_CONFIG.get(key))
        for position, key in enumerate(PLAN_ORDER)
    )
    plans = [
        {
            "key": key,
            "name": entry["name"],
            "url": entry["url"],
            "price": entry["price"],
            "limits": entry["limits"],
            "order": position,
        }
        for position, key, entry in indexed
        if entry
    ]
    body = {"defaultPlanKey": "free", "plans": plans}
    return JSONResponse(status_code=200, content=body)
@app.get("/tiers")
async def tiers():
    """Return the configured plans other than ``free``, in ``PLAN_ORDER``.

    Keys from ``PLAN_ORDER`` without a ``TIER_CONFIG`` entry are skipped.
    """
    candidates = (
        (key, TIER_CONFIG.get(key)) for key in PLAN_ORDER if key != "free"
    )
    paid_plans = [
        {
            "key": key,
            "name": entry["name"],
            "url": entry["url"],
            "price": entry["price"],
            "limits": entry["limits"],
        }
        for key, entry in candidates
        if entry
    ]
    return JSONResponse(status_code=200, content=paid_plans)
@app.websocket("/ws/chat")
async def websocket_chat(ws: WebSocket):
    """Authenticated websocket that multiplexes chat-completion requests.

    Flow, as implemented below:

    1. The connection is accepted, then rate limited per client address
       (close code 4408 when over the limit).
    2. The first message must be a JSON object whose ``key`` equals
       ``WEBSOCKET_KEY``. Malformed input closes with 4400, a wrong or
       unconfigured key with 4403. On success ``{"type": "auth",
       "status": "ok"}`` is sent.
    3. Every later message is a JSON object with a ``body`` and optional
       ``headers``. Each one is POSTed to the internal completions endpoint
       in its own task, and every non-empty response line is sent back as
       ``"<request_id>:<line>"``, where ``request_id`` is a per-connection
       counter starting at 1.
    """
    # Local import keeps this block self-contained.
    import hmac

    # ws.client may be None; bucket such connections together.
    ip = ws.client.host if ws.client else "unknown"

    # Accept first so the close codes below actually reach the client.
    await ws.accept()

    # Basic connection rate limiting
    if not check_ws_auth_rate_limit(ip):
        await ws.close(code=4408)
        return

    try:
        # Expect an auth message first
        auth_msg = await ws.receive_text()
        try:
            auth_data = json.loads(auth_msg)
        except (ValueError, RecursionError):
            await ws.close(code=4400)
            return
        if not isinstance(auth_data, dict):
            # Valid JSON but not an object: treat like any malformed auth.
            await ws.close(code=4400)
            return

        provided_key = auth_data.get("key")
        # Constant-time comparison so response timing does not leak how much
        # of the key matched. Bytes are compared because compare_digest
        # rejects non-ASCII str; surrogatepass keeps encoding from raising
        # on lone surrogates.
        key_ok = (
            bool(WEBSOCKET_KEY)
            and isinstance(provided_key, str)
            and hmac.compare_digest(
                provided_key.encode("utf-8", "surrogatepass"),
                WEBSOCKET_KEY.encode("utf-8", "surrogatepass"),
            )
        )
        if not key_ok:
            await ws.close(code=4403)
            return

        await ws.send_json({"type": "auth", "status": "ok"})

        internal_url = "http://127.0.0.1:7860/gen/chat/completions"

        async with httpx.AsyncClient(
            timeout=None,
            follow_redirects=False,
        ) as client:
            request_counter = 0
            # Strong references to in-flight tasks: the event loop only keeps
            # weak ones, so an unreferenced task can be garbage-collected
            # mid-execution.
            pending = set()

            async def process_request(client, request_id, body, headers):
                """Forward one request upstream and relay its response lines."""
                try:
                    async with client.stream(
                        "POST",
                        internal_url,
                        json=body,
                        headers=headers,
                    ) as response:
                        # Handle upstream errors
                        if response.status_code >= 400:
                            error_text = ""
                            try:
                                error_text = (await response.aread()).decode(
                                    "utf-8", errors="replace"
                                )[:500]
                            except Exception:
                                # Best effort: report the status even if the
                                # error body cannot be read.
                                pass
                            error_payload = json.dumps({
                                "error": "Upstream request failed",
                                "status": response.status_code,
                                "detail": error_text
                            })
                            try:
                                await ws.send_text(f"{request_id}:{error_payload}")
                            except RuntimeError:
                                return
                            return

                        # Stream tokens/lines to the websocket
                        async for line in response.aiter_lines():
                            if not line:
                                continue
                            try:
                                await ws.send_text(f"{request_id}:{line}")
                            except RuntimeError:
                                # Socket already closed; stop relaying.
                                return
                except Exception as stream_error:
                    try:
                        error_payload = json.dumps({
                            "error": "Streaming error",
                            "detail": str(stream_error)
                        })
                        await ws.send_text(f"{request_id}:{error_payload}")
                    except Exception:
                        # Nothing left to report to; the socket is gone.
                        pass

            async def handle_incoming_requests():
                """Read client messages until disconnect, spawning one task each."""
                nonlocal request_counter
                while True:
                    try:
                        msg = await ws.receive_text()
                    except WebSocketDisconnect:
                        break

                    try:
                        data = json.loads(msg)
                    except (ValueError, RecursionError):
                        await ws.send_json({"error": "Invalid JSON"})
                        continue
                    if not isinstance(data, dict):
                        # Previously this crashed the whole connection on .get().
                        await ws.send_json({"error": "Expected a JSON object"})
                        continue

                    body = data.get("body")
                    headers = data.get("headers") or {}
                    if not body:
                        await ws.send_json({"error": "Missing body"})
                        continue

                    request_counter += 1
                    request_id = request_counter

                    # Process this request concurrently
                    task = asyncio.create_task(
                        process_request(client, request_id, body, headers)
                    )
                    pending.add(task)
                    task.add_done_callback(pending.discard)

            try:
                # Start handling incoming requests
                await handle_incoming_requests()
            finally:
                # The reader only stops when the socket is gone, so abort the
                # in-flight requests before the shared client is closed.
                in_flight = list(pending)
                for task in in_flight:
                    task.cancel()
                if in_flight:
                    await asyncio.gather(*in_flight, return_exceptions=True)

    except WebSocketDisconnect:
        return
    except Exception as e:
        try:
            await ws.send_json({
                "error": "Server error",
                "detail": str(e)
            })
        except Exception:
            pass

    try:
        await ws.close()
    except Exception:
        # Already closed or disconnected.
        pass
@app.get("/portal")
@app.post("/portal")
async def redirect_to_protal(request: Request):
    """Point the caller at the Stripe billing portal login page.

    * ``GET``: answers with a 302 redirect to the portal.
    * ``POST``: reads an optional ``email`` string from the JSON body. With an
      email, responds with JSON ``{"redirect_url": ...}`` whose URL carries the
      address as ``prefilled_email``; without one, answers with a 302 redirect
      to the bare portal URL.

    The email is percent-encoded before being placed in the query string, so
    characters such as ``+`` or ``&`` cannot corrupt or extend the URL. A body
    that is not valid JSON, is not an object, or holds a non-string ``email``
    is treated as "no email".
    """
    base_url = "https://billing.stripe.com/p/login/5kQdR9aIM3ts4steyabbG00"
    is_post = request.method == "POST"

    email = None
    if is_post:
        try:
            body = await request.json()
        except ValueError:
            # Empty or malformed JSON body: the email is optional.
            body = None
        if isinstance(body, dict):
            candidate = body.get("email")
            if isinstance(candidate, str):
                email = candidate

    if not email:
        return RedirectResponse(url=base_url, status_code=status.HTTP_302_FOUND)

    encoded_email = quote(email, safe="")
    portal_url = f"{base_url}?prefilled_email={encoded_email}"

    if is_post:
        return JSONResponse({"redirect_url": portal_url})

    # NOTE(review): currently unreachable, since only POST populates `email`.
    # Kept so GET behaves sensibly if it ever gains an email source.
    return RedirectResponse(url=portal_url, status_code=status.HTTP_302_FOUND)