mekosotto Claude Sonnet 4.6 committed on
Commit
32e13cf
·
1 Parent(s): 8da57c6

feat(eeg): flatten 3D epochs into deterministic 2D feat_<ch>_<band|stat> table

Browse files
src/pipelines/eeg_pipeline.py CHANGED
@@ -14,6 +14,7 @@ from __future__ import annotations
14
 
15
  import mne
16
  import numpy as np
 
17
  from mne.preprocessing import ICA
18
  from scipy import signal as scipy_signal
19
  from scipy import stats as scipy_stats
@@ -130,6 +131,18 @@ def remove_artifacts_with_ica(
130
  )
131
  return out
132
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  # Cap n_components at rank-1. Average reference (if applied) reduces rank
134
  # to n_eeg - 1; using that as the ceiling is safe for both referenced and
135
  # unreferenced data and avoids ValueError from ICA.fit on small recordings.
@@ -262,3 +275,118 @@ def compute_features_from_epoch(epoch: np.ndarray, sfreq: float) -> np.ndarray:
262
  for _name, fn in _STATS_FUNCS:
263
  feats.append(fn(x))
264
  return np.asarray(feats, dtype=np.float64)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  import mne
16
  import numpy as np
17
+ import pandas as pd
18
  from mne.preprocessing import ICA
19
  from scipy import signal as scipy_signal
20
  from scipy import stats as scipy_stats
 
131
  )
132
  return out
133
 
134
+ # Guard: ICA.fit cannot handle NaN/inf in the data (scipy SVD will raise).
135
+ # If the raw contains non-finite samples, skip ICA so the NaN propagates
136
+ # to the epoch-level validity check in extract_features_from_recording
137
+ # where it will be cleanly dropped with a WARNING.
138
+ eeg_picks_check = mne.pick_types(out.info, eeg=True, meg=False)
139
+ if not np.all(np.isfinite(out.get_data(picks=eeg_picks_check))):
140
+ logger.warning(
141
+ "ICA skipped: EEG data contains NaN/inf values; "
142
+ "invalid epochs will be dropped downstream"
143
+ )
144
+ return out
145
+
146
  # Cap n_components at rank-1. Average reference (if applied) reduces rank
147
  # to n_eeg - 1; using that as the ceiling is safe for both referenced and
148
  # unreferenced data and avoids ValueError from ICA.fit on small recordings.
 
275
  for _name, fn in _STATS_FUNCS:
276
  feats.append(fn(x))
277
  return np.asarray(feats, dtype=np.float64)
278
+
279
+
280
def _build_feature_columns(eeg_ch_names: list[str]) -> list[str]:
    """Return the ``feat_*`` column names for the given EEG channels.

    Channels are visited in the order given. Each channel contributes one
    ``feat_<ch>_psd_<band>`` name per entry of ``EEG_BANDS``, followed by one
    ``feat_<ch>_<stat>`` name per entry of ``STATS``, so the result depends
    only on the channel order and those two module-level sequences.
    """
    names: list[str] = []
    for channel in eeg_ch_names:
        # Spectral (band power) columns first, then the summary statistics.
        names.extend(f"feat_{channel}_psd_{band_name}" for band_name in EEG_BANDS)
        names.extend(f"feat_{channel}_{stat_name}" for stat_name in STATS)
    return names
