File size: 8,916 Bytes
93ed35a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
"""Inference entry point for PC-DDPM scenario generation.

`load_model()` pulls the λ=2.0 Temporal U-Net checkpoint and its
normalisation stats from the Hugging Face Hub (cached locally after
the first call). `predict()` runs the reverse-diffusion loop and
returns denormalised MW scenarios. The Gradio demo and the FastAPI
service both call `predict()` — the deployment surface is one
function deep on purpose.

The architecture mirrors `train_ddpm_constrained_lam2.py` from
`pc-ddpm-epec2026` exactly; checkpoints are not portable across
shape changes.
"""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812 — PyTorch convention
from huggingface_hub import hf_hub_download

# Shape of one generated scenario: SEQ_LEN time steps (24-hour scenarios per
# the predict() docstring) for each of CHANNELS series (wind, solar, load).
# predict() uses both constants to shape its initial noise tensor.
SEQ_LEN = 24
CHANNELS = 3
# Number of diffusion steps; load_model() builds the beta schedule with this.
T_STEPS = 1000
# Default TemporalUNet widths and timestep-embedding size. load_model() always
# takes the widths from the checkpoint's hparams["UNET_DIMS"] and uses T_DIM
# only as a fallback when hparams has no "t_dim".
UNET_DIMS: tuple[int, ...] = (64, 128, 256)
T_DIM = 128

# Hugging Face Hub defaults consumed by load_model().
HF_REPO_ID = "jbobym/pc-ddpm-alberta"
DEFAULT_MODEL_FILE = "ddpm_constrained_lam2.pt"
# NOTE(review): the constrained (lambda=2.0) checkpoint is paired with the
# *unconstrained* model's normalisation stats — presumably both were trained
# on identically normalised data; confirm against the training repo.
DEFAULT_NORM_FILE = "ddpm_unconstrained_norm.npz"


