mekosotto Claude Sonnet 4.6 committed on
Commit
7215c7f
·
1 Parent(s): f7e54c4

feat(mri): add run_pipeline orchestrator + CLI (NIfTI dir → ComBat Parquet)

Browse files
src/pipelines/mri_pipeline.py CHANGED
@@ -12,6 +12,7 @@ traceability (in/out/dropped counts at INFO), and idempotent overwrite.
12
  from __future__ import annotations
13
 
14
  import os
 
15
 
16
  import nibabel as nib
17
  import numpy as np
@@ -258,3 +259,162 @@ def harmonize_combat(
258
  len(out), len(feature_cols), sites.nunique(),
259
  )
260
  return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  from __future__ import annotations
13
 
14
  import os
15
+ from pathlib import Path
16
 
17
  import nibabel as nib
18
  import numpy as np
 
259
  len(out), len(feature_cols), sites.nunique(),
260
  )
261
  return out
262
+
263
+
264
# Default I/O paths for the MRI pipeline. Override via run_pipeline() args.
# NOTE(review): paths are relative — presumably resolved against the repo
# root / current working directory when run via `python -m`; confirm.
DEFAULT_INPUT = Path("data/raw/mri")
DEFAULT_OUTPUT = Path("data/processed/mri_features.parquet")
267
+
268
+
269
def _list_nifti_volumes(input_dir: Path) -> list[Path]:
    """Return sorted list of .nii / .nii.gz files in `input_dir`."""
    volumes: list[Path] = []
    for entry in input_dir.iterdir():
        # `.suffix` of "x.nii.gz" is ".gz", so compressed volumes need the
        # explicit name check; plain ".nii" is matched by suffix.
        if entry.name.endswith(".nii.gz") or entry.suffix == ".nii":
            volumes.append(entry)
    volumes.sort()
    return volumes
