"""MRI (magnetic resonance imaging) pipeline.
Loads NIfTI volumes (`.nii` / `.nii.gz`), applies a brain mask, harmonizes
across sites with ComBat (`neuroHarmonize`), and writes per-subject ROI
statistics as a model-ready Parquet at `data/processed/mri_features.parquet`.
Follows the Data Readiness contract in AGENTS.md §4 and the Parquet storage
convention in §6: schema validity, domain validity (drop NaN/inf volumes
with a logged WARNING), determinism (ComBat is RNG-free given fixed input),
traceability (in/out/dropped counts at INFO), and idempotent overwrite.
"""
from __future__ import annotations
import time
from collections.abc import Callable
from pathlib import Path
import nibabel as nib
import numpy as np
import pandas as pd
from scipy import ndimage as scipy_ndimage
from src.core.determinism import pin_threads
from src.core.logger import get_logger
from src.core.storage import write_parquet
from src.core.tracking import track_pipeline_run
logger = get_logger(__name__)
# Pin BLAS / OpenMP / pyarrow to single-threaded mode so byte-determinism
# (AGENTS.md §4 rule 3) holds across hardware. See src.core.determinism.
pin_threads()
def is_valid_volume(volume: object) -> bool:
"""Return True iff `volume` is a non-empty 3-D numeric array with no NaN/inf.
Used to drop corrupted volumes before masking + feature extraction.
Defensive against the full set of garbage we expect from real archives:
lists, None, NaN/inf samples, zero-sized arrays, string-dtype arrays.
"""
if not isinstance(volume, np.ndarray):
return False
if volume.ndim != 3:
return False
if volume.size == 0:
return False
if not np.issubdtype(volume.dtype, np.number):
return False
if not np.all(np.isfinite(volume)):
return False
return True
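# Illustrative checks (a hedged sketch with toy values, mirroring the guards
# above in order):
#
#     is_valid_volume(np.zeros((4, 4, 4)))              # True
#     is_valid_volume(None)                             # False (not an ndarray)
#     is_valid_volume(np.zeros((4, 4)))                 # False (2-D, not 3-D)
#     is_valid_volume(np.zeros((2, 2, 2), dtype="U1"))  # False (string dtype)
#     is_valid_volume(np.full((2, 2, 2), np.nan))       # False (non-finite)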
def mask_brain(
volume: np.ndarray,
intensity_threshold: float | None = None,
) -> np.ndarray:
"""Build a brain mask from a 3-D MRI volume.
Two-step pipeline:
1. Intensity threshold: keep voxels above `intensity_threshold`. When
`None`, use the volume's mean as a robust auto-threshold (works on
the synthetic fixture where brain ≫ background; for real data the
caller should pass an Otsu or BET-derived threshold explicitly).
    2. Morphological opening (`scipy.ndimage.binary_opening`, 6-connectivity,
       iterations=1) to remove isolated noise voxels and disconnected
       fragments. Note: thin features (< 3 voxels wide along any axis pair)
       may be eroded entirely; for production data with cortical sheets or
       sulcal bridges, prefer a 26-connectivity structuring element or skip
       the opening entirely (in SciPy, `iterations < 1` means "repeat until
       convergence", which erodes more aggressively, not less).
If the resulting mask is all-False (e.g. caller passed a threshold above
the volume's max intensity, or the volume is constant-valued), a WARNING
is emitted so silent feature-zeroing is visible in production logs.
Args:
volume: 3-D numeric `np.ndarray` (must satisfy `is_valid_volume`).
intensity_threshold: Voxel-intensity floor. `None` → use `volume.mean()`.
Returns:
A boolean `np.ndarray` of the same shape as `volume`. True = brain.
"""
if intensity_threshold is None:
intensity_threshold = float(volume.mean())
raw = volume > intensity_threshold
cleaned = scipy_ndimage.binary_opening(raw, iterations=1).astype(bool)
if not cleaned.any():
logger.warning(
"mask_brain produced an all-False mask "
"(volume min=%.4f, max=%.4f, threshold=%.4f); "
"downstream features for this volume will be all-zero.",
float(volume.min()), float(volume.max()), intensity_threshold,
)
return cleaned
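# Usage sketch (hedged; a synthetic toy volume, not a real scan): a bright
# 4×4×4 "brain" cube on a zero background, so the mean-based auto-threshold
# separates the two cleanly.
#
#     vol = np.zeros((8, 8, 8)); vol[2:6, 2:6, 2:6] = 100.0
#     mask = mask_brain(vol)            # auto-threshold = vol.mean() = 12.5
#     bool(mask[4, 4, 4]), bool(mask[0, 0, 0])   # (True, False)
#
# The opening erodes the cube's corners and edges before re-dilating its
# core, which is the thin-feature caveat documented in the docstring above.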
# Default ROI partition: split a (D, H, W) volume into 2×2×2 = 8 octant ROIs.
# Octant index follows binary (z, y, x) ordering: 0..7.
DEFAULT_N_ROI_AXES: tuple[int, int, int] = (2, 2, 2)
def _roi_slices(
shape: tuple[int, int, int],
n_roi_axes: tuple[int, int, int],
) -> list[tuple[slice, slice, slice]]:
"""Generate the ROI slice list in deterministic (z, y, x) octant order."""
nz, ny, nx = n_roi_axes
dz, dy, dx = shape
bins_z = np.array_split(np.arange(dz), nz)
bins_y = np.array_split(np.arange(dy), ny)
bins_x = np.array_split(np.arange(dx), nx)
out: list[tuple[slice, slice, slice]] = []
for bz in bins_z:
for by in bins_y:
for bx in bins_x:
out.append((
slice(bz[0], bz[-1] + 1),
slice(by[0], by[-1] + 1),
slice(bx[0], bx[-1] + 1),
))
return out
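# Ordering sketch (hedged; assumes each n_roi_axes entry <= the matching
# dim, otherwise np.array_split yields empty bins and `bz[0]` would raise):
#
#     _roi_slices((4, 4, 4), (2, 2, 2))
#     # -> [(slice(0, 2), slice(0, 2), slice(0, 2)),  # ROI 0: z-lo, y-lo, x-lo
#     #     (slice(0, 2), slice(0, 2), slice(2, 4)),  # ROI 1: z-lo, y-lo, x-hi
#     #     ...                                       # x varies fastest
#     #     (slice(2, 4), slice(2, 4), slice(2, 4))]  # ROI 7: z-hi, y-hi, x-hi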
# Statistical functions, bound to their column-label names. The `ROI_STATS`
# tuple below is derived from this list so labels and computations cannot
# drift out of sync (a class of bug the prior parallel-list design was
# vulnerable to — same pattern as EEG's _STATS_FUNCS).
#
# `mean`/`std` use NumPy with `ddof=0` (biased / population estimators).
# `p10`/`p50`/`p90` use `np.percentile` default linear interpolation.
# `voxel_count` is stored as float for column-uniformity in the eventual
# Parquet, but always represents a whole number (assertable via
# `v == float(int(v))`).
_ROI_STATS_FUNCS: tuple[tuple[str, Callable[[np.ndarray], float]], ...] = (
("mean", lambda v: float(v.mean())),
("std", lambda v: float(v.std())),
("p10", lambda v: float(np.percentile(v, 10))),
("p50", lambda v: float(np.percentile(v, 50))),
("p90", lambda v: float(np.percentile(v, 90))),
("voxel_count", lambda v: float(v.size)),
)
ROI_STATS: tuple[str, ...] = tuple(name for name, _ in _ROI_STATS_FUNCS)
def _roi_stats_for(values: np.ndarray) -> dict[str, float]:
"""Compute the ROI stats. Empty array → all 0.0 (no-NaN contract)."""
if values.size == 0:
return {name: 0.0 for name, _ in _ROI_STATS_FUNCS}
return {name: fn(values) for name, fn in _ROI_STATS_FUNCS}
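# Worked example (hedged; toy values): _roi_stats_for(np.array([1.0, 3.0]))
# evaluates to mean=2.0, std=1.0 (ddof=0), p10=1.2, p50=2.0, p90=2.8
# (linear interpolation), voxel_count=2.0.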
def extract_features_from_volume(
volume: np.ndarray,
mask: np.ndarray,
n_roi_axes: tuple[int, int, int] = DEFAULT_N_ROI_AXES,
) -> dict[str, float]:
"""Compute per-ROI summary statistics from a masked volume.
The volume is partitioned into ``prod(n_roi_axes)`` axis-aligned octants
in deterministic (z, y, x) order. For each ROI, intensity values from
voxels where `mask` is True are summarized via mean / std / 10th, 50th,
90th percentile / voxel count. Empty ROIs (no mask voxels) report all
zeros so the resulting Parquet has no NaN values.
Statistical conventions:
- ``mean`` / ``std`` use ``ddof=0`` (biased / population estimators).
- ``p10`` / ``p50`` / ``p90`` use ``np.percentile`` with the default
linear interpolation.
- ``voxel_count`` is stored as float for column uniformity but always
represents a whole number.
Args:
volume: 3-D numeric `np.ndarray` (already validated).
mask: Boolean `np.ndarray` of the same shape (from `mask_brain`).
n_roi_axes: ROI grid along (z, y, x). Default `(2, 2, 2)` → 8 ROIs.
Returns:
Flat dict `{"feat_roi{i}_{stat}": float}` of length
``prod(n_roi_axes) * len(ROI_STATS)``.
Raises:
ValueError: if `volume.shape` and `mask.shape` differ.
"""
if volume.shape != mask.shape:
raise ValueError(
f"volume.shape {volume.shape} != mask.shape {mask.shape}"
)
feats: dict[str, float] = {}
slices = _roi_slices(volume.shape, n_roi_axes)
for i, sl in enumerate(slices):
roi_values = volume[sl][mask[sl]]
stats = _roi_stats_for(roi_values)
for stat_name, stat_val in stats.items():
feats[f"feat_roi{i}_{stat_name}"] = stat_val
return feats
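# Key-naming sketch (hedged; `vol`/`mask` as in the mask_brain example
# above): the default (2, 2, 2) grid yields 8 ROIs × 6 stats = 48 keys.
#
#     feats = extract_features_from_volume(vol, mask)
#     sorted(feats)[:2]   # ['feat_roi0_mean', 'feat_roi0_p10'], running
#                         # through 'feat_roi7_voxel_count'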
def harmonize_combat(
features: pd.DataFrame,
sites: pd.Series,
feature_cols: list[str],
) -> pd.DataFrame:
"""Apply ComBat harmonization across sites to remove site-level domain shift.
Wraps `neuroHarmonize.harmonizationLearn` which fits a parametric ComBat
model (no internal RNG → byte-deterministic given fixed input). Only
`feature_cols` are harmonized; other columns in `features` (e.g.
metadata) are not touched by this function — callers should join after.
Args:
features: DataFrame with at least the columns listed in `feature_cols`.
sites: Site label per row (length must match `len(features)`).
feature_cols: Names of the columns to harmonize.
Returns:
A new DataFrame of identical shape & column order to
`features[feature_cols]`, with ComBat-harmonized values.
Raises:
ValueError: if fewer than 2 distinct sites are present.
"""
from neuroHarmonize import harmonizationLearn
if not feature_cols:
raise ValueError("feature_cols must be a non-empty list")
if len(features) != len(sites):
raise ValueError(
f"features has {len(features)} rows but sites has {len(sites)} elements"
)
if sites.nunique() < 2:
raise ValueError(
f"ComBat requires at least 2 sites; got {sites.nunique()} "
f"({sites.unique().tolist()})"
)
matrix = features[feature_cols].to_numpy(dtype=np.float64)
covars = pd.DataFrame({"SITE": sites.to_numpy()})
_, harmonized = harmonizationLearn(matrix, covars)
# Defensive: with OMP/OPENBLAS/MKL_NUM_THREADS=1 (set at module import,
# per AGENTS.md §4), harmonizationLearn is already bit-identical across
    # calls. Rounding to 14 decimal places provides an additional
    # determinism boundary for
# environments where those env pins are overridden before module load
# (e.g. a sub-process that re-exports a thread count). It discards ~5
# trailing-mantissa bits, which is well below ComBat's biological
# effect-size precision floor.
out = pd.DataFrame(
np.round(np.asarray(harmonized, dtype=np.float64), 14),
columns=list(feature_cols),
index=features.index,
)
logger.info(
"ComBat harmonized %d rows × %d features across %d sites",
len(out), len(feature_cols), sites.nunique(),
)
return out
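# Usage sketch (hedged; a toy two-site frame with illustrative values only):
#
#     df = pd.DataFrame({
#         "feat_roi0_mean": [1.0, 1.2, 3.0, 3.1],
#         "feat_roi0_std":  [0.4, 0.5, 0.9, 1.0],
#     })
#     sites = pd.Series(["siteA", "siteA", "siteB", "siteB"])
#     out = harmonize_combat(df, sites, ["feat_roi0_mean", "feat_roi0_std"])
#     # `out` keeps df's shape, index, and column order, with site-level
#     # location/scale differences removed.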
# Default I/O paths for the MRI pipeline. Override via run_pipeline() args.
DEFAULT_INPUT = Path("data/raw/mri")
DEFAULT_OUTPUT = Path("data/processed/mri_features.parquet")
# Variance floor used to decide whether a feature column is "constant" for
# ComBat. Strict ``std() > 0`` would still send near-zero-variance columns
# (e.g. ULP-level differences) into ComBat, where var_pooled ≈ 0 produces
# NaN. 1e-8 is well above machine epsilon and far below any biologically
# meaningful signal variance.
_MIN_VAR_THRESHOLD: float = 1e-8
def _list_nifti_volumes(input_dir: Path) -> list[Path]:
"""Return sorted list of .nii / .nii.gz files in `input_dir`."""
return sorted(
p for p in input_dir.iterdir()
if p.suffix == ".nii" or p.name.endswith(".nii.gz")
)
def run_pipeline(
input_dir: Path = DEFAULT_INPUT,
sites_csv: Path | None = None,
output_path: Path = DEFAULT_OUTPUT,
intensity_threshold: float | None = None,
n_roi_axes: tuple[int, int, int] = DEFAULT_N_ROI_AXES,
) -> None:
"""Run the MRI pipeline end-to-end: NIfTI directory → harmonized Parquet.
For each `subject_id.nii(.gz)` in `input_dir`, validates the volume,
masks the brain, computes per-ROI statistics, then harmonizes across
sites (column "site" of `sites_csv`, joined on "subject_id") via ComBat.
Output is float64 Parquet at `output_path`.
Args:
input_dir: Directory containing one NIfTI per subject and a
`sites.csv` (or `sites_csv` override) with columns
`subject_id, site`.
sites_csv: Path to the site-covariates CSV. If `None`, defaults to
`input_dir / "sites.csv"`.
output_path: Where to write the processed feature Parquet file.
intensity_threshold: Brain-mask intensity floor. `None` → per-volume
mean (see `mask_brain`).
n_roi_axes: ROI grid (z, y, x).
Raises:
        FileNotFoundError: if `input_dir` or `sites_csv` does not exist.
IsADirectoryError: if `output_path` resolves to an existing directory.
KeyError: if `sites_csv` is missing a site for some subject.
"""
input_dir = Path(input_dir)
output_path = Path(output_path)
if not input_dir.exists():
raise FileNotFoundError(f"MRI input directory not found: {input_dir}")
sites_csv = Path(sites_csv) if sites_csv is not None else input_dir / "sites.csv"
if not sites_csv.exists():
raise FileNotFoundError(f"sites_csv not found: {sites_csv}")
started = time.perf_counter()
logger.info("Reading MRI volumes from %s", input_dir)
nifti_paths = _list_nifti_volumes(input_dir)
sites_df = pd.read_csv(sites_csv)
rows: list[dict[str, float | str]] = []
invalid_subject_ids: list[str] = []
for path in nifti_paths:
subject_id = path.name.removesuffix(".nii.gz").removesuffix(".nii")
volume = nib.load(path).get_fdata()
if not is_valid_volume(volume):
invalid_subject_ids.append(subject_id)
continue
mask = mask_brain(volume, intensity_threshold=intensity_threshold)
feats = extract_features_from_volume(volume, mask, n_roi_axes=n_roi_axes)
rows.append({"subject_id": subject_id, **feats})
n_total = len(nifti_paths)
n_dropped = len(invalid_subject_ids)
if n_dropped:
display = invalid_subject_ids[:10]
suffix = (
f"... (+{n_dropped - 10} more)" if n_dropped > 10 else ""
)
logger.warning(
"Dropping %d/%d volumes with invalid samples (subjects=%s%s)",
n_dropped, n_total, display, suffix,
)
feature_cols = [
f"feat_roi{i}_{stat}"
for i in range(int(np.prod(n_roi_axes)))
for stat in ROI_STATS
]
if not rows:
logger.info(
"Feature extraction complete: in=%d, out=0, dropped=%d (%.2f%%)",
n_total, n_dropped, 100.0 * n_dropped / max(n_total, 1),
)
final = pd.DataFrame(
columns=["subject_id", "site", *feature_cols]
).astype({c: np.float64 for c in feature_cols})
else:
raw_features = pd.DataFrame(rows)
raw_features = raw_features.merge(sites_df, on="subject_id", how="left")
if raw_features["site"].isna().any():
missing = raw_features.loc[raw_features["site"].isna(), "subject_id"].tolist()
raise KeyError(
f"sites_csv missing site assignment for subjects: {missing}"
)
# ComBat cannot handle (near-)zero-variance columns: var_pooled ≈ 0 produces
# NaN. Split feature_cols on a strictly-positive variance floor so ULP-level
# noise is treated as constant.
col_std = raw_features[feature_cols].std()
var_feature_cols = [c for c in feature_cols if col_std[c] > _MIN_VAR_THRESHOLD]
zero_var_cols = [c for c in feature_cols if col_std[c] <= _MIN_VAR_THRESHOLD]
if not var_feature_cols:
# Degenerate dataset: every feature is essentially constant. ComBat has
# no signal to harmonize on; pass all columns through and warn.
logger.warning(
"All %d feature columns have variance ≤ %.1e; ComBat skipped "
"(output contains unharmonized features).",
len(feature_cols), _MIN_VAR_THRESHOLD,
)
harmonized = raw_features[feature_cols].copy()
else:
harmonized = harmonize_combat(
raw_features, raw_features["site"], var_feature_cols,
)
# Re-attach zero-variance columns (unchanged) and restore the original
# column order.
for c in zero_var_cols:
harmonized[c] = raw_features[c].to_numpy()
harmonized = harmonized[feature_cols]
final = pd.concat(
[raw_features[["subject_id", "site"]].reset_index(drop=True),
harmonized.reset_index(drop=True)],
axis=1,
)
logger.info(
"Feature extraction complete: in=%d, out=%d, dropped=%d (%.2f%%)",
n_total, len(final), n_dropped, 100.0 * n_dropped / max(n_total, 1),
)
# Parquet preserves dtypes (float64 features stay float64) and is
# byte-deterministic with single-threaded snappy. AGENTS.md §6. Unconditional
# so the §4-rule-4 traceability log fires for both empty and non-empty paths.
write_parquet(final, output_path)
logger.info(
"Wrote processed features to %s (rows=%d, cols=%d)",
output_path, len(final), final.shape[1],
)
duration_sec = time.perf_counter() - started
with track_pipeline_run(
experiment_name="mri_pipeline",
params={
"input_dir": str(input_dir),
"sites_csv": str(sites_csv),
"output_path": str(output_path),
"intensity_threshold": str(intensity_threshold),
"n_roi_axes": str(n_roi_axes),
},
metrics={
"subjects_in": float(n_total),
"subjects_out": float(len(final)),
"subjects_dropped": float(n_dropped),
"duration_sec": duration_sec,
},
artifact_path=output_path,
):
pass
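# Programmatic usage sketch (hedged; the paths and threshold are
# hypothetical, not project defaults):
#
#     run_pipeline(
#         input_dir=Path("data/raw/mri_siteB"),
#         output_path=Path("data/processed/mri_features_siteB.parquet"),
#         intensity_threshold=250.0,   # e.g. an Otsu- or BET-derived floor
#     )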
def compute_harmonization_diagnostics(
input_dir: Path,
sites_csv: Path | None = None,
intensity_threshold: float | None = None,
n_roi_axes: tuple[int, int, int] = DEFAULT_N_ROI_AXES,
) -> pd.DataFrame:
"""Run the MRI pipeline twice — pre-ComBat features and post-ComBat —
and return a long-format DataFrame ready for visualization.
Output columns: ``subject_id``, ``site``, ``feature``, ``feature_value``,
``harmonization_state`` ('Pre-ComBat' or 'Post-ComBat').
Used by the FastAPI ``/pipeline/mri/diagnostics`` endpoint to feed the
Streamlit MRI tab's KDE / histogram comparison plot.
Raises:
        FileNotFoundError: if ``input_dir`` or the sites CSV does not exist.
KeyError: if any subject is missing a site assignment.
"""
input_dir = Path(input_dir)
if not input_dir.exists():
raise FileNotFoundError(f"MRI input directory not found: {input_dir}")
sites_csv = Path(sites_csv) if sites_csv is not None else input_dir / "sites.csv"
sites_df = pd.read_csv(sites_csv)
feature_cols = [
f"feat_roi{i}_{stat}"
for i in range(int(np.prod(n_roi_axes)))
for stat in ROI_STATS
]
rows: list[dict[str, object]] = []
for nifti_path in sorted(input_dir.glob("*.nii*")):
subject_id = nifti_path.stem.replace(".nii", "")
volume = nib.load(nifti_path).get_fdata()
if not is_valid_volume(volume):
continue
mask = mask_brain(volume, intensity_threshold=intensity_threshold)
feats = extract_features_from_volume(
volume, mask, n_roi_axes=n_roi_axes,
)
row: dict[str, object] = {"subject_id": subject_id}
row.update(feats)
rows.append(row)
if not rows:
return pd.DataFrame(columns=[
"subject_id", "site", "feature", "feature_value", "harmonization_state",
])
raw_features = pd.DataFrame(rows).merge(sites_df, on="subject_id", how="left")
if raw_features["site"].isna().any():
missing = raw_features.loc[raw_features["site"].isna(), "subject_id"].tolist()
raise KeyError(
f"sites_csv missing site assignment for subjects: {missing}"
)
# Post-ComBat: variance-aware harmonization. Reuses the same logic as
# run_pipeline so diagnostics reflect production behavior exactly.
col_std = raw_features[feature_cols].std()
var_feature_cols = [
c for c in feature_cols if col_std[c] > _MIN_VAR_THRESHOLD
]
zero_var_cols = [
c for c in feature_cols if col_std[c] <= _MIN_VAR_THRESHOLD
]
if not var_feature_cols:
harmonized = raw_features[feature_cols].copy()
else:
harmonized = harmonize_combat(
raw_features, raw_features["site"], var_feature_cols,
)
for c in zero_var_cols:
harmonized[c] = raw_features[c].to_numpy()
harmonized = harmonized[feature_cols]
post_features = pd.concat(
[raw_features[["subject_id", "site"]].reset_index(drop=True),
harmonized.reset_index(drop=True)],
axis=1,
)
long_pre = raw_features.melt(
id_vars=["subject_id", "site"], value_vars=feature_cols,
var_name="feature", value_name="feature_value",
)
long_pre["harmonization_state"] = "Pre-ComBat"
long_post = post_features.melt(
id_vars=["subject_id", "site"], value_vars=feature_cols,
var_name="feature", value_name="feature_value",
)
long_post["harmonization_state"] = "Post-ComBat"
return pd.concat([long_pre, long_post], ignore_index=True)
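# Consumption sketch (hedged; seaborn is not a dependency of this module and
# is shown only to illustrate the long format the Streamlit tab expects):
#
#     diag = compute_harmonization_diagnostics(Path("data/raw/mri"))
#     one = diag[diag["feature"] == "feat_roi0_mean"]
#     # sns.displot(one, x="feature_value", hue="site",
#     #             col="harmonization_state", kind="kde")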
if __name__ == "__main__":
# Day-3 CLI entrypoint — runs with default paths against `data/raw/mri/`.
# Expects `data/raw/mri/sites.csv` with columns `subject_id, site`.
# Argument parsing (argparse / click) will land in a later task.
# python -m src.pipelines.mri_pipeline
run_pipeline()