mekosotto Claude Opus 4.7 (1M context) committed on
Commit
8da57c6
·
1 Parent(s): a1ab9ac

refactor(eeg): bind STATS labels to callables; document moment conventions

Browse files
src/pipelines/eeg_pipeline.py CHANGED
@@ -170,7 +170,6 @@ EEG_BANDS: dict[str, tuple[float, float]] = {
170
  "beta": (13.0, 30.0),
171
  "gamma": (30.0, 40.0),
172
  }
173
- STATS: tuple[str, ...] = ("mean", "std", "var", "skew", "kurtosis")
174
 
175
 
176
  def _band_power(freqs: np.ndarray, psd: np.ndarray, lo: float, hi: float) -> float:
@@ -181,6 +180,44 @@ def _band_power(freqs: np.ndarray, psd: np.ndarray, lo: float, hi: float) -> flo
181
  return float(psd[mask].mean())
182
 
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  def compute_features_from_epoch(epoch: np.ndarray, sfreq: float) -> np.ndarray:
185
  """Compute PSD-band + statistical features for one epoch.
186
 
@@ -190,12 +227,25 @@ def compute_features_from_epoch(epoch: np.ndarray, sfreq: float) -> np.ndarray:
190
  Channels are stacked in their input order. The resulting 1-D vector has
191
  length ``n_channels * (len(EEG_BANDS) + len(STATS))``.
192
 
193
- PSD is computed with Welch's method (`scipy.signal.welch`) at the
194
- epoch's sample rate. Higher moments use `scipy.stats` with default
195
- bias correction.
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  Args:
198
- epoch: A 2-D array shape (n_channels, n_samples).
199
  sfreq: Sampling rate in Hz.
200
 
201
  Returns:
@@ -209,9 +259,6 @@ def compute_features_from_epoch(epoch: np.ndarray, sfreq: float) -> np.ndarray:
209
  freqs, psd = scipy_signal.welch(x, fs=sfreq, nperseg=nperseg)
210
  for _band, (lo, hi) in EEG_BANDS.items():
211
  feats.append(_band_power(freqs, psd, lo, hi))
212
- feats.append(float(np.mean(x)))
213
- feats.append(float(np.std(x)))
214
- feats.append(float(np.var(x)))
215
- feats.append(float(scipy_stats.skew(x)))
216
- feats.append(float(scipy_stats.kurtosis(x)))
217
  return np.asarray(feats, dtype=np.float64)
 
170
  "beta": (13.0, 30.0),
171
  "gamma": (30.0, 40.0),
172
  }
 
173
 
174
 
175
  def _band_power(freqs: np.ndarray, psd: np.ndarray, lo: float, hi: float) -> float:
 
180
  return float(psd[mask].mean())
181
 
182
 
183
+ # Statistical-moment functions, bound to their column-label names. The
184
+ # `STATS` tuple below is derived from this list so labels and computations
185
+ # can never drift out of sync (a class of bug the original parallel-list
186
+ # design was vulnerable to).
187
+ _STATS_FUNCS: tuple[tuple[str, "_StatFn"], ...] # populated below
188
+ _StatFn = "callable that maps a 1-D channel array to a single float"
189
+
190
+
191
+ def _stat_mean(x: np.ndarray) -> float:
192
+ return float(np.mean(x))
193
+
194
+
195
+ def _stat_std(x: np.ndarray) -> float:
196
+ return float(np.std(x))
197
+
198
+
199
+ def _stat_var(x: np.ndarray) -> float:
200
+ return float(np.var(x))
201
+
202
+
203
+ def _stat_skew(x: np.ndarray) -> float:
204
+ return float(scipy_stats.skew(x))
205
+
206
+
207
+ def _stat_kurtosis(x: np.ndarray) -> float:
208
+ return float(scipy_stats.kurtosis(x))
209
+
210
+
211
+ _STATS_FUNCS = (
212
+ ("mean", _stat_mean),
213
+ ("std", _stat_std),
214
+ ("var", _stat_var),
215
+ ("skew", _stat_skew),
216
+ ("kurtosis", _stat_kurtosis),
217
+ )
218
+ STATS: tuple[str, ...] = tuple(name for name, _ in _STATS_FUNCS)
219
+
220
+
221
  def compute_features_from_epoch(epoch: np.ndarray, sfreq: float) -> np.ndarray:
222
  """Compute PSD-band + statistical features for one epoch.
223
 
 
227
  Channels are stacked in their input order. The resulting 1-D vector has
228
  length ``n_channels * (len(EEG_BANDS) + len(STATS))``.
229
 
230
+ PSD uses Welch's method (`scipy.signal.welch`, `nperseg=min(256, n_samples)`).
231
+ For meaningful Welch averaging, the epoch should contain at least
232
+ `2 * nperseg` samples (e.g. ≥2 seconds at 256 Hz); shorter epochs degrade
233
+ to a single-segment periodogram with high estimation variance.
234
+
235
+ Statistical conventions:
236
+ - ``mean`` is the arithmetic mean; ``std`` and ``var`` use NumPy's default
237
+ ``ddof=0`` (biased / population estimators). For sample statistics callers
238
+ must apply the ``ddof=1`` (Bessel) correction downstream.
239
+ - ``skew`` uses ``scipy.stats.skew(bias=True)`` (biased estimator).
240
+ - ``kurtosis`` uses ``scipy.stats.kurtosis(fisher=True, bias=True)`` —
241
+ Fisher's *excess* kurtosis (Gaussian → 0, not 3). Add 3 if Pearson
242
+ kurtosis is required downstream.
243
+
244
+ Precondition: `epoch` must be finite (no NaN/inf). Filter via
245
+ `is_valid_epoch` before calling — feature values are NaN-propagating.
246
 
247
  Args:
248
+ epoch: A 2-D array shape (n_channels, n_samples), all-finite.
249
  sfreq: Sampling rate in Hz.
250
 
251
  Returns:
 
259
  freqs, psd = scipy_signal.welch(x, fs=sfreq, nperseg=nperseg)
260
  for _band, (lo, hi) in EEG_BANDS.items():
261
  feats.append(_band_power(freqs, psd, lo, hi))
262
+ for _name, fn in _STATS_FUNCS:
263
+ feats.append(fn(x))
 
 
 
264
  return np.asarray(feats, dtype=np.float64)
tests/pipelines/test_eeg_pipeline.py CHANGED
@@ -219,3 +219,10 @@ class TestComputeFeaturesFromEpoch:
219
  a = compute_features_from_epoch(epoch, sfreq=256.0)
220
  b = compute_features_from_epoch(epoch, sfreq=256.0)
221
  np.testing.assert_array_equal(a, b)
 
 
 
 
 
 
 
 
219
  a = compute_features_from_epoch(epoch, sfreq=256.0)
220
  b = compute_features_from_epoch(epoch, sfreq=256.0)
221
  np.testing.assert_array_equal(a, b)
222
+
223
+ def test_stats_labels_and_funcs_stay_in_sync(self) -> None:
224
+ """STATS labels must equal the names in _STATS_FUNCS — single source of truth."""
225
+ from src.pipelines.eeg_pipeline import _STATS_FUNCS
226
+
227
+ derived_names = tuple(name for name, _ in _STATS_FUNCS)
228
+ assert derived_names == STATS