File size: 1,869 Bytes
e8e922d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""End-to-end: EEG classifier output flows into fusion as the `eeg` modality."""
from __future__ import annotations

from pathlib import Path

import numpy as np

from src.fusion import engine
from src.fusion.types import (
    FusionInput,
    ModalityClassProb,
    ModalityPrediction,
)
from src.models import eeg_model
from tests.fixtures.build_dummy_eeg_clf import build as build_dummy_eeg


def _eeg_pred_from_features(model, features: np.ndarray) -> ModalityPrediction:
    """Classify one EEG feature vector and repackage the result for fusion.

    ``eeg_model.predict_features`` returns a mapping indexed by string keys.
    This helper copies its ``label_text``, ``label`` and ``confidence``
    entries straight across. Each item of its ``probabilities`` sequence
    becomes a ``ModalityClassProb``. The result is the typed
    ``ModalityPrediction`` that the tests pass as ``FusionInput(eeg=...)``.
    """
    prediction = eeg_model.predict_features(model, features)

    # Top-1 summary fields, read in the same order the original constructor
    # call evaluated them.
    top_text = prediction["label_text"]
    top_label = prediction["label"]
    top_conf = prediction["confidence"]

    # Per-class distribution, one ModalityClassProb per reported class.
    per_class = [
        ModalityClassProb(
            label_text=entry["label_text"],
            probability=entry["probability"],
        )
        for entry in prediction["probabilities"]
    ]

    return ModalityPrediction(
        label_text=top_text,
        label=top_label,
        confidence=top_conf,
        probabilities=per_class,
    )


class TestEEGFusionFlow:
    """EEG-only fusion runs: the EEG classifier output should steer the Alzheimer's score."""

    # Feature width the dummy EEG classifier is built with; the probe vectors
    # fed to it below must have exactly this length.
    N_FEATURES = 16

    def _fuse_eeg_only(self, tmp_path: Path, features: np.ndarray):
        """Build and load the dummy EEG checkpoint, classify *features*, then fuse.

        Only the ``eeg`` modality is supplied to ``engine.fuse``. Every other
        modality is therefore absent from the fusion input.
        """
        ckpt = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=self.N_FEATURES)
        model = eeg_model.load(ckpt)
        eeg_pred = _eeg_pred_from_features(model, features)
        return engine.fuse(FusionInput(eeg=eeg_pred))

    @staticmethod
    def _disease(out, name: str):
        """Return the entry of ``out.diseases`` whose ``disease`` equals *name*.

        A bare ``next()`` without a default would raise an opaque
        ``StopIteration`` when the disease is missing. This instead fails the
        test with a message listing the diseases that were produced.
        """
        found = next((d for d in out.diseases if d.disease == name), None)
        assert found is not None, (
            f"{name!r} missing from fused diseases: "
            f"{[d.disease for d in out.diseases]}"
        )
        return found

    def test_alzheimers_eeg_lifts_alzheimers_disease_score(self, tmp_path: Path) -> None:
        out = self._fuse_eeg_only(
            tmp_path, np.full((self.N_FEATURES,), 2.0, dtype=np.float32)
        )

        alz = self._disease(out, "alzheimers")
        assert alz.probability > 0.5
        assert any(c.modality == "eeg" for c in alz.contributions)
        # Only EEG was supplied, so MRI must be reported as a missing input.
        assert "mri" in out.missing_inputs

    def test_control_eeg_does_not_inflate_alzheimers(self, tmp_path: Path) -> None:
        out = self._fuse_eeg_only(
            tmp_path, np.zeros((self.N_FEATURES,), dtype=np.float32)
        )

        alz = self._disease(out, "alzheimers")
        assert alz.probability < 0.5