mekosotto committed on
Commit
e8e922d
·
1 Parent(s): a2a375c

feat(eeg,frontend): EEG fusion-flow test + Streamlit EEG form + real-artifact sanity

Browse files

- tests/fusion/test_eeg_modality_flow.py: end-to-end EEG -> fusion `eeg`
modality flow (alzheimers feature spike lifts AD score; control profile
doesn't inflate it).
- tests/models/test_eeg_model_real.py: skip-if-absent sanity for the real
EEG joblib; runs automatically once the file lands at the canonical path.
- src/frontend/app.py: doctor MRI tab grows an EEG predict form
(comma-separated features) calling POST /predict/eeg. 503 surfaces as a
human-readable warning instead of a stack trace.

src/frontend/app.py CHANGED
@@ -1414,6 +1414,47 @@ def _render_mri_tab() -> None:
1414
  if probs:
1415
  st.dataframe(probs, use_container_width=True, hide_index=True)
1416
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1417
 
1418
  def _render_prediction_card(result: dict) -> None:
1419
  """Editorial decision card: provenance · verdict · signals · SHAP."""
 
1414
  if probs:
1415
  st.dataframe(probs, use_container_width=True, hide_index=True)
1416
 
1417
+ st.markdown("#### EEG Pretrained Classifier")
1418
+ st.caption(
1419
+ "Stub-able for the demo: drop a sklearn `predict_proba` joblib at "
1420
+ "`data/processed/eeg_clf.joblib` (or set `EEG_CLF_ARTIFACT`). Default "
1421
+ "labels are `(control, alzheimers)` — override via `EEG_CLF_LABELS`."
1422
+ )
1423
+ eeg_csv = st.text_area(
1424
+ "EEG features (comma-separated)",
1425
+ ",".join(["0.0"] * 16),
1426
+ key="eeg_predict_features",
1427
+ height=80,
1428
+ )
1429
+ if st.button("Predict EEG", key="eeg_predict"):
1430
+ try:
1431
+ features = [float(x.strip()) for x in eeg_csv.split(",") if x.strip()]
1432
+ except ValueError:
1433
+ st.error("EEG features must all be numeric.")
1434
+ else:
1435
+ payload = {"features": features}
1436
+ with st.spinner("Running EEG classifier..."):
1437
+ try:
1438
+ result = _post("/predict/eeg", payload, timeout=30.0)
1439
+ except httpx.HTTPStatusError as e:
1440
+ if e.response.status_code == 503:
1441
+ st.warning(
1442
+ "EEG model artifact missing. Drop the trained joblib at "
1443
+ "`data/processed/eeg_clf.joblib` or set `EEG_CLF_ARTIFACT`."
1444
+ )
1445
+ else:
1446
+ st.error(f"EEG prediction failed (HTTP {e.response.status_code}): {e.response.text}")
1447
+ except httpx.RequestError as e:
1448
+ st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}")
1449
+ else:
1450
+ st.metric(
1451
+ label=result.get("label_text", "prediction"),
1452
+ value=f"{float(result.get('confidence', 0.0)) * 100:.1f}%",
1453
+ )
1454
+ probs = result.get("probabilities", [])
1455
+ if probs:
1456
+ st.dataframe(probs, use_container_width=True, hide_index=True)
1457
+
1458
 
1459
  def _render_prediction_card(result: dict) -> None:
1460
  """Editorial decision card: provenance · verdict · signals · SHAP."""
