mekosotto commited on
Commit
a2a375c
·
1 Parent(s): a3f2882

feat(api): add POST /predict/eeg route (stub-able for demo)

Browse files
src/api/routes.py CHANGED
@@ -27,7 +27,10 @@ from src.api.schemas import (
27
  BBBPredictResponse,
28
  BBBRequest,
29
  CalibrationContext,
 
30
  EEGExplainRequest,
 
 
31
  FusionRequest,
32
  FusionResponse,
33
  EEGExplainResponse,
@@ -317,6 +320,43 @@ def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
317
  )
318
 
319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  @predict_router.post("/mri", response_model=MRIPredictResponse)
321
  def predict_mri(req: MRIPredictRequest) -> MRIPredictResponse:
322
  """Predict from one MRI image. Backend selected by MRI_MODEL_KIND env.
 
27
  BBBPredictResponse,
28
  BBBRequest,
29
  CalibrationContext,
30
+ EEGClassProbability,
31
  EEGExplainRequest,
32
+ EEGPredictRequest,
33
+ EEGPredictResponse,
34
  FusionRequest,
35
  FusionResponse,
36
  EEGExplainResponse,
 
320
  )
321
 
322
 
323
+ @predict_router.post("/eeg", response_model=EEGPredictResponse)
324
+ def predict_eeg(req: EEGPredictRequest) -> EEGPredictResponse:
325
+ """Predict from EEG features using an externally-trained sklearn classifier.
326
+
327
+ Real artifact lands at data/processed/eeg_clf.joblib (override via
328
+ EEG_CLF_ARTIFACT). For the demo a stub fixture (RandomForestClassifier
329
+ on synthetic features) is acceptable — the response shape stays stable.
330
+ """
331
+ import numpy as np
332
+ from src.models import eeg_model
333
+
334
+ artifact = Path(os.environ.get("EEG_CLF_ARTIFACT", "data/processed/eeg_clf.joblib"))
335
+ if not artifact.exists():
336
+ raise HTTPException(
337
+ status_code=503,
338
+ detail=(
339
+ f"EEG model artifact not available at {artifact}. "
340
+ "Drop the trained joblib at this path or set EEG_CLF_ARTIFACT."
341
+ ),
342
+ )
343
+ try:
344
+ clf = eeg_model.load(artifact)
345
+ features = np.asarray(req.features, dtype=np.float32)
346
+ out = eeg_model.predict_features(clf, features)
347
+ except FileNotFoundError as e:
348
+ raise HTTPException(status_code=404, detail=str(e))
349
+ except ValueError as e:
350
+ raise HTTPException(status_code=400, detail=str(e))
351
+
352
+ return EEGPredictResponse(
353
+ label=int(out["label"]),
354
+ label_text=str(out["label_text"]),
355
+ confidence=float(out["confidence"]),
356
+ probabilities=[EEGClassProbability(**p) for p in out["probabilities"]],
357
+ )
358
+
359
+
360
  @predict_router.post("/mri", response_model=MRIPredictResponse)
361
  def predict_mri(req: MRIPredictRequest) -> MRIPredictResponse:
362
  """Predict from one MRI image. Backend selected by MRI_MODEL_KIND env.
src/api/schemas.py CHANGED
@@ -113,6 +113,29 @@ class BBBPredictResponse(BaseModel):
113
  )
114
 
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  class MRIPredictRequest(BaseModel):
117
  """Single-subject MRI image prediction request."""
118
  input_path: str = Field(
 
113
  )
114
 
115
 
116
+ class EEGPredictRequest(BaseModel):
117
+ """Single-subject EEG-features prediction request."""
118
+ features: list[float] = Field(
119
+ ..., min_length=1,
120
+ description="EEG features matching the classifier's training-time feature count.",
121
+ )
122
+
123
+
124
+ class EEGClassProbability(BaseModel):
125
+ """One EEG model class probability."""
126
+ label: int
127
+ label_text: str
128
+ probability: float
129
+
130
+
131
+ class EEGPredictResponse(BaseModel):
132
+ """EEG prediction payload — same shape as MRIPredictResponse minus model_path."""
133
+ label: int
134
+ label_text: str
135
+ confidence: float
136
+ probabilities: list[EEGClassProbability]
137
+
138
+
139
  class MRIPredictRequest(BaseModel):
140
  """Single-subject MRI image prediction request."""
141
  input_path: str = Field(
tests/api/test_eeg_predict_route.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Integration: POST /predict/eeg."""
2
+ from __future__ import annotations
3
+
4
+ import pytest
5
+ from fastapi.testclient import TestClient
6
+
7
+ from src.api.main import app
8
+ from tests.fixtures.build_dummy_eeg_clf import build as build_dummy_eeg
9
+
10
+
11
+ @pytest.fixture()
12
+ def client(monkeypatch, tmp_path):
13
+ artifact = build_dummy_eeg(tmp_path / "eeg.joblib", n_features=16)
14
+ monkeypatch.setenv("EEG_CLF_ARTIFACT", str(artifact))
15
+ return TestClient(app)
16
+
17
+
18
+ def test_predict_eeg_happy_path(client):
19
+ body = {"features": [0.0] * 16}
20
+ r = client.post("/predict/eeg", json=body)
21
+ assert r.status_code == 200, r.text
22
+ data = r.json()
23
+ assert data["label_text"] in {"control", "alzheimers"}
24
+ assert 0.0 <= data["confidence"] <= 1.0
25
+ assert len(data["probabilities"]) == 2
26
+
27
+
28
+ def test_predict_eeg_alzheimers_profile(client):
29
+ body = {"features": [2.0] * 16}
30
+ r = client.post("/predict/eeg", json=body)
31
+ assert r.status_code == 200, r.text
32
+ data = r.json()
33
+ assert data["label_text"] == "alzheimers"
34
+
35
+
36
+ def test_predict_eeg_feature_mismatch_returns_400(client):
37
+ body = {"features": [0.0] * 8}
38
+ r = client.post("/predict/eeg", json=body)
39
+ assert r.status_code == 400
40
+
41
+
42
+ def test_predict_eeg_missing_artifact_returns_503(monkeypatch, tmp_path):
43
+ monkeypatch.setenv("EEG_CLF_ARTIFACT", str(tmp_path / "missing.joblib"))
44
+ client = TestClient(app)
45
+ r = client.post("/predict/eeg", json={"features": [0.0] * 16})
46
+ assert r.status_code == 503