feat(mri): add harmonize_combat wrapper around neuroHarmonize.harmonizationLearn
Browse filesImplements parametric ComBat harmonization to remove site-level domain
shift. Guards against single-site input (ValueError), ensures float64
output, and pins determinism via np.round(14) to eliminate sub-ULP
floating-point noise from neuroHarmonize's internal matrix ops.
Also installs missing transitive deps (statsmodels, neuroCombat).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
src/pipelines/mri_pipeline.py
CHANGED
|
@@ -15,6 +15,7 @@ import os
|
|
| 15 |
|
| 16 |
import nibabel as nib
|
| 17 |
import numpy as np
|
|
|
|
| 18 |
import pyarrow as pa
|
| 19 |
from scipy import ndimage as scipy_ndimage
|
| 20 |
|
|
@@ -195,3 +196,54 @@ def extract_features_from_volume(
|
|
| 195 |
for stat_name, stat_val in stats.items():
|
| 196 |
feats[f"feat_roi{i}_{stat_name}"] = stat_val
|
| 197 |
return feats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
import nibabel as nib
|
| 17 |
import numpy as np
|
| 18 |
+
import pandas as pd
|
| 19 |
import pyarrow as pa
|
| 20 |
from scipy import ndimage as scipy_ndimage
|
| 21 |
|
|
|
|
| 196 |
for stat_name, stat_val in stats.items():
|
| 197 |
feats[f"feat_roi{i}_{stat_name}"] = stat_val
|
| 198 |
return feats
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def harmonize_combat(
|
| 202 |
+
features: pd.DataFrame,
|
| 203 |
+
sites: pd.Series,
|
| 204 |
+
feature_cols: list[str],
|
| 205 |
+
) -> pd.DataFrame:
|
| 206 |
+
"""Apply ComBat harmonization across sites to remove site-level domain shift.
|
| 207 |
+
|
| 208 |
+
Wraps `neuroHarmonize.harmonizationLearn` which fits a parametric ComBat
|
| 209 |
+
model (no internal RNG → byte-deterministic given fixed input). Only
|
| 210 |
+
`feature_cols` are harmonized; other columns in `features` (e.g.
|
| 211 |
+
metadata) are not touched by this function — callers should join after.
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
features: DataFrame with at least the columns listed in `feature_cols`.
|
| 215 |
+
sites: Site label per row (length must match `len(features)`).
|
| 216 |
+
feature_cols: Names of the columns to harmonize.
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
A new DataFrame of identical shape & column order to
|
| 220 |
+
`features[feature_cols]`, with ComBat-harmonized values.
|
| 221 |
+
|
| 222 |
+
Raises:
|
| 223 |
+
ValueError: if fewer than 2 distinct sites are present.
|
| 224 |
+
"""
|
| 225 |
+
from neuroHarmonize import harmonizationLearn
|
| 226 |
+
|
| 227 |
+
if sites.nunique() < 2:
|
| 228 |
+
raise ValueError(
|
| 229 |
+
f"ComBat requires at least 2 sites; got {sites.nunique()} "
|
| 230 |
+
f"({sites.unique().tolist()})"
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
matrix = features[feature_cols].to_numpy(dtype=np.float64)
|
| 234 |
+
covars = pd.DataFrame({"SITE": sites.to_numpy()})
|
| 235 |
+
|
| 236 |
+
_, harmonized = harmonizationLearn(matrix, covars)
|
| 237 |
+
# Round to 14 decimal places to eliminate sub-ULP floating-point noise
|
| 238 |
+
# (neuroHarmonize's internal matrix ops can produce ±1-ULP variation
|
| 239 |
+
# across calls; 14 d.p. retains all meaningful precision at float64).
|
| 240 |
+
out = pd.DataFrame(
|
| 241 |
+
np.round(np.asarray(harmonized, dtype=np.float64), 14),
|
| 242 |
+
columns=list(feature_cols),
|
| 243 |
+
index=features.index,
|
| 244 |
+
)
|
| 245 |
+
logger.info(
|
| 246 |
+
"ComBat harmonized %d rows × %d features across %d sites",
|
| 247 |
+
len(out), len(feature_cols), sites.nunique(),
|
| 248 |
+
)
|
| 249 |
+
return out
|
tests/pipelines/test_mri_pipeline.py
CHANGED
|
@@ -5,12 +5,14 @@ from pathlib import Path
|
|
| 5 |
|
| 6 |
import nibabel as nib
|
| 7 |
import numpy as np
|
|
|
|
| 8 |
import pytest
|
| 9 |
|
| 10 |
from src.pipelines.mri_pipeline import (
|
| 11 |
DEFAULT_N_ROI_AXES,
|
| 12 |
ROI_STATS,
|
| 13 |
extract_features_from_volume,
|
|
|
|
| 14 |
is_valid_volume,
|
| 15 |
mask_brain,
|
| 16 |
)
|
|
@@ -203,3 +205,70 @@ class TestExtractFeaturesFromVolume:
|
|
| 203 |
bad_mask = np.zeros((4, 4, 4), dtype=bool)
|
| 204 |
with pytest.raises(ValueError, match=r"volume\.shape .* != mask\.shape"):
|
| 205 |
extract_features_from_volume(vol, bad_mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
import nibabel as nib
|
| 7 |
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
import pytest
|
| 10 |
|
| 11 |
from src.pipelines.mri_pipeline import (
|
| 12 |
DEFAULT_N_ROI_AXES,
|
| 13 |
ROI_STATS,
|
| 14 |
extract_features_from_volume,
|
| 15 |
+
harmonize_combat,
|
| 16 |
is_valid_volume,
|
| 17 |
mask_brain,
|
| 18 |
)
|
|
|
|
| 205 |
bad_mask = np.zeros((4, 4, 4), dtype=bool)
|
| 206 |
with pytest.raises(ValueError, match=r"volume\.shape .* != mask\.shape"):
|
| 207 |
extract_features_from_volume(vol, bad_mask)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class TestHarmonizeCombat:
|
| 211 |
+
def _build_two_site_features(self) -> tuple[pd.DataFrame, pd.Series, list[str]]:
|
| 212 |
+
"""Synthesize a 6-row × 4-feature table with a clear site bias."""
|
| 213 |
+
rng = np.random.default_rng(seed=42)
|
| 214 |
+
feature_cols = ["feat_roi0_mean", "feat_roi1_mean", "feat_roi2_mean", "feat_roi3_mean"]
|
| 215 |
+
# Site A baseline: mean ~0; Site B baseline: mean ~5 (the bias to remove).
|
| 216 |
+
site_a = rng.normal(loc=0.0, scale=1.0, size=(3, 4))
|
| 217 |
+
site_b = rng.normal(loc=5.0, scale=1.0, size=(3, 4))
|
| 218 |
+
df = pd.DataFrame(
|
| 219 |
+
np.vstack([site_a, site_b]),
|
| 220 |
+
columns=feature_cols,
|
| 221 |
+
)
|
| 222 |
+
sites = pd.Series(["A", "A", "A", "B", "B", "B"], name="site")
|
| 223 |
+
return df, sites, feature_cols
|
| 224 |
+
|
| 225 |
+
def test_returns_dataframe_same_shape_and_columns(self) -> None:
|
| 226 |
+
df, sites, feature_cols = self._build_two_site_features()
|
| 227 |
+
out = harmonize_combat(df, sites, feature_cols)
|
| 228 |
+
assert isinstance(out, pd.DataFrame)
|
| 229 |
+
assert out.shape == df.shape
|
| 230 |
+
assert list(out.columns) == feature_cols
|
| 231 |
+
|
| 232 |
+
def test_reduces_site_mean_difference(self) -> None:
|
| 233 |
+
"""ComBat must shrink the per-site mean gap on every harmonized column."""
|
| 234 |
+
df, sites, feature_cols = self._build_two_site_features()
|
| 235 |
+
gap_before = (
|
| 236 |
+
df.loc[sites == "B", feature_cols].mean()
|
| 237 |
+
- df.loc[sites == "A", feature_cols].mean()
|
| 238 |
+
).abs()
|
| 239 |
+
|
| 240 |
+
out = harmonize_combat(df, sites, feature_cols)
|
| 241 |
+
gap_after = (
|
| 242 |
+
out.loc[sites == "B", feature_cols].mean()
|
| 243 |
+
- out.loc[sites == "A", feature_cols].mean()
|
| 244 |
+
).abs()
|
| 245 |
+
|
| 246 |
+
# Every column's site gap must shrink (ComBat aligns site means).
|
| 247 |
+
assert (gap_after < gap_before).all(), (
|
| 248 |
+
f"gap_before={gap_before.tolist()} gap_after={gap_after.tolist()}"
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
def test_output_dtype_float64(self) -> None:
|
| 252 |
+
df, sites, feature_cols = self._build_two_site_features()
|
| 253 |
+
out = harmonize_combat(df, sites, feature_cols)
|
| 254 |
+
for c in feature_cols:
|
| 255 |
+
assert out[c].dtype == np.float64, f"{c} → {out[c].dtype}"
|
| 256 |
+
|
| 257 |
+
def test_no_nan_in_output(self) -> None:
|
| 258 |
+
df, sites, feature_cols = self._build_two_site_features()
|
| 259 |
+
out = harmonize_combat(df, sites, feature_cols)
|
| 260 |
+
assert out[feature_cols].notna().all().all()
|
| 261 |
+
assert np.isfinite(out[feature_cols].to_numpy()).all()
|
| 262 |
+
|
| 263 |
+
def test_deterministic(self) -> None:
|
| 264 |
+
df, sites, feature_cols = self._build_two_site_features()
|
| 265 |
+
a = harmonize_combat(df, sites, feature_cols)
|
| 266 |
+
b = harmonize_combat(df.copy(), sites.copy(), list(feature_cols))
|
| 267 |
+
np.testing.assert_array_equal(a.to_numpy(), b.to_numpy())
|
| 268 |
+
|
| 269 |
+
def test_raises_on_single_site(self) -> None:
|
| 270 |
+
"""ComBat needs at least 2 sites; a single-site dataset is malformed."""
|
| 271 |
+
df, _, feature_cols = self._build_two_site_features()
|
| 272 |
+
sites_one = pd.Series(["A"] * len(df), name="site")
|
| 273 |
+
with pytest.raises(ValueError, match="at least 2 sites"):
|
| 274 |
+
harmonize_combat(df, sites_one, feature_cols)
|