# Seth0330's picture
# Update app.py
# 8a030d9 verified
# raw
# history blame
# 8.18 kB
import streamlit as st
import pandas as pd
import os
import requests
import json
# --- Streamlit page configuration ---
st.set_page_config(page_title="CSV-Backed AI Chat Agent", layout="wide")

# --- Header: title plus banner image ---
st.title("CSV-Backed AI Chat Agent")
st.image("./nadi-lok-image.png")

# --- OpenAI credentials come from the environment (Space Settings → Secrets) ---
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    st.error("❌ OPENAI_API_KEY not set in Settings → Secrets.")
    st.stop()

# Headers shared by every OpenAI REST request.
HEADERS = {
    "Authorization": f"Bearer {OPENAI_API_KEY}",
    "Content-Type": "application/json",
}

# --- Sidebar: CSV upload widget ---
st.sidebar.header("Upload CSV File")
uploaded_file = st.sidebar.file_uploader("Choose a CSV file", type="csv")

# --- Conversation memory lives in Streamlit session state so it survives reruns ---
for _state_key, _state_default in (("messages", []), ("temp_input", "")):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# --- Load the CSV (if any) and keep the chat's system prompt in sync ---
if uploaded_file is not None:
    try:
        df = pd.read_csv(uploaded_file)
        st.sidebar.success("File uploaded successfully!")
        st.sidebar.write("Preview of the uploaded file:")
        st.sidebar.dataframe(df.head())

        columns = ", ".join(df.columns)
        # System prompt steering the model toward terse, function-first answers.
        system_message = {
            "role": "system",
            "content": (
                f"You are an AI data analyst for a CSV file with these columns: {columns}. "
                "When the user asks a question, always use the most relevant function to get the answer directly. "
                "Do not describe your plan or reasoning steps. Do not ask the user for clarification. "
                "Just call the function needed and give the answer, as briefly as possible. "
                "If you need to search or filter the CSV, use the 'search_csv' function. "
                "If you need to count unique values, use the 'count_unique' function. "
                "If you use 'search_csv', use Pandas query syntax."
            ),
        }

        messages = st.session_state.messages
        # Reset memory only when there is no system prompt yet (fresh session /
        # first upload).  If a system prompt exists but a different CSV was
        # loaded (different columns ⇒ different prompt text), refresh it in
        # place so the conversation history is preserved.
        # NOTE: role is compared with == rather than substring membership —
        # `"system" not in role` would misclassify any role merely containing
        # the word "system".
        if not messages or messages[0].get("role") != "system":
            st.session_state.messages = [system_message]
        elif messages[0].get("content") != system_message["content"]:
            st.session_state.messages[0] = system_message
    except Exception as e:
        st.sidebar.error(f"Error reading file: {e}")
        df = None
else:
    df = None

if df is not None:
    st.markdown(f"**Loaded CSV:** {df.shape[0]} rows × {df.shape[1]} columns")
# --- Functions exposed to the model via function calling ---
def search_csv(query: str):
    """Filter the loaded DataFrame with a Pandas query string.

    Returns at most the first 10 matching rows as a list of record dicts,
    or an error payload the model can relay back to the user.
    """
    try:
        matches = df.query(query)
        # Cap the payload so huge result sets never flood the chat context.
        return matches.head(10).to_dict(orient="records")
    except Exception as exc:
        return {"error": f"Invalid query. Example: 'price > 100'. Details: {str(exc)}"}
def count_unique(column: str):
    """Return how many distinct values *column* holds in the loaded CSV."""
    try:
        unique_total = int(df[column].nunique())
    except Exception as exc:
        return {"error": f"Column '{column}' not found or not countable. Details: {str(exc)}"}
    return {"column": column, "unique_count": unique_total}
# --- JSON schemas advertising the callable tools to the OpenAI API ---
def _tool_schema(name, description, arg_name, arg_description):
    """Build a single-string-argument function schema in OpenAI's function-call format."""
    return {
        "name": name,
        "description": description,
        "parameters": {
            "type": "object",
            "properties": {
                arg_name: {"type": "string", "description": arg_description},
            },
            "required": [arg_name],
        },
    }


