File size: 15,478 Bytes
2d7b690
 
 
 
c0a7163
2d7b690
c0a7163
c26a55c
2d7b690
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae883d4
 
c0a7163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae883d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c26a55c
 
 
 
 
 
 
ae883d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42366a8
 
 
 
 
 
 
 
 
 
 
 
ae883d4
c26a55c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28ca4f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae883d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
985240b
 
c0a7163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
985240b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e9f487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f348a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4000ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
"""Tests for /pipeline/{bbb,eeg,mri} POST endpoints."""
from __future__ import annotations

from pathlib import Path
from unittest.mock import patch

import pandas as pd
import pytest
from fastapi.testclient import TestClient

from src.api.main import app


# Shared TestClient wrapping the FastAPI ``app`` imported from ``src.api.main``;
# the route tests in this module send their requests through it.
client = TestClient(app)
# Fixture directory: ``parents[1]`` is the parent of the directory that holds
# this file, so this resolves to the sibling ``fixtures`` folder next to it.
_FIXTURES = Path(__file__).resolve().parents[1] / "fixtures"


class TestBBBRoute:
    """POST /pipeline/bbb: success, missing-input and malformed-body cases."""

    def test_returns_200_with_valid_input(self, tmp_path: Path):
        """The sample CSV yields 200, an "ok" status, rows > 0 and an output file."""
        destination = tmp_path / "out.parquet"
        payload = {
            "input_path": str(_FIXTURES / "bbbp_sample.csv"),
            "output_path": str(destination),
        }

        response = client.post("/pipeline/bbb", json=payload)

        assert response.status_code == 200
        result = response.json()
        assert result["status"] == "ok"
        assert result["rows"] > 0
        assert destination.exists()

    def test_returns_404_when_input_missing(self, tmp_path: Path):
        """An input path that does not exist on disk is answered with 404."""
        payload = {
            "input_path": str(tmp_path / "does_not_exist.csv"),
            "output_path": str(tmp_path / "out.parquet"),
        }

        response = client.post("/pipeline/bbb", json=payload)

        assert response.status_code == 404

    def test_returns_422_on_malformed_body(self):
        """A body without the expected fields is answered with 422."""
        unexpected_payload = {"banana": 1}
        response = client.post("/pipeline/bbb", json=unexpected_payload)
        # pydantic validation
        assert response.status_code == 422


class TestEEGRoute:
    """POST /pipeline/eeg happy path."""

    def test_returns_200_with_valid_input(self, tmp_path: Path):
        """The sample FIF fixture returns 200 and a positive row count."""
        payload = {
            "input_path": str(_FIXTURES / "eeg_sample.fif"),
            "output_path": str(tmp_path / "out.parquet"),
        }

        response = client.post("/pipeline/eeg", json=payload)

        assert response.status_code == 200
        assert response.json()["rows"] > 0


class TestMRIRoute:
    """POST /pipeline/mri happy path."""

    def test_returns_200_with_valid_input(self, tmp_path: Path):
        """A freshly built MRI fixture directory returns 200 and a positive row count."""
        from tests.fixtures.build_mri_fixture import build as build_mri

        mri_dir = build_mri(out_dir=tmp_path / "mri_fixture")
        payload = {
            "input_dir": str(mri_dir),
            "sites_csv": str(mri_dir / "sites.csv"),
            "output_path": str(tmp_path / "out.parquet"),
        }

        response = client.post("/pipeline/mri", json=payload)

        assert response.status_code == 200
        assert response.json()["rows"] > 0


class TestPipelineWrap:
    """Direct checks of the ``routes._wrap`` helper."""

    def test_wrap_skips_mlflow_lookup_when_disabled(self, tmp_path: Path, monkeypatch):
        """With NEUROBRIDGE_DISABLE_MLFLOW=1, ``_wrap`` never calls ``mlflow.search_runs``."""
        from src.api import routes

        monkeypatch.setenv("NEUROBRIDGE_DISABLE_MLFLOW", "1")

        # Pre-create a one-row parquet file at the output path handed to ``_wrap``.
        parquet_path = tmp_path / "out.parquet"
        pd.DataFrame({"x": [1]}).to_parquet(parquet_path)

        with patch("src.api.routes.mlflow.search_runs") as mocked_search:
            result = routes._wrap("bbb_pipeline", parquet_path, lambda: None)

        mocked_search.assert_not_called()
        assert result.status == "ok"
        assert result.mlflow_run_id is None


