mekosotto Claude Sonnet 4.6 commited on
Commit
c743c0a
·
1 Parent(s): e3c6c58

feat(eeg): add non-mutating bandpass_filter (default 1-40 Hz)

Browse files

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

src/pipelines/eeg_pipeline.py CHANGED
@@ -12,6 +12,7 @@ a logged WARNING), determinism (seeded ICA + sklearn RNG), traceability
12
  """
13
  from __future__ import annotations
14
 
 
15
  import numpy as np
16
 
17
  from src.core.logger import get_logger
@@ -37,3 +38,27 @@ def is_valid_epoch(epoch: np.ndarray | None) -> bool:
37
  if not np.all(np.isfinite(epoch)):
38
  return False
39
  return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  """
13
  from __future__ import annotations
14
 
15
+ import mne
16
  import numpy as np
17
 
18
  from src.core.logger import get_logger
 
38
  if not np.all(np.isfinite(epoch)):
39
  return False
40
  return True
41
+
42
+
43
+ def bandpass_filter(
44
+ raw: mne.io.BaseRaw,
45
+ l_freq: float = 1.0,
46
+ h_freq: float = 40.0,
47
+ ) -> mne.io.BaseRaw:
48
+ """Apply a non-mutating bandpass filter to an MNE Raw.
49
+
50
+ Default 1-40 Hz removes drift below 1 Hz and high-frequency noise / line
51
+ artifacts above 40 Hz. Returns a copy; the input `raw` is unchanged.
52
+
53
+ Args:
54
+ raw: Loaded `mne.io.BaseRaw` (call `.load_data()` first if from disk).
55
+ l_freq: Low-cut frequency in Hz.
56
+ h_freq: High-cut frequency in Hz.
57
+
58
+ Returns:
59
+ A filtered copy of `raw`.
60
+ """
61
+ out = raw.copy()
62
+ out.filter(l_freq=l_freq, h_freq=h_freq, picks="all", verbose="ERROR")
63
+ logger.info("Bandpass filter applied: %.1f-%.1f Hz", l_freq, h_freq)
64
+ return out
tests/pipelines/test_eeg_pipeline.py CHANGED
@@ -3,10 +3,14 @@ from __future__ import annotations
3
 
4
  from pathlib import Path
5
 
 
6
  import numpy as np
7
  import pytest
8
 
9
- from src.pipelines.eeg_pipeline import is_valid_epoch
 
 
 
10
 
11
 
12
  FIXTURE = Path(__file__).parent.parent / "fixtures" / "eeg_sample.fif"
@@ -45,3 +49,35 @@ class TestIsValidEpoch:
45
  """String / object dtype arrays must be rejected without raising."""
46
  epoch = np.array([["a", "b"], ["c", "d"]])
47
  assert is_valid_epoch(epoch) is False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  from pathlib import Path
5
 
6
+ import mne
7
  import numpy as np
8
  import pytest
9
 
10
+ from src.pipelines.eeg_pipeline import (
11
+ bandpass_filter,
12
+ is_valid_epoch,
13
+ )
14
 
15
 
16
  FIXTURE = Path(__file__).parent.parent / "fixtures" / "eeg_sample.fif"
 
49
  """String / object dtype arrays must be rejected without raising."""
50
  epoch = np.array([["a", "b"], ["c", "d"]])
51
  assert is_valid_epoch(epoch) is False
52
+
53
+
54
+ class TestBandpassFilter:
55
+ def _load(self) -> mne.io.BaseRaw:
56
+ return mne.io.read_raw_fif(FIXTURE, preload=True, verbose="ERROR")
57
+
58
+ def test_returns_raw_instance(self) -> None:
59
+ raw = self._load()
60
+ out = bandpass_filter(raw, l_freq=1.0, h_freq=40.0)
61
+ assert isinstance(out, mne.io.BaseRaw)
62
+
63
+ def test_preserves_shape(self) -> None:
64
+ raw = self._load()
65
+ n_ch_before, n_t_before = raw.get_data().shape
66
+ out = bandpass_filter(raw, l_freq=1.0, h_freq=40.0)
67
+ assert out.get_data().shape == (n_ch_before, n_t_before)
68
+
69
+ def test_attenuates_dc_component(self) -> None:
70
+ """A bandpass with l_freq=1.0 must remove a DC offset."""
71
+ raw = self._load()
72
+ # Inject a large DC offset on every channel.
73
+ data = raw.get_data() + 1e-3
74
+ raw_dc = mne.io.RawArray(data, raw.info, verbose="ERROR")
75
+ out = bandpass_filter(raw_dc, l_freq=1.0, h_freq=40.0)
76
+ # Mean on each channel should be near zero (much smaller than 1e-3).
77
+ assert np.all(np.abs(out.get_data().mean(axis=1)) < 1e-4)
78
+
79
+ def test_does_not_mutate_input(self) -> None:
80
+ raw = self._load()
81
+ original_mean = raw.get_data().mean()
82
+ _ = bandpass_filter(raw, l_freq=1.0, h_freq=40.0)
83
+ assert raw.get_data().mean() == pytest.approx(original_mean, rel=1e-12)