| """Tests for src.fusion.weights — disease/modality weight registry.""" |
| from __future__ import annotations |
|
|
| import pytest |
|
|
| from src.fusion import weights |
|
|
|
|
class TestWeights:
    """Checks on the listing and lookup functions of ``src.fusion.weights``."""

    def test_available_diseases_includes_known(self) -> None:
        """The disease listing contains the three expected keys."""
        expected = {"alzheimers", "parkinsons", "other"}
        listed = set(weights.available_diseases())
        assert listed.issuperset(expected)

    def test_available_clinical_tests_includes_each_named_input(self) -> None:
        """The clinical-test listing contains every expected input name."""
        expected = {"mmse", "moca", "updrs", "gait", "age"}
        listed = set(weights.available_clinical_tests())
        assert listed.issuperset(expected)

    def test_get_weights_returns_nonempty_mapping(self) -> None:
        """The alzheimers table has a positive ``mri`` entry and sums to one."""
        table = weights.get_weights("alzheimers")
        assert table["mri"] > 0
        # Compare with an absolute tolerance rather than exact float equality.
        total = sum(table.values())
        assert total == pytest.approx(1.0, abs=1e-6)

    def test_get_weights_unknown_disease_raises(self) -> None:
        """Looking up an unregistered disease raises a matching ``KeyError``."""
        expectation = pytest.raises(KeyError, match="unknown disease")
        with expectation:
            weights.get_weights("invented_disease")

    def test_each_disease_weight_table_sums_to_one(self) -> None:
        """Every listed disease has a weight table whose values sum to one."""
        unity = pytest.approx(1.0, abs=1e-6)
        for disease in weights.available_diseases():
            total = sum(weights.get_weights(disease).values())
            # The disease key is the assertion message, so a failure names it.
            assert total == unity, disease
|
|