import html
import json
import os
import traceback

import requests
import streamlit as st
|
|
| |
# Seed the per-session state. Streamlit re-executes this script on every
# interaction, so a key is only created when it is still missing; values
# that already exist are left untouched.
for state_key, make_default in (
    ("json_data", dict),      # parsed JSON content, keyed by file name
    ("messages", list),       # chat history sent to the model
    ("temp_input", str),      # backing value of the chat text box
    ("files_loaded", bool),   # True once the current uploads were parsed
):
    if state_key not in st.session_state:
        st.session_state[state_key] = make_default()
|
|
| |
# Page chrome.
st.set_page_config(layout="wide", page_title="JSON-Backed AI Chat Agent")
st.title("JSON-Backed AI Chat Agent")

# The OpenAI key is read from the environment. Without it no request can be
# authorised, so the script shows an error and stops rendering right here.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    st.error("❌ OPENAI_API_KEY not set in Settings → Secrets.")
    st.stop()

# Headers shared by every chat-completions request.
HEADERS = {
    "Authorization": "Bearer " + OPENAI_API_KEY,
    "Content-Type": "application/json",
}

# Sidebar: the uploader accepts any number of .json files.
with st.sidebar:
    st.header("Upload Multiple JSON Files")
    uploaded_files = st.file_uploader(
        "Choose one or more JSON files",
        type="json",
        accept_multiple_files=True,
    )
|
|
| |
# Load (or reload) the uploaded JSON files into session state and reset the
# chat history with a fresh system prompt.
#
# A signature of the current upload set (name and size of every file) is kept
# next to the parsed data, so adding or removing a file after the first load
# triggers a reload instead of being silently ignored.
upload_signature = sorted((f.name, f.size) for f in (uploaded_files or []))
uploads_changed = st.session_state.get("loaded_signature") != upload_signature

if uploaded_files and (not st.session_state.files_loaded or uploads_changed):
    st.session_state.json_data.clear()
    file_summaries = []
    for f in uploaded_files:
        try:
            # Rewind first: the buffer may already have been consumed by an
            # earlier load of the same upload.
            f.seek(0)
            content = json.load(f)
        except Exception as e:
            # Best effort per file: report the failure and keep loading the rest.
            st.sidebar.error(f"Error reading {f.name}: {e}")
            continue

        st.session_state.json_data[f.name] = content

        # Summarise the file by its top-level keys; for a list of records the
        # keys of the first record are used.
        if isinstance(content, dict):
            keys = list(content.keys())
        elif isinstance(content, list) and content and isinstance(content[0], dict):
            keys = list(content[0].keys())
        else:
            keys = []
        keys_preview = f"{keys[:10]}{'...' if len(keys) > 10 else ''}"

        file_summaries.append(f"{f.name}: keys={keys_preview}")
        st.sidebar.success(f"Loaded: {f.name}")
        st.sidebar.write(f"Keys: {keys_preview}")

    system_prompt = (
        "You are an AI data analyst for uploaded JSON files. "
        "Each file may have different structures and keys, including lists and nested dictionaries. "
        "You have access to a function 'search_all_jsons' that finds all records in all JSON files where a key matches a value, recursively. "
        "If a user asks about groups of people or wants to know counts such as 'How many females are there?', "
        "interpret this as 'search for all records where gender equals female'. "
        "Always use the 'search_all_jsons' function with key='gender' and value='female' for such queries, unless another key/value is clear from context. "
        "If someone asks 'How many males?', search for gender equals male, etc. "
        "EXAMPLES:\n"
        "User: How many females are there?\n"
        "Assistant: (Call search_all_jsons with key='gender', value='female')\n"
        "User: How many males are there?\n"
        "Assistant: (Call search_all_jsons with key='gender', value='male')\n"
        "User: Show all females\n"
        "Assistant: (Call search_all_jsons with key='gender', value='female')\n"
        "User: How many people are named Emily?\n"
        "Assistant: (Call search_all_jsons with key='firstName', value='Emily')"
    )
    if file_summaries:
        # The per-file functions take a 'file_name' argument, so the model has
        # to be told which files exist and what their top-level keys are.
        system_prompt += (
            "\n\nUPLOADED FILES (use these exact names as 'file_name' when calling "
            "search_json, list_keys or count_key_occurrences):\n"
            + "\n".join(file_summaries)
        )

    system_message = {"role": "system", "content": system_prompt}
    st.session_state.messages = [system_message]
    st.session_state.files_loaded = True
    st.session_state.loaded_signature = upload_signature
