"""Inference entry point for PC-DDPM scenario generation. `load_model()` pulls the λ=2.0 Temporal U-Net checkpoint and its normalisation stats from the Hugging Face Hub (cached locally after the first call). `predict()` runs the reverse-diffusion loop and returns denormalised MW scenarios. The Gradio demo and the FastAPI service both call `predict()` — the deployment surface is one function deep on purpose. The architecture mirrors `train_ddpm_constrained_lam2.py` from `pc-ddpm-epec2026` exactly; checkpoints are not portable across shape changes. """ from __future__ import annotations from collections.abc import Callable from dataclasses import dataclass import numpy as np import torch import torch.nn as nn import torch.nn.functional as F # noqa: N812 — PyTorch convention from huggingface_hub import hf_hub_download SEQ_LEN = 24 CHANNELS = 3 T_STEPS = 1000 UNET_DIMS: tuple[int, ...] = (64, 128, 256) T_DIM = 128 HF_REPO_ID = "jbobym/pc-ddpm-alberta" DEFAULT_MODEL_FILE = "ddpm_constrained_lam2.pt" DEFAULT_NORM_FILE = "ddpm_unconstrained_norm.npz" class SinusoidalPE(nn.Module): def __init__(self, dim: int) -> None: super().__init__() self.dim = dim self.proj = nn.Sequential( nn.Linear(dim, dim * 4), nn.SiLU(), nn.Linear(dim * 4, dim), ) def forward(self, t: torch.Tensor) -> torch.Tensor: half = self.dim // 2 freqs = torch.exp( -torch.arange(half, device=t.device, dtype=torch.float32) * (np.log(10000) / (half - 1)) ) emb = t[:, None].float() * freqs[None] return self.proj(torch.cat([emb.sin(), emb.cos()], dim=-1)) # type: ignore[no-any-return] class ResBlock1D(nn.Module): def __init__(self, in_ch: int, out_ch: int, t_dim: int) -> None: super().__init__() self.conv1 = nn.Conv1d(in_ch, out_ch, 3, padding=1) self.conv2 = nn.Conv1d(out_ch, out_ch, 3, padding=1) self.norm1 = nn.GroupNorm(8, out_ch) self.norm2 = nn.GroupNorm(8, out_ch) self.t_proj = nn.Linear(t_dim, out_ch) self.res_conv: nn.Module = ( nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity() ) def 
forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor: h = F.silu(self.norm1(self.conv1(x))) h = h * (1 + self.t_proj(t_emb)[:, :, None]) h = F.silu(self.norm2(self.conv2(h))) return h + self.res_conv(x) # type: ignore[no-any-return] class TemporalUNet(nn.Module): def __init__( self, in_channels: int = CHANNELS, dims: tuple[int, ...] = UNET_DIMS, t_dim: int = T_DIM, ) -> None: super().__init__() self.t_emb = SinusoidalPE(t_dim) self.enc_in = nn.Conv1d(in_channels, dims[0], 3, padding=1) self.enc = nn.ModuleList() self.down = nn.ModuleList() ch = dims[0] for d in dims[1:]: self.enc.append(ResBlock1D(ch, d, t_dim)) self.down.append(nn.Conv1d(d, d, 4, stride=2, padding=1)) ch = d self.mid1 = ResBlock1D(ch, ch, t_dim) self.mid2 = ResBlock1D(ch, ch, t_dim) self.dec = nn.ModuleList() self.up = nn.ModuleList() for skip_ch, d in zip(reversed(dims[1:]), reversed(dims[:-1]), strict=False): self.up.append(nn.Upsample(scale_factor=2, mode="linear", align_corners=False)) self.dec.append(ResBlock1D(ch + skip_ch, d, t_dim)) ch = d self.out_norm = nn.GroupNorm(8, ch) self.out_conv = nn.Conv1d(ch, in_channels, 1) def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: t_emb = self.t_emb(t) h = self.enc_in(x) skips: list[torch.Tensor] = [] for res, down in zip(self.enc, self.down, strict=True): h = res(h, t_emb) skips.append(h) h = down(h) h = self.mid2(self.mid1(h, t_emb), t_emb) for up, res in zip(self.up, self.dec, strict=True): h = up(h) skip = skips.pop() if h.shape[-1] != skip.shape[-1]: h = F.pad(h, (0, skip.shape[-1] - h.shape[-1])) h = res(torch.cat([h, skip], dim=1), t_emb) return self.out_conv(F.silu(self.out_norm(h))) # type: ignore[no-any-return] def _cosine_beta_schedule(n_steps: int, s: float = 0.008) -> torch.Tensor: steps = torch.arange(n_steps + 1, dtype=torch.float64) f = torch.cos((steps / n_steps + s) / (1 + s) * torch.pi / 2) ** 2 alphas_cumprod = f / f[0] return (1 - alphas_cumprod[1:] / alphas_cumprod[:-1]).clamp(0, 0.999).float() # 
def build_schedule(n_steps: int, device: torch.device) -> dict[str, torch.Tensor]:
    """Precompute the per-step diffusion quantities used by the reverse loop.

    Returns a dict of length-`n_steps` tensors on `device`:
    `betas`, `alphas` (1 - beta), `alpha_bar` (cumulative product of alphas),
    `alpha_bar_prev` (alpha_bar shifted right by one, with 1.0 at step 0) and
    `post_var` (posterior variance of q(x_{t-1} | x_t, x_0)).
    """
    betas = _cosine_beta_schedule(n_steps).to(device)
    alphas = 1 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)
    alpha_bar_prev = F.pad(alpha_bar[:-1], (1, 0), value=1.0)
    return {
        "betas": betas,
        "alphas": alphas,
        "alpha_bar": alpha_bar,
        "alpha_bar_prev": alpha_bar_prev,
        "post_var": betas * (1 - alpha_bar_prev) / (1 - alpha_bar),
    }


