Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import os | |
| import logging | |
| import sys | |
| from dotenv import load_dotenv | |
| from .config import DATASET_CONFIGS, load_prompt_template | |
| from openai import OpenAI | |
| from openai.types.chat import ChatCompletionMessageParam | |
| import json | |
# Load environment variables
load_dotenv()

# Lazy imports to avoid blocking startup: the modules below are deliberately
# NOT imported at module level and are meant to be imported at the point of use.
# from .pipeline import RAGPipeline  # Will import when needed
# import umap  # Will import when needed for visualization
# import plotly.express as px  # Will import when needed for visualization
# import plotly.graph_objects as go  # Will import when needed for visualization
# from plotly.subplots import make_subplots  # Will import when needed for visualization
# import numpy as np  # Will import when needed for visualization
# from sklearn.preprocessing import normalize  # Will import when needed for visualization
# import pandas as pd  # Will import when needed for visualization
# Configure logging: INFO and above, written to stdout with a
# "timestamp - logger name - level - message" layout.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[_stdout_handler],
)

# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
app = FastAPI(
    title="RAG Pipeline API",
    description="Multi-dataset RAG API",
    version="1.0.0",
)

# Initialize OpenRouter client.
# Fail fast at import time when the key is missing: every chat request needs it.
openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
if not openrouter_api_key:
    raise ValueError("OPENROUTER_API_KEY environment variable is not set")

# OpenRouter is reached through the OpenAI client by overriding the base URL.
openrouter_client = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=openrouter_api_key,
)

# Model configuration
MODEL_NAME = "z-ai/glm-4.5-air:free"

# Initialize pipelines for all datasets.
# Maps dataset name -> loaded pipeline; starts empty and is filled in later by
# load_datasets_background().
pipelines = {}

google_api_key = os.getenv("GOOGLE_API_KEY")

logger.info("Starting RAG Pipeline API")
logger.info(f"Port from env: {os.getenv('PORT', 'Not set - will use 8000')}")
# Only the presence of the key is logged, never its value.
logger.info(f"Google API Key present: {'Yes' if google_api_key else 'No'}")
logger.info(f"Available datasets: {list(DATASET_CONFIGS.keys())}")
| # Define tools for the GLM model | |
def rag_qa(question: str, dataset: str = "developer-portfolio") -> str:
    """
    Get answers from the RAG pipeline for specific questions about the dataset.

    Problems are reported as plain-text messages instead of being raised,
    because the return value is passed back to the model as the tool result.

    Args:
        question: The question to answer using the RAG pipeline
        dataset: The dataset to search in (default: developer-portfolio)

    Returns:
        Answer from the RAG pipeline, or a human-readable message when the
        datasets are not loaded yet, the dataset is unknown, or the pipeline
        raised an error.
    """
    try:
        # Check if pipelines are loaded
        if not pipelines:
            return "RAG Pipeline is running but datasets are still loading in the background. Please try again in a moment."

        # Select the appropriate pipeline based on dataset
        if dataset not in pipelines:
            return f"Dataset '{dataset}' not available. Available datasets: {list(pipelines.keys())}"

        selected_pipeline = pipelines[dataset]
        return selected_pipeline.answer_question(question)
    except Exception as e:
        # The caller only ever sees the short message below, so record the
        # full traceback here; otherwise the root cause is lost.
        logger.exception("RAG pipeline failed for dataset %r", dataset)
        return f"Error accessing RAG pipeline: {str(e)}"
# Tool definitions for GLM

# JSON-schema description of the arguments accepted by the rag_qa tool.
# Only "question" is mandatory; "dataset" falls back to developer-portfolio.
_RAG_QA_PARAMETERS = {
    "type": "object",
    "properties": {
        "question": {
            "type": "string",
            "description": "The question to answer using the RAG pipeline",
        },
        "dataset": {
            "type": "string",
            "description": "The dataset to search in (default: developer-portfolio)",
            "default": "developer-portfolio",
        },
    },
    "required": ["question"],
}

# Tools advertised to the model on every chat completion request.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "rag_qa",
            "description": "Get answers from the RAG pipeline for specific questions about datasets",
            "parameters": _RAG_QA_PARAMETERS,
        },
    },
]
# Don't load datasets during startup - do it asynchronously after server starts
logger.info(
    "RAG Pipeline API is ready to serve requests"
    " - datasets will load in background"
)

# Visualization function disabled to speed up startup
# def create_3d_visualization(pipeline):
#     ... (commented out for faster startup)
class Question(BaseModel):
    """A single question aimed at one dataset."""

    # The question text.
    text: str
    # Dataset to query; defaults to the developer-portfolio dataset.
    dataset: str = "developer-portfolio"
