# NOTE(review): the three lines below this module's history were GitHub web-UI
# residue (avatar alt-text, commit subject, short hash) pasted into the file;
# preserved here as a comment so the module remains importable.
# commit 93ed35a — "space deploy: trim short_description to fit HF 60-char cap"
"""Inference entry point for PC-DDPM scenario generation.
`load_model()` pulls the λ=2.0 Temporal U-Net checkpoint and its
normalisation stats from the Hugging Face Hub (cached locally after
the first call). `predict()` runs the reverse-diffusion loop and
returns denormalised MW scenarios. The Gradio demo and the FastAPI
service both call `predict()` — the deployment surface is one
function deep on purpose.
The architecture mirrors `train_ddpm_constrained_lam2.py` from
`pc-ddpm-epec2026` exactly; checkpoints are not portable across
shape changes.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812 — PyTorch convention
from huggingface_hub import hf_hub_download
SEQ_LEN = 24
CHANNELS = 3
T_STEPS = 1000
UNET_DIMS: tuple[int, ...] = (64, 128, 256)
T_DIM = 128
HF_REPO_ID = "jbobym/pc-ddpm-alberta"
DEFAULT_MODEL_FILE = "ddpm_constrained_lam2.pt"
DEFAULT_NORM_FILE = "ddpm_unconstrained_norm.npz"
class SinusoidalPE(nn.Module):
    """Sinusoidal timestep embedding followed by a small MLP projection.

    Maps a batch of timesteps `t` of shape (B,) to a (B, dim) embedding.
    Assumes `dim` is even, so the sin/cos halves concatenate back to `dim`.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim
        # Two-layer MLP with a 4x hidden expansion over the raw embedding.
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half_dim = self.dim // 2
        # Geometric frequency ladder spanning 1 down to 1/10000.
        exponents = torch.arange(half_dim, device=t.device, dtype=torch.float32)
        freqs = torch.exp(-exponents * (np.log(10000) / (half_dim - 1)))
        angles = t[:, None].float() * freqs[None]
        features = torch.cat([angles.sin(), angles.cos()], dim=-1)
        return self.proj(features)  # type: ignore[no-any-return]
