"""NeuroBridge Enterprise — Streamlit B2B dashboard (Editorial redesign).

Five tabs (Molecule / Signal / Image / AI Assistant / Experiments) sitting on
top of one FastAPI surface. Every interaction returns an auditable decision
artefact: label + confidence + calibration + drift + provenance + SHAP.

Visual language (post-redesign):
- Dark theme = editorial Netflix-style — deep neutral grays + sand accent
- Light theme = warm paper + charcoal type — Apple HIG / NYT-Cooking energy
- Single sand brand-mark across both themes (#D2C4B1)
- Inter (display + body) + JetBrains Mono (data / code)

Launch: `streamlit run src/frontend/app.py`
"""

from __future__ import annotations

import html as _html
import os

import httpx
import streamlit as st

# Base URL of the FastAPI backend this dashboard talks to; every _post/_get
# helper below prefixes its path with this value. Override per deployment.
_API_URL = os.environ.get("NEUROBRIDGE_API_URL", "http://localhost:8000")

# MLflow UI base URL used to build deep links to runs. An explicit
# NEUROBRIDGE_MLFLOW_URL wins; otherwise fall back to the standard
# MLFLOW_TRACKING_URI, then to the local default.
_MLFLOW_URL = os.environ.get(
    "NEUROBRIDGE_MLFLOW_URL",
    os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000"),
)

# Feature kill-switches — set the env var to the literal string "1" to turn
# off MLflow tracking links / the LLM explainer (UI falls back to a
# deterministic template in the latter case).
_MLFLOW_DISABLED = os.environ.get("NEUROBRIDGE_DISABLE_MLFLOW") == "1"
_LLM_DISABLED = os.environ.get("NEUROBRIDGE_DISABLE_LLM") == "1"

# --------------------------------------------------------------------------- #
# Design tokens — single source of truth for both themes.                     #
# Tokens are exposed as CSS custom properties at the :root level; every       #
# component reads from them so a theme swap is just a value swap.
# # --------------------------------------------------------------------------- # _TOKENS_DARK = { # Surfaces (deepest → most elevated) "bg-base": "#0e0e10", "bg-elevated": "#161618", "bg-elevated-2": "#1e1e21", "bg-elevated-3": "#2a2a2e", # Brand accent "accent": "#D2C4B1", "accent-strong": "#E8DCC6", "accent-soft": "rgba(210, 196, 177, 0.12)", "accent-ring": "rgba(210, 196, 177, 0.35)", # Text "text-primary": "#F5F2ED", "text-secondary": "#A8A29A", "text-tertiary": "#6B6660", "text-on-accent": "#161618", # Lines "border": "#2a2a2e", "border-strong": "#3a3a3e", # Semantic (keep cool — never red/green dominant in editorial) "success": "#7FB069", "warning": "#E0B469", "danger": "#D97A6C", # Effects "shadow-sm": "0 1px 2px rgba(0, 0, 0, 0.4)", "shadow-md": "0 8px 24px rgba(0, 0, 0, 0.45)", "shadow-lg": "0 16px 48px rgba(0, 0, 0, 0.55)", } _TOKENS_LIGHT = { "bg-base": "#FAF7F2", "bg-elevated": "#FFFFFF", "bg-elevated-2": "#F5F0E8", "bg-elevated-3": "#EDE5D5", "accent": "#1e1e21", "accent-strong": "#0e0e10", "accent-soft": "rgba(30, 30, 33, 0.06)", "accent-ring": "rgba(30, 30, 33, 0.18)", "text-primary": "#161618", "text-secondary": "#4A4540", "text-tertiary": "#8A857E", "text-on-accent": "#FAF7F2", "border": "#E5DDC9", "border-strong": "#D2C4B1", "success": "#3F7D45", "warning": "#A06D1F", "danger": "#A1483D", "shadow-sm": "0 1px 2px rgba(40, 30, 20, 0.04)", "shadow-md": "0 4px 16px rgba(40, 30, 20, 0.08)", "shadow-lg": "0 12px 40px rgba(40, 30, 20, 0.12)", } def _build_css(theme: str) -> str: """Return the full """ # --------------------------------------------------------------------------- # # Theme management # # --------------------------------------------------------------------------- # def _init_theme() -> str: """Initialize and return the active theme ('dark' default).""" if "theme" not in st.session_state: st.session_state["theme"] = "dark" return st.session_state["theme"] def _altair_theme(theme: str) -> dict: """Return an altair theme matching the active 
palette. Registered as 'neurobridge' on first call; subsequent calls just enable. """ tokens = _TOKENS_DARK if theme == "dark" else _TOKENS_LIGHT return { "config": { "background": tokens["bg-elevated"], "view": {"stroke": "transparent"}, "axis": { "labelColor": tokens["text-secondary"], "titleColor": tokens["text-secondary"], "labelFont": "Inter", "titleFont": "Inter", "labelFontSize": 11, "titleFontSize": 12, "gridColor": tokens["border"], "domainColor": tokens["border"], "tickColor": tokens["border"], }, "header": { "labelColor": tokens["text-primary"], "labelFont": "Inter", "labelFontSize": 13, "labelFontWeight": 600, "titleColor": tokens["text-secondary"], }, "legend": { "labelColor": tokens["text-secondary"], "titleColor": tokens["text-secondary"], "labelFont": "Inter", "titleFont": "Inter", }, "title": { "color": tokens["text-primary"], "font": "Inter", "fontWeight": 600, }, "range": { # Editorial palette: sand-led, then warm secondaries. "category": [ tokens["accent"], "#8FB3C9", "#C99B8F", "#9DAD86", "#B8A4C9", "#D4B86A", "#7FB069", "#A6A2C2", ], }, } } def _register_altair_theme(theme: str) -> None: """Register + enable the neurobridge altair theme for the current run.""" try: import altair as alt alt.themes.register("neurobridge", lambda: _altair_theme(theme)) alt.themes.enable("neurobridge") except Exception: # altair may not be importable in some environments; chart calls # will simply use altair defaults — no functional impact. 
pass # --------------------------------------------------------------------------- # # HTTP helpers # # --------------------------------------------------------------------------- # def _check_api_health() -> tuple[bool, str]: """Ping FastAPI /health endpoint; return (ok, status_text).""" try: resp = httpx.get(f"{_API_URL}/health", timeout=2.0) if resp.status_code == 200: return True, "operational" return False, f"http {resp.status_code}" except httpx.RequestError as e: return False, type(e).__name__.lower() def _post(endpoint: str, payload: dict, timeout: float = 120.0) -> dict: """POST to the FastAPI surface; let httpx raise on non-2xx.""" resp = httpx.post(f"{_API_URL}{endpoint}", json=payload, timeout=timeout) resp.raise_for_status() return resp.json() def _get(path: str) -> dict: """GET helper symmetric with _post.""" resp = httpx.get(f"{_API_URL}{path}", timeout=10.0) resp.raise_for_status() return resp.json() # --------------------------------------------------------------------------- # # Hero / sidebar / section primitives # # --------------------------------------------------------------------------- # def _render_brand_header(api_ok: bool, api_status: str) -> None: """Editorial hero strip: word-mark + tagline + 3 status dots.""" api_class = "is-ok" if api_ok else "is-down" mlflow_class = "is-mute" if _MLFLOW_DISABLED else "is-ok" mlflow_label = "tracking off" if _MLFLOW_DISABLED else "tracking" llm_class = "is-mute" if _LLM_DISABLED else "is-ok" llm_label = "template only" if _LLM_DISABLED else "llm online" st.markdown( f"""

