feat(eeg,frontend): EEG fusion-flow test + Streamlit EEG form + real-artifact sanity
Browse files- tests/fusion/test_eeg_modality_flow.py: end-to-end EEG -> fusion `eeg`
modality flow (alzheimers feature spike lifts AD score; control profile
doesn't inflate it).
- tests/models/test_eeg_model_real.py: skip-if-absent sanity for the real
EEG joblib; runs automatically once the file lands at the canonical path.
- src/frontend/app.py: doctor MRI tab grows an EEG predict form
(comma-separated features) calling POST /predict/eeg. 503 surfaces as a
human-readable warning instead of a stack trace.
- src/frontend/app.py +41 -0
- tests/fusion/test_eeg_modality_flow.py +52 -0
- tests/models/test_eeg_model_real.py +23 -0
src/frontend/app.py
CHANGED
|
@@ -1414,6 +1414,47 @@ def _render_mri_tab() -> None:
|
|
| 1414 |
if probs:
|
| 1415 |
st.dataframe(probs, use_container_width=True, hide_index=True)
|
| 1416 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1417 |
|
| 1418 |
def _render_prediction_card(result: dict) -> None:
|
| 1419 |
"""Editorial decision card: provenance · verdict · signals · SHAP."""
|
|
|
|
| 1414 |
if probs:
|
| 1415 |
st.dataframe(probs, use_container_width=True, hide_index=True)
|
| 1416 |
|
| 1417 |
+
st.markdown("#### EEG Pretrained Classifier")
|
| 1418 |
+
st.caption(
|
| 1419 |
+
"Stub-able for the demo: drop a sklearn `predict_proba` joblib at "
|
| 1420 |
+
"`data/processed/eeg_clf.joblib` (or set `EEG_CLF_ARTIFACT`). Default "
|
| 1421 |
+
"labels are `(control, alzheimers)` — override via `EEG_CLF_LABELS`."
|
| 1422 |
+
)
|
| 1423 |
+
eeg_csv = st.text_area(
|
| 1424 |
+
"EEG features (comma-separated)",
|
| 1425 |
+
",".join(["0.0"] * 16),
|
| 1426 |
+
key="eeg_predict_features",
|
| 1427 |
+
height=80,
|
| 1428 |
+
)
|
| 1429 |
+
if st.button("Predict EEG", key="eeg_predict"):
|
| 1430 |
+
try:
|
| 1431 |
+
features = [float(x.strip()) for x in eeg_csv.split(",") if x.strip()]
|
| 1432 |
+
except ValueError:
|
| 1433 |
+
st.error("EEG features must all be numeric.")
|
| 1434 |
+
else:
|
| 1435 |
+
payload = {"features": features}
|
| 1436 |
+
with st.spinner("Running EEG classifier..."):
|
| 1437 |
+
try:
|
| 1438 |
+
result = _post("/predict/eeg", payload, timeout=30.0)
|
| 1439 |
+
except httpx.HTTPStatusError as e:
|
| 1440 |
+
if e.response.status_code == 503:
|
| 1441 |
+
st.warning(
|
| 1442 |
+
"EEG model artifact missing. Drop the trained joblib at "
|
| 1443 |
+
"`data/processed/eeg_clf.joblib` or set `EEG_CLF_ARTIFACT`."
|
| 1444 |
+
)
|
| 1445 |
+
else:
|
| 1446 |
+
st.error(f"EEG prediction failed (HTTP {e.response.status_code}): {e.response.text}")
|
| 1447 |
+
except httpx.RequestError as e:
|
| 1448 |
+
st.error(f"Cannot reach FastAPI at {_API_URL}: {e!r}")
|
| 1449 |
+
else:
|
| 1450 |
+
st.metric(
|
| 1451 |
+
label=result.get("label_text", "prediction"),
|
| 1452 |
+
value=f"{float(result.get('confidence', 0.0)) * 100:.1f}%",
|
| 1453 |
+
)
|
| 1454 |
+
probs = result.get("probabilities", [])
|
| 1455 |
+
if probs:
|
| 1456 |
+
st.dataframe(probs, use_container_width=True, hide_index=True)
|
| 1457 |
+
|
| 1458 |
|
| 1459 |
def _render_prediction_card(result: dict) -> None:
|
| 1460 |
"""Editorial decision card: provenance · verdict · signals · SHAP."""
|
tests/fusion/test_eeg_modality_flow.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""End-to-end: EEG classifier output flows into fusion as the `eeg` modality."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from src.fusion import engine
|
| 9 |
+
from src.fusion.types import (
|
| 10 |
+
FusionInput,
|
| 11 |
+
ModalityClassProb,
|
| 12 |
+
ModalityPrediction,
|
| 13 |
+
)
|
| 14 |
+
from src.models import eeg_model
|
| 15 |
+
from tests.fixtures.build_dummy_eeg_clf import build as build_dummy_eeg
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _eeg_pred_from_features(model, features: np.ndarray) -> ModalityPrediction:
|
| 19 |
+
raw = eeg_model.predict_features(model, features)
|
| 20 |
+
return ModalityPrediction(
|
| 21 |
+
label_text=raw["label_text"],
|
| 22 |
+
label=raw["label"],
|
| 23 |
+
confidence=raw["confidence"],
|
| 24 |
+
probabilities=[
|
| 25 |
+
ModalityClassProb(label_text=p["label_text"], probability=p["probability"])
|
| 26 |
+
for p in raw["probabilities"]
|
| 27 |
+
],
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class TestEEGFusionFlow:
|
| 32 |
+
def test_alzheimers_eeg_lifts_alzheimers_disease_score(self, tmp_path: Path) -> None:
|
| 33 |
+
ckpt = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=16)
|
| 34 |
+
model = eeg_model.load(ckpt)
|
| 35 |
+
eeg_pred = _eeg_pred_from_features(model, np.full((16,), 2.0, dtype=np.float32))
|
| 36 |
+
|
| 37 |
+
out = engine.fuse(FusionInput(eeg=eeg_pred))
|
| 38 |
+
|
| 39 |
+
alz = next(d for d in out.diseases if d.disease == "alzheimers")
|
| 40 |
+
assert alz.probability > 0.5
|
| 41 |
+
assert any(c.modality == "eeg" for c in alz.contributions)
|
| 42 |
+
assert "mri" in out.missing_inputs
|
| 43 |
+
|
| 44 |
+
def test_control_eeg_does_not_inflate_alzheimers(self, tmp_path: Path) -> None:
|
| 45 |
+
ckpt = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=16)
|
| 46 |
+
model = eeg_model.load(ckpt)
|
| 47 |
+
eeg_pred = _eeg_pred_from_features(model, np.zeros((16,), dtype=np.float32))
|
| 48 |
+
|
| 49 |
+
out = engine.fuse(FusionInput(eeg=eeg_pred))
|
| 50 |
+
|
| 51 |
+
alz = next(d for d in out.diseases if d.disease == "alzheimers")
|
| 52 |
+
assert alz.probability < 0.5
|
tests/models/test_eeg_model_real.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Real-artifact EEG sanity. Skipped unless data/processed/eeg_clf.joblib exists."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
+
from src.models import eeg_model
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
REAL_CKPT = Path("data/processed/eeg_clf.joblib")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@pytest.mark.skipif(not REAL_CKPT.exists(), reason="real EEG checkpoint not present")
|
| 16 |
+
def test_real_eeg_checkpoint_loads_and_predicts():
|
| 17 |
+
model = eeg_model.load(REAL_CKPT)
|
| 18 |
+
n_features = int(getattr(model, "n_features_in_", 16))
|
| 19 |
+
features = np.zeros((n_features,), dtype=np.float32)
|
| 20 |
+
out = eeg_model.predict_features(model, features)
|
| 21 |
+
s = sum(p["probability"] for p in out["probabilities"])
|
| 22 |
+
assert abs(s - 1.0) < 1e-5
|
| 23 |
+
assert out["label_text"]
|