mekosotto Claude Sonnet 4.6 commited on
Commit
a1ab9ac
·
1 Parent(s): 9931366

feat(eeg): add compute_features_from_epoch (PSD bands + 5 statistics)

Browse files
src/pipelines/eeg_pipeline.py CHANGED
@@ -15,6 +15,8 @@ from __future__ import annotations
15
  import mne
16
  import numpy as np
17
  from mne.preprocessing import ICA
 
 
18
 
19
  from src.core.logger import get_logger
20
 
@@ -159,3 +161,57 @@ def remove_artifacts_with_ica(
159
  )
160
  ica.apply(out, verbose="ERROR")
161
  return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  import mne
16
  import numpy as np
17
  from mne.preprocessing import ICA
18
+ from scipy import signal as scipy_signal
19
+ from scipy import stats as scipy_stats
20
 
21
  from src.core.logger import get_logger
22
 
 
161
  )
162
  ica.apply(out, verbose="ERROR")
163
  return out
164
+
165
+
166
+ EEG_BANDS: dict[str, tuple[float, float]] = {
167
+ "delta": (1.0, 4.0),
168
+ "theta": (4.0, 8.0),
169
+ "alpha": (8.0, 13.0),
170
+ "beta": (13.0, 30.0),
171
+ "gamma": (30.0, 40.0),
172
+ }
173
+ STATS: tuple[str, ...] = ("mean", "std", "var", "skew", "kurtosis")
174
+
175
+
176
+ def _band_power(freqs: np.ndarray, psd: np.ndarray, lo: float, hi: float) -> float:
177
+ """Mean PSD value within the [lo, hi) frequency band."""
178
+ mask = (freqs >= lo) & (freqs < hi)
179
+ if not mask.any():
180
+ return 0.0
181
+ return float(psd[mask].mean())
182
+
183
+
184
+ def compute_features_from_epoch(epoch: np.ndarray, sfreq: float) -> np.ndarray:
185
+ """Compute PSD-band + statistical features for one epoch.
186
+
187
+ Per channel, the feature block is:
188
+ [psd_delta, psd_theta, psd_alpha, psd_beta, psd_gamma,
189
+ mean, std, var, skew, kurtosis]
190
+ Channels are stacked in their input order. The resulting 1-D vector has
191
+ length ``n_channels * (len(EEG_BANDS) + len(STATS))``.
192
+
193
+ PSD is computed with Welch's method (`scipy.signal.welch`) at the
194
+ epoch's sample rate. Higher moments use `scipy.stats` with default
195
+ bias correction.
196
+
197
+ Args:
198
+ epoch: A 2-D array shape (n_channels, n_samples).
199
+ sfreq: Sampling rate in Hz.
200
+
201
+ Returns:
202
+ A 1-D `np.ndarray` of dtype float64.
203
+ """
204
+ n_channels, n_samples = epoch.shape
205
+ nperseg = min(256, n_samples)
206
+ feats: list[float] = []
207
+ for ch in range(n_channels):
208
+ x = epoch[ch]
209
+ freqs, psd = scipy_signal.welch(x, fs=sfreq, nperseg=nperseg)
210
+ for _band, (lo, hi) in EEG_BANDS.items():
211
+ feats.append(_band_power(freqs, psd, lo, hi))
212
+ feats.append(float(np.mean(x)))
213
+ feats.append(float(np.std(x)))
214
+ feats.append(float(np.var(x)))
215
+ feats.append(float(scipy_stats.skew(x)))
216
+ feats.append(float(scipy_stats.kurtosis(x)))
217
+ return np.asarray(feats, dtype=np.float64)
tests/pipelines/test_eeg_pipeline.py CHANGED
@@ -9,6 +9,7 @@ import pytest
9
 
10
  from src.pipelines.eeg_pipeline import (
11
  bandpass_filter,
 
12
  is_valid_epoch,
13
  remove_artifacts_with_ica,
14
  )
