feat(mri): log run params, metrics, and parquet artifact to MLflow
src/pipelines/mri_pipeline.py
CHANGED

@@ -11,6 +11,7 @@ traceability (in/out/dropped counts at INFO), and idempotent overwrite.
 """
 from __future__ import annotations
 
+import time
 from pathlib import Path
 
 import nibabel as nib
@@ -21,6 +22,7 @@ from scipy import ndimage as scipy_ndimage
 from src.core.determinism import pin_threads
 from src.core.logger import get_logger
 from src.core.storage import write_parquet
+from src.core.tracking import track_pipeline_run
 
 logger = get_logger(__name__)
 
@@ -315,6 +317,7 @@ def run_pipeline(
     if not sites_csv.exists():
         raise FileNotFoundError(f"sites_csv not found: {sites_csv}")
 
+    started = time.perf_counter()
     logger.info("Reading MRI volumes from %s", input_dir)
     nifti_paths = _list_nifti_volumes(input_dir)
     sites_df = pd.read_csv(sites_csv)
@@ -354,65 +357,85 @@ def run_pipeline(
             "Feature extraction complete: in=%d, out=0, dropped=%d (%.2f%%)",
             n_total, n_dropped, 100.0 * n_dropped / max(n_total, 1),
         )
-        empty = pd.DataFrame(
+        final = pd.DataFrame(
             columns=["subject_id", "site", *feature_cols]
         ).astype({c: np.float64 for c in feature_cols})
-        write_parquet(empty, output_path)
-        return
-
-    raw_features = pd.DataFrame(rows)
-    raw_features = raw_features.merge(sites_df, on="subject_id", how="left")
-    if raw_features["site"].isna().any():
-        missing = raw_features.loc[raw_features["site"].isna(), "subject_id"].tolist()
-        raise KeyError(
-            f"sites_csv missing site assignment for subjects: {missing}"
-        )
-
-    # ComBat cannot handle (near-)zero-variance columns: var_pooled ≈ 0 produces
-    # NaN. Split feature_cols on a strictly-positive variance floor so ULP-level
-    # noise is treated as constant.
-    col_std = raw_features[feature_cols].std()
-    var_feature_cols = [c for c in feature_cols if col_std[c] > _MIN_VAR_THRESHOLD]
-    zero_var_cols = [c for c in feature_cols if col_std[c] <= _MIN_VAR_THRESHOLD]
-
-    if not var_feature_cols:
-        # Degenerate dataset: every feature is essentially constant. ComBat has
-        # no signal to harmonize on; pass all columns through and warn.
-        logger.warning(
-            "All %d feature columns have variance ≤ %.1e; ComBat skipped "
-            "(output contains unharmonized features).",
-            len(feature_cols), _MIN_VAR_THRESHOLD,
-        )
-        harmonized = raw_features[feature_cols].copy()
-    else:
-        harmonized = harmonize_combat(
-            raw_features, raw_features["site"], var_feature_cols,
-        )
-    # Re-attach zero-variance columns (unchanged) and restore the original
-    # column order.
-    for c in zero_var_cols:
-        harmonized[c] = raw_features[c].to_numpy()
-    harmonized = harmonized[feature_cols]
-
-    final = pd.concat(
-        [raw_features[["subject_id", "site"]].reset_index(drop=True),
-         harmonized.reset_index(drop=True)],
-        axis=1,
-    )
+    else:
+        raw_features = pd.DataFrame(rows)
+        raw_features = raw_features.merge(sites_df, on="subject_id", how="left")
+        if raw_features["site"].isna().any():
+            missing = raw_features.loc[raw_features["site"].isna(), "subject_id"].tolist()
+            raise KeyError(
+                f"sites_csv missing site assignment for subjects: {missing}"
+            )
+
+        # ComBat cannot handle (near-)zero-variance columns: var_pooled ≈ 0 produces
+        # NaN. Split feature_cols on a strictly-positive variance floor so ULP-level
+        # noise is treated as constant.
+        col_std = raw_features[feature_cols].std()
+        var_feature_cols = [c for c in feature_cols if col_std[c] > _MIN_VAR_THRESHOLD]
+        zero_var_cols = [c for c in feature_cols if col_std[c] <= _MIN_VAR_THRESHOLD]
+
+        if not var_feature_cols:
+            # Degenerate dataset: every feature is essentially constant. ComBat has
+            # no signal to harmonize on; pass all columns through and warn.
+            logger.warning(
+                "All %d feature columns have variance ≤ %.1e; ComBat skipped "
+                "(output contains unharmonized features).",
+                len(feature_cols), _MIN_VAR_THRESHOLD,
+            )
+            harmonized = raw_features[feature_cols].copy()
+        else:
+            harmonized = harmonize_combat(
+                raw_features, raw_features["site"], var_feature_cols,
+            )
+        # Re-attach zero-variance columns (unchanged) and restore the original
+        # column order.
+        for c in zero_var_cols:
+            harmonized[c] = raw_features[c].to_numpy()
+        harmonized = harmonized[feature_cols]
+
+        final = pd.concat(
+            [raw_features[["subject_id", "site"]].reset_index(drop=True),
+             harmonized.reset_index(drop=True)],
+            axis=1,
+        )
 
-
-
-
-
+    logger.info(
+        "Feature extraction complete: in=%d, out=%d, dropped=%d (%.2f%%)",
+        n_total, len(final), n_dropped, 100.0 * n_dropped / max(n_total, 1),
+    )
 
     # Parquet preserves dtypes (float64 features stay float64) and is
-    # byte-deterministic with single-threaded snappy. AGENTS.md §6.
+    # byte-deterministic with single-threaded snappy. AGENTS.md §6. Unconditional
+    # so the §4-rule-4 traceability log fires for both empty and non-empty paths.
     write_parquet(final, output_path)
     logger.info(
         "Wrote processed features to %s (rows=%d, cols=%d)",
         output_path, len(final), final.shape[1],
    )
 
+    duration_sec = time.perf_counter() - started
+
+    with track_pipeline_run(
+        experiment_name="mri_pipeline",
+        params={
+            "input_dir": str(input_dir),
+            "sites_csv": str(sites_csv),
+            "output_path": str(output_path),
+            "intensity_threshold": str(intensity_threshold),
+            "n_roi_axes": str(n_roi_axes),
+        },
+        metrics={
+            "subjects_in": float(n_total),
+            "subjects_out": float(len(final)),
+            "subjects_dropped": float(n_dropped),
+            "duration_sec": duration_sec,
+        },
+        artifact_path=output_path,
+    ):
+        pass
+
 
 if __name__ == "__main__":
     # Day-3 CLI entrypoint — runs with default paths against `data/raw/mri/`.
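
Note: `src/core/tracking.py` is not part of this diff, so the exact shape of `track_pipeline_run` is not shown. From the call site above it is presumably a context manager that opens an MLflow run and logs the supplied params, metrics, and artifact. A minimal sketch of such a helper, assuming the standard `mlflow` client API (the signature is inferred from the call site, not confirmed by this commit):

    import contextlib
    from pathlib import Path
    from typing import Iterator, Mapping

    import mlflow


    @contextlib.contextmanager
    def track_pipeline_run(
        experiment_name: str,
        params: Mapping[str, str],
        metrics: Mapping[str, float],
        artifact_path: Path,
    ) -> Iterator[None]:
        """Open one MLflow run and log params, metrics, and the output artifact."""
        mlflow.set_experiment(experiment_name)  # creates the experiment if absent
        with mlflow.start_run():
            mlflow.log_params(dict(params))
            mlflow.log_metrics(dict(metrics))
            mlflow.log_artifact(str(artifact_path))  # copies the parquet into the run
            yield

Under this reading, the `with track_pipeline_run(...): pass` block at the end of `run_pipeline` does all of its logging on entry, which is why an explicit `duration_sec` metric is computed from `time.perf_counter()` rather than relying on MLflow's own run timestamps: the run only starts after the pipeline work has finished.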
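Note: the "byte-deterministic with single-threaded snappy" comment refers to `src.core.storage.write_parquet`, which is also outside this diff. A plausible sketch of such a writer, assuming pyarrow (illustrative only; the project's actual implementation may differ):

    from pathlib import Path

    import pandas as pd
    import pyarrow as pa
    import pyarrow.parquet as pq


    def write_parquet(df: pd.DataFrame, path: Path) -> None:
        """Write df with settings chosen for reproducible output bytes."""
        path.parent.mkdir(parents=True, exist_ok=True)
        pa.set_cpu_count(1)  # single-threaded encoding, stable chunk order
        table = pa.Table.from_pandas(df, preserve_index=False)
        pq.write_table(table, str(path), compression="snappy")

Snappy compression with a fixed row-group layout is deterministic for a pinned pyarrow version, which is what makes the idempotent-overwrite guarantee in the module docstring testable byte-for-byte.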
tests/pipelines/test_mri_pipeline.py
CHANGED

@@ -445,3 +445,24 @@ class TestRunPipeline:
         extract_idx = log_output.index("Feature extraction complete:")
         wrote_idx = log_output.index("Wrote processed features to")
         assert extract_idx < wrote_idx, "extraction summary must precede write log"
+
+
+import mlflow
+from src.pipelines import mri_pipeline as _mri_for_mlflow_test
+from tests.fixtures import build_mri_fixture as _build_mri_for_mlflow_test
+
+
+class TestMRIPipelineMLflow:
+    def test_run_pipeline_creates_mlflow_run(self, tmp_path):
+        fixture_dir = _build_mri_for_mlflow_test.build(out_dir=tmp_path / "mri_fixture")
+        out = tmp_path / "out.parquet"
+        _mri_for_mlflow_test.run_pipeline(
+            input_dir=fixture_dir, output_path=out,
+        )
+        runs = mlflow.search_runs(
+            experiment_names=["mri_pipeline"],
+            order_by=["start_time DESC"],
+        )
+        assert len(runs) >= 1
+        assert "metrics.subjects_out" in runs.columns
+        assert runs.iloc[0]["metrics.subjects_out"] > 0
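
Note: the new test queries whatever tracking store `mlflow` resolves by default (typically a repo-level `./mlruns` directory), so it can see runs from earlier invocations and leaves artifacts outside `tmp_path`. If isolation matters for this suite, a fixture along these lines (a sketch; the fixture name is illustrative) would scope each test to a throwaway store:

    import mlflow
    import pytest


    @pytest.fixture(autouse=True)
    def _isolated_mlflow(tmp_path, monkeypatch):
        # Point both the in-process client and any child processes at a
        # per-test file store so runs never leak between tests.
        uri = (tmp_path / "mlruns").as_uri()
        monkeypatch.setenv("MLFLOW_TRACKING_URI", uri)
        mlflow.set_tracking_uri(uri)
        yield

With or without that isolation, `mlflow.search_runs(experiment_names=["mri_pipeline"])` returns a pandas DataFrame whose flattened `metrics.*` and `params.*` columns are exactly what the assertions above index into.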