275
+
276
+
277
def _write_output(df: pd.DataFrame, output_path: Path) -> None:
    """Write `df` to `output_path` as snappy Parquet, creating parent dirs.

    Raises:
        IsADirectoryError: if `output_path` is an existing directory.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if output_path.is_dir():
        raise IsADirectoryError(
            f"output_path must be a file, got a directory: {output_path}"
        )
    # Parquet preserves dtypes (float64 features stay float64) and is
    # byte-deterministic with single-threaded snappy. AGENTS.md §6.
    df.to_parquet(
        output_path, index=False, engine="pyarrow", compression="snappy",
    )


def run_pipeline(
    input_dir: Path = DEFAULT_INPUT,
    sites_csv: Path | None = None,
    output_path: Path = DEFAULT_OUTPUT,
    intensity_threshold: float | None = None,
    n_roi_axes: tuple[int, int, int] = DEFAULT_N_ROI_AXES,
) -> None:
    """Run the MRI pipeline end-to-end: NIfTI directory → harmonized Parquet.

    For each `subject_id.nii(.gz)` in `input_dir`, validates the volume,
    masks the brain, computes per-ROI statistics, then harmonizes across
    sites (column "site" of `sites_csv`, joined on "subject_id") via ComBat.
    Output is float64 Parquet at `output_path`.

    Args:
        input_dir: Directory containing one NIfTI per subject and a
            `sites.csv` (or `sites_csv` override) with columns
            `subject_id, site`.
        sites_csv: Path to the site-covariates CSV. If `None`, defaults to
            `input_dir / "sites.csv"`.
        output_path: Where to write the processed feature Parquet file.
        intensity_threshold: Brain-mask intensity floor. `None` → per-volume
            mean (see `mask_brain`).
        n_roi_axes: ROI grid (z, y, x).

    Raises:
        FileNotFoundError: if `input_dir` or the sites CSV does not exist.
        IsADirectoryError: if `output_path` resolves to an existing directory.
        KeyError: if `sites_csv` is missing a site for some subject.
    """
    input_dir = Path(input_dir)
    output_path = Path(output_path)
    if not input_dir.exists():
        raise FileNotFoundError(f"MRI input directory not found: {input_dir}")
    sites_csv = Path(sites_csv) if sites_csv is not None else input_dir / "sites.csv"
    if not sites_csv.exists():
        raise FileNotFoundError(f"sites_csv not found: {sites_csv}")

    logger.info("Reading MRI volumes from %s", input_dir)
    nifti_paths = _list_nifti_volumes(input_dir)
    sites_df = pd.read_csv(sites_csv)

    # Extract features per subject; invalid volumes are collected and dropped.
    rows: list[dict[str, float | str]] = []
    invalid_subject_ids: list[str] = []
    for path in nifti_paths:
        # Order matters: strip ".nii.gz" before ".nii" so "x.nii.gz" → "x".
        subject_id = path.name.removesuffix(".nii.gz").removesuffix(".nii")
        volume = nib.load(path).get_fdata()
        if not is_valid_volume(volume):
            invalid_subject_ids.append(subject_id)
            continue
        mask = mask_brain(volume, intensity_threshold=intensity_threshold)
        feats = extract_features_from_volume(volume, mask, n_roi_axes=n_roi_axes)
        feats["subject_id"] = subject_id
        rows.append(feats)

    n_total = len(nifti_paths)
    n_dropped = len(invalid_subject_ids)
    if n_dropped:
        # Cap the logged subject list at 10 to keep the warning readable.
        display = invalid_subject_ids[:10]
        suffix = (
            f"... (+{n_dropped - 10} more)" if n_dropped > 10 else ""
        )
        logger.warning(
            "Dropping %d/%d volumes with invalid samples (subjects=%s%s)",
            n_dropped, n_total, display, suffix,
        )

    # Canonical feature-column order: roi-major, stat-minor.
    feature_cols = [
        f"feat_roi{i}_{stat}"
        for i in range(int(np.prod(n_roi_axes)))
        for stat in ROI_STATS
    ]

    if not rows:
        # No valid subjects: still emit a schema-correct empty Parquet so
        # downstream readers see the expected columns and dtypes.
        logger.info(
            "Feature extraction complete: in=%d, out=0, dropped=%d (%.2f%%)",
            n_total, n_dropped, 100.0 * n_dropped / max(n_total, 1),
        )
        empty = pd.DataFrame(
            columns=["subject_id", "site", *feature_cols]
        ).astype({c: np.float64 for c in feature_cols})
        _write_output(empty, output_path)
        return

    raw_features = pd.DataFrame(rows)
    raw_features = raw_features.merge(sites_df, on="subject_id", how="left")
    if raw_features["site"].isna().any():
        missing = raw_features.loc[raw_features["site"].isna(), "subject_id"].tolist()
        raise KeyError(
            f"sites_csv missing site assignment for subjects: {missing}"
        )

    # ComBat cannot handle zero-variance columns (var_pooled = 0 → NaN divide).
    # Split feature_cols into variable (harmonize) and constant (pass through).
    # FIX: with a single surviving subject, Series.std() (ddof=1) is NaN, so a
    # `std() == 0` test would leave such columns in *neither* partition and the
    # reindex below would raise KeyError. Using the set complement keeps the
    # partition total: anything not provably variable is passed through.
    var_feature_cols = [c for c in feature_cols if raw_features[c].std() > 0]
    var_set = set(var_feature_cols)
    zero_var_cols = [c for c in feature_cols if c not in var_set]

    harmonized = harmonize_combat(
        raw_features, raw_features["site"], var_feature_cols,
    )
    # Re-attach pass-through columns (unchanged) and restore original column order.
    for c in zero_var_cols:
        harmonized[c] = raw_features[c].to_numpy()
    harmonized = harmonized[feature_cols]

    final = pd.concat(
        [raw_features[["subject_id", "site"]].reset_index(drop=True),
         harmonized.reset_index(drop=True)],
        axis=1,
    )

    _write_output(final, output_path)
    logger.info(
        "Feature extraction complete: in=%d, out=%d, dropped=%d (%.2f%%)",
        n_total, len(final), n_dropped, 100.0 * n_dropped / max(n_total, 1),
    )
    logger.info(
        "Wrote processed features to %s (rows=%d, cols=%d)",
        output_path, len(final), final.shape[1],
    )
413
+
414
+
415
if __name__ == "__main__":
    # Day-3 CLI entrypoint — runs with default paths against `data/raw/mri/`.
    # Expects `data/raw/mri/sites.csv` with columns `subject_id, site`.
    # Argument parsing (argparse / click) will land in a later task.
    # Usage: python -m src.pipelines.mri_pipeline
    run_pipeline()
tests/pipelines/test_mri_pipeline.py CHANGED
@@ -1,6 +1,7 @@
1
  """Unit + integration tests for the MRI ComBat pipeline."""
2
  from __future__ import annotations
3
 
 
4
  from pathlib import Path
5
 
6
  import nibabel as nib
@@ -15,6 +16,7 @@ from src.pipelines.mri_pipeline import (
15
  harmonize_combat,
16
  is_valid_volume,
17
  mask_brain,
 
18
  )
19
 
20
 
@@ -284,3 +286,95 @@ class TestHarmonizeCombat:
284
  bad_sites = sites.iloc[:5]
285
  with pytest.raises(ValueError, match=r"features has 6 rows but sites has 5 elements"):
286
  harmonize_combat(df, bad_sites, feature_cols)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Unit + integration tests for the MRI ComBat pipeline."""
