File size: 2,320 Bytes
8f586ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""MLflow tracking helper used by all three pipelines.

Wraps `mlflow.start_run` so each pipeline can log params, metrics, and an
output artifact in one block. Honors `NEUROBRIDGE_DISABLE_MLFLOW=1` for
environments where the tracking server is not reachable (offline demos, CI
without mlflow service). When disabled, yields `None` and does no I/O.

Tracking URI source of truth: the standard `MLFLOW_TRACKING_URI` env var.
Tests pin this via the repo-wide conftest.py autouse fixture.
"""
from __future__ import annotations

import contextlib
import os
from pathlib import Path
from typing import Iterator

import mlflow

from src.core.logger import get_logger

# Module-level logger obtained from the project's logger factory, keyed by this module's name.
logger = get_logger(__name__)

# Name of the environment variable that switches tracking off: `track_pipeline_run`
# skips every MLflow call when this variable is set to exactly "1".
_DISABLE_FLAG = "NEUROBRIDGE_DISABLE_MLFLOW"


def _log_artifact_if_present(artifact_path: Path) -> None:
    """Log `artifact_path` to the active MLflow run, or warn if it does not exist."""
    if Path(artifact_path).exists():
        mlflow.log_artifact(str(artifact_path))
    else:
        logger.warning(
            "artifact_path %s does not exist; skipping artifact log",
            artifact_path,
        )


@contextlib.contextmanager
def track_pipeline_run(
    experiment_name: str,
    params: dict[str, object],
    metrics: dict[str, float],
    artifact_path: Path,
) -> Iterator[str | None]:
    """Context manager that creates an MLflow run for one pipeline invocation.

    On enter: creates/loads `experiment_name`, starts a run, logs params + metrics.
    On exit: logs `artifact_path` as an artifact (if it exists) and ends the run.

    Yields the active `run_id` (str), or `None` if MLflow is disabled via the
    `NEUROBRIDGE_DISABLE_MLFLOW=1` environment variable (no MLflow call is made
    in that case).

    Error handling on exit:
        * Body succeeded: the artifact is logged; any error raised while
          logging it propagates to the caller.
        * Body raised: the artifact is still logged on a best-effort basis, but
          a failure while doing so is only reported as a warning so that it
          cannot replace the pipeline's own exception, which is re-raised.

    Args:
        experiment_name: e.g. "bbb_pipeline" / "eeg_pipeline" / "mri_pipeline".
        params: Run parameters (input path, hyper-params, etc.). Stringified by MLflow.
        metrics: Numeric metrics (row counts, durations).
        artifact_path: Path to the produced Parquet — logged as a run artifact.
            A missing path is not an error; a warning is emitted instead.
    """
    if os.environ.get(_DISABLE_FLAG) == "1":
        logger.info("MLflow disabled via %s=1; skipping run tracking", _DISABLE_FLAG)
        yield None
        return

    mlflow.set_experiment(experiment_name)
    with mlflow.start_run() as run:
        for key, value in params.items():
            mlflow.log_param(key, value)
        for key, value in metrics.items():
            mlflow.log_metric(key, value)
        try:
            yield run.info.run_id
        except BaseException:
            # The pipeline body failed. Still try to attach whatever artifact
            # exists, but never let a tracking error mask the real failure:
            # callers must see (and be able to catch) the pipeline's exception.
            try:
                _log_artifact_if_present(artifact_path)
            except Exception as exc:
                logger.warning(
                    "failed to log artifact %s after pipeline error: %s",
                    artifact_path,
                    exc,
                )
            raise
        else:
            # Normal completion: artifact-logging errors are surfaced as before.
            _log_artifact_if_present(artifact_path)