feat(api): add POST /fusion/predict route for multi-modal fusion
Browse files- src/api/main.py +2 -0
- src/api/routes.py +14 -0
- src/api/schemas.py +13 -0
- tests/api/test_fusion_route.py +50 -0
src/api/main.py
CHANGED
|
@@ -12,6 +12,7 @@ from src.api.routes import (
|
|
| 12 |
explain_router,
|
| 13 |
experiments_router,
|
| 14 |
agent_router,
|
|
|
|
| 15 |
)
|
| 16 |
from src.api.schemas import HealthResponse
|
| 17 |
|
|
@@ -26,6 +27,7 @@ app.include_router(predict_router)
|
|
| 26 |
app.include_router(explain_router)
|
| 27 |
app.include_router(experiments_router)
|
| 28 |
app.include_router(agent_router)
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
@app.get("/health", response_model=HealthResponse)
|
|
|
|
| 12 |
explain_router,
|
| 13 |
experiments_router,
|
| 14 |
agent_router,
|
| 15 |
+
fusion_router, # NEW
|
| 16 |
)
|
| 17 |
from src.api.schemas import HealthResponse
|
| 18 |
|
|
|
|
| 27 |
app.include_router(explain_router)
|
| 28 |
app.include_router(experiments_router)
|
| 29 |
app.include_router(agent_router)
|
| 30 |
+
app.include_router(fusion_router)
|
| 31 |
|
| 32 |
|
| 33 |
@app.get("/health", response_model=HealthResponse)
|
src/api/routes.py
CHANGED
|
@@ -28,6 +28,8 @@ from src.api.schemas import (
|
|
| 28 |
BBBRequest,
|
| 29 |
CalibrationContext,
|
| 30 |
EEGExplainRequest,
|
|
|
|
|
|
|
| 31 |
EEGExplainResponse,
|
| 32 |
EEGRequest,
|
| 33 |
FeatureAttribution,
|
|
@@ -632,3 +634,15 @@ def run_agent(req: AgentRunRequest) -> AgentRunResponse:
|
|
| 632 |
model=result.model,
|
| 633 |
finish_reason=result.finish_reason,
|
| 634 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
BBBRequest,
|
| 29 |
CalibrationContext,
|
| 30 |
EEGExplainRequest,
|
| 31 |
+
FusionRequest,
|
| 32 |
+
FusionResponse,
|
| 33 |
EEGExplainResponse,
|
| 34 |
EEGRequest,
|
| 35 |
FeatureAttribution,
|
|
|
|
| 634 |
model=result.model,
|
| 635 |
finish_reason=result.finish_reason,
|
| 636 |
)
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
# --- Fusion router ---------------------------------------------------------
|
| 640 |
+
|
| 641 |
+
fusion_router = APIRouter(prefix="/fusion")
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
@fusion_router.post("/predict", response_model=FusionResponse)
|
| 645 |
+
def fusion_predict(req: FusionRequest) -> FusionResponse:
|
| 646 |
+
"""Combine MRI, EEG, and clinical scores into per-disease confidence."""
|
| 647 |
+
from src.fusion.engine import fuse as fuse_engine
|
| 648 |
+
return fuse_engine(req)
|
src/api/schemas.py
CHANGED
|
@@ -288,3 +288,16 @@ class AgentRunResponse(BaseModel):
|
|
| 288 |
trace: list[AgentToolTraceItem] = Field(default_factory=list)
|
| 289 |
model: str | None = None
|
| 290 |
finish_reason: str = "complete"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
trace: list[AgentToolTraceItem] = Field(default_factory=list)
|
| 289 |
model: str | None = None
|
| 290 |
finish_reason: str = "complete"
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# --- Fusion engine surface --------------------------------------------------
|
| 294 |
+
|
| 295 |
+
# Re-export the fusion types so the API surface lives in one file but the
|
| 296 |
+
# implementation stays in src/fusion. This keeps `from src.api.schemas import *`
|
| 297 |
+
# style imports stable for the frontend layer.
|
| 298 |
+
from src.fusion.types import ( # noqa: E402,F401
|
| 299 |
+
ClinicalScores as FusionClinicalScores,
|
| 300 |
+
FusionInput as FusionRequest,
|
| 301 |
+
FusionOutput as FusionResponse,
|
| 302 |
+
ModalityPrediction as FusionModalityPrediction,
|
| 303 |
+
)
|
tests/api/test_fusion_route.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Integration test for POST /fusion/predict."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from fastapi.testclient import TestClient
|
| 5 |
+
|
| 6 |
+
from src.api.main import app
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
client = TestClient(app)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TestFusionRoute:
|
| 13 |
+
def test_happy_path_mri_only(self) -> None:
|
| 14 |
+
body = {
|
| 15 |
+
"mri": {
|
| 16 |
+
"label_text": "alzheimers",
|
| 17 |
+
"label": 1,
|
| 18 |
+
"confidence": 0.88,
|
| 19 |
+
"probabilities": [
|
| 20 |
+
{"label_text": "control", "probability": 0.12},
|
| 21 |
+
{"label_text": "alzheimers", "probability": 0.88},
|
| 22 |
+
],
|
| 23 |
+
},
|
| 24 |
+
}
|
| 25 |
+
r = client.post("/fusion/predict", json=body)
|
| 26 |
+
assert r.status_code == 200, r.text
|
| 27 |
+
data = r.json()
|
| 28 |
+
assert "diseases" in data
|
| 29 |
+
assert any(d["disease"] == "alzheimers" for d in data["diseases"])
|
| 30 |
+
assert data["top_disease"] in {"alzheimers", "parkinsons", "other"}
|
| 31 |
+
|
| 32 |
+
def test_empty_input_returns_baseline(self) -> None:
|
| 33 |
+
r = client.post("/fusion/predict", json={})
|
| 34 |
+
assert r.status_code == 200
|
| 35 |
+
data = r.json()
|
| 36 |
+
for d in data["diseases"]:
|
| 37 |
+
assert abs(d["probability"] - 0.5) < 1e-6
|
| 38 |
+
assert "mri" in data["missing_inputs"]
|
| 39 |
+
|
| 40 |
+
def test_invalid_probability_returns_422(self) -> None:
|
| 41 |
+
body = {
|
| 42 |
+
"mri": {
|
| 43 |
+
"label_text": "x",
|
| 44 |
+
"label": 0,
|
| 45 |
+
"confidence": 1.5,
|
| 46 |
+
"probabilities": [{"label_text": "x", "probability": 1.5}],
|
| 47 |
+
},
|
| 48 |
+
}
|
| 49 |
+
r = client.post("/fusion/predict", json=body)
|
| 50 |
+
assert r.status_code == 422
|