fix(mri): handle all-constant features; tighten variance threshold; reorder log
Browse files
src/pipelines/mri_pipeline.py
CHANGED
|
@@ -266,6 +266,14 @@ DEFAULT_INPUT = Path("data/raw/mri")
|
|
| 266 |
DEFAULT_OUTPUT = Path("data/processed/mri_features.parquet")
|
| 267 |
|
| 268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
def _list_nifti_volumes(input_dir: Path) -> list[Path]:
|
| 270 |
"""Return sorted list of .nii / .nii.gz files in `input_dir`."""
|
| 271 |
return sorted(
|
|
@@ -326,8 +334,7 @@ def run_pipeline(
|
|
| 326 |
continue
|
| 327 |
mask = mask_brain(volume, intensity_threshold=intensity_threshold)
|
| 328 |
feats = extract_features_from_volume(volume, mask, n_roi_axes=n_roi_axes)
|
| 329 |
-
|
| 330 |
-
rows.append(feats)
|
| 331 |
|
| 332 |
n_total = len(nifti_paths)
|
| 333 |
n_dropped = len(invalid_subject_ids)
|
|
@@ -373,18 +380,31 @@ def run_pipeline(
|
|
| 373 |
f"sites_csv missing site assignment for subjects: {missing}"
|
| 374 |
)
|
| 375 |
|
| 376 |
-
# ComBat cannot handle zero-variance columns
|
| 377 |
-
# Split feature_cols
|
| 378 |
-
|
| 379 |
-
|
|
|
|
|
|
|
| 380 |
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
|
| 389 |
final = pd.concat(
|
| 390 |
[raw_features[["subject_id", "site"]].reset_index(drop=True),
|
|
@@ -392,6 +412,11 @@ def run_pipeline(
|
|
| 392 |
axis=1,
|
| 393 |
)
|
| 394 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 396 |
if output_path.is_dir():
|
| 397 |
raise IsADirectoryError(
|
|
@@ -402,10 +427,6 @@ def run_pipeline(
|
|
| 402 |
final.to_parquet(
|
| 403 |
output_path, index=False, engine="pyarrow", compression="snappy",
|
| 404 |
)
|
| 405 |
-
logger.info(
|
| 406 |
-
"Feature extraction complete: in=%d, out=%d, dropped=%d (%.2f%%)",
|
| 407 |
-
n_total, len(final), n_dropped, 100.0 * n_dropped / max(n_total, 1),
|
| 408 |
-
)
|
| 409 |
logger.info(
|
| 410 |
"Wrote processed features to %s (rows=%d, cols=%d)",
|
| 411 |
output_path, len(final), final.shape[1],
|
|
|
|
| 266 |
DEFAULT_OUTPUT = Path("data/processed/mri_features.parquet")
|
| 267 |
|
| 268 |
|
| 269 |
+
# Floor, compared against each column's ``std()``, used to decide whether a feature column is "constant" for
|
| 270 |
+
# ComBat. Strict ``std() > 0`` would still send near-zero-variance columns
|
| 271 |
+
# (e.g. ULP-level differences) into ComBat, where var_pooled ≈ 0 produces
|
| 272 |
+
# NaN. 1e-8 is well above machine epsilon and far below any biologically
|
| 273 |
+
# meaningful signal variance.
|
| 274 |
+
_MIN_VAR_THRESHOLD: float = 1e-8
|
| 275 |
+
|
| 276 |
+
|
| 277 |
def _list_nifti_volumes(input_dir: Path) -> list[Path]:
|
| 278 |
"""Return sorted list of .nii / .nii.gz files in `input_dir`."""
|
| 279 |
return sorted(
|
|
|
|
| 334 |
continue
|
| 335 |
mask = mask_brain(volume, intensity_threshold=intensity_threshold)
|
| 336 |
feats = extract_features_from_volume(volume, mask, n_roi_axes=n_roi_axes)
|
| 337 |
+
rows.append({"subject_id": subject_id, **feats})
|
|
|
|
| 338 |
|
| 339 |
n_total = len(nifti_paths)
|
| 340 |
n_dropped = len(invalid_subject_ids)
|
|
|
|
| 380 |
f"sites_csv missing site assignment for subjects: {missing}"
|
| 381 |
)
|
| 382 |
|
| 383 |
+
# ComBat cannot handle (near-)zero-variance columns: var_pooled ≈ 0 produces
|
| 384 |
+
# NaN. Split feature_cols on a strictly-positive floor applied to each column's std() so ULP-level
|
| 385 |
+
# noise is treated as constant.
|
| 386 |
+
col_std = raw_features[feature_cols].std()
|
| 387 |
+
var_feature_cols = [c for c in feature_cols if col_std[c] > _MIN_VAR_THRESHOLD]
|
| 388 |
+
zero_var_cols = [c for c in feature_cols if col_std[c] <= _MIN_VAR_THRESHOLD]
|
| 389 |
|
| 390 |
+
if not var_feature_cols:
|
| 391 |
+
# Degenerate dataset: every feature is essentially constant. ComBat has
|
| 392 |
+
# no signal to harmonize on; pass all columns through and warn.
|
| 393 |
+
logger.warning(
|
| 394 |
+
"All %d feature columns have variance ≤ %.1e; ComBat skipped "
|
| 395 |
+
"(output contains unharmonized features).",
|
| 396 |
+
len(feature_cols), _MIN_VAR_THRESHOLD,
|
| 397 |
+
)
|
| 398 |
+
harmonized = raw_features[feature_cols].copy()
|
| 399 |
+
else:
|
| 400 |
+
harmonized = harmonize_combat(
|
| 401 |
+
raw_features, raw_features["site"], var_feature_cols,
|
| 402 |
+
)
|
| 403 |
+
# Re-attach zero-variance columns (unchanged) and restore the original
|
| 404 |
+
# column order.
|
| 405 |
+
for c in zero_var_cols:
|
| 406 |
+
harmonized[c] = raw_features[c].to_numpy()
|
| 407 |
+
harmonized = harmonized[feature_cols]
|
| 408 |
|
| 409 |
final = pd.concat(
|
| 410 |
[raw_features[["subject_id", "site"]].reset_index(drop=True),
|
|
|
|
| 412 |
axis=1,
|
| 413 |
)
|
| 414 |
|
| 415 |
+
logger.info(
|
| 416 |
+
"Feature extraction complete: in=%d, out=%d, dropped=%d (%.2f%%)",
|
| 417 |
+
n_total, len(final), n_dropped, 100.0 * n_dropped / max(n_total, 1),
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 421 |
if output_path.is_dir():
|
| 422 |
raise IsADirectoryError(
|
|
|
|
| 427 |
final.to_parquet(
|
| 428 |
output_path, index=False, engine="pyarrow", compression="snappy",
|
| 429 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
logger.info(
|
| 431 |
"Wrote processed features to %s (rows=%d, cols=%d)",
|
| 432 |
output_path, len(final), final.shape[1],
|
tests/pipelines/test_mri_pipeline.py
CHANGED
|
@@ -378,3 +378,70 @@ class TestRunPipeline:
|
|
| 378 |
# 5 surviving valid subjects (subject_5 dropped).
|
| 379 |
assert len(df) == 5
|
| 380 |
assert "subject_5" not in df["subject_id"].tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
# 5 surviving valid subjects (subject_5 dropped).
|
| 379 |
assert len(df) == 5
|
| 380 |
assert "subject_5" not in df["subject_id"].tolist()
|
| 381 |
+
|
| 382 |
+
def test_run_pipeline_handles_all_constant_features(self, tmp_path: Path) -> None:
|
| 383 |
+
"""Degenerate dataset: every feature column is constant — ComBat must be
|
| 384 |
+
skipped gracefully with a WARNING, not crash with ValueError."""
|
| 385 |
+
import io
|
| 386 |
+
import logging
|
| 387 |
+
|
| 388 |
+
from src.core.logger import get_logger
|
| 389 |
+
from src.pipelines import mri_pipeline as mod
|
| 390 |
+
|
| 391 |
+
raw_dir, sites_csv, output_path = self._stage_inputs(tmp_path)
|
| 392 |
+
# Overwrite all volumes with the same constant intensity so every
|
| 393 |
+
# feature column is identical across subjects.
|
| 394 |
+
affine = np.eye(4)
|
| 395 |
+
for nii in sorted(raw_dir.glob("*.nii.gz")):
|
| 396 |
+
const_vol = np.full((8, 8, 8), 7.0, dtype=np.float64)
|
| 397 |
+
nib.save(nib.Nifti1Image(const_vol, affine=affine), nii)
|
| 398 |
+
|
| 399 |
+
logger = get_logger(mod.__name__, level=logging.INFO)
|
| 400 |
+
handler = logger.handlers[0]
|
| 401 |
+
buf = io.StringIO()
|
| 402 |
+
original_stream = handler.stream
|
| 403 |
+
handler.stream = buf
|
| 404 |
+
try:
|
| 405 |
+
run_pipeline(
|
| 406 |
+
input_dir=raw_dir, sites_csv=sites_csv,
|
| 407 |
+
output_path=output_path, intensity_threshold=1.0,
|
| 408 |
+
)
|
| 409 |
+
finally:
|
| 410 |
+
handler.stream = original_stream
|
| 411 |
+
|
| 412 |
+
df = pd.read_parquet(output_path)
|
| 413 |
+
assert len(df) == 6
|
| 414 |
+
feat_cols = [c for c in df.columns if c.startswith("feat_")]
|
| 415 |
+
# All-zero-variance fallback: features pass through unchanged.
|
| 416 |
+
assert df[feat_cols].notna().all().all()
|
| 417 |
+
log_output = buf.getvalue()
|
| 418 |
+
assert "ComBat skipped" in log_output
|
| 419 |
+
|
| 420 |
+
def test_run_pipeline_extraction_log_precedes_write(self, tmp_path: Path) -> None:
|
| 421 |
+
"""The 'Feature extraction complete' INFO must fire BEFORE the
|
| 422 |
+
'Wrote processed features' INFO so that operators get a summary
|
| 423 |
+
even if to_parquet raises."""
|
| 424 |
+
import io
|
| 425 |
+
import logging
|
| 426 |
+
|
| 427 |
+
from src.core.logger import get_logger
|
| 428 |
+
from src.pipelines import mri_pipeline as mod
|
| 429 |
+
|
| 430 |
+
raw_dir, sites_csv, output_path = self._stage_inputs(tmp_path)
|
| 431 |
+
|
| 432 |
+
logger = get_logger(mod.__name__, level=logging.INFO)
|
| 433 |
+
handler = logger.handlers[0]
|
| 434 |
+
buf = io.StringIO()
|
| 435 |
+
original_stream = handler.stream
|
| 436 |
+
handler.stream = buf
|
| 437 |
+
try:
|
| 438 |
+
run_pipeline(
|
| 439 |
+
input_dir=raw_dir, sites_csv=sites_csv, output_path=output_path,
|
| 440 |
+
)
|
| 441 |
+
finally:
|
| 442 |
+
handler.stream = original_stream
|
| 443 |
+
|
| 444 |
+
log_output = buf.getvalue()
|
| 445 |
+
extract_idx = log_output.index("Feature extraction complete:")
|
| 446 |
+
wrote_idx = log_output.index("Wrote processed features to")
|
| 447 |
+
assert extract_idx < wrote_idx, "extraction summary must precede write log"
|