File size: 1,869 Bytes
e8e922d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 | """End-to-end: EEG classifier output flows into fusion as the `eeg` modality."""
from __future__ import annotations
from pathlib import Path
import numpy as np
from src.fusion import engine
from src.fusion.types import (
FusionInput,
ModalityClassProb,
ModalityPrediction,
)
from src.models import eeg_model
from tests.fixtures.build_dummy_eeg_clf import build as build_dummy_eeg
def _eeg_pred_from_features(model, features: np.ndarray) -> ModalityPrediction:
    """Classify *features* with the EEG model and wrap the result for the fusion engine.

    The raw output of ``eeg_model.predict_features`` is read by key: ``label_text``,
    ``label``, ``confidence`` and ``probabilities``. Each element of ``probabilities``
    is itself read by ``label_text`` and ``probability``.
    """
    prediction = eeg_model.predict_features(model, features)
    # Pull the top-level fields first, in the same order the prediction object uses them.
    top_label_text = prediction["label_text"]
    top_label = prediction["label"]
    top_confidence = prediction["confidence"]
    per_class = [
        ModalityClassProb(label_text=entry["label_text"], probability=entry["probability"])
        for entry in prediction["probabilities"]
    ]
    return ModalityPrediction(
        label_text=top_label_text,
        label=top_label,
        confidence=top_confidence,
        probabilities=per_class,
    )
class TestEEGFusionFlow:
    """Feed a dummy EEG classifier's prediction into ``engine.fuse`` as the only modality."""

    # Feature width the dummy checkpoint is built with; every input vector below uses it too,
    # so keeping it in one place prevents the checkpoint and features from drifting apart.
    N_FEATURES = 16

    def _fuse_eeg_features(self, tmp_path: Path, features: np.ndarray):
        """Build and load a dummy EEG checkpoint under *tmp_path*, classify *features*,
        and return the output of ``engine.fuse`` with EEG as the sole input."""
        ckpt = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=self.N_FEATURES)
        model = eeg_model.load(ckpt)
        eeg_pred = _eeg_pred_from_features(model, features)
        return engine.fuse(FusionInput(eeg=eeg_pred))

    @staticmethod
    def _disease(out, name: str):
        """Return the entry in ``out.diseases`` whose ``disease`` equals *name*.

        Fails the test with a readable message when no such entry exists. A bare
        ``next(...)`` would instead surface as an opaque ``StopIteration`` error.
        """
        match = next((d for d in out.diseases if d.disease == name), None)
        assert match is not None, f"fused output has no {name!r} disease entry"
        return match

    def test_alzheimers_eeg_lifts_alzheimers_disease_score(self, tmp_path: Path) -> None:
        # All-2.0 features are presumed to land on the Alzheimer's side of the dummy
        # classifier (per this test's name) — confirm against build_dummy_eeg if it changes.
        features = np.full((self.N_FEATURES,), 2.0, dtype=np.float32)
        out = self._fuse_eeg_features(tmp_path, features)
        alz = self._disease(out, "alzheimers")
        assert alz.probability > 0.5
        assert any(c.modality == "eeg" for c in alz.contributions)
        # Only EEG was supplied, so MRI must be reported as a missing input.
        assert "mri" in out.missing_inputs

    def test_control_eeg_does_not_inflate_alzheimers(self, tmp_path: Path) -> None:
        # All-zero features stand in for a control recording (per this test's name).
        features = np.zeros((self.N_FEATURES,), dtype=np.float32)
        out = self._fuse_eeg_features(tmp_path, features)
        alz = self._disease(out, "alzheimers")
        assert alz.probability < 0.5
|