| |
| |
| |
| |
|
|
import asyncio
import json
import logging
import os
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import httpx
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from gradio_client import Client
from pydantic import BaseModel, Field, ValidationError
|
|
# Root logging configuration: INFO and above go to one stream handler, and
# every record carries timestamp, level, logger name and thread name.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(threadName)s | %(message)s",
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)

# Logger shared by the rest of this module.
logger = logging.getLogger("api_gateway")
|
|
class SessionData:
    """Mutable in-memory state for a single conversation session.

    Attributes:
        system: System prompt text for the session ("" until one is set).
        history: Accumulated message dicts for the conversation.
        last_access: ``time.time()`` stamp, initialised at construction.
        active_tasks: asyncio tasks registered for the session, by task id.
        last_request_time: ``time.time()`` stamp of the latest accepted
            request; 0.0 means none yet.
    """

    def __init__(self):
        # Creation counts as the first access.
        self.last_access: float = time.time()
        self.last_request_time: float = 0.0
        self.system: str = ""
        # Fresh containers per instance; never shared between sessions.
        self.history: List[Dict[str, Any]] = []
        self.active_tasks: Dict[str, asyncio.Task] = {}
|
|
class SessionManager:
    """In-memory registry of sessions, stored as ``{user: {session_id: SessionData}}``.

    All mutations of ``sessions`` performed by this class happen while
    holding ``lock``.
    """

    # Bucket used when the caller supplies no user. The /v1/history and
    # /v1/responses/cancel handlers fall back to the same key, so a session
    # created without a user can be found again by its session id alone.
    ANONYMOUS_USER = "anonymous"

    def __init__(self):
        self.sessions: Dict[str, Dict[str, SessionData]] = {}
        self.lock = asyncio.Lock()

    def _purge_expired(self, ttl: float) -> int:
        """Drop sessions idle for more than ``ttl`` seconds; return how many were removed.

        Unfinished tasks of a removed session are cancelled. The caller must
        hold ``self.lock``.
        """
        now = time.time()
        removed = 0
        for user in list(self.sessions):
            user_sessions = self.sessions[user]
            stale_ids = [
                sid for sid, data in user_sessions.items()
                if now - data.last_access > ttl
            ]
            for sid in stale_ids:
                data = user_sessions.pop(sid)
                for task in data.active_tasks.values():
                    if not task.done():
                        task.cancel()
                removed += 1
            if not user_sessions:
                del self.sessions[user]
        return removed

    async def cleanup(self, interval: float = 60.0, ttl: float = 300.0):
        """Run forever, purging idle sessions every ``interval`` seconds.

        Args:
            interval: Seconds to sleep between sweeps.
            ttl: Idle time in seconds after which a session is discarded.
        """
        while True:
            await asyncio.sleep(interval)
            async with self.lock:
                self._purge_expired(ttl)

    async def get_session(self, user: Optional[str], session_id: Optional[str]) -> Tuple[str, str, SessionData]:
        """Return ``(user, session_id, session)``, creating the session if needed.

        A missing or empty ``user`` maps to ``ANONYMOUS_USER``. A missing or
        unknown ``session_id`` is replaced by a freshly generated UUID with a
        new ``SessionData``. The session's ``last_access`` is refreshed.
        """
        async with self.lock:
            user = user or self.ANONYMOUS_USER
            user_sessions = self.sessions.setdefault(user, {})
            if not session_id or session_id not in user_sessions:
                session_id = str(uuid.uuid4())
                user_sessions[session_id] = SessionData()
            session = user_sessions[session_id]
            session.last_access = time.time()
            return user, session_id, session