2
  from __future__ import annotations
3
 
4
+ import shutil
5
  from pathlib import Path
6
 
7
  import nibabel as nib
 
16
  harmonize_combat,
17
  is_valid_volume,
18
  mask_brain,
19
+ run_pipeline,
20
  )
21
 
22
 
 
286
  bad_sites = sites.iloc[:5]
287
  with pytest.raises(ValueError, match=r"features has 6 rows but sites has 5 elements"):
288
  harmonize_combat(df, bad_sites, feature_cols)
289
+
290
+
291
class TestRunPipeline:
    """End-to-end tests for run_pipeline against the committed NIfTI fixture.

    Each test stages the fixture directory (volumes + sites.csv) into an
    isolated tmp_path layout mirroring the default `data/raw` / `data/processed`
    structure, then runs the full pipeline and inspects the output Parquet.
    """

    def _stage_inputs(self, tmp_path: Path) -> tuple[Path, Path, Path]:
        """Copy the committed MRI fixture into a tmp_path layout."""
        raw_dir = tmp_path / "data" / "raw" / "mri"
        proc_dir = tmp_path / "data" / "processed"
        raw_dir.mkdir(parents=True)
        proc_dir.mkdir(parents=True)
        # Flat copy of every fixture file (NIfTI volumes + sites.csv).
        for src in FIXTURE_DIR.iterdir():
            shutil.copy(src, raw_dir / src.name)
        sites_csv = raw_dir / "sites.csv"
        output_path = proc_dir / "mri_features.parquet"
        return raw_dir, sites_csv, output_path

    def test_end_to_end_writes_processed_parquet(self, tmp_path: Path) -> None:
        """Happy path: output exists with id/site columns and ROI features."""
        raw_dir, sites_csv, output_path = self._stage_inputs(tmp_path)
        run_pipeline(
            input_dir=raw_dir, sites_csv=sites_csv, output_path=output_path,
        )
        assert output_path.exists()
        df = pd.read_parquet(output_path)
        # Fixture ships 6 valid subjects.
        assert len(df) == 6
        assert "subject_id" in df.columns
        assert "site" in df.columns
        assert any(c.startswith("feat_roi") for c in df.columns)

    def test_run_pipeline_preserves_float64_for_features(self, tmp_path: Path) -> None:
        """Feature dtypes must survive the Parquet round-trip as float64."""
        raw_dir, sites_csv, output_path = self._stage_inputs(tmp_path)
        run_pipeline(
            input_dir=raw_dir, sites_csv=sites_csv, output_path=output_path,
        )
        df = pd.read_parquet(output_path)
        feat_cols = [c for c in df.columns if c.startswith("feat_")]
        for c in feat_cols:
            assert df[c].dtype == np.float64, f"{c} widened to {df[c].dtype}"

    def test_run_pipeline_is_idempotent(self, tmp_path: Path) -> None:
        """Two identical runs must overwrite with byte-identical output."""
        raw_dir, sites_csv, output_path = self._stage_inputs(tmp_path)
        run_pipeline(
            input_dir=raw_dir, sites_csv=sites_csv, output_path=output_path,
        )
        first = output_path.read_bytes()
        run_pipeline(
            input_dir=raw_dir, sites_csv=sites_csv, output_path=output_path,
        )
        second = output_path.read_bytes()
        assert first == second, "MRI pipeline output must be byte-deterministic"

    def test_run_pipeline_reduces_site_gap(self, tmp_path: Path) -> None:
        """End-to-end: ComBat must shrink the per-site mean gap in feat_roi0_mean."""
        raw_dir, sites_csv, output_path = self._stage_inputs(tmp_path)
        run_pipeline(
            input_dir=raw_dir, sites_csv=sites_csv, output_path=output_path,
        )
        df = pd.read_parquet(output_path)
        # Fixture has exactly two sites, "A" and "B".
        site_means = df.groupby("site")["feat_roi0_mean"].mean()
        gap = abs(site_means["B"] - site_means["A"])
        # NOTE(review): threshold 1.0 presumably reflects the fixture's
        # pre-harmonization gap being well above 1 — confirm against fixture.
        assert gap < 1.0, f"site gap after ComBat: {gap}"

    def test_run_pipeline_raises_when_input_missing(self, tmp_path: Path) -> None:
        """A nonexistent input directory must fail fast, not write output."""
        with pytest.raises(FileNotFoundError, match="MRI input directory not found"):
            run_pipeline(
                input_dir=tmp_path / "nope",
                sites_csv=tmp_path / "sites.csv",
                output_path=tmp_path / "out.parquet",
            )

    def test_run_pipeline_rejects_directory_as_output(self, tmp_path: Path) -> None:
        """output_path pointing at an existing directory must be rejected."""
        raw_dir, sites_csv, _ = self._stage_inputs(tmp_path)
        bad_output = tmp_path / "out_dir"
        bad_output.mkdir()
        with pytest.raises(IsADirectoryError, match="must be a file"):
            run_pipeline(
                input_dir=raw_dir, sites_csv=sites_csv, output_path=bad_output,
            )

    def test_run_pipeline_drops_invalid_volumes(self, tmp_path: Path) -> None:
        """A NaN-containing volume must be logged + dropped, not silently included."""
        raw_dir, sites_csv, output_path = self._stage_inputs(tmp_path)
        # Corrupt subject_5 to contain NaN. Re-save in place.
        # NOTE(review): re-saving with np.eye(4) replaces the fixture's affine;
        # assumed irrelevant to validity checking — confirm is_valid_volume
        # only inspects voxel data.
        bad = nib.load(raw_dir / "subject_5.nii.gz").get_fdata()
        bad[0, 0, 0] = np.nan
        nib.save(nib.Nifti1Image(bad, affine=np.eye(4)), raw_dir / "subject_5.nii.gz")

        run_pipeline(
            input_dir=raw_dir, sites_csv=sites_csv, output_path=output_path,
        )
        df = pd.read_parquet(output_path)
        # 5 surviving valid subjects (subject_5 dropped).
        assert len(df) == 5
        assert "subject_5" not in df["subject_id"].tolist()