class SinusoidalPE(nn.Module):
    """Timestep embedding: fixed sin/cos features followed by a small MLP.

    A 1-D batch of timesteps ``t`` of shape ``(B,)`` becomes ``(B, dim)``.
    The sin and cos halves each have ``dim // 2`` features, and the MLP's input
    layer expects exactly ``dim`` of them, so ``dim`` must be even.
    Submodule names (``proj``) are ``state_dict`` keys — keep them stable.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim
        hidden = dim * 4
        self.proj = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        n_freq = self.dim // 2
        # Geometric ladder of angular frequencies from 1 down to 1/10000.
        log_step = np.log(10000) / (n_freq - 1)
        idx = torch.arange(n_freq, device=t.device, dtype=torch.float32)
        freqs = torch.exp(idx * -log_step)
        angles = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        features = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return self.proj(features)  # type: ignore[no-any-return]


class ResBlock1D(nn.Module):
    """Residual 1-D conv block with multiplicative timestep conditioning.

    Path: conv(k=3) -> GroupNorm -> SiLU, scaled per channel by
    ``1 + Linear(t_emb)``, then conv(k=3) -> GroupNorm -> SiLU. The result is
    added to a skip path: a 1x1 conv when ``in_ch != out_ch``, identity
    otherwise. Sequence length is preserved (padding=1 on the k=3 convs).
    ``out_ch`` must be divisible by 8 (GroupNorm with 8 groups).
    """

    def __init__(self, in_ch: int, out_ch: int, t_dim: int) -> None:
        super().__init__()
        # Attribute names are state_dict keys — keep them stable so existing
        # checkpoints still load.
        self.conv1 = nn.Conv1d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv1d(out_ch, out_ch, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.t_proj = nn.Linear(t_dim, out_ch)
        if in_ch == out_ch:
            self.res_conv: nn.Module = nn.Identity()
        else:
            self.res_conv = nn.Conv1d(in_ch, out_ch, 1)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        # Per-channel gain, broadcast over the sequence axis.
        gain = 1 + self.t_proj(t_emb).unsqueeze(2)
        hidden = F.silu(self.norm1(self.conv1(x))) * gain
        hidden = F.silu(self.norm2(self.conv2(hidden)))
        return hidden + self.res_conv(x)  # type: ignore[no-any-return]


class TemporalUNet(nn.Module):
    """1-D U-Net over ``(batch, in_channels, length)`` inputs, conditioned on timestep.

    Encoder: ``enc_in`` conv to ``dims[0]`` channels, then for each width in
    ``dims[1:]`` a ResBlock1D followed by a stride-2 conv that halves the
    length (rounding down). Bottleneck: two ResBlock1D blocks. Decoder
    mirrors the encoder: linear 2x upsample, right-pad (or crop, via
    negative padding) to the matching encoder activation's length,
    concatenate it on the channel axis, ResBlock1D. A GroupNorm + SiLU + 1x1
    conv head maps back to ``in_channels``, so the output has the same shape
    as ``x``; ``predict()`` uses it as the noise estimate.

    Submodule attribute names are ``state_dict`` keys and must match the
    training script for checkpoints to load.
    """

    def __init__(
        self,
        in_channels: int = CHANNELS,
        dims: tuple[int, ...] = UNET_DIMS,
        t_dim: int = T_DIM,
    ) -> None:
        super().__init__()
        self.t_emb = SinusoidalPE(t_dim)
        base = dims[0]
        self.enc_in = nn.Conv1d(in_channels, base, 3, padding=1)
        self.enc = nn.ModuleList()
        self.down = nn.ModuleList()
        width = base
        for level_width in dims[1:]:
            self.enc.append(ResBlock1D(width, level_width, t_dim))
            # kernel 4 / stride 2 / padding 1 -> floor(length / 2).
            self.down.append(nn.Conv1d(level_width, level_width, 4, stride=2, padding=1))
            width = level_width
        self.mid1 = ResBlock1D(width, width, t_dim)
        self.mid2 = ResBlock1D(width, width, t_dim)
        self.dec = nn.ModuleList()
        self.up = nn.ModuleList()
        # Decoder level k consumes the skip saved by encoder level -k and
        # outputs the width of the level above it.
        skip_widths = list(reversed(dims[1:]))
        out_widths = list(reversed(dims[:-1]))
        for skip_width, level_width in zip(skip_widths, out_widths, strict=False):
            self.up.append(nn.Upsample(scale_factor=2, mode="linear", align_corners=False))
            self.dec.append(ResBlock1D(width + skip_width, level_width, t_dim))
            width = level_width
        self.out_norm = nn.GroupNorm(8, width)
        self.out_conv = nn.Conv1d(width, in_channels, 1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        cond = self.t_emb(t)
        feats = self.enc_in(x)
        saved: list[torch.Tensor] = []
        for block, pool in zip(self.enc, self.down, strict=True):
            feats = block(feats, cond)
            saved.append(feats)
            feats = pool(feats)
        for block in (self.mid1, self.mid2):
            feats = block(feats, cond)
        for upsample, block in zip(self.up, self.dec, strict=True):
            feats = upsample(feats)
            skip = saved.pop()
            gap = skip.shape[-1] - feats.shape[-1]
            if gap:
                feats = F.pad(feats, (0, gap))
            feats = block(torch.cat([feats, skip], dim=1), cond)
        return self.out_conv(F.silu(self.out_norm(feats)))  # type: ignore[no-any-return]


def _cosine_beta_schedule(n_steps: int, s: float = 0.008) -> torch.Tensor:
    steps = torch.arange(n_steps + 1, dtype=torch.float64)
    f = torch.cos((steps / n_steps + s) / (1 + s) * torch.pi / 2) ** 2
    alphas_cumprod = f / f[0]
    return (1 - alphas_cumprod[1:] / alphas_cumprod[:-1]).clamp(0, 0.999).float()  # type: ignore[no-any-return]


def build_schedule(n_steps: int, device: torch.device) -> dict[str, torch.Tensor]:
    """Precompute the per-step diffusion coefficients used by the sampler.

    Returns a dict of ``(n_steps,)`` float32 tensors on ``device``:

    - ``betas``: cosine noise schedule from ``_cosine_beta_schedule``;
    - ``alphas``: ``1 - betas``;
    - ``alpha_bar``: cumulative product of ``alphas``;
    - ``alpha_bar_prev``: ``alpha_bar`` shifted right by one step, with
      ``1.0`` at step 0;
    - ``post_var``: ``betas * (1 - alpha_bar_prev) / (1 - alpha_bar)``, the
      DDPM posterior variance (0 at step 0 because ``alpha_bar_prev[0] == 1``).
    """
    betas = _cosine_beta_schedule(n_steps).to(device)
    alphas = 1 - betas
    alpha_bar = alphas.cumprod(dim=0)
    alpha_bar_prev = torch.cat([alpha_bar.new_ones(1), alpha_bar[:-1]])
    post_var = betas * (1 - alpha_bar_prev) / (1 - alpha_bar)
    return dict(
        betas=betas,
        alphas=alphas,
        alpha_bar=alpha_bar,
        alpha_bar_prev=alpha_bar_prev,
        post_var=post_var,
    )


@dataclass
class ModelBundle:
    """Everything `predict()` needs in one place. Construct via `load_model()`.

    Field order is the positional constructor signature, so do not reorder.
    """

    # The denoising network. load_model() moves it to `device` and puts it
    # in eval mode; predict() calls it as model(x_t, t) and uses the result
    # as the noise estimate.
    model: TemporalUNet
    # Per-step coefficients from build_schedule() (keys: "betas", "alphas",
    # "alpha_bar", "alpha_bar_prev", "post_var"), already on `device`.
    schedule: dict[str, torch.Tensor]
    # float32 normalisation offset and scale. predict() maps samples back to
    # MW as `x * ch_range + ch_min`. load_model() replaces non-positive spans
    # in ch_range with 1.0. NOTE(review): both must broadcast against the
    # (n_scenarios, channels, seq_len) sample array — presumably stored as
    # (channels, 1) or (1, channels, 1); confirm against the .npz file.
    ch_min: np.ndarray
    ch_range: np.ndarray
    # Device holding the model and schedule; predict() creates its RNG and
    # noise tensors here.
    device: torch.device


def load_model(
    repo_id: str = HF_REPO_ID,
    model_filename: str = DEFAULT_MODEL_FILE,
    norm_filename: str = DEFAULT_NORM_FILE,
    device: str | torch.device | None = None,
) -> ModelBundle:
    """Pull weights + normalisation stats from HF Hub and assemble the bundle.

    First call downloads ~8.4 MB to the HF cache; subsequent calls hit the
    cache. HF Spaces free tier is CPU-only — pass `device="cpu"` explicitly
    in deployment to skip the CUDA probe.

    Args:
        repo_id: Hugging Face Hub repository holding both files.
        model_filename: checkpoint file. It must contain ``"hparams"`` (with
            ``"CHANNELS"``, ``"UNET_DIMS"`` and optionally ``"t_dim"``) and
            ``"model_state"``.
        norm_filename: ``.npz`` file with ``"ch_min"`` and ``"ch_max"`` arrays.
        device: target device. ``None`` picks CUDA when available, else CPU.

    Returns:
        A ModelBundle whose model is on ``device`` and in eval mode, with a
        ``T_STEPS``-step schedule on the same device.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        device = torch.device(device)

    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
    norm_path = hf_hub_download(repo_id=repo_id, filename=norm_filename)

    # np.load on an .npz returns an NpzFile that keeps the file handle open
    # and reads members lazily; the context manager closes it once both
    # arrays have been materialised.
    with np.load(norm_path) as norm:
        ch_min = norm["ch_min"].astype(np.float32)
        ch_max = norm["ch_max"].astype(np.float32)
    span = ch_max - ch_min
    # Constant (or inverted) channels get a unit scale so the scale is never 0.
    ch_range = np.where(span > 0, span, 1.0).astype(np.float32)

    # weights_only=True restricts unpickling to tensors and plain containers.
    ckpt = torch.load(model_path, map_location=device, weights_only=True)
    hp = ckpt["hparams"]
    # Build the architecture from the checkpoint's own hparams so state_dict
    # keys and shapes line up. T_DIM is used only when "t_dim" is absent.
    model = TemporalUNet(
        in_channels=int(hp["CHANNELS"]),
        dims=tuple(int(d) for d in hp["UNET_DIMS"]),
        t_dim=int(hp.get("t_dim", T_DIM)),
    ).to(device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    return ModelBundle(
        model=model,
        schedule=build_schedule(n_steps=T_STEPS, device=device),
        ch_min=ch_min,
        ch_range=ch_range,
        device=device,
    )


@torch.no_grad()
def predict(
    bundle: ModelBundle,
    n_scenarios: int = 100,
    seed: int | None = None,
    progress: Callable[[int, int], None] | None = None,
) -> np.ndarray:
    """Draw `n_scenarios` 24-hour scenarios via reverse diffusion.

    Returns a `(n_scenarios, 3, 24)` `float32` array of MW values for
    (wind, solar, load) — denormalised and clamped non-negative. The
    `progress(step, total)` callback fires once per reverse step;
    Gradio's progress bar plugs in here.

    Args:
        bundle: model, schedule, normalisation stats and device, as built by
            `load_model()`.
        n_scenarios: number of independent scenarios to draw.
        seed: seed for the sampling RNG. The same seed reproduces the same
            scenarios on the same device. ``None`` draws a fresh
            non-deterministic seed on every call.
        progress: optional ``progress(step, total)`` callback, with ``step``
            running from 1 to ``total``.
    """
    device = bundle.device
    sched = bundle.schedule
    n_steps = int(sched["betas"].shape[0])

    gen = torch.Generator(device=device)
    if seed is None:
        # A freshly constructed Generator always starts from the same fixed
        # default seed, so without this every unseeded call would return
        # identical scenarios. seed() reseeds from a non-deterministic source.
        gen.seed()
    else:
        gen.manual_seed(seed)

    # Hoist the schedule lookups out of the 1000-step loop.
    betas = sched["betas"]
    alphas = sched["alphas"]
    alpha_bar = sched["alpha_bar"]
    alpha_bar_prev = sched["alpha_bar_prev"]
    post_var = sched["post_var"]

    x = torch.randn(n_scenarios, CHANNELS, SEQ_LEN, device=device, generator=gen)

    for i, step in enumerate(reversed(range(n_steps)), start=1):
        t_b = torch.full((n_scenarios,), step, device=device, dtype=torch.long)
        eps = bundle.model(x, t_b)
        beta = betas[step]
        alpha = alphas[step]
        ab = alpha_bar[step]
        ab_prev = alpha_bar_prev[step]
        # Clean-sample estimate implied by eps, clipped in normalised units to
        # keep the reverse chain bounded.
        x0_pred = ((x - (1 - ab).sqrt() * eps) / ab.sqrt()).clamp(-1.5, 1.5)
        # Mean of the Gaussian posterior q(x_{t-1} | x_t, x0_pred).
        mean = (
            (ab_prev.sqrt() * beta / (1 - ab)) * x0_pred
            + (alpha.sqrt() * (1 - ab_prev) / (1 - ab)) * x
        )
        if step > 0:
            noise = torch.randn(x.shape, device=device, generator=gen)
            x = mean + post_var[step].sqrt() * noise
        else:
            # No noise is added on the final step.
            x = mean
        if progress is not None:
            progress(i, n_steps)

    samples_norm = x.cpu().numpy()
    # Map back to MW. NOTE(review): ch_range/ch_min must broadcast against
    # (n_scenarios, CHANNELS, SEQ_LEN) — confirm the stored stats' shape.
    samples_mw = samples_norm * bundle.ch_range + bundle.ch_min
    # None of the series (wind, solar, load) can be negative in MW.
    np.clip(samples_mw, 0, None, out=samples_mw)
    return samples_mw.astype(np.float32)  # type: ignore[no-any-return]