mekosotto committed on
Commit
2d7b690
·
1 Parent(s): fae874a

feat(api): POST /pipeline/{bbb,eeg,mri} dispatch routes

Browse files
Files changed (3) hide show
  1. src/api/main.py +3 -0
  2. src/api/routes.py +110 -0
  3. tests/api/test_routes.py +72 -0
src/api/main.py CHANGED
@@ -6,6 +6,7 @@ from __future__ import annotations
6
 
7
  from fastapi import FastAPI
8
 
 
9
  from src.api.schemas import HealthResponse
10
 
11
  app = FastAPI(
@@ -14,6 +15,8 @@ app = FastAPI(
14
  version="0.4.0",
15
  )
16
 
 
 
17
 
18
  @app.get("/health", response_model=HealthResponse)
19
  def health() -> HealthResponse:
 
6
 
7
  from fastapi import FastAPI
8
 
9
+ from src.api.routes import router as pipeline_router
10
  from src.api.schemas import HealthResponse
11
 
12
  app = FastAPI(
 
15
  version="0.4.0",
16
  )
17
 
18
+ app.include_router(pipeline_router)
19
+
20
 
21
  @app.get("/health", response_model=HealthResponse)
22
  def health() -> HealthResponse:
src/api/routes.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """POST /pipeline/{bbb,eeg,mri} routes — thin dispatchers over the pipelines.
2
+
3
+ Each route validates its request body via Pydantic, invokes the pipeline,
4
+ reads back the produced Parquet to populate row/column counts, and returns
5
+ a uniform PipelineResponse. Pipeline-domain errors map to standard HTTP
6
+ codes: FileNotFoundError -> 404, ValueError/KeyError -> 400, anything else -> 500.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import time
11
+ from pathlib import Path
12
+ from typing import Callable
13
+
14
+ import mlflow
15
+ import pandas as pd
16
+ from fastapi import APIRouter, HTTPException
17
+
18
+ from src.api.schemas import (
19
+ BBBRequest,
20
+ EEGRequest,
21
+ MRIRequest,
22
+ PipelineResponse,
23
+ )
24
+ from src.core.logger import get_logger
25
+ from src.pipelines import bbb_pipeline, eeg_pipeline, mri_pipeline
26
+
27
+ logger = get_logger(__name__)
28
+ router = APIRouter(prefix="/pipeline")
29
+
30
+
31
def _wrap(
    experiment_name: str,
    output_path: Path,
    fn: Callable[[], None],
) -> PipelineResponse:
    """Run `fn()` (the pipeline call), gather metrics, return PipelineResponse.

    Args:
        experiment_name: MLflow experiment name that is searched for the
            most recently started run once `fn` has returned.
        output_path: Parquet file that is read back after `fn` returns in
            order to report its row and column counts.
        fn: Zero-argument callable that executes the pipeline.

    Returns:
        A PipelineResponse with status "ok", the output path as a string,
        the row/column counts of the Parquet file, the wall-clock duration
        of `fn` in seconds, and the id of the latest run in the experiment
        (None when the search returns no runs).

    Raises:
        HTTPException: 404 when `fn` raises FileNotFoundError; 400 when it
            raises ValueError or KeyError. The original exception is chained
            as the cause. Any other exception from `fn`, and any exception
            raised while reading the Parquet file or querying MLflow,
            propagates untranslated.
    """
    started = time.perf_counter()
    try:
        fn()
    except FileNotFoundError as e:
        # Chain the cause so the traceback shows a deliberate translation
        # rather than "another exception occurred during handling".
        raise HTTPException(status_code=404, detail=str(e)) from e
    except (ValueError, KeyError) as e:
        # KeyError: MRI pipeline raises this when sites_csv is missing a
        # site assignment for a subject — a user data problem, not a 500.
        raise HTTPException(status_code=400, detail=str(e)) from e
    duration_sec = time.perf_counter() - started

    # Only the shape is reported, but the whole file is loaded to get it.
    df = pd.read_parquet(output_path)

    # NOTE(review): this assumes the newest run in the experiment is the one
    # `fn` just produced; concurrent requests against the same experiment
    # could report each other's run id — confirm whether that matters.
    runs = mlflow.search_runs(
        experiment_names=[experiment_name],
        max_results=1,
        order_by=["start_time DESC"],
    )
    run_id = runs.iloc[0]["run_id"] if len(runs) else None

    return PipelineResponse(
        status="ok",
        output_path=str(output_path),
        rows=len(df),
        columns=df.shape[1],
        duration_sec=duration_sec,
        mlflow_run_id=run_id,
    )
64
+
65
+
66
@router.post("/bbb", response_model=PipelineResponse)
def run_bbb(req: BBBRequest) -> PipelineResponse:
    """Dispatch a BBB pipeline run and report its output summary.

    The request's paths are converted to `Path` objects and forwarded to
    `bbb_pipeline.run_pipeline` together with the fingerprint settings
    carried on the request; `_wrap` builds the response.
    """
    destination = Path(req.output_path)

    def _invoke():
        # Deferred so that `_wrap` controls when the pipeline executes.
        return bbb_pipeline.run_pipeline(
            input_path=Path(req.input_path),
            output_path=destination,
            smiles_col=req.smiles_col,
            n_bits=req.n_bits,
            radius=req.radius,
        )

    return _wrap("bbb_pipeline", destination, _invoke)
80
+
81
+
82
@router.post("/eeg", response_model=PipelineResponse)
def run_eeg(req: EEGRequest) -> PipelineResponse:
    """Dispatch an EEG pipeline run and report its output summary.

    The request's paths are converted to `Path` objects and forwarded to
    `eeg_pipeline.run_pipeline` together with the epoch, EOG channel and
    component settings carried on the request; `_wrap` builds the response.
    """
    destination = Path(req.output_path)

    def _invoke():
        # Deferred so that `_wrap` controls when the pipeline executes.
        return eeg_pipeline.run_pipeline(
            input_path=Path(req.input_path),
            output_path=destination,
            epoch_duration_s=req.epoch_duration_s,
            eog_ch_name=req.eog_ch_name,
            n_components=req.n_components,
            random_state=req.random_state,
        )

    return _wrap("eeg_pipeline", destination, _invoke)
97
+
98
+
99
@router.post("/mri", response_model=PipelineResponse)
def run_mri(req: MRIRequest) -> PipelineResponse:
    """Dispatch an MRI pipeline run and report its output summary.

    The request's input directory, sites CSV and output path are converted
    to `Path` objects and forwarded to `mri_pipeline.run_pipeline`; `_wrap`
    builds the response.
    """
    destination = Path(req.output_path)

    def _invoke():
        # Deferred so that `_wrap` controls when the pipeline executes.
        return mri_pipeline.run_pipeline(
            input_dir=Path(req.input_dir),
            sites_csv=Path(req.sites_csv),
            output_path=destination,
        )

    return _wrap("mri_pipeline", destination, _invoke)
tests/api/test_routes.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for /pipeline/{bbb,eeg,mri} POST endpoints."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ from fastapi.testclient import TestClient
7
+
8
+ from src.api.main import app
9
+
10
+
11
+ client = TestClient(app)
12
+ _FIXTURES = Path(__file__).resolve().parents[1] / "fixtures"
13
+
14
+
15
class TestBBBRoute:
    """Endpoint checks for POST /pipeline/bbb."""

    def test_returns_200_with_valid_input(self, tmp_path: Path):
        """The sample CSV yields 200, an "ok" body and a written output file."""
        parquet_target = tmp_path / "out.parquet"
        payload = {
            "input_path": str(_FIXTURES / "bbbp_sample.csv"),
            "output_path": str(parquet_target),
        }

        response = client.post("/pipeline/bbb", json=payload)

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "ok"
        assert data["rows"] > 0
        assert parquet_target.exists()

    def test_returns_404_when_input_missing(self, tmp_path: Path):
        """A non-existent input file is reported as 404."""
        payload = {
            "input_path": str(tmp_path / "does_not_exist.csv"),
            "output_path": str(tmp_path / "out.parquet"),
        }

        response = client.post("/pipeline/bbb", json=payload)

        assert response.status_code == 404

    def test_returns_422_on_malformed_body(self):
        """A body without the required fields is rejected with 422."""
        response = client.post("/pipeline/bbb", json={"banana": 1})
        # 422 comes from pydantic request validation.
        assert response.status_code == 422
44
+
45
+
46
class TestEEGRoute:
    """Endpoint checks for POST /pipeline/eeg."""

    def test_returns_200_with_valid_input(self, tmp_path: Path):
        """The sample FIF recording yields 200 and a non-empty row count."""
        payload = {
            "input_path": str(_FIXTURES / "eeg_sample.fif"),
            "output_path": str(tmp_path / "out.parquet"),
        }

        response = client.post("/pipeline/eeg", json=payload)

        assert response.status_code == 200
        assert response.json()["rows"] > 0
56
+
57
+
58
class TestMRIRoute:
    """Endpoint checks for POST /pipeline/mri."""

    def test_returns_200_with_valid_input(self, tmp_path: Path):
        """A freshly built MRI fixture yields 200 and a non-empty row count."""
        # Imported inside the test, as in the original.
        from tests.fixtures.build_mri_fixture import build as build_mri

        mri_root = build_mri(out_dir=tmp_path / "mri_fixture")
        payload = {
            "input_dir": str(mri_root),
            "sites_csv": str(mri_root / "sites.csv"),
            "output_path": str(tmp_path / "out.parquet"),
        }

        response = client.post("/pipeline/mri", json=payload)

        assert response.status_code == 200
        assert response.json()["rows"] > 0