| """Inference entry point for PC-DDPM scenario generation. | |
| `load_model()` pulls the λ=2.0 Temporal U-Net checkpoint and its | |
| normalisation stats from the Hugging Face Hub (cached locally after | |
| the first call). `predict()` runs the reverse-diffusion loop and | |
| returns denormalised MW scenarios. The Gradio demo and the FastAPI | |
| service both call `predict()` — the deployment surface is one | |
| function deep on purpose. | |
| The architecture mirrors `train_ddpm_constrained_lam2.py` from | |
| `pc-ddpm-epec2026` exactly; checkpoints are not portable across | |
| shape changes. | |
| """ | |
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812 — PyTorch convention
from huggingface_hub import hf_hub_download
SEQ_LEN = 24  # hourly resolution: one day per scenario
CHANNELS = 3  # (wind, solar, load)
T_STEPS = 1000  # diffusion timesteps
UNET_DIMS: tuple[int, ...] = (64, 128, 256)
T_DIM = 128  # width of the timestep embedding

HF_REPO_ID = "jbobym/pc-ddpm-alberta"
DEFAULT_MODEL_FILE = "ddpm_constrained_lam2.pt"
DEFAULT_NORM_FILE = "ddpm_unconstrained_norm.npz"
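

# Transformer-style sinusoidal timestep embedding, refined by a small MLP.
# `dim` is assumed even so the sin/cos halves concatenate back to `dim`.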
class SinusoidalPE(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.SiLU(), nn.Linear(dim * 4, dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        freqs = torch.exp(
            -torch.arange(half, device=t.device, dtype=torch.float32)
            * (np.log(10000) / (half - 1))
        )
        emb = t[:, None].float() * freqs[None]
        return self.proj(torch.cat([emb.sin(), emb.cos()], dim=-1))  # type: ignore[no-any-return]
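

# Residual block with FiLM-style, scale-only conditioning: the timestep
# embedding is projected to one scale per channel and applied as h * (1 + s).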
class ResBlock1D(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, t_dim: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv1d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv1d(out_ch, out_ch, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.t_proj = nn.Linear(t_dim, out_ch)
        self.res_conv: nn.Module = (
            nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
        )

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        h = F.silu(self.norm1(self.conv1(x)))
        h = h * (1 + self.t_proj(t_emb)[:, :, None])
        h = F.silu(self.norm2(self.conv2(h)))
        return h + self.res_conv(x)  # type: ignore[no-any-return]
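

# 1-D U-Net over the 24-hour axis: strided-conv downsampling on the way in,
# linear upsampling plus skip concatenation on the way out.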
class TemporalUNet(nn.Module):
    def __init__(
        self,
        in_channels: int = CHANNELS,
        dims: tuple[int, ...] = UNET_DIMS,
        t_dim: int = T_DIM,
    ) -> None:
        super().__init__()
        self.t_emb = SinusoidalPE(t_dim)
        self.enc_in = nn.Conv1d(in_channels, dims[0], 3, padding=1)
        self.enc = nn.ModuleList()
        self.down = nn.ModuleList()
        ch = dims[0]
        for d in dims[1:]:
            self.enc.append(ResBlock1D(ch, d, t_dim))
            self.down.append(nn.Conv1d(d, d, 4, stride=2, padding=1))
            ch = d
        self.mid1 = ResBlock1D(ch, ch, t_dim)
        self.mid2 = ResBlock1D(ch, ch, t_dim)
        self.dec = nn.ModuleList()
        self.up = nn.ModuleList()
        for skip_ch, d in zip(reversed(dims[1:]), reversed(dims[:-1]), strict=False):
            self.up.append(nn.Upsample(scale_factor=2, mode="linear", align_corners=False))
            self.dec.append(ResBlock1D(ch + skip_ch, d, t_dim))
            ch = d
        self.out_norm = nn.GroupNorm(8, ch)
        self.out_conv = nn.Conv1d(ch, in_channels, 1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        t_emb = self.t_emb(t)
        h = self.enc_in(x)
        skips: list[torch.Tensor] = []
        for res, down in zip(self.enc, self.down, strict=True):
            h = res(h, t_emb)
            skips.append(h)
            h = down(h)
        h = self.mid2(self.mid1(h, t_emb), t_emb)
        for up, res in zip(self.up, self.dec, strict=True):
            h = up(h)
            skip = skips.pop()
            if h.shape[-1] != skip.shape[-1]:
                h = F.pad(h, (0, skip.shape[-1] - h.shape[-1]))
            h = res(torch.cat([h, skip], dim=1), t_emb)
        return self.out_conv(F.silu(self.out_norm(h)))  # type: ignore[no-any-return]
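

# Shape walk-through for the default config with batch size B (illustrative):
#   (B, 3, 24) -enc_in-> (B, 64, 24) -enc/down-> (B, 128, 12) -> (B, 256, 6)
#   -mid-> (B, 256, 6) -up/cat/dec-> (B, 128, 12) -> (B, 64, 24) -out-> (B, 3, 24)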
def _cosine_beta_schedule(n_steps: int, s: float = 0.008) -> torch.Tensor:
    # Cosine schedule of Nichol & Dhariwal (2021), "Improved DDPMs".
    steps = torch.arange(n_steps + 1, dtype=torch.float64)
    f = torch.cos((steps / n_steps + s) / (1 + s) * torch.pi / 2) ** 2
    alphas_cumprod = f / f[0]
    return (1 - alphas_cumprod[1:] / alphas_cumprod[:-1]).clamp(0, 0.999).float()  # type: ignore[no-any-return]
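

# Forward-process constants consumed by the reverse loop in `predict()`.
# `post_var` is the DDPM posterior variance of Ho et al. (2020):
#   beta_tilde_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)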
def build_schedule(n_steps: int, device: torch.device) -> dict[str, torch.Tensor]:
    betas = _cosine_beta_schedule(n_steps).to(device)
    alphas = 1 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)
    alpha_bar_prev = F.pad(alpha_bar[:-1], (1, 0), value=1.0)
    return {
        "betas": betas,
        "alphas": alphas,
        "alpha_bar": alpha_bar,
        "alpha_bar_prev": alpha_bar_prev,
        "post_var": betas * (1 - alpha_bar_prev) / (1 - alpha_bar),
    }


@dataclass
class ModelBundle:
    """Everything `predict()` needs in one place. Construct via `load_model()`."""

    model: TemporalUNet
    schedule: dict[str, torch.Tensor]
    ch_min: np.ndarray
    ch_range: np.ndarray
    device: torch.device


def load_model(
    repo_id: str = HF_REPO_ID,
    model_filename: str = DEFAULT_MODEL_FILE,
    norm_filename: str = DEFAULT_NORM_FILE,
    device: str | torch.device | None = None,
) -> ModelBundle:
    """Pull weights + normalisation stats from HF Hub and assemble the bundle.

    First call downloads ~8.4 MB to the HF cache; subsequent calls hit the
    cache. HF Spaces free tier is CPU-only — pass `device="cpu"` explicitly
    in deployment to skip the CUDA probe.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        device = torch.device(device)
    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
    norm_path = hf_hub_download(repo_id=repo_id, filename=norm_filename)
    norm = np.load(norm_path)
    ch_min = norm["ch_min"].astype(np.float32)
    ch_max = norm["ch_max"].astype(np.float32)
    # Constant channels get range 1.0 so denormalisation reduces to adding ch_min.
    ch_range = np.where(ch_max - ch_min > 0, ch_max - ch_min, 1.0).astype(np.float32)
    ckpt = torch.load(model_path, map_location=device, weights_only=True)
    # Rebuild the network from the hparams stored alongside the weights so the
    # architecture always matches the checkpoint.
    hp = ckpt["hparams"]
    model = TemporalUNet(
        in_channels=int(hp["CHANNELS"]),
        dims=tuple(int(d) for d in hp["UNET_DIMS"]),
        t_dim=int(hp.get("t_dim", T_DIM)),
    ).to(device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    return ModelBundle(
        model=model,
        schedule=build_schedule(n_steps=T_STEPS, device=device),
        ch_min=ch_min,
        ch_range=ch_range,
        device=device,
    )


@torch.no_grad()  # inference only: without this, `x.cpu().numpy()` below would fail
def predict(
    bundle: ModelBundle,
    n_scenarios: int = 100,
    seed: int | None = None,
    progress: Callable[[int, int], None] | None = None,
) -> np.ndarray:
    """Draw `n_scenarios` 24-hour scenarios via reverse diffusion.

    Returns a `(n_scenarios, 3, 24)` `float32` array of MW values for
    (wind, solar, load) — denormalised and clamped non-negative. The
    `progress(step, total)` callback fires once per reverse step;
    Gradio's progress bar plugs in here.
    """
    device = bundle.device
    sched = bundle.schedule
    n_steps = int(sched["betas"].shape[0])
    gen = torch.Generator(device=device)
    if seed is not None:
        gen.manual_seed(seed)
    x = torch.randn(n_scenarios, CHANNELS, SEQ_LEN, device=device, generator=gen)
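    # Standard DDPM ancestral sampling (Ho et al., 2020): each step forms the
    # posterior mean from the eps-prediction,
    #   mu_t = sqrt(alpha_bar_{t-1}) * beta_t / (1 - alpha_bar_t) * x0_pred
    #        + sqrt(alpha_t) * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * x_t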
    for i, step in enumerate(reversed(range(n_steps))):
        t_b = torch.full((n_scenarios,), step, device=device, dtype=torch.long)
        eps = bundle.model(x, t_b)
        beta = sched["betas"][step]
        alpha = sched["alphas"][step]
        ab = sched["alpha_bar"][step]
        ab_prev = sched["alpha_bar_prev"][step]
        # Predicted clean sample, clamped for numerical stability.
        x0_pred = ((x - (1 - ab).sqrt() * eps) / ab.sqrt()).clamp(-1.5, 1.5)
        mean = (
            (ab_prev.sqrt() * beta / (1 - ab)) * x0_pred
            + (alpha.sqrt() * (1 - ab_prev) / (1 - ab)) * x
        )
        if step > 0:
            noise = torch.randn(x.shape, device=device, generator=gen)
            x = mean + sched["post_var"][step].sqrt() * noise
        else:
            x = mean  # the final step is deterministic: no noise at t=0
        if progress is not None:
            progress(i + 1, n_steps)
    samples_norm = x.cpu().numpy()
    samples_mw = samples_norm * bundle.ch_range + bundle.ch_min
    samples_mw = np.clip(samples_mw, 0.0, None)  # wind, solar and load are all >= 0
    return samples_mw.astype(np.float32)  # type: ignore[no-any-return]
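

if __name__ == "__main__":
    # Minimal smoke test, not part of the deployment surface. Assumes network
    # access to HF_REPO_ID on the first run; forces CPU per the Spaces note
    # in `load_model()`. A small n_scenarios keeps the 1000-step loop quick.
    bundle = load_model(device="cpu")
    scenarios = predict(bundle, n_scenarios=4, seed=0)
    print(scenarios.shape)  # expected: (4, 3, 24)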