# ISOMORPH-demo / uq /plot_uq_envelope.py
# HyeminGu
# initial demo
# ec7be33
"""§4 UQ forecast-envelope figure: parameter UQ propagated to forecaster output.

For each model with K=20 LHS perturbation tensors, plots the median and the
10th-90th percentile band of y_true (input) and y_pred (output) across K, at
one or three forecast windows for one item. The grey band is the band of
physical realisations the network produces under demand-side parameter
perturbation; the coloured band is the band of zero-shot forecasts of those
realisations.

Usage:
    python plot_uq_envelope.py                     # 2x2, default window (w=12)
    python plot_uq_envelope.py --multi             # 3x4 multi-window grid
    python plot_uq_envelope.py --window 25         # explicit single window
    python plot_uq_envelope.py --narrowest         # window with the smallest cross-K spread
    python plot_uq_envelope.py --no-zoom           # 2x2 without the zoom panels
    python plot_uq_envelope.py --item I05 chronos  # other item, subset of models
"""
from __future__ import annotations
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
# Parent of the directory that contains this file.
REPO = Path(__file__).resolve().parents[1]
# Directory that load_K_tensors reads the per-perturbation ``*_tensors.npz``
# files from.
RESULT_DIR = REPO / "results" / "eval" / "uq"
# Directory that main() creates on demand and writes the figures into.
FIG_DIR = REPO / "results" / "uq" / "figures"
# Per-model plotting configuration, keyed by the model name accepted on the
# command line:
#   prefix  - file-name prefix of that model's tensor files
#   fill    - colour of the forecast percentile band
#   line    - colour of the forecast median line
#   display - text used as the panel title
MODELS = {
"chronos": {"prefix": "chronos_t5_base",
"fill": "#9BB0CC", "line": "#2F4A75",
"display": "Chronos"},
"moirai": {"prefix": "moirai_1_1_R_base",
"fill": "#A8BFA0", "line": "#345531",
"display": "Moirai"},
"timesfm": {"prefix": "timesfm_2_0_500m_pytorch",
"fill": "#D7A992", "line": "#693220",
"display": "TimesFM"},
"lagllama": {"prefix": "lag_llama",
"fill": "#C2A8CC", "line": "#4D2752",
"display": "Lag-Llama"},
}
# Colours of the y_true percentile band and of its median line.
TRUTH_BAND_COLOR = "#9CA3AF"
TRUTH_LINE_COLOR = "#374151"
# Axes (and legend) background colour.
AXES_FACE = "#FFFFFF"
# Colour shared by spines, ticks, grid lines, labels and titles.
SPINE_COLOR = "#374151"
def load_K_tensors(prefix: str, K: int = 20, result_dir: "Path | None" = None):
    """Load and stack the K perturbation tensor files saved for one model.

    The files ``{prefix}_perturb_k01_tensors.npz`` ... ``k{K:02d}`` are read
    from *result_dir* (``RESULT_DIR`` when omitted).

    Args:
        prefix: File-name prefix identifying the model.
        K: Number of perturbation files to load, numbered from 1.
        result_dir: Directory holding the files; defaults to ``RESULT_DIR``.

    Returns:
        ``None`` if any of the K files is missing. Otherwise a tuple
        ``(y_pred, y_true, item_ids, window_starts)`` where ``y_pred`` and
        ``y_true`` are the per-file arrays stacked along a new leading axis
        and ``item_ids`` / ``window_starts`` come from the first file.

    Raises:
        ValueError: if ``K`` is smaller than 1.
    """
    if K < 1:
        raise ValueError(f"K must be at least 1, got {K}")
    base = RESULT_DIR if result_dir is None else Path(result_dir)
    paths = [base / f"{prefix}_perturb_k{k:02d}_tensors.npz"
             for k in range(1, K + 1)]
    # Check up front so an incomplete set is rejected before any file is read.
    if not all(p.exists() for p in paths):
        return None

    yp_list, yt_list = [], []
    item_ids = window_starts = None
    for p in paths:
        # NpzFile keeps the archive open until closed; the context manager
        # releases the handle once the arrays have been materialised.
        with np.load(p) as d:
            yp_list.append(d["y_pred"])
            yt_list.append(d["y_true"])
            if item_ids is None:
                item_ids, window_starts = d["item_ids"], d["window_starts"]
    y_pred = np.stack(yp_list, axis=0)  # (K, W, H, C)
    y_true = np.stack(yt_list, axis=0)
    return y_pred, y_true, item_ids, window_starts