@@ -17,6 +18,10 @@ from src.pipelines.eeg_pipeline import (
17
  FIXTURE = Path(__file__).parent.parent / "fixtures" / "eeg_sample.fif"
18
 
19
 
 
 
 
 
20
  class TestIsValidEpoch:
21
  def test_accepts_2d_finite_array(self) -> None:
22
  epoch = np.zeros((4, 256), dtype=np.float64)
@@ -175,3 +180,42 @@ class TestRemoveArtifactsWithIca:
175
  np.testing.assert_allclose(out.get_data(), raw.get_data(), rtol=1e-6, atol=1e-12)
176
  log_output = buf.getvalue()
177
  assert "ICA skipped: eog_ch_name='EOG_DOES_NOT_EXIST' not found" in log_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  from src.pipelines.eeg_pipeline import (
11
  bandpass_filter,
12
+ compute_features_from_epoch,
13
  is_valid_epoch,
14
  remove_artifacts_with_ica,
15
  )
 
18
  FIXTURE = Path(__file__).parent.parent / "fixtures" / "eeg_sample.fif"
19
 
20
 
21
+ EEG_BANDS = ("delta", "theta", "alpha", "beta", "gamma")
22
+ STATS = ("mean", "std", "var", "skew", "kurtosis")
23
+
24
+
25
  class TestIsValidEpoch:
26
  def test_accepts_2d_finite_array(self) -> None:
27
  epoch = np.zeros((4, 256), dtype=np.float64)
 
180
  np.testing.assert_allclose(out.get_data(), raw.get_data(), rtol=1e-6, atol=1e-12)
181
  log_output = buf.getvalue()
182
  assert "ICA skipped: eog_ch_name='EOG_DOES_NOT_EXIST' not found" in log_output
183
+
184
+
185
+ class TestComputeFeaturesFromEpoch:
186
+ def test_returns_1d_float_array(self) -> None:
187
+ epoch = np.random.default_rng(0).standard_normal((4, 256))
188
+ out = compute_features_from_epoch(epoch, sfreq=256.0)
189
+ assert isinstance(out, np.ndarray)
190
+ assert out.ndim == 1
191
+ assert out.dtype == np.float64
192
+
193
+ def test_feature_count_matches_contract(self) -> None:
194
+ """Each channel contributes len(EEG_BANDS) PSD features + len(STATS) stats."""
195
+ n_channels = 4
196
+ epoch = np.random.default_rng(0).standard_normal((n_channels, 256))
197
+ out = compute_features_from_epoch(epoch, sfreq=256.0)
198
+ expected = n_channels * (len(EEG_BANDS) + len(STATS))
199
+ assert out.shape == (expected,)
200
+
201
+ def test_alpha_band_dominates_for_alpha_signal(self) -> None:
202
+ """Pure 10 Hz sine on 1 channel should put most PSD power in alpha (8-13 Hz)."""
203
+ sfreq = 256.0
204
+ t = np.arange(int(sfreq * 2.0)) / sfreq
205
+ signal = np.sin(2 * np.pi * 10.0 * t)[None, :] # (1, n_samples)
206
+ out = compute_features_from_epoch(signal, sfreq=sfreq)
207
+ # Layout for n_channels=1: [psd_delta, psd_theta, psd_alpha, psd_beta, psd_gamma, mean, std, var, skew, kurtosis]
208
+ psd_block = out[: len(EEG_BANDS)]
209
+ alpha_idx = EEG_BANDS.index("alpha")
210
+ assert psd_block[alpha_idx] == psd_block.max()
211
+
212
+ def test_finite_output(self) -> None:
213
+ epoch = np.random.default_rng(0).standard_normal((4, 256))
214
+ out = compute_features_from_epoch(epoch, sfreq=256.0)
215
+ assert np.all(np.isfinite(out))
216
+
217
+ def test_deterministic_for_same_input(self) -> None:
218
+ epoch = np.random.default_rng(0).standard_normal((4, 256))
219
+ a = compute_features_from_epoch(epoch, sfreq=256.0)
220
+ b = compute_features_from_epoch(epoch, sfreq=256.0)
221
+ np.testing.assert_array_equal(a, b)