Seth0330's picture
Update app.py
e32a04d verified
raw
history blame
12.2 kB
import html
import json
import os
import traceback

import requests
import streamlit as st
# --- ALWAYS INIT SESSION STATE FIRST (before any widgets)
# Each key is seeded only when absent, so values survive Streamlit reruns.
# The defaults tuple is rebuilt on every run, so no mutable default is shared.
for _state_key, _state_default in (
    ("json_data", {}),
    ("messages", []),
    ("temp_input", ""),
    ("files_loaded", False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# --- Page config
# The browser tab title and the on-page heading share the same text.
_PAGE_TITLE = "JSON-Backed AI Chat Agent"
st.set_page_config(layout="wide", page_title=_PAGE_TITLE)
st.title(_PAGE_TITLE)
# --- Load API key
# The key is read from the environment; without it the script run is halted
# after showing an error, so nothing below executes.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    st.error("❌ OPENAI_API_KEY not set in Settings → Secrets.")
    st.stop()

# Headers sent with every chat-completions request.
HEADERS = {
    "Authorization": "Bearer " + OPENAI_API_KEY,
    "Content-Type": "application/json",
}
# Sidebar: heading plus a multi-file uploader restricted to JSON files.
with st.sidebar:
    st.header("Upload Multiple JSON Files")
    uploaded_files = st.file_uploader(
        label="Choose one or more JSON files",
        type="json",
        accept_multiple_files=True,
    )
# --- Only clear/load state when files change
# A (name, size) fingerprint of the current selection lets us notice when the
# user adds, removes or replaces a file after the first load; the plain
# `files_loaded` flag alone would keep serving the stale data.
_upload_signature = tuple((f.name, f.size) for f in uploaded_files or ())
_needs_reload = bool(uploaded_files) and (
    not st.session_state.files_loaded
    or st.session_state.get("files_signature") != _upload_signature
)

if _needs_reload:
    st.session_state.json_data.clear()
    file_summaries = []
    for f in uploaded_files:
        try:
            # Rewind first in case the buffer was already consumed by an
            # earlier run of the script.
            f.seek(0)
            content = json.load(f)
            st.session_state.json_data[f.name] = content
            # Preview the top-level keys (or the first record's keys for a
            # list of records); anything else has no keys to show.
            if isinstance(content, dict):
                keys = list(content.keys())
            elif isinstance(content, list) and content and isinstance(content[0], dict):
                keys = list(content[0].keys())
            else:
                keys = []
            key_preview = f"{keys[:10]}{'...' if len(keys) > 10 else ''}"
            file_summaries.append(f"{f.name}: keys={key_preview}")
            st.sidebar.success(f"Loaded: {f.name}")
            st.sidebar.write(f"Keys: {key_preview}")
        except Exception as e:
            # Best effort: one unreadable file must not block the others.
            st.sidebar.error(f"Error reading {f.name}: {e}")

    # --- System prompt with explicit few-shot examples
    prompt_text = (
        "You are an AI data analyst for uploaded JSON files. "
        "Each file may have different structures and keys, including lists and nested dictionaries. "
        "You have access to a function 'search_all_jsons' that finds all records in all JSON files where a key matches a value, recursively. "
        "If a user asks about groups of people or wants to know counts such as 'How many females are there?', "
        "interpret this as 'search for all records where gender equals female'. "
        "Always use the 'search_all_jsons' function with key='gender' and value='female' for such queries, unless another key/value is clear from context. "
        "If someone asks 'How many males?', search for gender equals male, etc. "
        "EXAMPLES:\n"
        "User: How many females are there?\n"
        "Assistant: (Call search_all_jsons with key='gender', value='female')\n"
        "User: How many males are there?\n"
        "Assistant: (Call search_all_jsons with key='gender', value='male')\n"
        "User: Show all females\n"
        "Assistant: (Call search_all_jsons with key='gender', value='female')\n"
        "User: How many people are named Emily?\n"
        "Assistant: (Call search_all_jsons with key='firstName', value='Emily')"
    )
    if file_summaries:
        # The per-file functions take a file name, so tell the model which
        # files exist and what their leading keys look like.
        prompt_text += "\n\nUploaded files:\n" + "\n".join(file_summaries)
    system_message = {"role": "system", "content": prompt_text}

    # A new data set starts a fresh conversation seeded with the prompt.
    st.session_state.messages = [system_message]
    st.session_state.files_loaded = True
    st.session_state.files_signature = _upload_signature
elif not uploaded_files:
    # Nothing selected: drop any previously loaded data.
    st.session_state.json_data.clear()
    st.session_state.files_loaded = False
# --- Recursive search for key/value in all files
def search_all_jsons(key, value):
    """Find every dict, in every loaded file, whose ``key`` equals ``value``.

    Values are compared as lower-cased strings. Each hit is returned as a
    copy of the matching dict with an extra ``"__file__"`` entry naming the
    file it came from. Hits are listed file by file in depth-first order.
    """
    wanted = str(value).lower()

    def _hits(node):
        # Yield matching dicts under `node`, parents before their children.
        if isinstance(node, dict):
            if key in node and str(node[key]).lower() == wanted:
                yield node
            children = node.values()
        elif isinstance(node, list):
            children = node
        else:
            return
        for child in children:
            yield from _hits(child)

    matches = []
    for source_name, payload in st.session_state.json_data.items():
        matches.extend({**hit, "__file__": source_name} for hit in _hits(payload))
    return matches
# --- Other functions for LLM
def search_json(file_name, key, value):
    """Find dicts inside one loaded file whose ``key`` equals ``value``.

    The whole structure is walked (nested dicts and lists included) and
    values are compared as lower-cased strings. Returns the list of matching
    dicts in depth-first order, or ``{"error": <message>}`` if anything
    fails, e.g. when ``file_name`` has not been loaded.
    """
    try:
        pending = [st.session_state.json_data[file_name]]
        wanted = str(value).lower()
        hits = []
        # Explicit stack; children are pushed reversed so they are popped in
        # their original order, giving a parent-first, left-to-right walk.
        while pending:
            node = pending.pop()
            if isinstance(node, dict):
                if key in node and str(node[key]).lower() == wanted:
                    hits.append(node)
                pending.extend(reversed(list(node.values())))
            elif isinstance(node, list):
                pending.extend(reversed(node))
        return hits
    except Exception as e:
        return {"error": str(e)}
def list_keys(file_name):
    """Return the top-level keys of a loaded file.

    For a dict the dict's own keys are returned; for a non-empty list whose
    first element is a dict, that first element's keys are returned; any
    other shape yields ``[]``. Failures (such as an unknown ``file_name``)
    are reported as ``{"error": <message>}``.
    """
    try:
        sample = st.session_state.json_data[file_name]
        if isinstance(sample, list) and sample:
            # A list of records is described by its first record.
            sample = sample[0]
        return list(sample) if isinstance(sample, dict) else []
    except Exception as e:
        return {"error": str(e)}
def count_key_occurrences(file_name, key):
    """Count the dicts in a loaded file that contain ``key``.

    The whole structure is walked, so keys inside nested dicts and lists are
    counted too, matching the recursive behaviour of the search functions.
    For flat data (a single dict, or a list of flat dicts) the result is the
    same as a top-level-only count.

    Returns the count as an int, or ``{"error": <message>}`` if anything
    fails, e.g. when ``file_name`` has not been loaded.
    """
    try:
        total = 0
        pending = [st.session_state.json_data[file_name]]
        # Iterative walk; visiting order is irrelevant for a count.
        while pending:
            node = pending.pop()
            if isinstance(node, dict):
                if key in node:
                    total += 1
                pending.extend(node.values())
            elif isinstance(node, list):
                pending.extend(node)
        return total
    except Exception as e:
        return {"error": str(e)}
# --- Function schema
def _function_spec(name, description, arguments):
    """Build one function description for the chat-completions request.

    ``arguments`` maps each argument name to its description. Every argument
    is declared as a string and listed as required, in the given order.
    """
    return {
        "name": name,
        "description": description,
        "parameters": {
            "type": "object",
            "properties": {
                arg_name: {"type": "string", "description": arg_help}
                for arg_name, arg_help in arguments.items()
            },
            "required": list(arguments),
        },
    }


# Functions the model may call; the names match the Python functions that
# implement them.
function_schema = [
    _function_spec(
        "search_json",
        "Find records in the specified JSON file where key matches a given value.",
        {
            "file_name": "The uploaded JSON file to search.",
            "key": "The key/field to filter by.",
            "value": "The value to match.",
        },
    ),
    _function_spec(
        "list_keys",
        "List all top-level keys in a given JSON file.",
        {"file_name": "The uploaded JSON file."},
    ),
    _function_spec(
        "count_key_occurrences",
        "Count the number of times a given key appears in a JSON file.",
        {
            "file_name": "The uploaded JSON file.",
            "key": "The key to count.",
        },
    ),
    _function_spec(
        "search_all_jsons",
        "Search all uploaded JSON files recursively for dicts where a key matches a value.",
        {
            "key": "The key to search for (e.g. 'gender')",
            "value": "The value to match (e.g. 'female')",
        },
    ),
]
# --- Conversation UI
st.markdown("### Conversation")
# Index 0 is the system prompt, which is not shown to the user.
# Everything interpolated below comes from the user, the model or the
# uploaded files, and is rendered with unsafe_allow_html=True, so it is
# HTML-escaped first. quote=False is enough because the text only ever lands
# in element content, never inside an attribute.
for msg in st.session_state.messages[1:]:
    role = msg["role"]
    if role == "user":
        user_text = html.escape(str(msg["content"]), quote=False)
        st.markdown(
            f"<div style='color: #4F8BF9;'><b>User:</b> {user_text}</div>",
            unsafe_allow_html=True,
        )
    elif role == "assistant":
        # The stored content may be None as well as missing.
        content = msg.get("content") or ""
        if content.strip():
            agent_text = html.escape(content, quote=False)
            st.markdown(
                f"<div style='color: #1C6E4C;'><b>Agent:</b> {agent_text}</div>",
                unsafe_allow_html=True,
            )
        else:
            st.markdown(
                "<div style='color: #DC143C;'><b>Agent:</b> [No response generated]</div>",
                unsafe_allow_html=True,
            )
    elif role == "function":
        func_label = html.escape(str(msg["name"]), quote=False)
        try:
            result = json.loads(msg["content"])
        except (TypeError, ValueError):
            # Not valid JSON: show the raw payload instead.
            raw_text = html.escape(str(msg["content"]), quote=False)
            st.markdown(
                f"<b>Function '{func_label}' output:</b> {raw_text}",
                unsafe_allow_html=True,
            )
        else:
            pretty = html.escape(json.dumps(result, indent=2), quote=False)
            st.markdown(
                f"<details><summary><b>Function '{func_label}' output:</b></summary><pre>{pretty}</pre></details>",
                unsafe_allow_html=True,
            )
# --- Chat input and OpenAI handling
_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
_CHAT_MODEL = "gpt-4.1"


def _windowed_history(messages, limit):
    """Return a copy of ``messages`` capped at ``limit`` entries.

    When trimming is needed the first message (the system prompt) is kept
    and the remainder is filled with the most recent messages.
    """
    if len(messages) > limit:
        return [messages[0]] + messages[-(limit - 1):]
    return list(messages)


def _request_completion(messages, max_tokens, **extra):
    """POST one chat-completions request and return the first choice's message.

    ``extra`` is merged into the request body (used for the function-calling
    fields). Raises on HTTP errors via ``raise_for_status``.
    """
    response = requests.post(
        _CHAT_COMPLETIONS_URL,
        headers=HEADERS,
        json={
            "model": _CHAT_MODEL,
            "messages": messages,
            "temperature": 0,
            "max_tokens": max_tokens,
            **extra,
        },
        timeout=60,
    )
    response.raise_for_status()
    return response.json()["choices"][0]["message"]


def _run_function(func_name, args):
    """Execute the local function the model asked for and return its result."""
    handlers = {
        "search_json": lambda a: search_json(a.get("file_name"), a.get("key"), a.get("value")),
        "list_keys": lambda a: list_keys(a.get("file_name")),
        "count_key_occurrences": lambda a: count_key_occurrences(a.get("file_name"), a.get("key")),
        "search_all_jsons": lambda a: search_all_jsons(a.get("key"), a.get("value")),
    }
    handler = handlers.get(func_name)
    if handler is None:
        return {"error": f"Unknown function: {func_name}"}
    return handler(args)


def send_message():
    """Handle a submitted chat message (``on_change`` callback of the input).

    Appends the user's text to the conversation and asks the model for a
    reply. If the model requests a function call, the function is run
    locally, its JSON result is appended as a ``function`` message, and a
    second request produces the final answer. The assistant's reply is
    appended and the input box is cleared. Any exception is shown in the UI
    together with its traceback rather than propagated.
    """
    try:
        user_input = st.session_state.temp_input
        if not (user_input and user_input.strip()):
            return

        history = st.session_state.messages
        history.append({"role": "user", "content": user_input})

        reply = _request_completion(
            _windowed_history(history, 10),
            1000,
            functions=function_schema,
            function_call="auto",
        )

        call = reply.get("function_call")
        if call:
            func_name = call["name"]
            # An empty arguments string is treated as "no arguments".
            args = json.loads(call.get("arguments") or "{}")
            history.append({
                "role": "function",
                "name": func_name,
                "content": json.dumps(_run_function(func_name, args)),
            })
            # Second round: let the model phrase an answer from the result.
            reply = _request_completion(_windowed_history(history, 12), 1500)

        # The API may return null content; store "" so the renderer can
        # always treat the content as a string.
        history.append({"role": "assistant", "content": reply.get("content") or ""})
        st.session_state.temp_input = ""
    except Exception as e:
        st.error("Exception: " + str(e))
        st.code(traceback.format_exc())
# Offer the chat box only once at least one file has been loaded; submitting
# the text triggers send_message through the widget's on_change callback.
if st.session_state.json_data:
    st.text_input("Your message:", key="temp_input", on_change=send_message)
else:
    st.info("Please upload at least one JSON file to start chatting.")