def deterministic_windows(W: int, n: int) -> list[int]:
    """Return *n* evenly spaced window indices inside ``range(W)``.

    For ``n == 1`` the middle window ``W // 2`` is returned. Otherwise the
    indices sit at the fractions ``1/(n+1), ..., n/(n+1)`` of ``W``, which
    keeps them away from the edges of the test split.

    Args:
        W: Number of available windows.
        n: Number of window indices to pick.

    Returns:
        A list of ``n`` indices, each in ``[0, W - 1]``.

    Raises:
        ValueError: if ``W`` is smaller than 1.
    """
    if W < 1:
        raise ValueError(f"need at least one window, got W={W}")
    if n == 1:
        return [W // 2]
    last = W - 1
    # Rounding can land on W itself when W is small relative to n
    # (e.g. W=2, n=3 -> 1.5 rounds to 2), so clamp to the last valid index.
    return [min(int(round(W * (i + 1) / (n + 1))), last) for i in range(n)]
def draw_band(ax, arr_kh: np.ndarray, fill_color: str, line_color: str,
              label: str, fill_alpha: float, lw: float, z_base: int):
    """Draw the 10-90 percentile band and the median of *arr_kh* on *ax*.

    Percentiles and median are taken along axis 0 of *arr_kh*; the x values
    are ``1 .. arr_kh.shape[1]``. The band is filled at ``zorder=z_base`` and
    the labelled median line is drawn at ``zorder=z_base + 2``.
    """
    horizon = np.arange(1, arr_kh.shape[1] + 1)
    centre = np.median(arr_kh, axis=0)
    lower, upper = np.percentile(arr_kh, [10, 90], axis=0)
    ax.fill_between(
        horizon, lower, upper,
        color=fill_color, alpha=fill_alpha, linewidth=0, zorder=z_base,
    )
    ax.plot(
        horizon, centre,
        color=line_color, lw=lw, label=label, zorder=z_base + 2,
    )
def style_axes(ax):
    """Apply the shared look to *ax*: face colour, spines, ticks and grid."""
    ax.set_facecolor(AXES_FACE)
    for name in ("top", "right", "left", "bottom"):
        spine = ax.spines[name]
        spine.set_color(SPINE_COLOR)
        spine.set_linewidth(0.7)
    ax.tick_params(colors=SPINE_COLOR, length=3, width=0.7)
    # Recolour the tick labels on both axes explicitly.
    tick_labels = ax.get_xticklabels() + ax.get_yticklabels()
    for tick_label in tick_labels:
        tick_label.set_color(SPINE_COLOR)
    ax.grid(True, linestyle=':', alpha=0.35, linewidth=0.5, color=SPINE_COLOR)
def _draw_main(ax, yp_kh, yt_kh, model_key, draw_legend):
    """Draw the truth and forecast bands for one model on *ax*.

    The truth band (from *yt_kh*) is drawn first in the neutral truth
    colours, then the forecast band (from *yp_kh*) in the colours configured
    for *model_key*. The axes are styled afterwards and, when *draw_legend*
    is true, a framed legend is placed in the upper-left corner.
    """
    draw_band(
        ax, yt_kh,
        fill_color=TRUTH_BAND_COLOR,
        line_color=TRUTH_LINE_COLOR,
        label=r"truth $y_{i,t}$",
        fill_alpha=0.30, lw=1.0, z_base=1,
    )
    palette = MODELS[model_key]
    draw_band(
        ax, yp_kh,
        fill_color=palette["fill"],
        line_color=palette["line"],
        label=r"forecast $\hat y_{i,t}$",
        fill_alpha=0.30, lw=1.8, z_base=3,
    )
    style_axes(ax)
    if not draw_legend:
        return
    legend = ax.legend(loc="upper left", fontsize=8.5, frameon=True,
                       facecolor=AXES_FACE, edgecolor=SPINE_COLOR)
    legend.get_frame().set_linewidth(0.6)
    for entry in legend.get_texts():
        entry.set_color(SPINE_COLOR)
def _zoom_ylim(yp_kh, yt_kh, pad_frac: float = 0.08):
yt_med = np.median(yt_kh, axis=0)
yp_med = np.median(yp_kh, axis=0)
y_lo = float(min(yt_med.min(), yp_med.min()))
y_hi = float(max(yt_med.max(), yp_med.max()))
span = max(y_hi - y_lo, 1e-6)
pad = pad_frac * span
return y_lo - pad, y_hi + pad
def _draw_zoom(zoom_ax, main_ax, yp_kh, yt_kh, model_key):
    """Draw the zoom panel that belongs to `main_ax`.

    `zoom_ax` shows the same truth/forecast bands and medians, with its
    y-range set to the limits returned by `_zoom_ylim` for the two inputs.
    The same y-range is shaded and outlined with dotted lines on `main_ax`
    so the reader can see which slice of the main panel is enlarged.
    """
    steps = np.arange(1, yp_kh.shape[1] + 1)
    truth_mid = np.median(yt_kh, axis=0)
    pred_mid = np.median(yp_kh, axis=0)
    truth_lo, truth_hi = np.percentile(yt_kh, [10, 90], axis=0)
    pred_lo, pred_hi = np.percentile(yp_kh, [10, 90], axis=0)
    lower, upper = _zoom_ylim(yp_kh, yt_kh)

    # Truth first, then the model's forecast on top.
    zoom_ax.fill_between(steps, truth_lo, truth_hi, color=TRUTH_BAND_COLOR,
                         alpha=0.30, linewidth=0, zorder=1)
    zoom_ax.plot(steps, truth_mid, color=TRUTH_LINE_COLOR, lw=1.0, zorder=3)
    zoom_ax.fill_between(steps, pred_lo, pred_hi,
                         color=MODELS[model_key]["fill"],
                         alpha=0.30, linewidth=0, zorder=2)
    zoom_ax.plot(steps, pred_mid, color=MODELS[model_key]["line"],
                 lw=1.6, zorder=4)

    zoom_ax.set_xlim(int(steps[0]), int(steps[-1]))
    zoom_ax.set_ylim(lower, upper)
    style_axes(zoom_ax)
    zoom_ax.tick_params(axis="y", labelsize=8)

    # Mark the zoomed y-slice on the parent panel.
    main_ax.axhspan(lower, upper, color=SPINE_COLOR, alpha=0.10,
                    linewidth=0, zorder=0.5)
    for edge in (lower, upper):
        main_ax.axhline(edge, color=SPINE_COLOR, lw=0.5, ls=":",
                        alpha=0.7, zorder=0.6)
def plot_2x2(data: dict, item_id: str, item_idx: int, w: int,
             window_start: int, fig_path: Path, with_zoom: bool = True):
    """Render up to four models in a 2x2 grid for one item and one window.

    Args:
        data: Mapping ``model_key -> (y_pred, y_true)``; both arrays are
            indexed as ``[:, w, :, item_idx]``. At most four entries are
            drawn, filling the grid row by row.
        item_id: Item name used in the y-axis label.
        item_idx: Index of the item on the last array axis.
        w: Window index on the second array axis.
        window_start: Accepted for interface compatibility; not drawn.
        fig_path: Output path; a ``.png`` copy is written next to it.
        with_zoom: When true, each panel gets a zoom panel underneath.
    """
    fig = plt.figure(figsize=(9.6, 8.4 if with_zoom else 5.6))
    try:
        fig.patch.set_facecolor("white")
        outer = fig.add_gridspec(2, 2, hspace=0.30, wspace=0.18,
                                 left=0.07, right=0.99, top=0.94, bottom=0.07)
        placements = [(0, 0), (0, 1), (1, 0), (1, 1)]
        panels = list(zip(placements, data.items()))

        # Lowest occupied row in each column: that panel carries the x-label,
        # so a subset of models (fewer than four panels) is still labelled.
        bottom_row: dict = {}
        for (ri, ci), _ in panels:
            bottom_row[ci] = max(ri, bottom_row.get(ci, ri))

        for (ri, ci), (model_key, (yp, yt)) in panels:
            yp_kh = yp[:, w, :, item_idx]
            yt_kh = yt[:, w, :, item_idx]
            if with_zoom:
                inner = outer[ri, ci].subgridspec(
                    2, 1, height_ratios=[2.6, 1.9], hspace=0.06)
                main_ax = fig.add_subplot(inner[0])
                zoom_ax = fig.add_subplot(inner[1], sharex=main_ax)
            else:
                main_ax = fig.add_subplot(outer[ri, ci])
                zoom_ax = None

            _draw_main(main_ax, yp_kh, yt_kh, model_key,
                       draw_legend=(ri == 0 and ci == 0))
            main_ax.set_title(MODELS[model_key]["display"], fontsize=11,
                              color=SPINE_COLOR)
            if zoom_ax is not None:
                _draw_zoom(zoom_ax, main_ax, yp_kh, yt_kh, model_key)
                # The zoom panel shares x, so only it shows tick labels.
                plt.setp(main_ax.get_xticklabels(), visible=False)

            if ri == bottom_row[ci]:
                (zoom_ax if zoom_ax is not None else main_ax).set_xlabel(
                    r"forecast horizon $h$ (time units)", color=SPINE_COLOR)
            if ci == 0:
                main_ax.set_ylabel(f"item {item_id} demand",
                                   color=SPINE_COLOR)
                if zoom_ax is not None:
                    zoom_ax.set_ylabel("zoom (medians)", fontsize=8.5,
                                       color=SPINE_COLOR)

        fig.savefig(fig_path, bbox_inches="tight", facecolor="white")
        fig.savefig(fig_path.with_suffix(".png"), bbox_inches="tight",
                    dpi=160, facecolor="white")
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)
def plot_multi(data: dict, item_id: str, item_idx: int,
               windows: list[int], window_starts_arr: np.ndarray,
               fig_path: Path):
    """Render a grid with one row per window and one column per model.

    Each cell is a (main, zoom) vertical pair sharing x; the zoom panel uses
    tightened y-limits.

    Args:
        data: Mapping ``model_key -> (y_pred, y_true)``; both arrays are
            indexed as ``[:, w, :, item_idx]``.
        item_id: Item name used in the figure-level y label.
        item_idx: Index of the item on the last array axis.
        windows: Window indices, one per grid row.
        window_starts_arr: Indexed by window to get the ``t0`` shown in the
            row label.
        fig_path: Output path; a ``.png`` copy is written next to it.
    """
    n_rows = len(windows)
    n_cols = len(data)
    fig = plt.figure(figsize=(3.2 * n_cols, 4.0 * n_rows))
    try:
        fig.patch.set_facecolor("white")
        outer = fig.add_gridspec(n_rows, n_cols, hspace=0.32, wspace=0.20,
                                 left=0.06, right=0.99, top=0.95, bottom=0.06)
        model_items = list(data.items())
        for r, w in enumerate(windows):
            t0 = int(window_starts_arr[w])
            for c, (model_key, (yp, yt)) in enumerate(model_items):
                inner = outer[r, c].subgridspec(
                    2, 1, height_ratios=[2.6, 1.9], hspace=0.06)
                main_ax = fig.add_subplot(inner[0])
                zoom_ax = fig.add_subplot(inner[1], sharex=main_ax)
                yp_kh = yp[:, w, :, item_idx]
                yt_kh = yt[:, w, :, item_idx]
                _draw_main(main_ax, yp_kh, yt_kh, model_key,
                           draw_legend=(r == 0 and c == 0))
                if r == 0:
                    # Model names only on the top row.
                    main_ax.set_title(MODELS[model_key]["display"],
                                      fontsize=11, color=SPINE_COLOR)
                _draw_zoom(zoom_ax, main_ax, yp_kh, yt_kh, model_key)
                # The zoom panel shares x, so only it shows tick labels.
                plt.setp(main_ax.get_xticklabels(), visible=False)
                if c == 0:
                    main_ax.set_ylabel(f"$t_0{{=}}{t0}$", fontsize=10,
                                       color=SPINE_COLOR)
                    zoom_ax.set_ylabel("zoom", fontsize=8.5,
                                       color=SPINE_COLOR)
                if r == n_rows - 1:
                    zoom_ax.set_xlabel(r"forecast horizon $h$ (time units)",
                                       color=SPINE_COLOR)
        fig.text(0.005, 0.5, f"item {item_id} demand",
                 rotation="vertical", va="center", ha="left",
                 fontsize=10.5, color=SPINE_COLOR)
        fig.savefig(fig_path, bbox_inches="tight", facecolor="white")
        fig.savefig(fig_path.with_suffix(".png"), bbox_inches="tight",
                    dpi=160, facecolor="white")
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)
def _pop_flag(args: list, flag: str) -> bool:
    """Remove the first *flag* from *args* in place; report if it was there."""
    if flag not in args:
        return False
    args.remove(flag)
    return True


