| from fastapi import FastAPI, Request, Response, UploadFile, File |
| from fastapi.responses import StreamingResponse, FileResponse |
| from fastapi.staticfiles import StaticFiles |
| import httpx |
| import json |
| import asyncio |
| import time |
| import base64 |
| from typing import Optional, Dict, Any, List |
| from io import BytesIO |
|
|
| |
# Upstream Qwen chat service endpoints.
QWEN_API_URL = "https://chat.qwenlm.ai/api/chat/completions"
QWEN_MODELS_URL = "https://chat.qwenlm.ai/api/models"
QWEN_FILES_URL = "https://chat.qwenlm.ai/api/v1/files/"
# Retry policy for upstream requests: up to MAX_RETRIES attempts with a
# linearly growing delay (RETRY_DELAY * attempt number) between them.
MAX_RETRIES = 3
RETRY_DELAY = 1  # seconds; base unit for the linear backoff


# Module-global cache for the raw JSON text of the upstream /api/models
# response, so repeated model listings don't hit the upstream every time.
cached_models = None
cached_models_timestamp = 0  # time.time() of the last successful fetch
CACHE_TTL = 60 * 60  # seconds; cached model list is reused for one hour
|
|
# FastAPI application plus one shared async HTTP client; reusing the client
# lets connections to the upstream service be pooled across requests.
app = FastAPI()
# Serve files under ./static at /static (assets for index.html).
app.mount("/static", StaticFiles(directory="static"), name="static")
client = httpx.AsyncClient()
|
|
@app.get("/")
async def root():
    """Serve the single-page frontend shipped alongside this module."""
    index_page = FileResponse("index.html")
    return index_page
|
|
async def sleep(seconds: float):
    """Asynchronously pause for *seconds* (thin wrapper over asyncio.sleep)."""
    return await asyncio.sleep(seconds)
|
|
| |
async def base64_to_file(base64_str: str) -> BytesIO:
    """Decode a base64 image payload into an in-memory file object.

    Accepts either a bare base64 string or a full data URI
    (``data:image/png;base64,...``); in the latter case everything up to the
    first comma is stripped before decoding.

    Raises:
        Exception: wrapping the underlying decode error; the original
            exception is preserved as ``__cause__`` for debuggability.
    """
    try:
        # Strip the "data:<mime>;base64," prefix of a data URI, if present.
        if ',' in base64_str:
            base64_str = base64_str.split(',', 1)[1]

        image_data = base64.b64decode(base64_str)
        return BytesIO(image_data)
    except Exception as e:
        # Chain the cause so the traceback keeps the real decode failure
        # while callers still see a single readable message.
        raise Exception(f"Failed to convert base64 to file: {str(e)}") from e
|
|
| |
async def upload_image_to_qwen(auth_header: str, image_data: BytesIO) -> str:
    """Upload an in-memory image to the Qwen file endpoint and return its ID.

    Args:
        auth_header: the caller's ``Authorization`` header value, forwarded
            verbatim to the upstream service.
        image_data: image bytes to upload; always sent as ``image.jpg`` with
            MIME type ``image/jpeg`` regardless of the actual image format.

    Returns:
        The ``id`` field of the upstream JSON response.

    Raises:
        Exception: on a non-success status or a response without an ``id``.
            Note the inner ``raise`` statements are caught by the outer
            ``except`` below and re-wrapped, so callers always see the
            "Failed to upload image:" prefix.
    """
    try:
        files = {'file': ('image.jpg', image_data, 'image/jpeg')}
        headers = {
            "Authorization": auth_header,
            "accept": "application/json"
        }

        # A fresh client (rather than the shared module-level one) scopes this
        # upload's connection lifetime to the request.
        async with httpx.AsyncClient() as client:
            response = await client.post(
                QWEN_FILES_URL,
                headers=headers,
                files=files
            )

            if response.is_success:
                data = response.json()
                if not data.get('id'):
                    raise Exception("File upload failed: No valid file ID returned")
                return data['id']
            else:
                raise Exception(f"File upload failed with status {response.status_code}")

    except Exception as e:
        raise Exception(f"Failed to upload image: {str(e)}")
