refactor(mri): bind ROI_STATS to callables; guard volume/mask shape mismatch
Browse files
src/pipelines/mri_pipeline.py
CHANGED
|
@@ -98,7 +98,6 @@ def mask_brain(
|
|
| 98 |
# Default ROI partition: split a (D, H, W) volume into 2×2×2 = 8 octant ROIs.
|
| 99 |
# Octant index follows binary (z, y, x) ordering: 0..7.
|
| 100 |
DEFAULT_N_ROI_AXES: tuple[int, int, int] = (2, 2, 2)
|
| 101 |
-
ROI_STATS: tuple[str, ...] = ("mean", "std", "p10", "p50", "p90", "voxel_count")
|
| 102 |
|
| 103 |
|
| 104 |
def _roi_slices(
|
|
@@ -123,18 +122,32 @@ def _roi_slices(
|
|
| 123 |
return out
|
| 124 |
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
def _roi_stats_for(values: np.ndarray) -> dict[str, float]:
|
| 127 |
-
"""Compute the
|
| 128 |
if values.size == 0:
|
| 129 |
-
return {
|
| 130 |
-
return {
|
| 131 |
-
"mean": float(values.mean()),
|
| 132 |
-
"std": float(values.std()),
|
| 133 |
-
"p10": float(np.percentile(values, 10)),
|
| 134 |
-
"p50": float(np.percentile(values, 50)),
|
| 135 |
-
"p90": float(np.percentile(values, 90)),
|
| 136 |
-
"voxel_count": float(values.size),
|
| 137 |
-
}
|
| 138 |
|
| 139 |
|
| 140 |
def extract_features_from_volume(
|
|
@@ -150,6 +163,13 @@ def extract_features_from_volume(
|
|
| 150 |
90th percentile / voxel count. Empty ROIs (no mask voxels) report all
|
| 151 |
zeros so the resulting Parquet has no NaN values.
|
| 152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
Args:
|
| 154 |
volume: 3-D numeric `np.ndarray` (already validated).
|
| 155 |
mask: Boolean `np.ndarray` of the same shape (from `mask_brain`).
|
|
@@ -158,7 +178,15 @@ def extract_features_from_volume(
|
|
| 158 |
Returns:
|
| 159 |
Flat dict `{"feat_roi{i}_{stat}": float}` of length
|
| 160 |
``prod(n_roi_axes) * len(ROI_STATS)``.
|
|
|
|
|
|
|
|
|
|
| 161 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
feats: dict[str, float] = {}
|
| 163 |
slices = _roi_slices(volume.shape, n_roi_axes)
|
| 164 |
for i, sl in enumerate(slices):
|
|
|
|
| 98 |
# Default ROI partition: split a (D, H, W) volume into 2×2×2 = 8 octant ROIs.
|
| 99 |
# Octant index follows binary (z, y, x) ordering: 0..7.
|
| 100 |
DEFAULT_N_ROI_AXES: tuple[int, int, int] = (2, 2, 2)
|
|
|
|
| 101 |
|
| 102 |
|
| 103 |
def _roi_slices(
|
|
|
|
| 122 |
return out
|
| 123 |
|
| 124 |
|
| 125 |
+
# Statistical functions, bound to their column-label names. The `ROI_STATS`
|
| 126 |
+
# tuple below is derived from this list so labels and computations cannot
|
| 127 |
+
# drift out of sync (a class of bug the prior parallel-list design was
|
| 128 |
+
# vulnerable to — same pattern as EEG's _STATS_FUNCS).
|
| 129 |
+
#
|
| 130 |
+
# `mean`/`std` use NumPy with `ddof=0` (biased / population estimators).
|
| 131 |
+
# `p10`/`p50`/`p90` use `np.percentile` default linear interpolation.
|
| 132 |
+
# `voxel_count` is stored as float for column-uniformity in the eventual
|
| 133 |
+
# Parquet, but always represents a whole number (assertable via
|
| 134 |
+
# `v == float(int(v))`).
|
| 135 |
+
_ROI_STATS_FUNCS: tuple[tuple[str, "object"], ...] = (
|
| 136 |
+
("mean", lambda v: float(v.mean())),
|
| 137 |
+
("std", lambda v: float(v.std())),
|
| 138 |
+
("p10", lambda v: float(np.percentile(v, 10))),
|
| 139 |
+
("p50", lambda v: float(np.percentile(v, 50))),
|
| 140 |
+
("p90", lambda v: float(np.percentile(v, 90))),
|
| 141 |
+
("voxel_count", lambda v: float(v.size)),
|
| 142 |
+
)
|
| 143 |
+
ROI_STATS: tuple[str, ...] = tuple(name for name, _ in _ROI_STATS_FUNCS)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
def _roi_stats_for(values: np.ndarray) -> dict[str, float]:
|
| 147 |
+
"""Compute the ROI stats. Empty array → all 0.0 (no-NaN contract)."""
|
| 148 |
if values.size == 0:
|
| 149 |
+
return {name: 0.0 for name, _ in _ROI_STATS_FUNCS}
|
| 150 |
+
return {name: fn(values) for name, fn in _ROI_STATS_FUNCS}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
|
| 153 |
def extract_features_from_volume(
|
|
|
|
| 163 |
90th percentile / voxel count. Empty ROIs (no mask voxels) report all
|
| 164 |
zeros so the resulting Parquet has no NaN values.
|
| 165 |
|
| 166 |
+
Statistical conventions:
|
| 167 |
+
- ``mean`` / ``std`` use ``ddof=0`` (biased / population estimators).
|
| 168 |
+
- ``p10`` / ``p50`` / ``p90`` use ``np.percentile`` with the default
|
| 169 |
+
linear interpolation.
|
| 170 |
+
- ``voxel_count`` is stored as float for column uniformity but always
|
| 171 |
+
represents a whole number.
|
| 172 |
+
|
| 173 |
Args:
|
| 174 |
volume: 3-D numeric `np.ndarray` (already validated).
|
| 175 |
mask: Boolean `np.ndarray` of the same shape (from `mask_brain`).
|
|
|
|
| 178 |
Returns:
|
| 179 |
Flat dict `{"feat_roi{i}_{stat}": float}` of length
|
| 180 |
``prod(n_roi_axes) * len(ROI_STATS)``.
|
| 181 |
+
|
| 182 |
+
Raises:
|
| 183 |
+
ValueError: if `volume.shape` and `mask.shape` differ.
|
| 184 |
"""
|
| 185 |
+
if volume.shape != mask.shape:
|
| 186 |
+
raise ValueError(
|
| 187 |
+
f"volume.shape {volume.shape} != mask.shape {mask.shape}"
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
feats: dict[str, float] = {}
|
| 191 |
slices = _roi_slices(volume.shape, n_roi_axes)
|
| 192 |
for i, sl in enumerate(slices):
|
tests/pipelines/test_mri_pipeline.py
CHANGED
|
@@ -189,3 +189,17 @@ class TestExtractFeaturesFromVolume:
|
|
| 189 |
a = extract_features_from_volume(vol, mask)
|
| 190 |
b = extract_features_from_volume(vol, mask)
|
| 191 |
assert a == b
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
a = extract_features_from_volume(vol, mask)
|
| 190 |
b = extract_features_from_volume(vol, mask)
|
| 191 |
assert a == b
|
| 192 |
+
|
| 193 |
+
def test_roi_stats_labels_and_funcs_stay_in_sync(self) -> None:
|
| 194 |
+
"""ROI_STATS labels must equal the names in _ROI_STATS_FUNCS — single source of truth."""
|
| 195 |
+
from src.pipelines.mri_pipeline import _ROI_STATS_FUNCS
|
| 196 |
+
|
| 197 |
+
derived_names = tuple(name for name, _ in _ROI_STATS_FUNCS)
|
| 198 |
+
assert derived_names == ROI_STATS
|
| 199 |
+
|
| 200 |
+
def test_raises_on_shape_mismatch(self) -> None:
|
| 201 |
+
"""volume.shape and mask.shape must agree — the contract is enforced."""
|
| 202 |
+
vol = np.zeros((8, 8, 8), dtype=np.float64)
|
| 203 |
+
bad_mask = np.zeros((4, 4, 4), dtype=bool)
|
| 204 |
+
with pytest.raises(ValueError, match=r"volume\.shape .* != mask\.shape"):
|
| 205 |
+
extract_features_from_volume(vol, bad_mask)
|