| import os |
| import json |
| import requests |
| from fastapi import FastAPI, Header, HTTPException |
| from fastapi.responses import StreamingResponse, PlainTextResponse, Response |
| from pydantic import BaseModel |
| from typing import List, Optional, Union |
| from dotenv import load_dotenv |
|
|
# Load environment variables from a local .env file (if present) before
# reading configuration below.
load_dotenv()


app = FastAPI()


# Upstream Gemini API key (used in the proxied request URLs).
GEMMA_API_KEY = os.getenv("GEMMA_API_KEY")
# Key our own clients must present as "Authorization: Bearer <APP_API_KEY>".
APP_API_KEY = os.getenv("APP_API_KEY")
# Base URL of the Google Generative Language REST API.
GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
|
|
|
|
class Message(BaseModel):
    """One chat message in OpenAI chat-completions format."""

    # Sender role, e.g. "user" / "assistant" / "system".
    role: str
    # Either a plain string, or a list of content-part dicts
    # (e.g. {"type": "text", "text": "..."}); only text parts are used downstream.
    content: Union[str, List[dict]]
|
|
|
|
class ChatRequest(BaseModel):
    """Request body for POST /v1/chat/completions (OpenAI-compatible)."""

    # Gemini model name; the handler falls back to a default when empty.
    model: str
    # Conversation history; flattened into a single prompt before proxying.
    messages: List[Message]
    # When true, respond as an SSE stream of chat.completion.chunk events.
    stream: Optional[bool] = False
    # When true (non-streaming only), return the raw text instead of JSON.
    plain: Optional[bool] = False
|
|
|
|
class UTF8JSONResponse(Response):
    """JSON response that emits raw UTF-8 instead of ASCII-escaped output.

    FastAPI's stock JSON rendering escapes non-ASCII characters; this class
    serializes with ``ensure_ascii=False`` and compact separators, and
    advertises the charset in the Content-Type header.
    """

    media_type = "application/json; charset=utf-8"

    def render(self, content) -> bytes:
        # Compact separators keep the payload small; non-ASCII text passes
        # through untouched and is encoded as UTF-8 bytes at the end.
        serialized = json.dumps(content, ensure_ascii=False, separators=(",", ":"))
        return serialized.encode("utf-8")
|
|
|
|
def extract_text(messages):
    """Flatten a sequence of chat messages into a single prompt string.

    Each message's ``content`` may be a plain string or an OpenAI-style list
    of content-part dicts; only parts with ``type == "text"`` contribute.
    Anything else (images, unknown types, other content shapes) is ignored.

    Args:
        messages: Objects exposing a ``.content`` attribute that is either a
            ``str`` or a ``list[dict]``.

    Returns:
        All text pieces joined with newlines, with leading/trailing
        whitespace stripped. Empty input yields "".
    """
    pieces = []

    for msg in messages:
        content = msg.content

        if isinstance(content, list):
            # Multimodal-style content: collect only the text parts.
            pieces.extend(
                item.get("text", "")
                for item in content
                if item.get("type") == "text"
            )
        elif isinstance(content, str):
            pieces.append(content)

    # join + strip is equivalent to the old "+= piece + '\n'" then strip(),
    # but builds the string in O(n) instead of quadratically.
    return "\n".join(pieces).strip()
|
|
|
|
def fix_mojibake(text: str) -> str:
    """
    Repairs common UTF-8 mojibake (UTF-8 bytes mis-decoded with a
    single-byte codec), e.g.:
        Itâ€™s  -> It’s
        cafÃ©   -> café

    Returns non-strings, and strings without the characteristic markers,
    unchanged. If re-encoding/decoding fails, the text is returned as-is
    rather than raising.
    """
    if not isinstance(text, str):
        # Defensive: pass through anything that isn't a string.
        return text

    suspicious = ("’", "“", "â€", "‘", "–", "—", "Ã")
    if any(s in text for s in suspicious):
        # Try cp1252 first: the markers above contain characters like € and ™
        # that only exist in cp1252's 0x80-0x9F range — latin-1 cannot encode
        # them, so a latin-1-only round-trip raises UnicodeEncodeError and
        # silently fails on the most common mojibake. latin-1 remains as a
        # fallback for bytes cp1252 leaves undefined.
        for codec in ("cp1252", "latin1"):
            try:
                return text.encode(codec).decode("utf-8")
            except UnicodeError:
                continue
        return text

    return text
|
|
|
|
def gemini_generate_url(model_name: str) -> str:
    """Build the non-streaming generateContent URL for *model_name*,
    with the upstream API key passed as a query parameter."""
    return "{0}/models/{1}:generateContent?key={2}".format(
        GEMINI_BASE_URL, model_name, GEMMA_API_KEY
    )
|
|
|
|
def gemini_stream_url(model_name: str) -> str:
    """Build the streaming streamGenerateContent URL for *model_name*.

    ``alt=sse`` asks the API for a server-sent-events response; the upstream
    API key rides along as a query parameter.
    """
    return "{0}/models/{1}:streamGenerateContent?alt=sse&key={2}".format(
        GEMINI_BASE_URL, model_name, GEMMA_API_KEY
    )
