File size: 2,332 Bytes
a3f2882
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""EEG classifier inference utilities.

Loads any sklearn-style classifier (object with `predict_proba`) from joblib
and emits the same dict shape as src.models.mri_model.predict_with_proba so
the API surface and fusion engine treat MRI and EEG predictions identically.

The real pretrained artifact swaps in at data/processed/eeg_clf.joblib (or
override via EEG_CLF_ARTIFACT env). Tests use a stub fixture; the real model
drops in without code changes.
"""
from __future__ import annotations

import os
from pathlib import Path
from typing import Any, Sequence

import joblib
import numpy as np

from src.core.logger import get_logger

# Module-level logger, obtained from the project's logging helper.
logger = get_logger(__name__)

# Class names returned by _resolve_labels() when the EEG_CLF_LABELS
# environment variable is unset or empty. Position in the tuple is the class
# index that predict_features() reports as "label".
DEFAULT_LABELS: tuple[str, ...] = ("control", "alzheimers")


def _resolve_labels() -> tuple[str, ...]:
    """Return the class-label names configured for the EEG classifier.

    The ``EEG_CLF_LABELS`` environment variable is read as a comma-separated
    list: each entry is stripped of surrounding whitespace and blank entries
    are dropped. When the variable is unset or empty, ``DEFAULT_LABELS`` is
    returned instead.

    Note that a value made up only of commas/whitespace yields an empty
    tuple, not the defaults.
    """
    configured = os.environ.get("EEG_CLF_LABELS")
    if configured:
        trimmed = (piece.strip() for piece in configured.split(","))
        return tuple(name for name in trimmed if name)
    return DEFAULT_LABELS


def load(path: Path) -> Any:
    """Load the EEG classifier artifact stored at *path* using joblib.

    Args:
        path: Location of the serialized classifier. The value is passed
            through ``Path(...)``, so a plain string path also works.

    Returns:
        Whatever object ``joblib.load`` produces for the file.

    Raises:
        FileNotFoundError: If nothing exists at *path*.
    """
    # NOTE(review): joblib artifacts are pickle-based, so this should only be
    # pointed at files from a trusted source -- confirm the deployment
    # guarantees that.
    artifact = Path(path)
    if artifact.exists():
        return joblib.load(str(artifact))
    raise FileNotFoundError(f"EEG classifier artifact not found: {artifact}")


def predict_features(
    model: Any,
    features: np.ndarray,
    labels: Sequence[str] | None = None,
) -> dict[str, Any]:
    """Run inference on one row of EEG features.

    Args:
        model: Object exposing ``predict_proba``. If it also carries an
            ``n_features_in_`` attribute, the feature count is validated
            against it; otherwise no count check is performed.
        features: Feature values for a single sample. Any array-like is
            accepted; it is cast to float32 and flattened to one dimension.
        labels: Optional class names, indexed by position in the model's
            probability output. Lists, tuples and NumPy arrays are all
            accepted; entries are converted with ``str``. When omitted or
            empty, names come from ``_resolve_labels()``.

    Returns:
        A dict with keys ``label`` (index of the highest probability),
        ``label_text`` (name for that index), ``confidence`` (that
        probability as a float) and ``probabilities`` (one
        ``{"label", "label_text", "probability"}`` dict per class).
        If the number of label names differs from the number of model
        outputs, a warning is logged and names fall back to
        ``class_0 .. class_N``.

    Raises:
        ValueError: If the flattened feature count differs from the model's
            ``n_features_in_``.
    """
    arr = np.asarray(features, dtype=np.float32).reshape(-1)
    # Models without `n_features_in_` default to the observed size, which
    # makes the check below a no-op for them.
    expected = int(getattr(model, "n_features_in_", arr.size))
    if arr.size != expected:
        raise ValueError(
            f"EEG feature count mismatch: model expects {expected}, got {arr.size}"
        )

    # predict_proba takes a 2-D batch; feed one row and keep its output row.
    proba = np.asarray(model.predict_proba(arr.reshape(1, -1))[0], dtype=np.float32)

    # Materialise caller-supplied labels *before* testing for emptiness:
    # `labels or ...` raises "truth value of an array is ambiguous" when a
    # multi-element NumPy array (e.g. `model.classes_`) is passed.
    label_names = () if labels is None else tuple(str(name) for name in labels)
    if not label_names:
        label_names = tuple(_resolve_labels())

    if len(label_names) != proba.shape[0]:
        logger.warning(
            "EEG label count (%d) != model output dim (%d); falling back to class_0..N",
            len(label_names), proba.shape[0],
        )
        label_names = tuple(f"class_{i}" for i in range(proba.shape[0]))

    label_idx = int(np.argmax(proba))
    return {
        "label": label_idx,
        "label_text": label_names[label_idx],
        "confidence": float(proba[label_idx]),
        "probabilities": [
            {"label": i, "label_text": label_names[i], "probability": float(p)}
            for i, p in enumerate(proba)
        ],
    }