| import streamlit as st |
| import folium |
| from streamlit_folium import st_folium |
# Streamlit page configuration. NOTE(review): st.set_page_config must be the
# first Streamlit command executed, which is presumably why it sits between
# the import statements — confirm before reordering imports.
st.set_page_config(
    page_title="🔬 Explainable Multi-Agent BioData Constructor",
    layout="centered",
    initial_sidebar_state="collapsed"
)
| from neo4j import GraphDatabase |
| import openai |
| import pandas as pd |
| import os |
| import re |
| import hashlib |
| import json |
| import pydeck as pdk |
| import faiss |
| import numpy as np |
| from sklearn.preprocessing import normalize |
| from transformers import AutoTokenizer, AutoModel |
| import torch |
| import ast |
| import textwrap |
| import requests |
| |
# Connection credentials come from the environment; any of these may be None
# when the variable is unset (the driver/client constructors will then fail).
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
openai_api_key = os.getenv("openai_api_key")


# Redirect the HuggingFace model cache to a writable location (useful on
# read-only container filesystems).
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
| |
def download_if_missing(url: str, local_path: str, timeout: float = 60.0) -> None:
    """Download *url* to *local_path* unless the file already exists.

    Fixes over the original:
    - HTTP errors are raised instead of silently saving the error page body
      as the data file.
    - The payload is streamed to disk in chunks rather than buffered whole
      in memory (the embedding files are large).
    - The data is written to a temp file and atomically renamed, so an
      interrupted download never leaves a truncated file that would be
      treated as complete on the next run.

    Parameters:
        url: source URL.
        local_path: destination path; skipped entirely when it exists.
        timeout: per-request timeout in seconds (new, defaulted — callers
            using the old two-argument form are unaffected).
    """
    if os.path.exists(local_path):
        return
    tmp_path = local_path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as rsp:
        rsp.raise_for_status()  # surface 4xx/5xx instead of saving an error page
        with open(tmp_path, "wb") as f:
            for chunk in rsp.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(tmp_path, local_path)  # atomic on POSIX
|
|
# GitHub release assets hosting the pre-built graph exports (CSV) and their
# embedding matrices (NPY).
base_url = "https://github.com/Tianyu-yang-anna/EcoData-collector/releases/download/v1.0"
files = {
    "nodes.csv": "/tmp/nodes.csv",
    "nodes_embeddings.npy": "/tmp/nodes_embeddings.npy",
    "relationships.csv": "/tmp/relationships.csv",
    "relationships_embeddings.npy": "/tmp/relationships_embeddings.npy"
}


# Fetch each asset once; later Streamlit reruns reuse the cached /tmp copies.
for fname, path in files.items():
    download_if_missing(f"{base_url}/{fname}", path)
|
|
| |
@st.cache_resource(show_spinner=False)
def create_driver():
    """Build (and cache across reruns) a Neo4j driver, verifying connectivity.

    Returns the driver on success, or None after surfacing the error in the
    UI so the caller can halt gracefully.
    """
    try:
        neo4j_driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
        # Cheap round-trip to confirm the URI/credentials actually work.
        with neo4j_driver.session() as probe:
            probe.run("RETURN 1")
    except Exception as exc:
        st.error(f"🔴 Neo4j connection failed: {exc}")
        return None
    return neo4j_driver
|
|
# Module-level singletons shared by every agent below.
driver = create_driver()

openai_client = openai.OpenAI(api_key=openai_api_key)
|
|
def gpt_chat(sys_msg: str, user_msg: str, **kwargs):
    """Send one system+user exchange to gpt-4o and return the stripped reply text.

    Extra keyword arguments (e.g. temperature) are forwarded to the API call.
    """
    messages = [
        {"role": "system", "content": sys_msg},
        {"role": "user", "content": user_msg},
    ]
    completion = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=messages,
        **kwargs,
    )
    return completion.choices[0].message.content.strip()
|
|
| |
class SimpleEncoder:
    """Sentence embedder wrapping a locally stored HuggingFace model.

    Pooling is the plain mean over the token axis — deliberately unchanged
    from the original so that query vectors stay comparable with the
    pre-computed .npy embedding matrices shipped alongside the CSV exports.
    """

    def __init__(self):
        # Prefer GPU when available; inference only, so the model is frozen.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained("/app/model")
        self.model = AutoModel.from_pretrained("/app/model").to(self.device)
        self.model.eval()

    def encode(self, texts, batch_size: int = 16) -> np.ndarray:
        """Embed *texts* (sequence of str) in mini-batches.

        Fix: an empty input used to crash in np.vstack([]); it now returns
        an empty (0, hidden_size) float32 array instead.
        """
        if not texts:
            return np.empty((0, self.model.config.hidden_size), dtype=np.float32)
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            with torch.no_grad():
                inputs = self.tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(self.device)
                outputs = self.model(**inputs)
                # Mean-pool over tokens; move to CPU before numpy in case we are on GPU.
                batch_emb = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
            embeddings.append(batch_emb)
        return np.vstack(embeddings)
