refactor(eeg): bind STATS labels to callables; document moment conventions
Browse files
src/pipelines/eeg_pipeline.py
CHANGED
|
@@ -170,7 +170,6 @@ EEG_BANDS: dict[str, tuple[float, float]] = {
|
|
| 170 |
"beta": (13.0, 30.0),
|
| 171 |
"gamma": (30.0, 40.0),
|
| 172 |
}
|
| 173 |
-
STATS: tuple[str, ...] = ("mean", "std", "var", "skew", "kurtosis")
|
| 174 |
|
| 175 |
|
| 176 |
def _band_power(freqs: np.ndarray, psd: np.ndarray, lo: float, hi: float) -> float:
|
|
@@ -181,6 +180,44 @@ def _band_power(freqs: np.ndarray, psd: np.ndarray, lo: float, hi: float) -> flo
|
|
| 181 |
return float(psd[mask].mean())
|
| 182 |
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
def compute_features_from_epoch(epoch: np.ndarray, sfreq: float) -> np.ndarray:
|
| 185 |
"""Compute PSD-band + statistical features for one epoch.
|
| 186 |
|
|
@@ -190,12 +227,25 @@ def compute_features_from_epoch(epoch: np.ndarray, sfreq: float) -> np.ndarray:
|
|
| 190 |
Channels are stacked in their input order. The resulting 1-D vector has
|
| 191 |
length ``n_channels * (len(EEG_BANDS) + len(STATS))``.
|
| 192 |
|
| 193 |
-
PSD
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
Args:
|
| 198 |
-
epoch: A 2-D array shape (n_channels, n_samples).
|
| 199 |
sfreq: Sampling rate in Hz.
|
| 200 |
|
| 201 |
Returns:
|
|
@@ -209,9 +259,6 @@ def compute_features_from_epoch(epoch: np.ndarray, sfreq: float) -> np.ndarray:
|
|
| 209 |
freqs, psd = scipy_signal.welch(x, fs=sfreq, nperseg=nperseg)
|
| 210 |
for _band, (lo, hi) in EEG_BANDS.items():
|
| 211 |
feats.append(_band_power(freqs, psd, lo, hi))
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
feats.append(float(np.var(x)))
|
| 215 |
-
feats.append(float(scipy_stats.skew(x)))
|
| 216 |
-
feats.append(float(scipy_stats.kurtosis(x)))
|
| 217 |
return np.asarray(feats, dtype=np.float64)
|
|
|
|
| 170 |
"beta": (13.0, 30.0),
|
| 171 |
"gamma": (30.0, 40.0),
|
| 172 |
}
|
|
|
|
| 173 |
|
| 174 |
|
| 175 |
def _band_power(freqs: np.ndarray, psd: np.ndarray, lo: float, hi: float) -> float:
|
|
|
|
| 180 |
return float(psd[mask].mean())
|
| 181 |
|
| 182 |
|
| 183 |
+
# Statistical-moment functions, bound to their column-label names. The
# `STATS` tuple below is derived from this list so labels and computations
# can never drift out of sync (a class of bug the original parallel-list
# design was vulnerable to).
#
# NOTE(review): this import belongs in the module's top-of-file import block;
# it lives here only so this change is self-contained.
from collections.abc import Callable

# A stat function maps one 1-D channel array to a single Python float.
# This must be a real type alias (not a descriptive string) and must be
# defined before the annotation that uses it, otherwise the annotation does
# not name an actual type.
_StatFn = Callable[[np.ndarray], float]

# Forward declaration only: the value is assigned once the `_stat_*` helpers
# have been defined.
_STATS_FUNCS: tuple[tuple[str, _StatFn], ...]  # populated below
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _stat_mean(x: np.ndarray) -> float:
    """Return the arithmetic mean of ``x`` as a built-in ``float``."""
    average = np.mean(x)
    # Convert the NumPy scalar so the feature list holds plain floats.
    return float(average)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _stat_std(x: np.ndarray) -> float:
    """Return the population standard deviation of ``x`` as a ``float``.

    ``ddof=0`` is NumPy's default; it is spelled out here so the estimator
    convention is visible at the call site.
    """
    spread = np.std(x, ddof=0)
    return float(spread)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _stat_var(x: np.ndarray) -> float:
    """Return the population variance of ``x`` as a ``float``.

    ``ddof=0`` is NumPy's default; it is spelled out here so the estimator
    convention is visible at the call site.
    """
    variance = np.var(x, ddof=0)
    return float(variance)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def _stat_skew(x: np.ndarray) -> float:
    """Return the biased sample skewness of ``x`` as a ``float``.

    ``bias=True`` is SciPy's default; it is spelled out here so the estimator
    convention is visible at the call site.
    """
    skewness = scipy_stats.skew(x, bias=True)
    return float(skewness)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def _stat_kurtosis(x: np.ndarray) -> float:
    """Return the biased Fisher (excess) kurtosis of ``x`` as a ``float``.

    ``fisher=True`` and ``bias=True`` are SciPy's defaults; they are spelled
    out here so the convention is visible at the call site.
    """
    excess = scipy_stats.kurtosis(x, fisher=True, bias=True)
    return float(excess)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# Ordered (label, callable) pairs. Column labels are read from this tuple
# rather than maintained in a separate hand-written list.
_STATS_FUNCS = (
    ("mean", _stat_mean),
    ("std", _stat_std),
    ("var", _stat_var),
    ("skew", _stat_skew),
    ("kurtosis", _stat_kurtosis),
)

# Labels only, in the same order as the pairs above.
STATS: tuple[str, ...] = tuple(pair[0] for pair in _STATS_FUNCS)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
def compute_features_from_epoch(epoch: np.ndarray, sfreq: float) -> np.ndarray:
|
| 222 |
"""Compute PSD-band + statistical features for one epoch.
|
| 223 |
|
|
|
|
| 227 |
Channels are stacked in their input order. The resulting 1-D vector has
|
| 228 |
length ``n_channels * (len(EEG_BANDS) + len(STATS))``.
|
| 229 |
|
| 230 |
+
PSD uses Welch's method (`scipy.signal.welch`, `nperseg=min(256, n_samples)`).
|
| 231 |
+
For meaningful Welch averaging, the epoch should contain at least 512
|
| 232 |
+
samples (2 × the 256-sample segment cap, i.e. ≥2 seconds at 256 Hz). Below 384
|
| 233 |
+
samples only one segment fits at SciPy's default 50% overlap, giving a single-segment periodogram with high estimation variance.
|
| 234 |
+
|
| 235 |
+
Statistical conventions:
|
| 236 |
+
- ``std`` and ``var`` use NumPy's default ``ddof=0`` (biased / population
|
| 237 |
+
estimators; ``mean`` takes no ``ddof``). For sample statistics, rescale downstream:
|
| 238 |
+
``var * n / (n - 1)`` and ``std * sqrt(n / (n - 1))`` for ``n`` samples.
|
| 239 |
+
- ``skew`` uses ``scipy.stats.skew(bias=True)`` (biased estimator).
|
| 240 |
+
- ``kurtosis`` uses ``scipy.stats.kurtosis(fisher=True, bias=True)`` —
|
| 241 |
+
Fisher's *excess* kurtosis (Gaussian → 0, not 3). Add 3 if Pearson
|
| 242 |
+
kurtosis is required downstream.
|
| 243 |
+
|
| 244 |
+
Precondition: `epoch` must be finite (no NaN/inf). Filter via
|
| 245 |
+
`is_valid_epoch` before calling — feature values are NaN-propagating.
|
| 246 |
|
| 247 |
Args:
|
| 248 |
+
epoch: A 2-D array of shape (n_channels, n_samples), all-finite.
|
| 249 |
sfreq: Sampling rate in Hz.
|
| 250 |
|
| 251 |
Returns:
|
|
|
|
| 259 |
freqs, psd = scipy_signal.welch(x, fs=sfreq, nperseg=nperseg)
|
| 260 |
for _band, (lo, hi) in EEG_BANDS.items():
|
| 261 |
feats.append(_band_power(freqs, psd, lo, hi))
|
| 262 |
+
for _name, fn in _STATS_FUNCS:
|
| 263 |
+
feats.append(fn(x))
|
|
|
|
|
|
|
|
|
|
| 264 |
return np.asarray(feats, dtype=np.float64)
|
tests/pipelines/test_eeg_pipeline.py
CHANGED
|
@@ -219,3 +219,10 @@ class TestComputeFeaturesFromEpoch:
|
|
| 219 |
a = compute_features_from_epoch(epoch, sfreq=256.0)
|
| 220 |
b = compute_features_from_epoch(epoch, sfreq=256.0)
|
| 221 |
np.testing.assert_array_equal(a, b)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
a = compute_features_from_epoch(epoch, sfreq=256.0)
|
| 220 |
b = compute_features_from_epoch(epoch, sfreq=256.0)
|
| 221 |
np.testing.assert_array_equal(a, b)
|
| 222 |
+
|
| 223 |
+
def test_stats_labels_and_funcs_stay_in_sync(self) -> None:
    """STATS labels, their bound callables, and the feature layout must agree.

    Comparing ``STATS`` with the names in ``_STATS_FUNCS`` only guards
    against ``STATS`` being re-hardcoded, because ``STATS`` is built with
    that same expression. The remaining checks can fail independently.
    """
    from src.pipelines.eeg_pipeline import EEG_BANDS, _STATS_FUNCS

    # Labels: same names in the same order, and no duplicated column label.
    derived_names = tuple(name for name, _ in _STATS_FUNCS)
    assert derived_names == STATS
    assert len(set(derived_names)) == len(derived_names)

    n_channels, n_samples = 3, 512
    epoch = np.random.default_rng(0).standard_normal((n_channels, n_samples))

    # Every bound callable must reduce one channel to a single Python float.
    for name, fn in _STATS_FUNCS:
        value = fn(epoch[0])
        assert isinstance(value, float), name

    # The documented layout is one block of band powers plus one value per
    # STATS label for each channel.
    feats = compute_features_from_epoch(epoch, sfreq=256.0)
    assert feats.shape == (n_channels * (len(EEG_BANDS) + len(STATS)),)
|