asdf98 committed · Commit 774e194 (verified) · Parent: f69abc9

Add iris_model.py

Files changed (1): iris_model.py (+1246, −0)
"""
IRIS: Iterative Recurrent Image Synthesis
==========================================
A novel architecture for mobile-first high-quality image generation.

Key innovations:
1. Wavelet-Frequency Latent Space (Haar DWT + lightweight VAE)
2. Recurrent Depth Core (Prelude-Core-Coda with shared weights)
3. Gated Recurrent Fourier Mixer (GRFM) — novel token mixing
4. Manhattan Spatial Decay — learned 2D inductive bias
5. Rectified Flow training with consistency distillation support
6. Adaptive compute budget (4-16 iterations, same model)

Author: IRIS Research
License: Apache 2.0
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from dataclasses import dataclass, field

# ============================================================================
# Configuration
# ============================================================================

@dataclass
class IRISConfig:
    """Configuration for IRIS model."""
    # Latent space
    latent_channels: int = 16         # Channels in latent space
    latent_spatial: int = 32          # Spatial dim of latent (for 512px with 16x compression)

    # Model dimensions
    hidden_dim: int = 512             # Main hidden dimension
    num_heads: int = 8                # Number of attention heads
    head_dim: int = 64                # Dimension per head
    ffn_ratio: float = 2.667          # FFN expansion ratio (SwiGLU-adjusted)

    # Architecture structure
    num_prelude_blocks: int = 2       # Prelude blocks (unique weights)
    num_core_layers: int = 4          # Layers WITHIN each core iteration
    num_coda_blocks: int = 2          # Coda blocks (unique weights)
    default_iterations: int = 8       # Default core iterations at inference
    max_iterations: int = 16          # Maximum core iterations

    # GRFM settings
    fourier_num_blocks: int = 8       # Block-diagonal blocks in Fourier MLP
    sparsity_threshold: float = 0.01  # Soft-shrinkage lambda
    recurrence_dim: int = 256         # Dimension for gated recurrence pathway
    manhattan_window: int = 16        # Windowed Manhattan decay (for efficiency)

    # Cross-attention
    text_dim: int = 768               # CLIP-L/14 text embedding dim
    max_text_tokens: int = 77         # Maximum text sequence length

    # Patch embedding
    patch_size: int = 2               # Patches in latent space (2×2)

    # Conditioning
    num_timesteps: int = 1000         # Noise schedule discretization

    # VAE
    vae_channels: list = field(default_factory=lambda: [32, 64, 128, 256])

    # Training
    dropout: float = 0.0

    @property
    def vae_latent_channels(self) -> int:
        """VAE latent channels must match generator latent channels."""
        return self.latent_channels

    @property
    def num_patches(self) -> int:
        return (self.latent_spatial // self.patch_size) ** 2

    @property
    def patch_dim(self) -> int:
        return self.latent_channels * self.patch_size * self.patch_size

# ============================================================================
# Wavelet Transforms (Haar)
# ============================================================================

class HaarDWT2D(nn.Module):
    """2D Discrete Wavelet Transform using Haar wavelets.
    Decomposes x ∈ R^{B,C,H,W} into R^{B,4C,H/2,W/2} (LL, LH, HL, HH subbands).
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Haar DWT: combine even/odd samples along both spatial dims
        x_ll = (x[:, :, 0::2, 0::2] + x[:, :, 0::2, 1::2] +
                x[:, :, 1::2, 0::2] + x[:, :, 1::2, 1::2]) * 0.5
        x_lh = (x[:, :, 0::2, 0::2] + x[:, :, 0::2, 1::2] -
                x[:, :, 1::2, 0::2] - x[:, :, 1::2, 1::2]) * 0.5
        x_hl = (x[:, :, 0::2, 0::2] - x[:, :, 0::2, 1::2] +
                x[:, :, 1::2, 0::2] - x[:, :, 1::2, 1::2]) * 0.5
        x_hh = (x[:, :, 0::2, 0::2] - x[:, :, 0::2, 1::2] -
                x[:, :, 1::2, 0::2] + x[:, :, 1::2, 1::2]) * 0.5
        return torch.cat([x_ll, x_lh, x_hl, x_hh], dim=1)


class HaarIDWT2D(nn.Module):
    """2D Inverse Discrete Wavelet Transform (Haar).
    Reconstructs x ∈ R^{B,C,H,W} from R^{B,4C,H/2,W/2}.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C4, Hh, Wh = x.shape
        C = C4 // 4
        ll, lh, hl, hh = x[:, :C], x[:, C:2*C], x[:, 2*C:3*C], x[:, 3*C:]

        # Reconstruct 2× spatial resolution
        H, W = Hh * 2, Wh * 2
        out = torch.zeros(B, C, H, W, device=x.device, dtype=x.dtype)
        out[:, :, 0::2, 0::2] = (ll + lh + hl + hh) * 0.5
        out[:, :, 0::2, 1::2] = (ll + lh - hl - hh) * 0.5
        out[:, :, 1::2, 0::2] = (ll - lh + hl - hh) * 0.5
        out[:, :, 1::2, 1::2] = (ll - lh - hl + hh) * 0.5
        return out

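
# Added illustrative sketch (not part of the original file; the helper name is
# hypothetical): the Haar pair above is orthonormal (both directions scale by
# 0.5), so the inverse transform should reproduce the input exactly for even
# H and W.
def _demo_haar_roundtrip() -> None:
    dwt, idwt = HaarDWT2D(), HaarIDWT2D()
    x = torch.randn(2, 3, 64, 64)
    coeffs = dwt(x)       # [2, 12, 32, 32]: 4 subbands x 3 channels
    x_rec = idwt(coeffs)  # [2, 3, 64, 64]
    assert torch.allclose(x, x_rec, atol=1e-5)
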
# ============================================================================
# Lightweight Wavelet VAE
# ============================================================================

class DepthwiseSeparableConv(nn.Module):
    """Depthwise separable convolution — key mobile optimization."""
    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size, stride, padding, groups=in_ch)
        self.pointwise = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))


class ResBlock(nn.Module):
    """Residual block with depthwise separable convolutions."""
    def __init__(self, channels):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, channels)
        self.conv1 = DepthwiseSeparableConv(channels, channels)
        self.norm2 = nn.GroupNorm(8, channels)
        self.conv2 = DepthwiseSeparableConv(channels, channels)
        # Zero-init final layer for residual learning stability
        nn.init.zeros_(self.conv2.pointwise.weight)
        nn.init.zeros_(self.conv2.pointwise.bias)

    def forward(self, x):
        h = F.silu(self.norm1(x))
        h = self.conv1(h)
        h = F.silu(self.norm2(h))
        h = self.conv2(h)
        return x + h

class WaveletVAEEncoder(nn.Module):
    """Lightweight encoder: Haar DWT preprocessing + small convolutional encoder.
    Input: images R^{B,3,H,W} → Output: latent R^{B,C_latent,H/16,W/16}
    Compression: 3×H×W → C_latent×(H/16)×(W/16)
    """
    def __init__(self, config: IRISConfig):
        super().__init__()
        self.dwt = HaarDWT2D()
        channels = config.vae_channels
        latent_ch = config.vae_latent_channels

        # DWT: 3 channels → 12 channels at H/2 × W/2
        self.conv_in = nn.Conv2d(12, channels[0], 3, 1, 1)

        # Downsampling path: H/2→H/4→H/8→H/16
        self.down_blocks = nn.ModuleList()
        for i in range(len(channels) - 1):
            self.down_blocks.append(nn.Sequential(
                ResBlock(channels[i]),
                nn.Conv2d(channels[i], channels[i+1], 3, 2, 1),  # 2× downsample
            ))

        # Bottleneck
        self.mid = nn.Sequential(
            ResBlock(channels[-1]),
            ResBlock(channels[-1]),
        )

        # To latent (mean + logvar)
        self.norm_out = nn.GroupNorm(8, channels[-1])
        self.conv_out = nn.Conv2d(channels[-1], 2 * latent_ch, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Haar DWT preprocessing
        x = self.dwt(x)  # [B, 12, H/2, W/2]
        x = self.conv_in(x)

        for down in self.down_blocks:
            x = down(x)

        x = self.mid(x)
        x = F.silu(self.norm_out(x))
        x = self.conv_out(x)

        mean, logvar = x.chunk(2, dim=1)
        logvar = torch.clamp(logvar, -30.0, 20.0)
        return mean, logvar

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mean, logvar = self.forward(x)
        std = torch.exp(0.5 * logvar)
        z = mean + std * torch.randn_like(std)  # reparameterization trick
        return z, mean, logvar


class WaveletVAEDecoder(nn.Module):
    """Tiny decoder: latent → wavelet coefficients → Haar IDWT → image.
    Designed to be as small as possible for mobile inference.
    """
    def __init__(self, config: IRISConfig):
        super().__init__()
        channels = list(reversed(config.vae_channels))
        latent_ch = config.vae_latent_channels
        self.idwt = HaarIDWT2D()

        # From latent
        self.conv_in = nn.Conv2d(latent_ch, channels[0], 3, 1, 1)

        # Bottleneck
        self.mid = nn.Sequential(
            ResBlock(channels[0]),
        )

        # Upsampling path
        self.up_blocks = nn.ModuleList()
        for i in range(len(channels) - 1):
            self.up_blocks.append(nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                DepthwiseSeparableConv(channels[i], channels[i+1]),
                nn.SiLU(),
                ResBlock(channels[i+1]),
            ))

        # To wavelet coefficients (12 channels: 4 subbands × 3 color channels)
        self.norm_out = nn.GroupNorm(8, channels[-1])
        self.conv_out = nn.Conv2d(channels[-1], 12, 3, 1, 1)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        x = self.conv_in(z)
        x = self.mid(x)

        for up in self.up_blocks:
            x = up(x)

        x = F.silu(self.norm_out(x))
        x = self.conv_out(x)  # [B, 12, H/2, W/2] wavelet coefficients

        # Inverse DWT to get the image
        x = self.idwt(x)  # [B, 3, H, W]
        return x


class WaveletVAE(nn.Module):
    """Complete Wavelet VAE with DWT preprocessing."""
    def __init__(self, config: IRISConfig):
        super().__init__()
        self.encoder = WaveletVAEEncoder(config)
        self.decoder = WaveletVAEDecoder(config)
        self.config = config

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.encoder.encode(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        z, mean, logvar = self.encode(x)
        x_recon = self.decode(z)
        return x_recon, mean, logvar

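
# Added illustrative sketch (hypothetical helper, not original API): with the
# default vae_channels=[32, 64, 128, 256], the DWT halves the resolution once
# and the three strided convs halve it three more times, giving the 16x
# compression the encoder docstring claims.
def _demo_vae_shapes() -> None:
    cfg = IRISConfig()
    vae = WaveletVAE(cfg)
    imgs = torch.randn(2, 3, 512, 512)
    z, mean, logvar = vae.encode(imgs)  # z: [2, 16, 32, 32]
    assert z.shape == (2, cfg.latent_channels, 32, 32)
    recon = vae.decode(z)               # [2, 3, 512, 512]
    assert recon.shape == imgs.shape
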
# ============================================================================
# Conditioning Modules
# ============================================================================

class TimestepEmbedding(nn.Module):
    """Sinusoidal timestep embedding with MLP projection."""
    def __init__(self, dim: int, max_period: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.SiLU(),
            nn.Linear(4 * dim, dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(self.max_period) * torch.arange(half, device=t.device, dtype=t.dtype) / half
        )
        args = t[:, None] * freqs[None, :]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        return self.mlp(embedding)


class IterationEmbedding(nn.Module):
    """Learnable embedding for the iteration index within the recurrent core."""
    def __init__(self, max_iterations: int, dim: int):
        super().__init__()
        self.embedding = nn.Embedding(max_iterations, dim)

    def forward(self, i: torch.Tensor) -> torch.Tensor:
        return self.embedding(i)


class AdaLNSingle(nn.Module):
    """Adaptive Layer Normalization (single shared MLP, per-layer bias).
    From PixArt-α: saves 27% params vs standard adaLN.

    Produces (scale, shift, gate) for each sub-layer from a shared condition vector.
    """
    def __init__(self, dim: int, num_modulations: int = 6):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, num_modulations * dim)
        self.num_modulations = num_modulations
        # Zero-init so every modulated branch starts as the identity (adaLN-Zero)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, c: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """c: [B, D] condition vector → tuple of num_modulations tensors [B, D]."""
        params = self.linear(self.silu(c))
        return params.chunk(self.num_modulations, dim=-1)

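
# Added illustrative sketch (hypothetical helper): because AdaLNSingle is
# zero-initialized, every scale, shift, and gate is exactly zero at init, so
# each adaLN-Zero-gated sub-layer starts out as an identity map.
def _demo_adaln_zero_init() -> None:
    mod = AdaLNSingle(dim=512, num_modulations=6)
    c = torch.randn(4, 512)
    for m in mod(c):
        assert torch.all(m == 0)  # holds before any training updates
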
# ============================================================================
# GRFM: Gated Recurrent Fourier Mixer (Novel Contribution)
# ============================================================================

class FourierMixingPathway(nn.Module):
    """Pathway 1: Adaptive Fourier Neural Operator-style global mixing.
    O(N log N) complexity via FFT. Block-diagonal MLP in the frequency domain.
    """
    def __init__(self, dim: int, num_blocks: int = 8, sparsity_threshold: float = 0.01):
        super().__init__()
        self.dim = dim
        self.num_blocks = num_blocks
        self.block_size = dim // num_blocks
        self.sparsity_threshold = sparsity_threshold

        # Block-diagonal complex-valued MLP in the Fourier domain.
        # Each block maps C^{block_size} → C^{block_size}; the complex weights
        # are stored as separate real/imaginary parameter tensors.
        self.w1_real = nn.Parameter(torch.randn(num_blocks, self.block_size, self.block_size) * 0.02)
        self.w1_imag = nn.Parameter(torch.randn(num_blocks, self.block_size, self.block_size) * 0.02)
        self.w2_real = nn.Parameter(torch.randn(num_blocks, self.block_size, self.block_size) * 0.02)
        self.w2_imag = nn.Parameter(torch.randn(num_blocks, self.block_size, self.block_size) * 0.02)
        self.b1 = nn.Parameter(torch.zeros(num_blocks, self.block_size))
        self.b2 = nn.Parameter(torch.zeros(num_blocks, self.block_size))

    def complex_matmul(self, x: torch.Tensor, w_real: torch.Tensor, w_imag: torch.Tensor) -> torch.Tensor:
        """Complex matrix multiplication: (a+bi)(c+di) = (ac-bd) + (ad+bc)i
        x: [..., num_blocks, block_size] (complex)
        w: [num_blocks, block_size, block_size] (real and imaginary parts)
        """
        # einsum performs the block-diagonal matmul:
        # x: [B, Hf, Wf, K, bs], w: [K, bs, bs] → out: [B, Hf, Wf, K, bs]
        out_real = torch.einsum('...ki,kij->...kj', x.real, w_real) - torch.einsum('...ki,kij->...kj', x.imag, w_imag)
        out_imag = torch.einsum('...ki,kij->...kj', x.real, w_imag) + torch.einsum('...ki,kij->...kj', x.imag, w_real)
        return torch.complex(out_real, out_imag)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        B, N, D = x.shape
        x_2d = x.reshape(B, H, W, D)

        # 2D real FFT over the spatial dimensions
        x_freq = torch.fft.rfft2(x_2d, dim=(1, 2), norm='ortho')  # [B, H, W//2+1, D]

        # Reshape channel dim for the block-diagonal MLP: D → (num_blocks, block_size)
        Hf, Wf = x_freq.shape[1], x_freq.shape[2]
        x_freq = x_freq.reshape(B, Hf, Wf, self.num_blocks, self.block_size)

        # Block MLP layer 1: operates on the last dim (block_size)
        x_freq = self.complex_matmul(x_freq, self.w1_real, self.w1_imag)
        x_freq = x_freq + self.b1  # real-valued bias, promoted to complex on add
        x_freq = torch.complex(F.relu(x_freq.real), F.relu(x_freq.imag))

        # Block MLP layer 2
        x_freq = self.complex_matmul(x_freq, self.w2_real, self.w2_imag)
        x_freq = x_freq + self.b2

        # Reshape back to [B, Hf, Wf, D]
        x_freq = x_freq.reshape(B, Hf, Wf, D)

        # Soft-shrinkage (sparsity in the Fourier domain):
        # preserve phase, shrink magnitude toward zero by the threshold
        magnitude = x_freq.abs()
        shrunk_mag = F.relu(magnitude - self.sparsity_threshold)
        x_freq = x_freq * (shrunk_mag / (magnitude + 1e-8))

        # Inverse FFT back to the spatial domain
        x_out = torch.fft.irfft2(x_freq, s=(H, W), dim=(1, 2), norm='ortho')
        return x_out.reshape(B, N, D)

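
# Added illustrative sketch (hypothetical helper): the Fourier pathway is a
# shape-preserving token mixer for tokens laid out on an H×W grid; its cost is
# dominated by the FFT, i.e. O(N log N) in the token count N.
def _demo_fourier_pathway() -> None:
    path = FourierMixingPathway(dim=512, num_blocks=8)
    x = torch.randn(2, 16 * 16, 512)  # tokens from a 16x16 grid
    y = path(x, H=16, W=16)
    assert y.shape == x.shape
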
class GatedLinearRecurrence(nn.Module):
    """Pathway 2: Bidirectional Gated Linear Recurrence (RG-LRU inspired).
    O(N) complexity with O(1) state per position.

    h_t = a_t * h_{t-1} + sqrt(1 - a_t^2) * (i_t * x_t)
    where a_t = sigmoid(Λ)^(c * sigmoid(W_a x_t))
    """
    def __init__(self, dim: int, recurrence_dim: int):
        super().__init__()
        self.dim = dim
        self.rec_dim = recurrence_dim

        # Project to recurrence space
        self.proj_in = nn.Linear(dim, recurrence_dim * 2)  # Forward + backward

        # Gating parameters
        self.W_a = nn.Linear(recurrence_dim, recurrence_dim, bias=False)
        self.W_x = nn.Linear(recurrence_dim, recurrence_dim, bias=False)
        self.Lambda = nn.Parameter(torch.randn(recurrence_dim) * 0.5 + 2.0)  # Init for decay ~0.88-0.95
        self.c = 8.0  # Decay scaling constant (from Griffin)

        # Output projection
        self.proj_out = nn.Linear(recurrence_dim * 2, dim)

    def _scan(self, x: torch.Tensor) -> torch.Tensor:
        """Sequential scan for a single direction. x: [B, N, rec_dim]"""
        B, N, D = x.shape

        # Compute gates (parallel over all positions)
        a_base = torch.sigmoid(self.Lambda)  # [D]
        r = torch.sigmoid(self.W_a(x))       # [B, N, D] - recurrence gate
        i = torch.sigmoid(self.W_x(x))       # [B, N, D] - input gate

        # a_t = a_base^(c * r_t) - data-dependent decay
        a = a_base.pow(self.c * r)           # [B, N, D]

        # Normalized input: sqrt(1 - a^2) for variance preservation
        input_scale = torch.sqrt(1.0 - a * a + 1e-8)
        scaled_input = input_scale * (i * x)  # [B, N, D]

        # Sequential recurrence (use a parallel scan in production)
        outputs = []
        h = torch.zeros(B, D, device=x.device, dtype=x.dtype)
        for t in range(N):
            h = a[:, t] * h + scaled_input[:, t]
            outputs.append(h)

        return torch.stack(outputs, dim=1)  # [B, N, D]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, D = x.shape

        # Project to recurrence space and split for the two directions
        x_proj = self.proj_in(x)  # [B, N, 2*rec_dim]
        x_fwd, x_bwd = x_proj.chunk(2, dim=-1)

        # Forward and backward scans
        h_fwd = self._scan(x_fwd)
        h_bwd = self._scan(x_bwd.flip(1)).flip(1)

        # Merge bidirectional states
        h = torch.cat([h_fwd, h_bwd], dim=-1)
        return self.proj_out(h)

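
# Added hedged sketch (an assumption about one standard way to parallelize the
# Python loop in _scan, not the original implementation): the recurrence
# h_t = a_t * h_{t-1} + x_t is associative under the combine
# (a1, x1) o (a2, x2) = (a1*a2, a2*x1 + x2), so a Hillis-Steele style scan
# computes all h_t in O(log N) sequential rounds. Feeding it the same `a` and
# `scaled_input` tensors built inside _scan should reproduce its output.
def _parallel_linear_scan(a: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """a, x: [B, N, D] -> h: [B, N, D] with h_t = a_t * h_{t-1} + x_t, h_0 = 0."""
    h, A = x.clone(), a.clone()
    offset = 1
    while offset < x.shape[1]:
        # Fold each position together with the prefix ending `offset` steps earlier
        h = torch.cat([h[:, :offset], A[:, offset:] * h[:, :-offset] + h[:, offset:]], dim=1)
        A = torch.cat([A[:, :offset], A[:, offset:] * A[:, :-offset]], dim=1)
        offset *= 2
    return h
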
class ManhattanSpatialGate(nn.Module):
    """Pathway 3: Manhattan distance spatial decay gating.
    Provides a learned 2D spatial inductive bias with per-head multi-scale
    receptive fields. Positions beyond `window` are masked out; note that the
    full N×N distance matrix is still materialized here, so the windowing
    bounds the receptive field rather than the memory cost.
    """
    def __init__(self, dim: int, num_heads: int, window: int = 16):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window = window

        # Per-head learnable decay rate, initialized so that
        # gamma = sigmoid(gamma_logit) spans roughly [0.7, 0.95]
        # (sigmoid(0.85) ≈ 0.70, sigmoid(2.94) ≈ 0.95), giving multi-scale heads.
        self.gamma_logit = nn.Parameter(torch.linspace(0.85, 2.94, num_heads))

        # Value and gate projections
        self.v_proj = nn.Linear(dim, dim)
        self.g_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)

    def _get_manhattan_mask(self, H: int, W: int, device: torch.device) -> torch.Tensor:
        """Compute the Manhattan distance matrix between all 2D positions."""
        coords = torch.stack(torch.meshgrid(
            torch.arange(H, device=device),
            torch.arange(W, device=device),
            indexing='ij'
        ), dim=-1).reshape(-1, 2).float()  # [N, 2]

        # Manhattan distance: |x1-x2| + |y1-y2|
        dist = torch.cdist(coords, coords, p=1)  # [N, N]
        return dist

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        B, N, D = x.shape

        # Compute spatial decay
        gamma = torch.sigmoid(self.gamma_logit)                    # [num_heads]
        manhattan_dist = self._get_manhattan_mask(H, W, x.device)  # [N, N]

        # Window the decay: zero out positions beyond the window distance
        decay_mask = (manhattan_dist <= self.window).float()

        # Per-head decay: gamma_h^dist
        decay = gamma[:, None, None].pow(manhattan_dist[None, :, :])  # [heads, N, N]
        decay = decay * decay_mask[None, :, :]

        # Value and gate
        v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim)
        g = torch.sigmoid(self.g_proj(x))

        # Apply spatial decay to values:
        # [B, heads, N, head_dim] = [1, heads, N, N] @ [B, heads, N, head_dim]
        v = v.permute(0, 2, 1, 3)                  # [B, heads, N, head_dim]
        out = torch.matmul(decay.unsqueeze(0), v)  # [B, heads, N, head_dim]

        # Normalize by the decay sum (a decay-weighted local average)
        decay_sum = decay.sum(dim=-1, keepdim=True).unsqueeze(0) + 1e-8  # [1, heads, N, 1]
        out = out / decay_sum

        out = out.permute(0, 2, 1, 3).reshape(B, N, D)  # [B, N, D]
        out = out * g  # Gating
        return self.o_proj(out)

class GRFM(nn.Module):
    """Gated Recurrent Fourier Mixer — the core innovation of IRIS.

    Fuses three complementary pathways:
    1. Fourier Global Mixing (O(N log N)) — captures textures, patterns
    2. Gated Linear Recurrence (O(N)) — captures sequential/local dependencies
    3. Manhattan Spatial Gate — provides 2D inductive bias

    Pathways are combined via learned adaptive gating.
    """
    def __init__(self, config: IRISConfig):
        super().__init__()
        D = config.hidden_dim

        self.fourier = FourierMixingPathway(D, config.fourier_num_blocks, config.sparsity_threshold)
        self.recurrence = GatedLinearRecurrence(D, config.recurrence_dim)
        self.spatial = ManhattanSpatialGate(D, config.num_heads, config.manhattan_window)

        # Adaptive gate: learns to blend Fourier vs recurrence based on content
        self.blend_gate = nn.Sequential(
            nn.Linear(D, D),
            nn.SiLU(),
            nn.Linear(D, D),
            nn.Sigmoid(),
        )

        # Spatial pathway weight (smaller, additive contribution)
        self.spatial_scale = nn.Parameter(torch.tensor(0.1))

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        # Three pathways
        x_fourier = self.fourier(x, H, W)
        x_recurrent = self.recurrence(x)
        x_spatial = self.spatial(x, H, W)

        # Adaptive blending: per-token, per-channel convex combination
        gate = self.blend_gate(x)  # [B, N, D], values in [0, 1]

        # Fourier for global structure, recurrence for local detail
        output = gate * x_fourier + (1 - gate) * x_recurrent

        # Add spatial bias (small contribution)
        output = output + self.spatial_scale * x_spatial

        return output

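
# Added illustrative sketch (hypothetical helper): GRFM is a drop-in token
# mixer with signature [B, N, D] -> [B, N, D] for tokens on an H×W grid, so it
# can be exercised standalone with the default config.
def _demo_grfm() -> None:
    cfg = IRISConfig()
    mixer = GRFM(cfg)
    H = W = cfg.latent_spatial // cfg.patch_size  # 16x16 token grid
    x = torch.randn(2, H * W, cfg.hidden_dim)
    y = mixer(x, H, W)
    assert y.shape == x.shape
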
# ============================================================================
# Cross-Attention (for text conditioning)
# ============================================================================

class CrossAttention(nn.Module):
    """Efficient cross-attention for text conditioning.
    Only 77 text tokens → O(N × 77 × d) per layer, very cheap.
    """
    def __init__(self, dim: int, text_dim: int, num_heads: int, head_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        self.q_proj = nn.Linear(dim, num_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(text_dim, num_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(text_dim, num_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * head_dim, dim)

        # QK normalization for stability (from SANA-Sprint).
        # Note: nn.RMSNorm requires PyTorch >= 2.4.
        self.q_norm = nn.RMSNorm(head_dim)
        self.k_norm = nn.RMSNorm(head_dim)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        B, N, _ = x.shape
        _, S, _ = context.shape

        q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(context).reshape(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(context).reshape(B, S, self.num_heads, self.head_dim).transpose(1, 2)

        # QK normalization
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Scaled dot-product attention
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)

        out = out.transpose(1, 2).reshape(B, N, -1)
        return self.o_proj(out)

# ============================================================================
# Feed-Forward Network (SwiGLU)
# ============================================================================

class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network — in practice a stronger choice than GELU FFNs for transformers."""
    def __init__(self, dim: int, ratio: float = 2.667, dropout: float = 0.0):
        super().__init__()
        hidden = int(dim * ratio)
        # Round hidden up to a multiple of 64 for hardware efficiency
        hidden = ((hidden + 63) // 64) * 64

        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(dim, hidden, bias=False)  # Gate
        self.w3 = nn.Linear(hidden, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w3(self.dropout(F.silu(self.w1(x)) * self.w2(x)))

# ============================================================================
# Prelude Block (unique weights, conv-based)
# ============================================================================

class PreludeBlock(nn.Module):
    """Lightweight conv-based block for initial feature extraction."""
    def __init__(self, dim: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=5, padding=2, groups=dim)
        self.pointwise = nn.Linear(dim, dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = SwiGLUFFN(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Depthwise conv path (1D conv over the token sequence)
        h = self.norm1(x)
        h = h.transpose(1, 2)               # [B, D, N]
        h = self.dwconv(h).transpose(1, 2)  # [B, N, D]
        h = F.silu(h)
        h = self.pointwise(h)
        x = x + h

        # FFN
        x = x + self.ffn(self.norm2(x))
        return x

# ============================================================================
# Core Block (shared weights, the heart of IRIS)
# ============================================================================

class CoreLayer(nn.Module):
    """Single layer within the core block.
    Contains GRFM + cross-attention + FFN, all with adaLN-Zero conditioning.
    """
    def __init__(self, config: IRISConfig):
        super().__init__()
        D = config.hidden_dim

        # Sub-layer 1: GRFM
        self.norm1 = nn.LayerNorm(D, elementwise_affine=False)
        self.grfm = GRFM(config)

        # Sub-layer 2: Cross-Attention
        self.norm2 = nn.LayerNorm(D, elementwise_affine=False)
        self.cross_attn = CrossAttention(D, config.text_dim, config.num_heads, config.head_dim)

        # Sub-layer 3: FFN
        self.norm3 = nn.LayerNorm(D, elementwise_affine=False)
        self.ffn = SwiGLUFFN(D, config.ffn_ratio, config.dropout)

        # adaLN-Zero: 9 modulations (scale, shift, gate for each of the 3 sub-layers)
        self.adaln = AdaLNSingle(D, num_modulations=9)

    def _modulate(self, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def forward(self, x: torch.Tensor, c: torch.Tensor, text_tokens: torch.Tensor,
                H: int, W: int) -> torch.Tensor:
        """
        x: [B, N, D] — token sequence
        c: [B, D] — conditioning vector (timestep + iteration + pooled text)
        text_tokens: [B, S, text_dim] — CLIP text tokens
        H, W: spatial dimensions of the token grid
        """
        s1, sh1, g1, s2, sh2, g2, s3, sh3, g3 = self.adaln(c)

        # GRFM with adaLN-Zero
        h = self._modulate(self.norm1(x), s1, sh1)
        h = self.grfm(h, H, W)
        x = x + g1.unsqueeze(1) * h

        # Cross-attention with adaLN-Zero
        h = self._modulate(self.norm2(x), s2, sh2)
        h = self.cross_attn(h, text_tokens)
        x = x + g2.unsqueeze(1) * h

        # FFN with adaLN-Zero
        h = self._modulate(self.norm3(x), s3, sh3)
        h = self.ffn(h)
        x = x + g3.unsqueeze(1) * h

        return x

class CoreBlock(nn.Module):
    """The shared-weight core block, iterated r times.
    Contains multiple CoreLayers to give sufficient per-iteration capacity.
    """
    def __init__(self, config: IRISConfig):
        super().__init__()
        self.layers = nn.ModuleList([
            CoreLayer(config) for _ in range(config.num_core_layers)
        ])

    def forward(self, x: torch.Tensor, c: torch.Tensor, text_tokens: torch.Tensor,
                H: int, W: int) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, c, text_tokens, H, W)
        return x

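
# Added illustrative sketch (hypothetical helper): the core's parameter count
# is independent of how many times it is iterated. Running it r times
# multiplies effective depth by r without adding weights, which is the point
# of the Prelude-Core-Coda design.
def _demo_core_weight_sharing() -> None:
    cfg = IRISConfig()
    core = CoreBlock(cfg)
    n_params = sum(p.numel() for p in core.parameters())
    x = torch.randn(1, 16 * 16, cfg.hidden_dim)
    c = torch.randn(1, cfg.hidden_dim)
    text = torch.randn(1, cfg.max_text_tokens, cfg.text_dim)
    for _ in range(4):  # 4 iterations reuse the same weights
        x = core(x, c, text, 16, 16)
    assert sum(p.numel() for p in core.parameters()) == n_params
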
# ============================================================================
# Coda Block (unique weights, final refinement)
# ============================================================================

class LocalWindowAttention(nn.Module):
    """Window-based local attention for final refinement.
    A small window (8×8) keeps local detail refinement efficient.
    """
    def __init__(self, dim: int, num_heads: int, head_dim: int, window_size: int = 8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.window_size = window_size
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, 3 * num_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * head_dim, dim)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        B, N, D = x.shape
        ws = self.window_size

        # Reshape to 2D and partition into windows
        x_2d = x.reshape(B, H, W, D)

        # Pad if the grid is not a multiple of the window size
        pad_h = (ws - H % ws) % ws
        pad_w = (ws - W % ws) % ws
        if pad_h > 0 or pad_w > 0:
            x_2d = F.pad(x_2d, (0, 0, 0, pad_w, 0, pad_h))

        Hp, Wp = x_2d.shape[1], x_2d.shape[2]
        nH, nW = Hp // ws, Wp // ws

        # [B, nH, ws, nW, ws, D] → [B*nH*nW, ws*ws, D]
        x_win = x_2d.reshape(B, nH, ws, nW, ws, D)
        x_win = x_win.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws * ws, D)

        # QKV and attention within windows
        qkv = self.qkv(x_win).reshape(-1, ws * ws, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)

        out = out.transpose(1, 2).reshape(-1, ws * ws, self.num_heads * self.head_dim)
        out = self.o_proj(out)

        # Unpartition windows back onto the 2D grid
        out = out.reshape(B, nH, nW, ws, ws, D)
        out = out.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, D)

        # Remove padding
        out = out[:, :H, :W, :].reshape(B, N, D)
        return out

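
# Added illustrative sketch (hypothetical helper): window partitioning pads
# the grid up to a multiple of the window size and crops afterwards, so shapes
# round-trip even when H and W are not multiples of 8.
def _demo_window_attention_padding() -> None:
    attn = LocalWindowAttention(dim=512, num_heads=8, head_dim=64, window_size=8)
    x = torch.randn(2, 10 * 12, 512)  # 10x12 grid, not a multiple of 8
    y = attn(x, H=10, W=12)
    assert y.shape == x.shape
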
class CodaBlock(nn.Module):
    """Final refinement block with local window attention."""
    def __init__(self, config: IRISConfig):
        super().__init__()
        D = config.hidden_dim
        self.norm1 = nn.LayerNorm(D)
        self.attn = LocalWindowAttention(D, config.num_heads, config.head_dim, window_size=8)
        self.norm2 = nn.LayerNorm(D)
        self.ffn = SwiGLUFFN(D, config.ffn_ratio)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        x = x + self.attn(self.norm1(x), H, W)
        x = x + self.ffn(self.norm2(x))
        return x

# ============================================================================
# IRIS Generator (Main Model)
# ============================================================================

class IRISGenerator(nn.Module):
    """
    IRIS: Iterative Recurrent Image Synthesis

    The main denoising network with Prelude-Core-Coda structure.
    Predicts the velocity field v for rectified flow training.
    """
    def __init__(self, config: IRISConfig):
        super().__init__()
        self.config = config
        D = config.hidden_dim

        # Patch embedding: latent patches → tokens
        self.patch_embed = nn.Linear(config.patch_dim, D)

        # Positional embedding (learned)
        self.pos_embed = nn.Parameter(torch.randn(1, config.num_patches, D) * 0.02)

        # Conditioning
        self.time_embed = TimestepEmbedding(D)
        self.iter_embed = IterationEmbedding(config.max_iterations, D)
        self.text_proj = nn.Linear(config.text_dim, D)  # Project CLIP text to model dim

        # Global text pooling for conditioning
        self.text_pool_proj = nn.Sequential(
            nn.Linear(config.text_dim, D),
            nn.SiLU(),
            nn.Linear(D, D),
        )

        # Prelude (unique weights)
        self.prelude = nn.ModuleList([PreludeBlock(D) for _ in range(config.num_prelude_blocks)])

        # Core (shared weights, iterated)
        self.core = CoreBlock(config)

        # Long skip connection (from Diffusion-RWKV: linear(cat(shallow, deep)))
        self.skip_proj = nn.Linear(2 * D, D)

        # Coda (unique weights)
        self.coda = nn.ModuleList([CodaBlock(config) for _ in range(config.num_coda_blocks)])

        # Output projection: tokens → latent patches
        self.final_norm = nn.LayerNorm(D)
        self.output_proj = nn.Linear(D, config.patch_dim)

        # Zero-init output for a stable training start
        nn.init.zeros_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

        # Precompute patch spatial dimensions
        self.patch_h = config.latent_spatial // config.patch_size
        self.patch_w = config.latent_spatial // config.patch_size

    def patchify(self, z: torch.Tensor) -> torch.Tensor:
        """Convert latent z [B, C, H, W] → patches [B, N, patch_dim]."""
        B, C, H, W = z.shape
        p = self.config.patch_size
        z = z.reshape(B, C, H // p, p, W // p, p)
        z = z.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * p * p)
        return z

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Convert patches [B, N, patch_dim] → latent [B, C, H, W]."""
        B, N, _ = x.shape
        p = self.config.patch_size
        C = self.config.latent_channels
        H = self.patch_h
        W = self.patch_w
        x = x.reshape(B, H, W, C, p, p)
        x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H * p, W * p)
        return x

    def forward(
        self,
        z_t: torch.Tensor,                     # Noisy latent [B, C, H, W]
        t: torch.Tensor,                       # Timestep [B] in [0, 1]
        text_tokens: torch.Tensor,             # CLIP text embeddings [B, S, text_dim]
        num_iterations: Optional[int] = None,  # Override iteration count
    ) -> torch.Tensor:
        """Predict the velocity field v(z_t, t, c) for rectified flow."""
        B = z_t.shape[0]
        # Explicit None check so an explicit 0 is not silently replaced
        r = num_iterations if num_iterations is not None else self.config.default_iterations
        H, W = self.patch_h, self.patch_w

        # Patchify and embed
        x = self.patch_embed(self.patchify(z_t)) + self.pos_embed

        # Timestep conditioning
        t_emb = self.time_embed(t * self.config.num_timesteps)  # [B, D]

        # Text conditioning. Note: as written, CrossAttention projects the raw
        # CLIP tokens itself, so this per-token projection is currently unused.
        text_projected = self.text_proj(text_tokens)  # [B, S, D]

        # Global text pool for adaLN conditioning
        text_global = self.text_pool_proj(text_tokens.mean(dim=1))  # [B, D]

        # ============ PRELUDE ============
        for block in self.prelude:
            x = block(x)

        # Save for the long skip connection
        x_shallow = x

        # ============ CORE (iterated r times) ============
        for i in range(r):
            # Iteration-aware conditioning
            iter_idx = torch.full((B,), i, device=z_t.device, dtype=torch.long)
            i_emb = self.iter_embed(iter_idx)  # [B, D]

            # Combined conditioning: timestep + iteration + global text
            c = t_emb + i_emb + text_global  # [B, D]

            # Apply the shared core block (raw text_tokens go to cross-attention)
            x = self.core(x, c, text_tokens, H, W)

        # Long skip connection (from the Diffusion-RWKV paper)
        x = self.skip_proj(torch.cat([x_shallow, x], dim=-1))

        # ============ CODA ============
        for block in self.coda:
            x = block(x, H, W)

        # Output projection
        x = self.final_norm(x)
        x = self.output_proj(x)

        # Unpatchify to latent shape
        v_pred = self.unpatchify(x)
        return v_pred

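
# Added illustrative sketch (hypothetical helper): a single velocity
# prediction under an explicit compute budget. With the default config, a
# 32x32 latent is patchified into a 16x16 token grid, and the prediction has
# the same shape as the input latent.
def _demo_generator_forward() -> None:
    cfg = IRISConfig()
    gen = IRISGenerator(cfg)
    z_t = torch.randn(2, cfg.latent_channels, cfg.latent_spatial, cfg.latent_spatial)
    t = torch.rand(2)
    text = torch.randn(2, cfg.max_text_tokens, cfg.text_dim)
    v = gen(z_t, t, text, num_iterations=4)  # fewer iterations, cheaper call
    assert v.shape == z_t.shape
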
# ============================================================================
# Full IRIS System
# ============================================================================

class IRIS(nn.Module):
    """Complete IRIS system: VAE + Generator.

    For training: use train_step(), which handles noise scheduling.
    For inference: use generate(), which runs the full pipeline.
    """
    def __init__(self, config: IRISConfig):
        super().__init__()
        self.config = config
        self.vae = WaveletVAE(config)
        self.generator = IRISGenerator(config)

    def encode(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode images to latent space."""
        return self.vae.encode(images)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latents to images."""
        return self.vae.decode(z)

    def get_velocity_target(self, z_0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Rectified flow velocity target: v = noise - z_0."""
        return noise - z_0

    def add_noise(self, z_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Rectified flow forward process: z_t = (1-t)*z_0 + t*noise."""
        t_expand = t[:, None, None, None]
        return (1 - t_expand) * z_0 + t_expand * noise

    def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Sample timesteps from a logit-normal distribution (as in SD3/RF).
        Concentrates sampling on intermediate timesteps where learning is hardest.
        """
        u = torch.randn(batch_size, device=device)
        t = torch.sigmoid(u)  # logit-normal with mean=0, std=1
        # Clamp to avoid exactly t=0 and t=1
        t = t.clamp(1e-5, 1 - 1e-5)
        return t

    def train_step(
        self,
        images: torch.Tensor,
        text_tokens: torch.Tensor,
        num_iterations: Optional[int] = None,
    ) -> dict:
        """Single training step for rectified flow.

        Returns a dict with the loss and diagnostics.
        """
        B = images.shape[0]
        device = images.device

        # Encode to latent space
        z_0, mean, logvar = self.encode(images)

        # Sample noise and timesteps
        noise = torch.randn_like(z_0)
        t = self.sample_timesteps(B, device)

        # Create the noisy latent
        z_t = self.add_noise(z_0, noise, t)

        # Randomly sample the iteration count so the model stays robust
        # across compute budgets
        if num_iterations is None:
            r_choices = [4, 6, 8, 10, 12]
            r = r_choices[torch.randint(0, len(r_choices), (1,)).item()]
        else:
            r = num_iterations

        # Predict velocity
        v_pred = self.generator(z_t, t, text_tokens, num_iterations=r)
        v_target = self.get_velocity_target(z_0, noise)

        # Timestep-dependent weighting: w(t) = t / (1 - t) emphasizes
        # high-noise timesteps
        w = t / (1 - t + 1e-8)
        w = w[:, None, None, None]

        # Velocity matching loss
        velocity_loss = (w * (v_pred - v_target).pow(2)).mean()

        # VAE KL loss
        kl_loss = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp()).mean()

        return {
            'loss': velocity_loss + 0.001 * kl_loss,
            'velocity_loss': velocity_loss.item(),
            'kl_loss': kl_loss.item(),
            'mean_t': t.mean().item(),
        }

    @torch.no_grad()
    def generate(
        self,
        text_tokens: torch.Tensor,
        num_steps: int = 4,
        num_iterations: int = 8,
        cfg_scale: float = 4.0,
        seed: Optional[int] = None,
    ) -> torch.Tensor:
        """Generate images from text conditioning using an Euler solver.

        Args:
            text_tokens: [B, S, text_dim] CLIP text embeddings
            num_steps: Number of ODE solver steps (1-50)
            num_iterations: Core iterations per step (quality budget)
            cfg_scale: Classifier-free guidance scale
            seed: Random seed for reproducibility
        """
        B = text_tokens.shape[0]
        device = text_tokens.device

        if seed is not None:
            torch.manual_seed(seed)

        # Start from pure noise
        z = torch.randn(B, self.config.latent_channels,
                        self.config.latent_spatial, self.config.latent_spatial,
                        device=device)

        # Euler solver for the rectified flow ODE. Since z_t = (1-t)*z_0 + t*noise,
        # dz/dt = noise - z_0 = v, and we integrate from t=1 (noise) down to
        # t=0 (data), so each step moves z against the velocity: z ← z - dt*v.
        dt = 1.0 / num_steps

        for step in range(num_steps):
            t_val = 1.0 - step * dt
            t = torch.full((B,), t_val, device=device)

            # Predict velocity
            v = self.generator(z, t, text_tokens, num_iterations=num_iterations)

            # Classifier-free guidance (if cfg_scale > 1); zeroed text tokens
            # stand in for the null condition
            if cfg_scale > 1.0:
                null_tokens = torch.zeros_like(text_tokens)
                v_uncond = self.generator(z, t, null_tokens, num_iterations=num_iterations)
                v = v_uncond + cfg_scale * (v - v_uncond)

            # Euler step
            z = z - dt * v

        # Decode to images
        images = self.decode(z)
        images = images.clamp(-1, 1)
        return images

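
# Added illustrative sketch (hypothetical helper; uses create_iris_tiny,
# defined below in this file): the two entry points, end to end. The CLIP text
# encoder lives outside this file, so random tensors stand in for its token
# embeddings.
def _demo_iris_usage() -> None:
    model = create_iris_tiny()
    images = torch.randn(2, 3, 512, 512)
    text = torch.randn(2, 77, 768)  # stand-in for CLIP-L/14 token embeddings
    out = model.train_step(images, text, num_iterations=4)
    assert out['loss'].requires_grad
    samples = model.generate(text, num_steps=2, num_iterations=4, cfg_scale=1.0)
    assert samples.shape == (2, 3, 512, 512)
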
# ============================================================================
# Utility Functions
# ============================================================================

def count_parameters(model: nn.Module) -> dict:
    """Count parameters in each top-level component."""
    counts = {}
    total = 0
    for name, module in model.named_children():
        n = sum(p.numel() for p in module.parameters())
        counts[name] = n
        total += n
    counts['total'] = total

    # Separate trainable from frozen parameters
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    counts['trainable'] = trainable
    return counts


def estimate_memory_mb(model: nn.Module, dtype=torch.float16) -> float:
    """Estimate model weight memory in MB (2 bytes for fp16, otherwise 4)."""
    bytes_per_param = 2 if dtype == torch.float16 else 4
    total_params = sum(p.numel() for p in model.parameters())
    return total_params * bytes_per_param / (1024 * 1024)

def create_iris_small(latent_spatial: int = 32) -> IRIS:
    """Create IRIS-Small: ~75M generator params, suitable for mobile."""
    config = IRISConfig(
        latent_channels=16,
        latent_spatial=latent_spatial,
        hidden_dim=512,
        num_heads=8,
        head_dim=64,
        ffn_ratio=2.667,
        num_prelude_blocks=2,
        num_core_layers=4,
        num_coda_blocks=2,
        default_iterations=8,
        max_iterations=16,
        fourier_num_blocks=8,
        sparsity_threshold=0.01,
        recurrence_dim=256,
        manhattan_window=16,
        text_dim=768,
        max_text_tokens=77,
        patch_size=2,
    )
    return IRIS(config)


def create_iris_tiny(latent_spatial: int = 32) -> IRIS:
    """Create IRIS-Tiny: ~30M generator params, ultra-mobile."""
    config = IRISConfig(
        latent_channels=8,
        latent_spatial=latent_spatial,
        hidden_dim=384,
        num_heads=6,
        head_dim=64,
        ffn_ratio=2.667,
        num_prelude_blocks=1,
        num_core_layers=3,
        num_coda_blocks=1,
        default_iterations=8,
        max_iterations=16,
        fourier_num_blocks=6,
        sparsity_threshold=0.01,
        recurrence_dim=192,
        manhattan_window=12,
        text_dim=768,
        max_text_tokens=77,
        patch_size=2,
    )
    return IRIS(config)


def create_iris_base(latent_spatial: int = 32) -> IRIS:
    """Create IRIS-Base: ~150M generator params, quality-focused."""
    config = IRISConfig(
        latent_channels=16,
        latent_spatial=latent_spatial,
        hidden_dim=768,
        num_heads=12,
        head_dim=64,
        ffn_ratio=2.667,
        num_prelude_blocks=2,
        num_core_layers=6,
        num_coda_blocks=2,
        default_iterations=8,
        max_iterations=16,
        fourier_num_blocks=12,
        sparsity_threshold=0.01,
        recurrence_dim=384,
        manhattan_window=16,
        text_dim=768,
        max_text_tokens=77,
        patch_size=2,
    )
    return IRIS(config)

if __name__ == "__main__":
    print("=" * 70)
    print("IRIS: Iterative Recurrent Image Synthesis")
    print("=" * 70)

    # Create and summarize each model variant
    for name, create_fn in [("IRIS-Tiny", create_iris_tiny),
                            ("IRIS-Small", create_iris_small),
                            ("IRIS-Base", create_iris_base)]:
        print(f"\n{'─' * 50}")
        print(f" {name}")
        print(f"{'─' * 50}")
        model = create_fn()
        counts = count_parameters(model)
        mem_fp16 = estimate_memory_mb(model, torch.float16)
        mem_fp32 = estimate_memory_mb(model, torch.float32)

        print(f" Total params:     {counts['total']:>12,}")
        print(f" Trainable params: {counts['trainable']:>12,}")
        print(f" Memory (fp16):    {mem_fp16:>10.1f} MB")
        print(f" Memory (fp32):    {mem_fp32:>10.1f} MB")
        print(" Components:")
        for k, v in counts.items():
            if k not in ('total', 'trainable'):
                print(f"   {k:20s}: {v:>10,} ({v/counts['total']*100:.1f}%)")