File size: 2,412 Bytes
a3f2882
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""Tests for src.models.eeg_model."""
from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest

from src.models import eeg_model
from tests.fixtures.build_dummy_eeg_clf import build as build_dummy_eeg


class TestEEGModel:
    """Tests for loading the EEG classifier artifact and predicting from feature vectors."""

    # Feature width the dummy classifier is built with; predict inputs must match it
    # (test_feature_count_mismatch_raises checks the mismatch path).
    N_FEATURES = 16

    @classmethod
    def _load_dummy(cls, tmp_path: Path):
        """Build the dummy EEG classifier artifact under *tmp_path* and load it."""
        ckpt = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=cls.N_FEATURES)
        return eeg_model.load(ckpt)

    def test_load_missing_artifact_raises(self, tmp_path: Path) -> None:
        """Loading a non-existent artifact path raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError, match="EEG classifier artifact not found"):
            eeg_model.load(tmp_path / "nope.joblib")

    def test_predict_returns_full_dict(self, tmp_path: Path) -> None:
        """A prediction exposes label, label text, confidence and a normalized distribution."""
        clf = self._load_dummy(tmp_path)
        features = np.zeros((self.N_FEATURES,), dtype=np.float32)

        out = eeg_model.predict_features(clf, features)

        assert set(out) == {"label", "label_text", "confidence", "probabilities"}
        assert out["label"] in {0, 1}
        assert out["label_text"] in eeg_model.DEFAULT_LABELS
        assert 0.0 <= out["confidence"] <= 1.0
        probs = out["probabilities"]
        assert len(probs) == 2
        assert sum(p["probability"] for p in probs) == pytest.approx(1.0, abs=1e-5)

    def test_alzheimers_separation_with_synthetic_features(self, tmp_path: Path) -> None:
        """The dummy classifier separates all-2.0 features from all-zero features."""
        clf = self._load_dummy(tmp_path)
        alz_features = np.full((self.N_FEATURES,), 2.0, dtype=np.float32)
        ctrl_features = np.zeros((self.N_FEATURES,), dtype=np.float32)

        alz_pred = eeg_model.predict_features(clf, alz_features)
        ctrl_pred = eeg_model.predict_features(clf, ctrl_features)

        assert alz_pred["label_text"] == "alzheimers"
        assert ctrl_pred["label_text"] == "control"

    def test_label_override_via_env(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """EEG_CLF_LABELS renames the class labels in the order given."""
        # Set before building/loading so the override is in effect for every step.
        monkeypatch.setenv("EEG_CLF_LABELS", "no_disease,alzheimers")
        clf = self._load_dummy(tmp_path)

        ctrl = eeg_model.predict_features(
            clf, np.zeros((self.N_FEATURES,), dtype=np.float32)
        )
        alz = eeg_model.predict_features(
            clf, np.full((self.N_FEATURES,), 2.0, dtype=np.float32)
        )

        # Zeros predict the first class and all-2.0 the second (see the separation
        # test above). Pin both so a mis-ordered override is caught, not just any
        # label drawn from the override set.
        assert ctrl["label_text"] == "no_disease"
        assert alz["label_text"] == "alzheimers"

    def test_feature_count_mismatch_raises(self, tmp_path: Path) -> None:
        """Predicting with the wrong number of features raises ValueError."""
        clf = self._load_dummy(tmp_path)
        with pytest.raises(ValueError, match="feature count"):
            eeg_model.predict_features(
                clf, np.zeros((self.N_FEATURES // 2,), dtype=np.float32)
            )