feat(mri): add run_pipeline orchestrator + CLI (NIfTI dir → ComBat Parquet)
Browse files
src/pipelines/mri_pipeline.py
CHANGED
|
@@ -12,6 +12,7 @@ traceability (in/out/dropped counts at INFO), and idempotent overwrite.
|
|
| 12 |
from __future__ import annotations
|
| 13 |
|
| 14 |
import os
|
|
|
|
| 15 |
|
| 16 |
import nibabel as nib
|
| 17 |
import numpy as np
|
|
@@ -258,3 +259,162 @@ def harmonize_combat(
|
|
| 258 |
len(out), len(feature_cols), sites.nunique(),
|
| 259 |
)
|
| 260 |
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
from __future__ import annotations
|
| 13 |
|
| 14 |
import os
|
| 15 |
+
from pathlib import Path
|
| 16 |
|
| 17 |
import nibabel as nib
|
| 18 |
import numpy as np
|
|
|
|
| 259 |
len(out), len(feature_cols), sites.nunique(),
|
| 260 |
)
|
| 261 |
return out
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# Default I/O paths for the MRI pipeline. Override via run_pipeline() args.
# Input: a directory of per-subject NIfTI volumes (plus `sites.csv`);
# output: a single harmonized-feature Parquet file.
DEFAULT_INPUT = Path("data/raw/mri")
DEFAULT_OUTPUT = Path("data/processed/mri_features.parquet")
+
def _list_nifti_volumes(input_dir: Path) -> list[Path]:
|
| 270 |
+
"""Return sorted list of .nii / .nii.gz files in `input_dir`."""
|
| 271 |
+
return sorted(
|
| 272 |
+
p for p in input_dir.iterdir()
|
| 273 |
+
if p.suffix == ".nii" or p.name.endswith(".nii.gz")
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def _write_features_parquet(df: pd.DataFrame, output_path: Path) -> None:
    """Write `df` to `output_path` as snappy-compressed Parquet.

    Creates parent directories as needed; idempotent overwrite.

    Raises:
        IsADirectoryError: if `output_path` resolves to an existing directory.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if output_path.is_dir():
        raise IsADirectoryError(
            f"output_path must be a file, got a directory: {output_path}"
        )
    # Parquet preserves dtypes (float64 features stay float64) and is
    # byte-deterministic with single-threaded snappy. AGENTS.md §6.
    df.to_parquet(
        output_path, index=False, engine="pyarrow", compression="snappy",
    )


def run_pipeline(
    input_dir: Path = DEFAULT_INPUT,
    sites_csv: Path | None = None,
    output_path: Path = DEFAULT_OUTPUT,
    intensity_threshold: float | None = None,
    n_roi_axes: tuple[int, int, int] = DEFAULT_N_ROI_AXES,
) -> None:
    """Run the MRI pipeline end-to-end: NIfTI directory → harmonized Parquet.

    For each `subject_id.nii(.gz)` in `input_dir`, validates the volume,
    masks the brain, computes per-ROI statistics, then harmonizes across
    sites (column "site" of `sites_csv`, joined on "subject_id") via ComBat.
    Output is float64 Parquet at `output_path`.

    Args:
        input_dir: Directory containing one NIfTI per subject and a
            `sites.csv` (or `sites_csv` override) with columns
            `subject_id, site`.
        sites_csv: Path to the site-covariates CSV. If `None`, defaults to
            `input_dir / "sites.csv"`.
        output_path: Where to write the processed feature Parquet file.
        intensity_threshold: Brain-mask intensity floor. `None` → per-volume
            mean (see `mask_brain`).
        n_roi_axes: ROI grid (z, y, x).

    Raises:
        FileNotFoundError: if `input_dir` or the resolved `sites_csv`
            does not exist.
        IsADirectoryError: if `output_path` resolves to an existing directory.
        KeyError: if `sites_csv` is missing a site for some subject.
    """
    input_dir = Path(input_dir)
    output_path = Path(output_path)
    if not input_dir.exists():
        raise FileNotFoundError(f"MRI input directory not found: {input_dir}")
    sites_csv = Path(sites_csv) if sites_csv is not None else input_dir / "sites.csv"
    if not sites_csv.exists():
        raise FileNotFoundError(f"sites_csv not found: {sites_csv}")

    logger.info("Reading MRI volumes from %s", input_dir)
    nifti_paths = _list_nifti_volumes(input_dir)
    sites_df = pd.read_csv(sites_csv)

    rows: list[dict[str, float | str]] = []
    invalid_subject_ids: list[str] = []
    for path in nifti_paths:
        # Strip both suffixes so `x.nii.gz` and `x.nii` map to subject `x`.
        subject_id = path.name.removesuffix(".nii.gz").removesuffix(".nii")
        volume = nib.load(path).get_fdata()
        if not is_valid_volume(volume):
            invalid_subject_ids.append(subject_id)
            continue
        mask = mask_brain(volume, intensity_threshold=intensity_threshold)
        feats = extract_features_from_volume(volume, mask, n_roi_axes=n_roi_axes)
        feats["subject_id"] = subject_id
        rows.append(feats)

    n_total = len(nifti_paths)
    n_dropped = len(invalid_subject_ids)
    if n_dropped:
        # Cap the logged subject list at 10 to keep log lines bounded.
        display = invalid_subject_ids[:10]
        suffix = (
            f"... (+{n_dropped - 10} more)" if n_dropped > 10 else ""
        )
        logger.warning(
            "Dropping %d/%d volumes with invalid samples (subjects=%s%s)",
            n_dropped, n_total, display, suffix,
        )

    feature_cols = [
        f"feat_roi{i}_{stat}"
        for i in range(int(np.prod(n_roi_axes)))
        for stat in ROI_STATS
    ]

    if not rows:
        # Nothing survived validation: still write an empty, correctly-typed
        # frame so downstream readers see a stable schema.
        logger.info(
            "Feature extraction complete: in=%d, out=0, dropped=%d (%.2f%%)",
            n_total, n_dropped, 100.0 * n_dropped / max(n_total, 1),
        )
        empty = pd.DataFrame(
            columns=["subject_id", "site", *feature_cols]
        ).astype({c: np.float64 for c in feature_cols})
        _write_features_parquet(empty, output_path)
        return

    raw_features = pd.DataFrame(rows)
    raw_features = raw_features.merge(sites_df, on="subject_id", how="left")
    if raw_features["site"].isna().any():
        missing = raw_features.loc[raw_features["site"].isna(), "subject_id"].tolist()
        raise KeyError(
            f"sites_csv missing site assignment for subjects: {missing}"
        )

    # ComBat cannot handle zero-variance columns (var_pooled = 0 → NaN divide).
    # Split feature_cols into variable (harmonize) and constant (pass through).
    # NOTE: the pass-through set is the *complement* of the variable set —
    # not `std() == 0` — because with a single row pandas' std (ddof=1) is
    # NaN, which would leave a column in neither bucket and make the
    # `harmonized[feature_cols]` reindex below raise KeyError.
    var_feature_cols = [c for c in feature_cols if raw_features[c].std() > 0]
    var_set = set(var_feature_cols)
    zero_var_cols = [c for c in feature_cols if c not in var_set]

    if var_feature_cols:
        harmonized = harmonize_combat(
            raw_features, raw_features["site"], var_feature_cols,
        )
    else:
        # Every feature is constant: nothing to harmonize; keep row alignment.
        harmonized = pd.DataFrame(index=raw_features.index)
    # Re-attach zero-variance columns (unchanged) and restore original column order.
    for c in zero_var_cols:
        harmonized[c] = raw_features[c].to_numpy()
    harmonized = harmonized[feature_cols]

    final = pd.concat(
        [raw_features[["subject_id", "site"]].reset_index(drop=True),
         harmonized.reset_index(drop=True)],
        axis=1,
    )

    _write_features_parquet(final, output_path)
    logger.info(
        "Feature extraction complete: in=%d, out=%d, dropped=%d (%.2f%%)",
        n_total, len(final), n_dropped, 100.0 * n_dropped / max(n_total, 1),
    )
    logger.info(
        "Wrote processed features to %s (rows=%d, cols=%d)",
        output_path, len(final), final.shape[1],
    )
+
|
| 415 |
+
if __name__ == "__main__":
    # Day-3 CLI entrypoint — runs with default paths against `data/raw/mri/`.
    # Expects `data/raw/mri/sites.csv` with columns `subject_id, site`.
    # Argument parsing (argparse / click) will land in a later task.
    # Usage: python -m src.pipelines.mri_pipeline
    run_pipeline()
tests/pipelines/test_mri_pipeline.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
"""Unit + integration tests for the MRI ComBat pipeline."""
|
| 2 |
from __future__ import annotations
|
| 3 |
|
|
|
|
| 4 |
from pathlib import Path
|
| 5 |
|
| 6 |
import nibabel as nib
|
|
@@ -15,6 +16,7 @@ from src.pipelines.mri_pipeline import (
|
|
| 15 |
harmonize_combat,
|
| 16 |
is_valid_volume,
|
| 17 |
mask_brain,
|
|
|
|
| 18 |
)
|
| 19 |
|
| 20 |
|
|
@@ -284,3 +286,95 @@ class TestHarmonizeCombat:
|
|
| 284 |
bad_sites = sites.iloc[:5]
|
| 285 |
with pytest.raises(ValueError, match=r"features has 6 rows but sites has 5 elements"):
|
| 286 |
harmonize_combat(df, bad_sites, feature_cols)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""Unit + integration tests for the MRI ComBat pipeline."""
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
+
import shutil
|
| 5 |
from pathlib import Path
|
| 6 |
|
| 7 |
import nibabel as nib
|
|
|
|
| 16 |
harmonize_combat,
|
| 17 |
is_valid_volume,
|
| 18 |
mask_brain,
|
| 19 |
+
run_pipeline,
|
| 20 |
)
|
| 21 |
|
| 22 |
|
|
|
|
| 286 |
bad_sites = sites.iloc[:5]
|
| 287 |
with pytest.raises(ValueError, match=r"features has 6 rows but sites has 5 elements"):
|
| 288 |
harmonize_combat(df, bad_sites, feature_cols)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class TestRunPipeline:
    """End-to-end tests for run_pipeline against the committed MRI fixture."""

    def _stage_inputs(self, tmp_path: Path) -> tuple[Path, Path, Path]:
        """Copy the committed MRI fixture into a tmp_path layout."""
        input_dir = tmp_path / "data" / "raw" / "mri"
        processed_dir = tmp_path / "data" / "processed"
        input_dir.mkdir(parents=True)
        processed_dir.mkdir(parents=True)
        for fixture_file in FIXTURE_DIR.iterdir():
            shutil.copy(fixture_file, input_dir / fixture_file.name)
        return (
            input_dir,
            input_dir / "sites.csv",
            processed_dir / "mri_features.parquet",
        )

    def test_end_to_end_writes_processed_parquet(self, tmp_path: Path) -> None:
        input_dir, sites_csv, out_file = self._stage_inputs(tmp_path)
        run_pipeline(input_dir=input_dir, sites_csv=sites_csv, output_path=out_file)
        assert out_file.exists()
        result = pd.read_parquet(out_file)
        assert len(result) == 6
        assert "subject_id" in result.columns
        assert "site" in result.columns
        assert any(col.startswith("feat_roi") for col in result.columns)

    def test_run_pipeline_preserves_float64_for_features(self, tmp_path: Path) -> None:
        input_dir, sites_csv, out_file = self._stage_inputs(tmp_path)
        run_pipeline(input_dir=input_dir, sites_csv=sites_csv, output_path=out_file)
        frame = pd.read_parquet(out_file)
        feature_columns = [name for name in frame.columns if name.startswith("feat_")]
        for name in feature_columns:
            assert frame[name].dtype == np.float64, f"{name} widened to {frame[name].dtype}"

    def test_run_pipeline_is_idempotent(self, tmp_path: Path) -> None:
        input_dir, sites_csv, out_file = self._stage_inputs(tmp_path)
        run_pipeline(input_dir=input_dir, sites_csv=sites_csv, output_path=out_file)
        first_bytes = out_file.read_bytes()
        run_pipeline(input_dir=input_dir, sites_csv=sites_csv, output_path=out_file)
        second_bytes = out_file.read_bytes()
        assert first_bytes == second_bytes, "MRI pipeline output must be byte-deterministic"

    def test_run_pipeline_reduces_site_gap(self, tmp_path: Path) -> None:
        """End-to-end: ComBat must shrink the per-site mean gap in feat_roi0_mean."""
        input_dir, sites_csv, out_file = self._stage_inputs(tmp_path)
        run_pipeline(input_dir=input_dir, sites_csv=sites_csv, output_path=out_file)
        result = pd.read_parquet(out_file)
        per_site = result.groupby("site")["feat_roi0_mean"].mean()
        gap = abs(per_site["B"] - per_site["A"])
        assert gap < 1.0, f"site gap after ComBat: {gap}"

    def test_run_pipeline_raises_when_input_missing(self, tmp_path: Path) -> None:
        missing_dir = tmp_path / "nope"
        with pytest.raises(FileNotFoundError, match="MRI input directory not found"):
            run_pipeline(
                input_dir=missing_dir,
                sites_csv=tmp_path / "sites.csv",
                output_path=tmp_path / "out.parquet",
            )

    def test_run_pipeline_rejects_directory_as_output(self, tmp_path: Path) -> None:
        input_dir, sites_csv, _ = self._stage_inputs(tmp_path)
        dir_as_output = tmp_path / "out_dir"
        dir_as_output.mkdir()
        with pytest.raises(IsADirectoryError, match="must be a file"):
            run_pipeline(
                input_dir=input_dir, sites_csv=sites_csv, output_path=dir_as_output,
            )

    def test_run_pipeline_drops_invalid_volumes(self, tmp_path: Path) -> None:
        """A NaN-containing volume must be logged + dropped, not silently included."""
        input_dir, sites_csv, out_file = self._stage_inputs(tmp_path)
        # Corrupt subject_5 to contain NaN. Re-save in place.
        target = input_dir / "subject_5.nii.gz"
        corrupted = nib.load(target).get_fdata()
        corrupted[0, 0, 0] = np.nan
        nib.save(nib.Nifti1Image(corrupted, affine=np.eye(4)), target)

        run_pipeline(input_dir=input_dir, sites_csv=sites_csv, output_path=out_file)
        result = pd.read_parquet(out_file)
        # 5 surviving valid subjects (subject_5 dropped).
        assert len(result) == 5
        assert "subject_5" not in result["subject_id"].tolist()
|