Seth0330's picture
Update app.py
17fc679 verified
raw
history blame
9.48 kB
import html
import json
import os
import traceback

import requests
import streamlit as st
# Configure the page and draw the title.
st.set_page_config(page_title="JSON-Backed AI Chat Agent", layout="wide")
st.title("JSON-Backed AI Chat Agent")

# The OpenAI key is read from the environment; without it an error is shown
# and st.stop() is called.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    st.error("❌ OPENAI_API_KEY not set in Settings → Secrets.")
    st.stop()

# HTTP headers passed to the OpenAI requests made in send_message().
HEADERS = {
    "Authorization": "Bearer {}".format(OPENAI_API_KEY),
    "Content-Type": "application/json",
}
# Sidebar widget through which the user attaches one or more JSON files.
st.sidebar.header("Upload Multiple JSON Files")
uploaded_files = st.sidebar.file_uploader(
    "Choose one or more JSON files",
    type="json",
    accept_multiple_files=True,
)

# Seed each session-state slot only when it is absent, so values already
# present in st.session_state are left untouched.
#   json_data  -> parsed content of each uploaded file, keyed by file name
#   messages   -> chat history exchanged with the model
#   temp_input -> backing value of the chat text box
for _state_key, _make_default in (
    ("json_data", dict),
    ("messages", list),
    ("temp_input", str),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
if uploaded_files:
    # Rebuild the parsed-file cache from scratch so that files removed from
    # the uploader no longer appear in it.
    st.session_state.json_data.clear()
    file_summaries = []
    for f in uploaded_files:
        try:
            # Defensive rewind: if this handle was already consumed, json.load
            # would otherwise see an empty stream.
            f.seek(0)
            content = json.load(f)
        except Exception as e:
            # Best effort: report the bad file and carry on with the others.
            st.sidebar.error(f"Error reading {f.name}: {e}")
            continue

        st.session_state.json_data[f.name] = content

        # Summarise the file by its top-level keys (for a list of records,
        # the keys of the first record).
        if isinstance(content, dict):
            keys = list(content)
        elif isinstance(content, list) and content and isinstance(content[0], dict):
            keys = list(content[0])
        else:
            keys = []
        preview = f"{keys[:10]}{'...' if len(keys) > 10 else ''}"

        file_summaries.append(f"{f.name}: keys={preview}")
        st.sidebar.success(f"Loaded: {f.name}")
        st.sidebar.write(f"Keys: {preview}")

    system_message = {
        "role": "system",
        "content": (
            "You are an AI data analyst for the following JSON files:\n"
            + "\n".join(file_summaries)
            + "\nEach file may have a different structure and set of keys. "
            "When the user asks a question, identify which file(s) it applies to, "
            "then use the most relevant function to extract the answer. "
            "If the user does not specify a file, make your best guess based on keys/fields mentioned."
        ),
    }

    # Streamlit re-executes this script on every interaction, and this branch
    # runs for as long as files stay uploaded. Resetting the history
    # unconditionally would erase each exchange right after the on_change
    # callback recorded it, so only start over when the system prompt (i.e.
    # the set of files or their keys) actually changed.
    history = st.session_state.messages
    if not history or history[0] != system_message:
        st.session_state.messages = [system_message]
else:
    # No uploads: drop any previously parsed data.
    st.session_state.json_data.clear()
def search_json(file_name, key, value):
    """Find entries in an uploaded JSON file whose ``key`` equals ``value``.

    Values are compared by their string forms, so the string ``"42"``
    matches the number ``42``.

    Args:
        file_name: Name of an uploaded file, as stored in
            ``st.session_state.json_data``.
        key: Top-level field to compare.
        value: Value to look for.

    Returns:
        For a file holding a list: at most 10 dict items that contain
        ``key`` with a matching value. For a file holding a dict: a
        one-element list ``[{key: stored_value}]`` when the top-level
        ``key`` matches, otherwise ``[]``. Any other content yields ``[]``.
        On failure (including an unknown ``file_name``) a dict
        ``{"error": message}`` is returned instead of raising.
    """
    try:
        files = st.session_state.json_data
        if file_name not in files:
            # Name the available files so the caller can retry with a valid one.
            return {
                "error": f"Unknown file {file_name!r}. Available files: {sorted(files)}"
            }

        data = files[file_name]
        wanted = str(value)

        if isinstance(data, list):
            # Require the key to be present: item.get(key) on a record that
            # lacks it would stringify to "None" and falsely match "None".
            matches = [
                item
                for item in data
                if isinstance(item, dict) and key in item and str(item[key]) == wanted
            ]
            return matches[:10]

        if isinstance(data, dict) and key in data and str(data[key]) == wanted:
            # Return the stored value (with its original type), not the query.
            return [{key: data[key]}]

        return []
    except Exception as e:
        # The result is serialised and handed back to the model, so report
        # problems as data rather than raising.
        return {"error": str(e)}
def list_keys(file_name):
    """List the top-level keys of an uploaded JSON file.

    For a dict the dict's own keys are returned; for a non-empty list whose
    first element is a dict, that first element's keys are returned; any
    other content gives ``[]``. Any exception (such as an unknown
    ``file_name``) is reported as ``{"error": message}``.
    """
    try:
        payload = st.session_state.json_data[file_name]
        if isinstance(payload, dict):
            return list(payload)
        # Only a non-empty list can contribute a representative first record.
        first = payload[0] if isinstance(payload, list) and payload else None
        return list(first) if isinstance(first, dict) else []
    except Exception as exc:
        return {"error": str(exc)}
def count_key_occurrences(file_name, key):
    """Count how often ``key`` occurs at the top level of an uploaded file.

    A dict counts as 1 when it has the key and 0 otherwise; a list counts
    its dict items that have the key; any other content counts as 0. Any
    exception (such as an unknown ``file_name``) is reported as
    ``{"error": message}``.
    """
    try:
        payload = st.session_state.json_data[file_name]
        if isinstance(payload, dict):
            return int(key in payload)
        if isinstance(payload, list):
            hits = [row for row in payload if isinstance(row, dict) and key in row]
            return len(hits)
        return 0
    except Exception as exc:
        return {"error": str(exc)}
def _string_param(description):
    """Build the JSON-schema entry for one string-typed argument."""
    return {"type": "string", "description": description}


def _function_spec(name, description, params):
    """Build one function declaration in which every argument is required.

    ``params`` maps each argument name to its description; the argument
    order is preserved in both ``properties`` and ``required``.
    """
    return {
        "name": name,
        "description": description,
        "parameters": {
            "type": "object",
            "properties": {arg: _string_param(text) for arg, text in params.items()},
            "required": list(params),
        },
    }


# Function declarations passed as "functions" in the chat-completions request;
# each name corresponds to a helper function of the same name in this file.
function_schema = [
    _function_spec(
        "search_json",
        "Find records in the specified JSON file where key matches a given value.",
        {
            "file_name": "The uploaded JSON file to search.",
            "key": "The key/field to filter by.",
            "value": "The value to match.",
        },
    ),
    _function_spec(
        "list_keys",
        "List all top-level keys in a given JSON file.",
        {"file_name": "The uploaded JSON file."},
    ),
    _function_spec(
        "count_key_occurrences",
        "Count the number of times a given key appears in a JSON file.",
        {
            "file_name": "The uploaded JSON file.",
            "key": "The key to count.",
        },
    ),
]
st.markdown("### Conversation")
# Render the chat history. Index 0 is skipped: when files are loaded it holds
# the system prompt, which is not shown.
for msg in st.session_state.messages[1:]:
    role = msg["role"]
    # "content" may be absent or None; normalise so .strip() and escaping work.
    text = msg.get("content") or ""

    # Everything interpolated below comes from the user, the model, or the
    # uploaded files, and is rendered with unsafe_allow_html=True, so it is
    # HTML-escaped to keep it from injecting markup.
    if role == "user":
        st.markdown(
            f"<div style='color: #4F8BF9;'><b>User:</b> {html.escape(text)}</div>",
            unsafe_allow_html=True,
        )
    elif role == "assistant":
        if text.strip():
            st.markdown(
                f"<div style='color: #1C6E4C;'><b>Agent:</b> {html.escape(text)}</div>",
                unsafe_allow_html=True,
            )
        else:
            st.markdown(
                "<div style='color: #DC143C;'><b>Agent:</b> [No response generated]</div>",
                unsafe_allow_html=True,
            )
    elif role == "function":
        func_name = html.escape(str(msg["name"]))
        try:
            # Pretty-print the stored JSON result inside a collapsible section.
            pretty = json.dumps(json.loads(msg["content"]), indent=2)
        except (TypeError, ValueError):
            # Not valid JSON: fall back to showing the raw text.
            st.markdown(
                f"<b>Function '{func_name}' output:</b> {html.escape(text)}",
                unsafe_allow_html=True,
            )
        else:
            st.markdown(
                f"<details><summary><b>Function '{func_name}' output:</b></summary>"
                f"<pre>{html.escape(pretty)}</pre></details>",
                unsafe_allow_html=True,
            )
def _windowed(messages, limit):
    """Return a copy of ``messages`` capped at ``limit`` entries.

    When the history is longer than ``limit``, the first message is kept and
    the rest is filled with the most recent ``limit - 1`` messages.
    """
    if len(messages) > limit:
        return [messages[0]] + messages[-(limit - 1):]
    return list(messages)


def _chat_completion(payload):
    """POST ``payload`` to the chat-completions endpoint; return the first choice's message.

    The model and temperature are fixed here; ``payload`` supplies the
    rest. Raises ``requests.HTTPError`` for a non-success status.
    """
    resp = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=HEADERS,
        json={"model": "gpt-4.1", "temperature": 0, **payload},
        timeout=60,
    )
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]


