# iris-image-gen / iris/pde_ssm.py
# Latest commit 8c17293 (verified) by asdf98: "Fix conv2d bf16 crash on T4: iris/pde_ssm.py"
"""
PDE-SSM: Fourier-domain spatial mixing block for IRIS.
Replaces O(N²) self-attention with O(N log N) spectral convolution.
Implements a learnable convection-diffusion-reaction PDE in Fourier space.
Native 2D — no rasterization or scanning needed.
References:
- PDE-SSM-DiT (arxiv:2603.13663): Spectral state space approach
- FNO (arxiv:2010.08895): Fourier Neural Operator
- DyDiLA (arxiv:2601.13683): Token differential to prevent oversmoothing
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SpectralConv2d(nn.Module):
    """
    2D Fourier-domain learnable convolution.

    Applies learnable complex weights to the low-frequency modes of the
    real 2D FFT of the input and zeroes every other mode. This is the core
    of the PDE-SSM operator.

    Complexity: O(N log N) via FFT, vs O(N²) for attention.
    For 4×4 spatial grids (16 tokens), this is trivially fast,
    but the architecture is designed to scale to 8×8 (64 tokens)
    or 16×16 (256 tokens) without quadratic blowup.

    Args:
        channels: number of input (and output) channels.
        modes_h: number of modes kept per band along H; one band at the
            start of the axis, one at the end.
        modes_w: number of modes kept along the (halved) W axis.
    """

    def __init__(self, channels: int, modes_h: int, modes_w: int):
        super().__init__()
        self.channels = channels
        self.modes_h = modes_h
        self.modes_w = modes_w

        # Learnable complex weights for the retained modes, stored as real
        # tensors with a trailing (real, imag) axis of size 2.
        # Two sets: one for the band at the start of the H axis, one for the
        # band at the end of it. rfft2 output is (B, C, H, W//2+1), so only
        # one band is needed along W.
        scale = 1.0 / (channels * channels)
        self.weight_pos = nn.Parameter(
            scale * torch.randn(channels, channels, modes_h, modes_w, 2)
        )
        self.weight_neg = nn.Parameter(
            scale * torch.randn(channels, channels, modes_h, modes_w, 2)
        )

    def _complex_mul(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """Per-mode complex channel mixing.

        Args:
            x: (B, Ci, H, W, 2) spectrum, last axis = (real, imag).
            w: (Ci, Co, H, W, 2) weights, last axis = (real, imag).

        Returns:
            (B, Co, H, W, 2) tensor holding the complex product summed over
            the input-channel axis.

        Both operands are kept as real tensors instead of torch.cfloat; the
        original author notes this is for AMP compatibility.
        """
        xr, xi = x[..., 0], x[..., 1]
        wr, wi = w[..., 0], w[..., 1]
        # (xr + i*xi) * (wr + i*wi) = (xr*wr - xi*wi) + i*(xr*wi + xi*wr)
        out_real = (
            torch.einsum("bihw,iohw->bohw", xr, wr)
            - torch.einsum("bihw,iohw->bohw", xi, wi)
        )
        out_imag = (
            torch.einsum("bihw,iohw->bohw", xr, wi)
            + torch.einsum("bihw,iohw->bohw", xi, wr)
        )
        return torch.stack([out_real, out_imag], dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W) real tensor

        Returns:
            (B, C, H, W) real tensor — spatially mixed, cast back to x.dtype
        """
        B, C, H, W = x.shape
        half_w = W // 2 + 1

        # Forward FFT in float32, then expose the complex spectrum as a real
        # tensor of shape (B, C, H, W//2+1, 2).
        x_ft = torch.view_as_real(torch.fft.rfft2(x.float(), norm="ortho"))

        # Output spectrum starts at zero; modes that are not written below
        # are therefore dropped.
        out_ft = torch.zeros(
            (B, C, H, half_w, 2), device=x.device, dtype=x_ft.dtype
        )

        # Clamp the requested mode counts to what this input actually has.
        mh = min(self.modes_h, H)
        mw = min(self.modes_w, half_w)

        # Band at the start of the H axis.
        out_ft[:, :, :mh, :mw] = self._complex_mul(
            x_ft[:, :, :mh, :mw], self.weight_pos[:, :, :mh, :mw]
        )

        # Band at the end of the H axis (negative frequencies). Only rows not
        # already covered by the first band are written: when mh < H < 2*mh a
        # plain `-mh:` slice would overlap `:mh` and silently overwrite the
        # result computed above, leaving part of weight_pos without gradient.
        mh_neg = min(mh, H - mh)
        if mh_neg > 0:
            # Use the trailing weight rows so that the last row of weight_neg
            # keeps multiplying the last spectrum row, exactly as it does
            # when the two bands do not overlap.
            out_ft[:, :, H - mh_neg:, :mw] = self._complex_mul(
                x_ft[:, :, H - mh_neg:, :mw],
                self.weight_neg[:, :, mh - mh_neg:mh, :mw],
            )

        # Inverse FFT back to the spatial domain at the original size.
        result = torch.fft.irfft2(
            torch.view_as_complex(out_ft), s=(H, W), norm="ortho"
        )
        return result.to(x.dtype)
