|
|
from dotenv import load_dotenv
# Populate os.environ from a .env file (python-dotenv) before the configuration
# reads further down (HF_HOME, HF_TOKEN, AI_MODEL, VISION_MODEL).
load_dotenv()
import os
import httpx
|
|
| |
|
|
| """ |
| FastAPI Backend AI Service using Gemma-3n-E4B-it |
| Provides OpenAI-compatible chat completion endpoints powered by google/gemma-3n-E4B-it |
| """ |
import warnings

# Silence transformers FutureWarnings and any warning whose message mentions
# "slow image processor" or "rope_scaling".
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
warnings.filterwarnings("ignore", message=".*slow image processor.*")
warnings.filterwarnings("ignore", message=".*rope_scaling.*")

# Hugging Face environment settings. They are assigned before transformers is
# imported further down, presumably so the libraries see them at import time —
# keep this ordering.
# Default cache directory; an HF_HOME already present in the environment wins.
os.environ.setdefault("HF_HOME", "/tmp/.cache/huggingface")
# Unconditionally disable transformers advisory warnings (overwrites any existing value).
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
# NOTE(review): hf_token is not referenced anywhere else in this file; presumably the
# Hugging Face libraries read HF_TOKEN from the environment themselves — confirm.
hf_token = os.environ.get("HF_TOKEN")
| import asyncio |
| import logging |
| import time |
| from contextlib import asynccontextmanager |
| from typing import List, Dict, Any, Optional, Union |
|
|
| from fastapi import FastAPI, HTTPException, Depends, Request |
| from fastapi.responses import StreamingResponse, JSONResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field, field_validator |
|
|
| import uvicorn |
| import requests |
| from PIL import Image |
|
|
|
|
|
|
| |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
| |
| from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoConfig |
| from transformers import BitsAndBytesConfig |
| |
| from transformers import Gemma3nForConditionalGeneration, AutoProcessor |
| import torch |
| |
# Configure root logging at INFO so the service's startup/progress messages are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
| |
class TextContent(BaseModel):
    """Text part of a multimodal chat message: ``{"type": "text", "text": ...}``."""

    type: str = Field(default="text", description="Content type")
    text: str = Field(..., description="Text content")

    @field_validator('type')
    @classmethod
    def validate_type(cls, v: str) -> str:
        # Accept only the "text" discriminator so an image part cannot validate as text.
        if v == "text":
            return v
        raise ValueError("Type must be 'text'")
|
|
class ImageContent(BaseModel):
    """Image part of a multimodal chat message: ``{"type": "image", "url": ...}``."""

    type: str = Field(default="image", description="Content type")
    url: str = Field(..., description="Image URL")

    @field_validator('type')
    @classmethod
    def validate_type(cls, v: str) -> str:
        # Accept only the "image" discriminator so a text part cannot validate as an image.
        if v == "image":
            return v
        raise ValueError("Type must be 'image'")
|
|
| |
class ChatMessage(BaseModel):
    """One conversation turn.

    ``content`` is either a plain string or a list of text/image parts.
    """

    role: str = Field(..., description="The role of the message author")
    content: Union[str, List[Union[TextContent, ImageContent]]] = Field(..., description="The content of the message - either string or list of content items")

    @field_validator('role')
    @classmethod
    def validate_role(cls, v: str) -> str:
        permitted_roles = ("system", "user", "assistant")
        if v in permitted_roles:
            return v
        raise ValueError("Role must be one of: system, user, assistant")
|
|
class ChatCompletionRequest(BaseModel):
    """Request body for ``POST /v1/chat/completions`` (OpenAI-compatible subset).

    NOTE(review): ``stream`` is accepted for compatibility, but the chat endpoint in
    this file never reads it and always returns a single non-streamed response.
    """

    # A constant default needs no factory; a plain ``default`` behaves the same at
    # validation time and lets Pydantic expose the value in the generated schema.
    model: str = Field(default="google/gemma-3n-E4B-it", description="The model to use for completion")
    messages: List[ChatMessage] = Field(..., description="List of messages in the conversation")
    max_tokens: Optional[int] = Field(default=512, ge=1, le=2048, description="Maximum tokens to generate")
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    stream: Optional[bool] = Field(default=False, description="Whether to stream responses")
    top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0, description="Top-p sampling")
