| import asyncio
|
| import json
|
| from datetime import datetime, timezone
|
| import os
|
| import base64
|
| import tempfile
|
|
|
| from fastapi import FastAPI, HTTPException, Request
|
| from fastapi.middleware.cors import CORSMiddleware
|
| from fastapi.responses import JSONResponse
|
| from fastapi.responses import StreamingResponse
|
| from pydantic import BaseModel
|
| from typing import List, Optional, Dict, Any, Union
|
| import time
|
| import uuid
|
| import logging
|
|
|
| from gemini_webapi import GeminiClient, set_log_level
|
| from gemini_webapi.constants import Model
|
|
|
|
|
# Single log level shared by this server's root logger and gemini_webapi.
_LOG_LEVEL = "INFO"

logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(__name__)
set_log_level(_LOG_LEVEL)
|
|
|
app = FastAPI(title="Gemini API FastAPI Server")

# Wide-open CORS so browser-based OpenAI-compatible clients on any origin can
# reach this server.
# NOTE(review): FastAPI's CORS docs say "*" origins/methods/headers should not
# be combined with allow_credentials=True; list origins explicitly if
# credentialed cross-origin requests actually matter.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# Shared GeminiClient; stays None until the first request creates it.
gemini_client = None
|
|
|
|
|
def _read_credential(name: str) -> str:
    """Return environment variable *name*, cleaned of stray whitespace and quotes.

    Cookie values pasted into ``.env`` files may pick up surrounding
    whitespace (e.g. a trailing carriage return) or a pair of enclosing
    quotes. Either one silently invalidates the cookie, so both are stripped.
    A missing variable yields ``""``.
    """
    value = os.environ.get(name, "").strip()
    # Remove one matching pair of surrounding quotes: "abc" or 'abc' -> abc.
    if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
        value = value[1:-1].strip()
    return value


SECURE_1PSID = _read_credential("SECURE_1PSID")
SECURE_1PSIDTS = _read_credential("SECURE_1PSIDTS")

# Warn loudly at startup instead of failing later on the first request.
if not SECURE_1PSID or not SECURE_1PSIDTS:
    logger.warning("⚠️ Gemini API credentials are not set or empty! Please check your environment variables.")
    logger.warning("Make sure SECURE_1PSID and SECURE_1PSIDTS are correctly set in your .env file or environment.")
    logger.warning("If using Docker, ensure the .env file is correctly mounted and formatted.")
    logger.warning("Example format in .env file (no quotes):")
    logger.warning("SECURE_1PSID=your_secure_1psid_value_here")
    logger.warning("SECURE_1PSIDTS=your_secure_1psidts_value_here")
else:
    # Log only a short prefix so operators can confirm which values loaded.
    logger.info(f"Credentials found. SECURE_1PSID starts with: {SECURE_1PSID[:5]}...")
    logger.info(f"Credentials found. SECURE_1PSIDTS starts with: {SECURE_1PSIDTS[:5]}...")
|
|
|
|
|
|
|
class ContentItem(BaseModel):
    """One part of a multi-part message ``content`` list.

    ``type`` says which payload is meaningful: ``"text"`` parts carry
    ``text``; ``"image_url"`` parts carry ``image_url``, a mapping whose
    ``"url"`` entry is read by the conversation builder.
    """

    type: str
    text: Union[str, None] = None
    image_url: Union[Dict[str, str], None] = None
|
|
|
|
|
class Message(BaseModel):
    """A single chat message in OpenAI chat-completions format.

    ``content`` is either a plain string or a list of ``ContentItem`` parts.
    """

    role: str
    content: Union[str, List[ContentItem]]
    name: Union[str, None] = None
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """Request body accepted by ``POST /v1/chat/completions``.

    The handler in this file reads ``model``, ``messages`` and ``stream``.
    The sampling fields are accepted for OpenAI-client compatibility but are
    not forwarded to Gemini.
    """

    model: str
    messages: List[Message]

    # Sampling parameters (accepted, currently unused).
    temperature: Union[float, None] = 0.7
    top_p: Union[float, None] = 1.0
    n: Union[int, None] = 1

    stream: Union[bool, None] = False

    max_tokens: Union[int, None] = None
    presence_penalty: Union[float, None] = 0
    frequency_penalty: Union[float, None] = 0
    user: Union[str, None] = None
