RAVENOCC committed on
Commit
c45a9b8
·
verified ·
1 Parent(s): 2ffc683

Upload app.py

Files changed (1)
  1. app.py +3137 -0
app.py ADDED
@@ -0,0 +1,3137 @@
1
+ import os
2
+
3
+ import streamlit as st
4
+
5
+ import pandas as pd
6
+
7
+ import pickle
8
+
9
+ import base64
10
+
11
+ from io import BytesIO, StringIO
12
+
13
+ import sys
14
+
15
+ import operator
16
+
17
+ from typing import Literal, Sequence, TypedDict, Annotated, List, Dict, Tuple
18
+
19
+ import tempfile
20
+
21
+ import shutil
22
+
23
+ import plotly.io as pio
24
+
25
+ import io
26
+
27
+ import re
28
+
29
+ import json
30
+
31
+ import openai
32
+
33
+ # from fpdf import FPDF
34
+
35
+ import base64
36
+
37
+ from datetime import datetime
38
+
39
+ from reportlab.lib.pagesizes import letter
40
+
41
+ from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image
42
+
43
+ from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
44
+
45
+ from reportlab.lib.units import inch
46
+
47
+ from PIL import Image as PILImage
48
+
49
+
50
+
51
+ # Import LangChain and LangGraph components
52
+
53
+ from langchain_core.messages import AIMessage, ToolMessage, HumanMessage, BaseMessage
54
+
55
+ from langchain_core.prompts import ChatPromptTemplate
56
+
57
+ from langchain_openai import ChatOpenAI
58
+
59
+ from langchain_experimental.utilities import PythonREPL
60
+
61
+ from langgraph.prebuilt import ToolInvocation, ToolExecutor
62
+
63
+ from langchain_core.tools import tool
64
+
65
+ from langgraph.prebuilt import InjectedState
66
+
67
+ from langgraph.graph import StateGraph, END
68
+
69
+ from reportlab.platypus import PageBreak
70
+
71
+ from PIL import Image as PILImage
72
+
73
+
74
+
75
+ # Initialize session state for AI provider settings
76
+
77
+ if 'ai_provider' not in st.session_state:
78
+
79
+ st.session_state.ai_provider = "openai"
80
+
81
+
82
+
83
+ if 'api_key' not in st.session_state:
84
+
85
+ st.session_state.api_key = ""
86
+
87
+
88
+
89
+ if 'selected_model' not in st.session_state:
90
+
91
+ st.session_state.selected_model = "gpt-4"
92
+
93
+
94
+
95
+ # Define model options for each provider
96
+
97
+ OPENAI_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-4-mini", "gpt-3.5-turbo"]
98
+
99
+ GROQ_MODELS = ["llama3.3-70b-versatile", "gemma2-9b-it", "llama-3-8b-8192"]
100
+
101
+
102
+
103
+ # Create temporary directory for file storage
104
+
105
+ if 'temp_dir' not in st.session_state:
106
+
107
+ st.session_state.temp_dir = tempfile.mkdtemp()
108
+
109
+ st.session_state.images_dir = os.path.join(st.session_state.temp_dir, "images/plotly_figures/pickle")
110
+
111
+ os.makedirs(st.session_state.images_dir, exist_ok=True)
112
+
113
+ print(f"Created temporary directory: {st.session_state.temp_dir}")
114
+
115
+ print(f"Created images directory: {st.session_state.images_dir}")
116
+
117
+
118
+
119
+ # Define the system prompt
120
+
121
+ SYSTEM_PROMPT = """## Role
122
+
123
+ You are a professional data scientist helping a non-technical user understand, analyze, and visualize their data.
124
+
125
+
126
+
127
+ ## Capabilities
128
+
129
+ 1. **Execute python code** using the `complete_python_task` tool.
130
+
131
+
132
+
133
+ ## Goals
134
+
135
+ 1. Understand the user's objectives clearly.
136
+
137
+ 2. Take the user on a data analysis journey, iterating to find the best way to visualize or analyse their data to solve their problems.
138
+
139
+ 3. Investigate if the goal is achievable by running Python code via the `python_code` field.
140
+
141
+ 4. Gain input from the user at every step to ensure the analysis is on the right track and to understand business nuances.
142
+
143
+
144
+
145
+ ## Code Guidelines
146
+
147
+ - **ALL INPUT DATA IS LOADED ALREADY**, so use the provided variable names to access the data.
148
+
149
+ - **VARIABLES PERSIST BETWEEN RUNS**, so reuse previously defined variables if needed.
150
+
151
+ - **TO SEE CODE OUTPUT**, use `print()` statements. You won't be able to see outputs of `pd.head()`, `pd.describe()` etc. otherwise.
152
+
153
+ - **ONLY USE THE FOLLOWING LIBRARIES**:
154
+
155
+ - `pandas`
156
+
157
+ - `sklearn` (including all major ML models)
158
+
159
+ - `plotly`
160
+
161
+ - `numpy`
162
+
163
+
164
+
165
+ All these libraries are already imported for you.
166
+
167
+
168
+
169
+ ## Machine Learning Guidelines
170
+
171
+ - For regression tasks:
172
+
173
+ - Linear Regression: `LinearRegression`
174
+
175
+ - Logistic Regression: `LogisticRegression`
176
+
177
+ - Ridge Regression: `Ridge`
178
+
179
+ - Lasso Regression: `Lasso`
180
+
181
+ - Random Forest Regression: `RandomForestRegressor`
182
+
183
+
184
+
185
+ - For classification tasks:
186
+
187
+ - Logistic Regression: `LogisticRegression`
188
+
189
+ - Decision Trees: `DecisionTreeClassifier`
190
+
191
+ - Random Forests: `RandomForestClassifier`
192
+
193
+ - Support Vector Machines: `SVC`
194
+
195
+ - K-Nearest Neighbors: `KNeighborsClassifier`
196
+
197
+ - Naive Bayes: `GaussianNB`
198
+
199
+
200
+
201
+ - For clustering:
202
+
203
+ - K-Means: `KMeans`
204
+
205
+ - DBSCAN: `DBSCAN`
206
+
207
+
208
+
209
+ - For dimensionality reduction:
210
+
211
+ - PCA: `PCA`
212
+
213
+
214
+
215
+ - Always preprocess data appropriately:
216
+
217
+ - Scale numerical features with `StandardScaler` or `MinMaxScaler`
218
+
219
+ - Encode categorical variables with `OneHotEncoder` when needed
220
+
221
+ - Handle missing values with `SimpleImputer`
222
+
223
+
224
+
225
+ - Always split data into training and testing sets using `train_test_split`
226
+
227
+ - Evaluate models using appropriate metrics:
228
+
229
+ - For regression: `mean_squared_error`, `mean_absolute_error`, `r2_score`
230
+
231
+ - For classification: `accuracy_score`, `confusion_matrix`, `classification_report`
232
+
233
+ - For clustering: `silhouette_score`
234
+
235
+
236
+
237
+ - Consider using `cross_val_score` for more robust evaluation
238
+
239
+ - Visualize ML results with plotly when possible
240
+
241
+
242
+
243
+ ## Plotting Guidelines
244
+
245
+ - Always use the `plotly` library for plotting.
246
+
247
+ - Store all plotly figures inside a `plotly_figures` list, they will be saved automatically.
248
+
249
+ - Do not try and show the plots inline with `fig.show()`.
250
+
251
+ """
252
+
253
+
254
+
255
+ # Define the State class
256
+
257
+ class AgentState(TypedDict):
258
+
259
+ messages: Annotated[Sequence[BaseMessage], operator.add]
260
+
261
+ input_data: Annotated[List[Dict], operator.add]
262
+
263
+ intermediate_outputs: Annotated[List[dict], operator.add]
264
+
265
+ current_variables: dict
266
+
267
+ output_image_paths: Annotated[List[str], operator.add]
268
+
269
+
270
+
271
+ # Initialize session state variables
272
+
273
+ if 'in_memory_datasets' not in st.session_state:
274
+
275
+ st.session_state.in_memory_datasets = {}
276
+
277
+
278
+
279
+ if 'persistent_vars' not in st.session_state:
280
+
281
+ st.session_state.persistent_vars = {}
282
+
283
+
284
+
285
+ if 'dataset_metadata_list' not in st.session_state:
286
+
287
+ st.session_state.dataset_metadata_list = []
288
+
289
+
290
+
291
+ if 'chat_history' not in st.session_state:
292
+
293
+ st.session_state.chat_history = []
294
+
295
+
296
+
297
+ if 'dashboard_plots' not in st.session_state:
298
+
299
+ st.session_state.dashboard_plots = [None, None, None, None]
300
+
301
+
302
+
303
+ if 'columns' not in st.session_state:
304
+
305
+ st.session_state.columns = ["No columns available"]
306
+
307
+
308
+
309
+ if 'custom_plots_to_save' not in st.session_state:
310
+
311
+ st.session_state.custom_plots_to_save = {}
312
+
313
+
314
+
315
+ # Set up the tools
316
+
317
+ repl = PythonREPL()
318
+
319
+ plotly_saving_code = """import pickle
320
+
321
+
322
+
323
+ import uuid
324
+
325
+ import os
326
+
327
+ for figure in plotly_figures:
328
+
329
+ pickle_filename = f"{images_dir}/{uuid.uuid4()}.pickle"
330
+
331
+ with open(pickle_filename, 'wb') as f:
332
+
333
+ pickle.dump(figure, f)
334
+
335
+ """
336
+
337
+
338
+
339
+ @tool
340
+
341
+ def complete_python_task(
342
+
343
+ graph_state: Annotated[dict, InjectedState],
344
+
345
+ thought: str,
346
+
347
+ python_code: str
348
+
349
+ ) -> Tuple[str, dict]:
350
+
351
+ """Execute Python code for data analysis and visualization."""
352
+
353
+
354
+
355
+ current_variables = graph_state.get("current_variables", {})
356
+
357
+
358
+
359
+ # Load datasets from in-memory storage
360
+
361
+ for input_dataset in graph_state.get("input_data", []):
362
+
363
+ var_name = input_dataset.get("variable_name")
364
+
365
+ if var_name and var_name not in current_variables and var_name in st.session_state.in_memory_datasets:
366
+
367
+ print(f"Loading {var_name} from in-memory storage")
368
+
369
+ current_variables[var_name] = st.session_state.in_memory_datasets[var_name]
370
+
371
+ current_image_pickle_files = os.listdir(st.session_state.images_dir)
372
+
373
+
374
+
375
+ try:
376
+
377
+ # Capture stdout
378
+
379
+ old_stdout = sys.stdout
380
+
381
+ sys.stdout = StringIO()
382
+
383
+
384
+
385
+ # Execute the code and capture the result
386
+
387
+ exec_globals = globals().copy()
388
+
389
+ exec_globals.update(st.session_state.persistent_vars)
390
+
391
+ exec_globals.update(current_variables)
392
+
393
+
394
+
395
+ # Add scikit-learn modules to execution environment
396
+
397
+ import sklearn
398
+
399
+ import numpy as np
400
+
401
+
402
+
403
+ # Import scikit-learn components
404
+
405
+ from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge, Lasso # type: ignore
406
+
407
+ from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier # type: ignore
408
+
409
+ from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
410
+
411
+ from sklearn.svm import SVC, SVR
412
+
413
+ from sklearn.naive_bayes import GaussianNB
414
+
415
+ from sklearn.decomposition import PCA
416
+
417
+ from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
418
+
419
+ from sklearn.cluster import KMeans, DBSCAN
420
+
421
+ from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
422
+
423
+ from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
424
+
425
+ from sklearn.metrics import (
426
+
427
+ accuracy_score, confusion_matrix, classification_report,
428
+
429
+ mean_squared_error, r2_score, mean_absolute_error, silhouette_score
430
+
431
+ )
432
+
433
+ from sklearn.pipeline import Pipeline
434
+
435
+ from sklearn.impute import SimpleImputer
436
+
437
+
438
+
439
+ # Update execution globals with all ML components
440
+
441
+ exec_globals.update({
442
+
443
+ "plotly_figures": [],
444
+
445
+ "images_dir": st.session_state.images_dir,
446
+
447
+ "np": np,
448
+
449
+ # Linear models
450
+
451
+ "LinearRegression": LinearRegression,
452
+
453
+ "LogisticRegression": LogisticRegression,
454
+
455
+ "Ridge": Ridge,
456
+
457
+ "Lasso": Lasso,
458
+
459
+ # Tree-based models
460
+
461
+ "DecisionTreeClassifier": DecisionTreeClassifier,
462
+
463
+ "DecisionTreeRegressor": DecisionTreeRegressor,
464
+
465
+ "RandomForestClassifier": RandomForestClassifier,
466
+
467
+ "RandomForestRegressor": RandomForestRegressor,
468
+
469
+ "GradientBoostingClassifier": GradientBoostingClassifier,
470
+
471
+ # SVM models
472
+
473
+ "SVC": SVC,
474
+
475
+ "SVR": SVR,
476
+
477
+ # Other models
478
+
479
+ "GaussianNB": GaussianNB,
480
+
481
+ "PCA": PCA,
482
+
483
+ "KNeighborsClassifier": KNeighborsClassifier,
484
+
485
+ "KNeighborsRegressor": KNeighborsRegressor,
486
+
487
+ "KMeans": KMeans,
488
+
489
+ "DBSCAN": DBSCAN,
490
+
491
+ # Preprocessing
492
+
493
+ "StandardScaler": StandardScaler,
494
+
495
+ "MinMaxScaler": MinMaxScaler,
496
+
497
+ "OneHotEncoder": OneHotEncoder,
498
+
499
+ "SimpleImputer": SimpleImputer,
500
+
501
+ # Model selection and evaluation
502
+
503
+ "train_test_split": train_test_split,
504
+
505
+ "cross_val_score": cross_val_score,
506
+
507
+ "GridSearchCV": GridSearchCV,
508
+
509
+ "accuracy_score": accuracy_score,
510
+
511
+ "confusion_matrix": confusion_matrix,
512
+
513
+ "classification_report": classification_report,
514
+
515
+ "mean_squared_error": mean_squared_error,
516
+
517
+ "r2_score": r2_score,
518
+
519
+ "mean_absolute_error": mean_absolute_error,
520
+
521
+ "silhouette_score": silhouette_score,
522
+
523
+ # Pipeline
524
+
525
+ "Pipeline": Pipeline
526
+
527
+ })
528
+
529
+
530
+
531
+ exec(python_code, exec_globals)
532
+
533
+
534
+
535
+ st.session_state.persistent_vars.update({k: v for k, v in exec_globals.items() if k not in globals()})
536
+
537
+
538
+
539
+ # Get the captured stdout
540
+
541
+ output = sys.stdout.getvalue()
542
+
543
+
544
+
545
+ # Restore stdout
546
+
547
+ sys.stdout = old_stdout
548
+
549
+
550
+
551
+ updated_state = {
552
+
553
+ "intermediate_outputs": [{"thought": thought, "code": python_code, "output": output}],
554
+
555
+ "current_variables": st.session_state.persistent_vars
556
+
557
+ }
558
+
559
+
560
+
561
+ if 'plotly_figures' in exec_globals and exec_globals['plotly_figures']:
562
+
563
+ exec(plotly_saving_code, exec_globals)
564
+
565
+
566
+
567
+ # Check if any images were created
568
+
569
+ new_image_folder_contents = os.listdir(st.session_state.images_dir)
570
+
571
+ new_image_files = [file for file in new_image_folder_contents if file not in current_image_pickle_files]
572
+
573
+
574
+
575
+ if new_image_files:
576
+
577
+ updated_state["output_image_paths"] = new_image_files
578
+
579
+ st.session_state.persistent_vars["plotly_figures"] = []
580
+
581
+ return output, updated_state
582
+
583
+
584
+
585
+ except Exception as e:
586
+
587
+ sys.stdout = old_stdout # Restore stdout in case of error
588
+
589
+ print(f"Error in complete_python_task: {str(e)}")
590
+
591
+ return str(e), {"intermediate_outputs": [{"thought": thought, "code": python_code, "output": str(e)}]}
592
+
593
+
594
+
595
+ # Function to initialize the LLM based on selected provider and model
596
+
597
+ def initialize_llm():
598
+
599
+ api_key = st.session_state.api_key
600
+
601
+ model = st.session_state.selected_model
602
+
603
+
604
+
605
+ if not api_key:
606
+
607
+ return None
608
+
609
+
610
+
611
+ try:
612
+
613
+ if st.session_state.ai_provider == "openai":
614
+
615
+ os.environ["OPENAI_API_KEY"] = api_key
616
+
617
+ return ChatOpenAI(model=model, temperature=0)
618
+
619
+ elif st.session_state.ai_provider == "groq":
620
+
621
+ os.environ["GROQ_API_KEY"] = api_key
622
+
623
+ # For Groq, set the base URL and use the model
624
+
625
+ from langchain_groq import ChatGroq
626
+
627
+ return ChatGroq(model=model, temperature=0)
628
+
629
+ except Exception as e:
630
+
631
+ print(f"Error initializing LLM: {str(e)}")
632
+
633
+ return None
634
+
635
+
636
+
637
+ # Set up the tools
638
+
639
+ tools = [complete_python_task]
640
+
641
+ tool_executor = ToolExecutor(tools)
642
+
643
+
644
+
645
+ # Load the prompt template
646
+
647
+ chat_template = ChatPromptTemplate.from_messages([
648
+
649
+ ("system", SYSTEM_PROMPT),
650
+
651
+ ("placeholder", "{messages}"),
652
+
653
+ ])
654
+
655
+
656
+
657
+ def create_data_summary(state: AgentState) -> str:
658
+
659
+ summary = ""
660
+
661
+ variables = []
662
+
663
+
664
+
665
+ # Add sample data for each dataset
666
+
667
+ for d in state.get("input_data", []):
668
+
669
+ var_name = d.get("variable_name")
670
+
671
+ if var_name:
672
+
673
+
674
+
675
+ variables.append(var_name)
676
+
677
+ summary += f"\n\nVariable: {var_name}\n"
678
+
679
+ summary += f"Description: {d.get('data_description', 'No description')}\n"
680
+
681
+
682
+
683
+ # Add sample data if available
684
+
685
+ if var_name in st.session_state.in_memory_datasets:
686
+
687
+ df = st.session_state.in_memory_datasets[var_name]
688
+
689
+ summary += "\nSample Data (first 5 rows):\n"
690
+
691
+ summary += df.head(5).to_string()
692
+
693
+
694
+
695
+ if "current_variables" in state:
696
+
697
+ remaining_variables = [v for v in state["current_variables"] if v not in variables and not v.startswith("_")]
698
+
699
+
700
+
701
+ for v in remaining_variables:
702
+
703
+
704
+
705
+ var_value = state["current_variables"].get(v)
706
+
707
+
708
+
709
+ if isinstance(var_value, pd.DataFrame):
710
+
711
+ summary += f"\n\nVariable: {v} (DataFrame with shape {var_value.shape})"
712
+
713
+ else:
714
+
715
+ summary += f"\n\nVariable: {v}"
716
+
717
+ return summary
718
+
719
+
720
+
721
+ def route_to_tools(state: AgentState) -> Literal["tools", "__end__"]:
722
+
723
+ """Determine if we should route to tools or end the chain"""
724
+
725
+ if messages := state.get("messages", []):
726
+
727
+ ai_message = messages[-1]
728
+
729
+ else:
730
+
731
+ raise ValueError(f"No messages found in input state to tool_edge: {state}")
732
+
733
+
734
+
735
+ if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
736
+
737
+ return "tools"
738
+
739
+
740
+
741
+ return "__end__"
742
+
743
+
744
+
745
+ def call_model(state: AgentState):
746
+
747
+ """Call the LLM to get a response"""
748
+
749
+ current_data_template = """The following data is available:\n{data_summary}"""
750
+
751
+ current_data_message = HumanMessage(
752
+
753
+ content=current_data_template.format(data_summary=create_data_summary(state))
754
+
755
+ )
756
+
757
+ messages = [current_data_message] + state["messages"]
758
+
759
+
760
+
761
+ # Get the initialized LLM
762
+
763
+ llm = initialize_llm()
764
+
765
+ if llm is None:
766
+
767
+ return {"messages": [AIMessage(content="Please configure a valid API key and model in the settings tab.")]}
768
+
769
+
770
+
771
+ # Create the model with bound tools
772
+
773
+ model = llm.bind_tools(tools)
774
+
775
+ model = chat_template | model
776
+
777
+
778
+
779
+ llm_outputs = model.invoke({"messages": messages})
780
+
781
+ return {"messages": [llm_outputs], "intermediate_outputs": [current_data_message.content]}
782
+
783
+
784
+
785
+ def call_tools(state: AgentState):
786
+
787
+ """Execute tools called by the LLM"""
788
+
789
+ last_message = state["messages"][-1]
790
+
791
+ tool_invocations = []
792
+
793
+
794
+
795
+ if isinstance(last_message, AIMessage) and hasattr(last_message, 'tool_calls'):
796
+
797
+ tool_invocations = [
798
+
799
+ ToolInvocation(
800
+
801
+ tool=tool_call["name"],
802
+
803
+ tool_input={**tool_call["args"], "graph_state": state}
804
+
805
+ ) for tool_call in last_message.tool_calls
806
+
807
+ ]
808
+
809
+ responses = tool_executor.batch(tool_invocations, return_exceptions=True)
810
+
811
+
812
+
813
+ tool_messages = []
814
+
815
+ state_updates = {}
816
+
817
+
818
+
819
+ for tc, response in zip(last_message.tool_calls, responses):
820
+
821
+ if isinstance(response, Exception):
822
+
823
+ print(f"Exception in tool execution: {str(response)}")
824
+
825
+ tool_messages.append(ToolMessage(
826
+
827
+ content=f"Error: {str(response)}",
828
+
829
+ name=tc["name"],
830
+
831
+ tool_call_id=tc["id"]
832
+
833
+ ))
834
+
835
+ continue
836
+
837
+
838
+
839
+ message, updates = response
840
+
841
+ tool_messages.append(ToolMessage(
842
+
843
+ content=str(message),
844
+
845
+ name=tc["name"],
846
+
847
+ tool_call_id=tc["id"]
848
+
849
+ ))
850
+
851
+
852
+
853
+ # Merge updates instead of overwriting
854
+
855
+ for key, value in updates.items():
856
+
857
+ if key in state_updates:
858
+
859
+ if isinstance(value, list) and isinstance(state_updates[key], list):
860
+
861
+ state_updates[key].extend(value)
862
+
863
+ elif isinstance(value, dict) and isinstance(state_updates[key], dict):
864
+
865
+ state_updates[key].update(value)
866
+
867
+ else:
868
+
869
+ state_updates[key] = value
870
+
871
+ else:
872
+
873
+ state_updates[key] = value
874
+
875
+
876
+
877
+ if 'messages' not in state_updates:
878
+
879
+ state_updates["messages"] = []
880
+
881
+
882
+
883
+ state_updates["messages"] = tool_messages
884
+
885
+ return state_updates
886
+
887
+
888
+
889
+ # Set up the graph
890
+
891
+ workflow = StateGraph(AgentState)
892
+
893
+ workflow.add_node("agent", call_model)
894
+
895
+ workflow.add_node("tools", call_tools)
896
+
897
+ workflow.add_conditional_edges(
898
+
899
+ "agent",
900
+
901
+ route_to_tools,
902
+
903
+ {
904
+
905
+ "tools": "tools",
906
+
907
+ "__end__": END
908
+
909
+ }
910
+
911
+ )
912
+
913
+ workflow.add_edge("tools", "agent")
914
+
915
+ workflow.set_entry_point("agent")
916
+
917
+
918
+
919
+ chain = workflow.compile()
920
+
921
+
922
+
923
+ def process_file_upload(files):
924
+
925
+ """Process uploaded files and return dataframe previews and column names"""
926
+
927
+ st.session_state.in_memory_datasets = {} # Clear previous datasets
928
+
929
+ st.session_state.dataset_metadata_list = [] # Clear previous metadata
930
+
931
+ st.session_state.persistent_vars.clear() # Clear persistent variables for new session
932
+
933
+
934
+
935
+ if not files:
936
+
937
+ return "No files uploaded.", [], ["No columns available"]
938
+
939
+
940
+
941
+ results = []
942
+
943
+ all_columns = [] # Track all columns from all datasets
944
+
945
+
946
+
947
+ for file in files:
948
+
949
+ try:
950
+
951
+ # Use file object directly
952
+
953
+ if file.name.endswith('.csv'):
954
+
955
+ df = pd.read_csv(file)
956
+
957
+ elif file.name.endswith(('.xls', '.xlsx')):
958
+
959
+ df = pd.read_excel(file)
960
+
961
+ else:
962
+
963
+ results.append(f"Unsupported file format: {file.name}. Please upload CSV or Excel files.")
964
+
965
+ continue
966
+
967
+
968
+
969
+ var_name = file.name.split('.')[0].replace('-', '_').replace(' ', '_').lower()
970
+
971
+ st.session_state.in_memory_datasets[var_name] = df
972
+
973
+
974
+
975
+ # Collect all columns
976
+
977
+ all_columns.extend(df.columns.tolist())
978
+
979
+
980
+
981
+ # Create dataset metadata
982
+
983
+ dataset_metadata = {
984
+
985
+ "variable_name": var_name,
986
+
987
+ "data_path": "in_memory",
988
+
989
+ "data_description": f"Dataset containing {df.shape[0]} rows and {df.shape[1]} columns. Columns: {', '.join(df.columns.tolist())}",
990
+
991
+ "original_filename": file.name
992
+
993
+ }
994
+
995
+
996
+
997
+ st.session_state.dataset_metadata_list.append(dataset_metadata)
998
+
999
+
1000
+
1001
+ # Return preview of the dataset
1002
+
1003
+ preview = f"### Dataset: {file.name}\nVariable name: `{var_name}`\n\n"
1004
+
1005
+ preview += df.head(10).to_markdown()
1006
+
1007
+ results.append(preview)
1008
+
1009
+ print(f"Successfully processed {file.name}")
1010
+
1011
+
1012
+
1013
+ except Exception as e:
1014
+
1015
+ print(f"Error processing {file.name}: {str(e)}")
1016
+
1017
+ results.append(f"Error processing {file.name}: {str(e)}")
1018
+
1019
+
1020
+
1021
+ # Get unique columns
1022
+
1023
+ unique_columns = []
1024
+
1025
+ seen = set()
1026
+
1027
+
1028
+
1029
+ for col in all_columns:
1030
+
1031
+ if col not in seen:
1032
+
1033
+ seen.add(col)
1034
+
1035
+ unique_columns.append(col)
1036
+
1037
+
1038
+
1039
+ if not unique_columns:
1040
+
1041
+ unique_columns = ["No columns available"]
1042
+
1043
+
1044
+
1045
+ print(f"Found {len(unique_columns)} unique columns across datasets")
1046
+
1047
+ return "\n\n".join(results), st.session_state.dataset_metadata_list, unique_columns
1048
+
1049
+
1050
+
1051
+ def get_columns():
1052
+
1053
+ """Directly gets columns from in-memory datasets"""
1054
+
1055
+ all_columns = []
1056
+
1057
+
1058
+
1059
+ for var_name, df in st.session_state.in_memory_datasets.items():
1060
+
1061
+ if isinstance(df, pd.DataFrame):
1062
+
1063
+ all_columns.extend(df.columns.tolist())
1064
+
1065
+
1066
+
1067
+ # Remove duplicates while preserving order
1068
+
1069
+ unique_columns = []
1070
+
1071
+ seen = set()
1072
+
1073
+
1074
+
1075
+ for col in all_columns:
1076
+
1077
+ if col not in seen:
1078
+
1079
+ seen.add(col)
1080
+
1081
+ unique_columns.append(col)
1082
+
1083
+
1084
+
1085
+ if not unique_columns:
1086
+
1087
+ unique_columns = ["No columns available"]
1088
+
1089
+
1090
+
1091
+ print(f"Populating dropdowns with {len(unique_columns)} columns")
1092
+
1093
+ return unique_columns
1094
+
1095
+
1096
+
1097
+ # === FUNCTIONS ===
1098
+
1099
+ import openai
1100
+
1101
+ import pandas as pd
1102
+
1103
+ import json
1104
+
1105
+ import re
1106
+
1107
+
1108
+
1109
+ def standard_clean(df):
1110
+
1111
+ df.columns = [re.sub(r'\W+', '_', col).strip().lower() for col in df.columns]
1112
+
1113
+ df.drop_duplicates(inplace=True)
1114
+
1115
+ df.dropna(axis=1, how='all', inplace=True)
1116
+
1117
+ df.dropna(axis=0, how='all', inplace=True)
1118
+
1119
+ for col in df.select_dtypes(include='object').columns:
1120
+
1121
+ df[col] = df[col].astype(str).str.strip()
1122
+
1123
+ return df
1124
+
1125
+
1126
+
1127
+ def query_openai(prompt):
+     try:
+         # Use the configured API key and model from session state
+         api_key = st.session_state.api_key
+         model = st.session_state.selected_model
+
+         if st.session_state.ai_provider == "openai":
+             client = openai.OpenAI(api_key=api_key)
+             response = client.chat.completions.create(
+                 model=model,
+                 messages=[{"role": "user", "content": prompt}],
+                 temperature=0.7
+             )
+             return response.choices[0].message.content
+         elif st.session_state.ai_provider == "groq":
+             from groq import Groq
+             client = Groq(api_key=api_key)
+             response = client.chat.completions.create(
+                 model=model,
+                 messages=[{"role": "user", "content": prompt}],
+                 temperature=0.7
+             )
+             return response.choices[0].message.content
+
+         # Unknown provider: return the same empty-dict fallback as the error path
+         return "{}"
+     except Exception as e:
+         print(f"API Error: {e}")
+         return "{}"
+
+ def llm_suggest_cleaning(df):
+     sample = df.head(10).to_csv(index=False)
+     prompt = f"""
+ You are a professional data wrangler. Below is a sample of a messy dataset.
+
+ Return a Python dictionary with the following keys:
+
+ 1. rename_columns – fix unclear or inconsistent column names
+ 2. convert_types – correct datatypes: int, float, str, or date
+ 3. fill_missing – use 'mean', 'median', 'mode', or a constant like 'Unknown' or 0
+ 4. value_map – map inconsistent values (e.g., yes/Yes/Y → Yes)
+
+ Do not drop any rows or columns. Your output must be a valid Python dict.
+
+ Example:
+ {{
+     "rename_columns": {{"dob": "date_of_birth"}},
+     "convert_types": {{"age": "int", "salary": "float", "signup_date": "date"}},
+     "fill_missing": {{"gender": "mode", "salary": -1}},
+     "value_map": {{
+         "gender": {{"M": "Male", "F": "Female"}},
+         "subscribed": {{"Y": "Yes", "N": "No"}}
+     }}
+ }}
+
+ Beyond these steps, study the data and apply any other cleaning this particular dataset needs.
+
+ Sample data:
+ {sample}
+ """
+     raw_response = query_openai(prompt)
+     try:
+         import ast
+         # literal_eval only accepts Python literals, unlike eval, so a reply
+         # containing arbitrary code raises instead of executing
+         suggestions = ast.literal_eval(raw_response)
+         return suggestions
+     except (ValueError, SyntaxError):
+         print("Could not parse suggestions.")
+         return {
+             "rename_columns": {},
+             "convert_types": {},
+             "fill_missing": {},
+             "value_map": {}
+         }
+
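One way to harden the dict-parsing step: `ast.literal_eval` accepts only Python literals, so a model reply that contains code raises an error instead of executing. A quick sketch with a hypothetical reply string:

```python
import ast

reply = '{"rename_columns": {"dob": "date_of_birth"}, "convert_types": {"age": "int"}}'
suggestions = ast.literal_eval(reply)
assert suggestions["rename_columns"]["dob"] == "date_of_birth"

# A non-literal expression is rejected, not evaluated
try:
    ast.literal_eval("__import__('os').system('echo pwned')")
except (ValueError, SyntaxError):
    print("rejected")
```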
+ def apply_suggestions(df, suggestions):
+     df.rename(columns=suggestions.get("rename_columns", {}), inplace=True)
+
+     for col, dtype in suggestions.get("convert_types", {}).items():
+         if col not in df.columns:
+             continue
+         try:
+             if dtype == "int":
+                 df[col] = pd.to_numeric(df[col], errors='coerce').astype("Int64")
+             elif dtype == "float":
+                 df[col] = pd.to_numeric(df[col], errors='coerce')
+             elif dtype == "str":
+                 df[col] = df[col].astype(str)
+             elif dtype == "date":
+                 df[col] = pd.to_datetime(df[col], errors='coerce')
+         except Exception:
+             print(f"Failed to convert {col} to {dtype}")
+
+     for col, method in suggestions.get("fill_missing", {}).items():
+         if col not in df.columns:
+             continue
+         try:
+             if method == "mean":
+                 df[col] = df[col].fillna(df[col].mean())
+             elif method == "median":
+                 df[col] = df[col].fillna(df[col].median())
+             elif method == "mode":
+                 df[col] = df[col].fillna(df[col].mode().iloc[0])
+             else:
+                 # Constant fill value (e.g. 'Unknown' or 0, as in the prompt example)
+                 df[col] = df[col].fillna(method)
+         except Exception:
+             print(f"Could not fill missing values for {col}")
+
+     for col, mapping in suggestions.get("value_map", {}).items():
+         if col in df.columns:
+             try:
+                 df[col] = df[col].replace(mapping)
+             except Exception:
+                 print(f"Could not map values in {col}")
+
+     return df
+
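The `fill_missing` branch dispatches each strategy name onto a pandas aggregate. A compact, self-contained sketch of the same dispatch on a hypothetical mini-frame:

```python
import pandas as pd

df = pd.DataFrame({"age": [20.0, None, 30.0], "gender": ["M", None, "M"]})
fill_missing = {"age": "mean", "gender": "mode"}

for col, method in fill_missing.items():
    if method == "mean":
        df[col] = df[col].fillna(df[col].mean())
    elif method == "median":
        df[col] = df[col].fillna(df[col].median())
    elif method == "mode":
        # mode() drops NaN by default; take the first (most frequent) value
        df[col] = df[col].fillna(df[col].mode().iloc[0])
    else:
        # Anything else is treated as a constant fill value
        df[col] = df[col].fillna(method)

print(df["age"].tolist())     # [20.0, 25.0, 30.0]
print(df["gender"].tolist())  # ['M', 'M', 'M']
```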
+ def capture_dashboard_screenshot():
+     """Capture the entire dashboard as a single image"""
+     try:
+         # Create a figure that combines all dashboard plots
+         import plotly.graph_objects as go
+         from plotly.subplots import make_subplots
+
+         # Create a 2x2 subplot
+         fig = make_subplots(rows=2, cols=2,
+                             subplot_titles=["Visualization 1", "Visualization 2",
+                                             "Visualization 3", "Visualization 4"])
+
+         # Add each plot from the dashboard to the combined figure
+         for i, plot in enumerate(st.session_state.dashboard_plots):
+             if plot is not None:
+                 row = (i // 2) + 1
+                 col = (i % 2) + 1
+
+                 # Extract traces from the original figure and add to our subplot
+                 for trace in plot.data:
+                     fig.add_trace(trace, row=row, col=col)
+
+                 # Copy axis titles onto the matching subplot axis. A standalone
+                 # figure only has 'xaxis'/'yaxis'; in the 2x2 grid the axes are
+                 # numbered row-major: xaxis, xaxis2, xaxis3, xaxis4. Only the
+                 # title is copied so the subplot domains are not clobbered.
+                 for axis_type in ['xaxis', 'yaxis']:
+                     subplot_axis = f"{axis_type}{i + 1 if i > 0 else ''}"
+                     src_axis = getattr(plot.layout, axis_type, None)
+                     if src_axis is not None and src_axis.title.text:
+                         fig.update_layout({subplot_axis: {"title": src_axis.title.text}})
+
+         # Update layout for better appearance
+         fig.update_layout(
+             height=800,
+             width=1000,
+             title_text="Dashboard Overview",
+             showlegend=False,
+         )
+
+         # Save to a temporary file (higher scale for better resolution)
+         dashboard_path = f"{st.session_state.temp_dir}/dashboard_combined.png"
+         fig.write_image(dashboard_path, scale=2)
+         return dashboard_path
+
+     except Exception as e:
+         import traceback
+         print(f"Error capturing dashboard: {str(e)}")
+         print(traceback.format_exc())
+         return None
+
+ def generate_enhanced_pdf_report():
+     """Generate an enhanced PDF report with proper handling of base64 image data"""
+     try:
+         # Create a buffer for the PDF
+         buffer = io.BytesIO()
+
+         # Create the PDF document
+         doc = SimpleDocTemplate(buffer, pagesize=letter,
+                                 leftMargin=36, rightMargin=36,
+                                 topMargin=36, bottomMargin=36)
+
+         # Create custom styles with better formatting
+         styles = getSampleStyleSheet()
+
+         styles.add(ParagraphStyle(
+             name='ReportTitle',
+             parent=styles['Heading1'],
+             fontSize=24,
+             alignment=1,  # Centered
+             spaceAfter=20,
+             textColor='#2C3E50'  # Dark blue
+         ))
+
+         styles.add(ParagraphStyle(
+             name='SectionHeader',
+             parent=styles['Heading2'],
+             fontSize=16,
+             spaceBefore=15,
+             spaceAfter=10,
+             textColor='#2C3E50',
+             borderWidth=1,
+             borderColor='#95A5A6',
+             borderPadding=5,
+             borderRadius=5
+         ))
+
+         styles.add(ParagraphStyle(
+             name='SubHeader',
+             parent=styles['Heading3'],
+             fontSize=14,
+             spaceBefore=10,
+             spaceAfter=8,
+             textColor='#34495E',
+             fontWeight='bold'
+         ))
+
+         styles.add(ParagraphStyle(
+             name='UserMessage',
+             parent=styles['Normal'],
+             fontSize=11,
+             leftIndent=10,
+             spaceBefore=8,
+             spaceAfter=4
+         ))
+
+         styles.add(ParagraphStyle(
+             name='AssistantMessage',
+             parent=styles['Normal'],
+             fontSize=11,
+             leftIndent=10,
+             spaceBefore=4,
+             spaceAfter=12,
+             textColor='#2980B9'
+         ))
+
+         styles.add(ParagraphStyle(
+             name='Timestamp',
+             parent=styles['Italic'],
+             fontSize=10,
+             textColor='#7F8C8D',
+             alignment=2  # Right aligned
+         ))
+
+         # Create the document content
+         elements = []
+
+         # Add title and timestamp
+         elements.append(Paragraph('Data Analysis Report', styles['ReportTitle']))
+         elements.append(Paragraph(f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}',
+                                   styles['Timestamp']))
+         elements.append(Spacer(1, 0.5*inch))
+
+         # Add conversation history with better formatting
+         elements.append(Paragraph('Analysis Conversation History', styles['SectionHeader']))
+
+         if st.session_state.chat_history:
+             for i, (user_msg, assistant_msg) in enumerate(st.session_state.chat_history):
+                 # Format user message with proper styling
+                 elements.append(Paragraph('<b>You:</b>', styles['SubHeader']))
+                 user_msg_formatted = user_msg.replace('\n', '<br/>')
+                 elements.append(Paragraph(user_msg_formatted, styles['UserMessage']))
+
+                 # Look for markdown image syntax with base64 data
+                 base64_pattern = r'!\[Visualization\]\(data:image\/png;base64,([^\)]+)\)'
+
+                 # Check if the message contains visualizations
+                 if '### Visualizations' in assistant_msg or re.search(base64_pattern, assistant_msg):
+                     # Split the message at the Visualizations header if it exists
+                     if '### Visualizations' in assistant_msg:
+                         parts = assistant_msg.split('### Visualizations', 1)
+                         text_part = parts[0]
+                         viz_part = "### Visualizations" + parts[1]
+                     else:
+                         # No header, but the message still embeds a visualization
+                         match = re.search(base64_pattern, assistant_msg)
+                         text_part = assistant_msg[:match.start()]
+                         viz_part = assistant_msg[match.start():]
+
+                     # Format the text part
+                     elements.append(Paragraph('<b>Assistant:</b>', styles['SubHeader']))
+                     text_part = text_part.replace('\n', '<br/>')
+                     elements.append(Paragraph(text_part, styles['AssistantMessage']))
+
+                     # Process visualizations
+                     matches = re.findall(base64_pattern, viz_part)
+                     for j, base64_data in enumerate(matches):
+                         try:
+                             # Decode the base64 image and write it to a temp file
+                             image_data = base64.b64decode(base64_data)
+                             temp_img_path = f"{st.session_state.temp_dir}/chat_viz_{i}_{j}.png"
+                             with open(temp_img_path, 'wb') as f:
+                                 f.write(image_data)
+
+                             # Add the image to the PDF
+                             elements.append(Paragraph('<b>Visualization:</b>', styles['SubHeader']))
+                             elements.append(Spacer(1, 0.1*inch))
+                             img = Image(temp_img_path, width=6*inch, height=4*inch)
+                             elements.append(img)
+                             elements.append(Spacer(1, 0.2*inch))
+                         except Exception as e:
+                             print(f"Error processing base64 image: {str(e)}")
+                             elements.append(Paragraph(f"[Error displaying visualization: {str(e)}]",
+                                                       styles['Normal']))
+                 else:
+                     # No visualizations, just format the text
+                     elements.append(Paragraph('<b>Assistant:</b>', styles['SubHeader']))
+                     assistant_msg_formatted = assistant_msg.replace('\n', '<br/>')
+                     if len(assistant_msg_formatted) > 1500:
+                         assistant_msg_formatted = assistant_msg_formatted[:1500] + '...'
+                     elements.append(Paragraph(assistant_msg_formatted, styles['AssistantMessage']))
+
+                 elements.append(Spacer(1, 0.2*inch))
+         else:
+             elements.append(Paragraph('No conversation history available.', styles['Normal']))
+
+         # Force a page break before the dashboard
+         elements.append(PageBreak())
+
+         # Add dashboard section header
+         elements.append(Paragraph('Dashboard Overview', styles['SectionHeader']))
+         elements.append(Spacer(1, 0.2*inch))
+
+         # Capture the dashboard as a single image
+         dashboard_img_path = capture_dashboard_screenshot()
+
+         if dashboard_img_path:
+             # Scale the image to fit the page width, preserving aspect ratio
+             available_width = doc.width
+             pil_img = PILImage.open(dashboard_img_path)
+             img_width, img_height = pil_img.size
+             scale_factor = available_width / img_width
+             new_height = img_height * scale_factor
+             img = Image(dashboard_img_path, width=available_width, height=new_height)
+             elements.append(img)
+         else:
+             # Fallback: add individual plots if the combined dashboard fails
+             plot_count = 0
+             for i, plot in enumerate(st.session_state.dashboard_plots):
+                 if plot is not None:
+                     plot_count += 1
+
+                     # Convert the plotly figure to a PNG in a temp file
+                     img_bytes = io.BytesIO()
+                     plot.write_image(img_bytes, format='png', width=500, height=300)
+                     img_bytes.seek(0)
+                     temp_img_path = f"{st.session_state.temp_dir}/plot_{i}.png"
+                     with open(temp_img_path, 'wb') as f:
+                         f.write(img_bytes.getvalue())
+
+                     # Add to PDF with caption and proper scaling
+                     elements.append(Paragraph(f'Dashboard Visualization {i+1}', styles['SubHeader']))
+                     elements.append(Spacer(1, 0.1*inch))
+                     img = Image(temp_img_path, width=6.5*inch, height=4*inch)
+                     elements.append(img)
+                     elements.append(Spacer(1, 0.3*inch))
+
+             if plot_count == 0:
+                 elements.append(Paragraph('No visualizations have been added to the dashboard.',
+                                           styles['Normal']))
+
+         # Build the PDF and return its bytes
+         doc.build(elements)
+         pdf_value = buffer.getvalue()
+         buffer.close()
+         return pdf_value
+
+     except Exception as e:
+         import traceback
+         print(f"Error generating enhanced PDF report: {str(e)}")
+         print(traceback.format_exc())
+         return None
+
+ def chat_with_workflow(message, history, dataset_info):
+     """Send user query to the workflow and get response"""
+     if not dataset_info:
+         return "Please upload at least one dataset before asking questions."
+
+     # Check if we have a valid API key and model
+     if not st.session_state.api_key:
+         return "Please set up your API key and model in the Settings tab before chatting."
+
+     print(f"Chat with workflow called with {len(dataset_info)} datasets")
+
+     try:
+         # Extract chat history for context (last 3 exchanges)
+         max_history = 3
+         previous_messages = []
+
+         if history:
+             start_idx = max(0, len(history) - max_history)
+             recent_history = history[start_idx:]
+             for exchange in recent_history:
+                 if exchange[0]:  # User message
+                     previous_messages.append(HumanMessage(content=exchange[0]))
+                 if exchange[1]:  # AI response
+                     previous_messages.append(AIMessage(content=exchange[1]))
+
+         # Initialize the workflow state
+         state = AgentState(
+             messages=previous_messages + [HumanMessage(content=message)],
+             input_data=dataset_info,
+             intermediate_outputs=[],
+             current_variables=st.session_state.persistent_vars,
+             output_image_paths=[]
+         )
+
+         # Execute the workflow
+         print("Executing workflow...")
+         result = chain.invoke(state)
+         print("Workflow execution completed")
+
+         # Extract messages from the result
+         messages = result["messages"]
+
+         # Format the response - only use the latest message
+         response = ""
+         if messages:
+             latest_message = messages[-1]
+             if hasattr(latest_message, "content"):
+                 content = latest_message.content
+
+                 # Remove any instance where the user's message is repeated
+                 if message in content:
+                     content = content.split(message)[-1].strip()
+
+                 # Remove any chat history markers
+                 content_lines = content.split('\n')
+                 filtered_lines = [line for line in content_lines
+                                   if not line.strip().startswith(("You:", "User:", "Human:", "Assistant:"))]
+                 content = '\n'.join(filtered_lines)
+
+                 response = content.strip() + "\n\n"
+
+         # Handle visualizations
+         if "output_image_paths" in result and result["output_image_paths"]:
+             response += "### Visualizations\n\n"
+             for img_path in result["output_image_paths"]:
+                 try:
+                     full_path = os.path.join(st.session_state.images_dir, img_path)
+                     with open(full_path, 'rb') as f:
+                         fig = pickle.load(f)
+
+                     # Convert the plotly figure to a PNG image
+                     img_bytes = BytesIO()
+                     fig.update_layout(width=800, height=500)
+                     pio.write_image(fig, img_bytes, format='png')
+                     img_bytes.seek(0)
+
+                     # Embed as a base64 markdown image
+                     b64_img = base64.b64encode(img_bytes.read()).decode()
+                     response += f"![Visualization](data:image/png;base64,{b64_img})\n\n"
+                 except Exception as e:
+                     response += f"Error loading visualization: {str(e)}\n\n"
+
+         return response
+
+     except Exception as e:
+         import traceback
+         print(f"Error in chat_with_workflow: {str(e)}")
+         print(traceback.format_exc())
+         return f"Error executing workflow: {str(e)}"
+
+ def auto_generate_dashboard(dataset_info):
+     """Generate an automatic dashboard with four plots"""
+     if not dataset_info:
+         return "Please upload a dataset first.", [None, None, None, None]
+
+     prompt = """
+ You are a data visualization expert. Given a dataset, identify the top 4 most insightful plots using statistical reasoning or patterns (correlation, distribution, trends).
+
+ Use plotly and store the plots in a list named plotly_figures.
+
+ Include multivariate plots using color/size/facets when helpful.
+ """
+
+     state = AgentState(
+         messages=[HumanMessage(content=prompt)],
+         input_data=dataset_info,
+         intermediate_outputs=[],
+         current_variables=st.session_state.persistent_vars,
+         output_image_paths=[]
+     )
+
+     result = chain.invoke(state)
+     figures = []
+
+     if "output_image_paths" in result:
+         for img_path in result["output_image_paths"][:4]:
+             try:
+                 full_path = os.path.join(st.session_state.images_dir, img_path)
+                 with open(full_path, 'rb') as f:
+                     fig = pickle.load(f)
+                 figures.append(fig)
+             except Exception as e:
+                 print(f"Error loading figure: {e}")
+
+     # Pad to exactly four slots
+     while len(figures) < 4:
+         figures.append(None)
+
+     st.session_state.dashboard_plots = figures
+     return "Dashboard generated!", figures
+
+ def generate_custom_plots_with_llm(dataset_info, x_col, y_col, facet_col):
+     """Generate custom plots based on user-selected columns"""
+     if not dataset_info or not x_col or not y_col:
+         return [None, None, None]
+
+     prompt = f"""
+ You are a data visualization expert.
+
+ Create 3 insightful visualizations using Plotly based on:
+
+ - X-axis: {x_col}
+ - Y-axis: {y_col}
+ - Facet (optional): {facet_col if facet_col != 'None' else 'None'}
+
+ Try to find interesting relationships, trends, or clusters using appropriate chart types.
+
+ Use the `plotly_figures` list and avoid calling fig.show().
+ """
+
+     state = AgentState(
+         messages=[HumanMessage(content=prompt)],
+         input_data=dataset_info,
+         intermediate_outputs=[],
+         current_variables=st.session_state.persistent_vars,
+         output_image_paths=[]
+     )
+
+     result = chain.invoke(state)
+     figures = []
+
+     if "output_image_paths" in result:
+         for img_path in result["output_image_paths"][:3]:
+             try:
+                 full_path = os.path.join(st.session_state.images_dir, img_path)
+                 with open(full_path, 'rb') as f:
+                     fig = pickle.load(f)
+                 figures.append(fig)
+             except Exception as e:
+                 print(f"Error loading figure: {e}")
+
+     # Pad to exactly three slots
+     while len(figures) < 3:
+         figures.append(None)
+     return figures
+
+ def remove_plot(index):
+     """Remove a plot from the dashboard"""
+     if 0 <= index < len(st.session_state.dashboard_plots):
+         st.session_state.dashboard_plots[index] = None
+
+ def respond(message):
+     """Handle chat message response"""
+     if not st.session_state.dataset_metadata_list:
+         bot_message = "Please upload at least one dataset before asking questions."
+     else:
+         bot_message = chat_with_workflow(message, st.session_state.chat_history, st.session_state.dataset_metadata_list)
+
+     st.session_state.chat_history.append((message, bot_message))
+     st.rerun()
+
+ def save_plot_to_dashboard(plot_index):
+     """Callback for the Add Plot button"""
+     for i in range(len(st.session_state.dashboard_plots)):
+         if st.session_state.dashboard_plots[i] is None:
+             # Found an empty slot
+             st.session_state.dashboard_plots[i] = st.session_state.custom_plots_to_save[plot_index]
+             return
+
+ # Streamlit UI
+ st.set_page_config(page_title="QueryMind 🧠", layout="wide")
+ st.title("QueryMind 🧠 - Data Assistant")
+ st.markdown("Upload your datasets, ask questions, and generate visualizations to gain insights.")
+
+ # Create tabs
+ tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs(["Upload Datasets", "Data Cleaning", "Chat with AI Assistant", "Auto Dashboard Generator", "Generate Report", "Settings"])
+
+ with tab1:
+     st.header("Upload Datasets")
+     uploaded_files = st.file_uploader("Upload CSV or Excel Files",
+                                       accept_multiple_files=True,
+                                       type=['csv', 'xlsx', 'xls'])
+
+     if uploaded_files and st.button("Process Uploaded Files"):
+         with st.spinner("Processing files..."):
+             preview, metadata_list, columns = process_file_upload(uploaded_files)
+             st.session_state.columns = columns
+
+         # Display basic information about processed files
+         st.success(f"✅ Successfully processed {len(uploaded_files)} file(s)")
+
+         # Show detailed preview for each dataset
+         st.subheader("Dataset Previews")
+
+         for dataset_name, df in st.session_state.in_memory_datasets.items():
+             with st.expander(f"Preview: {dataset_name}"):
+                 # Display dataset info
+                 st.write(f"**Rows:** {df.shape[0]} | **Columns:** {df.shape[1]}")
+
+                 # Display column information
+                 col_info = pd.DataFrame({
+                     'Column Name': df.columns,
+                     'Data Type': df.dtypes.astype(str),
+                     'Non-Null Count': df.count().values,
+                     'Sample Values': [', '.join(df[col].dropna().astype(str).head(3).tolist()) for col in df.columns]
+                 })
+
+                 # Show column information in a compact table
+                 st.write("**Column Information:**")
+                 st.dataframe(col_info, use_container_width=True)
+
+                 # Show actual data preview
+                 st.write("**Data Preview (First 10 rows):**")
+                 st.dataframe(df.head(10), use_container_width=True)
+
+         # Provide a hint for the next steps
+         st.info("👆 Click on the dataset names above to see detailed previews. Then proceed to the Data Cleaning tab to clean your data or Chat with AI Assistant to analyze it.")
+
+ with tab2:
+     st.header("Data Cleaning")
+
+     if 'cleaning_done' not in st.session_state:
+         st.session_state.cleaning_done = False
+     if 'cleaned_datasets' not in st.session_state:
+         st.session_state.cleaned_datasets = {}
+     if 'cleaning_summaries' not in st.session_state:
+         st.session_state.cleaning_summaries = {}
+
+     if st.session_state.get("in_memory_datasets"):
+         if not st.session_state.cleaning_done:
+             if st.button("Run Data Cleaning"):
+                 with st.spinner("Running LLM-assisted cleaning..."):
+                     for name, df in st.session_state.in_memory_datasets.items():
+                         raw_df = df.copy()
+                         df_std = standard_clean(raw_df.copy())
+                         suggestions = llm_suggest_cleaning(df_std.copy())
+                         df_clean = apply_suggestions(df_std.copy(), suggestions)
+                         st.session_state.cleaned_datasets[name] = df_clean
+                         st.session_state.cleaning_summaries[name] = suggestions
+                 st.session_state.cleaning_done = True
+                 st.rerun()
+             else:
+                 st.info("Click Run Data Cleaning to clean your datasets using the LLM.")
+         else:
+             for name, df_clean in st.session_state.cleaned_datasets.items():
+                 raw_df = st.session_state.in_memory_datasets[name]
+
+                 st.subheader(f"Dataset: {name}")
+                 col1, col2 = st.columns(2)
+
+                 with col1:
+                     st.markdown("Original Data (First 5 Rows)")
+                     st.dataframe(raw_df.head())
+
+                 with col2:
+                     st.markdown("Cleaned Data (First 5 Rows)")
+                     st.dataframe(df_clean.head())
+
+                 st.markdown("Summary of Cleaning Actions")
+                 suggestions = st.session_state.cleaning_summaries[name]
+                 summary_text = ""
+
+                 if suggestions:
+                     for key, value in suggestions.items():
+                         summary_text += f"**{key}**: {json.dumps(value, indent=2)}\n\n"
+                     st.markdown(summary_text)
+
+                 st.markdown("Refine the Cleaning (Natural Language Instructions)")
+                 user_input = st.text_input("Example: Convert 'dob' to datetime and fill missing with '2000-01-01'",
+                                            key=f"user_input_{name}")
+
+                 if f'corrections_{name}' not in st.session_state:
+                     st.session_state[f'corrections_{name}'] = []
+
+                 if st.button("Apply Correction", key=f'apply_correction_{name}'):
+                     if user_input.strip():
+                         correction_prompt = f"""
+ You are a data cleaning expert. Below is a previously cleaned dataset with these actions:
+
+ {summary_text}
+
+ The user now wants the following additional instruction:
+ "{user_input.strip()}"
+
+ Write only the Python code that modifies the pandas DataFrame `df` accordingly.
+ Do not include explanations or markdown.
+ """
+                         correction_code = query_openai(correction_prompt)
+
+                         try:
+                             df = st.session_state.cleaned_datasets[name].copy()
+                             local_vars = {"df": df}
+                             # Note: executes LLM-generated code against a copy of the
+                             # dataframe; the code is shown to the user below
+                             exec(correction_code, {}, local_vars)
+                             df_updated = local_vars["df"]
+
+                             st.session_state.cleaned_datasets[name] = df_updated
+                             st.session_state[f'corrections_{name}'].append((user_input, correction_code))
+                             st.success("Correction applied.")
+                             st.rerun()
+                         except Exception as e:
+                             st.error(f"Failed to apply correction: {str(e)}")
+
+                 if st.session_state[f'corrections_{name}']:
+                     st.markdown("Applied Corrections")
+                     for i, (msg, code) in enumerate(st.session_state[f'corrections_{name}']):
+                         st.markdown(f"**Instruction:** {msg}")
+                         st.code(code, language='python')
+
+         col1, col2 = st.columns([1, 2])
+         with col1:
+             if st.button("Reset Cleaning and Re-run"):
+                 st.session_state.cleaning_done = False
+                 st.rerun()
+
+         with col2:
+             if st.button("Finalize and Proceed to Visualizations"):
+                 st.session_state.cleaning_finalized = True
+                 st.rerun()
+     else:
+         st.info("Please upload and process datasets first.")
+
+ with tab3:
+     st.header("Chat with AI Assistant")
+
+     # Show API warning if not set
+     if not st.session_state.api_key:
+         st.warning("⚠️ Please set up your API key and model in the Settings tab before using the chat.")
+
+     st.markdown("""
+ ## Example Questions
+ - "What analysis can you perform on this dataset?"
+ - "Show me basic statistics for all columns"
+ - "Create a correlation heatmap"
+ - "Plot the distribution of a specific column"
+ - "What is the relationship between two columns?"
+ """)
+
+     # Display chat history
+     for exchange in st.session_state.chat_history:
+         with st.chat_message("user"):
+             st.write(exchange[0])
+         with st.chat_message("assistant"):
+             st.write(exchange[1])
+
+     # Chat input
+     if prompt := st.chat_input("Your question"):
+         with st.spinner("Thinking..."):
+             respond(prompt)
+
+ with tab4:
+     st.header("Auto Dashboard Generator")
+
+     # Dashboard controls
+     dashboard_title = st.text_input("Dashboard Title", placeholder="Enter your dashboard title")
+
+     col1, col2 = st.columns(2)
+
+     with col1:
+         if st.button("Generate Suggested Dashboard (Auto)"):
+             if not st.session_state.api_key:
+                 st.warning("⚠️ Please set up your API key and model in the Settings tab first.")
+             else:
+                 with st.spinner("Generating dashboard..."):
+                     message, figures = auto_generate_dashboard(st.session_state.dataset_metadata_list)
+                     st.success(message)
+
+     with col2:
+         if st.button("Refresh Column Options"):
+             st.session_state.columns = get_columns()
+             st.rerun()
+
+     # Dashboard display
+     st.subheader("Dashboard")
+
+     # Row 1
+     col1, col2 = st.columns(2)
+
+     with col1:
+         if st.session_state.dashboard_plots[0]:
+             st.plotly_chart(st.session_state.dashboard_plots[0], use_container_width=True)
+             if st.button("Remove Plot 1"):
+                 remove_plot(0)
+                 st.rerun()
+
+     with col2:
+         if st.session_state.dashboard_plots[1]:
+             st.plotly_chart(st.session_state.dashboard_plots[1], use_container_width=True)
+             if st.button("Remove Plot 2"):
+                 remove_plot(1)
+                 st.rerun()
+
+     # Row 2
+     col3, col4 = st.columns(2)
+
+     with col3:
+         if st.session_state.dashboard_plots[2]:
+             st.plotly_chart(st.session_state.dashboard_plots[2], use_container_width=True)
+             if st.button("Remove Plot 3"):
+                 remove_plot(2)
+                 st.rerun()
+
+     with col4:
+         if st.session_state.dashboard_plots[3]:
+             st.plotly_chart(st.session_state.dashboard_plots[3], use_container_width=True)
+             if st.button("Remove Plot 4"):
+                 remove_plot(3)
+                 st.rerun()
+
+     # Custom plot generator
+     st.subheader("Custom Plot Generator")
+
+     # Column selection
+     col1, col2, col3 = st.columns(3)
+
+     with col1:
+         x_axis = st.selectbox("X-axis Column", options=st.session_state.columns)
+
+     with col2:
+         y_axis = st.selectbox("Y-axis Column", options=st.session_state.columns)
+
+     with col3:
+         facet = st.selectbox("Facet (optional)", options=["None"] + st.session_state.columns)
+
+     if st.button("Generate Custom Visualizations"):
+         if not st.session_state.api_key:
+             st.warning("⚠️ Please set up your API key and model in the Settings tab first.")
+         else:
+             with st.spinner("Generating custom visualizations..."):
+                 custom_plots = generate_custom_plots_with_llm(st.session_state.dataset_metadata_list, x_axis, y_axis, facet)
+
+                 # Store plots in session state
+                 for i, plot in enumerate(custom_plots):
+                     if plot:
+                         st.session_state.custom_plots_to_save[i] = plot
+
+                 # Display custom plots with add buttons
+                 for i, plot in enumerate(custom_plots):
+                     if plot:
+                         st.plotly_chart(plot, use_container_width=True)
+                         st.button(
+                             f"Add Plot {i+1} to Dashboard",
+                             key=f"add_plot_{i}",
+                             on_click=save_plot_to_dashboard,
+                             args=(i,)
+                         )
+
+ with tab5:
+     st.header("Generate Analysis Report")
+
+     st.markdown("""
+ Generate a PDF report containing:
+ - Dashboard visualizations
+ - Chat conversation history
+ """)
+
+     report_title = st.text_input("Report Title (Optional)", "Data Analysis Report")
+
+     if st.button("Generate PDF Report"):
+         if not st.session_state.api_key:
+             st.warning("⚠️ Please set up your API key and model in the Settings tab first.")
+         else:
+             with st.spinner("Generating report..."):
+                 pdf_data = generate_enhanced_pdf_report()
+                 if pdf_data:
+                     # Build a base64 data-URI download link for the PDF
+                     b64_pdf = base64.b64encode(pdf_data).decode('utf-8')
+                     pdf_download_link = f'<a href="data:application/pdf;base64,{b64_pdf}" download="data_analysis_report.pdf">Download PDF Report</a>'
+                     st.markdown("### Your report is ready!")
+                     st.markdown(pdf_download_link, unsafe_allow_html=True)
+
+                     # Preview option (simplified)
+                     with st.expander("Preview Report"):
+                         st.warning("PDF preview is not available in Streamlit; please download the report to view it.")
+                 else:
+                     st.error("Failed to generate the report. Please try again.")
+
+ with tab6:
+     st.header("AI Provider Settings")
+
+     # AI Provider selection
+     provider = st.radio("Select AI Provider",
+                         options=["OpenAI", "Groq"],
+                         index=0 if st.session_state.ai_provider == "openai" else 1,
+                         horizontal=True)
+
+     # Update session state based on selection
+     st.session_state.ai_provider = provider.lower()
+
+     # API Key input
+     api_key = st.text_input("Enter API Key",
+                             value=st.session_state.api_key,
+                             type="password",
+                             help="Your API key for the selected provider")
+
+     # Display different model options based on provider
+     if st.session_state.ai_provider == "openai":
+         model_options = OPENAI_MODELS
+         model_help = "GPT-4 provides the best results but is slower. GPT-3.5-Turbo is faster but less capable."
+     else:  # groq
+         model_options = GROQ_MODELS
+         model_help = "Llama 3.3 70B is most capable. Gemma 2 9B offers a good balance. Llama 3 8B is fastest."
+
+     # Model selection
+     selected_model = st.selectbox("Select Model",
+                                   options=model_options,
+                                   index=model_options.index(st.session_state.selected_model) if st.session_state.selected_model in model_options else 0,
+                                   help=model_help)
+
+     # Save button
+     if st.button("Save Settings"):
+         st.session_state.api_key = api_key
+         st.session_state.selected_model = selected_model
+
+         # Test the API key and model
+         try:
+             # Initialize LLM using the provided settings
+             test_llm = initialize_llm()
+             if test_llm:
+                 st.success(f"✅ Successfully configured {provider} with model: {selected_model}")
+             else:
+                 st.error("Failed to initialize the AI provider. Please check your API key and model selection.")
+         except Exception as e:
+             st.error(f"Error testing settings: {str(e)}")
+
+     # Display current settings
+     st.subheader("Current Settings")
+     settings_info = f"""
+ - **Provider**: {st.session_state.ai_provider.upper()}
+ - **Model**: {st.session_state.selected_model}
+ - **API Key**: {'✅ Set' if st.session_state.api_key else '❌ Not Set'}
+ """
+     st.markdown(settings_info)
+
+     # Provider-specific information
+     if st.session_state.ai_provider == "openai":
+         st.info("""
+ **OpenAI Models Information:**
+ - **GPT-4**: Most powerful model, best for complex analysis and detailed explanations
+ - **GPT-4-Turbo**: Faster than GPT-4 with similar capabilities
+ - **GPT-4-Mini**: Economical option with good performance for standard tasks
+ - **GPT-3.5-Turbo**: Fastest option, suitable for basic analysis and visualization
+ """)
+     else:
+         st.info("""
+ **Groq Models Information:**
+ - **llama-3.3-70b-versatile**: Most powerful model for comprehensive analysis
+ - **gemma2-9b-it**: Good balance of speed and capabilities
+ - **llama3-8b-8192**: Fastest option for basic analysis tasks
+ """)
+
+     # Integration instructions
+     with st.expander("How to get API Keys"):
+         if st.session_state.ai_provider == "openai":
+             st.markdown("""
+ ### Getting an OpenAI API Key
+
+ 1. Go to [OpenAI's platform](https://platform.openai.com)
+ 2. Sign up or log in to your account
+ 3. Navigate to the API section
+ 4. Create a new API key
+ 5. Copy the key and paste it above
+
+ Note: OpenAI API usage incurs charges based on tokens used.
+ """)
+         else:
+             st.markdown("""
+ ### Getting a Groq API Key
+
+ 1. Go to [Groq's website](https://console.groq.com/keys)
+ 2. Sign up or log in to your account
+ 3. Navigate to the API Keys section
+ 4. Create a new API key
+ 5. Copy the key and paste it above
+
+ Note: Check Groq's pricing page for current rates.
+ """)
+
+ # Cleanup on app exit
+ def cleanup():
+     try:
+         shutil.rmtree(st.session_state.temp_dir)
+         print(f"Cleaned up temporary directory: {st.session_state.temp_dir}")
+     except Exception as e:
+         # st.session_state may already be torn down at interpreter exit,
+         # so failures here are logged rather than raised
+         print(f"Error cleaning up: {e}")
+
+ import atexit
+ atexit.register(cleanup)