Spaces:
Running
Running
| """ | |
| RAG Agent FastAPI Server using OpenAI Agents SDK | |
| Provides POST /chat endpoint for grounded Q&A using OpenAI Agents SDK | |
| and retrieval from Qdrant via Spec-2's retrieve.py module. | |
| """ | |
| import os | |
| import sys | |
| import uuid | |
| import asyncio | |
| from datetime import datetime | |
| from typing import List, Dict, Any, Optional | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field, validator | |
| from dotenv import load_dotenv | |
| # Load environment first | |
| load_dotenv() | |
| from agents import OpenAIChatCompletionsModel | |
| from openai import AsyncOpenAI | |
# Endpoint and model used for every chat completion. The stock AsyncOpenAI
# client is pointed at OpenRouter's base URL instead of OpenAI's.
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
# OpenRouter's free model.
OPENROUTER_MODEL = "tencent/hy3-preview:free"

# The key is mandatory: importing this module fails fast when it is missing.
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
if not OPENROUTER_API_KEY:
    raise ValueError(
        "OPENROUTER_API_KEY environment variable must be set. "
        "Get a free key from https://openrouter.ai/"
    )

# Async client shared by the Agents SDK model wrapper below.
client = AsyncOpenAI(api_key=OPENROUTER_API_KEY, base_url=OPENROUTER_BASE_URL)

# Chat-completions model handed to the agent in create_agent().
third_party_model = OpenAIChatCompletionsModel(
    openai_client=client,
    model=OPENROUTER_MODEL,
)
# Put this file's own directory at the front of sys.path (once) so the
# sibling backend modules imported below can be found.
_this_file = os.path.abspath(__file__)
current_dir = os.path.dirname(_this_file)
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
| # Import backend modules | |
| from config import get_config | |
| from retrieve import search as retrieve_search | |
| from logging_config import setup_logging | |
| # Import OpenAI Agents SDK (must be installed separately) | |
| try: | |
| from agents import Agent, Runner, function_tool, ModelSettings, ToolCallOutputItem | |
| except ImportError: | |
| raise ImportError( | |
| "openai-agents package required. Install: pip install openai-agents" | |
| ) | |
# Setup logging
logger = setup_logging("agent")

# Initialize FastAPI app
app = FastAPI(
    title="RAG Book Chatbot API",
    version="1.0.0",
    description="Chatbot for humanoid robotics book using OpenAI Agents SDK",
)

# ============ CORS Configuration ============
app.add_middleware(
    CORSMiddleware,
    # Exact-match origins: local development plus the production frontend.
    allow_origins=[
        "http://localhost:3000",
        "http://127.0.0.1:3000",
        "https://hackathon-1-humanoid-ai-robotics.vercel.app",
    ],
    # Entries in allow_origins are compared literally, so the former
    # "https://*.vercel.app" entry never matched any real origin. Wildcard
    # matching has to go through allow_origin_regex (full-match semantics).
    # NOTE(review): together with allow_credentials=True this trusts every
    # *.vercel.app deployment, not only this project's previews - confirm
    # that is intended or tighten the pattern.
    allow_origin_regex=r"https://[A-Za-z0-9-]+\.vercel\.app",
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["Content-Type", "Authorization"],
)
| # ============ Pydantic Models ============ | |
class ChatRequest(BaseModel):
    """Request body for the chat endpoint: a single user question."""

    # 1..1000 characters; surrounding whitespace is removed by the validator.
    question: str = Field(..., min_length=1, max_length=1000)

    # Without this decorator the method is an ordinary, never-called function
    # and whitespace-only questions slip past min_length=1.
    @validator("question")
    def validate_question(cls, v):
        """Return the question stripped of surrounding whitespace.

        Raises:
            ValueError: if the question is empty or whitespace only.
        """
        cleaned = (v or "").strip()
        if not cleaned:
            raise ValueError("Question cannot be empty")
        return cleaned
class Source(BaseModel):
    """One retrieved book chunk cited alongside an answer."""

    # URL stored in the chunk's payload.
    url: str
    # Index of the chunk as stored in the payload.
    chunk_index: int
    # Leading part of the chunk text (the chat endpoint truncates it to 200 chars).
    text_snippet: str
class ChatResponse(BaseModel):
    """Successful response of the chat endpoint."""

    # Final text produced by the agent run.
    answer: str
    # Chunks returned by the retrieval tool during the run.
    sources: List[Source]
    # Total token count reported by the run's usage data (0 when unavailable).
    tokens_used: int
    # Short trace string; the chat endpoint sets it to "<request_id>: completed".
    agent_trace: Optional[str] = None