|
|
|
|
|
class Choice(BaseModel):
    """One entry in ``ChatCompletionResponse.choices``."""

    index: int
    message: Message
    finish_reason: str
|
|
|
|
|
class Usage(BaseModel):
    """Token counters reported with a chat completion."""

    prompt_tokens: int
    completion_tokens: int
    # Sum of the two counters above.
    total_tokens: int
|
|
|
|
|
class ChatCompletionResponse(BaseModel):
    """Non-streaming chat completion payload (OpenAI ``chat.completion`` shape).

    The chat endpoint in this file builds an equivalent plain dict rather
    than instantiating this model.
    """

    id: str
    object: str = "chat.completion"
    created: int  # Unix timestamp, seconds
    model: str
    choices: List[Choice]
    usage: Usage
|
|
|
|
|
class ModelData(BaseModel):
    """One model entry of a ``/v1/models`` listing.

    ``list_models`` in this file returns plain dicts instead of this model.
    """

    id: str
    object: str = "model"
    created: int  # Unix timestamp, seconds
    owned_by: str = "google"
|
|
|
|
|
class ModelList(BaseModel):
    """Envelope of a ``/v1/models`` response: ``{"object": "list", "data": [...]}``."""

    object: str = "list"
    data: List[ModelData]
|
|
|
|
|
|
|
@app.middleware("http")
async def error_handling(request: Request, call_next):
    """Turn any exception escaping the app into an OpenAI-style JSON 500.

    The response body is ``{"error": {"message": ..., "type":
    "internal_server_error"}}``, with the exception text as the message.

    NOTE(review): FastAPI's own handlers normally convert ``HTTPException``
    into a response before it reaches this middleware, so presumably only
    unexpected errors land here; confirm if relying on that.
    """
    try:
        return await call_next(request)
    except Exception as e:
        # logger.exception keeps the traceback; logger.error(str(e)) lost it.
        logger.exception(f"Request failed: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": {"message": str(e), "type": "internal_server_error"}},
        )
|
|
|
|
|
|
|
@app.get("/v1/models")
async def list_models():
    """Return the models declared in gemini_webapi's ``Model`` enum.

    The output follows the OpenAI ``/v1/models`` shape. Every entry shares the
    current UTC time (whole seconds) as ``created`` and is reported as owned
    by ``"google-gemini-web"``.
    """
    now = int(datetime.now(tz=timezone.utc).timestamp())
    data = [
        {
            # Same fallback as map_model_name, in case a member lacks model_name.
            "id": m.model_name if hasattr(m, "model_name") else str(m),
            "object": "model",
            "created": now,
            "owned_by": "google-gemini-web",
        }
        for m in Model
    ]
    # Use the logger rather than printing to stdout on every request.
    logger.debug("Listing models: %s", data)
    return {"object": "list", "data": data}
|
|
|
|
|
|
|
# Keyword fallbacks for legacy OpenAI-style aliases (keys are lowercase).
# A Model matches when its name contains every keyword; aliases not listed
# here fall back to ["pro"].
_MODEL_KEYWORDS = {
    "gemini-pro": ["pro", "2.0"],
    "gemini-pro-vision": ["vision", "pro"],
    "gemini-flash": ["flash", "2.0"],
    "gemini-1.5-pro": ["1.5", "pro"],
    "gemini-1.5-flash": ["1.5", "flash"],
}


def _model_id(m: Model) -> str:
    """Return the name of a ``Model`` member, falling back to ``str(m)``."""
    return m.model_name if hasattr(m, "model_name") else str(m)


