jbobym's picture
space deploy: trim short_description to fit HF 60-char cap
93ed35a
"""Visualisation helpers for the deployment surface.
The Gradio demo overlays a reference 24-hour day on top of the generated
scenario ensemble — these helpers do the loading and the plotting. Both
are pure functions so the plot logic can be unit-tested without spinning
up a Gradio session.
"""
from __future__ import annotations
from pathlib import Path
import matplotlib.figure
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pc_ddpm_alberta.config import DATA_DIR
# Reference-day CSVs for the demo are read from this directory.
DEMO_DIR = DATA_DIR / "demo"

# Short reference-day name -> CSV filename, resolved relative to DEMO_DIR.
REFERENCE_DAYS: dict[str, str] = dict(
    typical_day="typical_day.csv",
    high_wind_day="high_wind_day.csv",
    low_wind_day="low_wind_day.csv",
)

# Per-channel y-axis labels and the matching reference-day CSV columns.
# Both tuples are in (wind, solar, load) order, so index i of each refers
# to the same channel.
CHANNEL_LABELS: tuple[str, str, str] = ("Wind (MW)", "Solar (MW)", "Load (MW)")
CHANNEL_COLUMNS: tuple[str, str, str] = ("wind_mw", "solar_mw", "load_mw")
def load_reference_day(name: str, demo_dir: Path | str = DEMO_DIR) -> pd.DataFrame:
    """Load a 24-hour reference day by short name (`typical_day` etc.).

    Args:
        name: One of the keys of `REFERENCE_DAYS`.
        demo_dir: Directory holding the reference-day CSVs. A plain string
            is accepted and converted to a `Path`.

    Returns:
        The CSV as a DataFrame with exactly 24 rows. The first column becomes
        the index and is passed through pandas' `parse_dates=True`, so it is
        a `DatetimeIndex` when pandas can parse it. All columns of the file
        are kept; the plotting helper expects at least
        `wind_mw, solar_mw, load_mw`.

    Raises:
        ValueError: If `name` is not a known reference day, or the file does
            not contain exactly 24 rows.
    """
    if name not in REFERENCE_DAYS:
        raise ValueError(
            f"unknown reference day {name!r}; expected one of {list(REFERENCE_DAYS)}"
        )
    # Accept plain string paths (e.g. from CLI args or env vars). Leave any
    # other path object untouched so existing Path-like callers keep working.
    if isinstance(demo_dir, str):
        demo_dir = Path(demo_dir)
    path = demo_dir / REFERENCE_DAYS[name]
    df = pd.read_csv(path, index_col=0, parse_dates=True)
    if len(df) != 24:
        raise ValueError(f"{path}: expected 24 rows, got {len(df)}")
    return df
def build_overlay_figure(
    scenarios: np.ndarray,
    reference: pd.DataFrame,
    title: str | None = None,
) -> matplotlib.figure.Figure:
    """Three-panel plot of generated scenarios overlaid on a reference day.

    `scenarios` is the `(n_scenarios, 3, 24)` array returned by `predict()`;
    channel order is (wind, solar, load) per the upstream training script.
    Any array-like of that shape is accepted. `n_scenarios` may be 0, in
    which case only the reference day is drawn.

    `reference` must have 24 rows and contain the columns in
    `CHANNEL_COLUMNS`.

    The reference day is drawn in red; scenarios are translucent grey so the
    ensemble spread reads at a glance.

    The figure is built from `matplotlib.figure.Figure` directly rather than
    via pyplot. It is therefore not registered with pyplot's global figure
    manager, and repeated calls from a long-running server do not accumulate
    open figures.

    Raises:
        ValueError: If `scenarios` is not shaped `(n, 3, 24)` or `reference`
            does not have 24 rows.
    """
    from matplotlib.lines import Line2D

    scenarios = np.asarray(scenarios)
    if scenarios.ndim != 3 or scenarios.shape[1:] != (3, 24):
        raise ValueError(
            f"scenarios must be (n, 3, 24); got {scenarios.shape}"
        )
    if len(reference) != 24:
        raise ValueError(f"reference must have 24 rows; got {len(reference)}")

    hours = np.arange(24)
    n_scenarios = scenarios.shape[0]
    # Fade individual traces as the ensemble grows so a dense ensemble still
    # reads as a band. Guard n == 0, which would otherwise divide by zero.
    alpha = max(0.05, min(0.4, 4.0 / n_scenarios)) if n_scenarios else 0.4

    fig = matplotlib.figure.Figure(figsize=(9, 7))
    axes = fig.subplots(3, 1, sharex=True)
    panels = zip(axes, CHANNEL_LABELS, CHANNEL_COLUMNS, strict=True)
    for ch, (ax, label, col) in enumerate(panels):
        if n_scenarios:
            ax.plot(hours, scenarios[:, ch, :].T, color="0.4", alpha=alpha, linewidth=0.8)
        ax.plot(
            hours,
            reference[col].to_numpy(),
            color="crimson",
            linewidth=2.0,
            label="reference day",
        )
        ax.set_ylabel(label)
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0, 23)

    # Pass explicit handles. With labels alone, matplotlib pairs them with the
    # first artists on the axes (the first scenario traces), so the
    # "reference day" entry would show a grey line instead of the red one.
    # Opaque proxies also keep the swatches readable when alpha is tiny.
    handles = [Line2D([], [], color="crimson", linewidth=2.0)]
    labels = ["reference day"]
    if n_scenarios:
        handles.insert(0, Line2D([], [], color="0.4", linewidth=0.8))
        labels.insert(0, f"{n_scenarios} generated scenarios")
    axes[0].legend(handles, labels, loc="upper right", fontsize=9)
    axes[-1].set_xlabel("Hour of day")

    if title is None:
        first_stamp = reference.index[0]
        # Timestamps expose .date(); fall back for non-datetime indexes.
        ref_date = first_stamp.date() if hasattr(first_stamp, "date") else "—"
        title = f"PC-DDPM scenarios vs reference day ({ref_date})"
    fig.suptitle(title, fontsize=12)
    fig.tight_layout()
    return fig