File size: 1,553 Bytes
b91e55e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""Tests for src.fusion.modality — turn ModalityPrediction into a per-disease signal."""
from __future__ import annotations

import pytest

from src.fusion.modality import signal_for_disease
from src.fusion.types import ModalityClassProb, ModalityPrediction


def _pred(probs: dict[str, float]) -> ModalityPrediction:
    """Build a ModalityPrediction from a ``{label_text: probability}`` mapping.

    The top class is the entry with the highest probability; on ties the
    first entry in *probs* wins. Its position in *probs* (insertion order)
    becomes ``label``, and its probability becomes ``confidence``.

    Raises:
        ValueError: if *probs* is empty.
    """
    if not probs:
        raise ValueError("probs must contain at least one class")
    items = [ModalityClassProb(label_text=k, probability=v) for k, v in probs.items()]
    # Take the index positionally instead of searching the keys for
    # ``top.label_text``. This avoids a second scan, and it does not depend on
    # ModalityClassProb keeping label_text byte-identical to the input key.
    top_index, top = max(enumerate(items), key=lambda pair: pair[1].probability)
    return ModalityPrediction(
        label_text=top.label_text,
        label=top_index,
        confidence=top.probability,
        probabilities=items,
    )


class TestSignalForDisease:
    """Expected output of ``signal_for_disease`` for single-modality predictions."""

    def test_disease_class_present_high_prob(self) -> None:
        """Disease at 0.9 vs control at 0.1 should give a signal of 0.8."""
        prediction = _pred({"control": 0.1, "alzheimers": 0.9})
        result = signal_for_disease(prediction, disease="alzheimers")
        expected = pytest.approx(0.8)
        assert result == expected

    def test_disease_class_present_low_prob(self) -> None:
        """Disease at 0.05 vs control at 0.95 should give a signal of -0.9."""
        prediction = _pred({"control": 0.95, "alzheimers": 0.05})
        result = signal_for_disease(prediction, disease="alzheimers")
        expected = pytest.approx(-0.9)
        assert result == expected

    def test_disease_class_absent_returns_none(self) -> None:
        """A prediction with no class for the requested disease yields None."""
        prediction = _pred({"control": 0.4, "parkinsons": 0.6})
        result = signal_for_disease(prediction, disease="alzheimers")
        assert result is None

    def test_label_alias_matches_case_insensitively(self) -> None:
        """Upper/mixed-case class labels still match the lower-case disease name."""
        prediction = _pred({"Control": 0.2, "ALZHEIMERS": 0.8})
        result = signal_for_disease(prediction, disease="alzheimers")
        expected = pytest.approx(0.6)
        assert result == expected