"""NeuroBridge Enterprise — Streamlit B2B dashboard. Three tabs (Molecule / Signal / Image), each fires a POST request against the sibling FastAPI service and renders a result card with row counts, runtime, and a deep link to the corresponding MLflow run. Design: Trust & Authority — navy + sky CTA + cool-white background, Plus Jakarta Sans, generous whitespace. Avoids emoji icons, AI gradients, and playful flourishes (per design-system guidance for clinical-ML B2B). Launch: `streamlit run src/frontend/app.py` """ from __future__ import annotations import html as _html import os import httpx import streamlit as st _API_URL = os.environ.get("NEUROBRIDGE_API_URL", "http://localhost:8000") _MLFLOW_URL = os.environ.get( "NEUROBRIDGE_MLFLOW_URL", os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000"), ) _MLFLOW_DISABLED = os.environ.get("NEUROBRIDGE_DISABLE_MLFLOW") == "1" # Trust & Authority custom CSS — overrides Streamlit defaults to lock the # design-system tokens. Loaded once at app start via st.markdown. _CUSTOM_CSS = """ """ def _check_api_health() -> tuple[bool, str]: """Ping FastAPI /health endpoint; return (ok, status_text).""" try: resp = httpx.get(f"{_API_URL}/health", timeout=2.0) if resp.status_code == 200: return True, "ok" return False, f"http {resp.status_code}" except httpx.RequestError as e: return False, str(type(e).__name__) def _post(endpoint: str, payload: dict) -> dict: """POST to the FastAPI surface; let httpx raise on non-2xx.""" resp = httpx.post(f"{_API_URL}{endpoint}", json=payload, timeout=120.0) resp.raise_for_status() return resp.json() def _render_brand_header() -> None: st.markdown( """
Three-modality clinical ML — Data Drift, Missing Modalities, Artifacts
{eyebrow}
{desc}
""", unsafe_allow_html=True, ) def _render_result(body: dict) -> None: """Render a 3-metric result card + MLflow deep link.""" cols = st.columns(3) cols[0].metric("Rows", f"{body['rows']:,}") cols[1].metric("Columns", f"{body['columns']:,}") cols[2].metric("Runtime", f"{body['duration_sec']:.2f} s") safe_output_path = _html.escape(str(body["output_path"])) st.markdown( f""
f"Output written to "
f"{safe_output_path}
MLflow run: " f"{safe_run_id[:12]}…
", unsafe_allow_html=True, ) elif _MLFLOW_DISABLED: st.markdown( "" "MLflow tracking is disabled (NEUROBRIDGE_DISABLE_MLFLOW=1).
", unsafe_allow_html=True, ) def _render_sidebar(api_ok: bool, api_status: str) -> None: with st.sidebar: st.markdown("### System Status") safe_api_status = _html.escape(api_status) api_pill = ( f"API · {safe_api_status}" if api_ok else f"API · {safe_api_status}" ) mlflow_pill = ( "MLflow · disabled" if _MLFLOW_DISABLED else "MLflow · tracking" ) st.markdown(api_pill + mlflow_pill, unsafe_allow_html=True) st.markdown("### Endpoints") st.markdown( f""
f"FastAPI · {_API_URL}
"
f"MLflow · {_MLFLOW_URL}
" "Solving Data Drift, Missing Modalities, and Artifacts in clinical " "biosignal pipelines. Three production modalities behind one FastAPI " "surface, all runs tracked to MLflow.
", unsafe_allow_html=True, ) def _render_bbb_tab() -> None: _render_section( "MOLECULE — BBBP", "Blood-Brain-Barrier permeability decision", "Enter a SMILES string. The system computes a 2,048-bit Morgan " "fingerprint, runs it through a trained Random Forest classifier, " "and returns the predicted permeability label, the model's " "self-rated confidence, and the top SHAP feature attributions " "explaining the decision.", ) EDGE_CASES = { "Custom input (default)": { "smiles": "CCO", "label": "Ethanol — small, drug-like, BBB-permeable", "expectation": "High confidence, label = permeable", }, "Invalid SMILES (parse-error path)": { "smiles": "this_is_not_a_valid_molecule_at_all_!!", "label": "Garbage string — should not parse", "expectation": "API returns HTTP 400 with parse error; UI shows recoverable warning", }, "Empty string (boundary)": { "smiles": "", "label": "Empty input — boundary condition", "expectation": "Pydantic accepts empty; API returns 400 (RDKit cannot parse)", }, "Massive OOD: cyclosporine-like macrocycle": { "smiles": ( "CC[C@H](C)[C@@H]1NC(=O)[C@H](CC(C)C)N(C)C(=O)[C@H](CC(C)C)N(C)C(=O)" "[C@@H]2CCCN2C(=O)[C@H](C(C)C)NC(=O)[C@H]([C@@H](C)CC)N(C)C(=O)" "[C@H](C)NC(=O)[C@H](C)NC(=O)[C@H](CC(C)C)N(C)C(=O)[C@@H](NC(=O)" "[C@H](CC(C)C)N(C)C(=O)CN(C)C1=O)C(C)C" ), "label": "Cyclosporine — 11-residue macrocycle (~1.2 kDa)", "expectation": ( "Far outside training distribution; model should hedge " "with low confidence (well-calibrated systems don't " "pretend to know)." ), }, "OOD: heavy halogenated aromatic": { "smiles": "Fc1c(F)c(F)c(c(F)c1F)c2c(F)c(F)c(F)c(F)c2F", "label": "Decafluorobiphenyl — extreme halogen density", "expectation": "Rare scaffold; expect lowered confidence vs ethanol", }, } case_name = st.selectbox( "Test Edge Cases", options=list(EDGE_CASES.keys()), index=0, key="bbb_case", help=( "Pick a robustness probe. Each case demonstrates how the " "system handles a real-world failure mode — invalid input, " "out-of-distribution molecules, or boundary conditions." ), ) case = EDGE_CASES[case_name] st.caption(f"**Probe:** {case['label']} · **Expected:** {case['expectation']}") smiles = st.text_input( "SMILES string", value=case["smiles"], key="bbb_smiles", help="Examples: CCO (ethanol), CC(=O)Nc1ccc(O)cc1 (paracetamol)", ) top_k = st.slider( "SHAP features to display", min_value=3, max_value=10, value=5, key="bbb_topk", ) if st.button("Predict BBB permeability", type="primary", key="bbb_predict"): with st.spinner("Computing fingerprint, predicting, and explaining…"): try: result = _post("/predict/bbb", {"smiles": smiles, "top_k": top_k}) _render_prediction_card(result) st.toast("Prediction complete", icon="✅") except httpx.HTTPStatusError as e: if e.response.status_code == 503: st.error( "Model artifact not loaded yet. Run " "`python -m src.models.bbb_model` to train it, " "then retry." ) elif e.response.status_code == 400: # Robustness story: show the WARNING instead of an ERROR # — invalid input is a recoverable path, not a crash. st.warning( f"Robustness check passed: API rejected the input " f"with HTTP 400 (no crash). Detail: " f"{e.response.json().get('detail', e.response.text)}" ) else: st.error( f"Prediction failed (HTTP {e.response.status_code}): " f"{e.response.text}" ) except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") def _render_eeg_tab() -> None: _render_section( "SIGNAL — EEG", "Electroencephalogram artifact removal", "Bandpass-filters raw FIF/EDF recordings, removes EOG artifacts via " "ICA decomposition, and extracts per-band PSD + statistical features " "across fixed-duration epochs.", ) eeg_in = st.text_input("Input FIF/EDF path", "data/raw/eeg.fif", key="eeg_in") eeg_out = st.text_input("Output Parquet path", "data/processed/eeg_features.parquet", key="eeg_out") if st.button("Run EEG pipeline", type="primary", key="eeg_run"): with st.spinner("Filtering and running ICA…"): try: _render_result(_post("/pipeline/eeg", { "input_path": eeg_in, "output_path": eeg_out, })) st.toast("EEG pipeline complete", icon="✅") except httpx.HTTPStatusError as e: st.error(f"Pipeline failed (HTTP {e.response.status_code}): {e.response.text}") except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") def _render_mri_tab() -> None: _render_section( "IMAGE — MRI", "Multi-site harmonization via ComBat", "Loads NIfTI volumes, masks brain tissue, computes per-ROI summary " "statistics, then harmonizes across acquisition sites with neuroHarmonize " "to remove scanner-driven domain shift. The diagnostic plot below " "compares per-site feature distributions before and after harmonization." ) mri_dir = st.text_input( "Input NIfTI directory", "tests/fixtures/mri_sample", key="mri_dir", help="Path to a directory of .nii(.gz) files + sites.csv", ) sites_csv = st.text_input( "Sites CSV", "tests/fixtures/mri_sample/sites.csv", key="mri_sites", ) if st.button("Run ComBat diagnostics", type="primary", key="mri_diag"): with st.spinner("Running pre + post ComBat (×2 the work)…"): try: result = _post( "/pipeline/mri/diagnostics", {"input_dir": mri_dir, "sites_csv": sites_csv}, ) _render_combat_diagnostics(result) st.toast("Diagnostics complete", icon="✅") except httpx.HTTPStatusError as e: st.error( f"Diagnostics failed (HTTP {e.response.status_code}): " f"{e.response.text}" ) except httpx.RequestError as e: st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}") def _render_prediction_card(result: dict) -> None: """Render a B2B-styled decision card: label badge + confidence + SHAP bars.""" st.session_state["last_bbb_prediction"] = result provenance = result.get("provenance") if provenance is not None: run_id = provenance.get("mlflow_run_id") run_label = run_id[:8] if run_id else "—" train_date = provenance.get("train_date") or "—" n_examples = provenance.get("n_examples") n_label = f"n={n_examples}" if n_examples else "n=—" st.caption( f"🔎 MLflow run **{run_label}** · " f"Model **{provenance.get('model_version', 'v1')}** · " f"trained {train_date} · {n_label}" ) label_text = _html.escape(str(result["label_text"])) badge_color = "#166534" if result["label"] == 1 else "#991B1B" badge_bg = "#DCFCE7" if result["label"] == 1 else "#FEE2E2" confidence_pct = result["confidence"] * 100 st.markdown( f"""Prediction
" "Confidence
", unsafe_allow_html=True, ) st.progress(float(result["confidence"])) # Trust caption — precision-at-confidence from held-out 20% test split. # Silent skip when the API response has no calibration field (legacy models). calibration = result.get("calibration") if calibration is not None: threshold_pct = round(calibration["threshold"] * 100) precision_pct = round(calibration["precision"] * 100) support = calibration["support"] if support == 0: st.caption( "📊 Bu güven aralığında held-out test örneği yok — " "kalibrasyon bilgisi mevcut değil." ) else: st.caption( f"📊 Test set'te ≥{threshold_pct}% güven üreten tahminlerin " f"precision'ı **{precision_pct}%** (n={support})." ) drift_z = result.get("drift_z") rolling_n = result.get("rolling_n", 0) if drift_z is None and rolling_n < 10: st.caption( f"📈 Drift: warming up ({rolling_n}/10 predictions buffered)." ) elif drift_z is None: st.caption( "📈 Drift: unavailable (model lacks train-time confidence stats)." ) else: # Sign + magnitude: |z| < 1 in-band, 1–2 mild, >=2 significant. if abs(drift_z) < 1.0: tag = "within expected range" elif abs(drift_z) < 2.0: tag = "mild distribution shift" else: tag = "significant shift — retrain recommended" st.caption( f"📈 Drift: trailing-{rolling_n} confidence median is " f"**{drift_z:+.2f}σ** from train-time distribution ({tag})." ) # SHAP attributions chart n_features = len(result["top_features"]) st.markdown( f"" f"Top {n_features} SHAP attributions
", unsafe_allow_html=True, ) import pandas as pd shap_df = pd.DataFrame(result["top_features"]).set_index("feature") st.bar_chart(shap_df, height=240, color="#0369A1") st.caption( "Positive SHAP values pushed the model toward the predicted class; " "negative values pushed it away. Feature names are 2,048-bit Morgan " "fingerprint indices (`fp_