File size: 4,421 Bytes
2134339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac781dd
2134339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""Tests for src.fusion.engine.fuse — the core multi-modal combiner."""
from __future__ import annotations

import logging
from typing import Any

import pytest

from src.fusion import engine
from src.fusion.types import (
    ClinicalScores,
    FusionInput,
    ModalityClassProb,
    ModalityPrediction,
)


def _mri(prob_alz: float, prob_pd: float = 0.0) -> ModalityPrediction:
    """Build a three-class ``ModalityPrediction`` fixture.

    Whatever probability mass is left after ``prob_alz`` and ``prob_pd``
    (clamped at zero) is assigned to the ``control`` class. The predicted
    label is the class with the highest probability; on ties the earliest
    class in control / alzheimers / parkinsons order wins.

    Args:
        prob_alz: Probability given to the ``alzheimers`` class.
        prob_pd: Probability given to the ``parkinsons`` class.

    Returns:
        A prediction whose ``label`` is the winning class's position in the
        probability list and whose ``confidence`` is that class's probability.
    """
    remainder = 1.0 - prob_alz - prob_pd
    class_probs = [
        ModalityClassProb(label_text=name, probability=value)
        for name, value in (
            ("control", max(0.0, remainder)),
            ("alzheimers", prob_alz),
            ("parkinsons", prob_pd),
        )
    ]
    # max() keeps the first maximal element, so ties resolve to the lowest index.
    winner_idx = max(
        range(len(class_probs)),
        key=lambda idx: class_probs[idx].probability,
    )
    winner = class_probs[winner_idx]
    return ModalityPrediction(
        label_text=winner.label_text,
        label=winner_idx,
        confidence=winner.probability,
        probabilities=class_probs,
    )


class TestFuse:
    """Behavioural checks for ``engine.fuse`` across input combinations."""

    def test_empty_input_returns_baseline_with_missing_listed(self) -> None:
        """With no inputs every disease stays at 0.5 and inputs are reported missing."""
        result = engine.fuse(FusionInput())
        reported = {entry.disease for entry in result.diseases}
        assert reported >= {"alzheimers", "parkinsons", "other"}
        for entry in result.diseases:
            assert entry.probability == pytest.approx(0.5, abs=1e-6)
            assert entry.contributions == []
        for modality in ("mri", "eeg"):
            assert modality in result.missing_inputs
        assert result.top_disease is None

    def test_mri_only_alzheimers_high(self) -> None:
        """A confident MRI Alzheimer's prediction dominates the fused output."""
        fusion_input = FusionInput(mri=_mri(prob_alz=0.9))
        result = engine.fuse(fusion_input)
        alz_entry = next(
            entry for entry in result.diseases if entry.disease == "alzheimers"
        )
        assert alz_entry.probability > 0.7
        assert any(
            contrib.modality == "mri" for contrib in alz_entry.contributions
        )
        assert result.top_disease == "alzheimers"

    def test_mri_eeg_agreement_boosts_above_either_alone(self) -> None:
        """Two agreeing modalities score higher than either one alone."""
        mri_result = engine.fuse(FusionInput(mri=_mri(prob_alz=0.8)))
        eeg_result = engine.fuse(FusionInput(eeg=_mri(prob_alz=0.8)))
        combined_result = engine.fuse(
            FusionInput(mri=_mri(prob_alz=0.8), eeg=_mri(prob_alz=0.8))
        )

        def alz_probability(fused: Any) -> float:
            alz_entry = next(
                entry for entry in fused.diseases if entry.disease == "alzheimers"
            )
            return alz_entry.probability

        for single_result in (mri_result, eeg_result):
            assert alz_probability(combined_result) > alz_probability(single_result)

    def test_clinical_only_low_mmse_raises_alzheimers(self) -> None:
        """A low MMSE score alone pushes Alzheimer's above 0.55."""
        scores = ClinicalScores(mmse=10.0)
        result = engine.fuse(FusionInput(clinical=scores))
        alz_entry = next(
            entry for entry in result.diseases if entry.disease == "alzheimers"
        )
        assert alz_entry.probability > 0.55
        assert any(
            contrib.modality == "clinical_mmse"
            for contrib in alz_entry.contributions
        )

    def test_disagreement_moderates_confidence(self) -> None:
        """MRI says Alzheimer's, MMSE of 30 disagrees: result lands in between."""
        result = engine.fuse(
            FusionInput(
                mri=_mri(prob_alz=0.85),
                clinical=ClinicalScores(mmse=30.0),
            )
        )
        alz_entry = next(
            entry for entry in result.diseases if entry.disease == "alzheimers"
        )
        alz_prob = alz_entry.probability
        assert 0.5 < alz_prob < 0.78

    def test_unknown_clinical_field_is_ignored_safely(self) -> None:
        """Fusing with only ``age_years`` set still yields a known top disease."""
        scores = ClinicalScores(age_years=80.0)
        result = engine.fuse(FusionInput(clinical=scores))
        assert result.top_disease in {"alzheimers", "parkinsons", "other"}

    def test_engine_does_not_depend_on_bbb(self) -> None:
        """Neither the engine source nor any weight key may mention "bbb"."""
        # Independence regression: fusion must stay decoupled from BBB, so a
        # patient with only MRI/EEG/clinical data still gets a valid output
        # without any BBB module being involved.
        import inspect
        import src.fusion.engine as engine_mod
        import src.fusion.weights as weights_mod

        engine_source = inspect.getsource(engine_mod).lower()
        assert "bbb" not in engine_source
        for disease in weights_mod.available_diseases():
            for key in weights_mod.get_weights(disease):
                assert "bbb" not in key.lower(), (disease, key)

    def test_warning_logged_when_disease_has_no_signals(
        self, caplog: pytest.LogCaptureFixture
    ) -> None:
        """'other' stays at the 0.5 baseline when only clinical MMSE is given."""
        # 'other' disease with no MRI/EEG inputs -> no signals available.
        # NOTE(review): the captured records are never inspected here, so the
        # "warning logged" part of the test name is not actually asserted --
        # confirm the engine's level/message and consider adding a check.
        capture_handler = caplog.handler
        engine_logger = engine.logger
        engine_logger.addHandler(capture_handler)
        capture_handler.setLevel(logging.INFO)
        try:
            result = engine.fuse(FusionInput(clinical=ClinicalScores(mmse=10.0)))
        finally:
            # Always detach so the handler does not leak into other tests.
            engine_logger.removeHandler(capture_handler)
        other_entry = next(
            entry for entry in result.diseases if entry.disease == "other"
        )
        assert other_entry.probability == pytest.approx(0.5, abs=1e-6)
        assert other_entry.contributions == []