# Module-wide registry shared by every request handler.
session_manager = SessionManager()
|
|
async def create_gradio_client_with_retry(url: str, max_retries: int = 3, backoff_factor: float = 0.5) -> Client:
    """Create a Gradio ``Client`` for ``url``, retrying with exponential backoff.

    The client is constructed in a worker thread so the synchronous
    connection attempt does not stall the event loop.

    Args:
        url: Address of the upstream Gradio app.
        max_retries: Maximum number of connection attempts.
        backoff_factor: Base delay in seconds; attempt ``n`` (0-based) is
            followed by a sleep of ``backoff_factor * 2**n``.

    Returns:
        A connected ``Client``.

    Raises:
        RuntimeError: If every attempt fails; the last underlying error is
            attached as ``__cause__``.
    """
    def _connect() -> Client:
        client = Client(url)
        # Kept from the original: reading ``config`` presumably surfaces a
        # half-initialised client here rather than on first use.
        _ = client.config
        return client

    last_error: Optional[BaseException] = None
    for attempt in range(max_retries):
        try:
            return await asyncio.to_thread(_connect)
        except Exception as exc:
            last_error = exc
            logger.warning(f"Attempt {attempt+1}/{max_retries} failed to create Gradio client: {exc}")
            # No point in waiting once there is no attempt left.
            if attempt + 1 < max_retries:
                await asyncio.sleep(backoff_factor * (2 ** attempt))
    raise RuntimeError("Failed to create Gradio client after retries") from last_error
|
|
async def refresh_client(app: FastAPI):
    """Replace ``app.state.client`` with a fresh Gradio client every 15 minutes.

    Runs until cancelled. Each refresh happens under ``app.state.client_lock``.
    If creating the replacement fails, the error is logged and the client is
    left as ``None``.
    """
    refresh_interval = 15 * 60
    while True:
        await asyncio.sleep(refresh_interval)
        async with app.state.client_lock:
            try:
                stale = app.state.client
                if stale is not None:
                    # Drop our references to the previous client first.
                    app.state.client = None
                    del stale
                app.state.client = await create_gradio_client_with_retry("https://hadadrjt-ai.hf.space/")
                logger.info("Refreshed Gradio client connection")
            except Exception as exc:
                logger.error(f"Error refreshing Gradio client: {exc}", exc_info=True)
                app.state.client = None