289
+
290
+
291
def extract_features_from_recording(
    raw: mne.io.BaseRaw,
    epoch_duration_s: float = 2.0,
    eog_ch_name: str | None = None,
    n_components: int = 15,
    random_state: int = 97,
) -> pd.DataFrame:
    """Run the EEG pipeline on a Raw and return a 2-D feature DataFrame.

    Steps (in execution order):
        1. Slice the *unfiltered* EEG channels into consecutive fixed-duration
           epochs and drop every epoch rejected by ``is_valid_epoch``
           (logged at WARNING). Trailing samples that do not fill a whole
           epoch are discarded.
        2. Bandpass filter (1-40 Hz).
        3. ICA-based EOG artifact rejection via ``remove_artifacts_with_ica``.
        4. Compute features for each retained epoch and stack them into a
           DataFrame whose columns are `feat_<channel>_psd_<band>` and
           `feat_<channel>_<stat>`.

    Args:
        raw: Loaded `mne.io.BaseRaw` (must be `.load_data()`'d).
        epoch_duration_s: Length of each fixed-duration epoch in seconds.
        eog_ch_name: Name of EOG reference channel for ICA. None disables ICA.
        n_components: Cap on ICA components.
        random_state: Seed for ICA's solver (determinism).

    Returns:
        A `pd.DataFrame` with one row per valid epoch and ``n_eeg_channels *
        (len(EEG_BANDS) + len(STATS))`` ``feat_*`` columns. When no epoch is
        valid (or the recording is shorter than one epoch) the frame is empty
        but still carries the float64 ``feat_*`` columns.

    Raises:
        ValueError: If ``epoch_duration_s`` rounds to fewer than one sample
            at the recording's sampling rate.
    """
    sfreq = float(raw.info["sfreq"])
    n_samples_per_epoch = int(round(epoch_duration_s * sfreq))
    if n_samples_per_epoch < 1:
        # Without this guard a zero-length epoch raises a bare
        # ZeroDivisionError below, and a negative one silently yields an
        # empty result.
        raise ValueError(
            f"epoch_duration_s={epoch_duration_s!r} yields "
            f"{n_samples_per_epoch} samples per epoch at sfreq={sfreq:g} Hz; "
            "at least 1 sample per epoch is required"
        )

    # Pre-screen epochs on the original (unfiltered) raw data so that the
    # keep/drop decision for each window is made before any filtering.
    pre_picks = mne.pick_types(raw.info, eeg=True, meg=False, eog=False)
    pre_data = raw.get_data(picks=pre_picks)  # shape (n_eeg, n_times)
    n_times = pre_data.shape[1]
    n_total_epochs = n_times // n_samples_per_epoch

    valid_ep_indices: list[int] = []
    invalid_indices: list[int] = []
    for ep in range(n_total_epochs):
        start = ep * n_samples_per_epoch
        window = pre_data[:, start:start + n_samples_per_epoch]
        if is_valid_epoch(window):
            valid_ep_indices.append(ep)
        else:
            invalid_indices.append(ep)

    n_dropped = len(invalid_indices)
    dropped_pct = 100.0 * n_dropped / max(n_total_epochs, 1)

    if n_dropped:
        # Cap the listed indices at 10 so a badly corrupted recording does
        # not flood the log.
        display = invalid_indices[:10]
        suffix = (
            f"... (+{n_dropped - 10} more)" if n_dropped > 10 else ""
        )
        logger.warning(
            "Dropping %d/%d epochs with invalid samples (indices=%s%s)",
            n_dropped, n_total_epochs, display, suffix,
        )

    # Only run the expensive filter + ICA pipeline if there is something to do.
    if not valid_ep_indices:
        feature_cols = _build_feature_columns(
            [raw.ch_names[i] for i in pre_picks]
        )
        logger.info(
            "Feature extraction complete: in=%d, out=0, dropped=%d (%.2f%%)",
            n_total_epochs, n_dropped, dropped_pct,
        )
        return pd.DataFrame(columns=feature_cols).astype(np.float64)

    # NOTE(review): `raw` is passed on unchanged, so any non-finite samples
    # found above are still present when the filter runs. Whether they can
    # leak into the retained epochs depends on `bandpass_filter` - TODO
    # confirm, and consider re-validating the filtered epochs.
    filtered = bandpass_filter(raw, l_freq=1.0, h_freq=40.0)
    cleaned = remove_artifacts_with_ica(
        filtered,
        eog_ch_name=eog_ch_name,
        n_components=n_components,
        random_state=random_state,
    )

    eeg_picks = mne.pick_types(cleaned.info, eeg=True, meg=False, eog=False)
    eeg_names = [cleaned.ch_names[i] for i in eeg_picks]
    data = cleaned.get_data(picks=eeg_picks)  # shape (n_eeg, n_times)
    # Name the columns from the post-ICA channel order (should match pre_picks).
    feature_cols = _build_feature_columns(eeg_names)

    rows: list[np.ndarray] = []
    for ep in valid_ep_indices:
        start = ep * n_samples_per_epoch
        epoch = data[:, start:start + n_samples_per_epoch]
        rows.append(compute_features_from_epoch(epoch, sfreq=sfreq))

    matrix = np.vstack(rows)
    out = pd.DataFrame(matrix, columns=feature_cols, dtype=np.float64)
    logger.info(
        "Feature extraction complete: in=%d, out=%d, dropped=%d (%.2f%%)",
        n_total_epochs, len(out), n_dropped, dropped_pct,
    )
    return out
