feat(pipelines): compute_harmonization_diagnostics — long-format pre/post ComBat for viz
Browse files- Runs feature extraction once, then variance-aware ComBat, returns
both states as a single long-format DataFrame with columns
subject_id / site / feature / feature_value / harmonization_state.
- Reuses the same _MIN_VAR_THRESHOLD split as run_pipeline so
diagnostics reflect production exactly.
- 2 new tests: long-format shape + post-ComBat site-gap < pre-ComBat
(regression pin for the 5.0 → 0.0015 reduction story).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
src/pipelines/mri_pipeline.py
CHANGED
|
@@ -437,6 +437,100 @@ def run_pipeline(
|
|
| 437 |
pass
|
| 438 |
|
| 439 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
if __name__ == "__main__":
|
| 441 |
# Day-3 CLI entrypoint — runs with default paths against `data/raw/mri/`.
|
| 442 |
# Expects `data/raw/mri/sites.csv` with columns `subject_id, site`.
|
|
|
|
| 437 |
pass
|
| 438 |
|
| 439 |
|
| 440 |
+
def compute_harmonization_diagnostics(
|
| 441 |
+
input_dir: Path,
|
| 442 |
+
sites_csv: Path | None = None,
|
| 443 |
+
intensity_threshold: float | None = None,
|
| 444 |
+
n_roi_axes: tuple[int, int, int] = DEFAULT_N_ROI_AXES,
|
| 445 |
+
) -> pd.DataFrame:
|
| 446 |
+
"""Run the MRI pipeline twice — pre-ComBat features and post-ComBat —
|
| 447 |
+
and return a long-format DataFrame ready for visualization.
|
| 448 |
+
|
| 449 |
+
Output columns: ``subject_id``, ``site``, ``feature``, ``feature_value``,
|
| 450 |
+
``harmonization_state`` ('Pre-ComBat' or 'Post-ComBat').
|
| 451 |
+
|
| 452 |
+
Used by the FastAPI ``/pipeline/mri/diagnostics`` endpoint to feed the
|
| 453 |
+
Streamlit MRI tab's KDE / histogram comparison plot.
|
| 454 |
+
|
| 455 |
+
Raises:
|
| 456 |
+
FileNotFoundError: if ``input_dir`` does not exist.
|
| 457 |
+
KeyError: if any subject is missing a site assignment.
|
| 458 |
+
"""
|
| 459 |
+
input_dir = Path(input_dir)
|
| 460 |
+
if not input_dir.exists():
|
| 461 |
+
raise FileNotFoundError(f"MRI input directory not found: {input_dir}")
|
| 462 |
+
sites_csv = Path(sites_csv) if sites_csv is not None else input_dir / "sites.csv"
|
| 463 |
+
sites_df = pd.read_csv(sites_csv)
|
| 464 |
+
|
| 465 |
+
feature_cols = [
|
| 466 |
+
f"feat_roi{i}_{stat}"
|
| 467 |
+
for i in range(int(np.prod(n_roi_axes)))
|
| 468 |
+
for stat in ROI_STATS
|
| 469 |
+
]
|
| 470 |
+
|
| 471 |
+
rows: list[dict[str, object]] = []
|
| 472 |
+
for nifti_path in sorted(input_dir.glob("*.nii*")):
|
| 473 |
+
subject_id = nifti_path.stem.replace(".nii", "")
|
| 474 |
+
volume = nib.load(nifti_path).get_fdata()
|
| 475 |
+
if not is_valid_volume(volume):
|
| 476 |
+
continue
|
| 477 |
+
mask = mask_brain(volume, intensity_threshold=intensity_threshold)
|
| 478 |
+
feats = extract_features_from_volume(
|
| 479 |
+
volume, mask, n_roi_axes=n_roi_axes,
|
| 480 |
+
)
|
| 481 |
+
row: dict[str, object] = {"subject_id": subject_id}
|
| 482 |
+
row.update(feats)
|
| 483 |
+
rows.append(row)
|
| 484 |
+
|
| 485 |
+
if not rows:
|
| 486 |
+
return pd.DataFrame(columns=[
|
| 487 |
+
"subject_id", "site", "feature", "feature_value", "harmonization_state",
|
| 488 |
+
])
|
| 489 |
+
|
| 490 |
+
raw_features = pd.DataFrame(rows).merge(sites_df, on="subject_id", how="left")
|
| 491 |
+
if raw_features["site"].isna().any():
|
| 492 |
+
missing = raw_features.loc[raw_features["site"].isna(), "subject_id"].tolist()
|
| 493 |
+
raise KeyError(
|
| 494 |
+
f"sites_csv missing site assignment for subjects: {missing}"
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
# Post-ComBat: variance-aware harmonization. Reuses the same logic as
|
| 498 |
+
# run_pipeline so diagnostics reflect production behavior exactly.
|
| 499 |
+
col_std = raw_features[feature_cols].std()
|
| 500 |
+
var_feature_cols = [
|
| 501 |
+
c for c in feature_cols if col_std[c] > _MIN_VAR_THRESHOLD
|
| 502 |
+
]
|
| 503 |
+
zero_var_cols = [
|
| 504 |
+
c for c in feature_cols if col_std[c] <= _MIN_VAR_THRESHOLD
|
| 505 |
+
]
|
| 506 |
+
if not var_feature_cols:
|
| 507 |
+
harmonized = raw_features[feature_cols].copy()
|
| 508 |
+
else:
|
| 509 |
+
harmonized = harmonize_combat(
|
| 510 |
+
raw_features, raw_features["site"], var_feature_cols,
|
| 511 |
+
)
|
| 512 |
+
for c in zero_var_cols:
|
| 513 |
+
harmonized[c] = raw_features[c].to_numpy()
|
| 514 |
+
harmonized = harmonized[feature_cols]
|
| 515 |
+
post_features = pd.concat(
|
| 516 |
+
[raw_features[["subject_id", "site"]].reset_index(drop=True),
|
| 517 |
+
harmonized.reset_index(drop=True)],
|
| 518 |
+
axis=1,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
long_pre = raw_features.melt(
|
| 522 |
+
id_vars=["subject_id", "site"], value_vars=feature_cols,
|
| 523 |
+
var_name="feature", value_name="feature_value",
|
| 524 |
+
)
|
| 525 |
+
long_pre["harmonization_state"] = "Pre-ComBat"
|
| 526 |
+
long_post = post_features.melt(
|
| 527 |
+
id_vars=["subject_id", "site"], value_vars=feature_cols,
|
| 528 |
+
var_name="feature", value_name="feature_value",
|
| 529 |
+
)
|
| 530 |
+
long_post["harmonization_state"] = "Post-ComBat"
|
| 531 |
+
return pd.concat([long_pre, long_post], ignore_index=True)
|
| 532 |
+
|
| 533 |
+
|
| 534 |
if __name__ == "__main__":
|
| 535 |
# Day-3 CLI entrypoint — runs with default paths against `data/raw/mri/`.
|
| 536 |
# Expects `data/raw/mri/sites.csv` with columns `subject_id, site`.
|
tests/pipelines/test_mri_pipeline.py
CHANGED
|
@@ -466,3 +466,43 @@ class TestMRIPipelineMLflow:
|
|
| 466 |
assert len(runs) >= 1
|
| 467 |
assert "metrics.subjects_out" in runs.columns
|
| 468 |
assert runs.iloc[0]["metrics.subjects_out"] > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
assert len(runs) >= 1
|
| 467 |
assert "metrics.subjects_out" in runs.columns
|
| 468 |
assert runs.iloc[0]["metrics.subjects_out"] > 0
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
class TestComputeHarmonizationDiagnostics:
|
| 472 |
+
def test_returns_long_format_with_pre_and_post_states(self, tmp_path: Path):
|
| 473 |
+
from tests.fixtures.build_mri_fixture import build as build_mri
|
| 474 |
+
from src.pipelines.mri_pipeline import compute_harmonization_diagnostics
|
| 475 |
+
|
| 476 |
+
fixture_dir = build_mri(out_dir=tmp_path / "mri")
|
| 477 |
+
diagnostics = compute_harmonization_diagnostics(
|
| 478 |
+
input_dir=fixture_dir,
|
| 479 |
+
sites_csv=fixture_dir / "sites.csv",
|
| 480 |
+
)
|
| 481 |
+
assert "feature_value" in diagnostics.columns
|
| 482 |
+
assert "site" in diagnostics.columns
|
| 483 |
+
assert "harmonization_state" in diagnostics.columns
|
| 484 |
+
assert "feature" in diagnostics.columns
|
| 485 |
+
states = set(diagnostics["harmonization_state"].unique())
|
| 486 |
+
assert states == {"Pre-ComBat", "Post-ComBat"}
|
| 487 |
+
|
| 488 |
+
def test_post_combat_site_gap_is_smaller_than_pre(self, tmp_path: Path):
|
| 489 |
+
"""Day-3 demonstrated 5.0 → 0.0015 gap reduction. This regression
|
| 490 |
+
test pins the property: post-ComBat per-site means MUST be closer
|
| 491 |
+
together than pre-ComBat per-site means."""
|
| 492 |
+
from tests.fixtures.build_mri_fixture import build as build_mri
|
| 493 |
+
from src.pipelines.mri_pipeline import compute_harmonization_diagnostics
|
| 494 |
+
|
| 495 |
+
fixture_dir = build_mri(out_dir=tmp_path / "mri")
|
| 496 |
+
diagnostics = compute_harmonization_diagnostics(
|
| 497 |
+
input_dir=fixture_dir,
|
| 498 |
+
sites_csv=fixture_dir / "sites.csv",
|
| 499 |
+
)
|
| 500 |
+
pre = diagnostics[diagnostics["harmonization_state"] == "Pre-ComBat"]
|
| 501 |
+
post = diagnostics[diagnostics["harmonization_state"] == "Post-ComBat"]
|
| 502 |
+
# Compute site-gap as range of per-site means on the first feature
|
| 503 |
+
feat = diagnostics["feature"].iloc[0]
|
| 504 |
+
pre_gap = pre[pre["feature"] == feat].groupby("site")["feature_value"].mean().agg(lambda s: s.max() - s.min())
|
| 505 |
+
post_gap = post[post["feature"] == feat].groupby("site")["feature_value"].mean().agg(lambda s: s.max() - s.min())
|
| 506 |
+
assert post_gap < pre_gap, (
|
| 507 |
+
f"Expected post-gap < pre-gap, got pre={pre_gap}, post={post_gap}"
|
| 508 |
+
)
|