class TestBBBPredictRoute:
    """POST /predict/bbb: prediction body, drift telemetry, provenance and error codes."""

    def _setup_model_artifact(self, tmp_path: Path) -> Path:
        """Build features + train + save a tiny model. Returns artifact path.

        Runs the BBB pipeline on the ``bbbp_sample.csv`` fixture, trains a
        model with ``n_estimators=10`` on the resulting features and saves it
        as ``bbb_model.joblib`` under ``tmp_path``.
        """
        from src.pipelines import bbb_pipeline
        from src.models import bbb_model

        features_path = tmp_path / "features.parquet"
        bbb_pipeline.run_pipeline(
            input_path=_FIXTURES / "bbbp_sample.csv",
            output_path=features_path,
        )
        # Relies on the module-level ``pd`` import; no function-local re-import.
        df = pd.read_parquet(features_path)
        model = bbb_model.train(df, label_col="p_np", n_estimators=10, random_state=42)
        artifact = tmp_path / "bbb_model.joblib"
        bbb_model.save(model, artifact)
        return artifact

    @pytest.fixture
    def _set_bbb_model_path(self, tmp_path: Path, monkeypatch) -> Path:
        """Build a model artifact and point BBB_MODEL_PATH at it for the test."""
        artifact = self._setup_model_artifact(tmp_path)
        monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
        return artifact

    def test_returns_200_with_prediction_and_attributions(self, _set_bbb_model_path):
        """A valid SMILES returns label, confidence, top-k features and calibration."""
        resp = client.post(
            "/predict/bbb",
            json={"smiles": "CCO", "top_k": 5},
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["label"] in (0, 1)
        assert body["label_text"] in ("permeable", "non-permeable")
        assert 0.0 <= body["confidence"] <= 1.0
        assert len(body["top_features"]) == 5
        for f in body["top_features"]:
            assert f["feature"].startswith("fp_")
            assert isinstance(f["shap_value"], float)
        # Day-6 calibration assertions: trained test fixture model has
        # _neurobridge_calibration metadata, so calibration must be populated.
        assert body["calibration"] is not None
        cal = body["calibration"]
        valid_thresholds = [0.50, 0.60, 0.70, 0.75, 0.80, 0.90]
        assert any(
            cal["threshold"] == pytest.approx(t) for t in valid_thresholds
        ), f"threshold {cal['threshold']} not in {valid_thresholds}"
        assert cal["threshold"] <= body["confidence"]
        assert 0.0 <= cal["precision"] <= 1.0
        assert isinstance(cal["support"], int)
        assert cal["support"] >= 0

    def test_predict_response_includes_drift_z_and_rolling_n(
        self, _set_bbb_model_path,
    ):
        """T1B: drift_z and rolling_n keys must always appear in the body."""
        # Reset deque before this test so rolling_n starts deterministic.
        from src.api import routes
        routes.WORKER_CONFIDENCE_DEQUE.clear()

        resp = client.post("/predict/bbb", json={"smiles": "CCO", "top_k": 5})
        assert resp.status_code == 200, resp.text
        body = resp.json()
        assert "drift_z" in body
        assert "rolling_n" in body
        # First request: buffer has 1 sample (just appended), so warming up.
        assert body["rolling_n"] == 1
        assert body["drift_z"] is None  # <10 samples = warming up

    def test_predict_deque_rolls_at_100(self, _set_bbb_model_path):
        """T1B: after 100 predictions, deque caps at maxlen=100 (rolls)."""
        from src.api import routes
        routes.WORKER_CONFIDENCE_DEQUE.clear()
        # Fire 105 calls; final rolling_n must be 100, not 105.
        last_body = None
        for _ in range(105):
            resp = client.post(
                "/predict/bbb", json={"smiles": "CCO", "top_k": 3},
            )
            assert resp.status_code == 200
            last_body = resp.json()
        assert last_body["rolling_n"] == 100
        # By call 105, drift_z is computable (≥10 samples) — assert numeric.
        assert isinstance(last_body["drift_z"], float)

    def test_predict_response_includes_provenance(self, _set_bbb_model_path):
        """T2: provenance field is present in body (fields may be None)."""
        from src.api import routes
        routes.WORKER_CONFIDENCE_DEQUE.clear()

        resp = client.post("/predict/bbb", json={"smiles": "CCO", "top_k": 3})
        assert resp.status_code == 200, resp.text
        body = resp.json()
        assert "provenance" in body
        assert body["provenance"] is not None, "provenance should be populated even when MLflow is empty"
        prov = body["provenance"]
        assert "mlflow_run_id" in prov
        assert "model_version" in prov
        assert prov["model_version"] == "v1"  # default until bumped manually
        assert "train_date" in prov
        assert "n_examples" in prov
        # n_examples comes from train_stats — must be a positive int for the test fixture
        assert isinstance(prov["n_examples"], int) and prov["n_examples"] >= 1

    def test_returns_400_on_invalid_smiles(self, _set_bbb_model_path):
        """A string that is not a SMILES is rejected with 400 even though a model is available."""
        resp = client.post(
            "/predict/bbb",
            json={"smiles": "this_is_not_a_smiles", "top_k": 5},
        )
        assert resp.status_code == 400

    def test_returns_503_when_artifact_missing(self, tmp_path: Path, monkeypatch):
        """BBB_MODEL_PATH pointing at a nonexistent file is answered with 503."""
        # No artifact is built here on purpose, so the shared fixture is not used.
        monkeypatch.setenv("BBB_MODEL_PATH", str(tmp_path / "does_not_exist.joblib"))
        resp = client.post(
            "/predict/bbb",
            json={"smiles": "CCO", "top_k": 5},
        )
        assert resp.status_code == 503


class TestMRIPredictRoute:
    """POST /predict/mri: missing artifact, missing input and successful prediction."""

    def test_returns_503_when_artifact_missing(self, tmp_path: Path, monkeypatch):
        """MRI_MODEL_PATH set to a nonexistent file gives 503 with an explanatory message."""
        absent_model = tmp_path / "missing.onnx"
        monkeypatch.setenv("MRI_MODEL_PATH", str(absent_model))
        scan = _FIXTURES / "mri_sample" / "subject_0.nii.gz"

        response = client.post("/predict/mri", json={"input_path": str(scan)})

        assert response.status_code == 503
        assert "MRI model artifact not available" in response.text

    def test_returns_404_when_input_missing(self, tmp_path: Path, monkeypatch):
        """With a dummy ONNX model configured, a missing input volume gives 404."""
        from tests.fixtures.build_dummy_mri_onnx import build as build_dummy_mri_onnx

        model_file = build_dummy_mri_onnx(tmp_path / "mri_model.onnx")
        monkeypatch.setenv("MRI_MODEL_PATH", str(model_file))
        payload = {
            "input_path": str(tmp_path / "missing.nii.gz"),
            "target_shape": [8, 8, 8],
        }

        response = client.post("/predict/mri", json=payload)

        assert response.status_code == 404

    def test_returns_200_with_prediction(self, tmp_path: Path, monkeypatch):
        """The sample volume is classified as "abnormal" (label 1) by the dummy model."""
        from tests.fixtures.build_dummy_mri_onnx import build as build_dummy_mri_onnx

        model_file = build_dummy_mri_onnx(tmp_path / "mri_model.onnx")
        monkeypatch.setenv("MRI_MODEL_PATH", str(model_file))
        payload = {
            "input_path": str(_FIXTURES / "mri_sample" / "subject_0.nii.gz"),
            "target_shape": [8, 8, 8],
            "label_names": ["control", "abnormal"],
        }

        response = client.post("/predict/mri", json=payload)

        assert response.status_code == 200, response.text
        prediction = response.json()
        assert prediction["label"] == 1
        assert prediction["label_text"] == "abnormal"
        assert prediction["confidence"] > 0.5
        assert prediction["input_path"].endswith("subject_0.nii.gz")
        assert prediction["model_path"] == str(model_file)
        assert len(prediction["probabilities"]) == 2


class TestMRIDiagnosticsRoute:
    """POST /pipeline/mri/diagnostics: harmonization report and missing-directory case."""

    def test_returns_200_with_pre_and_post_data(self, tmp_path: Path):
        """The report carries non-negative site gaps, a factor >= 1 and both harmonization states."""
        from tests.fixtures.build_mri_fixture import build as build_mri

        mri_dir = build_mri(out_dir=tmp_path / "mri")
        payload = {
            "input_dir": str(mri_dir),
            "sites_csv": str(mri_dir / "sites.csv"),
        }

        response = client.post("/pipeline/mri/diagnostics", json=payload)

        assert response.status_code == 200
        report = response.json()
        assert len(report["rows"]) > 0
        assert report["site_gap_pre"] >= 0.0
        assert report["site_gap_post"] >= 0.0
        # Reduction factor is the headline KPI: ComBat must reduce, not amplify.
        assert report["reduction_factor"] >= 1.0
        seen_states = set()
        for row in report["rows"]:
            seen_states.add(row["harmonization_state"])
        assert seen_states == {"Pre-ComBat", "Post-ComBat"}

    def test_returns_404_when_input_dir_missing(self, tmp_path: Path):
        """A nonexistent input directory is answered with 404."""
        payload = {
            "input_dir": str(tmp_path / "does_not_exist"),
            "sites_csv": str(tmp_path / "sites.csv"),
        }

        response = client.post("/pipeline/mri/diagnostics", json=payload)

        assert response.status_code == 404


class TestExplainBBBRoute:
    """Day-7 T3B: POST /explain/bbb."""

    def test_returns_200_with_template_source(self, monkeypatch):
        """With NEUROBRIDGE_DISABLE_LLM=1 the rationale comes back with source=template."""
        monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
        attributions = [
            {"feature": "fp_341", "shap_value": 0.045},
            {"feature": "fp_902", "shap_value": -0.031},
            {"feature": "fp_77", "shap_value": 0.022},
        ]
        request_payload = {
            "smiles": "CCO",
            "label": 1,
            "label_text": "permeable",
            "confidence": 0.82,
            "top_features": attributions,
            "calibration": {"threshold": 0.80, "precision": 0.92, "support": 18},
            "drift_z": 0.42,
            "user_question": "Why permeable?",
        }

        response = client.post("/explain/bbb", json=request_payload)

        assert response.status_code == 200, response.text
        answer = response.json()
        assert answer["source"] == "template"
        assert answer["model"] is None
        rationale = answer["rationale"]
        # Template must mention all three features
        unmentioned = [
            name for name in ("fp_341", "fp_902", "fp_77") if name not in rationale
        ]
        assert not unmentioned
        assert "permeable" in rationale


class TestExplainEEGRoute:
    """Day-8 T1B: POST /explain/eeg."""

    def test_returns_200_with_template_source(self, monkeypatch):
        """With NEUROBRIDGE_DISABLE_LLM=1 the template rationale quotes the row and column counts."""
        monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
        request_payload = {
            "rows": 30,
            "columns": 95,
            "duration_sec": 4.32,
            "mlflow_run_id": "abc12345",
            "user_question": "Why were epochs dropped?",
        }

        response = client.post("/explain/eeg", json=request_payload)

        assert response.status_code == 200, response.text
        answer = response.json()
        assert answer["source"] == "template"
        assert answer["model"] is None
        rationale = answer["rationale"]
        assert "30" in rationale
        assert "95" in rationale


class TestExplainMRIRoute:
    """Day-8 T1B: POST /explain/mri."""

    def test_returns_200_with_template_source(self, monkeypatch):
        """With NEUROBRIDGE_DISABLE_LLM=1 the template rationale quotes the factor and subject count."""
        monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
        request_payload = {
            "site_gap_pre": 5.0004,
            "site_gap_post": 0.0015,
            "reduction_factor": 3290.0,
            "n_subjects": 6,
            "user_question": "Why does ComBat matter?",
        }

        response = client.post("/explain/mri", json=request_payload)

        assert response.status_code == 200, response.text
        answer = response.json()
        assert answer["source"] == "template"
        rationale = answer["rationale"]
        assert "3290" in rationale
        assert "6" in rationale


class TestExperimentsRoutes:
    """Day-8 T2A: GET /experiments/runs and POST /experiments/diff."""

    def test_runs_endpoint_returns_list(self):
        """GET /experiments/runs answers 200 with a ``runs`` list whose entries carry the expected keys."""
        response = client.get("/experiments/runs")
        assert response.status_code == 200, response.text

        listing = response.json()
        assert "runs" in listing
        runs = listing["runs"]
        assert isinstance(runs, list)

        # The list may be empty; any run that is present must expose every key.
        required_keys = (
            "run_id", "experiment_name", "start_time", "status", "metrics", "params",
        )
        for entry in runs:
            absent = [key for key in required_keys if key not in entry]
            assert not absent

    def test_diff_endpoint_handles_unknown_runs_gracefully(self):
        """Bogus run ids give 404, or 200 with no rows; any other status fails the test."""
        payload = {"run_id_a": "nonexistent_aaa", "run_id_b": "nonexistent_bbb"}
        resp = client.post("/experiments/diff", json=payload)

        # 404 is the documented contract; 200 with empty rows is acceptable too
        # because some MLflow stores treat unknown ids as "empty result".
        assert resp.status_code in (404, 200), (
            f"unexpected status {resp.status_code}: {resp.text}"
        )
        parsed = resp.json()
        if resp.status_code != 200:
            return
        assert parsed.get("rows", []) == []