""" PDE-SSM: Fourier-domain spatial mixing block for IRIS. Replaces O(N²) self-attention with O(N log N) spectral convolution. Implements a learnable convection-diffusion-reaction PDE in Fourier space. Native 2D — no rasterization or scanning needed. References: - PDE-SSM-DiT (arxiv:2603.13663): Spectral state space approach - FNO (arxiv:2010.08895): Fourier Neural Operator - DyDiLA (arxiv:2601.13683): Token differential to prevent oversmoothing """ import torch import torch.nn as nn import torch.nn.functional as F import math class SpectralConv2d(nn.Module): """ 2D Fourier-domain learnable convolution. Applies learnable complex weights to low-frequency modes in the Fourier domain. This is the core of the PDE-SSM operator. Complexity: O(N log N) via FFT, vs O(N²) for attention. For 4×4 spatial grids (16 tokens), this is trivially fast, but the architecture is designed to scale to 8×8 (64 tokens) or 16×16 (256 tokens) without quadratic blowup. """ def __init__(self, channels: int, modes_h: int, modes_w: int): super().__init__() self.channels = channels self.modes_h = modes_h self.modes_w = modes_w # Learnable complex weights for low-frequency modes # Two sets: positive and negative frequency halves in H dimension # rfft2 output is (B, C, H, W//2+1) — W is halved due to Hermitian symmetry scale = 1.0 / (channels * channels) self.weight_pos = nn.Parameter( scale * torch.randn(channels, channels, modes_h, modes_w, 2) ) self.weight_neg = nn.Parameter( scale * torch.randn(channels, channels, modes_h, modes_w, 2) ) def _complex_mul(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: """Complex matrix multiply: (B,Ci,H,W) x (Ci,Co,H,W) -> (B,Co,H,W). Both x and w are stored as real tensors with last dim = 2 (real, imag). This avoids torch.cfloat which has poor AMP support. """ # x: (B, Ci, H, W, 2), w: (Ci, Co, H, W, 2) # Real part: xr*wr - xi*wi # Imag part: xr*wi + xi*wr xr, xi = x[..., 0], x[..., 1] wr, wi = w[..., 0], w[..., 1] or_ = torch.einsum("bihw,iohw->bohw", xr, wr) - torch.einsum("bihw,iohw->bohw", xi, wi) oi = torch.einsum("bihw,iohw->bohw", xr, wi) + torch.einsum("bihw,iohw->bohw", xi, wr) return torch.stack([or_, oi], dim=-1) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: (B, C, H, W) real tensor Returns: (B, C, H, W) real tensor — spatially mixed """ B, C, H, W = x.shape # Forward FFT: real -> complex spectrum # Use view_as_real to get (B, C, H, W//2+1, 2) for AMP compatibility x_ft = torch.fft.rfft2(x.float(), norm="ortho") x_ft = torch.view_as_real(x_ft) # (B, C, H, W//2+1, 2) # Output spectrum (zero-initialized) out_shape = (B, C, H, W // 2 + 1, 2) out_ft = torch.zeros(out_shape, device=x.device, dtype=x_ft.dtype) # Low-frequency positive modes (top of spectrum) mh = min(self.modes_h, H) mw = min(self.modes_w, W // 2 + 1) out_ft[:, :, :mh, :mw] = self._complex_mul( x_ft[:, :, :mh, :mw], self.weight_pos[:, :, :mh, :mw] ) # Low-frequency negative modes (bottom of spectrum, wraps around) if mh > 0 and H > mh: out_ft[:, :, -mh:, :mw] = self._complex_mul( x_ft[:, :, -mh:, :mw], self.weight_neg[:, :, :mh, :mw] ) # Inverse FFT back to spatial domain out_ft_complex = torch.view_as_complex(out_ft) result = torch.fft.irfft2(out_ft_complex, s=(H, W), norm="ortho") return result.to(x.dtype) class TokenDifferential(nn.Module): """ Prevents oversmoothing in spectral/linear blocks. From DyDiLA (arxiv:2601.13683): adds a learned high-pass residual that preserves local contrast. diff(h) = alpha * (h - AvgPool(h)) This is critical — without it, FFT-based mixing kills edges. """ def __init__(self, channels: int): super().__init__() self.alpha = nn.Parameter(torch.zeros(1, channels, 1, 1)) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: (B, C, H, W) Returns: (B, C, H, W) with high-frequency residual added """ # Average pool with kernel size matching spatial dims # For small grids (4×4), use adaptive pool to single value avg = F.adaptive_avg_pool2d(x, 1) # (B, C, 1, 1) — global average return x + self.alpha * (x - avg) class PDESSMBlock(nn.Module): """ Complete PDE-SSM spatial mixing block. Combines: 1. Fourier-domain spectral convolution (global mixing, O(N log N)) 2. Pointwise convolution (local residual path) 3. Token differential (anti-oversmoothing) 4. Pre-norm + residual connection This replaces self-attention in the IRIS architecture. For a 4×4 grid (16 tokens), the FFT is 16×log(16) ≈ 64 ops vs 16² = 256 for attention. At 16×16 (256 tokens): 2048 vs 65536. """ def __init__(self, dim: int, spatial_size: int = 4): """ Args: dim: channel dimension spatial_size: H = W of the spatial grid (e.g., 4 for 4×4) """ super().__init__() # Keep ~50% of frequency modes. For size=4, modes=2. modes = max(2, spatial_size // 2) self.norm = nn.LayerNorm(dim) self.spectral = SpectralConv2d(dim, modes, modes) self.local_conv = nn.Conv2d(dim, dim, kernel_size=1, bias=True) self.token_diff = TokenDifferential(dim) self.gate = nn.Sequential( nn.Linear(dim, dim), nn.SiLU(), ) def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: """ Args: x: (B, N, D) — sequence of spatial tokens H, W: spatial dimensions such that N = H * W Returns: (B, N, D) — spatially mixed tokens """ B, N, D = x.shape assert N == H * W, f"N={N} must equal H*W={H*W}" residual = x # Pre-norm x = self.norm(x) # Reshape to 2D spatial grid for FFT x_2d = x.view(B, H, W, D).permute(0, 3, 1, 2) # (B, D, H, W) # Spectral mixing (global) + Local conv (residual) + Token diff (anti-smoothing) # Run conv in float32 — grouped/1x1 convs lack bf16 cuDNN kernels on T4 with torch.amp.autocast(device_type='cuda', enabled=False): spectral_out = self.spectral(x_2d.float()) local_out = self.local_conv(x_2d.float()) mixed = spectral_out.to(x.dtype) + local_out.to(x.dtype) mixed = self.token_diff(mixed) # Back to sequence format mixed = mixed.permute(0, 2, 3, 1).reshape(B, N, D) # (B, N, D) # Gated output mixed = mixed * self.gate(mixed) return residual + mixed