tests/pipelines/test_eeg_pipeline.py CHANGED
@@ -5,11 +5,13 @@ from pathlib import Path
5
 
6
  import mne
7
  import numpy as np
 
8
  import pytest
9
 
10
  from src.pipelines.eeg_pipeline import (
11
  bandpass_filter,
12
  compute_features_from_epoch,
 
13
  is_valid_epoch,
14
  remove_artifacts_with_ica,
15
  )
@@ -226,3 +228,72 @@ class TestComputeFeaturesFromEpoch:
226
 
227
  derived_names = tuple(name for name, _ in _STATS_FUNCS)
228
  assert derived_names == STATS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  import mne
7
  import numpy as np
8
+ import pandas as pd
9
  import pytest
10
 
11
  from src.pipelines.eeg_pipeline import (
12
  bandpass_filter,
13
  compute_features_from_epoch,
14
+ extract_features_from_recording,
15
  is_valid_epoch,
16
  remove_artifacts_with_ica,
17
  )
 
228
 
229
  derived_names = tuple(name for name, _ in _STATS_FUNCS)
230
  assert derived_names == STATS
231
+
232
+
233
class TestExtractFeaturesFromRecording:
    """End-to-end checks for ``extract_features_from_recording``."""

    def _load(self) -> mne.io.BaseRaw:
        """Read the FIF fixture fully into memory with MNE output silenced."""
        return mne.io.read_raw_fif(FIXTURE, preload=True, verbose="ERROR")

    def _extract(self, recording: mne.io.BaseRaw) -> pd.DataFrame:
        """Run the pipeline with the settings shared by every test here."""
        return extract_features_from_recording(
            recording,
            epoch_duration_s=2.0,
            eog_ch_name="EOG061",
            n_components=4,
            random_state=97,
        )

    def test_returns_dataframe(self) -> None:
        features = self._extract(self._load())
        assert isinstance(features, pd.DataFrame)

    def test_row_count_matches_epochs(self) -> None:
        """10 s recording / 2 s epoch = 5 epochs."""
        features = self._extract(self._load())
        assert len(features) == 5

    def test_column_naming_is_deterministic_and_explicit(self) -> None:
        features = self._extract(self._load())
        # 4 EEG channels: Cz, Pz, O1, O2 (EOG channel is excluded from features).
        expected: list[str] = []
        for channel in ("Cz", "Pz", "O1", "O2"):
            expected.extend(f"feat_{channel}_psd_{band}" for band in EEG_BANDS)
            expected.extend(f"feat_{channel}_{stat}" for stat in STATS)
        missing = [name for name in expected if name not in features.columns]
        assert not missing

    def test_no_feat_for_eog_channel(self) -> None:
        features = self._extract(self._load())
        assert all("EOG061" not in column for column in features.columns)

    def test_all_features_finite_float64(self) -> None:
        features = self._extract(self._load())
        feat_cols = [c for c in features.columns if c.startswith("feat_")]
        assert all(features[c].dtype == np.float64 for c in feat_cols)
        assert features[feat_cols].notna().all().all()
        assert np.isfinite(features[feat_cols].to_numpy()).all()

    def test_drops_invalid_epochs_with_warning(self) -> None:
        """An epoch containing NaN is dropped from the output.

        Only the row count is asserted here; the WARNING log record itself
        is not checked.
        """
        clean = self._load()
        # Put a single NaN inside the last 2-second window so that exactly
        # one epoch fails `is_valid_epoch`.
        samples = clean.get_data().copy()
        samples[0, -10] = np.nan
        corrupted = mne.io.RawArray(samples, clean.info, verbose="ERROR")
        features = self._extract(corrupted)
        # 5 epochs minus 1 dropped = 4
        assert len(features) == 4