Spaces:
Running
Running
| """ | |
| Moirai zero-shot rolling-origin inference. | |
| Wraps Salesforce moirai (uni2ts.model.moirai) to forecast our supply-chain | |
| demand release. Multivariate-native: a single forward predicts all C items. | |
| Designed to mirror chronos_runner.py's I/O contract so metrics.py can stay | |
| unchanged: produces y_pred of shape (n_windows, H, n_items) with point | |
| forecast = median over num_samples draws. | |
| """ | |
| from __future__ import annotations | |
| import sys | |
| import time | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from gluonts.dataset.multivariate_grouper import MultivariateGrouper | |
| from gluonts.dataset.pandas import PandasDataset | |
| from gluonts.dataset.split import split as gluonts_split | |
| from uni2ts.model.moirai import MoiraiForecast, MoiraiModule | |
| def load_moirai(model_id: str, prediction_length: int, context_length: int, | |
| target_dim: int, num_samples: int = 100, | |
| patch_size: int | str = 32, batch_size: int = 32, | |
| device: str | None = None): | |
| """Build a Moirai predictor from a pretrained checkpoint.""" | |
| if device is None: | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f" loading {model_id} on {device} " | |
| f"(L={context_length}, H={prediction_length}, " | |
| f"target_dim={target_dim}, patch_size={patch_size}, " | |
| f"num_samples={num_samples})", file=sys.stderr) | |
| module = MoiraiModule.from_pretrained(model_id) | |
| model = MoiraiForecast( | |
| module=module, | |
| prediction_length=prediction_length, | |
| context_length=context_length, | |
| patch_size=patch_size, | |
| num_samples=num_samples, | |
| target_dim=target_dim, | |
| feat_dynamic_real_dim=0, | |
| past_feat_dynamic_real_dim=0, | |
| ) | |
| predictor = model.create_predictor(batch_size=batch_size) | |
| return predictor | |
| def build_test_data(D: np.ndarray, item_ids: list[str], | |
| val_end: int, H: int, stride: int, | |
| max_windows: int | None = None): | |
| """Build a GluonTS multivariate test dataset for rolling-origin eval. | |
| The test region is [val_end, T). Window i has its forecast horizon at | |
| [val_end + i*stride, val_end + i*stride + H), with full-history input | |
| (the model itself trims to the last `context_length` steps). | |
| """ | |
| T, n_items = D.shape | |
| test_len = T - val_end | |
| # Wide DataFrame: each column is one item; daily frequency. | |
| df = pd.DataFrame( | |
| D, | |
| columns=item_ids, | |
| index=pd.date_range("2000-01-01", periods=T, freq="D"), | |
| ) | |
| ds = PandasDataset(dict(df)) | |
| grouper = MultivariateGrouper(len(ds)) | |
| multivar_ds = grouper(ds) | |
| train, test_template = gluonts_split(multivar_ds, offset=-test_len) | |
| n_windows = (test_len - H) // stride + 1 | |
| if max_windows is not None: | |
| n_windows = min(n_windows, max_windows) | |
| test_data = test_template.generate_instances( | |
| prediction_length=H, | |
| windows=n_windows, | |
| distance=stride, | |
| ) | |
| return test_data, n_windows | |
| def predict_rolling_origin(predictor, test_data, n_windows: int, | |
| H: int, n_items: int) -> np.ndarray: | |
| """Returns y_pred of shape (n_windows, H, n_items); point=median samples.""" | |
| y_pred = np.zeros((n_windows, H, n_items), dtype=np.float32) | |
| t0 = time.time() | |
| forecasts = predictor.predict(test_data.input) | |
| for wi, fc in enumerate(forecasts): | |
| # fc.samples shape: (num_samples, H, target_dim) | |
| samples = np.asarray(fc.samples, dtype=np.float32) | |
| med = np.median(samples, axis=0) # (H, target_dim) | |
| y_pred[wi] = med | |
| if wi == 0 or (wi + 1) % 10 == 0 or wi + 1 == n_windows: | |
| elapsed = time.time() - t0 | |
| rate = (wi + 1) / max(elapsed, 1e-9) | |
| eta = (n_windows - wi - 1) / max(rate, 1e-9) | |
| print(f" window {wi+1:4d}/{n_windows} " | |
| f"elapsed={elapsed:6.1f}s " | |
| f"rate={rate:5.2f} win/s " | |
| f"eta={eta/60:5.1f}min", file=sys.stderr) | |
| return y_pred | |
| def collect_y_true(test_data, n_windows: int, H: int, | |
| n_items: int) -> np.ndarray: | |
| """Extract ground truth labels: shape (n_windows, H, n_items).""" | |
| y_true = np.zeros((n_windows, H, n_items), dtype=np.float32) | |
| for wi, lbl in enumerate(test_data.label): | |
| # lbl["target"] shape: (target_dim, H) | |
| y_true[wi] = np.asarray(lbl["target"], dtype=np.float32).T | |
| return y_true | |