class ResBlock1D(nn.Module):
    """1-D residual block with timestep conditioning.

    Two conv → GroupNorm → SiLU stages; the timestep embedding modulates the
    first stage's activations with a learned per-channel (1 + gain) factor.
    A 1x1 conv aligns the skip path whenever in_ch != out_ch.
    """

    def __init__(self, in_ch: int, out_ch: int, t_dim: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv1d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv1d(out_ch, out_ch, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.t_proj = nn.Linear(t_dim, out_ch)
        # Identity skip when the channel counts already match.
        if in_ch != out_ch:
            self.res_conv: nn.Module = nn.Conv1d(in_ch, out_ch, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        out = F.silu(self.norm1(self.conv1(x)))
        # Broadcast the per-channel gain over the time axis.
        out = out * (1 + self.t_proj(t_emb)[:, :, None])
        out = F.silu(self.norm2(self.conv2(out)))
        return out + self.res_conv(x)  # type: ignore[no-any-return]
class TemporalUNet(nn.Module):
    """1-D U-Net over the time axis; predicts the diffusion noise ε.

    Input and output are `(B, in_channels, L)`; the timestep batch `t` is
    embedded via `SinusoidalPE` and injected into every residual block.
    Mirrors `train_ddpm_constrained_lam2.py` — checkpoints are not portable
    across shape changes, so do not alter the architecture without retraining.
    """

    def __init__(
        self,
        in_channels: int = CHANNELS,
        dims: tuple[int, ...] = UNET_DIMS,
        t_dim: int = T_DIM,
    ) -> None:
        super().__init__()
        self.t_emb = SinusoidalPE(t_dim)
        self.enc_in = nn.Conv1d(in_channels, dims[0], 3, padding=1)

        # Encoder: one ResBlock + strided-conv downsample per resolution level.
        self.enc = nn.ModuleList()
        self.down = nn.ModuleList()
        ch = dims[0]
        for d in dims[1:]:
            self.enc.append(ResBlock1D(ch, d, t_dim))
            self.down.append(nn.Conv1d(d, d, 4, stride=2, padding=1))
            ch = d

        self.mid1 = ResBlock1D(ch, ch, t_dim)
        self.mid2 = ResBlock1D(ch, ch, t_dim)

        # Decoder: upsample + ResBlock per level, consuming encoder skips.
        # `dims[1:]` and `dims[:-1]` always have equal length, so strict=True
        # is safe here and matches the zip(..., strict=True) convention used
        # in forward() (was strict=False — inconsistent).
        self.dec = nn.ModuleList()
        self.up = nn.ModuleList()
        for skip_ch, d in zip(reversed(dims[1:]), reversed(dims[:-1]), strict=True):
            self.up.append(nn.Upsample(scale_factor=2, mode="linear", align_corners=False))
            self.dec.append(ResBlock1D(ch + skip_ch, d, t_dim))
            ch = d

        self.out_norm = nn.GroupNorm(8, ch)
        self.out_conv = nn.Conv1d(ch, in_channels, 1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Predict noise for `x` at timesteps `t`; output has the same shape as `x`."""
        t_emb = self.t_emb(t)
        h = self.enc_in(x)
        skips: list[torch.Tensor] = []
        for res, down in zip(self.enc, self.down, strict=True):
            h = res(h, t_emb)
            skips.append(h)  # saved pre-downsample for the decoder concat
            h = down(h)
        h = self.mid2(self.mid1(h, t_emb), t_emb)
        for up, res in zip(self.up, self.dec, strict=True):
            h = up(h)
            skip = skips.pop()
            # Odd-length features lose a sample on downsampling; right-pad to
            # realign with the stored skip before concatenation.
            if h.shape[-1] != skip.shape[-1]:
                h = F.pad(h, (0, skip.shape[-1] - h.shape[-1]))
            h = res(torch.cat([h, skip], dim=1), t_emb)
        return self.out_conv(F.silu(self.out_norm(h)))  # type: ignore[no-any-return]
def _cosine_beta_schedule(n_steps: int, s: float = 0.008) -> torch.Tensor:
steps = torch.arange(n_steps + 1, dtype=torch.float64)
f = torch.cos((steps / n_steps + s) / (1 + s) * torch.pi / 2) ** 2
alphas_cumprod = f / f[0]
return (1 - alphas_cumprod[1:] / alphas_cumprod[:-1]).clamp(0, 0.999).float() # type: ignore[no-any-return]
def build_schedule(n_steps: int, device: torch.device) -> dict[str, torch.Tensor]:
    """Precompute every diffusion constant the reverse sampler needs.

    Keys (all shape `(n_steps,)`, on `device`): "betas", "alphas",
    "alpha_bar", "alpha_bar_prev", and "post_var" — the posterior variance
    of q(x_{t-1} | x_t, x_0).
    """
    betas = _cosine_beta_schedule(n_steps).to(device)
    alphas = 1 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)
    # ᾱ_{t-1}, defining ᾱ_{-1} := 1 via a left pad.
    alpha_bar_prev = F.pad(alpha_bar[:-1], (1, 0), value=1.0)
    posterior_var = betas * (1 - alpha_bar_prev) / (1 - alpha_bar)
    schedule = {
        "betas": betas,
        "alphas": alphas,
        "alpha_bar": alpha_bar,
        "alpha_bar_prev": alpha_bar_prev,
        "post_var": posterior_var,
    }
    return schedule
@dataclass
class ModelBundle:
    """Everything `predict()` needs in one place. Construct via `load_model()`."""
    model: TemporalUNet  # noise-prediction net, in eval mode, already on `device`
    schedule: dict[str, torch.Tensor]  # output of build_schedule(T_STEPS, device)
    ch_min: np.ndarray  # per-channel minima used to denormalise samples to MW
    ch_range: np.ndarray  # per-channel (max - min); zero ranges replaced by 1.0
    device: torch.device  # device the model and schedule tensors live on
def load_model(
    repo_id: str = HF_REPO_ID,
    model_filename: str = DEFAULT_MODEL_FILE,
    norm_filename: str = DEFAULT_NORM_FILE,
    device: str | torch.device | None = None,
) -> ModelBundle:
    """Pull weights + normalisation stats from HF Hub and assemble the bundle.

    First call downloads ~8.4 MB to the HF cache; subsequent calls hit the
    cache. HF Spaces free tier is CPU-only — pass `device="cpu"` explicitly
    in deployment to skip the CUDA probe.

    Args:
        repo_id: HF Hub repository holding checkpoint and norm stats.
        model_filename: checkpoint with "hparams" and "model_state" entries.
        norm_filename: `.npz` containing "ch_min"/"ch_max" per-channel stats.
        device: target device; `None` auto-selects CUDA when available.

    Returns:
        A ready-to-use `ModelBundle` (model in eval mode, schedule built).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        device = torch.device(device)

    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
    norm_path = hf_hub_download(repo_id=repo_id, filename=norm_filename)

    # Bug fix: np.load on an .npz returns a lazy NpzFile that keeps the file
    # handle open; the original never closed it. The context-manager form
    # releases the handle once the arrays are materialised.
    with np.load(norm_path) as norm:
        ch_min = norm["ch_min"].astype(np.float32)
        ch_max = norm["ch_max"].astype(np.float32)
    # Degenerate (zero-range) channels get scale 1.0 so denormalisation in
    # predict() never multiplies by 0.
    ch_range = np.where(ch_max - ch_min > 0, ch_max - ch_min, 1.0).astype(np.float32)

    # weights_only=True keeps torch.load from unpickling arbitrary objects.
    ckpt = torch.load(model_path, map_location=device, weights_only=True)
    hp = ckpt["hparams"]
    model = TemporalUNet(
        in_channels=int(hp["CHANNELS"]),
        dims=tuple(int(d) for d in hp["UNET_DIMS"]),
        t_dim=int(hp.get("t_dim", T_DIM)),
    ).to(device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    return ModelBundle(
        model=model,
        schedule=build_schedule(n_steps=T_STEPS, device=device),
        ch_min=ch_min,
        ch_range=ch_range,
        device=device,
    )
@torch.no_grad()
def predict(
    bundle: ModelBundle,
    n_scenarios: int = 100,
    seed: int | None = None,
    progress: Callable[[int, int], None] | None = None,
) -> np.ndarray:
    """Draw `n_scenarios` 24-hour scenarios via reverse diffusion.

    Returns a `(n_scenarios, 3, 24)` `float32` array of MW values for
    (wind, solar, load) — denormalised and clamped non-negative. The
    `progress(step, total)` callback fires once per reverse step;
    Gradio's progress bar plugs in here.
    """
    device = bundle.device
    sched = bundle.schedule
    n_steps = int(sched["betas"].shape[0])

    rng = torch.Generator(device=device)
    if seed is not None:
        rng.manual_seed(seed)

    # Start from pure Gaussian noise in normalised space.
    x = torch.randn(n_scenarios, CHANNELS, SEQ_LEN, device=device, generator=rng)

    for done, step in enumerate(range(n_steps - 1, -1, -1), start=1):
        t_batch = torch.full((n_scenarios,), step, device=device, dtype=torch.long)
        eps_hat = bundle.model(x, t_batch)

        beta_t = sched["betas"][step]
        alpha_t = sched["alphas"][step]
        abar_t = sched["alpha_bar"][step]
        abar_prev = sched["alpha_bar_prev"][step]

        # DDPM posterior mean computed through the clamped x0 estimate.
        x0_hat = ((x - (1 - abar_t).sqrt() * eps_hat) / abar_t.sqrt()).clamp(-1.5, 1.5)
        mean = (
            (abar_prev.sqrt() * beta_t / (1 - abar_t)) * x0_hat
            + (alpha_t.sqrt() * (1 - abar_prev) / (1 - abar_t)) * x
        )
        if step == 0:
            x = mean  # final step is deterministic — no noise injected
        else:
            noise = torch.randn(x.shape, device=device, generator=rng)
            x = mean + sched["post_var"][step].sqrt() * noise
        if progress is not None:
            progress(done, n_steps)

    # Denormalise back to MW; all three channels are physically non-negative.
    # NOTE(review): assumes bundle.ch_min / ch_range broadcast over (N, 3, 24),
    # i.e. are stored with a trailing axis (e.g. shape (3, 1)) — confirm
    # against the .npz written at training time.
    scenarios = x.cpu().numpy() * bundle.ch_range + bundle.ch_min
    scenarios = np.clip(scenarios, 0, None)
    return scenarios.astype(np.float32)  # type: ignore[no-any-return]