feat(eeg): add remove_artifacts_with_ica with EOG correlation rejection
Browse files
Implements ICA-based EOG artifact removal using measure="correlation",
threshold=0.9 to reliably flag components on small (4-channel) fixtures
where the default z-score threshold is algebraically unreachable.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
src/pipelines/eeg_pipeline.py
CHANGED
|
@@ -14,6 +14,7 @@ from __future__ import annotations
|
|
| 14 |
|
| 15 |
import mne
|
| 16 |
import numpy as np
|
|
|
|
| 17 |
|
| 18 |
from src.core.logger import get_logger
|
| 19 |
|
|
@@ -73,3 +74,69 @@ def bandpass_filter(
|
|
| 73 |
out.filter(l_freq=l_freq, h_freq=h_freq, picks="all", verbose="ERROR")
|
| 74 |
logger.info("Bandpass filter applied: %.1f-%.1f Hz", l_freq, h_freq)
|
| 75 |
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
import mne
|
| 16 |
import numpy as np
|
| 17 |
+
from mne.preprocessing import ICA
|
| 18 |
|
| 19 |
from src.core.logger import get_logger
|
| 20 |
|
|
|
|
| 74 |
out.filter(l_freq=l_freq, h_freq=h_freq, picks="all", verbose="ERROR")
|
| 75 |
logger.info("Bandpass filter applied: %.1f-%.1f Hz", l_freq, h_freq)
|
| 76 |
return out
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def remove_artifacts_with_ica(
    raw: mne.io.BaseRaw,
    eog_ch_name: str | None = None,
    n_components: int = 15,
    random_state: int = 97,
    *,
    eog_threshold: float = 0.9,
) -> mne.io.BaseRaw:
    """Remove EOG-like artifacts using MNE's ICA + EOG correlation.

    Fits an ICA decomposition on the EEG channels of a copy of `raw`, flags
    components whose time courses correlate (Pearson) with the named EOG
    channel via `find_bads_eog` using `measure="correlation"`, and
    reconstructs the signal without them. All work happens on the copy; the
    input `raw` is left unchanged.

    The function is a pass-through that returns an untouched copy of `raw`
    when any of the following holds, which keeps it safe to call on
    recordings without a usable EOG reference:

    * `eog_ch_name` is None,
    * `eog_ch_name` is not one of the channel names in `raw`,
    * no ICA component is flagged at `eog_threshold`.

    Args:
        raw: Loaded, ideally bandpass-filtered, `mne.io.BaseRaw`.
        eog_ch_name: Name of the EOG channel for correlation-based detection.
            None disables auto-rejection.
        n_components: Upper bound on the number of ICA components. The value
            actually used is capped here at one less than the number of EEG
            channels (and never below 1) so that small recordings do not
            upset the solver.
        random_state: Seed for ICA's underlying solver. Required for §4
            Determinism.
        eog_threshold: Absolute correlation threshold handed to
            `find_bads_eog`; components scoring above it are rejected.
            Keyword-only; the default matches the previously hard-coded 0.9.

    Returns:
        A copy of `raw` with EOG-correlated ICA components removed, or an
        unmodified copy when the function passes through.
    """
    out = raw.copy()

    if eog_ch_name is None:
        logger.info("ICA skipped: no EOG channel reference provided")
        return out
    if eog_ch_name not in out.ch_names:
        # A name was supplied but does not exist in this recording. Still a
        # pass-through, but say so explicitly (and louder) so that a typo in
        # the channel name is not mistaken for "no EOG configured".
        logger.warning(
            "ICA skipped: EOG channel %r not found in recording", eog_ch_name
        )
        return out

    # Cap n_components below the EEG channel count to avoid solver complaints
    # on small synthetic fixtures.
    n_eeg = len(mne.pick_types(out.info, eeg=True, meg=False))
    safe_n = min(n_components, max(n_eeg - 1, 1))

    ica = ICA(
        n_components=safe_n,
        random_state=random_state,
        max_iter="auto",
        method="fastica",
        verbose="ERROR",
    )
    ica.fit(out, picks="eeg", verbose="ERROR")

    # Use raw correlation (not z-score) so we can reliably flag artifact
    # components on small recordings where n_components < 10 makes the
    # default z-score threshold algebraically unreachable.
    bad_idx, _ = ica.find_bads_eog(
        out,
        ch_name=eog_ch_name,
        measure="correlation",
        threshold=eog_threshold,
        verbose="ERROR",
    )
    ica.exclude = list(bad_idx)
    logger.info(
        "ICA fit: n_components=%d, EOG-correlated rejected=%d",
        safe_n, len(ica.exclude),
    )

    if not ica.exclude:
        # Nothing to remove: skip `apply` so the returned copy holds exactly
        # the input data rather than a re-projected reconstruction of it.
        return out

    ica.apply(out, verbose="ERROR")
    return out
|
tests/pipelines/test_eeg_pipeline.py
CHANGED
|
@@ -10,6 +10,7 @@ import pytest
|
|
| 10 |
from src.pipelines.eeg_pipeline import (
|
| 11 |
bandpass_filter,
|
| 12 |
is_valid_epoch,
|
|
|
|
| 13 |
)
|
| 14 |
|
| 15 |
|
|
@@ -89,3 +90,61 @@ class TestBandpassFilter:
|
|
| 89 |
bandpass_filter(raw, l_freq=40.0, h_freq=1.0)
|
| 90 |
with pytest.raises(ValueError, match="must be strictly less than"):
|
| 91 |
bandpass_filter(raw, l_freq=10.0, h_freq=10.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
from src.pipelines.eeg_pipeline import (
|
| 11 |
bandpass_filter,
|
| 12 |
is_valid_epoch,
|
| 13 |
+
remove_artifacts_with_ica,
|
| 14 |
)
|
| 15 |
|
| 16 |
|
|
|
|
| 90 |
bandpass_filter(raw, l_freq=40.0, h_freq=1.0)
|
| 91 |
with pytest.raises(ValueError, match="must be strictly less than"):
|
| 92 |
bandpass_filter(raw, l_freq=10.0, h_freq=10.0)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class TestRemoveArtifactsWithIca:
    """Checks for `remove_artifacts_with_ica` run against the FIF fixture."""

    def _load(self) -> mne.io.BaseRaw:
        """Read the fixture recording with its data preloaded."""
        return mne.io.read_raw_fif(FIXTURE, preload=True, verbose="ERROR")

    def _filtered(self) -> mne.io.BaseRaw:
        """Return the fixture after the 1-40 Hz bandpass every test starts from."""
        return bandpass_filter(self._load(), l_freq=1.0, h_freq=40.0)

    @staticmethod
    def _clean(recording: mne.io.BaseRaw, eog_ch_name: str | None = "EOG061"):
        """Run the function under test with the settings shared by all tests."""
        return remove_artifacts_with_ica(
            recording, eog_ch_name=eog_ch_name, n_components=4, random_state=97,
        )

    @staticmethod
    def _abs_corr(data: np.ndarray, row_a: int, row_b: int) -> float:
        """Absolute Pearson correlation between two rows of `data`."""
        return abs(np.corrcoef(data[row_a], data[row_b])[0, 1])

    def test_returns_raw_instance(self) -> None:
        """The result is still an MNE raw object."""
        cleaned = self._clean(self._filtered())
        assert isinstance(cleaned, mne.io.BaseRaw)

    def test_preserves_shape(self) -> None:
        """Cleaning changes neither the channel count nor the sample count."""
        filtered = self._filtered()
        expected_shape = filtered.get_data().shape
        cleaned = self._clean(filtered)
        assert cleaned.get_data().shape == expected_shape

    def test_reduces_eog_correlation_on_frontal_channel(self) -> None:
        """ICA must reduce correlation between EOG and Cz (the bleed channel)."""
        filtered = self._filtered()
        names = filtered.ch_names
        cz_row, eog_row = names.index("Cz"), names.index("EOG061")
        corr_before = self._abs_corr(filtered.get_data(), cz_row, eog_row)

        cleaned = self._clean(filtered)
        corr_after = self._abs_corr(cleaned.get_data(), cz_row, eog_row)

        # Allow for noise — but the dominant EOG bleed must be reduced.
        assert corr_after < corr_before

    def test_no_eog_channel_is_a_noop(self) -> None:
        """Without an EOG reference, ICA can't auto-reject — should pass through."""
        filtered = self._filtered()
        passthrough = self._clean(filtered, eog_ch_name=None)

        # Identical shape; data approximately equal (no rejection happened).
        assert passthrough.get_data().shape == filtered.get_data().shape
        np.testing.assert_allclose(
            passthrough.get_data(), filtered.get_data(), rtol=1e-6, atol=1e-12
        )

    def test_is_deterministic_with_seed(self) -> None:
        """Two runs with the same seed on the same input give matching data."""
        filtered = self._filtered()
        first = self._clean(filtered)
        second = self._clean(filtered)
        np.testing.assert_allclose(
            first.get_data(), second.get_data(), rtol=1e-12, atol=1e-15
        )
|