elif not uploaded_files:
    # Every file was removed: drop the parsed data and allow a fresh load.
    st.session_state.json_data.clear()
    st.session_state.files_loaded = False
|
|
| |
def search_all_jsons(key, value):
    """Search every loaded JSON file for dicts whose *key* equals *value*.

    The comparison is done on the lower-cased ``str()`` form of both sides.
    Dicts are visited at any nesting depth, parent before children.

    Returns:
        A list of shallow copies of the matching dicts, each extended with a
        ``"__file__"`` entry naming the file it was found in.
    """
    wanted = str(value).lower()

    def iter_dicts(node):
        # Yield every dict reachable from `node`, parents before children.
        if isinstance(node, dict):
            yield node
            for child in node.values():
                yield from iter_dicts(child)
        elif isinstance(node, list):
            for child in node:
                yield from iter_dicts(child)

    matches = []
    for source_name, payload in st.session_state.json_data.items():
        for record in iter_dicts(payload):
            if key in record and str(record[key]).lower() == wanted:
                matches.append({**record, "__file__": source_name})
    return matches
|
|
| |
def search_json(file_name, key, value):
    """Search one loaded JSON file for dicts whose *key* equals *value*.

    The comparison is done on the lower-cased ``str()`` form of both sides and
    dicts are visited at any nesting depth.

    Args:
        file_name: Name of an uploaded file, as stored in
            ``st.session_state.json_data``.
        key: Dict key to test.
        value: Value the key must match.

    Returns:
        A list of the matching dicts, or ``{"error": message}`` when the file
        is not loaded or the search fails. This function never raises.
    """
    try:
        loaded = st.session_state.json_data
        if file_name not in loaded:
            # str(KeyError) is only the quoted key, which tells the caller
            # nothing; name the problem and list the valid choices instead.
            return {
                "error": (
                    f"File {file_name!r} is not loaded. "
                    f"Available files: {sorted(loaded)}"
                )
            }

        wanted = str(value).lower()
        found = []

        def collect(node):
            if isinstance(node, dict):
                if key in node and str(node[key]).lower() == wanted:
                    found.append(node)
                for child in node.values():
                    collect(child)
            elif isinstance(node, list):
                for child in node:
                    collect(child)

        collect(loaded[file_name])
        return found
    except Exception as e:
        # The result is serialised and handed back to the model, so failures
        # are reported as data rather than raised.
        return {"error": str(e)}
|
|
def list_keys(file_name):
    """List the top-level keys of one loaded JSON file.

    For a dict the dict's own keys are returned; for a non-empty list whose
    first element is a dict, the keys of that first element; otherwise ``[]``.

    Returns:
        A list of keys, or ``{"error": message}`` when the file is not loaded
        or the lookup fails. This function never raises.
    """
    try:
        loaded = st.session_state.json_data
        if file_name not in loaded:
            # str(KeyError) is only the quoted key; give an actionable message.
            return {
                "error": (
                    f"File {file_name!r} is not loaded. "
                    f"Available files: {sorted(loaded)}"
                )
            }

        data = loaded[file_name]
        if isinstance(data, dict):
            return list(data.keys())
        if isinstance(data, list) and data and isinstance(data[0], dict):
            return list(data[0].keys())
        return []
    except Exception as e:
        # Failures are reported as data because the result goes to the model.
        return {"error": str(e)}
