feat(eeg): add compute_features_from_epoch (PSD bands + 5 statistics)
Browse files
src/pipelines/eeg_pipeline.py
CHANGED
|
@@ -15,6 +15,8 @@ from __future__ import annotations
|
|
| 15 |
import mne
|
| 16 |
import numpy as np
|
| 17 |
from mne.preprocessing import ICA
|
|
|
|
|
|
|
| 18 |
|
| 19 |
from src.core.logger import get_logger
|
| 20 |
|
|
@@ -159,3 +161,57 @@ def remove_artifacts_with_ica(
|
|
| 159 |
)
|
| 160 |
ica.apply(out, verbose="ERROR")
|
| 161 |
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
import mne
|
| 16 |
import numpy as np
|
| 17 |
from mne.preprocessing import ICA
|
| 18 |
+
from scipy import signal as scipy_signal
|
| 19 |
+
from scipy import stats as scipy_stats
|
| 20 |
|
| 21 |
from src.core.logger import get_logger
|
| 22 |
|
|
|
|
| 161 |
)
|
| 162 |
ica.apply(out, verbose="ERROR")
|
| 163 |
return out
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# Frequency bands as (low, high) edges in Hz, listed in the order their
# PSD features are emitted. `_band_power` treats each pair as half-open:
# low edge included, high edge excluded.
EEG_BANDS: dict[str, tuple[float, float]] = dict(
    delta=(1.0, 4.0),
    theta=(4.0, 8.0),
    alpha=(8.0, 13.0),
    beta=(13.0, 30.0),
    gamma=(30.0, 40.0),
)

# Names of the per-channel summary statistics, in emission order.
STATS: tuple[str, ...] = (
    "mean",
    "std",
    "var",
    "skew",
    "kurtosis",
)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _band_power(freqs: np.ndarray, psd: np.ndarray, lo: float, hi: float) -> float:
    """Average the PSD bins whose frequency falls in the half-open band [lo, hi).

    Args:
        freqs: Frequency of each PSD bin.
        psd: PSD values, indexed in step with ``freqs``.
        lo: Lower band edge (included).
        hi: Upper band edge (excluded).

    Returns:
        The mean of the selected PSD values as a Python float, or ``0.0``
        when no bin lies inside the band.
    """
    in_band = np.logical_and(freqs >= lo, freqs < hi)
    # An empty selection would average to NaN; report zero power instead.
    if np.count_nonzero(in_band) == 0:
        return 0.0
    return float(np.mean(psd[in_band]))
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def compute_features_from_epoch(epoch: np.ndarray, sfreq: float) -> np.ndarray:
    """Compute PSD-band and statistical features for one epoch.

    Per channel, the feature block is::

        [psd_delta, psd_theta, psd_alpha, psd_beta, psd_gamma,
         mean, std, var, skew, kurtosis]

    Channels are stacked in their input order, so the resulting 1-D vector
    has length ``n_channels * (len(EEG_BANDS) + len(STATS))``.

    PSD is estimated with Welch's method (``scipy.signal.welch``) at the
    epoch's sample rate, using segments of ``min(256, n_samples)`` samples;
    each band feature is the value ``_band_power`` returns for that band.
    ``std`` and ``var`` are NumPy's defaults (``ddof=0``). ``skew`` and
    ``kurtosis`` use the ``scipy.stats`` defaults: the biased estimators
    (``bias=True``, i.e. *no* bias correction) and Fisher (excess) kurtosis.

    Args:
        epoch: Array-like of shape (n_channels, n_samples).
        sfreq: Sampling rate in Hz.

    Returns:
        A 1-D ``np.ndarray`` of dtype float64.

    Raises:
        ValueError: If ``epoch`` is not 2-D.
    """
    data = np.asarray(epoch)
    if data.ndim != 2:
        # Fail with an explicit message instead of a tuple-unpacking error.
        raise ValueError(
            f"epoch must be 2-D (n_channels, n_samples), got shape {data.shape}"
        )

    # Welch rejects segments longer than the signal, so clamp for short epochs.
    nperseg = min(256, data.shape[1])
    # Only the band edges are needed; dict order fixes the feature order.
    band_edges = tuple(EEG_BANDS.values())

    feats: list[float] = []
    for x in data:
        freqs, psd = scipy_signal.welch(x, fs=sfreq, nperseg=nperseg)
        feats.extend(_band_power(freqs, psd, lo, hi) for lo, hi in band_edges)
        feats.extend(
            (
                float(np.mean(x)),
                float(np.std(x)),
                float(np.var(x)),
                float(scipy_stats.skew(x)),
                float(scipy_stats.kurtosis(x)),
            )
        )
    return np.asarray(feats, dtype=np.float64)
