"""
PDE-SSM: Fourier-domain spatial mixing block for IRIS.
Replaces O(N²) self-attention with O(N log N) spectral convolution.
Implements a learnable convection-diffusion-reaction PDE in Fourier space.
Native 2D — no rasterization or scanning needed.
References:
- PDE-SSM-DiT (arxiv:2603.13663): Spectral state space approach
- FNO (arxiv:2010.08895): Fourier Neural Operator
- DyDiLA (arxiv:2601.13683): Token differential to prevent oversmoothing
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SpectralConv2d(nn.Module):
    """
    2D Fourier-domain learnable convolution.

    Applies learnable complex weights to low-frequency modes in the
    Fourier domain. This is the core of the PDE-SSM operator.

    Complexity: O(N log N) via FFT, vs O(N²) for attention.

    For 4×4 spatial grids (16 tokens), this is trivially fast,
    but the architecture is designed to scale to 8×8 (64 tokens)
    or 16×16 (256 tokens) without quadratic blowup.
    """

    def __init__(self, channels: int, modes_h: int, modes_w: int):
        """
        Args:
            channels: number of input (and output) channels.
            modes_h: retained frequency modes along H (per half-spectrum).
            modes_w: retained frequency modes along the rfft W axis
                (which has W//2+1 bins due to Hermitian symmetry).
        """
        super().__init__()
        self.channels = channels
        self.modes_h = modes_h
        self.modes_w = modes_w
        # Learnable complex weights for low-frequency modes.
        # Two sets: positive and negative frequency halves in H dimension.
        # rfft2 output is (B, C, H, W//2+1) — W is halved due to Hermitian symmetry.
        # Weights are stored as real tensors with a trailing dim of 2
        # (real, imag) for AMP compatibility.
        scale = 1.0 / (channels * channels)
        self.weight_pos = nn.Parameter(
            scale * torch.randn(channels, channels, modes_h, modes_w, 2)
        )
        self.weight_neg = nn.Parameter(
            scale * torch.randn(channels, channels, modes_h, modes_w, 2)
        )

    def _complex_mul(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """Complex matrix multiply: (B,Ci,H,W) x (Ci,Co,H,W) -> (B,Co,H,W).

        Both x and w are stored as real tensors with last dim = 2 (real, imag).
        This avoids torch.cfloat which has poor AMP support.
        """
        # x: (B, Ci, H, W, 2), w: (Ci, Co, H, W, 2)
        # (a+bi)(c+di) = (ac - bd) + (ad + bc)i, contracted over in-channels.
        xr, xi = x[..., 0], x[..., 1]
        wr, wi = w[..., 0], w[..., 1]
        or_ = torch.einsum("bihw,iohw->bohw", xr, wr) - torch.einsum("bihw,iohw->bohw", xi, wi)
        oi = torch.einsum("bihw,iohw->bohw", xr, wi) + torch.einsum("bihw,iohw->bohw", xi, wr)
        return torch.stack([or_, oi], dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W) real tensor
        Returns:
            (B, C, H, W) real tensor — spatially mixed
        """
        B, C, H, W = x.shape
        # Forward FFT: real -> complex spectrum.
        # Use view_as_real to get (B, C, H, W//2+1, 2) for AMP compatibility.
        x_ft = torch.fft.rfft2(x.float(), norm="ortho")
        x_ft = torch.view_as_real(x_ft)  # (B, C, H, W//2+1, 2)
        # Output spectrum (zero-initialized: discarded modes stay zero).
        out_shape = (B, C, H, W // 2 + 1, 2)
        out_ft = torch.zeros(out_shape, device=x.device, dtype=x_ft.dtype)
        # Low-frequency positive modes (top of spectrum).
        mh = min(self.modes_h, H)
        mw = min(self.modes_w, W // 2 + 1)
        out_ft[:, :, :mh, :mw] = self._complex_mul(
            x_ft[:, :, :mh, :mw], self.weight_pos[:, :, :mh, :mw]
        )
        # Low-frequency negative modes (bottom of spectrum, wraps around).
        # BUGFIX: clamp to the rows NOT already written by the positive
        # branch. Previously, when 2*mh > H the slice [-mh:] overlapped
        # rows [:mh] and silently overwrote them. For normal configs
        # (modes_h <= H//2) this is unchanged.
        mh_neg = min(mh, H - mh)
        if mh_neg > 0:
            out_ft[:, :, -mh_neg:, :mw] = self._complex_mul(
                x_ft[:, :, -mh_neg:, :mw], self.weight_neg[:, :, :mh_neg, :mw]
            )
        # Inverse FFT back to spatial domain (same spatial size as input).
        out_ft_complex = torch.view_as_complex(out_ft)
        result = torch.fft.irfft2(out_ft_complex, s=(H, W), norm="ortho")
        return result.to(x.dtype)
class TokenDifferential(nn.Module):
    """
    Prevents oversmoothing in spectral/linear blocks.

    From DyDiLA (arxiv:2601.13683): adds a learned high-pass
    residual that preserves local contrast.

        diff(h) = alpha * (h - AvgPool(h))

    This is critical — without it, FFT-based mixing kills edges.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Per-channel gain on the high-pass residual. Zero-initialized,
        # so the module starts out as an exact identity.
        self.alpha = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W)
        Returns:
            (B, C, H, W) with high-frequency residual added
        """
        # Global per-channel spatial mean — equivalent to adaptively
        # average-pooling down to a single 1×1 cell.
        spatial_mean = x.mean(dim=(2, 3), keepdim=True)  # (B, C, 1, 1)
        high_pass = x - spatial_mean
        return x + self.alpha * high_pass
class PDESSMBlock(nn.Module):
    """
    Complete PDE-SSM spatial mixing block.

    Combines:
    1. Fourier-domain spectral convolution (global mixing, O(N log N))
    2. Pointwise convolution (local residual path)
    3. Token differential (anti-oversmoothing)
    4. Pre-norm + residual connection

    This replaces self-attention in the IRIS architecture.
    For a 4×4 grid (16 tokens), the FFT is 16×log(16) ≈ 64 ops
    vs 16² = 256 for attention. At 16×16 (256 tokens): 2048 vs 65536.
    """

    def __init__(self, dim: int, spatial_size: int = 4):
        """
        Args:
            dim: channel dimension
            spatial_size: H = W of the spatial grid (e.g., 4 for 4×4)
        """
        super().__init__()
        # Keep ~50% of frequency modes. For size=4, modes=2.
        modes = max(2, spatial_size // 2)
        self.norm = nn.LayerNorm(dim)
        self.spectral = SpectralConv2d(dim, modes, modes)
        self.local_conv = nn.Conv2d(dim, dim, kernel_size=1, bias=True)
        self.token_diff = TokenDifferential(dim)
        self.gate = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        Args:
            x: (B, N, D) — sequence of spatial tokens
            H, W: spatial dimensions such that N = H * W
        Returns:
            (B, N, D) — spatially mixed tokens
        """
        B, N, D = x.shape
        assert N == H * W, f"N={N} must equal H*W={H*W}"
        residual = x
        # Pre-norm
        x = self.norm(x)
        # Reshape to 2D spatial grid for FFT
        x_2d = x.view(B, H, W, D).permute(0, 3, 1, 2)  # (B, D, H, W)
        # Spectral mixing (global) + Local conv (residual) + Token diff (anti-smoothing)
        # Run conv in float32 — grouped/1x1 convs lack bf16 cuDNN kernels on T4.
        # BUGFIX: derive the autocast device type from the tensor instead of
        # hard-coding 'cuda', which warns/misbehaves on CPU-only hosts.
        amp_device = x_2d.device.type if x_2d.device.type in ("cuda", "cpu") else "cpu"
        with torch.amp.autocast(device_type=amp_device, enabled=False):
            spectral_out = self.spectral(x_2d.float())
            local_out = self.local_conv(x_2d.float())
            mixed = spectral_out.to(x.dtype) + local_out.to(x.dtype)
            mixed = self.token_diff(mixed)
        # Back to sequence format
        mixed = mixed.permute(0, 2, 3, 1).reshape(B, N, D)  # (B, N, D)
        # Gated output (SiLU gate on a linear projection of the mixed tokens)
        mixed = mixed * self.gate(mixed)
        return residual + mixed