mekosotto Claude Opus 4.7 (1M context) commited on
Commit
c68ac12
·
1 Parent(s): 4d00e0f

feat(mri): add harmonize_combat wrapper around neuroHarmonize.harmonizationLearn

Browse files

Implements parametric ComBat harmonization to remove site-level domain
shift. Guards against single-site input (ValueError), ensures float64
output, and pins determinism via np.round(14) to eliminate sub-ULP
floating-point noise from neuroHarmonize's internal matrix ops.
Also installs missing transitive deps (statsmodels, neuroCombat).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

src/pipelines/mri_pipeline.py CHANGED
@@ -15,6 +15,7 @@ import os
15
 
16
  import nibabel as nib
17
  import numpy as np
 
18
  import pyarrow as pa
19
  from scipy import ndimage as scipy_ndimage
20
 
@@ -195,3 +196,54 @@ def extract_features_from_volume(
195
  for stat_name, stat_val in stats.items():
196
  feats[f"feat_roi{i}_{stat_name}"] = stat_val
197
  return feats
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  import nibabel as nib
17
  import numpy as np
18
+ import pandas as pd
19
  import pyarrow as pa
20
  from scipy import ndimage as scipy_ndimage
21
 
 
196
  for stat_name, stat_val in stats.items():
197
  feats[f"feat_roi{i}_{stat_name}"] = stat_val
198
  return feats
199
+
200
+
201
+ def harmonize_combat(
202
+ features: pd.DataFrame,
203
+ sites: pd.Series,
204
+ feature_cols: list[str],
205
+ ) -> pd.DataFrame:
206
+ """Apply ComBat harmonization across sites to remove site-level domain shift.
207
+
208
+ Wraps `neuroHarmonize.harmonizationLearn` which fits a parametric ComBat
209
+ model (no internal RNG → byte-deterministic given fixed input). Only
210
+ `feature_cols` are harmonized; other columns in `features` (e.g.
211
+ metadata) are not touched by this function — callers should join after.
212
+
213
+ Args:
214
+ features: DataFrame with at least the columns listed in `feature_cols`.
215
+ sites: Site label per row (length must match `len(features)`).
216
+ feature_cols: Names of the columns to harmonize.
217
+
218
+ Returns:
219
+ A new DataFrame of identical shape & column order to
220
+ `features[feature_cols]`, with ComBat-harmonized values.
221
+
222
+ Raises:
223
+ ValueError: if fewer than 2 distinct sites are present.
224
+ """
225
+ from neuroHarmonize import harmonizationLearn
226
+
227
+ if sites.nunique() < 2:
228
+ raise ValueError(
229
+ f"ComBat requires at least 2 sites; got {sites.nunique()} "
230
+ f"({sites.unique().tolist()})"
231
+ )
232
+
233
+ matrix = features[feature_cols].to_numpy(dtype=np.float64)
234
+ covars = pd.DataFrame({"SITE": sites.to_numpy()})
235
+
236
+ _, harmonized = harmonizationLearn(matrix, covars)
237
+ # Round to 14 decimal places to eliminate sub-ULP floating-point noise
238
+ # (neuroHarmonize's internal matrix ops can produce ±1-ULP variation
239
+ # across calls; 14 d.p. retains all meaningful precision at float64).
240
+ out = pd.DataFrame(
241
+ np.round(np.asarray(harmonized, dtype=np.float64), 14),
242
+ columns=list(feature_cols),
243
+ index=features.index,
244
+ )
245
+ logger.info(
246
+ "ComBat harmonized %d rows × %d features across %d sites",
247
+ len(out), len(feature_cols), sites.nunique(),
248
+ )
249
+ return out
tests/pipelines/test_mri_pipeline.py CHANGED
@@ -5,12 +5,14 @@ from pathlib import Path
5
 
6
  import nibabel as nib
7
  import numpy as np
 
8
  import pytest
9
 
10
  from src.pipelines.mri_pipeline import (
11
  DEFAULT_N_ROI_AXES,
12
  ROI_STATS,
13
  extract_features_from_volume,
 
14
  is_valid_volume,
15
  mask_brain,
16
  )
@@ -203,3 +205,70 @@ class TestExtractFeaturesFromVolume:
203
  bad_mask = np.zeros((4, 4, 4), dtype=bool)
204
  with pytest.raises(ValueError, match=r"volume\.shape .* != mask\.shape"):
205
  extract_features_from_volume(vol, bad_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  import nibabel as nib
7
  import numpy as np
8
+ import pandas as pd
9
  import pytest
10
 
11
  from src.pipelines.mri_pipeline import (
12
  DEFAULT_N_ROI_AXES,
13
  ROI_STATS,
14
  extract_features_from_volume,
15
+ harmonize_combat,
16
  is_valid_volume,
17
  mask_brain,
18
  )
 
205
  bad_mask = np.zeros((4, 4, 4), dtype=bool)
206
  with pytest.raises(ValueError, match=r"volume\.shape .* != mask\.shape"):
207
  extract_features_from_volume(vol, bad_mask)
208
+
209
+
210
+ class TestHarmonizeCombat:
211
+ def _build_two_site_features(self) -> tuple[pd.DataFrame, pd.Series, list[str]]:
212
+ """Synthesize a 6-row × 4-feature table with a clear site bias."""
213
+ rng = np.random.default_rng(seed=42)
214
+ feature_cols = ["feat_roi0_mean", "feat_roi1_mean", "feat_roi2_mean", "feat_roi3_mean"]
215
+ # Site A baseline: mean ~0; Site B baseline: mean ~5 (the bias to remove).
216
+ site_a = rng.normal(loc=0.0, scale=1.0, size=(3, 4))
217
+ site_b = rng.normal(loc=5.0, scale=1.0, size=(3, 4))
218
+ df = pd.DataFrame(
219
+ np.vstack([site_a, site_b]),
220
+ columns=feature_cols,
221
+ )
222
+ sites = pd.Series(["A", "A", "A", "B", "B", "B"], name="site")
223
+ return df, sites, feature_cols
224
+
225
+ def test_returns_dataframe_same_shape_and_columns(self) -> None:
226
+ df, sites, feature_cols = self._build_two_site_features()
227
+ out = harmonize_combat(df, sites, feature_cols)
228
+ assert isinstance(out, pd.DataFrame)
229
+ assert out.shape == df.shape
230
+ assert list(out.columns) == feature_cols
231
+
232
+ def test_reduces_site_mean_difference(self) -> None:
233
+ """ComBat must shrink the per-site mean gap on every harmonized column."""
234
+ df, sites, feature_cols = self._build_two_site_features()
235
+ gap_before = (
236
+ df.loc[sites == "B", feature_cols].mean()
237
+ - df.loc[sites == "A", feature_cols].mean()
238
+ ).abs()
239
+
240
+ out = harmonize_combat(df, sites, feature_cols)
241
+ gap_after = (
242
+ out.loc[sites == "B", feature_cols].mean()
243
+ - out.loc[sites == "A", feature_cols].mean()
244
+ ).abs()
245
+
246
+ # Every column's site gap must shrink (ComBat aligns site means).
247
+ assert (gap_after < gap_before).all(), (
248
+ f"gap_before={gap_before.tolist()} gap_after={gap_after.tolist()}"
249
+ )
250
+
251
+ def test_output_dtype_float64(self) -> None:
252
+ df, sites, feature_cols = self._build_two_site_features()
253
+ out = harmonize_combat(df, sites, feature_cols)
254
+ for c in feature_cols:
255
+ assert out[c].dtype == np.float64, f"{c} → {out[c].dtype}"
256
+
257
+ def test_no_nan_in_output(self) -> None:
258
+ df, sites, feature_cols = self._build_two_site_features()
259
+ out = harmonize_combat(df, sites, feature_cols)
260
+ assert out[feature_cols].notna().all().all()
261
+ assert np.isfinite(out[feature_cols].to_numpy()).all()
262
+
263
+ def test_deterministic(self) -> None:
264
+ df, sites, feature_cols = self._build_two_site_features()
265
+ a = harmonize_combat(df, sites, feature_cols)
266
+ b = harmonize_combat(df.copy(), sites.copy(), list(feature_cols))
267
+ np.testing.assert_array_equal(a.to_numpy(), b.to_numpy())
268
+
269
+ def test_raises_on_single_site(self) -> None:
270
+ """ComBat needs at least 2 sites; a single-site dataset is malformed."""
271
+ df, _, feature_cols = self._build_two_site_features()
272
+ sites_one = pd.Series(["A"] * len(df), name="site")
273
+ with pytest.raises(ValueError, match="at least 2 sites"):
274
+ harmonize_combat(df, sites_one, feature_cols)