File size: 647 Bytes
b91e55e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""Convert a modality classifier's probability vector into a signed signal."""
from __future__ import annotations

from src.fusion.types import ModalityPrediction


def signal_for_disease(pred: ModalityPrediction, disease: str) -> float | None:
    """Return a signed signal in [-1, 1] for `disease`, or None if the model
    has no matching class.

    Matching rules:
    - Surrounding whitespace is stripped from both `label_text` and `disease`.
    - The comparison is case-insensitive and uses `str.casefold`, so Unicode
      case variants (e.g. "ß" / "SS") compare equal.
    - If several classes match, the first one in `pred.probabilities` wins.

    Signal = 2 * P(disease) - 1. This maps P = 0 to -1, P = 0.5 to 0 and
    P = 1 to +1.

    Args:
        pred: Prediction whose `probabilities` is iterated. Each entry's
            `label_text` and `probability` attributes are read.
        disease: Class name to look up.

    Returns:
        The signal as a plain float, or None when no class matches.
    """
    target = disease.strip().casefold()
    for cls in pred.probabilities:
        if cls.label_text.strip().casefold() == target:
            # Clamp so the documented [-1, 1] range holds even if an upstream
            # model emits a probability slightly outside [0, 1]. float() also
            # normalises numpy/torch scalars to a plain Python float.
            p = min(max(float(cls.probability), 0.0), 1.0)
            return 2.0 * p - 1.0
    return None