|
|
|
|
@st.cache_resource(show_spinner=False)
def get_encoder():
    """Return the process-wide SimpleEncoder (model loaded once, then cached)."""
    return SimpleEncoder()
|
|
| |
# Each (csv, npy) pair holds row metadata plus its pre-computed embedding matrix.
csv_file_pairs = [
    ("/tmp/nodes.csv", "/tmp/nodes_embeddings.npy"),
    ("/tmp/relationships.csv", "/tmp/relationships_embeddings.npy"),
]


# Hard-stop early when an embedding file is missing — the RAG fallback below
# cannot work without it.
for csv_path, npy_path in csv_file_pairs:
    if not os.path.exists(npy_path):
        st.error(f"❌ Embedding file not found: {npy_path}")
        st.stop()
|
|
@st.cache_resource(show_spinner=False)
def load_embeddings_and_faiss_indexes(file_pairs):
    """Load every (csv, npy) pair into a FAISS inner-product index plus its metadata frame.

    A pair that fails to load contributes (None, empty DataFrame), so the two
    returned lists always stay index-aligned with *file_pairs*.
    """
    index_list, metadatas = [], []
    for csv_path, npy_path in file_pairs:
        try:
            frame = pd.read_csv(csv_path).fillna("")
            vectors = np.load(npy_path).astype("float32")
            faiss_index = faiss.IndexFlatIP(vectors.shape[1])
            # Move the index onto GPU 0 when FAISS was built with GPU support.
            if faiss.get_num_gpus() > 0:
                gpu_res = faiss.StandardGpuResources()
                faiss_index = faiss.index_cpu_to_gpu(gpu_res, 0, faiss_index)
            faiss_index.add(vectors)
        except Exception as exc:
            st.warning(f"⚠️ Failed to load {csv_path} / {npy_path}: {exc}")
            index_list.append(None)
            metadatas.append(pd.DataFrame())
        else:
            index_list.append(faiss_index)
            metadatas.append(frame)
    return index_list, metadatas
|
|
# Built once at startup; cached across Streamlit reruns by st.cache_resource.
csv_faiss_indexes, csv_metadatas = load_embeddings_and_faiss_indexes(csv_file_pairs)
|
|
| |
|
|
def flatten_props(df: pd.DataFrame) -> pd.DataFrame:
    """Expand a stringified-dict "props" column into real columns.

    Fix: the original parsed the whole column with a single apply(), so one
    malformed cell — e.g. the empty string produced upstream by fillna("") —
    raised and aborted the expansion for every row. Cells are now parsed
    independently; unparsable cells simply contribute no values (NaN).

    Returns the input frame unchanged when there is no "props" column.
    """
    if "props" not in df.columns:
        return df

    def _parse(cell):
        # Cells may already be dicts (straight from the graph) or string reprs.
        if isinstance(cell, dict):
            return cell
        try:
            value = ast.literal_eval(cell)
            return value if isinstance(value, dict) else {}
        except (ValueError, SyntaxError):
            return {}

    try:
        props_df = df["props"].apply(_parse).apply(pd.Series)
        return pd.concat([df.drop(columns=["props"]), props_df], axis=1)
    except Exception as e:
        # Unexpected structural failure: keep the original frame usable.
        st.warning(f"⚠️ Failed to parse props column: {e}")
        return df
|
|
def unpack_singletons(df: pd.DataFrame) -> pd.DataFrame:
    """Replace one-element lists/tuples with their sole item, column by column.

    Columns containing no singleton sequences are left untouched. The frame
    is modified in place and also returned for call chaining.
    """
    def _is_singleton(value):
        return isinstance(value, (list, tuple)) and len(value) == 1

    for column in df.columns:
        if df[column].apply(_is_singleton).any():
            df[column] = df[column].apply(lambda v: v[0] if _is_singleton(v) else v)
    return df
|
|
def standardize_latlon(df: pd.DataFrame) -> pd.DataFrame:
    """Normalize coordinate columns for downstream mapping.

    - Rename any latitude/longitude-like column to ``latitudes`` / ``longitudes``.
    - If renaming creates duplicate column names, keep the first occurrence
      and drop the rest.
    - Coerce coordinate values to numeric (unwrapping 1-element lists first).
    - Keep ``longitudes`` where it is and move ``latitudes`` directly to its
      right.
    """

    # Map any column whose name looks like a latitude ("lat" but not "lon")
    # or a longitude ("lon"/"lng") onto the canonical plural names.
    col_map = {}
    for col in df.columns:
        low = col.lower()
        if "lat" in low and "lon" not in low:
            col_map[col] = "latitudes"
        elif ("lon" in low or "lng" in low):
            col_map[col] = "longitudes"
    df = df.rename(columns=col_map)

    # Renaming may have produced duplicate column names; repeatedly drop every
    # occurrence of a duplicated name except the first.
    while df.columns.duplicated().any():
        dup_col = df.columns[df.columns.duplicated()][0]
        first_idx = list(df.columns).index(dup_col)
        keep = [True] * len(df.columns)
        for i, c in enumerate(df.columns):
            if c == dup_col and i != first_idx:
                keep[i] = False
        df = df.loc[:, keep]

    # Coerce both coordinate columns to numeric, unwrapping singleton
    # list/tuple cells first so pd.to_numeric does not turn them into NaN.
    for c in ("latitudes", "longitudes"):
        if c in df.columns and not isinstance(df[c], pd.Series):
            # Defensive: after the dedup loop df[c] should always be a Series,
            # but fall back to the first sub-column if it is still a frame.
            df[c] = df[c].iloc[:, 0]
        if c in df.columns:
            df[c] = df[c].apply(
                lambda x: x[0] if isinstance(x, (list, tuple)) and len(x) == 1 else x
            )
            df[c] = pd.to_numeric(df[c], errors="coerce")

    # Reorder so that "latitudes" sits immediately after "longitudes".
    if {"longitudes", "latitudes"}.issubset(df.columns):
        cols = list(df.columns)
        lon_idx = cols.index("longitudes")
        lat_idx = cols.index("latitudes")
        if lat_idx != lon_idx + 1:
            cols.pop(lat_idx)
            cols.insert(lon_idx + 1, "latitudes")
        df = df[cols]

    return df
