"""Tests for src.fusion.modality — turn ModalityPrediction into a per-disease signal."""

from __future__ import annotations

import pytest

from src.fusion.modality import signal_for_disease
from src.fusion.types import ModalityClassProb, ModalityPrediction


def _pred(probs: dict[str, float]) -> ModalityPrediction:
    """Build a ModalityPrediction from an ordered ``{label_text: probability}`` map.

    The highest-probability class becomes the prediction's ``label_text`` and
    ``confidence``; ``label`` is that class's position in *probs* (insertion
    order). On a tie, the first maximal entry wins. Every class is also kept in
    ``probabilities`` in the same order.
    """
    items = [ModalityClassProb(label_text=k, probability=v) for k, v in probs.items()]
    # Carry the index alongside each item so the winning position is known
    # directly, instead of re-searching the labels for it afterwards.
    label, top = max(enumerate(items), key=lambda pair: pair[1].probability)
    return ModalityPrediction(
        label_text=top.label_text,
        label=label,
        confidence=top.probability,
        probabilities=items,
    )


class TestSignalForDisease:
    """``signal_for_disease`` turns a prediction into a signed per-disease score.

    It returns ``None`` when the requested disease is not among the predicted
    classes.

    NOTE(review): the expected values below are consistent with ``2 * p - 1``
    on the disease-class probability (equivalently p(disease) - p(control) for
    these two-class inputs). Confirm against src.fusion.modality.
    """

    def test_disease_class_present_high_prob(self) -> None:
        pred = _pred({"control": 0.1, "alzheimers": 0.9})
        sig = signal_for_disease(pred, disease="alzheimers")
        assert sig == pytest.approx(0.8)

    def test_disease_class_present_low_prob(self) -> None:
        pred = _pred({"control": 0.95, "alzheimers": 0.05})
        sig = signal_for_disease(pred, disease="alzheimers")
        assert sig == pytest.approx(-0.9)

    def test_disease_class_absent_returns_none(self) -> None:
        pred = _pred({"control": 0.4, "parkinsons": 0.6})
        sig = signal_for_disease(pred, disease="alzheimers")
        assert sig is None

    def test_label_alias_matches_case_insensitively(self) -> None:
        # Labels differ from the requested disease only by letter case.
        pred = _pred({"Control": 0.2, "ALZHEIMERS": 0.8})
        sig = signal_for_disease(pred, disease="alzheimers")
        assert sig == pytest.approx(0.6)