| """ |
| U-Net architecture for conditional diffusion on spatiotemporal PDE data. |
| Supports non-square inputs, time conditioning, and skip connections. |
| """ |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class SinusoidalPosEmb(nn.Module):
    """Embed scalar diffusion timesteps as `dim`-dimensional vectors.

    Standard transformer-style scheme: the first half of the channels are
    sines and the second half cosines, over log-spaced frequencies from 1
    down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        # Geometric frequency ladder: freqs[i] = 10000^(-i / (half_dim - 1)).
        scale = -math.log(10000) / (half_dim - 1)
        freqs = torch.exp(scale * torch.arange(half_dim, device=t.device))
        # Outer product of timesteps [B] with frequencies [half_dim] -> [B, half_dim].
        args = t.float().unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat((args.sin(), args.cos()), dim=-1)
|
|
|
|
class ResBlock(nn.Module):
    """Pre-activation residual block (GN -> SiLU -> conv, twice).

    The diffusion timestep embedding is projected to a per-channel bias and
    added between the two convolutions. A 1x1 projection on the identity
    path keeps the residual add shape-compatible when channel counts differ.
    """

    def __init__(self, in_ch, out_ch, time_dim, dropout=0.1):
        super().__init__()
        # NOTE: submodule registration order is kept stable so seeded
        # parameter initialization is reproducible.
        self.norm1 = nn.GroupNorm(min(32, in_ch), in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, out_ch))
        self.norm2 = nn.GroupNorm(min(32, out_ch), out_ch)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        if in_ch != out_ch:
            self.skip = nn.Conv2d(in_ch, out_ch, 1)
        else:
            self.skip = nn.Identity()

    def forward(self, x, t_emb):
        out = self.conv1(F.silu(self.norm1(x)))
        # Broadcast the [B, out_ch] time projection over the spatial dims.
        out = out + self.time_mlp(t_emb).unsqueeze(-1).unsqueeze(-1)
        out = self.conv2(self.dropout(F.silu(self.norm2(out))))
        return self.skip(x) + out
|
|
|
|
class SelfAttention(nn.Module):
    """Multi-head self-attention over spatial positions, with residual add."""

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.norm = nn.GroupNorm(min(32, channels), channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x):
        batch, ch, height, width = x.shape
        # Flatten the spatial grid into a token sequence: [B, H*W, C].
        tokens = self.norm(x).flatten(2).transpose(1, 2)
        attended, _ = self.attn(tokens, tokens, tokens)
        # Restore channels-first layout [B, C, H, W] and apply the residual.
        attended = attended.transpose(1, 2).reshape(batch, ch, height, width)
        return x + attended
|
|
|
|
class Downsample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, ch):
        super().__init__()
        # Stride-2 conv: output size is ceil(H/2) x ceil(W/2).
        self.conv = nn.Conv2d(ch, ch, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)
|
|
|
|
class Upsample(nn.Module):
    """Double spatial resolution: nearest-neighbor resize then a 3x3 conv."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, kernel_size=3, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(upsampled)