|
|
|
|
|
|
| |
@st.cache_data(show_spinner=False)
def rag_csv_fallback(subtask, top_k=2000):
    """Dense-retrieval fallback: embed *subtask*, pull the top_k nearest rows
    from every loaded FAISS index, and return them concatenated/de-duplicated.
    """
    encoder = get_encoder()
    query_vec = normalize(encoder.encode([subtask]), axis=1).astype("float32")
    # An all-zero vector would match everything equally — treat as "no query".
    if not np.any(query_vec):
        return pd.DataFrame()

    hits = []
    for faiss_index, metadata in zip(csv_faiss_indexes, csv_metadatas):
        if faiss_index is None or metadata.empty:
            continue
        _, row_ids = faiss_index.search(query_vec, top_k)
        hits.append(metadata.iloc[row_ids[0]].copy())

    if not hits:
        return pd.DataFrame()
    return pd.concat(hits).drop_duplicates().reset_index(drop=True)
|
|
|
|
|
|
def generate_cypher_with_gpt(subtask: str) -> str:
    """Ask gpt-4o to translate *subtask* into a Cypher query for the EcoData schema.

    Returns the generated Cypher string, or "" when the model reply cannot be
    parsed as the expected JSON object (callers treat "" as "fall back").

    Fixes: the old fence-stripping regexes were effectively no-ops (they
    matched an optional bare word at the start and the empty end-of-string),
    so a reply wrapped in a Markdown code fence always failed json.loads;
    the per-call OpenAI client is replaced by the shared module client; and
    the bare except now catches only parse/shape errors.
    """
    prompt = f"""
You are an expert Cypher query generator for a Neo4j biodiversity database. The schema is as follows:

Node Types and Properties:
- Observation: animal_name, date, latitude, longitude
- Species: name, species_full_name
- Site: name
- County: name
- State: name
- Hurricane: name
- Policy: title, description
- ClimateEvent: event_type, date
- TemperatureReading: value, date, location
- Precipitation: amount, date, location

Relationship Types:
- OBSERVED_IN: (Observation)-[:OBSERVED_IN]->(Site)
- OBSERVED_ORGANISM: (Observation)-[:OBSERVED_ORGANISM]->(Species)
- BELONGS_TO: (Site)-[:BELONGS_TO]->(County)
- IN_COUNTY: (Observation)-[:IN_COUNTY]->(County)
- IN_STATE: (County)-[:IN_STATE]->(State)
- interactsWith: (Species)-[:interactsWith]->(Species)
- preysOn: (Species)-[:preysOn]->(Species)

Your task is to generate a **precise and efficient** Cypher query for the following subtask:
"{subtask}"

Guidelines:
- Do NOT return all nodes of a type (e.g., all Species) unless the subtask explicitly asks for it.
- If a location (county/state) is mentioned or implied, include location filtering using IN_COUNTY, IN_STATE, or BELONGS_TO.
- If the subtask implies a taxonomic or common name group (e.g., frog, snake, salmon), apply CONTAINS or STARTS WITH filters on Species.name or species_full_name, using toLower(...) for case-insensitive matching.
- If the subtask includes a time range, include date filtering.
- Prefer using DISTINCT to avoid redundant results.
- Only return fields that are clearly needed to fulfill the subtask.

Return your response strictly as a **JSON object** with the following fields:
- "intent": a short description of what the query does
- "cypher_query": the Cypher query
- "fields": a list of returned field names (e.g., ["species", "county", "date"])

Do not include any explanation or commentary—only return the JSON object.
"""

    # Reuse the module-level client instead of constructing one per call.
    response = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0
    )
    content = response.choices[0].message.content.strip()
    # Strip a surrounding Markdown code fence (``` / ```json / ```python) that
    # the model often adds despite the instructions, then any bare leading
    # "json"/"python" language tag.
    content = re.sub(r"^```(?:json|python)?\s*", "", content, flags=re.IGNORECASE)
    content = re.sub(r"\s*```$", "", content)
    content = re.sub(r"^(?:json|python)\b\s*", "", content, flags=re.IGNORECASE).strip()

    try:
        return json.loads(content)["cypher_query"]
    except (json.JSONDecodeError, KeyError, TypeError):
        # Malformed reply — signal "no query" so the caller falls back to RAG.
        return ""
