mekosotto Claude Sonnet 4.6 commited on
Commit
2134339
·
1 Parent(s): b91e55e

feat(fusion): add core multi-modal fuse() with per-disease attribution

Browse files
Files changed (2) hide show
  1. src/fusion/engine.py +105 -0
  2. tests/fusion/test_engine.py +107 -0
src/fusion/engine.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Multi-modal fusion engine — combines MRI, EEG, and clinical signals into
2
+ per-disease confidence with full attribution.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ import math
7
+ from typing import Callable
8
+
9
+ from src.core.logger import get_logger
10
+ from src.fusion import clinical as clinical_signals
11
+ from src.fusion import weights as weight_registry
12
+ from src.fusion.modality import signal_for_disease
13
+ from src.fusion.types import (
14
+ ClinicalScores,
15
+ DiseaseScore,
16
+ FusionInput,
17
+ FusionOutput,
18
+ ModalityContribution,
19
+ ModalityPrediction,
20
+ )
21
+
22
+ logger = get_logger(__name__)
23
+
24
+ _LOGIT_SCALE = 4.0 # tuned so a single saturated modality maps to ~0.88
25
+
26
+
27
+ # Clinical-test name -> (signal_fn, attribute_on_ClinicalScores)
28
+ _CLINICAL_FNS: dict[str, tuple[Callable[[float], float], str]] = {
29
+ "clinical_mmse": (clinical_signals.mmse_to_signal, "mmse"),
30
+ "clinical_moca": (clinical_signals.moca_to_signal, "moca"),
31
+ "clinical_updrs": (clinical_signals.updrs_to_signal, "updrs"),
32
+ "clinical_gait": (clinical_signals.gait_to_signal, "gait_speed_m_s"),
33
+ "clinical_age": (clinical_signals.age_to_signal, "age_years"),
34
+ }
35
+
36
+
37
+ def fuse(inp: FusionInput) -> FusionOutput:
38
+ """Combine all available modalities into a per-disease confidence."""
39
+ missing: list[str] = []
40
+ if inp.mri is None:
41
+ missing.append("mri")
42
+ if inp.eeg is None:
43
+ missing.append("eeg")
44
+
45
+ diseases: list[DiseaseScore] = []
46
+ for disease in weight_registry.available_diseases():
47
+ diseases.append(_score_one_disease(disease, inp))
48
+
49
+ top = max(diseases, key=lambda d: d.probability).disease
50
+ return FusionOutput(diseases=diseases, top_disease=top, missing_inputs=missing)
51
+
52
+
53
+ def _score_one_disease(disease: str, inp: FusionInput) -> DiseaseScore:
54
+ weights = weight_registry.get_weights(disease)
55
+ contributions: list[ModalityContribution] = []
56
+
57
+ for modality_key, weight in weights.items():
58
+ signal = _signal_for_modality(modality_key, disease, inp.mri, inp.eeg, inp.clinical)
59
+ if signal is None:
60
+ continue
61
+ contributions.append(ModalityContribution(
62
+ modality=modality_key,
63
+ weight=weight,
64
+ signal=signal,
65
+ delta_logit=weight * signal,
66
+ ))
67
+
68
+ if not contributions:
69
+ logger.info("no signals available for disease=%s; returning baseline 0.5", disease)
70
+ return DiseaseScore(disease=disease, probability=0.5, contributions=[])
71
+
72
+ logit = sum(c.delta_logit for c in contributions)
73
+ probability = _sigmoid(_LOGIT_SCALE * logit)
74
+ return DiseaseScore(
75
+ disease=disease,
76
+ probability=probability,
77
+ contributions=contributions,
78
+ )
79
+
80
+
81
+ def _signal_for_modality(
82
+ modality_key: str,
83
+ disease: str,
84
+ mri: ModalityPrediction | None,
85
+ eeg: ModalityPrediction | None,
86
+ clinical: ClinicalScores,
87
+ ) -> float | None:
88
+ if modality_key == "mri":
89
+ return signal_for_disease(mri, disease) if mri is not None else None
90
+ if modality_key == "eeg":
91
+ return signal_for_disease(eeg, disease) if eeg is not None else None
92
+ if modality_key in _CLINICAL_FNS:
93
+ fn, attr = _CLINICAL_FNS[modality_key]
94
+ value = getattr(clinical, attr, None)
95
+ return fn(value) if value is not None else None
96
+ logger.warning("unknown modality key in weights table: %s", modality_key)
97
+ return None
98
+
99
+
100
+ def _sigmoid(x: float) -> float:
101
+ if x >= 0:
102
+ z = math.exp(-x)
103
+ return 1.0 / (1.0 + z)
104
+ z = math.exp(x)
105
+ return z / (1.0 + z)
tests/fusion/test_engine.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.fusion.engine.fuse — the core multi-modal combiner."""
2
+ from __future__ import annotations
3
+
4
+ import logging
5
+ from typing import Any
6
+
7
+ import pytest
8
+
9
+ from src.fusion import engine
10
+ from src.fusion.types import (
11
+ ClinicalScores,
12
+ FusionInput,
13
+ ModalityClassProb,
14
+ ModalityPrediction,
15
+ )
16
+
17
+
18
+ def _mri(prob_alz: float, prob_pd: float = 0.0) -> ModalityPrediction:
19
+ p_other = max(0.0, 1.0 - prob_alz - prob_pd)
20
+ items = [
21
+ ModalityClassProb(label_text="control", probability=p_other),
22
+ ModalityClassProb(label_text="alzheimers", probability=prob_alz),
23
+ ModalityClassProb(label_text="parkinsons", probability=prob_pd),
24
+ ]
25
+ top = max(items, key=lambda p: p.probability)
26
+ return ModalityPrediction(
27
+ label_text=top.label_text,
28
+ label=[p.label_text for p in items].index(top.label_text),
29
+ confidence=top.probability,
30
+ probabilities=items,
31
+ )
32
+
33
+
34
+ class TestFuse:
35
+ def test_empty_input_returns_baseline_with_missing_listed(self) -> None:
36
+ out = engine.fuse(FusionInput())
37
+ assert {d.disease for d in out.diseases} >= {"alzheimers", "parkinsons", "other"}
38
+ for ds in out.diseases:
39
+ assert ds.probability == pytest.approx(0.5, abs=1e-6)
40
+ assert ds.contributions == []
41
+ assert "mri" in out.missing_inputs
42
+ assert "eeg" in out.missing_inputs
43
+
44
+ def test_mri_only_alzheimers_high(self) -> None:
45
+ inp = FusionInput(mri=_mri(prob_alz=0.9))
46
+ out = engine.fuse(inp)
47
+ alz = next(d for d in out.diseases if d.disease == "alzheimers")
48
+ assert alz.probability > 0.7
49
+ assert any(c.modality == "mri" for c in alz.contributions)
50
+ assert out.top_disease == "alzheimers"
51
+
52
+ def test_mri_eeg_agreement_boosts_above_either_alone(self) -> None:
53
+ only_mri = engine.fuse(FusionInput(mri=_mri(prob_alz=0.8)))
54
+ only_eeg = engine.fuse(FusionInput(eeg=_mri(prob_alz=0.8)))
55
+ both = engine.fuse(FusionInput(
56
+ mri=_mri(prob_alz=0.8), eeg=_mri(prob_alz=0.8),
57
+ ))
58
+
59
+ def alz(out: Any) -> float:
60
+ return next(d for d in out.diseases if d.disease == "alzheimers").probability
61
+
62
+ assert alz(both) > alz(only_mri)
63
+ assert alz(both) > alz(only_eeg)
64
+
65
+ def test_clinical_only_low_mmse_raises_alzheimers(self) -> None:
66
+ out = engine.fuse(FusionInput(clinical=ClinicalScores(mmse=10.0)))
67
+ alz = next(d for d in out.diseases if d.disease == "alzheimers")
68
+ assert alz.probability > 0.55
69
+ assert any(c.modality == "clinical_mmse" for c in alz.contributions)
70
+
71
+ def test_disagreement_moderates_confidence(self) -> None:
72
+ out = engine.fuse(FusionInput(
73
+ mri=_mri(prob_alz=0.85),
74
+ clinical=ClinicalScores(mmse=30.0),
75
+ ))
76
+ alz = next(d for d in out.diseases if d.disease == "alzheimers")
77
+ assert 0.5 < alz.probability < 0.78
78
+
79
+ def test_unknown_clinical_field_is_ignored_safely(self) -> None:
80
+ out = engine.fuse(FusionInput(clinical=ClinicalScores(age_years=80.0)))
81
+ assert out.top_disease in {"alzheimers", "parkinsons", "other"}
82
+
83
+ def test_engine_does_not_depend_on_bbb(self) -> None:
84
+ # Independence regression: fusion must not couple to BBB. A patient
85
+ # with only MRI/EEG/clinical data must produce a valid output even
86
+ # though no BBB module is involved.
87
+ import inspect
88
+ import src.fusion.engine as engine_mod
89
+ import src.fusion.weights as weights_mod
90
+ assert "bbb" not in inspect.getsource(engine_mod).lower()
91
+ for disease in weights_mod.available_diseases():
92
+ for key in weights_mod.get_weights(disease):
93
+ assert "bbb" not in key.lower(), (disease, key)
94
+
95
+ def test_warning_logged_when_disease_has_no_signals(
96
+ self, caplog: pytest.LogCaptureFixture
97
+ ) -> None:
98
+ # 'other' disease with no MRI/EEG inputs -> no signals available.
99
+ engine.logger.addHandler(caplog.handler)
100
+ caplog.handler.setLevel(logging.INFO)
101
+ try:
102
+ out = engine.fuse(FusionInput(clinical=ClinicalScores(mmse=10.0)))
103
+ finally:
104
+ engine.logger.removeHandler(caplog.handler)
105
+ other = next(d for d in out.diseases if d.disease == "other")
106
+ assert other.probability == pytest.approx(0.5, abs=1e-6)
107
+ assert other.contributions == []