mekosotto commited on
Commit
ccf23d1
·
1 Parent(s): 1914360

feat(fusion): add disease/modality weight registry

Browse files
src/fusion/weights.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Disease × modality weight registry for the fusion engine.
2
+
3
+ Heuristic preset weights — tune offline as clinician feedback arrives.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ from typing import Mapping
8
+
9
+ DEFAULT_WEIGHTS: dict[str, dict[str, float]] = {
10
+ "alzheimers": {
11
+ "mri": 0.35,
12
+ "eeg": 0.20,
13
+ "clinical_mmse": 0.20,
14
+ "clinical_moca": 0.15,
15
+ "clinical_age": 0.10,
16
+ },
17
+ "parkinsons": {
18
+ "mri": 0.20,
19
+ "eeg": 0.30,
20
+ "clinical_updrs": 0.30,
21
+ "clinical_gait": 0.15,
22
+ "clinical_age": 0.05,
23
+ },
24
+ "other": {
25
+ "mri": 0.50,
26
+ "eeg": 0.50,
27
+ },
28
+ }
29
+
30
+
31
+ def available_diseases() -> list[str]:
32
+ return sorted(DEFAULT_WEIGHTS.keys())
33
+
34
+
35
+ def available_clinical_tests() -> list[str]:
36
+ """Return the bare clinical-test names (without the 'clinical_' prefix)."""
37
+ names: set[str] = set()
38
+ for table in DEFAULT_WEIGHTS.values():
39
+ for key in table:
40
+ if key.startswith("clinical_"):
41
+ names.add(key[len("clinical_"):])
42
+ return sorted(names)
43
+
44
+
45
+ def get_weights(disease: str) -> Mapping[str, float]:
46
+ if disease not in DEFAULT_WEIGHTS:
47
+ raise KeyError(f"unknown disease: {disease!r}")
48
+ return DEFAULT_WEIGHTS[disease]
tests/fusion/test_weights.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.fusion.weights — disease/modality weight registry."""
2
+ from __future__ import annotations
3
+
4
+ import pytest
5
+
6
+ from src.fusion import weights
7
+
8
+
9
+ class TestWeights:
10
+ def test_available_diseases_includes_known(self) -> None:
11
+ diseases = set(weights.available_diseases())
12
+ assert {"alzheimers", "parkinsons", "other"} <= diseases
13
+
14
+ def test_available_clinical_tests_includes_each_named_input(self) -> None:
15
+ tests = set(weights.available_clinical_tests())
16
+ assert {"mmse", "moca", "updrs", "gait", "age"} <= tests
17
+
18
+ def test_get_weights_returns_nonempty_mapping(self) -> None:
19
+ ws = weights.get_weights("alzheimers")
20
+ assert ws["mri"] > 0
21
+ assert sum(ws.values()) == pytest.approx(1.0, abs=1e-6)
22
+
23
+ def test_get_weights_unknown_disease_raises(self) -> None:
24
+ with pytest.raises(KeyError, match="unknown disease"):
25
+ weights.get_weights("invented_disease")
26
+
27
+ def test_each_disease_weight_table_sums_to_one(self) -> None:
28
+ for d in weights.available_diseases():
29
+ assert sum(weights.get_weights(d).values()) == pytest.approx(1.0, abs=1e-6), d