File size: 8,184 Bytes
96f0aad
2790bac
b66e512
 
52ac874
96f0aad
e1dabbf
 
037b17d
e1dabbf
 
fcc25ba
037b17d
e1dabbf
b66e512
 
8bf44c2
 
 
b66e512
 
 
 
 
e1dabbf
2790bac
 
 
8a030d9
 
 
 
 
 
 
 
2790bac
52ac874
2790bac
b66e512
52ac874
8a030d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2790bac
 
52ac874
 
 
2790bac
52ac874
 
 
e1dabbf
fb5cdfa
 
 
8a030d9
fb5cdfa
cb5bb24
fb5cdfa
cb5bb24
 
 
 
 
 
 
e1dabbf
fb5cdfa
 
 
 
 
 
 
 
 
fcc25ba
fb5cdfa
 
 
 
cb5bb24
 
 
 
 
 
 
 
 
 
 
 
 
 
fb5cdfa
 
 
cb5bb24
 
 
 
 
e1dabbf
 
fcc25ba
e1dabbf
 
 
 
 
 
 
 
 
 
 
b22ffb2
 
 
 
 
8a030d9
 
 
 
 
 
b22ffb2
 
b66e512
 
 
e1dabbf
b22ffb2
 
 
b66e512
b22ffb2
b66e512
 
52ac874
b22ffb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a030d9
 
 
 
 
 
b22ffb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1dabbf
b22ffb2
ed83876
b22ffb2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import html
import json
import os

import pandas as pd
import requests
import streamlit as st

# --- Page config
# NOTE: st.set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="CSV-Backed AI Chat Agent", layout="wide")

# --- Title & image
st.title("CSV-Backed AI Chat Agent")
st.image("./nadi-lok-image.png")

# --- Load API key
# Read from the environment. NOTE(review): the error text refers to
# "Settings → Secrets" — presumably a Streamlit Cloud deployment where
# secrets surface as environment variables; confirm for other hosts.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    st.error("❌ OPENAI_API_KEY not set in Settings → Secrets.")
    st.stop()

# Shared headers for every OpenAI REST call (bearer auth + JSON body).
HEADERS = {
    "Authorization": f"Bearer {OPENAI_API_KEY}",
    "Content-Type": "application/json",
}

# --- Sidebar: CSV upload & preview
st.sidebar.header("Upload CSV File")
uploaded_file = st.sidebar.file_uploader("Choose a CSV file", type="csv")

# --- Conversation memory: Use Streamlit session state
if "messages" not in st.session_state:
    st.session_state.messages = []
if "temp_input" not in st.session_state:
    st.session_state.temp_input = ""

# --- Load the CSV (re-read on every rerun) and keep the system prompt in
# sync with the currently uploaded file's columns. The chat history is reset
# only when no system prompt exists yet; when the user uploads a file with
# different columns, the system prompt is replaced in place so the rest of
# the conversation is preserved.
if uploaded_file is not None:
    try:
        df = pd.read_csv(uploaded_file)
        st.sidebar.success("File uploaded successfully!")
        st.sidebar.write("Preview of the uploaded file:")
        st.sidebar.dataframe(df.head())
        columns = ", ".join(df.columns)
        system_message = {
            "role": "system",
            "content": (
                f"You are an AI data analyst for a CSV file with these columns: {columns}. "
                "When the user asks a question, always use the most relevant function to get the answer directly. "
                "Do not describe your plan or reasoning steps. Do not ask the user for clarification. "
                "Just call the function needed and give the answer, as briefly as possible. "
                "If you need to search or filter the CSV, use the 'search_csv' function. "
                "If you need to count unique values, use the 'count_unique' function. "
                "If you use 'search_csv', use Pandas query syntax."
            ),
        }
        # Start fresh when there is no system prompt at position 0; otherwise
        # only swap the prompt when its content (i.e. the columns) changed.
        # (Replaces the original's redundant double-check and fragile
        # substring test on the role with explicit equality checks.)
        if not st.session_state.messages or st.session_state.messages[0].get("role") != "system":
            st.session_state.messages = [system_message]
        elif st.session_state.messages[0].get("content") != system_message["content"]:
            st.session_state.messages[0] = system_message
    except Exception as e:
        st.sidebar.error(f"Error reading file: {e}")
        df = None
else:
    df = None

if df is not None:
    st.markdown(f"**Loaded CSV:** {df.shape[0]} rows × {df.shape[1]} columns")

# --- Functions for function calling
def search_csv(query: str):
    """Filter the loaded CSV with a Pandas query string.

    Returns at most the first 10 matching rows as a list of record dicts,
    or an ``{"error": ...}`` dict when the query (or the DataFrame state)
    is invalid — errors are reported back to the model, never raised.
    """
    try:
        # Cap the result at 10 rows so a broad query cannot blow up the
        # function-call payload sent back to the API.
        return df.query(query).head(10).to_dict(orient="records")
    except Exception as exc:
        return {"error": f"Invalid query. Example: 'price > 100'. Details: {str(exc)}"}