|
|
| |
async def process_messages(messages: List[Dict], auth_header: str) -> List[Dict]:
    """Rewrite chat messages, uploading any inline base64 images to Qwen.

    Each content part of type ``image_url`` whose URL is a ``data:`` URI is
    decoded, uploaded via ``upload_image_to_qwen``, and replaced with an
    ``{'type': 'image', 'image': <file_id>}`` part.  All other messages and
    content parts pass through unchanged.

    Unlike the previous implementation, the caller's message dicts are NOT
    mutated: messages needing image replacement are shallow-copied.

    Args:
        messages: OpenAI-style chat messages.
        auth_header: ``Authorization`` header forwarded to the upload endpoint.

    Returns:
        A new list of messages safe to send to the Qwen chat API.
    """
    processed_messages = []

    for message in messages:
        content = message.get('content')
        if isinstance(content, list):
            new_content = []
            for part in content:
                is_inline_image = (
                    part.get('type') == 'image_url' and
                    part.get('image_url', {}).get('url', '').startswith('data:')
                )
                if is_inline_image:
                    image_data = await base64_to_file(part['image_url']['url'])
                    image_id = await upload_image_to_qwen(auth_header, image_data)
                    new_content.append({
                        'type': 'image',
                        'image': image_id
                    })
                else:
                    new_content.append(part)
            # Copy the message instead of mutating the caller's dict in place.
            processed_messages.append({**message, 'content': new_content})
        else:
            processed_messages.append(message)

    return processed_messages
|
|
async def fetch_with_retry(url: str, options: Dict, retries: int = MAX_RETRIES):
    """Issue an HTTP request via the shared module ``client``, retrying on
    server-side failures.

    Retries (with a linearly increasing delay of ``RETRY_DELAY * attempt``)
    when the response is a 5xx, when the body looks like an HTML error page,
    or when the request itself raises.  Any other non-success status (e.g. a
    4xx) aborts immediately.

    Args:
        url: target URL.
        options: dict with optional ``method`` (default ``"GET"``),
            ``headers`` and ``json`` keys, mirroring httpx request kwargs.
        retries: maximum number of attempts.

    Returns:
        The successful ``httpx.Response``.

    Raises:
        Exception: with a JSON-encoded summary of the last failure.
            NOTE(review): the "All retry attempts failed" message is also
            used when a non-retryable status caused an immediate abort.
    """
    last_error = None

    for i in range(retries):
        try:
            response = await client.request(
                method=options.get("method", "GET"),
                url=url,
                headers=options.get("headers", {}),
                json=options.get("json"),
            )

            if response.is_success:
                return response

            content_type = response.headers.get("content-type", "")
            # 5xx or an HTML body (likely an upstream error page) is treated
            # as transient and worth retrying.
            if response.status_code >= 500 or "text/html" in content_type:
                last_error = {
                    "status": response.status_code,
                    "content_type": content_type,
                    "response_text": response.text[:1000],  # truncated for the error payload
                    "headers": dict(response.headers)
                }

                if i < retries - 1:
                    await sleep(RETRY_DELAY * (i + 1))
                    continue
            else:
                # Non-retryable status (e.g. 4xx): record it and stop trying.
                last_error = {
                    "status": response.status_code,
                    "headers": dict(response.headers)
                }
                break

        except Exception as error:
            # Network/transport failure; retry unless attempts are exhausted.
            last_error = error
            if i < retries - 1:
                await sleep(RETRY_DELAY * (i + 1))
                continue

    raise Exception(json.dumps({
        "error": True,
        "message": "All retry attempts failed",
        "last_error": str(last_error),
        "retries": retries
    }))
|
|
async def process_line(line: str, previous_content: str) -> tuple[str, Optional[dict]]:
    """Parse one SSE ``data:`` line from the upstream stream.

    Qwen streams cumulative content (each delta carries the full text so
    far), while OpenAI-style clients expect incremental deltas.  This strips
    the previously seen prefix so only newly generated text is emitted.

    Args:
        line: raw line beginning with ``"data: "``.
        previous_content: full content seen so far on this stream.

    Returns:
        ``(new_previous_content, payload)`` where ``payload`` is an
        OpenAI-style delta dict, the original parsed object when it carries
        no delta, or ``None`` when the line is unparseable (including the
        terminal ``data: [DONE]`` line).
    """
    try:
        data = json.loads(line[6:])  # drop the "data: " prefix
        if (data.get("choices") and data["choices"][0].get("delta")):
            delta = data["choices"][0]["delta"]
            current_content = delta.get("content", "")

            # Cumulative content: emit only the newly appended suffix.  If
            # the new content doesn't extend what we saw, forward it as-is.
            if previous_content and current_content:
                if current_content.startswith(previous_content):
                    new_content = current_content[len(previous_content):]
                else:
                    new_content = current_content
            else:
                new_content = current_content

            new_data = {
                "choices": [{
                    "delta": {
                        "role": delta.get("role", "assistant"),
                        "content": new_content
                    }
                }]
            }

            return current_content, new_data
        return previous_content, data
    except Exception:
        # Malformed JSON or unexpected payload shape: swallow and emit
        # nothing.  A bare ``except:`` here would also trap
        # KeyboardInterrupt/SystemExit, which must propagate.
        return previous_content, None
