File size: 4,578 Bytes
b00faa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b7eb68
b00faa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b620a9c
 
b00faa3
00be188
 
b620a9c
 
 
 
00be188
 
b620a9c
 
 
 
 
 
 
 
 
 
 
 
b00faa3
 
b620a9c
b00faa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Toto 2.0 inference wrapper.

We use the smallest Toto 2.0 variant (4M params) for speed on CPU. The model
is downloaded from the HuggingFace Hub on first use and cached.

API confirmed against DataDog/toto's `toto2/notebooks/quick_start.ipynb`:

    from toto2 import Toto2Model
    model = Toto2Model.from_pretrained("Datadog/Toto-2.0-4m", map_location=device)
    quantiles = model.forecast(
        {"target": ..., "target_mask": ..., "series_ids": ...},
        horizon=H,
    )
    # quantiles shape: (9, batch, n_var, horizon)
    # quantile levels:  [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
"""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np
import pandas as pd

DEFAULT_MODEL_ID = "Datadog/Toto-2.0-22m"

# Index into the 9-quantile output.
Q10_IDX = 0
Q50_IDX = 4
Q90_IDX = 8


@dataclass
class TotoForecast:
    """One metric's forecast.

    All three attributes are pandas Series sharing the same
    future-timestamp DatetimeIndex (one entry per forecast step, at the
    input series' cadence):

    - `median`: the 0.5-quantile forecast.
    - `p10` / `p90`: the 0.1- and 0.9-quantile bounds.
    """
    median: pd.Series
    p10: pd.Series
    p90: pd.Series


_MODEL_CACHE: dict[str, object] = {}


def load_model(model_id: str = DEFAULT_MODEL_ID, device: str = "cpu"):
    """Lazy-load + cache a Toto model.

    torch/toto2 are imported lazily so this module stays importable in
    environments without torch (local dev on Intel mac).

    Args:
        model_id: HuggingFace Hub id of the checkpoint.
        device: Requested torch device. A "cuda" request silently falls
            back to "cpu" when CUDA is unavailable.

    Returns:
        The loaded model, in eval mode, on the resolved device.
    """
    import torch  # noqa: PLC0415

    # Resolve the device before the cache lookup so a cuda request on a
    # CPU-only box shares a cache entry with an explicit cpu request.
    actual_device = device if (device != "cuda" or torch.cuda.is_available()) else "cpu"

    # Bug fix: key the cache by device as well as model id. Previously a
    # model first loaded on cpu would be returned for a later cuda request
    # (and vice versa), since only `model_id` was used as the key.
    cache_key = f"{model_id}::{actual_device}"
    if cache_key in _MODEL_CACHE:
        return _MODEL_CACHE[cache_key]

    from toto2 import Toto2Model  # noqa: PLC0415

    model = Toto2Model.from_pretrained(model_id, map_location=actual_device)
    model = model.to(actual_device).eval()
    _MODEL_CACHE[cache_key] = model
    return model


def _series_freq(series: pd.Series) -> pd.Timedelta:
    """Infer the spacing of a regular time series; default to 1 hour."""
    if len(series.index) < 2:
        return pd.Timedelta("1h")
    diffs = pd.Series(series.index).diff().dropna()
    if diffs.empty:
        return pd.Timedelta("1h")
    return diffs.median()


def forecast_series(
    series: pd.Series,
    horizon: int = 24,
    model_id: str = DEFAULT_MODEL_ID,
    device: str = "cpu",
) -> TotoForecast:
    """Univariate forecast for one metric.

    Args:
        series: Regularly-spaced values with a DatetimeIndex (UTC). NaNs
            are interpolated before forecasting.
        horizon: Number of future steps to predict.
        model_id: Toto checkpoint to load.
        device: Requested torch device for inference.

    Returns:
        TotoForecast with median/p10/p90 Series over `horizon` future
        timestamps at the input's cadence.

    Raises:
        ValueError: If `series` is empty.
    """
    import torch  # noqa: PLC0415

    if series.empty:
        raise ValueError("Cannot forecast an empty series")

    # Fill interior and edge NaNs; Toto can't consume missing values.
    clean = series.astype(float).interpolate(limit_direction="both")

    model = load_model(model_id, device=device)

    # Toto requires the context length to be a multiple of the model's
    # patch_size (read from the model config). If we have at least one
    # full patch, truncate the oldest points to fit. If we have fewer,
    # left-pad with the first value and mark the padded region False in
    # the mask so Toto ignores it.
    patch = int(model.config.patch_size)
    raw = clean.to_numpy(dtype=np.float32)
    n_raw = len(raw)

    if n_raw >= patch:
        n = (n_raw // patch) * patch
        arr = raw[-n:]
        mask_vec = np.ones(n, dtype=bool)
    else:
        pad = patch - n_raw
        arr = np.concatenate([np.full(pad, raw[0], dtype=np.float32), raw])
        mask_vec = np.concatenate([np.zeros(pad, dtype=bool), np.ones(n_raw, dtype=bool)])

    target = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0)  # (1, 1, T)
    target_mask = torch.from_numpy(mask_vec).unsqueeze(0).unsqueeze(0)
    series_ids = torch.zeros(1, 1, dtype=torch.long)

    # Bug fix: move inputs to the device the model actually lives on.
    # `load_model` silently falls back to cpu when cuda is unavailable, so
    # sending tensors to the *requested* device could mismatch the model.
    model_device = next(model.parameters()).device
    target = target.to(model_device)
    target_mask = target_mask.to(model_device)
    series_ids = series_ids.to(model_device)

    with torch.no_grad():
        quantiles = model.forecast(
            {"target": target, "target_mask": target_mask, "series_ids": series_ids},
            horizon=horizon,
        )

    # quantiles: (9, 1, 1, horizon) → grab three quantile slices
    q = quantiles.detach().cpu().numpy()
    p10 = q[Q10_IDX, 0, 0]
    p50 = q[Q50_IDX, 0, 0]
    p90 = q[Q90_IDX, 0, 0]

    # Extend the input's cadence `horizon` steps past its last timestamp.
    freq = _series_freq(clean)
    last_ts = clean.index[-1]
    future_idx = pd.date_range(start=last_ts + freq, periods=horizon, freq=freq, tz=last_ts.tz)

    return TotoForecast(
        median=pd.Series(p50, index=future_idx, name=f"{series.name}_median"),
        p10=pd.Series(p10, index=future_idx, name=f"{series.name}_p10"),
        p90=pd.Series(p90, index=future_idx, name=f"{series.name}_p90"),
    )