feat(fusion): add disease/modality weight registry
Browse files- src/fusion/weights.py +48 -0
- tests/fusion/test_weights.py +29 -0
src/fusion/weights.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Disease × modality weight registry for the fusion engine.
|
| 2 |
+
|
| 3 |
+
Heuristic preset weights — tune offline as clinician feedback arrives.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
from typing import Mapping
|
| 8 |
+
|
| 9 |
+
DEFAULT_WEIGHTS: dict[str, dict[str, float]] = {
|
| 10 |
+
"alzheimers": {
|
| 11 |
+
"mri": 0.35,
|
| 12 |
+
"eeg": 0.20,
|
| 13 |
+
"clinical_mmse": 0.20,
|
| 14 |
+
"clinical_moca": 0.15,
|
| 15 |
+
"clinical_age": 0.10,
|
| 16 |
+
},
|
| 17 |
+
"parkinsons": {
|
| 18 |
+
"mri": 0.20,
|
| 19 |
+
"eeg": 0.30,
|
| 20 |
+
"clinical_updrs": 0.30,
|
| 21 |
+
"clinical_gait": 0.15,
|
| 22 |
+
"clinical_age": 0.05,
|
| 23 |
+
},
|
| 24 |
+
"other": {
|
| 25 |
+
"mri": 0.50,
|
| 26 |
+
"eeg": 0.50,
|
| 27 |
+
},
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def available_diseases() -> list[str]:
|
| 32 |
+
return sorted(DEFAULT_WEIGHTS.keys())
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def available_clinical_tests() -> list[str]:
|
| 36 |
+
"""Return the bare clinical-test names (without the 'clinical_' prefix)."""
|
| 37 |
+
names: set[str] = set()
|
| 38 |
+
for table in DEFAULT_WEIGHTS.values():
|
| 39 |
+
for key in table:
|
| 40 |
+
if key.startswith("clinical_"):
|
| 41 |
+
names.add(key[len("clinical_"):])
|
| 42 |
+
return sorted(names)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_weights(disease: str) -> Mapping[str, float]:
|
| 46 |
+
if disease not in DEFAULT_WEIGHTS:
|
| 47 |
+
raise KeyError(f"unknown disease: {disease!r}")
|
| 48 |
+
return DEFAULT_WEIGHTS[disease]
|
tests/fusion/test_weights.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for src.fusion.weights — disease/modality weight registry."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import pytest
|
| 5 |
+
|
| 6 |
+
from src.fusion import weights
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TestWeights:
|
| 10 |
+
def test_available_diseases_includes_known(self) -> None:
|
| 11 |
+
diseases = set(weights.available_diseases())
|
| 12 |
+
assert {"alzheimers", "parkinsons", "other"} <= diseases
|
| 13 |
+
|
| 14 |
+
def test_available_clinical_tests_includes_each_named_input(self) -> None:
|
| 15 |
+
tests = set(weights.available_clinical_tests())
|
| 16 |
+
assert {"mmse", "moca", "updrs", "gait", "age"} <= tests
|
| 17 |
+
|
| 18 |
+
def test_get_weights_returns_nonempty_mapping(self) -> None:
|
| 19 |
+
ws = weights.get_weights("alzheimers")
|
| 20 |
+
assert ws["mri"] > 0
|
| 21 |
+
assert sum(ws.values()) == pytest.approx(1.0, abs=1e-6)
|
| 22 |
+
|
| 23 |
+
def test_get_weights_unknown_disease_raises(self) -> None:
|
| 24 |
+
with pytest.raises(KeyError, match="unknown disease"):
|
| 25 |
+
weights.get_weights("invented_disease")
|
| 26 |
+
|
| 27 |
+
def test_each_disease_weight_table_sums_to_one(self) -> None:
|
| 28 |
+
for d in weights.available_diseases():
|
| 29 |
+
assert sum(weights.get_weights(d).values()) == pytest.approx(1.0, abs=1e-6), d
|