File size: 1,451 Bytes
93ed35a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
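"""W&B helpers. All training scripts go through here so logging stays
consistent and the run URL is easy to grab for the model card and README.

Default: online run if WANDB_API_KEY is set, offline otherwise. If the
``wandb`` package is not installed at all, ``run`` yields a no-op stand-in
so training code runs unchanged.
"""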

from __future__ import annotations

from contextlib import contextmanager
from typing import Any

try:
    import wandb
except ImportError:  # let local dev work without W&B installed
    wandb = None  # type: ignore[assignment]

from collections.abc import Iterator

from . import config


@contextmanager
def run(name: str, run_config: dict[str, Any], tags: list[str] | None = None) -> Iterator[Any]:
    """Context manager that opens a W&B run and yields the run object.

    Use:
        with tracking.run(name="train_v3", run_config={"lr": 1e-3}) as r:
            ...
            r.log({"val_mae": 0.42})
    """
    if wandb is None:
        # Yield a no-op stand-in so training code stays unchanged offline.
        class _NoopRun:
            url = "(wandb not installed)"

            def log(self, *a, **k):
                pass

            def log_artifact(self, *a, **k):
                pass

            def finish(self):
                pass

        yield _NoopRun()
        return

    r = wandb.init(
        project=config.WANDB_PROJECT,
        entity=config.WANDB_ENTITY or None,
        name=name,
        config=run_config,
        tags=tags or [],
    )
    try:
        yield r
    finally:
        r.finish()