feat(fusion): add core multi-modal fuse() with per-disease attribution
Browse files- src/fusion/engine.py +105 -0
- tests/fusion/test_engine.py +107 -0
src/fusion/engine.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Multi-modal fusion engine — combines MRI, EEG, and clinical signals into
|
| 2 |
+
per-disease confidence with full attribution.
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from typing import Callable
|
| 8 |
+
|
| 9 |
+
from src.core.logger import get_logger
|
| 10 |
+
from src.fusion import clinical as clinical_signals
|
| 11 |
+
from src.fusion import weights as weight_registry
|
| 12 |
+
from src.fusion.modality import signal_for_disease
|
| 13 |
+
from src.fusion.types import (
|
| 14 |
+
ClinicalScores,
|
| 15 |
+
DiseaseScore,
|
| 16 |
+
FusionInput,
|
| 17 |
+
FusionOutput,
|
| 18 |
+
ModalityContribution,
|
| 19 |
+
ModalityPrediction,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
logger = get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
_LOGIT_SCALE = 4.0 # tuned so a single saturated modality maps to ~0.88
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Clinical-test name -> (signal_fn, attribute_on_ClinicalScores)
|
| 28 |
+
_CLINICAL_FNS: dict[str, tuple[Callable[[float], float], str]] = {
|
| 29 |
+
"clinical_mmse": (clinical_signals.mmse_to_signal, "mmse"),
|
| 30 |
+
"clinical_moca": (clinical_signals.moca_to_signal, "moca"),
|
| 31 |
+
"clinical_updrs": (clinical_signals.updrs_to_signal, "updrs"),
|
| 32 |
+
"clinical_gait": (clinical_signals.gait_to_signal, "gait_speed_m_s"),
|
| 33 |
+
"clinical_age": (clinical_signals.age_to_signal, "age_years"),
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def fuse(inp: FusionInput) -> FusionOutput:
|
| 38 |
+
"""Combine all available modalities into a per-disease confidence."""
|
| 39 |
+
missing: list[str] = []
|
| 40 |
+
if inp.mri is None:
|
| 41 |
+
missing.append("mri")
|
| 42 |
+
if inp.eeg is None:
|
| 43 |
+
missing.append("eeg")
|
| 44 |
+
|
| 45 |
+
diseases: list[DiseaseScore] = []
|
| 46 |
+
for disease in weight_registry.available_diseases():
|
| 47 |
+
diseases.append(_score_one_disease(disease, inp))
|
| 48 |
+
|
| 49 |
+
top = max(diseases, key=lambda d: d.probability).disease
|
| 50 |
+
return FusionOutput(diseases=diseases, top_disease=top, missing_inputs=missing)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _score_one_disease(disease: str, inp: FusionInput) -> DiseaseScore:
|
| 54 |
+
weights = weight_registry.get_weights(disease)
|
| 55 |
+
contributions: list[ModalityContribution] = []
|
| 56 |
+
|
| 57 |
+
for modality_key, weight in weights.items():
|
| 58 |
+
signal = _signal_for_modality(modality_key, disease, inp.mri, inp.eeg, inp.clinical)
|
| 59 |
+
if signal is None:
|
| 60 |
+
continue
|
| 61 |
+
contributions.append(ModalityContribution(
|
| 62 |
+
modality=modality_key,
|
| 63 |
+
weight=weight,
|
| 64 |
+
signal=signal,
|
| 65 |
+
delta_logit=weight * signal,
|
| 66 |
+
))
|
| 67 |
+
|
| 68 |
+
if not contributions:
|
| 69 |
+
logger.info("no signals available for disease=%s; returning baseline 0.5", disease)
|
| 70 |
+
return DiseaseScore(disease=disease, probability=0.5, contributions=[])
|
| 71 |
+
|
| 72 |
+
logit = sum(c.delta_logit for c in contributions)
|
| 73 |
+
probability = _sigmoid(_LOGIT_SCALE * logit)
|
| 74 |
+
return DiseaseScore(
|
| 75 |
+
disease=disease,
|
| 76 |
+
probability=probability,
|
| 77 |
+
contributions=contributions,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _signal_for_modality(
|
| 82 |
+
modality_key: str,
|
| 83 |
+
disease: str,
|
| 84 |
+
mri: ModalityPrediction | None,
|
| 85 |
+
eeg: ModalityPrediction | None,
|
| 86 |
+
clinical: ClinicalScores,
|
| 87 |
+
) -> float | None:
|
| 88 |
+
if modality_key == "mri":
|
| 89 |
+
return signal_for_disease(mri, disease) if mri is not None else None
|
| 90 |
+
if modality_key == "eeg":
|
| 91 |
+
return signal_for_disease(eeg, disease) if eeg is not None else None
|
| 92 |
+
if modality_key in _CLINICAL_FNS:
|
| 93 |
+
fn, attr = _CLINICAL_FNS[modality_key]
|
| 94 |
+
value = getattr(clinical, attr, None)
|
| 95 |
+
return fn(value) if value is not None else None
|
| 96 |
+
logger.warning("unknown modality key in weights table: %s", modality_key)
|
| 97 |
+
return None
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _sigmoid(x: float) -> float:
|
| 101 |
+
if x >= 0:
|
| 102 |
+
z = math.exp(-x)
|
| 103 |
+
return 1.0 / (1.0 + z)
|
| 104 |
+
z = math.exp(x)
|
| 105 |
+
return z / (1.0 + z)
|
tests/fusion/test_engine.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for src.fusion.engine.fuse — the core multi-modal combiner."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
+
from src.fusion import engine
|
| 10 |
+
from src.fusion.types import (
|
| 11 |
+
ClinicalScores,
|
| 12 |
+
FusionInput,
|
| 13 |
+
ModalityClassProb,
|
| 14 |
+
ModalityPrediction,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _mri(prob_alz: float, prob_pd: float = 0.0) -> ModalityPrediction:
|
| 19 |
+
p_other = max(0.0, 1.0 - prob_alz - prob_pd)
|
| 20 |
+
items = [
|
| 21 |
+
ModalityClassProb(label_text="control", probability=p_other),
|
| 22 |
+
ModalityClassProb(label_text="alzheimers", probability=prob_alz),
|
| 23 |
+
ModalityClassProb(label_text="parkinsons", probability=prob_pd),
|
| 24 |
+
]
|
| 25 |
+
top = max(items, key=lambda p: p.probability)
|
| 26 |
+
return ModalityPrediction(
|
| 27 |
+
label_text=top.label_text,
|
| 28 |
+
label=[p.label_text for p in items].index(top.label_text),
|
| 29 |
+
confidence=top.probability,
|
| 30 |
+
probabilities=items,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class TestFuse:
|
| 35 |
+
def test_empty_input_returns_baseline_with_missing_listed(self) -> None:
|
| 36 |
+
out = engine.fuse(FusionInput())
|
| 37 |
+
assert {d.disease for d in out.diseases} >= {"alzheimers", "parkinsons", "other"}
|
| 38 |
+
for ds in out.diseases:
|
| 39 |
+
assert ds.probability == pytest.approx(0.5, abs=1e-6)
|
| 40 |
+
assert ds.contributions == []
|
| 41 |
+
assert "mri" in out.missing_inputs
|
| 42 |
+
assert "eeg" in out.missing_inputs
|
| 43 |
+
|
| 44 |
+
def test_mri_only_alzheimers_high(self) -> None:
|
| 45 |
+
inp = FusionInput(mri=_mri(prob_alz=0.9))
|
| 46 |
+
out = engine.fuse(inp)
|
| 47 |
+
alz = next(d for d in out.diseases if d.disease == "alzheimers")
|
| 48 |
+
assert alz.probability > 0.7
|
| 49 |
+
assert any(c.modality == "mri" for c in alz.contributions)
|
| 50 |
+
assert out.top_disease == "alzheimers"
|
| 51 |
+
|
| 52 |
+
def test_mri_eeg_agreement_boosts_above_either_alone(self) -> None:
|
| 53 |
+
only_mri = engine.fuse(FusionInput(mri=_mri(prob_alz=0.8)))
|
| 54 |
+
only_eeg = engine.fuse(FusionInput(eeg=_mri(prob_alz=0.8)))
|
| 55 |
+
both = engine.fuse(FusionInput(
|
| 56 |
+
mri=_mri(prob_alz=0.8), eeg=_mri(prob_alz=0.8),
|
| 57 |
+
))
|
| 58 |
+
|
| 59 |
+
def alz(out: Any) -> float:
|
| 60 |
+
return next(d for d in out.diseases if d.disease == "alzheimers").probability
|
| 61 |
+
|
| 62 |
+
assert alz(both) > alz(only_mri)
|
| 63 |
+
assert alz(both) > alz(only_eeg)
|
| 64 |
+
|
| 65 |
+
def test_clinical_only_low_mmse_raises_alzheimers(self) -> None:
|
| 66 |
+
out = engine.fuse(FusionInput(clinical=ClinicalScores(mmse=10.0)))
|
| 67 |
+
alz = next(d for d in out.diseases if d.disease == "alzheimers")
|
| 68 |
+
assert alz.probability > 0.55
|
| 69 |
+
assert any(c.modality == "clinical_mmse" for c in alz.contributions)
|
| 70 |
+
|
| 71 |
+
def test_disagreement_moderates_confidence(self) -> None:
|
| 72 |
+
out = engine.fuse(FusionInput(
|
| 73 |
+
mri=_mri(prob_alz=0.85),
|
| 74 |
+
clinical=ClinicalScores(mmse=30.0),
|
| 75 |
+
))
|
| 76 |
+
alz = next(d for d in out.diseases if d.disease == "alzheimers")
|
| 77 |
+
assert 0.5 < alz.probability < 0.78
|
| 78 |
+
|
| 79 |
+
def test_unknown_clinical_field_is_ignored_safely(self) -> None:
|
| 80 |
+
out = engine.fuse(FusionInput(clinical=ClinicalScores(age_years=80.0)))
|
| 81 |
+
assert out.top_disease in {"alzheimers", "parkinsons", "other"}
|
| 82 |
+
|
| 83 |
+
def test_engine_does_not_depend_on_bbb(self) -> None:
|
| 84 |
+
# Independence regression: fusion must not couple to BBB. A patient
|
| 85 |
+
# with only MRI/EEG/clinical data must produce a valid output even
|
| 86 |
+
# though no BBB module is involved.
|
| 87 |
+
import inspect
|
| 88 |
+
import src.fusion.engine as engine_mod
|
| 89 |
+
import src.fusion.weights as weights_mod
|
| 90 |
+
assert "bbb" not in inspect.getsource(engine_mod).lower()
|
| 91 |
+
for disease in weights_mod.available_diseases():
|
| 92 |
+
for key in weights_mod.get_weights(disease):
|
| 93 |
+
assert "bbb" not in key.lower(), (disease, key)
|
| 94 |
+
|
| 95 |
+
def test_warning_logged_when_disease_has_no_signals(
|
| 96 |
+
self, caplog: pytest.LogCaptureFixture
|
| 97 |
+
) -> None:
|
| 98 |
+
# 'other' disease with no MRI/EEG inputs -> no signals available.
|
| 99 |
+
engine.logger.addHandler(caplog.handler)
|
| 100 |
+
caplog.handler.setLevel(logging.INFO)
|
| 101 |
+
try:
|
| 102 |
+
out = engine.fuse(FusionInput(clinical=ClinicalScores(mmse=10.0)))
|
| 103 |
+
finally:
|
| 104 |
+
engine.logger.removeHandler(caplog.handler)
|
| 105 |
+
other = next(d for d in out.diseases if d.disease == "other")
|
| 106 |
+
assert other.probability == pytest.approx(0.5, abs=1e-6)
|
| 107 |
+
assert other.contributions == []
|