class HealthStatus(BaseModel):
    """Payload returned by the health check."""

    # "healthy" when both dependencies are connected, otherwise "degraded".
    status: str
    # "connected" or "disconnected", as reported by check_qdrant_health().
    qdrant: str
    # "connected" or "disconnected", as reported by check_openai_health().
    openai: str
    # UTC timestamp in ISO-8601 form with a trailing "Z".
    timestamp: str
| # ============ Retrieval Tool ============ | |
# The agent receives this object in tools=[...]; the Agents SDK expects tool
# objects there, so the plain function is wrapped with function_tool.
@function_tool
def retrieve_chunks(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
    """
    Retrieve relevant book chunks from Qdrant.

    Args:
        query: User's question
        top_k: Number of chunks to retrieve (default: 5, max: 10)

    Returns:
        List of chunks with url, chunk_index, text, score, and source_number
    """
    # top_k is chosen by the model, so enforce the documented 1..10 range
    # instead of forwarding whatever value arrives.
    top_k = max(1, min(top_k, 10))
    logger.info(
        f"[Tool] retrieve_chunks called: query='{query[:100]}...', top_k={top_k}"
    )
    try:
        # Imported lazily so a missing client library surfaces as a logged
        # tool failure rather than an import error at module load.
        import cohere
        from qdrant_client import QdrantClient

        cfg = get_config()
        cohere_client = cohere.ClientV2(api_key=cfg["cohere_api_key"])
        qdrant_client = QdrantClient(
            url=cfg["qdrant_url"], api_key=cfg["qdrant_api_key"]
        )
        results = retrieve_search(
            query_text=query,
            cohere_client=cohere_client,
            qdrant_client=qdrant_client,
            collection_name=cfg["qdrant_collection"],
            top_k=top_k,
        )

        chunks = []
        for position, result in enumerate(results):
            # Tolerate a hit whose payload is missing or explicitly None.
            payload = result.get("payload") or {}
            chunks.append(
                {
                    "url": payload.get("url", ""),
                    "chunk_index": payload.get("chunk_index", position),
                    "text": payload.get("text", ""),
                    "score": result.get("score", 0.0),
                    # 1-based number the agent cites as [Source N].
                    "source_number": position + 1,
                }
            )
        logger.info(f"[Tool] Retrieved {len(chunks)} chunks")
        return chunks
    except Exception as e:
        # Log with traceback, then let the SDK see the failure.
        logger.error(f"[Tool] Retrieval failed: {e}", exc_info=True)
        raise
| # ============ Agent Definition ============ | |
def get_agent_instructions() -> str:
    """Return the system prompt that restricts the agent to retrieved book content."""
    prompt_lines = (
        "You are a helpful assistant answering questions about a humanoid robotics book.",
        "IMPORTANT GROUNDING RULES:",
        "1. Answer ONLY using the retrieved book content provided by the retrieve_chunks tool.",
        "2. Do NOT use external knowledge or make up information.",
        '3. If the retrieved content does not contain relevant information, say "I couldn\'t find relevant information in the book."',
        "4. Always cite your sources using the format [Source 1], [Source 2], etc. Each source number corresponds to the chunk number from the tool.",
        "5. Be concise and accurate.",
        "Your responses should be helpful, clear, and grounded exclusively in the provided context.",
    )
    # One rule per line, no trailing newline.
    return "\n".join(prompt_lines)
def create_agent():
    """Build a new agent wired to the retrieval tool and the OpenRouter-backed model."""
    prompt = get_agent_instructions()
    # Sampling temperature 0.7, answers capped at 500 tokens.
    settings = ModelSettings(temperature=0.7, max_tokens=500)
    agent_options = {
        "name": "RAG Book Assistant",
        "instructions": prompt,
        "tools": [retrieve_chunks],
        "model": third_party_model,
        "model_settings": settings,
    }
    return Agent(**agent_options)
# Process-wide agent cache; stays None until the first get_agent() call.
_agent_instance = None


def get_agent():
    """Return the shared agent, creating it on first use."""
    global _agent_instance
    if _agent_instance is not None:
        return _agent_instance
    _agent_instance = create_agent()
    return _agent_instance
| # ============ Health Checks ============ | |
def check_qdrant_health() -> str:
    """Return "connected" if the configured Qdrant collection can be fetched, else "disconnected".

    Every failure (missing library, bad config, network error) is logged as a
    warning and reported as "disconnected"; nothing is raised.
    """
    try:
        from qdrant_client import QdrantClient

        settings = get_config()
        probe = QdrantClient(
            url=settings["qdrant_url"],
            api_key=settings["qdrant_api_key"],
        )
        # Fetching the collection is the actual connectivity test.
        probe.get_collection(settings["qdrant_collection"])
    except Exception as exc:
        logger.warning(f"Qdrant health check failed: {exc}")
        return "disconnected"
    return "connected"
def check_openai_health() -> str:
    """Return "connected" if the LLM provider used by the agent is reachable.

    The agent talks to OpenRouter with OPENROUTER_API_KEY, so the probe uses
    the same key and base URL. Checking OPENAI_API_KEY against api.openai.com
    (as before) reported "disconnected" on every correctly configured
    deployment. The function name and the "openai" health field are kept for
    API compatibility.

    Returns:
        "connected" when the models listing succeeds, otherwise
        "disconnected" (missing key, missing library, network/API error).
        Never raises.
    """
    try:
        api_key = os.getenv("OPENROUTER_API_KEY")
        if not api_key:
            return "disconnected"
        import openai

        probe = openai.OpenAI(
            api_key=api_key,
            base_url="https://openrouter.ai/api/v1",
            # A health probe must not sit on the client's long default timeout.
            timeout=10.0,
        )
        # Simple models.list call to verify API connectivity.
        # NOTE(review): confirm whether this endpoint rejects invalid keys; if
        # it does not, this proves reachability only, not key validity.
        probe.models.list()
        return "connected"
    except Exception as e:
        logger.warning(f"OpenAI health check failed: {e}")
        return "disconnected"
| # ============ FastAPI Endpoints ============ | |
def _collect_sources(items) -> List[Source]:
    """Turn the chunk lists returned by tool calls into Source entries."""
    collected: List[Source] = []
    for item in items or []:
        if not isinstance(item, ToolCallOutputItem):
            continue
        output = item.output
        if not isinstance(output, list):
            continue
        for chunk in output:
            # Skip anything that is not a chunk dict instead of failing the
            # whole request on one malformed entry.
            if not isinstance(chunk, dict):
                continue
            collected.append(
                Source(
                    url=chunk.get("url", ""),
                    chunk_index=chunk.get("chunk_index", 0),
                    text_snippet=(chunk.get("text") or "")[:200],
                )
            )
    return collected


def _classify_chat_error(exc: Exception):
    """Map an exception to (status_code, error_code, message) for the client."""
    # Include the exception type so e.g. "AuthenticationError" is recognised
    # even when the message itself does not mention authentication.
    error_str = f"{type(exc).__name__}: {exc}".lower()
    # The former bare '"api" in error_str' test matched almost every provider
    # error and hid the branch below; only real auth markers are used now.
    auth_markers = (
        "authentication",
        "unauthorized",
        "api key",
        "api_key",
        "401",
        "403",
    )
    if any(marker in error_str for marker in auth_markers):
        return (
            503,
            "api_auth_failed",
            "Authentication failed with API provider. Check your OPENROUTER_API_KEY.",
        )
    if "openrouter" in error_str or "openai" in error_str:
        return 503, "openai_failed", f"API service error: {str(exc)[:100]}"
    return 500, "internal_error", f"Internal error: {str(exc)[:100]}"


# The module docstring promises "POST /chat"; without this decorator the
# handler is never registered with the app.
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Answer a question with the RAG agent.

    Runs the shared agent on the question with a 60 second limit, collects the
    chunks returned by the retrieval tool as sources and reports token usage.

    Returns:
        ChatResponse on success; a JSONResponse with status 504 on timeout,
        503 for provider/auth failures, or 500 for anything else.
    """
    timeout_seconds = 60.0
    request_id = str(uuid.uuid4())[:8]
    question = request.question.strip()
    logger.info(f"[{request_id}] Received chat: {question[:100]}...")
    try:
        logger.info(f"[{request_id}] Initializing agent...")
        agent = get_agent()
        logger.info(f"[{request_id}] Agent initialized successfully")

        # Runner.run is awaited directly; wait_for enforces the time limit.
        logger.info(f"[{request_id}] Starting agent run...")
        result = await asyncio.wait_for(
            Runner.run(agent, question),
            timeout=timeout_seconds,
        )
        logger.info(f"[{request_id}] Agent run completed")

        sources = _collect_sources(result.new_items)

        # Token usage is optional; fall back to 0 when it is not reported.
        tokens_used = 0
        if result.context_wrapper and hasattr(result.context_wrapper, "usage"):
            tokens_used = result.context_wrapper.usage.total_tokens

        # ChatResponse.answer is a str; coerce other output types rather
        # than failing response validation.
        final_output = result.final_output
        answer = "" if final_output is None else str(final_output)

        response = ChatResponse(
            answer=answer,
            sources=sources,
            tokens_used=tokens_used,
            agent_trace=f"{request_id}: completed",
        )
        logger.info(
            f"[{request_id}] Completed: tokens={tokens_used}, sources={len(sources)}"
        )
        return response
    except asyncio.TimeoutError:
        logger.error(f"[{request_id}] Timeout after {timeout_seconds:.0f}s")
        return JSONResponse(
            status_code=504,
            content={
                "error": "timeout",
                "message": "The chatbot is taking too long to respond. Please try a shorter or simpler question.",
            },
        )
    except Exception as e:
        logger.error(
            f"[{request_id}] Error during chat: {type(e).__name__}: {e}", exc_info=True
        )
        status_code, error_code, message = _classify_chat_error(e)
        logger.error(f"[{request_id}] Returning {status_code}: {error_code}")
        return JSONResponse(
            status_code=status_code,
            content={"error": error_code, "message": message, "request_id": request_id},
        )
# NOTE(review): no route registration is visible for this handler; "/health"
# is assumed from the function name - confirm the intended path.
@app.get("/health", response_model=HealthStatus)
async def health_check():
    """Report the status of Qdrant and the LLM provider.

    Returns:
        HealthStatus with status "healthy" when both checks report
        "connected", otherwise "degraded", plus a UTC timestamp.
    """
    from datetime import timezone

    # Both probes do blocking network I/O; run them in worker threads (and in
    # parallel) so the event loop keeps serving other requests.
    qdrant_status, llm_status = await asyncio.gather(
        asyncio.to_thread(check_qdrant_health),
        asyncio.to_thread(check_openai_health),
    )
    all_connected = qdrant_status == "connected" and llm_status == "connected"
    # Same "...Z" format as before, without the deprecated datetime.utcnow().
    timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    return HealthStatus(
        status="healthy" if all_connected else "degraded",
        qdrant=qdrant_status,
        openai=llm_status,
        timestamp=timestamp,
    )
# NOTE(review): no route registration is visible for this handler; the path
# "/test-agent" is an assumption - confirm or adjust before deploying.
@app.get("/test-agent")
async def test_agent_endpoint():
    """Test endpoint to verify agent can be initialized.

    Returns:
        A dict with status "ok" and the agent's name on success; a
        JSONResponse with status 500 and the error text when creating the
        agent fails.
    """
    request_id = str(uuid.uuid4())[:8]
    try:
        logger.info(f"[{request_id}] Testing agent initialization...")
        agent = get_agent()
        logger.info(f"[{request_id}] Agent initialized successfully")
        return {
            "status": "ok",
            "message": "Agent initialized successfully",
            "agent_name": getattr(agent, "name", "unknown"),
        }
    except Exception as e:
        logger.error(f"[{request_id}] Agent init test failed: {e}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                "message": f"Agent initialization failed: {str(e)}",
                "request_id": request_id,
            },
        )
# Without registration this coroutine is never called and the startup
# diagnostics below never run.
@app.on_event("startup")
async def startup_event():
    """Log configuration status and run one retrieval smoke test at startup.

    A failing retrieval test is logged only; it does not stop the server.
    """
    logger.info("=" * 60)
    logger.info("RAG Agent FastAPI Server Starting")
    logger.info("=" * 60)

    if not os.getenv("OPENROUTER_API_KEY"):
        logger.error("OPENROUTER_API_KEY not set - chat will fail!")
    else:
        logger.info("OPENROUTER_API_KEY is configured")

    # Test retrieval: a single top_k=1 query against the configured collection.
    try:
        import cohere
        from qdrant_client import QdrantClient

        cfg = get_config()
        cohere_client = cohere.ClientV2(api_key=cfg["cohere_api_key"])
        qdrant_client = QdrantClient(
            url=cfg["qdrant_url"], api_key=cfg["qdrant_api_key"]
        )
        test_result = retrieve_search(
            query_text="test",
            cohere_client=cohere_client,
            qdrant_client=qdrant_client,
            collection_name=cfg["qdrant_collection"],
            top_k=1,
        )
        logger.info(f"Retrieval test OK: {len(test_result)} results")
    except Exception as e:
        # Best effort: report the problem but keep the server starting.
        logger.error(f"Retrieval test failed: {e}")

    logger.info("Server startup complete")
if __name__ == "__main__":
    # Run directly as a script: serve the app with uvicorn on all
    # interfaces, port 8000.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)