mekosotto commited on
Commit
b91e55e
·
1 Parent(s): dd8acc2

feat(fusion): map modality predictions to per-disease signals

Browse files
src/fusion/modality.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Convert a modality classifier's probability vector into a signed signal."""
2
+ from __future__ import annotations
3
+
4
+ from src.fusion.types import ModalityPrediction
5
+
6
+
7
+ def signal_for_disease(pred: ModalityPrediction, disease: str) -> float | None:
8
+ """Return signal in [-1, 1] for `disease`, or None if the model has no
9
+ matching class.
10
+
11
+ A class matches if its `label_text` equals `disease` case-insensitively.
12
+ Signal = 2 * P(disease) - 1.
13
+ """
14
+ target = disease.strip().lower()
15
+ for cls in pred.probabilities:
16
+ if cls.label_text.strip().lower() == target:
17
+ return 2.0 * cls.probability - 1.0
18
+ return None
tests/fusion/test_modality.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.fusion.modality — turn ModalityPrediction into a per-disease signal."""
2
+ from __future__ import annotations
3
+
4
+ import pytest
5
+
6
+ from src.fusion.modality import signal_for_disease
7
+ from src.fusion.types import ModalityClassProb, ModalityPrediction
8
+
9
+
10
+ def _pred(probs: dict[str, float]) -> ModalityPrediction:
11
+ items = [ModalityClassProb(label_text=k, probability=v) for k, v in probs.items()]
12
+ top = max(items, key=lambda p: p.probability)
13
+ return ModalityPrediction(
14
+ label_text=top.label_text,
15
+ label=list(probs).index(top.label_text),
16
+ confidence=top.probability,
17
+ probabilities=items,
18
+ )
19
+
20
+
21
+ class TestSignalForDisease:
22
+ def test_disease_class_present_high_prob(self) -> None:
23
+ pred = _pred({"control": 0.1, "alzheimers": 0.9})
24
+ sig = signal_for_disease(pred, disease="alzheimers")
25
+ assert sig == pytest.approx(0.8)
26
+
27
+ def test_disease_class_present_low_prob(self) -> None:
28
+ pred = _pred({"control": 0.95, "alzheimers": 0.05})
29
+ sig = signal_for_disease(pred, disease="alzheimers")
30
+ assert sig == pytest.approx(-0.9)
31
+
32
+ def test_disease_class_absent_returns_none(self) -> None:
33
+ pred = _pred({"control": 0.4, "parkinsons": 0.6})
34
+ sig = signal_for_disease(pred, disease="alzheimers")
35
+ assert sig is None
36
+
37
+ def test_label_alias_matches_case_insensitively(self) -> None:
38
+ pred = _pred({"Control": 0.2, "ALZHEIMERS": 0.8})
39
+ sig = signal_for_disease(pred, disease="alzheimers")
40
+ assert sig == pytest.approx(0.6)