mekosotto Claude Sonnet 4.6 committed on
Commit
9931366
·
1 Parent(s): 1a86a6e

refactor(eeg): tighten ICA docstring + log accuracy; extract EOG threshold constant

Browse files
src/pipelines/eeg_pipeline.py CHANGED
@@ -20,6 +20,12 @@ from src.core.logger import get_logger
20
 
21
  logger = get_logger(__name__)
22
 
 
 
 
 
 
 
23
 
24
  def is_valid_epoch(epoch: np.ndarray | None) -> bool:
25
  """Return True iff `epoch` is a non-empty 2-D numeric array with no NaN/inf.
@@ -89,29 +95,42 @@ def remove_artifacts_with_ica(
89
  `measure="correlation"`, marks them as "bad" and reconstructs the signal
90
  without them. Returns a copy; the input `raw` is unchanged.
91
 
92
- If `eog_ch_name` is None or no bad components are found, returns a copy of
93
- `raw` unchanged. This keeps the function safe to call on recordings
94
- without an EOG reference.
95
 
96
  Args:
97
  raw: Loaded, ideally bandpass-filtered, `mne.io.BaseRaw`.
98
  eog_ch_name: Name of the EOG channel for correlation-based detection.
99
- None disables auto-rejection.
100
- n_components: Cap on ICA components. For small recordings, MNE will
101
- silently cap this at the rank of the data.
 
 
102
  random_state: Seed for ICA's underlying solver. Required for §4
103
  Determinism.
104
 
105
  Returns:
106
- A copy of `raw` with EOG-correlated ICA components removed.
 
 
 
 
 
107
  """
108
  out = raw.copy()
109
- if eog_ch_name is None or eog_ch_name not in out.ch_names:
110
- logger.info("ICA skipped: no EOG channel reference provided")
 
 
 
 
 
 
111
  return out
112
 
113
- # Cap n_components at the rank of the data to avoid solver complaints
114
- # on small synthetic fixtures.
 
115
  n_eeg = len(mne.pick_types(out.info, eeg=True, meg=False))
116
  safe_n = min(n_components, max(n_eeg - 1, 1))
117
 
@@ -130,13 +149,13 @@ def remove_artifacts_with_ica(
130
  out,
131
  ch_name=eog_ch_name,
132
  measure="correlation",
133
- threshold=0.9,
134
  verbose="ERROR",
135
  )
136
  ica.exclude = list(bad_idx)
137
  logger.info(
138
- "ICA fit: n_components=%d, EOG-correlated rejected=%d",
139
- safe_n, len(ica.exclude),
140
  )
141
  ica.apply(out, verbose="ERROR")
142
  return out
 
20
 
21
  logger = get_logger(__name__)
22
 
23
+ # Pearson-correlation threshold for EOG-component rejection in ICA.
24
+ # Real-world EOG components typically score 0.8-0.95 against the EOG channel;
25
+ # 0.9 is a conservative floor that avoids false positives at the cost of
26
+ # missing weak artifacts. Lower it to 0.7-0.8 for noisier recordings.
27
+ _EOG_CORR_THRESHOLD: float = 0.9
28
+
29
 
30
  def is_valid_epoch(epoch: np.ndarray | None) -> bool:
31
  """Return True iff `epoch` is a non-empty 2-D numeric array with no NaN/inf.
 
95
  `measure="correlation"`, marks them as "bad" and reconstructs the signal
96
  without them. Returns a copy; the input `raw` is unchanged.
97
 
98
+ If `eog_ch_name` is None or not present in the recording's channels,
99
+ ICA is skipped entirely and a copy of `raw` is returned unchanged.
 
100
 
101
  Args:
102
  raw: Loaded, ideally bandpass-filtered, `mne.io.BaseRaw`.
103
  eog_ch_name: Name of the EOG channel for correlation-based detection.
104
+ None disables auto-rejection; a string that is not in the recording's
105
+ channel list logs a WARNING and skips ICA.
106
+ n_components: Cap on ICA components. If this exceeds the number of EEG
107
+ channels, MNE raises ValueError, so the implementation internally
108
+ caps it at `max(n_eeg - 1, 1)` before fitting.
109
  random_state: Seed for ICA's underlying solver. Required for §4
110
  Determinism.
111
 
112
  Returns:
113
+ A copy of `raw` with EOG-correlated ICA components removed (or an
114
+ unchanged copy if ICA was skipped).
115
+
116
+ Raises:
117
+ ValueError: if the EEG data is rank-deficient (all-zero or constant
118
+ channels) and `mne.preprocessing.ICA.fit` cannot converge.
119
  """
