asdf98 committed on
Commit 91e9e63 · verified · 1 Parent(s): 89579fd

Add ARCHITECTURE.md

Files changed (1)
  1. ARCHITECTURE.md +364 -0
ARCHITECTURE.md ADDED
@@ -0,0 +1,364 @@
+ # IRIS: Iterative Recurrent Image Synthesis
+ ## A Novel Architecture for Mobile-First High-Quality Image Generation
+
+ ### Version 1.0 | Architecture Design Document
+
+ ---
+
+ ## 1. Executive Summary
+
+ **IRIS** (Iterative Recurrent Image Synthesis) is a novel image generation architecture designed from first principles to achieve high visual quality on mobile devices (< 3-4GB RAM). It combines six key innovations drawn from cutting-edge research across multiple domains:
+
+ 1. **Wavelet-Frequency Latent Space**: 16× spatial compression via Haar DWT + learned VAE, operating in frequency-aware space
+ 2. **Recurrent Depth Core**: Shared-weight denoising block iterated N times (inspired by Huginn), achieving deep model behavior from tiny parameter count
+ 3. **Gated Recurrent Fourier Mixer (GRFM)**: Novel token mixing that combines RG-LRU gated recurrence with Adaptive Fourier Neural Operators, replacing O(N²) attention with O(N log N) global mixing
+ 4. **Manhattan Spatial Decay**: Learned per-head 2D spatial inductive bias via Manhattan distance exponential decay (from RMT)
+ 5. **Rectified Flow with Consistency Distillation**: Straight ODE paths for few-step generation (1-4 steps)
+ 6. **Adaptive Compute Budget**: Same model, variable quality: 4 iterations for mobile, 16 for quality
+
+ ### Target Specifications
+
+ | Metric | Target | Achieved By |
+ |--------|--------|------------|
+ | Total Parameters | < 250M (generator) | Recurrent depth + efficient blocks |
+ | RAM (inference) | < 3GB total | ~600MB model + ~400MB VAE + ~200MB text encoder + buffers |
+ | Inference Steps | 1-4 | Rectified flow + consistency distillation |
+ | Core Iterations | 4 (fast) / 8-16 (quality) | Recurrent depth, shared weights |
+ | Image Quality | Competitive with SDXL at 512px | Frequency-aware latent + proper training |
+ | Prompt Adherence | Strong | CLIP-L/14 cross-attention conditioning |
+ | Training Cost | < 200 A100 GPU-hours (Stage 1) | Efficient architecture + progressive training |
+
+ ---
+
+ ## 2. Theoretical Foundations
+
+ ### 2.1 Why Current Approaches Fail on Mobile
+
+ **Problem 1: Parameter Explosion in Transformers**
+ Standard DiT/UNet architectures use independent parameters for each layer. The 28-layer DiT-XL has ~675M params. Each self-attention layer stores O(d²) params for its Q, K, V projections, and that cost is multiplied by the number of layers.
+
+ **Problem 2: Quadratic Attention Complexity**
+ For 512×512 images with 8× VAE downsampling: 64×64 = 4096 tokens. Self-attention requires on the order of 4096² × d operations per layer. At d=768, that's ~12.9 billion multiply-accumulates per attention layer for the QKᵀ score matrix alone.
+
+ **Problem 3: Step Count**
+ Standard diffusion requires 20-50 neural function evaluations (NFE). Even a small model run for 50 steps is impractical on a phone.
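
As a quick check on the figure quoted under Problem 2, the sketch below recomputes the token count and the N²·d cost. The function name `attention_score_macs` is illustrative only, and it counts just the QKᵀ product; the attention-weighted sum over V costs roughly the same again.

```python
def attention_score_macs(image_size: int, vae_downsample: int, d_model: int) -> tuple[int, int]:
    """Return (num_tokens, multiply-accumulates for the N x N score matrix)."""
    side = image_size // vae_downsample      # latent side length
    n_tokens = side * side                   # one token per latent position
    macs = n_tokens * n_tokens * d_model     # N^2 * d for Q @ K^T
    return n_tokens, macs


tokens, macs = attention_score_macs(image_size=512, vae_downsample=8, d_model=768)
print(tokens)                 # 4096
print(round(macs / 1e9, 1))   # 12.9
```

Doubling the image side quadruples the token count and therefore multiplies this cost by 16, which is why the quadratic term dominates at higher resolutions.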
+
+ ### 2.2 Our Solution: Mathematical Framework
+
+ #### 2.2.1 Recurrent Depth as Implicit Neural ODE
+
+ The key insight from Huginn (arXiv:2502.05171): a shared-weight block `R` applied iteratively defines a discrete neural ODE:
+
+ ```
+ s_0 = P(x)                [Prelude: encode input]
+ s_{i+1} = R(s_i, c, t)    [Core: iterate with conditioning c and timestep t]
+ y = C(s_r)                [Coda: decode output]
+ ```
+
+ When `R` is residual, i.e. R(s, c, t) = s + Δτ · F_θ(s, c, t) with Δτ = 1/r, this is the forward Euler discretization of:
+ ```
+ ds/dτ = F_θ(s(τ), c, t)    where τ ∈ [0, 1], discretized into r steps
+ ```
+
+ **Parameter efficiency**: If block R has P parameters, then r iterations give an effective depth of r×L layers (where L is the number of layers in R) using only P parameters. A 6-layer block iterated 16 times gives 96 effective layers.
+
+ **Connection to diffusion**: In standard diffusion, the denoiser f_θ is applied at each noise level t with the same parameters, which is already recurrent depth along the noise-schedule axis. IRIS makes it recurrent over both axes: noise schedule (outer loop, t) and computational depth (inner loop, τ).
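
To make the prelude-core-coda loop concrete, here is a minimal scalar sketch, not the IRIS implementation. It assumes a residual core update and a toy velocity F(s) = weight · s with one shared parameter, so the exact ODE solution is known; `prelude`, `core_velocity`, `coda` and `recurrent_depth` are hypothetical names.

```python
import math


def prelude(x: float) -> float:
    """Toy P: encode the input into the initial state s_0."""
    return 2.0 * x


def core_velocity(s: float, weight: float) -> float:
    """Toy F_theta: a single shared parameter `weight` reused at every iteration."""
    return weight * s


def coda(s: float) -> float:
    """Toy C: decode the final state (identity here)."""
    return s


def recurrent_depth(x: float, r: int, weight: float = -1.0) -> float:
    s = prelude(x)
    dtau = 1.0 / r                                   # step size over tau in [0, 1]
    for _ in range(r):                               # the same core is applied r times
        s = s + dtau * core_velocity(s, weight)      # residual (forward Euler) update
    return coda(s)


exact = 2.0 * math.exp(-1.0)   # closed-form solution of ds/dtau = -s at tau = 1, s_0 = 2
for r in (4, 8, 16):
    approx = recurrent_depth(1.0, r)
    print(r, round(approx, 4), round(abs(approx - exact), 4))
```

The parameter count is identical for every `r`; only the number of iterations changes, and the gap to the exact solution shrinks as `r` grows. This is the toy analogue of trading compute for quality with a fixed set of weights.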
+
+ #### 2.2.2 Gated Recurrent Fourier Mixer (GRFM): Novel Contribution
+
+ We introduce GRFM, which processes the 2D token sequence through three parallel pathways merged multiplicatively:
+
+ **Pathway 1: Fourier Global Mixing (O(N log N))**
+ ```
+ x_fourier = IRFFT2(SoftShrink(BlockMLP(RFFT2(x))))
+ ```
+ From AFNO: captures global structure via frequency-domain mixing. The soft-shrinkage promotes sparsity in the Fourier domain (images are naturally sparse in frequency).
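
The effect of soft-shrinkage in the frequency domain can be seen on a 1-D toy signal. This is a sketch under stated assumptions: it uses a naive DFT instead of an FFT, omits the BlockMLP, and applies the shrinkage to real and imaginary parts separately, which is one plausible reading of the formula above. The threshold 0.05 and the signal are made up for illustration.

```python
import cmath
import math


def soft_shrink(v: float, lambd: float) -> float:
    """Shrink v toward zero by lambd; values with |v| <= lambd become exactly 0."""
    if v > lambd:
        return v - lambd
    if v < -lambd:
        return v + lambd
    return 0.0


def dft(x: list[float]) -> list[complex]:
    """Naive O(n^2) discrete Fourier transform, normalised by 1/n."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n)) / n
            for k in range(n)]


def idft(spectrum: list[complex]) -> list[float]:
    """Inverse of `dft` for spectra of real signals."""
    n = len(spectrum)
    return [sum(spectrum[k] * cmath.exp(2j * math.pi * k * t / n) for k in range(n)).real
            for t in range(n)]


n = 16
# One strong low-frequency component plus a weak high-frequency ripple.
signal = [math.cos(2 * math.pi * t / n) + 0.02 * math.cos(2 * math.pi * 7 * t / n)
          for t in range(n)]

spectrum = dft(signal)
shrunk = [complex(soft_shrink(c.real, 0.05), soft_shrink(c.imag, 0.05)) for c in spectrum]
kept = [k for k, c in enumerate(shrunk) if abs(c) > 0]
filtered = idft(shrunk)
print(kept)   # [1, 15]
```

Only the two bins of the dominant cosine survive; the weak ripple is zeroed out, and the surviving component is attenuated by the threshold.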
+
+ **Pathway 2: Gated Linear Recurrence (O(N))**
+ ```
+ a_t = σ(Λ)^(c·σ(W_a · x_t))                        [decay gate, per-element]
+ i_t = σ(W_x · x_t)                                  [input gate]
+ h_t = a_t ⊙ h_{t-1} + √(1 - a_t²) ⊙ (i_t ⊙ x_t)    [RG-LRU update]
+ x_recurrent,t = W_o · h_t                           [per-token output]
+ ```
+ From Griffin (arXiv:2402.19427): captures sequential dependencies with a fixed-size hidden state, so the cost is O(1) per token position. The scan is bidirectional (forward + backward).
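
A scalar version of the RG-LRU update makes the gating behaviour easier to follow. This is a minimal sketch: the weights `lam`, `w_a`, `w_x` and the constant `c` are arbitrary illustrative values, not trained or recommended settings, and real models use vectors and matrices in their place.

```python
import math


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def rg_lru_scan(xs, lam=2.0, w_a=0.5, w_x=1.0, c=8.0):
    """Run the scalar RG-LRU update over a sequence and return every h_t."""
    h = 0.0
    states = []
    for x in xs:
        r = sigmoid(w_a * x)                  # recurrence gate in (0, 1)
        a = sigmoid(lam) ** (c * r)           # decay a_t = sigma(Lambda)^(c * r_t)
        i = sigmoid(w_x * x)                  # input gate
        h = a * h + math.sqrt(1.0 - a * a) * (i * x)
        states.append(h)
    return states


def bidirectional(xs):
    """Pair each position's forward state with its backward state."""
    fwd = rg_lru_scan(xs)
    bwd = rg_lru_scan(xs[::-1])[::-1]         # scan the reversed sequence, then re-align
    return list(zip(fwd, bwd))


tokens = [1.0, 0.0, 0.0, 0.0]                 # a single impulse at position 0
for fwd_h, bwd_h in bidirectional(tokens):
    print(round(fwd_h, 4), round(bwd_h, 4))
```

The forward state decays geometrically away from the impulse, while the backward state is non-zero only at position 0, since nothing lies to the right of it that carries signal. Combining both scans gives every token context from both directions at O(N) cost.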
+
+ **Pathway 3: Manhattan Spatial Gate**
+ ```
+ D_{nm} = γ_head^(|x_n - x_m| + |y_n - y_m|)    [Manhattan decay matrix]
+ gate = σ(W_g · x) ⊙ (D · (W_v · x))
+ ```
+ From RMT (arXiv:2309.11523): per-head learnable spatial decay provides multi-scale locality bias.
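
The decay matrix D can be built directly from token coordinates. The sketch below assumes raster-order tokens on a small grid; `manhattan_decay` is a hypothetical helper, and the two gamma values stand in for two heads with different locality.

```python
def manhattan_decay(height: int, width: int, gamma: float) -> list[list[float]]:
    """Return the N x N matrix D[n][m] = gamma ** (|x_n - x_m| + |y_n - y_m|)."""
    coords = [(y, x) for y in range(height) for x in range(width)]   # raster order
    return [[gamma ** (abs(yn - ym) + abs(xn - xm)) for (ym, xm) in coords]
            for (yn, xn) in coords]


# Two hypothetical heads: a small gamma gives a tight local window,
# a gamma near 1 lets distant tokens contribute.
for gamma in (0.5, 0.9):
    D = manhattan_decay(height=3, width=3, gamma=gamma)
    centre = D[4]                      # decay weights seen from the centre token
    print([round(w, 2) for w in centre])
```

For gamma = 0.5 the centre token weights its four direct neighbours at 0.5 and the corners at 0.25; for gamma = 0.9 the same neighbours get 0.9 and 0.81, so that head sees a much wider neighbourhood.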
+
+ **Fusion (Novel)**:
+ ```
+ output = LayerNorm(x_fourier ⊙ σ(gate) + x_recurrent ⊙ (1 - σ(gate)))
+ ```
+
+ The gate adaptively selects between global Fourier features (textures, patterns) and local recurrent features (edges, fine details) based on spatial context. This is not a simple concatenation: it is a learned, spatially-varying interpolation.
+
+ #### 2.2.3 Wavelet-Frequency Latent Space
+
+ Instead of a standard VAE operating on pixels, we first apply a Haar DWT:
+ ```
+ x ∈ R^{3×H×W} → DWT → y ∈ R^{12×H/2×W/2}
+ ```
+
+ Then a lightweight VAE encoder compresses to:
+ ```
+ z ∈ R^{C×H/16×W/16}    (effective 16× total spatial compression from original)
+ ```
+
+ The VAE operates on wavelet coefficients, preserving frequency structure. The LL (low-low) subband carries global structure; LH, HL, HH carry directional high-frequency details. This means the latent space is inherently frequency-aware.
+
+ **Benefit**: The denoiser operates on a latent that already separates structure from detail, making the learning problem easier for a small model.
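
The shape change from 3×H×W to 12×H/2×W/2 follows from applying a one-level 2D Haar transform to each colour channel. The sketch below is a plain-Python illustration on a 4×4 toy image, not the WaveletVAE front end; the LH/HL naming convention varies between libraries and is an assumption here.

```python
def haar_dwt2(img):
    """One-level orthonormal 2D Haar transform of an H x W grid (H, W even).

    Returns the four (H/2 x W/2) subbands LL, LH, HL, HH.
    """
    h, w = len(img), len(img[0])
    ll, lh, hl, hh = ([[0.0] * (w // 2) for _ in range(h // 2)] for _ in range(4))
    for i in range(0, h, 2):
        for j in range(0, w, 2):
            a, b = img[i][j], img[i][j + 1]
            c, d = img[i + 1][j], img[i + 1][j + 1]
            ll[i // 2][j // 2] = (a + b + c + d) / 2   # local average (structure)
            lh[i // 2][j // 2] = (a + b - c - d) / 2   # top-minus-bottom detail
            hl[i // 2][j // 2] = (a - b + c - d) / 2   # left-minus-right detail
            hh[i // 2][j // 2] = (a - b - c + d) / 2   # diagonal detail
    return ll, lh, hl, hh


def haar_idwt2(ll, lh, hl, hh):
    """Exact inverse of `haar_dwt2`."""
    h2, w2 = len(ll), len(ll[0])
    img = [[0.0] * (2 * w2) for _ in range(2 * h2)]
    for i in range(h2):
        for j in range(w2):
            s, v, u, x = ll[i][j], lh[i][j], hl[i][j], hh[i][j]
            img[2 * i][2 * j] = (s + v + u + x) / 2
            img[2 * i][2 * j + 1] = (s + v - u - x) / 2
            img[2 * i + 1][2 * j] = (s - v + u - x) / 2
            img[2 * i + 1][2 * j + 1] = (s - v - u + x) / 2
    return img


# Toy 3-channel 4x4 "image".
rgb = [[[float((c + 1) * (i * 4 + j)) for j in range(4)] for i in range(4)] for c in range(3)]
subbands = [band for channel in rgb for band in haar_dwt2(channel)]
print(len(subbands), len(subbands[0]), len(subbands[0][0]))   # 12 2 2
```

Because the transform is invertible, no information is lost before the VAE; the encoder simply receives structure (LL) and directional detail (LH, HL, HH) in separate channels.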
+
+ #### 2.2.4 Rectified Flow + Consistency Distillation
+
+ **Training Phase 1 (Rectified Flow)**:
+ ```
+ x_t = (1-t) · x_0 + t · ε                    [linear interpolation]
+ v_target = ε - x_0                           [velocity field]
+ L = w(t) · ||v_θ(x_t, t, c) - v_target||²
+ w(t) = t/(1-t+δ)                             [SNR reweighting; δ is a small constant for numerical stability]
+ t ~ LogitNormal(0, 1)                        [concentrate on hard timesteps]
+ ```
+
+ **Training Phase 2 (Consistency Distillation)**:
+ ```
+ f_θ(x_t, t) = c_skip(t)·x_t + c_out(t)·F_θ(x_t, t)
+ L_CD = d(f_θ(x_{t_{n+1}}, t_{n+1}), f_{θ⁻}(x̂_{t_n}, t_n))
+ ```
+ Where θ⁻ is an EMA of θ, d is a distance function, and x̂_{t_n} is obtained from x_{t_{n+1}} by one ODE-solver step of the teacher. This enables 1-4 step generation.
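
The rectified-flow relations above can be checked numerically on a toy vector. In this sketch the learned network v_θ is replaced by an oracle that returns the true velocity for one (x_0, ε) pair, which is an assumption made purely for illustration; with a perfectly straight path, Euler integration recovers x_0 in any number of steps, including one.

```python
import random


def interpolate(x0, eps, t):
    """x_t = (1 - t) * x0 + t * eps, elementwise."""
    return [(1 - t) * a + t * e for a, e in zip(x0, eps)]


def velocity_target(x0, eps):
    """v = eps - x0, constant along the straight path."""
    return [e - a for a, e in zip(x0, eps)]


def euler_sample(eps, velocity_fn, num_steps):
    """Integrate dx/dt = v from t = 1 (noise) down to t = 0 (data)."""
    x, dt = list(eps), 1.0 / num_steps
    for step in range(num_steps):
        t = 1.0 - step * dt
        v = velocity_fn(x, t)
        x = [xi - dt * vi for xi, vi in zip(x, v)]
    return x


random.seed(0)
x0 = [0.5, -1.0, 2.0]
eps = [random.gauss(0.0, 1.0) for _ in x0]
v = velocity_target(x0, eps)


def oracle(x, t):
    """Stand-in for a trained v_theta: returns the exact velocity of this one pair."""
    return v


x_half = interpolate(x0, eps, 0.5)
recovered = [xt - 0.5 * vi for xt, vi in zip(x_half, v)]    # x0 = x_t - t * v
one_step = euler_sample(eps, oracle, num_steps=1)
four_step = euler_sample(eps, oracle, num_steps=4)
print([round(a, 6) for a in one_step])
```

A trained model only approximates this oracle, so real trajectories curve slightly; consistency distillation is the stage intended to close that gap for 1-4 step sampling.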
+
+ ---
+
+ ## 3. Architecture Details
+
+ ### 3.1 Overall Pipeline
+
+ ```
+ Text  → CLIP-L/14 → c ∈ R^{77×768}
+
+ Image → Haar DWT → WaveletVAE Encode → z₀ ∈ R^{C×h×w}
+                  │
+                  │  Noise schedule (RF):
+                  │  z_t = (1-t)z₀ + t·ε
+                  ▼
+         ┌─────────────────┐
+         │     PRELUDE     │
+         │   (2 blocks)    │
+         │  PatchEmbed +   │
+         │ Initial mixing  │
+         └────────┬────────┘
+                  ▼
+         ┌─────────────────┐
+         │  CORE (shared)  │◄── Iterate r times (4-16)
+         │   GRFM Block    │
+         │     + FFN       │
+         │  + adaLN-Zero   │
+         └────────┬────────┘
+                  ▼
+         ┌─────────────────┐
+         │      CODA       │
+         │   (2 blocks)    │
+         │ Final refine +  │
+         │   Unpatchify    │
+         └────────┬────────┘
+                  ▼
+         v̂ = predicted velocity
+         z₀_pred = z_t - t·v̂
+                  │
+                  ▼
+ WaveletVAE Decode → Haar IDWT → Image
+ ```
+
+ ### 3.2 Detailed Block Design
+
+ #### Prelude (2 blocks, unique weights)
+ ```python
+ class Prelude:
+     patch_embed: Conv2d(C_latent, D, kernel_size=2, stride=2)  # 2× spatial reduce
+     pos_embed: learned R^{(h/2 × w/2) × D}
+     blocks: [PreludeBlock × 2]
+
+ class PreludeBlock:
+     norm1 → DepthwiseSepConv3x3 → GELU → PointwiseConv → norm2 → FFN
+     # Uses conv instead of attention: cheap local feature extraction
+ ```
+
+ #### Core (shared weights, iterated r times)
+ ```python
+ class CoreBlock:
+     # adaLN-Zero conditioning on (timestep t, iteration i, text_global)
+     adaln_modulation: Linear(D_cond, 9*D)  # shift, scale, gate for each of norm1, norm2, norm3
+
+     norm1 → GRFM → gate1 → residual
+     norm2 → CrossAttention(q=x, kv=text_tokens) → gate2 → residual  # Only 77 text tokens
+     norm3 → FFN(SiLU) → gate3 → residual
+ ```
+
+ Cross-attention with only 77 text tokens is cheap: O(N × 77 × d) ≈ O(N·d).
+
+ #### GRFM (Gated Recurrent Fourier Mixer): The Core Innovation
+ ```python
+ class GRFM:
+     def forward(self, x, spatial_shape):
+         B, N, D = x.shape
+         H, W = spatial_shape
+         x_2d = x.reshape(B, H, W, D)
+
+         # Pathway 1: Fourier Global (O(N log N))
+         x_freq = rfft2(x_2d, dim=(1, 2))
+         x_freq = block_mlp(x_freq)  # Block-diagonal MLP in freq domain
+         x_freq = soft_shrink(x_freq, lambd=self.sparsity_threshold)
+         x_fourier = irfft2(x_freq, s=(H, W), dim=(1, 2)).reshape(B, N, D)  # back to token layout
+
+         # Pathway 2: Bidirectional Gated Recurrence (O(N))
+         x_flat_fwd = x          # N tokens in raster order
+         x_flat_bwd = x.flip(1)  # Reversed
+         h_fwd = gated_linear_recurrence(x_flat_fwd, self.decay_fwd, self.gate_fwd)
+         h_bwd = gated_linear_recurrence(x_flat_bwd, self.decay_bwd, self.gate_bwd)
+         x_recurrent = linear(concat(h_fwd, h_bwd.flip(1)))
+
+         # Pathway 3: Manhattan Spatial Gate
+         heads = self.num_heads
+         manhattan_dist = compute_manhattan(H, W)                  # [N, N], precomputed
+         gamma = sigmoid(self.gamma_param)                         # [heads], per-head decay base
+         spatial_decay = gamma[:, None, None].pow(manhattan_dist)  # [heads, N, N]; sparse/windowed in practice
+         v = value_proj(x).reshape(B, N, heads, D // heads)
+         # Mix tokens m into token n with per-head decay weights (sum over m, not a diagonal)
+         x_gated = einsum('hnm,bmhd->bnhd', spatial_decay, v).reshape(B, N, D)
+         gate = sigmoid(gate_proj(x))
+
+         # Adaptive Fusion
+         output = x_fourier * gate + x_recurrent * (1 - gate)
+         output = output + 0.1 * x_gated  # Small residual from spatial
+
+         return output_proj(output)
+ ```
+
+ #### Coda (2 blocks, unique weights)
+ ```python
+ class Coda:
+     blocks: [CodaBlock × 2]
+     unpatchify: ConvTranspose2d(D, C_latent, kernel_size=2, stride=2)
+     final_norm: LayerNorm(D)
+
+ class CodaBlock:
+     norm1 → LocalWindowAttention(window=8) → residual  # Small window, efficient
+     norm2 → FFN → residual
+ ```
+
+ ### 3.3 Parameter Budget
+
+ | Component | Parameters | Notes |
+ |-----------|-----------|-------|
+ | WaveletVAE Encoder | ~15M | Lightweight (LiteVAE-style) |
+ | WaveletVAE Decoder | ~8M | Tiny decoder (SnapGen-style) |
+ | CLIP-L/14 Text Encoder | ~39M | Frozen, not counted for training |
+ | Prelude (2 blocks) | ~12M | Conv-based, cheap |
+ | Core Block (shared) | ~45M | GRFM + CrossAttn + FFN |
+ | Coda (2 blocks) | ~15M | Local attention + FFN |
+ | Embeddings/conditioning | ~3M | Time, iteration, position |
+ | **Total Generator** | **~75M unique** | Core shared across iterations |
+ | **Effective depth** | **75M → behaves like 400M+** | At r=8 iterations |
+ | **Total system** | **~137M** | Including VAE + text encoder |
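
The totals in the table can be reproduced with a few lines of arithmetic. The sketch below only restates the table's own estimates (in millions of parameters); `unrolled_equivalent` is a hypothetical helper that counts what a non-shared model with one Core copy per iteration would need.

```python
generator_params_m = {
    "Prelude (2 blocks)": 12,
    "Core Block (shared)": 45,
    "Coda (2 blocks)": 15,
    "Embeddings/conditioning": 3,
}
other_params_m = {
    "WaveletVAE Encoder": 15,
    "WaveletVAE Decoder": 8,
    "CLIP-L/14 Text Encoder": 39,
}


def unrolled_equivalent(params_m: dict, core_key: str, iterations: int) -> int:
    """Parameters (millions) of a non-shared model with one core copy per iteration."""
    core = params_m[core_key]
    return sum(params_m.values()) - core + core * iterations


unique = sum(generator_params_m.values())
system = unique + sum(other_params_m.values())
print(unique)                                                              # 75
print(system)                                                              # 137
print(unrolled_equivalent(generator_params_m, "Core Block (shared)", 8))   # 390
```

At r=8 the unrolled equivalent is about 390M parameters, which is where the "behaves like 400M+" row comes from; at r=16 the same calculation gives 750M.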
+
+ ### 3.4 Memory Analysis (Inference at 512×512)
+
+ ```
+ CLIP-L/14 text encoder:      ~156 MB   (fp16)
+ WaveletVAE Decoder:          ~16 MB    (fp16)
+ IRIS Generator:              ~150 MB   (fp16)
+ Latent tensor:               ~0.03 MB  (32×32×16, fp16)
+ KV cache (text cross-attn):  ~12 MB
+ Intermediate activations:    ~100 MB   (single block, not accumulated)
+ OS/framework overhead:       ~500 MB
+ ─────────────────────────────────────────
+ Total:                       ~934 MB ✓ (well under 3GB)
+ ```
+
+ **Key insight**: At inference, the activations of one Core iteration are no longer needed once the next iteration has run, so every iteration reuses the same activation buffer. Because the Core weights are shared, running more iterations adds compute but no additional weight memory.
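
The fp16 weight figures follow from two bytes per parameter, and the latent size from two bytes per value. The sketch below recomputes three of the entries from the parameter budget; the helper names are illustrative, and MB here means 10^6 bytes.

```python
BYTES_PER_FP16 = 2


def params_to_mb(params_millions: float, bytes_per_param: int = BYTES_PER_FP16) -> float:
    """Weight memory in MB (10^6 bytes) for a parameter count given in millions."""
    return params_millions * bytes_per_param


def tensor_kb(shape: tuple, bytes_per_value: int = BYTES_PER_FP16) -> float:
    """Memory in KB (1024 bytes) of a dense tensor with the given shape."""
    count = 1
    for dim in shape:
        count *= dim
    return count * bytes_per_value / 1024


print(params_to_mb(75))          # IRIS generator (75M params): 150 MB
print(params_to_mb(8))           # WaveletVAE decoder (8M params): 16 MB
print(tensor_kb((32, 32, 16)))   # latent tensor: 32.0 KB
```

The latent itself is negligible next to weights and activations, so the budget is dominated by model weights and runtime overhead rather than by the tensor being denoised.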
+
+ ---
+
+ ## 4. Training Recipe
+
+ ### Stage 1: Wavelet VAE Training (Standalone)
+ ```
+ Data: ImageNet (1.2M images) + CC3M (3M images)
+ Resolution: 256×256
+ Objective: Reconstruction loss + KL + Perceptual (LPIPS) + Wavelet frequency loss
+ Batch: 32
+ LR: 1e-4, cosine decay
+ Duration: ~20 GPU-hours on A100
+ ```
+
+ ### Stage 2: Class-Conditional Pretraining
+ ```
+ Data: ImageNet 256×256 (class labels)
+ Objective: Rectified Flow velocity matching
+ Batch: 256
+ LR: 1e-4, warmup 5000 steps, cosine decay
+ Core iterations: r=8 (randomly sample r ∈ {4,6,8,10,12} for robustness)
+ Duration: ~100 GPU-hours on A100
+ ```
+
+ ### Stage 3: Text-Image Alignment
+ ```
+ Data: CC3M + CC12M (15M images with captions, re-captioned by VLM)
+ Resolution: 256→512 progressive
+ Objective: Rectified Flow + cross-attention on CLIP-L text tokens
+ Batch: 128
+ LR: 2e-5, constant
+ Duration: ~200 GPU-hours on A100
+ ```
+
+ ### Stage 4: Aesthetic Fine-tuning
+ ```
+ Data: JourneyDB + high-aesthetic LAION subset (1M images, aesthetic score > 6.0)
+ Resolution: 512×512
+ Batch: 64
+ LR: 5e-6
+ Duration: ~50 GPU-hours
+ ```
+
+ ### Stage 5: Consistency Distillation
+ ```
+ Teacher: Trained IRIS model from Stage 4
+ Student: Same architecture, initialized from teacher
+ Objective: Consistency loss (CD) + optional LADD (adversarial)
+ Target: 1-4 step generation
+ Duration: ~30 GPU-hours
+ ```
+
+ **Total estimated cost: ~400 A100 GPU-hours ≈ $1,600 at cloud prices**
+ **Colab/Kaggle feasible**: Stage 1-2 can run on T4/A100 free tier
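
The total can be checked by summing the per-stage durations listed above. The dictionary below just restates those estimates; the hourly rate is derived from the document's own $1,600 figure and is not a quoted price.

```python
stage_gpu_hours = {
    "Stage 1: Wavelet VAE": 20,
    "Stage 2: Class-conditional pretraining": 100,
    "Stage 3: Text-image alignment": 200,
    "Stage 4: Aesthetic fine-tuning": 50,
    "Stage 5: Consistency distillation": 30,
}

total_hours = sum(stage_gpu_hours.values())
implied_rate = 1600 / total_hours      # USD per A100 GPU-hour implied by the $1,600 figure
stage_1_and_2 = (stage_gpu_hours["Stage 1: Wavelet VAE"]
                 + stage_gpu_hours["Stage 2: Class-conditional pretraining"])

print(total_hours)     # 400
print(implied_rate)    # 4.0
print(stage_1_and_2)   # 120
```

Stage 3 alone accounts for half of the budget, so the text-image alignment data and resolution schedule are the main levers for reducing cost.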
+
+ ---
+
+ ## 5. Novel Contributions Summary
+
+ 1. **GRFM (Gated Recurrent Fourier Mixer)**: To our knowledge, the first architecture to fuse Fourier global mixing, gated linear recurrence, and Manhattan spatial decay in a single differentiable block with learned gating
+ 2. **Recurrent Depth for Image Generation**: To our knowledge, the first application of the Huginn prelude-core-coda pattern to image generation, enabling budget-adaptive compute
+ 3. **Wavelet-Frequency Latent Space**: DWT preprocessing before VAE encoding preserves frequency structure in the latent space
+ 4. **Iteration-Aware Conditioning**: The core block receives both timestep t and iteration index i via adaLN, allowing it to learn different behavior at different depths
+ 5. **Dual-Axis Recurrence**: Recurrence over both the noise schedule (diffusion steps) and computational depth (core iterations), a new paradigm for efficient generation
+
+ ---
+
+ ## 6. Extensions
+
+ ### 6.1 Image Editing (Inpainting, Super-Resolution)
+ The iterative nature of IRIS makes it natural for editing:
+ - **Inpainting**: Mask latent tokens, condition core iterations on unmasked context
+ - **Super-Resolution**: Encode low-res image via WaveletVAE, condition generation on LL subband
+ - **Prompt-based editing**: Encode source image, modify text conditioning, run partial denoising (SDEdit-style)
+
+ ### 6.2 ControlNet-like Conditioning
+ Add a lightweight adapter to the Prelude that injects spatial control signals (edges, depth, pose) into the latent; the shared Core then propagates them through its iterations.
+
+ ---