| """Tests for src.models.eeg_model.""" |
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import numpy as np |
| import pytest |
|
|
| from src.models import eeg_model |
| from tests.fixtures.build_dummy_eeg_clf import build as build_dummy_eeg |
|
|
|
|
class TestEEGModel:
    """Behavioural tests for the EEG classifier wrapper in ``src.models.eeg_model``.

    Tests that need a model build a fresh dummy artifact under ``tmp_path`` via
    ``build_dummy_eeg``, so no on-disk state is shared between tests.
    """

    # Environment variable that overrides the label names (see
    # ``test_label_override_via_env``). Tests asserting on default label text
    # clear it first so an ambient value in the shell/CI cannot change them.
    _LABELS_ENV = "EEG_CLF_LABELS"

    # Feature-vector length the dummy classifier is built with.
    _N_FEATURES = 16

    @classmethod
    def _build_and_load(cls, tmp_path: Path):
        """Build a dummy EEG classifier artifact in *tmp_path* and load it."""
        ckpt = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=cls._N_FEATURES)
        return eeg_model.load(ckpt)

    def test_load_missing_artifact_raises(self, tmp_path: Path) -> None:
        with pytest.raises(FileNotFoundError, match="EEG classifier artifact not found"):
            eeg_model.load(tmp_path / "nope.joblib")

    def test_predict_returns_full_dict(self, tmp_path: Path, monkeypatch) -> None:
        # Default label names are asserted below; drop any ambient override.
        monkeypatch.delenv(self._LABELS_ENV, raising=False)
        clf = self._build_and_load(tmp_path)
        features = np.zeros((self._N_FEATURES,), dtype=np.float32)

        out = eeg_model.predict_features(clf, features)

        assert set(out) == {"label", "label_text", "confidence", "probabilities"}
        assert out["label"] in {0, 1}
        assert out["label_text"] in eeg_model.DEFAULT_LABELS
        assert 0.0 <= out["confidence"] <= 1.0
        probs = out["probabilities"]
        assert len(probs) == 2
        assert all(0.0 <= p["probability"] <= 1.0 for p in probs)
        assert sum(p["probability"] for p in probs) == pytest.approx(1.0, abs=1e-5)

    def test_alzheimers_separation_with_synthetic_features(
        self, tmp_path: Path, monkeypatch
    ) -> None:
        # Exact default label text is asserted; drop any ambient override.
        monkeypatch.delenv(self._LABELS_ENV, raising=False)
        clf = self._build_and_load(tmp_path)
        alz_features = np.full((self._N_FEATURES,), 2.0, dtype=np.float32)
        ctrl_features = np.zeros((self._N_FEATURES,), dtype=np.float32)

        alz_pred = eeg_model.predict_features(clf, alz_features)
        ctrl_pred = eeg_model.predict_features(clf, ctrl_features)

        assert alz_pred["label_text"] == "alzheimers"
        assert ctrl_pred["label_text"] == "control"

    def test_label_override_via_env(self, tmp_path: Path, monkeypatch) -> None:
        override = ("no_disease", "alzheimers")
        # Set before building/loading so the override is visible regardless of
        # whether eeg_model reads it at load time or at prediction time.
        monkeypatch.setenv(self._LABELS_ENV, ",".join(override))
        clf = self._build_and_load(tmp_path)

        out = eeg_model.predict_features(
            clf, np.zeros((self._N_FEATURES,), dtype=np.float32)
        )

        assert out["label_text"] in set(override)
        # The text must be the override entry at the predicted class index.
        assert out["label_text"] == override[out["label"]]

    def test_feature_count_mismatch_raises(self, tmp_path: Path) -> None:
        clf = self._build_and_load(tmp_path)
        with pytest.raises(ValueError, match="feature count"):
            eeg_model.predict_features(
                clf, np.zeros((self._N_FEATURES // 2,), dtype=np.float32)
            )
|
|