File size: 1,376 Bytes
93ed35a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""Training entry point. The train function is what /log-run wraps.

This is a skeleton — replace the model + loss + optimization with your
actual model. The contract that matters: take a config dict in, return
metrics + an artifact path out, and log everything via tracking.run().
"""

from __future__ import annotations

from typing import Any

from .. import tracking


def train(run_config: dict[str, Any]) -> dict[str, float]:
    """Run one tracked training job described by ``run_config``.

    This is still a skeleton. Today the body opens a ``tracking.run`` context,
    logs a placeholder metric, prints the run URL and returns the metric.
    Replace the TODOs with real data loading, model construction, the training
    loop and artifact saving.

    Args:
        run_config: Run configuration. The optional ``"name"`` key is used as
            the tracking run name (falls back to ``"train"``); the whole dict
            is passed through to ``tracking.run``.

    Returns:
        Headline metrics, intended for /eval-report to pick up. Until a real
        loop is plugged in this is ``{"val_mae": nan}``.
    """
    run_name = run_config.get("name", "train")
    with tracking.run(name=run_name, run_config=run_config) as active_run:
        # TODO: load data via your_project.data
        # TODO: build model
        # TODO: training loop with active_run.log({"train_loss": ..., "val_mae": ...})
        # TODO: save the model artifact (e.g. models/<name>.pt, .joblib, .onnx)
        #       and register it with active_run.log_artifact(path)
        headline: dict[str, float] = {"val_mae": float("nan")}
        active_run.log(headline)
        print(f"W&B run URL: {active_run.url}")
        # Returned inside the `with` so the value is produced before the
        # run context exits, exactly as the caller has always seen it.
        return headline


if __name__ == "__main__":
    # CLI: python -m <package>.train --config path/to/run.json
    import argparse
    import json

    p = argparse.ArgumentParser(description="Run train() with a JSON run config.")
    p.add_argument("--config", required=True, help="Path to run config JSON")
    args = p.parse_args()

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    # Unreadable/malformed files become a clean usage error instead of a traceback.
    try:
        with open(args.config, encoding="utf-8") as f:
            cfg = json.load(f)
    except (OSError, json.JSONDecodeError) as err:
        p.error(f"could not load config {args.config!r}: {err}")  # exits (status 2)

    # train() reads the config with dict methods (run_config.get), so reject a
    # top-level JSON list/scalar here rather than failing deep inside train().
    if not isinstance(cfg, dict):
        p.error(
            f"config {args.config!r} must be a JSON object, "
            f"got {type(cfg).__name__}"
        )

    print(train(cfg))