| import os |
|
|
| import streamlit as st |
|
|
| import pandas as pd |
|
|
| import pickle |
|
|
| import base64 |
|
|
| from io import BytesIO, StringIO |
|
|
| import sys |
|
|
| import operator |
|
|
| from typing import Literal, Sequence, TypedDict, Annotated, List, Dict, Tuple |
|
|
| import tempfile |
|
|
| import shutil |
|
|
| import plotly.io as pio |
|
|
| import io |
|
|
| import re |
|
|
| import json |
|
|
| import openai |
|
|
| |
|
|
|
|
| from datetime import datetime |
|
|
| from reportlab.lib.pagesizes import letter |
|
|
| from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image |
|
|
| from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle |
|
|
| from reportlab.lib.units import inch |
|
|
| from PIL import Image as PILImage |
|
|
| |
|
|
| |
|
|
| from langchain_core.messages import AIMessage, ToolMessage, HumanMessage, BaseMessage |
|
|
| from langchain_core.prompts import ChatPromptTemplate |
|
|
| from langchain_openai import ChatOpenAI |
|
|
| from langchain_experimental.utilities import PythonREPL |
|
|
| from langgraph.prebuilt import ToolInvocation, ToolExecutor |
|
|
| from langchain_core.tools import tool |
|
|
| from langgraph.prebuilt import InjectedState |
|
|
| from langgraph.graph import StateGraph, END |
|
|
| from reportlab.platypus import PageBreak |
|
|
|
|
| |
|
|
| |
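| # Session-state defaults for the AI provider, API key, and model selection. |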
|
|
| if 'ai_provider' not in st.session_state: |
|
|
| st.session_state.ai_provider = "openai" |
|
|
| |
|
|
| if 'api_key' not in st.session_state: |
|
|
| st.session_state.api_key = "" |
|
|
| |
|
|
| if 'selected_model' not in st.session_state: |
|
|
| st.session_state.selected_model = "gpt-4" |
|
|
| |
|
|
| |
|
|
| OPENAI_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-4o-mini", "gpt-3.5-turbo"] |
|
|
| GROQ_MODELS = ["llama-3.3-70b-versatile", "gemma2-9b-it", "llama3-8b-8192"] |
|
|
| |
|
|
| |
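| # Per-session temporary working directory; generated Plotly figures are pickled under images_dir. |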
|
|
| if 'temp_dir' not in st.session_state: |
|
|
| st.session_state.temp_dir = tempfile.mkdtemp() |
|
|
| st.session_state.images_dir = os.path.join(st.session_state.temp_dir, "images/plotly_figures/pickle") |
|
|
| os.makedirs(st.session_state.images_dir, exist_ok=True) |
|
|
| print(f"Created temporary directory: {st.session_state.temp_dir}") |
|
|
| print(f"Created images directory: {st.session_state.images_dir}") |
|
|
| |
|
|
| |
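| # System prompt that constrains the agent's code generation, ML choices, and plotting conventions. |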
|
|
| SYSTEM_PROMPT = """## Role |
| |
| You are a professional data scientist helping a non-technical user understand, analyze, and visualize their data. |
| |
| |
| |
| ## Capabilities |
| |
| 1. **Execute python code** using the `complete_python_task` tool. |
| |
| |
| |
| ## Goals |
| |
| 1. Understand the user's objectives clearly. |
| |
| 2. Take the user on a data analysis journey, iterating to find the best way to visualize or analyze their data to solve their problems. |
| |
| 3. Investigate if the goal is achievable by running Python code via the `python_code` field. |
| |
| 4. Gain input from the user at every step to ensure the analysis is on the right track and to understand business nuances. |
| |
| |
| |
| ## Code Guidelines |
| |
| - **ALL INPUT DATA IS LOADED ALREADY**, so use the provided variable names to access the data. |
| |
| - **VARIABLES PERSIST BETWEEN RUNS**, so reuse previously defined variables if needed. |
| |
| - **TO SEE CODE OUTPUT**, use `print()` statements. You won't be able to see outputs of `df.head()`, `df.describe()` etc. otherwise. |
| |
| - **ONLY USE THE FOLLOWING LIBRARIES**: |
| |
| - `pandas` |
| |
| - `sklearn` (including all major ML models) |
| |
| - `plotly` |
| |
| - `numpy` |
| |
| |
| |
| All these libraries are already imported for you. |
| |
| |
| |
| ## Machine Learning Guidelines |
| |
| - For regression tasks: |
| |
| - Linear Regression: `LinearRegression` |
| |
| |
| - Ridge Regression: `Ridge` |
| |
| - Lasso Regression: `Lasso` |
| |
| - Random Forest Regression: `RandomForestRegressor` |
| |
| |
| |
| - For classification tasks: |
| |
| - Logistic Regression: `LogisticRegression` |
| |
| - Decision Trees: `DecisionTreeClassifier` |
| |
| - Random Forests: `RandomForestClassifier` |
| |
| - Support Vector Machines: `SVC` |
| |
| - K-Nearest Neighbors: `KNeighborsClassifier` |
| |
| - Naive Bayes: `GaussianNB` |
| |
| |
| |
| - For clustering: |
| |
| - K-Means: `KMeans` |
| |
| - DBSCAN: `DBSCAN` |
| |
| |
| |
| - For dimensionality reduction: |
| |
| - PCA: `PCA` |
| |
| |
| |
| - Always preprocess data appropriately: |
| |
| - Scale numerical features with `StandardScaler` or `MinMaxScaler` |
| |
| - Encode categorical variables with `OneHotEncoder` when needed |
| |
| - Handle missing values with `SimpleImputer` |
| |
| |
| |
| - Always split data into training and testing sets using `train_test_split` |
| |
| - Evaluate models using appropriate metrics: |
| |
| - For regression: `mean_squared_error`, `mean_absolute_error`, `r2_score` |
| |
| - For classification: `accuracy_score`, `confusion_matrix`, `classification_report` |
| |
| - For clustering: `silhouette_score` |
| |
| |
| |
| - Consider using `cross_val_score` for more robust evaluation |
| |
| - Visualize ML results with plotly when possible |
| |
| |
| |
| ## Plotting Guidelines |
| |
| - Always use the `plotly` library for plotting. |
| |
| - Store all plotly figures inside a `plotly_figures` list, they will be saved automatically. |
| |
| - Do not try to show the plots inline with `fig.show()`. |
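|
| For example, create figures and append them to the list (illustrative only; `df` stands for one of the loaded DataFrames and `"some_column"` is a placeholder column name): |
|
| ```python |
| import plotly.express as px |
| fig = px.histogram(df, x="some_column") |
| plotly_figures.append(fig)  # saved automatically, no fig.show() needed |
| ``` |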
| |
| """ |
|
|
| |
|
|
| |
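| # Shared LangGraph state; fields annotated with operator.add are appended to (not replaced) on each update. |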
|
|
| class AgentState(TypedDict): |
|
|
| messages: Annotated[Sequence[BaseMessage], operator.add] |
|
|
| input_data: Annotated[List[Dict], operator.add] |
|
|
| intermediate_outputs: Annotated[List[dict], operator.add] |
|
|
| current_variables: dict |
|
|
| output_image_paths: Annotated[List[str], operator.add] |
|
|
| |
|
|
| |
|
|
| if 'in_memory_datasets' not in st.session_state: |
|
|
| st.session_state.in_memory_datasets = {} |
|
|
| |
|
|
| if 'persistent_vars' not in st.session_state: |
|
|
| st.session_state.persistent_vars = {} |
|
|
| |
|
|
| if 'dataset_metadata_list' not in st.session_state: |
|
|
| st.session_state.dataset_metadata_list = [] |
|
|
| |
|
|
| if 'chat_history' not in st.session_state: |
|
|
| st.session_state.chat_history = [] |
|
|
| |
|
|
| if 'dashboard_plots' not in st.session_state: |
|
|
| st.session_state.dashboard_plots = [None, None, None, None] |
|
|
| |
|
|
| if 'columns' not in st.session_state: |
|
|
| st.session_state.columns = ["No columns available"] |
|
|
| |
|
|
| if 'custom_plots_to_save' not in st.session_state: |
|
|
| st.session_state.custom_plots_to_save = {} |
|
|
| |
|
|
| |
|
|
| repl = PythonREPL() |
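| # Snippet executed after the user's code: pickles every figure collected in `plotly_figures` into images_dir. |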
|
|
| plotly_saving_code = """import pickle |
| |
| |
| |
| import uuid |
| |
| import os |
| |
| for figure in plotly_figures: |
| |
| pickle_filename = f"{images_dir}/{uuid.uuid4()}.pickle" |
| |
| with open(pickle_filename, 'wb') as f: |
| |
| pickle.dump(figure, f) |
| |
| """ |
|
|
| |
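| # Tool exposed to the agent: executes LLM-generated Python with the uploaded datasets and common |
| # sklearn helpers in scope, captures stdout, persists variables between calls, and reports newly saved figures. |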
|
|
| @tool |
|
|
| def complete_python_task( |
| |
| graph_state: Annotated[dict, InjectedState], |
| |
| thought: str, |
| |
| python_code: str |
| |
| ) -> Tuple[str, dict]: |
|
|
| """Execute Python code for data analysis and visualization.""" |
|
|
| |
|
|
| current_variables = graph_state.get("current_variables", {}) |
|
|
| |
|
|
| |
|
|
| for input_dataset in graph_state.get("input_data", []): |
|
|
| var_name = input_dataset.get("variable_name") |
|
|
| if var_name and var_name not in current_variables and var_name in st.session_state.in_memory_datasets: |
|
|
| print(f"Loading {var_name} from in-memory storage") |
|
|
| current_variables[var_name] = st.session_state.in_memory_datasets[var_name] |
|
|
| current_image_pickle_files = os.listdir(st.session_state.images_dir) |
|
|
| |
|
|
| try: |
|
|
| |
|
|
| old_stdout = sys.stdout |
|
|
| sys.stdout = StringIO() |
|
|
| |
|
|
| |
|
|
| exec_globals = globals().copy() |
|
|
| exec_globals.update(st.session_state.persistent_vars) |
|
|
| exec_globals.update(current_variables) |
|
|
| |
|
|
| |
|
|
| import sklearn |
|
|
| import numpy as np |
|
|
| |
|
|
| |
|
|
| from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge, Lasso |
|
|
| from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier |
|
|
| from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor |
|
|
| from sklearn.svm import SVC, SVR |
|
|
| from sklearn.naive_bayes import GaussianNB |
|
|
| from sklearn.decomposition import PCA |
|
|
| from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor |
|
|
| from sklearn.cluster import KMeans, DBSCAN |
|
|
| from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder |
|
|
| from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV |
|
|
| from sklearn.metrics import ( |
|
|
| accuracy_score, confusion_matrix, classification_report, |
|
|
| mean_squared_error, r2_score, mean_absolute_error, silhouette_score |
|
|
| ) |
|
|
| from sklearn.pipeline import Pipeline |
|
|
| from sklearn.impute import SimpleImputer |
|
|
| |
|
|
| |
|
|
| exec_globals.update({ |
|
|
| "plotly_figures": [], |
|
|
| "images_dir": st.session_state.images_dir, |
|
|
| "np": np, |
|
|
| |
|
|
| "LinearRegression": LinearRegression, |
|
|
| "LogisticRegression": LogisticRegression, |
|
|
| "Ridge": Ridge, |
|
|
| "Lasso": Lasso, |
|
|
| |
|
|
| "DecisionTreeClassifier": DecisionTreeClassifier, |
|
|
| "DecisionTreeRegressor": DecisionTreeRegressor, |
|
|
| "RandomForestClassifier": RandomForestClassifier, |
|
|
| "RandomForestRegressor": RandomForestRegressor, |
|
|
| "GradientBoostingClassifier": GradientBoostingClassifier, |
|
|
| |
|
|
| "SVC": SVC, |
|
|
| "SVR": SVR, |
|
|
| |
|
|
| "GaussianNB": GaussianNB, |
|
|
| "PCA": PCA, |
|
|
| "KNeighborsClassifier": KNeighborsClassifier, |
|
|
| "KNeighborsRegressor": KNeighborsRegressor, |
|
|
| "KMeans": KMeans, |
|
|
| "DBSCAN": DBSCAN, |
|
|
| |
|
|
| "StandardScaler": StandardScaler, |
|
|
| "MinMaxScaler": MinMaxScaler, |
|
|
| "OneHotEncoder": OneHotEncoder, |
|
|
| "SimpleImputer": SimpleImputer, |
|
|
| |
|
|
| "train_test_split": train_test_split, |
|
|
| "cross_val_score": cross_val_score, |
|
|
| "GridSearchCV": GridSearchCV, |
|
|
| "accuracy_score": accuracy_score, |
|
|
| "confusion_matrix": confusion_matrix, |
|
|
| "classification_report": classification_report, |
|
|
| "mean_squared_error": mean_squared_error, |
|
|
| "r2_score": r2_score, |
|
|
| "mean_absolute_error": mean_absolute_error, |
|
|
| "silhouette_score": silhouette_score, |
|
|
| |
|
|
| "Pipeline": Pipeline |
|
|
| }) |
|
|
| |
|
|
| exec(python_code, exec_globals) |
|
|
| |
|
|
| st.session_state.persistent_vars.update({k: v for k, v in exec_globals.items() if k not in globals()}) |
|
|
| |
|
|
| |
|
|
| output = sys.stdout.getvalue() |
|
|
| |
|
|
| |
|
|
| sys.stdout = old_stdout |
|
|
| |
|
|
| updated_state = { |
|
|
| "intermediate_outputs": [{"thought": thought, "code": python_code, "output": output}], |
|
|
| "current_variables": st.session_state.persistent_vars |
|
|
| } |
|
|
| |
|
|
| if 'plotly_figures' in exec_globals and exec_globals['plotly_figures']: |
|
|
| exec(plotly_saving_code, exec_globals) |
|
|
| |
|
|
| |
|
|
| new_image_folder_contents = os.listdir(st.session_state.images_dir) |
|
|
| new_image_files = [file for file in new_image_folder_contents if file not in current_image_pickle_files] |
|
|
| |
|
|
| if new_image_files: |
|
|
| updated_state["output_image_paths"] = new_image_files |
|
|
| st.session_state.persistent_vars["plotly_figures"] = [] |
|
|
| return output, updated_state |
|
|
| |
|
|
| except Exception as e: |
|
|
| sys.stdout = old_stdout |
|
|
| print(f"Error in complete_python_task: {str(e)}") |
|
|
| return str(e), {"intermediate_outputs": [{"thought": thought, "code": python_code, "output": str(e)}]} |
|
|
| |
|
|
| |
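| # Build the chat model for the configured provider; returns None when no API key is set or initialization fails. |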
|
|
| def initialize_llm(): |
|
|
| api_key = st.session_state.api_key |
|
|
| model = st.session_state.selected_model |
|
|
| |
|
|
| if not api_key: |
|
|
| return None |
|
|
| |
|
|
| try: |
|
|
| if st.session_state.ai_provider == "openai": |
|
|
| os.environ["OPENAI_API_KEY"] = api_key |
|
|
| return ChatOpenAI(model=model, temperature=0) |
|
|
| elif st.session_state.ai_provider == "groq": |
|
|
| os.environ["GROQ_API_KEY"] = api_key |
|
|
| |
|
|
| from langchain_groq import ChatGroq |
|
|
| return ChatGroq(model=model, temperature=0) |
|
|
| except Exception as e: |
|
|
| print(f"Error initializing LLM: {str(e)}") |
|
|
| return None |
|
|
| |
|
|
| |
|
|
| tools = [complete_python_task] |
|
|
| tool_executor = ToolExecutor(tools) |
|
|
| |
|
|
| |
|
|
| chat_template = ChatPromptTemplate.from_messages([ |
|
|
| ("system", SYSTEM_PROMPT), |
|
|
| ("placeholder", "{messages}"), |
|
|
| ]) |
|
|
| |
|
|
| def create_data_summary(state: AgentState) -> str: |
|
|
| summary = "" |
|
|
| variables = [] |
|
|
| |
|
|
| |
|
|
| for d in state.get("input_data", []): |
|
|
| var_name = d.get("variable_name") |
|
|
| if var_name: |
|
|
| |
|
|
| variables.append(var_name) |
|
|
| summary += f"\n\nVariable: {var_name}\n" |
|
|
| summary += f"Description: {d.get('data_description', 'No description')}\n" |
|
|
| |
|
|
| |
|
|
| if var_name in st.session_state.in_memory_datasets: |
|
|
| df = st.session_state.in_memory_datasets[var_name] |
|
|
| summary += "\nSample Data (first 5 rows):\n" |
|
|
| summary += df.head(5).to_string() |
|
|
| |
|
|
| if "current_variables" in state: |
|
|
| remaining_variables = [v for v in state["current_variables"] if v not in variables and not v.startswith("_")] |
|
|
| |
|
|
| for v in remaining_variables: |
|
|
| |
|
|
| var_value = state["current_variables"].get(v) |
|
|
| |
|
|
| if isinstance(var_value, pd.DataFrame): |
|
|
| summary += f"\n\nVariable: {v} (DataFrame with shape {var_value.shape})" |
|
|
| else: |
|
|
| summary += f"\n\nVariable: {v}" |
|
|
| return summary |
|
|
| |
|
|
| def route_to_tools(state: AgentState) -> Literal["tools", "__end__"]: |
|
|
| """Determine if we should route to tools or end the chain""" |
|
|
| if messages := state.get("messages", []): |
|
|
| ai_message = messages[-1] |
|
|
| else: |
|
|
| raise ValueError(f"No messages found in input state to tool_edge: {state}") |
|
|
| |
|
|
| if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: |
|
|
| return "tools" |
|
|
| |
|
|
| return "__end__" |
|
|
| |
|
|
| def call_model(state: AgentState): |
|
|
| """Call the LLM to get a response""" |
|
|
| current_data_template = """The following data is available:\n{data_summary}""" |
|
|
| current_data_message = HumanMessage( |
|
|
| content=current_data_template.format(data_summary=create_data_summary(state)) |
|
|
| ) |
|
|
| messages = [current_data_message] + state["messages"] |
|
|
| |
|
|
| |
|
|
| llm = initialize_llm() |
|
|
| if llm is None: |
|
|
| return {"messages": [AIMessage(content="Please configure a valid API key and model in the settings tab.")]} |
|
|
| |
|
|
| |
|
|
| model = llm.bind_tools(tools) |
|
|
| model = chat_template | model |
|
|
| |
|
|
| llm_outputs = model.invoke({"messages": messages}) |
|
|
| return {"messages": [llm_outputs], "intermediate_outputs": [current_data_message.content]} |
|
|
| |
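| # Run every tool call from the last AI message and merge the returned state updates (lists are extended, dicts merged). |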
|
|
| def call_tools(state: AgentState): |
|
|
| """Execute tools called by the LLM""" |
|
|
| last_message = state["messages"][-1] |
|
|
| tool_invocations = [] |
|
|
| |
|
|
| if isinstance(last_message, AIMessage) and hasattr(last_message, 'tool_calls'): |
|
|
| tool_invocations = [ |
|
|
| ToolInvocation( |
|
|
| tool=tool_call["name"], |
|
|
| tool_input={**tool_call["args"], "graph_state": state} |
|
|
| ) for tool_call in last_message.tool_calls |
|
|
| ] |
|
|
| responses = tool_executor.batch(tool_invocations, return_exceptions=True) |
|
|
| |
|
|
| tool_messages = [] |
|
|
| state_updates = {} |
|
|
| |
|
|
| for tc, response in zip(last_message.tool_calls, responses): |
|
|
| if isinstance(response, Exception): |
|
|
| print(f"Exception in tool execution: {str(response)}") |
|
|
| tool_messages.append(ToolMessage( |
|
|
| content=f"Error: {str(response)}", |
|
|
| name=tc["name"], |
|
|
| tool_call_id=tc["id"] |
|
|
| )) |
|
|
| continue |
|
|
| |
|
|
| message, updates = response |
|
|
| tool_messages.append(ToolMessage( |
|
|
| content=str(message), |
|
|
| name=tc["name"], |
|
|
| tool_call_id=tc["id"] |
|
|
| )) |
|
|
| |
|
|
| |
|
|
| for key, value in updates.items(): |
|
|
| if key in state_updates: |
|
|
| if isinstance(value, list) and isinstance(state_updates[key], list): |
|
|
| state_updates[key].extend(value) |
|
|
| elif isinstance(value, dict) and isinstance(state_updates[key], dict): |
|
|
| state_updates[key].update(value) |
|
|
| else: |
|
|
| state_updates[key] = value |
|
|
| else: |
|
|
| state_updates[key] = value |
|
|
| |
|
|
| if 'messages' not in state_updates: |
|
|
| state_updates["messages"] = [] |
|
|
| |
|
|
| state_updates["messages"] = tool_messages |
|
|
| return state_updates |
|
|
| |
|
|
| |
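| # Assemble the agent loop: the model node routes to the tool node whenever it emits tool calls, |
| # and tool results feed back into the model until no further tool calls are produced. |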
|
|
| workflow = StateGraph(AgentState) |
|
|
| workflow.add_node("agent", call_model) |
|
|
| workflow.add_node("tools", call_tools) |
|
|
| workflow.add_conditional_edges( |
|
|
| "agent", |
|
|
| route_to_tools, |
|
|
| { |
|
|
| "tools": "tools", |
|
|
| "__end__": END |
|
|
| } |
|
|
| ) |
|
|
| workflow.add_edge("tools", "agent") |
|
|
| workflow.set_entry_point("agent") |
|
|
| |
|
|
| chain = workflow.compile() |
|
|
| |
|
|
| def process_file_upload(files): |
|
|
| """Process uploaded files and return dataframe previews and column names""" |
|
|
| st.session_state.in_memory_datasets = {} |
|
|
| st.session_state.dataset_metadata_list = [] |
|
|
| st.session_state.persistent_vars.clear() |
|
|
| |
|
|
| if not files: |
|
|
| return "No files uploaded.", [], ["No columns available"] |
|
|
| |
|
|
| results = [] |
|
|
| all_columns = [] |
|
|
| |
|
|
| for file in files: |
|
|
| try: |
|
|
| |
|
|
| if file.name.endswith('.csv'): |
|
|
| df = pd.read_csv(file) |
|
|
| elif file.name.endswith(('.xls', '.xlsx')): |
|
|
| df = pd.read_excel(file) |
|
|
| else: |
|
|
| results.append(f"Unsupported file format: {file.name}. Please upload CSV or Excel files.") |
|
|
| continue |
|
|
| |
|
|
| var_name = file.name.split('.')[0].replace('-', '_').replace(' ', '_').lower() |
|
|
| st.session_state.in_memory_datasets[var_name] = df |
|
|
| |
|
|
| |
|
|
| all_columns.extend(df.columns.tolist()) |
|
|
| |
|
|
| |
|
|
| dataset_metadata = { |
|
|
| "variable_name": var_name, |
|
|
| "data_path": "in_memory", |
|
|
| "data_description": f"Dataset containing {df.shape[0]} rows and {df.shape[1]} columns. Columns: {', '.join(df.columns.tolist())}", |
|
|
| "original_filename": file.name |
|
|
| } |
|
|
| |
|
|
| st.session_state.dataset_metadata_list.append(dataset_metadata) |
|
|
| |
|
|
| |
|
|
| preview = f"### Dataset: {file.name}\nVariable name: `{var_name}`\n\n" |
|
|
| preview += df.head(10).to_markdown() |
|
|
| results.append(preview) |
|
|
| print(f"Successfully processed {file.name}") |
|
|
| |
|
|
| except Exception as e: |
|
|
| print(f"Error processing {file.name}: {str(e)}") |
|
|
| results.append(f"Error processing {file.name}: {str(e)}") |
|
|
| |
|
|
| |
|
|
| unique_columns = [] |
|
|
| seen = set() |
|
|
| |
|
|
| for col in all_columns: |
|
|
| if col not in seen: |
|
|
| seen.add(col) |
|
|
| unique_columns.append(col) |
|
|
| |
|
|
| if not unique_columns: |
|
|
| unique_columns = ["No columns available"] |
|
|
| |
|
|
| print(f"Found {len(unique_columns)} unique columns across datasets") |
|
|
| return "\n\n".join(results), st.session_state.dataset_metadata_list, unique_columns |
|
|
| |
|
|
| def get_columns(): |
|
|
| """Directly gets columns from in-memory datasets""" |
|
|
| all_columns = [] |
|
|
| |
|
|
| for var_name, df in st.session_state.in_memory_datasets.items(): |
|
|
| if isinstance(df, pd.DataFrame): |
|
|
| all_columns.extend(df.columns.tolist()) |
|
|
| |
|
|
| |
|
|
| unique_columns = [] |
|
|
| seen = set() |
|
|
| |
|
|
| for col in all_columns: |
|
|
| if col not in seen: |
|
|
| seen.add(col) |
|
|
| unique_columns.append(col) |
|
|
| |
|
|
| if not unique_columns: |
|
|
| unique_columns = ["No columns available"] |
|
|
| |
|
|
| print(f"Populating dropdowns with {len(unique_columns)} columns") |
|
|
| return unique_columns |
|
|
| |
|
|
| |
|
|
|
|
| |
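| # Deterministic first-pass cleaning (column names, duplicates, empty rows/columns, whitespace) applied before LLM suggestions. |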
|
|
| def standard_clean(df): |
|
|
| df.columns = [re.sub(r'\W+', '_', col).strip('_').lower() for col in df.columns] |
|
|
| df.drop_duplicates(inplace=True) |
|
|
| df.dropna(axis=1, how='all', inplace=True) |
|
|
| df.dropna(axis=0, how='all', inplace=True) |
|
|
| for col in df.select_dtypes(include='object').columns: |
|
|
| df[col] = df[col].astype(str).str.strip() |
|
|
| return df |
|
|
| |
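| # Minimal chat-completions wrapper for the configured provider; returns "{}" if the API call fails. |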
|
|
| def query_openai(prompt): |
|
|
| try: |
|
|
| |
|
|
| api_key = st.session_state.api_key |
|
|
| model = st.session_state.selected_model |
|
|
| |
|
|
| if st.session_state.ai_provider == "openai": |
|
|
| client = openai.OpenAI(api_key=api_key) |
|
|
| response = client.chat.completions.create( |
|
|
| model=model, |
|
|
| messages=[{"role": "user", "content": prompt}], |
|
|
| temperature=0.7 |
|
|
| ) |
|
|
| return response.choices[0].message.content |
|
|
| elif st.session_state.ai_provider == "groq": |
|
|
| from groq import Groq |
|
|
| client = Groq(api_key=api_key) |
|
|
| response = client.chat.completions.create( |
|
|
| model=model, |
|
|
| messages=[{"role": "user", "content": prompt}], |
|
|
| temperature=0.7 |
|
|
| ) |
|
|
| return response.choices[0].message.content |
|
|
| except Exception as e: |
|
|
| print(f"API Error: {e}") |
|
|
| return "{}" |
|
|
| |
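| # Ask the LLM for a cleaning plan (column renames, type conversions, missing-value fills, value maps) as a Python dict literal. |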
|
|
| def llm_suggest_cleaning(df): |
|
|
| sample = df.head(10).to_csv(index=False) |
|
|
| prompt = f""" |
| |
| You are a professional data wrangler. Below is a sample of a messy dataset. |
| |
| |
| |
| Return a Python dictionary with the following keys: |
| |
| |
| |
| 1. rename_columns – fix unclear or inconsistent column names |
| |
| 2. convert_types – correct datatypes: int, float, str, or date |
| |
| 3. fill_missing – use 'mean', 'median', 'mode', or a constant like 'Unknown' or 0 |
| |
| 4. value_map – map inconsistent values (e.g., yes/Yes/Y → Yes) |
| |
| |
| |
| Do not drop any rows or columns. Your output must be a valid Python dict. |
| |
| |
| |
| Example: |
| |
| {{ |
| |
| "rename_columns": {{"dob": "date_of_birth"}}, |
| |
| "convert_types": {{"age": "int", "salary": "float", "signup_date": "date"}}, |
| |
| "fill_missing": {{"gender": "mode", "salary": -1}}, |
| |
| "value_map": {{ |
| |
| "gender": {{"M": "Male", "F": "Female"}}, |
| |
| "subscribed": {{"Y": "Yes", "N": "No"}} |
| |
| }} |
| |
| }} |
| |
| Beyond these steps, study the data and apply any other cleaning actions that are appropriate for this particular dataset. |
| |
| Sample data: |
| |
| {sample} |
| |
| """ |
|
|
| raw_response = query_openai(prompt) |
|
|
| try: |
|
|
| # Parse the returned dict literal safely rather than eval-ing arbitrary code from the LLM. |
| import ast |
| suggestions = ast.literal_eval(raw_response) |
|
|
| return suggestions |
|
|
| except: |
|
|
| print("Could not parse suggestions.") |
|
|
| return { |
|
|
| "rename_columns": {}, |
|
|
| "convert_types": {}, |
|
|
| "fill_missing": {}, |
|
|
| "value_map": {} |
|
|
| } |
|
|
| |
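| # Apply the LLM's cleaning plan step by step, skipping columns that are missing or operations that fail. |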
|
|
| def apply_suggestions(df, suggestions): |
|
|
| df.rename(columns=suggestions.get("rename_columns", {}), inplace=True) |
|
|
| |
|
|
| for col, dtype in suggestions.get("convert_types", {}).items(): |
|
|
| if col not in df.columns: |
|
|
| continue |
|
|
| try: |
|
|
| if dtype == "int": |
|
|
| df[col] = pd.to_numeric(df[col], errors='coerce').astype("Int64") |
|
|
| elif dtype == "float": |
|
|
| df[col] = pd.to_numeric(df[col], errors='coerce') |
|
|
| elif dtype == "str": |
|
|
| df[col] = df[col].astype(str) |
|
|
| elif dtype == "date": |
|
|
| df[col] = pd.to_datetime(df[col], errors='coerce') |
|
|
| except: |
|
|
| print(f"Failed to convert {col} to {dtype}") |
|
|
| |
|
|
| for col, method in suggestions.get("fill_missing", {}).items(): |
|
|
| if col not in df.columns: |
|
|
| continue |
|
|
| try: |
|
|
| if method == "mean": |
|
|
| df[col].fillna(df[col].mean(), inplace=True) |
|
|
| elif method == "median": |
|
|
| df[col].fillna(df[col].median(), inplace=True) |
|
|
| elif method == "mode": |
|
|
| df[col].fillna(df[col].mode().iloc[0], inplace=True) |
|
|
| elif isinstance(method, str): |
|
|
| df[col].fillna(method, inplace=True) |
|
|
| except: |
|
|
| print(f"Could not fill missing values for {col}") |
|
|
| |
|
|
| for col, mapping in suggestions.get("value_map", {}).items(): |
|
|
| if col in df.columns: |
|
|
| try: |
|
|
| df[col] = df[col].replace(mapping) |
|
|
| except: |
|
|
| print(f"Could not map values in {col}") |
|
|
| |
|
|
| return df |
|
|
| |
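| # Combine the four dashboard figures into a single 2x2 subplot image; write_image needs a static-export backend such as kaleido. |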
|
|
| def capture_dashboard_screenshot(): |
|
|
| """Capture the entire dashboard as a single image""" |
|
|
| try: |
|
|
| |
|
|
| import plotly.graph_objects as go |
|
|
| from plotly.subplots import make_subplots |
|
|
| |
|
|
| |
|
|
| fig = make_subplots(rows=2, cols=2, |
|
|
| subplot_titles=["Visualization 1", "Visualization 2", |
|
|
| "Visualization 3", "Visualization 4"]) |
|
|
| |
|
|
| |
|
|
| for i, plot in enumerate(st.session_state.dashboard_plots): |
|
|
| if plot is not None: |
|
|
| row = (i // 2) + 1 |
|
|
| col = (i % 2) + 1 |
|
|
| |
|
|
| |
|
|
| for trace in plot.data: |
|
|
| fig.add_trace(trace, row=row, col=col) |
|
|
| |
|
|
| |
|
|
| for axis_type in ['xaxis', 'yaxis']: |
| # The source figure uses plain 'xaxis'/'yaxis'; make_subplots numbers the target axes sequentially |
| # ('xaxis', 'xaxis2', ...), matching the row-major order in which the plots are added. |
| source_axis = getattr(plot.layout, axis_type, None) |
| target_axis = f"{axis_type}{i + 1 if i > 0 else ''}" |
| # Copy only the axis title; copying every axis property would override the subplot's domain. |
| if source_axis is not None and source_axis.title.text: |
| fig.update_layout({target_axis: {"title": {"text": source_axis.title.text}}}) |
|
|
| |
|
|
| |
|
|
| fig.update_layout( |
|
|
| height=800, |
|
|
| width=1000, |
|
|
| title_text="Dashboard Overview", |
|
|
| showlegend=False, |
|
|
| ) |
|
|
| |
|
|
| |
|
|
| dashboard_path = f"{st.session_state.temp_dir}/dashboard_combined.png" |
|
|
| fig.write_image(dashboard_path, scale=2) |
|
|
| return dashboard_path |
|
|
| |
|
|
| except Exception as e: |
|
|
| import traceback |
|
|
| print(f"Error capturing dashboard: {str(e)}") |
|
|
| print(traceback.format_exc()) |
|
|
| return None |
|
|
| |
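| # Build the PDF with ReportLab: chat history (including any embedded base64 visualizations) followed by the dashboard. |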
|
|
| def generate_enhanced_pdf_report(): |
|
|
| """Generate an enhanced PDF report with proper handling of base64 image data""" |
|
|
| try: |
|
|
| |
|
|
| buffer = io.BytesIO() |
|
|
| |
|
|
| |
|
|
| doc = SimpleDocTemplate(buffer, pagesize=letter, |
|
|
| leftMargin=36, rightMargin=36, |
|
|
| topMargin=36, bottomMargin=36) |
|
|
| |
|
|
| |
|
|
| styles = getSampleStyleSheet() |
|
|
| |
|
|
| |
|
|
| styles.add(ParagraphStyle( |
|
|
| name='ReportTitle', |
|
|
| parent=styles['Heading1'], |
|
|
| fontSize=24, |
|
|
| alignment=1, |
|
|
| spaceAfter=20, |
|
|
| textColor='#2C3E50' |
|
|
| )) |
|
|
| |
|
|
| styles.add(ParagraphStyle( |
|
|
| name='SectionHeader', |
|
|
| parent=styles['Heading2'], |
|
|
| fontSize=16, |
|
|
| spaceBefore=15, |
|
|
| spaceAfter=10, |
|
|
| textColor='#2C3E50', |
|
|
| borderWidth=1, |
|
|
| borderColor='#95A5A6', |
|
|
| borderPadding=5, |
|
|
| borderRadius=5 |
|
|
| )) |
|
|
| |
|
|
| styles.add(ParagraphStyle( |
|
|
| name='SubHeader', |
|
|
| parent=styles['Heading3'], |
|
|
| fontSize=14, |
|
|
| spaceBefore=10, |
|
|
| spaceAfter=8, |
|
|
| textColor='#34495E', |
|
|
| fontWeight='bold' |
|
|
| )) |
|
|
| styles.add(ParagraphStyle( |
|
|
| name='UserMessage', |
|
|
| parent=styles['Normal'], |
|
|
| fontSize=11, |
|
|
| leftIndent=10, |
|
|
| spaceBefore=8, |
|
|
| spaceAfter=4 |
|
|
| )) |
|
|
| |
|
|
| styles.add(ParagraphStyle( |
|
|
| name='AssistantMessage', |
|
|
| parent=styles['Normal'], |
|
|
| fontSize=11, |
|
|
| leftIndent=10, |
|
|
| spaceBefore=4, |
|
|
| spaceAfter=12, |
|
|
| textColor='#2980B9' |
|
|
| )) |
|
|
| |
|
|
| styles.add(ParagraphStyle( |
|
|
| name='Timestamp', |
|
|
| parent=styles['Italic'], |
|
|
| fontSize=10, |
|
|
| textColor='#7F8C8D', |
|
|
| alignment=2 |
|
|
| )) |
|
|
| |
|
|
| |
|
|
| elements = [] |
|
|
| |
|
|
| |
|
|
| elements.append(Paragraph('Data Analysis Report', styles['ReportTitle'])) |
|
|
| |
|
|
| |
|
|
| elements.append(Paragraph(f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}', |
|
|
| styles['Timestamp'])) |
|
|
| elements.append(Spacer(1, 0.5*inch)) |
|
|
| |
|
|
| |
|
|
| elements.append(Paragraph('Analysis Conversation History', styles['SectionHeader'])) |
|
|
| |
|
|
| if st.session_state.chat_history: |
|
|
| for i, (user_msg, assistant_msg) in enumerate(st.session_state.chat_history): |
|
|
| |
|
|
| elements.append(Paragraph(f'<b>You:</b>', styles['SubHeader'])) |
|
|
| user_msg_formatted = user_msg.replace('\n', '<br/>') |
|
|
| elements.append(Paragraph(user_msg_formatted, styles['UserMessage'])) |
|
|
| |
|
|
| |
|
|
| |
|
|
| base64_pattern = r'!\[Visualization\]\(data:image\/png;base64,([^\)]+)\)' |
|
|
| |
|
|
| |
|
|
| if '### Visualizations' in assistant_msg or re.search(base64_pattern, assistant_msg): |
|
|
| |
|
|
| if '### Visualizations' in assistant_msg: |
|
|
| parts = assistant_msg.split('### Visualizations', 1) |
|
|
| text_part = parts[0] |
|
|
| viz_part = "### Visualizations" + parts[1] if len(parts) > 1 else "" |
|
|
| else: |
|
|
| |
|
|
| match = re.search(base64_pattern, assistant_msg) |
|
|
| text_part = assistant_msg[:match.start()] |
|
|
| viz_part = assistant_msg[match.start():] |
|
|
| |
|
|
| |
|
|
| elements.append(Paragraph(f'<b>Assistant:</b>', styles['SubHeader'])) |
|
|
| text_part = text_part.replace('\n', '<br/>') |
|
|
| elements.append(Paragraph(text_part, styles['AssistantMessage'])) |
|
|
| |
|
|
| |
|
|
| matches = re.findall(base64_pattern, viz_part) |
|
|
| for j, base64_data in enumerate(matches): |
|
|
| try: |
|
|
| |
|
|
| image_data = base64.b64decode(base64_data) |
|
|
| |
|
|
| |
|
|
| temp_img_path = f"{st.session_state.temp_dir}/chat_viz_{i}_{j}.png" |
|
|
| |
|
|
| with open(temp_img_path, 'wb') as f: |
|
|
| f.write(image_data) |
|
|
| |
|
|
| |
|
|
| elements.append(Paragraph(f'<b>Visualization:</b>', styles['SubHeader'])) |
|
|
| elements.append(Spacer(1, 0.1*inch)) |
|
|
| img = Image(temp_img_path, width=6*inch, height=4*inch) |
|
|
| elements.append(img) |
|
|
| elements.append(Spacer(1, 0.2*inch)) |
|
|
| except Exception as e: |
|
|
| print(f"Error processing base64 image: {str(e)}") |
|
|
| elements.append(Paragraph(f"[Error displaying visualization: {str(e)}]", |
|
|
| styles['Normal'])) |
|
|
| else: |
|
|
| |
|
|
| elements.append(Paragraph(f'<b>Assistant:</b>', styles['SubHeader'])) |
|
|
| assistant_msg_formatted = assistant_msg.replace('\n', '<br/>') |
|
|
| if len(assistant_msg_formatted) > 1500: |
|
|
| assistant_msg_formatted = assistant_msg_formatted[:1500] + '...' |
|
|
| elements.append(Paragraph(assistant_msg_formatted, styles['AssistantMessage'])) |
|
|
| |
|
|
| elements.append(Spacer(1, 0.2*inch)) |
|
|
| else: |
|
|
| elements.append(Paragraph('No conversation history available.', styles['Normal'])) |
|
|
| |
|
|
| |
|
|
| elements.append(PageBreak()) |
|
|
| |
|
|
| |
|
|
| elements.append(Paragraph('Dashboard Overview', styles['SectionHeader'])) |
|
|
| elements.append(Spacer(1, 0.2*inch)) |
|
|
| |
|
|
| |
|
|
| dashboard_img_path = capture_dashboard_screenshot() |
|
|
| |
|
|
| if dashboard_img_path: |
|
|
| |
|
|
| available_width = doc.width |
|
|
| |
|
|
| |
|
|
| pil_img = PILImage.open(dashboard_img_path) |
|
|
| img_width, img_height = pil_img.size |
|
|
| |
|
|
| |
|
|
| scale_factor = available_width / img_width |
|
|
| |
|
|
| |
|
|
| new_height = img_height * scale_factor |
|
|
| |
|
|
| |
|
|
| img = Image(dashboard_img_path, width=available_width, height=new_height) |
|
|
| elements.append(img) |
|
|
| else: |
|
|
| |
|
|
| plot_count = 0 |
|
|
| for i, plot in enumerate(st.session_state.dashboard_plots): |
|
|
| if plot is not None: |
|
|
| plot_count += 1 |
|
|
| |
|
|
| |
|
|
| img_bytes = io.BytesIO() |
|
|
| plot.write_image(img_bytes, format='png', width=500, height=300) |
|
|
| img_bytes.seek(0) |
|
|
| |
|
|
| |
|
|
| temp_img_path = f"{st.session_state.temp_dir}/plot_{i}.png" |
|
|
| |
|
|
| with open(temp_img_path, 'wb') as f: |
|
|
| f.write(img_bytes.getvalue()) |
|
|
| |
|
|
| |
|
|
| elements.append(Paragraph(f'Dashboard Visualization {i+1}', styles['SubHeader'])) |
|
|
| elements.append(Spacer(1, 0.1*inch)) |
|
|
| |
|
|
| |
|
|
| img = Image(temp_img_path, width=6.5*inch, height=4*inch) |
|
|
| elements.append(img) |
|
|
| elements.append(Spacer(1, 0.3*inch)) |
|
|
| |
|
|
| if plot_count == 0: |
|
|
| elements.append(Paragraph('No visualizations have been added to the dashboard.', |
|
|
| styles['Normal'])) |
|
|
| |
|
|
| |
|
|
| doc.build(elements) |
|
|
| |
|
|
| |
|
|
| pdf_value = buffer.getvalue() |
|
|
| buffer.close() |
|
|
| |
|
|
| return pdf_value |
|
|
| |
|
|
| except Exception as e: |
|
|
| import traceback |
|
|
| print(f"Error generating enhanced PDF report: {str(e)}") |
|
|
| print(traceback.format_exc()) |
|
|
| return None |
|
|
| |
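| # Run a single user turn through the LangGraph workflow and append any generated figures to the reply as inline base64 images. |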
|
|
| def chat_with_workflow(message, history, dataset_info): |
|
|
| """Send user query to the workflow and get response""" |
|
|
| |
|
|
| if not dataset_info: |
|
|
| return "Please upload at least one dataset before asking questions." |
|
|
| |
|
|
| |
|
|
| if not st.session_state.api_key: |
|
|
| return "Please set up your API key and model in the Settings tab before chatting." |
|
|
| |
|
|
| print(f"Chat with workflow called with {len(dataset_info)} datasets") |
|
|
| |
|
|
| try: |
|
|
| |
|
|
| max_history = 3 |
|
|
| previous_messages = [] |
|
|
| |
|
|
| if history: |
|
|
| start_idx = max(0, len(history) - max_history) |
|
|
| recent_history = history[start_idx:] |
|
|
| |
|
|
| for exchange in recent_history: |
|
|
| if exchange[0]: |
|
|
| previous_messages.append(HumanMessage(content=exchange[0])) |
|
|
| if exchange[1]: |
|
|
| previous_messages.append(AIMessage(content=exchange[1])) |
|
|
| |
|
|
| |
|
|
| state = AgentState( |
|
|
| messages=previous_messages + [HumanMessage(content=message)], |
|
|
| input_data=dataset_info, |
|
|
| intermediate_outputs=[], |
|
|
| current_variables=st.session_state.persistent_vars, |
|
|
| output_image_paths=[] |
|
|
| ) |
|
|
| |
|
|
| |
|
|
| print("Executing workflow...") |
|
|
| result = chain.invoke(state) |
|
|
| print("Workflow execution completed") |
|
|
| |
|
|
| |
|
|
| messages = result["messages"] |
|
|
| |
|
|
| |
|
|
| response = "" |
|
|
| if messages: |
|
|
| latest_message = messages[-1] |
|
|
| if hasattr(latest_message, "content"): |
|
|
| content = latest_message.content |
|
|
| |
|
|
| |
|
|
| |
|
|
| if message in content: |
|
|
| content = content.split(message)[-1].strip() |
|
|
| |
|
|
| |
|
|
| content_lines = content.split('\n') |
|
|
| filtered_lines = [line for line in content_lines |
|
|
| if not line.strip().startswith(("You:", "User:", "Human:", "Assistant:"))] |
|
|
| content = '\n'.join(filtered_lines) |
|
|
| |
|
|
| response = content.strip() + "\n\n" |
|
|
| |
|
|
| |
|
|
| if "output_image_paths" in result and result["output_image_paths"]: |
|
|
| response += "### Visualizations\n\n" |
|
|
| for img_path in result["output_image_paths"]: |
|
|
| try: |
|
|
| full_path = os.path.join(st.session_state.images_dir, img_path) |
|
|
| with open(full_path, 'rb') as f: |
|
|
| fig = pickle.load(f) |
|
|
| |
|
|
| |
|
|
| img_bytes = BytesIO() |
|
|
| fig.update_layout(width=800, height=500) |
|
|
| pio.write_image(fig, img_bytes, format='png') |
|
|
| img_bytes.seek(0) |
|
|
| |
|
|
| |
|
|
| b64_img = base64.b64encode(img_bytes.read()).decode() |
|
|
| response += f"![Visualization](data:image/png;base64,{b64_img})\n\n" |
|
|
| except Exception as e: |
|
|
| response += f"Error loading visualization: {str(e)}\n\n" |
|
|
| |
|
|
| return response |
|
|
| |
|
|
| except Exception as e: |
|
|
| import traceback |
|
|
| print(f"Error in chat_with_workflow: {str(e)}") |
|
|
| print(traceback.format_exc()) |
|
|
| return f"Error executing workflow: {str(e)}" |
|
|
| |
|
|
| def auto_generate_dashboard(dataset_info): |
|
|
| """Generate an automatic dashboard with four plots""" |
|
|
| |
|
|
| if not dataset_info: |
|
|
| return "Please upload a dataset first.", [None, None, None, None] |
|
|
| |
|
|
| prompt = """ |
| |
| You are a data visualization expert. Given a dataset, identify the top 4 most insightful plots using statistical reasoning or patterns (correlation, distribution, trends). |
| |
| |
| |
| Use plotly and store the plots in a list named plotly_figures. |
| |
| |
| |
| Include multivariate plots using color/size/facets when helpful. |
| |
| """ |
|
|
| |
|
|
| state = AgentState( |
|
|
| messages=[HumanMessage(content=prompt)], |
|
|
| input_data=dataset_info, |
|
|
| intermediate_outputs=[], |
|
|
| current_variables=st.session_state.persistent_vars, |
|
|
| output_image_paths=[] |
|
|
| ) |
|
|
| |
|
|
| result = chain.invoke(state) |
|
|
| figures = [] |
|
|
| |
|
|
| if "output_image_paths" in result: |
|
|
| for img_path in result["output_image_paths"][:4]: |
|
|
| try: |
|
|
| full_path = os.path.join(st.session_state.images_dir, img_path) |
|
|
| with open(full_path, 'rb') as f: |
|
|
| fig = pickle.load(f) |
|
|
| figures.append(fig) |
|
|
| except Exception as e: |
|
|
| print(f"Error loading figure: {e}") |
|
|
| |
|
|
| while len(figures) < 4: |
|
|
| figures.append(None) |
|
|
| |
|
|
| st.session_state.dashboard_plots = figures |
|
|
| return "Dashboard generated!", figures |
|
|
| |
|
|
| def generate_custom_plots_with_llm(dataset_info, x_col, y_col, facet_col): |
|
|
| """Generate custom plots based on user-selected columns""" |
|
|
| |
|
|
| if not dataset_info or not x_col or not y_col: |
|
|
| return [None, None, None] |
|
|
| |
|
|
| prompt = f""" |
| |
| You are a data visualization expert. |
| |
| |
| |
| Create 3 insightful visualizations using Plotly based on: |
| |
| |
| |
| - X-axis: {x_col} |
| |
| - Y-axis: {y_col} |
| |
| - Facet (optional): {facet_col if facet_col != 'None' else 'None'} |
| |
| |
| |
| Try to find interesting relationships, trends, or clusters using appropriate chart types. |
| |
| |
| |
| Use `plotly_figures` list and avoid using fig.show(). |
| |
| """ |
|
|
| |
|
|
| state = AgentState( |
|
|
| messages=[HumanMessage(content=prompt)], |
|
|
| input_data=dataset_info, |
|
|
| intermediate_outputs=[], |
|
|
| current_variables=st.session_state.persistent_vars, |
|
|
| output_image_paths=[] |
|
|
| ) |
|
|
| |
|
|
| result = chain.invoke(state) |
|
|
| figures = [] |
|
|
| |
|
|
| if "output_image_paths" in result: |
|
|
| for img_path in result["output_image_paths"][:3]: |
|
|
| try: |
|
|
| full_path = os.path.join(st.session_state.images_dir, img_path) |
|
|
| with open(full_path, 'rb') as f: |
|
|
| fig = pickle.load(f) |
|
|
| figures.append(fig) |
|
|
| except Exception as e: |
|
|
| print(f"Error loading figure: {e}") |
|
|
| |
|
|
| while len(figures) < 3: |
|
|
| figures.append(None) |
|
|
| return figures |
|
|
| |
|
|
| def remove_plot(index): |
|
|
| """Remove a plot from the dashboard""" |
|
|
| if 0 <= index < len(st.session_state.dashboard_plots): |
|
|
| st.session_state.dashboard_plots[index] = None |
|
|
| |
|
|
| def respond(message): |
|
|
| """Handle chat message response""" |
|
|
| if not st.session_state.dataset_metadata_list: |
|
|
| bot_message = "Please upload at least one dataset before asking questions." |
|
|
| else: |
|
|
| bot_message = chat_with_workflow(message, st.session_state.chat_history, st.session_state.dataset_metadata_list) |
|
|
| |
|
|
| st.session_state.chat_history.append((message, bot_message)) |
|
|
| st.rerun() |
|
|
| |
|
|
| def save_plot_to_dashboard(plot_index): |
|
|
| """Callback for the Add Plot button""" |
|
|
| for i in range(len(st.session_state.dashboard_plots)): |
|
|
| if st.session_state.dashboard_plots[i] is None: |
|
|
| |
|
|
| st.session_state.dashboard_plots[i] = st.session_state.custom_plots_to_save[plot_index] |
|
|
| return |
|
|
| |
|
|
| |
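| # ---------------- Streamlit UI ---------------- |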
|
|
| st.set_page_config(page_title="QueryMind 🧠", layout="wide") |
|
|
| st.title("QueryMind 🧠 - Data Assistant") |
|
|
| st.markdown("Upload your datasets, ask questions, and generate visualizations to gain insights.") |
|
|
| |
|
|
| |
|
|
| tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs(["Upload Datasets", "Data Cleaning", "Chat with AI Assistant", "Auto Dashboard Generator", "Generate Report", "Settings"]) |
|
|
| |
|
|
| with tab1: |
|
|
| st.header("Upload Datasets") |
|
|
| uploaded_files = st.file_uploader("Upload CSV or Excel Files", |
|
|
| accept_multiple_files=True, |
|
|
| type=['csv', 'xlsx', 'xls']) |
|
|
| |
|
|
| if uploaded_files and st.button("Process Uploaded Files"): |
|
|
| with st.spinner("Processing files..."): |
|
|
| preview, metadata_list, columns = process_file_upload(uploaded_files) |
|
|
| st.session_state.columns = columns |
|
|
| |
|
|
| |
|
|
| st.success(f"✅ Successfully processed {len(uploaded_files)} file(s)") |
|
|
| |
|
|
| |
|
|
| st.subheader("Dataset Previews") |
|
|
| |
|
|
| for dataset_name, df in st.session_state.in_memory_datasets.items(): |
|
|
| with st.expander(f"Preview: {dataset_name}"): |
|
|
| |
|
|
| st.write(f"**Rows:** {df.shape[0]} | **Columns:** {df.shape[1]}") |
|
|
| |
|
|
| |
|
|
| col_info = pd.DataFrame({ |
|
|
| 'Column Name': df.columns, |
|
|
| 'Data Type': df.dtypes.astype(str), |
|
|
| 'Non-Null Count': df.count().values, |
|
|
| 'Sample Values': [', '.join(df[col].dropna().astype(str).head(3).tolist()) for col in df.columns] |
|
|
| }) |
|
|
| |
|
|
| |
|
|
| st.write("**Column Information:**") |
|
|
| st.dataframe(col_info, use_container_width=True) |
|
|
| |
|
|
| |
|
|
| st.write("**Data Preview (First 10 rows):**") |
|
|
| st.dataframe(df.head(10), use_container_width=True) |
|
|
| |
|
|
| |
|
|
| st.info("👆 Click on the dataset names above to see detailed previews. Then proceed to the Data Cleaning tab to clean your data or Chat with AI Assistant to analyze it.") |
|
|
| |
|
|
| with tab2: |
|
|
| st.header("Data Cleaning") |
|
|
| |
|
|
| if 'cleaning_done' not in st.session_state: |
|
|
| st.session_state.cleaning_done = False |
|
|
| |
|
|
| if 'cleaned_datasets' not in st.session_state: |
|
|
| st.session_state.cleaned_datasets = {} |
|
|
| |
|
|
| if 'cleaning_summaries' not in st.session_state: |
|
|
| st.session_state.cleaning_summaries = {} |
|
|
| |
|
|
| if st.session_state.get("in_memory_datasets"): |
|
|
| if not st.session_state.cleaning_done: |
|
|
| if st.button("Run Data Cleaning"): |
|
|
| with st.spinner("Running LLM-assisted cleaning..."): |
|
|
| for name, df in st.session_state.in_memory_datasets.items(): |
|
|
| raw_df = df.copy() |
|
|
| df_std = standard_clean(raw_df.copy()) |
|
|
| suggestions = llm_suggest_cleaning(df_std.copy()) |
|
|
| df_clean = apply_suggestions(df_std.copy(), suggestions) |
|
|
| st.session_state.cleaned_datasets[name] = df_clean |
|
|
| st.session_state.cleaning_summaries[name] = suggestions |
|
|
| st.session_state.cleaning_done = True |
|
|
| st.rerun() |
|
|
| else: |
|
|
| st.info("Click Run Data Cleaning to clean your datasets using the LLM.") |
|
|
| else: |
|
|
| for name, df_clean in st.session_state.cleaned_datasets.items(): |
|
|
| raw_df = st.session_state.in_memory_datasets[name] |
|
|
| |
|
|
| st.subheader(f"Dataset: {name}") |
|
|
| col1, col2 = st.columns(2) |
|
|
| |
|
|
| with col1: |
|
|
| st.markdown("Original Data (First 5 Rows)") |
|
|
| st.dataframe(raw_df.head()) |
|
|
| |
|
|
| with col2: |
|
|
| st.markdown("Cleaned Data (First 5 Rows)") |
|
|
| st.dataframe(df_clean.head()) |
|
|
| |
|
|
| st.markdown("Summary of Cleaning Actions") |
|
|
| suggestions = st.session_state.cleaning_summaries[name] |
|
|
| summary_text = "" |
|
|
| |
|
|
| if suggestions: |
|
|
| for key, value in suggestions.items(): |
|
|
| summary_text += f"**{key}**: {json.dumps(value, indent=2)}\n\n" |
|
|
| st.markdown(summary_text) |
|
|
| |
|
|
| st.markdown("Refine the Cleaning (Natural Language Instructions)") |
|
|
| user_input = st.text_input("Example: Convert 'dob' to datetime and fill missing with '2000-01-01'", |
|
|
| key=f"user_input_{name}") |
|
|
| |
|
|
| if f'corrections_{name}' not in st.session_state: |
|
|
| st.session_state[f'corrections_{name}'] = [] |
|
|
| |
|
|
| if st.button("Apply Correction", key=f'apply_correction_{name}'): |
|
|
| if user_input.strip(): |
|
|
| correction_prompt = f""" |
| |
| You are a data cleaning expert. Below is a previously cleaned dataset with these actions: |
| |
| |
| |
| {summary_text} |
| |
| |
| |
| The user now wants the following additional instruction: |
| |
| \"{user_input.strip()}\" |
| |
| |
| |
| Write only the Python code that modifies the pandas DataFrame `df` accordingly. |
| |
| Do not include explanations or markdown. |
| |
| """ |
|
|
| correction_code = query_openai(correction_prompt) |
|
|
| |
|
|
| try: |
|
|
| df = st.session_state.cleaned_datasets[name].copy() |
|
|
| local_vars = {"df": df} |
|
|
| exec(correction_code, {}, local_vars) |
|
|
| df_updated = local_vars["df"] |
|
|
| |
|
|
| st.session_state.cleaned_datasets[name] = df_updated |
|
|
| st.session_state[f'corrections_{name}'].append((user_input, correction_code)) |
|
|
| st.success("Correction applied.") |
|
|
| st.rerun() |
|
|
| |
|
|
| except Exception as e: |
|
|
| st.error(f"Failed to apply correction: {str(e)}") |
|
|
| |
|
|
| if st.session_state[f'corrections_{name}']: |
|
|
| st.markdown("Applied Corrections") |
|
|
| for i, (msg, code) in enumerate(st.session_state[f'corrections_{name}']): |
|
|
| st.markdown(f"**Instruction:** {msg}") |
|
|
| st.code(code, language='python') |
|
|
| |
|
|
| col1, col2 = st.columns([1, 2]) |
|
|
| with col1: |
|
|
| if st.button("Reset Cleaning and Re-run"): |
|
|
| st.session_state.cleaning_done = False |
|
|
| st.rerun() |
|
|
| |
|
|
| with col2: |
|
|
| if st.button("Finalize and Proceed to Visualizations"): |
|
|
| st.session_state.cleaning_finalized = True |
|
|
| st.rerun() |
|
|
| else: |
|
|
| st.info("Please upload and process datasets first.") |
|
|
| |
|
|
| with tab3: |
|
|
| st.header("Chat with AI Assistant") |
|
|
| |
|
|
| |
|
|
| if not st.session_state.api_key: |
|
|
| st.warning("⚠️ Please set up your API key and model in the Settings tab before using the chat.") |
|
|
| |
|
|
| st.markdown(""" |
| |
| ## Example Questions |
| |
| - "What analysis can you perform on this dataset?" |
| |
| - "Show me basic statistics for all columns" |
| |
| - "Create a correlation heatmap" |
| |
| - "Plot the distribution of a specific column" |
| |
| - "What is the relationship between two columns?" |
| |
| """) |
|
|
| |
|
|
| |
|
|
| for exchange in st.session_state.chat_history: |
|
|
| with st.chat_message("user"): |
|
|
| st.write(exchange[0]) |
|
|
| with st.chat_message("assistant"): |
|
|
| st.write(exchange[1]) |
|
|
| |
|
|
| |
|
|
| if prompt := st.chat_input("Your question"): |
|
|
| with st.spinner("Thinking..."): |
|
|
| respond(prompt) |
|
|
| |
|
|
| with tab4: |
|
|
| st.header("Auto Dashboard Generator") |
|
|
| |
|
|
| |
|
|
| dashboard_title = st.text_input("Dashboard Title", placeholder="Enter your dashboard title") |
|
|
|
|
| col1, col2 = st.columns(2) |
|
|
| |
|
|
| with col1: |
|
|
| if st.button("Generate Suggested Dashboard (Auto)"): |
|
|
| if not st.session_state.api_key: |
|
|
| st.warning("⚠️ Please set up your API key and model in the Settings tab first.") |
|
|
| else: |
|
|
| with st.spinner("Generating dashboard..."): |
|
|
| message, figures = auto_generate_dashboard(st.session_state.dataset_metadata_list) |
|
|
| st.success(message) |
|
|
| |
|
|
| with col2: |
|
|
| if st.button("Refresh Column Options"): |
|
|
| st.session_state.columns = get_columns() |
|
|
| st.rerun() |
|
|
| |
|
|
| |
|
|
| st.subheader("Dashboard") |
|
|
| |
|
|
| |
|
|
| col1, col2 = st.columns(2) |
|
|
| |
|
|
| with col1: |
|
|
| if st.session_state.dashboard_plots[0]: |
|
|
| st.plotly_chart(st.session_state.dashboard_plots[0], use_container_width=True) |
|
|
| if st.button("Remove Plot 1"): |
|
|
| remove_plot(0) |
|
|
| st.rerun() |
|
|
| |
|
|
| with col2: |
|
|
| if st.session_state.dashboard_plots[1]: |
|
|
| st.plotly_chart(st.session_state.dashboard_plots[1], use_container_width=True) |
|
|
| if st.button("Remove Plot 2"): |
|
|
| remove_plot(1) |
|
|
| st.rerun() |
|
|
| |
|
|
| |
|
|
| col3, col4 = st.columns(2) |
|
|
| |
|
|
| with col3: |
|
|
| if st.session_state.dashboard_plots[2]: |
|
|
| st.plotly_chart(st.session_state.dashboard_plots[2], use_container_width=True) |
|
|
| if st.button("Remove Plot 3"): |
|
|
| remove_plot(2) |
|
|
| st.rerun() |
|
|
| |
|
|
| with col4: |
|
|
| if st.session_state.dashboard_plots[3]: |
|
|
| st.plotly_chart(st.session_state.dashboard_plots[3], use_container_width=True) |
|
|
| if st.button("Remove Plot 4"): |
|
|
| remove_plot(3) |
|
|
| st.rerun() |
|
|
| |
|
|
| |
|
|
| st.subheader("Custom Plot Generator") |
|
|
| |
|
|
| |
|
|
| col1, col2, col3 = st.columns(3) |
|
|
| |
|
|
| with col1: |
|
|
| x_axis = st.selectbox("X-axis Column", options=st.session_state.columns) |
|
|
| |
|
|
| with col2: |
|
|
| y_axis = st.selectbox("Y-axis Column", options=st.session_state.columns) |
|
|
| |
|
|
| with col3: |
|
|
| facet = st.selectbox("Facet (optional)", options=["None"] + st.session_state.columns) |
|
|
|
|
| if st.button("Generate Custom Visualizations"): |
|
|
| if not st.session_state.api_key: |
|
|
| st.warning("⚠️ Please set up your API key and model in the Settings tab first.") |
|
|
| else: |
|
|
| with st.spinner("Generating custom visualizations..."): |
|
|
| custom_plots = generate_custom_plots_with_llm(st.session_state.dataset_metadata_list, x_axis, y_axis, facet) |
|
|
| |
|
|
| for i, plot in enumerate(custom_plots): |
|
|
| if plot: |
|
|
| st.session_state.custom_plots_to_save[i] = plot |
|
|
| |
|
|
| |
|
|
| for i, plot in enumerate(custom_plots): |
|
|
| if plot: |
|
|
| st.plotly_chart(plot, use_container_width=True) |
|
|
| st.button( |
|
|
| f"Add Plot {i+1} to Dashboard", |
|
|
| key=f"add_plot_{i}", |
|
|
| on_click=save_plot_to_dashboard, |
|
|
| args=(i,) |
|
|
| ) |
|
|
| |
|
|
| with tab5: |
|
|
| st.header("Generate Analysis Report") |
|
|
| |
|
|
| st.markdown(""" |
| |
| Generate a PDF report containing: |
| |
| - Dashboard visualizations |
| |
| - Chat conversation history |
| |
| """) |
|
|
| |
|
|
| report_title = st.text_input("Report Title (Optional)", "Data Analysis Report") |
|
|
| |
|
|
| if st.button("Generate PDF Report"): |
|
|
| if not st.session_state.api_key: |
|
|
| st.warning("⚠️ Please set up your API key and model in the Settings tab first.") |
|
|
| else: |
|
|
| with st.spinner("Generating report..."): |
|
|
| pdf_data = generate_enhanced_pdf_report() |
|
|
| if pdf_data: |
|
|
| |
|
|
| b64_pdf = base64.b64encode(pdf_data).decode('utf-8') |
|
|
| |
|
|
| pdf_download_link = f'<a href="data:application/pdf;base64,{b64_pdf}" download="data_analysis_report.pdf">Download PDF Report</a>' |
|
|
| st.markdown("### Your report is ready!") |
|
|
| st.markdown(pdf_download_link, unsafe_allow_html=True) |
|
|
| |
|
|
| with st.expander("Preview Report"): |
|
|
| st.warning("PDF preview is not available in Streamlit; please download the report to view it.") |
|
|
| else: |
|
|
| st.error("Failed to generate the report. Please try again.") |
|
|
| |
|
|
| with tab6: |
|
|
| st.header("AI Provider Settings") |
|
|
| |
|
|
| |
|
|
| provider = st.radio("Select AI Provider", |
|
|
| options=["OpenAI", "Groq"], |
|
|
| index=0 if st.session_state.ai_provider == "openai" else 1, |
|
|
| horizontal=True) |
|
|
| |
|
|
| |
|
|
| st.session_state.ai_provider = provider.lower() |
|
|
| |
|
|
| |
|
|
| api_key = st.text_input("Enter API Key", |
|
|
| value=st.session_state.api_key, |
|
|
| type="password", |
|
|
| help="Your API key for the selected provider") |
|
|
| |
|
|
| |
|
|
| if st.session_state.ai_provider == "openai": |
|
|
| model_options = OPENAI_MODELS |
|
|
| model_help = "GPT-4 provides the best results but is slower. GPT-3.5-Turbo is faster but less capable." |
|
|
| else: |
|
|
| model_options = GROQ_MODELS |
|
|
| model_help = "Llama 3.3 70B is most capable. Gemma 2 9B offers good balance. Llama 3 8B is fastest." |
|
|
| |
|
|
| |
|
|
| selected_model = st.selectbox("Select Model", |
|
|
| options=model_options, |
|
|
| index=model_options.index(st.session_state.selected_model) if st.session_state.selected_model in model_options else 0, |
|
|
| help=model_help) |
|
|
| |
|
|
| |
|
|
| if st.button("Save Settings"): |
|
|
| st.session_state.api_key = api_key |
|
|
| st.session_state.selected_model = selected_model |
|
|
| |
|
|
| |
|
|
| try: |
|
|
| |
|
|
| test_llm = initialize_llm() |
|
|
| if test_llm: |
|
|
| st.success(f"✅ Successfully configured {provider} with model: {selected_model}") |
|
|
| else: |
|
|
| st.error("Failed to initialize the AI provider. Please check your API key and model selection.") |
|
|
| except Exception as e: |
|
|
| st.error(f"Error testing settings: {str(e)}") |
|
|
| |
|
|
| |
|
|
| st.subheader("Current Settings") |
|
|
| settings_info = f""" |
| |
| - **Provider**: {st.session_state.ai_provider.upper()} |
| |
| - **Model**: {st.session_state.selected_model} |
| |
| - **API Key**: {'✅ Set' if st.session_state.api_key else '❌ Not Set'} |
| |
| """ |
|
|
| st.markdown(settings_info) |
|
|
| |
|
|
| |
|
|
| if st.session_state.ai_provider == "openai": |
|
|
| st.info(""" |
| |
| **OpenAI Models Information:** |
| |
| - **GPT-4**: Most powerful model, best for complex analysis and detailed explanations |
| |
| - **GPT-4-Turbo**: Faster than GPT-4 with similar capabilities |
| |
| - **GPT-4o-Mini**: Economical option with good performance for standard tasks |
| |
| - **GPT-3.5-Turbo**: Fastest option, suitable for basic analysis and visualization |
| |
| """) |
|
|
| else: |
|
|
| st.info(""" |
| |
| **Groq Models Information:** |
| |
| - **llama-3.3-70b-versatile**: Most powerful model for comprehensive analysis |
| |
| - **gemma2-9b-it**: Good balance of speed and capabilities |
| |
| - **llama3-8b-8192**: Fastest option for basic analysis tasks |
| |
| """) |
|
|
| |
|
|
| |
|
|
| with st.expander("How to get API Keys"): |
|
|
| if st.session_state.ai_provider == "openai": |
|
|
| st.markdown(""" |
| |
| ### Getting an OpenAI API Key |
| |
| |
| |
| 1. Go to [OpenAI's platform](https://platform.openai.com) |
| |
| 2. Sign up or log in to your account |
| |
| 3. Navigate to the API section |
| |
| 4. Create a new API key |
| |
| 5. Copy the key and paste it above |
| |
| |
| |
| Note: OpenAI API usage incurs charges based on tokens used. |
| |
| """) |
|
|
| else: |
|
|
| st.markdown(""" |
| |
| ### Getting a Groq API Key |
| |
| |
| |
| 1. Go to [Groq's website](https://console.groq.com/keys) |
| |
| 2. Sign up or log in to your account |
| |
| 3. Navigate to API Keys section |
| |
| 4. Create a new API key |
| |
| 5. Copy the key and paste it above |
| |
| |
| |
| Note: Check Groq's pricing page for current rates. |
| |
| """) |
|
|
| |
|
|
| |
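| # Best-effort removal of the per-session temporary directory when the process exits (registered with atexit below). |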
|
|
| def cleanup(): |
|
|
| try: |
|
|
| shutil.rmtree(st.session_state.temp_dir) |
|
|
| print(f"Cleaned up temporary directory: {st.session_state.temp_dir}") |
|
|
| except Exception as e: |
|
|
| print(f"Error cleaning up: {e}") |
|
|
| |
|
|
| import atexit |
|
|
| atexit.register(cleanup) |