import streamlit as st
import pandas as pd
import os
import requests
import json
# --- Page configuration ---
st.set_page_config(page_title="CSV-Backed AI Chat Agent", layout="wide")

# --- Header ---
st.title("CSV-Backed AI Chat Agent")
st.image("./nadi-lok-image.png")

# --- API key: required before anything else runs ---
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    # Hard stop: the rest of the app is useless without credentials.
    st.error("❌ OPENAI_API_KEY not set in Settings → Secrets.")
    st.stop()

# Shared headers for every OpenAI REST call below.
HEADERS = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {OPENAI_API_KEY}",
}

# --- Sidebar: CSV upload widget ---
st.sidebar.header("Upload CSV File")
uploaded_file = st.sidebar.file_uploader("Choose a CSV file", type="csv")

# --- Conversation memory lives in Streamlit session state so it survives reruns ---
st.session_state.setdefault("messages", [])
st.session_state.setdefault("temp_input", "")
# --- Only load df and reset chat on new file upload
if uploaded_file is not None:
    try:
        df = pd.read_csv(uploaded_file)
        st.sidebar.success("File uploaded successfully!")
        st.sidebar.write("Preview of the uploaded file:")
        st.sidebar.dataframe(df.head())
        # Build the system prompt from the file's actual column names so the
        # model knows exactly what it can query.
        columns = ", ".join(df.columns)
        system_message = {
            "role": "system",
            "content": (
                f"You are an AI data analyst for a CSV file with these columns: {columns}. "
                "When the user asks a question, always use the most relevant function to get the answer directly. "
                "Do not describe your plan or reasoning steps. Do not ask the user for clarification. "
                "Just call the function needed and give the answer, as briefly as possible. "
                "If you need to search or filter the CSV, use the 'search_csv' function. "
                "If you need to count unique values, use the 'count_unique' function. "
                "If you use 'search_csv', use Pandas query syntax."
            ),
        }
        # Only reset memory on new file load:
        # seed the history with the system prompt when there is no history yet,
        # or when the first stored message is not a system message.
        # NOTE(review): despite the comment above, a new upload only swaps the
        # system prompt — earlier chat turns are kept, not cleared.
        if not st.session_state.messages or (
            st.session_state.messages and
            ("system" not in st.session_state.messages[0].get("role", ""))
        ):
            st.session_state.messages = [system_message]
        # A different CSV was loaded: replace the stale system prompt in place,
        # preserving the rest of the conversation.
        elif (
            st.session_state.messages and
            st.session_state.messages[0].get("role", "") == "system" and
            st.session_state.messages[0].get("content", "") != system_message["content"]
        ):
            st.session_state.messages[0] = system_message
    except Exception as e:
        # Unreadable file (bad encoding, not CSV, ...): report and disable the chat.
        st.sidebar.error(f"Error reading file: {e}")
        df = None
else:
    # No upload yet; downstream UI checks `df is not None`.
    df = None
if df is not None:
    st.markdown(f"**Loaded CSV:** {df.shape[0]} rows × {df.shape[1]} columns")
# --- Functions for function calling
def search_csv(query: str):
    """Run a Pandas query string against the loaded CSV.

    Returns at most the first 10 matching rows as a list of record dicts,
    or an ``{"error": ...}`` dict when the query is invalid.
    """
    try:
        matched = df.query(query).head(10)  # limit for safety
        return matched.to_dict(orient="records")
    except Exception as e:
        return {"error": f"Invalid query. Example: 'price > 100'. Details: {str(e)}"}
def count_unique(column: str):
    """Count the distinct values in *column* of the loaded CSV.

    Returns ``{"column": ..., "unique_count": ...}`` on success, or an
    ``{"error": ...}`` dict when the column is missing or uncountable.
    """
    try:
        return {"column": column, "unique_count": int(df[column].nunique())}
    except Exception as e:
        return {"error": f"Column '{column}' not found or not countable. Details: {str(e)}"}
