asdf98 commited on
Commit
74cbbdd
·
verified ·
1 Parent(s): a548be7

Upload iris/pde_ssm.py

Browse files
Files changed (1) hide show
  1. iris/pde_ssm.py +197 -0
iris/pde_ssm.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PDE-SSM: Fourier-domain spatial mixing block for IRIS.
3
+
4
+ Replaces O(N²) self-attention with O(N log N) spectral convolution.
5
+ Implements a learnable convection-diffusion-reaction PDE in Fourier space.
6
+ Native 2D — no rasterization or scanning needed.
7
+
8
+ References:
9
+ - PDE-SSM-DiT (arxiv:2603.13663): Spectral state space approach
10
+ - FNO (arxiv:2010.08895): Fourier Neural Operator
11
+ - DyDiLA (arxiv:2601.13683): Token differential to prevent oversmoothing
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import math
18
+
19
+
20
+ class SpectralConv2d(nn.Module):
21
+ """
22
+ 2D Fourier-domain learnable convolution.
23
+
24
+ Applies learnable complex weights to low-frequency modes in the
25
+ Fourier domain. This is the core of the PDE-SSM operator.
26
+
27
+ Complexity: O(N log N) via FFT, vs O(N²) for attention.
28
+
29
+ For 4×4 spatial grids (16 tokens), this is trivially fast,
30
+ but the architecture is designed to scale to 8×8 (64 tokens)
31
+ or 16×16 (256 tokens) without quadratic blowup.
32
+ """
33
+
34
+ def __init__(self, channels: int, modes_h: int, modes_w: int):
35
+ super().__init__()
36
+ self.channels = channels
37
+ self.modes_h = modes_h
38
+ self.modes_w = modes_w
39
+
40
+ # Learnable complex weights for low-frequency modes
41
+ # Two sets: positive and negative frequency halves in H dimension
42
+ # rfft2 output is (B, C, H, W//2+1) — W is halved due to Hermitian symmetry
43
+ scale = 1.0 / (channels * channels)
44
+ self.weight_pos = nn.Parameter(
45
+ scale * torch.randn(channels, channels, modes_h, modes_w, 2)
46
+ )
47
+ self.weight_neg = nn.Parameter(
48
+ scale * torch.randn(channels, channels, modes_h, modes_w, 2)
49
+ )
50
+
51
+ def _complex_mul(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
52
+ """Complex matrix multiply: (B,Ci,H,W) x (Ci,Co,H,W) -> (B,Co,H,W).
53
+
54
+ Both x and w are stored as real tensors with last dim = 2 (real, imag).
55
+ This avoids torch.cfloat which has poor AMP support.
56
+ """
57
+ # x: (B, Ci, H, W, 2), w: (Ci, Co, H, W, 2)
58
+ # Real part: xr*wr - xi*wi
59
+ # Imag part: xr*wi + xi*wr
60
+ xr, xi = x[..., 0], x[..., 1]
61
+ wr, wi = w[..., 0], w[..., 1]
62
+ or_ = torch.einsum("bihw,iohw->bohw", xr, wr) - torch.einsum("bihw,iohw->bohw", xi, wi)
63
+ oi = torch.einsum("bihw,iohw->bohw", xr, wi) + torch.einsum("bihw,iohw->bohw", xi, wr)
64
+ return torch.stack([or_, oi], dim=-1)
65
+
66
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
67
+ """
68
+ Args:
69
+ x: (B, C, H, W) real tensor
70
+ Returns:
71
+ (B, C, H, W) real tensor — spatially mixed
72
+ """
73
+ B, C, H, W = x.shape
74
+
75
+ # Forward FFT: real -> complex spectrum
76
+ # Use view_as_real to get (B, C, H, W//2+1, 2) for AMP compatibility
77
+ x_ft = torch.fft.rfft2(x.float(), norm="ortho")
78
+ x_ft = torch.view_as_real(x_ft) # (B, C, H, W//2+1, 2)
79
+
80
+ # Output spectrum (zero-initialized)
81
+ out_shape = (B, C, H, W // 2 + 1, 2)
82
+ out_ft = torch.zeros(out_shape, device=x.device, dtype=x_ft.dtype)
83
+
84
+ # Low-frequency positive modes (top of spectrum)
85
+ mh = min(self.modes_h, H)
86
+ mw = min(self.modes_w, W // 2 + 1)
87
+ out_ft[:, :, :mh, :mw] = self._complex_mul(
88
+ x_ft[:, :, :mh, :mw], self.weight_pos[:, :, :mh, :mw]
89
+ )
90
+
91
+ # Low-frequency negative modes (bottom of spectrum, wraps around)
92
+ if mh > 0 and H > mh:
93
+ out_ft[:, :, -mh:, :mw] = self._complex_mul(
94
+ x_ft[:, :, -mh:, :mw], self.weight_neg[:, :, :mh, :mw]
95
+ )
96
+
97
+ # Inverse FFT back to spatial domain
98
+ out_ft_complex = torch.view_as_complex(out_ft)
99
+ result = torch.fft.irfft2(out_ft_complex, s=(H, W), norm="ortho")
100
+ return result.to(x.dtype)
101
+
102
+
103
+ class TokenDifferential(nn.Module):
104
+ """
105
+ Prevents oversmoothing in spectral/linear blocks.
106
+
107
+ From DyDiLA (arxiv:2601.13683): adds a learned high-pass
108
+ residual that preserves local contrast.
109
+
110
+ diff(h) = alpha * (h - AvgPool(h))
111
+
112
+ This is critical — without it, FFT-based mixing kills edges.
113
+ """
114
+
115
+ def __init__(self, channels: int):
116
+ super().__init__()
117
+ self.alpha = nn.Parameter(torch.zeros(1, channels, 1, 1))
118
+
119
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
120
+ """
121
+ Args:
122
+ x: (B, C, H, W)
123
+ Returns:
124
+ (B, C, H, W) with high-frequency residual added
125
+ """
126
+ # Average pool with kernel size matching spatial dims
127
+ # For small grids (4×4), use adaptive pool to single value
128
+ avg = F.adaptive_avg_pool2d(x, 1) # (B, C, 1, 1) — global average
129
+ return x + self.alpha * (x - avg)
130
+
131
+
132
+ class PDESSMBlock(nn.Module):
133
+ """
134
+ Complete PDE-SSM spatial mixing block.
135
+
136
+ Combines:
137
+ 1. Fourier-domain spectral convolution (global mixing, O(N log N))
138
+ 2. Pointwise convolution (local residual path)
139
+ 3. Token differential (anti-oversmoothing)
140
+ 4. Pre-norm + residual connection
141
+
142
+ This replaces self-attention in the IRIS architecture.
143
+ For a 4×4 grid (16 tokens), the FFT is 16×log(16) ≈ 64 ops
144
+ vs 16² = 256 for attention. At 16×16 (256 tokens): 2048 vs 65536.
145
+ """
146
+
147
+ def __init__(self, dim: int, spatial_size: int = 4):
148
+ """
149
+ Args:
150
+ dim: channel dimension
151
+ spatial_size: H = W of the spatial grid (e.g., 4 for 4×4)
152
+ """
153
+ super().__init__()
154
+ # Keep ~50% of frequency modes. For size=4, modes=2.
155
+ modes = max(2, spatial_size // 2)
156
+
157
+ self.norm = nn.LayerNorm(dim)
158
+ self.spectral = SpectralConv2d(dim, modes, modes)
159
+ self.local_conv = nn.Conv2d(dim, dim, kernel_size=1, bias=True)
160
+ self.token_diff = TokenDifferential(dim)
161
+ self.gate = nn.Sequential(
162
+ nn.Linear(dim, dim),
163
+ nn.SiLU(),
164
+ )
165
+
166
+ def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
167
+ """
168
+ Args:
169
+ x: (B, N, D) — sequence of spatial tokens
170
+ H, W: spatial dimensions such that N = H * W
171
+ Returns:
172
+ (B, N, D) — spatially mixed tokens
173
+ """
174
+ B, N, D = x.shape
175
+ assert N == H * W, f"N={N} must equal H*W={H*W}"
176
+
177
+ residual = x
178
+
179
+ # Pre-norm
180
+ x = self.norm(x)
181
+
182
+ # Reshape to 2D spatial grid for FFT
183
+ x_2d = x.view(B, H, W, D).permute(0, 3, 1, 2) # (B, D, H, W)
184
+
185
+ # Spectral mixing (global) + Local conv (residual) + Token diff (anti-smoothing)
186
+ spectral_out = self.spectral(x_2d)
187
+ local_out = self.local_conv(x_2d)
188
+ mixed = spectral_out + local_out
189
+ mixed = self.token_diff(mixed)
190
+
191
+ # Back to sequence format
192
+ mixed = mixed.permute(0, 2, 3, 1).reshape(B, N, D) # (B, N, D)
193
+
194
+ # Gated output
195
+ mixed = mixed * self.gate(mixed)
196
+
197
+ return residual + mixed