mekosotto Claude Opus 4.7 (1M context) committed on
Commit
1068ed1
·
1 Parent(s): 1c727f2

feat(pipelines): compute_harmonization_diagnostics — long-format pre/post ComBat for viz

Browse files

- Runs feature extraction once, then variance-aware ComBat, returns
both states as a single long-format DataFrame with columns
subject_id / site / feature / feature_value / harmonization_state.
- Reuses the same _MIN_VAR_THRESHOLD split as run_pipeline so
diagnostics reflect production exactly.
- 2 new tests: long-format shape + post-ComBat site-gap < pre-ComBat
(regression pin for the 5.0 → 0.0015 reduction story).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

src/pipelines/mri_pipeline.py CHANGED
@@ -437,6 +437,100 @@ def run_pipeline(
437
  pass
438
 
439
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  if __name__ == "__main__":
441
  # Day-3 CLI entrypoint — runs with default paths against `data/raw/mri/`.
442
  # Expects `data/raw/mri/sites.csv` with columns `subject_id, site`.
 
437
  pass
438
 
439
 
440
def compute_harmonization_diagnostics(
    input_dir: Path,
    sites_csv: Path | None = None,
    intensity_threshold: float | None = None,
    n_roi_axes: tuple[int, int, int] = DEFAULT_N_ROI_AXES,
) -> pd.DataFrame:
    """Extract ROI features once and return them pre- and post-ComBat.

    Feature extraction runs a single time over every ``*.nii*`` file in
    ``input_dir``; the resulting table is then harmonized with ComBat and
    both states are stacked into one long-format DataFrame ready for
    visualization.

    Output columns: ``subject_id``, ``site``, ``feature``, ``feature_value``,
    ``harmonization_state`` ('Pre-ComBat' or 'Post-ComBat').

    Used by the FastAPI ``/pipeline/mri/diagnostics`` endpoint to feed the
    Streamlit MRI tab's KDE / histogram comparison plot.

    Args:
        input_dir: Directory containing the NIfTI volumes.
        sites_csv: CSV with ``subject_id`` and ``site`` columns. Defaults to
            ``input_dir / "sites.csv"``.
        intensity_threshold: Forwarded unchanged to ``mask_brain``.
        n_roi_axes: ROI grid passed to ``extract_features_from_volume``; the
            product of its entries determines how many ROI feature columns
            are expected.

    Returns:
        Long-format DataFrame with one row per (subject, feature, state).
        An empty frame with the same five columns is returned when no
        volume passes ``is_valid_volume``.

    Raises:
        FileNotFoundError: if ``input_dir`` does not exist (``pd.read_csv``
            raises the same error when ``sites_csv`` is missing).
        KeyError: if any subject is missing a site assignment.
    """
    input_dir = Path(input_dir)
    if not input_dir.exists():
        raise FileNotFoundError(f"MRI input directory not found: {input_dir}")
    sites_csv = Path(sites_csv) if sites_csv is not None else input_dir / "sites.csv"
    # Subject ids derived from file names are always strings. Force the CSV
    # column to str as well so ids such as "001" are not parsed as integers,
    # which would make the merge below fail on mismatched key dtypes.
    sites_df = pd.read_csv(sites_csv, dtype={"subject_id": str})

    feature_cols = [
        f"feat_roi{i}_{stat}"
        for i in range(int(np.prod(n_roi_axes)))
        for stat in ROI_STATS
    ]

    rows: list[dict[str, object]] = []
    for nifti_path in sorted(input_dir.glob("*.nii*")):
        # ``stem`` only strips the last suffix, so "sub.nii.gz" still ends in
        # ".nii" at this point; remove it to get the bare subject id.
        subject_id = nifti_path.stem.replace(".nii", "")
        volume = nib.load(nifti_path).get_fdata()
        if not is_valid_volume(volume):
            continue
        mask = mask_brain(volume, intensity_threshold=intensity_threshold)
        feats = extract_features_from_volume(
            volume, mask, n_roi_axes=n_roi_axes,
        )
        row: dict[str, object] = {"subject_id": subject_id}
        row.update(feats)
        rows.append(row)

    if not rows:
        return pd.DataFrame(columns=[
            "subject_id", "site", "feature", "feature_value", "harmonization_state",
        ])

    raw_features = pd.DataFrame(rows).merge(sites_df, on="subject_id", how="left")
    if raw_features["site"].isna().any():
        missing = raw_features.loc[raw_features["site"].isna(), "subject_id"].tolist()
        raise KeyError(
            f"sites_csv missing site assignment for subjects: {missing}"
        )

    # Post-ComBat: variance-aware harmonization using the same
    # _MIN_VAR_THRESHOLD split as run_pipeline. Only columns whose standard
    # deviation exceeds the threshold are sent through ComBat.
    col_std = raw_features[feature_cols].std()
    var_feature_cols = [
        c for c in feature_cols if col_std[c] > _MIN_VAR_THRESHOLD
    ]
    # Every other column is passed through unchanged. Taking the complement
    # (rather than testing ``<= threshold``) also covers columns whose std is
    # NaN, for which both comparisons are False and which would otherwise be
    # missing from the harmonized frame.
    harmonized_set = set(var_feature_cols)
    passthrough_cols = [c for c in feature_cols if c not in harmonized_set]

    if not var_feature_cols:
        harmonized = raw_features[feature_cols].copy()
    else:
        harmonized = harmonize_combat(
            raw_features, raw_features["site"], var_feature_cols,
        )
        for c in passthrough_cols:
            # ``to_numpy`` sidesteps index alignment with the ComBat output.
            harmonized[c] = raw_features[c].to_numpy()
        harmonized = harmonized[feature_cols]

    post_features = pd.concat(
        [raw_features[["subject_id", "site"]].reset_index(drop=True),
         harmonized.reset_index(drop=True)],
        axis=1,
    )

    long_pre = raw_features.melt(
        id_vars=["subject_id", "site"], value_vars=feature_cols,
        var_name="feature", value_name="feature_value",
    )
    long_pre["harmonization_state"] = "Pre-ComBat"
    long_post = post_features.melt(
        id_vars=["subject_id", "site"], value_vars=feature_cols,
        var_name="feature", value_name="feature_value",
    )
    long_post["harmonization_state"] = "Post-ComBat"
    return pd.concat([long_pre, long_post], ignore_index=True)
532
+
533
+
534
  if __name__ == "__main__":
535
  # Day-3 CLI entrypoint — runs with default paths against `data/raw/mri/`.
536
  # Expects `data/raw/mri/sites.csv` with columns `subject_id, site`.
tests/pipelines/test_mri_pipeline.py CHANGED
@@ -466,3 +466,43 @@ class TestMRIPipelineMLflow:
466
  assert len(runs) >= 1
467
  assert "metrics.subjects_out" in runs.columns
468
  assert runs.iloc[0]["metrics.subjects_out"] > 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  assert len(runs) >= 1
467
  assert "metrics.subjects_out" in runs.columns
468
  assert runs.iloc[0]["metrics.subjects_out"] > 0
469
+
470
+
471
class TestComputeHarmonizationDiagnostics:
    """Tests for the long-format pre/post ComBat diagnostics frame."""

    @staticmethod
    def _site_gap(frame, feature):
        """Return the range (max - min) of per-site means for ``feature``.

        Computed directly on the per-site mean Series instead of through
        ``Series.agg(<lambda>)``, so the result does not depend on how pandas
        dispatches a callable passed to ``agg``.
        """
        site_means = (
            frame.loc[frame["feature"] == feature]
            .groupby("site")["feature_value"]
            .mean()
        )
        return site_means.max() - site_means.min()

    def test_returns_long_format_with_pre_and_post_states(self, tmp_path: Path):
        from tests.fixtures.build_mri_fixture import build as build_mri
        from src.pipelines.mri_pipeline import compute_harmonization_diagnostics

        fixture_dir = build_mri(out_dir=tmp_path / "mri")
        diagnostics = compute_harmonization_diagnostics(
            input_dir=fixture_dir,
            sites_csv=fixture_dir / "sites.csv",
        )
        # All five documented output columns must be present.
        expected_columns = {
            "subject_id", "site", "feature", "feature_value",
            "harmonization_state",
        }
        assert expected_columns <= set(diagnostics.columns)
        states = set(diagnostics["harmonization_state"].unique())
        assert states == {"Pre-ComBat", "Post-ComBat"}

    def test_post_combat_site_gap_is_smaller_than_pre(self, tmp_path: Path):
        """Day-3 demonstrated 5.0 → 0.0015 gap reduction. This regression
        test pins the property: post-ComBat per-site means MUST be closer
        together than pre-ComBat per-site means."""
        from tests.fixtures.build_mri_fixture import build as build_mri
        from src.pipelines.mri_pipeline import compute_harmonization_diagnostics

        fixture_dir = build_mri(out_dir=tmp_path / "mri")
        diagnostics = compute_harmonization_diagnostics(
            input_dir=fixture_dir,
            sites_csv=fixture_dir / "sites.csv",
        )
        pre = diagnostics[diagnostics["harmonization_state"] == "Pre-ComBat"]
        post = diagnostics[diagnostics["harmonization_state"] == "Post-ComBat"]
        # Compute site-gap as range of per-site means on the first feature
        feat = diagnostics["feature"].iloc[0]
        pre_gap = self._site_gap(pre, feat)
        post_gap = self._site_gap(post, feat)
        assert post_gap < pre_gap, (
            f"Expected post-gap < pre-gap, got pre={pre_gap}, post={post_gap}"
        )