# IRIS: Iterative Recurrent Image Synthesis
## A Novel Architecture for Mobile-First High-Quality Image Generation
### Version 1.0 | Architecture Design Document
---
## 1. Executive Summary
**IRIS** (Iterative Recurrent Image Synthesis) is a novel image generation architecture designed from first principles to achieve high visual quality on mobile devices (under 3 GB of RAM). It combines six key innovations drawn from recent research across multiple domains:
1. **Wavelet-Frequency Latent Space**: 16× spatial compression via Haar DWT + a learned VAE, operating in a frequency-aware space
2. **Recurrent Depth Core**: a shared-weight denoising block iterated N times (inspired by Huginn), achieving deep-model behavior from a tiny parameter count
3. **Gated Recurrent Fourier Mixer (GRFM)**: novel token mixing that combines RG-LRU gated recurrence with Adaptive Fourier Neural Operators, replacing O(N²) attention with O(N log N) global mixing
4. **Manhattan Spatial Decay**: learned per-head 2D spatial inductive bias via Manhattan-distance exponential decay (from RMT)
5. **Rectified Flow with Consistency Distillation**: straight ODE paths for few-step generation (1-4 steps)
6. **Adaptive Compute Budget**: one model, variable quality (4 core iterations for mobile, 16 for quality)
### Target Specifications
| Metric | Target | Achieved By |
|--------|--------|------------|
| Total Parameters | < 250M (generator) | Recurrent depth + efficient blocks |
| RAM (inference) | < 3GB total | ~600MB model + ~400MB VAE + ~200MB text encoder + buffers |
| Inference Steps | 1-4 | Rectified flow + consistency distillation |
| Core Iterations | 4 (fast) / 8-16 (quality) | Recurrent depth, shared weights |
| Image Quality | Competitive with SDXL at 512px | Frequency-aware latent + proper training |
| Prompt Adherence | Strong | CLIP-L/14 cross-attention conditioning |
| Training Cost | < 200 A100 GPU-hours (Stages 1-2) | Efficient architecture + progressive training |
---
## 2. Theoretical Foundations
### 2.1 Why Current Approaches Fail on Mobile
**Problem 1: Parameter Explosion in Transformers**
Standard DiT/UNet architectures use independent parameters for each layer. A 28-layer DiT-XL has ~675M params. Each self-attention layer stores O(d²) params for its Q, K, V projections, multiplied across the number of layers.
**Problem 2: Quadratic Attention Complexity**
For 512×512 images with 8× VAE downsampling: 64×64 = 4096 tokens. Self-attention requires 4096² × d operations per layer. At d=768, that's ~12.9 GFLOPs per attention layer.
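As a quick sanity check on that figure:
```python
# Back-of-envelope attention cost at 512x512 with an 8x VAE (one QK^T term).
tokens = (512 // 8) ** 2            # 64 * 64 = 4096 latent tokens
d = 768
flops = tokens ** 2 * d             # N^2 * d multiply-accumulates
print(f"{flops / 1e9:.1f} GFLOPs")  # -> 12.9 GFLOPs per attention layer
```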
**Problem 3: Step Count**
Standard diffusion requires 20-50 neural function evaluations (NFEs). Even a small model × 50 steps is impractical on-device.
### 2.2 Our Solution: Mathematical Framework
#### 2.2.1 Recurrent Depth as Implicit Neural ODE
The key insight from Huginn (arXiv:2502.05171): a shared-weight block `R` applied iteratively defines a discrete neural ODE:
```
s_0 = P(x) [Prelude: encode input]
s_{i+1} = R(s_i, c, t) [Core: iterate with conditioning c and timestep t]
y = C(s_r) [Coda: decode output]
```
This is mathematically equivalent to an Euler discretization of:
```
ds/dτ = F_θ(s(τ), c, t)    where τ ∈ [0, 1], discretized into r steps
```
**Parameter efficiency**: If block R has P parameters, then r iterations give an effective depth of r×L layers (where L = layers in R) using only P parameters. A 6-layer block iterated 16 times = 96 effective layers.
**Connection to diffusion**: In standard diffusion, the denoiser f_θ is applied at each noise level t with the SAME parameters; this IS recurrent depth, but over the noise-schedule axis. IRIS makes it recurrent over BOTH axes: the noise schedule (outer loop, t) and computational depth (inner loop, τ).
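A minimal PyTorch sketch of this prelude-core-coda loop; the module names and signatures here are illustrative stand-ins for the blocks defined in Section 3.2, not the final IRIS interfaces:
```python
import torch.nn as nn

class RecurrentDepthDenoiser(nn.Module):
    """Weight-shared core iterated r times between unique prelude/coda."""

    def __init__(self, prelude: nn.Module, core: nn.Module, coda: nn.Module):
        super().__init__()
        self.prelude = prelude   # unique weights, run once
        self.core = core         # ONE block, reused for every iteration
        self.coda = coda         # unique weights, run once

    def forward(self, x, cond, t, r: int = 8):
        s = self.prelude(x)                         # s_0 = P(x)
        for i in range(r):                          # s_{i+1} = R(s_i, c, t)
            s = self.core(s, cond, t, iteration=i)
        return self.coda(s)                         # y = C(s_r)
```
The same weights serve r=4 (fast) and r=16 (quality); only the loop bound changes at inference time.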
#### 2.2.2 Gated Recurrent Fourier Mixer (GRFM): a Novel Contribution
We introduce GRFM, which processes the 2D token sequence through three parallel pathways merged multiplicatively:
**Pathway 1: Fourier Global Mixing (O(N log N))**
```
x_fourier = IRFFT2(SoftShrink(BlockMLP(RFFT2(x))))
```
From AFNO: captures global structure via frequency-domain mixing. The soft-shrinkage promotes sparsity in the Fourier domain (natural images are approximately sparse in frequency).
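A simplified, runnable sketch of this pathway; for brevity it replaces the block-diagonal MLP with a single dense complex channel-mixing matrix, and `sparsity` is an assumed hyperparameter:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FourierMixer(nn.Module):
    def __init__(self, dim: int, sparsity: float = 0.01):
        super().__init__()
        # Complex channel-mixing weight, stored as (real, imag) pairs.
        self.w = nn.Parameter(torch.randn(dim, dim, 2) * dim ** -0.5)
        self.sparsity = sparsity

    def forward(self, x):                           # x: [B, H, W, D], real
        xf = torch.fft.rfft2(x, dim=(1, 2))         # [B, H, W//2+1, D], complex
        xf = xf @ torch.view_as_complex(self.w)     # mix channels per frequency
        xf = F.softshrink(torch.view_as_real(xf), lambd=self.sparsity)
        xf = torch.view_as_complex(xf.contiguous())  # sparsified spectrum
        return torch.fft.irfft2(xf, s=x.shape[1:3], dim=(1, 2))
```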
**Pathway 2: Gated Linear Recurrence (O(N))**
```
a_t = σ(Λ)^(c·σ(W_a · x_t))                        [decay gate, per-element]
i_t = σ(W_x · x_t)                                 [input gate]
h_t = a_t ⊙ h_{t-1} + √(1 - a_t²) ⊙ (i_t ⊙ x_t)    [RG-LRU update]
x_recurrent = W_o · h_T
```
From Griffin (arXiv:2402.19427): captures sequential dependencies with O(1) state per token position. Bidirectional (forward + backward scan).
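A minimal, non-fused sketch of this recurrence; a production implementation would use a parallel scan, the constant c = 8 follows the Griffin paper, and the remaining choices are illustrative:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedLinearRecurrence(nn.Module):
    def __init__(self, dim: int, c: float = 8.0):
        super().__init__()
        self.w_a = nn.Linear(dim, dim)             # recurrence gate
        self.w_x = nn.Linear(dim, dim)             # input gate
        self.lam = nn.Parameter(torch.zeros(dim))  # Λ (pre-sigmoid decay)
        self.c = c

    def forward(self, x):                          # x: [B, N, D], raster order
        log_a0 = F.logsigmoid(self.lam)            # log σ(Λ)
        a = torch.exp(self.c * torch.sigmoid(self.w_a(x)) * log_a0)  # σ(Λ)^(c·r_t)
        i = torch.sigmoid(self.w_x(x))
        h = x.new_zeros(x.shape[0], x.shape[2])
        out = []
        for t in range(x.shape[1]):                # sequential scan
            h = a[:, t] * h + torch.sqrt(1 - a[:, t] ** 2) * (i[:, t] * x[:, t])
            out.append(h)
        return torch.stack(out, dim=1)             # [B, N, D]
```
Running it once in raster order and once on the flipped sequence gives the bidirectional variant.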
**Pathway 3: Manhattan Spatial Gate**
```
D_{nm} = γ_head^(|x_n - x_m| + |y_n - y_m|)        [Manhattan decay matrix]
gate = σ(W_g · x) ⊙ (D · (W_v · x))
```
From RMT (arXiv:2309.11523): per-head learnable spatial decay provides multi-scale locality bias.
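A small sketch of precomputing the decay matrix; in IRIS the per-head γ values are learned, and the fixed values below are only for illustration:
```python
import torch

def manhattan_decay(H: int, W: int, gammas: torch.Tensor) -> torch.Tensor:
    """gammas: [heads], each in (0, 1). Returns [heads, H*W, H*W]."""
    ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    coords = torch.stack([ys.flatten(), xs.flatten()], dim=-1).float()  # [N, 2]
    dist = (coords[:, None, :] - coords[None, :, :]).abs().sum(-1)      # |Δx|+|Δy|
    return gammas[:, None, None] ** dist[None]       # γ_head^(Manhattan dist)

D = manhattan_decay(8, 8, torch.tensor([0.9, 0.5]))  # 2 heads, 64 tokens
```
Small γ gives a tight local window; γ near 1 approaches uniform global weighting, so different heads learn different receptive scales.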
**Fusion (Novel)**:
```
output = LayerNorm(x_fourier ⊙ σ(gate) + x_recurrent ⊙ (1 - σ(gate)))
```
The gate adaptively selects between global Fourier features (textures, patterns) and local recurrent features (edges, fine details) based on spatial context. This is NOT a simple concatenation; it is a learned, spatially varying interpolation.
#### 2.2.3 Wavelet-Frequency Latent Space
Instead of standard VAE operating on pixels, we first apply Haar DWT:
```
x ∈ R^{3×H×W} → DWT → y ∈ R^{12×H/2×W/2}
```
Then a lightweight VAE encoder compresses to:
```
z ∈ R^{C×H/16×W/16}    (effective 16× total spatial compression from the original)
```
The VAE operates on wavelet coefficients, preserving frequency structure. The LL (low-low) subband carries global structure; LH, HL, HH carry directional high-frequency details. This means the latent space is inherently frequency-aware.
**Benefit**: The denoiser operates on a latent that already separates structure from detail, making the learning problem easier for a small model.
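A minimal single-level Haar DWT implemented as stride-2 convolutions, matching the 3 → 12 channel mapping above (filter sign conventions vary between libraries; this is one standard orthonormal choice):
```python
import torch
import torch.nn.functional as F

def haar_dwt(x: torch.Tensor) -> torch.Tensor:
    """x: [B, 3, H, W] -> [B, 12, H/2, W/2] (LL, LH, HL, HH per color channel)."""
    ll = torch.tensor([[0.5, 0.5], [0.5, 0.5]])
    lh = torch.tensor([[0.5, 0.5], [-0.5, -0.5]])
    hl = torch.tensor([[0.5, -0.5], [0.5, -0.5]])
    hh = torch.tensor([[0.5, -0.5], [-0.5, 0.5]])
    k = torch.stack([ll, lh, hl, hh]).unsqueeze(1)        # [4, 1, 2, 2]
    B, C, H, W = x.shape
    y = F.conv2d(x.reshape(B * C, 1, H, W), k, stride=2)  # per-channel transform
    return y.reshape(B, C * 4, H // 2, W // 2)

z = haar_dwt(torch.randn(2, 3, 256, 256))  # -> [2, 12, 128, 128]
```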
#### 2.2.4 Rectified Flow + Consistency Distillation
**Training Phase 1 (Rectified Flow)**:
```
x_t = (1-t) · x_0 + t · ε               [linear interpolation]
v_target = ε - x_0                      [velocity field]
L = w(t) · ||v_θ(x_t, t, c) - v_target||²
w(t) = t / (1 - t + ε_w)                [SNR reweighting; ε_w a small constant]
t ~ LogitNormal(0, 1)                   [concentrate on hard timesteps]
```
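Putting these pieces together, a sketch of one Phase 1 training step; `model` is assumed to take `(z_t, t, cond)` and return a predicted velocity, and `eps_w` stands in for the small constant in the reweighting denominator:
```python
import torch

def rectified_flow_loss(model, z0, cond, eps_w: float = 1e-4):
    B = z0.shape[0]
    t = torch.sigmoid(torch.randn(B, device=z0.device))  # LogitNormal(0, 1)
    t_ = t.view(B, 1, 1, 1)
    noise = torch.randn_like(z0)
    z_t = (1 - t_) * z0 + t_ * noise                     # linear interpolation
    v_target = noise - z0                                # velocity field
    v_pred = model(z_t, t, cond)
    w = t_ / (1 - t_ + eps_w)                            # SNR reweighting
    return (w * (v_pred - v_target) ** 2).mean()
```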
**Training Phase 2 (Consistency Distillation)**:
```
f_θ(x_t, t) = c_skip(t) · x_t + c_out(t) · F_θ(x_t, t)
L_CD = d(f_θ(x_{t_{n+1}}, t_{n+1}), f_{θ⁻}(x̂_{t_n}, t_n))
```
Where θ⁻ is an EMA of θ, and x̂_{t_n} is obtained from x_{t_{n+1}} by one ODE-solver step. This enables 1-4 step generation.
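A sketch of the consistency parameterization, using one common choice of boundary coefficients from the consistency-models literature (σ_data is an assumed constant) so that f_θ(x, t) → x as t → 0:
```python
import torch

def consistency_fn(F_theta, x_t, t, cond, sigma_data: float = 0.5):
    """t: tensor broadcastable to x_t. At t=0: c_skip=1, c_out=0, so f(x, 0)=x."""
    c_skip = sigma_data ** 2 / (t ** 2 + sigma_data ** 2)
    c_out = sigma_data * t / torch.sqrt(t ** 2 + sigma_data ** 2)
    return c_skip * x_t + c_out * F_theta(x_t, t, cond)
```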
---
## 3. Architecture Details
### 3.1 Overall Pipeline
```
Text → CLIP-L/14 → c ∈ R^{77×768} ──────────────────────┐
                                                         │
Image → Haar DWT → WaveletVAE Encode → z₀ ∈ R^{C×h×w}    │
                                                         │
        Noise schedule (RF):                             │
        z_t = (1-t)·z₀ + t·ε                             │
                 │                                       │
                 ▼                                       │
        ┌─────────────────┐                              │
        │     PRELUDE     │                              │
        │    (2 blocks)   │                              │
        │   PatchEmbed +  │                              │
        │  Initial mixing │                              │
        └────────┬────────┘                              │
                 │                                       │
                 ▼                                       │
        ┌─────────────────┐                              │
        │  CORE (shared)  │◄── Iterate r times (4-16)    │
        │   GRFM Block    │                              │
        │   + FFN         │◄── text cross-attention ─────┘
        │   + adaLN-Zero  │
        └────────┬────────┘
                 │
                 ▼
        ┌─────────────────┐
        │      CODA       │
        │    (2 blocks)   │
        │ Final refine +  │
        │   Unpatchify    │
        └────────┬────────┘
                 │
                 ▼
        v̂ = predicted velocity
        ẑ₀ = z_t - t·v̂
                 │
                 ▼
        WaveletVAE Decode → Haar IDWT → Image
```
### 3.2 Detailed Block Design
#### Prelude (2 blocks, unique weights)
```python
class Prelude:
    patch_embed: Conv2d(C_latent, D, kernel_size=2, stride=2)  # 2× spatial reduce
    pos_embed: learned, R^{(h/2 × w/2) × D}
    blocks: [PreludeBlock × 2]

class PreludeBlock:
    norm1 → DepthwiseSepConv3x3 → GELU → PointwiseConv → norm2 → FFN
    # Uses conv instead of attention → cheap local feature extraction
```
#### Core (shared weights, iterated r times)
```python
class CoreBlock:
    # adaLN-Zero conditioning on (timestep t, iteration i, text_global)
    adaln_modulation: Linear(D_cond, 9*D)  # (scale, shift, gate) × 3 sublayers
    norm1 → GRFM → gate1 → residual
    norm2 → CrossAttention(q=x, kv=text_tokens) → gate2 → residual  # only 77 text tokens
    norm3 → FFN(SiLU) → gate3 → residual
```
Cross-attention with only 77 text tokens is cheap: O(N × 77 × d) ≈ O(N·d).
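For reference, a minimal sketch of the adaLN-Zero modulation assumed above; zero initialization makes every gated sublayer start as the identity, and the dimensions are illustrative:
```python
import torch
import torch.nn as nn

class AdaLNZero(nn.Module):
    def __init__(self, d_cond: int, d_model: int, n_sublayers: int = 3):
        super().__init__()
        self.n = n_sublayers
        self.proj = nn.Linear(d_cond, 3 * n_sublayers * d_model)
        nn.init.zeros_(self.proj.weight)   # zero-init: sublayers start as identity
        nn.init.zeros_(self.proj.bias)

    def forward(self, cond):                # cond: [B, d_cond]
        p = self.proj(cond).unsqueeze(1)    # [B, 1, 3*n*D]
        return p.chunk(3 * self.n, dim=-1)  # 3n tensors of shape [B, 1, D]

# Usage inside the core block (x: [B, N, D]):
#   s1, b1, g1, s2, b2, g2, s3, b3, g3 = adaln(cond)
#   x = x + g1 * grfm(norm1(x) * (1 + s1) + b1)
```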
#### GRFM (Gated Recurrent Fourier Mixer): the Core Innovation
```python
class GRFM:
    def forward(self, x, spatial_shape):
        B, N, D = x.shape
        H, W = spatial_shape
        x_2d = x.reshape(B, H, W, D)

        # Pathway 1: Fourier global mixing (O(N log N))
        x_freq = rfft2(x_2d, dim=(1, 2))
        x_freq = block_mlp(x_freq)                 # block-diagonal MLP in freq domain
        x_freq = soft_shrink(x_freq, lambd=self.sparsity_threshold)
        x_fourier = irfft2(x_freq, dim=(1, 2)).reshape(B, N, D)

        # Pathway 2: bidirectional gated recurrence (O(N))
        x_fwd = x                                  # N tokens in raster order
        x_bwd = x.flip(1)                          # reversed
        h_fwd = gated_linear_recurrence(x_fwd, self.decay_fwd, self.gate_fwd)
        h_bwd = gated_linear_recurrence(x_bwd, self.decay_bwd, self.gate_bwd)
        x_recurrent = linear(concat(h_fwd, h_bwd.flip(1)))

        # Pathway 3: Manhattan spatial gate
        manhattan_dist = compute_manhattan(H, W)   # precomputed, [N, N]
        gamma = sigmoid(self.gamma_param)          # per-head, in (0, 1)
        spatial_decay = gamma.pow(manhattan_dist)  # [heads, N, N]; windowed/sparse in practice
        v = value_proj(x).reshape(B, N, self.heads, D // self.heads)
        x_gated = einsum('hnm,bmhd->bnhd', spatial_decay, v).reshape(B, N, D)
        gate = sigmoid(gate_proj(x))

        # Adaptive fusion
        output = x_fourier * gate + x_recurrent * (1 - gate)
        output = output + 0.1 * x_gated            # small residual from the spatial path
        return output_proj(output)
```
#### Coda (2 blocks, unique weights)
```python
class Coda:
    blocks: [CodaBlock × 2]
    unpatchify: ConvTranspose2d(D, C_latent, kernel_size=2, stride=2)
    final_norm: LayerNorm(D)

class CodaBlock:
    norm1 → LocalWindowAttention(window=8) → residual  # small window, efficient
    norm2 → FFN → residual
```
### 3.3 Parameter Budget
| Component | Parameters | Notes |
|-----------|-----------|-------|
| WaveletVAE Encoder | ~15M | Lightweight (LiteVAE-style) |
| WaveletVAE Decoder | ~8M | Tiny decoder (SnapGen-style) |
| CLIP-L/14 Text Encoder | ~123M | Frozen, not counted for training |
| Prelude (2 blocks) | ~12M | Conv-based, cheap |
| Core Block (shared) | ~45M | GRFM + CrossAttn + FFN |
| Coda (2 blocks) | ~15M | Local attention + FFN |
| Embeddings/conditioning | ~3M | Time, iteration, position |
| **Total Generator** | **~75M unique** | Core shared across iterations |
| **Effective depth** | **75M → behaves like 400M+** | At r=8 iterations |
| **Total system** | **~221M** | Including VAE + text encoder |
### 3.4 Memory Analysis (Inference at 512×512)
```
CLIP-L/14 text encoder:      ~250 MB (fp16)
WaveletVAE decoder:           ~16 MB (fp16)
IRIS generator:              ~150 MB (fp16)
Latent tensor:                 <1 MB (32×32×16, fp16)
KV cache (text cross-attn):   ~12 MB
Intermediate activations:    ~100 MB (single block, not accumulated)
OS/framework overhead:       ~500 MB
─────────────────────────────────────────
Total:                      ~1.0 GB ✓ (well under 3 GB)
```
**Key insight**: Because Core block weights are shared, we don't accumulate layer-by-layer activations. Each iteration reuses the same memory buffer.
---
## 4. Training Recipe
### Stage 1: Wavelet VAE Training (Standalone)
```
Data: ImageNet (1.2M images) + CC3M (3M images)
Resolution: 256Γ256
Objective: Reconstruction loss + KL + Perceptual (LPIPS) + Wavelet frequency loss
Batch: 32
LR: 1e-4, cosine decay
Duration: ~20 GPU-hours on A100
```
### Stage 2: Class-Conditional Pretraining
```
Data: ImageNet 256×256 (class labels)
Objective: Rectified Flow velocity matching
Batch: 256
LR: 1e-4, warmup 5000 steps, cosine decay
Core iterations: r=8 (randomly sample r ∈ {4, 6, 8, 10, 12} for robustness)
Duration: ~100 GPU-hours on A100
```
### Stage 3: Text-Image Alignment
```
Data: CC3M + CC12M (15M images with captions, re-captioned by a VLM)
Resolution: 256→512 progressive
Objective: Rectified Flow + cross-attention on CLIP-L text tokens
Batch: 128
LR: 2e-5, constant
Duration: ~200 GPU-hours on A100
```
### Stage 4: Aesthetic Fine-tuning
```
Data: JourneyDB + high-aesthetic LAION subset (1M images, aesthetic score > 6.0)
Resolution: 512×512
Batch: 64
LR: 5e-6
Duration: ~50 GPU-hours
```
### Stage 5: Consistency Distillation
```
Teacher: Trained IRIS model from Stage 4
Student: Same architecture, initialized from teacher
Objective: Consistency loss (CD) + optional LADD (adversarial)
Target: 1-4 step generation
Duration: ~30 GPU-hours
```
**Total estimated cost: ~400 A100 GPU-hours ≈ $1,600 at cloud prices (~$4/GPU-hour)**
**Colab/Kaggle feasible**: Stages 1-2 can run on the T4/A100 free tier
---
## 5. Novel Contributions Summary
1. **GRFM (Gated Recurrent Fourier Mixer)**: First architecture to fuse Fourier global mixing, gated linear recurrence, and Manhattan spatial decay in a single differentiable block with learned gating
2. **Recurrent Depth for Image Generation**: First application of the Huginn prelude-core-coda pattern to image generation, enabling budget-adaptive compute
3. **Wavelet-Frequency Latent Space**: DWT preprocessing before VAE encoding preserves frequency structure in the latent space
4. **Iteration-Aware Conditioning**: The core block receives both timestep t and iteration index i via adaLN, allowing it to learn different behavior at different depths
5. **Dual-Axis Recurrence**: Recurrence over both the noise schedule (diffusion steps) and computational depth (core iterations), a new paradigm for efficient generation
---
## 6. Extensions
### 6.1 Image Editing (Inpainting, Super-Resolution)
The iterative nature of IRIS makes it natural for editing:
- **Inpainting**: Mask latent tokens, condition core iterations on unmasked context
- **Super-Resolution**: Encode low-res image via WaveletVAE, condition generation on LL subband
- **Prompt-based editing**: Encode the source image, modify the text conditioning, and run partial denoising (SDEdit-style; see the sketch below)
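A sketch of the SDEdit-style path under the rectified-flow formulation; `iris_step` is a hypothetical one-call velocity predictor, and `t0` controls edit strength (higher = more deviation from the source):
```python
import torch

def sdedit(iris_step, z_src, cond_new, t0: float = 0.6, steps: int = 4):
    """Re-noise the source latent to t0, then denoise under the new prompt."""
    z = (1 - t0) * z_src + t0 * torch.randn_like(z_src)  # partial re-noising
    ts = torch.linspace(t0, 0.0, steps + 1)
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        v = iris_step(z, t_cur, cond_new)                # predicted velocity
        z = z + (t_next - t_cur) * v                     # Euler step toward t=0
    return z
```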
### 6.2 ControlNet-like Conditioning
Add lightweight adapter to Prelude that injects spatial control signals (edges, depth, pose) into the latent, then the shared Core naturally propagates this through iterations.
---