fix(fusion): correct logit-scale comment; top_disease=None on empty input
Browse files- src/fusion/engine.py +5 -2
- src/fusion/types.py +1 -1
- tests/api/test_fusion_route.py +1 -0
- tests/fusion/test_engine.py +1 -0
src/fusion/engine.py
CHANGED
|
@@ -21,7 +21,7 @@ from src.fusion.types import (
|
|
| 21 |
|
| 22 |
logger = get_logger(__name__)
|
| 23 |
|
| 24 |
-
_LOGIT_SCALE = 4.0 #
|
| 25 |
|
| 26 |
|
| 27 |
# Clinical-test name -> (signal_fn, attribute_on_ClinicalScores)
|
|
@@ -46,7 +46,10 @@ def fuse(inp: FusionInput) -> FusionOutput:
|
|
| 46 |
for disease in weight_registry.available_diseases():
|
| 47 |
diseases.append(_score_one_disease(disease, inp))
|
| 48 |
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
| 50 |
return FusionOutput(diseases=diseases, top_disease=top, missing_inputs=missing)
|
| 51 |
|
| 52 |
|
|
|
|
| 21 |
|
| 22 |
logger = get_logger(__name__)
|
| 23 |
|
| 24 |
+
_LOGIT_SCALE = 4.0 # a single saturated modality maps to ~0.77-0.80 (depending on its weight)
|
| 25 |
|
| 26 |
|
| 27 |
# Clinical-test name -> (signal_fn, attribute_on_ClinicalScores)
|
|
|
|
| 46 |
for disease in weight_registry.available_diseases():
|
| 47 |
diseases.append(_score_one_disease(disease, inp))
|
| 48 |
|
| 49 |
+
if any(d.contributions for d in diseases):
|
| 50 |
+
top: str | None = max(diseases, key=lambda d: d.probability).disease
|
| 51 |
+
else:
|
| 52 |
+
top = None
|
| 53 |
return FusionOutput(diseases=diseases, top_disease=top, missing_inputs=missing)
|
| 54 |
|
| 55 |
|
src/fusion/types.py
CHANGED
|
@@ -51,5 +51,5 @@ class DiseaseScore(BaseModel):
|
|
| 51 |
|
| 52 |
class FusionOutput(BaseModel):
|
| 53 |
diseases: list[DiseaseScore]
|
| 54 |
-
top_disease: str
|
| 55 |
missing_inputs: list[str] = Field(default_factory=list)
|
|
|
|
| 51 |
|
| 52 |
class FusionOutput(BaseModel):
|
| 53 |
diseases: list[DiseaseScore]
|
| 54 |
+
top_disease: str | None
|
| 55 |
missing_inputs: list[str] = Field(default_factory=list)
|
tests/api/test_fusion_route.py
CHANGED
|
@@ -36,6 +36,7 @@ class TestFusionRoute:
|
|
| 36 |
for d in data["diseases"]:
|
| 37 |
assert abs(d["probability"] - 0.5) < 1e-6
|
| 38 |
assert "mri" in data["missing_inputs"]
|
|
|
|
| 39 |
|
| 40 |
def test_invalid_probability_returns_422(self) -> None:
|
| 41 |
body = {
|
|
|
|
| 36 |
for d in data["diseases"]:
|
| 37 |
assert abs(d["probability"] - 0.5) < 1e-6
|
| 38 |
assert "mri" in data["missing_inputs"]
|
| 39 |
+
assert data["top_disease"] is None
|
| 40 |
|
| 41 |
def test_invalid_probability_returns_422(self) -> None:
|
| 42 |
body = {
|
tests/fusion/test_engine.py
CHANGED
|
@@ -40,6 +40,7 @@ class TestFuse:
|
|
| 40 |
assert ds.contributions == []
|
| 41 |
assert "mri" in out.missing_inputs
|
| 42 |
assert "eeg" in out.missing_inputs
|
|
|
|
| 43 |
|
| 44 |
def test_mri_only_alzheimers_high(self) -> None:
|
| 45 |
inp = FusionInput(mri=_mri(prob_alz=0.9))
|
|
|
|
| 40 |
assert ds.contributions == []
|
| 41 |
assert "mri" in out.missing_inputs
|
| 42 |
assert "eeg" in out.missing_inputs
|
| 43 |
+
assert out.top_disease is None
|
| 44 |
|
| 45 |
def test_mri_only_alzheimers_high(self) -> None:
|
| 46 |
inp = FusionInput(mri=_mri(prob_alz=0.9))
|