feat(mri): add extract_features_from_volume (8 ROI octants × 6 stats)
Browse files
src/pipelines/mri_pipeline.py
CHANGED
|
@@ -93,3 +93,77 @@ def mask_brain(
|
|
| 93 |
float(volume.min()), float(volume.max()), intensity_threshold,
|
| 94 |
)
|
| 95 |
return cleaned
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
float(volume.min()), float(volume.max()), intensity_threshold,
|
| 94 |
)
|
| 95 |
return cleaned
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# Default ROI partition: split a (D, H, W) volume into 2×2×2 = 8 octant ROIs.
|
| 99 |
+
# Octant index i = 4*iz + 2*iy + ix (z varies slowest, x fastest), i.e. binary (z, y, x) ordering: 0..7.
|
| 100 |
+
DEFAULT_N_ROI_AXES: tuple[int, int, int] = (2, 2, 2)
|
| 101 |
+
ROI_STATS: tuple[str, ...] = ("mean", "std", "p10", "p50", "p90", "voxel_count")
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _roi_slices(
|
| 105 |
+
shape: tuple[int, int, int],
|
| 106 |
+
n_roi_axes: tuple[int, int, int],
|
| 107 |
+
) -> list[tuple[slice, slice, slice]]:
|
| 108 |
+
"""Generate the ROI slice list in deterministic (z, y, x) octant order."""
|
| 109 |
+
nz, ny, nx = n_roi_axes
|
| 110 |
+
dz, dy, dx = shape
|
| 111 |
+
bins_z = np.array_split(np.arange(dz), nz)
|
| 112 |
+
bins_y = np.array_split(np.arange(dy), ny)
|
| 113 |
+
bins_x = np.array_split(np.arange(dx), nx)
|
| 114 |
+
out: list[tuple[slice, slice, slice]] = []
|
| 115 |
+
for bz in bins_z:
|
| 116 |
+
for by in bins_y:
|
| 117 |
+
for bx in bins_x:
|
| 118 |
+
out.append((
|
| 119 |
+
slice(bz[0], bz[-1] + 1),
|
| 120 |
+
slice(by[0], by[-1] + 1),
|
| 121 |
+
slice(bx[0], bx[-1] + 1),
|
| 122 |
+
))
|
| 123 |
+
return out
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _roi_stats_for(values: np.ndarray) -> dict[str, float]:
|
| 127 |
+
"""Compute the 6 ROI stats. Empty array → all 0.0 (no-NaN contract)."""
|
| 128 |
+
if values.size == 0:
|
| 129 |
+
return {stat: 0.0 for stat in ROI_STATS}
|
| 130 |
+
return {
|
| 131 |
+
"mean": float(values.mean()),
|
| 132 |
+
"std": float(values.std()),
|
| 133 |
+
"p10": float(np.percentile(values, 10)),
|
| 134 |
+
"p50": float(np.percentile(values, 50)),
|
| 135 |
+
"p90": float(np.percentile(values, 90)),
|
| 136 |
+
"voxel_count": float(values.size),
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def extract_features_from_volume(
    volume: np.ndarray,
    mask: np.ndarray,
    n_roi_axes: tuple[int, int, int] = DEFAULT_N_ROI_AXES,
) -> dict[str, float]:
    """Compute per-ROI summary statistics from a masked volume.

    The volume is partitioned into ``prod(n_roi_axes)`` axis-aligned ROIs
    in deterministic (z, y, x) order. For each ROI, intensity values from
    voxels where `mask` is True are summarized via mean / std / 10th, 50th,
    90th percentile / voxel count. Empty ROIs (no mask voxels) report all
    zeros so the resulting feature table has no NaN values.

    Args:
        volume: 3-D numeric `np.ndarray` (already validated).
        mask: `np.ndarray` of the same shape as `volume` (from `mask_brain`).
            Non-boolean masks are interpreted by truthiness (non-zero → True).
        n_roi_axes: ROI grid along (z, y, x). Default `(2, 2, 2)` → 8 ROIs.

    Returns:
        Flat dict `{"feat_roi{i}_{stat}": float}` of length
        ``prod(n_roi_axes) * len(ROI_STATS)``.

    Raises:
        ValueError: If `mask` and `volume` shapes differ, or if the volume
            shape / ROI grid is rejected by `_roi_slices`.
    """
    # A non-boolean mask (e.g. 0/1 integers) would be treated by NumPy as
    # positional fancy indices and silently select the wrong voxels, so force
    # boolean-mask semantics. This is a no-op for an existing bool array.
    bool_mask = np.asarray(mask, dtype=bool)
    if bool_mask.shape != volume.shape:
        # Without this check an oversized mask would be silently cropped by
        # the ROI slices and a smaller one would fail with an opaque IndexError.
        raise ValueError(
            f"mask shape {bool_mask.shape} does not match "
            f"volume shape {volume.shape}"
        )

    feats: dict[str, float] = {}
    for i, sl in enumerate(_roi_slices(volume.shape, n_roi_axes)):
        # Slice both arrays to the ROI, then keep only masked-in voxels.
        roi_values = volume[sl][bool_mask[sl]]
        for stat_name, stat_val in _roi_stats_for(roi_values).items():
            feats[f"feat_roi{i}_{stat_name}"] = stat_val
    return feats
|
tests/pipelines/test_mri_pipeline.py
CHANGED
|
@@ -8,6 +8,9 @@ import numpy as np
|
|
| 8 |
import pytest
|
| 9 |
|
| 10 |
from src.pipelines.mri_pipeline import (
|
|
|
|
|
|
|
|
|
|
| 11 |
is_valid_volume,
|
| 12 |
mask_brain,
|
| 13 |
)
|
|
@@ -128,3 +131,61 @@ class TestMaskBrain:
|
|
| 128 |
log_output = buf.getvalue()
|
| 129 |
assert "all-False mask" in log_output
|
| 130 |
assert "downstream features for this volume will be all-zero" in log_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
import pytest
|
| 9 |
|
| 10 |
from src.pipelines.mri_pipeline import (
|
| 11 |
+
DEFAULT_N_ROI_AXES,
|
| 12 |
+
ROI_STATS,
|
| 13 |
+
extract_features_from_volume,
|
| 14 |
is_valid_volume,
|
| 15 |
mask_brain,
|
| 16 |
)
|
|
|
|
| 131 |
log_output = buf.getvalue()
|
| 132 |
assert "all-False mask" in log_output
|
| 133 |
assert "downstream features for this volume will be all-zero" in log_output
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class TestExtractFeaturesFromVolume:
    """Contract tests for `extract_features_from_volume` on the `subject_0` fixture."""

    def _load_subject(self, sid: str) -> np.ndarray:
        """Load ``{sid}.nii.gz`` from `FIXTURE_DIR` and return its voxel data."""
        image = nib.load(FIXTURE_DIR / f"{sid}.nii.gz")
        return image.get_fdata()

    def test_returns_dict_with_correct_keys(self) -> None:
        """Key set is exactly one `feat_roi{i}_{stat}` per ROI per stat."""
        volume = self._load_subject("subject_0")
        features = extract_features_from_volume(volume, mask_brain(volume))
        roi_total = int(np.prod(DEFAULT_N_ROI_AXES))
        wanted = set()
        for i in range(roi_total):
            for stat in ROI_STATS:
                wanted.add(f"feat_roi{i}_{stat}")
        assert set(features.keys()) == wanted

    def test_feature_count_matches_contract(self) -> None:
        """Number of features equals (number of ROIs) × (number of stats)."""
        volume = self._load_subject("subject_0")
        features = extract_features_from_volume(volume, mask_brain(volume))
        expected_len = int(np.prod(DEFAULT_N_ROI_AXES)) * len(ROI_STATS)
        assert len(features) == expected_len

    def test_all_features_finite_float(self) -> None:
        """Every feature is a built-in float and is finite."""
        volume = self._load_subject("subject_0")
        features = extract_features_from_volume(volume, mask_brain(volume))
        for k in features:
            v = features[k]
            assert isinstance(v, float), f"{k}: {type(v).__name__}"
            assert np.isfinite(v), f"{k}: {v}"

    def test_voxel_count_is_integer_valued(self) -> None:
        """voxel_count is stored as float for column-uniformity, but must be
        a whole number."""
        volume = self._load_subject("subject_0")
        features = extract_features_from_volume(volume, mask_brain(volume))
        counts = [
            v for k, v in features.items() if k.endswith("_voxel_count")
        ]
        for v in counts:
            assert v == float(int(v))

    def test_empty_mask_yields_zero_features(self) -> None:
        """If a volume has zero brain voxels (mask all False), every stat
        must default to 0.0 — not NaN — to preserve the no-NaN Parquet contract."""
        volume = self._load_subject("subject_0")
        nothing_selected = np.zeros_like(volume, dtype=bool)
        features = extract_features_from_volume(volume, nothing_selected)
        for k in features:
            v = features[k]
            assert v == 0.0, f"{k}: {v}"

    def test_deterministic_for_same_input(self) -> None:
        """Two calls on identical inputs return equal feature dicts."""
        volume = self._load_subject("subject_0")
        brain = mask_brain(volume)
        first = extract_features_from_volume(volume, brain)
        second = extract_features_from_volume(volume, brain)
        assert first == second
|