# IRIS: Iterative Recurrent Image Synthesis
## A Novel Architecture for Mobile-First High-Quality Image Generation

### Version 1.0 | Architecture Design Document

---

## 1. Executive Summary

**IRIS** (Iterative Recurrent Image Synthesis) is a novel image generation architecture designed from first principles to achieve high visual quality on mobile devices (3-4 GB RAM). It combines six key innovations drawn from cutting-edge research across multiple domains:

1. **Wavelet-Frequency Latent Space**: 16× spatial compression via a Haar DWT + learned VAE, operating in a frequency-aware space
2. **Recurrent Depth Core**: a shared-weight denoising block iterated N times (inspired by Huginn), achieving deep-model behavior from a tiny parameter count
3. **Gated Recurrent Fourier Mixer (GRFM)**: a novel token mixer that combines RG-LRU gated recurrence with Adaptive Fourier Neural Operators, replacing O(N²) attention with O(N log N) global mixing
4. **Manhattan Spatial Decay**: a learned per-head 2D spatial inductive bias via Manhattan-distance exponential decay (from RMT)
5. **Rectified Flow with Consistency Distillation**: straight ODE paths enabling few-step generation (1-4 steps)
6. **Adaptive Compute Budget**: one model, variable quality; 4 core iterations for mobile, 16 for maximum quality

### Target Specifications

| Metric | Target | Achieved By |
|--------|--------|-------------|
| Total Parameters | < 250M (generator) | Recurrent depth + efficient blocks |
| RAM (inference) | < 3 GB total | ~150 MB generator + ~250 MB text encoder + ~25 MB VAE + activations and runtime overhead (see Section 3.4) |
| Inference Steps | 1-4 | Rectified flow + consistency distillation |
| Core Iterations | 4 (fast) / 8-16 (quality) | Recurrent depth, shared weights |
| Image Quality | Competitive with SDXL at 512 px | Frequency-aware latent + proper training |
| Prompt Adherence | Strong | CLIP-L/14 cross-attention conditioning |
| Training Cost | < 200 A100 GPU-hours (Stages 1-2) | Efficient architecture + progressive training |

---

## 2. Theoretical Foundations

### 2.1 Why Current Approaches Fail on Mobile

**Problem 1: Parameter Explosion in Transformers**
Standard DiT/UNet architectures use independent parameters for each layer. A 28-layer DiT-XL has ~675M params. Every self-attention layer carries its own O(d²) parameters for the Q, K, V projections, so parameter count grows linearly with depth.

**Problem 2: Quadratic Attention Complexity**
For 512×512 images with 8× VAE downsampling: 64×64 = 4096 tokens. Self-attention requires on the order of 4096² × d operations per layer. At d = 768 that is 4096² × 768 ≈ 1.3 × 10^10, i.e. ~12.9 GFLOPs per attention layer.

**Problem 3: Step Count**
Standard diffusion requires 20-50 neural function evaluations (NFE). Even a small model evaluated 50 times per image is impractical on-device.

### 2.2 Our Solution: Mathematical Framework

#### 2.2.1 Recurrent Depth as Implicit Neural ODE

The key insight from Huginn (arXiv:2502.05171): a shared-weight block `R` applied iteratively defines a discrete neural ODE:

```
s_0 = P(x)              [Prelude: encode input]
s_{i+1} = R(s_i, c, t)  [Core: iterate with conditioning c and timestep t]
y = C(s_r)              [Coda: decode output]
```

This can be read as a discretization of:
```
ds/dτ = F_θ(s(τ), c, t),   τ ∈ [0, 1], discretized into r steps
```

**Parameter efficiency**: If block R has P parameters, then r iterations give an effective depth of r×L layers (where L = layers in R) using only P parameters. A 6-layer block iterated 16 times = 96 effective layers.

**Connection to diffusion**: In standard diffusion, the denoiser f_θ is applied at every noise level t with the SAME parameters; this is already recurrent depth, but over the noise-schedule axis. IRIS makes it recurrent over BOTH axes: the noise schedule (outer loop, t) and computational depth (inner loop, τ).

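The prelude-core-coda recurrence maps directly to a small amount of code. A minimal PyTorch sketch (assuming hypothetical `Prelude`, `CoreBlock`, and `Coda` modules with the call signatures shown; not the full IRIS implementation):

```python
import torch
import torch.nn as nn

class RecurrentDenoiser(nn.Module):
    """Prelude-core-coda recurrence: one shared core block iterated r times."""

    def __init__(self, prelude: nn.Module, core: nn.Module, coda: nn.Module):
        super().__init__()
        self.prelude = prelude   # unique weights, run once
        self.core = core         # shared weights, iterated
        self.coda = coda         # unique weights, run once

    def forward(self, z_t, t, cond, num_iterations: int = 8):
        # s_0 = P(x): encode the noisy latent once
        s = self.prelude(z_t, t, cond)
        # s_{i+1} = R(s_i, c, t): the same parameters every iteration,
        # so effective depth grows with num_iterations at zero parameter cost
        for i in range(num_iterations):
            s = self.core(s, t, cond, iteration=i)
        # y = C(s_r): decode to the velocity prediction
        return self.coda(s)
```
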
#### 2.2.2 Gated Recurrent Fourier Mixer (GRFM) - Novel Contribution

We introduce GRFM, which processes the 2D token sequence through three parallel pathways merged multiplicatively:

**Pathway 1: Fourier Global Mixing (O(N log N))**
```
x_fourier = IRFFT2(SoftShrink(BlockMLP(RFFT2(x))))
```
From AFNO: captures global structure via frequency-domain mixing. The soft-shrinkage promotes sparsity in the Fourier domain (images are naturally sparse in frequency).

**Pathway 2: Gated Linear Recurrence (O(N))**
```
a_t = σ(Λ)^(c·σ(W_a·x_t))                          [decay gate, per-element]
i_t = σ(W_x·x_t)                                   [input gate]
h_t = a_t ⊙ h_{t-1} + √(1 - a_t²) ⊙ (i_t ⊙ x_t)    [RG-LRU update]
x_recurrent = W_o·h_t                              [per-token output projection]
```
From Griffin (arXiv:2402.19427): captures sequential dependencies with O(1) state per token position. Bidirectional (forward + backward scan).
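The Section 3.2 pseudocode calls `gated_linear_recurrence` without defining it. A minimal, unoptimized sketch of the RG-LRU-style scan above (a plain sequential loop rather than a fused or parallel scan; the fixed constant c = 8 follows Griffin and is an assumption here):

```python
import torch
import torch.nn as nn

class GatedLinearRecurrence(nn.Module):
    """RG-LRU-style scan: h_t = a_t * h_{t-1} + sqrt(1 - a_t^2) * (i_t * x_t)."""

    def __init__(self, dim: int, c: float = 8.0):
        super().__init__()
        self.W_a = nn.Linear(dim, dim)                    # recurrence-gate projection
        self.W_x = nn.Linear(dim, dim)                    # input-gate projection
        self.log_lambda = nn.Parameter(torch.randn(dim))  # learnable decay Λ
        self.c = c

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq_len, dim], scanned in the given token order
        B, N, D = x.shape
        a_base = torch.sigmoid(self.log_lambda)   # σ(Λ), each entry in (0, 1)
        r = torch.sigmoid(self.W_a(x))            # recurrence gate
        a = a_base.pow(self.c * r)                # a_t = σ(Λ)^(c·σ(W_a x_t))
        i = torch.sigmoid(self.W_x(x))            # input gate
        h = torch.zeros(B, D, device=x.device, dtype=x.dtype)
        outputs = []
        for t in range(N):                        # naive sequential scan
            h = a[:, t] * h + torch.sqrt(1.0 - a[:, t] ** 2) * (i[:, t] * x[:, t])
            outputs.append(h)
        return torch.stack(outputs, dim=1)        # per-token hidden states [B, N, D]
```
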
**Pathway 3: Manhattan Spatial Gate**
```
D_{nm} = γ_head^(|x_n - x_m| + |y_n - y_m|)    [Manhattan decay matrix]
gate = σ(W_g·x) ⊙ (D·(W_v·x))
```
From RMT (arXiv:2309.11523): per-head learnable spatial decay provides multi-scale locality bias.

**Fusion (Novel)**:
```
output = LayerNorm(x_fourier ⊙ σ(gate) + x_recurrent ⊙ (1 - σ(gate)))
```

The gate adaptively selects between global Fourier features (textures, patterns) and local recurrent features (edges, fine details) based on spatial context. This is NOT a simple concatenation; it is a learned, spatially-varying interpolation.

#### 2.2.3 Wavelet-Frequency Latent Space

Instead of a standard VAE operating on pixels, we first apply a Haar DWT:
```
x ∈ R^{3×H×W}  →  DWT  →  y ∈ R^{12×H/2×W/2}
```

Then a lightweight VAE encoder compresses to:
```
z ∈ R^{C×H/16×W/16}    (16× total spatial compression relative to the original image)
```

The VAE operates on wavelet coefficients, preserving frequency structure. The LL (low-low) subband carries global structure; LH, HL, HH carry directional high-frequency details. This means the latent space is inherently frequency-aware.

**Benefit**: The denoiser operates on a latent that already separates structure from detail, making the learning problem easier for a small model.

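For concreteness, a single-level 2D Haar DWT is just strided sums and differences of 2×2 blocks. A minimal sketch of the 3×H×W → 12×H/2×W/2 mapping above (orthonormal scaling, no external wavelet library):

```python
import torch

def haar_dwt_2d(x: torch.Tensor) -> torch.Tensor:
    """Single-level 2D Haar DWT.

    x: [B, C, H, W] with even H, W  ->  [B, 4*C, H/2, W/2]
    Channel order per input channel: LL, then the three detail subbands.
    """
    a = x[:, :, 0::2, 0::2]  # top-left of each 2x2 block
    b = x[:, :, 0::2, 1::2]  # top-right
    c = x[:, :, 1::2, 0::2]  # bottom-left
    d = x[:, :, 1::2, 1::2]  # bottom-right
    ll = (a + b + c + d) / 2  # low-low: local averages (global structure)
    lh = (a - b + c - d) / 2  # detail along one axis
    hl = (a + b - c - d) / 2  # detail along the other axis
    hh = (a - b - c + d) / 2  # diagonal detail
    return torch.cat([ll, lh, hl, hh], dim=1)

# For RGB input: [B, 3, 512, 512] -> [B, 12, 256, 256]
```
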
#### 2.2.4 Rectified Flow + Consistency Distillation

**Training Phase 1 (Rectified Flow)**:
```
x_t = (1-t)·x_0 + t·ε                      [linear interpolation]
v_target = ε - x_0                         [velocity field]
L = w(t)·||v_θ(x_t, t, c) - v_target||²
w(t) = t/(1 - t + ε)                       [SNR reweighting; ε here is a small constant]
t ~ LogitNormal(0, 1)                      [concentrate on hard timesteps]
```
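A minimal training-step sketch of the objective above, assuming a hypothetical `model(z_t, t, cond)` that predicts velocity; the timestep sampling and weighting mirror the formulas in the block:

```python
import torch

def rectified_flow_loss(model, z0, cond, eps_weight=1e-4):
    """One rectified-flow step: x_t = (1-t)·x_0 + t·ε, target v = ε - x_0."""
    B = z0.shape[0]
    # t ~ LogitNormal(0, 1): sample a standard normal, then squash through sigmoid
    t = torch.sigmoid(torch.randn(B, device=z0.device))
    t_ = t.view(B, 1, 1, 1)
    noise = torch.randn_like(z0)
    z_t = (1.0 - t_) * z0 + t_ * noise        # linear interpolation path
    v_target = noise - z0                      # straight-line velocity field
    v_pred = model(z_t, t, cond)
    w = t_ / (1.0 - t_ + eps_weight)           # SNR-style reweighting w(t)
    return (w * (v_pred - v_target) ** 2).mean()
```
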
**Training Phase 2 (Consistency Distillation)**:
```
f_θ(x_t, t) = c_skip(t)·x_t + c_out(t)·F_θ(x_t, t)
L_CD = d(f_θ(x_{t_{n+1}}, t_{n+1}), f_{θ⁻}(x̂_{t_n}, t_n))
```
Where θ⁻ is an EMA of θ. This enables 1-4 step generation.

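The coefficients c_skip and c_out are not pinned down above; one common consistency-model choice (an assumption here, with σ_data = 0.5 and a small t_min) that satisfies the boundary condition f_θ(x, t_min) = x, together with the EMA update for θ⁻, is sketched below:

```python
import torch

SIGMA_DATA = 0.5   # assumed data std, as in typical consistency-model setups
T_MIN = 0.002      # smallest timestep, where f_theta(x, T_MIN) must equal x

def c_skip(t):
    return SIGMA_DATA**2 / ((t - T_MIN)**2 + SIGMA_DATA**2)

def c_out(t):
    return SIGMA_DATA * (t - T_MIN) / torch.sqrt(SIGMA_DATA**2 + t**2)

def consistency_output(net, x_t, t, cond):
    """f_theta(x_t, t) = c_skip(t)*x_t + c_out(t)*F_theta(x_t, t)."""
    t_ = t.view(-1, 1, 1, 1)
    return c_skip(t_) * x_t + c_out(t_) * net(x_t, t, cond)

@torch.no_grad()
def update_ema(ema_net, net, decay=0.999):
    """theta^- <- decay*theta^- + (1-decay)*theta (the EMA teacher target)."""
    for p_ema, p in zip(ema_net.parameters(), net.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1.0 - decay)
```
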
---

## 3. Architecture Details

### 3.1 Overall Pipeline

```
Text → CLIP-L/14 → c ∈ R^{77×768}   (text tokens, used by every Core iteration)

Image → Haar DWT → WaveletVAE Encode → z₀ ∈ R^{C×h×w}
          │
          │  Noise schedule (RF):
          │  z_t = (1-t)·z₀ + t·ε
          ▼
┌───────────────────┐
│      PRELUDE      │
│     (2 blocks)    │
│   PatchEmbed +    │
│   Initial mixing  │
└─────────┬─────────┘
          ▼
┌───────────────────┐
│   CORE (shared)   │ ◄── iterate r times (4-16)
│    GRFM Block     │
│      + FFN        │ ◄── cross-attention on c
│    + adaLN-Zero   │
└─────────┬─────────┘
          ▼
┌───────────────────┐
│       CODA        │
│     (2 blocks)    │
│  Final refine +   │
│    Unpatchify     │
└─────────┬─────────┘
          ▼
  v̂ = predicted velocity
  z₀_pred = z_t - t·v̂
          │
          ▼
WaveletVAE Decode → Haar IDWT → Image
```

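At inference this pipeline runs inside two nested loops: an outer loop over the few rectified-flow timesteps and an inner loop of core iterations hidden inside the model call. A minimal Euler sampler sketch, assuming the same hypothetical `model(z, t, cond, num_iterations)` interface as the Section 2.2.1 sketch:

```python
import torch

@torch.no_grad()
def sample(model, vae_decoder, cond, latent_shape, num_steps=4, core_iterations=8):
    """Few-step rectified-flow sampling: outer loop over t, inner recurrence inside the model."""
    z = torch.randn(latent_shape)                      # z_1 = pure noise
    # Uniform timesteps from t=1 (noise) down to t=0 (data)
    ts = torch.linspace(1.0, 0.0, num_steps + 1)
    for n in range(num_steps):
        t, t_next = ts[n], ts[n + 1]
        t_batch = torch.full((latent_shape[0],), float(t))
        # Inner loop (core_iterations) happens inside the model: 4 for fast, 16 for quality
        v = model(z, t_batch, cond, num_iterations=core_iterations)
        z = z + (t_next - t) * v                       # Euler step along the straight path
    return vae_decoder(z)                              # WaveletVAE decode + inverse DWT
```
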
### 3.2 Detailed Block Design

#### Prelude (2 blocks, unique weights)
```python
class Prelude:
    patch_embed: Conv2d(C_latent, D, kernel_size=2, stride=2)  # 2× spatial reduce
    pos_embed: learned R^{(h/2 × w/2) × D}
    blocks: [PreludeBlock × 2]

class PreludeBlock:
    norm1 → DepthwiseSepConv3x3 → GELU → PointwiseConv → norm2 → FFN
    # Uses conv instead of attention: cheap local feature extraction
```

#### Core (shared weights, iterated r times)
```python
class CoreBlock:
    # adaLN-Zero conditioning on (timestep t, iteration i, text_global)
    adaln_modulation: Linear(D_cond, 9*D)  # (shift, scale, gate) for each of the 3 residual branches

    norm1 → GRFM → gate1 → residual
    norm2 → CrossAttention(q=x, kv=text_tokens) → gate2 → residual   # Only 77 text tokens
    norm3 → FFN(SiLU) → gate3 → residual
```

Cross-attention with only 77 text tokens is cheap: O(N × 77 × d) ≈ O(N·d).

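A sketch of how the adaLN-Zero conditioning above could be wired, assuming the 9·D modulation (shift, scale, gate per residual branch) and a conditioning vector that already combines timestep, iteration-index, and pooled-text embeddings; the zero-initialized projection is what makes every residual branch start as the identity:

```python
import torch
import torch.nn as nn

class AdaLNZero(nn.Module):
    """Produces per-branch (shift, scale, gate) from the conditioning vector."""

    def __init__(self, dim: int, cond_dim: int, num_branches: int = 3):
        super().__init__()
        self.num_branches = num_branches
        self.proj = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, 3 * num_branches * dim))
        nn.init.zeros_(self.proj[1].weight)   # "Zero": every branch starts disabled
        nn.init.zeros_(self.proj[1].bias)

    def forward(self, cond: torch.Tensor):
        # cond: [B, cond_dim] = time_emb + iteration_emb + pooled_text_emb
        params = self.proj(cond).chunk(3 * self.num_branches, dim=-1)
        # Group as (shift, scale, gate) per residual branch
        return [params[i:i + 3] for i in range(0, len(params), 3)]

def modulate(x, shift, scale):
    # x: [B, N, D]; shift/scale: [B, D]
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

# Inside the core block (conceptually):
#   (sh1, sc1, g1), (sh2, sc2, g2), (sh3, sc3, g3) = adaln(cond)
#   x = x + g1.unsqueeze(1) * grfm(modulate(norm1(x), sh1, sc1))
#   x = x + g2.unsqueeze(1) * cross_attn(modulate(norm2(x), sh2, sc2), text_tokens)
#   x = x + g3.unsqueeze(1) * ffn(modulate(norm3(x), sh3, sc3))
```
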
#### GRFM (Gated Recurrent Fourier Mixer) - The Core Innovation
```python
class GRFM:
    def forward(self, x, spatial_shape):
        B, N, D = x.shape
        H, W = spatial_shape
        x_2d = x.reshape(B, H, W, D)

        # Pathway 1: Fourier Global (O(N log N))
        x_freq = rfft2(x_2d, dim=(1, 2))
        x_freq = block_mlp(x_freq)                       # Block-diagonal MLP in the frequency domain
        x_freq = soft_shrink(x_freq, lambd=self.sparsity_threshold)
        x_fourier = irfft2(x_freq, dim=(1, 2), s=(H, W)).reshape(B, N, D)

        # Pathway 2: Bidirectional Gated Recurrence (O(N))
        x_flat_fwd = x                                   # N tokens in raster order
        x_flat_bwd = x.flip(1)                           # Reversed order
        h_fwd = gated_linear_recurrence(x_flat_fwd, self.decay_fwd, self.gate_fwd)
        h_bwd = gated_linear_recurrence(x_flat_bwd, self.decay_bwd, self.gate_bwd)
        x_recurrent = linear(concat(h_fwd, h_bwd.flip(1)))

        # Pathway 3: Manhattan Spatial Gate
        manhattan_dist = compute_manhattan(H, W)         # Precomputed [N, N] distances
        gamma = sigmoid(self.gamma_param)                # Per-head decay in (0, 1)
        spatial_decay = gamma[:, None, None].pow(manhattan_dist)  # [heads, N, N]; window/sparsify in practice
        v = value_proj(x).reshape(B, N, self.heads, -1)  # split channels across heads
        x_gated = einsum('hnm,bmhd->bnhd', spatial_decay, v).reshape(B, N, D)
        gate = sigmoid(gate_proj(x))

        # Adaptive Fusion
        output = x_fourier * gate + x_recurrent * (1 - gate)
        output = output + 0.1 * x_gated                  # Small residual from the spatial pathway

        return output_proj(output)
```

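`compute_manhattan` is used above but never defined. A minimal sketch that precomputes the [N, N] Manhattan distance matrix for an H×W token grid and the per-head decay (dense for clarity; a real implementation would window or sparsify it, as the comment in the code notes):

```python
import torch

def compute_manhattan(H: int, W: int) -> torch.Tensor:
    """[N, N] Manhattan distances between all positions of an H x W token grid (N = H*W)."""
    ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    coords = torch.stack([ys.flatten(), xs.flatten()], dim=-1).float()   # [N, 2]
    # |y_n - y_m| + |x_n - x_m| for every token pair
    return (coords[:, None, :] - coords[None, :, :]).abs().sum(dim=-1)   # [N, N]

def spatial_decay(gamma_param: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """Per-head decay matrix D_h[n, m] = gamma_h ** manhattan(n, m)."""
    dist = compute_manhattan(H, W)                        # [N, N]
    gamma = torch.sigmoid(gamma_param)                    # [heads], each in (0, 1)
    return gamma[:, None, None] ** dist[None, :, :]       # [heads, N, N]
```
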
#### Coda (2 blocks, unique weights)
```python
class Coda:
    blocks: [CodaBlock × 2]
    unpatchify: ConvTranspose2d(D, C_latent, kernel_size=2, stride=2)
    final_norm: LayerNorm(D)

class CodaBlock:
    norm1 → LocalWindowAttention(window=8) → residual   # Small window, efficient
    norm2 → FFN → residual
```

### 3.3 Parameter Budget

| Component | Parameters | Notes |
|-----------|-----------|-------|
| WaveletVAE Encoder | ~15M | Lightweight (LiteVAE-style) |
| WaveletVAE Decoder | ~8M | Tiny decoder (SnapGen-style) |
| CLIP-L/14 Text Encoder | ~123M | Frozen, not counted for training |
| Prelude (2 blocks) | ~12M | Conv-based, cheap |
| Core Block (shared) | ~45M | GRFM + CrossAttn + FFN |
| Coda (2 blocks) | ~15M | Local attention + FFN |
| Embeddings/conditioning | ~3M | Time, iteration, position |
| **Total Generator** | **~75M unique** | Core shared across iterations |
| **Effective depth** | **75M → behaves like 400M+** | At r=8 iterations |
| **Total system** | **~221M** | Including frozen VAE + text encoder |

### 3.4 Memory Analysis (Inference at 512×512)

```
CLIP-L/14 text encoder:       ~247 MB (fp16)
WaveletVAE Decoder:           ~16 MB  (fp16)
IRIS Generator:               ~150 MB (fp16)
Latent tensor:                ~32 KB  (32×32×16, fp16)
KV cache (text cross-attn):   ~12 MB
Intermediate activations:     ~100 MB (single block, not accumulated)
OS/framework overhead:        ~500 MB
─────────────────────────────────────────
Total:                        ~1.0 GB  ✓ (well under 3 GB)
```

**Key insight**: Because the Core block's weights are shared, we do not accumulate layer-by-layer activations. Each iteration reuses the same memory buffer.

---

## 4. Training Recipe

### Stage 1: Wavelet VAE Training (Standalone)
```
Data: ImageNet (1.2M images) + CC3M (3M images)
Resolution: 256×256
Objective: Reconstruction loss + KL + Perceptual (LPIPS) + Wavelet frequency loss
Batch: 32
LR: 1e-4, cosine decay
Duration: ~20 GPU-hours on A100
```

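The "Wavelet frequency loss" term is not specified above; one simple reading (an assumption, reusing the `haar_dwt_2d` sketch from Section 2.2.3) is an L1 loss on Haar subbands that up-weights the high-frequency details a plain pixel loss tends to under-penalize:

```python
import torch

def wavelet_frequency_loss(recon, target, hf_weight: float = 2.0):
    """L1 on Haar subbands; high-frequency subbands weighted more heavily.

    Assumes haar_dwt_2d from the Section 2.2.3 sketch: [B, C, H, W] -> [B, 4C, H/2, W/2]
    with the first C channels being LL and the rest the detail subbands.
    """
    C = recon.shape[1]
    r, t = haar_dwt_2d(recon), haar_dwt_2d(target)
    ll_loss = (r[:, :C] - t[:, :C]).abs().mean()   # structure term
    hf_loss = (r[:, C:] - t[:, C:]).abs().mean()   # detail term
    return ll_loss + hf_weight * hf_loss
```
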
### Stage 2: Class-Conditional Pretraining
```
Data: ImageNet 256×256 (class labels)
Objective: Rectified Flow velocity matching
Batch: 256
LR: 1e-4, warmup 5000 steps, cosine decay
Core iterations: r=8 (randomly sample r ∈ {4, 6, 8, 10, 12} for robustness)
Duration: ~100 GPU-hours on A100
```

### Stage 3: Text-Image Alignment
```
Data: CC3M + CC12M (~15M images with captions, re-captioned by a VLM)
Resolution: 256→512 progressive
Objective: Rectified Flow + cross-attention on CLIP-L text tokens
Batch: 128
LR: 2e-5, constant
Duration: ~200 GPU-hours on A100
```

### Stage 4: Aesthetic Fine-tuning
```
Data: JourneyDB + high-aesthetic LAION subset (1M images, aesthetic score > 6.0)
Resolution: 512×512
Batch: 64
LR: 5e-6
Duration: ~50 GPU-hours
```

### Stage 5: Consistency Distillation
```
Teacher: Trained IRIS model from Stage 4
Student: Same architecture, initialized from the teacher
Objective: Consistency loss (CD) + optional LADD (adversarial)
Target: 1-4 step generation
Duration: ~30 GPU-hours
```

**Total estimated cost: ~400 A100 GPU-hours ≈ $1,600 at cloud prices**
**Colab/Kaggle feasible**: Stages 1-2 can run on free-tier T4-class GPUs

---

## 5. Novel Contributions Summary

1. **GRFM (Gated Recurrent Fourier Mixer)**: First architecture to fuse Fourier global mixing, gated linear recurrence, and Manhattan spatial decay in a single differentiable block with learned gating
2. **Recurrent Depth for Image Generation**: First application of the Huginn prelude-core-coda pattern to image generation, enabling budget-adaptive compute
3. **Wavelet-Frequency Latent Space**: DWT preprocessing before VAE encoding preserves frequency structure in the latent space
4. **Iteration-Aware Conditioning**: The core block receives both timestep t and iteration index i via adaLN, allowing it to learn different behavior at different depths
5. **Dual-Axis Recurrence**: Recurrence over both the noise schedule (diffusion steps) and computational depth (core iterations), a new paradigm for efficient generation

---

## 6. Extensions

### 6.1 Image Editing (Inpainting, Super-Resolution)
The iterative nature of IRIS makes it natural for editing:
- **Inpainting**: Mask latent tokens, condition core iterations on the unmasked context (see the sketch after this list)
- **Super-Resolution**: Encode the low-res image via the WaveletVAE, condition generation on the LL subband
- **Prompt-based editing**: Encode the source image, modify the text conditioning, run partial denoising (SDEdit-style)

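A sketch of the inpainting case referenced above, reusing the sampler interface from the earlier sketches and the standard trick of re-noising the known region to the current noise level at every step (illustrative; the `model`, `mask`, and `z_ref` conventions are assumptions):

```python
import torch

@torch.no_grad()
def inpaint(model, z_ref, mask, cond, num_steps=4, core_iterations=8):
    """mask = 1 where the latent should be regenerated, 0 where z_ref must be kept."""
    z = torch.randn_like(z_ref)
    ts = torch.linspace(1.0, 0.0, num_steps + 1)
    for n in range(num_steps):
        t, t_next = ts[n], ts[n + 1]
        # Re-noise the known region to the current noise level and paste it in
        z_known = (1.0 - t) * z_ref + t * torch.randn_like(z_ref)
        z = mask * z + (1.0 - mask) * z_known
        t_batch = torch.full((z.shape[0],), float(t), device=z.device)
        v = model(z, t_batch, cond, num_iterations=core_iterations)
        z = z + (t_next - t) * v
    return mask * z + (1.0 - mask) * z_ref   # keep the exact known content at the end
```
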
### 6.2 ControlNet-like Conditioning
Add a lightweight adapter to the Prelude that injects spatial control signals (edges, depth, pose) into the latent; the shared Core then naturally propagates this through its iterations.

---