|
|
|
|
def extract_gemini_text(payload: dict) -> str:
    """Extract the generated text from a Gemini API response payload.

    Reads the first candidate and concatenates the ``text`` of ALL of its
    content parts. (The previous version returned only ``parts[0]``, which
    silently truncated multi-part responses.)

    Args:
        payload: Decoded JSON body from a generateContent or
            streamGenerateContent call.

    Returns:
        The candidate's full text, or "" when candidates/content/parts are
        missing, empty, or explicitly null.
    """
    candidates = payload.get("candidates") or []
    if not candidates:
        return ""

    content = candidates[0].get("content") or {}
    parts = content.get("parts") or []

    # Non-text parts (or parts with text=None) contribute nothing.
    return "".join(part.get("text", "") or "" for part in parts)
|
|
|
|
@app.post("/v1/chat/completions")
def chat_completions(
    request: ChatRequest,
    authorization: Optional[str] = Header(None),
):
    """OpenAI-compatible chat endpoint that proxies to the Gemini API.

    Auth: clients must send "Authorization: Bearer <APP_API_KEY>".
    Response shape depends on request flags:
      - ``stream=True``  -> SSE stream of chat.completion.chunk events,
      - ``plain=True``   -> raw UTF-8 text (non-streaming only),
      - otherwise        -> a single chat.completion JSON object.

    Raises:
        HTTPException: 401 missing header, 403 bad key, 500 missing upstream
            key or any upstream/transport failure (non-streaming path).
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="Missing Authorization header")

    # NOTE(review): .replace removes "Bearer " anywhere in the header, not
    # just as a prefix — fine for well-formed headers, but removeprefix()
    # would be stricter. Also a plain != comparison, not constant-time.
    token = authorization.replace("Bearer ", "").strip()
    if token != APP_API_KEY:
        raise HTTPException(status_code=403, detail="Invalid API key")

    if not GEMMA_API_KEY:
        raise HTTPException(status_code=500, detail="GEMMA_API_KEY is not set")

    model_name = request.model or "gemma-3-27b-it"
    # The whole conversation is flattened into one text blob; per-message
    # roles are not forwarded to Gemini.
    prompt = extract_text(request.messages)

    # Gemini generateContent body: a single content entry with one text part.
    payload = {
        "contents": [
            {
                "parts": [{"text": prompt}]
            }
        ]
    }

    if request.stream:
        def generate():
            """Yield OpenAI-style SSE chunks translated from Gemini's SSE."""
            try:
                url = gemini_stream_url(model_name)

                with requests.post(
                    url,
                    json=payload,
                    stream=True,
                    timeout=120,
                    headers={"Content-Type": "application/json"},
                ) as res:
                    res.raise_for_status()
                    # Force UTF-8 so iter_lines(decode_unicode=True) doesn't
                    # guess a wrong charset from the response headers.
                    res.encoding = "utf-8"

                    # OpenAI clients expect the role only in the first delta.
                    sent_role = False

                    for raw_line in res.iter_lines(decode_unicode=True):
                        if not raw_line:
                            continue

                        line = raw_line.strip()

                        # Strip the SSE "data:" framing prefix, if present.
                        if line.startswith("data:"):
                            line = line[5:].strip()

                        if not line or line == "[DONE]":
                            continue

                        # Skip non-JSON keepalives/partial lines silently.
                        try:
                            chunk_json = json.loads(line)
                        except json.JSONDecodeError:
                            continue

                        text = extract_gemini_text(chunk_json)
                        if not text:
                            continue

                        text = fix_mojibake(text)

                        delta = {"content": text}
                        if not sent_role:
                            delta["role"] = "assistant"
                            sent_role = True

                        # Re-wrap as an OpenAI chat.completion.chunk event.
                        openai_chunk = {
                            "id": "chatcmpl-gemma",
                            "object": "chat.completion.chunk",
                            "choices": [
                                {
                                    "index": 0,
                                    "delta": delta,
                                    "finish_reason": None,
                                }
                            ],
                        }

                        yield f"data: {json.dumps(openai_chunk, ensure_ascii=False)}\n\n"

                    yield "data: [DONE]\n\n"

            except Exception as e:
                # Headers are already sent mid-stream, so errors are surfaced
                # as a final SSE event rather than an HTTP error status.
                error_chunk = {"error": str(e)}
                yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"
                yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate(),
            media_type="text/event-stream; charset=utf-8",
            headers={"Cache-Control": "no-cache"},
        )

    # Non-streaming path: one upstream call, one decoded answer.
    try:
        url = gemini_generate_url(model_name)
        res = requests.post(
            url,
            json=payload,
            timeout=120,
            headers={"Content-Type": "application/json"},
        )
        res.raise_for_status()
        # NOTE(review): setting res.encoding after res.raise_for_status() but
        # before res.json() has no effect on json(), which sniffs encoding
        # from the bytes — presumably kept for symmetry with the stream path.
        res.encoding = "utf-8"

        data = res.json()
        output = extract_gemini_text(data)
        output = fix_mojibake(output)

    except Exception as e:
        # Funnel all upstream/transport failures into a 500 for the client.
        raise HTTPException(status_code=500, detail=str(e))

    if request.plain:
        return PlainTextResponse(
            output,
            media_type="text/plain; charset=utf-8",
        )

    # Default: OpenAI chat.completion envelope, rendered without ASCII escapes.
    return UTF8JSONResponse({
        "id": "chatcmpl-gemma",
        "object": "chat.completion",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": output,
                },
                "finish_reason": "stop",
            }
        ],
    })