|
|
class ChatCompletionChoice(BaseModel):
    """One generated alternative inside a ``ChatCompletionResponse``."""

    index: int  # position of this choice within ``choices``
    message: ChatMessage  # the generated assistant message
    finish_reason: str  # why generation stopped; the endpoints in this file always send "stop"
|
|
class ChatCompletionResponse(BaseModel):
    """Non-streaming response body of ``POST /v1/chat/completions``."""

    id: str  # completion identifier
    object: str = "chat.completion"  # OpenAI object-type tag
    created: int  # creation time; the chat endpoint fills it from int(time.time())
    model: str  # model name reported back to the client
    choices: List[ChatCompletionChoice]  # generated alternatives (the endpoint returns one)
|
|
class ChatCompletionChunk(BaseModel):
    """Streaming-chunk shape (``chat.completion.chunk``).

    NOTE(review): no endpoint in this file constructs this model; streaming is not implemented.
    """

    id: str
    object: str = "chat.completion.chunk"
    created: int
    model: str
    choices: List[Dict[str, Any]]  # untyped per-chunk choice payloads
|
|
class HealthResponse(BaseModel):
    """Response body of ``GET /health``."""

    status: str  # "healthy" or "unhealthy", as set by the health endpoint
    model: str  # identifier of the configured model
    version: str  # service version string
|
|
class ModelInfo(BaseModel):
    """One entry in the ``GET /v1/models`` listing."""

    id: str  # model identifier (Hugging Face repo id)
    object: str = "model"
    created: int  # the listing endpoint sets this to int(time.time()) at request time
    owned_by: str = "huggingface"
|
|
class ModelsResponse(BaseModel):
    """Response body of ``GET /v1/models``."""

    object: str = "list"  # OpenAI list-envelope tag
    data: List[ModelInfo]  # the listed models
|
|
class CompletionRequest(BaseModel):
    """Request body for ``POST /v1/completions`` (legacy text completion).

    There is no ``top_p`` field; the completion endpoint passes a fixed 0.95.
    """

    prompt: str = Field(..., description="The prompt to complete")
    max_tokens: Optional[int] = Field(default=512, ge=1, le=2048)  # cap on generated tokens
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)  # sampling temperature
|
|
|
|
|
|
| |
| |
# Model selection comes from the environment, with Gemma 3n as the default.
ai_model_env = os.environ.get("AI_MODEL", "google/gemma-3n-E4B-it")

# The loaders in lifespan() call from_pretrained without any GGUF-specific arguments,
# so GGUF repos are replaced by the default Gemma 3n model. The match is
# case-insensitive so ids spelled "...-gguf" are caught as well as "...-GGUF".
if "gguf" in ai_model_env.lower():
    current_model = "google/gemma-3n-E4B-it"
    logger.info(f"🔄 Overriding GGUF model {ai_model_env} with transformers-compatible model: {current_model}")
else:
    current_model = ai_model_env
