| |
import logging
import re
import sqlite3
from contextlib import closing

import uvicorn
from fastapi import FastAPI, HTTPException
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
|
|
| |
# Process-wide logging: timestamp, logger name, level, then the message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("Private_MCP_Server")

# SQLite file backing the portfolio. This is a relative path, so it resolves
# against the process's current working directory.
DB_FILE = "portfolio.db"

# Local "llama3" model served through Ollama; temperature 0 keeps sampling
# randomness to a minimum for SQL generation.
llm = ChatOllama(model="llama3", temperature=0)

# Prompt asking the model to turn a natural-language question into a single
# read-only SQLite query over the `holdings` table.
text_to_sql_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a Text-to-SQL assistant. Convert the question to a read-only SQLite query for the 'holdings' table.\n"
            "Schema: symbol (TEXT), shares (INTEGER), average_cost (REAL).\n"
            "RULES:\n"
            "1. SELECT only. No INSERT/UPDATE/DELETE.\n"
            "2. Output ONLY the SQL query. No markdown.\n",
        ),
        ("human", "Question: {question}"),
    ]
)

# prompt -> model -> plain string holding the generated SQL text.
sql_generation_chain = text_to_sql_prompt | llm | StrOutputParser()

app = FastAPI(title="Aegis Private MCP Server")
|
|
# Demo rows (symbol, shares, average_cost) seeded into an empty `holdings` table.
_DUMMY_HOLDINGS = (
    ('AAPL', 5000, 180.20), ('MSFT', 3000, 350.50), ('GOOGL', 1500, 140.10), ('NVDA', 800, 450.00), ('AMD', 2000, 110.30),
    ('INTC', 4000, 35.40), ('CRM', 1200, 220.10), ('ADBE', 600, 550.20), ('ORCL', 2500, 115.50), ('CSCO', 3500, 52.10),

    ('JPM', 2000, 150.40), ('BAC', 5000, 32.10), ('GS', 500, 340.50), ('V', 1000, 240.20), ('MA', 800, 380.10),

    ('WMT', 1500, 160.30), ('TGT', 1000, 130.50), ('COST', 400, 550.10), ('KO', 3000, 58.20), ('PEP', 2500, 170.40),
    ('PG', 2000, 150.10), ('NKE', 1200, 105.30), ('SBUX', 1800, 95.40),

    ('JNJ', 2500, 160.20), ('PFE', 4000, 35.10), ('UNH', 600, 480.50), ('LLY', 400, 580.10), ('MRK', 2000, 110.20),

    ('XOM', 3000, 105.40), ('CVX', 2000, 150.20), ('GE', 1500, 110.50), ('CAT', 800, 280.10), ('BA', 500, 210.30),

    ('TSLA', 1000, 220.90), ('F', 5000, 12.10), ('GM', 4000, 35.40),
)


