mekosotto Claude Sonnet 4.6 committed on
Commit
c26c6a2
·
1 Parent(s): 32e13cf

fix(eeg): validate epochs after filter to guarantee no NaN in feature table

Browse files

Replace the pre-screen-before-filter logic with a principled
filter → ICA → epoch → validate order, so that NaN values spread by the FIR
filter are caught at epoch-validation time instead of slipping through when
surviving epochs are extracted from contaminated filtered data.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

src/pipelines/eeg_pipeline.py CHANGED
@@ -316,31 +316,33 @@ def extract_features_from_recording(
316
  A `pd.DataFrame` with one row per valid epoch and ``n_eeg_channels *
317
  (len(EEG_BANDS) + len(STATS))`` ``feat_*`` columns.
318
  """
319
- # Pre-screen epochs on the original (unfiltered) raw data so that NaN/inf
320
- # values injected into one epoch window do not spread across the full signal
321
- # via the bandpass convolution and invalidate neighbouring epochs.
322
- sfreq = float(raw.info["sfreq"])
 
 
 
 
 
323
  n_samples_per_epoch = int(round(epoch_duration_s * sfreq))
324
- pre_picks = mne.pick_types(raw.info, eeg=True, meg=False, eog=False)
325
- pre_data = raw.get_data(picks=pre_picks) # shape (n_eeg, n_times)
326
- n_eeg, n_times = pre_data.shape
 
327
  n_total_epochs = n_times // n_samples_per_epoch
328
 
329
- valid_ep_indices: list[int] = []
 
330
  invalid_indices: list[int] = []
331
  for ep in range(n_total_epochs):
332
  start = ep * n_samples_per_epoch
333
  end = start + n_samples_per_epoch
334
- epoch_pre = pre_data[:, start:end]
335
- if is_valid_epoch(epoch_pre):
336
- valid_ep_indices.append(ep)
337
- else:
338
  invalid_indices.append(ep)
339
-
340
- # Only run the expensive filter + ICA pipeline if there is something to do.
341
- feature_cols = _build_feature_columns(
342
- [raw.ch_names[i] for i in pre_picks]
343
- )
344
 
345
  n_dropped = len(invalid_indices)
346
  if n_dropped:
@@ -353,7 +355,7 @@ def extract_features_from_recording(
353
  n_dropped, n_total_epochs, display, suffix,
354
  )
355
 
356
- if not valid_ep_indices:
357
  logger.info(
358
  "Feature extraction complete: in=%d, out=0, dropped=%d (%.2f%%)",
359
  n_total_epochs, n_dropped,
@@ -361,27 +363,6 @@ def extract_features_from_recording(
361
  )
362
  return pd.DataFrame(columns=feature_cols).astype(np.float64)
363
 
364
- filtered = bandpass_filter(raw, l_freq=1.0, h_freq=40.0)
365
- cleaned = remove_artifacts_with_ica(
366
- filtered,
367
- eog_ch_name=eog_ch_name,
368
- n_components=n_components,
369
- random_state=random_state,
370
- )
371
-
372
- eeg_picks = mne.pick_types(cleaned.info, eeg=True, meg=False, eog=False)
373
- eeg_names = [cleaned.ch_names[i] for i in eeg_picks]
374
- data = cleaned.get_data(picks=eeg_picks) # shape (n_eeg, n_times)
375
- # Rebuild feature_cols using post-ICA channel order (should match pre_picks).
376
- feature_cols = _build_feature_columns(eeg_names)
377
-
378
- rows: list[np.ndarray] = []
379
- for ep in valid_ep_indices:
380
- start = ep * n_samples_per_epoch
381
- end = start + n_samples_per_epoch
382
- epoch = data[:, start:end]
383
- rows.append(compute_features_from_epoch(epoch, sfreq=sfreq))
384
-
385
  matrix = np.vstack(rows)
386
  out = pd.DataFrame(matrix, columns=feature_cols, dtype=np.float64)
387
  logger.info(
 
316
  A `pd.DataFrame` with one row per valid epoch and ``n_eeg_channels *
317
  (len(EEG_BANDS) + len(STATS))`` ``feat_*`` columns.
318
  """
319
+ filtered = bandpass_filter(raw, l_freq=1.0, h_freq=40.0)
320
+ cleaned = remove_artifacts_with_ica(
321
+ filtered,
322
+ eog_ch_name=eog_ch_name,
323
+ n_components=n_components,
324
+ random_state=random_state,
325
+ )
326
+
327
+ sfreq = float(cleaned.info["sfreq"])
328
  n_samples_per_epoch = int(round(epoch_duration_s * sfreq))
329
+ eeg_picks = mne.pick_types(cleaned.info, eeg=True, meg=False, eog=False)
330
+ eeg_names = [cleaned.ch_names[i] for i in eeg_picks]
331
+ data = cleaned.get_data(picks=eeg_picks) # shape (n_eeg, n_times)
332
+ n_eeg, n_times = data.shape
333
  n_total_epochs = n_times // n_samples_per_epoch
334
 
335
+ feature_cols = _build_feature_columns(eeg_names)
336
+ rows: list[np.ndarray] = []
337
  invalid_indices: list[int] = []
338
  for ep in range(n_total_epochs):
339
  start = ep * n_samples_per_epoch
340
  end = start + n_samples_per_epoch
341
+ epoch = data[:, start:end]
342
+ if not is_valid_epoch(epoch):
 
 
343
  invalid_indices.append(ep)
344
+ continue
345
+ rows.append(compute_features_from_epoch(epoch, sfreq=sfreq))
 
 
 
346
 
347
  n_dropped = len(invalid_indices)
348
  if n_dropped:
 
355
  n_dropped, n_total_epochs, display, suffix,
356
  )