| |
|
|
def intelligent_retriever_agent(subtask, saved_hashes=None):
    """Retrieve a dataset for *subtask*: try a GPT-generated Cypher query
    first, then fall back to dense (FAISS) retrieval over the exported CSVs.

    saved_hashes: set of md5 digests of already-accepted datasets. A Cypher
    result whose hash is in this set, or with fewer than 10 rows, triggers
    the CSV fallback. Returns a (possibly empty) DataFrame.
    """
    if saved_hashes is None:
        saved_hashes = set()
    st.success("🔍 Attempting to retrieve data from the Ecodata knowledge graph…")
    cypher_query = generate_cypher_with_gpt(subtask)
    cypher_df = pd.DataFrame()
    if cypher_query.strip():
        st.code(cypher_query, language="cypher")
        try:
            # Drop a trailing LIMIT clause so the full result set is fetched.
            query = re.sub(r"(?i)LIMIT\s+\d+\s*$", "", cypher_query)
            with driver.session() as session:
                result = session.run(query)
                cypher_df = pd.DataFrame(result.data())
        except Exception as e:
            st.error(f"🚨 Cypher execution error: {e}")
            # `query` is bound on the first line of the try, so it is set
            # whenever session setup or execution was what failed.
            st.code(query, language="cypher")

    # Fall back to CSV retrieval when Cypher returned nothing, a duplicate of
    # a previously saved dataset, or too few rows to be useful.
    fallback_needed = False
    if cypher_df.empty:
        fallback_needed = True
    else:
        df_hash = hashlib.md5(cypher_df.to_csv(index=False).encode()).hexdigest()
        st.write(f"ℹ️ Cypher rows: {len(cypher_df)} | duplicate?: {df_hash in saved_hashes}")
        if df_hash in saved_hashes or len(cypher_df) < 10:
            fallback_needed = True
    if fallback_needed:
        csv_df = rag_csv_fallback(subtask)
        if not csv_df.empty:
            # Normalize the fallback frame the same way as a Cypher result.
            csv_df = flatten_props(csv_df)
            csv_df = unpack_singletons(csv_df)
            csv_df = standardize_latlon(csv_df)
            return csv_df
        st.warning("❌ CSV fallback also returned nothing.")
        return pd.DataFrame()

    st.success("✅ Cypher query successful. Using Cypher result.")
    # Normalize the Cypher result to the column names the UI/map code expects.
    cypher_df = flatten_props(cypher_df)
    cypher_df = unpack_singletons(cypher_df)
    cypher_df = standardize_latlon(cypher_df)
    if "species" not in cypher_df.columns and "animal_name" in cypher_df.columns:
        cypher_df["species"] = cypher_df["animal_name"]
    if "date" in cypher_df.columns:
        cypher_df["date"] = pd.to_datetime(cypher_df["date"], errors="coerce")
    cypher_df.rename(columns={"latitudes": "latitude", "longitudes": "longitude", "lat": "latitude", "lon": "longitude"}, inplace=True)
    for col in ("latitude", "longitude"):
        if col in cypher_df.columns:
            cypher_df[col] = pd.to_numeric(cypher_df[col], errors="coerce")
    return cypher_df
|
|
|
|
def planner_agent(question: str) -> str:
    """Produce 1–3 'Dataset Need' blocks describing the datasets required to
    answer *question*, using a fixed planning prompt.
    """
    prompt = f"""
You are a **research‑data planning assistant**.

------------------------ 📝 TASK ------------------------
Your job is to list the **separate data sets** a researcher must collect
to answer the research question below.

*Each data set* should be focused on one clearly defined entity or
phenomenon (e.g. "Tracks of hurricanes affecting Florida since 1950",
"Geo‑tagged snake observations in Florida 2000‑present").

-------------------- 📋 OUTPUT FORMAT --------------------
Write 1–3 blocks. For **each** block use *all* four lines exactly:

Dataset Need X: <Concise title, ≤ 10 words>
Description: <Why this data matters—1 short sentence>

⚠️ Do NOT add extra lines or markdown.
⚠️ Keep variable names short; no code blocks; no quotes.

-------------------- 🔍 RESEARCH QUESTION --------------------
{question}
"""
    # Same system/user exchange as before, routed through the shared helper.
    return gpt_chat("You are an expert research planner.", prompt, temperature=0.2)
