"""Visualisation helpers for the deployment surface. The Gradio demo overlays a reference 24-hour day on top of the generated scenario ensemble — these helpers do the loading and the plotting. Both are pure functions so the plot logic can be unit-tested without spinning up a Gradio session. """ from __future__ import annotations from pathlib import Path import matplotlib.figure import matplotlib.pyplot as plt import numpy as np import pandas as pd from pc_ddpm_alberta.config import DATA_DIR DEMO_DIR = DATA_DIR / "demo" REFERENCE_DAYS: dict[str, str] = { "typical_day": "typical_day.csv", "high_wind_day": "high_wind_day.csv", "low_wind_day": "low_wind_day.csv", } CHANNEL_LABELS: tuple[str, str, str] = ("Wind (MW)", "Solar (MW)", "Load (MW)") CHANNEL_COLUMNS: tuple[str, str, str] = ("wind_mw", "solar_mw", "load_mw") def load_reference_day(name: str, demo_dir: Path = DEMO_DIR) -> pd.DataFrame: """Load a 24-hour reference day by short name (`typical_day` etc.). Returns a DataFrame with a parsed `DatetimeIndex` of length 24 and the full upstream column schema (`wind_mw, solar_mw, load_mw, ...`). """ if name not in REFERENCE_DAYS: raise ValueError( f"unknown reference day {name!r}; expected one of {list(REFERENCE_DAYS)}" ) path = demo_dir / REFERENCE_DAYS[name] df = pd.read_csv(path, index_col=0, parse_dates=True) if len(df) != 24: raise ValueError(f"{path}: expected 24 rows, got {len(df)}") return df def build_overlay_figure( scenarios: np.ndarray, reference: pd.DataFrame, title: str | None = None, ) -> matplotlib.figure.Figure: """Three-panel plot of generated scenarios overlaid on a reference day. `scenarios` is the `(n_scenarios, 3, 24)` array returned by `predict()`; channel order is (wind, solar, load) per the upstream training script. The reference day is drawn in red; scenarios are translucent grey so the ensemble spread reads at a glance. """ if scenarios.ndim != 3 or scenarios.shape[1:] != (3, 24): raise ValueError( f"scenarios must be (n, 3, 24); got {scenarios.shape}" ) hours = np.arange(24) fig, axes = plt.subplots(3, 1, figsize=(9, 7), sharex=True) n_scenarios = scenarios.shape[0] alpha = max(0.05, min(0.4, 4.0 / n_scenarios)) for ch, (ax, label, col) in enumerate(zip(axes, CHANNEL_LABELS, CHANNEL_COLUMNS, strict=True)): ax.plot(hours, scenarios[:, ch, :].T, color="0.4", alpha=alpha, linewidth=0.8) ax.plot( hours, reference[col].to_numpy(), color="crimson", linewidth=2.0, label="reference day", ) ax.set_ylabel(label) ax.grid(True, alpha=0.3) ax.set_xlim(0, 23) axes[0].legend( [f"{n_scenarios} generated scenarios", "reference day"], loc="upper right", fontsize=9, ) axes[-1].set_xlabel("Hour of day") if title is None: ref_date = reference.index[0].date() if hasattr(reference.index[0], "date") else "—" title = f"PC-DDPM scenarios vs reference day ({ref_date})" fig.suptitle(title, fontsize=12) fig.tight_layout() return fig