|
|
def count_key_occurrences(file_name, key):
    """Count how many dicts in one loaded JSON file contain *key*.

    Dicts are visited at any nesting depth, matching the recursive behaviour
    of the search functions, so a file shaped like ``{"users": [{...}, ...]}``
    is counted correctly rather than reporting 0.

    Returns:
        The number of dicts containing *key*, or ``{"error": message}`` when
        the file is not loaded or counting fails. This function never raises.
    """
    try:
        loaded = st.session_state.json_data
        if file_name not in loaded:
            # str(KeyError) is only the quoted key; give an actionable message.
            return {
                "error": (
                    f"File {file_name!r} is not loaded. "
                    f"Available files: {sorted(loaded)}"
                )
            }

        def count_in(node):
            if isinstance(node, dict):
                own = 1 if key in node else 0
                return own + sum(count_in(child) for child in node.values())
            if isinstance(node, list):
                return sum(count_in(child) for child in node)
            return 0

        return count_in(loaded[file_name])
    except Exception as e:
        # Failures are reported as data because the result goes to the model.
        return {"error": str(e)}
|
|
| |
def _string_param(description):
    """Build the JSON-schema entry for one string-typed argument."""
    return {"type": "string", "description": description}


def _function_spec(name, description, properties):
    """Build one function descriptor in which every listed property is required."""
    return {
        "name": name,
        "description": description,
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": list(properties),
        },
    }


# Descriptors of the local functions offered to the model; sent as the
# "functions" field of the chat-completions request.
function_schema = [
    _function_spec(
        "search_json",
        "Find records in the specified JSON file where key matches a given value.",
        {
            "file_name": _string_param("The uploaded JSON file to search."),
            "key": _string_param("The key/field to filter by."),
            "value": _string_param("The value to match."),
        },
    ),
    _function_spec(
        "list_keys",
        "List all top-level keys in a given JSON file.",
        {
            "file_name": _string_param("The uploaded JSON file."),
        },
    ),
    _function_spec(
        "count_key_occurrences",
        "Count the number of times a given key appears in a JSON file.",
        {
            "file_name": _string_param("The uploaded JSON file."),
            "key": _string_param("The key to count."),
        },
    ),
    _function_spec(
        "search_all_jsons",
        "Search all uploaded JSON files recursively for dicts where a key matches a value.",
        {
            "key": _string_param("The key to search for (e.g. 'gender')"),
            "value": _string_param("The value to match (e.g. 'female')"),
        },
    ),
]
|
|
| |
# Render the chat history. The first entry (the system prompt) is skipped.
#
# Every piece of text that originates from the user, the model or an uploaded
# file is HTML-escaped before it is placed into markup rendered with
# unsafe_allow_html=True; otherwise a stray "<" would break the layout and
# arbitrary markup could be injected into the page. Quotes are left alone
# because the text is only ever placed in element content, never in attributes.
st.markdown("### Conversation")
for msg in st.session_state.messages[1:]:
    role = msg["role"]
    if role == "user":
        user_text = html.escape(str(msg["content"]), quote=False)
        st.markdown(
            f"<div style='color: #4F8BF9;'><b>User:</b> {user_text}</div>",
            unsafe_allow_html=True,
        )
    elif role == "assistant":
        # The API may return null content; treat it like an empty reply
        # instead of failing on None.strip().
        content = msg.get("content") or ""
        if content.strip():
            agent_text = html.escape(content, quote=False)
            st.markdown(
                f"<div style='color: #1C6E4C;'><b>Agent:</b> {agent_text}</div>",
                unsafe_allow_html=True,
            )
        else:
            st.markdown(
                "<div style='color: #DC143C;'><b>Agent:</b> [No response generated]</div>",
                unsafe_allow_html=True,
            )
    elif role == "function":
        func_label = html.escape(str(msg["name"]), quote=False)
        try:
            result = json.loads(msg["content"])
        except (TypeError, ValueError):
            # Not valid JSON: show the raw payload instead of a pretty dump.
            raw_output = html.escape(str(msg["content"]), quote=False)
            st.markdown(
                f"<b>Function '{func_label}' output:</b> {raw_output}",
                unsafe_allow_html=True,
            )
        else:
            pretty_output = html.escape(json.dumps(result, indent=2), quote=False)
            st.markdown(
                f"<details><summary><b>Function '{func_label}' output:</b></summary>"
                f"<pre>{pretty_output}</pre></details>",
                unsafe_allow_html=True,
            )
