| import os |
| import requests |
| import json |
| from typing import Dict, Any, List, Optional |
| from langchain.tools import BaseTool |
| from pydantic import Field |
| import logging |
|
|
| logger = logging.getLogger(__name__) |
|
|
class MCPClient:
    """Client for making authenticated REST API calls to the MCP server.

    Every request carries the API key in the ``x-api-key`` header together
    with a JSON content type. ``post`` never raises for transport, HTTP or
    decoding failures; it reports them as
    ``{"status": "error", "message": ...}`` dictionaries instead.
    """

    def __init__(self, mcp_url: str, api_key: str, timeout: float = 60.0):
        """Store the connection settings.

        Args:
            mcp_url: Base URL of the MCP server.
            api_key: Value sent in the ``x-api-key`` header.
            timeout: Seconds to wait for the server before the request is
                abandoned. Without a timeout ``requests`` waits forever.
        """
        self.mcp_url = mcp_url
        self.timeout = timeout
        self.headers = {
            "x-api-key": api_key,
            "Content-Type": "application/json",
        }

    def _build_url(self, endpoint: str) -> str:
        """Join the base URL and *endpoint* with exactly one slash."""
        return f"{self.mcp_url.rstrip('/')}/{endpoint.lstrip('/')}"

    def post(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Send a POST request to a given MCP endpoint.

        Args:
            endpoint: Path relative to the base URL.
            data: Payload, serialized to JSON for the request body.

        Returns:
            The decoded JSON object from the server, or a dictionary with
            ``"status": "error"`` and a ``"message"`` describing the failure.
        """
        url = self._build_url(endpoint)
        try:
            response = requests.post(
                url,
                headers=self.headers,
                data=json.dumps(data),
                timeout=self.timeout,
            )
            response.raise_for_status()
        except requests.exceptions.HTTPError as http_err:
            failed = http_err.response
            # Compare against None explicitly: a Response is falsy whenever
            # its status code signals an error, which is always the case here.
            if failed is not None:
                logger.error("HTTP error occurred: %s - %s", http_err, failed.text)
                return {
                    "status": "error",
                    "message": f"HTTP error: {failed.status_code} {failed.reason}",
                }
            logger.error("HTTP error occurred: %s", http_err)
            return {"status": "error", "message": f"HTTP error: {http_err}"}
        except requests.exceptions.RequestException as req_err:
            logger.error("Request error occurred: %s", req_err)
            return {"status": "error", "message": f"Request failed: {req_err}"}

        # Decode separately so a malformed body is reported as a JSON problem
        # rather than being caught by the broader request-error handler above.
        try:
            payload = response.json()
        except ValueError:
            logger.error("Failed to decode JSON response.")
            return {"status": "error", "message": "Invalid JSON response from server."}

        if not isinstance(payload, dict):
            # Callers rely on dict methods such as .get(); reject other shapes.
            logger.error("Unexpected JSON response type: %s", type(payload).__name__)
            return {"status": "error", "message": "Unexpected response format from server."}
        return payload
|
|
|
|
class SchemaSearchTool(BaseTool):
    """LangChain tool for searching database schemas.

    Sends the query to the ``discovery/get_relevant_schemas`` endpoint and
    renders the returned schema entries as a bulleted, newline-separated list.
    """

    name: str = "schema_search"
    description: str = """
    Search for relevant database schemas based on a natural language query.
    Use this when you need to find which tables/columns are relevant to a user's question.
    Input should be a descriptive query like 'patient information' or 'drug trials'.
    """
    mcp_client: MCPClient

    @staticmethod
    def _format_type(raw_type: Any) -> str:
        """Return a printable type name for one schema entry.

        A non-empty list or tuple contributes its first element, a non-empty
        string is used as-is, and anything else becomes ``"Unknown"``.
        """
        if isinstance(raw_type, str):
            return raw_type or "Unknown"
        if isinstance(raw_type, (list, tuple)) and raw_type:
            return str(raw_type[0])
        return "Unknown"

    def _run(self, query: str) -> str:
        """Execute schema search.

        Args:
            query: Natural-language description of the data being looked for.

        Returns:
            A listing of matching ``database.table.name (type)`` entries, a
            "no results" notice, or an error message. Errors are returned as
            text rather than raised.
        """
        try:
            response = self.mcp_client.post(
                "discovery/get_relevant_schemas",
                {"query": query},
            )

            if response.get("status") != "success":
                return f"Error searching schemas: {response.get('message', 'Unknown error')}"

            schemas = response.get("schemas", [])
            if not schemas:
                return "No relevant schemas found."

            lines = ["Found relevant schemas:"]
            for schema in schemas:
                database = schema.get("database", "Unknown")
                table = schema.get("table", "Unknown")
                column = schema.get("name", "Unknown")
                type_name = self._format_type(schema.get("type"))
                lines.append(f"- {database}.{table}.{column} ({type_name})")
            # Real newlines (not the two characters backslash + n) so each
            # entry lands on its own line.
            return "\n".join(lines) + "\n"
        except Exception as e:
            # Tool boundary: report the failure as text, like the sibling
            # tools do, but keep the traceback in the log.
            logger.exception("Schema search failed")
            return f"Failed to search schemas: {str(e)}"

    async def _arun(self, query: str) -> str:
        """Async version - runs the blocking sync version in a worker thread."""
        import asyncio

        # _run performs blocking HTTP I/O; calling it directly would stall
        # the event loop for the duration of the request.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, query)
|
|
|
|
class JoinPathFinderTool(BaseTool):
    """LangChain tool for finding join paths between tables.

    Splits the comma-separated input into two table names and asks the
    ``graph/find_join_path`` endpoint how they can be joined.
    """

    name: str = "find_join_path"
    description: str = """
    Find how to join two tables together using foreign key relationships.
    Use this when you need to query across multiple tables.
    Input should be two table names separated by a comma, like 'patients,studies'.
    """
    mcp_client: MCPClient

    def _run(self, table_names: str) -> str:
        """Find join path.

        Args:
            table_names: Two table names separated by a comma.

        Returns:
            The join path reported by the server, a usage hint when the input
            does not contain exactly two non-empty names, or an error message.
            Errors are returned as text rather than raised.
        """
        try:
            tables = [t.strip() for t in table_names.split(',')]
            # Reject blank names too: "patients," splits into two items but
            # only one of them is a usable table name.
            if len(tables) != 2 or not all(tables):
                return "Please provide exactly two table names separated by a comma."

            first_table, second_table = tables
            response = self.mcp_client.post(
                "graph/find_join_path",
                {"table1": first_table, "table2": second_table},
            )

            if response.get("status") != "success":
                return f"Error finding join path: {response.get('message', 'Unknown error')}"

            path = response.get("path", "No path found")
            return f"Join path: {path}"
        except Exception as e:
            # Tool boundary: report the failure as text, but keep the
            # traceback in the log instead of discarding it.
            logger.exception("Join path lookup failed")
            return f"Failed to find join path: {str(e)}"

    async def _arun(self, table_names: str) -> str:
        """Async version - runs the blocking sync version in a worker thread."""
        import asyncio

        # _run performs blocking HTTP I/O; calling it directly would stall
        # the event loop for the duration of the request.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, table_names)
|
|
|
|
class QueryExecutorTool(BaseTool):
    """LangChain tool for executing SQL queries.

    Posts the SQL text to the ``intelligence/execute_query`` endpoint and
    renders the first rows of the result as a pipe-separated text table.
    """

    name: str = "execute_query"
    description: str = """
    Execute a SQL query against the databases and return results.
    Use this after you have a valid SQL query.
    Input should be a valid SQL query string.
    """
    mcp_client: MCPClient
    # Maximum number of result rows included in the rendered preview.
    max_display_rows: int = 10

    def _run(self, sql: str) -> str:
        """Execute query.

        Args:
            sql: SQL statement to run.

        Returns:
            A text table with a row count, a header taken from the first
            row's keys, up to ``max_display_rows`` rows and a note about any
            remaining rows; a "no results" notice; or an error message.
            Errors are returned as text rather than raised.
        """
        try:
            # NOTE(review): the SQL string is forwarded to the server verbatim;
            # presumably the server enforces permissions/read-only access -
            # confirm, since the input originates from a language model.
            response = self.mcp_client.post(
                "intelligence/execute_query",
                {"sql": sql},
            )

            if response.get("status") != "success":
                return f"Error executing query: {response.get('message', 'Unknown error')}"

            results = response.get("results", [])
            if not results:
                return "Query executed successfully but returned no results."

            # Column order follows the keys of the first row.
            headers = list(results[0].keys())
            header_line = " | ".join(headers)

            # Real newline after the count line (not the two characters
            # backslash + n), matching the rest of the table.
            lines = [
                f"Query returned {len(results)} rows:",
                header_line,
                "-" * len(header_line),
            ]
            lines.extend(
                " | ".join(str(row.get(h, "")) for h in headers)
                for row in results[:self.max_display_rows]
            )

            hidden = len(results) - self.max_display_rows
            if hidden > 0:
                lines.append(f"... and {hidden} more rows")

            return "\n".join(lines) + "\n"
        except Exception as e:
            # Tool boundary: report the failure as text, but keep the
            # traceback in the log instead of discarding it.
            logger.exception("Query execution failed")
            return f"Failed to execute query: {str(e)}"

    async def _arun(self, sql: str) -> str:
        """Async version - runs the blocking sync version in a worker thread."""
        import asyncio

        # _run performs blocking HTTP I/O; calling it directly would stall
        # the event loop for the duration of the request.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, sql)
|
|