Living decision system · clinical ML

NeuroBridge Enterprise

Three production pipelines — molecule, signal, image — behind one auditable surface. Every prediction returns label, calibration, drift, provenance and a natural-language rationale.

api · {_html.escape(api_status)} mlflow · {mlflow_label} explainer · {llm_label}
""", unsafe_allow_html=True, ) def _render_section(eyebrow: str, title: str, desc: str) -> None: st.markdown( f"""

{_html.escape(eyebrow)}

{_html.escape(title)}

{_html.escape(desc)}

""", unsafe_allow_html=True, ) def _render_result(body: dict) -> None: """Render a 3-metric result card + (optional) MLflow deep link.""" cols = st.columns(3) cols[0].metric("Rows", f"{body['rows']:,}") cols[1].metric("Columns", f"{body['columns']:,}") cols[2].metric("Runtime", f"{body['duration_sec']:.2f} s") safe_output_path = _html.escape(str(body["output_path"])) st.markdown( f"

" f"output → {safe_output_path}

", unsafe_allow_html=True, ) run_id = body.get("mlflow_run_id") if run_id and not _MLFLOW_DISABLED: safe_run_id = _html.escape(str(run_id)) safe_url = _html.escape(_MLFLOW_URL, quote=True) st.markdown( f"

" f"mlflow run · " f"{safe_run_id[:12]}…

", unsafe_allow_html=True, ) elif _MLFLOW_DISABLED: st.caption("mlflow tracking disabled (NEUROBRIDGE_DISABLE_MLFLOW=1)") def _render_sidebar(api_ok: bool, api_status: str) -> None: with st.sidebar: st.markdown( """ """, unsafe_allow_html=True, ) st.markdown("### Theme") theme = st.session_state.get("theme", "dark") is_dark = st.toggle( "Dark mode", value=(theme == "dark"), key="theme_toggle", help="Switch between editorial dark (Netflix-style) and warm paper (Apple HIG-style).", ) new_theme = "dark" if is_dark else "light" if new_theme != theme: st.session_state["theme"] = new_theme st.rerun() st.markdown("### System") api_class = "is-ok" if api_ok else "is-down" mlflow_class = "is-mute" if _MLFLOW_DISABLED else "is-ok" llm_class = "is-mute" if _LLM_DISABLED else "is-ok" st.markdown( f"""
api · {_html.escape(api_status)} mlflow · {"off" if _MLFLOW_DISABLED else "on"} llm · {"template" if _LLM_DISABLED else "online"}
""", unsafe_allow_html=True, ) st.markdown("### Endpoints") st.markdown( f"

" f"fastapi · {_API_URL}
" f"mlflow  · {_MLFLOW_URL}

", unsafe_allow_html=True, ) if st.button("🔧 Diagnose LLM", key="diag_llm_btn", help="Probe OpenRouter from this container"): try: diag = httpx.get(f"{_API_URL}/diag/openrouter", timeout=15.0).json() st.json(diag) except Exception as e: st.error(f"diag failed: {e!r}") st.markdown("### About") st.markdown( "

" "Trust-engineered clinical-ML platform. Three modalities — BBB drug " "screening, EEG signal cleaning, MRI multi-site harmonization — " "behind one FastAPI surface. Every inference is auditable.

", unsafe_allow_html=True, ) # --------------------------------------------------------------------------- # # Tabs # # --------------------------------------------------------------------------- # def _render_bbb_tab() -> None: _render_section( "MOLECULE — BBBP", "Blood-Brain-Barrier permeability decision", "Enter a SMILES string. The system computes a 2,048-bit Morgan " "fingerprint, runs it through a Random Forest classifier, and returns " "a label, calibration-grounded confidence, drift signal, and the top " "SHAP attributions explaining the decision.", ) EDGE_CASES = { "Custom input (default)": { "smiles": "CCO", "label": "Ethanol — small, drug-like, BBB-permeable", "expectation": "High confidence, label = permeable", }, "Invalid SMILES (parse-error path)": { "smiles": "this_is_not_a_valid_molecule_at_all_!!", "label": "Garbage string — should not parse", "expectation": "API returns HTTP 400 with parse error; UI shows recoverable warning", }, "Empty string (boundary)": { "smiles": "", "label": "Empty input — boundary condition", "expectation": "Pydantic accepts empty; API returns 400 (RDKit cannot parse)", }, "Massive OOD: cyclosporine-like macrocycle": { "smiles": ( "CC[C@H](C)[C@@H]1NC(=O)[C@H](CC(C)C)N(C)C(=O)[C@H](CC(C)C)N(C)C(=O)" "[C@@H]2CCCN2C(=O)[C@H](C(C)C)NC(=O)[C@H]([C@@H](C)CC)N(C)C(=O)" "[C@H](C)NC(=O)[C@H](C)NC(=O)[C@H](CC(C)C)N(C)C(=O)[C@@H](NC(=O)" "[C@H](CC(C)C)N(C)C(=O)CN(C)C1=O)C(C)C" ), "label": "Cyclosporine — 11-residue macrocycle (~1.2 kDa)", "expectation": ( "Far outside training distribution; model should hedge " "with low confidence (well-calibrated systems don't " "pretend to know)." 
), }, "OOD: heavy halogenated aromatic": { "smiles": "Fc1c(F)c(F)c(c(F)c1F)c2c(F)c(F)c(F)c(F)c2F", "label": "Decafluorobiphenyl — extreme halogen density", "expectation": "Rare scaffold; expect lowered confidence vs ethanol", }, } case_name = st.selectbox( "Test edge cases", options=list(EDGE_CASES.keys()), index=0, key="bbb_case", help=( "Pick a robustness probe. Each case demonstrates how the system " "handles a real-world failure mode — invalid input, " "out-of-distribution molecules, or boundary conditions." ), ) case = EDGE_CASES[case_name] st.caption(f"**Probe:** {case['label']} · **Expected:** {case['expectation']}") smiles = st.text_input( "SMILES string", value=case["smiles"], key="bbb_smiles", help="Examples: CCO (ethanol), CC(=O)Nc1ccc(O)cc1 (paracetamol)", ) top_k = st.slider( "SHAP features to display", min_value=3, max_value=10, value=5, key="bbb_topk", ) if st.button("Predict BBB permeability", type="primary", key="bbb_predict"): with st.spinner("Computing fingerprint, predicting, explaining…"): try: result = _post("/predict/bbb", {"smiles": smiles, "top_k": top_k}) _render_prediction_card(result) st.toast("Prediction complete", icon="✅") except httpx.HTTPStatusError as e: if e.response.status_code == 503: st.error( "Model artifact not loaded yet. Run " "`python -m src.models.bbb_model` to train it, " "then retry." ) elif e.response.status_code == 400: # Robustness story: WARNING (recoverable), not ERROR. st.warning( f"Robustness check passed: API rejected the input " f"with HTTP 400 (no crash). 
Detail: " f"{e.response.json().get('detail', e.response.text)}" ) else: st.error( f"Prediction failed (HTTP {e.response.status_code}): " f"{e.response.text}" ) except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") def _render_eeg_tab() -> None: _render_section( "SIGNAL — EEG", "Electroencephalogram artifact removal", "Bandpass-filters raw FIF/EDF recordings, removes EOG artifacts via " "ICA decomposition, and extracts per-band PSD + statistical features " "across fixed-duration epochs.", ) eeg_in = st.text_input( "Input FIF/EDF path", "tests/fixtures/eeg_sample.fif", key="eeg_in", help="Path to a .fif/.edf EEG recording on the server filesystem.", ) eeg_out = st.text_input( "Output Parquet path", "data/processed/eeg_features.parquet", key="eeg_out", ) if st.button("Run EEG pipeline", type="primary", key="eeg_run"): with st.spinner("Filtering and running ICA…"): try: result = _post( "/pipeline/eeg", {"input_path": eeg_in, "output_path": eeg_out}, ) st.session_state["last_eeg_run"] = result _render_result(result) st.toast("EEG pipeline complete", icon="✅") except httpx.HTTPStatusError as e: st.error( f"Pipeline failed (HTTP {e.response.status_code}): " f"{e.response.text}" ) except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") last_eeg = st.session_state.get("last_eeg_run") if last_eeg is not None: with st.expander("Ask the AI Assistant about this EEG run", expanded=False): eeg_q_presets = [ "Why were certain ICA components dropped?", "What does the bandpass filter do?", "Is this run consistent with previous runs?", ] eeg_preset = st.selectbox( "Preset question", options=eeg_q_presets, key="eeg_ai_preset", ) eeg_custom = st.text_input( "Or type your own question (optional)", value="", key="eeg_ai_custom", ) eeg_question = eeg_custom.strip() or eeg_preset if st.button("Ask AI Assistant", key="eeg_ai_ask"): with st.spinner("Composing rationale…"): try: eeg_resp = _post( "/explain/eeg", { "rows": 
int(last_eeg.get("rows", 0)), "columns": int(last_eeg.get("columns", 0)), "duration_sec": float(last_eeg.get("duration_sec", 0.0)), "mlflow_run_id": last_eeg.get("mlflow_run_id"), "user_question": eeg_question, }, ) st.markdown(f"**A:** {eeg_resp['rationale']}") st.caption( f"Source: `{eeg_resp.get('source', '?')}` · " f"Model: `{eeg_resp.get('model') or '—'}`" ) except httpx.HTTPStatusError as e: st.error( f"Assistant failed (HTTP {e.response.status_code}): " f"{e.response.text}" ) except httpx.RequestError as e: st.error(f"Cannot reach FastAPI: {e!r}") def _render_mri_tab() -> None: _render_section( "IMAGE — MRI", "Multi-site harmonization via ComBat", "Loads NIfTI volumes, masks brain tissue, computes per-ROI summary " "statistics, then harmonizes across acquisition sites with " "neuroHarmonize to remove scanner-driven domain shift. The diagnostic " "plot below compares per-site feature distributions before and after " "harmonization.", ) mri_dir = st.text_input( "Input NIfTI directory", "tests/fixtures/mri_sample", key="mri_dir", help="Path to a directory of .nii(.gz) files + sites.csv", ) sites_csv = st.text_input( "Sites CSV", "tests/fixtures/mri_sample/sites.csv", key="mri_sites", ) if st.button("Run ComBat diagnostics", type="primary", key="mri_diag"): with st.spinner("Running pre + post ComBat (×2 the work)…"): try: result = _post( "/pipeline/mri/diagnostics", {"input_dir": mri_dir, "sites_csv": sites_csv}, ) _render_combat_diagnostics(result) st.toast("Diagnostics complete", icon="✅") except httpx.HTTPStatusError as e: st.error( f"Diagnostics failed (HTTP {e.response.status_code}): " f"{e.response.text}" ) except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") st.markdown("#### MRI Image Model") mri_kind = os.environ.get("MRI_MODEL_KIND", "volumetric_onnx") if mri_kind == "resnet18_2d": mri_image = st.text_input( "2D MRI image (.png/.jpg)", "tests/fixtures/mri_sample/subject_0_axial.png", key="mri_predict_image", ) 
st.caption( "Resnet18 4-class — labels: MildDemented, ModerateDemented, " "NonDemented, VeryMildDemented. Resize/labels are baked into the model." ) if st.button("Predict MRI image", key="mri_predict"): payload = {"input_path": mri_image} with st.spinner("Running 2D MRI model..."): try: result = _post("/predict/mri", payload, timeout=120.0) except httpx.HTTPStatusError as e: if e.response.status_code == 503: st.warning( "MRI 2D model artifact missing. Drop the trained checkpoint at " "`data/processed/mri_dl_2d/best_model.pt` or set `MRI_MODEL_PATH_2D`." ) else: st.error(f"MRI prediction failed (HTTP {e.response.status_code}): {e.response.text}") except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") else: st.metric( label=result.get("label_text", "prediction"), value=f"{float(result.get('confidence', 0.0)) * 100:.1f}%", ) probs = result.get("probabilities", []) if probs: st.dataframe(probs, use_container_width=True, hide_index=True) else: mri_image = st.text_input( "NIfTI image", "tests/fixtures/mri_sample/subject_0.nii.gz", key="mri_predict_image", ) mri_labels = st.text_input( "Class labels", "control,abnormal", key="mri_predict_labels", ) shape_cols = st.columns(3) target_d = shape_cols[0].number_input( "Resize D", min_value=1, max_value=256, value=64, step=1, key="mri_predict_d" ) target_h = shape_cols[1].number_input( "Resize H", min_value=1, max_value=256, value=64, step=1, key="mri_predict_h" ) target_w = shape_cols[2].number_input( "Resize W", min_value=1, max_value=256, value=64, step=1, key="mri_predict_w" ) st.caption( "Resize target as (D, H, W). Default 64³ matches typical model exports." 
) if st.button("Predict MRI image", key="mri_predict"): labels = [x.strip() for x in mri_labels.split(",") if x.strip()] payload: dict = { "input_path": mri_image, "target_shape": [int(target_d), int(target_h), int(target_w)], } if labels: payload["label_names"] = labels with st.spinner("Running MRI image model..."): try: result = _post("/predict/mri", payload, timeout=120.0) except httpx.HTTPStatusError as e: detail = e.response.text if e.response.status_code == 503: st.warning( "MRI model artifact is not available yet. Export the trained " "ONNX model to `data/processed/mri_model.onnx` or set `MRI_MODEL_PATH`." ) else: st.error(f"MRI prediction failed (HTTP {e.response.status_code}): {detail}") except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") else: st.metric( label=result.get("label_text", "prediction"), value=f"{float(result.get('confidence', 0.0)) * 100:.1f}%", ) probs = result.get("probabilities", []) if probs: st.dataframe(probs, use_container_width=True, hide_index=True) st.markdown("#### EEG Pretrained Classifier") st.caption( "Pretrained sklearn classifier on EEG band-power features. " "Output: per-class probabilities for `(control, alzheimers)`." ) eeg_csv = st.text_area( "EEG features (comma-separated)", ",".join(["0.0"] * 16), key="eeg_predict_features", height=80, ) if st.button("Predict EEG", key="eeg_predict"): try: features = [float(x.strip()) for x in eeg_csv.split(",") if x.strip()] except ValueError: st.error("EEG features must all be numeric.") else: payload = {"features": features} with st.spinner("Running EEG classifier..."): try: result = _post("/predict/eeg", payload, timeout=30.0) except httpx.HTTPStatusError as e: if e.response.status_code == 503: st.warning( "EEG model artifact missing. Drop the trained joblib at " "`data/processed/eeg_clf.joblib` or set `EEG_CLF_ARTIFACT`." 
) else: st.error(f"EEG prediction failed (HTTP {e.response.status_code}): {e.response.text}") except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") else: st.metric( label=result.get("label_text", "prediction"), value=f"{float(result.get('confidence', 0.0)) * 100:.1f}%", ) probs = result.get("probabilities", []) if probs: st.dataframe(probs, use_container_width=True, hide_index=True) def _render_prediction_card(result: dict) -> None: """Editorial decision card: provenance · verdict · signals · SHAP.""" st.session_state["last_bbb_prediction"] = result label_text = _html.escape(str(result["label_text"])) confidence_pct = float(result["confidence"]) * 100 # 1) Provenance strip (auditable line) provenance = result.get("provenance") or {} run_id = provenance.get("mlflow_run_id") run_label = run_id[:8] if run_id else "—" train_date = provenance.get("train_date") or "—" model_version = provenance.get("model_version", "v1") n_examples = provenance.get("n_examples") n_label = f"n={n_examples}" if n_examples else "n=—" # 2) Build signal rows: calibration, drift signal_rows: list[tuple[str, str]] = [] calibration = result.get("calibration") if calibration is not None: threshold_pct = round(float(calibration["threshold"]) * 100) precision_pct = round(float(calibration["precision"]) * 100) support = int(calibration["support"]) if support == 0: cal_str = "no held-out support in this band" else: cal_str = ( f"≥{threshold_pct}% confident → " f"{precision_pct}% precision · n={support}" ) signal_rows.append(("calibration", cal_str)) drift_z = result.get("drift_z") rolling_n = int(result.get("rolling_n", 0)) if drift_z is None and rolling_n < 10: drift_str = f"warming up · {rolling_n}/10 buffered" elif drift_z is None: drift_str = "unavailable · model lacks train-time stats" else: if abs(drift_z) < 1.0: tag = "within expected range" elif abs(drift_z) < 2.0: tag = "mild distribution shift" else: tag = "significant shift — retrain recommended" drift_str 
= ( f"trailing-{rolling_n} median {drift_z:+.2f}σ · {tag}" ) signal_rows.append(("drift", drift_str)) signals_html = "".join( f'
{k}' f'{v}
' for k, v in signal_rows ) st.markdown( f"""
mlflow · {_html.escape(run_label)} model · {_html.escape(model_version)} trained · {_html.escape(train_date)} {_html.escape(n_label)}