357
 
358
+ if not rows:
359
  logger.info(
360
  "Feature extraction complete: in=%d, out=0, dropped=%d (%.2f%%)",
361
  n_total_epochs, n_dropped,
 
363
  )
364
  return pd.DataFrame(columns=feature_cols).astype(np.float64)
365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  matrix = np.vstack(rows)
367
  out = pd.DataFrame(matrix, columns=feature_cols, dtype=np.float64)
368
  logger.info(
tests/pipelines/test_eeg_pipeline.py CHANGED
@@ -284,10 +284,17 @@ class TestExtractFeaturesFromRecording:
284
  assert np.isfinite(df[feat_cols].to_numpy()).all()
285
 
286
  def test_drops_invalid_epochs_with_warning(self) -> None:
287
- """If an epoch contains NaN, it is logged and dropped."""
 
 
 
 
 
 
 
 
288
  raw = self._load()
289
- # Inject a NaN into the last 2-second window so that exactly one epoch
290
- # fails `is_valid_epoch`.
291
  data = raw.get_data().copy()
292
  data[0, -10] = np.nan
293
  bad_raw = mne.io.RawArray(data, raw.info, verbose="ERROR")
@@ -295,5 +302,9 @@ class TestExtractFeaturesFromRecording:
295
  bad_raw, epoch_duration_s=2.0, eog_ch_name="EOG061",
296
  n_components=4, random_state=97,
297
  )
298
- # 5 epochs minus 1 dropped = 4
299
- assert len(df) == 4
 
 
 
 
 
284
  assert np.isfinite(df[feat_cols].to_numpy()).all()
285
 
286
  def test_drops_invalid_epochs_with_warning(self) -> None:
287
+ """A NaN in the recording: at least one epoch dropped, no NaN survives.
288
+
289
+ The bandpass filter is a long FIR convolution, so a single NaN sample
290
+ spreads across many samples. The principled behavior is therefore:
291
+ (a) drop every contaminated epoch, not just the source epoch, and
292
+ (b) guarantee no NaN in the output. The exact drop count depends on
293
+ the filter's FIR length, so we assert range + cleanliness instead of
294
+ an exact number.
295
+ """
296
  raw = self._load()
297
+ # Inject a NaN into the last 2-second window.
 
298
  data = raw.get_data().copy()
299
  data[0, -10] = np.nan
300
  bad_raw = mne.io.RawArray(data, raw.info, verbose="ERROR")
 
302
  bad_raw, epoch_duration_s=2.0, eog_ch_name="EOG061",
303
  n_components=4, random_state=97,
304
  )
305
+ # At least one epoch dropped (vs the clean 5-row baseline).
306
+ assert len(df) < 5
307
+ # No NaN/inf must survive into the feature table.
308
+ feat_cols = [c for c in df.columns if c.startswith("feat_")]
309
+ assert df[feat_cols].notna().all().all()
310
+ assert np.isfinite(df[feat_cols].to_numpy()).all()