feat(core): add MLflow tracking helper with disable env-flag
- conftest.py +23 -0
- src/core/tracking.py +67 -0
- tests/core/test_tracking.py +82 -0
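Intended call pattern, for reviewers (a minimal sketch mirroring the new tests; the output path and parameter values are hypothetical, not part of this diff):

from pathlib import Path

from src.core.tracking import track_pipeline_run

out = Path("data/processed/bbb.parquet")  # hypothetical output location
with track_pipeline_run(
    experiment_name="bbb_pipeline",
    params={"input_path": "data/raw/bbb.csv", "n_bits": 2048},  # hypothetical values
    metrics={"rows_in": 6.0, "rows_out": 4.0},
    artifact_path=out,
) as run_id:
    print(f"tracking under MLflow run {run_id}")  # pipeline body goes here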
conftest.py
ADDED
@@ -0,0 +1,23 @@
"""Repo-wide pytest fixtures.

Pins MLflow's tracking URI to a per-session tmp directory so pipeline tests
don't litter `./mlruns/` in the working tree, and so test runs are isolated
from production MLflow state.
"""
from __future__ import annotations

import os
import tempfile
from pathlib import Path
from typing import Iterator

import pytest


@pytest.fixture(autouse=True, scope="session")
def _isolate_mlflow_tracking_uri() -> Iterator[None]:
    tmp_root = Path(tempfile.mkdtemp(prefix="mlflow_test_"))
    os.environ["MLFLOW_TRACKING_URI"] = f"file://{tmp_root}"
    yield
    # Don't rmtree: pytest tmpdir cleanup or the OS handles it; rmtree
    # races with mlflow background writes on slow CI.
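For illustration, what the fixture guarantees inside any test (a hypothetical sanity check, not part of this diff; `mlflow.get_tracking_uri` is the standard accessor):

import mlflow

def test_tracking_uri_is_isolated() -> None:
    uri = mlflow.get_tracking_uri()
    # Set by the autouse session fixture in conftest.py above
    assert uri.startswith("file://")
    assert "mlflow_test_" in uri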
src/core/tracking.py
ADDED
@@ -0,0 +1,67 @@
"""MLflow tracking helper used by all three pipelines.

Wraps `mlflow.start_run` so each pipeline can log params, metrics, and an
output artifact in one block. Honors `NEUROBRIDGE_DISABLE_MLFLOW=1` for
environments where the tracking server is not reachable (offline demos, CI
without an mlflow service). When disabled, yields `None` and does no I/O.

Tracking URI source of truth: the standard `MLFLOW_TRACKING_URI` env var.
Tests pin this via the repo-wide conftest.py autouse fixture.
"""
from __future__ import annotations

import contextlib
import os
from pathlib import Path
from typing import Iterator

import mlflow

from src.core.logger import get_logger

logger = get_logger(__name__)

_DISABLE_FLAG = "NEUROBRIDGE_DISABLE_MLFLOW"


@contextlib.contextmanager
def track_pipeline_run(
    experiment_name: str,
    params: dict[str, object],
    metrics: dict[str, float],
    artifact_path: Path,
) -> Iterator[str | None]:
    """Context manager that creates an MLflow run for one pipeline invocation.

    On enter: creates/loads `experiment_name`, starts a run, logs params + metrics.
    On exit: logs `artifact_path` as an artifact and ends the run.

    Yields the active `run_id` (str), or `None` if MLflow is disabled.

    Args:
        experiment_name: e.g. "bbb_pipeline" / "eeg_pipeline" / "mri_pipeline".
        params: Run parameters (input path, hyper-params, etc.). Stringified by MLflow.
        metrics: Numeric metrics (row counts, durations).
        artifact_path: Path to the produced Parquet, logged as a run artifact.
    """
    if os.environ.get(_DISABLE_FLAG) == "1":
        logger.info("MLflow disabled via %s=1; skipping run tracking", _DISABLE_FLAG)
        yield None
        return

    mlflow.set_experiment(experiment_name)
    with mlflow.start_run() as run:
        for key, value in params.items():
            mlflow.log_param(key, value)
        for key, value in metrics.items():
            mlflow.log_metric(key, value)
        try:
            yield run.info.run_id
        finally:
            if Path(artifact_path).exists():
                mlflow.log_artifact(str(artifact_path))
            else:
                logger.warning(
                    "artifact_path %s does not exist; skipping artifact log",
                    artifact_path,
                )
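The disabled path described in the module docstring, sketched end to end (input and artifact names are hypothetical; no tracking server or file store is touched):

import os
from pathlib import Path

from src.core.tracking import track_pipeline_run

os.environ["NEUROBRIDGE_DISABLE_MLFLOW"] = "1"
with track_pipeline_run(
    experiment_name="eeg_pipeline",
    params={"input_path": "session.edf"},  # hypothetical input
    metrics={"duration_sec": 0.0},
    artifact_path=Path("eeg.parquet"),  # need not exist; no I/O happens
) as run_id:
    assert run_id is None  # helper yields None and skips all MLflow calls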
tests/core/test_tracking.py
ADDED
@@ -0,0 +1,82 @@
"""Tests for src.core.tracking."""
from __future__ import annotations

from pathlib import Path

import mlflow
import pandas as pd

from src.core import tracking


class TestTrackPipelineRun:
    def test_creates_run_with_experiment_name(self, tmp_path: Path):
        out = tmp_path / "out.parquet"
        pd.DataFrame({"a": [1]}).to_parquet(out)
        with tracking.track_pipeline_run(
            experiment_name="bbb_pipeline",
            params={"input_path": "x.csv"},
            metrics={"rows_in": 6.0, "rows_out": 4.0},
            artifact_path=out,
        ) as run_id:
            assert run_id is not None
        runs = mlflow.search_runs(experiment_names=["bbb_pipeline"])
        assert len(runs) >= 1

    def test_logs_params(self, tmp_path: Path):
        out = tmp_path / "out.parquet"
        pd.DataFrame({"a": [1]}).to_parquet(out)
        with tracking.track_pipeline_run(
            experiment_name="bbb_pipeline_params",
            params={"n_bits": 2048, "radius": 2},
            metrics={},
            artifact_path=out,
        ):
            pass
        runs = mlflow.search_runs(experiment_names=["bbb_pipeline_params"])
        assert "params.n_bits" in runs.columns
        assert runs.iloc[0]["params.n_bits"] == "2048"

    def test_logs_metrics(self, tmp_path: Path):
        out = tmp_path / "out.parquet"
        pd.DataFrame({"a": [1]}).to_parquet(out)
        with tracking.track_pipeline_run(
            experiment_name="eeg_pipeline_metrics",
            params={},
            metrics={"duration_sec": 1.234, "rows_out": 100.0},
            artifact_path=out,
        ):
            pass
        runs = mlflow.search_runs(experiment_names=["eeg_pipeline_metrics"])
        assert runs.iloc[0]["metrics.duration_sec"] == 1.234
        assert runs.iloc[0]["metrics.rows_out"] == 100.0

    def test_logs_artifact(self, tmp_path: Path):
        out = tmp_path / "out.parquet"
        pd.DataFrame({"a": [1]}).to_parquet(out)
        with tracking.track_pipeline_run(
            experiment_name="mri_pipeline_artifact",
            params={},
            metrics={},
            artifact_path=out,
        ) as run_id:
            pass
        artifacts = mlflow.MlflowClient().list_artifacts(run_id)
        assert any(a.path.endswith("out.parquet") for a in artifacts)

    def test_disabled_via_env_returns_no_op(self, monkeypatch, tmp_path: Path):
        """Setting NEUROBRIDGE_DISABLE_MLFLOW=1 must skip MLflow entirely
        (used by the live demo when the tracking server is down)."""
        monkeypatch.setenv("NEUROBRIDGE_DISABLE_MLFLOW", "1")
        out = tmp_path / "out.parquet"
        pd.DataFrame({"a": [1]}).to_parquet(out)
        with tracking.track_pipeline_run(
            experiment_name="should_not_appear",
            params={"x": 1},
            metrics={"y": 2.0},
            artifact_path=out,
        ) as run_id:
            assert run_id is None
        # No "should_not_appear" experiment was created
        names = [e.name for e in mlflow.search_experiments()]
        assert "should_not_appear" not in names
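To run just this suite against the tmp-dir tracking URI pinned by conftest.py (standard pytest invocation, nothing project-specific assumed):

    pytest tests/core/test_tracking.py -q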