mekosotto Claude Opus 4.7 (1M context) committed on
Commit
5e9f487
·
1 Parent(s): e5c1c61

feat(api): POST /explain/bbb — natural-language rationale endpoint

Browse files

- New explain_router with /explain prefix; symmetric with /predict/bbb
and reserves /explain/eeg, /explain/mri for future expansion.
- BBBExplainRequest carries the prediction snapshot + optional
user_question. top_features is required and must be non-empty
(Pydantic min_length=1 → 422 on empty).
- BBBExplainResponse: {rationale, source, model}. Always 200 because
the explainer's template fallback never raises.
- 1 new test: 200 + source='template' under NEUROBRIDGE_DISABLE_LLM=1
with full SHAP + calibration + drift payload.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

src/api/main.py CHANGED
@@ -6,7 +6,7 @@ from __future__ import annotations
6
 
7
  from fastapi import FastAPI
8
 
9
- from src.api.routes import router as pipeline_router, predict_router
10
  from src.api.schemas import HealthResponse
11
 
12
  app = FastAPI(
@@ -17,6 +17,7 @@ app = FastAPI(
17
 
18
  app.include_router(pipeline_router)
19
  app.include_router(predict_router)
 
20
 
21
 
22
  @app.get("/health", response_model=HealthResponse)
 
6
 
7
  from fastapi import FastAPI
8
 
9
+ from src.api.routes import router as pipeline_router, predict_router, explain_router
10
  from src.api.schemas import HealthResponse
11
 
12
  app = FastAPI(
 
17
 
18
  app.include_router(pipeline_router)
19
  app.include_router(predict_router)
20
+ app.include_router(explain_router)
21
 
22
 
23
  @app.get("/health", response_model=HealthResponse)
src/api/routes.py CHANGED
@@ -18,6 +18,8 @@ import pandas as pd
18
  from fastapi import APIRouter, HTTPException
19
 
20
  from src.api.schemas import (
 
 
21
  BBBPredictRequest,
22
  BBBPredictResponse,
23
  BBBRequest,
@@ -32,12 +34,14 @@ from src.api.schemas import (
32
  PipelineResponse,
33
  )
34
  from src.core.logger import get_logger
 
35
  from src.models import bbb_model
36
  from src.pipelines import bbb_pipeline, eeg_pipeline, mri_pipeline
37
 
38
  logger = get_logger(__name__)
39
  router = APIRouter(prefix="/pipeline")
40
  predict_router = APIRouter(prefix="/predict")
 
41
 
42
 
43
  def _wrap(
@@ -320,3 +324,41 @@ def mri_diagnostics(req: MRIDiagnosticsRequest) -> MRIDiagnosticsResponse:
320
  site_gap_post=site_gap_post,
321
  reduction_factor=reduction_factor,
322
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  from fastapi import APIRouter, HTTPException
19
 
20
  from src.api.schemas import (
21
+ BBBExplainRequest,
22
+ BBBExplainResponse,
23
  BBBPredictRequest,
24
  BBBPredictResponse,
25
  BBBRequest,
 
34
  PipelineResponse,
35
  )
36
  from src.core.logger import get_logger
37
+ from src.llm import explainer as llm_explainer
38
  from src.models import bbb_model
39
  from src.pipelines import bbb_pipeline, eeg_pipeline, mri_pipeline
40
 
41
  logger = get_logger(__name__)
42
  router = APIRouter(prefix="/pipeline")
43
  predict_router = APIRouter(prefix="/predict")
44
+ explain_router = APIRouter(prefix="/explain")
45
 
46
 
47
  def _wrap(
 
324
  site_gap_post=site_gap_post,
325
  reduction_factor=reduction_factor,
326
  )
327
+
328
+
329
@explain_router.post("/bbb", response_model=BBBExplainResponse)
def explain_bbb(req: BBBExplainRequest) -> BBBExplainResponse:
    """Natural-language rationale for a single BBB prediction.

    This endpoint always answers 200: the explainer falls back to a
    deterministic template and never raises. Request validation (a
    non-empty top_features list) is enforced by Pydantic, so an empty
    list is rejected with 422 before this handler is invoked.
    """
    # Flatten the per-feature SHAP attributions into plain dicts for the
    # explainer payload.
    attributions = [
        {"feature": item.feature, "shap_value": item.shap_value}
        for item in req.top_features
    ]

    # Calibration context is optional; forward it only when supplied.
    calibration = None
    if req.calibration is not None:
        calibration = {
            "threshold": req.calibration.threshold,
            "precision": req.calibration.precision,
            "support": req.calibration.support,
        }

    payload: llm_explainer.ExplainPayload = {
        "smiles": req.smiles,
        "label": req.label,
        "label_text": req.label_text,
        "confidence": req.confidence,
        "top_features": attributions,
        "calibration": calibration,
        "drift_z": req.drift_z,
        "user_question": req.user_question or "",
    }

    result = llm_explainer.explain(payload)
    return BBBExplainResponse(
        rationale=result["rationale"],
        source=result["source"],
        model=result["model"],
    )
src/api/schemas.py CHANGED
@@ -133,3 +133,31 @@ class MRIDiagnosticsResponse(BaseModel):
133
  site_gap_pre: float = Field(..., description="Range of per-site means before ComBat")
134
  site_gap_post: float = Field(..., description="Range of per-site means after ComBat")
135
  reduction_factor: float = Field(..., description="site_gap_pre / max(site_gap_post, eps)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  site_gap_pre: float = Field(..., description="Range of per-site means before ComBat")
134
  site_gap_post: float = Field(..., description="Range of per-site means after ComBat")
135
  reduction_factor: float = Field(..., description="site_gap_pre / max(site_gap_post, eps)")
136
+
137
+
138
class BBBExplainRequest(BaseModel):
    """Day-7 T3B: payload for POST /explain/bbb (chat-style explainer).

    Carries the prediction snapshot (label, confidence, SHAP top
    features) plus optional calibration/drift context and a free-form
    user question that is forwarded to the LLM prompt only.
    """
    smiles: str = Field(..., description="SMILES string of the molecule")
    label: int = Field(..., description="Predicted label (0 = non-permeable, 1 = permeable)")
    label_text: str = Field(..., description="'permeable' or 'non-permeable'")
    confidence: float = Field(..., ge=0.0, le=1.0)
    top_features: list[FeatureAttribution] = Field(
        ..., min_length=1,
        # FIX: min_length=1 makes FastAPI reject an empty list with 422
        # (validation error), not 400 — the description previously
        # contradicted the handler docstring and the commit message.
        description="Non-empty list of SHAP attributions; an empty list returns 422.",
    )
    calibration: CalibrationContext | None = None
    drift_z: float | None = None
    user_question: str | None = Field(
        None,
        description="Optional question from the user; passed to the LLM prompt only.",
    )
154
+
155
+
156
class BBBExplainResponse(BaseModel):
    """Day-7 T3B: response from POST /explain/bbb.

    Mirrors the explainer result verbatim: the rationale text, which
    backend produced it ('llm' or 'template'), and the model name when
    an LLM was used.
    """
    # Human-readable explanation; the template fallback guarantees this
    # is always populated, so the endpoint can always return 200.
    rationale: str = Field(..., description="2-4 sentence natural-language explanation")
    # Provenance flag consumed by the tests (source == 'template' when
    # NEUROBRIDGE_DISABLE_LLM is set).
    source: str = Field(..., description="'llm' or 'template'")
    # None whenever the deterministic template produced the rationale.
    model: str | None = Field(
        None,
        description="LLM model name when source='llm'; None when source='template'",
    )
tests/api/test_routes.py CHANGED
@@ -228,3 +228,34 @@ class TestMRIDiagnosticsRoute:
228
  },
229
  )
230
  assert resp.status_code == 404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  },
229
  )
230
  assert resp.status_code == 404
231
+
232
+
233
class TestExplainBBBRoute:
    """Day-7 T3B: POST /explain/bbb."""

    def test_returns_200_with_template_source(self, monkeypatch):
        """Kill-switch on → /explain/bbb returns rationale with source=template."""
        # Force the deterministic-template path regardless of any LLM config.
        monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")

        payload = {
            "smiles": "CCO",
            "label": 1,
            "label_text": "permeable",
            "confidence": 0.82,
            "top_features": [
                {"feature": "fp_341", "shap_value": 0.045},
                {"feature": "fp_902", "shap_value": -0.031},
                {"feature": "fp_77", "shap_value": 0.022},
            ],
            "calibration": {"threshold": 0.80, "precision": 0.92, "support": 18},
            "drift_z": 0.42,
            "user_question": "Why permeable?",
        }

        resp = client.post("/explain/bbb", json=payload)

        assert resp.status_code == 200, resp.text
        data = resp.json()
        assert data["source"] == "template"
        assert data["model"] is None
        # The template fallback must cite every submitted feature.
        assert all(
            name in data["rationale"]
            for name in ("fp_341", "fp_902", "fp_77")
        )
        assert "permeable" in data["rationale"]