|
|
| |
_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions"
_CHAT_MODEL = "gpt-4.1"


def _trim_history(messages, limit):
    """Return a copy of *messages* capped at *limit* entries, always keeping the first."""
    if len(messages) > limit:
        return [messages[0]] + messages[-(limit - 1):]
    return list(messages)


def _request_chat_completion(messages, max_tokens, **extra):
    """POST one chat-completions request and return the first choice's message dict."""
    response = requests.post(
        _CHAT_COMPLETIONS_URL,
        headers=HEADERS,
        json={
            "model": _CHAT_MODEL,
            "messages": messages,
            "temperature": 0,
            "max_tokens": max_tokens,
            **extra,
        },
        timeout=60,
    )
    response.raise_for_status()
    return response.json()["choices"][0]["message"]


def _call_local_function(name, args):
    """Run the local function the model asked for and return its result."""
    handlers = {
        "search_json": lambda: search_json(
            args.get("file_name"), args.get("key"), args.get("value")
        ),
        "list_keys": lambda: list_keys(args.get("file_name")),
        "count_key_occurrences": lambda: count_key_occurrences(
            args.get("file_name"), args.get("key")
        ),
        "search_all_jsons": lambda: search_all_jsons(
            args.get("key"), args.get("value")
        ),
    }
    handler = handlers.get(name)
    if handler is None:
        return {"error": f"Unknown function: {name}"}
    return handler()


def send_message():
    """Send the text in the chat box to the model and record the exchange.

    Used as the ``on_change`` callback of the chat text input. The user's text
    is appended to ``st.session_state.messages`` and sent to the model together
    with the function descriptors. If the model requests a function call, the
    function is run locally, its JSON-encoded result is appended as a
    ``function`` message and a second request asks the model for the final
    answer. The answer is appended as an ``assistant`` message and the input
    box is cleared.

    On any failure the error and traceback are shown, and every message added
    by this call is removed again so the history never contains a question
    without an answer. The input box keeps its text in that case.
    """
    user_input = st.session_state.temp_input
    if not user_input or not user_input.strip():
        st.session_state.temp_input = ""
        return

    history = st.session_state.messages
    checkpoint = len(history)
    try:
        history.append({"role": "user", "content": user_input})
        reply = _request_chat_completion(
            _trim_history(history, 10),
            max_tokens=1000,
            functions=function_schema,
            function_call="auto",
        )

        call = reply.get("function_call")
        if call:
            func_name = call["name"]
            args = json.loads(call.get("arguments") or "{}")
            result = _call_local_function(func_name, args)
            history.append({
                "role": "function",
                "name": func_name,
                "content": json.dumps(result),
            })

            followup_messages = _trim_history(history, 12)
            # The model should see its own function_call right before the
            # function result it produced. That message is added to the
            # request only, not to the stored history, because it has no
            # text to display.
            followup_messages.insert(
                len(followup_messages) - 1,
                {
                    "role": "assistant",
                    "content": reply.get("content"),
                    "function_call": {
                        "name": func_name,
                        "arguments": call.get("arguments") or "{}",
                    },
                },
            )
            reply = _request_chat_completion(followup_messages, max_tokens=1500)

        # Content may be null in the API response; store a string so the
        # renderer can always call str methods on it.
        history.append({"role": "assistant", "content": reply.get("content") or ""})
        st.session_state.temp_input = ""
    except Exception as e:
        # UI boundary: undo the partial exchange, then surface the failure.
        del history[checkpoint:]
        st.error("Exception: " + str(e))
        st.code(traceback.format_exc())
|
|
# The chat box is only offered once at least one file has been parsed.
# Submitting the text runs send_message as the widget's on_change callback.
# (A stray `st.json(results)` call was removed here: `results` is not defined
# at module level, so it raised NameError as soon as a file was loaded.)
if st.session_state.json_data:
    st.text_input("Your message:", key="temp_input", on_change=send_message)
else:
    st.info("Please upload at least one JSON file to start chatting.")
|
|