| """ |
| FastAPI Backend AI Service using Gemma-3n-E4B-it-GGUF |
| Provides OpenAI-compatible chat completion endpoints powered by llama-cpp-python |
| """ |
|
|
| import os |
| import warnings |
| import logging |
| import time |
| from contextlib import asynccontextmanager |
| from typing import List, Dict, Any, Optional, Union |
|
|
| from fastapi import FastAPI, HTTPException, Request |
| from fastapi.responses import StreamingResponse, JSONResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field, field_validator |
|
|
| |
| try: |
| from llama_cpp import Llama |
| llama_cpp_available = True |
| except ImportError: |
| llama_cpp_available = False |
|
|
| import uvicorn |
|
|
| |
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger(__name__) |


class ChatMessage(BaseModel):
    role: str = Field(..., description="The role of the message author")
    content: str = Field(..., description="The content of the message")

    @field_validator('role')
    @classmethod
    def validate_role(cls, v: str) -> str:
        if v not in ["system", "user", "assistant"]:
            raise ValueError("Role must be one of: system, user, assistant")
        return v


class ChatCompletionRequest(BaseModel):
    model: str = Field(default="gemma-3n-e4b-it", description="The model to use for completion")
    messages: List[ChatMessage] = Field(..., description="List of messages in the conversation")
    max_tokens: Optional[int] = Field(default=512, ge=1, le=2048, description="Maximum tokens to generate")
    temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0, description="Top-p sampling")
    top_k: Optional[int] = Field(default=64, ge=1, le=100, description="Top-k sampling")
    stream: Optional[bool] = Field(default=False, description="Whether to stream responses (accepted but not yet implemented)")


class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: str


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]


class HealthResponse(BaseModel):
    status: str
    model: str
    version: str


class ModelInfo(BaseModel):
    id: str
    object: str = "model"
    created: int
    owned_by: str = "huggingface"


class ModelsResponse(BaseModel):
    object: str = "list"
    data: List[ModelInfo]


# Global state: the Hugging Face repo to load (overridable via the AI_MODEL
# environment variable) and the Llama instance initialized at startup.
current_model = os.environ.get("AI_MODEL", "unsloth/gemma-3n-E4B-it-GGUF")
llm = None


def create_gemma_chat_template():
    """
    Create a custom chat template for Gemma 3n.
    Based on the format: <bos><start_of_turn>user\n{user_message}<end_of_turn>\n<start_of_turn>model\n{assistant_response}<end_of_turn>
    """
    return """<bos>{% for message in messages %}{% if message['role'] == 'user' %}<start_of_turn>user
{{ message['content'] }}<end_of_turn>
{% elif message['role'] == 'assistant' %}<start_of_turn>model
{{ message['content'] }}<end_of_turn>
{% elif message['role'] == 'system' %}<start_of_turn>system
{{ message['content'] }}<end_of_turn>
{% endif %}{% endfor %}<start_of_turn>model
"""


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager for startup and shutdown events"""
    global llm
    logger.info("🚀 Starting Gemma 3n Backend Service...")

    if not llama_cpp_available:
        logger.error("❌ llama-cpp-python is not available. Please install with: pip install llama-cpp-python")
        raise RuntimeError("llama-cpp-python not available")

    try:
        logger.info(f"📥 Loading Gemma 3n model from {current_model}...")

        # Download (or reuse a cached copy of) a quantized GGUF file from the
        # Hugging Face Hub; the filename glob selects the Q4_K_M quantization.
        llm = Llama.from_pretrained(
            repo_id=current_model,
            filename="*q4_k_m.gguf",
            verbose=True,
            n_ctx=4096,           # context window, in tokens
            n_threads=4,          # CPU threads used for inference
            n_gpu_layers=-1,      # offload all layers to GPU when available
            chat_format="gemma",  # built-in Gemma turn template
        )

        logger.info("✅ Successfully loaded Gemma 3n model")

    except Exception as e:
        logger.error(f"❌ Failed to initialize Gemma 3n model: {e}")
        logger.warning("⚠️ You can download a GGUF file manually from: https://huggingface.co/unsloth/gemma-3n-E4B-it-GGUF")
        raise RuntimeError(
            "Model loading failed. Please download the GGUF model locally:\n"
            "1. Visit: https://huggingface.co/unsloth/gemma-3n-E4B-it-GGUF\n"
            "2. Download a GGUF file (recommended: gemma-3n-e4b-it-q4_k_m.gguf)\n"
            "3. Update the model path in the code"
        ) from e

    yield
    logger.info("🔄 Shutting down Gemma 3n Backend Service...")
    if llm:
        # Dropping the reference lets llama.cpp free the model's resources.
        llm = None


app = FastAPI(
    title="Gemma 3n Backend Service",
    description="OpenAI-compatible chat completion API powered by Gemma-3n-E4B-it-GGUF",
    version="1.0.0",
    lifespan=lifespan
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
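# NOTE: wildcard origins together with allow_credentials=True is deliberately
# permissive; fine for local development, but restrict allow_origins before
# exposing this service publicly.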


def ensure_model_ready():
    """Raise a 503 if the Gemma 3n model has not been loaded yet."""
    if llm is None:
        raise HTTPException(status_code=503, detail="Service not ready - Gemma 3n model not initialized")


def convert_messages_to_prompt(messages: List[ChatMessage]) -> str:
    """Convert OpenAI-style messages to the Gemma 3n turn format."""
    prompt_parts = []

    for message in messages:
        role = message.role
        content = message.content

        if role == "system":
            # Gemma has no dedicated system role; a system turn is used here for simplicity.
            prompt_parts.append(f"<start_of_turn>system\n{content}<end_of_turn>")
        elif role == "user":
            prompt_parts.append(f"<start_of_turn>user\n{content}<end_of_turn>")
        elif role == "assistant":
            prompt_parts.append(f"<start_of_turn>model\n{content}<end_of_turn>")

    # Open the model turn so generation continues as the assistant.
    prompt_parts.append("<start_of_turn>model\n")

    # <bos> is prepended directly (no newline), matching the documented format.
    return "<bos>" + "\n".join(prompt_parts)


async def generate_response_gemma(
    messages: List[ChatMessage],
    max_tokens: int = 512,
    temperature: float = 1.0,
    top_p: float = 0.95,
    top_k: int = 64
) -> str:
    """Generate a response using the Gemma 3n model."""
    ensure_model_ready()

    try:
        # Preferred path: the high-level chat API, which applies the Gemma
        # turn template configured at load time (chat_format="gemma").
        if hasattr(llm, 'create_chat_completion'):
            messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages]

            response = llm.create_chat_completion(
                messages=messages_dict,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stop=["<end_of_turn>", "<eos>"]
            )

            return response['choices'][0]['message']['content'].strip()

        else:
            # Fallback path: build the turn-formatted prompt by hand and use
            # raw text completion.
            prompt = convert_messages_to_prompt(messages)

            response = llm(
                prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stop=["<end_of_turn>", "<eos>"],
                echo=False
            )

            return response['choices'][0]['text'].strip()

    except Exception as e:
        logger.error(f"Generation failed: {e}")
        return "I apologize, but I'm having trouble generating a response right now. Please try again."


@app.get("/", response_class=JSONResponse)
async def root() -> Dict[str, Any]:
    """Root endpoint with service information"""
    return {
        "message": "Gemma 3n Backend Service is running!",
        "model": current_model,
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "models": "/v1/models",
            "chat_completions": "/v1/chat/completions"
        }
    }


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy" if llm is not None else "unhealthy",
        model=current_model,
        version="1.0.0"
    )


@app.get("/v1/models", response_model=ModelsResponse)
async def list_models():
    """List available models (OpenAI-compatible)"""
    models = [
        ModelInfo(
            id="gemma-3n-e4b-it",
            created=int(time.time()),
            owned_by="google"
        )
    ]
    return ModelsResponse(data=models)


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(
    request: ChatCompletionRequest
) -> ChatCompletionResponse:
    """Create a chat completion (OpenAI-compatible) using Gemma 3n"""
    try:
        if not request.messages:
            raise HTTPException(status_code=400, detail="Messages cannot be empty")

        logger.info(f"Generating Gemma 3n response for {len(request.messages)} messages")

        response_text = await generate_response_gemma(
            request.messages,
            request.max_tokens or 512,
            request.temperature or 1.0,
            request.top_p or 0.95,
            request.top_k or 64
        )

        response_text = response_text.strip() if response_text else "No response generated."

        return ChatCompletionResponse(
            id=f"chatcmpl-{int(time.time())}",
            created=int(time.time()),
            model=request.model,
            choices=[ChatCompletionChoice(
                index=0,
                message=ChatMessage(role="assistant", content=response_text),
                finish_reason="stop"
            )]
        )

    except HTTPException:
        # Re-raise HTTP errors (e.g. the 400 above) instead of masking them as 500s.
        raise
    except Exception as e:
        logger.error(f"Error in chat completion: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
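# Example request once the service is running (host/port assumed from the
# uvicorn settings below):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gemma-3n-e4b-it",
#          "messages": [{"role": "user", "content": "Hello!"}]}'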


if __name__ == "__main__":
    # The import string must match this file's module name (assumed here to be
    # gemma_backend_service.py); reload=True is intended for development only.
    uvicorn.run("gemma_backend_service:app", host="0.0.0.0", port=8000, reload=True)
|