feat(fusion): add pydantic data contract for multi-modal fusion
Browse files- src/fusion/__init__.py +0 -0
- src/fusion/types.py +56 -0
- tests/fusion/__init__.py +0 -0
- tests/fusion/test_types.py +68 -0
src/fusion/__init__.py
ADDED
|
File without changes
|
src/fusion/types.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic data contract for the multi-modal fusion engine."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import Annotated
|
| 5 |
+
|
| 6 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ModalityClassProb(BaseModel):
    """Probability attached to one class label.

    Attributes:
        label_text: Text label of the class.
        probability: Value restricted to the closed interval [0.0, 1.0].
    """

    label_text: str
    probability: Annotated[float, Field(ge=0.0, le=1.0)]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ModalityPrediction(BaseModel):
    """One modality's classifier output (MRI or EEG).

    Attributes:
        label_text: Text label reported by the classifier.
        label: Integer label; must be >= 0.
        confidence: Score restricted to the closed interval [0.0, 1.0].
        probabilities: Per-class entries; at least one is required.

    Only the per-field bounds above are validated. Nothing here checks that
    the entries in ``probabilities`` sum to one, or that ``confidence`` or
    ``label_text`` agree with any of them.
    """

    # Disables pydantic's protected-namespace check on field names.
    # NOTE(review): no field declared in this class starts with ``model_``;
    # presumably this is kept for subclasses -- confirm it is still needed.
    model_config = ConfigDict(protected_namespaces=())

    label_text: str
    label: Annotated[int, Field(ge=0)]
    confidence: Annotated[float, Field(ge=0.0, le=1.0)]
    probabilities: Annotated[list[ModalityClassProb], Field(min_length=1)]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ClinicalScores(BaseModel):
    """Doctor-entered extra-test scores. Each is optional.

    Every field defaults to ``None``. A supplied value must lie inside the
    inclusive range declared on that field, otherwise validation fails.
    """

    # Inclusive range 0.0 .. 30.0.
    mmse: Annotated[float, Field(ge=0.0, le=30.0)] | None = None
    # Inclusive range 0.0 .. 30.0.
    moca: Annotated[float, Field(ge=0.0, le=30.0)] | None = None
    # Inclusive range 0.0 .. 199.0.
    updrs: Annotated[float, Field(ge=0.0, le=199.0)] | None = None
    # Inclusive range 0.0 .. 2.5 (metres per second, going by the field name).
    gait_speed_m_s: Annotated[float, Field(ge=0.0, le=2.5)] | None = None
    # Inclusive range 0.0 .. 120.0.
    age_years: Annotated[float, Field(ge=0.0, le=120.0)] | None = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class FusionInput(BaseModel):
    """Input bundle for the fusion engine.

    Attributes:
        mri: MRI classifier output, or ``None`` when not supplied.
        eeg: EEG classifier output, or ``None`` when not supplied.
        clinical: Clinical scores; a fresh all-``None`` ``ClinicalScores``
            is created when the caller omits it.

    All three fields have defaults, so ``FusionInput()`` is valid.
    """

    mri: ModalityPrediction | None = None
    eeg: ModalityPrediction | None = None
    clinical: ClinicalScores = Field(default_factory=ClinicalScores)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ModalityContribution(BaseModel):
    """One row of the attribution table for a single disease score.

    Attributes:
        modality: Source identifier (see the inline note on accepted forms).
        weight: Unconstrained float.
        signal: Value restricted to the closed interval [-1.0, 1.0].
        delta_logit: Unconstrained float.
    """

    # Expected forms: "mri" | "eeg" | "clinical_<name>". Documented only --
    # the field is a plain ``str`` and the value is not validated.
    modality: str
    weight: float
    signal: Annotated[float, Field(ge=-1.0, le=1.0)]
    delta_logit: float
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class DiseaseScore(BaseModel):
    """Score for one disease together with its attribution rows.

    Attributes:
        disease: Disease identifier.
        probability: Value restricted to the closed interval [0.0, 1.0].
        contributions: Attribution rows; required, but may be empty.
    """

    disease: str
    probability: Annotated[float, Field(ge=0.0, le=1.0)]
    contributions: list[ModalityContribution]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class FusionOutput(BaseModel):
    """Result bundle of the fusion engine.

    Attributes:
        diseases: Per-disease scores; required.
        top_disease: Disease identifier; required. It is not cross-checked
            against the entries in ``diseases``.
        missing_inputs: Names of inputs that were absent; defaults to a new
            empty list per instance.
    """

    diseases: list[DiseaseScore]
    top_disease: str
    missing_inputs: list[str] = Field(default_factory=list)