tests/fusion/test_eeg_modality_flow.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """End-to-end: EEG classifier output flows into fusion as the `eeg` modality."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+
8
+ from src.fusion import engine
9
+ from src.fusion.types import (
10
+ FusionInput,
11
+ ModalityClassProb,
12
+ ModalityPrediction,
13
+ )
14
+ from src.models import eeg_model
15
+ from tests.fixtures.build_dummy_eeg_clf import build as build_dummy_eeg
16
+
17
+
18
def _eeg_pred_from_features(model, features: np.ndarray) -> ModalityPrediction:
    """Classify ``features`` with the EEG model and wrap the result for fusion.

    Args:
        model: Classifier object as returned by ``eeg_model.load``.
        features: Feature vector handed unchanged to
            ``eeg_model.predict_features``.

    Returns:
        A ``ModalityPrediction`` carrying the ``label_text``, ``label``,
        ``confidence`` and per-class ``probabilities`` entries of the raw
        prediction dict.
    """

    def _as_class_prob(entry) -> ModalityClassProb:
        # Each raw probability entry is a mapping with these two keys.
        return ModalityClassProb(
            label_text=entry["label_text"],
            probability=entry["probability"],
        )

    prediction = eeg_model.predict_features(model, features)
    return ModalityPrediction(
        label_text=prediction["label_text"],
        label=prediction["label"],
        confidence=prediction["confidence"],
        probabilities=list(map(_as_class_prob, prediction["probabilities"])),
    )
29
+
30
+
31
class TestEEGFusionFlow:
    """EEG-only fusion: a dummy EEG classifier's output drives the `eeg` modality."""

    @staticmethod
    def _fuse_eeg_only(tmp_path: Path, features: np.ndarray):
        """Build the dummy 16-feature classifier, predict, and fuse EEG alone.

        Returns a ``(fusion_output, alzheimers_entry)`` pair, where the second
        item is the first fused disease whose ``disease`` is ``"alzheimers"``.
        """
        checkpoint = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=16)
        classifier = eeg_model.load(checkpoint)
        prediction = _eeg_pred_from_features(classifier, features)

        fused = engine.fuse(FusionInput(eeg=prediction))

        alzheimers = next(
            filter(lambda entry: entry.disease == "alzheimers", fused.diseases)
        )
        return fused, alzheimers

    def test_alzheimers_eeg_lifts_alzheimers_disease_score(self, tmp_path: Path) -> None:
        """An all-2.0 feature vector pushes the fused Alzheimer's score above 0.5."""
        spike = np.full((16,), 2.0, dtype=np.float32)
        out, alz = self._fuse_eeg_only(tmp_path, spike)

        assert alz.probability > 0.5
        contributing_modalities = [c.modality for c in alz.contributions]
        assert "eeg" in contributing_modalities
        # Only EEG was supplied, so MRI must be reported as a missing input.
        assert "mri" in out.missing_inputs

    def test_control_eeg_does_not_inflate_alzheimers(self, tmp_path: Path) -> None:
        """An all-zero feature vector keeps the fused Alzheimer's score below 0.5."""
        baseline = np.zeros((16,), dtype=np.float32)
        _, alz = self._fuse_eeg_only(tmp_path, baseline)

        assert alz.probability < 0.5
tests/models/test_eeg_model_real.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Real-artifact EEG sanity. Skipped unless data/processed/eeg_clf.joblib exists."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pytest
8
+
9
+ from src.models import eeg_model
10
+
11
+
12
# Anchor the checkpoint to the repository root (this file lives under
# tests/models/, two levels below it). A bare relative path is resolved against
# the current working directory, so the skip guard below would silently skip
# the test whenever pytest is started from anywhere other than the repo root.
REAL_CKPT = Path(__file__).resolve().parents[2] / "data" / "processed" / "eeg_clf.joblib"


@pytest.mark.skipif(not REAL_CKPT.exists(), reason="real EEG checkpoint not present")
def test_real_eeg_checkpoint_loads_and_predicts():
    """Load the real EEG checkpoint and sanity-check one prediction.

    Feeds an all-zero feature vector sized from the model's
    ``n_features_in_`` attribute (falling back to 16 when the attribute is
    absent) and checks that the returned class probabilities sum to 1 and
    that a non-empty ``label_text`` is produced.
    """
    model = eeg_model.load(REAL_CKPT)
    n_features = int(getattr(model, "n_features_in_", 16))
    features = np.zeros((n_features,), dtype=np.float32)

    out = eeg_model.predict_features(model, features)

    probabilities = out["probabilities"]
    # An empty list would otherwise only surface as an opaque "sum != 1".
    assert probabilities, "EEG model returned no class probabilities"
    total = sum(p["probability"] for p in probabilities)
    assert abs(total - 1.0) < 1e-5, f"class probabilities sum to {total}, expected 1.0"
    assert out["label_text"]