| """Tests for src.fusion.engine.fuse — the core multi-modal combiner.""" |
| from __future__ import annotations |
|
|
| import logging |
| from typing import Any |
|
|
| import pytest |
|
|
| from src.fusion import engine |
| from src.fusion.types import ( |
| ClinicalScores, |
| FusionInput, |
| ModalityClassProb, |
| ModalityPrediction, |
| ) |
|
|
|
|
def _mri(prob_alz: float, prob_pd: float = 0.0) -> ModalityPrediction:
    """Build a synthetic three-class prediction for feeding into ``FusionInput``.

    Classes are, in order, ``control``, ``alzheimers`` and ``parkinsons``.
    ``control`` receives ``1 - prob_alz - prob_pd``, floored at zero, so the
    probabilities only sum to 1 when ``prob_alz + prob_pd <= 1``. The predicted
    label is the class with the highest probability. On a tie, the earliest
    class in that order wins.
    """
    control = max(0.0, 1.0 - prob_alz - prob_pd)
    probs = [
        ModalityClassProb(label_text="control", probability=control),
        ModalityClassProb(label_text="alzheimers", probability=prob_alz),
        ModalityClassProb(label_text="parkinsons", probability=prob_pd),
    ]
    # ``max`` returns the first maximal element, so the winner's position in
    # ``probs`` is the integer class label.
    winner_idx, winner = max(enumerate(probs), key=lambda pair: pair[1].probability)
    return ModalityPrediction(
        label_text=winner.label_text,
        label=winner_idx,
        confidence=winner.probability,
        probabilities=probs,
    )
|
|
|
|
class TestFuse:
    """Behavioural tests for ``engine.fuse``, the multi-modal combiner."""

    @staticmethod
    def _disease(out: Any, name: str) -> Any:
        """Return the per-disease entry named *name* from a fusion result.

        Fails the test with a readable message, instead of a bare
        ``StopIteration``, when the engine did not emit that disease.
        """
        for ds in out.diseases:
            if ds.disease == name:
                return ds
        pytest.fail(f"fusion output has no entry for disease {name!r}")

    def test_empty_input_returns_baseline_with_missing_listed(self) -> None:
        """With no inputs, every disease sits at 0.5 with no contributions."""
        out = engine.fuse(FusionInput())
        assert {d.disease for d in out.diseases} >= {"alzheimers", "parkinsons", "other"}
        for ds in out.diseases:
            assert ds.probability == pytest.approx(0.5, abs=1e-6)
            assert ds.contributions == []
        assert "mri" in out.missing_inputs
        assert "eeg" in out.missing_inputs
        assert out.top_disease is None

    def test_mri_only_alzheimers_high(self) -> None:
        """A confident MRI Alzheimer's prediction alone makes it the top disease."""
        out = engine.fuse(FusionInput(mri=_mri(prob_alz=0.9)))
        alz = self._disease(out, "alzheimers")
        assert alz.probability > 0.7
        assert any(c.modality == "mri" for c in alz.contributions)
        assert out.top_disease == "alzheimers"

    def test_mri_eeg_agreement_boosts_above_either_alone(self) -> None:
        """Two agreeing modalities yield a higher score than either one alone."""
        # ``_mri`` is reused to build the EEG prediction; only its class
        # probabilities matter here.
        only_mri = engine.fuse(FusionInput(mri=_mri(prob_alz=0.8)))
        only_eeg = engine.fuse(FusionInput(eeg=_mri(prob_alz=0.8)))
        both = engine.fuse(FusionInput(mri=_mri(prob_alz=0.8), eeg=_mri(prob_alz=0.8)))

        both_alz = self._disease(both, "alzheimers").probability
        assert both_alz > self._disease(only_mri, "alzheimers").probability
        assert both_alz > self._disease(only_eeg, "alzheimers").probability

    def test_clinical_only_low_mmse_raises_alzheimers(self) -> None:
        """A low MMSE score alone pushes Alzheimer's above the 0.5 baseline."""
        out = engine.fuse(FusionInput(clinical=ClinicalScores(mmse=10.0)))
        alz = self._disease(out, "alzheimers")
        assert alz.probability > 0.55
        assert any(c.modality == "clinical_mmse" for c in alz.contributions)

    def test_disagreement_moderates_confidence(self) -> None:
        """A normal MMSE pulls a strong MRI Alzheimer's signal back toward 0.5."""
        out = engine.fuse(FusionInput(
            mri=_mri(prob_alz=0.85),
            clinical=ClinicalScores(mmse=30.0),
        ))
        alz = self._disease(out, "alzheimers")
        assert 0.5 < alz.probability < 0.78

    def test_unknown_clinical_field_is_ignored_safely(self) -> None:
        """Fusing a clinical field with no dedicated handling must not crash.

        NOTE(review): assumes ``age_years`` carries no fusion weight yet still
        yields a non-None ``top_disease``. Confirm against the engine.
        """
        out = engine.fuse(FusionInput(clinical=ClinicalScores(age_years=80.0)))
        assert out.top_disease in {"alzheimers", "parkinsons", "other"}

    def test_engine_does_not_depend_on_bbb(self) -> None:
        """Neither the engine source nor any weight key may reference "bbb"."""
        import inspect

        import src.fusion.weights as weights_mod

        # ``engine`` is the module imported at file level, so no re-import is needed.
        assert "bbb" not in inspect.getsource(engine).lower()
        for disease in weights_mod.available_diseases():
            for key in weights_mod.get_weights(disease):
                assert "bbb" not in key.lower(), (disease, key)

    def test_warning_logged_when_disease_has_no_signals(
        self, caplog: pytest.LogCaptureFixture
    ) -> None:
        """A disease with no contributing signal stays at 0.5 and is warned about.

        MMSE alone contributes to ``alzheimers`` (see the clinical-only test)
        but leaves ``other`` without any contribution.
        """
        # Attach caplog's handler directly to the engine logger, so records are
        # captured even if that logger does not propagate to the root logger.
        engine.logger.addHandler(caplog.handler)
        caplog.handler.setLevel(logging.INFO)
        try:
            out = engine.fuse(FusionInput(clinical=ClinicalScores(mmse=10.0)))
        finally:
            engine.logger.removeHandler(caplog.handler)

        other = self._disease(out, "other")
        assert other.probability == pytest.approx(0.5, abs=1e-6)
        assert other.contributions == []
        # Previously, logs were captured but never inspected, so the behaviour
        # this test is named after went unchecked.
        assert any(rec.levelno >= logging.WARNING for rec in caplog.records), (
            "expected a WARNING from the fusion engine for a signal-less disease"
        )
|
|