@app.on_event("startup")
async def startup_db():
    """Create the ``holdings`` table if needed and seed it when it is empty.

    Runs on application startup. Any failure is logged (with traceback) and
    swallowed, so the server still starts; later queries may then fail until
    the database problem is fixed.
    """
    try:
        # sqlite3's own `with conn:` only commits/rolls back the transaction and
        # never closes the connection; `closing` guarantees it is released.
        # Closing without a commit discards a partially applied insert.
        with closing(sqlite3.connect(DB_FILE)) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS holdings (
                    symbol TEXT PRIMARY KEY,
                    shares INTEGER,
                    average_cost REAL
                )
            """)

            cursor.execute("SELECT count(*) FROM holdings")
            if cursor.fetchone()[0] == 0:
                logger.info("Populating database with diverse dummy data...")
                cursor.executemany(
                    "INSERT INTO holdings (symbol, shares, average_cost) VALUES (?, ?, ?)",
                    _DUMMY_HOLDINGS,
                )
                conn.commit()
                logger.info("Database populated successfully.")
            else:
                logger.info("Database already contains data.")
    except Exception as e:
        # Best-effort by design: don't block startup, but keep the traceback.
        logger.exception("Failed to initialize database: %s", e)
|
|
|
|
def execute_safe_query(query: str, params=None):
    """Run one read-only SQL query against ``DB_FILE`` and return its rows.

    This is a security boundary, because ``query`` may be produced by an LLM.
    Two independent guards apply:

    1. The statement must start with ``SELECT`` (case-insensitive, ignoring
       surrounding whitespace); anything else is rejected with HTTP 403.
    2. The connection runs with ``PRAGMA query_only = ON``, so even a
       statement that slips past the prefix check cannot modify the database.

    Args:
        query: SQL text to execute (a single statement).
        params: Optional bound parameters for ``query``; ignored when falsy.

    Returns:
        A list of dicts mapping column name to value, one per row. SQL ``NULL``
        values are reported as ``0``.

    Raises:
        HTTPException: 403 for a non-SELECT query; 500 if SQLite rejects or
            fails to run the query.
    """
    if not query.strip().upper().startswith("SELECT"):
        logger.error(f"SECURITY VIOLATION: Attempted to execute non-SELECT query: {query}")
        raise HTTPException(status_code=403, detail="Forbidden: Only SELECT queries are allowed.")

    try:
        # `with sqlite3.connect(...)` would leave the connection open; `closing`
        # actually releases it.
        with closing(sqlite3.connect(DB_FILE)) as conn:
            conn.row_factory = sqlite3.Row
            # Defence in depth: any write attempted on this connection now
            # fails with SQLITE_READONLY instead of touching the data.
            conn.execute("PRAGMA query_only = ON")
            cursor = conn.execute(query, params) if params else conn.execute(query)
            rows = cursor.fetchall()
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning is not an sqlite3.Error subclass; Python < 3.12
        # raises it when the text contains more than one statement.
        logger.error(f"Database error executing query '{query}': {e}")
        raise HTTPException(status_code=500, detail=f"Database query failed: {e}") from e

    # Report NULLs (e.g. an aggregate over zero rows) as 0 for the caller.
    return [
        {key: 0 if value is None else value for key, value in dict(row).items()}
        for row in rows
    ]
|
|
# Words the offline fallback must never treat as a ticker symbol.
_FALLBACK_IGNORED_WORDS = frozenset({
    "WHAT", "IS", "THE", "TO", "OF", "FOR", "IN", "AND", "OR", "SHOW", "ME",
    "DATA", "STOCK", "PRICE", "DO", "WE", "OWN", "HAVE", "ANY", "EXPOSURE",
    "CURRENT",
})

# Standalone 1-5 letter words of the upper-cased question are ticker candidates.
_TICKER_CANDIDATE_RE = re.compile(r"\b[A-Z]{1,5}\b")

# A leading ``` / ```sql fence or a trailing ``` fence around LLM output.
_CODE_FENCE_RE = re.compile(r"^```[A-Za-z]*\s*|\s*```$")


def _clean_generated_sql(raw: str) -> str:
    """Strip whitespace and a Markdown code fence the LLM may wrap its SQL in."""
    return _CODE_FENCE_RE.sub("", raw.strip()).strip()


def _known_symbols() -> frozenset:
    """Return the symbols currently in ``holdings``; empty if the lookup fails."""
    try:
        rows = execute_safe_query("SELECT symbol FROM holdings")
    except HTTPException:
        return frozenset()
    return frozenset(str(row["symbol"]) for row in rows)


def _fallback_sql(question: str, known_symbols=frozenset()) -> str:
    """Build a heuristic SELECT for ``question`` when the LLM is unavailable.

    Candidates are the 1-5 letter words of the upper-cased question that are
    not in ``_FALLBACK_IGNORED_WORDS``. The first candidate found in
    ``known_symbols`` wins. Otherwise the first candidate is used, so a
    question about an unheld ticker yields an empty result. With no
    candidates, all holdings are selected.
    """
    candidates = [
        word
        for word in _TICKER_CANDIDATE_RE.findall(question.upper())
        if word not in _FALLBACK_IGNORED_WORDS
    ]
    if not candidates:
        return "SELECT * FROM holdings"
    symbol = next((word for word in candidates if word in known_symbols), candidates[0])
    # Interpolation is safe here: the regex restricts `symbol` to A-Z letters.
    return f"SELECT * FROM holdings WHERE symbol='{symbol}'"


@app.post("/portfolio_data")
async def get_portfolio_data(payload: dict):
    """
    Takes a natural language question, converts it to SQL using Llama 3,
    and executes it against the internal portfolio database.

    If the LLM call fails for any reason, a keyword heuristic builds the query
    instead. The returned dict has ``status``, ``question``, ``generated_sql``
    and ``data`` on success. Unexpected errors come back as
    ``{"status": "error", ...}``, while HTTP errors (400/403/500) propagate.
    """
    question = payload.get("question")
    if not question:
        raise HTTPException(status_code=400, detail="'question' is a required field.")
    if not isinstance(question, str):
        raise HTTPException(status_code=400, detail="'question' must be a string.")

    logger.info(f"Received portfolio data question: '{question}'")

    try:
        try:
            raw_sql = await sql_generation_chain.ainvoke({"question": question})
        except Exception as llm_error:
            # Broad on purpose: any LLM failure (Ollama offline, timeout, ...)
            # degrades to the heuristic instead of failing the request.
            logger.warning(f"LLM generation failed (likely Ollama offline): {llm_error}. Using fallback logic.")
            generated_sql = _fallback_sql(question, _known_symbols())
            logger.info(f"Fallback SQL generated: {generated_sql}")
        else:
            # Models often wrap SQL in ```sql fences despite the prompt; without
            # this the SELECT-prefix check would reject a valid query.
            generated_sql = _clean_generated_sql(raw_sql)
            logger.info(f"Llama 3 generated SQL: {generated_sql}")

        results = execute_safe_query(generated_sql)
        logger.info(f"Successfully executed query and found {len(results)} records.")

        return {"status": "success", "question": question, "generated_sql": generated_sql, "data": results}

    except HTTPException:
        # 403/500 from execute_safe_query must reach the client unchanged.
        raise
    except Exception as e:
        logger.critical(f"An unexpected error occurred in the portfolio data endpoint: {e}", exc_info=True)
        return {"status": "error", "message": str(e), "data": []}
|
|
@app.get("/")
def read_root():
    """Root endpoint: always answers with a fixed 'operational' status message."""
    status_message = "Aegis Private MCP Server is operational."
    return {"message": status_message}
|
|
| |
if __name__ == "__main__":
    # Bind to the loopback interface only: the server is not reachable from
    # other hosts.
    uvicorn.run(
        app,
        port=8003,
        host="127.0.0.1",
    )