|
tests/fusion/__init__.py
ADDED
|
File without changes
|
tests/fusion/test_types.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for src.fusion.types — pydantic contract for fusion I/O."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import pytest
|
| 5 |
+
from pydantic import ValidationError
|
| 6 |
+
|
| 7 |
+
from src.fusion.types import (
|
| 8 |
+
ClinicalScores,
|
| 9 |
+
DiseaseScore,
|
| 10 |
+
FusionInput,
|
| 11 |
+
FusionOutput,
|
| 12 |
+
ModalityContribution,
|
| 13 |
+
ModalityPrediction,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TestModalityPrediction:
    """Construction, round-trip and validation checks for ``ModalityPrediction``."""

    def test_minimal_round_trip(self) -> None:
        """A minimal prediction survives dict and JSON round trips unchanged."""
        pred = ModalityPrediction(
            label_text="alzheimers",
            label=1,
            confidence=0.81,
            probabilities=[
                {"label_text": "control", "probability": 0.19},
                {"label_text": "alzheimers", "probability": 0.81},
            ],
        )
        assert pred.label == 1
        assert pred.probabilities[1].probability == pytest.approx(0.81)

        # The actual round trip: dumping and re-validating must rebuild an
        # equal model, both via plain Python data and via JSON text.
        assert ModalityPrediction.model_validate(pred.model_dump()) == pred
        assert ModalityPrediction.model_validate_json(pred.model_dump_json()) == pred

    def test_probabilities_must_be_non_empty(self) -> None:
        """An empty ``probabilities`` list is rejected (min_length=1)."""
        with pytest.raises(ValidationError):
            ModalityPrediction(
                label_text="x", label=0, confidence=0.5, probabilities=[]
            )
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class TestClinicalScores:
    """Default and range checks for ``ClinicalScores``."""

    def test_all_optional(self) -> None:
        """Constructing with no arguments leaves the scores unset."""
        empty = ClinicalScores()
        assert empty.mmse is None
        assert empty.age_years is None

    def test_rejects_out_of_range_mmse(self) -> None:
        """An ``mmse`` above the declared upper bound fails validation."""
        too_high = 42.0  # above the 30.0 ceiling declared on ``mmse``
        with pytest.raises(ValidationError):
            ClinicalScores(mmse=too_high)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class TestFusionInputOutput:
    """Default and round-trip checks for ``FusionInput`` / ``FusionOutput``."""

    def test_fusion_input_allows_no_modalities(self) -> None:
        """``FusionInput()`` is valid and falls back to empty defaults."""
        # Caller may pass nothing — engine returns baseline scores.
        f = FusionInput()
        assert f.mri is None and f.eeg is None
        assert f.clinical == ClinicalScores()

    def test_fusion_output_round_trip(self) -> None:
        """A populated output survives dict and JSON round trips unchanged."""
        out = FusionOutput(
            diseases=[
                DiseaseScore(
                    disease="alzheimers",
                    probability=0.7,
                    contributions=[
                        ModalityContribution(
                            modality="mri",
                            weight=0.35,
                            signal=0.6,
                            delta_logit=0.21,
                        )
                    ],
                )
            ],
            top_disease="alzheimers",
            missing_inputs=["eeg"],
        )
        assert out.top_disease == "alzheimers"
        assert out.diseases[0].contributions[0].delta_logit == pytest.approx(0.21)

        # The actual round trip: nested models must rebuild to an equal
        # object from both the dumped dict and the dumped JSON text.
        assert FusionOutput.model_validate(out.model_dump()) == out
        assert FusionOutput.model_validate_json(out.model_dump_json()) == out
|