|
|
|
|
|
|
def evaluate_dataset_with_gpt(subtask: str, df: pd.DataFrame, client=openai_client) -> str:
    """Ask gpt-4o whether *df* is useful for *subtask* and return its verdict text.

    Only the first 50 columns and 3 sample rows are sent to keep the prompt
    small. Fix: sample rows often contain pandas Timestamps / numpy scalars
    (the retriever converts "date" to datetime before this runs), which made
    json.dumps raise TypeError; ``default=str`` now stringifies any
    non-JSON-native value instead of crashing.
    """
    max_columns = 50
    selected_cols = df.columns[:max_columns]
    column_info = {col: str(df[col].dtype) for col in selected_cols}
    sample_rows = df.head(3)[selected_cols].to_dict(orient="records")

    prompt = f"""
You are a data‑validation assistant. Decide whether the dataset below is useful for the research subtask.

===== TASK =====
Subtask: "{subtask}"

===== DATASET PREVIEW =====
Schema (first {len(selected_cols)} columns):
{json.dumps(column_info, indent=10)}
Sample rows (10 max):
{json.dumps(sample_rows, indent=10, default=str)}

===== OUTPUT INSTRUCTIONS (follow strictly) =====
Case A – Relevant:
• Write exactly two sentences, each no more than 30 words.
• Summarize what the dataset contains and why it helps the subtask.
• Do not mention column names or list individual rows.

Case B – Not relevant:
• Write one or two sentences, each no more than 30 words, **describing only what the dataset contains**.
• Do **not** mention the subtask, relevance, suitability, limitations, or missing information (avoid phrases like “not related,” “does not focus,” “irrelevant,” etc.).
• After the sentences, output the header **Additionally, here are some external resources you might find helpful:** on a new line. Format your output in markdown as:
- [Name of Source](URL)
• Then list 2–3 bullet points, each on its own line, starting with “- ” followed immediately by a URL likely to contain the needed data.
• No additional commentary.



General rules:
Plain text only — no code fences. Markdown link syntax (`[text](url)`) is allowed.
"""

    rsp = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3,
    )
    return rsp.choices[0].message.content.strip()
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
|
|
|
|
def external_resource_recommender(subtask: str, client=openai_client) -> str:
    """Return a markdown bullet list of three public data sources for *subtask*."""
    prompt = f"""
You are a helpful research assistant. Your task is to recommend **three reliable, publicly accessible online datasets or data repositories** that can assist with the following scientific subtask:

{subtask}

Only include sources that are:
- Trusted (e.g., government, academic, or well-established platforms)
- Relevant to the topic
- Accessible without login when possible

Format your answer strictly in markdown:
- [Name of Source](URL)
- [Name of Source](URL)
- [Name of Source](URL)

Do not include any explanations or extra text—only the list.
"""
    messages = [{"role": "user", "content": prompt}]
    reply = client.chat.completions.create(model="gpt-4o", messages=messages, temperature=0.3)
    return reply.choices[0].message.content.strip()
|
|
|
|
|
|
|
|
def fallback_query_router(subtask: str, driver) -> pd.DataFrame:
    """Keyword-routed canned Cypher queries (no LLM) for common subtask shapes.

    Scans *subtask* for trigger words, runs the matching query against
    *driver*, and — when the graph returns nothing — shows external data
    source suggestions in the UI instead. Returns the (possibly empty) frame.

    Fix: the year regex used a capturing group ``(19|20)``, so re.findall
    returned just "19"/"20" and the date filter was built from a truncated
    year like date('19-01-01'); the group is now non-capturing so the full
    four-digit year is used.
    """
    text = subtask.lower()

    with driver.session() as session:

        # "Where was X observed/found?" → observation locations.
        if "where" in text and ("observed" in text or "found" in text):
            query = """
            MATCH (o:Observation)-[:OBSERVED_ORGANISM]->(s:Species)
            RETURN s.name AS species, o.site_name AS location, o.date AS date
            ORDER BY o.date DESC
            """

        # Temporal filters ("before 1990", "after 2005").
        elif "before" in text or "after" in text:
            years = re.findall(r'\b(?:19|20)\d{2}\b', text)
            if years:
                op = "<" if "before" in text else ">"
                query = f"""
                MATCH (o:Observation)-[:OBSERVED_ORGANISM]->(s:Species)
                WHERE o.date {op} date('{years[0]}-01-01')
                RETURN s.name AS species, o.site_name AS location, o.date AS date
                ORDER BY o.date DESC
                """
            else:
                query = "RETURN 1"

        # Hurricane-related observations in Florida.
        elif "hurricane" in text:
            query = """
            MATCH (o:Observation)-[:OBSERVED_AT]->(h:Hurricane),
                  (o)-[:OBSERVED_ORGANISM]->(s:Species),
                  (o)-[:OBSERVED_IN]->(site)-[:BELONGS_TO]->(c:County)-[:IN_STATE]->(st:State)
            WHERE st.name = 'Florida'
            RETURN h.name AS hurricane,
                   s.name AS species,
                   site.name AS site,
                   c.name AS county,
                   o.date AS date
            ORDER BY o.date DESC
            """

        # Predator/prey relationships.
        elif "preys on" in text or "predator" in text:
            query = """
            MATCH (s1:Species)-[:preysOn]->(s2:Species)
            RETURN s1.name AS predator, s2.name AS prey
            """

        # Default: all observations.
        else:
            query = """
            MATCH (o:Observation)
            RETURN o.animal_name AS species, o.site_name AS location, o.date AS date
            """

        result = session.run(query)
        df = pd.DataFrame(result.data())

    if df.empty:
        st.info("🌐 I couldn't find relevant data in KN‑Wildlife. Let me check external sources for you...")
        suggestions = external_resource_recommender(subtask)
        st.markdown(suggestions)

    return df