|
|
|
|
class UNet(nn.Module):
    """U-Net noise predictor for conditional diffusion.

    The condition (e.g. the previous frame) is concatenated to the noisy
    input along the channel dimension — either by the caller before calling
    ``forward`` or via the ``cond`` argument — so set
    ``in_channels = out_channels + condition_channels``.

    Non-square and odd spatial sizes are supported: decoder features are
    cropped to the matching skip connection's spatial size before
    concatenation, since nearest-neighbor upsampling of an odd dimension
    overshoots by one pixel (``2 * ceil(H/2) == H + 1``).

    Args:
        in_channels: noisy-target channels + condition channels.
        out_channels: channels to predict (same as target).
        base_ch: base channel width.
        ch_mults: per-level channel multipliers.
        n_res: residual blocks per level.
        attn_levels: which levels get self-attention (0-indexed).
        dropout: dropout rate.
        time_dim: timestep embedding dimension.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        base_ch=64,
        ch_mults=(1, 2, 4, 8),
        n_res=2,
        attn_levels=(3,),
        dropout=0.1,
        time_dim=256,
    ):
        super().__init__()
        self.n_res = n_res
        self.ch_mults = ch_mults

        # Timestep embedding: sinusoidal features refined by a small MLP.
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, time_dim * 4),
            nn.SiLU(),
            nn.Linear(time_dim * 4, time_dim),
        )

        self.input_conv = nn.Conv2d(in_channels, base_ch, 3, padding=1)

        # --- Encoder ---
        self.downs = nn.ModuleList()
        ch = base_ch
        # Channel widths of every feature pushed onto the skip stack; the
        # decoder pops them in reverse to size its concatenated inputs.
        skip_chs = [ch]

        for lvl, mult in enumerate(ch_mults):
            out_ch = base_ch * mult
            for _ in range(n_res):
                self.downs.append(
                    nn.ModuleDict(
                        {
                            "res": ResBlock(ch, out_ch, time_dim, dropout),
                            **(
                                {"attn": SelfAttention(out_ch)}
                                if lvl in attn_levels
                                else {}
                            ),
                        }
                    )
                )
                ch = out_ch
                skip_chs.append(ch)
            if lvl < len(ch_mults) - 1:  # no downsample after the deepest level
                self.downs.append(nn.ModuleDict({"down": Downsample(ch)}))
                skip_chs.append(ch)

        # --- Bottleneck ---
        self.mid_res1 = ResBlock(ch, ch, time_dim, dropout)
        self.mid_attn = SelfAttention(ch)
        self.mid_res2 = ResBlock(ch, ch, time_dim, dropout)

        # --- Decoder (mirrors the encoder, consuming skips in reverse) ---
        self.ups = nn.ModuleList()
        for lvl in reversed(range(len(ch_mults))):
            out_ch = base_ch * ch_mults[lvl]
            # n_res + 1 blocks per level: one extra to consume the skip saved
            # at the downsample (or input-conv) boundary.
            for _ in range(n_res + 1):
                skip_ch = skip_chs.pop()
                self.ups.append(
                    nn.ModuleDict(
                        {
                            "res": ResBlock(ch + skip_ch, out_ch, time_dim, dropout),
                            **(
                                {"attn": SelfAttention(out_ch)}
                                if lvl in attn_levels
                                else {}
                            ),
                        }
                    )
                )
                ch = out_ch
            if lvl > 0:
                self.ups.append(nn.ModuleDict({"up": Upsample(ch)}))

        self.out_norm = nn.GroupNorm(min(32, ch), ch)
        self.out_conv = nn.Conv2d(ch, out_channels, 3, padding=1)

    def forward(self, x, t, cond=None):
        """Predict the noise added to ``x`` at diffusion step ``t``.

        Args:
            x: noisy target [B, C_out, H, W]
            t: diffusion timestep [B] (int or float)
            cond: condition [B, C_cond, H, W] (optional; concatenated to x)
        Returns:
            predicted noise [B, C_out, H, W]
        """
        if cond is not None:
            x = torch.cat([x, cond], dim=1)

        t_emb = self.time_embed(t)
        h = self.input_conv(x)

        # Encoder: record a skip after every res block and every downsample.
        skips = [h]
        for block in self.downs:
            if "down" in block:
                h = block["down"](h)
                skips.append(h)
            else:
                h = block["res"](h, t_emb)
                if "attn" in block:
                    h = block["attn"](h)
                skips.append(h)

        h = self.mid_res1(h, t_emb)
        h = self.mid_attn(h)
        h = self.mid_res2(h, t_emb)

        # Decoder: pop skips in reverse and fuse by channel concatenation.
        for block in self.ups:
            if "up" in block:
                h = block["up"](h)
            else:
                s = skips.pop()
                # BUG FIX: for odd spatial sizes the stride-2 downsample gives
                # ceil(H/2), so the subsequent 2x upsample yields H + 1 and the
                # concat below would fail. The upsampled h is never smaller
                # than the skip, so cropping to the skip's size is exact (and
                # a no-op for even dimensions).
                if h.shape[-2:] != s.shape[-2:]:
                    h = h[..., : s.shape[-2], : s.shape[-1]]
                h = torch.cat([h, s], dim=1)
                h = block["res"](h, t_emb)
                if "attn" in block:
                    h = block["attn"](h)

        h = F.silu(self.out_norm(h))
        return self.out_conv(h)
|
|