120
  out = raw.copy()
121
+ if eog_ch_name is None:
122
+ logger.info("ICA skipped: eog_ch_name not provided")
123
+ return out
124
+ if eog_ch_name not in out.ch_names:
125
+ logger.warning(
126
+ "ICA skipped: eog_ch_name=%r not found in channels %s",
127
+ eog_ch_name, out.ch_names,
128
+ )
129
  return out
130
 
131
+ # Cap n_components at n_eeg - 1. Average reference (if applied) reduces rank
132
+ # to n_eeg - 1; using that as the ceiling is safe for both referenced and
133
+ # unreferenced data and avoids ValueError from ICA.fit on small recordings.
134
  n_eeg = len(mne.pick_types(out.info, eeg=True, meg=False))
135
  safe_n = min(n_components, max(n_eeg - 1, 1))
136
 
 
149
  out,
150
  ch_name=eog_ch_name,
151
  measure="correlation",
152
+ threshold=_EOG_CORR_THRESHOLD,
153
  verbose="ERROR",
154
  )
155
  ica.exclude = list(bad_idx)
156
  logger.info(
157
+ "ICA fit: n_components=%d, EOG-correlated rejected=%d (indices=%s)",
158
+ safe_n, len(ica.exclude), ica.exclude,
159
  )
160
  ica.apply(out, verbose="ERROR")
161
  return out
tests/pipelines/test_eeg_pipeline.py CHANGED
@@ -148,3 +148,30 @@ class TestRemoveArtifactsWithIca:
148
  raw, eog_ch_name="EOG061", n_components=4, random_state=97,
149
  )
150
  np.testing.assert_allclose(a.get_data(), b.get_data(), rtol=1e-12, atol=1e-15)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  raw, eog_ch_name="EOG061", n_components=4, random_state=97,
149
  )
150
  np.testing.assert_allclose(a.get_data(), b.get_data(), rtol=1e-12, atol=1e-15)
151
+
152
+ def test_unknown_eog_channel_logs_warning_and_is_a_noop(self) -> None:
153
+ """A misconfigured eog_ch_name (typo) must not silently behave like None."""
154
+ import io
155
+ import logging
156
+
157
+ from src.core.logger import get_logger
158
+ from src.pipelines import eeg_pipeline as mod
159
+
160
+ raw = bandpass_filter(self._load(), l_freq=1.0, h_freq=40.0)
161
+ logger = get_logger(mod.__name__, level=logging.INFO)
162
+ handler = logger.handlers[0]
163
+ buf = io.StringIO()
164
+ original_stream = handler.stream
165
+ handler.stream = buf
166
+ try:
167
+ out = remove_artifacts_with_ica(
168
+ raw, eog_ch_name="EOG_DOES_NOT_EXIST",
169
+ n_components=4, random_state=97,
170
+ )
171
+ finally:
172
+ handler.stream = original_stream
173
+
174
+ # Behavior: ICA was skipped (no-op), but the log message distinguishes this from the eog_ch_name=None case.
175
+ np.testing.assert_allclose(out.get_data(), raw.get_data(), rtol=1e-6, atol=1e-12)
176
+ log_output = buf.getvalue()
177
+ assert "ICA skipped: eog_ch_name='EOG_DOES_NOT_EXIST' not found" in log_output