def map_model_name(openai_model_name: str) -> Model:
    """Resolve a client-supplied model name to a gemini_webapi ``Model`` member.

    Resolution order (all comparisons case-insensitive):
      1. a model whose name equals the requested name;
      2. the first model whose name contains the requested name;
      3. the first model whose name contains every keyword listed for the
         name in ``_MODEL_KEYWORDS`` (default ``["pro"]``);
      4. the first declared ``Model`` member.
    """
    candidates = [(m, _model_id(m)) for m in Model]
    # Runs on every request, so keep this at debug level.
    logger.debug(f"Available models: {[name for _, name in candidates]}")

    requested = openai_model_name.lower()

    # Exact match first. Otherwise a name that is a substring of a longer
    # model name could resolve to that longer model if it is declared first.
    for m, name in candidates:
        if name.lower() == requested:
            return m

    for m, name in candidates:
        if requested in name.lower():
            return m

    keywords = _MODEL_KEYWORDS.get(requested, ["pro"])
    for m, name in candidates:
        lowered = name.lower()
        if all(kw in lowered for kw in keywords):
            return m

    return next(iter(Model))
|
|
|
|
|
|
|
# Prompt prefixes per chat role. "developer" is OpenAI's newer name for
# system-level instructions, so it is rendered like "system".
_ROLE_PREFIXES = {
    "system": "System: ",
    "developer": "System: ",
    "user": "Human: ",
    "assistant": "Assistant: ",
}


def _save_data_url_image(image_url: str) -> Optional[str]:
    """Decode a base64 ``data:image/...`` URL into a temp file and return its path.

    The file suffix follows the URL's MIME type (``.png`` if unrecognised).
    Returns None, after logging, if decoding or writing fails. In that case
    no file is left behind. The caller owns the returned file and must
    delete it.
    """
    import contextlib
    import mimetypes

    header, sep, payload = image_url.partition(",")
    if not sep:
        logger.error("Error processing base64 image: data URL has no ',' separator")
        return None

    # header looks like "data:image/jpeg;base64" -> MIME type "image/jpeg".
    mime_type = header[len("data:"):].split(";", 1)[0]
    suffix = mimetypes.guess_extension(mime_type) or ".png"

    path = None
    try:
        image_data = base64.b64decode(payload)  # binascii.Error is a ValueError
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            path = tmp.name
            tmp.write(image_data)
    except (ValueError, OSError) as e:
        logger.error(f"Error processing base64 image: {str(e)}")
        if path is not None:
            # Best effort: don't leak a partially written file.
            with contextlib.suppress(OSError):
                os.unlink(path)
        return None
    return path


def prepare_conversation(messages: List[Message]) -> tuple:
    """Flatten chat messages into one prompt string plus a list of image files.

    Each message is rendered as its role prefix (System/Human/Assistant)
    followed by its text and a blank line. The prompt ends with
    ``"Assistant: "`` to cue the reply.

    - String content from a role without a prefix is dropped.
    - List content from such a role is kept, without a prefix.
    - Base64 ``data:image/...`` parts are written to temporary files.
    - Other image URLs (e.g. http links) are skipped with a warning.

    Returns:
        ``(conversation, temp_files)``; the caller must delete ``temp_files``.
    """
    parts = []
    temp_files = []

    for msg in messages:
        prefix = _ROLE_PREFIXES.get(msg.role)

        if isinstance(msg.content, str):
            if prefix is not None:
                parts.append(f"{prefix}{msg.content}\n\n")
            continue

        if prefix is not None:
            parts.append(prefix)
        for item in msg.content:
            if item.type == "text":
                parts.append(item.text or "")
            elif item.type == "image_url" and item.image_url:
                image_url = item.image_url.get("url", "")
                if image_url.startswith("data:image/"):
                    path = _save_data_url_image(image_url)
                    if path is not None:
                        temp_files.append(path)
                else:
                    logger.warning("Skipping image_url part: only base64 data:image/ URLs are supported")
        parts.append("\n\n")

    parts.append("Assistant: ")
    return "".join(parts), temp_files
|
|
|
|
|
|
|
async def get_gemini_client():
    """Return the shared GeminiClient, creating and initialising it on first use.

    The module-level ``gemini_client`` is assigned only after ``init()``
    succeeds. A failed initialisation is therefore retried on the next call
    instead of caching a half-initialised client forever.

    Raises:
        HTTPException: 500 if the client cannot be created or initialised.
    """
    global gemini_client
    if gemini_client is None:
        # NOTE(review): concurrent first requests may each build a client; the
        # last successful one wins. Add a lock if that becomes a problem.
        try:
            client = GeminiClient(SECURE_1PSID, SECURE_1PSIDTS)
            await client.init(timeout=300)
        except Exception as e:
            logger.error(f"Failed to initialize Gemini client: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to initialize Gemini client: {str(e)}") from e
        gemini_client = client
        logger.info("Gemini client initialized successfully")
    return gemini_client