class ChatMessage(BaseModel):
    """One turn of a chat conversation."""

    # Speaker of the turn; forwarded unchanged as the chat API "role".
    role: str
    # Text of the turn; forwarded unchanged as the chat API "content".
    content: str
class ChatRequest(BaseModel):
    """Request body for a chat exchange: the conversation plus its dataset."""

    # Conversation so far, in the order it should be sent to the model.
    messages: list[ChatMessage]
    # Dataset the conversation is about; defaults to developer-portfolio.
    dataset: str = "developer-portfolio"
async def chat_with_ai(request: ChatRequest):
    """
    Chat with the AI assistant. The AI will use the RAG pipeline when needed to answer questions about the datasets.

    The model is called once with the tool definitions. If it asks for tools,
    each requested call is executed, the results are appended to the
    conversation, and the model is called a second time for the final answer.

    NOTE(review): no route decorator is visible on this handler in the file as
    provided - confirm it is registered with the app.

    Args:
        request: The conversation so far and the dataset that tool calls use
            when the model does not name one.

    Returns:
        A dict with "response" (the assistant's text) and "tool_calls" (the
        name and raw arguments of every tool call the model made, or None when
        it answered directly).

    Raises:
        HTTPException: status 500 carrying the error text when anything fails.
    """
    # Local import so this block does not depend on a file-level asyncio import.
    import asyncio

    try:
        messages = _build_chat_messages(request)

        # openrouter_client is the synchronous OpenAI client; calling it
        # directly inside a coroutine would stall the event loop for the whole
        # network round trip, so it runs in a worker thread instead.
        response = await asyncio.to_thread(
            openrouter_client.chat.completions.create,
            model=MODEL_NAME,
            messages=messages,
            tools=TOOLS,  # type: ignore
            tool_choice="auto",
        )
        choice = response.choices[0]
        message = choice.message
        tool_calls = getattr(message, "tool_calls", None)

        if choice.finish_reason != "tool_calls" or not tool_calls:
            # Direct response without tool calls
            return {
                "response": message.content,
                "tool_calls": None,
            }

        # Replay the assistant turn that requested the tools so the follow-up
        # request contains the complete exchange.
        assistant_message: ChatCompletionMessageParam = {
            "role": "assistant",
            "content": message.content or "",
            "tool_calls": [
                {
                    "id": tc.id,
                    "type": tc.type,
                    "function": {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments,
                    },
                }
                for tc in tool_calls
            ],
        }
        messages.append(assistant_message)

        # Every tool call gets exactly one tool message, including calls that
        # cannot be executed: an assistant turn whose tool_call_ids are not
        # all answered makes the follow-up request invalid.
        for tool_call in tool_calls:
            # The RAG lookup is a plain synchronous call (presumably slow), so
            # it is also kept off the event loop.
            result = await asyncio.to_thread(
                _run_tool_call, tool_call, request.dataset
            )
            tool_message: ChatCompletionMessageParam = {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": result,
            }
            messages.append(tool_message)

        # Get final response
        final_response = await asyncio.to_thread(
            openrouter_client.chat.completions.create,
            model=MODEL_NAME,
            messages=messages,
        )
        return {
            "response": final_response.choices[0].message.content,
            "tool_calls": [
                {
                    "name": tc.function.name,
                    "arguments": tc.function.arguments,
                }
                for tc in tool_calls
            ],
        }
    except Exception as e:
        # Keep the traceback in the server log; the client only gets str(e).
        logger.exception("Chat request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


def _build_chat_messages(request: ChatRequest) -> list[ChatCompletionMessageParam]:
    """Return the request's messages in OpenAI format, led by the system prompt for its dataset."""
    # Convert messages to OpenAI format with proper typing
    history: list[ChatCompletionMessageParam] = [
        {"role": msg.role, "content": msg.content}  # type: ignore
        for msg in request.messages
    ]

    # The developer-portfolio dataset has a dedicated prompt; every other
    # dataset shares the generic one.
    if request.dataset == "developer-portfolio":
        template_name = "system-instruction.txt"
    else:
        template_name = "generic-system-instruction.txt"
    system_message: ChatCompletionMessageParam = {
        "role": "system",
        "content": load_prompt_template(template_name),
    }
    return [system_message, *history]


def _run_tool_call(tool_call, default_dataset: str) -> str:
    """Execute one model-requested tool call and return the text to send back as its result."""
    name = tool_call.function.name
    if name != "rag_qa":
        return f"Unknown tool '{name}'. The only available tool is 'rag_qa'."

    # Arguments are model-generated JSON and may be malformed; report that to
    # the model instead of failing the whole request.
    try:
        args = json.loads(tool_call.function.arguments or "{}")
    except json.JSONDecodeError as exc:
        return f"Invalid JSON arguments for tool 'rag_qa': {exc}"
    if not isinstance(args, dict):
        return "Invalid arguments for tool 'rag_qa': expected a JSON object."

    question = args.get("question")
    if not question:
        return "Invalid arguments for tool 'rag_qa': 'question' is required."

    # Fall back to the dataset of the request when the model names none.
    dataset = args.get("dataset") or default_dataset
    return rag_qa(question, dataset)
| # /answer endpoint removed - use /chat for all interactions | |
async def list_datasets():
    """Return the names of the datasets whose pipelines are currently loaded."""
    dataset_names = list(pipelines.keys())
    return {"datasets": dataset_names}
async def list_questions(dataset: str = "developer-portfolio"):
    """Return every question stored in the documents of a loaded dataset.

    Args:
        dataset: Name of the dataset to inspect (default: developer-portfolio).

    Returns:
        A dict with the dataset name and the list of its questions.

    Raises:
        HTTPException: status 400 when the dataset has no loaded pipeline.
    """
    # Guard clause: unknown or not-yet-loaded dataset.
    if dataset not in pipelines:
        raise HTTPException(
            status_code=400,
            detail=f"Dataset '{dataset}' not available. Available datasets: {list(pipelines.keys())}",
        )

    # Collect the question of every document that carries one in its metadata.
    questions = []
    for doc in pipelines[dataset].documents:
        if 'question' in doc.meta:
            questions.append(doc.meta['question'])

    return {"dataset": dataset, "questions": questions}
async def load_datasets_background():
    """Load datasets in background after server starts

    Builds the developer-portfolio pipeline and stores it in the module-level
    ``pipelines`` dict. Failures are logged rather than raised, so the server
    keeps running without the dataset.
    """
    # Local import so this block does not depend on a file-level asyncio import.
    import asyncio

    # Only load developer-portfolio to save memory
    dataset_name = "developer-portfolio"

    def _build_pipeline():
        # Import RAGPipeline only when needed. The import lives inside the
        # worker (and therefore inside the try below) so that an import
        # failure is logged instead of killing the task.
        from .pipeline import RAGPipeline

        return RAGPipeline.from_preset(preset_name=dataset_name)

    try:
        logger.info(f"Loading dataset: {dataset_name}")
        # from_preset is a plain synchronous call; running it directly in this
        # coroutine would block the event loop - and every request - until it
        # returned, so it is handed to a worker thread.
        pipelines[dataset_name] = await asyncio.to_thread(_build_pipeline)
        logger.info(f"Successfully loaded {dataset_name}")
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Failed to load {dataset_name}: {e}")

    logger.info(f"Background loading complete - {len(pipelines)} datasets loaded")
async def startup_event():
    """Log startup and kick off dataset loading without waiting for it.

    NOTE(review): no startup-hook registration (decorator or lifespan) is
    visible for this function in the file as provided - confirm it is wired
    to the app, otherwise the datasets are never loaded.
    """
    import asyncio

    logger.info("FastAPI application startup complete")
    logger.info(f"Server should be running on port: {os.getenv('PORT', '8000')}")

    # Start loading datasets in background (non-blocking).
    # The event loop holds only a weak reference to tasks, so a task whose
    # handle is discarded can be garbage-collected before it finishes. Keeping
    # the handle on app.state keeps the loader alive until it completes.
    app.state.dataset_loading_task = asyncio.create_task(load_datasets_background())
async def shutdown_event():
    """Log that the application is shutting down."""
    shutdown_notice = "FastAPI application shutting down"
    logger.info(shutdown_notice)
async def root():
    """Root endpoint: report service identity and the datasets loaded so far."""
    loaded_datasets = list(pipelines.keys())
    payload = {
        "status": "ok",
        "message": "RAG Pipeline API",
        "version": "1.0.0",
        "datasets": loaded_datasets,
    }
    return payload
async def health_check():
    """Health check endpoint: report liveness and dataset loading progress."""
    logger.info("Health check called")

    # Loading counts as complete once the developer-portfolio pipeline exists.
    if "developer-portfolio" in pipelines:
        loading_status = "complete"
    else:
        loading_status = "loading"

    report = {
        "status": "healthy",
        "datasets_loaded": len(pipelines),
        "total_datasets": 1,  # Only loading developer-portfolio
        "loading_status": loading_status,
        "port": os.getenv('PORT', '8000'),
    }
    return report