mekosotto Claude Sonnet 4.6 commited on
Commit
5d4dc71
·
1 Parent(s): 2134339

feat(api): add POST /fusion/predict route for multi-modal fusion

Browse files
src/api/main.py CHANGED
@@ -12,6 +12,7 @@ from src.api.routes import (
12
  explain_router,
13
  experiments_router,
14
  agent_router,
 
15
  )
16
  from src.api.schemas import HealthResponse
17
 
@@ -26,6 +27,7 @@ app.include_router(predict_router)
26
  app.include_router(explain_router)
27
  app.include_router(experiments_router)
28
  app.include_router(agent_router)
 
29
 
30
 
31
  @app.get("/health", response_model=HealthResponse)
 
12
  explain_router,
13
  experiments_router,
14
  agent_router,
15
+ fusion_router, # NEW
16
  )
17
  from src.api.schemas import HealthResponse
18
 
 
27
  app.include_router(explain_router)
28
  app.include_router(experiments_router)
29
  app.include_router(agent_router)
30
+ app.include_router(fusion_router)
31
 
32
 
33
  @app.get("/health", response_model=HealthResponse)
src/api/routes.py CHANGED
@@ -28,6 +28,8 @@ from src.api.schemas import (
28
  BBBRequest,
29
  CalibrationContext,
30
  EEGExplainRequest,
 
 
31
  EEGExplainResponse,
32
  EEGRequest,
33
  FeatureAttribution,
@@ -632,3 +634,15 @@ def run_agent(req: AgentRunRequest) -> AgentRunResponse:
632
  model=result.model,
633
  finish_reason=result.finish_reason,
634
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  BBBRequest,
29
  CalibrationContext,
30
  EEGExplainRequest,
31
+ FusionRequest,
32
+ FusionResponse,
33
  EEGExplainResponse,
34
  EEGRequest,
35
  FeatureAttribution,
 
634
  model=result.model,
635
  finish_reason=result.finish_reason,
636
  )
637
+
638
+
639
+ # --- Fusion router ---------------------------------------------------------
640
+
641
+ fusion_router = APIRouter(prefix="/fusion")
642
+
643
+
644
+ @fusion_router.post("/predict", response_model=FusionResponse)
645
+ def fusion_predict(req: FusionRequest) -> FusionResponse:
646
+ """Combine MRI, EEG, and clinical scores into per-disease confidence."""
647
+ from src.fusion.engine import fuse as fuse_engine
648
+ return fuse_engine(req)
src/api/schemas.py CHANGED
@@ -288,3 +288,16 @@ class AgentRunResponse(BaseModel):
288
  trace: list[AgentToolTraceItem] = Field(default_factory=list)
289
  model: str | None = None
290
  finish_reason: str = "complete"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  trace: list[AgentToolTraceItem] = Field(default_factory=list)
289
  model: str | None = None
290
  finish_reason: str = "complete"
291
+
292
+
293
+ # --- Fusion engine surface --------------------------------------------------
294
+
295
+ # Re-export the fusion types so the API surface lives in one file but the
296
+ # implementation stays in src/fusion. This keeps `from src.api.schemas import *`
297
+ # style imports stable for the frontend layer.
298
+ from src.fusion.types import ( # noqa: E402,F401
299
+ ClinicalScores as FusionClinicalScores,
300
+ FusionInput as FusionRequest,
301
+ FusionOutput as FusionResponse,
302
+ ModalityPrediction as FusionModalityPrediction,
303
+ )
tests/api/test_fusion_route.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Integration test for POST /fusion/predict."""
2
+ from __future__ import annotations
3
+
4
+ from fastapi.testclient import TestClient
5
+
6
+ from src.api.main import app
7
+
8
+
9
+ client = TestClient(app)
10
+
11
+
12
+ class TestFusionRoute:
13
+ def test_happy_path_mri_only(self) -> None:
14
+ body = {
15
+ "mri": {
16
+ "label_text": "alzheimers",
17
+ "label": 1,
18
+ "confidence": 0.88,
19
+ "probabilities": [
20
+ {"label_text": "control", "probability": 0.12},
21
+ {"label_text": "alzheimers", "probability": 0.88},
22
+ ],
23
+ },
24
+ }
25
+ r = client.post("/fusion/predict", json=body)
26
+ assert r.status_code == 200, r.text
27
+ data = r.json()
28
+ assert "diseases" in data
29
+ assert any(d["disease"] == "alzheimers" for d in data["diseases"])
30
+ assert data["top_disease"] in {"alzheimers", "parkinsons", "other"}
31
+
32
+ def test_empty_input_returns_baseline(self) -> None:
33
+ r = client.post("/fusion/predict", json={})
34
+ assert r.status_code == 200
35
+ data = r.json()
36
+ for d in data["diseases"]:
37
+ assert abs(d["probability"] - 0.5) < 1e-6
38
+ assert "mri" in data["missing_inputs"]
39
+
40
+ def test_invalid_probability_returns_422(self) -> None:
41
+ body = {
42
+ "mri": {
43
+ "label_text": "x",
44
+ "label": 0,
45
+ "confidence": 1.5,
46
+ "probabilities": [{"label_text": "x", "probability": 1.5}],
47
+ },
48
+ }
49
+ r = client.post("/fusion/predict", json=body)
50
+ assert r.status_code == 422