feat(models): EEG classifier loader + predict (stub-able for hackathon demo)
Browse files- src/models/eeg_model.py +72 -0
- tests/fixtures/build_dummy_eeg_clf.py +34 -0
- tests/models/test_eeg_model.py +56 -0
src/models/eeg_model.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""EEG classifier inference utilities.
|
| 2 |
+
|
| 3 |
+
Loads any sklearn-style classifier (object with `predict_proba`) from joblib
|
| 4 |
+
and emits the same dict shape as src.models.mri_model.predict_with_proba so
|
| 5 |
+
the API surface and fusion engine treat MRI and EEG predictions identically.
|
| 6 |
+
|
| 7 |
+
The real pretrained artifact swaps in at data/processed/eeg_clf.joblib (or
|
| 8 |
+
override via EEG_CLF_ARTIFACT env). Tests use a stub fixture; the real model
|
| 9 |
+
drops in without code changes.
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Any, Sequence
|
| 16 |
+
|
| 17 |
+
import joblib
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
from src.core.logger import get_logger
|
| 21 |
+
|
| 22 |
+
logger = get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
DEFAULT_LABELS: tuple[str, ...] = ("control", "alzheimers")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _resolve_labels() -> tuple[str, ...]:
|
| 28 |
+
raw = os.environ.get("EEG_CLF_LABELS")
|
| 29 |
+
if not raw:
|
| 30 |
+
return DEFAULT_LABELS
|
| 31 |
+
return tuple(s.strip() for s in raw.split(",") if s.strip())
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def load(path: Path) -> Any:
|
| 35 |
+
path = Path(path)
|
| 36 |
+
if not path.exists():
|
| 37 |
+
raise FileNotFoundError(f"EEG classifier artifact not found: {path}")
|
| 38 |
+
return joblib.load(str(path))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def predict_features(
|
| 42 |
+
model: Any,
|
| 43 |
+
features: np.ndarray,
|
| 44 |
+
labels: Sequence[str] | None = None,
|
| 45 |
+
) -> dict[str, Any]:
|
| 46 |
+
"""Run inference on one row of EEG features."""
|
| 47 |
+
arr = np.asarray(features, dtype=np.float32).reshape(-1)
|
| 48 |
+
expected = int(getattr(model, "n_features_in_", arr.size))
|
| 49 |
+
if arr.size != expected:
|
| 50 |
+
raise ValueError(
|
| 51 |
+
f"EEG feature count mismatch: model expects {expected}, got {arr.size}"
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
proba = np.asarray(model.predict_proba(arr.reshape(1, -1))[0], dtype=np.float32)
|
| 55 |
+
label_names = tuple(labels or _resolve_labels())
|
| 56 |
+
if len(label_names) != proba.shape[0]:
|
| 57 |
+
logger.warning(
|
| 58 |
+
"EEG label count (%d) != model output dim (%d); falling back to class_0..N",
|
| 59 |
+
len(label_names), proba.shape[0],
|
| 60 |
+
)
|
| 61 |
+
label_names = tuple(f"class_{i}" for i in range(proba.shape[0]))
|
| 62 |
+
|
| 63 |
+
label_idx = int(np.argmax(proba))
|
| 64 |
+
return {
|
| 65 |
+
"label": label_idx,
|
| 66 |
+
"label_text": label_names[label_idx],
|
| 67 |
+
"confidence": float(proba[label_idx]),
|
| 68 |
+
"probabilities": [
|
| 69 |
+
{"label": i, "label_text": label_names[i], "probability": float(p)}
|
| 70 |
+
for i, p in enumerate(proba)
|
| 71 |
+
],
|
| 72 |
+
}
|
tests/fixtures/build_dummy_eeg_clf.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Build a stub EEG classifier (sklearn RF) for tests.
|
| 2 |
+
|
| 3 |
+
Demo-time placeholder — produces a 2-class probability output matching the
|
| 4 |
+
eeg_model.predict_features contract. Replace with the real artifact when
|
| 5 |
+
the user provides it; tests don't change.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import joblib
|
| 12 |
+
import numpy as np
|
| 13 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def build(path: Path, n_features: int = 16, seed: int = 0) -> Path:
|
| 17 |
+
"""Save a fitted RandomForestClassifier at `path` and return the path."""
|
| 18 |
+
path = Path(path)
|
| 19 |
+
if path.exists():
|
| 20 |
+
return path
|
| 21 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 22 |
+
|
| 23 |
+
rng = np.random.default_rng(seed)
|
| 24 |
+
n = 200
|
| 25 |
+
n_alz = n // 2
|
| 26 |
+
X_ctrl = rng.normal(0.0, 1.0, size=(n - n_alz, n_features))
|
| 27 |
+
X_alz = rng.normal(2.0, 1.0, size=(n_alz, n_features))
|
| 28 |
+
X = np.vstack([X_ctrl, X_alz])
|
| 29 |
+
y = np.array([0] * (n - n_alz) + [1] * n_alz)
|
| 30 |
+
|
| 31 |
+
clf = RandomForestClassifier(n_estimators=12, max_depth=6, random_state=seed)
|
| 32 |
+
clf.fit(X, y)
|
| 33 |
+
joblib.dump(clf, str(path))
|
| 34 |
+
return path
|
tests/models/test_eeg_model.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for src.models.eeg_model."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
+
from src.models import eeg_model
|
| 10 |
+
from tests.fixtures.build_dummy_eeg_clf import build as build_dummy_eeg
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TestEEGModel:
|
| 14 |
+
def test_load_missing_artifact_raises(self, tmp_path: Path) -> None:
|
| 15 |
+
with pytest.raises(FileNotFoundError, match="EEG classifier artifact not found"):
|
| 16 |
+
eeg_model.load(tmp_path / "nope.joblib")
|
| 17 |
+
|
| 18 |
+
def test_predict_returns_full_dict(self, tmp_path: Path) -> None:
|
| 19 |
+
ckpt = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=16)
|
| 20 |
+
clf = eeg_model.load(ckpt)
|
| 21 |
+
features = np.zeros((16,), dtype=np.float32)
|
| 22 |
+
|
| 23 |
+
out = eeg_model.predict_features(clf, features)
|
| 24 |
+
|
| 25 |
+
assert set(out) == {"label", "label_text", "confidence", "probabilities"}
|
| 26 |
+
assert out["label"] in {0, 1}
|
| 27 |
+
assert out["label_text"] in eeg_model.DEFAULT_LABELS
|
| 28 |
+
assert 0.0 <= out["confidence"] <= 1.0
|
| 29 |
+
probs = out["probabilities"]
|
| 30 |
+
assert len(probs) == 2
|
| 31 |
+
assert abs(sum(p["probability"] for p in probs) - 1.0) < 1e-5
|
| 32 |
+
|
| 33 |
+
def test_alzheimers_separation_with_synthetic_features(self, tmp_path: Path) -> None:
|
| 34 |
+
ckpt = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=16)
|
| 35 |
+
clf = eeg_model.load(ckpt)
|
| 36 |
+
alz_features = np.full((16,), 2.0, dtype=np.float32)
|
| 37 |
+
ctrl_features = np.zeros((16,), dtype=np.float32)
|
| 38 |
+
|
| 39 |
+
alz_pred = eeg_model.predict_features(clf, alz_features)
|
| 40 |
+
ctrl_pred = eeg_model.predict_features(clf, ctrl_features)
|
| 41 |
+
|
| 42 |
+
assert alz_pred["label_text"] == "alzheimers"
|
| 43 |
+
assert ctrl_pred["label_text"] == "control"
|
| 44 |
+
|
| 45 |
+
def test_label_override_via_env(self, tmp_path: Path, monkeypatch) -> None:
|
| 46 |
+
monkeypatch.setenv("EEG_CLF_LABELS", "no_disease,alzheimers")
|
| 47 |
+
ckpt = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=16)
|
| 48 |
+
clf = eeg_model.load(ckpt)
|
| 49 |
+
out = eeg_model.predict_features(clf, np.zeros((16,), dtype=np.float32))
|
| 50 |
+
assert out["label_text"] in {"no_disease", "alzheimers"}
|
| 51 |
+
|
| 52 |
+
def test_feature_count_mismatch_raises(self, tmp_path: Path) -> None:
|
| 53 |
+
ckpt = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=16)
|
| 54 |
+
clf = eeg_model.load(ckpt)
|
| 55 |
+
with pytest.raises(ValueError, match="feature count"):
|
| 56 |
+
eeg_model.predict_features(clf, np.zeros((8,), dtype=np.float32))
|