Spaces:
Running
Running
| """Shared session state management and data I/O utilities. | |
| Manages cross-page state (selected ROIs, predictions, analysis log) | |
| and provides upload/download widgets. | |
| """ | |
| import io | |
| import json | |
| from datetime import datetime | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| def init_session(): | |
| """Initialize session state with defaults. Safe to call multiple times.""" | |
| defaults = { | |
| "brain_predictions": None, | |
| "model_features": {}, | |
| "roi_indices": None, | |
| "n_vertices": 0, | |
| "selected_rois": [], | |
| "data_source": "synthetic", | |
| "stimulus_type": "visual", | |
| "tr_seconds": 1.0, | |
| "n_timepoints": 80, | |
| "seed": 42, | |
| "analysis_log": [], | |
| "carry_rois": [], # ROIs carried from another page | |
| } | |
| for key, value in defaults.items(): | |
| if key not in st.session_state: | |
| st.session_state[key] = value | |
| def log_analysis(description): | |
| """Append an entry to the analysis log.""" | |
| timestamp = datetime.now().strftime("%H:%M:%S") | |
| entry = f"[{timestamp}] {description}" | |
| if "analysis_log" not in st.session_state: | |
| st.session_state["analysis_log"] = [] | |
| st.session_state["analysis_log"].append(entry) | |
| def carry_rois(rois, target_page=""): | |
| """Store selected ROIs for cross-page workflow.""" | |
| st.session_state["carry_rois"] = list(rois) | |
| log_analysis(f"Carried {len(rois)} ROIs to {target_page}") | |
| def get_carried_rois(): | |
| """Retrieve ROIs carried from another page.""" | |
| return st.session_state.get("carry_rois", []) | |
| def get_or_generate_data(roi_indices): | |
| """Get brain predictions from session or generate new synthetic data.""" | |
| from synthetic import generate_realistic_predictions | |
| params_key = ( | |
| st.session_state.get("n_timepoints", 80), | |
| st.session_state.get("stimulus_type", "visual"), | |
| st.session_state.get("seed", 42), | |
| ) | |
| # Check if we need to regenerate | |
| if ( | |
| st.session_state.get("brain_predictions") is None | |
| or st.session_state.get("_data_params") != params_key | |
| or st.session_state.get("data_source") == "synthetic" | |
| ): | |
| if st.session_state.get("data_source") == "uploaded" and st.session_state.get("brain_predictions") is not None: | |
| return st.session_state["brain_predictions"] | |
| predictions = generate_realistic_predictions( | |
| n_timepoints=st.session_state["n_timepoints"], | |
| roi_indices=roi_indices, | |
| stimulus_type=st.session_state["stimulus_type"], | |
| tr_seconds=st.session_state["tr_seconds"], | |
| seed=st.session_state["seed"], | |
| ) | |
| st.session_state["brain_predictions"] = predictions | |
| st.session_state["_data_params"] = params_key | |
| return st.session_state["brain_predictions"] | |
| def upload_npy_widget(label, key): | |
| """File uploader for .npy arrays with validation.""" | |
| uploaded = st.file_uploader(label, type=["npy"], key=key) | |
| if uploaded is not None: | |
| try: | |
| data = np.load(io.BytesIO(uploaded.read())) | |
| st.success(f"Loaded: shape {data.shape}, dtype {data.dtype}") | |
| return data | |
| except Exception as e: | |
| st.error(f"Failed to load file: {e}") | |
| return None | |
| def download_csv_button(df, filename, label="Download CSV"): | |
| """Download button for a pandas DataFrame as CSV.""" | |
| csv = df.to_csv(index=False) | |
| st.download_button(label, csv, filename, "text/csv") | |
| def download_json_button(data, filename, label="Download JSON"): | |
| """Download button for a dict as JSON.""" | |
| json_str = json.dumps(data, indent=2, default=str) | |
| st.download_button(label, json_str, filename, "application/json") | |
| def show_analysis_log(): | |
| """Display the analysis log in the sidebar.""" | |
| log = st.session_state.get("analysis_log", []) | |
| if log: | |
| with st.sidebar: | |
| with st.expander("Analysis Log", expanded=False): | |
| for entry in reversed(log[-20:]): | |
| st.caption(entry) | |
| def data_summary_widget(predictions, roi_indices): | |
| """Show a summary of the current data.""" | |
| if predictions is None: | |
| st.info("No data loaded. Generate synthetic data or upload your own.") | |
| return | |
| col1, col2, col3, col4 = st.columns(4) | |
| col1.metric("Timepoints", predictions.shape[0]) | |
| col2.metric("Vertices", predictions.shape[1]) | |
| col3.metric("ROIs", len(roi_indices)) | |
| col4.metric("Source", st.session_state.get("data_source", "synthetic").title()) | |