mekosotto Claude Sonnet 4.6 commited on
Commit
ac781dd
·
1 Parent(s): f45c02a

fix(fusion): correct logit-scale comment; top_disease=None on empty input

Browse files
src/fusion/engine.py CHANGED
@@ -21,7 +21,7 @@ from src.fusion.types import (
21
 
22
  logger = get_logger(__name__)
23
 
24
- _LOGIT_SCALE = 4.0 # tuned so a single saturated modality maps to ~0.88
25
 
26
 
27
  # Clinical-test name -> (signal_fn, attribute_on_ClinicalScores)
@@ -46,7 +46,10 @@ def fuse(inp: FusionInput) -> FusionOutput:
46
  for disease in weight_registry.available_diseases():
47
  diseases.append(_score_one_disease(disease, inp))
48
 
49
- top = max(diseases, key=lambda d: d.probability).disease
 
 
 
50
  return FusionOutput(diseases=diseases, top_disease=top, missing_inputs=missing)
51
 
52
 
 
21
 
22
  logger = get_logger(__name__)
23
 
24
+ _LOGIT_SCALE = 4.0 # a single saturated modality maps to ~0.77-0.80 (depending on its weight)
25
 
26
 
27
  # Clinical-test name -> (signal_fn, attribute_on_ClinicalScores)
 
46
  for disease in weight_registry.available_diseases():
47
  diseases.append(_score_one_disease(disease, inp))
48
 
49
+ if any(d.contributions for d in diseases):
50
+ top: str | None = max(diseases, key=lambda d: d.probability).disease
51
+ else:
52
+ top = None
53
  return FusionOutput(diseases=diseases, top_disease=top, missing_inputs=missing)
54
 
55
 
src/fusion/types.py CHANGED
@@ -51,5 +51,5 @@ class DiseaseScore(BaseModel):
51
 
52
  class FusionOutput(BaseModel):
53
  diseases: list[DiseaseScore]
54
- top_disease: str
55
  missing_inputs: list[str] = Field(default_factory=list)
 
51
 
52
  class FusionOutput(BaseModel):
53
  diseases: list[DiseaseScore]
54
+ top_disease: str | None
55
  missing_inputs: list[str] = Field(default_factory=list)
tests/api/test_fusion_route.py CHANGED
@@ -36,6 +36,7 @@ class TestFusionRoute:
36
  for d in data["diseases"]:
37
  assert abs(d["probability"] - 0.5) < 1e-6
38
  assert "mri" in data["missing_inputs"]
 
39
 
40
  def test_invalid_probability_returns_422(self) -> None:
41
  body = {
 
36
  for d in data["diseases"]:
37
  assert abs(d["probability"] - 0.5) < 1e-6
38
  assert "mri" in data["missing_inputs"]
39
+ assert data["top_disease"] is None
40
 
41
  def test_invalid_probability_returns_422(self) -> None:
42
  body = {
tests/fusion/test_engine.py CHANGED
@@ -40,6 +40,7 @@ class TestFuse:
40
  assert ds.contributions == []
41
  assert "mri" in out.missing_inputs
42
  assert "eeg" in out.missing_inputs
 
43
 
44
  def test_mri_only_alzheimers_high(self) -> None:
45
  inp = FusionInput(mri=_mri(prob_alz=0.9))
 
40
  assert ds.contributions == []
41
  assert "mri" in out.missing_inputs
42
  assert "eeg" in out.missing_inputs
43
+ assert out.top_disease is None
44
 
45
  def test_mri_only_alzheimers_high(self) -> None:
46
  inp = FusionInput(mri=_mri(prob_alz=0.9))