mekosotto Claude Sonnet 4.6 commited on
Commit
4d00e0f
·
1 Parent(s): 853cb9e

refactor(mri): bind ROI_STATS to callables; guard volume/mask shape mismatch

Browse files
src/pipelines/mri_pipeline.py CHANGED
@@ -98,7 +98,6 @@ def mask_brain(
98
  # Default ROI partition: split a (D, H, W) volume into 2×2×2 = 8 octant ROIs.
99
  # Octant index follows binary (z, y, x) ordering: 0..7.
100
  DEFAULT_N_ROI_AXES: tuple[int, int, int] = (2, 2, 2)
101
- ROI_STATS: tuple[str, ...] = ("mean", "std", "p10", "p50", "p90", "voxel_count")
102
 
103
 
104
  def _roi_slices(
@@ -123,18 +122,32 @@ def _roi_slices(
123
  return out
124
 
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def _roi_stats_for(values: np.ndarray) -> dict[str, float]:
127
- """Compute the 6 ROI stats. Empty array → all 0.0 (no-NaN contract)."""
128
  if values.size == 0:
129
- return {stat: 0.0 for stat in ROI_STATS}
130
- return {
131
- "mean": float(values.mean()),
132
- "std": float(values.std()),
133
- "p10": float(np.percentile(values, 10)),
134
- "p50": float(np.percentile(values, 50)),
135
- "p90": float(np.percentile(values, 90)),
136
- "voxel_count": float(values.size),
137
- }
138
 
139
 
140
  def extract_features_from_volume(
@@ -150,6 +163,13 @@ def extract_features_from_volume(
150
  90th percentile / voxel count. Empty ROIs (no mask voxels) report all
151
  zeros so the resulting Parquet has no NaN values.
152
 
 
 
 
 
 
 
 
153
  Args:
154
  volume: 3-D numeric `np.ndarray` (already validated).
155
  mask: Boolean `np.ndarray` of the same shape (from `mask_brain`).
@@ -158,7 +178,15 @@ def extract_features_from_volume(
158
  Returns:
159
  Flat dict `{"feat_roi{i}_{stat}": float}` of length
160
  ``prod(n_roi_axes) * len(ROI_STATS)``.
 
 
 
161
  """
 
 
 
 
 
162
  feats: dict[str, float] = {}
163
  slices = _roi_slices(volume.shape, n_roi_axes)
164
  for i, sl in enumerate(slices):
 
98
  # Default ROI partition: split a (D, H, W) volume into 2×2×2 = 8 octant ROIs.
99
  # Octant index follows binary (z, y, x) ordering: 0..7.
100
  DEFAULT_N_ROI_AXES: tuple[int, int, int] = (2, 2, 2)
 
101
 
102
 
103
  def _roi_slices(
 
122
  return out
123
 
124
 
125
+ # Statistical functions, bound to their column-label names. The `ROI_STATS`
126
+ # tuple below is derived from this list so labels and computations cannot
127
+ # drift out of sync (a class of bug the prior parallel-list design was
128
+ # vulnerable to — same pattern as EEG's _STATS_FUNCS).
129
+ #
130
+ # `mean`/`std` use NumPy with `ddof=0` (biased / population estimators).
131
+ # `p10`/`p50`/`p90` use `np.percentile` default linear interpolation.
132
+ # `voxel_count` is stored as float for column-uniformity in the eventual
133
+ # Parquet, but always represents a whole number (assertable via
134
+ # `v == float(int(v))`).
135
+ _ROI_STATS_FUNCS: tuple[tuple[str, "object"], ...] = (
136
+ ("mean", lambda v: float(v.mean())),
137
+ ("std", lambda v: float(v.std())),
138
+ ("p10", lambda v: float(np.percentile(v, 10))),
139
+ ("p50", lambda v: float(np.percentile(v, 50))),
140
+ ("p90", lambda v: float(np.percentile(v, 90))),
141
+ ("voxel_count", lambda v: float(v.size)),
142
+ )
143
+ ROI_STATS: tuple[str, ...] = tuple(name for name, _ in _ROI_STATS_FUNCS)
144
+
145
+
146
  def _roi_stats_for(values: np.ndarray) -> dict[str, float]:
147
+ """Compute the ROI stats. Empty array → all 0.0 (no-NaN contract)."""
148
  if values.size == 0:
149
+ return {name: 0.0 for name, _ in _ROI_STATS_FUNCS}
150
+ return {name: fn(values) for name, fn in _ROI_STATS_FUNCS}
 
 
 
 
 
 
 
151
 
152
 
153
  def extract_features_from_volume(
 
163
  90th percentile / voxel count. Empty ROIs (no mask voxels) report all
164
  zeros so the resulting Parquet has no NaN values.
165
 
166
+ Statistical conventions:
167
+ - ``mean`` / ``std`` use ``ddof=0`` (biased / population estimators).
168
+ - ``p10`` / ``p50`` / ``p90`` use ``np.percentile`` with the default
169
+ linear interpolation.
170
+ - ``voxel_count`` is stored as float for column uniformity but always
171
+ represents a whole number.
172
+
173
  Args:
174
  volume: 3-D numeric `np.ndarray` (already validated).
175
  mask: Boolean `np.ndarray` of the same shape (from `mask_brain`).
 
178
  Returns:
179
  Flat dict `{"feat_roi{i}_{stat}": float}` of length
180
  ``prod(n_roi_axes) * len(ROI_STATS)``.
181
+
182
+ Raises:
183
+ ValueError: if `volume.shape` and `mask.shape` differ.
184
  """
185
+ if volume.shape != mask.shape:
186
+ raise ValueError(
187
+ f"volume.shape {volume.shape} != mask.shape {mask.shape}"
188
+ )
189
+
190
  feats: dict[str, float] = {}
191
  slices = _roi_slices(volume.shape, n_roi_axes)
192
  for i, sl in enumerate(slices):
tests/pipelines/test_mri_pipeline.py CHANGED
@@ -189,3 +189,17 @@ class TestExtractFeaturesFromVolume:
189
  a = extract_features_from_volume(vol, mask)
190
  b = extract_features_from_volume(vol, mask)
191
  assert a == b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  a = extract_features_from_volume(vol, mask)
190
  b = extract_features_from_volume(vol, mask)
191
  assert a == b
192
+
193
+ def test_roi_stats_labels_and_funcs_stay_in_sync(self) -> None:
194
+ """ROI_STATS labels must equal the names in _ROI_STATS_FUNCS — single source of truth."""
195
+ from src.pipelines.mri_pipeline import _ROI_STATS_FUNCS
196
+
197
+ derived_names = tuple(name for name, _ in _ROI_STATS_FUNCS)
198
+ assert derived_names == ROI_STATS
199
+
200
+ def test_raises_on_shape_mismatch(self) -> None:
201
+ """volume.shape and mask.shape must agree — the contract is enforced."""
202
+ vol = np.zeros((8, 8, 8), dtype=np.float64)
203
+ bad_mask = np.zeros((4, 4, 4), dtype=bool)
204
+ with pytest.raises(ValueError, match=r"volume\.shape .* != mask\.shape"):
205
+ extract_features_from_volume(vol, bad_mask)