verdict

{label_text.lower()}

Model confidence · {confidence_pct:.1f}%

""", unsafe_allow_html=True, ) # Native progress bar — themed via CSS variables st.progress(float(result["confidence"])) st.markdown( f"""
{signals_html}
""", unsafe_allow_html=True, ) # SHAP attributions chart n_features = len(result["top_features"]) st.markdown( f'

' f'top {n_features} shap attributions

', unsafe_allow_html=True, ) import pandas as pd shap_df = pd.DataFrame(result["top_features"]).set_index("feature") # Keep st.bar_chart for simplicity; the wrapper now sits in a themed frame. st.bar_chart(shap_df, height=240, color=_TOKENS_DARK["accent"] if st.session_state.get("theme", "dark") == "dark" else _TOKENS_LIGHT["accent"]) st.caption( "Positive SHAP values pushed the model toward the predicted class; " "negative values pushed it away. Features are 2,048-bit Morgan " "fingerprint indices (`fp_`)." ) def _render_combat_diagnostics(result: dict) -> None: """Pre/Post-ComBat KDE comparison + 3-metric site-gap KPI strip.""" import altair as alt import pandas as pd rows = result.get("rows", []) if not rows: st.info( "No data returned. Check that the input directory contains " ".nii(.gz) files and a sites.csv with subject_id/site columns." ) return cols = st.columns(3) cols[0].metric("Site-gap (Pre-ComBat)", f"{result['site_gap_pre']:.4f}") cols[1].metric("Site-gap (Post-ComBat)", f"{result['site_gap_post']:.4f}") cols[2].metric( "Reduction factor", f"{result['reduction_factor']:.0f}×", help=( "Pre-gap / Post-gap. A 100× reduction means ComBat removed " "two orders of magnitude of site-driven domain shift." 
), ) df = pd.DataFrame(rows) feat = df["feature"].iloc[0] feat_df = df[df["feature"] == feat] chart = ( alt.Chart(feat_df) .transform_density( density="feature_value", groupby=["site", "harmonization_state"], as_=["feature_value", "density"], ) .mark_area(opacity=0.5) .encode( x=alt.X("feature_value:Q", title=f"{feat} (intensity)"), y=alt.Y("density:Q", title="Density"), color=alt.Color( "site:N", title="Site", ), tooltip=[ alt.Tooltip("site:N"), alt.Tooltip("feature_value:Q", format=".4f"), alt.Tooltip("density:Q", format=".3f"), ], ) .properties(width=380, height=260) .facet( column=alt.Column( "harmonization_state:N", title=None, sort=["Pre-ComBat", "Post-ComBat"], ) ) .resolve_scale(x="shared", y="shared") ) st.altair_chart(chart, use_container_width=True) st.caption( f"Per-site density of `{feat}` before and after ComBat. Each colored " f"region is one acquisition site. **Convergence of the colored " f"regions in the Post-ComBat panel is the visual proof of " f"harmonization** — the same property the " f"{result['reduction_factor']:.0f}× site-gap reduction quantifies." 
) n_subjects = len({r["subject_id"] for r in result.get("rows", [])}) with st.expander("Ask the AI Assistant about this ComBat run", expanded=False): mri_q_presets = [ "Why does ComBat matter for multi-site MRI?", "How significant is this reduction factor?", "What would I lose without harmonization?", ] mri_preset = st.selectbox( "Preset question", options=mri_q_presets, key="mri_ai_preset", ) mri_custom = st.text_input( "Or type your own question (optional)", value="", key="mri_ai_custom", ) mri_question = mri_custom.strip() or mri_preset if st.button("Ask AI Assistant", key="mri_ai_ask"): with st.spinner("Composing rationale…"): try: mri_resp = _post( "/explain/mri", { "site_gap_pre": float(result["site_gap_pre"]), "site_gap_post": float(result["site_gap_post"]), "reduction_factor": float(result["reduction_factor"]), "n_subjects": n_subjects, "user_question": mri_question, }, ) st.markdown(f"**A:** {mri_resp['rationale']}") st.caption( f"Source: `{mri_resp.get('source', '?')}` · " f"Model: `{mri_resp.get('model') or '—'}`" ) except httpx.HTTPStatusError as e: st.error( f"Assistant failed (HTTP {e.response.status_code}): " f"{e.response.text}" ) except httpx.RequestError as e: st.error(f"Cannot reach FastAPI: {e!r}") def _render_researcher_tab() -> None: """Drug researcher view: BBB permeability map + dose adjustment.""" st.markdown("### Drug Researcher") st.caption( "DCE-MRI inspired BBB leakage score → revised dose suggestion. " "Output is a research signal, NOT medical advice." ) col_left, col_right = st.columns(2) with col_left: st.markdown("**1. 
Patient BBB permeability**") mri_path = st.text_input( "MRI image path (server-side)", "tests/fixtures/mri_sample/subject_0_axial.png", key="researcher_mri_path", ) mode = st.selectbox( "Scoring mode", ["heuristic_proxy", "dce_onnx"], index=0, key="researcher_perm_mode", help="heuristic_proxy uses the 2D classifier; dce_onnx requires a trained DCE artifact.", ) if st.button("Compute BBB leakage score", key="researcher_compute_perm"): with st.spinner("Running BBB permeability scorer..."): try: result = _post( "/predict/bbb_permeability_map", {"input_path": mri_path, "mode": mode}, timeout=60.0, ) except httpx.HTTPStatusError as e: st.error(f"BBB permeability failed (HTTP {e.response.status_code}): {e.response.text}") except httpx.RequestError as e: st.error(f"Cannot reach FastAPI: {e!r}") else: st.session_state["researcher_perm"] = result st.metric( label=result.get("interpretation", "BBB"), value=f"{float(result['permeability_score']) * 100:.1f}%", help=f"method={result.get('method', '?')}", ) with col_right: st.markdown("**2. 
Drug + baseline dose**") smiles = st.text_input("SMILES", "CCO", key="researcher_smiles") baseline = st.number_input( "Baseline dose (mg)", min_value=0.1, max_value=2000.0, value=100.0, step=10.0, key="researcher_baseline", ) score_default = float( st.session_state.get("researcher_perm", {}).get("permeability_score", 0.0) ) score = st.number_input( "BBB permeability score", min_value=0.0, max_value=1.0, value=score_default, step=0.05, key="researcher_score", help="Auto-fills from the BBB leakage score above; override manually if you want.", ) if st.button("Suggest revised dose", key="researcher_compute_dose"): payload = { "smiles": smiles or None, "baseline_dose_mg": float(baseline), "bbb_permeability_score": float(score), } with st.spinner("Computing dose adjustment..."): try: result = _post("/research/drug_dose_adjustment", payload, timeout=30.0) except httpx.HTTPStatusError as e: st.error(f"Dose adjustment failed (HTTP {e.response.status_code}): {e.response.text}") except httpx.RequestError as e: st.error(f"Cannot reach FastAPI: {e!r}") else: risk = result.get("risk_level", "unknown") risk_emoji = {"low": "🟢", "moderate": "🟡", "high": "🔴"}.get(risk, "⚪️") st.metric( label=f"{risk_emoji} Recommended dose", value=f"{result['recommended_dose_mg']:.1f} mg", delta=f"{(result['adjustment_factor'] - 1.0) * 100:+.0f}%", delta_color="inverse", ) drug_perm = result.get("drug_bbb_permeable") if drug_perm is not None: st.caption(f"Drug BBB-permeable: **{drug_perm}**") st.info(result.get("rationale", "")) def _render_ai_assistant_tab() -> None: """Chat-style explainer for the most recent BBB prediction.""" _render_section( "AI Assistant", "Natural-language rationale (LLM or deterministic template)", "Pulls the most recent BBB prediction from this session and asks the " "explainer to justify it. 
Falls back to a deterministic, auditable " "template when no LLM is configured.", ) last = st.session_state.get("last_bbb_prediction") if last is None: st.info( "Run a BBB prediction first (Molecule tab → Predict button), " "then come back here to ask the assistant about it." ) return top_features_preview = ", ".join( f["feature"] for f in last.get("top_features", [])[:3] ) st.caption( f"Latest prediction: **{last['label_text']}** " f"({float(last['confidence']) * 100:.0f}% confident) · " f"Top SHAP: {top_features_preview}" ) PRESETS = [ "Why was this molecule predicted as permeable?", "Which features pushed the verdict the most?", "Is this prediction trustworthy given the drift signal?", ] preset = st.selectbox("Preset question", options=PRESETS, key="ai_preset") custom = st.text_input( "Or type your own question (optional)", value="", key="ai_custom", help=( "Custom questions only affect the LLM path; the template gives a " "generic SHAP-driven rationale either way." ), ) question = custom.strip() or preset if st.button("Ask the AI Assistant", type="primary", key="ai_ask"): with st.spinner("Composing rationale…"): try: body = { "smiles": last.get("smiles", ""), "label": last["label"], "label_text": last["label_text"], "confidence": last["confidence"], "top_features": last.get("top_features", []), "calibration": last.get("calibration"), "drift_z": last.get("drift_z"), "user_question": question, } if not body["smiles"]: body["smiles"] = st.session_state.get("bbb_smiles", "") resp = _post("/explain/bbb", body) except httpx.HTTPStatusError as e: st.error( f"Explainer failed (HTTP {e.response.status_code}): " f"{e.response.text}" ) return except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") return history = st.session_state.setdefault("explain_history", []) history.insert(0, (question, resp)) history = st.session_state.get("explain_history", []) if history: st.markdown("### Conversation") for q, r in history[:10]: st.markdown(f"**Q:** 
{q}") st.markdown(f"**A:** {r['rationale']}") source = r.get("source", "?") model = r.get("model") or "—" st.caption(f"Source: `{source}` · Model: `{model}`") st.divider() def _render_experiments_tab() -> None: """MLflow runs table + two-run diff (Track 5).""" _render_section( "Experiments — MLOps Audit", "MLflow runs across BBB / EEG / MRI experiments", "Lists every recorded training run; pick any two to see a side-by-side " "metric + parameter diff. Foundation for auditable, reproducible " "model lineage.", ) if st.button("Refresh runs", key="exp_refresh"): st.session_state.pop("experiments_runs_cache", None) runs = st.session_state.get("experiments_runs_cache") if runs is None: try: data = _get("/experiments/runs") runs = data.get("runs", []) st.session_state["experiments_runs_cache"] = runs except httpx.HTTPStatusError as e: st.error( f"Failed to load runs (HTTP {e.response.status_code}): " f"{e.response.text}" ) return except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") return if not runs: st.info( "No MLflow runs found. Trigger a pipeline first (Molecule / " "Signal / Image), then refresh this tab. (Under " "NEUROBRIDGE_DISABLE_MLFLOW=1 the list will stay empty.)" ) return rows_preview = [ { "run_id": run["run_id"][:8], "experiment": run["experiment_name"], "start_time": run["start_time"][:19], "status": run["status"], "n_metrics": len(run["metrics"]), "n_params": len(run["params"]), } for run in runs ] st.dataframe(rows_preview, use_container_width=True, hide_index=True) st.markdown("### Compare two runs") run_ids = [r["run_id"] for r in runs] if len(run_ids) < 2: st.caption("Need at least 2 runs to compare. 
Trigger another pipeline.") return col_a, col_b = st.columns(2) with col_a: sel_a = st.selectbox( "Run A", options=run_ids, format_func=lambda x: x[:8], key="diff_a", ) with col_b: sel_b = st.selectbox( "Run B", options=run_ids, index=min(1, len(run_ids) - 1), format_func=lambda x: x[:8], key="diff_b", ) if st.button("Show diff", type="primary", key="exp_diff_go"): try: diff = _post( "/experiments/diff", {"run_id_a": sel_a, "run_id_b": sel_b}, ) except httpx.HTTPStatusError as e: st.error( f"Diff failed (HTTP {e.response.status_code}): " f"{e.response.text}" ) return rows = diff.get("rows", []) if not rows: st.info("Both runs have identical metrics and params (or are empty).") return diff_table = [ { "key": r["key"], "kind": r["kind"], "A": r["value_a"] or "—", "B": r["value_b"] or "—", "differs": "✓" if r["differs"] else "", } for r in rows ] st.dataframe(diff_table, use_container_width=True, hide_index=True) # --------------------------------------------------------------------------- # # Entrypoint # # --------------------------------------------------------------------------- # def main() -> None: """Streamlit entrypoint. Idempotent — Streamlit re-runs on every interaction.""" st.set_page_config( page_title="NeuroBridge Enterprise", page_icon=None, layout="wide", initial_sidebar_state="expanded", ) theme = _init_theme() st.markdown(_build_css(theme), unsafe_allow_html=True) _register_altair_theme(theme) api_ok, api_status = _check_api_health() _render_brand_header(api_ok, api_status) _render_sidebar(api_ok, api_status) if not api_ok: st.warning( f"FastAPI surface is not reachable at `{_API_URL}` ({api_status}). " "Pipeline runs will fail until the API service is up. " "Run `uvicorn src.api.main:app --port 8000` or `docker compose up`." 
        )

    # Tab labels and the unpacked handles are positionally paired; keep the
    # two lists in the same order when adding tabs.
    bbb_tab, eeg_tab, mri_tab, researcher_tab, assistant_tab, experiments_tab, agent_tab = st.tabs([
        "Molecule",
        "Signal",
        "Image",
        "Researcher",
        "AI Assistant",
        "Experiments",
        "🤖 Agent",
    ])
    with bbb_tab:
        _render_bbb_tab()
    with eeg_tab:
        _render_eeg_tab()
    with mri_tab:
        _render_mri_tab()
    with researcher_tab:
        _render_researcher_tab()
    with assistant_tab:
        _render_ai_assistant_tab()
    with experiments_tab:
        _render_experiments_tab()
    with agent_tab:
        # Agent tab is rendered inline rather than via a _render_* helper.
        st.markdown("### Orchestrator Agent")
        st.caption(
            "Pick the pipeline automatically, run it, then ground the response "
            "in curated reference docs (RAG)."
        )
        # st.form batches the three inputs into a single submit, avoiding a
        # Streamlit re-run (and agent call) on every keystroke.
        with st.form("agent_form"):
            agent_input = st.text_input(
                "Input",
                value="CCO",
                help="SMILES (e.g., CCO), .fif/.edf path, or NIfTI directory path",
            )
            agent_question = st.text_input(
                "Question (optional)",
                value="",
                help="Ask in any language — the agent will mirror it in the response",
            )
            agent_sites_csv = st.text_input(
                "MRI sites CSV (optional)",
                value="",
                help="Defaults to /sites.csv",
            )
            submitted = st.form_submit_button("Run agent")
        if submitted and agent_input:
            with st.spinner("Agent is reasoning..."):
                try:
                    # Optional fields are omitted (not sent empty) so the API
                    # can apply its own defaults.
                    payload: dict = {"user_input": agent_input}
                    if agent_question:
                        payload["user_question"] = agent_question
                    if agent_sites_csv:
                        payload["sites_csv"] = agent_sites_csv
                    # Generous timeout: the agent may chain several pipeline
                    # tool calls before answering.
                    response = _post("/agent/run", payload, timeout=120.0)
                except Exception as e:
                    # Broad catch is deliberate at this top-level UI boundary:
                    # any failure renders as an error box instead of crashing
                    # the whole dashboard.
                    st.error(f"Agent run failed: {e}")
                else:
                    st.markdown("#### Response")
                    st.write(response.get("text", ""))
                    st.caption(
                        f"model: `{response.get('model', '?')}` · "
                        f"finish: `{response.get('finish_reason', '?')}`"
                    )
                    # Auditable decision trace: one entry per tool call made
                    # by the agent, expanded by default.
                    trace = response.get("trace", [])
                    expander_title = f"🧠 Decision trace ({len(trace)} step{'s' if len(trace) != 1 else ''})"
                    with st.expander(expander_title, expanded=True):
                        if not trace:
                            st.write("_(no tool calls)_")
                        for i, step in enumerate(trace, start=1):
                            st.markdown(f"**{i}. `{step['name']}`**")
                            if step.get("error"):
                                st.error(step["error"])
                            else:
                                st.json(step.get("args", {}))
                                st.json(step.get("result", {}))


if __name__ == "__main__":
    main()