def count_unique(column: str):
    """Count the distinct values in *column* of the loaded CSV.

    Returns ``{"column": ..., "unique_count": int}`` on success, or an
    ``{"error": ...}`` dict when the column is missing or not countable —
    errors are reported back to the model, never raised.
    """
    try:
        distinct = df[column].nunique()
        return {"column": column, "unique_count": int(distinct)}
    except Exception as exc:
        return {"error": f"Column '{column}' not found or not countable. Details: {str(exc)}"}

# --- Function schemas for OpenAI
# JSON-schema descriptions of the two local tools, sent as the "functions"
# parameter of the chat-completions request so the model can request a call.
# Names and parameter keys must match the Python functions above exactly.
function_schema = [
    {
        "name": "search_csv",
        "description": "Filter the CSV rows by a Pandas query. Example: price > 100",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "A Pandas query string, e.g. 'price > 100 and city == \"Miami\"'"
                },
            },
            "required": ["query"],
        },
    },
    {
        "name": "count_unique",
        "description": "Count the number of unique values in a column.",
        "parameters": {
            "type": "object",
            "properties": {
                "column": {
                    "type": "string",
                    "description": "The column name to count unique values."
                },
            },
            "required": ["column"],
        },
    }
]

# Dispatch table: maps the function name returned by the model to the local
# Python implementation. Keys must stay in sync with function_schema.
function_map = {
    "search_csv": search_csv,
    "count_unique": count_unique,
}

# --- Chat interface
# Render the conversation. Because these st.markdown calls use
# unsafe_allow_html=True, every piece of dynamic content (user text, model
# replies, function output) is passed through html.escape first — otherwise
# a chat message or a crafted CSV value could inject arbitrary HTML/script
# into the page.
st.markdown("### Conversation")
for msg in st.session_state.messages[1:]:  # Skip system message for display
    if msg["role"] == "user":
        st.markdown(f"<div style='color: #4F8BF9;'><b>User:</b> {html.escape(msg['content'])}</div>", unsafe_allow_html=True)
    elif msg["role"] == "assistant":
        st.markdown(f"<div style='color: #1C6E4C;'><b>Agent:</b> {html.escape(msg['content'])}</div>", unsafe_allow_html=True)
    elif msg["role"] == "function":
        # Function results are stored as JSON strings; pretty-print them in
        # a collapsible <details> block, falling back to the raw text.
        try:
            result = json.loads(msg["content"])
            st.markdown(f"<details><summary><b>Function '{html.escape(msg['name'])}' output:</b></summary><pre>{html.escape(json.dumps(result, indent=2))}</pre></details>", unsafe_allow_html=True)
        except Exception:
            st.markdown(f"<b>Function '{html.escape(msg['name'])}' output:</b> {html.escape(msg['content'])}", unsafe_allow_html=True)

# --- Sending user input and OpenAI call logic using a callback
def send_message():
    user_input = st.session_state.temp_input
    if user_input and user_input.strip():
        st.session_state.messages.append({"role": "user", "content": user_input})
        # Limit history for context size (keep system + last 8)
        chat_messages = st.session_state.messages
        if len(chat_messages) > 10:
            chat_messages = [chat_messages[0]] + chat_messages[-9:]
        else:
            chat_messages = chat_messages.copy()
        # First OpenAI call: Check for function call
        chat_resp = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=HEADERS,
            json={
                "model": "gpt-3.5-turbo-1106",
                "messages": chat_messages,
                "functions": function_schema,
                "function_call": "auto",
                "temperature": 0,
                "max_tokens": 1000,
            },
            timeout=60,
        )
        chat_resp.raise_for_status()
        response_json = chat_resp.json()
        msg = response_json["choices"][0]["message"]

        # If OpenAI requests a function call
        if msg.get("function_call"):
            func_name = msg["function_call"]["name"]
            args_json = msg["function_call"]["arguments"]
            args = json.loads(args_json)
            # Call the correct Python function
            if func_name in function_map:
                function_result = function_map[func_name](**args)
            else:
                function_result = {"error": f"Unknown function: {func_name}"}
            st.session_state.messages.append({
                "role": "function",
                "name": func_name,
                "content": json.dumps(function_result),
            })
            # Limit history again for second call
            followup_messages = st.session_state.messages
            if len(followup_messages) > 12:
                followup_messages = [followup_messages[0]] + followup_messages[-11:]
            else:
                followup_messages = followup_messages.copy()
            final_resp = requests.post(
                "https://api.openai.com/v1/chat/completions",
                headers=HEADERS,
                json={
                    "model": "gpt-3.5-turbo-1106",
                    "messages": followup_messages,
                    "temperature": 0,
                    "max_tokens": 1500,
                },
                timeout=60,
            )
            final_resp.raise_for_status()
            answer = final_resp.json()["choices"][0]["message"]["content"]
            st.session_state.messages.append({"role": "assistant", "content": answer})
        else:
            st.session_state.messages.append({"role": "assistant", "content": msg["content"]})

        st.session_state.temp_input = ""

# Render the input box only once a CSV is loaded; on_change invokes
# send_message when the user submits (presses Enter). The widget shares
# the "temp_input" session-state key that send_message reads and clears.
if df is not None:
    st.text_input("Your message:", key="temp_input", on_change=send_message)