|
|
async def clear_terminal_periodically():
    """Clear the terminal every 300 seconds until cancelled.

    On Windows (``os.name == "nt"``) this runs ``cls``; elsewhere it prints
    the ``ESC c`` control sequence.
    """
    on_windows = os.name == "nt"
    while True:
        await asyncio.sleep(300)
        if not on_windows:
            print("\033c", end="", flush=True)
            continue
        os.system("cls")
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: set up shared state, run background tasks, tear down.

    Startup attaches the session manager, a client lock and (if reachable) a
    Gradio client to ``app.state``. A failed initial connection is only
    logged; the client stays ``None`` so ``get_client`` can create it on
    demand. Three background tasks are started: client refresh, session
    cleanup and periodic terminal clearing.

    Shutdown cancels those tasks and waits for them to actually finish,
    instead of sleeping for a fixed delay and hoping they are done.
    """
    app.state.session_manager = session_manager
    app.state.client = None
    app.state.client_lock = asyncio.Lock()
    try:
        app.state.client = await create_gradio_client_with_retry("https://hadadrjt-ai.hf.space/")
    except Exception as e:
        logger.error(f"Initial Gradio client creation failed: {e}", exc_info=True)
        app.state.client = None

    app.state.refresh_task = asyncio.create_task(refresh_client(app))
    app.state.cleanup_task = asyncio.create_task(session_manager.cleanup())
    app.state.clear_log_task = asyncio.create_task(clear_terminal_periodically())
    background_tasks = (
        app.state.refresh_task,
        app.state.cleanup_task,
        app.state.clear_log_task,
    )
    try:
        yield
    finally:
        for task in background_tasks:
            task.cancel()
        # return_exceptions=True collects each task's CancelledError (or any
        # other failure) as a result, so teardown itself never raises.
        await asyncio.gather(*background_tasks, return_exceptions=True)
|
|
# The ASGI application; startup and shutdown are handled by `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="J.A.R.V.I.S. OpenAI-Compatible API",
    version="2.1.3-0625",
)

# CORS: any origin and any request header, restricted to the listed methods.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET", "POST", "OPTIONS", "HEAD"],
    allow_origins=["*"],
)
|
|
class Function(BaseModel):
    """Function definition as supplied in ``functions`` or inside a ``Tool``.

    Only ``name`` is mandatory. ``description`` defaults to ``None`` and
    ``parameters`` to an empty schema, so definitions that omit them are
    accepted (under Pydantic v2 an ``Optional`` annotation without a default
    is still a required field).
    """

    name: str
    description: Optional[str] = None
    parameters: Dict[str, Any] = Field(default_factory=dict)
|
|
class Tool(BaseModel):
    """A tool entry of the request; the only accepted ``type`` is "function"."""

    # The literal type doubles as the default, so callers may omit it.
    type: Literal["function"] = "function"
    function: Function
|
|
class ToolCall(BaseModel):
    """A tool call attached to a message.

    ``function`` is a plain string-to-string mapping; its keys are not
    constrained by this model.
    """

    id: str
    type: Literal["function"]
    function: Dict[str, str]
|
|
class Message(BaseModel):
    """One chat message.

    ``role`` must be one of system, user, assistant, tool or function.
    ``content`` is either a plain string or a list of content-part dicts; it
    defaults to ``None`` so messages that carry only ``tool_calls`` validate
    (under Pydantic v2 an ``Optional`` annotation without a default is still
    a required field).
    """

    role: str = Field(..., pattern="^(system|user|assistant|tool|function)$")
    content: Optional[Union[str, List[Dict[str, Any]]]] = None
    name: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None
|
|
class CommonParams(BaseModel):
    """Fields shared by the request models of this gateway.

    Apart from ``model`` and ``stream`` every field is optional and defaults
    to ``None``.
    """

    # Routing / session identification.
    model: str
    stream: bool = False
    user: Optional[str] = None
    session_id: Optional[str] = None

    # Sampling and length controls.
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    max_new_tokens: Optional[int] = None

    # Penalties and biasing.
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    repetition_penalty: Optional[float] = None
    logit_bias: Optional[Dict[str, float]] = None
    repeat_penalty: Optional[float] = None
    seed: Optional[int] = None

    # Tool / function calling, in both the `tools` and `functions` spellings.
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict[str, str]]] = None
    functions: Optional[List[Function]] = None
    function_call: Optional[Union[str, Dict[str, str]]] = None
|
|
class ChatCompletionRequest(CommonParams):
    """Body of ``POST /v1/chat/completions``: common parameters plus the messages."""

    messages: List[Message]
|
|
class CompletionRequest(CommonParams):
    """Body of ``POST /v1/completions``: common parameters plus one or more prompts."""

    prompt: Union[str, List[str]]
|
|
class EmbeddingRequest(BaseModel):
    """Body of ``POST /v1/embeddings``: a model name and one or more input strings."""

    model: str
    # Named after the wire field; a single string or a list of strings.
    input: Union[str, List[str]]
    user: Optional[str] = None
|
|
class RouterRequest(CommonParams):
    """Body accepted by the combined ``/v1`` route.

    ``endpoint`` selects the target handler and defaults to
    "chat/completions". ``messages``, ``prompt`` and ``input`` are all
    optional here because which one is needed depends on the endpoint.

    ``model`` is relaxed to optional for the same reason: the "models",
    "history" and "responses/cancel" endpoints do not use it, and the router
    itself reports a missing model for the endpoints that do.
    """

    model: Optional[str] = None  # type: ignore[assignment]
    endpoint: Optional[str] = "chat/completions"
    messages: Optional[List[Message]] = None
    prompt: Optional[Union[str, List[str]]] = None
    input: Optional[Union[str, List[str]]] = None
|
|
def sanitize_messages(messages: List[Message]) -> List[Message]:
    """Reduce messages to plain-text form, dropping those without text.

    * String content: the message is kept unchanged.
    * List content: the ``text`` values of parts whose ``type`` is "text"
      are joined with single spaces into a new ``Message`` carrying only
      ``role`` and ``content``; the message is dropped if no such part exists.
    * Anything else (e.g. ``None``): the message is dropped.
    """
    result = []
    for msg in messages:
        body = msg.content
        if isinstance(body, str):
            result.append(msg)
            continue
        if not isinstance(body, list):
            continue
        text_parts = [
            part.get("text", "")
            for part in body
            if isinstance(part, dict) and part.get("type") == "text"
        ]
        if text_parts:
            result.append(Message(role=msg.role, content=" ".join(text_parts)))
    return result
|
|
def map_messages(system: str, history: List[Dict[str, Any]], new_msgs: List[Message]) -> str:
    """Flatten a conversation into one ``role:content`` line per message.

    The order is: the system prompt (if non-empty), the stored history, then
    the new messages except those with role "system". The result has leading
    and trailing whitespace stripped.
    """
    transcript = [{"role": "system", "content": system}] if system else []
    transcript += history
    transcript += [m.model_dump() for m in new_msgs if m.role != "system"]
    rendered = [f"{entry.get('role','')}:{entry.get('content','')}\n" for entry in transcript]
    return "".join(rendered).strip()
|
|
async def get_client(app: FastAPI) -> Client:
    """Return the shared Gradio client, creating it first if there is none.

    The check and the creation both happen under ``app.state.client_lock``.

    Raises:
        HTTPException: 502 when the client cannot be created.
    """
    async with app.state.client_lock:
        # Fast path: a client already exists.
        if app.state.client is not None:
            return app.state.client
        try:
            app.state.client = await create_gradio_client_with_retry("https://hadadrjt-ai.hf.space/")
            logger.info("Created Gradio client connection on demand")
        except Exception as exc:
            logger.error(f"Failed to create Gradio client: {exc}", exc_info=True)
            raise HTTPException(status_code=502, detail="Failed to connect to upstream Gradio app")
        return app.state.client
|
|
async def call_gradio(client: Client, params: dict):
    """Submit ``params`` to the Gradio app and return the resulting job.

    ``client.submit`` is invoked in a worker thread and retried up to three
    times, waiting 0.2s and then 0.4s between attempts. Each failure is
    logged so upstream problems are visible.

    Raises:
        HTTPException: 502 once every attempt has failed; the last upstream
            error is attached as ``__cause__``.
    """
    max_attempts = 3
    last_error: Optional[BaseException] = None
    for attempt in range(max_attempts):
        try:
            return await asyncio.to_thread(client.submit, **params)
        except Exception as exc:
            last_error = exc
            logger.warning(f"Gradio submit attempt {attempt + 1}/{max_attempts} failed: {exc}")
            # Skip the delay when no further attempt will follow.
            if attempt + 1 < max_attempts:
                await asyncio.sleep(0.2 * (attempt + 1))
    raise HTTPException(status_code=502, detail="Upstream Gradio app error") from last_error
|
|
# Sentinel returned by ``next`` when the upstream job has no more chunks.
_STREAM_END = object()


def _chunk_text(chunk: Any) -> Any:
    """Extract the text of one upstream chunk.

    For a list, this is the ``content`` of the first dict item that has that
    key, or ``str(chunk)`` when no item does. Any other chunk is stringified.
    """
    if isinstance(chunk, list):
        return next(
            (item.get("content") for item in chunk if isinstance(item, dict) and "content" in item),
            str(chunk),
        )
    return str(chunk)


async def stream_response(job, session_id: str, session_history: List[Dict[str, Any]], new_messages: "List[Message]", response_type: str):
    """Yield server-sent events for an upstream job as its chunks arrive.

    Each upstream chunk is treated as the text generated so far; only the
    part that extends the previous chunk is emitted (the whole text if it is
    not an extension). Chunks are pulled one at a time in a worker thread, so
    events are sent while the upstream is still generating.

    Args:
        job: Iterable of upstream chunks (strings or lists of dicts).
        session_id: Echoed in every event.
        session_history: Mutated once the job is exhausted: the non-system
            ``new_messages`` and the assistant reply are appended.
        new_messages: Request messages to record in the history.
        response_type: "chat" for chat-completion chunks; any other value
            produces text-completion chunks.

    Yields:
        ``data: <json>\\n\\n`` events sharing one id, then a final event with
        ``finish_reason`` "stop", then ``data: [DONE]\\n\\n``.
    """
    is_chat = response_type == "chat"
    chunk_object = "chat.completion.chunk" if is_chat else "text_completion.chunk"
    # One id for the whole stream so clients can correlate its chunks.
    stream_id = str(uuid.uuid4())

    def _event(choice: Dict[str, Any]) -> str:
        payload = {
            "id": stream_id,
            "object": chunk_object,
            "choices": [choice],
            "session_id": session_id,
        }
        return f"data: {json.dumps(payload)}\n\n"

    partial = ""
    source = iter(job)
    while True:
        try:
            # A default is passed to next() because StopIteration cannot be
            # propagated out of a worker thread through a future.
            chunk = await asyncio.to_thread(next, source, _STREAM_END)
        except Exception:
            # Same logger as the module-level one; end the stream gracefully.
            logging.getLogger("api_gateway").exception("Upstream stream failed")
            break
        if chunk is _STREAM_END:
            break

        response = _chunk_text(chunk)
        if not isinstance(response, str):
            # e.g. a list item whose content is None: nothing to emit.
            continue
        token = response[len(partial):] if response.startswith(partial) else response
        partial = response
        if is_chat:
            yield _event({"delta": {"content": token}, "index": 0, "finish_reason": None})
        else:
            yield _event({"text": token, "index": 0, "finish_reason": None})

    session_history.extend([m.model_dump() for m in new_messages if m.role != "system"])
    session_history.append({"role": "assistant", "content": partial})

    if is_chat:
        yield _event({"delta": {}, "index": 0, "finish_reason": "stop"})
    else:
        yield _event({"text": "", "index": 0, "finish_reason": "stop"})
    yield "data: [DONE]\n\n"
|
|
# Minimum spacing, in seconds, between two accepted requests on one session.
RATE_LIMIT_SECONDS = 1.0
|
|
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest):
    """Handle a chat completion request against the upstream Gradio app.

    The session is looked up (or created) from ``req.user`` and
    ``req.session_id``. Requests arriving less than ``RATE_LIMIT_SECONDS``
    after the previous accepted one on the same session are rejected. The
    first system message, if any, replaces the session's system prompt.
    The full conversation is flattened to text and sent upstream.

    Returns:
        A ``StreamingResponse`` of server-sent events when ``req.stream`` is
        set, otherwise a chat-completion dict that also carries
        ``session_id``.

    Raises:
        HTTPException: 429 when rate limited, 502 when the upstream fails.
    """
    _user, session_id, session = await session_manager.get_session(req.user, req.session_id)

    now = time.time()
    if now - session.last_request_time < RATE_LIMIT_SECONDS:
        raise HTTPException(status_code=429, detail="Too many requests, please slow down")
    session.last_request_time = now

    req.messages = sanitize_messages(req.messages)
    system_message = next((m for m in req.messages if m.role == "system"), None)
    if system_message is not None:
        session.system = system_message.content
    text = map_messages(session.system, session.history, req.messages)

    # Tool/function specs are Pydantic models; convert them to plain dicts so
    # the payload handed to the Gradio client is JSON-serialisable.
    tool_specs = req.functions or req.tools
    params = {
        "message": text,
        "model_label": req.model,
        "api_name": "/api",
        "top_p": req.top_p,
        "top_k": req.top_k,
        "temperature": req.temperature,
        "max_tokens": req.max_tokens,
        "max_new_tokens": req.max_new_tokens,
        "presence_penalty": req.presence_penalty,
        "frequency_penalty": req.frequency_penalty,
        "repetition_penalty": req.repetition_penalty,
        "repeat_penalty": req.repeat_penalty,
        "logit_bias": req.logit_bias,
        "seed": req.seed,
        "functions": [spec.model_dump() for spec in tool_specs] if tool_specs else None,
        "function_call": req.function_call or req.tool_choice,
    }
    # Only forward parameters the caller actually set.
    params = {k: v for k, v in params.items() if v is not None}

    client = await get_client(app)
    if req.stream:
        job = await call_gradio(client, params)
        generator = stream_response(job, session_id, session.history, req.messages, "chat")
        return StreamingResponse(generator, media_type="text/event-stream")

    try:
        # predict() blocks until the upstream answers, so run it off-loop.
        result = await asyncio.to_thread(client.predict, **params)
    except Exception as exc:
        logger.error(f"Chat completion upstream call failed: {exc}", exc_info=True)
        raise HTTPException(status_code=502, detail="Upstream Gradio app error") from exc

    session.history.extend([m.model_dump() for m in req.messages if m.role != "system"])
    session.history.append({"role": "assistant", "content": result})
    return {
        "id": str(uuid.uuid4()),
        "object": "chat.completion",
        "created": int(now),
        "model": req.model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": result},
                "finish_reason": "stop",
            }
        ],
        "session_id": session_id,
    }
|
|
@app.post("/v1/completions")
async def completions(req: CompletionRequest):
    """Handle a text completion request against the upstream Gradio app.

    A list prompt is joined with newlines into a single message. The
    per-session rate limit applies, but nothing is added to the session
    history by this endpoint.

    Returns:
        A ``StreamingResponse`` of server-sent events when ``req.stream`` is
        set, otherwise a text-completion dict. Both forms carry
        ``session_id``.

    Raises:
        HTTPException: 429 when rate limited, 502 when the upstream fails.
    """
    _user, session_id, session = await session_manager.get_session(req.user, req.session_id)

    now = time.time()
    if now - session.last_request_time < RATE_LIMIT_SECONDS:
        raise HTTPException(status_code=429, detail="Too many requests, please slow down")
    session.last_request_time = now

    prompt = req.prompt if isinstance(req.prompt, str) else "\n".join(req.prompt)
    params = {
        "message": prompt,
        "model_label": req.model,
        "api_name": "/api",
        "top_p": req.top_p,
        "top_k": req.top_k,
        "temperature": req.temperature,
        "max_tokens": req.max_tokens,
        "max_new_tokens": req.max_new_tokens,
        "presence_penalty": req.presence_penalty,
        "frequency_penalty": req.frequency_penalty,
        "repetition_penalty": req.repetition_penalty,
        "repeat_penalty": req.repeat_penalty,
        "logit_bias": req.logit_bias,
        "seed": req.seed,
    }
    # Only forward parameters the caller actually set.
    params = {k: v for k, v in params.items() if v is not None}

    client = await get_client(app)
    if req.stream:
        job = await call_gradio(client, params)
        # Throwaway history/message lists: completions are not recorded.
        generator = stream_response(job, session_id, [], [], "text")
        return StreamingResponse(generator, media_type="text/event-stream")

    try:
        # predict() blocks until the upstream answers, so run it off-loop.
        result = await asyncio.to_thread(client.predict, **params)
    except Exception as exc:
        logger.error(f"Completion upstream call failed: {exc}", exc_info=True)
        raise HTTPException(status_code=502, detail="Upstream Gradio app error") from exc

    return {
        "id": str(uuid.uuid4()),
        "object": "text_completion",
        "created": int(now),
        "model": req.model,
        "choices": [{"index": 0, "text": result, "finish_reason": "stop"}],
        "session_id": session_id,
    }
|
|
@app.post("/v1/embeddings")
async def embeddings(req: EmbeddingRequest):
    """Return one placeholder embedding per input.

    Every vector is 768 zeros; no model is consulted.
    """
    texts = [req.input] if not isinstance(req.input, list) else req.input
    vectors = [[0.0] * 768 for _ in texts]
    items = [{"embedding": vector, "index": position} for position, vector in enumerate(vectors)]
    return {"object": "list", "data": items}
|
|
@app.get("/v1/models")
async def get_models():
    """Return the static model listing (a single entry, "Q8_K_XL")."""
    model_entry = {"id": "Q8_K_XL", "object": "model", "owned_by": "J.A.R.V.I.S."}
    return {"object": "list", "data": [model_entry]}
|
|
@app.get("/v1/history")
async def get_history(user: Optional[str] = None, session_id: Optional[str] = None):
    """Return the stored history of a session, or an empty list if unknown.

    A missing ``user`` is treated as "anonymous". The session is only read;
    nothing is created or modified.
    """
    user = user or "anonymous"
    user_sessions = session_manager.sessions.get(user)
    known = user_sessions is not None and bool(session_id) and session_id in user_sessions
    history = user_sessions[session_id].history if known else []
    return {"user": user, "session_id": session_id, "history": history}
|
|
@app.post("/v1/responses/cancel")
async def cancel_response(user: Optional[str] = None, session_id: Optional[str] = None, task_id: Optional[str] = None):
    """Cancel a running task registered on a session.

    All three parameters default to ``None`` so they are genuinely optional
    query parameters: a missing ``user`` falls back to "anonymous" and a
    missing ``task_id`` produces the explicit 400 below rather than a
    framework-level validation error.

    Raises:
        HTTPException: 400 when ``task_id`` is missing; 404 when the session
            or task is unknown, or the task has already finished.
    """
    user = user or "anonymous"
    if not task_id:
        raise HTTPException(status_code=400, detail="Missing task_id for cancellation")
    async with session_manager.lock:
        session = session_manager.sessions.get(user, {}).get(session_id)
        task = session.active_tasks.get(task_id) if session is not None else None
        if task and not task.done():
            task.cancel()
            return {"message": f"Cancelled task {task_id}"}
    raise HTTPException(status_code=404, detail="Task not found or already completed")
|
|
@app.api_route("/v1", methods=["POST", "GET", "OPTIONS", "HEAD"])
async def router(request: Request):
    """Single-route dispatcher: forward a JSON body to the handler named by ``endpoint``.

    Only POST is served; other methods get a 405 JSON reply. The body is
    validated as a ``RouterRequest`` and then re-validated as the target
    handler's own request model. For "responses/cancel" the task id is read
    from ``tool_choice`` when that field is a string.

    Raises:
        HTTPException: 400 for a body that is not a JSON object, 422 for
            validation failures or missing required fields, 404 for an
            unknown endpoint.
    """
    if request.method != "POST":
        return JSONResponse({"message": "Send POST request with JSON body"}, status_code=status.HTTP_405_METHOD_NOT_ALLOWED)

    try:
        body_json = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON body")
    if not isinstance(body_json, dict):
        # Valid JSON, but e.g. a list or a number cannot be unpacked as fields.
        raise HTTPException(status_code=400, detail="JSON body must be an object")

    def _validated(model_cls, data):
        # Turn model validation failures into a 422 instead of a server error.
        try:
            return model_cls(**data)
        except ValidationError as e:
            raise HTTPException(status_code=422, detail=e.errors())

    body = _validated(RouterRequest, body_json)
    endpoint = body.endpoint or "chat/completions"

    if endpoint == "chat/completions":
        if not body.model or not body.messages:
            raise HTTPException(status_code=422, detail="Missing 'model' or 'messages'")
        return await chat_completions(_validated(ChatCompletionRequest, body.model_dump()))
    if endpoint == "completions":
        if not body.model or not body.prompt:
            raise HTTPException(status_code=422, detail="Missing 'model' or 'prompt'")
        return await completions(_validated(CompletionRequest, body.model_dump()))
    if endpoint == "embeddings":
        if not body.model or body.input is None:
            raise HTTPException(status_code=422, detail="Missing 'model' or 'input'")
        return await embeddings(_validated(EmbeddingRequest, body.model_dump()))
    if endpoint == "models":
        return await get_models()
    if endpoint == "history":
        return await get_history(body.user, body.session_id)
    if endpoint == "responses/cancel":
        task_id = body.tool_choice if isinstance(body.tool_choice, str) else None
        return await cancel_response(body.user, body.session_id, task_id)
    raise HTTPException(status_code=404, detail="Endpoint not found")
|
|
@app.get("/")
async def root():
    """List the paths served by this gateway."""
    available = [
        "/v1/chat/completions",
        "/v1/completions",
        "/v1/embeddings",
        "/v1/models",
        "/v1/history",
        "/v1/responses/cancel",
        "/v1",
    ]
    return {"endpoints": available}
|
|
# Script entry point: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when run as a script.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")
|
|