File size: 2,412 Bytes
a3f2882 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 | """Tests for src.models.eeg_model."""
from __future__ import annotations
from pathlib import Path
import numpy as np
import pytest
from src.models import eeg_model
from tests.fixtures.build_dummy_eeg_clf import build as build_dummy_eeg
class TestEEGModel:
    """Behavioural tests for the EEG classifier wrapper in ``src.models.eeg_model``.

    Tests that need a model build a throw-away dummy artifact under pytest's
    per-test ``tmp_path`` with ``build_dummy_eeg`` and load it through
    ``eeg_model.load``.
    """

    # Feature dimensionality the dummy classifier is built with; well-formed
    # feature vectors in these tests have exactly this many entries.
    N_FEATURES = 16

    def _load_classifier(self, tmp_path: Path):
        """Build the dummy EEG artifact inside *tmp_path* and load it via ``eeg_model.load``."""
        ckpt = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=self.N_FEATURES)
        return eeg_model.load(ckpt)

    def _features(self, value: float = 0.0, n_features: int | None = None) -> np.ndarray:
        """Return a 1-D float32 feature vector filled with *value*.

        The vector has ``N_FEATURES`` entries unless *n_features* is given,
        which lets a test build a deliberately mis-sized input.
        """
        size = self.N_FEATURES if n_features is None else n_features
        return np.full((size,), value, dtype=np.float32)

    def test_load_missing_artifact_raises(self, tmp_path: Path) -> None:
        """Loading a path that does not exist raises ``FileNotFoundError``."""
        with pytest.raises(FileNotFoundError, match="EEG classifier artifact not found"):
            eeg_model.load(tmp_path / "nope.joblib")

    def test_predict_returns_full_dict(self, tmp_path: Path) -> None:
        """``predict_features`` returns exactly the expected keys with sane values."""
        clf = self._load_classifier(tmp_path)
        out = eeg_model.predict_features(clf, self._features())
        assert set(out) == {"label", "label_text", "confidence", "probabilities"}
        assert out["label"] in {0, 1}
        assert out["label_text"] in eeg_model.DEFAULT_LABELS
        assert 0.0 <= out["confidence"] <= 1.0
        probs = out["probabilities"]
        # Two classes are expected, so one probability entry per class.
        assert len(probs) == 2
        assert all(0.0 <= p["probability"] <= 1.0 for p in probs)
        assert sum(p["probability"] for p in probs) == pytest.approx(1.0, abs=1e-5)

    def test_alzheimers_separation_with_synthetic_features(self, tmp_path: Path) -> None:
        """All-2.0 features are expected to map to "alzheimers", all-zero to "control"."""
        clf = self._load_classifier(tmp_path)
        alz_pred = eeg_model.predict_features(clf, self._features(2.0))
        ctrl_pred = eeg_model.predict_features(clf, self._features(0.0))
        assert alz_pred["label_text"] == "alzheimers"
        assert ctrl_pred["label_text"] == "control"

    def test_label_override_via_env(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Setting ``EEG_CLF_LABELS`` changes the label names reported by predictions."""
        # Set the override before building/loading the artifact so it is in
        # effect regardless of when eeg_model reads the variable.
        monkeypatch.setenv("EEG_CLF_LABELS", "no_disease,alzheimers")
        clf = self._load_classifier(tmp_path)
        out = eeg_model.predict_features(clf, self._features())
        assert out["label_text"] in {"no_disease", "alzheimers"}

    def test_feature_count_mismatch_raises(self, tmp_path: Path) -> None:
        """A feature vector of the wrong length is rejected with ``ValueError``."""
        clf = self._load_classifier(tmp_path)
        with pytest.raises(ValueError, match="feature count"):
            eeg_model.predict_features(clf, self._features(n_features=8))
|