# Captioning model for image inputs; lifespan() only loads it for non-Gemma-3n models.
vision_model = os.environ.get("VISION_MODEL", "Salesforce/blip-image-captioning-base")
|
|
| |
# Model state, assigned by lifespan() at startup and reset to None at shutdown.
processor: Optional[Any] = None  # AutoProcessor (Gemma 3n) or AutoTokenizer (other models)
model: Optional[Any] = None  # the loaded transformers model
image_text_pipeline: Optional[Any] = None  # image-to-text pipeline; only loaded for non-Gemma-3n models
|
|
|
|
|
|
| |
async def download_image(url: str) -> Image.Image:
    """Download the image at ``url`` and return it as a fully decoded PIL image.

    The blocking ``requests`` download runs in a worker thread so the event loop
    is not stalled while waiting on the network.

    Raises:
        HTTPException: 400 if the request fails, returns an error status, or the
            payload cannot be decoded as an image.
    """
    import io

    def _fetch_and_decode() -> Image.Image:
        # NOTE(review): url is client-supplied; fetching arbitrary URLs server-side
        # is an SSRF risk — consider restricting schemes/hosts.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        # io.BytesIO is the stdlib buffer; requests.compat does not export BytesIO.
        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy; load() forces decoding so corrupt data fails here (-> 400).
        image.load()
        return image

    try:
        return await asyncio.to_thread(_fetch_and_decode)
    except Exception as e:
        logger.error(f"Failed to download image from {url}: {e}")
        raise HTTPException(status_code=400, detail=f"Failed to download image: {str(e)}") from e
|
|
def extract_text_and_images(content: Union[str, List[Any]]) -> tuple[str, List[str]]:
    """Split message content into its joined text and its image URLs.

    A plain string is returned unchanged with an empty URL list. For a list,
    parts whose ``type`` is ``"text"`` contribute ``str(part.text)`` and parts
    whose ``type`` is ``"image"`` contribute ``str(part.url)``. Parts lacking the
    expected attributes are skipped. Text pieces are joined with single spaces.
    """
    if isinstance(content, str):
        return content, []

    texts: List[str] = []
    urls: List[str] = []
    for part in content:
        kind = getattr(part, "type", None)
        if kind == "text" and hasattr(part, "text"):
            texts.append(str(part.text))
        elif kind == "image" and hasattr(part, "url"):
            urls.append(str(part.url))
    return " ".join(texts), urls
