mekosotto Claude Opus 4.7 (1M context) committed on
Commit
0d591d4
·
1 Parent(s): c26c6a2

fix(eeg): NaN-clean features for flat channels; guard zero-size epochs; assert WARNING

Browse files
src/pipelines/eeg_pipeline.py CHANGED
@@ -253,6 +253,11 @@ def compute_features_from_epoch(epoch: np.ndarray, sfreq: float) -> np.ndarray:
253
  - ``kurtosis`` uses ``scipy.stats.kurtosis(fisher=True, bias=True)`` —
254
  Fisher's *excess* kurtosis (Gaussian → 0, not 3). Add 3 if Pearson
255
  kurtosis is required downstream.
 
 
 
 
 
256
 
257
  Precondition: `epoch` must be finite (no NaN/inf). Filter via
258
  `is_valid_epoch` before calling — feature values are NaN-propagating.
@@ -274,7 +279,11 @@ def compute_features_from_epoch(epoch: np.ndarray, sfreq: float) -> np.ndarray:
274
  feats.append(_band_power(freqs, psd, lo, hi))
275
  for _name, fn in _STATS_FUNCS:
276
  feats.append(fn(x))
277
- return np.asarray(feats, dtype=np.float64)
 
 
 
 
278
 
279
 
280
  def _build_feature_columns(eeg_ch_names: list[str]) -> list[str]:
@@ -315,6 +324,11 @@ def extract_features_from_recording(
315
  Returns:
316
  A `pd.DataFrame` with one row per valid epoch and ``n_eeg_channels *
317
  (len(EEG_BANDS) + len(STATS))`` ``feat_*`` columns.
 
 
 
 
 
318
  """
319
  filtered = bandpass_filter(raw, l_freq=1.0, h_freq=40.0)
320
  cleaned = remove_artifacts_with_ica(
@@ -326,10 +340,15 @@ def extract_features_from_recording(
326
 
327
  sfreq = float(cleaned.info["sfreq"])
328
  n_samples_per_epoch = int(round(epoch_duration_s * sfreq))
 
 
 
 
 
329
  eeg_picks = mne.pick_types(cleaned.info, eeg=True, meg=False, eog=False)
330
  eeg_names = [cleaned.ch_names[i] for i in eeg_picks]
331
  data = cleaned.get_data(picks=eeg_picks) # shape (n_eeg, n_times)
332
- n_eeg, n_times = data.shape
333
  n_total_epochs = n_times // n_samples_per_epoch
334
 
335
  feature_cols = _build_feature_columns(eeg_names)
 
253
  - ``kurtosis`` uses ``scipy.stats.kurtosis(fisher=True, bias=True)`` —
254
  Fisher's *excess* kurtosis (Gaussian → 0, not 3). Add 3 if Pearson
255
  kurtosis is required downstream.
256
+ - For constant-valued channels (zero variance), ``skew`` and
257
+ ``kurtosis`` are mathematically undefined and scipy returns NaN.
258
+ We post-process the feature vector with ``np.nan_to_num`` to map
259
+ any NaN/inf to 0.0, preserving the "no NaN survives" Parquet
260
+ contract from AGENTS.md §6.
261
 
262
  Precondition: `epoch` must be finite (no NaN/inf). Filter via
263
  `is_valid_epoch` before calling — feature values are NaN-propagating.
 
279
  feats.append(_band_power(freqs, psd, lo, hi))
280
  for _name, fn in _STATS_FUNCS:
281
  feats.append(fn(x))
282
+ arr = np.asarray(feats, dtype=np.float64)
283
+ # Constant-valued / zero-variance channels (e.g., disconnected electrodes)
284
+ # make scipy.stats.skew / kurtosis return NaN. Map those to 0.0 so the
285
+ # downstream Parquet contract ("no NaN in feature table") holds.
286
+ return np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
287
 
288
 
289
  def _build_feature_columns(eeg_ch_names: list[str]) -> list[str]:
 
324
  Returns:
325
  A `pd.DataFrame` with one row per valid epoch and ``n_eeg_channels *
326
  (len(EEG_BANDS) + len(STATS))`` ``feat_*`` columns.
327
+
328
+ Raises:
329
+ ValueError: if `epoch_duration_s * sfreq` rounds to less than 1 sample.
330
+ (Other ValueError sources can propagate from `bandpass_filter`
331
+ and `remove_artifacts_with_ica`; see their respective docstrings.)
332
  """
333
  filtered = bandpass_filter(raw, l_freq=1.0, h_freq=40.0)
334
  cleaned = remove_artifacts_with_ica(
 
340
 
341
  sfreq = float(cleaned.info["sfreq"])
342
  n_samples_per_epoch = int(round(epoch_duration_s * sfreq))
343
+ if n_samples_per_epoch < 1:
344
+ raise ValueError(
345
+ f"epoch_duration_s={epoch_duration_s!r} at sfreq={sfreq} Hz produces "
346
+ f"{n_samples_per_epoch} samples per epoch (must be >= 1)"
347
+ )
348
  eeg_picks = mne.pick_types(cleaned.info, eeg=True, meg=False, eog=False)
349
  eeg_names = [cleaned.ch_names[i] for i in eeg_picks]
350
  data = cleaned.get_data(picks=eeg_picks) # shape (n_eeg, n_times)
351
+ _, n_times = data.shape
352
  n_total_epochs = n_times // n_samples_per_epoch
353
 
354
  feature_cols = _build_feature_columns(eeg_names)
tests/pipelines/test_eeg_pipeline.py CHANGED
@@ -229,6 +229,12 @@ class TestComputeFeaturesFromEpoch:
229
  derived_names = tuple(name for name, _ in _STATS_FUNCS)
230
  assert derived_names == STATS
231
 
 
 
 
 
 
 
232
 
233
  class TestExtractFeaturesFromRecording:
234
  def _load(self) -> mne.io.BaseRaw:
@@ -284,27 +290,54 @@ class TestExtractFeaturesFromRecording:
284
  assert np.isfinite(df[feat_cols].to_numpy()).all()
285
 
286
  def test_drops_invalid_epochs_with_warning(self) -> None:
287
- """A NaN in the recording: at least one epoch dropped, no NaN survives.
288
 
289
  The bandpass filter is a long FIR convolution, so a single NaN sample
290
  spreads across many samples. The principled behavior is therefore:
291
  (a) drop every contaminated epoch, not just the source epoch, and
292
  (b) guarantee no NaN in the output. The exact drop count depends on
293
  the filter's FIR length, so we assert range + cleanliness instead of
294
- an exact number.
 
295
  """
 
 
 
 
 
 
296
  raw = self._load()
297
- # Inject a NaN into the last 2-second window.
298
  data = raw.get_data().copy()
299
  data[0, -10] = np.nan
300
  bad_raw = mne.io.RawArray(data, raw.info, verbose="ERROR")
301
- df = extract_features_from_recording(
302
- bad_raw, epoch_duration_s=2.0, eog_ch_name="EOG061",
303
- n_components=4, random_state=97,
304
- )
 
 
 
 
 
 
 
 
 
 
305
  # At least one epoch dropped (vs the clean 5-row baseline).
306
  assert len(df) < 5
307
  # No NaN/inf must survive into the feature table.
308
  feat_cols = [c for c in df.columns if c.startswith("feat_")]
309
  assert df[feat_cols].notna().all().all()
310
  assert np.isfinite(df[feat_cols].to_numpy()).all()
 
 
 
 
 
 
 
 
 
 
 
 
229
  derived_names = tuple(name for name, _ in _STATS_FUNCS)
230
  assert derived_names == STATS
231
 
232
+ def test_constant_channel_yields_finite_features(self) -> None:
233
+ """A flat-line channel must not produce NaN features (skew/kurtosis are undefined for zero-variance)."""
234
+ epoch = np.zeros((4, 512), dtype=np.float64)
235
+ out = compute_features_from_epoch(epoch, sfreq=256.0)
236
+ assert np.all(np.isfinite(out))
237
+
238
 
239
  class TestExtractFeaturesFromRecording:
240
  def _load(self) -> mne.io.BaseRaw:
 
290
  assert np.isfinite(df[feat_cols].to_numpy()).all()
291
 
292
  def test_drops_invalid_epochs_with_warning(self) -> None:
293
+ """A NaN in the recording: at least one epoch dropped, no NaN survives, WARNING is logged.
294
 
295
  The bandpass filter is a long FIR convolution, so a single NaN sample
296
  spreads across many samples. The principled behavior is therefore:
297
  (a) drop every contaminated epoch, not just the source epoch, and
298
  (b) guarantee no NaN in the output. The exact drop count depends on
299
  the filter's FIR length, so we assert range + cleanliness instead of
300
+ an exact number. The WARNING line is part of the AGENTS.md §4
301
+ traceability contract and must always fire when drops happen.
302
  """
303
+ import io
304
+ import logging
305
+
306
+ from src.core.logger import get_logger
307
+ from src.pipelines import eeg_pipeline as mod
308
+
309
  raw = self._load()
 
310
  data = raw.get_data().copy()
311
  data[0, -10] = np.nan
312
  bad_raw = mne.io.RawArray(data, raw.info, verbose="ERROR")
313
+
314
+ logger = get_logger(mod.__name__, level=logging.INFO)
315
+ handler = logger.handlers[0]
316
+ buf = io.StringIO()
317
+ original_stream = handler.stream
318
+ handler.stream = buf
319
+ try:
320
+ df = extract_features_from_recording(
321
+ bad_raw, epoch_duration_s=2.0, eog_ch_name="EOG061",
322
+ n_components=4, random_state=97,
323
+ )
324
+ finally:
325
+ handler.stream = original_stream
326
+
327
  # At least one epoch dropped (vs the clean 5-row baseline).
328
  assert len(df) < 5
329
  # No NaN/inf must survive into the feature table.
330
  feat_cols = [c for c in df.columns if c.startswith("feat_")]
331
  assert df[feat_cols].notna().all().all()
332
  assert np.isfinite(df[feat_cols].to_numpy()).all()
333
+ # AGENTS.md §4: the WARNING line was actually emitted.
334
+ log_output = buf.getvalue()
335
+ assert "Dropping" in log_output and "epochs with invalid samples" in log_output
336
+
337
+ def test_raises_when_epoch_duration_too_small(self) -> None:
338
+ raw = self._load()
339
+ with pytest.raises(ValueError, match="must be >= 1"):
340
+ extract_features_from_recording(
341
+ raw, epoch_duration_s=1e-6, eog_ch_name="EOG061",
342
+ n_components=4, random_state=97,
343
+ )