@dataclass
class ModelBundle:
    """Everything `predict()` needs in one place. Construct via `load_model()`."""

    model: TemporalUNet  # network in eval mode, already moved to `device`
    schedule: dict[str, torch.Tensor]  # output of `build_schedule()`
    ch_min: np.ndarray  # float32 per-channel minimum used for denormalisation
    ch_range: np.ndarray  # float32 ch_max - ch_min, zero widths replaced by 1.0
    device: torch.device


def load_model(
    repo_id: str = HF_REPO_ID,
    model_filename: str = DEFAULT_MODEL_FILE,
    norm_filename: str = DEFAULT_NORM_FILE,
    device: str | torch.device | None = None,
) -> ModelBundle:
    """Pull weights + normalisation stats from HF Hub and assemble the bundle.

    First call downloads ~8.4 MB to the HF cache; subsequent calls hit the
    cache. HF Spaces free tier is CPU-only — pass `device="cpu"` explicitly
    in deployment to skip the CUDA probe.

    Args:
        repo_id: Hugging Face Hub repository holding both artefacts.
        model_filename: checkpoint containing `"hparams"` and `"model_state"`.
        norm_filename: `.npz` archive containing `ch_min` and `ch_max`.
        device: target device; `None` picks CUDA when available, else CPU.

    Returns:
        A `ModelBundle` whose model is in eval mode and whose diffusion
        schedule is precomputed on `device`.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        device = torch.device(device)

    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
    norm_path = hf_hub_download(repo_id=repo_id, filename=norm_filename)

    # An .npz load returns a lazily-reading NpzFile that holds the file open;
    # the context manager closes it once the arrays are copied out.
    with np.load(norm_path) as norm:
        ch_min = norm["ch_min"].astype(np.float32)
        ch_max = norm["ch_max"].astype(np.float32)
    # Constant channels get a unit range — presumably mirroring the
    # training-time guard against dividing by zero when normalising.
    ch_range = np.where(ch_max - ch_min > 0, ch_max - ch_min, 1.0).astype(np.float32)

    ckpt = torch.load(model_path, map_location=device, weights_only=True)
    hp = ckpt["hparams"]
    # Rebuild the network from the checkpoint's own hyper-parameters so the
    # state-dict shapes line up; only `t_dim` falls back to the module default.
    model = TemporalUNet(
        in_channels=int(hp["CHANNELS"]),
        dims=tuple(int(d) for d in hp["UNET_DIMS"]),
        t_dim=int(hp.get("t_dim", T_DIM)),
    ).to(device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    return ModelBundle(
        model=model,
        schedule=build_schedule(n_steps=T_STEPS, device=device),
        ch_min=ch_min,
        ch_range=ch_range,
        device=device,
    )


@torch.no_grad()
def predict(
    bundle: ModelBundle,
    n_scenarios: int = 100,
    seed: int | None = None,
    progress: Callable[[int, int], None] | None = None,
) -> np.ndarray:
    """Draw `n_scenarios` 24-hour scenarios via reverse diffusion.

    Returns a `(n_scenarios, 3, 24)` `float32` array of MW values for
    (wind, solar, load) — denormalised and clamped non-negative. The channel
    count is taken from the loaded model (3 for the published checkpoint).

    Args:
        bundle: output of `load_model()`.
        n_scenarios: number of independent scenarios to draw (batch size).
        seed: seed for reproducible draws. `None` seeds the sampler from
            fresh entropy, so repeated calls yield different scenarios.
        progress: optional `progress(step, total)` callback fired once per
            reverse step; Gradio's progress bar plugs in here.
    """
    device = bundle.device
    sched = bundle.schedule
    n_steps = int(sched["betas"].shape[0])
    # Size the noise from the network itself rather than the module constant,
    # matching how `load_model()` builds the net from checkpoint hparams.
    n_channels = bundle.model.enc_in.in_channels

    gen = torch.Generator(device=device)
    if seed is None:
        # A fresh Generator starts from a fixed default seed; without this
        # every unseeded call would return the identical batch.
        gen.seed()
    else:
        gen.manual_seed(seed)

    x = torch.randn(n_scenarios, n_channels, SEQ_LEN, device=device, generator=gen)
    for i, step in enumerate(reversed(range(n_steps))):
        t_b = torch.full((n_scenarios,), step, device=device, dtype=torch.long)
        eps = bundle.model(x, t_b)

        beta = sched["betas"][step]
        alpha = sched["alphas"][step]
        ab = sched["alpha_bar"][step]
        ab_prev = sched["alpha_bar_prev"][step]

        # Estimate x_0 from the predicted noise and clip it to keep the
        # reverse chain from diverging (bounds presumably sized for data
        # normalised to about [0, 1] — confirm against training).
        x0_pred = ((x - (1 - ab).sqrt() * eps) / ab.sqrt()).clamp(-1.5, 1.5)
        # Mean of the DDPM posterior q(x_{t-1} | x_t, x_0).
        mean = (
            (ab_prev.sqrt() * beta / (1 - ab)) * x0_pred
            + (alpha.sqrt() * (1 - ab_prev) / (1 - ab)) * x
        )
        if step > 0:
            noise = torch.randn(x.shape, device=device, generator=gen)
            x = mean + sched["post_var"][step].sqrt() * noise
        else:
            # The final step is deterministic (posterior variance is zero).
            x = mean

        if progress is not None:
            progress(i + 1, n_steps)

    # NOTE(review): ch_min/ch_range must broadcast against
    # (n_scenarios, channels, SEQ_LEN), e.g. shape (channels, 1) — not
    # validated here.
    samples_mw = x.cpu().numpy() * bundle.ch_range + bundle.ch_min
    # MW output and demand cannot be negative; clip every channel at zero.
    return np.clip(samples_mw, 0, None).astype(np.float32)  # type: ignore[no-any-return]