# IRIS: Iterative Recurrent Image Synthesis

## A Novel Architecture for Mobile-First High-Quality Image Generation

### Version 1.0 | Architecture Design Document

---

## 1. Executive Summary

**IRIS** (Iterative Recurrent Image Synthesis) is a novel image generation architecture designed from first principles to achieve high visual quality on mobile devices with 3-4 GB of RAM or less. It combines six key innovations drawn from cutting-edge research across multiple domains:

1. **Wavelet-Frequency Latent Space** — 16× spatial compression via Haar DWT + learned VAE, operating in a frequency-aware space
2. **Recurrent Depth Core** — a shared-weight denoising block iterated r times (inspired by Huginn), achieving deep-model behavior from a tiny parameter count
3. **Gated Recurrent Fourier Mixer (GRFM)** — novel token mixing that combines RG-LRU gated recurrence with Adaptive Fourier Neural Operators, replacing O(N²) attention with O(N log N) global mixing
4. **Manhattan Spatial Decay** — learned per-head 2D spatial inductive bias via Manhattan-distance exponential decay (from RMT)
5. **Rectified Flow with Consistency Distillation** — straight ODE paths for few-step generation (1-4 steps)
6. **Adaptive Compute Budget** — same model, variable quality: 4 iterations for mobile, 16 for quality

### Target Specifications

| Metric | Target | Achieved By |
|--------|--------|-------------|
| Total Parameters | < 250M (generator) | Recurrent depth + efficient blocks |
| RAM (inference) | < 3GB total | ~150MB generator + ~50MB VAE + ~250MB text encoder + activations/overhead |
| Inference Steps | 1-4 | Rectified flow + consistency distillation |
| Core Iterations | 4 (fast) / 8-16 (quality) | Recurrent depth, shared weights |
| Image Quality | Competitive with SDXL at 512px | Frequency-aware latent + proper training |
| Prompt Adherence | Strong | CLIP-L/14 cross-attention conditioning |
| Training Cost | ~400 A100 GPU-hours (all stages) | Efficient architecture + progressive training |

---

## 2. Theoretical Foundations

### 2.1 Why Current Approaches Fail on Mobile

**Problem 1: Parameter Explosion in Transformers**

Standard DiT/UNet architectures use independent parameters for each layer. A 28-layer DiT-XL has ~675M params; each self-attention layer stores O(d²) params for its Q, K, V projections, multiplied across all layers.

**Problem 2: Quadratic Attention Complexity**

For 512×512 images with 8× VAE downsampling: 64×64 = 4096 tokens. Self-attention requires 4096² × d operations per layer; at d = 768, that is ~12.9 GFLOPs for the QKᵀ product alone, per attention layer.

**Problem 3: Step Count**

Standard diffusion requires 20-50 neural function evaluations (NFE). Even a small model × 50 steps is impractical on-device.

### 2.2 Our Solution: Mathematical Framework

#### 2.2.1 Recurrent Depth as Implicit Neural ODE

The key insight from Huginn (arXiv:2502.05171): a shared-weight block `R` applied iteratively defines a discrete neural ODE:

```
s_0     = P(x)           [Prelude: encode input]
s_{i+1} = R(s_i, c, t)   [Core: iterate with conditioning c and timestep t]
y       = C(s_r)         [Coda: decode output]
```

This can be read as an Euler discretization of:

```
ds/dτ = F_θ(s(τ), c, t)    where τ ∈ [0, 1], discretized into r steps
```

**Parameter efficiency**: if block R has P parameters, then r iterations give an effective depth of r×L layers (where L = layers in R) using only P parameters. A 6-layer block iterated 16 times = 96 effective layers.
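To make the pattern concrete, here is a minimal PyTorch sketch of the prelude-core-coda recurrence. `RecurrentDepthModel` and the `iteration` keyword are illustrative names, not the actual IRIS API; the real Core also receives adaLN-Zero conditioning as described in Section 3.2.

```python
import torch.nn as nn

class RecurrentDepthModel(nn.Module):
    """Sketch of prelude-core-coda: unique P and C, shared R iterated r times."""

    def __init__(self, prelude: nn.Module, core: nn.Module, coda: nn.Module):
        super().__init__()
        self.prelude = prelude  # P: encode input, run once
        self.core = core        # R: shared weights, applied r times
        self.coda = coda        # C: decode output, run once

    def forward(self, x, cond, t, r: int = 8):
        s = self.prelude(x)                        # s_0 = P(x)
        for i in range(r):                         # s_{i+1} = R(s_i, c, t)
            s = self.core(s, cond, t, iteration=i)
        return self.coda(s)                        # y = C(s_r)
```

Because `r` is a runtime argument, the same weights serve both the 4-iteration mobile budget and the 16-iteration quality budget.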
**Connection to diffusion**: in standard diffusion, the denoiser f_θ is applied at each noise level t with the SAME parameters — this IS recurrent depth, but over the noise-schedule axis. IRIS makes it recurrent over BOTH axes: noise schedule (outer loop, t) and computational depth (inner loop, τ).

#### 2.2.2 Gated Recurrent Fourier Mixer (GRFM) — Novel Contribution

We introduce GRFM, which processes the 2D token sequence through three parallel pathways merged through a learned gate:

**Pathway 1: Fourier Global Mixing (O(N log N))**

```
x_fourier = IRFFT2(SoftShrink(BlockMLP(RFFT2(x))))
```

From AFNO: captures global structure via frequency-domain mixing. The soft-shrinkage promotes sparsity in the Fourier domain (images are naturally sparse in frequency).

**Pathway 2: Gated Linear Recurrence (O(N))**

```
a_t = σ(Λ)^(c·σ(W_a · x_t))                      [decay gate, per-element]
i_t = σ(W_x · x_t)                               [input gate]
h_t = a_t ⊙ h_{t-1} + √(1 - a_t²) ⊙ (i_t ⊙ x_t)  [RG-LRU update]
x_recurrent = W_o · h_T
```

From Griffin (arXiv:2402.19427): captures sequential dependencies with O(1) state per token position. Bidirectional (forward + backward scan).

**Pathway 3: Manhattan Spatial Gate**

```
D_{nm} = γ_head^(|x_n - x_m| + |y_n - y_m|)   [Manhattan decay matrix]
gate   = σ(W_g · x) ⊙ (D · (W_v · x))
```

From RMT (arXiv:2309.11523): per-head learnable spatial decay provides a multi-scale locality bias.

**Fusion (Novel)**:

```
output = LayerNorm(x_fourier ⊙ σ(gate) + x_recurrent ⊙ (1 - σ(gate)))
```

The gate adaptively selects between global Fourier features (textures, patterns) and local recurrent features (edges, fine details) based on spatial context. This is NOT a simple concatenation — it is a learned, spatially-varying interpolation.

#### 2.2.3 Wavelet-Frequency Latent Space

Instead of a standard VAE operating on pixels, we first apply a Haar DWT:

```
x ∈ R^{3×H×W} → DWT → y ∈ R^{12×H/2×W/2}
```

A lightweight VAE encoder then compresses the wavelet coefficients by a further 8× per side:

```
z ∈ R^{C×H/16×W/16}    (16× total spatial compression from the original image)
```

The VAE operates on wavelet coefficients, preserving frequency structure. The LL (low-low) subband carries global structure; LH, HL, HH carry directional high-frequency details. This means the latent space is inherently frequency-aware.

**Benefit**: the denoiser operates on a latent that already separates structure from detail, making the learning problem easier for a small model.

#### 2.2.4 Rectified Flow + Consistency Distillation

**Training Phase 1 (Rectified Flow)**:

```
x_t      = (1-t) · x_0 + t · ε                  [linear interpolation]
v_target = ε - x_0                              [velocity field]
L        = w(t) · ||v_θ(x_t, t, c) - v_target||²
w(t)     = t/(1-t+ε)                            [SNR reweighting]
t        ~ LogitNormal(0, 1)                    [concentrate on hard timesteps]
```

**Training Phase 2 (Consistency Distillation)**:

```
f_θ(x_t, t) = c_skip(t)·x_t + c_out(t)·F_θ(x_t, t)
L_CD        = d(f_θ(x_{t_{n+1}}, t_{n+1}), f_{θ⁻}(x̂_{t_n}, t_n))
```

where θ⁻ is an EMA of θ, and x̂_{t_n} is obtained by one teacher ODE step from x_{t_{n+1}}. This enables 1-4 step generation.
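A minimal PyTorch sketch of the Phase 1 objective, assuming a velocity model `v_theta(z_t, t, c)`; the function name and the `eps` guard are illustrative:

```python
import torch

def rectified_flow_loss(v_theta, z0, cond, eps=1e-4):
    """One rectified-flow training step: linear path + velocity matching."""
    B = z0.shape[0]
    # t ~ LogitNormal(0, 1): sigmoid of a standard normal draw
    t = torch.sigmoid(torch.randn(B, device=z0.device))
    noise = torch.randn_like(z0)
    t_ = t.view(B, *([1] * (z0.dim() - 1)))    # broadcast t over latent dims
    z_t = (1 - t_) * z0 + t_ * noise           # x_t = (1-t)·x_0 + t·ε
    v_target = noise - z0                      # v = ε - x_0
    v_pred = v_theta(z_t, t, cond)
    w = t / (1 - t + eps)                      # SNR reweighting w(t)
    per_sample = ((v_pred - v_target) ** 2).flatten(1).mean(dim=1)
    return (w * per_sample).mean()
```

---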
## 3. Architecture Details

### 3.1 Overall Pipeline

```
Text  ─► CLIP-L/14 ─► c ∈ R^{77×768}   (text tokens, fed to Core cross-attention)

Image ─► Haar DWT ─► WaveletVAE Encode ─► z₀ ∈ R^{C×h×w}

Noise schedule (RF):  z_t = (1-t)·z₀ + t·ε
          │
          ▼
┌──────────────────────┐
│  PRELUDE (2 blocks)  │   PatchEmbed + initial mixing
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│  CORE (shared)       │ ◄── iterate r times (4-16)
│  GRFM + FFN          │     adaLN-Zero on (t, i, text)
└──────────┬───────────┘
           ▼
┌──────────────────────┐
│  CODA (2 blocks)     │   Final refine + unpatchify
└──────────┬───────────┘
           ▼
v̂ = predicted velocity;   z₀_pred = z_t - t·v̂
           ▼
WaveletVAE Decode ─► Haar IDWT ─► Image
```

### 3.2 Detailed Block Design

#### Prelude (2 blocks, unique weights)

```python
class Prelude:
    patch_embed: Conv2d(C_latent, D, kernel_size=2, stride=2)  # 2× spatial reduce
    pos_embed: learned R^{(h/2 × w/2) × D}
    blocks: [PreludeBlock × 2]

class PreludeBlock:
    norm1 → DepthwiseSepConv3x3 → GELU → PointwiseConv → norm2 → FFN
    # Uses conv instead of attention — cheap local feature extraction
```

#### Core (shared weights, iterated r times)

```python
class CoreBlock:
    # adaLN-Zero conditioning on (timestep t, iteration i, text_global)
    adaln_modulation: Linear(D_cond, 9*D)  # shift, scale, gate for each of the three sublayers
    norm1 → GRFM → gate1 → residual
    norm2 → CrossAttention(q=x, kv=text_tokens) → gate2 → residual  # only 77 text tokens
    norm3 → FFN(SiLU) → gate3 → residual
```

Cross-attention with only 77 text tokens is cheap: O(N × 77 × d) ≈ O(N·d).

#### GRFM (Gated Recurrent Fourier Mixer) — The Core Innovation

```python
class GRFM:
    def forward(self, x, spatial_shape):
        B, N, D = x.shape
        H, W = spatial_shape
        x_2d = x.reshape(B, H, W, D)

        # Pathway 1: Fourier Global (O(N log N))
        x_freq = rfft2(x_2d, dim=(1, 2))
        x_freq = block_mlp(x_freq)                 # block-diagonal MLP in freq domain
        x_freq = soft_shrink(x_freq, lambd=self.sparsity_threshold)
        x_fourier = irfft2(x_freq, dim=(1, 2)).reshape(B, N, D)

        # Pathway 2: Bidirectional Gated Recurrence (O(N))
        x_flat_fwd = x          # N tokens in raster order
        x_flat_bwd = x.flip(1)  # reversed
        h_fwd = gated_linear_recurrence(x_flat_fwd, self.decay_fwd, self.gate_fwd)
        h_bwd = gated_linear_recurrence(x_flat_bwd, self.decay_bwd, self.gate_bwd)
        x_recurrent = linear(concat(h_fwd, h_bwd.flip(1)))

        # Pathway 3: Manhattan Spatial Gate
        manhattan_dist = compute_manhattan(H, W)   # [N, N], precomputed
        gamma = sigmoid(self.gamma_param)          # per-head decay in (0, 1)
        spatial_decay = gamma.pow(manhattan_dist)  # [heads, N, N] — sparse/windowed in practice
        v = value_proj(x).reshape(B, N, heads, D // heads)
        x_gated = einsum('hnm,bmhd->bnhd', spatial_decay, v).reshape(B, N, D)
        gate = sigmoid(gate_proj(x))

        # Adaptive Fusion
        output = layer_norm(x_fourier * gate + x_recurrent * (1 - gate))
        output = output + 0.1 * x_gated            # small residual from spatial pathway
        return output_proj(output)
```

#### Coda (2 blocks, unique weights)

```python
class Coda:
    blocks: [CodaBlock × 2]
    unpatchify: ConvTranspose2d(D, C_latent, kernel_size=2, stride=2)
    final_norm: LayerNorm(D)

class CodaBlock:
    norm1 → LocalWindowAttention(window=8) → residual  # small window, efficient
    norm2 → FFN → residual
```
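The GRFM pseudocode above calls `gated_linear_recurrence` without defining it. Below is a minimal sketch of one direction of that scan, following the RG-LRU-style update from Section 2.2.2; `GatedLinearRecurrence` is an illustrative name, and the sequential Python loop is for clarity only, where a production kernel would use a parallel associative scan.

```python
import torch
import torch.nn as nn

class GatedLinearRecurrence(nn.Module):
    """One direction of GRFM Pathway 2 (RG-LRU-style scan), for exposition."""

    def __init__(self, dim: int, c: float = 8.0):
        super().__init__()
        self.W_a = nn.Linear(dim, dim)             # decay-exponent gate
        self.W_x = nn.Linear(dim, dim)             # input gate
        self.lam = nn.Parameter(torch.zeros(dim))  # Λ: learned base decay
        self.c = c

    def forward(self, x):                          # x: [B, N, D], raster order
        a_base = torch.sigmoid(self.lam)           # σ(Λ) ∈ (0, 1)
        r = torch.sigmoid(self.W_a(x))
        i = torch.sigmoid(self.W_x(x))
        a = a_base.pow(self.c * r)                 # a_t = σ(Λ)^(c·σ(W_a·x_t))
        h = torch.zeros_like(x[:, 0])
        out = []
        for t in range(x.shape[1]):                # O(N) scan, O(1) state
            h = a[:, t] * h + torch.sqrt(1 - a[:, t] ** 2) * (i[:, t] * x[:, t])
            out.append(h)
        return torch.stack(out, dim=1)

# Bidirectional use (as in GRFM Pathway 2): run one instance on x and a second
# on x.flip(1), then concatenate after flipping the backward output back.
```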
### 3.3 Parameter Budget

| Component | Parameters | Notes |
|-----------|-----------|-------|
| WaveletVAE Encoder | ~15M | Lightweight (LiteVAE-style) |
| WaveletVAE Decoder | ~8M | Tiny decoder (SnapGen-style) |
| CLIP-L/14 Text Encoder | ~123M | Frozen, not counted for training |
| Prelude (2 blocks) | ~12M | Conv-based, cheap |
| Core Block (shared) | ~45M | GRFM + CrossAttn + FFN |
| Coda (2 blocks) | ~15M | Local attention + FFN |
| Embeddings/conditioning | ~3M | Time, iteration, position |
| **Total Generator** | **~75M unique** | Core shared across iterations |
| **Effective depth** | **75M → behaves like 400M+** | At r=8 iterations |
| **Total system** | **~221M** | Including VAE + text encoder |

### 3.4 Memory Analysis (Inference at 512×512)

```
CLIP-L/14 text encoder:      ~246 MB  (fp16)
WaveletVAE Decoder:           ~16 MB  (fp16)
IRIS Generator:              ~150 MB  (fp16)
Latent tensor:                 <1 MB  (32×32×16, fp16)
KV cache (text cross-attn):   ~12 MB
Intermediate activations:    ~100 MB  (single block, not accumulated)
OS/framework overhead:       ~500 MB
─────────────────────────────────────
Total:                       ~1.0 GB  ✓ (well under 3GB)
```

**Key insight**: because the Core block weights are shared, we don't accumulate layer-by-layer activations. Each iteration reuses the same memory buffer.

---

## 4. Training Recipe

### Stage 1: Wavelet VAE Training (Standalone)

```
Data:       ImageNet (1.2M images) + CC3M (3M images)
Resolution: 256×256
Objective:  Reconstruction loss + KL + Perceptual (LPIPS) + Wavelet frequency loss
Batch:      32
LR:         1e-4, cosine decay
Duration:   ~20 GPU-hours on A100
```

### Stage 2: Class-Conditional Pretraining

```
Data:            ImageNet 256×256 (class labels)
Objective:       Rectified Flow velocity matching
Batch:           256
LR:              1e-4, warmup 5000 steps, cosine decay
Core iterations: r=8 (randomly sample r ∈ {4,6,8,10,12} for robustness)
Duration:        ~100 GPU-hours on A100
```

### Stage 3: Text-Image Alignment

```
Data:       CC3M + CC12M (15M images with captions, re-captioned by a VLM)
Resolution: 256→512 progressive
Objective:  Rectified Flow + cross-attention on CLIP-L text tokens
Batch:      128
LR:         2e-5, constant
Duration:   ~200 GPU-hours on A100
```

### Stage 4: Aesthetic Fine-tuning

```
Data:       JourneyDB + high-aesthetic LAION subset (1M images, aesthetic score > 6.0)
Resolution: 512×512
Batch:      64
LR:         5e-6
Duration:   ~50 GPU-hours
```

### Stage 5: Consistency Distillation

```
Teacher:   Trained IRIS model from Stage 4
Student:   Same architecture, initialized from the teacher
Objective: Consistency loss (CD) + optional LADD (adversarial)
Target:    1-4 step generation
Duration:  ~30 GPU-hours
```

**Total estimated cost: ~400 A100 GPU-hours ≈ $1,600 at cloud prices**

**Colab/Kaggle feasible**: Stages 1-2 can run on the T4/A100 free tiers
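To tie Stage 5 back to the adaptive compute budget from Section 1, here is a sketch of a few-step Euler sampler on the straight rectified-flow ODE. It assumes the `RecurrentDepthModel` interface sketched in Section 2.2.1; the step counts and the latent shape are illustrative, not prescriptive.

```python
import torch

@torch.no_grad()
def sample(model, cond, shape, steps: int = 4, core_iters: int = 8, device="cpu"):
    """Euler sampler on the rectified-flow ODE, integrating t from 1 to 0."""
    z = torch.randn(shape, device=device)              # z_1 = pure noise
    ts = torch.linspace(1.0, 0.0, steps + 1, device=device)
    for k in range(steps):
        t, t_next = ts[k], ts[k + 1]
        v = model(z, cond, t.expand(shape[0]), r=core_iters)  # predicted velocity
        z = z + (t_next - t) * v                       # Euler step toward the data end
    return z                                           # z_0: denoised latent

# Adaptive budget: same weights, different quality/latency trade-offs, e.g.
#   latent_fast = sample(model, cond, (1, 16, 32, 32), steps=1, core_iters=4)
#   latent_hq   = sample(model, cond, (1, 16, 32, 32), steps=4, core_iters=16)
```

---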
## 5. Novel Contributions Summary

1. **GRFM (Gated Recurrent Fourier Mixer)**: to our knowledge, the first architecture to fuse Fourier global mixing, gated linear recurrence, and Manhattan spatial decay in a single differentiable block with learned gating
2. **Recurrent Depth for Image Generation**: to our knowledge, the first application of the Huginn prelude-core-coda pattern to image generation, enabling budget-adaptive compute
3. **Wavelet-Frequency Latent Space**: DWT preprocessing before VAE encoding preserves frequency structure in the latent space
4. **Iteration-Aware Conditioning**: the core block receives both the timestep t and the iteration index i via adaLN, allowing it to learn different behavior at different depths
5. **Dual-Axis Recurrence**: recurrence over both the noise schedule (diffusion steps) and computational depth (core iterations) — a new paradigm for efficient generation

---

## 6. Extensions

### 6.1 Image Editing (Inpainting, Super-Resolution)

The iterative nature of IRIS makes it natural for editing:

- **Inpainting**: mask latent tokens; condition core iterations on the unmasked context
- **Super-Resolution**: encode the low-res image via the WaveletVAE; condition generation on the LL subband
- **Prompt-based editing**: encode the source image, modify the text conditioning, and run partial denoising (SDEdit-style; see the sketch at the end of this section)

### 6.2 ControlNet-like Conditioning

Add a lightweight adapter to the Prelude that injects spatial control signals (edges, depth, pose) into the latent; the shared Core then naturally propagates this conditioning through its iterations.
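As a concrete illustration of the prompt-based editing path from Section 6.1, here is a sketch of SDEdit-style editing: renoise the source latent to an intermediate t₀, then run the tail of the sampler under the new prompt. The `vae.encode`/`vae.decode` calls and the model signature are assumed interfaces carried over from the earlier sketches, not the actual IRIS API.

```python
import torch

@torch.no_grad()
def sdedit(model, vae, image, new_cond, t0: float = 0.6,
           steps: int = 4, core_iters: int = 8):
    """SDEdit-style edit: partial renoising, then denoise under a new prompt."""
    z0 = vae.encode(image)                    # source latent (DWT handled inside vae)
    noise = torch.randn_like(z0)
    z_t = (1 - t0) * z0 + t0 * noise          # renoise to intermediate time t0
    ts = torch.linspace(t0, 0.0, steps + 1, device=z0.device)
    for k in range(steps):
        t, t_next = ts[k], ts[k + 1]
        v = model(z_t, new_cond, t.expand(z0.shape[0]), r=core_iters)
        z_t = z_t + (t_next - t) * v          # Euler step under the new prompt
    return vae.decode(z_t)
```

---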