mekosotto committed on
Commit
ae883d4
·
1 Parent(s): 7dad1a9

feat(api): POST /predict/bbb with prediction, uncertainty, SHAP top-k

Browse files
src/api/main.py CHANGED
@@ -6,7 +6,7 @@ from __future__ import annotations
6
 
7
  from fastapi import FastAPI
8
 
9
- from src.api.routes import router as pipeline_router
10
  from src.api.schemas import HealthResponse
11
 
12
  app = FastAPI(
@@ -16,6 +16,7 @@ app = FastAPI(
16
  )
17
 
18
  app.include_router(pipeline_router)
 
19
 
20
 
21
  @app.get("/health", response_model=HealthResponse)
 
6
 
7
  from fastapi import FastAPI
8
 
9
+ from src.api.routes import router as pipeline_router, predict_router
10
  from src.api.schemas import HealthResponse
11
 
12
  app = FastAPI(
 
16
  )
17
 
18
  app.include_router(pipeline_router)
19
+ app.include_router(predict_router)
20
 
21
 
22
  @app.get("/health", response_model=HealthResponse)
src/api/routes.py CHANGED
@@ -7,6 +7,7 @@ codes: FileNotFoundError -> 404, ValueError -> 400, anything else -> 500.
7
  """
8
  from __future__ import annotations
9
 
 
10
  import time
11
  from pathlib import Path
12
  from typing import Callable
@@ -16,16 +17,21 @@ import pandas as pd
16
  from fastapi import APIRouter, HTTPException
17
 
18
  from src.api.schemas import (
 
 
19
  BBBRequest,
20
  EEGRequest,
 
21
  MRIRequest,
22
  PipelineResponse,
23
  )
24
  from src.core.logger import get_logger
 
25
  from src.pipelines import bbb_pipeline, eeg_pipeline, mri_pipeline
26
 
27
  logger = get_logger(__name__)
28
  router = APIRouter(prefix="/pipeline")
 
29
 
30
 
31
  def _wrap(
@@ -108,3 +114,50 @@ def run_mri(req: MRIRequest) -> PipelineResponse:
108
  output_path=Path(req.output_path),
109
  ),
110
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  """
8
  from __future__ import annotations
9
 
10
+ import os
11
  import time
12
  from pathlib import Path
13
  from typing import Callable
 
17
  from fastapi import APIRouter, HTTPException
18
 
19
  from src.api.schemas import (
20
+ BBBPredictRequest,
21
+ BBBPredictResponse,
22
  BBBRequest,
23
  EEGRequest,
24
+ FeatureAttribution,
25
  MRIRequest,
26
  PipelineResponse,
27
  )
28
  from src.core.logger import get_logger
29
+ from src.models import bbb_model
30
  from src.pipelines import bbb_pipeline, eeg_pipeline, mri_pipeline
31
 
32
  logger = get_logger(__name__)
33
  router = APIRouter(prefix="/pipeline")
34
+ predict_router = APIRouter(prefix="/predict")
35
 
36
 
37
  def _wrap(
 
114
  output_path=Path(req.output_path),
115
  ),
116
  )
117
+
118
+
119
# Default artifact location. Overridable via the BBB_MODEL_PATH env var so tests
# can point at a tmp-built model without touching production paths.
_DEFAULT_BBB_MODEL_PATH = Path("data/processed/bbb_model.joblib")


def _bbb_model_path() -> Path:
    """Return the BBB model artifact path.

    The ``BBB_MODEL_PATH`` environment variable takes precedence when it is
    set to a non-blank value; otherwise the module default is returned. The
    environment is read on every call, so an override applied at runtime
    takes effect on the next call.

    Returns:
        Path of the serialized model artifact. Existence is not checked here.
    """
    override = os.environ.get("BBB_MODEL_PATH")
    # A blank override would become Path(""), i.e. the current directory,
    # which exists but is not a model file. Treat it the same as "unset".
    if override is None or not override.strip():
        return _DEFAULT_BBB_MODEL_PATH
    return Path(override)
127
+
128
+
129
@predict_router.post("/bbb", response_model=BBBPredictResponse)
def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
    """Predict BBB permeability and return SHAP attributions for one SMILES.

    Args:
        req: Request carrying the SMILES string and the number of top SHAP
            features to return.

    Returns:
        The decision payload: predicted label, its text form, the confidence
        reported by the model layer, and the top-k feature attributions.

    Raises:
        HTTPException: 503 if the model artifact is missing or is not a
            regular file (operator hasn't run the trainer CLI yet); 400 if
            prediction or explanation raises ``ValueError`` (e.g. an invalid
            SMILES string).
    """
    artifact = _bbb_model_path()
    # is_file() rather than exists(): a directory at the artifact path is
    # just as unusable as a missing file and should also yield 503.
    if not artifact.is_file():
        raise HTTPException(
            status_code=503,
            detail=(
                f"BBB model artifact not available at {artifact}. "
                "Run `python -m src.models.bbb_model` to train it."
            ),
        )
    # NOTE(review): the artifact is loaded on every request; if load cost
    # matters, consider caching the loaded model — confirm with the owner.
    try:
        model = bbb_model.load(artifact)
    except FileNotFoundError as e:
        # The file can disappear between the check above and the load.
        raise HTTPException(status_code=503, detail=str(e)) from e

    try:
        pred = bbb_model.predict_with_proba(model, req.smiles)
        attributions = bbb_model.explain_prediction(model, req.smiles, top_k=req.top_k)
    except ValueError as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(status_code=400, detail=str(e)) from e

    label_text = "permeable" if pred["label"] == 1 else "non-permeable"
    return BBBPredictResponse(
        label=pred["label"],
        label_text=label_text,
        confidence=pred["confidence"],
        top_features=[FeatureAttribution(**a) for a in attributions],
    )
