| """MLflow tracking helper used by all three pipelines. |
| |
| Wraps `mlflow.start_run` so each pipeline can log params, metrics, and an |
| output artifact in one block. Honors `NEUROBRIDGE_DISABLE_MLFLOW=1` for |
| environments where the tracking server is not reachable (offline demos, CI |
| without mlflow service). When disabled, yields `None` and does no I/O. |
| |
| Tracking URI source of truth: the standard `MLFLOW_TRACKING_URI` env var. |
| Tests pin this via the repo-wide conftest.py autouse fixture. |
| """ |
| from __future__ import annotations |
|
|
| import contextlib |
| import os |
| from pathlib import Path |
| from typing import Iterator |
|
|
| import mlflow |
|
|
| from src.core.logger import get_logger |
|
|
# Name of the environment variable that turns run tracking into a no-op when set to "1".
_DISABLE_FLAG = "NEUROBRIDGE_DISABLE_MLFLOW"

# Module-scoped logger obtained through the project's logging helper.
logger = get_logger(__name__)
|
|
|
|
def _log_artifact_if_present(artifact_path: Path) -> None:
    """Attach `artifact_path` to the active MLflow run, or warn when it is missing."""
    if Path(artifact_path).exists():
        mlflow.log_artifact(str(artifact_path))
    else:
        logger.warning(
            "artifact_path %s does not exist; skipping artifact log",
            artifact_path,
        )


@contextlib.contextmanager
def track_pipeline_run(
    experiment_name: str,
    params: dict[str, object],
    metrics: dict[str, float],
    artifact_path: Path,
) -> Iterator[str | None]:
    """Context manager that creates an MLflow run for one pipeline invocation.

    On enter: creates/loads `experiment_name`, starts a run, logs params + metrics.
    On exit: logs `artifact_path` as an artifact (if the path exists) and ends the run.

    If the `with` body raises, artifact logging is still attempted. Any error
    raised while logging the artifact is written to the log and suppressed, so
    the body's original exception is the one that propagates to the caller.

    When `NEUROBRIDGE_DISABLE_MLFLOW=1` is set, nothing is sent to MLflow and
    the manager yields `None`.

    Yields the active `run_id` (str), or `None` if MLflow is disabled.

    Args:
        experiment_name: e.g. "bbb_pipeline" / "eeg_pipeline" / "mri_pipeline".
        params: Run parameters (input path, hyper-params, etc.). Stringified by MLflow.
        metrics: Numeric metrics (row counts, durations). Logged on enter, so
            later changes to this dict inside the `with` body are not recorded.
        artifact_path: Path to the produced Parquet — logged as a run artifact.
    """
    if os.environ.get(_DISABLE_FLAG) == "1":
        logger.info("MLflow disabled via %s=1; skipping run tracking", _DISABLE_FLAG)
        yield None
        return

    mlflow.set_experiment(experiment_name)
    with mlflow.start_run() as run:
        # Batch APIs instead of one logging call per key.
        mlflow.log_params(params)
        mlflow.log_metrics(metrics)
        try:
            yield run.info.run_id
        except BaseException:
            # NOTE(review): on failure the file at artifact_path may be partial or
            # left over from an earlier run — confirm it should be attached here.
            try:
                _log_artifact_if_present(artifact_path)
            except Exception:
                # An artifact-upload failure must not mask the pipeline's own error.
                logger.warning(
                    "failed to log artifact %s after pipeline error",
                    artifact_path,
                    exc_info=True,
                )
            raise
        else:
            _log_artifact_if_present(artifact_path)
|
|