"""MLflow tracking helper used by all three pipelines.
Wraps `mlflow.start_run` so each pipeline can log params, metrics, and an
output artifact in one block. Honors `NEUROBRIDGE_DISABLE_MLFLOW=1` for
environments where the tracking server is not reachable (offline demos, CI
without an MLflow service). When disabled, yields `None` and makes no MLflow calls.
Tracking URI source of truth: the standard `MLFLOW_TRACKING_URI` env var.
Tests pin this via the repo-wide conftest.py autouse fixture.
"""
from __future__ import annotations
import contextlib
import os
from pathlib import Path
from typing import Iterator
import mlflow
from src.core.logger import get_logger
# Environment variable that switches run tracking off when set to exactly "1".
_DISABLE_FLAG: str = "NEUROBRIDGE_DISABLE_MLFLOW"

# Module-scoped logger obtained from the project's logging helper.
logger = get_logger(__name__)
@contextlib.contextmanager
def track_pipeline_run(
    experiment_name: str,
    params: dict[str, object],
    metrics: dict[str, float],
    artifact_path: Path,
) -> Iterator[str | None]:
    """Context manager that creates an MLflow run for one pipeline invocation.

    On enter: creates/loads ``experiment_name``, starts a run, and logs
    ``params`` and ``metrics`` (one batched call each; empty dicts are skipped).

    On exit: logs ``artifact_path`` to the run, on both the success and the
    failure path, and the run is ended by ``mlflow.start_run``'s own context
    manager. A regular path is logged with ``mlflow.log_artifact``; a directory
    (e.g. a partitioned Parquet dataset) is logged with ``mlflow.log_artifacts``
    under its own directory name. A missing path only produces a warning.

    If the ``with`` body raises, the body's exception always propagates: a
    failure while logging the artifact on that path is logged and suppressed
    so it cannot mask the original error. On the success path, artifact
    logging errors propagate to the caller.

    When ``NEUROBRIDGE_DISABLE_MLFLOW=1`` is set, no MLflow calls are made.

    Args:
        experiment_name: e.g. "bbb_pipeline" / "eeg_pipeline" / "mri_pipeline".
        params: Run parameters (input path, hyper-params, etc.). Stringified by MLflow.
        metrics: Numeric metrics (row counts, durations).
        artifact_path: Path to the produced Parquet output (file or directory),
            logged as a run artifact.

    Yields:
        The active run's ``run_id`` (str), or ``None`` if MLflow is disabled.
    """
    if os.environ.get(_DISABLE_FLAG) == "1":
        logger.info("MLflow disabled via %s=1; skipping run tracking", _DISABLE_FLAG)
        yield None
        return

    mlflow.set_experiment(experiment_name)
    with mlflow.start_run() as run:
        # One batched request each instead of one request per key.
        if params:
            mlflow.log_params(params)
        if metrics:
            mlflow.log_metrics(metrics)
        try:
            yield run.info.run_id
        except BaseException:
            # Failure path: artifact upload is best-effort so that an upload
            # error never replaces the pipeline's own exception.
            try:
                _log_output_artifact(Path(artifact_path))
            except Exception:
                logger.exception(
                    "failed to log artifact %s after pipeline error", artifact_path
                )
            raise
        else:
            _log_output_artifact(Path(artifact_path))


def _log_output_artifact(path: Path) -> None:
    """Log ``path`` to the active MLflow run as a directory or single artifact."""
    if path.is_dir():
        # log_artifact copies single files only; directories need log_artifacts.
        # Passing the directory name keeps it as a folder in the run's artifacts.
        mlflow.log_artifacts(str(path), artifact_path=path.name)
    elif path.exists():
        mlflow.log_artifact(str(path))
    else:
        logger.warning(
            "artifact_path %s does not exist; skipping artifact log",
            path,
        )