"""ยง4 UQ forecast-envelope figure: parameter UQ propagated to forecaster output. For each model with K=20 LHS perturbation tensors, plots median + 10/90 band of y_true (input) and y_pred (output) across K, at one or three forecast windows for one item. The grey band is the band of physical realisations the network produces under demand-side parameter perturbation; the coloured band is the band of zero-shot forecasts of those realisations. python plot_uq_envelope.py # 2x2, deterministic mid window python plot_uq_envelope.py --multi # 3x4 multi-window grid python plot_uq_envelope.py --window 25 # explicit single window python plot_uq_envelope.py --item I05 chronos # subset of models """ from __future__ import annotations import sys from pathlib import Path import numpy as np import matplotlib.pyplot as plt REPO = Path(__file__).resolve().parents[1] RESULT_DIR = REPO / "results" / "eval" / "uq" FIG_DIR = REPO / "results" / "uq" / "figures" MODELS = { "chronos": {"prefix": "chronos_t5_base", "fill": "#9BB0CC", "line": "#2F4A75", "display": "Chronos"}, "moirai": {"prefix": "moirai_1_1_R_base", "fill": "#A8BFA0", "line": "#345531", "display": "Moirai"}, "timesfm": {"prefix": "timesfm_2_0_500m_pytorch", "fill": "#D7A992", "line": "#693220", "display": "TimesFM"}, "lagllama": {"prefix": "lag_llama", "fill": "#C2A8CC", "line": "#4D2752", "display": "Lag-Llama"}, } TRUTH_BAND_COLOR = "#9CA3AF" TRUTH_LINE_COLOR = "#374151" AXES_FACE = "#FFFFFF" SPINE_COLOR = "#374151" def load_K_tensors(prefix: str, K: int = 20): yp_list, yt_list = [], [] item_ids = window_starts = None for k in range(1, K + 1): p = RESULT_DIR / f"{prefix}_perturb_k{k:02d}_tensors.npz" if not p.exists(): return None d = np.load(p) yp_list.append(d["y_pred"]) yt_list.append(d["y_true"]) if item_ids is None: item_ids, window_starts = d["item_ids"], d["window_starts"] y_pred = np.stack(yp_list, axis=0) # (K, W, H, C) y_true = np.stack(yt_list, axis=0) return y_pred, y_true, item_ids, window_starts def deterministic_windows(W: int, n: int) -> list[int]: """Evenly-spaced windows inside the test split, avoiding the edges.""" if n == 1: return [W // 2] return [int(round(W * (i + 1) / (n + 1))) for i in range(n)] def draw_band(ax, arr_kh: np.ndarray, fill_color: str, line_color: str, label: str, fill_alpha: float, lw: float, z_base: int): h = np.arange(1, arr_kh.shape[1] + 1) med = np.median(arr_kh, axis=0) q10 = np.percentile(arr_kh, 10, axis=0) q90 = np.percentile(arr_kh, 90, axis=0) ax.fill_between(h, q10, q90, color=fill_color, alpha=fill_alpha, linewidth=0, zorder=z_base) ax.plot(h, med, color=line_color, lw=lw, label=label, zorder=z_base + 2) def style_axes(ax): ax.set_facecolor(AXES_FACE) for side in ("top", "right", "left", "bottom"): ax.spines[side].set_color(SPINE_COLOR) ax.spines[side].set_linewidth(0.7) ax.tick_params(colors=SPINE_COLOR, length=3, width=0.7) for label in ax.get_xticklabels() + ax.get_yticklabels(): label.set_color(SPINE_COLOR) ax.grid(True, linestyle=':', alpha=0.35, linewidth=0.5, color=SPINE_COLOR) def _draw_main(ax, yp_kh, yt_kh, model_key, draw_legend): draw_band(ax, yt_kh, fill_color=TRUTH_BAND_COLOR, line_color=TRUTH_LINE_COLOR, label=r"truth $y_{i,t}$", fill_alpha=0.30, lw=1.0, z_base=1) draw_band(ax, yp_kh, fill_color=MODELS[model_key]["fill"], line_color=MODELS[model_key]["line"], label=r"forecast $\hat y_{i,t}$", fill_alpha=0.30, lw=1.8, z_base=3) style_axes(ax) if draw_legend: leg = ax.legend(loc="upper left", fontsize=8.5, frameon=True, facecolor=AXES_FACE, edgecolor=SPINE_COLOR) leg.get_frame().set_linewidth(0.6) for txt in leg.get_texts(): txt.set_color(SPINE_COLOR) def _zoom_ylim(yp_kh, yt_kh, pad_frac: float = 0.08): yt_med = np.median(yt_kh, axis=0) yp_med = np.median(yp_kh, axis=0) y_lo = float(min(yt_med.min(), yp_med.min())) y_hi = float(max(yt_med.max(), yp_med.max())) span = max(y_hi - y_lo, 1e-6) pad = pad_frac * span return y_lo - pad, y_hi + pad def _draw_zoom(zoom_ax, main_ax, yp_kh, yt_kh, model_key): """Sibling axes below `main_ax` with the same bands and medians, but y-axis tightened to the median range. Also shades the corresponding horizontal slice on `main_ax` so the link is explicit. """ h = np.arange(1, yp_kh.shape[1] + 1) yt_med = np.median(yt_kh, axis=0) yp_med = np.median(yp_kh, axis=0) yt_q10 = np.percentile(yt_kh, 10, axis=0) yt_q90 = np.percentile(yt_kh, 90, axis=0) yp_q10 = np.percentile(yp_kh, 10, axis=0) yp_q90 = np.percentile(yp_kh, 90, axis=0) y_lo, y_hi = _zoom_ylim(yp_kh, yt_kh) zoom_ax.fill_between(h, yt_q10, yt_q90, color=TRUTH_BAND_COLOR, alpha=0.30, linewidth=0, zorder=1) zoom_ax.plot(h, yt_med, color=TRUTH_LINE_COLOR, lw=1.0, zorder=3) zoom_ax.fill_between(h, yp_q10, yp_q90, color=MODELS[model_key]["fill"], alpha=0.30, linewidth=0, zorder=2) zoom_ax.plot(h, yp_med, color=MODELS[model_key]["line"], lw=1.6, zorder=4) zoom_ax.set_xlim(int(h[0]), int(h[-1])) zoom_ax.set_ylim(y_lo, y_hi) style_axes(zoom_ax) zoom_ax.tick_params(axis="y", labelsize=8) # Mark the zoom y-slice on the parent so the reader sees exactly which # part of the main panel is being zoomed. main_ax.axhspan(y_lo, y_hi, color=SPINE_COLOR, alpha=0.10, linewidth=0, zorder=0.5) main_ax.axhline(y_lo, color=SPINE_COLOR, lw=0.5, ls=":", alpha=0.7, zorder=0.6) main_ax.axhline(y_hi, color=SPINE_COLOR, lw=0.5, ls=":", alpha=0.7, zorder=0.6) def plot_2x2(data: dict, item_id: str, item_idx: int, w: int, window_start: int, fig_path: Path, with_zoom: bool = True): fig_h = 8.4 if with_zoom else 5.6 fig = plt.figure(figsize=(9.6, fig_h)) fig.patch.set_facecolor("white") outer = fig.add_gridspec(2, 2, hspace=0.30, wspace=0.18, left=0.07, right=0.99, top=0.94, bottom=0.07) items = list(data.items()) placements = [(0, 0), (0, 1), (1, 0), (1, 1)] for (ri, ci), (model_key, (yp, yt)) in zip(placements, items): yp_kh = yp[:, w, :, item_idx] yt_kh = yt[:, w, :, item_idx] if with_zoom: inner = outer[ri, ci].subgridspec( 2, 1, height_ratios=[2.6, 1.9], hspace=0.06) main_ax = fig.add_subplot(inner[0]) zoom_ax = fig.add_subplot(inner[1], sharex=main_ax) else: main_ax = fig.add_subplot(outer[ri, ci]) zoom_ax = None _draw_main(main_ax, yp_kh, yt_kh, model_key, draw_legend=(ri == 0 and ci == 0)) main_ax.set_title(MODELS[model_key]["display"], fontsize=11, color=SPINE_COLOR) if zoom_ax is not None: _draw_zoom(zoom_ax, main_ax, yp_kh, yt_kh, model_key) plt.setp(main_ax.get_xticklabels(), visible=False) if ri == 1: (zoom_ax if zoom_ax is not None else main_ax).set_xlabel( r"forecast horizon $h$ (time units)", color=SPINE_COLOR) if ci == 0: main_ax.set_ylabel(f"item {item_id} demand", color=SPINE_COLOR) if zoom_ax is not None: zoom_ax.set_ylabel("zoom (medians)", fontsize=8.5, color=SPINE_COLOR) fig.savefig(fig_path, bbox_inches="tight", facecolor="white") fig.savefig(fig_path.with_suffix(".png"), bbox_inches="tight", dpi=160, facecolor="white") plt.close(fig) def plot_multi(data: dict, item_id: str, item_idx: int, windows: list[int], window_starts_arr: np.ndarray, fig_path: Path): """Grid: rows = windows, cols = models. Each cell is a (main, zoom) vertical pair sharing x; the zoom row uses tightened y-limits.""" n_rows = len(windows) n_cols = len(data) fig = plt.figure(figsize=(3.2 * n_cols, 4.0 * n_rows)) fig.patch.set_facecolor("white") outer = fig.add_gridspec(n_rows, n_cols, hspace=0.32, wspace=0.20, left=0.06, right=0.99, top=0.95, bottom=0.06) model_items = list(data.items()) for r, w in enumerate(windows): t0 = int(window_starts_arr[w]) for c, (model_key, (yp, yt)) in enumerate(model_items): inner = outer[r, c].subgridspec( 2, 1, height_ratios=[2.6, 1.9], hspace=0.06) main_ax = fig.add_subplot(inner[0]) zoom_ax = fig.add_subplot(inner[1], sharex=main_ax) yp_kh = yp[:, w, :, item_idx] yt_kh = yt[:, w, :, item_idx] _draw_main(main_ax, yp_kh, yt_kh, model_key, draw_legend=(r == 0 and c == 0)) if r == 0: main_ax.set_title(MODELS[model_key]["display"], fontsize=11, color=SPINE_COLOR) _draw_zoom(zoom_ax, main_ax, yp_kh, yt_kh, model_key) plt.setp(main_ax.get_xticklabels(), visible=False) if c == 0: main_ax.set_ylabel(f"$t_0{{=}}{t0}$", fontsize=10, color=SPINE_COLOR) zoom_ax.set_ylabel("zoom", fontsize=8.5, color=SPINE_COLOR) if r == n_rows - 1: zoom_ax.set_xlabel(r"forecast horizon $h$ (time units)", color=SPINE_COLOR) fig.text(0.005, 0.5, f"item {item_id} demand", rotation="vertical", va="center", ha="left", fontsize=10.5, color=SPINE_COLOR) fig.savefig(fig_path, bbox_inches="tight", facecolor="white") fig.savefig(fig_path.with_suffix(".png"), bbox_inches="tight", dpi=160, facecolor="white") plt.close(fig) def main(): args = sys.argv[1:] item = "I01" window_arg: int | None = None multi = False narrowest = False if "--item" in args: i = args.index("--item"); item = args[i + 1]; args = args[:i] + args[i + 2:] if "--window" in args: i = args.index("--window"); window_arg = int(args[i + 1]) args = args[:i] + args[i + 2:] if "--multi" in args: multi = True; args.remove("--multi") if "--narrowest" in args: narrowest = True; args.remove("--narrowest") no_zoom = False if "--no-zoom" in args: no_zoom = True; args.remove("--no-zoom") requested = args if args else list(MODELS.keys()) bad = [m for m in requested if m not in MODELS] if bad: sys.exit(f"unknown model(s): {bad}; choose from {list(MODELS.keys())}") FIG_DIR.mkdir(parents=True, exist_ok=True) data: dict = {} item_ids = window_starts = None for m in requested: out = load_K_tensors(MODELS[m]["prefix"]) if out is None: print(f"[{m}] missing tensors, skipping") continue yp, yt, ids, ws = out if item_ids is None: item_ids, window_starts = ids, ws data[m] = (yp, yt) print(f"[{m}] y_pred={yp.shape}, y_true={yt.shape}") if not data: sys.exit("no models loaded") item_idx = list(item_ids).index(item) W_total = next(iter(data.values()))[0].shape[1] if multi: windows = deterministic_windows(W_total, 3) print(f"item {item} (idx={item_idx}); multi-window grid w={windows} " f"(t0={[int(window_starts[w]) for w in windows]})") out = FIG_DIR / f"uq_envelope_{item}_multi.pdf" plot_multi(data, item, item_idx, windows, np.asarray(window_starts), out) print(f"wrote {out}") return if window_arg is None: if narrowest: yt_any = next(iter(data.values()))[1] win_mean = yt_any[:, :, :, item_idx].mean(axis=-1) # (K, W) spread = win_mean.std(axis=0) # (W,) w = int(spread.argmin()) print(f"min cross-K spread window: w={w}, " f"spread={spread[w]:.2f} " f"(min={spread.min():.2f}, max={spread.max():.2f})") else: w = 12 print(f"default window: w={w}") else: w = window_arg t0 = int(window_starts[w]) print(f"item {item} (idx={item_idx}); window w={w}, t_start={t0}") if window_arg is not None: suffix = f"{item}_w{w:02d}" elif narrowest: suffix = f"{item}_narrowest" else: suffix = item if no_zoom: suffix = f"{suffix}_nozoom" out = FIG_DIR / f"uq_envelope_{suffix}.pdf" plot_2x2(data, item, item_idx, w, t0, out, with_zoom=not no_zoom) print(f"wrote {out}") if __name__ == "__main__": main()