mekosotto Claude Opus 4.7 (1M context) committed on
Commit
28ca4f9
·
1 Parent(s): 95c5aff

feat(api+frontend): MLflow provenance badge in decision card

Browse files

- ModelProvenance schema (mlflow_run_id, model_version, train_date,
n_examples). BBBPredictResponse.provenance is always populated; failed
MLflow lookup degrades to None fields without breaking the response.
- _build_provenance() module-level cache: one MLflow query per worker.
NEUROBRIDGE_DISABLE_MLFLOW=1 short-circuits to None fields. n_examples
pulled per-request from model._neurobridge_train_stats.
- Streamlit decision card renders a one-line audit badge above the
label: run id (first 8 chars), model version, train date, n_examples.
- 1 new test: provenance field present in /predict/bbb body with the
fixture model (n_examples ≥ 1 from train stats).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

src/api/routes.py CHANGED
@@ -25,6 +25,7 @@ from src.api.schemas import (
25
  EEGRequest,
26
  FeatureAttribution,
27
  HarmonizationRow,
 
28
  MRIDiagnosticsRequest,
29
  MRIDiagnosticsResponse,
30
  MRIRequest,
@@ -160,6 +161,59 @@ def _compute_drift_z(model, confidence: float) -> tuple[float | None, int]:
160
  return float(drift_z), rolling_n
161
 
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  def _matching_calibration_bin(model, confidence: float) -> CalibrationContext | None:
164
  """Pick the highest-threshold bin whose threshold <= confidence. None if no match or no metadata."""
165
  bins = getattr(model, "_neurobridge_calibration", None)
@@ -211,6 +265,7 @@ def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
211
  label_text = "permeable" if pred["label"] == 1 else "non-permeable"
212
  calibration = _matching_calibration_bin(model, pred["confidence"])
213
  drift_z, rolling_n = _compute_drift_z(model, pred["confidence"])
 
214
  return BBBPredictResponse(
215
  label=pred["label"],
216
  label_text=label_text,
@@ -219,6 +274,7 @@ def predict_bbb(req: BBBPredictRequest) -> BBBPredictResponse:
219
  calibration=calibration,
220
  drift_z=drift_z,
221
  rolling_n=rolling_n,
 
222
  )
223
 
224
 
 
25
  EEGRequest,
26
  FeatureAttribution,
27
  HarmonizationRow,
28
+ ModelProvenance,
29
  MRIDiagnosticsRequest,
30
  MRIDiagnosticsResponse,
31
  MRIRequest,
 
161
  return float(drift_z), rolling_n
162
 
163
 
164
+ _PROVENANCE_CACHE: ModelProvenance | None = None
165
+ _MODEL_VERSION = "v1" # bump manually per train cycle
166
+
167
+
168
+ def _build_provenance(model) -> ModelProvenance:
169
+ """Look up the most recent BBB MLflow run; build a ModelProvenance.
170
+
171
+ Cached at module level so we hit MLflow once per worker. Failures (no
172
+ runs found, MLflow unreachable, NEUROBRIDGE_DISABLE_MLFLOW=1) all
173
+ degrade to a partial ModelProvenance with mlflow_run_id=None — the
174
+ badge still renders, just without a run id.
175
+ """
176
+ global _PROVENANCE_CACHE
177
+ if _PROVENANCE_CACHE is not None:
178
+ # Refresh n_examples each call from the model (cheap lookup).
179
+ n_train = None
180
+ stats = getattr(model, "_neurobridge_train_stats", None)
181
+ if stats is not None:
182
+ n_train = int(stats.get("n_train", 0)) or None
183
+ return _PROVENANCE_CACHE.model_copy(update={"n_examples": n_train})
184
+
185
+ run_id: str | None = None
186
+ train_date: str | None = None
187
+ if os.environ.get("NEUROBRIDGE_DISABLE_MLFLOW") != "1":
188
+ try:
189
+ runs = mlflow.search_runs(
190
+ experiment_names=["bbb_pipeline"],
191
+ max_results=1,
192
+ order_by=["start_time DESC"],
193
+ )
194
+ if len(runs):
195
+ row = runs.iloc[0]
196
+ run_id = str(row["run_id"])
197
+ ts = row.get("start_time")
198
+ if ts is not None:
199
+ train_date = str(pd.Timestamp(ts).isoformat())
200
+ except Exception as e: # broad: MLflow store unreachable, schema mismatch, etc.
201
+ logger.warning("MLflow provenance lookup failed: %s", e)
202
+
203
+ n_train = None
204
+ stats = getattr(model, "_neurobridge_train_stats", None)
205
+ if stats is not None:
206
+ n_train = int(stats.get("n_train", 0)) or None
207
+
208
+ _PROVENANCE_CACHE = ModelProvenance(
209
+ mlflow_run_id=run_id,
210
+ model_version=_MODEL_VERSION,
211
+ train_date=train_date,
212
+ n_examples=n_train,
213
+ )
214
+ return _PROVENANCE_CACHE
215
+
216
+
217
  def _matching_calibration_bin(model, confidence: float) -> CalibrationContext | None:
218
  """Pick the highest-threshold bin whose threshold <= confidence. None if no match or no metadata."""
219
  bins = getattr(model, "_neurobridge_calibration", None)
 
265
  label_text = "permeable" if pred["label"] == 1 else "non-permeable"
266
  calibration = _matching_calibration_bin(model, pred["confidence"])
267
  drift_z, rolling_n = _compute_drift_z(model, pred["confidence"])
268
+ provenance = _build_provenance(model)
269
  return BBBPredictResponse(
270
  label=pred["label"],
271
  label_text=label_text,
 
274
  calibration=calibration,
275
  drift_z=drift_z,
276
  rolling_n=rolling_n,
277
+ provenance=provenance,
278
  )
279
 
280
 
src/api/schemas.py CHANGED
@@ -70,6 +70,14 @@ class CalibrationContext(BaseModel):
70
  support: int = Field(..., description="Number of held-out predictions falling in this bin")
71
 
72
 
 
 
 
 
 
 
 
 
73
  class BBBPredictResponse(BaseModel):
74
  """Decision-system payload: prediction + uncertainty + explanation + drift."""
75
  label: int
@@ -95,6 +103,10 @@ class BBBPredictResponse(BaseModel):
95
  "rolling window (max 100). Zero on a fresh worker."
96
  ),