def _run_requested_function(func_name, raw_args):
    """Execute the local helper the model asked for and return its result.

    ``raw_args`` is the JSON text of the arguments. Malformed arguments or
    an unknown function name produce an ``{"error": ...}`` result rather
    than an exception, so the model can still be told what went wrong.
    """
    try:
        args = json.loads(raw_args or "{}")
    except (TypeError, ValueError) as exc:
        return {"error": f"Invalid function arguments: {exc}"}
    if not isinstance(args, dict):
        return {"error": "Function arguments must be a JSON object."}

    if func_name == "search_json":
        return search_json(args.get("file_name"), args.get("key"), args.get("value"))
    if func_name == "list_keys":
        return list_keys(args.get("file_name"))
    if func_name == "count_key_occurrences":
        return count_key_occurrences(args.get("file_name"), args.get("key"))
    return {"error": f"Unknown function: {func_name}"}


def send_message():
    """Handle a submitted chat message (``on_change`` callback of the text box).

    Appends the user's text to ``st.session_state.messages``, asks the model
    for a reply, and, if the model requests one of the local JSON helpers,
    runs it, records the result as a ``function`` message and asks the model
    again for the final answer. The assistant's answer is appended to the
    history. Any exception is shown in the UI instead of propagating, and
    the text box is cleared afterwards. Blank input is ignored.
    """
    user_input = st.session_state.temp_input
    if not (user_input and user_input.strip()):
        return

    st.session_state.messages.append({"role": "user", "content": user_input})
    try:
        msg = _chat_completion({
            "messages": _windowed(st.session_state.messages, 10),
            "functions": function_schema,
            "function_call": "auto",
            "max_tokens": 1000,
        })

        call = msg.get("function_call")
        if call:
            func_name = call["name"]
            result = _run_requested_function(func_name, call.get("arguments"))
            st.session_state.messages.append({
                "role": "function",
                "name": func_name,
                "content": json.dumps(result),
            })

            followup = _windowed(st.session_state.messages, 12)
            # Put the assistant's function_call turn directly before the
            # function result so the model can see which call (and which
            # arguments) produced it. It is added to this request only, not
            # to the stored history, which keeps holding just user,
            # function and final assistant messages.
            followup.insert(len(followup) - 1, {
                "role": "assistant",
                "content": msg.get("content"),
                "function_call": call,
            })
            reply = _chat_completion({
                "messages": followup,
                "max_tokens": 1500,
            }).get("content")
        else:
            reply = msg.get("content")

        # The API may return a null content; store a string so the history
        # never contains None.
        st.session_state.messages.append({"role": "assistant", "content": reply or ""})
    except Exception as e:
        # Top-level boundary of the callback: surface the failure in the UI.
        st.error("Exception: " + str(e))
        st.code(traceback.format_exc())
    finally:
        st.session_state.temp_input = ""
# The chat box is shown only when at least one parsed JSON file is available;
# otherwise the user is prompted to upload one.
if not st.session_state.json_data:
    st.info("Please upload at least one JSON file to start chatting.")
else:
    st.text_input("Your message:", key="temp_input", on_change=send_message)