feat(api): POST /pipeline/mri/diagnostics — pre/post ComBat KPIs + long-format rows
Browse files
- Adds MRIDiagnosticsRequest/HarmonizationRow/MRIDiagnosticsResponse
schemas. The response carries the long-format rows plus 3 KPIs:
site_gap_pre, site_gap_post, reduction_factor (= pre/max(post,eps)).
- Site-gap is computed on the first feature's per-site means
(max - min). reduction_factor uses 1e-9 epsilon to avoid div-by-zero
when ComBat collapses the gap to numerical zero.
- Empty-volume input returns an empty rows list with zero KPIs (no
exception). FileNotFoundError → 404, KeyError → 400.
- 2 new tests: 200 happy path on the synthetic fixture (must pin
reduction_factor >= 1.0) and 404 on missing input_dir.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- src/api/routes.py +47 -0
- src/api/schemas.py +22 -0
- tests/api/test_routes.py +32 -0
src/api/routes.py
CHANGED
|
@@ -23,6 +23,9 @@ from src.api.schemas import (
|
|
| 23 |
CalibrationContext,
|
| 24 |
EEGRequest,
|
| 25 |
FeatureAttribution,
|
|
|
|
|
|
|
|
|
|
| 26 |
MRIRequest,
|
| 27 |
PipelineResponse,
|
| 28 |
)
|
|
@@ -184,3 +187,47 @@ def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
|
|
| 184 |
top_features=[FeatureAttribution(**a) for a in attributions],
|
| 185 |
calibration=calibration,
|
| 186 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
CalibrationContext,
|
| 24 |
EEGRequest,
|
| 25 |
FeatureAttribution,
|
| 26 |
+
HarmonizationRow,
|
| 27 |
+
MRIDiagnosticsRequest,
|
| 28 |
+
MRIDiagnosticsResponse,
|
| 29 |
MRIRequest,
|
| 30 |
PipelineResponse,
|
| 31 |
)
|
|
|
|
| 187 |
top_features=[FeatureAttribution(**a) for a in attributions],
|
| 188 |
calibration=calibration,
|
| 189 |
)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _site_gap(feature_rows, state: str) -> float:
    """Return max - min of the per-site mean ``feature_value`` for one state.

    Args:
        feature_rows: Long-format frame already filtered to a single feature;
            must carry ``harmonization_state``, ``site`` and ``feature_value``
            columns.
        state: Value of ``harmonization_state`` to select.
    """
    per_site_means = (
        feature_rows[feature_rows["harmonization_state"] == state]
        .groupby("site")["feature_value"]
        .mean()
    )
    # NOTE(review): if `state` has no rows for this feature the series is
    # empty and max()/min() yield NaN — confirm the pipeline always emits
    # both states for every feature.
    return float(per_site_means.max() - per_site_means.min())


@router.post("/mri/diagnostics", response_model=MRIDiagnosticsResponse)
def mri_diagnostics(req: MRIDiagnosticsRequest) -> MRIDiagnosticsResponse:
    """Return pre/post ComBat long-format rows plus site-gap KPIs.

    Calls ``mri_pipeline.compute_harmonization_diagnostics`` on the requested
    input directory and sites CSV, then derives three KPIs from the first
    feature found in the result:

    * ``site_gap_pre`` / ``site_gap_post`` — range (max - min) of the per-site
      means for the "Pre-ComBat" / "Post-ComBat" rows.
    * ``reduction_factor`` — ``site_gap_pre / max(site_gap_post, 1e-9)``.

    An empty diagnostics frame yields an empty ``rows`` list with all KPIs
    set to 0.0.

    Raises:
        HTTPException: 404 when the pipeline raises ``FileNotFoundError``,
            400 when it raises ``KeyError``. The original exception is kept
            as the explicit cause.
    """
    input_dir = Path(req.input_dir)
    sites_csv = Path(req.sites_csv)
    try:
        df = mri_pipeline.compute_harmonization_diagnostics(
            input_dir=input_dir, sites_csv=sites_csv,
        )
    except FileNotFoundError as e:
        # Chain the cause so logs show the pipeline error, not an
        # "exception during handling of another exception" traceback.
        raise HTTPException(status_code=404, detail=str(e)) from e
    except KeyError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    if df.empty:
        return MRIDiagnosticsResponse(
            rows=[], site_gap_pre=0.0, site_gap_post=0.0, reduction_factor=0.0,
        )

    # The KPIs are computed on the first feature only (first row's feature).
    feat = df["feature"].iloc[0]
    feat_df = df[df["feature"] == feat]
    site_gap_pre = _site_gap(feat_df, "Pre-ComBat")
    site_gap_post = _site_gap(feat_df, "Post-ComBat")

    # Floor the denominator so a post-ComBat gap of zero cannot divide by zero.
    eps = 1e-9
    reduction_factor = site_gap_pre / max(site_gap_post, eps)

    rows = [
        HarmonizationRow(**rec) for rec in df.to_dict(orient="records")
    ]
    return MRIDiagnosticsResponse(
        rows=rows,
        site_gap_pre=site_gap_pre,
        site_gap_post=site_gap_post,
        reduction_factor=reduction_factor,
    )
