mekosotto committed
Commit 837970b · 1 Parent(s): 0ce94e3

feat(mri): log run params, metrics, and parquet artifact to MLflow

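The diff below calls `track_pipeline_run` from `src.core.tracking`, which is not part of this commit. As a reading aid, here is a minimal sketch of the contract the call site assumes: a context manager that opens an MLflow run and records params, metrics, and the parquet artifact. It is written against the standard `mlflow` client API, and the body is inferred from the call site, not taken from the real module:

# Hypothetical sketch of src/core/tracking.py -- inferred, not part of this commit.
from contextlib import contextmanager
from pathlib import Path

import mlflow


@contextmanager
def track_pipeline_run(experiment_name, params, metrics, artifact_path=None):
    """Open an MLflow run, log params/metrics/artifact, close the run on exit."""
    mlflow.set_experiment(experiment_name)
    with mlflow.start_run() as run:
        mlflow.log_params(params)    # str-valued run parameters
        mlflow.log_metrics(metrics)  # float-valued run metrics
        if artifact_path is not None and Path(artifact_path).exists():
            mlflow.log_artifact(str(artifact_path))  # copy the parquet into the run
        yield run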
src/pipelines/mri_pipeline.py CHANGED
@@ -11,6 +11,7 @@ traceability (in/out/dropped counts at INFO), and idempotent overwrite.
 """
 from __future__ import annotations
 
+import time
 from pathlib import Path
 
 import nibabel as nib
@@ -21,6 +22,7 @@ from scipy import ndimage as scipy_ndimage
 from src.core.determinism import pin_threads
 from src.core.logger import get_logger
 from src.core.storage import write_parquet
+from src.core.tracking import track_pipeline_run
 
 logger = get_logger(__name__)
 
@@ -315,6 +317,7 @@ def run_pipeline(
     if not sites_csv.exists():
         raise FileNotFoundError(f"sites_csv not found: {sites_csv}")
 
+    started = time.perf_counter()
     logger.info("Reading MRI volumes from %s", input_dir)
     nifti_paths = _list_nifti_volumes(input_dir)
     sites_df = pd.read_csv(sites_csv)
@@ -354,65 +357,85 @@
             "Feature extraction complete: in=%d, out=0, dropped=%d (%.2f%%)",
             n_total, n_dropped, 100.0 * n_dropped / max(n_total, 1),
         )
-        empty = pd.DataFrame(
+        final = pd.DataFrame(
             columns=["subject_id", "site", *feature_cols]
         ).astype({c: np.float64 for c in feature_cols})
-        write_parquet(empty, output_path)
-        return
-
-    raw_features = pd.DataFrame(rows)
-    raw_features = raw_features.merge(sites_df, on="subject_id", how="left")
-    if raw_features["site"].isna().any():
-        missing = raw_features.loc[raw_features["site"].isna(), "subject_id"].tolist()
-        raise KeyError(
-            f"sites_csv missing site assignment for subjects: {missing}"
-        )
-
-    # ComBat cannot handle (near-)zero-variance columns: var_pooled ≈ 0 produces
-    # NaN. Split feature_cols on a strictly-positive variance floor so ULP-level
-    # noise is treated as constant.
-    col_std = raw_features[feature_cols].std()
-    var_feature_cols = [c for c in feature_cols if col_std[c] > _MIN_VAR_THRESHOLD]
-    zero_var_cols = [c for c in feature_cols if col_std[c] <= _MIN_VAR_THRESHOLD]
-
-    if not var_feature_cols:
-        # Degenerate dataset: every feature is essentially constant. ComBat has
-        # no signal to harmonize on; pass all columns through and warn.
-        logger.warning(
-            "All %d feature columns have variance ≤ %.1e; ComBat skipped "
-            "(output contains unharmonized features).",
-            len(feature_cols), _MIN_VAR_THRESHOLD,
-        )
-        harmonized = raw_features[feature_cols].copy()
     else:
-        harmonized = harmonize_combat(
-            raw_features, raw_features["site"], var_feature_cols,
-        )
-    # Re-attach zero-variance columns (unchanged) and restore the original
-    # column order.
-    for c in zero_var_cols:
-        harmonized[c] = raw_features[c].to_numpy()
-    harmonized = harmonized[feature_cols]
-
-    final = pd.concat(
-        [raw_features[["subject_id", "site"]].reset_index(drop=True),
-         harmonized.reset_index(drop=True)],
-        axis=1,
-    )
+        raw_features = pd.DataFrame(rows)
+        raw_features = raw_features.merge(sites_df, on="subject_id", how="left")
+        if raw_features["site"].isna().any():
+            missing = raw_features.loc[raw_features["site"].isna(), "subject_id"].tolist()
+            raise KeyError(
+                f"sites_csv missing site assignment for subjects: {missing}"
+            )
+
+        # ComBat cannot handle (near-)zero-variance columns: var_pooled ≈ 0 produces
+        # NaN. Split feature_cols on a strictly-positive variance floor so ULP-level
+        # noise is treated as constant.
+        col_std = raw_features[feature_cols].std()
+        var_feature_cols = [c for c in feature_cols if col_std[c] > _MIN_VAR_THRESHOLD]
+        zero_var_cols = [c for c in feature_cols if col_std[c] <= _MIN_VAR_THRESHOLD]
+
+        if not var_feature_cols:
+            # Degenerate dataset: every feature is essentially constant. ComBat has
+            # no signal to harmonize on; pass all columns through and warn.
+            logger.warning(
+                "All %d feature columns have variance ≤ %.1e; ComBat skipped "
+                "(output contains unharmonized features).",
+                len(feature_cols), _MIN_VAR_THRESHOLD,
+            )
+            harmonized = raw_features[feature_cols].copy()
+        else:
+            harmonized = harmonize_combat(
+                raw_features, raw_features["site"], var_feature_cols,
+            )
+        # Re-attach zero-variance columns (unchanged) and restore the original
+        # column order.
+        for c in zero_var_cols:
+            harmonized[c] = raw_features[c].to_numpy()
+        harmonized = harmonized[feature_cols]
+
+        final = pd.concat(
+            [raw_features[["subject_id", "site"]].reset_index(drop=True),
+             harmonized.reset_index(drop=True)],
+            axis=1,
+        )
 
-    logger.info(
-        "Feature extraction complete: in=%d, out=%d, dropped=%d (%.2f%%)",
-        n_total, len(final), n_dropped, 100.0 * n_dropped / max(n_total, 1),
-    )
+    logger.info(
+        "Feature extraction complete: in=%d, out=%d, dropped=%d (%.2f%%)",
+        n_total, len(final), n_dropped, 100.0 * n_dropped / max(n_total, 1),
+    )
 
     # Parquet preserves dtypes (float64 features stay float64) and is
-    # byte-deterministic with single-threaded snappy. AGENTS.md §6.
+    # byte-deterministic with single-threaded snappy. AGENTS.md §6. Unconditional
+    # so the §4-rule-4 traceability log fires for both empty and non-empty paths.
     write_parquet(final, output_path)
     logger.info(
         "Wrote processed features to %s (rows=%d, cols=%d)",
         output_path, len(final), final.shape[1],
     )
 
+    duration_sec = time.perf_counter() - started
+
+    with track_pipeline_run(
+        experiment_name="mri_pipeline",
+        params={
+            "input_dir": str(input_dir),
+            "sites_csv": str(sites_csv),
+            "output_path": str(output_path),
+            "intensity_threshold": str(intensity_threshold),
+            "n_roi_axes": str(n_roi_axes),
+        },
+        metrics={
+            "subjects_in": float(n_total),
+            "subjects_out": float(len(final)),
+            "subjects_dropped": float(n_dropped),
+            "duration_sec": duration_sec,
+        },
+        artifact_path=output_path,
+    ):
+        pass
+
 
 if __name__ == "__main__":
     # Day-3 CLI entrypoint — runs with default paths against `data/raw/mri/`.
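The `_MIN_VAR_THRESHOLD` split in the hunk above exists because ComBat standardizes each feature by its pooled variance, so a constant column turns the standardization into 0/0. A small numpy-only illustration of that failure mode (not project code):

import numpy as np

x = np.full(8, 3.14)              # constant feature column: zero variance
with np.errstate(invalid="ignore"):
    z = (x - x.mean()) / x.std()  # 0/0 for every subject
print(z)                          # [nan nan nan nan nan nan nan nan]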
tests/pipelines/test_mri_pipeline.py CHANGED
@@ -445,3 +445,24 @@ class TestRunPipeline:
         extract_idx = log_output.index("Feature extraction complete:")
         wrote_idx = log_output.index("Wrote processed features to")
         assert extract_idx < wrote_idx, "extraction summary must precede write log"
+
+
+import mlflow
+from src.pipelines import mri_pipeline as _mri_for_mlflow_test
+from tests.fixtures import build_mri_fixture as _build_mri_for_mlflow_test
+
+
+class TestMRIPipelineMLflow:
+    def test_run_pipeline_creates_mlflow_run(self, tmp_path):
+        fixture_dir = _build_mri_for_mlflow_test.build(out_dir=tmp_path / "mri_fixture")
+        out = tmp_path / "out.parquet"
+        _mri_for_mlflow_test.run_pipeline(
+            input_dir=fixture_dir, output_path=out,
+        )
+        runs = mlflow.search_runs(
+            experiment_names=["mri_pipeline"],
+            order_by=["start_time DESC"],
+        )
+        assert len(runs) >= 1
+        assert "metrics.subjects_out" in runs.columns
+        assert runs.iloc[0]["metrics.subjects_out"] > 0
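Note that `mlflow.search_runs` in this test queries whatever tracking store MLflow resolves by default, so a run left behind by an earlier invocation could also satisfy the assertions. One possible way to keep the test hermetic, assuming the default file-based backend (a suggested fixture, not part of this commit):

import mlflow
import pytest


@pytest.fixture(autouse=True)
def _isolated_mlflow(tmp_path, monkeypatch):
    # Route the tracking store to a per-test directory so search_runs()
    # only sees runs created inside this test.
    uri = (tmp_path / "mlruns").as_uri()
    monkeypatch.setenv("MLFLOW_TRACKING_URI", uri)
    mlflow.set_tracking_uri(uri)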