src/api/schemas.py CHANGED
@@ -46,3 +46,26 @@ class PipelineResponse(BaseModel):
46
  class HealthResponse(BaseModel):
47
  status: str
48
  pipelines: list[str]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  class HealthResponse(BaseModel):
47
  status: str
48
  pipelines: list[str]
49
+
50
+
51
class BBBPredictRequest(BaseModel):
    """Single-molecule BBB-permeability prediction request.

    ``smiles`` must be a non-empty string; ``top_k`` defaults to 5 and is
    constrained to the range 1..20.
    """

    # An empty string can never describe a molecule, so reject it during
    # request validation instead of relying on the model layer to raise.
    smiles: str = Field(
        ...,
        min_length=1,
        description="SMILES string; e.g. 'CCO' for ethanol",
    )
    top_k: int = Field(5, ge=1, le=20, description="Top-k SHAP features to return")
55
+
56
+
57
class FeatureAttribution(BaseModel):
    """One SHAP attribution entry: a fingerprint feature name and its signed contribution."""

    feature: str = Field(
        ...,
        description="Fingerprint column name, e.g. 'fp_1234'",
    )
    shap_value: float = Field(
        ...,
        description=(
            "Signed SHAP value for the predicted class "
            "(positive pushed model toward, negative away)"
        ),
    )
64
+
65
+
66
class BBBPredictResponse(BaseModel):
    """Decision-system payload: prediction + uncertainty + explanation."""

    label: int = Field(
        ...,
        description="Predicted class label (1 = permeable, 0 = non-permeable)",
    )
    label_text: str = Field(..., description="'permeable' or 'non-permeable'")
    confidence: float = Field(
        ...,
        description="Model confidence in the predicted label, expected to lie in [0, 1]",
    )
    top_features: list[FeatureAttribution] = Field(
        ...,
        description="Top-k SHAP feature attributions for this prediction",
    )
tests/api/test_routes.py CHANGED
@@ -70,3 +70,57 @@ class TestMRIRoute:
70
  )
71
  assert resp.status_code == 200
72
  assert resp.json()["rows"] > 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  )
71
  assert resp.status_code == 200
72
  assert resp.json()["rows"] > 0
73
+
74
+
75
class TestBBBPredictRoute:
    """Exercise POST /predict/bbb against a small model trained inside the test."""

    def _setup_model_artifact(self, tmp_path: Path) -> Path:
        """Run the feature pipeline on the sample fixture, train, and save the model.

        Returns the path of the saved artifact inside ``tmp_path``.
        """
        from src.pipelines import bbb_pipeline
        from src.models import bbb_model
        import pandas as pd

        parquet_file = tmp_path / "features.parquet"
        bbb_pipeline.run_pipeline(
            input_path=_FIXTURES / "bbbp_sample.csv",
            output_path=parquet_file,
        )

        frame = pd.read_parquet(parquet_file)
        fitted = bbb_model.train(frame, label_col="p_np", n_estimators=10, random_state=42)

        model_file = tmp_path / "bbb_model.joblib"
        bbb_model.save(fitted, model_file)
        return model_file

    def test_returns_200_with_prediction_and_attributions(self, tmp_path: Path, monkeypatch):
        # Point the route at the freshly built artifact rather than the default path.
        model_file = self._setup_model_artifact(tmp_path)
        monkeypatch.setenv("BBB_MODEL_PATH", str(model_file))

        payload = {"smiles": "CCO", "top_k": 5}
        response = client.post("/predict/bbb", json=payload)

        assert response.status_code == 200
        data = response.json()
        assert data["label"] in (0, 1)
        assert data["label_text"] in ("permeable", "non-permeable")
        assert 0.0 <= data["confidence"] <= 1.0

        entries = data["top_features"]
        assert len(entries) == 5
        for entry in entries:
            assert entry["feature"].startswith("fp_")
            assert isinstance(entry["shap_value"], float)

    def test_returns_400_on_invalid_smiles(self, tmp_path: Path, monkeypatch):
        model_file = self._setup_model_artifact(tmp_path)
        monkeypatch.setenv("BBB_MODEL_PATH", str(model_file))

        payload = {"smiles": "this_is_not_a_smiles", "top_k": 5}
        response = client.post("/predict/bbb", json=payload)

        assert response.status_code == 400

    def test_returns_503_when_artifact_missing(self, tmp_path: Path, monkeypatch):
        # No artifact is built here; the env var points at a path that does not exist.
        absent = tmp_path / "does_not_exist.joblib"
        monkeypatch.setenv("BBB_MODEL_PATH", str(absent))

        payload = {"smiles": "CCO", "top_k": 5}
        response = client.post("/predict/bbb", json=payload)

        assert response.status_code == 503