|
|
|
|
def save_dataset(df: pd.DataFrame, filename: str) -> str:
    """Persist *df* to /tmp/saved_datasets/<filename>.csv and return the path.

    Returns "" (and warns) for datasets with fewer than 10 rows. If an
    identical file already exists, it is left untouched.

    Fixes: the *filename* parameter was previously ignored — every dataset
    was written to one hard-coded path, so later subtasks overwrote earlier
    ones — and the hash-check file handle was never closed.
    """
    if len(df) < 10:
        st.warning(f"❌ Dataset too small to save: only {len(df)} rows.")
        return ""
    save_dir = "/tmp/saved_datasets"
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, f"{filename}.csv")
    if os.path.exists(path):
        with open(path, 'rb') as fh:
            old_hash = hashlib.md5(fh.read()).hexdigest()
        new_hash = hashlib.md5(df.to_csv(index=False).encode()).hexdigest()
        if old_hash == new_hash:
            # Same content already on disk — skip the rewrite.
            st.info(f"ℹ️ Dataset saved: {filename}.csv")
            return path
    df.to_csv(path, index=False)
    st.info(f"✅ Dataset saved: {filename}.csv")
    return path
| |
|
|
def suggest_charts_with_gpt(df: pd.DataFrame) -> str:
    """Generate Streamlit chart code for automatic visualisation.

    Returns a Python source string (executed by the caller via exec) that
    renders a species bar chart, an observations-over-time line chart and a
    folium map, each guarded by the presence of the relevant columns.

    NOTE(review): despite the name, no GPT call happens here — the returned
    code is a fixed template. Also note *df* is normalised IN PLACE (date
    parsing, species aliasing, lat/lon renames), so the caller's frame is
    mutated as a side effect.
    """
    try:
        # Normalise the date column: unwrap singleton lists, then parse.
        if "date" in df.columns:
            df["date"] = df["date"].apply(lambda x: x[0] if isinstance(x, (list, tuple)) and len(x) == 1 else x)
            df["date"] = pd.to_datetime(df["date"], errors="coerce")

        # Alias animal_name → species so the bar-chart template finds it.
        if "animal_name" in df.columns and "species" not in df.columns:
            df["species"] = df["animal_name"]

        # The template expects singular latitude/longitude column names.
        df.rename(columns={"latitudes": "latitude", "longitudes": "longitude"}, inplace=True)

        chart_code = """
# --- Species Bar Chart ---
if 'species' in df.columns:
    st.markdown('📊 Count of Observations by Species')
    try:
        species_counts = df['species'].astype(str).value_counts()
        st.bar_chart(species_counts)
    except Exception as e:
        st.warning(f'⚠️ Could not render species chart: {e}')

# --- Timeline Line Chart ---
if 'date' in df.columns:
    st.markdown('📈 Observations Over Time')
    try:
        timeline = df['date'].dropna().value_counts().sort_index()
        st.line_chart(timeline)
    except Exception as e:
        st.warning(f'⚠️ Could not render date chart: {e}')

# --- Map Visualisation (highlight all points) ---
if 'latitude' in df.columns and 'longitude' in df.columns:
    st.markdown('🗺️ Observation Locations on Map')
    try:
        coords = df[['latitude', 'longitude']].dropna()
        coords = coords[(coords['latitude'].between(-90, 90)) & (coords['longitude'].between(-180, 180))]

        if len(coords) == 0:
            raise Exception('⚠️ No valid coordinates to plot on the map.')
        else:
            # 计算中心点
            center = [coords['latitude'].mean(), coords['longitude'].mean()]
            m = folium.Map(location=center, zoom_start=5)

            # 添加散点
            for _, row in coords.iterrows():
                folium.CircleMarker(
                    location=[row['latitude'], row['longitude']],
                    radius=5,
                    color='green',
                    fill=True,
                    fill_color='green',
                    fill_opacity=0.7,
                ).add_to(m)

            st_folium(m, width=700, height=500)
    except Exception as e:
        st.warning(f'⚠️ Could not render map: {e}')
"""
        return textwrap.dedent(chart_code)
    except Exception as outer_error:
        # Fall back to a snippet that surfaces the failure inside the app.
        return f"st.warning('❌ Chart generation failed: {outer_error}')"