function_schema = [
    _tool_schema(
        "search_csv",
        "Filter the CSV rows by a Pandas query. Example: price > 100",
        "query",
        "A Pandas query string, e.g. 'price > 100 and city == \"Miami\"'",
    ),
    _tool_schema(
        "count_unique",
        "Count the number of unique values in a column.",
        "column",
        "The column name to count unique values.",
    ),
]
# Dispatch table: function name (exactly as the model emits it) → local implementation.
function_map = {fn.__name__: fn for fn in (search_csv, count_unique)}
# --- Render the conversation (the system prompt at index 0 is never shown) ---
st.markdown("### Conversation")
for msg in st.session_state.messages[1:]:
    role = msg["role"]
    if role == "user":
        st.markdown(f"<div style='color: #4F8BF9;'><b>User:</b> {msg['content']}</div>", unsafe_allow_html=True)
    elif role == "assistant":
        st.markdown(f"<div style='color: #1C6E4C;'><b>Agent:</b> {msg['content']}</div>", unsafe_allow_html=True)
    elif role == "function":
        # Function output is stored as JSON; pretty-print it in a collapsible
        # block, falling back to raw text when it does not parse.
        try:
            parsed = json.loads(msg["content"])
        except Exception:
            st.markdown(f"<b>Function '{msg['name']}' output:</b> {msg['content']}", unsafe_allow_html=True)
        else:
            st.markdown(f"<details><summary><b>Function '{msg['name']}' output:</b></summary><pre>{json.dumps(parsed, indent=2)}</pre></details>", unsafe_allow_html=True)
# --- Chat submission: OpenAI call logic, run as the text input's on_change callback ---
OPENAI_CHAT_URL = "https://api.openai.com/v1/chat/completions"
OPENAI_MODEL = "gpt-3.5-turbo-1106"


def _trim_history(messages, max_len):
    """Bound context size: keep the system prompt plus the most recent messages.

    Returns a new list (never the caller's list object) of at most *max_len*
    messages, always preserving messages[0].
    """
    if len(messages) > max_len:
        return [messages[0]] + messages[-(max_len - 1):]
    return messages.copy()


def send_message():
    """on_change callback for the chat input box.

    Appends the user's message to session memory, asks OpenAI for a reply
    (optionally routing through one local function call), appends the
    assistant's answer, and clears the input widget.
    """
    user_input = st.session_state.temp_input
    if not (user_input and user_input.strip()):
        return
    st.session_state.messages.append({"role": "user", "content": user_input})

    # First call: the model either answers directly or requests a function call.
    chat_messages = _trim_history(st.session_state.messages, 10)
    chat_resp = requests.post(
        OPENAI_CHAT_URL,
        headers=HEADERS,
        json={
            "model": OPENAI_MODEL,
            "messages": chat_messages,
            "functions": function_schema,
            "function_call": "auto",
            "temperature": 0,
            "max_tokens": 1000,
        },
        timeout=60,
    )
    chat_resp.raise_for_status()
    msg = chat_resp.json()["choices"][0]["message"]

    if msg.get("function_call"):
        func_name = msg["function_call"]["name"]
        args = json.loads(msg["function_call"]["arguments"])
        # Dispatch to the matching local function.
        if func_name in function_map:
            function_result = function_map[func_name](**args)
        else:
            function_result = {"error": f"Unknown function: {func_name}"}
        st.session_state.messages.append({
            "role": "function",
            "name": func_name,
            "content": json.dumps(function_result),
        })

        # Second call: hand the function output back so the model can answer.
        followup_messages = _trim_history(st.session_state.messages, 12)
        # The Chat Completions API requires the assistant message that
        # requested the function call to precede the role-"function" message;
        # splice it in just before the function result (it is not kept in
        # session memory, so the rendered chat is unaffected).
        followup_messages = followup_messages[:-1] + [msg, followup_messages[-1]]
        final_resp = requests.post(
            OPENAI_CHAT_URL,
            headers=HEADERS,
            json={
                "model": OPENAI_MODEL,
                "messages": followup_messages,
                "temperature": 0,
                "max_tokens": 1500,
            },
            timeout=60,
        )
        final_resp.raise_for_status()
        answer = final_resp.json()["choices"][0]["message"]["content"]
        st.session_state.messages.append({"role": "assistant", "content": answer})
    else:
        # No function call requested: record the direct answer.
        st.session_state.messages.append({"role": "assistant", "content": msg["content"]})

    # Clearing the widget's state is allowed inside an on_change callback.
    st.session_state.temp_input = ""
# Show the chat input only once a CSV is loaded; submitting the box fires
# send_message() via the on_change callback.
if df is not None:
    st.text_input("Your message:", key="temp_input", on_change=send_message)