class TokenDifferential(nn.Module):
    """
    Learned high-pass residual that counteracts oversmoothing.

    Attributed by the original author to DyDiLA (arxiv:2601.13683).
    Computes

        out = x + alpha * (x - mean_{H,W}(x))

    with one learnable ``alpha`` per channel. ``alpha`` is initialised to
    zero, so a freshly constructed module returns its input unchanged.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Shape (1, C, 1, 1) so it broadcasts over batch and spatial axes.
        per_channel_zeros = torch.zeros(1, channels, 1, 1)
        self.alpha = nn.Parameter(per_channel_zeros)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W)

        Returns:
            (B, C, H, W) — input plus the scaled deviation from its
            per-channel spatial mean
        """
        # Pooling to a 1×1 output yields the global spatial mean, (B, C, 1, 1).
        spatial_mean = F.adaptive_avg_pool2d(x, 1)
        deviation = x - spatial_mean
        return x + self.alpha * deviation
class PDESSMBlock(nn.Module):
    """
    PDE-SSM spatial mixing block — stands in for self-attention in IRIS.

    Steps applied to a (B, N, D) token sequence:
      1. LayerNorm (pre-norm)
      2. Reshape to a (B, D, H, W) grid
      3. SpectralConv2d (global path) + 1×1 Conv2d (pointwise path), summed
      4. TokenDifferential (anti-oversmoothing)
      5. Self-gating with SiLU(Linear(.))
      6. Residual connection back onto the un-normalised input

    Cost intuition from the original author: on a 4×4 grid (16 tokens) the
    FFT is 16×log(16) ≈ 64 ops vs 16² = 256 for attention; at 16×16
    (256 tokens) it is 2048 vs 65536.
    """

    def __init__(self, dim: int, spatial_size: int = 4):
        """
        Args:
            dim: channel dimension
            spatial_size: H = W of the spatial grid (e.g., 4 for 4×4)
        """
        super().__init__()

        # Roughly half of the modes per axis, but never fewer than 2
        # (spatial_size=4 gives 2).
        half_size = spatial_size // 2
        num_modes = half_size if half_size > 2 else 2

        self.norm = nn.LayerNorm(dim)
        self.spectral = SpectralConv2d(dim, num_modes, num_modes)
        self.local_conv = nn.Conv2d(dim, dim, kernel_size=1, bias=True)
        self.token_diff = TokenDifferential(dim)
        self.gate = nn.Sequential(nn.Linear(dim, dim), nn.SiLU())

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        Args:
            x: (B, N, D) — sequence of spatial tokens
            H, W: spatial dimensions such that N = H * W

        Returns:
            (B, N, D) — spatially mixed tokens
        """
        B, N, D = x.shape
        assert N == H * W, f"N={N} must equal H*W={H*W}"

        # Pre-norm; the raw input is kept for the skip connection.
        normed = self.norm(x)
        work_dtype = normed.dtype

        # Sequence -> channels-first grid, (B, D, H, W).
        grid = normed.view(B, H, W, D).permute(0, 3, 1, 2)

        # Both mixing paths run in float32 with CUDA autocast switched off;
        # per the original author, grouped/1x1 convs lack bf16 cuDNN kernels
        # on T4.
        with torch.amp.autocast(device_type='cuda', enabled=False):
            grid_fp32 = grid.float()
            global_path = self.spectral(grid_fp32)
            pointwise_path = self.local_conv(grid_fp32)

        # Sum the two paths in the post-norm dtype, then re-inject contrast.
        fused = global_path.to(work_dtype) + pointwise_path.to(work_dtype)
        fused = self.token_diff(fused)

        # Grid -> sequence, (B, N, D).
        tokens = fused.permute(0, 2, 3, 1).reshape(B, N, D)

        # Self-gated output added onto the skip path.
        gated = tokens * self.gate(tokens)
        return x + gated