| """ | |
| FastAPI wrapper for the RAG Book Assistant agent. | |
| This module provides a standalone FastAPI application that exposes the | |
| /chat endpoint using the agent defined in agent.py. It is separate from | |
| agent.py to allow independent deployment and testing. | |
| """ | |

import os
import sys
import uuid
import asyncio
import logging
import traceback
from datetime import datetime
from typing import List, Dict, Any, Optional
from collections import deque

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, validator
from dotenv import load_dotenv

# Make backend package importable
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Load environment
load_dotenv()

# Import agent components
try:
    from agent import get_agent, Source as AgentSource
    from agents import Runner, ToolCallOutputItem
except ImportError as e:
    raise ImportError(f"Failed to import agent module: {e}")

# Initialize FastAPI app
app = FastAPI(
    title="RAG Chatbot API",
    version="1.0.0",
    description="FastAPI wrapper for RAG Book Assistant",
)

# ============ CORS Configuration ============
# allow_origins entries are matched exactly by CORSMiddleware, so a wildcard
# such as "https://*.vercel.app" never matches; Vercel preview deployments
# are covered with allow_origin_regex instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "http://127.0.0.1:3000",
        "https://hackathon-1-humanoid-ai-robotics.vercel.app",
    ],
    allow_origin_regex=r"https://.*\.vercel\.app",
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["Content-Type", "Authorization"],
)

# ============ Logging Configuration ============
class InMemoryLogHandler(logging.Handler):
    """Handler that stores log records in memory with a max size limit."""

    def __init__(self, max_logs=500):
        super().__init__()
        self.logs = deque(maxlen=max_logs)

    def emit(self, record):
        try:
            message = self.format(record)
            if record.exc_info:
                # format_exception returns a list of lines; join it before appending
                message = f"{message}\n{''.join(traceback.format_exception(*record.exc_info))}"
            log_entry = {
                "timestamp": datetime.fromtimestamp(record.created).isoformat(),
                "level": record.levelname,
                "logger": record.name,
                "message": message,
            }
            self.logs.append(log_entry)
        except Exception:
            self.handleError(record)

# Set up in-memory logging
log_handler = InMemoryLogHandler(max_logs=500)
log_formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
log_handler.setFormatter(log_formatter)

# Add handler to root logger and FastAPI logger
root_logger = logging.getLogger()
root_logger.addHandler(log_handler)
if root_logger.level == logging.NOTSET:
    root_logger.setLevel(logging.INFO)

# Also capture uvicorn logs
uvicorn_logger = logging.getLogger("uvicorn")
uvicorn_logger.addHandler(log_handler)
uvicorn_logger.setLevel(logging.INFO)

# ============ Pydantic Models ============
class ChatRequest(BaseModel):
    question: str = Field(..., min_length=1, max_length=1000)

    @validator("question")
    def validate_question(cls, v):
        if not v or not v.strip():
            raise ValueError("Question cannot be empty")
        return v.strip()

class Source(BaseModel):
    url: str
    chunk_index: int
    text_snippet: str


class ChatResponse(BaseModel):
    answer: str
    sources: List[Source]
    tokens_used: int
    agent_trace: Optional[str] = None


class HealthStatus(BaseModel):
    status: str
    qdrant: str
    openai: str
    timestamp: str


class LogEntry(BaseModel):
    timestamp: str
    level: str
    logger: str
    message: str


class LogsResponse(BaseModel):
    logs: List[LogEntry]
    total_entries: int
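
# A minimal sketch of the request/response shapes these models imply; the
# field values below are illustrative placeholders, not taken from a real run.
#
#   POST /chat
#   {"question": "What does the book say about actuators?"}
#
#   200 OK
#   {
#     "answer": "...",
#     "sources": [
#       {"url": "https://example.com/book/ch2", "chunk_index": 3, "text_snippet": "..."}
#     ],
#     "tokens_used": 1234,
#     "agent_trace": "a1b2c3d4: completed"
#   }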

# ============ Health Check ============
def check_qdrant_health() -> str:
    try:
        from config import get_config
        from qdrant_client import QdrantClient

        cfg = get_config()
        client = QdrantClient(url=cfg["qdrant_url"], api_key=cfg["qdrant_api_key"])
        client.get_collection(cfg["qdrant_collection"])
        return "connected"
    except Exception:
        return "disconnected"


def check_openai_health() -> str:
    try:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            return "disconnected"
        import openai

        client = openai.OpenAI(api_key=api_key)
        client.models.list()
        return "connected"
    except Exception:
        return "disconnected"


@app.get("/health")
async def health_check():
    qdrant = check_qdrant_health()
    openai = check_openai_health()
    status = (
        "healthy" if qdrant == "connected" and openai == "connected" else "degraded"
    )
    return HealthStatus(
        status=status,
        qdrant=qdrant,
        openai=openai,
        timestamp=datetime.utcnow().isoformat() + "Z",
    )


# ============ Logs Endpoint ============
@app.get("/logs")
async def get_logs(limit: Optional[int] = None):
    """
    Retrieve application logs.

    Args:
        limit: Optional maximum number of logs to return (default: all, max 500)

    Returns:
        LogsResponse with list of log entries
    """
    all_logs = list(log_handler.logs)

    # Apply limit if specified
    if limit is not None and limit > 0:
        all_logs = all_logs[-limit:] if limit < len(all_logs) else all_logs

    return LogsResponse(
        logs=[LogEntry(**log) for log in all_logs],
        total_entries=len(all_logs),
    )
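
# Example (host is an illustrative placeholder): GET http://localhost:8000/logs?limit=50
# returns only the 50 most recent in-memory log entries.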

# ============ Root Endpoint ============
@app.get("/")
async def root(logs: Optional[str] = None):
    """
    Root endpoint that handles various query parameters.

    Args:
        logs: If set to 'container', returns application logs

    Returns:
        Logs if logs=container, otherwise returns API info
    """
    if logs == "container":
        all_logs = list(log_handler.logs)
        return {
            "logs": [LogEntry(**log) for log in all_logs],
            "total_entries": len(all_logs),
        }
    return {
        "name": "RAG Chatbot API",
        "version": "1.0.0",
        "endpoints": [
            {"method": "GET", "path": "/health", "description": "Health check"},
            {"method": "GET", "path": "/logs", "description": "Get application logs"},
            {
                "method": "GET",
                "path": "/?logs=container",
                "description": "Get container logs",
            },
            {"method": "POST", "path": "/chat", "description": "Chat endpoint"},
        ],
    }


# ============ Chat Endpoint ============
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    request_id = str(uuid.uuid4())[:8]
    question = request.question.strip()
    logger = logging.getLogger(__name__)
    logger.info(f"[{request_id}] Chat request: {question[:100]}...")

    try:
        logger.info(f"[{request_id}] Initializing agent...")
        agent = get_agent()
        logger.info(f"[{request_id}] Agent initialized successfully")

        # Run agent with timeout (60s to accommodate full workflow and large questions)
        logger.info(f"[{request_id}] Running agent with 60s timeout...")
        result = await asyncio.wait_for(Runner.run(agent, question), timeout=60.0)
        logger.info(f"[{request_id}] Agent completed successfully")

        # Extract sources from tool call outputs
        sources = []
        if result.new_items:
            for item in result.new_items:
                if isinstance(item, ToolCallOutputItem):
                    output = item.output
                    if isinstance(output, list):
                        for chunk in output:
                            sources.append(
                                Source(
                                    url=chunk.get("url", ""),
                                    chunk_index=chunk.get("chunk_index", 0),
                                    text_snippet=chunk.get("text", "")[:200],
                                )
                            )

        # Get token usage
        tokens_used = 0
        if result.context_wrapper and hasattr(result.context_wrapper, "usage"):
            tokens_used = result.context_wrapper.usage.total_tokens

        logger.info(f"[{request_id}] Returning response with {len(sources)} sources")
        return ChatResponse(
            answer=result.final_output,
            sources=sources,
            tokens_used=tokens_used,
            agent_trace=f"{request_id}: completed",
        )

    except asyncio.TimeoutError:
        logger.warning(f"[{request_id}] Request timeout after 60s")
        return JSONResponse(
            status_code=504,
            content={
                "error": "timeout",
                "message": "The chatbot is taking too long to respond. Please try a shorter question.",
            },
        )
    except Exception as e:
        # Log the full exception for debugging
        error_msg = f"Chat endpoint error [{request_id}]: {str(e)}"
        logger.error(error_msg, exc_info=True)

        if "openai" in str(e).lower() or "rate limit" in str(e).lower():
            return JSONResponse(
                status_code=503,
                content={
                    "error": "openai_failed",
                    "message": "The AI service is currently unavailable. Please try again in a few minutes.",
                },
            )

        return JSONResponse(
            status_code=500,
            content={
                "error": "internal_error",
                "message": "An unexpected error occurred. Please refresh the page and try again.",
                "request_id": request_id,
            },
        )
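
# A hedged usage sketch for the endpoint above; host and question are
# placeholders, not values from the actual deployment:
#
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"question": "What topics does the book cover?"}'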

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
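
# A minimal local test sketch (assumes fastapi's TestClient, httpx installed,
# and a configured agent/environment; not part of the deployed app):
#
#   from fastapi.testclient import TestClient
#
#   client = TestClient(app)
#   assert client.get("/health").status_code == 200
#   resp = client.post("/chat", json={"question": "What is this book about?"})
#   print(resp.status_code, resp.json())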