|
tests/pipelines/test_eeg_pipeline.py
CHANGED
|
@@ -9,6 +9,7 @@ import pytest
|
|
| 9 |
|
| 10 |
from src.pipelines.eeg_pipeline import (
|
| 11 |
bandpass_filter,
|
|
|
|
| 12 |
is_valid_epoch,
|
| 13 |
remove_artifacts_with_ica,
|
| 14 |
)
|
|
@@ -17,6 +18,10 @@ from src.pipelines.eeg_pipeline import (
|
|
| 17 |
FIXTURE = Path(__file__).parent.parent / "fixtures" / "eeg_sample.fif"
|
| 18 |
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
class TestIsValidEpoch:
|
| 21 |
def test_accepts_2d_finite_array(self) -> None:
|
| 22 |
epoch = np.zeros((4, 256), dtype=np.float64)
|
|
@@ -175,3 +180,42 @@ class TestRemoveArtifactsWithIca:
|
|
| 175 |
np.testing.assert_allclose(out.get_data(), raw.get_data(), rtol=1e-6, atol=1e-12)
|
| 176 |
log_output = buf.getvalue()
|
| 177 |
assert "ICA skipped: eog_ch_name='EOG_DOES_NOT_EXIST' not found" in log_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
from src.pipelines.eeg_pipeline import (
|
| 11 |
bandpass_filter,
|
| 12 |
+
compute_features_from_epoch,
|
| 13 |
is_valid_epoch,
|
| 14 |
remove_artifacts_with_ica,
|
| 15 |
)
|
|
|
|
| 18 |
FIXTURE = Path(__file__).parent.parent / "fixtures" / "eeg_sample.fif"
|
| 19 |
|
| 20 |
|
| 21 |
+
# Band names in the order the tests index into the per-channel PSD block.
EEG_BANDS = (
    "delta",
    "theta",
    "alpha",
    "beta",
    "gamma",
)
# Statistic names; only their count is used, to size the expected output.
STATS = (
    "mean",
    "std",
    "var",
    "skew",
    "kurtosis",
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
class TestIsValidEpoch:
|
| 26 |
def test_accepts_2d_finite_array(self) -> None:
|
| 27 |
epoch = np.zeros((4, 256), dtype=np.float64)
|
|
|
|
| 180 |
np.testing.assert_allclose(out.get_data(), raw.get_data(), rtol=1e-6, atol=1e-12)
|
| 181 |
log_output = buf.getvalue()
|
| 182 |
assert "ICA skipped: eog_ch_name='EOG_DOES_NOT_EXIST' not found" in log_output
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class TestComputeFeaturesFromEpoch:
    """Contract tests for ``compute_features_from_epoch``."""

    def test_returns_1d_float_array(self) -> None:
        """The result is a one-dimensional float64 NumPy array."""
        rng = np.random.default_rng(0)
        features = compute_features_from_epoch(rng.standard_normal((4, 256)), sfreq=256.0)
        assert isinstance(features, np.ndarray)
        assert features.ndim == 1
        assert features.dtype == np.float64

    def test_feature_count_matches_contract(self) -> None:
        """Each channel contributes len(EEG_BANDS) PSD features + len(STATS) stats."""
        n_channels = 4
        per_channel = len(EEG_BANDS) + len(STATS)
        rng = np.random.default_rng(0)
        features = compute_features_from_epoch(
            rng.standard_normal((n_channels, 256)), sfreq=256.0
        )
        assert features.shape == (n_channels * per_channel,)

    def test_alpha_band_dominates_for_alpha_signal(self) -> None:
        """Pure 10 Hz sine on 1 channel should put most PSD power in alpha (8-13 Hz)."""
        sfreq = 256.0
        n_samples = int(sfreq * 2.0)
        times = np.arange(n_samples) / sfreq
        tone = np.sin(2 * np.pi * 10.0 * times)
        # Add a leading channel axis: shape (1, n_samples).
        features = compute_features_from_epoch(tone[None, :], sfreq=sfreq)
        # With a single channel the PSD band features come first in the vector.
        band_powers = features[: len(EEG_BANDS)]
        assert band_powers[EEG_BANDS.index("alpha")] == band_powers.max()

    def test_finite_output(self) -> None:
        """Random-noise input yields no NaN or infinite features."""
        rng = np.random.default_rng(0)
        features = compute_features_from_epoch(rng.standard_normal((4, 256)), sfreq=256.0)
        assert np.all(np.isfinite(features))

    def test_deterministic_for_same_input(self) -> None:
        """Two calls on the same epoch return exactly equal vectors."""
        rng = np.random.default_rng(0)
        epoch = rng.standard_normal((4, 256))
        first = compute_features_from_epoch(epoch, sfreq=256.0)
        second = compute_features_from_epoch(epoch, sfreq=256.0)
        np.testing.assert_array_equal(first, second)
|