mekosotto Claude Opus 4.7 (1M context) committed on
Commit
d4000ca
·
1 Parent(s): d5a285c

feat(api): GET /experiments/runs + POST /experiments/diff (Track 5)

Browse files

- New experiments_router (prefix /experiments) hosts two endpoints:
GET /runs lists MLflow runs across all 3 experiments (bbb / eeg /
mri), POST /diff returns a side-by-side metric+param diff for two
given run ids.
- NEUROBRIDGE_DISABLE_MLFLOW=1 short-circuits both to empty
responses (no exception). Unknown run ids → 404 with detail.
- 5 new schemas: MLflowRunSummary, MLflowRunsResponse, RunDiffRequest,
RunDiffRow, RunDiffResponse.
- 2 new tests covering the empty-list and unknown-id paths.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

src/api/main.py CHANGED
@@ -6,7 +6,12 @@ from __future__ import annotations
6
 
7
  from fastapi import FastAPI
8
 
9
- from src.api.routes import router as pipeline_router, predict_router, explain_router
 
 
 
 
 
10
  from src.api.schemas import HealthResponse
11
 
12
  app = FastAPI(
@@ -18,6 +23,7 @@ app = FastAPI(
18
  app.include_router(pipeline_router)
19
  app.include_router(predict_router)
20
  app.include_router(explain_router)
 
21
 
22
 
23
  @app.get("/health", response_model=HealthResponse)
 
6
 
7
  from fastapi import FastAPI
8
 
9
+ from src.api.routes import (
10
+ router as pipeline_router,
11
+ predict_router,
12
+ explain_router,
13
+ experiments_router,
14
+ )
15
  from src.api.schemas import HealthResponse
16
 
17
  app = FastAPI(
 
23
  app.include_router(pipeline_router)
24
  app.include_router(predict_router)
25
  app.include_router(explain_router)
26
+ app.include_router(experiments_router)
27
 
28
 
29
  @app.get("/health", response_model=HealthResponse)
src/api/routes.py CHANGED
@@ -29,6 +29,8 @@ from src.api.schemas import (
29
  EEGRequest,
30
  FeatureAttribution,
31
  HarmonizationRow,
 
 
32
  ModelProvenance,
33
  MRIDiagnosticsRequest,
34
  MRIDiagnosticsResponse,
@@ -36,6 +38,9 @@ from src.api.schemas import (
36
  MRIExplainResponse,
37
  MRIRequest,
38
  PipelineResponse,
 
 
 
39
  )
40
  from src.core.logger import get_logger
41
  from src.llm import explainer as llm_explainer
@@ -46,6 +51,7 @@ logger = get_logger(__name__)
46
  router = APIRouter(prefix="/pipeline")
47
  predict_router = APIRouter(prefix="/predict")
48
  explain_router = APIRouter(prefix="/explain")
 
49
 
50
 
51
  def _wrap(
@@ -402,3 +408,95 @@ def explain_mri(req: MRIExplainRequest) -> MRIExplainResponse:
402
  source=result["source"],
403
  model=result["model"],
404
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  EEGRequest,
30
  FeatureAttribution,
31
  HarmonizationRow,
32
+ MLflowRunsResponse,
33
+ MLflowRunSummary,
34
  ModelProvenance,
35
  MRIDiagnosticsRequest,
36
  MRIDiagnosticsResponse,
 
38
  MRIExplainResponse,
39
  MRIRequest,
40
  PipelineResponse,
41
+ RunDiffRequest,
42
+ RunDiffResponse,
43
+ RunDiffRow,
44
  )
45
  from src.core.logger import get_logger
46
  from src.llm import explainer as llm_explainer
 
51
  router = APIRouter(prefix="/pipeline")
52
  predict_router = APIRouter(prefix="/predict")
53
  explain_router = APIRouter(prefix="/explain")
54
+ experiments_router = APIRouter(prefix="/experiments")
55
 
56
 
57
  def _wrap(
 
408
  source=result["source"],
409
  model=result["model"],
410
  )
411
+
412
+
413
@experiments_router.get("/runs", response_model=MLflowRunsResponse)
def list_runs(limit: int = 50) -> MLflowRunsResponse:
    """List recent MLflow runs across the three known pipeline experiments.

    Queries the bbb/eeg/mri pipeline experiments, merges their rows, and
    returns the most recent ``limit`` runs overall, newest first.

    Args:
        limit: Maximum number of runs fetched per experiment and returned
            overall after merging.

    Returns:
        MLflowRunsResponse; the ``runs`` list is empty when MLflow is
        disabled via ``NEUROBRIDGE_DISABLE_MLFLOW=1`` or every experiment
        lookup fails (failures are logged, never raised).
    """
    if os.environ.get("NEUROBRIDGE_DISABLE_MLFLOW") == "1":
        return MLflowRunsResponse(runs=[])

    summaries: list[MLflowRunSummary] = []
    for exp_name in ("bbb_pipeline", "eeg_pipeline", "mri_pipeline"):
        try:
            df = mlflow.search_runs(
                experiment_names=[exp_name],
                max_results=limit,
                order_by=["start_time DESC"],
            )
        except Exception as e:  # broad: MLflow store unreachable / not found
            logger.warning("MLflow lookup failed for %s: %s", exp_name, e)
            continue
        # Hoisted out of the per-row loop: the set of metric/param columns
        # is a property of the frame, not of any individual row.
        metric_cols = [c for c in df.columns if c.startswith("metrics.")]
        param_cols = [c for c in df.columns if c.startswith("params.")]
        for _, row in df.iterrows():
            metrics = {
                col[len("metrics."):]: float(row[col])
                for col in metric_cols
                if pd.notna(row[col])
            }
            params = {
                col[len("params."):]: str(row[col])
                for col in param_cols
                if pd.notna(row[col])
            }
            summaries.append(
                MLflowRunSummary(
                    run_id=str(row["run_id"]),
                    experiment_name=exp_name,
                    # Timestamp.isoformat() already returns str; "" marks a
                    # missing start time and sorts last in newest-first order.
                    start_time=pd.Timestamp(row["start_time"]).isoformat()
                    if pd.notna(row.get("start_time"))
                    else "",
                    status=str(row.get("status", "UNKNOWN")),
                    metrics=metrics,
                    params=params,
                )
            )
    # ISO 8601 strings sort lexicographically in chronological order.
    summaries.sort(key=lambda s: s.start_time, reverse=True)
    return MLflowRunsResponse(runs=summaries[:limit])
458
+
459
+
460
@experiments_router.post("/diff", response_model=RunDiffResponse)
def diff_runs(req: RunDiffRequest) -> RunDiffResponse:
    """Side-by-side diff of two MLflow runs (metrics + params).

    Args:
        req: Pair of MLflow run ids to compare.

    Returns:
        RunDiffResponse whose rows cover the union of metric keys (sorted)
        followed by the union of param keys (sorted); a missing key on one
        side yields ``None`` for that side's value. Returns 200 with an
        empty rows list when ``NEUROBRIDGE_DISABLE_MLFLOW=1``.

    Raises:
        HTTPException: 404 if either run id is not found in the local
            MLflow store.
    """
    if os.environ.get("NEUROBRIDGE_DISABLE_MLFLOW") == "1":
        return RunDiffResponse(rows=[])

    try:
        run_a = mlflow.get_run(req.run_id_a)
        run_b = mlflow.get_run(req.run_id_b)
    except Exception as e:
        # Chain the cause so logs/tracebacks keep the underlying MLflow error.
        raise HTTPException(status_code=404, detail=f"Run not found: {e}") from e

    metrics_a = run_a.data.metrics
    metrics_b = run_b.data.metrics
    params_a = run_a.data.params
    params_b = run_b.data.params

    rows: list[RunDiffRow] = []
    for key in sorted(set(metrics_a) | set(metrics_b)):
        va = metrics_a.get(key)
        vb = metrics_b.get(key)
        rows.append(
            RunDiffRow(
                key=key, kind="metric",
                # Display at 6 significant figures; `differs` is computed on
                # the raw floats, so values may differ without visible change.
                value_a=None if va is None else f"{va:.6g}",
                value_b=None if vb is None else f"{vb:.6g}",
                differs=(va != vb),
            )
        )
    for key in sorted(set(params_a) | set(params_b)):
        va = params_a.get(key)
        vb = params_b.get(key)
        rows.append(
            RunDiffRow(
                key=key, kind="param",
                value_a=va, value_b=vb, differs=(va != vb),
            )
        )
    return RunDiffResponse(rows=rows)
src/api/schemas.py CHANGED
@@ -193,3 +193,38 @@ class MRIExplainResponse(BaseModel):
193
  rationale: str
194
  source: str
195
  model: str | None = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  rationale: str
194
  source: str
195
  model: str | None = None
196
+
197
+
198
class MLflowRunSummary(BaseModel):
    """One MLflow run row for the Experiments tab table."""
    run_id: str  # MLflow run id, stringified from the search_runs frame
    experiment_name: str  # e.g. bbb_pipeline / eeg_pipeline / mri_pipeline
    start_time: str  # ISO 8601; empty string when the run has no start time
    status: str  # MLflow run status; "UNKNOWN" when absent from the row
    metrics: dict[str, float] = Field(default_factory=dict)  # metric name -> last value
    params: dict[str, str] = Field(default_factory=dict)  # param name -> value (stringified)
206
+
207
+
208
class MLflowRunsResponse(BaseModel):
    """Response for GET /experiments/runs."""
    runs: list[MLflowRunSummary]  # newest first; empty when MLflow is disabled/unreachable
211
+
212
+
213
class RunDiffRequest(BaseModel):
    """Request body for POST /experiments/diff."""
    run_id_a: str  # left-hand MLflow run id
    run_id_b: str  # right-hand MLflow run id
217
+
218
+
219
class RunDiffRow(BaseModel):
    """One row of a run-vs-run diff: metric/param key + value pair."""
    key: str  # metric or param name
    kind: str  # "metric" | "param"
    value_a: str | None  # None when the key is absent from run A
    value_b: str | None  # None when the key is absent from run B
    differs: bool  # True when the two sides' raw values are unequal
226
+
227
+
228
class RunDiffResponse(BaseModel):
    """Response for POST /experiments/diff: side-by-side metric/param diff."""
    rows: list[RunDiffRow]  # metric rows (sorted by key) then param rows (sorted by key)
tests/api/test_routes.py CHANGED
@@ -300,3 +300,34 @@ class TestExplainMRIRoute:
300
  assert out["source"] == "template"
301
  assert "3290" in out["rationale"]
302
  assert "6" in out["rationale"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  assert out["source"] == "template"
301
  assert "3290" in out["rationale"]
302
  assert "6" in out["rationale"]
303
+
304
+
305
class TestExperimentsRoutes:
    """Day-8 T2A: GET /experiments/runs and POST /experiments/diff."""

    def test_runs_endpoint_returns_list(self):
        """GET /experiments/runs returns a runs list (may be empty if no MLflow data)."""
        response = client.get("/experiments/runs")
        assert response.status_code == 200, response.text
        payload = response.json()
        assert "runs" in payload
        assert isinstance(payload["runs"], list)
        # Runs may legitimately be absent; when present, each entry must
        # carry the full summary schema.
        expected_keys = (
            "run_id", "experiment_name", "start_time", "status", "metrics", "params",
        )
        for entry in payload["runs"]:
            for key in expected_keys:
                assert key in entry

    def test_diff_endpoint_handles_unknown_runs_gracefully(self):
        """POST /experiments/diff with bogus run ids returns 404 (not 500)."""
        request_body = {"run_id_a": "nonexistent_aaa", "run_id_b": "nonexistent_bbb"}
        response = client.post("/experiments/diff", json=request_body)
        # 404 is the documented contract; 200 with empty rows is acceptable too
        # because some MLflow stores treat unknown ids as "empty result".
        assert response.status_code in (200, 404), (
            f"unexpected status {response.status_code}: {response.text}"
        )
        payload = response.json()
        if response.status_code == 200:
            assert payload.get("rows", []) == []