|
|
async def stream_generator(response: httpx.Response):
    """Re-emit the upstream SSE stream with cumulative deltas converted to
    incremental ones, terminated by a single ``data: [DONE]`` event."""
    pending = ""
    seen_content = ""

    async for raw_chunk in response.aiter_bytes():
        pending += raw_chunk.decode()

        # Complete lines are everything up to the last newline; the trailing
        # partial line stays buffered for the next chunk.
        *complete, pending = pending.split("\n")

        for raw_line in complete:
            raw_line = raw_line.strip()
            if not raw_line.startswith("data: "):
                continue
            seen_content, payload = await process_line(raw_line, seen_content)
            if payload:
                yield f"data: {json.dumps(payload)}\n\n"

    # Flush whatever remains in the buffer once the upstream closes.
    if pending:
        seen_content, payload = await process_line(pending, seen_content)
        if payload:
            yield f"data: {json.dumps(payload)}\n\n"

    yield "data: [DONE]\n\n"
|
|
@app.get("/healthz")
async def health_check():
    """Liveness probe; always reports OK."""
    return dict(status="ok")
|
|
@app.get("/api/models")
async def get_models(request: Request):
    """Proxy the upstream model list, caching the raw JSON for CACHE_TTL.

    Requires a ``Bearer`` Authorization header, which is forwarded upstream.
    The cache is module-global and shared by all callers; the first
    authorized request after expiry repopulates it.
    """
    global cached_models, cached_models_timestamp

    auth_header = request.headers.get("Authorization")
    if not auth_header or not auth_header.startswith("Bearer "):
        return Response(status_code=401, content="Unauthorized")

    # Serve from cache while it is still fresh.
    now = time.time()
    if cached_models and now - cached_models_timestamp < CACHE_TTL:
        return Response(
            content=cached_models,
            media_type="application/json"
        )

    try:
        response = await fetch_with_retry(
            QWEN_MODELS_URL,
            {"headers": {"Authorization": auth_header}}
        )

        # Cache the raw response text (already JSON) and the fetch time.
        cached_models = response.text
        cached_models_timestamp = now

        return Response(
            content=cached_models,
            media_type="application/json"
        )
    except Exception as error:
        # fetch_with_retry raises with a JSON-encoded summary; surface it.
        return Response(
            content=json.dumps({"error": True, "message": str(error)}),
            status_code=500
        )
|
|
@app.post("/api/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat endpoint that proxies to the Qwen chat API.

    Validates the Bearer token and ``model`` field, uploads any inline
    base64 images via ``process_messages``, then forwards the request.
    When ``stream`` is true, the upstream SSE stream is re-shaped by
    ``stream_generator`` before being returned to the client.
    """
    auth_header = request.headers.get("Authorization")
    if not auth_header or not auth_header.startswith("Bearer "):
        return Response(status_code=401, content="Unauthorized")

    request_data = await request.json()
    messages = request_data.get("messages")
    stream = request_data.get("stream", False)
    model = request_data.get("model")
    max_tokens = request_data.get("max_tokens")

    if not model:
        return Response(
            content=json.dumps({"error": True, "message": "Model parameter is required"}),
            status_code=400
        )

    try:
        # Replace inline data-URI images with uploaded file references.
        processed_messages = await process_messages(messages, auth_header)

        qwen_request = {
            "model": model,
            "messages": processed_messages,
            "stream": stream
        }

        if max_tokens is not None:
            qwen_request["max_tokens"] = max_tokens

        # NOTE(review): client.post() reads the full upstream body before
        # returning, so the "streaming" branch below replays an already
        # buffered response rather than truly streaming incrementally —
        # confirm whether client.stream(...) is wanted here.
        response = await client.post(
            QWEN_API_URL,
            headers={
                "Content-Type": "application/json",
                "Authorization": auth_header
            },
            json=qwen_request
        )

        if stream:
            return StreamingResponse(
                stream_generator(response),
                media_type="text/event-stream"
            )

        return Response(
            content=response.text,
            media_type="application/json"
        )

    except Exception as error:
        # Any failure (image upload, upstream call) maps to a JSON 500.
        return Response(
            content=json.dumps({"error": True, "message": str(error)}),
            status_code=500
        )
|
|
if __name__ == "__main__":
    # Local development entry point: serve the app on all interfaces.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|