fix(eeg): validate epochs after filter to guarantee no NaN in feature table
Browse files
Replace the pre-screen-before-filter logic with a principled filter→ICA→epoch→validate
order, so that NaN spread by the FIR filter is caught at epoch-validation time instead of
being missed when epochs that passed the pre-filter screen are extracted from contaminated filtered data.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
src/pipelines/eeg_pipeline.py
CHANGED
|
@@ -316,31 +316,33 @@ def extract_features_from_recording(
|
|
| 316 |
A `pd.DataFrame` with one row per valid epoch and ``n_eeg_channels *
|
| 317 |
(len(EEG_BANDS) + len(STATS))`` ``feat_*`` columns.
|
| 318 |
"""
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
n_samples_per_epoch = int(round(epoch_duration_s * sfreq))
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
|
|
|
| 327 |
n_total_epochs = n_times // n_samples_per_epoch
|
| 328 |
|
| 329 |
-
|
|
|
|
| 330 |
invalid_indices: list[int] = []
|
| 331 |
for ep in range(n_total_epochs):
|
| 332 |
start = ep * n_samples_per_epoch
|
| 333 |
end = start + n_samples_per_epoch
|
| 334 |
-
|
| 335 |
-
if is_valid_epoch(
|
| 336 |
-
valid_ep_indices.append(ep)
|
| 337 |
-
else:
|
| 338 |
invalid_indices.append(ep)
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
feature_cols = _build_feature_columns(
|
| 342 |
-
[raw.ch_names[i] for i in pre_picks]
|
| 343 |
-
)
|
| 344 |
|
| 345 |
n_dropped = len(invalid_indices)
|
| 346 |
if n_dropped:
|
|
@@ -353,7 +355,7 @@ def extract_features_from_recording(
|
|
| 353 |
n_dropped, n_total_epochs, display, suffix,
|
| 354 |
)
|
| 355 |
|
| 356 |
-
if not
|
| 357 |
logger.info(
|
| 358 |
"Feature extraction complete: in=%d, out=0, dropped=%d (%.2f%%)",
|
| 359 |
n_total_epochs, n_dropped,
|
|
@@ -361,27 +363,6 @@ def extract_features_from_recording(
|
|
| 361 |
)
|
| 362 |
return pd.DataFrame(columns=feature_cols).astype(np.float64)
|
| 363 |
|
| 364 |
-
filtered = bandpass_filter(raw, l_freq=1.0, h_freq=40.0)
|
| 365 |
-
cleaned = remove_artifacts_with_ica(
|
| 366 |
-
filtered,
|
| 367 |
-
eog_ch_name=eog_ch_name,
|
| 368 |
-
n_components=n_components,
|
| 369 |
-
random_state=random_state,
|
| 370 |
-
)
|
| 371 |
-
|
| 372 |
-
eeg_picks = mne.pick_types(cleaned.info, eeg=True, meg=False, eog=False)
|
| 373 |
-
eeg_names = [cleaned.ch_names[i] for i in eeg_picks]
|
| 374 |
-
data = cleaned.get_data(picks=eeg_picks) # shape (n_eeg, n_times)
|
| 375 |
-
# Rebuild feature_cols using post-ICA channel order (should match pre_picks).
|
| 376 |
-
feature_cols = _build_feature_columns(eeg_names)
|
| 377 |
-
|
| 378 |
-
rows: list[np.ndarray] = []
|
| 379 |
-
for ep in valid_ep_indices:
|
| 380 |
-
start = ep * n_samples_per_epoch
|
| 381 |
-
end = start + n_samples_per_epoch
|
| 382 |
-
epoch = data[:, start:end]
|
| 383 |
-
rows.append(compute_features_from_epoch(epoch, sfreq=sfreq))
|
| 384 |
-
|
| 385 |
matrix = np.vstack(rows)
|
| 386 |
out = pd.DataFrame(matrix, columns=feature_cols, dtype=np.float64)
|
| 387 |
logger.info(
|
|
|
|
| 316 |
A `pd.DataFrame` with one row per valid epoch and ``n_eeg_channels *
|
| 317 |
(len(EEG_BANDS) + len(STATS))`` ``feat_*`` columns.
|
| 318 |
"""
|
| 319 |
+
filtered = bandpass_filter(raw, l_freq=1.0, h_freq=40.0)
|
| 320 |
+
cleaned = remove_artifacts_with_ica(
|
| 321 |
+
filtered,
|
| 322 |
+
eog_ch_name=eog_ch_name,
|
| 323 |
+
n_components=n_components,
|
| 324 |
+
random_state=random_state,
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
sfreq = float(cleaned.info["sfreq"])
|
| 328 |
n_samples_per_epoch = int(round(epoch_duration_s * sfreq))
|
| 329 |
+
eeg_picks = mne.pick_types(cleaned.info, eeg=True, meg=False, eog=False)
|
| 330 |
+
eeg_names = [cleaned.ch_names[i] for i in eeg_picks]
|
| 331 |
+
data = cleaned.get_data(picks=eeg_picks) # shape (n_eeg, n_times)
|
| 332 |
+
n_eeg, n_times = data.shape
|
| 333 |
n_total_epochs = n_times // n_samples_per_epoch
|
| 334 |
|
| 335 |
+
feature_cols = _build_feature_columns(eeg_names)
|
| 336 |
+
rows: list[np.ndarray] = []
|
| 337 |
invalid_indices: list[int] = []
|
| 338 |
for ep in range(n_total_epochs):
|
| 339 |
start = ep * n_samples_per_epoch
|
| 340 |
end = start + n_samples_per_epoch
|
| 341 |
+
epoch = data[:, start:end]
|
| 342 |
+
if not is_valid_epoch(epoch):
|
|
|
|
|
|
|
| 343 |
invalid_indices.append(ep)
|
| 344 |
+
continue
|
| 345 |
+
rows.append(compute_features_from_epoch(epoch, sfreq=sfreq))
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
n_dropped = len(invalid_indices)
|
| 348 |
if n_dropped:
|
|
|
|
| 355 |
n_dropped, n_total_epochs, display, suffix,
|
| 356 |
)
|
| 357 |
|
| 358 |
+
if not rows:
|
| 359 |
logger.info(
|
| 360 |
"Feature extraction complete: in=%d, out=0, dropped=%d (%.2f%%)",
|
| 361 |
n_total_epochs, n_dropped,
|
|
|
|
| 363 |
)
|
| 364 |
return pd.DataFrame(columns=feature_cols).astype(np.float64)
|
| 365 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
matrix = np.vstack(rows)
|
| 367 |
out = pd.DataFrame(matrix, columns=feature_cols, dtype=np.float64)
|
| 368 |
logger.info(
|
tests/pipelines/test_eeg_pipeline.py
CHANGED
|
@@ -284,10 +284,17 @@ class TestExtractFeaturesFromRecording:
|
|
| 284 |
assert np.isfinite(df[feat_cols].to_numpy()).all()
|
| 285 |
|
| 286 |
def test_drops_invalid_epochs_with_warning(self) -> None:
|
| 287 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
raw = self._load()
|
| 289 |
-
# Inject a NaN into the last 2-second window
|
| 290 |
-
# fails `is_valid_epoch`.
|
| 291 |
data = raw.get_data().copy()
|
| 292 |
data[0, -10] = np.nan
|
| 293 |
bad_raw = mne.io.RawArray(data, raw.info, verbose="ERROR")
|
|
@@ -295,5 +302,9 @@ class TestExtractFeaturesFromRecording:
|
|
| 295 |
bad_raw, epoch_duration_s=2.0, eog_ch_name="EOG061",
|
| 296 |
n_components=4, random_state=97,
|
| 297 |
)
|
| 298 |
-
#
|
| 299 |
-
assert len(df)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
assert np.isfinite(df[feat_cols].to_numpy()).all()
|
| 285 |
|
| 286 |
def test_drops_invalid_epochs_with_warning(self) -> None:
|
| 287 |
+
"""A NaN in the recording: at least one epoch dropped, no NaN survives.
|
| 288 |
+
|
| 289 |
+
The bandpass filter is a long FIR convolution, so a single NaN sample
|
| 290 |
+
spreads across many samples. The principled behavior is therefore:
|
| 291 |
+
(a) drop every contaminated epoch, not just the source epoch, and
|
| 292 |
+
(b) guarantee no NaN in the output. The exact drop count depends on
|
| 293 |
+
the filter's FIR length, so we assert range + cleanliness instead of
|
| 294 |
+
an exact number.
|
| 295 |
+
"""
|
| 296 |
raw = self._load()
|
| 297 |
+
# Inject a NaN into the last 2-second window.
|
|
|
|
| 298 |
data = raw.get_data().copy()
|
| 299 |
data[0, -10] = np.nan
|
| 300 |
bad_raw = mne.io.RawArray(data, raw.info, verbose="ERROR")
|
|
|
|
| 302 |
bad_raw, epoch_duration_s=2.0, eog_ch_name="EOG061",
|
| 303 |
n_components=4, random_state=97,
|
| 304 |
)
|
| 305 |
+
# At least one epoch dropped (vs the clean 5-row baseline).
|
| 306 |
+
assert len(df) < 5
|
| 307 |
+
# No NaN/inf must survive into the feature table.
|
| 308 |
+
feat_cols = [c for c in df.columns if c.startswith("feat_")]
|
| 309 |
+
assert df[feat_cols].notna().all().all()
|
| 310 |
+
assert np.isfinite(df[feat_cols].to_numpy()).all()
|