|
|
def has_images(messages: List[ChatMessage]) -> bool:
    """Return True if any message has list content containing an ``image`` part."""
    return any(
        hasattr(part, 'type') and part.type == "image"
        for msg in messages
        if isinstance(msg.content, list)
        for part in msg.content
    )
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model at startup and release it at shutdown.

    Model ids containing ``"gemma-3n"`` are loaded with ``AutoProcessor`` and
    ``Gemma3nForConditionalGeneration``, and no captioning pipeline. Any other id
    uses ``AutoTokenizer`` and ``AutoModelForCausalLM``, plus a best-effort
    ``image-to-text`` pipeline for ``vision_model``.

    Raises:
        RuntimeError: if the main model or processor fails to load. The error is
            chained to the underlying exception.
    """
    global processor, model, image_text_pipeline, current_model
    logger.info("🚀 Starting AI Backend Service (Hugging Face Spaces mode)...")
    logger.info(f"🔧 Using model: {current_model}")
    is_gemma_3n = "gemma-3n" in current_model.lower()
    try:
        logger.info(f"📥 Loading model with transformers: {current_model}")

        if is_gemma_3n:
            logger.info("🔍 Detected Gemma 3n model - using specialized classes")
            processor = AutoProcessor.from_pretrained(current_model)
            # NOTE(review): bfloat16 is chosen whenever CUDA is available, but no
            # device_map/.to() is applied here — confirm the intended device placement.
            model = Gemma3nForConditionalGeneration.from_pretrained(
                current_model,
                low_cpu_mem_usage=True,
                # NOTE(review): trust_remote_code executes code shipped with the
                # checkpoint; only point AI_MODEL at trusted repositories.
                trust_remote_code=True,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            ).eval()
        else:
            logger.info("🔍 Using standard transformers classes")
            processor = AutoTokenizer.from_pretrained(current_model)
            model = AutoModelForCausalLM.from_pretrained(
                current_model,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )

        logger.info(f"✅ Successfully loaded model and processor: {current_model}")

        if is_gemma_3n:
            logger.info("✅ Gemma 3n has built-in multimodal support")
            image_text_pipeline = None
        else:
            # The captioning pipeline is optional: a load failure only disables image input.
            try:
                logger.info(f"🖼️ Initializing image captioning pipeline with model: {vision_model}")
                image_text_pipeline = pipeline("image-to-text", model=vision_model)
                logger.info("✅ Image captioning pipeline loaded successfully")
            except Exception as e:
                logger.warning(f"⚠️ Could not load image captioning pipeline: {e}")
                image_text_pipeline = None

    except Exception as e:
        logger.exception(f"❌ Failed to initialize model: {e}")
        raise RuntimeError(f"Service initialization failed: {e}") from e

    try:
        yield
    finally:
        # Release the module-level references even if the app exits with an error.
        logger.info("🔄 Shutting down AI Backend Service...")
        processor = None
        model = None
        image_text_pipeline = None
|
|
| |
# ASGI application; `lifespan` loads the model before serving and clears it on shutdown.
# NOTE(review): title/description name Gemma 3n even when AI_MODEL selects another model.
app = FastAPI(
    title="AI Backend Service - Gemma 3n",
    description="OpenAI-compatible chat completion API powered by google/gemma-3n-E4B-it",
    version="1.0.0",
    lifespan=lifespan
)
|
|
| |
# Allow cross-origin browser access from any origin, with any method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True permits credentialed
# cross-site requests from any site. Nothing in this file uses cookies or auth, so consider
# allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
def ensure_model_ready():
    """Raise HTTP 503 unless both the processor and the model have been loaded."""
    loaded = processor is not None and model is not None
    if not loaded:
        raise HTTPException(status_code=503, detail="Service not ready - no model initialized (transformers)")
|
|
def convert_messages_to_prompt(messages: List[ChatMessage]) -> str:
    """Flatten chat messages into a plain ``Label: text`` transcript.

    System, user and assistant messages become lines prefixed with ``System:``,
    ``Human:`` and ``Assistant:``. List contents contribute only their text parts.
    A trailing ``Assistant:`` line cues the model to respond.
    """
    labels = {"system": "System", "user": "Human", "assistant": "Assistant"}
    lines: List[str] = []
    for msg in messages:
        label = labels.get(msg.role)
        text = msg.content if isinstance(msg.content, str) else extract_text_and_images(msg.content)[0]
        if label is not None:
            lines.append(f"{label}: {text}")
    lines.append("Assistant:")
    return "\n".join(lines)
|
|
async def generate_multimodal_response(
    messages: List[ChatMessage],
    request: ChatCompletionRequest
) -> str:
    """Answer an image-bearing chat turn using the image-to-text captioning pipeline.

    The most recent user message that contains at least one image is used. Its
    first image is captioned, and the caption is wrapped in a templated reply that
    echoes the message's text. ``request`` is accepted for interface parity, but
    its sampling settings are not used by the captioning pipeline.

    Raises:
        HTTPException: 503 if the pipeline is not loaded; 400 if no user message
            carries an image.

    Pipeline or post-processing failures are not raised; they produce a
    fallback reply string instead.
    """
    pipe = image_text_pipeline
    if not pipe:
        raise HTTPException(status_code=503, detail="Image processing not available - pipeline not initialized")

    try:
        # Pick the latest user turn that actually has an image. A later text-only
        # list message must not hide an earlier image that has_images() detected.
        text_content, image_urls = "", []
        for message in reversed(messages):
            if message.role != "user" or not isinstance(message.content, list):
                continue
            text_content, image_urls = extract_text_and_images(message.content)
            if image_urls:
                break

        if not image_urls:
            raise HTTPException(status_code=400, detail="No user message with images found")

        # Only the first image of the message is processed.
        image_url = image_urls[0]

        logger.info(f"🖼️ Processing image: {image_url}")
        try:
            # The pipeline call is blocking; run it in a worker thread.
            result = await asyncio.to_thread(pipe, image_url)

            if not (result and hasattr(result, '__len__') and len(result) > 0):
                return f"I can see there's an image at {image_url}, but cannot process it right now."

            first_result = result[0]
            if hasattr(first_result, 'get'):
                generated_text = first_result.get('generated_text', f'I can see an image at {image_url}.')
            else:
                generated_text = str(first_result)

            if not text_content:
                return f"I can see: {generated_text}"

            response = f"Looking at this image, I can see: {generated_text}. "
            if "what" in text_content.lower() or "?" in text_content:
                response += f"Regarding your question '{text_content}': Based on what I can see, this appears to be {generated_text.lower()}."
            else:
                response += f"You mentioned: {text_content}"
            return response

        except Exception as pipeline_error:
            logger.warning(f"Pipeline error: {pipeline_error}")
            return f"I can see there's an image at {image_url}. The image appears to contain visual content that I'm having trouble processing right now."

    except HTTPException:
        # Client errors must reach the caller as HTTP errors, not as chat text.
        raise
    except Exception as e:
        logger.error(f"Error in multimodal generation: {e}")
        return f"I'm having trouble processing the image. Error: {str(e)}"
|
|
|
|
def generate_response_local(messages: List[ChatMessage], max_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.95) -> str:
    """Generate a reply with the locally loaded transformers model.

    Raises HTTP 503 (via ``ensure_model_ready``) when no model is loaded. Any
    failure during generation is logged and replaced by a fixed apology message.
    """
    ensure_model_ready()
    apology = "I apologize, but I'm having trouble generating a response right now. Please try again."
    try:
        logger.info(" Generating response using transformers model")
        reply = generate_response_transformers(messages, max_tokens, temperature, top_p)
    except Exception as err:
        logger.error(f"Local generation failed: {err}")
        return apology
    return reply
|
|
| |
|
|
def convert_messages_to_gemma_prompt(messages: List[ChatMessage]) -> str:
    """Render messages as a raw Gemma chat-format prompt string.

    The result is ``<bos>`` followed by one
    ``<start_of_turn>{role}\\n{text}<end_of_turn>\\n`` segment per message. The
    ``assistant`` role is written as ``model``. An open ``<start_of_turn>model\\n``
    turn is appended for the model to complete. List contents are reduced to
    their text parts (image parts are dropped); unknown roles are skipped.

    NOTE(review): Gemma's official chat template has no system turn and folds the
    system prompt into the first user turn. This helper emits a literal system
    turn — confirm before relying on it.
    """
    turn_roles = {"system": "system", "user": "user", "assistant": "model"}
    turns: List[str] = []
    for message in messages:
        role = turn_roles.get(message.role)
        if role is None:
            continue
        # List content must be flattened to text; f-stringing it would embed object reprs.
        if isinstance(message.content, str):
            text = message.content
        else:
            text = extract_text_and_images(message.content)[0]
        turns.append(f"<start_of_turn>{role}\n{text}<end_of_turn>\n")

    turns.append("<start_of_turn>model\n")
    # No separator after <bos>: the first turn follows it directly.
    return "<bos>" + "".join(turns)
|
|
def generate_response_transformers(messages: List[ChatMessage], max_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.95) -> str:
    """Generate an assistant reply for ``messages`` with the loaded transformers model.

    Messages are rendered through ``processor.apply_chat_template``. Gemma 3n
    models receive structured ``[{"type": "text", ...}]`` content; other models
    receive plain strings. Image parts are dropped, so this is text-only.
    Sampling is enabled only when ``temperature > 0``; otherwise decoding is greedy.

    Args:
        messages: Conversation history, oldest first.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        The decoded, stripped completion, or a fixed apology string if anything
        in tokenization or generation raises.
    """
    try:
        is_gemma = "gemma-3n" in current_model.lower()
        chat_messages = _build_chat_messages(messages, structured=is_gemma)

        inputs = processor.apply_chat_template(
            chat_messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        # Keep input tensors on the same device as the model weights.
        inputs = inputs.to(model.device)
        input_len = inputs["input_ids"].shape[-1]
        generation_kwargs = _generation_kwargs(max_tokens, temperature, top_p)

        with torch.inference_mode():
            if is_gemma:
                # The Gemma 3n processor output is passed through whole.
                outputs = model.generate(**inputs, **generation_kwargs)
            else:
                # Plain tokenizers may return extra keys some causal LMs reject;
                # pass only ids and mask.
                outputs = model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs.get("attention_mask"),
                    **generation_kwargs,
                )

        # Decode only the newly generated tokens, not the echoed prompt.
        generated_text = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
        return generated_text.strip()

    except Exception as e:
        logger.exception(f"Transformers generation failed: {e}")
        return "I apologize, but I'm having trouble generating a response right now. Please try again."


def _build_chat_messages(messages: List[ChatMessage], structured: bool) -> List[Dict[str, Any]]:
    """Convert ChatMessage objects to role/content dicts for ``apply_chat_template``.

    Only text is kept. With ``structured=True`` the content is a one-element list
    ``[{"type": "text", "text": ...}]``; otherwise it is the plain string.
    """
    chat_messages: List[Dict[str, Any]] = []
    for m in messages:
        text = m.content if isinstance(m.content, str) else extract_text_and_images(m.content)[0]
        content: Any = [{"type": "text", "text": text}] if structured else text
        chat_messages.append({"role": m.role, "content": content})
    return chat_messages


def _generation_kwargs(max_tokens: int, temperature: float, top_p: float) -> Dict[str, Any]:
    """Build ``model.generate`` kwargs; temperature/top_p are only passed when sampling."""
    kwargs: Dict[str, Any] = {"max_new_tokens": max_tokens, "do_sample": temperature > 0}
    if temperature > 0:
        kwargs["temperature"] = temperature
        kwargs["top_p"] = top_p
    return kwargs
|
|
|
|
@app.get("/", response_class=JSONResponse)
async def root() -> Dict[str, Any]:
    """Root endpoint: report service status, the configured model and the main routes."""
    return {
        # Name the model actually configured; the service does not run Mistral Nemo.
        "message": f"AI Backend Service is running with {current_model}!",
        "model": current_model,
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "models": "/v1/models",
            "chat_completions": "/v1/chat/completions"
        }
    }
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    Reports ``"healthy"`` only when both the processor/tokenizer and the model
    have been loaded by ``lifespan``; otherwise ``"unhealthy"``.
    """
    # The tokenizer/processor is stored in the module-level ``processor``; there is
    # no ``tokenizer`` global.
    ready = processor is not None and model is not None
    return HealthResponse(
        status="healthy" if ready else "unhealthy",
        model=current_model,
        version="1.0.0"
    )