|
|
|
|
|
|
|
|
| |
| if "chat_history" not in st.session_state: |
| st.session_state.chat_history = [] |
| |
| |
| |
| |
| |
| |
|
|
| |
| st.markdown( |
| """ |
| <style> |
| /* 针对正文文字 */ |
| html, body, .block-container, .markdown-text-container { |
| font-size: 19px !important; /* ← 这里改数字 */ |
| line-height: 1.6 !important; |
| } |
| /* 把默认窄屏的 max-width(约700px)改成 1400px,视需要可调整 */ |
| .block-container { |
| max-width: 1600px; |
| } |
| </style> |
| """, |
| unsafe_allow_html=True |
| ) |
|
|
| st.title("🐾 Quest2DataAgent_EcoData") |
|
|
|
|
| st.success(""" |
| 👋 Hi there! I’m **Lily**, your research assistant bot 🤖. I’m here to help you explore data sources related to your **complex research question**. Let’s work together to find the information you need! |
| |
| 💡 You can start by entering a research question like: |
| |
| - *In Florida, how do hurricanes affect the distribution of snakes?* |
| - *How does precipitation impact salmon abundance in freshwater ecosystems?* |
| - *How do climate change and urbanization jointly affect bird migration and diversity in Florida?* |
| """) |
|
|
| if driver: |
| st.success("🟢 Connected to **Ecodata** — a Neo4j-powered biodiversity graph focused on species and ecosystems. I’ll start by checking what relevant data we already have in Ecodata to support your research.") |
|
|
| else: |
| st.error("🔴 Failed to connect to Ecodata! Please fix connection first.") |
| st.stop() |
|
|
| question = st.text_area("Enter your research question:", "") |
|
|
| |
| if "start_clicked" not in st.session_state: |
| st.session_state.start_clicked = False |
| if "subtask_plan" not in st.session_state: |
| st.session_state.subtask_plan = "" |
| if "ready_to_continue" not in st.session_state: |
| st.session_state.ready_to_continue = False |
| if "stop_requested" not in st.session_state: |
| st.session_state.stop_requested = False |
| if "visualization_ready" not in st.session_state: |
| st.session_state.visualization_ready = False |
| if "do_visualize" not in st.session_state: |
| st.session_state.do_visualize = False |
| if "all_dataframes" not in st.session_state: |
| st.session_state.all_dataframes = [] |
| if "retrieval_done" not in st.session_state: |
| st.session_state.retrieval_done = False |
|
|
| |
# Kick off a new run: regenerate the plan and reset all downstream state.
if st.button("Let’s start") and question.strip():
    st.session_state.start_clicked = True
    st.session_state.subtask_plan = planner_agent(question)
    st.session_state.ready_to_continue = False
    st.session_state.stop_requested = False
    st.session_state.visualization_ready = False
    st.session_state.do_visualize = False
    st.session_state.all_dataframes = []
    st.session_state.retrieval_done = False


if st.session_state.start_clicked:

    st.success("🧠 I’ve identified the distinct datasets you’ll need for this research question.")
    with st.expander("🔹 Curious how I split your question? Click to see!", expanded=True):
        st.write(st.session_state.subtask_plan)

    st.success("📌 I’m ready to roll up my sleeves — shall I start finding datasets for each subtask? 🕒 This step might take a little while, so thanks for your patience!")

    # Confirm / cancel buttons gate the (slow) retrieval phase below.
    col1, col2 = st.columns([1, 1])
    with col1:
        if st.button("✅ Yes, go ahead", key="confirm_button"):
            st.session_state.ready_to_continue = True
            st.session_state.stop_requested = False
    with col2:
        if st.button("⛔ No, stop here", key="stop_button"):
            st.session_state.ready_to_continue = False
            st.session_state.stop_requested = True