def _pop_value(args: list, flag: str) -> "str | None":
    """Remove the first ``flag VALUE`` pair from *args* and return VALUE.

    Returns ``None`` when *flag* is absent; exits with a usage message when
    the flag is the last argument and therefore has no value.
    """
    if flag not in args:
        return None
    i = args.index(flag)
    if i + 1 >= len(args):
        sys.exit(f"{flag} requires a value")
    value = args[i + 1]
    del args[i:i + 2]
    return value


def main():
    """Command-line entry point: load the tensors and write one figure.

    Recognised arguments (remaining arguments are model names from
    ``MODELS``; all models are used when none is given):
        --item ID      item to plot (default ``I01``)
        --window W     explicit window index for the 2x2 figure
        --multi        three-window grid instead of the 2x2 figure
        --narrowest    pick the window with the smallest cross-K spread
        --no-zoom      2x2 figure without zoom panels
    """
    args = sys.argv[1:]
    item_text = _pop_value(args, "--item")
    item = "I01" if item_text is None else item_text
    window_text = _pop_value(args, "--window")
    window_arg: int | None = None
    if window_text is not None:
        try:
            window_arg = int(window_text)
        except ValueError:
            sys.exit(f"--window expects an integer, got {window_text!r}")
    multi = _pop_flag(args, "--multi")
    narrowest = _pop_flag(args, "--narrowest")
    no_zoom = _pop_flag(args, "--no-zoom")

    requested = args if args else list(MODELS.keys())
    bad = [m for m in requested if m not in MODELS]
    if bad:
        sys.exit(f"unknown model(s): {bad}; choose from {list(MODELS.keys())}")
    FIG_DIR.mkdir(parents=True, exist_ok=True)

    data: dict = {}
    item_ids = window_starts = None
    for m in requested:
        loaded = load_K_tensors(MODELS[m]["prefix"])
        if loaded is None:
            print(f"[{m}] missing tensors, skipping")
            continue
        yp, yt, ids, ws = loaded
        if item_ids is None:
            # Item ids and window starts are taken from the first model loaded.
            item_ids, window_starts = ids, ws
        data[m] = (yp, yt)
        print(f"[{m}] y_pred={yp.shape}, y_true={yt.shape}")
    if not data:
        sys.exit("no models loaded")

    try:
        item_idx = list(item_ids).index(item)
    except ValueError:
        sys.exit(f"unknown item {item!r}; "
                 f"available: {[str(i) for i in item_ids]}")

    # Use the smallest window count so every chosen index is valid for all
    # loaded models.
    W_total = min(yp.shape[1] for yp, _ in data.values())
    if W_total < 1:
        sys.exit("loaded tensors contain no forecast windows")

    if multi:
        windows = deterministic_windows(W_total, 3)
        print(f"item {item} (idx={item_idx}); multi-window grid w={windows} "
              f"(t0={[int(window_starts[w]) for w in windows]})")
        fig_path = FIG_DIR / f"uq_envelope_{item}_multi.pdf"
        plot_multi(data, item, item_idx, windows,
                   np.asarray(window_starts), fig_path)
        print(f"wrote {fig_path}")
        return

    if window_arg is not None:
        # Negative indices keep their usual "from the end" meaning.
        if not -W_total <= window_arg < W_total:
            sys.exit(f"--window {window_arg} out of range; "
                     f"valid indices are {-W_total}..{W_total - 1}")
        w = window_arg
    elif narrowest:
        yt_any = next(iter(data.values()))[1]
        win_mean = yt_any[:, :W_total, :, item_idx].mean(axis=-1)  # (K, W)
        spread = win_mean.std(axis=0)  # (W,)
        w = int(spread.argmin())
        print(f"min cross-K spread window: w={w}, "
              f"spread={spread[w]:.2f} "
              f"(min={spread.min():.2f}, max={spread.max():.2f})")
    else:
        # Window 12 is the established default; fall back to the middle
        # window when the data has too few windows for it.
        w = 12 if W_total > 12 else deterministic_windows(W_total, 1)[0]
        print(f"default window: w={w}")

    t0 = int(window_starts[w])
    print(f"item {item} (idx={item_idx}); window w={w}, t_start={t0}")
    if window_arg is not None:
        suffix = f"{item}_w{w:02d}"
    elif narrowest:
        suffix = f"{item}_narrowest"
    else:
        suffix = item
    if no_zoom:
        suffix = f"{suffix}_nozoom"
    fig_path = FIG_DIR / f"uq_envelope_{suffix}.pdf"
    plot_2x2(data, item, item_idx, w, t0, fig_path, with_zoom=not no_zoom)
    print(f"wrote {fig_path}")


if __name__ == "__main__":
    main()