mekosotto Claude Sonnet 4.6 committed on
Commit
b18a079
·
1 Parent(s): 7215c7f

fix(mri): handle all-constant features; tighten variance threshold; reorder log

Browse files
src/pipelines/mri_pipeline.py CHANGED
@@ -266,6 +266,14 @@ DEFAULT_INPUT = Path("data/raw/mri")
266
  DEFAULT_OUTPUT = Path("data/processed/mri_features.parquet")
267
 
268
 
 
 
 
 
 
 
 
 
269
  def _list_nifti_volumes(input_dir: Path) -> list[Path]:
270
  """Return sorted list of .nii / .nii.gz files in `input_dir`."""
271
  return sorted(
@@ -326,8 +334,7 @@ def run_pipeline(
326
  continue
327
  mask = mask_brain(volume, intensity_threshold=intensity_threshold)
328
  feats = extract_features_from_volume(volume, mask, n_roi_axes=n_roi_axes)
329
- feats["subject_id"] = subject_id
330
- rows.append(feats)
331
 
332
  n_total = len(nifti_paths)
333
  n_dropped = len(invalid_subject_ids)
@@ -373,18 +380,31 @@ def run_pipeline(
373
  f"sites_csv missing site assignment for subjects: {missing}"
374
  )
375
 
376
- # ComBat cannot handle zero-variance columns (var_pooled = 0 → NaN divide).
377
- # Split feature_cols into variable (harmonize) and constant (pass through).
378
- var_feature_cols = [c for c in feature_cols if raw_features[c].std() > 0]
379
- zero_var_cols = [c for c in feature_cols if raw_features[c].std() == 0]
 
 
380
 
381
- harmonized = harmonize_combat(
382
- raw_features, raw_features["site"], var_feature_cols,
383
- )
384
- # Re-attach zero-variance columns (unchanged) and restore original column order.
385
- for c in zero_var_cols:
386
- harmonized[c] = raw_features[c].to_numpy()
387
- harmonized = harmonized[feature_cols]
 
 
 
 
 
 
 
 
 
 
 
388
 
389
  final = pd.concat(
390
  [raw_features[["subject_id", "site"]].reset_index(drop=True),
@@ -392,6 +412,11 @@ def run_pipeline(
392
  axis=1,
393
  )
394
 
 
 
 
 
 
395
  output_path.parent.mkdir(parents=True, exist_ok=True)
396
  if output_path.is_dir():
397
  raise IsADirectoryError(
@@ -402,10 +427,6 @@ def run_pipeline(
402
  final.to_parquet(
403
  output_path, index=False, engine="pyarrow", compression="snappy",
404
  )
405
- logger.info(
406
- "Feature extraction complete: in=%d, out=%d, dropped=%d (%.2f%%)",
407
- n_total, len(final), n_dropped, 100.0 * n_dropped / max(n_total, 1),
408
- )
409
  logger.info(
410
  "Wrote processed features to %s (rows=%d, cols=%d)",
411
  output_path, len(final), final.shape[1],
 
266
  DEFAULT_OUTPUT = Path("data/processed/mri_features.parquet")
267
 
268
 
269
+ # Variance floor used to decide whether a feature column is "constant" for
270
+ # ComBat. Strict ``std() > 0`` would still send near-zero-variance columns
271
+ # (e.g. ULP-level differences) into ComBat, where var_pooled ≈ 0 produces
272
+ # NaN. 1e-8 is well above machine epsilon and far below any biologically
273
+ # meaningful signal variance.
274
+ _MIN_VAR_THRESHOLD: float = 1e-8
275
+
276
+
277
  def _list_nifti_volumes(input_dir: Path) -> list[Path]:
278
  """Return sorted list of .nii / .nii.gz files in `input_dir`."""
279
  return sorted(
 
334
  continue
335
  mask = mask_brain(volume, intensity_threshold=intensity_threshold)
336
  feats = extract_features_from_volume(volume, mask, n_roi_axes=n_roi_axes)
337
+ rows.append({"subject_id": subject_id, **feats})
 
338
 
339
  n_total = len(nifti_paths)
340
  n_dropped = len(invalid_subject_ids)
 
380
  f"sites_csv missing site assignment for subjects: {missing}"
381
  )
382
 
383
+ # ComBat cannot handle (near-)zero-variance columns: var_pooled ≈ 0 produces
384
+ # NaN. Split feature_cols on a strictly-positive variance floor so ULP-level
385
+ # noise is treated as constant.
386
+ col_std = raw_features[feature_cols].std()
387
+ var_feature_cols = [c for c in feature_cols if col_std[c] > _MIN_VAR_THRESHOLD]
388
+ zero_var_cols = [c for c in feature_cols if col_std[c] <= _MIN_VAR_THRESHOLD]
389
 
390
+ if not var_feature_cols:
391
+ # Degenerate dataset: every feature is essentially constant. ComBat has
392
+ # no signal to harmonize on; pass all columns through and warn.
393
+ logger.warning(
394
+ "All %d feature columns have variance ≤ %.1e; ComBat skipped "
395
+ "(output contains unharmonized features).",
396
+ len(feature_cols), _MIN_VAR_THRESHOLD,
397
+ )
398
+ harmonized = raw_features[feature_cols].copy()
399
+ else:
400
+ harmonized = harmonize_combat(
401
+ raw_features, raw_features["site"], var_feature_cols,
402
+ )
403
+ # Re-attach zero-variance columns (unchanged) and restore the original
404
+ # column order.
405
+ for c in zero_var_cols:
406
+ harmonized[c] = raw_features[c].to_numpy()
407
+ harmonized = harmonized[feature_cols]
408
 
409
  final = pd.concat(
410
  [raw_features[["subject_id", "site"]].reset_index(drop=True),
 
412
  axis=1,
413
  )
414
 
415
+ logger.info(
416
+ "Feature extraction complete: in=%d, out=%d, dropped=%d (%.2f%%)",
417
+ n_total, len(final), n_dropped, 100.0 * n_dropped / max(n_total, 1),
418
+ )
419
+
420
  output_path.parent.mkdir(parents=True, exist_ok=True)
421
  if output_path.is_dir():
422
  raise IsADirectoryError(
 
427
  final.to_parquet(
428
  output_path, index=False, engine="pyarrow", compression="snappy",
429
  )
 
 
 
 
430
  logger.info(
431
  "Wrote processed features to %s (rows=%d, cols=%d)",
432
  output_path, len(final), final.shape[1],
tests/pipelines/test_mri_pipeline.py CHANGED
@@ -378,3 +378,70 @@ class TestRunPipeline:
378
  # 5 surviving valid subjects (subject_5 dropped).
379
  assert len(df) == 5
380
  assert "subject_5" not in df["subject_id"].tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  # 5 surviving valid subjects (subject_5 dropped).
379
  assert len(df) == 5
380
  assert "subject_5" not in df["subject_id"].tolist()
381
+
382
+ def test_run_pipeline_handles_all_constant_features(self, tmp_path: Path) -> None:
383
+ """Degenerate dataset: every feature column is constant — ComBat must be
384
+ skipped gracefully with a WARNING, not crash with ValueError."""
385
+ import io
386
+ import logging
387
+
388
+ from src.core.logger import get_logger
389
+ from src.pipelines import mri_pipeline as mod
390
+
391
+ raw_dir, sites_csv, output_path = self._stage_inputs(tmp_path)
392
+ # Overwrite all volumes with the same constant intensity so every
393
+ # feature column is identical across subjects.
394
+ affine = np.eye(4)
395
+ for nii in sorted(raw_dir.glob("*.nii.gz")):
396
+ const_vol = np.full((8, 8, 8), 7.0, dtype=np.float64)
397
+ nib.save(nib.Nifti1Image(const_vol, affine=affine), nii)
398
+
399
+ logger = get_logger(mod.__name__, level=logging.INFO)
400
+ handler = logger.handlers[0]
401
+ buf = io.StringIO()
402
+ original_stream = handler.stream
403
+ handler.stream = buf
404
+ try:
405
+ run_pipeline(
406
+ input_dir=raw_dir, sites_csv=sites_csv,
407
+ output_path=output_path, intensity_threshold=1.0,
408
+ )
409
+ finally:
410
+ handler.stream = original_stream
411
+
412
+ df = pd.read_parquet(output_path)
413
+ assert len(df) == 6
414
+ feat_cols = [c for c in df.columns if c.startswith("feat_")]
415
+ # All-zero-variance fallback: features pass through unchanged.
416
+ assert df[feat_cols].notna().all().all()
417
+ log_output = buf.getvalue()
418
+ assert "ComBat skipped" in log_output
419
+
420
+ def test_run_pipeline_extraction_log_precedes_write(self, tmp_path: Path) -> None:
421
+ """The 'Feature extraction complete' INFO must fire BEFORE the
422
+ 'Wrote processed features' INFO so that operators get a summary
423
+ even if to_parquet raises."""
424
+ import io
425
+ import logging
426
+
427
+ from src.core.logger import get_logger
428
+ from src.pipelines import mri_pipeline as mod
429
+
430
+ raw_dir, sites_csv, output_path = self._stage_inputs(tmp_path)
431
+
432
+ logger = get_logger(mod.__name__, level=logging.INFO)
433
+ handler = logger.handlers[0]
434
+ buf = io.StringIO()
435
+ original_stream = handler.stream
436
+ handler.stream = buf
437
+ try:
438
+ run_pipeline(
439
+ input_dir=raw_dir, sites_csv=sites_csv, output_path=output_path,
440
+ )
441
+ finally:
442
+ handler.stream = original_stream
443
+
444
+ log_output = buf.getvalue()
445
+ extract_idx = log_output.index("Feature extraction complete:")
446
+ wrote_idx = log_output.index("Wrote processed features to")
447
+ assert extract_idx < wrote_idx, "extraction summary must precede write log"