|
|
|
|
| |
if st.session_state.ready_to_continue:

    # The planner may label blocks "Dataset Need N:" or "Subtask N:".
    if "Dataset Need" in st.session_state.subtask_plan:
        prefix = "Dataset Need"
    else:
        prefix = "Subtask"

    # Split the plan into one text block per dataset need (non-greedy match
    # up to the next header or end of text).
    pattern = rf"{prefix} \d+:.*?(?={prefix} \d+:|$)"
    subtasks = re.findall(pattern,
                          st.session_state.subtask_plan,
                          flags=re.DOTALL)

    if not subtasks:
        st.warning("⚠️ No dataset blocks detected in planner output.")
        st.stop()

    # First pass only: set up progress tracking and reset collected datasets.
    if not st.session_state.retrieval_done:
        progress_bar = st.progress(0)
        total = len(subtasks)
        saved_hashes = set()
        st.session_state.all_dataframes = []

    for idx, subtask in enumerate(subtasks):
        with st.expander(f"🔹 Retrieving data for dataset need {idx+1}:", expanded=True):
            # Drop the title line; show only the description text.
            cleaned_subtask = "\n".join(subtask.strip().split("\n")[1:])
            st.markdown(cleaned_subtask)

            if not st.session_state.retrieval_done:
                # Fresh retrieval (Cypher first, CSV fallback inside the agent).
                df = intelligent_retriever_agent(subtask, saved_hashes)

                if not df.empty:
                    df_hash = hashlib.md5(df.to_csv(index=False).encode()).hexdigest()
                    if df_hash in saved_hashes:
                        st.warning("⚠️ This dataset has already been saved — skipping duplicate.")
                    elif len(df) < 10:
                        st.warning(f"❌ This dataset is too small — just {len(df)} rows. Skipping save.")
                    else:
                        saved_hashes.add(df_hash)
                        df = flatten_props(df)
                        df = standardize_latlon(df)
                        summary = evaluate_dataset_with_gpt(subtask, df)
                        # Keep the accepted dataset in session state so reruns
                        # and the visualization phase can reuse it.
                        st.session_state.all_dataframes.append({
                            "hash": df_hash,
                            "df": df,
                            "summary": summary
                        })

                        st.dataframe(df.head(50))
                        save_path = save_dataset(df, f"subtask_{idx+1}")
                        if save_path:
                            st.markdown("**📝 Dataset Introduction:**")
                            st.write(summary)

                            with open(save_path, "rb") as f:
                                st.download_button(
                                    label="📥 Download dataset (CSV)",
                                    data=f,
                                    file_name=os.path.basename(save_path),
                                    mime="text/csv",
                                    key=f"download_init_{idx}"
                                )

                # progress_bar exists only on the first pass (see guard above).
                if 'progress_bar' in locals():
                    progress_bar.progress((idx + 1) / total)

            else:
                # Rerun: re-display previously retrieved datasets from session state.
                if idx < len(st.session_state.all_dataframes):
                    entry = st.session_state.all_dataframes[idx]
                    df = standardize_latlon(entry["df"])
                    st.dataframe(df.head(50))

                    st.markdown("**📝 Dataset Introduction:**")
                    st.write(entry.get("summary", ""))

                    # Re-export so the download button has a file to stream.
                    tmp_path = f"/tmp/subtask_{idx+1}_display.csv"
                    df.to_csv(tmp_path, index=False)
                    with open(tmp_path, "rb") as f:
                        st.download_button(
                            label="📥 Download dataset (CSV)",
                            data=f,
                            file_name=os.path.basename(tmp_path),
                            mime="text/csv",
                            key=f"download_rerun_{idx}"
                        )

    # Mark retrieval complete after the first full pass.
    if not st.session_state.retrieval_done:
        st.session_state.retrieval_done = True
        st.session_state.visualization_ready = bool(st.session_state.all_dataframes)

    if st.session_state.all_dataframes:
        st.session_state.visualization_ready = True
    else:
        st.success("🎉 All subtasks completed and datasets generated!💡 Feel free to ask me more questions anytime!")


# Offer visualization once datasets exist and the user has not started it yet.
if st.session_state.visualization_ready and not st.session_state.do_visualize:
    st.success("📊 All set! I’ve gathered the datasets. Ready to visualize them?")

    col1, col2 = st.columns([1, 1])
    with col1:
        if st.button("✅ Yes, go ahead", key="viz_confirm"):
            st.session_state.do_visualize = True
    with col2:
        if st.button("⛔ No, stop here", key="viz_stop"):
            st.session_state.visualization_ready = False
            st.success("🎉 All subtasks completed and datasets generated!💡 Feel free to ask me more questions anytime!")


# Visualization phase: render the template charts for each stored dataset.
if st.session_state.do_visualize:
    for i, entry in enumerate(st.session_state.all_dataframes):
        df = entry["df"]
        summary = entry.get("summary", "")
        if len(df) < 10:
            continue
        with st.expander(f"**🔹 Dataset {i + 1} Visualization**", expanded=True):
            st.markdown(f"Dataset {i + 1} Preview")
            st.dataframe(df.head(10))
            chart_code = suggest_charts_with_gpt(df)
            if chart_code:
                try:
                    # NOTE(review): chart_code is a locally generated template
                    # (no LLM output is executed), so exec here only runs our
                    # own code with an explicit namespace.
                    exec(chart_code, {"st": st, "pd": pd, "df": df, "pdk": pdk, "folium": folium, "st_folium": st_folium})
                except Exception as e:
                    st.error(f"❌ Error running chart code: {e}")

    st.success("🎉 All subtasks completed and datasets generated!💡 Feel free to ask me more questions anytime!")


# User declined at the confirmation step.
if st.session_state.stop_requested:
    st.info("👍 No problem! You can review the subtasks above or revise your question.")
|
|
|
|
|
|
| |
| with st.sidebar.expander("💬 Chat with Lily", expanded=True): |
| |
| user_msg = st.chat_input("Type your question here…", key="sidebar_chat_input") |
| if user_msg: |
| |
| context_parts = [] |
| if st.session_state.subtask_plan: |
| context_parts.append("Subtasks:\n" + st.session_state.subtask_plan) |
| for entry in st.session_state.all_dataframes: |
| context_parts.append("Data summary:\n" + entry["summary"]) |
| page_context = "\n\n".join(context_parts) |
|
|
| |
| with st.spinner("Lily is thinking…"): |
| assistant_msg = gpt_chat( |
| sys_msg=f"You are Lily, a research assistant. Here’s what’s on screen:\n\n{page_context}", |
| user_msg=user_msg |
| ) |
|
|
| |
| st.session_state.chat_history.append({"role": "user", "content": user_msg}) |
| st.session_state.chat_history.append({"role": "assistant", "content": assistant_msg}) |
|
|
| |
| for msg in st.session_state.chat_history: |
| if msg["role"] == "user": |
| st.chat_message("user").write(msg["content"]) |
| else: |
| st.chat_message("assistant").write(msg["content"]) |
|
|