| """End-to-end: EEG classifier output flows into fusion as the `eeg` modality.""" |
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import numpy as np |
|
|
| from src.fusion import engine |
| from src.fusion.types import ( |
| FusionInput, |
| ModalityClassProb, |
| ModalityPrediction, |
| ) |
| from src.models import eeg_model |
| from tests.fixtures.build_dummy_eeg_clf import build as build_dummy_eeg |
|
|
|
|
def _eeg_pred_from_features(model, features: np.ndarray) -> ModalityPrediction:
    """Classify one EEG feature vector and repackage it as a fusion input.

    Calls ``eeg_model.predict_features(model, features)`` and copies its
    ``label_text``, ``label``, ``confidence`` and per-class ``probabilities``
    entries (each carrying ``label_text`` and ``probability``) into the
    typed ``ModalityPrediction`` / ``ModalityClassProb`` objects that
    ``engine.fuse`` consumes via ``FusionInput(eeg=...)``.
    """
    result = eeg_model.predict_features(model, features)

    # Pull the top-level label fields first, then the per-class entries,
    # mirroring the order in which they are read.
    label_text = result["label_text"]
    label = result["label"]
    confidence = result["confidence"]
    class_probs = [
        ModalityClassProb(
            label_text=entry["label_text"],
            probability=entry["probability"],
        )
        for entry in result["probabilities"]
    ]

    return ModalityPrediction(
        label_text=label_text,
        label=label,
        confidence=confidence,
        probabilities=class_probs,
    )
|
|
|
|
class TestEEGFusionFlow:
    """Fuse an EEG-only prediction and check its effect on the Alzheimer's score."""

    # Feature-vector length the dummy classifier is built with; every
    # vector fed to it in these tests must have this length.
    _N_FEATURES = 16

    def _fuse_eeg_only(self, tmp_path: Path, features: np.ndarray):
        """Build and load the dummy EEG classifier, predict on ``features``, then fuse.

        Only the ``eeg`` modality is passed to ``FusionInput``; the fusion
        output is returned unchanged.
        """
        ckpt = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=self._N_FEATURES)
        model = eeg_model.load(ckpt)
        eeg_pred = _eeg_pred_from_features(model, features)
        return engine.fuse(FusionInput(eeg=eeg_pred))

    @staticmethod
    def _disease(out, name: str):
        """Return the entry in ``out.diseases`` whose ``disease`` equals ``name``.

        Fails the test with a descriptive message if no such entry exists.
        """
        match = next((d for d in out.diseases if d.disease == name), None)
        # A bare ``next`` would surface as an unexplained StopIteration;
        # report which diseases were actually produced instead.
        assert match is not None, (
            f"disease {name!r} missing from fusion output: "
            f"{[d.disease for d in out.diseases]}"
        )
        return match

    def test_alzheimers_eeg_lifts_alzheimers_disease_score(self, tmp_path: Path) -> None:
        # NOTE(review): presumably the dummy fixture maps large positive
        # features to the Alzheimer's class -- confirm in
        # tests.fixtures.build_dummy_eeg_clf.
        out = self._fuse_eeg_only(
            tmp_path, np.full((self._N_FEATURES,), 2.0, dtype=np.float32)
        )

        alz = self._disease(out, "alzheimers")
        assert alz.probability > 0.5
        assert any(c.modality == "eeg" for c in alz.contributions)
        # Only EEG was supplied, so MRI must be reported as missing.
        assert "mri" in out.missing_inputs

    def test_control_eeg_does_not_inflate_alzheimers(self, tmp_path: Path) -> None:
        # NOTE(review): presumably an all-zero vector is classified as the
        # control class by the dummy fixture -- confirm.
        out = self._fuse_eeg_only(
            tmp_path, np.zeros((self._N_FEATURES,), dtype=np.float32)
        )

        alz = self._disease(out, "alzheimers")
        assert alz.probability < 0.5
|
|