"""EEG classifier inference utilities.

Loads any sklearn-style classifier (an object exposing `predict_proba`) from a
joblib artifact and emits the same dict shape as
src.models.mri_model.predict_with_proba, so the API surface and fusion engine
treat MRI and EEG predictions identically.

The real pretrained artifact lives at data/processed/eeg_clf.joblib (the path
can be overridden via the EEG_CLF_ARTIFACT environment variable). Tests use a
stub fixture; the real model drops in without code changes.
"""
| from __future__ import annotations |
|
|
| import os |
| from pathlib import Path |
| from typing import Any, Sequence |
|
|
| import joblib |
| import numpy as np |
|
|
| from src.core.logger import get_logger |
|
|
# Module logger; predict_features uses it to warn when the number of label
# names does not match the model's output dimension.
logger = get_logger(__name__)
|
|
# Fallback class names used when EEG_CLF_LABELS is unset or blank. Position i
# names probability column i (assumed to match the model's predict_proba
# column order — confirm against the trained artifact).
DEFAULT_LABELS: tuple[str, ...] = ("control", "alzheimers")


def _resolve_labels() -> tuple[str, ...]:
    """Return class names from ``EEG_CLF_LABELS``, else ``DEFAULT_LABELS``.

    The variable is a comma-separated list (e.g. ``"control,alzheimers"``).
    Whitespace around each name is stripped and blank items are dropped. If
    the variable is unset, empty, or contains no non-blank names,
    ``DEFAULT_LABELS`` is returned rather than an empty tuple.
    """
    raw = os.environ.get("EEG_CLF_LABELS", "")
    names = tuple(part.strip() for part in raw.split(",") if part.strip())
    # A value like "," or "   " parses to nothing; treat it the same as an
    # unset variable instead of returning () and forcing the class_N
    # fallback (plus a warning) downstream.
    return names or DEFAULT_LABELS
|
|
|
|
def load(path: str | os.PathLike[str]) -> Any:
    """Load a classifier artifact from *path* with joblib.

    Args:
        path: Location of the joblib artifact; ``str`` and ``pathlib.Path``
            are both accepted.

    Returns:
        The deserialized object. No interface check is done here;
        ``predict_features`` expects it to expose ``predict_proba``.

    Raises:
        FileNotFoundError: If *path* does not exist or is not a regular file.
    """
    artifact = Path(path)
    # is_file() rather than exists(): a directory would otherwise pass the
    # check and then fail inside joblib with a less helpful OS error.
    if not artifact.is_file():
        raise FileNotFoundError(f"EEG classifier artifact not found: {artifact}")
    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code — only load artifacts from trusted locations.
    return joblib.load(str(artifact))
|
|
|
|
def predict_features(
    model: Any,
    features: np.ndarray,
    labels: Sequence[str] | None = None,
) -> dict[str, Any]:
    """Run inference on one row of EEG features.

    Args:
        model: Classifier exposing ``predict_proba``. If it also has
            ``n_features_in_`` (as fitted sklearn estimators do), the input
            length is validated against it.
        features: Feature values for a single sample. Any array-like is
            accepted and flattened to 1-D.
        labels: Class names in probability-column order. When ``None`` or
            empty, names come from ``_resolve_labels()``. If the count does
            not match the model's output dimension, ``class_0..class_{N-1}``
            is used instead (with a logged warning).

    Returns:
        A dict with ``label`` (argmax index), ``label_text``, ``confidence``
        (the argmax probability) and ``probabilities`` (a per-class list of
        ``{"label", "label_text", "probability"}`` dicts).

    Raises:
        ValueError: If the flattened feature count differs from
            ``model.n_features_in_``.
    """
    # float64 matches sklearn's native precision; a float32 round-trip would
    # perturb the inputs and make reported probabilities noisy
    # (e.g. 0.7 -> 0.699999988...).
    arr = np.asarray(features, dtype=np.float64).reshape(-1)
    expected = int(getattr(model, "n_features_in_", arr.size))
    if arr.size != expected:
        raise ValueError(
            f"EEG feature count mismatch: model expects {expected}, got {arr.size}"
        )

    # predict_proba takes a 2-D batch; take the single row back out.
    proba = np.asarray(model.predict_proba(arr.reshape(1, -1))[0], dtype=np.float64)
    n_classes = proba.shape[0]

    # Explicit None/len check instead of `labels or ...`: truth-testing a
    # numpy array of names raises "truth value ... is ambiguous".
    if labels is None or len(labels) == 0:
        label_names = tuple(_resolve_labels())
    else:
        label_names = tuple(labels)
    if len(label_names) != n_classes:
        logger.warning(
            "EEG label count (%d) != model output dim (%d); falling back to class_0..N",
            len(label_names), n_classes,
        )
        label_names = tuple(f"class_{i}" for i in range(n_classes))

    label_idx = int(np.argmax(proba))
    return {
        "label": label_idx,
        "label_text": label_names[label_idx],
        "confidence": float(proba[label_idx]),
        "probabilities": [
            {"label": i, "label_text": label_names[i], "probability": float(p)}
            for i, p in enumerate(proba)
        ],
    }
|
|