"""Tests for src.core.tracking."""
from __future__ import annotations
from pathlib import Path
import mlflow
import pandas as pd
from src.core import tracking
class TestTrackPipelineRun:
    """End-to-end tests for tracking.track_pipeline_run against MLflow.

    Each test writes a tiny parquet artifact, runs the context manager,
    then inspects the MLflow store to confirm the run/params/metrics/
    artifact were recorded (or, for the env-var kill switch, that
    nothing was recorded at all).
    """

    @staticmethod
    def _write_artifact(tmp_path: Path) -> Path:
        # Minimal one-row parquet file to stand in for pipeline output.
        artifact = tmp_path / "out.parquet"
        pd.DataFrame({"a": [1]}).to_parquet(artifact)
        return artifact

    def test_creates_run_with_experiment_name(self, tmp_path: Path):
        artifact = self._write_artifact(tmp_path)
        with tracking.track_pipeline_run(
            experiment_name="bbb_pipeline",
            params={"input_path": "x.csv"},
            metrics={"rows_in": 6.0, "rows_out": 4.0},
            artifact_path=artifact,
        ) as run_id:
            # A real MLflow run id is yielded when tracking is enabled.
            assert run_id is not None
        found = mlflow.search_runs(experiment_names=["bbb_pipeline"])
        assert len(found) >= 1

    def test_logs_params(self, tmp_path: Path):
        artifact = self._write_artifact(tmp_path)
        with tracking.track_pipeline_run(
            experiment_name="bbb_pipeline_params",
            params={"n_bits": 2048, "radius": 2},
            metrics={},
            artifact_path=artifact,
        ):
            pass
        found = mlflow.search_runs(experiment_names=["bbb_pipeline_params"])
        assert "params.n_bits" in found.columns
        # MLflow stores params as strings.
        assert found.iloc[0]["params.n_bits"] == "2048"

    def test_logs_metrics(self, tmp_path: Path):
        artifact = self._write_artifact(tmp_path)
        with tracking.track_pipeline_run(
            experiment_name="eeg_pipeline_metrics",
            params={},
            metrics={"duration_sec": 1.234, "rows_out": 100.0},
            artifact_path=artifact,
        ):
            pass
        found = mlflow.search_runs(experiment_names=["eeg_pipeline_metrics"])
        row = found.iloc[0]
        assert row["metrics.duration_sec"] == 1.234
        assert row["metrics.rows_out"] == 100.0

    def test_logs_artifact(self, tmp_path: Path):
        artifact = self._write_artifact(tmp_path)
        with tracking.track_pipeline_run(
            experiment_name="mri_pipeline_artifact",
            params={},
            metrics={},
            artifact_path=artifact,
        ) as run_id:
            pass
        logged = mlflow.MlflowClient().list_artifacts(run_id)
        assert any(entry.path.endswith("out.parquet") for entry in logged)

    def test_disabled_via_env_returns_no_op(self, monkeypatch, tmp_path: Path):
        """Setting NEUROBRIDGE_DISABLE_MLFLOW=1 must skip MLflow entirely
        (used by live demo when the tracking server is down)."""
        monkeypatch.setenv("NEUROBRIDGE_DISABLE_MLFLOW", "1")
        artifact = self._write_artifact(tmp_path)
        with tracking.track_pipeline_run(
            experiment_name="should_not_appear",
            params={"x": 1},
            metrics={"y": 2.0},
            artifact_path=artifact,
        ) as run_id:
            # Disabled mode yields no run id at all.
            assert run_id is None
        # No "should_not_appear" experiment was created
        names = [e.name for e in mlflow.search_experiments()]
        assert "should_not_appear" not in names