File size: 7,149 Bytes
74cbbdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c17293
 
 
 
 
74cbbdd
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""
PDE-SSM: Fourier-domain spatial mixing block for IRIS.

Replaces O(N²) self-attention with O(N log N) spectral convolution.
Implements a learnable convection-diffusion-reaction PDE in Fourier space.
Native 2D — no rasterization or scanning needed.

References:
- PDE-SSM-DiT (arxiv:2603.13663): Spectral state space approach
- FNO (arxiv:2010.08895): Fourier Neural Operator
- DyDiLA (arxiv:2601.13683): Token differential to prevent oversmoothing
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class SpectralConv2d(nn.Module):
    """
    2D Fourier-domain learnable convolution.

    Applies learnable complex weights to the low-frequency modes of the
    input's 2D real FFT, then transforms back. This is the core of the
    PDE-SSM operator: global spatial mixing at O(N log N) via FFT, versus
    O(N²) for self-attention.

    Weights are stored as real tensors with a trailing dim of 2
    (real, imag) rather than torch.cfloat, which has poor AMP support.

    For 4×4 spatial grids (16 tokens) this is trivially fast, but the
    architecture is designed to scale to 8×8 (64 tokens) or 16×16
    (256 tokens) without quadratic blowup.

    Args:
        channels: number of input/output channels (square mixing matrix).
        modes_h: retained frequency modes along H.
        modes_w: retained frequency modes along W (out of W//2+1).
    """

    def __init__(self, channels: int, modes_h: int, modes_w: int):
        super().__init__()
        self.channels = channels
        self.modes_h = modes_h
        self.modes_w = modes_w

        # Two weight sets: one for the positive-frequency rows (top of the
        # spectrum) and one for the negative-frequency rows (bottom, which
        # wraps around). rfft2 output is (B, C, H, W//2+1): the W axis is
        # already halved by Hermitian symmetry, so no negative-W set is
        # needed.
        scale = 1.0 / (channels * channels)
        self.weight_pos = nn.Parameter(
            scale * torch.randn(channels, channels, modes_h, modes_w, 2)
        )
        self.weight_neg = nn.Parameter(
            scale * torch.randn(channels, channels, modes_h, modes_w, 2)
        )

    def _complex_mul(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """Complex channel-mixing multiply: (B,Ci,H,W) x (Ci,Co,H,W) -> (B,Co,H,W).

        Both x and w are stored as real tensors with last dim = 2
        (real, imag), avoiding torch.cfloat for AMP compatibility.
        """
        # x: (B, Ci, H, W, 2), w: (Ci, Co, H, W, 2)
        # (a+bi)(c+di) = (ac - bd) + (ad + bc)i, contracted over Ci.
        xr, xi = x[..., 0], x[..., 1]
        wr, wi = w[..., 0], w[..., 1]
        or_ = torch.einsum("bihw,iohw->bohw", xr, wr) - torch.einsum("bihw,iohw->bohw", xi, wi)
        oi = torch.einsum("bihw,iohw->bohw", xr, wi) + torch.einsum("bihw,iohw->bohw", xi, wr)
        return torch.stack([or_, oi], dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W) real tensor
        Returns:
            (B, C, H, W) real tensor — spatially mixed
        """
        B, C, H, W = x.shape

        # Forward FFT in float32 (AMP-safe), viewed as (real, imag) pairs.
        x_ft = torch.fft.rfft2(x.float(), norm="ortho")
        x_ft = torch.view_as_real(x_ft)  # (B, C, H, W//2+1, 2)

        # Output spectrum: zero everywhere except the retained modes.
        out_ft = torch.zeros(
            (B, C, H, W // 2 + 1, 2), device=x.device, dtype=x_ft.dtype
        )

        # Clamp retained modes to what the input actually provides.
        mh = min(self.modes_h, H)
        mw = min(self.modes_w, W // 2 + 1)

        # Low-frequency positive modes (top of spectrum).
        out_ft[:, :, :mh, :mw] = self._complex_mul(
            x_ft[:, :, :mh, :mw], self.weight_pos[:, :, :mh, :mw]
        )

        # Low-frequency negative modes (bottom of spectrum, wraps around).
        # Clamp to at most H - mh rows so this write never overlaps the
        # positive block (bug fix: for H < 2*mh the slices [:mh] and [-mh:]
        # used to overlap, and the negative write silently clobbered part
        # of the positive-mode output). For the normal H >= 2*mh case this
        # is identical to the previous behavior.
        mh_neg = min(mh, H - mh)
        if mh_neg > 0:
            out_ft[:, :, -mh_neg:, :mw] = self._complex_mul(
                x_ft[:, :, -mh_neg:, :mw], self.weight_neg[:, :, :mh_neg, :mw]
            )

        # Inverse FFT back to the spatial domain; restore caller dtype.
        out_ft_complex = torch.view_as_complex(out_ft)
        result = torch.fft.irfft2(out_ft_complex, s=(H, W), norm="ortho")
        return result.to(x.dtype)


class TokenDifferential(nn.Module):
    """
    Learned high-pass residual that counteracts oversmoothing.

    From DyDiLA (arxiv:2601.13683): spectral/linear mixing tends to wash
    out local contrast, so each token's deviation from the global spatial
    mean is re-injected, scaled by a per-channel learnable gain:

        out = h + alpha * (h - GlobalAvgPool(h))

    alpha is initialized to zero, so the module starts as an identity.
    Without this residual, FFT-based mixing kills edges.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Per-channel gain, broadcast over batch and spatial dims.
        self.alpha = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W)
        Returns:
            (B, C, H, W) input plus scaled high-frequency residual
        """
        # Global spatial average per channel: (B, C, 1, 1). Adaptive pool
        # to a single value handles any grid size, including tiny 4×4.
        pooled = F.adaptive_avg_pool2d(x, 1)
        high_pass = x - pooled
        return x + self.alpha * high_pass


class PDESSMBlock(nn.Module):
    """
    PDE-SSM spatial mixing block — drop-in replacement for self-attention.

    Pipeline (pre-norm residual):
      1. Spectral convolution in the Fourier domain (global mixing,
         O(N log N)).
      2. Pointwise 1×1 convolution (local path).
      3. Token differential (anti-oversmoothing residual).
      4. SiLU-gated output added back onto the input.

    Cost comparison: for a 4×4 grid (16 tokens) the FFT is
    16·log(16) ≈ 64 ops vs 16² = 256 for attention; at 16×16
    (256 tokens): 2048 vs 65536.
    """

    def __init__(self, dim: int, spatial_size: int = 4):
        """
        Args:
            dim: channel dimension
            spatial_size: H = W of the spatial grid (e.g., 4 for 4×4)
        """
        super().__init__()
        # Retain roughly half the frequency modes (size=4 -> 2 modes).
        n_modes = max(2, spatial_size // 2)

        # NOTE: submodule construction order is part of the init contract —
        # it fixes RNG consumption order for seeded parameter init.
        self.norm = nn.LayerNorm(dim)
        self.spectral = SpectralConv2d(dim, n_modes, n_modes)
        self.local_conv = nn.Conv2d(dim, dim, kernel_size=1, bias=True)
        self.token_diff = TokenDifferential(dim)
        self.gate = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        Args:
            x: (B, N, D) — sequence of spatial tokens
            H, W: spatial dimensions such that N = H * W
        Returns:
            (B, N, D) — spatially mixed tokens
        """
        B, N, D = x.shape
        assert N == H * W, f"N={N} must equal H*W={H*W}"

        skip = x

        # Pre-norm, then lay tokens out as a (B, D, H, W) grid for the
        # convolution/FFT paths.
        h = self.norm(x)
        grid = h.view(B, H, W, D).permute(0, 3, 1, 2)

        # Force float32 here: grouped/1×1 convs lack bf16 cuDNN kernels
        # on T4, and the spectral path computes its FFT in float32 anyway.
        with torch.amp.autocast(device_type='cuda', enabled=False):
            global_path = self.spectral(grid.float())
            local_path = self.local_conv(grid.float())

        # Global + local mixing, then the anti-oversmoothing residual.
        fused = global_path.to(h.dtype) + local_path.to(h.dtype)
        fused = self.token_diff(fused)

        # Back to sequence layout (B, N, D) and apply the SiLU gate.
        fused = fused.permute(0, 2, 3, 1).reshape(B, N, D)
        fused = fused * self.gate(fused)

        return skip + fused