"""NeuroBridge Enterprise — Streamlit B2B dashboard (Editorial redesign). Five tabs (Molecule / Signal / Image / AI Assistant / Experiments) sitting on top of one FastAPI surface. Every interaction returns an auditable decision artefact: label + confidence + calibration + drift + provenance + SHAP. Visual language (post-redesign): - Dark theme = editorial Netflix-style — deep neutral grays + sand accent - Light theme = warm paper + charcoal type — Apple HIG / NYT-Cooking energy - Single sand brand-mark across both themes (#D2C4B1) - Inter (display + body) + JetBrains Mono (data / code) Launch: `streamlit run src/frontend/app.py` """ from __future__ import annotations import html as _html import os import httpx import streamlit as st _API_URL = os.environ.get("NEUROBRIDGE_API_URL", "http://localhost:8000") _MLFLOW_URL = os.environ.get( "NEUROBRIDGE_MLFLOW_URL", os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000"), ) _MLFLOW_DISABLED = os.environ.get("NEUROBRIDGE_DISABLE_MLFLOW") == "1" _LLM_DISABLED = os.environ.get("NEUROBRIDGE_DISABLE_LLM") == "1" # --------------------------------------------------------------------------- # # Design tokens — single source of truth for both themes. # # Tokens are exposed as CSS custom properties at the :root level; every # # component reads from them so a theme swap is just a value swap. 
# # --------------------------------------------------------------------------- # _TOKENS_DARK = { # Surfaces (deepest → most elevated) "bg-base": "#0e0e10", "bg-elevated": "#161618", "bg-elevated-2": "#1e1e21", "bg-elevated-3": "#2a2a2e", # Brand accent "accent": "#D2C4B1", "accent-strong": "#E8DCC6", "accent-soft": "rgba(210, 196, 177, 0.12)", "accent-ring": "rgba(210, 196, 177, 0.35)", # Text "text-primary": "#F5F2ED", "text-secondary": "#A8A29A", "text-tertiary": "#6B6660", "text-on-accent": "#161618", # Lines "border": "#2a2a2e", "border-strong": "#3a3a3e", # Semantic (keep cool — never red/green dominant in editorial) "success": "#7FB069", "warning": "#E0B469", "danger": "#D97A6C", # Effects "shadow-sm": "0 1px 2px rgba(0, 0, 0, 0.4)", "shadow-md": "0 8px 24px rgba(0, 0, 0, 0.45)", "shadow-lg": "0 16px 48px rgba(0, 0, 0, 0.55)", } _TOKENS_LIGHT = { "bg-base": "#FAF7F2", "bg-elevated": "#FFFFFF", "bg-elevated-2": "#F5F0E8", "bg-elevated-3": "#EDE5D5", "accent": "#1e1e21", "accent-strong": "#0e0e10", "accent-soft": "rgba(30, 30, 33, 0.06)", "accent-ring": "rgba(30, 30, 33, 0.18)", "text-primary": "#161618", "text-secondary": "#4A4540", "text-tertiary": "#8A857E", "text-on-accent": "#FAF7F2", "border": "#E5DDC9", "border-strong": "#D2C4B1", "success": "#3F7D45", "warning": "#A06D1F", "danger": "#A1483D", "shadow-sm": "0 1px 2px rgba(40, 30, 20, 0.04)", "shadow-md": "0 4px 16px rgba(40, 30, 20, 0.08)", "shadow-lg": "0 12px 40px rgba(40, 30, 20, 0.12)", } def _build_css(theme: str) -> str: """Return the full """ # --------------------------------------------------------------------------- # # Theme management # # --------------------------------------------------------------------------- # def _init_theme() -> str: """Initialize and return the active theme ('dark' default).""" if "theme" not in st.session_state: st.session_state["theme"] = "dark" return st.session_state["theme"] def _altair_theme(theme: str) -> dict: """Return an altair theme matching the active 
palette. Registered as 'neurobridge' on first call; subsequent calls just enable. """ tokens = _TOKENS_DARK if theme == "dark" else _TOKENS_LIGHT return { "config": { "background": tokens["bg-elevated"], "view": {"stroke": "transparent"}, "axis": { "labelColor": tokens["text-secondary"], "titleColor": tokens["text-secondary"], "labelFont": "Inter", "titleFont": "Inter", "labelFontSize": 11, "titleFontSize": 12, "gridColor": tokens["border"], "domainColor": tokens["border"], "tickColor": tokens["border"], }, "header": { "labelColor": tokens["text-primary"], "labelFont": "Inter", "labelFontSize": 13, "labelFontWeight": 600, "titleColor": tokens["text-secondary"], }, "legend": { "labelColor": tokens["text-secondary"], "titleColor": tokens["text-secondary"], "labelFont": "Inter", "titleFont": "Inter", }, "title": { "color": tokens["text-primary"], "font": "Inter", "fontWeight": 600, }, "range": { # Editorial palette: sand-led, then warm secondaries. "category": [ tokens["accent"], "#8FB3C9", "#C99B8F", "#9DAD86", "#B8A4C9", "#D4B86A", "#7FB069", "#A6A2C2", ], }, } } def _register_altair_theme(theme: str) -> None: """Register + enable the neurobridge altair theme for the current run.""" try: import altair as alt alt.themes.register("neurobridge", lambda: _altair_theme(theme)) alt.themes.enable("neurobridge") except Exception: # altair may not be importable in some environments; chart calls # will simply use altair defaults — no functional impact. 
# --------------------------------------------------------------------------- #
# HTTP helpers                                                                #
# --------------------------------------------------------------------------- #


def _check_api_health() -> tuple[bool, str]:
    """Ping FastAPI /health endpoint; return (ok, status_text)."""
    try:
        resp = httpx.get(f"{_API_URL}/health", timeout=2.0)
        if resp.status_code == 200:
            return True, "operational"
        return False, f"http {resp.status_code}"
    except httpx.RequestError as e:
        # Surface the exception class name (e.g. 'connecterror') as status.
        return False, type(e).__name__.lower()


def _post(endpoint: str, payload: dict, timeout: float = 120.0) -> dict:
    """POST to the FastAPI surface; let httpx raise on non-2xx."""
    resp = httpx.post(f"{_API_URL}{endpoint}", json=payload, timeout=timeout)
    resp.raise_for_status()
    return resp.json()


def _get(path: str) -> dict:
    """GET helper symmetric with _post."""
    resp = httpx.get(f"{_API_URL}{path}", timeout=10.0)
    resp.raise_for_status()
    return resp.json()


# --------------------------------------------------------------------------- #
# Hero / sidebar / section primitives                                         #
# --------------------------------------------------------------------------- #


def _render_brand_header(api_ok: bool, api_status: str) -> None:
    """Editorial hero strip: word-mark + tagline + 3 status dots.

    NOTE(review): the original hero HTML template was lost in extraction
    (unterminated f-string in this copy). The markup below is a minimal,
    functional reconstruction that reuses the same status classes/labels —
    restore the full editorial markup from version control.
    """
    api_class = "is-ok" if api_ok else "is-down"
    mlflow_class = "is-mute" if _MLFLOW_DISABLED else "is-ok"
    mlflow_label = "tracking off" if _MLFLOW_DISABLED else "tracking"
    llm_class = "is-mute" if _LLM_DISABLED else "is-ok"
    llm_label = "template only" if _LLM_DISABLED else "llm online"
    st.markdown(
        f"""
<div class="nb-hero">
  <div class="nb-hero-eyebrow">Living decision system · clinical ML</div>
  <p class="nb-hero-lede">Three production pipelines — molecule, signal,
  image — behind one auditable surface. Every prediction returns label,
  calibration, drift, provenance and a natural-language rationale.</p>
  <span class="nb-dot {api_class}">api · {_html.escape(api_status)}</span>
  <span class="nb-dot {mlflow_class}">{_html.escape(mlflow_label)}</span>
  <span class="nb-dot {llm_class}">{_html.escape(llm_label)}</span>
</div>
""",
        unsafe_allow_html=True,
    )
Living decision system · clinical ML
Three production pipelines — molecule, signal, image — behind one auditable surface. Every prediction returns label, calibration, drift, provenance and a natural-language rationale.
{_html.escape(eyebrow)}
{_html.escape(desc)}
"
f"output → {safe_output_path}
" f"mlflow run · " f"{safe_run_id[:12]}…
", unsafe_allow_html=True, ) elif _MLFLOW_DISABLED: st.caption("mlflow tracking disabled (NEUROBRIDGE_DISABLE_MLFLOW=1)") def _render_sidebar(api_ok: bool, api_status: str) -> None: with st.sidebar: st.markdown( """ """, unsafe_allow_html=True, ) st.markdown("### Theme") theme = st.session_state.get("theme", "dark") is_dark = st.toggle( "Dark mode", value=(theme == "dark"), key="theme_toggle", help="Switch between editorial dark (Netflix-style) and warm paper (Apple HIG-style).", ) new_theme = "dark" if is_dark else "light" if new_theme != theme: st.session_state["theme"] = new_theme st.rerun() st.markdown("### System") api_class = "is-ok" if api_ok else "is-down" mlflow_class = "is-mute" if _MLFLOW_DISABLED else "is-ok" llm_class = "is-mute" if _LLM_DISABLED else "is-ok" st.markdown( f""""
f"fastapi · {_API_URL}
"
f"mlflow · {_MLFLOW_URL}
" "Trust-engineered clinical-ML platform. Three modalities — BBB drug " "screening, EEG signal cleaning, MRI multi-site harmonization — " "behind one FastAPI surface. Every inference is auditable.
", unsafe_allow_html=True, ) # --------------------------------------------------------------------------- # # Tabs # # --------------------------------------------------------------------------- # def _render_bbb_tab() -> None: _render_section( "MOLECULE — BBBP", "Blood-Brain-Barrier permeability decision", "Enter a SMILES string. The system computes a 2,048-bit Morgan " "fingerprint, runs it through a Random Forest classifier, and returns " "a label, calibration-grounded confidence, drift signal, and the top " "SHAP attributions explaining the decision.", ) EDGE_CASES = { "Custom input (default)": { "smiles": "CCO", "label": "Ethanol — small, drug-like, BBB-permeable", "expectation": "High confidence, label = permeable", }, "Invalid SMILES (parse-error path)": { "smiles": "this_is_not_a_valid_molecule_at_all_!!", "label": "Garbage string — should not parse", "expectation": "API returns HTTP 400 with parse error; UI shows recoverable warning", }, "Empty string (boundary)": { "smiles": "", "label": "Empty input — boundary condition", "expectation": "Pydantic accepts empty; API returns 400 (RDKit cannot parse)", }, "Massive OOD: cyclosporine-like macrocycle": { "smiles": ( "CC[C@H](C)[C@@H]1NC(=O)[C@H](CC(C)C)N(C)C(=O)[C@H](CC(C)C)N(C)C(=O)" "[C@@H]2CCCN2C(=O)[C@H](C(C)C)NC(=O)[C@H]([C@@H](C)CC)N(C)C(=O)" "[C@H](C)NC(=O)[C@H](C)NC(=O)[C@H](CC(C)C)N(C)C(=O)[C@@H](NC(=O)" "[C@H](CC(C)C)N(C)C(=O)CN(C)C1=O)C(C)C" ), "label": "Cyclosporine — 11-residue macrocycle (~1.2 kDa)", "expectation": ( "Far outside training distribution; model should hedge " "with low confidence (well-calibrated systems don't " "pretend to know)." 
), }, "OOD: heavy halogenated aromatic": { "smiles": "Fc1c(F)c(F)c(c(F)c1F)c2c(F)c(F)c(F)c(F)c2F", "label": "Decafluorobiphenyl — extreme halogen density", "expectation": "Rare scaffold; expect lowered confidence vs ethanol", }, } case_name = st.selectbox( "Test edge cases", options=list(EDGE_CASES.keys()), index=0, key="bbb_case", help=( "Pick a robustness probe. Each case demonstrates how the system " "handles a real-world failure mode — invalid input, " "out-of-distribution molecules, or boundary conditions." ), ) case = EDGE_CASES[case_name] st.caption(f"**Probe:** {case['label']} · **Expected:** {case['expectation']}") smiles = st.text_input( "SMILES string", value=case["smiles"], key="bbb_smiles", help="Examples: CCO (ethanol), CC(=O)Nc1ccc(O)cc1 (paracetamol)", ) top_k = st.slider( "SHAP features to display", min_value=3, max_value=10, value=5, key="bbb_topk", ) if st.button("Predict BBB permeability", type="primary", key="bbb_predict"): with st.spinner("Computing fingerprint, predicting, explaining…"): try: result = _post("/predict/bbb", {"smiles": smiles, "top_k": top_k}) _render_prediction_card(result) st.toast("Prediction complete", icon="✅") except httpx.HTTPStatusError as e: if e.response.status_code == 503: st.error( "Model artifact not loaded yet. Run " "`python -m src.models.bbb_model` to train it, " "then retry." ) elif e.response.status_code == 400: # Robustness story: WARNING (recoverable), not ERROR. st.warning( f"Robustness check passed: API rejected the input " f"with HTTP 400 (no crash). 
Detail: " f"{e.response.json().get('detail', e.response.text)}" ) else: st.error( f"Prediction failed (HTTP {e.response.status_code}): " f"{e.response.text}" ) except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") def _render_eeg_tab() -> None: _render_section( "SIGNAL — EEG", "Electroencephalogram artifact removal", "Bandpass-filters raw FIF/EDF recordings, removes EOG artifacts via " "ICA decomposition, and extracts per-band PSD + statistical features " "across fixed-duration epochs.", ) eeg_in = st.text_input( "Input FIF/EDF path", "tests/fixtures/eeg_sample.fif", key="eeg_in", help="Path to a .fif/.edf EEG recording on the server filesystem.", ) eeg_out = st.text_input( "Output Parquet path", "data/processed/eeg_features.parquet", key="eeg_out", ) if st.button("Run EEG pipeline", type="primary", key="eeg_run"): with st.spinner("Filtering and running ICA…"): try: result = _post( "/pipeline/eeg", {"input_path": eeg_in, "output_path": eeg_out}, ) st.session_state["last_eeg_run"] = result _render_result(result) st.toast("EEG pipeline complete", icon="✅") except httpx.HTTPStatusError as e: st.error( f"Pipeline failed (HTTP {e.response.status_code}): " f"{e.response.text}" ) except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") last_eeg = st.session_state.get("last_eeg_run") if last_eeg is not None: with st.expander("Ask the AI Assistant about this EEG run", expanded=False): eeg_q_presets = [ "Why were certain ICA components dropped?", "What does the bandpass filter do?", "Is this run consistent with previous runs?", ] eeg_preset = st.selectbox( "Preset question", options=eeg_q_presets, key="eeg_ai_preset", ) eeg_custom = st.text_input( "Or type your own question (optional)", value="", key="eeg_ai_custom", ) eeg_question = eeg_custom.strip() or eeg_preset if st.button("Ask AI Assistant", key="eeg_ai_ask"): with st.spinner("Composing rationale…"): try: eeg_resp = _post( "/explain/eeg", { "rows": 
int(last_eeg.get("rows", 0)), "columns": int(last_eeg.get("columns", 0)), "duration_sec": float(last_eeg.get("duration_sec", 0.0)), "mlflow_run_id": last_eeg.get("mlflow_run_id"), "user_question": eeg_question, }, ) st.markdown(f"**A:** {eeg_resp['rationale']}") st.caption( f"Source: `{eeg_resp.get('source', '?')}` · " f"Model: `{eeg_resp.get('model') or '—'}`" ) except httpx.HTTPStatusError as e: st.error( f"Assistant failed (HTTP {e.response.status_code}): " f"{e.response.text}" ) except httpx.RequestError as e: st.error(f"Cannot reach FastAPI: {e!r}") def _render_mri_tab() -> None: _render_section( "IMAGE — MRI", "Multi-site harmonization via ComBat", "Loads NIfTI volumes, masks brain tissue, computes per-ROI summary " "statistics, then harmonizes across acquisition sites with " "neuroHarmonize to remove scanner-driven domain shift. The diagnostic " "plot below compares per-site feature distributions before and after " "harmonization.", ) mri_dir = st.text_input( "Input NIfTI directory", "tests/fixtures/mri_sample", key="mri_dir", help="Path to a directory of .nii(.gz) files + sites.csv", ) sites_csv = st.text_input( "Sites CSV", "tests/fixtures/mri_sample/sites.csv", key="mri_sites", ) if st.button("Run ComBat diagnostics", type="primary", key="mri_diag"): with st.spinner("Running pre + post ComBat (×2 the work)…"): try: result = _post( "/pipeline/mri/diagnostics", {"input_dir": mri_dir, "sites_csv": sites_csv}, ) _render_combat_diagnostics(result) st.toast("Diagnostics complete", icon="✅") except httpx.HTTPStatusError as e: st.error( f"Diagnostics failed (HTTP {e.response.status_code}): " f"{e.response.text}" ) except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") st.markdown("#### MRI Image Model") mri_kind = os.environ.get("MRI_MODEL_KIND", "volumetric_onnx") if mri_kind == "resnet18_2d": mri_image = st.text_input( "2D MRI image (.png/.jpg)", "tests/fixtures/mri_sample/subject_0_axial.png", key="mri_predict_image", ) 
st.caption( "Resnet18 4-class — labels: MildDemented, ModerateDemented, " "NonDemented, VeryMildDemented. Resize/labels are baked into the model." ) if st.button("Predict MRI image", key="mri_predict"): payload = {"input_path": mri_image} with st.spinner("Running 2D MRI model..."): try: result = _post("/predict/mri", payload, timeout=120.0) except httpx.HTTPStatusError as e: if e.response.status_code == 503: st.warning( "MRI 2D model artifact missing. Drop the trained checkpoint at " "`data/processed/mri_dl_2d/best_model.pt` or set `MRI_MODEL_PATH_2D`." ) else: st.error(f"MRI prediction failed (HTTP {e.response.status_code}): {e.response.text}") except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") else: st.metric( label=result.get("label_text", "prediction"), value=f"{float(result.get('confidence', 0.0)) * 100:.1f}%", ) probs = result.get("probabilities", []) if probs: st.dataframe(probs, use_container_width=True, hide_index=True) else: mri_image = st.text_input( "NIfTI image", "tests/fixtures/mri_sample/subject_0.nii.gz", key="mri_predict_image", ) mri_labels = st.text_input( "Class labels", "control,abnormal", key="mri_predict_labels", ) shape_cols = st.columns(3) target_d = shape_cols[0].number_input( "Resize D", min_value=1, max_value=256, value=64, step=1, key="mri_predict_d" ) target_h = shape_cols[1].number_input( "Resize H", min_value=1, max_value=256, value=64, step=1, key="mri_predict_h" ) target_w = shape_cols[2].number_input( "Resize W", min_value=1, max_value=256, value=64, step=1, key="mri_predict_w" ) st.caption( "Resize target as (D, H, W). Default 64³ matches typical model exports." 
) if st.button("Predict MRI image", key="mri_predict"): labels = [x.strip() for x in mri_labels.split(",") if x.strip()] payload: dict = { "input_path": mri_image, "target_shape": [int(target_d), int(target_h), int(target_w)], } if labels: payload["label_names"] = labels with st.spinner("Running MRI image model..."): try: result = _post("/predict/mri", payload, timeout=120.0) except httpx.HTTPStatusError as e: detail = e.response.text if e.response.status_code == 503: st.warning( "MRI model artifact is not available yet. Export the trained " "ONNX model to `data/processed/mri_model.onnx` or set `MRI_MODEL_PATH`." ) else: st.error(f"MRI prediction failed (HTTP {e.response.status_code}): {detail}") except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") else: st.metric( label=result.get("label_text", "prediction"), value=f"{float(result.get('confidence', 0.0)) * 100:.1f}%", ) probs = result.get("probabilities", []) if probs: st.dataframe(probs, use_container_width=True, hide_index=True) st.markdown("#### EEG Pretrained Classifier") st.caption( "Pretrained sklearn classifier on EEG band-power features. " "Output: per-class probabilities for `(control, alzheimers)`." ) eeg_csv = st.text_area( "EEG features (comma-separated)", ",".join(["0.0"] * 16), key="eeg_predict_features", height=80, ) if st.button("Predict EEG", key="eeg_predict"): try: features = [float(x.strip()) for x in eeg_csv.split(",") if x.strip()] except ValueError: st.error("EEG features must all be numeric.") else: payload = {"features": features} with st.spinner("Running EEG classifier..."): try: result = _post("/predict/eeg", payload, timeout=30.0) except httpx.HTTPStatusError as e: if e.response.status_code == 503: st.warning( "EEG model artifact missing. Drop the trained joblib at " "`data/processed/eeg_clf.joblib` or set `EEG_CLF_ARTIFACT`." 
) else: st.error(f"EEG prediction failed (HTTP {e.response.status_code}): {e.response.text}") except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") else: st.metric( label=result.get("label_text", "prediction"), value=f"{float(result.get('confidence', 0.0)) * 100:.1f}%", ) probs = result.get("probabilities", []) if probs: st.dataframe(probs, use_container_width=True, hide_index=True) def _render_prediction_card(result: dict) -> None: """Editorial decision card: provenance · verdict · signals · SHAP.""" st.session_state["last_bbb_prediction"] = result label_text = _html.escape(str(result["label_text"])) confidence_pct = float(result["confidence"]) * 100 # 1) Provenance strip (auditable line) provenance = result.get("provenance") or {} run_id = provenance.get("mlflow_run_id") run_label = run_id[:8] if run_id else "—" train_date = provenance.get("train_date") or "—" model_version = provenance.get("model_version", "v1") n_examples = provenance.get("n_examples") n_label = f"n={n_examples}" if n_examples else "n=—" # 2) Build signal rows: calibration, drift signal_rows: list[tuple[str, str]] = [] calibration = result.get("calibration") if calibration is not None: threshold_pct = round(float(calibration["threshold"]) * 100) precision_pct = round(float(calibration["precision"]) * 100) support = int(calibration["support"]) if support == 0: cal_str = "no held-out support in this band" else: cal_str = ( f"≥{threshold_pct}% confident → " f"{precision_pct}% precision · n={support}" ) signal_rows.append(("calibration", cal_str)) drift_z = result.get("drift_z") rolling_n = int(result.get("rolling_n", 0)) if drift_z is None and rolling_n < 10: drift_str = f"warming up · {rolling_n}/10 buffered" elif drift_z is None: drift_str = "unavailable · model lacks train-time stats" else: if abs(drift_z) < 1.0: tag = "within expected range" elif abs(drift_z) < 2.0: tag = "mild distribution shift" else: tag = "significant shift — retrain recommended" drift_str 
= ( f"trailing-{rolling_n} median {drift_z:+.2f}σ · {tag}" ) signal_rows.append(("drift", drift_str)) signals_html = "".join( f'verdict
{label_text.lower()}
Model confidence · {confidence_pct:.1f}%
' f'top {n_features} shap attributions
', unsafe_allow_html=True, ) import pandas as pd shap_df = pd.DataFrame(result["top_features"]).set_index("feature") # Keep st.bar_chart for simplicity; the wrapper now sits in a themed frame. st.bar_chart(shap_df, height=240, color=_TOKENS_DARK["accent"] if st.session_state.get("theme", "dark") == "dark" else _TOKENS_LIGHT["accent"]) st.caption( "Positive SHAP values pushed the model toward the predicted class; " "negative values pushed it away. Features are 2,048-bit Morgan " "fingerprint indices (`fp_