mekosotto commited on
Commit
8f586ea
·
1 Parent(s): cc8c965

feat(core): add MLflow tracking helper with disable env-flag

Browse files
Files changed (3) hide show
  1. conftest.py +23 -0
  2. src/core/tracking.py +67 -0
  3. tests/core/test_tracking.py +82 -0
conftest.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Repo-wide pytest fixtures.
2
+
3
+ Pins MLflow's tracking URI to a per-session tmp directory so pipeline tests
4
+ don't litter `./mlruns/` in the working tree, and so test runs are isolated
5
+ from production MLflow state.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ import tempfile
11
+ from pathlib import Path
12
+ from typing import Iterator
13
+
14
+ import pytest
15
+
16
+
17
+ @pytest.fixture(autouse=True, scope="session")
18
+ def _isolate_mlflow_tracking_uri() -> Iterator[None]:
19
+ tmp_root = Path(tempfile.mkdtemp(prefix="mlflow_test_"))
20
+ os.environ["MLFLOW_TRACKING_URI"] = f"file://{tmp_root}"
21
+ yield
22
+ # Don't rmtree — pytest tmpdir cleanup or OS handles it; rmtree
23
+ # races with mlflow background writes on slow CI.
src/core/tracking.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MLflow tracking helper used by all three pipelines.
2
+
3
+ Wraps `mlflow.start_run` so each pipeline can log params, metrics, and an
4
+ output artifact in one block. Honors `NEUROBRIDGE_DISABLE_MLFLOW=1` for
5
+ environments where the tracking server is not reachable (offline demos, CI
6
+ without mlflow service). When disabled, yields `None` and does no I/O.
7
+
8
+ Tracking URI source of truth: the standard `MLFLOW_TRACKING_URI` env var.
9
+ Tests pin this via the repo-wide conftest.py autouse fixture.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import contextlib
14
+ import os
15
+ from pathlib import Path
16
+ from typing import Iterator
17
+
18
+ import mlflow
19
+
20
+ from src.core.logger import get_logger
21
+
22
logger = get_logger(__name__)

# Env var name; when set to "1", track_pipeline_run becomes a no-op
# (yields None, performs no MLflow I/O).
_DISABLE_FLAG = "NEUROBRIDGE_DISABLE_MLFLOW"
25
+
26
+
27
@contextlib.contextmanager
def track_pipeline_run(
    experiment_name: str,
    params: dict[str, object],
    metrics: dict[str, float],
    artifact_path: Path,
) -> Iterator[str | None]:
    """Context manager that creates an MLflow run for one pipeline invocation.

    On enter: creates/loads `experiment_name`, starts a run, logs params + metrics.
    On exit: logs `artifact_path` as an artifact and ends the run.

    Yields the active `run_id` (str), or `None` if MLflow is disabled via
    the `NEUROBRIDGE_DISABLE_MLFLOW=1` env flag (no I/O is performed then).

    Args:
        experiment_name: e.g. "bbb_pipeline" / "eeg_pipeline" / "mri_pipeline".
        params: Run parameters (input path, hyper-params, etc.). Stringified by MLflow.
        metrics: Numeric metrics (row counts, durations).
        artifact_path: Path to the produced Parquet — logged as a run artifact.
    """
    if os.environ.get(_DISABLE_FLAG) == "1":
        logger.info("MLflow disabled via %s=1; skipping run tracking", _DISABLE_FLAG)
        yield None
        return

    mlflow.set_experiment(experiment_name)
    with mlflow.start_run() as run:
        # Batch APIs: one tracking-store call per dict instead of one per key.
        if params:
            mlflow.log_params(params)
        if metrics:
            mlflow.log_metrics(metrics)
        try:
            yield run.info.run_id
        finally:
            # Log the artifact even if the caller's body raised, so partial
            # outputs remain inspectable; skip (with a warning) if the
            # pipeline never produced the file.
            if Path(artifact_path).exists():
                mlflow.log_artifact(str(artifact_path))
            else:
                logger.warning(
                    "artifact_path %s does not exist; skipping artifact log",
                    artifact_path,
                )
tests/core/test_tracking.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.core.tracking."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ import mlflow
7
+ import pandas as pd
8
+
9
+ from src.core import tracking
10
+
11
+
12
class TestTrackPipelineRun:
    """Behavioral tests for `tracking.track_pipeline_run`."""

    @staticmethod
    def _make_artifact(tmp_path: Path) -> Path:
        """Write a minimal one-row Parquet file and return its path."""
        target = tmp_path / "out.parquet"
        pd.DataFrame({"a": [1]}).to_parquet(target)
        return target

    def test_creates_run_with_experiment_name(self, tmp_path: Path):
        artifact = self._make_artifact(tmp_path)
        with tracking.track_pipeline_run(
            experiment_name="bbb_pipeline",
            params={"input_path": "x.csv"},
            metrics={"rows_in": 6.0, "rows_out": 4.0},
            artifact_path=artifact,
        ) as run_id:
            assert run_id is not None
        found = mlflow.search_runs(experiment_names=["bbb_pipeline"])
        assert len(found) >= 1

    def test_logs_params(self, tmp_path: Path):
        artifact = self._make_artifact(tmp_path)
        with tracking.track_pipeline_run(
            experiment_name="bbb_pipeline_params",
            params={"n_bits": 2048, "radius": 2},
            metrics={},
            artifact_path=artifact,
        ):
            pass
        found = mlflow.search_runs(experiment_names=["bbb_pipeline_params"])
        # MLflow stringifies param values on log.
        assert "params.n_bits" in found.columns
        assert found.iloc[0]["params.n_bits"] == "2048"

    def test_logs_metrics(self, tmp_path: Path):
        artifact = self._make_artifact(tmp_path)
        with tracking.track_pipeline_run(
            experiment_name="eeg_pipeline_metrics",
            params={},
            metrics={"duration_sec": 1.234, "rows_out": 100.0},
            artifact_path=artifact,
        ):
            pass
        found = mlflow.search_runs(experiment_names=["eeg_pipeline_metrics"])
        first = found.iloc[0]
        assert first["metrics.duration_sec"] == 1.234
        assert first["metrics.rows_out"] == 100.0

    def test_logs_artifact(self, tmp_path: Path):
        artifact = self._make_artifact(tmp_path)
        with tracking.track_pipeline_run(
            experiment_name="mri_pipeline_artifact",
            params={},
            metrics={},
            artifact_path=artifact,
        ) as run_id:
            pass
        logged = mlflow.MlflowClient().list_artifacts(run_id)
        assert any(entry.path.endswith("out.parquet") for entry in logged)

    def test_disabled_via_env_returns_no_op(self, monkeypatch, tmp_path: Path):
        """Setting NEUROBRIDGE_DISABLE_MLFLOW=1 must skip MLflow entirely
        (used by live demo when the tracking server is down)."""
        monkeypatch.setenv("NEUROBRIDGE_DISABLE_MLFLOW", "1")
        artifact = self._make_artifact(tmp_path)
        with tracking.track_pipeline_run(
            experiment_name="should_not_appear",
            params={"x": 1},
            metrics={"y": 2.0},
            artifact_path=artifact,
        ) as run_id:
            assert run_id is None
        # No "should_not_appear" experiment was created
        names = [e.name for e in mlflow.search_experiments()]
        assert "should_not_appear" not in names