mekosotto commited on
Commit
e3c6c58
·
1 Parent(s): 7e0ed24

fix(eeg): tighten is_valid_epoch signature; reject non-numeric dtype + -inf

Browse files
src/pipelines/eeg_pipeline.py CHANGED
@@ -19,12 +19,12 @@ from src.core.logger import get_logger
19
  logger = get_logger(__name__)
20
 
21
 
22
- def is_valid_epoch(epoch: object) -> bool:
23
- """Return True iff `epoch` is a non-empty 2-D float array with no NaN/inf.
24
 
25
- Used to drop corrupted segments before feature extraction. Defensive
26
- against the full set of garbage we expect from real recordings: lists,
27
- None, NaN/inf samples, zero-sized arrays.
28
  """
29
  if not isinstance(epoch, np.ndarray):
30
  return False
@@ -32,6 +32,8 @@ def is_valid_epoch(epoch: object) -> bool:
32
  return False
33
  if epoch.size == 0:
34
  return False
 
 
35
  if not np.all(np.isfinite(epoch)):
36
  return False
37
  return True
 
19
  logger = get_logger(__name__)
20
 
21
 
22
+ def is_valid_epoch(epoch: np.ndarray | None) -> bool:
23
+ """Return True iff `epoch` is a non-empty 2-D numeric array with no NaN/inf.
24
 
25
+ The annotation is the *expected* input class; the implementation defensively
26
+ rejects any other garbage (lists, scalars, string dtypes, zero-sized arrays)
27
+ without raising matching the BBB pipeline's `is_valid_smiles` pattern.
28
  """
29
  if not isinstance(epoch, np.ndarray):
30
  return False
 
32
  return False
33
  if epoch.size == 0:
34
  return False
35
+ if not np.issubdtype(epoch.dtype, np.number):
36
+ return False
37
  if not np.all(np.isfinite(epoch)):
38
  return False
39
  return True
tests/pipelines/test_eeg_pipeline.py CHANGED
@@ -30,6 +30,8 @@ class TestIsValidEpoch:
30
  epoch = np.zeros((4, 256))
31
  epoch[1, 5] = np.inf
32
  assert is_valid_epoch(epoch) is False
 
 
33
 
34
  def test_rejects_empty(self) -> None:
35
  assert is_valid_epoch(np.zeros((0, 256))) is False
@@ -38,3 +40,8 @@ class TestIsValidEpoch:
38
  def test_rejects_non_array(self) -> None:
39
  assert is_valid_epoch([[1, 2, 3]]) is False
40
  assert is_valid_epoch(None) is False
 
 
 
 
 
 
30
  epoch = np.zeros((4, 256))
31
  epoch[1, 5] = np.inf
32
  assert is_valid_epoch(epoch) is False
33
+ epoch[1, 5] = -np.inf
34
+ assert is_valid_epoch(epoch) is False
35
 
36
  def test_rejects_empty(self) -> None:
37
  assert is_valid_epoch(np.zeros((0, 256))) is False
 
40
  def test_rejects_non_array(self) -> None:
41
  assert is_valid_epoch([[1, 2, 3]]) is False
42
  assert is_valid_epoch(None) is False
43
+
44
+ def test_rejects_non_numeric_dtype(self) -> None:
45
+ """String / object dtype arrays must be rejected without raising."""
46
+ epoch = np.array([["a", "b"], ["c", "d"]])
47
+ assert is_valid_epoch(epoch) is False