97
  )
 
 
 
 
98
 
99
 
100
  class MRIDiagnosticsRequest(BaseModel):
 
70
  support: int = Field(..., description="Number of held-out predictions falling in this bin")
71
 
72
 
73
+ class ModelProvenance(BaseModel):
74
+ """Auditable provenance of the BBB model that produced a prediction."""
75
+ mlflow_run_id: str | None = Field(None, description="MLflow run id of the most recent training run, if any")
76
+ model_version: str = Field("v1", description="Manually-bumped model version label")
77
+ train_date: str | None = Field(None, description="ISO 8601 train timestamp from MLflow run start_time")
78
+ n_examples: int | None = Field(None, description="Training set size (from model._neurobridge_train_stats[\"n_train\"])")
79
+
80
+
81
  class BBBPredictResponse(BaseModel):
82
  """Decision-system payload: prediction + uncertainty + explanation + drift."""
83
  label: int
 
103
  "rolling window (max 100). Zero on a fresh worker."
104
  ),
105
  )
106
+ provenance: ModelProvenance | None = Field(
107
+ None,
108
+ description="Auditing metadata (MLflow run id, train date, n_examples).",
109
+ )
110
 
111
 
112
  class MRIDiagnosticsRequest(BaseModel):
src/frontend/app.py CHANGED
@@ -457,6 +457,18 @@ def _render_mri_tab() -> None:
457
  def _render_prediction_card(result: dict) -> None:
458
  """Render a B2B-styled decision card: label badge + confidence + SHAP bars."""
459
  st.session_state["last_bbb_prediction"] = result
 
 
 
 
 
 
 
 
 
 
 
 
460
  label_text = _html.escape(str(result["label_text"]))
461
  badge_color = "#166534" if result["label"] == 1 else "#991B1B"
462
  badge_bg = "#DCFCE7" if result["label"] == 1 else "#FEE2E2"
 
457
  def _render_prediction_card(result: dict) -> None:
458
  """Render a B2B-styled decision card: label badge + confidence + SHAP bars."""
459
  st.session_state["last_bbb_prediction"] = result
460
+ provenance = result.get("provenance")
461
+ if provenance is not None:
462
+ run_id = provenance.get("mlflow_run_id")
463
+ run_label = run_id[:8] if run_id else "—"
464
+ train_date = provenance.get("train_date") or "—"
465
+ n_examples = provenance.get("n_examples")
466
+ n_label = f"n={n_examples}" if n_examples else "n=—"
467
+ st.caption(
468
+ f"🔎 MLflow run **{run_label}** · "
469
+ f"Model **{provenance.get('model_version', 'v1')}** · "
470
+ f"trained {train_date} · {n_label}"
471
+ )
472
  label_text = _html.escape(str(result["label_text"]))
473
  badge_color = "#166534" if result["label"] == 1 else "#991B1B"
474
  badge_bg = "#DCFCE7" if result["label"] == 1 else "#FEE2E2"
tests/api/test_routes.py CHANGED
@@ -160,6 +160,25 @@ class TestBBBPredictRoute:
160
  # By call 105, drift_z is computable (≥10 samples) — assert numeric.
161
  assert isinstance(last_body["drift_z"], float)
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  def test_returns_400_on_invalid_smiles(self, tmp_path: Path, monkeypatch):
164
  artifact = self._setup_model_artifact(tmp_path)
165
  monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
 
160
  # By call 105, drift_z is computable (≥10 samples) — assert numeric.
161
  assert isinstance(last_body["drift_z"], float)
162
 
163
+ def test_predict_response_includes_provenance(self, _set_bbb_model_path):
164
+ """T2: provenance field is present in body (fields may be None)."""
165
+ from src.api import routes
166
+ routes.WORKER_CONFIDENCE_DEQUE.clear()
167
+
168
+ resp = client.post("/predict/bbb", json={"smiles": "CCO", "top_k": 3})
169
+ assert resp.status_code == 200, resp.text
170
+ body = resp.json()
171
+ assert "provenance" in body
172
+ assert body["provenance"] is not None, "provenance should be populated even when MLflow is empty"
173
+ prov = body["provenance"]
174
+ assert "mlflow_run_id" in prov
175
+ assert "model_version" in prov
176
+ assert prov["model_version"] == "v1" # default until bumped manually
177
+ assert "train_date" in prov
178
+ assert "n_examples" in prov
179
+ # n_examples comes from train_stats — must be a positive int for the test fixture
180
+ assert isinstance(prov["n_examples"], int) and prov["n_examples"] >= 1
181
+
182
  def test_returns_400_on_invalid_smiles(self, tmp_path: Path, monkeypatch):
183
  artifact = self._setup_model_artifact(tmp_path)
184
  monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))