mekosotto Claude Sonnet 4.6 commited on
Commit
a189a33
·
1 Parent(s): 44397bd

feat(fusion): add pydantic data contract for multi-modal fusion

Browse files
src/fusion/__init__.py ADDED
File without changes
src/fusion/types.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic data contract for the multi-modal fusion engine."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Annotated
5
+
6
+ from pydantic import BaseModel, ConfigDict, Field
7
+
8
+
9
+ class ModalityClassProb(BaseModel):
10
+ label_text: str
11
+ probability: float = Field(..., ge=0.0, le=1.0)
12
+
13
+
14
+ class ModalityPrediction(BaseModel):
15
+ """One modality's classifier output (MRI or EEG)."""
16
+ model_config = ConfigDict(protected_namespaces=())
17
+
18
+ label_text: str
19
+ label: int = Field(..., ge=0)
20
+ confidence: float = Field(..., ge=0.0, le=1.0)
21
+ probabilities: list[ModalityClassProb] = Field(..., min_length=1)
22
+
23
+
24
+ class ClinicalScores(BaseModel):
25
+ """Doctor-entered extra-test scores. Each is optional."""
26
+ mmse: Annotated[float, Field(ge=0.0, le=30.0)] | None = None
27
+ moca: Annotated[float, Field(ge=0.0, le=30.0)] | None = None
28
+ updrs: Annotated[float, Field(ge=0.0, le=199.0)] | None = None
29
+ gait_speed_m_s: Annotated[float, Field(ge=0.0, le=2.5)] | None = None
30
+ age_years: Annotated[float, Field(ge=0.0, le=120.0)] | None = None
31
+
32
+
33
+ class FusionInput(BaseModel):
34
+ mri: ModalityPrediction | None = None
35
+ eeg: ModalityPrediction | None = None
36
+ clinical: ClinicalScores = Field(default_factory=ClinicalScores)
37
+
38
+
39
+ class ModalityContribution(BaseModel):
40
+ """One row of the attribution table for a single disease score."""
41
+ modality: str # "mri" | "eeg" | "clinical_<name>"
42
+ weight: float
43
+ signal: float = Field(..., ge=-1.0, le=1.0)
44
+ delta_logit: float
45
+
46
+
47
+ class DiseaseScore(BaseModel):
48
+ disease: str
49
+ probability: float = Field(..., ge=0.0, le=1.0)
50
+ contributions: list[ModalityContribution]
51
+
52
+
53
+ class FusionOutput(BaseModel):
54
+ diseases: list[DiseaseScore]
55
+ top_disease: str
56
+ missing_inputs: list[str] = Field(default_factory=list)
tests/fusion/__init__.py ADDED
File without changes
tests/fusion/test_types.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.fusion.types — pydantic contract for fusion I/O."""
2
+ from __future__ import annotations
3
+
4
+ import pytest
5
+ from pydantic import ValidationError
6
+
7
+ from src.fusion.types import (
8
+ ClinicalScores,
9
+ DiseaseScore,
10
+ FusionInput,
11
+ FusionOutput,
12
+ ModalityContribution,
13
+ ModalityPrediction,
14
+ )
15
+
16
+
17
+ class TestModalityPrediction:
18
+ def test_minimal_round_trip(self) -> None:
19
+ pred = ModalityPrediction(
20
+ label_text="alzheimers", label=1, confidence=0.81,
21
+ probabilities=[
22
+ {"label_text": "control", "probability": 0.19},
23
+ {"label_text": "alzheimers", "probability": 0.81},
24
+ ],
25
+ )
26
+ assert pred.label == 1
27
+ assert pred.probabilities[1].probability == pytest.approx(0.81)
28
+
29
+ def test_probabilities_must_be_non_empty(self) -> None:
30
+ with pytest.raises(ValidationError):
31
+ ModalityPrediction(label_text="x", label=0, confidence=0.5, probabilities=[])
32
+
33
+
34
+ class TestClinicalScores:
35
+ def test_all_optional(self) -> None:
36
+ s = ClinicalScores()
37
+ assert s.mmse is None and s.age_years is None
38
+
39
+ def test_rejects_out_of_range_mmse(self) -> None:
40
+ with pytest.raises(ValidationError):
41
+ ClinicalScores(mmse=42.0)
42
+
43
+
44
+ class TestFusionInputOutput:
45
+ def test_fusion_input_allows_no_modalities(self) -> None:
46
+ # Caller may pass nothing — engine returns baseline scores.
47
+ f = FusionInput()
48
+ assert f.mri is None and f.eeg is None
49
+ assert f.clinical == ClinicalScores()
50
+
51
+ def test_fusion_output_round_trip(self) -> None:
52
+ out = FusionOutput(
53
+ diseases=[
54
+ DiseaseScore(
55
+ disease="alzheimers",
56
+ probability=0.7,
57
+ contributions=[
58
+ ModalityContribution(
59
+ modality="mri", weight=0.35, signal=0.6, delta_logit=0.21,
60
+ )
61
+ ],
62
+ )
63
+ ],
64
+ top_disease="alzheimers",
65
+ missing_inputs=["eeg"],
66
+ )
67
+ assert out.top_disease == "alzheimers"
68
+ assert out.diseases[0].contributions[0].delta_logit == pytest.approx(0.21)