ISOMORPH-demo / eval /moirai_runner.py
HyeminGu
initial demo
ec7be33
"""
Moirai zero-shot rolling-origin inference.
Wraps Salesforce moirai (uni2ts.model.moirai) to forecast our supply-chain
demand release. Multivariate-native: a single forward predicts all C items.
Designed to mirror chronos_runner.py's I/O contract so metrics.py can stay
unchanged: produces y_pred of shape (n_windows, H, n_items) with point
forecast = median over num_samples draws.
"""
from __future__ import annotations
import sys
import time
import numpy as np
import pandas as pd
import torch
from gluonts.dataset.multivariate_grouper import MultivariateGrouper
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split as gluonts_split
from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
def load_moirai(model_id: str, prediction_length: int, context_length: int,
                target_dim: int, num_samples: int = 100,
                patch_size: int | str = 32, batch_size: int = 32,
                device: str | None = None):
    """Build a Moirai predictor from a pretrained checkpoint.

    Args:
        model_id: Checkpoint identifier handed to
            ``MoiraiModule.from_pretrained``.
        prediction_length: Forecast horizon (H) given to ``MoiraiForecast``.
        context_length: History length (L) given to ``MoiraiForecast``.
        target_dim: Number of target series forecast jointly.
        num_samples: Number of sample draws per forecast.
        patch_size: Patch size forwarded unchanged to ``MoiraiForecast``.
        batch_size: Batch size used by the created predictor.
        device: Device string for inference. When ``None``, ``"cuda"`` is
            chosen if ``torch.cuda.is_available()``, otherwise ``"cpu"``.

    Returns:
        The predictor returned by ``MoiraiForecast.create_predictor``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f" loading {model_id} on {device} "
          f"(L={context_length}, H={prediction_length}, "
          f"target_dim={target_dim}, patch_size={patch_size}, "
          f"num_samples={num_samples})", file=sys.stderr)

    module = MoiraiModule.from_pretrained(model_id)
    model = MoiraiForecast(
        module=module,
        prediction_length=prediction_length,
        context_length=context_length,
        patch_size=patch_size,
        num_samples=num_samples,
        target_dim=target_dim,
        feat_dynamic_real_dim=0,
        past_feat_dynamic_real_dim=0,
    )
    # Forward the resolved device so that an explicit `device` argument is
    # honoured and the log line above matches where inference actually runs.
    return model.create_predictor(batch_size=batch_size, device=device)
def build_test_data(D: np.ndarray, item_ids: list[str],
val_end: int, H: int, stride: int,
max_windows: int | None = None):
"""Build a GluonTS multivariate test dataset for rolling-origin eval.
The test region is [val_end, T). Window i has its forecast horizon at
[val_end + i*stride, val_end + i*stride + H), with full-history input
(the model itself trims to the last `context_length` steps).
"""
T, n_items = D.shape
test_len = T - val_end
# Wide DataFrame: each column is one item; daily frequency.
df = pd.DataFrame(
D,
columns=item_ids,
index=pd.date_range("2000-01-01", periods=T, freq="D"),
)
ds = PandasDataset(dict(df))
grouper = MultivariateGrouper(len(ds))
multivar_ds = grouper(ds)
train, test_template = gluonts_split(multivar_ds, offset=-test_len)
n_windows = (test_len - H) // stride + 1
if max_windows is not None:
n_windows = min(n_windows, max_windows)
test_data = test_template.generate_instances(
prediction_length=H,
windows=n_windows,
distance=stride,
)
return test_data, n_windows
def predict_rolling_origin(predictor, test_data, n_windows: int,
                           H: int, n_items: int) -> np.ndarray:
    """Run the predictor over all windows and return median point forecasts.

    Args:
        predictor: Object whose ``predict(inputs)`` yields one forecast per
            window; each forecast exposes a ``samples`` array.
        test_data: Object whose ``input`` attribute is passed to
            ``predictor.predict``.
        n_windows: Expected number of forecast windows.
        H: Prediction length of each window.
        n_items: Number of target series per window.

    Returns:
        float32 array of shape (n_windows, H, n_items); each entry is the
        median over the sample axis of the corresponding forecast.

    Raises:
        ValueError: If the predictor yields fewer than ``n_windows``
            forecasts (the remaining rows would otherwise stay all-zero).
        IndexError: If the predictor yields more than ``n_windows``
            forecasts.
    """
    y_pred = np.zeros((n_windows, H, n_items), dtype=np.float32)
    # Monotonic clock: progress timing must not jump with wall-clock changes.
    t0 = time.perf_counter()
    forecasts = predictor.predict(test_data.input)
    n_done = 0
    for wi, fc in enumerate(forecasts):
        # fc.samples shape: (num_samples, H, target_dim)
        samples = np.asarray(fc.samples, dtype=np.float32)
        med = np.median(samples, axis=0)  # (H, target_dim)
        if med.ndim == 1:
            # Samples without a target axis, (num_samples, H): add the axis
            # so the assignment below fits the (H, n_items) slot.
            med = med[:, np.newaxis]
        y_pred[wi] = med
        n_done = wi + 1
        if wi == 0 or (wi + 1) % 10 == 0 or wi + 1 == n_windows:
            elapsed = time.perf_counter() - t0
            rate = (wi + 1) / max(elapsed, 1e-9)
            eta = (n_windows - wi - 1) / max(rate, 1e-9)
            print(f" window {wi+1:4d}/{n_windows} "
                  f"elapsed={elapsed:6.1f}s "
                  f"rate={rate:5.2f} win/s "
                  f"eta={eta/60:5.1f}min", file=sys.stderr)
    if n_done != n_windows:
        # Fail loudly instead of returning zero-filled rows as forecasts.
        raise ValueError(
            f"predictor produced {n_done} forecasts, expected {n_windows}")
    return y_pred
def collect_y_true(test_data, n_windows: int, H: int,
                   n_items: int) -> np.ndarray:
    """Gather the ground-truth horizon of every window.

    Iterates ``test_data.label`` and stores each entry's ``"target"``,
    transposed, into a zero-initialised float32 array of shape
    (n_windows, H, n_items). Windows for which no label is yielded keep
    their zero fill.
    """
    truth = np.zeros(shape=(n_windows, H, n_items), dtype=np.float32)
    slot = 0
    for entry in test_data.label:
        # entry["target"] is laid out (target_dim, H); flip to (H, target_dim).
        horizon = np.asarray(entry["target"], dtype=np.float32)
        truth[slot] = np.transpose(horizon)
        slot += 1
    return truth