|
|
@app.get("/v1/models", response_model=ModelsResponse)
async def list_models():
    """List served models in the OpenAI ``/v1/models`` format.

    The chat model is always listed. The captioning model is added only when its
    pipeline was loaded.
    """
    served_ids = [current_model]
    if image_text_pipeline:
        served_ids.append(vision_model)

    return ModelsResponse(
        data=[
            ModelInfo(id=model_id, created=int(time.time()), owned_by="huggingface")
            for model_id in served_ids
        ]
    )
|
|
| |
| |
|
|
|
|
|
|
|
|
| |
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest) -> ChatCompletionResponse:
    """Create a chat completion (OpenAI-compatible) with multimodal support.

    Requests containing image parts go to the captioning pipeline. Text-only
    requests run the local transformers model in a worker thread. Streaming is
    not implemented; ``request.stream`` is ignored.

    Raises:
        HTTPException: 400 for empty ``messages``; 503 when the image pipeline or
            model is unavailable; 500 for any other failure.
    """
    import uuid

    try:
        if not request.messages:
            raise HTTPException(status_code=400, detail="Messages cannot be empty")
        if has_images(request.messages):
            if not image_text_pipeline:
                raise HTTPException(status_code=503, detail="Image processing not available")
            response_text = await generate_multimodal_response(request.messages, request)
        else:
            # Log the size of the conversation, not its (potentially sensitive) content.
            logger.info(f"Generating local response for {len(request.messages)} message(s)")
            # Explicit None checks: 0.0 is a valid temperature/top_p (greedy) and must
            # not be replaced by the defaults the way `value or default` would.
            response_text = await asyncio.to_thread(
                generate_response_local,
                request.messages,
                request.max_tokens if request.max_tokens is not None else 512,
                request.temperature if request.temperature is not None else 0.7,
                request.top_p if request.top_p is not None else 0.95,
            )
        response_text = response_text.strip() if response_text else "No response generated."
        return ChatCompletionResponse(
            # A random id stays unique even for concurrent requests in the same second.
            id=f"chatcmpl-{uuid.uuid4().hex}",
            created=int(time.time()),
            model=request.model,
            choices=[ChatCompletionChoice(
                index=0,
                message=ChatMessage(role="assistant", content=response_text),
                finish_reason="stop"
            )]
        )
    except HTTPException:
        # Preserve deliberate 4xx/503 responses instead of converting them to 500.
        raise
    except Exception as e:
        logger.exception(f"Error in chat completion: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
|
|
|
|
|
|
@app.post("/v1/completions")
async def create_completion(
    request: CompletionRequest
) -> Dict[str, Any]:
    """Create a text completion (OpenAI-compatible).

    The prompt is sent to the chat model as a single user message, with a fixed
    top_p of 0.95.

    Raises:
        HTTPException: 400 for an empty prompt; 503 if the model is not loaded;
            500 for any other failure.
    """
    import uuid

    try:
        if not request.prompt:
            raise HTTPException(status_code=400, detail="Prompt cannot be empty")
        ensure_model_ready()

        messages = [ChatMessage(role="user", content=request.prompt)]
        response_text = await asyncio.to_thread(
            generate_response_local,
            messages,
            request.max_tokens if request.max_tokens is not None else 512,
            # temperature=0.0 (greedy) is valid and must not fall back to 0.7.
            request.temperature if request.temperature is not None else 0.7,
            0.95
        )
        return {
            # A random id stays unique even for concurrent requests in the same second.
            "id": f"cmpl-{uuid.uuid4().hex}",
            "object": "text_completion",
            "created": int(time.time()),
            "model": current_model,
            "choices": [{
                "text": response_text,
                "index": 0,
                "finish_reason": "stop"
            }]
        }
    except HTTPException:
        # Keep the intended 400/503 status codes instead of converting them to 500.
        raise
    except Exception as e:
        logger.exception(f"Error in completion: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
|
|
@app.post("/api/response")
async def api_response(request: Request) -> JSONResponse:
    """Echo the ``message`` field of a JSON-object request body.

    Responds with ``status``, the received message, and an echo string. The
    message defaults to ``"No message provided"`` when the field is absent.

    Raises:
        HTTPException: 400 if the body is not valid JSON or is not a JSON object.
    """
    try:
        data = await request.json()
    except ValueError as e:
        # Malformed JSON (JSONDecodeError) or undecodable bytes (UnicodeDecodeError)
        # are both ValueError subclasses: this is a client error, not a 500.
        logger.error(f"Error processing API response: {e}")
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from e

    # A JSON array or scalar has no .get(); reject it explicitly.
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    message = data.get("message", "No message provided")
    return JSONResponse(content={
        "status": "success",
        "received_message": message,
        "response_message": f"You sent: {message}"
    })
|
|
| |
if __name__ == "__main__":
    # Development entry point with auto-reload. reload=True requires an import string,
    # so derive it from this file's name rather than hard-coding "backend_service".
    # HOST/PORT default to the previous 0.0.0.0:8000 and may be overridden via env.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(
        f"{module_name}:app",
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
        reload=True,
    )
|
|