mekosotto Claude Opus 4.7 (1M context) commited on
Commit
985240b
·
1 Parent(s): 1068ed1

feat(api): POST /pipeline/mri/diagnostics — pre/post ComBat KPIs + long-format rows

Browse files

- Adds MRIDiagnosticsRequest/HarmonizationRow/MRIDiagnosticsResponse
schemas. The response carries the long-format rows plus 3 KPIs:
site_gap_pre, site_gap_post, reduction_factor (= pre/max(post,eps)).
- Site-gap is computed on the first feature's per-site means
(max - min). reduction_factor uses 1e-9 epsilon to avoid div-by-zero
when ComBat collapses the gap to numerical zero.
- Empty-volume input returns an empty rows list with zero KPIs (no
exception). FileNotFoundError → 404, KeyError → 400.
- 2 new tests: 200 happy path on the synthetic fixture (must pin
reduction_factor >= 1.0) and 404 on missing input_dir.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Files changed (3) hide show
  1. src/api/routes.py +47 -0
  2. src/api/schemas.py +22 -0
  3. tests/api/test_routes.py +32 -0
src/api/routes.py CHANGED
@@ -23,6 +23,9 @@ from src.api.schemas import (
23
  CalibrationContext,
24
  EEGRequest,
25
  FeatureAttribution,
 
 
 
26
  MRIRequest,
27
  PipelineResponse,
28
  )
@@ -184,3 +187,47 @@ def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
184
  top_features=[FeatureAttribution(**a) for a in attributions],
185
  calibration=calibration,
186
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  CalibrationContext,
24
  EEGRequest,
25
  FeatureAttribution,
26
+ HarmonizationRow,
27
+ MRIDiagnosticsRequest,
28
+ MRIDiagnosticsResponse,
29
  MRIRequest,
30
  PipelineResponse,
31
  )
 
187
  top_features=[FeatureAttribution(**a) for a in attributions],
188
  calibration=calibration,
189
  )
190
+
191
+
192
+ @router.post("/mri/diagnostics", response_model=MRIDiagnosticsResponse)
193
+ def mri_diagnostics(req: MRIDiagnosticsRequest) -> MRIDiagnosticsResponse:
194
+ """Run the MRI pipeline twice and return pre/post ComBat data + site-gap KPIs."""
195
+ input_dir = Path(req.input_dir)
196
+ sites_csv = Path(req.sites_csv)
197
+ try:
198
+ df = mri_pipeline.compute_harmonization_diagnostics(
199
+ input_dir=input_dir, sites_csv=sites_csv,
200
+ )
201
+ except FileNotFoundError as e:
202
+ raise HTTPException(status_code=404, detail=str(e))
203
+ except KeyError as e:
204
+ raise HTTPException(status_code=400, detail=str(e))
205
+
206
+ if df.empty:
207
+ return MRIDiagnosticsResponse(
208
+ rows=[], site_gap_pre=0.0, site_gap_post=0.0, reduction_factor=0.0,
209
+ )
210
+
211
+ # Site-gap KPI on the first feature, averaged per site
212
+ feat = df["feature"].iloc[0]
213
+ feat_df = df[df["feature"] == feat]
214
+ pre_means = feat_df[feat_df["harmonization_state"] == "Pre-ComBat"].groupby(
215
+ "site"
216
+ )["feature_value"].mean()
217
+ post_means = feat_df[feat_df["harmonization_state"] == "Post-ComBat"].groupby(
218
+ "site"
219
+ )["feature_value"].mean()
220
+ site_gap_pre = float(pre_means.max() - pre_means.min())
221
+ site_gap_post = float(post_means.max() - post_means.min())
222
+ eps = 1e-9
223
+ reduction_factor = site_gap_pre / max(site_gap_post, eps)
224
+
225
+ rows = [
226
+ HarmonizationRow(**rec) for rec in df.to_dict(orient="records")
227
+ ]
228
+ return MRIDiagnosticsResponse(
229
+ rows=rows,
230
+ site_gap_pre=site_gap_pre,
231
+ site_gap_post=site_gap_post,
232
+ reduction_factor=reduction_factor,
233
+ )
src/api/schemas.py CHANGED
@@ -80,3 +80,25 @@ class BBBPredictResponse(BaseModel):
80
  None,
81
  description="Statistical context: how often the model is right when this confident on held-out data.",
82
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  None,
81
  description="Statistical context: how often the model is right when this confident on held-out data.",
82
  )
83
+
84
+
85
+ class MRIDiagnosticsRequest(BaseModel):
86
+ """Request body for /pipeline/mri/diagnostics — same as MRIRequest minus output_path."""
87
+ input_dir: str = Field(..., description="Directory of .nii.gz files")
88
+ sites_csv: str = Field(..., description="CSV mapping subject_id → site")
89
+
90
+
91
+ class HarmonizationRow(BaseModel):
92
+ subject_id: str
93
+ site: str
94
+ feature: str
95
+ feature_value: float
96
+ harmonization_state: str
97
+
98
+
99
+ class MRIDiagnosticsResponse(BaseModel):
100
+ """Long-format pre/post ComBat data for visualization."""
101
+ rows: list[HarmonizationRow]
102
+ site_gap_pre: float = Field(..., description="Range of per-site means before ComBat")
103
+ site_gap_post: float = Field(..., description="Range of per-site means after ComBat")
104
+ reduction_factor: float = Field(..., description="site_gap_pre / max(site_gap_post, eps)")
tests/api/test_routes.py CHANGED
@@ -137,3 +137,35 @@ class TestBBBPredictRoute:
137
  json={"smiles": "CCO", "top_k": 5},
138
  )
139
  assert resp.status_code == 503
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  json={"smiles": "CCO", "top_k": 5},
138
  )
139
  assert resp.status_code == 503
140
+
141
+
142
+ class TestMRIDiagnosticsRoute:
143
+ def test_returns_200_with_pre_and_post_data(self, tmp_path: Path):
144
+ from tests.fixtures.build_mri_fixture import build as build_mri
145
+ fixture_dir = build_mri(out_dir=tmp_path / "mri")
146
+ resp = client.post(
147
+ "/pipeline/mri/diagnostics",
148
+ json={
149
+ "input_dir": str(fixture_dir),
150
+ "sites_csv": str(fixture_dir / "sites.csv"),
151
+ },
152
+ )
153
+ assert resp.status_code == 200
154
+ body = resp.json()
155
+ assert len(body["rows"]) > 0
156
+ assert body["site_gap_pre"] >= 0.0
157
+ assert body["site_gap_post"] >= 0.0
158
+ # Reduction factor is the headline KPI
159
+ assert body["reduction_factor"] >= 1.0 # ComBat must reduce, not amplify
160
+ states = {r["harmonization_state"] for r in body["rows"]}
161
+ assert states == {"Pre-ComBat", "Post-ComBat"}
162
+
163
+ def test_returns_404_when_input_dir_missing(self, tmp_path: Path):
164
+ resp = client.post(
165
+ "/pipeline/mri/diagnostics",
166
+ json={
167
+ "input_dir": str(tmp_path / "does_not_exist"),
168
+ "sites_csv": str(tmp_path / "sites.csv"),
169
+ },
170
+ )
171
+ assert resp.status_code == 404