| """Toto 2.0 inference wrapper. |
| |
| We use the smallest Toto 2.0 variant (4M params) for speed on CPU. The model |
| is downloaded from the HuggingFace Hub on first use and cached. |
| |
| API confirmed against DataDog/toto's `toto2/notebooks/quick_start.ipynb`: |
| |
| from toto2 import Toto2Model |
| model = Toto2Model.from_pretrained("Datadog/Toto-2.0-4m", map_location=device) |
| quantiles = model.forecast( |
| {"target": ..., "target_mask": ..., "series_ids": ...}, |
| horizon=H, |
| ) |
| # quantiles shape: (9, batch, n_var, horizon) |
| # quantile levels: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import numpy as np |
| import pandas as pd |
|
|
| DEFAULT_MODEL_ID = "Datadog/Toto-2.0-22m" |
|
|
| |
| Q10_IDX = 0 |
| Q50_IDX = 4 |
| Q90_IDX = 8 |
|
|
|
|
| @dataclass |
| class TotoForecast: |
| """One metric's forecast. |
| |
| `index` is a future-timestamp DatetimeIndex; `median`, `p10`, `p90` are |
| pandas Series aligned to it. |
| """ |
| median: pd.Series |
| p10: pd.Series |
| p90: pd.Series |
|
|
|
|
| _MODEL_CACHE: dict[str, object] = {} |
|
|
|
|
| def load_model(model_id: str = DEFAULT_MODEL_ID, device: str = "cpu"): |
| """Lazy-load + cache the Toto model. Imports torch lazily so this module |
| is importable in environments without torch (local dev on Intel mac).""" |
| if model_id in _MODEL_CACHE: |
| return _MODEL_CACHE[model_id] |
|
|
| import torch |
| from toto2 import Toto2Model |
|
|
| actual_device = device if (device != "cuda" or torch.cuda.is_available()) else "cpu" |
| model = Toto2Model.from_pretrained(model_id, map_location=actual_device) |
| model = model.to(actual_device).eval() |
| _MODEL_CACHE[model_id] = model |
| return model |
|
|
|
|
| def _series_freq(series: pd.Series) -> pd.Timedelta: |
| """Infer the spacing of a regular time series; default to 1 hour.""" |
| if len(series.index) < 2: |
| return pd.Timedelta("1h") |
| diffs = pd.Series(series.index).diff().dropna() |
| if diffs.empty: |
| return pd.Timedelta("1h") |
| return diffs.median() |
|
|
|
|
| def forecast_series( |
| series: pd.Series, |
| horizon: int = 24, |
| model_id: str = DEFAULT_MODEL_ID, |
| device: str = "cpu", |
| ) -> TotoForecast: |
| """Univariate forecast for one metric. |
| |
| `series` must be regularly-spaced and have a DatetimeIndex (UTC). Returns |
| median, p10, p90 over `horizon` future steps at the same cadence. |
| """ |
| import torch |
|
|
| if series.empty: |
| raise ValueError("Cannot forecast an empty series") |
|
|
| import numpy as np |
|
|
| clean = series.astype(float).interpolate(limit_direction="both") |
|
|
| |
| |
| |
| |
| |
| model = load_model(model_id, device=device) |
| patch = int(model.config.patch_size) |
| raw = clean.to_numpy(dtype=np.float32) |
| n_raw = len(raw) |
|
|
| if n_raw >= patch: |
| n = (n_raw // patch) * patch |
| arr = raw[-n:] |
| mask_vec = np.ones(n, dtype=bool) |
| else: |
| n = patch |
| pad = n - n_raw |
| arr = np.concatenate([np.full(pad, raw[0], dtype=np.float32), raw]) |
| mask_vec = np.concatenate([np.zeros(pad, dtype=bool), np.ones(n_raw, dtype=bool)]) |
|
|
| target = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0) |
| target_mask = torch.from_numpy(mask_vec).unsqueeze(0).unsqueeze(0) |
| series_ids = torch.zeros(1, 1, dtype=torch.long) |
|
|
| target = target.to(device) |
| target_mask = target_mask.to(device) |
| series_ids = series_ids.to(device) |
|
|
| with torch.no_grad(): |
| quantiles = model.forecast( |
| {"target": target, "target_mask": target_mask, "series_ids": series_ids}, |
| horizon=horizon, |
| ) |
| |
| q = quantiles.detach().cpu().numpy() |
| p10 = q[Q10_IDX, 0, 0] |
| p50 = q[Q50_IDX, 0, 0] |
| p90 = q[Q90_IDX, 0, 0] |
|
|
| freq = _series_freq(clean) |
| last_ts = clean.index[-1] |
| future_idx = pd.date_range(start=last_ts + freq, periods=horizon, freq=freq, tz=last_ts.tz) |
|
|
| return TotoForecast( |
| median=pd.Series(p50, index=future_idx, name=f"{series.name}_median"), |
| p10=pd.Series(p10, index=future_idx, name=f"{series.name}_p10"), |
| p90=pd.Series(p90, index=future_idx, name=f"{series.name}_p90"), |
| ) |
|
|