feat(eeg): add non-mutating bandpass_filter (default 1-40 Hz)
Browse filesCo-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
src/pipelines/eeg_pipeline.py
CHANGED
|
@@ -12,6 +12,7 @@ a logged WARNING), determinism (seeded ICA + sklearn RNG), traceability
|
|
| 12 |
"""
|
| 13 |
from __future__ import annotations
|
| 14 |
|
|
|
|
| 15 |
import numpy as np
|
| 16 |
|
| 17 |
from src.core.logger import get_logger
|
|
@@ -37,3 +38,27 @@ def is_valid_epoch(epoch: np.ndarray | None) -> bool:
|
|
| 37 |
if not np.all(np.isfinite(epoch)):
|
| 38 |
return False
|
| 39 |
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
"""
|
| 13 |
from __future__ import annotations
|
| 14 |
|
| 15 |
+
import mne
|
| 16 |
import numpy as np
|
| 17 |
|
| 18 |
from src.core.logger import get_logger
|
|
|
|
| 38 |
if not np.all(np.isfinite(epoch)):
|
| 39 |
return False
|
| 40 |
return True
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def bandpass_filter(
|
| 44 |
+
raw: mne.io.BaseRaw,
|
| 45 |
+
l_freq: float = 1.0,
|
| 46 |
+
h_freq: float = 40.0,
|
| 47 |
+
) -> mne.io.BaseRaw:
|
| 48 |
+
"""Apply a non-mutating bandpass filter to an MNE Raw.
|
| 49 |
+
|
| 50 |
+
Default 1-40 Hz removes drift below 1 Hz and high-frequency noise / line
|
| 51 |
+
artifacts above 40 Hz. Returns a copy; the input `raw` is unchanged.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
raw: Loaded `mne.io.BaseRaw` (call `.load_data()` first if from disk).
|
| 55 |
+
l_freq: Low-cut frequency in Hz.
|
| 56 |
+
h_freq: High-cut frequency in Hz.
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
A filtered copy of `raw`.
|
| 60 |
+
"""
|
| 61 |
+
out = raw.copy()
|
| 62 |
+
out.filter(l_freq=l_freq, h_freq=h_freq, picks="all", verbose="ERROR")
|
| 63 |
+
logger.info("Bandpass filter applied: %.1f-%.1f Hz", l_freq, h_freq)
|
| 64 |
+
return out
|
tests/pipelines/test_eeg_pipeline.py
CHANGED
|
@@ -3,10 +3,14 @@ from __future__ import annotations
|
|
| 3 |
|
| 4 |
from pathlib import Path
|
| 5 |
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import pytest
|
| 8 |
|
| 9 |
-
from src.pipelines.eeg_pipeline import
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
FIXTURE = Path(__file__).parent.parent / "fixtures" / "eeg_sample.fif"
|
|
@@ -45,3 +49,35 @@ class TestIsValidEpoch:
|
|
| 45 |
"""String / object dtype arrays must be rejected without raising."""
|
| 46 |
epoch = np.array([["a", "b"], ["c", "d"]])
|
| 47 |
assert is_valid_epoch(epoch) is False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
from pathlib import Path
|
| 5 |
|
| 6 |
+
import mne
|
| 7 |
import numpy as np
|
| 8 |
import pytest
|
| 9 |
|
| 10 |
+
from src.pipelines.eeg_pipeline import (
|
| 11 |
+
bandpass_filter,
|
| 12 |
+
is_valid_epoch,
|
| 13 |
+
)
|
| 14 |
|
| 15 |
|
| 16 |
FIXTURE = Path(__file__).parent.parent / "fixtures" / "eeg_sample.fif"
|
|
|
|
| 49 |
"""String / object dtype arrays must be rejected without raising."""
|
| 50 |
epoch = np.array([["a", "b"], ["c", "d"]])
|
| 51 |
assert is_valid_epoch(epoch) is False
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class TestBandpassFilter:
|
| 55 |
+
def _load(self) -> mne.io.BaseRaw:
|
| 56 |
+
return mne.io.read_raw_fif(FIXTURE, preload=True, verbose="ERROR")
|
| 57 |
+
|
| 58 |
+
def test_returns_raw_instance(self) -> None:
|
| 59 |
+
raw = self._load()
|
| 60 |
+
out = bandpass_filter(raw, l_freq=1.0, h_freq=40.0)
|
| 61 |
+
assert isinstance(out, mne.io.BaseRaw)
|
| 62 |
+
|
| 63 |
+
def test_preserves_shape(self) -> None:
|
| 64 |
+
raw = self._load()
|
| 65 |
+
n_ch_before, n_t_before = raw.get_data().shape
|
| 66 |
+
out = bandpass_filter(raw, l_freq=1.0, h_freq=40.0)
|
| 67 |
+
assert out.get_data().shape == (n_ch_before, n_t_before)
|
| 68 |
+
|
| 69 |
+
def test_attenuates_dc_component(self) -> None:
|
| 70 |
+
"""A bandpass with l_freq=1.0 must remove a DC offset."""
|
| 71 |
+
raw = self._load()
|
| 72 |
+
# Inject a large DC offset on every channel.
|
| 73 |
+
data = raw.get_data() + 1e-3
|
| 74 |
+
raw_dc = mne.io.RawArray(data, raw.info, verbose="ERROR")
|
| 75 |
+
out = bandpass_filter(raw_dc, l_freq=1.0, h_freq=40.0)
|
| 76 |
+
# Mean on each channel should be near zero (much smaller than 1e-3).
|
| 77 |
+
assert np.all(np.abs(out.get_data().mean(axis=1)) < 1e-4)
|
| 78 |
+
|
| 79 |
+
def test_does_not_mutate_input(self) -> None:
|
| 80 |
+
raw = self._load()
|
| 81 |
+
original_mean = raw.get_data().mean()
|
| 82 |
+
_ = bandpass_filter(raw, l_freq=1.0, h_freq=40.0)
|
| 83 |
+
assert raw.get_data().mean() == pytest.approx(original_mean, rel=1e-12)
|