# --- Function schemas for OpenAI
# JSON schemas advertised to the model (legacy "functions" API) so it can
# request a call to search_csv or count_unique with typed arguments.
function_schema = [
    {
        "name": "search_csv",
        "description": "Filter the CSV rows by a Pandas query. Example: price > 100",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "A Pandas query string, e.g. 'price > 100 and city == \"Miami\"'"
                },
            },
            "required": ["query"],
        },
    },
    {
        "name": "count_unique",
        "description": "Count the number of unique values in a column.",
        "parameters": {
            "type": "object",
            "properties": {
                "column": {
                    "type": "string",
                    "description": "The column name to count unique values."
                },
            },
            "required": ["column"],
        },
    }
]
# Dispatch table: model-requested function name -> local implementation.
function_map = {
    "search_csv": search_csv,
    "count_unique": count_unique,
}
# --- Chat interface ---
# Render the stored conversation. Index 0 is the system prompt and is never
# shown. (Fix: the loop previously used enumerate() with an unused index.)
st.markdown("### Conversation")
for msg in st.session_state.messages[1:]:
    if msg["role"] == "user":
        st.markdown(f"<div style='color: #4F8BF9;'><b>User:</b> {msg['content']}</div>", unsafe_allow_html=True)
    elif msg["role"] == "assistant":
        st.markdown(f"<div style='color: #1C6E4C;'><b>Agent:</b> {msg['content']}</div>", unsafe_allow_html=True)
    elif msg["role"] == "function":
        # Function results are stored as JSON; show them pretty-printed in a
        # collapsible <details> block.
        try:
            result = json.loads(msg["content"])
            st.markdown(f"<details><summary><b>Function '{msg['name']}' output:</b></summary><pre>{json.dumps(result, indent=2)}</pre></details>", unsafe_allow_html=True)
        except Exception:
            # Payload was not valid JSON: fall back to the raw text.
            st.markdown(f"<b>Function '{msg['name']}' output:</b> {msg['content']}", unsafe_allow_html=True)
# --- Sending user input and OpenAI call logic using a callback
OPENAI_CHAT_URL = "https://api.openai.com/v1/chat/completions"
OPENAI_MODEL = "gpt-3.5-turbo-1106"


def _trimmed_history(limit: int):
    """Return the system message plus at most the last *limit* turns.

    Keeps the request inside the model's context window; assumes index 0 of
    ``st.session_state.messages`` is the system prompt.
    """
    history = st.session_state.messages
    if len(history) > limit + 1:
        return [history[0]] + history[-limit:]
    return history.copy()


def _chat_completion(messages, max_tokens: int, extra=None):
    """POST one chat-completion request and return the first choice's message.

    Raises ``requests.RequestException`` (incl. HTTPError) on transport or
    API failure; the caller surfaces the error to the UI.
    """
    body = {
        "model": OPENAI_MODEL,
        "messages": messages,
        "temperature": 0,
        "max_tokens": max_tokens,
    }
    if extra:
        body.update(extra)
    resp = requests.post(OPENAI_CHAT_URL, headers=HEADERS, json=body, timeout=60)
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]


def send_message():
    """Callback for the chat text input.

    Appends the user's message, asks OpenAI (optionally routing one CSV
    function call through ``function_map``), appends the reply, and clears
    the input box. Fixes over the original: API/network failures and
    malformed model-produced function arguments are reported via
    ``st.error`` instead of crashing the callback, and the input box is
    cleared even on failure.
    """
    user_input = st.session_state.temp_input
    if not (user_input and user_input.strip()):
        return
    st.session_state.messages.append({"role": "user", "content": user_input})
    try:
        # First call: let the model decide whether it needs a CSV function.
        msg = _chat_completion(
            _trimmed_history(9),
            1000,
            extra={"functions": function_schema, "function_call": "auto"},
        )
        if msg.get("function_call"):
            func_name = msg["function_call"]["name"]
            # The arguments string comes from the model and may be bad JSON.
            try:
                args = json.loads(msg["function_call"]["arguments"])
            except json.JSONDecodeError:
                args = None
            if args is None:
                function_result = {"error": "Model returned malformed function arguments."}
            elif func_name in function_map:
                function_result = function_map[func_name](**args)
            else:
                function_result = {"error": f"Unknown function: {func_name}"}
            st.session_state.messages.append({
                "role": "function",
                "name": func_name,
                "content": json.dumps(function_result),
            })
            # Second call: let the model phrase the function output as an answer.
            final_msg = _chat_completion(_trimmed_history(11), 1500)
            st.session_state.messages.append(
                {"role": "assistant", "content": final_msg["content"]}
            )
        else:
            st.session_state.messages.append(
                {"role": "assistant", "content": msg["content"]}
            )
    except requests.RequestException as exc:
        # Network/API failure: report it without killing the Streamlit script.
        st.error(f"OpenAI request failed: {exc}")
    finally:
        st.session_state.temp_input = ""
# Only show the input box once a CSV has been loaded; sending is handled by
# the send_message callback (which also clears the box via session state).
if df is not None:
    st.text_input("Your message:", key="temp_input", on_change=send_message)
|