| """ |
| PDE-SSM: Fourier-domain spatial mixing block for IRIS. |
| |
| Replaces O(N²) self-attention with O(N log N) spectral convolution. |
| Implements a learnable convection-diffusion-reaction PDE in Fourier space. |
| Native 2D — no rasterization or scanning needed. |
| |
| References: |
| - PDE-SSM-DiT (arxiv:2603.13663): Spectral state space approach |
| - FNO (arxiv:2010.08895): Fourier Neural Operator |
| - DyDiLA (arxiv:2601.13683): Token differential to prevent oversmoothing |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
|
|
|
|
class SpectralConv2d(nn.Module):
    """
    2D Fourier-domain learnable convolution.

    Applies learnable complex weights to low-frequency modes in the
    Fourier domain. This is the core of the PDE-SSM operator.

    Complexity: O(N log N) via FFT, vs O(N²) for attention.

    For 4×4 spatial grids (16 tokens), this is trivially fast,
    but the architecture is designed to scale to 8×8 (64 tokens)
    or 16×16 (256 tokens) without quadratic blowup.
    """

    def __init__(self, channels: int, modes_h: int, modes_w: int):
        """
        Args:
            channels: input/output channel count (mixing is C -> C).
            modes_h: retained frequency rows per sign along the H axis.
            modes_w: retained frequency columns along the rfft (W) axis.
        """
        super().__init__()
        self.channels = channels
        self.modes_h = modes_h
        self.modes_w = modes_w

        # Weights are stored as real tensors with trailing dim 2 = (real, imag)
        # rather than torch.cfloat, which has poor AMP support.
        # Separate weights for positive and negative H-frequencies; small init
        # scale keeps the initial spectral response close to zero.
        scale = 1.0 / (channels * channels)
        self.weight_pos = nn.Parameter(
            scale * torch.randn(channels, channels, modes_h, modes_w, 2)
        )
        self.weight_neg = nn.Parameter(
            scale * torch.randn(channels, channels, modes_h, modes_w, 2)
        )

    def _complex_mul(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """Complex matrix multiply: (B,Ci,H,W) x (Ci,Co,H,W) -> (B,Co,H,W).

        Both x and w are stored as real tensors with last dim = 2 (real, imag).
        This avoids torch.cfloat which has poor AMP support.
        """
        xr, xi = x[..., 0], x[..., 1]
        wr, wi = w[..., 0], w[..., 1]
        # (a+bi)(c+di) = (ac - bd) + (ad + bc)i, contracted over channels.
        or_ = torch.einsum("bihw,iohw->bohw", xr, wr) - torch.einsum("bihw,iohw->bohw", xi, wi)
        oi = torch.einsum("bihw,iohw->bohw", xr, wi) + torch.einsum("bihw,iohw->bohw", xi, wr)
        return torch.stack([or_, oi], dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W) real tensor
        Returns:
            (B, C, H, W) real tensor — spatially mixed
        """
        B, C, H, W = x.shape

        # FFT in fp32 for numerical stability; view as (..., 2) real pairs.
        x_ft = torch.fft.rfft2(x.float(), norm="ortho")
        x_ft = torch.view_as_real(x_ft)

        # rfft2 keeps W // 2 + 1 columns along the last spatial axis.
        out_shape = (B, C, H, W // 2 + 1, 2)
        out_ft = torch.zeros(out_shape, device=x.device, dtype=x_ft.dtype)

        # Mix the retained low-frequency modes; clamp to the actual grid size.
        mh = min(self.modes_h, H)
        mw = min(self.modes_w, W // 2 + 1)
        out_ft[:, :, :mh, :mw] = self._complex_mul(
            x_ft[:, :, :mh, :mw], self.weight_pos[:, :, :mh, :mw]
        )

        # Negative H-frequencies live in the last rows. Clamp the row count so
        # this write never overlaps the positive block above (which happened
        # previously when 2*mh > H, e.g. modes_h=2 with H=3, silently
        # clobbering already-mixed positive modes).
        nh = min(mh, H - mh)
        if nh > 0:
            out_ft[:, :, -nh:, :mw] = self._complex_mul(
                x_ft[:, :, -nh:, :mw], self.weight_neg[:, :, :nh, :mw]
            )

        # Back to the spatial domain; restore the caller's dtype (AMP-safe).
        out_ft_complex = torch.view_as_complex(out_ft)
        result = torch.fft.irfft2(out_ft_complex, s=(H, W), norm="ortho")
        return result.to(x.dtype)
|
|
|
|
class TokenDifferential(nn.Module):
    """
    Anti-oversmoothing correction for spectral/linear mixing blocks.

    From DyDiLA (arxiv:2601.13683): adds a learned high-pass
    residual that preserves local contrast.

        diff(h) = alpha * (h - AvgPool(h))

    This is critical — without it, FFT-based mixing kills edges.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Per-channel gain, zero-initialized so the module starts as identity.
        self.alpha = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W)
        Returns:
            (B, C, H, W) with high-frequency residual added
        """
        # Subtracting the global spatial mean (equivalent to
        # adaptive_avg_pool2d(x, 1)) isolates the high-frequency component.
        spatial_mean = x.mean(dim=(-2, -1), keepdim=True)
        high_pass = x - spatial_mean
        return x + self.alpha * high_pass
|
|
|
|
class PDESSMBlock(nn.Module):
    """
    Complete PDE-SSM spatial mixing block.

    Combines:
    1. Fourier-domain spectral convolution (global mixing, O(N log N))
    2. Pointwise convolution (local residual path)
    3. Token differential (anti-oversmoothing)
    4. Pre-norm + residual connection

    This replaces self-attention in the IRIS architecture.
    For a 4×4 grid (16 tokens), the FFT is 16×log(16) ≈ 64 ops
    vs 16² = 256 for attention. At 16×16 (256 tokens): 2048 vs 65536.
    """

    def __init__(self, dim: int, spatial_size: int = 4):
        """
        Args:
            dim: channel dimension
            spatial_size: H = W of the spatial grid (e.g., 4 for 4×4)
        """
        super().__init__()
        # Roughly half the grid size (Nyquist-ish), but at least 2 so tiny
        # grids still get non-trivial spectral mixing.
        modes = max(2, spatial_size // 2)

        self.norm = nn.LayerNorm(dim)
        self.spectral = SpectralConv2d(dim, modes, modes)
        # 1×1 conv provides a purely local (per-token) residual path alongside
        # the global spectral mix.
        self.local_conv = nn.Conv2d(dim, dim, kernel_size=1, bias=True)
        self.token_diff = TokenDifferential(dim)
        # Token-wise SiLU gate: output is mixed * SiLU(Linear(mixed)).
        self.gate = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        Args:
            x: (B, N, D) — sequence of spatial tokens
            H, W: spatial dimensions such that N = H * W
        Returns:
            (B, N, D) — spatially mixed tokens
        Raises:
            ValueError: if N != H * W.
        """
        B, N, D = x.shape
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if N != H * W:
            raise ValueError(f"N={N} must equal H*W={H * W}")

        residual = x
        x = self.norm(x)

        # (B, N, D) -> (B, D, H, W) for the conv/FFT operators.
        x_2d = x.view(B, H, W, D).permute(0, 3, 1, 2)

        # FFT is numerically sensitive, so force fp32 by disabling autocast
        # for the tensor's own device. Hard-coding 'cuda' here would leave
        # CPU (bf16) autocast active and defeat the fp32 protection.
        device_type = x.device.type if x.device.type in ("cuda", "cpu") else "cpu"
        with torch.amp.autocast(device_type=device_type, enabled=False):
            spectral_out = self.spectral(x_2d.float())
            local_out = self.local_conv(x_2d.float())
            mixed = spectral_out.to(x.dtype) + local_out.to(x.dtype)
            mixed = self.token_diff(mixed)

        # Back to token-sequence layout: (B, D, H, W) -> (B, N, D).
        mixed = mixed.permute(0, 2, 3, 1).reshape(B, N, D)

        # Gated output, then the pre-norm residual connection.
        mixed = mixed * self.gate(mixed)

        return residual + mixed
|
|