| import os |
| import sys |
| import logging |
| import json |
| from contextlib import asynccontextmanager |
| from fastapi import FastAPI |
| from pydantic import BaseModel |
| import uvicorn |
| from fastapi.responses import StreamingResponse |
|
|
| from langchain_core.messages import ToolMessage, AIMessage |
| from langchain_openai import ChatOpenAI |
| from langgraph.prebuilt import create_react_agent |
|
|
| from tools import MCPClient, SchemaSearchTool, JoinPathFinderTool, QueryExecutorTool |
|
|
| |
# Process-wide logging: timestamped INFO-level messages.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# Connection settings, all overridable via environment variables.
MCP_URL = os.getenv("MCP_URL", "http://mcp:8000/mcp")  # MCP server endpoint
API_KEY = os.getenv("MCP_API_KEY", "dev-key-123")  # key presented to the MCP server
LLM_API_KEY = os.getenv("LLM_API_KEY")  # OpenAI key; no default — required at startup
|
|
| |
# System prompt injected into the ReAct agent; the tool names listed here
# must stay in sync with the tools wired up in GraphRAGAgent.__init__.
SYSTEM_PROMPT = """You are a helpful assistant for querying life sciences databases.

You have access to these tools:
- schema_search: Find relevant database tables and columns based on keywords
- find_join_path: Discover how to join tables together using the knowledge graph
- execute_query: Run SQL queries against the databases

Always use schema_search first to understand the available data, then construct appropriate SQL queries.
When querying, be specific about what tables and columns you're using."""
|
|
| |
class GraphRAGAgent:
    """The core agent for handling GraphRAG queries using LangGraph.

    Wraps a ReAct agent built from a ChatOpenAI model plus the MCP-backed
    tools, and exposes an async streaming interface that emits
    newline-delimited JSON events ("thought", "observation", "final_answer").
    """

    def __init__(self):
        """Build the LLM, the MCP tools, and the ReAct graph.

        Raises:
            ValueError: If the LLM_API_KEY environment variable is not set.
        """
        if not LLM_API_KEY:
            raise ValueError("LLM_API_KEY environment variable not set.")

        # temperature=0 for deterministic SQL generation; max_retries=1 keeps
        # failures fast so errors surface on the stream quickly.
        llm = ChatOpenAI(api_key=LLM_API_KEY, model="gpt-4o-mini", temperature=0, max_retries=1)

        # All three tools share a single MCP client/connection config.
        mcp_client = MCPClient(mcp_url=MCP_URL, api_key=API_KEY)
        tools = [
            SchemaSearchTool(mcp_client=mcp_client),
            JoinPathFinderTool(mcp_client=mcp_client),
            QueryExecutorTool(mcp_client=mcp_client),
        ]

        self.graph = create_react_agent(llm, tools, state_modifier=SYSTEM_PROMPT)

    async def stream_query(self, question: str):
        """Process *question* and stream intermediate steps as NDJSON.

        Yields:
            JSON strings, each terminated by "\\n\\n", with a "type" of
            "thought" (tool call announced), "observation" (tool result),
            or "final_answer" (the model's answer or an error notice).
        """
        try:
            async for event in self.graph.astream(
                {"messages": [("user", question)]},
                stream_mode="values"
            ):
                # stream_mode="values" re-sends the full state; only the
                # newest message is interesting at each step.
                messages = event.get("messages", [])
                if not messages:
                    continue

                last_message = messages[-1]

                if isinstance(last_message, AIMessage) and last_message.tool_calls:
                    # Report every tool call, not just the first — the model
                    # may request several tools in a single turn.
                    for tool_call in last_message.tool_calls:
                        yield json.dumps({
                            "type": "thought",
                            "content": f"🤖 Calling tool `{tool_call['name']}` with args: {tool_call['args']}"
                        }) + "\n\n"
                elif isinstance(last_message, ToolMessage):
                    yield json.dumps({
                        "type": "observation",
                        "content": f"🛠️ Tool `{last_message.name}` returned:\n\n```\n{last_message.content}\n```"
                    }) + "\n\n"
                elif isinstance(last_message, AIMessage) and last_message.content:
                    yield json.dumps({
                        "type": "final_answer",
                        "content": last_message.content
                    }) + "\n\n"
        except Exception as e:
            # Top-level boundary: log with traceback and surface a friendly
            # message instead of crashing the HTTP stream mid-response.
            logger.error("Error in agent workflow: %s", e, exc_info=True)
            yield json.dumps({
                "type": "final_answer",
                "content": "I encountered an error while processing your request. Please try rephrasing your question or asking something simpler."
            }) + "\n\n"
|
|
| |
| agent = None |
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the global agent at startup; log again at shutdown.

    Initialization failures (missing API key) are logged rather than raised,
    leaving `agent` as None so the endpoints can report the problem.
    """
    global agent
    logger.info("Agent server startup...")
    try:
        agent = GraphRAGAgent()
    except ValueError as e:
        logger.error(f"Agent initialization failed: {e}")
    else:
        logger.info("GraphRAGAgent initialized successfully.")
    yield
    logger.info("Agent server shutdown.")
|
|
| app = FastAPI(title="GraphRAG Agent Server", lifespan=lifespan) |
|
|
class QueryRequest(BaseModel):
    """Request body for POST /query: the natural-language question to answer."""
    question: str
|
|
@app.post("/query")
async def execute_query(request: QueryRequest) -> StreamingResponse:
    """Endpoint to receive questions and stream the agent's response.

    Returns an NDJSON stream of {"type": ..., "content": ...} events.
    """
    if not agent:
        # Keep the error path consistent with the normal stream: same event
        # shape, same "\n\n" delimiter, same media type — so clients can use
        # one NDJSON parser for both paths. (The original emitted a bare
        # {"error": ...} object with no delimiter or media type.)
        async def error_stream():
            yield json.dumps({
                "type": "final_answer",
                "content": "Agent is not initialized. Check server logs."
            }) + "\n\n"
        return StreamingResponse(error_stream(), media_type="application/x-ndjson")

    return StreamingResponse(agent.stream_query(request.question), media_type="application/x-ndjson")
|
|
@app.get("/health")
def health_check():
    """Report liveness and whether the agent finished initializing."""
    payload = {"status": "ok", "agent_initialized": agent is not None}
    return payload
|
|
| |
def main():
    """Run the FastAPI app under uvicorn, listening on all interfaces."""
    logger.info("Starting agent server...")
    bind_host, bind_port = "0.0.0.0", 8001
    uvicorn.run(app, host=bind_host, port=bind_port)
|
|
# Allow running this module directly (e.g. `python agent_server.py`).
if __name__ == "__main__":
    main()