mekosotto Claude Sonnet 4.6 committed on
Commit
1a86a6e
·
1 Parent(s): 32359e5

feat(eeg): add remove_artifacts_with_ica with EOG correlation rejection

Browse files

Implements ICA-based EOG artifact removal using measure="correlation",
threshold=0.9 to reliably flag components on small (4-channel) fixtures
where the default z-score threshold is algebraically unreachable.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

src/pipelines/eeg_pipeline.py CHANGED
@@ -14,6 +14,7 @@ from __future__ import annotations
14
 
15
  import mne
16
  import numpy as np
 
17
 
18
  from src.core.logger import get_logger
19
 
@@ -73,3 +74,69 @@ def bandpass_filter(
73
  out.filter(l_freq=l_freq, h_freq=h_freq, picks="all", verbose="ERROR")
74
  logger.info("Bandpass filter applied: %.1f-%.1f Hz", l_freq, h_freq)
75
  return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  import mne
16
  import numpy as np
17
+ from mne.preprocessing import ICA
18
 
19
  from src.core.logger import get_logger
20
 
 
74
  out.filter(l_freq=l_freq, h_freq=h_freq, picks="all", verbose="ERROR")
75
  logger.info("Bandpass filter applied: %.1f-%.1f Hz", l_freq, h_freq)
76
  return out
77
+
78
+
79
def remove_artifacts_with_ica(
    raw: mne.io.BaseRaw,
    eog_ch_name: str | None = None,
    n_components: int = 15,
    random_state: int = 97,
    *,
    threshold: float = 0.9,
) -> mne.io.BaseRaw:
    """Remove EOG-like artifacts using MNE's ICA + EOG correlation.

    Fits an ICA decomposition on the EEG channels of a copy of `raw`, asks
    `find_bads_eog` (with `measure="correlation"`) which components track the
    named EOG channel, marks those components as excluded and reconstructs
    the signal without them. The input `raw` is never modified.

    The function passes the data through untouched (returns a plain copy) when:

    * `eog_ch_name` is None,
    * `eog_ch_name` is not one of the channels in `raw` (a warning is logged
      so a misspelled channel name does not go unnoticed), or
    * no component reaches `threshold`; in that case `ICA.apply` is not
      called at all.

    Args:
        raw: Loaded, ideally bandpass-filtered, `mne.io.BaseRaw`.
        eog_ch_name: Name of the EOG channel for correlation-based detection.
            None disables auto-rejection.
        n_components: Upper bound on ICA components. The value actually used
            is capped at (number of EEG channels - 1), with a floor of 1.
        random_state: Seed for ICA's underlying solver, so repeated runs on
            the same input produce the same output.
        threshold: Value forwarded to `find_bads_eog` as `threshold` together
            with `measure="correlation"`. Defaults to 0.9.

    Returns:
        A copy of `raw` with EOG-correlated ICA components removed.
    """
    out = raw.copy()

    if eog_ch_name is None:
        logger.info("ICA skipped: no EOG channel reference provided")
        return out
    if eog_ch_name not in out.ch_names:
        # A name was given but does not exist in this recording. Keep the
        # pass-through behaviour, but say so explicitly instead of claiming
        # that no reference was provided.
        logger.warning(
            "ICA skipped: EOG channel %r not found in recording", eog_ch_name
        )
        return out

    # Cap n_components below the EEG channel count to avoid solver complaints
    # on small synthetic fixtures.
    n_eeg = len(mne.pick_types(out.info, eeg=True, meg=False))
    safe_n = min(n_components, max(n_eeg - 1, 1))

    ica = ICA(
        n_components=safe_n,
        random_state=random_state,
        max_iter="auto",
        method="fastica",
        verbose="ERROR",
    )
    ica.fit(out, picks="eeg", verbose="ERROR")

    # Use raw correlation (not z-score) so we can reliably flag artifact
    # components on small recordings where n_components < 10 makes the
    # default z-score threshold algebraically unreachable.
    bad_idx, _ = ica.find_bads_eog(
        out,
        ch_name=eog_ch_name,
        measure="correlation",
        threshold=threshold,
        verbose="ERROR",
    )
    ica.exclude = list(bad_idx)
    logger.info(
        "ICA fit: n_components=%d, EOG-correlated rejected=%d",
        safe_n, len(ica.exclude),
    )

    if not ica.exclude:
        # Nothing to remove: skip the reconstruction so the returned data
        # are exactly the input data rather than a round-tripped version.
        return out

    ica.apply(out, verbose="ERROR")
    return out
tests/pipelines/test_eeg_pipeline.py CHANGED
@@ -10,6 +10,7 @@ import pytest
10
  from src.pipelines.eeg_pipeline import (
11
  bandpass_filter,
12
  is_valid_epoch,
 
13
  )
14
 
15
 
@@ -89,3 +90,61 @@ class TestBandpassFilter:
89
  bandpass_filter(raw, l_freq=40.0, h_freq=1.0)
90
  with pytest.raises(ValueError, match="must be strictly less than"):
91
  bandpass_filter(raw, l_freq=10.0, h_freq=10.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  from src.pipelines.eeg_pipeline import (
11
  bandpass_filter,
12
  is_valid_epoch,
13
+ remove_artifacts_with_ica,
14
  )
15
 
16
 
 
90
  bandpass_filter(raw, l_freq=40.0, h_freq=1.0)
91
  with pytest.raises(ValueError, match="must be strictly less than"):
92
  bandpass_filter(raw, l_freq=10.0, h_freq=10.0)
93
+
94
+
95
class TestRemoveArtifactsWithIca:
    """Checks for `remove_artifacts_with_ica` run against the FIXTURE file."""

    def _load(self) -> mne.io.BaseRaw:
        """Read the fixture recording with its data preloaded."""
        return mne.io.read_raw_fif(FIXTURE, preload=True, verbose="ERROR")

    def _filtered(self) -> mne.io.BaseRaw:
        """Return the fixture after the 1-40 Hz bandpass used by every test."""
        return bandpass_filter(self._load(), l_freq=1.0, h_freq=40.0)

    @staticmethod
    def _clean(
        recording: mne.io.BaseRaw, eog_ch_name: str | None = "EOG061"
    ) -> mne.io.BaseRaw:
        """Run the function under test with the settings shared by all tests."""
        return remove_artifacts_with_ica(
            recording,
            eog_ch_name=eog_ch_name,
            n_components=4,
            random_state=97,
        )

    def test_returns_raw_instance(self) -> None:
        cleaned = self._clean(self._filtered())
        assert isinstance(cleaned, mne.io.BaseRaw)

    def test_preserves_shape(self) -> None:
        filtered = self._filtered()
        expected_shape = filtered.get_data().shape
        cleaned = self._clean(filtered)
        assert cleaned.get_data().shape == expected_shape

    def test_reduces_eog_correlation_on_frontal_channel(self) -> None:
        """Cz must correlate less with the EOG channel after cleaning."""
        filtered = self._filtered()
        cz_row = filtered.ch_names.index("Cz")
        eog_row = filtered.ch_names.index("EOG061")

        def abs_corr(samples: np.ndarray) -> float:
            # Magnitude of the Pearson correlation between the two rows.
            return abs(np.corrcoef(samples[cz_row], samples[eog_row])[0, 1])

        corr_before = abs_corr(filtered.get_data())
        corr_after = abs_corr(self._clean(filtered).get_data())
        # Allow for noise — but the dominant EOG bleed must be reduced.
        assert corr_after < corr_before

    def test_no_eog_channel_is_a_noop(self) -> None:
        """With eog_ch_name=None the data must come back unchanged."""
        filtered = self._filtered()
        untouched = self._clean(filtered, eog_ch_name=None)
        # Identical shape; data approximately equal (no rejection happened).
        assert untouched.get_data().shape == filtered.get_data().shape
        np.testing.assert_allclose(
            untouched.get_data(), filtered.get_data(), rtol=1e-6, atol=1e-12
        )

    def test_is_deterministic_with_seed(self) -> None:
        """Two runs with the same seed must agree to within tight tolerance."""
        filtered = self._filtered()
        first = self._clean(filtered)
        second = self._clean(filtered)
        np.testing.assert_allclose(
            first.get_data(), second.get_data(), rtol=1e-12, atol=1e-15
        )