File size: 1,553 Bytes
b91e55e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 | """Tests for src.fusion.modality — turn ModalityPrediction into a per-disease signal."""
from __future__ import annotations
import pytest
from src.fusion.modality import signal_for_disease
from src.fusion.types import ModalityClassProb, ModalityPrediction
def _pred(probs: dict[str, float]) -> ModalityPrediction:
    """Build a ``ModalityPrediction`` from a ``{label_text: probability}`` mapping.

    The predicted class is the entry with the highest probability; on ties the
    first such entry wins (``max`` semantics). ``label`` is that entry's position
    in the mapping's insertion order, and ``confidence`` is its probability.
    Every entry is also carried through, in order, as ``probabilities``.

    Raises:
        ValueError: if ``probs`` is empty (propagated from ``max``).
    """
    class_probs = [
        ModalityClassProb(label_text=text, probability=prob)
        for text, prob in probs.items()
    ]
    best = max(class_probs, key=lambda cp: cp.probability)
    # Position of the winning label among the mapping's keys (insertion order).
    label_keys = list(probs)
    return ModalityPrediction(
        label_text=best.label_text,
        label=label_keys.index(best.label_text),
        confidence=best.probability,
        probabilities=class_probs,
    )
class TestSignalForDisease:
    """Tests for ``signal_for_disease``.

    Each expected value below is consistent with a signal of ``2 * p - 1``,
    where ``p`` is the probability assigned to the requested disease class.
    ``None`` is expected when the prediction has no class for that disease.
    """

    def test_disease_class_present_high_prob(self) -> None:
        # A confident disease prediction yields a strongly positive signal.
        prediction = _pred({"control": 0.1, "alzheimers": 0.9})
        result = signal_for_disease(prediction, disease="alzheimers")
        assert result == pytest.approx(0.8)

    def test_disease_class_present_low_prob(self) -> None:
        # A confident non-disease prediction yields a strongly negative signal.
        prediction = _pred({"control": 0.95, "alzheimers": 0.05})
        result = signal_for_disease(prediction, disease="alzheimers")
        assert result == pytest.approx(-0.9)

    def test_disease_class_absent_returns_none(self) -> None:
        # The requested disease is not among the predicted classes at all.
        prediction = _pred({"control": 0.4, "parkinsons": 0.6})
        assert signal_for_disease(prediction, disease="alzheimers") is None

    def test_label_alias_matches_case_insensitively(self) -> None:
        # Class labels differ from the requested disease only by letter case.
        prediction = _pred({"Control": 0.2, "ALZHEIMERS": 0.8})
        result = signal_for_disease(prediction, disease="alzheimers")
        assert result == pytest.approx(0.6)
|