mekosotto commited on
Commit
a3f2882
·
1 Parent(s): 27a97bf

feat(models): EEG classifier loader + predict (stub-able for hackathon demo)

Browse files
src/models/eeg_model.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """EEG classifier inference utilities.
2
+
3
+ Loads any sklearn-style classifier (object with `predict_proba`) from joblib
4
+ and emits the same dict shape as src.models.mri_model.predict_with_proba so
5
+ the API surface and fusion engine treat MRI and EEG predictions identically.
6
+
7
+ The real pretrained artifact swaps in at data/processed/eeg_clf.joblib (or
8
+ override via EEG_CLF_ARTIFACT env). Tests use a stub fixture; the real model
9
+ drops in without code changes.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import os
14
+ from pathlib import Path
15
+ from typing import Any, Sequence
16
+
17
+ import joblib
18
+ import numpy as np
19
+
20
+ from src.core.logger import get_logger
21
+
22
+ logger = get_logger(__name__)
23
+
24
+ DEFAULT_LABELS: tuple[str, ...] = ("control", "alzheimers")
25
+
26
+
27
+ def _resolve_labels() -> tuple[str, ...]:
28
+ raw = os.environ.get("EEG_CLF_LABELS")
29
+ if not raw:
30
+ return DEFAULT_LABELS
31
+ return tuple(s.strip() for s in raw.split(",") if s.strip())
32
+
33
+
34
+ def load(path: Path) -> Any:
35
+ path = Path(path)
36
+ if not path.exists():
37
+ raise FileNotFoundError(f"EEG classifier artifact not found: {path}")
38
+ return joblib.load(str(path))
39
+
40
+
41
+ def predict_features(
42
+ model: Any,
43
+ features: np.ndarray,
44
+ labels: Sequence[str] | None = None,
45
+ ) -> dict[str, Any]:
46
+ """Run inference on one row of EEG features."""
47
+ arr = np.asarray(features, dtype=np.float32).reshape(-1)
48
+ expected = int(getattr(model, "n_features_in_", arr.size))
49
+ if arr.size != expected:
50
+ raise ValueError(
51
+ f"EEG feature count mismatch: model expects {expected}, got {arr.size}"
52
+ )
53
+
54
+ proba = np.asarray(model.predict_proba(arr.reshape(1, -1))[0], dtype=np.float32)
55
+ label_names = tuple(labels or _resolve_labels())
56
+ if len(label_names) != proba.shape[0]:
57
+ logger.warning(
58
+ "EEG label count (%d) != model output dim (%d); falling back to class_0..N",
59
+ len(label_names), proba.shape[0],
60
+ )
61
+ label_names = tuple(f"class_{i}" for i in range(proba.shape[0]))
62
+
63
+ label_idx = int(np.argmax(proba))
64
+ return {
65
+ "label": label_idx,
66
+ "label_text": label_names[label_idx],
67
+ "confidence": float(proba[label_idx]),
68
+ "probabilities": [
69
+ {"label": i, "label_text": label_names[i], "probability": float(p)}
70
+ for i, p in enumerate(proba)
71
+ ],
72
+ }
tests/fixtures/build_dummy_eeg_clf.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Build a stub EEG classifier (sklearn RF) for tests.
2
+
3
+ Demo-time placeholder — produces a 2-class probability output matching the
4
+ eeg_model.predict_features contract. Replace with the real artifact when
5
+ the user provides it; tests don't change.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ from pathlib import Path
10
+
11
+ import joblib
12
+ import numpy as np
13
+ from sklearn.ensemble import RandomForestClassifier
14
+
15
+
16
+ def build(path: Path, n_features: int = 16, seed: int = 0) -> Path:
17
+ """Save a fitted RandomForestClassifier at `path` and return the path."""
18
+ path = Path(path)
19
+ if path.exists():
20
+ return path
21
+ path.parent.mkdir(parents=True, exist_ok=True)
22
+
23
+ rng = np.random.default_rng(seed)
24
+ n = 200
25
+ n_alz = n // 2
26
+ X_ctrl = rng.normal(0.0, 1.0, size=(n - n_alz, n_features))
27
+ X_alz = rng.normal(2.0, 1.0, size=(n_alz, n_features))
28
+ X = np.vstack([X_ctrl, X_alz])
29
+ y = np.array([0] * (n - n_alz) + [1] * n_alz)
30
+
31
+ clf = RandomForestClassifier(n_estimators=12, max_depth=6, random_state=seed)
32
+ clf.fit(X, y)
33
+ joblib.dump(clf, str(path))
34
+ return path
tests/models/test_eeg_model.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.models.eeg_model."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pytest
8
+
9
+ from src.models import eeg_model
10
+ from tests.fixtures.build_dummy_eeg_clf import build as build_dummy_eeg
11
+
12
+
13
+ class TestEEGModel:
14
+ def test_load_missing_artifact_raises(self, tmp_path: Path) -> None:
15
+ with pytest.raises(FileNotFoundError, match="EEG classifier artifact not found"):
16
+ eeg_model.load(tmp_path / "nope.joblib")
17
+
18
+ def test_predict_returns_full_dict(self, tmp_path: Path) -> None:
19
+ ckpt = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=16)
20
+ clf = eeg_model.load(ckpt)
21
+ features = np.zeros((16,), dtype=np.float32)
22
+
23
+ out = eeg_model.predict_features(clf, features)
24
+
25
+ assert set(out) == {"label", "label_text", "confidence", "probabilities"}
26
+ assert out["label"] in {0, 1}
27
+ assert out["label_text"] in eeg_model.DEFAULT_LABELS
28
+ assert 0.0 <= out["confidence"] <= 1.0
29
+ probs = out["probabilities"]
30
+ assert len(probs) == 2
31
+ assert abs(sum(p["probability"] for p in probs) - 1.0) < 1e-5
32
+
33
+ def test_alzheimers_separation_with_synthetic_features(self, tmp_path: Path) -> None:
34
+ ckpt = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=16)
35
+ clf = eeg_model.load(ckpt)
36
+ alz_features = np.full((16,), 2.0, dtype=np.float32)
37
+ ctrl_features = np.zeros((16,), dtype=np.float32)
38
+
39
+ alz_pred = eeg_model.predict_features(clf, alz_features)
40
+ ctrl_pred = eeg_model.predict_features(clf, ctrl_features)
41
+
42
+ assert alz_pred["label_text"] == "alzheimers"
43
+ assert ctrl_pred["label_text"] == "control"
44
+
45
+ def test_label_override_via_env(self, tmp_path: Path, monkeypatch) -> None:
46
+ monkeypatch.setenv("EEG_CLF_LABELS", "no_disease,alzheimers")
47
+ ckpt = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=16)
48
+ clf = eeg_model.load(ckpt)
49
+ out = eeg_model.predict_features(clf, np.zeros((16,), dtype=np.float32))
50
+ assert out["label_text"] in {"no_disease", "alzheimers"}
51
+
52
+ def test_feature_count_mismatch_raises(self, tmp_path: Path) -> None:
53
+ ckpt = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=16)
54
+ clf = eeg_model.load(ckpt)
55
+ with pytest.raises(ValueError, match="feature count"):
56
+ eeg_model.predict_features(clf, np.zeros((8,), dtype=np.float32))