|
src/api/schemas.py
CHANGED
|
@@ -80,3 +80,25 @@ class BBBPredictResponse(BaseModel):
|
|
| 80 |
None,
|
| 81 |
description="Statistical context: how often the model is right when this confident on held-out data.",
|
| 82 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
None,
|
| 81 |
description="Statistical context: how often the model is right when this confident on held-out data.",
|
| 82 |
)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class MRIDiagnosticsRequest(BaseModel):
    """Body accepted by POST /pipeline/mri/diagnostics.

    Both fields are required and are plain strings; the route turns them
    into paths itself.
    NOTE(review): described by its author as MRIRequest without
    ``output_path`` — MRIRequest is not visible here, confirm.
    """

    input_dir: str = Field(
        ...,
        description="Directory of .nii.gz files",
    )
    sites_csv: str = Field(
        ...,
        description="CSV mapping subject_id → site",
    )
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class HarmonizationRow(BaseModel):
    """One long-format record of the harmonization diagnostics table.

    Each record is built from one row of the diagnostics frame, so the field
    names must match that frame's column names.
    """

    # Identifier of the subject the value belongs to.
    subject_id: str
    # Acquisition site label; the site-gap KPIs group by this column.
    site: str
    # Name of the measured feature.
    feature: str
    # Numeric value of `feature` for this subject.
    feature_value: float
    # State label; the route and tests use "Pre-ComBat" and "Post-ComBat".
    harmonization_state: str
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class MRIDiagnosticsResponse(BaseModel):
    """Pre/post ComBat rows in long format, plus three site-gap KPIs."""

    # One entry per row of the diagnostics frame; may be empty.
    rows: list[HarmonizationRow]
    site_gap_pre: float = Field(
        ...,
        description="Range of per-site means before ComBat",
    )
    site_gap_post: float = Field(
        ...,
        description="Range of per-site means after ComBat",
    )
    reduction_factor: float = Field(
        ...,
        description="site_gap_pre / max(site_gap_post, eps)",
    )
|
tests/api/test_routes.py
CHANGED
|
@@ -137,3 +137,35 @@ class TestBBBPredictRoute:
|
|
| 137 |
json={"smiles": "CCO", "top_k": 5},
|
| 138 |
)
|
| 139 |
assert resp.status_code == 503
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
json={"smiles": "CCO", "top_k": 5},
|
| 138 |
)
|
| 139 |
assert resp.status_code == 503
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class TestMRIDiagnosticsRoute:
    """Route tests for POST /pipeline/mri/diagnostics."""

    def test_returns_200_with_pre_and_post_data(self, tmp_path: Path):
        """Synthetic fixture yields rows for both states and sane KPIs."""
        from tests.fixtures.build_mri_fixture import build as build_mri

        mri_dir = build_mri(out_dir=tmp_path / "mri")
        payload = {
            "input_dir": str(mri_dir),
            "sites_csv": str(mri_dir / "sites.csv"),
        }
        response = client.post("/pipeline/mri/diagnostics", json=payload)

        assert response.status_code == 200
        data = response.json()
        assert len(data["rows"]) > 0
        assert data["site_gap_pre"] >= 0.0
        assert data["site_gap_post"] >= 0.0
        # Headline KPI: harmonization must shrink the site gap, not widen it.
        assert data["reduction_factor"] >= 1.0
        seen_states = {row["harmonization_state"] for row in data["rows"]}
        assert seen_states == {"Pre-ComBat", "Post-ComBat"}

    def test_returns_404_when_input_dir_missing(self, tmp_path: Path):
        """A non-existent input directory is reported as 404."""
        payload = {
            "input_dir": str(tmp_path / "does_not_exist"),
            "sites_csv": str(tmp_path / "sites.csv"),
        }
        response = client.post("/pipeline/mri/diagnostics", json=payload)
        assert response.status_code == 404
|