""" Chronos zero-shot rolling-origin inference. Loads a Chronos pipeline from HuggingFace (uses HF_HOME cache so no network is needed once the checkpoint is local), runs L=512 → H=30 inference per channel batched across channels, and reduces the 20 sample paths to the per-day median for point-forecast metrics. """ from __future__ import annotations import sys from pathlib import Path import time import numpy as np import torch from chronos import ChronosPipeline def load_chronos(model_id: str, device: str | None = None, dtype: torch.dtype | None = None) -> ChronosPipeline: """Load Chronos pipeline; auto-detect device/dtype if not given.""" if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" if dtype is None: dtype = (torch.bfloat16 if device == "cuda" else torch.float32) print(f" loading {model_id} on {device} ({dtype})", file=sys.stderr) return ChronosPipeline.from_pretrained( model_id, device_map=device, torch_dtype=dtype, ) def predict_rolling_origin( pipe: ChronosPipeline, D: np.ndarray, # (T, n_items) window_starts: list[int], # rolling-origin t values; forecast [t, t+H) L: int = 512, H: int = 30, num_samples: int = 20, channel_batch: int = 16, ) -> np.ndarray: """Returns y_pred of shape (n_windows, H, n_items), point=median.""" n_windows = len(window_starts) n_items = D.shape[1] y_pred = np.zeros((n_windows, H, n_items), dtype=np.float32) t0 = time.time() for wi, t in enumerate(window_starts): ctx = D[t - L:t, :] # (L, n_items) # Predict all channels for this window in batches. for j0 in range(0, n_items, channel_batch): j1 = min(j0 + channel_batch, n_items) # ChronosPipeline.predict expects a list of 1-D tensors. ctxs = [torch.tensor(ctx[:, j], dtype=torch.float32) for j in range(j0, j1)] samples = pipe.predict( ctxs, prediction_length=H, num_samples=num_samples, limit_prediction_length=False, ) # samples: (batch, num_samples, H) -- median over samples samples = samples.cpu().to(torch.float32).numpy() med = np.median(samples, axis=1) # (batch, H) y_pred[wi, :, j0:j1] = med.T # (H, batch) if wi == 0 or (wi + 1) % 10 == 0 or wi + 1 == n_windows: elapsed = time.time() - t0 rate = (wi + 1) / max(elapsed, 1e-9) eta = (n_windows - wi - 1) / max(rate, 1e-9) print(f" window {wi+1:4d}/{n_windows} " f"elapsed={elapsed:6.1f}s " f"rate={rate:5.2f} win/s " f"eta={eta/60:5.1f}min", file=sys.stderr) return y_pred def collect_y_true(D: np.ndarray, window_starts: list[int], H: int) -> np.ndarray: n_windows = len(window_starts) n_items = D.shape[1] y_true = np.zeros((n_windows, H, n_items), dtype=np.float32) for wi, t in enumerate(window_starts): y_true[wi] = D[t:t + H] return y_true