|
|
|
|
|
def _stream_chunk(
    completion_id: str,
    created: int,
    model: str,
    delta: Dict[str, Any],
    finish_reason: Optional[str],
) -> str:
    """Format one ``chat.completion.chunk`` as a server-sent-events ``data:`` line."""
    data = {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model,
        "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
    }
    return f"data: {json.dumps(data)}\n\n"


def _remove_temp_files(paths: List[str]) -> None:
    """Delete the given files, logging (not raising) on failure."""
    for temp_file in paths:
        try:
            os.unlink(temp_file)
        except OSError as e:
            logger.warning(f"Failed to delete temp file {temp_file}: {str(e)}")


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completions backed by the Gemini web client.

    The full reply is fetched from Gemini first. With ``stream`` set, the
    reply is then replayed as SSE chunks:
      - a role chunk;
      - one chunk per character, 10 ms apart;
      - a ``finish_reason="stop"`` chunk;
      - ``data: [DONE]``.

    Otherwise a single ``chat.completion`` dict is returned. Its ``usage``
    figures are whitespace word counts, not real token counts.

    Raises:
        HTTPException: 500 if client initialisation or generation fails.
    """
    try:
        client = await get_gemini_client()

        conversation, temp_files = prepare_conversation(request.messages)
        try:
            # Prompts may contain user data; keep them out of INFO logs.
            logger.debug(f"Prepared conversation: {conversation}")
            logger.info(f"Temp files: {temp_files}")

            model = map_model_name(request.model)
            logger.info(f"Using model: {model}")

            logger.info("Sending request to Gemini...")
            if temp_files:
                response = await client.generate_content(conversation, files=temp_files, model=model)
            else:
                response = await client.generate_content(conversation, model=model)
        finally:
            # Remove decoded images whether or not generation succeeded.
            _remove_temp_files(temp_files)

        reply_text = response.text if hasattr(response, "text") else str(response)
        logger.debug(f"Response: {reply_text}")

        if not reply_text or reply_text.strip() == "":
            logger.warning("Empty response received from Gemini")
            reply_text = "服务器返回了空响应。请检查 Gemini API 凭据是否有效。"

        completion_id = f"chatcmpl-{uuid.uuid4()}"
        created_time = int(time.time())

        if request.stream:

            async def generate_stream():
                yield _stream_chunk(completion_id, created_time, request.model, {"role": "assistant"}, None)
                for char in reply_text:
                    yield _stream_chunk(completion_id, created_time, request.model, {"content": char}, None)
                    # Pace the replay so clients render it progressively.
                    await asyncio.sleep(0.01)
                yield _stream_chunk(completion_id, created_time, request.model, {}, "stop")
                yield "data: [DONE]\n\n"

            return StreamingResponse(generate_stream(), media_type="text/event-stream")

        prompt_tokens = len(conversation.split())
        completion_tokens = len(reply_text.split())
        result = {
            "id": completion_id,
            "object": "chat.completion",
            "created": created_time,
            "model": request.model,
            "choices": [{"index": 0, "message": {"role": "assistant", "content": reply_text}, "finish_reason": "stop"}],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

        logger.debug(f"Returning response: {result}")
        return result

    except HTTPException:
        # Already a proper HTTP error (e.g. client init failure); don't re-wrap.
        raise
    except Exception as e:
        logger.error(f"Error generating completion: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error generating completion: {str(e)}") from e
|
|
|
|
|
@app.get("/")
async def root():
    """Liveness probe: confirm the server is up and serving requests."""
    status_payload = {
        "status": "online",
        "message": "Gemini API FastAPI Server is running",
    }
    return status_payload
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "main:app" import string. The
    # string only works if this file is named main.py, and it makes uvicorn
    # import the module a second time, re-running all module-level setup.
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
|
|
|