refactor(eeg): tighten ICA docstring + log accuracy; extract EOG threshold constant
Browse files
src/pipelines/eeg_pipeline.py
CHANGED
|
@@ -20,6 +20,12 @@ from src.core.logger import get_logger
|
|
| 20 |
|
| 21 |
logger = get_logger(__name__)
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
def is_valid_epoch(epoch: np.ndarray | None) -> bool:
|
| 25 |
"""Return True iff `epoch` is a non-empty 2-D numeric array with no NaN/inf.
|
|
@@ -89,29 +95,42 @@ def remove_artifacts_with_ica(
|
|
| 89 |
`measure="correlation"`, marks them as "bad" and reconstructs the signal
|
| 90 |
without them. Returns a copy; the input `raw` is unchanged.
|
| 91 |
|
| 92 |
-
If `eog_ch_name` is None or
|
| 93 |
-
|
| 94 |
-
without an EOG reference.
|
| 95 |
|
| 96 |
Args:
|
| 97 |
raw: Loaded, ideally bandpass-filtered, `mne.io.BaseRaw`.
|
| 98 |
eog_ch_name: Name of the EOG channel for correlation-based detection.
|
| 99 |
-
None disables auto-rejection
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
| 102 |
random_state: Seed for ICA's underlying solver. Required for §4
|
| 103 |
Determinism.
|
| 104 |
|
| 105 |
Returns:
|
| 106 |
-
A copy of `raw` with EOG-correlated ICA components removed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
"""
|
| 108 |
out = raw.copy()
|
| 109 |
-
if eog_ch_name is None
|
| 110 |
-
logger.info("ICA skipped:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
return out
|
| 112 |
|
| 113 |
-
# Cap n_components at
|
| 114 |
-
#
|
|
|
|
| 115 |
n_eeg = len(mne.pick_types(out.info, eeg=True, meg=False))
|
| 116 |
safe_n = min(n_components, max(n_eeg - 1, 1))
|
| 117 |
|
|
@@ -130,13 +149,13 @@ def remove_artifacts_with_ica(
|
|
| 130 |
out,
|
| 131 |
ch_name=eog_ch_name,
|
| 132 |
measure="correlation",
|
| 133 |
-
threshold=
|
| 134 |
verbose="ERROR",
|
| 135 |
)
|
| 136 |
ica.exclude = list(bad_idx)
|
| 137 |
logger.info(
|
| 138 |
-
"ICA fit: n_components=%d, EOG-correlated rejected=%d",
|
| 139 |
-
safe_n, len(ica.exclude),
|
| 140 |
)
|
| 141 |
ica.apply(out, verbose="ERROR")
|
| 142 |
return out
|
|
|
|
| 20 |
|
| 21 |
logger = get_logger(__name__)
|
| 22 |
|
| 23 |
+
# Pearson-correlation threshold for EOG-component rejection in ICA.
|
| 24 |
+
# Real-world EOG components typically score 0.8-0.95 against the EOG channel;
|
| 25 |
+
# 0.9 is a conservative floor that avoids false positives at the cost of
|
| 26 |
+
# missing weak artifacts. Lower (0.7-0.8) for noisier recordings.
|
| 27 |
+
_EOG_CORR_THRESHOLD: float = 0.9
|
| 28 |
+
|
| 29 |
|
| 30 |
def is_valid_epoch(epoch: np.ndarray | None) -> bool:
|
| 31 |
"""Return True iff `epoch` is a non-empty 2-D numeric array with no NaN/inf.
|
|
|
|
| 95 |
`measure="correlation"`, marks them as "bad" and reconstructs the signal
|
| 96 |
without them. Returns a copy; the input `raw` is unchanged.
|
| 97 |
|
| 98 |
+
If `eog_ch_name` is None or not present in the recording's channels,
|
| 99 |
+
ICA is skipped entirely and a copy of `raw` is returned unchanged.
|
|
|
|
| 100 |
|
| 101 |
Args:
|
| 102 |
raw: Loaded, ideally bandpass-filtered, `mne.io.BaseRaw`.
|
| 103 |
eog_ch_name: Name of the EOG channel for correlation-based detection.
|
| 104 |
+
None disables auto-rejection; a string that is not in the recording's
|
| 105 |
+
channel list logs a WARNING and skips ICA.
|
| 106 |
+
n_components: Cap on ICA components. If this exceeds the number of EEG
|
| 107 |
+
channels, MNE raises ValueError, so the implementation internally
|
| 108 |
+
caps it at `max(n_eeg - 1, 1)` before fitting.
|
| 109 |
random_state: Seed for ICA's underlying solver. Required for §4
|
| 110 |
Determinism.
|
| 111 |
|
| 112 |
Returns:
|
| 113 |
+
A copy of `raw` with EOG-correlated ICA components removed (or an
|
| 114 |
+
unchanged copy if ICA was skipped).
|
| 115 |
+
|
| 116 |
+
Raises:
|
| 117 |
+
ValueError: if the EEG data is rank-deficient (all-zero or constant
|
| 118 |
+
channels) and `mne.preprocessing.ICA.fit` cannot converge.
|
| 119 |
"""
|
| 120 |
out = raw.copy()
|
| 121 |
+
if eog_ch_name is None:
|
| 122 |
+
logger.info("ICA skipped: eog_ch_name not provided")
|
| 123 |
+
return out
|
| 124 |
+
if eog_ch_name not in out.ch_names:
|
| 125 |
+
logger.warning(
|
| 126 |
+
"ICA skipped: eog_ch_name=%r not found in channels %s",
|
| 127 |
+
eog_ch_name, out.ch_names,
|
| 128 |
+
)
|
| 129 |
return out
|
| 130 |
|
| 131 |
+
# Cap n_components at rank-1. Average reference (if applied) reduces rank
|
| 132 |
+
# to n_eeg - 1; using that as the ceiling is safe for both referenced and
|
| 133 |
+
# unreferenced data and avoids ValueError from ICA.fit on small recordings.
|
| 134 |
n_eeg = len(mne.pick_types(out.info, eeg=True, meg=False))
|
| 135 |
safe_n = min(n_components, max(n_eeg - 1, 1))
|
| 136 |
|
|
|
|
| 149 |
out,
|
| 150 |
ch_name=eog_ch_name,
|
| 151 |
measure="correlation",
|
| 152 |
+
threshold=_EOG_CORR_THRESHOLD,
|
| 153 |
verbose="ERROR",
|
| 154 |
)
|
| 155 |
ica.exclude = list(bad_idx)
|
| 156 |
logger.info(
|
| 157 |
+
"ICA fit: n_components=%d, EOG-correlated rejected=%d (indices=%s)",
|
| 158 |
+
safe_n, len(ica.exclude), ica.exclude,
|
| 159 |
)
|
| 160 |
ica.apply(out, verbose="ERROR")
|
| 161 |
return out
|
tests/pipelines/test_eeg_pipeline.py
CHANGED
|
@@ -148,3 +148,30 @@ class TestRemoveArtifactsWithIca:
|
|
| 148 |
raw, eog_ch_name="EOG061", n_components=4, random_state=97,
|
| 149 |
)
|
| 150 |
np.testing.assert_allclose(a.get_data(), b.get_data(), rtol=1e-12, atol=1e-15)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
raw, eog_ch_name="EOG061", n_components=4, random_state=97,
|
| 149 |
)
|
| 150 |
np.testing.assert_allclose(a.get_data(), b.get_data(), rtol=1e-12, atol=1e-15)
|
| 151 |
+
|
| 152 |
+
def test_unknown_eog_channel_logs_warning_and_is_a_noop(self) -> None:
|
| 153 |
+
"""A misconfigured eog_ch_name (typo) must not silently behave like None."""
|
| 154 |
+
import io
|
| 155 |
+
import logging
|
| 156 |
+
|
| 157 |
+
from src.core.logger import get_logger
|
| 158 |
+
from src.pipelines import eeg_pipeline as mod
|
| 159 |
+
|
| 160 |
+
raw = bandpass_filter(self._load(), l_freq=1.0, h_freq=40.0)
|
| 161 |
+
logger = get_logger(mod.__name__, level=logging.INFO)
|
| 162 |
+
handler = logger.handlers[0]
|
| 163 |
+
buf = io.StringIO()
|
| 164 |
+
original_stream = handler.stream
|
| 165 |
+
handler.stream = buf
|
| 166 |
+
try:
|
| 167 |
+
out = remove_artifacts_with_ica(
|
| 168 |
+
raw, eog_ch_name="EOG_DOES_NOT_EXIST",
|
| 169 |
+
n_components=4, random_state=97,
|
| 170 |
+
)
|
| 171 |
+
finally:
|
| 172 |
+
handler.stream = original_stream
|
| 173 |
+
|
| 174 |
+
# Behavior: ICA was skipped (no-op) but the log differentiates it from None.
|
| 175 |
+
np.testing.assert_allclose(out.get_data(), raw.get_data(), rtol=1e-6, atol=1e-12)
|
| 176 |
+
log_output = buf.getvalue()
|
| 177 |
+
assert "ICA skipped: eog_ch_name='EOG_DOES_NOT_EXIST' not found" in log_output
|