
IRIS: Iterative Recurrent Image Synthesis

A Novel Architecture for Mobile-First High-Quality Image Generation

Version 1.0 | Architecture Design Document


1. Executive Summary

IRIS (Iterative Recurrent Image Synthesis) is a novel image generation architecture designed from first principles to achieve high visual quality on mobile devices (3-4 GB RAM). It combines six key innovations drawn from cutting-edge research across multiple domains:

  1. Wavelet-Frequency Latent Space: 16× spatial compression via Haar DWT + a learned VAE, operating in a frequency-aware space
  2. Recurrent Depth Core: Shared-weight denoising block iterated N times (inspired by Huginn), achieving deep-model behavior from a tiny parameter count
  3. Gated Recurrent Fourier Mixer (GRFM): Novel token mixing that combines RG-LRU gated recurrence with Adaptive Fourier Neural Operators, replacing O(N²) attention with O(N log N) global mixing
  4. Manhattan Spatial Decay: Learned per-head 2D spatial inductive bias via Manhattan-distance exponential decay (from RMT)
  5. Rectified Flow with Consistency Distillation: Straight ODE paths for few-step generation (1-4 steps)
  6. Adaptive Compute Budget: Same model, variable quality (4 iterations for mobile, 16 for quality)

Target Specifications

| Metric | Target | Achieved By |
|--------|--------|-------------|
| Total Parameters | < 250M (generator) | Recurrent depth + efficient blocks |
| RAM (inference) | < 3 GB total | ~600 MB model + ~400 MB VAE + ~250 MB text encoder + buffers |
| Inference Steps | 1-4 | Rectified flow + consistency distillation |
| Core Iterations | 4 (fast) / 8-16 (quality) | Recurrent depth, shared weights |
| Image Quality | Competitive with SDXL at 512px | Frequency-aware latent + proper training |
| Prompt Adherence | Strong | CLIP-L/14 cross-attention conditioning |
| Training Cost | < 200 A100 GPU-hours (Stage 1) | Efficient architecture + progressive training |

2. Theoretical Foundations

2.1 Why Current Approaches Fail on Mobile

Problem 1: Parameter Explosion in Transformers. Standard DiT/UNet architectures use independent parameters for each layer. A 28-layer DiT-XL has ~675M params; each self-attention layer stores O(d²) parameters for its Q, K, V projections, and this cost repeats at every layer.

Problem 2: Quadratic Attention Complexity. For 512×512 images with 8× VAE downsampling: 64×64 = 4096 tokens. Self-attention requires 4096² × d operations per layer; at d = 768, that is ~12.9 GFLOPs per attention layer.

Problem 3: Step Count. Standard diffusion requires 20-50 neural function evaluations (NFE). Even a small model evaluated 50 times per image is impractical on-device.

2.2 Our Solution: Mathematical Framework

2.2.1 Recurrent Depth as Implicit Neural ODE

The key insight from Huginn (arXiv:2502.05171): a shared-weight block R applied iteratively defines a discrete neural ODE:

s_0 = P(x)                    [Prelude: encode input]
s_{i+1} = R(s_i, c, t)        [Core: iterate with conditioning c and timestep t]  
y = C(s_r)                    [Coda: decode output]

This is mathematically equivalent to an Euler discretization of:

ds/dτ = F_θ(s(τ), c, t)      where τ ∈ [0, 1], discretized into r steps

Parameter efficiency: If block R has P parameters, then r iterations give an effective depth of r×L layers (where L = layers in R) using only P parameters. A 6-layer block iterated 16 times = 96 effective layers.

Connection to diffusion: In standard diffusion, the denoiser f_θ is applied at each noise level t with the SAME parameters; this IS recurrent depth, but over the noise-schedule axis. IRIS makes it recurrent over BOTH axes: the noise schedule (outer loop, t) and computational depth (inner loop, τ).
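
To make the loop concrete, here is a minimal sketch of the prelude-core-coda forward pass. The module names P, R, C follow the equations above; the wiring is a direct reading of those equations, not code from the repository:

import torch
import torch.nn as nn

def recurrent_denoise(P: nn.Module, R: nn.Module, C: nn.Module,
                      x_t: torch.Tensor, c: torch.Tensor, t: torch.Tensor,
                      r: int = 8) -> torch.Tensor:
    """Prelude-core-coda forward pass: one shared-weight block R iterated r times."""
    s = P(x_t)                     # s_0 = P(x): encode input once
    for i in range(r):             # effective depth grows with r; parameters do not
        s = R(s, c, t)             # s_{i+1} = R(s_i, c, t)
    return C(s)                    # y = C(s_r): decode output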

2.2.2 Gated Recurrent Fourier Mixer (GRFM): A Novel Contribution

We introduce GRFM, which processes the 2D token sequence through three parallel pathways fused by a learned multiplicative gate:

Pathway 1: Fourier Global Mixing (O(N log N))

x_fourier = IRFFT2(SoftShrink(BlockMLP(RFFT2(x))))

From AFNO: captures global structure via frequency-domain mixing. The soft-shrinkage promotes sparsity in the Fourier domain (images are naturally sparse in frequency).
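
A minimal PyTorch sketch of this pathway; a dense per-frequency channel mixing stands in for AFNO's block-diagonal MLP, and lambd is the sparsity threshold (both simplifying assumptions):

import torch
import torch.nn.functional as F

def fourier_mix(x_2d: torch.Tensor, weight: torch.Tensor, lambd: float = 0.01) -> torch.Tensor:
    """x_2d: [B, H, W, D] real; weight: [D, D] learned channel mixer (dense stand-in
    for the block-diagonal MLP). Returns a real [B, H, W, D] tensor."""
    f = torch.fft.rfft2(x_2d, dim=(1, 2))                     # complex [B, H, W//2+1, D]
    f = torch.einsum('bhwd,de->bhwe', f, weight.to(f.dtype))  # mix channels per frequency
    f = torch.view_as_complex(F.softshrink(torch.view_as_real(f), lambd))  # sparsify
    return torch.fft.irfft2(f, s=x_2d.shape[1:3], dim=(1, 2))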

Pathway 2: Gated Linear Recurrence (O(N))

a_t = σ(Λ)^(c·σ(W_a · x_t))     [decay gate, per-element]
i_t = σ(W_x · x_t)              [input gate]
h_t = a_t ⊙ h_{t-1} + √(1 - a_t²) ⊙ (i_t ⊙ x_t)    [RG-LRU update]
x_recurrent = W_o · h_T

From Griffin (arXiv:2402.19427): captures sequential dependencies with O(1) state per token position. Bidirectional (forward + backward scan).
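
A reference implementation of the scan, written sequentially for clarity (Griffin uses a hardware-efficient parallel scan in practice; c = 8.0 follows the paper's fixed scaling constant):

import torch
import torch.nn as nn

def gated_linear_recurrence(x: torch.Tensor, log_lambda: torch.Tensor,
                            W_a: nn.Linear, W_x: nn.Linear, c: float = 8.0) -> torch.Tensor:
    """RG-LRU scan. x: [B, N, D]; log_lambda: [D] learned decay logits."""
    B, N, D = x.shape
    h = x.new_zeros(B, D)
    out = []
    for n in range(N):
        r_t = torch.sigmoid(W_a(x[:, n]))               # recurrence gate
        i_t = torch.sigmoid(W_x(x[:, n]))               # input gate
        a_t = torch.sigmoid(log_lambda).pow(c * r_t)    # a_t = σ(Λ)^(c·r_t), per-element
        h = a_t * h + torch.sqrt(1 - a_t ** 2) * (i_t * x[:, n])
        out.append(h)
    return torch.stack(out, dim=1)                      # [B, N, D] hidden states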

Pathway 3: Manhattan Spatial Gate

D_{nm} = γ_head^(|x_n - x_m| + |y_n - y_m|)    [Manhattan decay matrix]
gate = σ(W_g · x) ⊙ (D · (W_v · x))

From RMT (arXiv:2309.11523): per-head learnable spatial decay provides multi-scale locality bias.

Fusion (Novel):

output = LayerNorm(x_fourier ⊙ σ(gate) + x_recurrent ⊙ (1 - σ(gate)))

The gate adaptively selects between global Fourier features (textures, patterns) and local recurrent features (edges, fine details) based on spatial context. This is NOT a simple concatenation; it is a learned, spatially-varying interpolation.

2.2.3 Wavelet-Frequency Latent Space

Instead of standard VAE operating on pixels, we first apply Haar DWT:

x ∈ R^{3×H×W} → DWT → y ∈ R^{12×H/2×W/2}

Then a lightweight VAE encoder compresses to:

z ∈ R^{C×H/16×W/16}   (16× total spatial downsampling from the original image: 2× from the DWT, 8× from the VAE)

The VAE operates on wavelet coefficients, preserving frequency structure. The LL (low-low) subband carries global structure; LH, HL, HH carry directional high-frequency details. This means the latent space is inherently frequency-aware.

Benefit: The denoiser operates on a latent that already separates structure from detail, making the learning problem easier for a small model.
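
For concreteness, a single-level orthonormal Haar DWT can be written with strided slicing alone (a sketch; subband naming conventions vary):

import torch

def haar_dwt(x: torch.Tensor) -> torch.Tensor:
    """x: [B, 3, H, W] -> [B, 12, H/2, W/2], channels ordered [LL, LH, HL, HH].
    Orthonormal single-level Haar transform via 2x2 butterflies."""
    a = x[..., 0::2, 0::2]  # top-left of each 2x2 block
    b = x[..., 0::2, 1::2]  # top-right
    c = x[..., 1::2, 0::2]  # bottom-left
    d = x[..., 1::2, 1::2]  # bottom-right
    ll = (a + b + c + d) / 2    # local average: global structure
    lh = (a - b + c - d) / 2    # detail along width
    hl = (a + b - c - d) / 2    # detail along height
    hh = (a - b - c + d) / 2    # diagonal detail
    return torch.cat([ll, lh, hl, hh], dim=1)

Because the 2×2 butterfly matrix is orthogonal and symmetric, applying the same four combinations to (LL, LH, HL, HH) recovers the pixels exactly, which gives the IDWT for free.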

2.2.4 Rectified Flow + Consistency Distillation

Training Phase 1 (Rectified Flow):

x_t = (1-t) · x_0 + t · ε           [linear interpolation]
v_target = ε - x_0                   [velocity field]
L = w(t) · ||v_θ(x_t, t, c) - v_target||²
w(t) = t/(1-t+δ)                     [SNR reweighting; δ a small constant]
t ~ LogitNormal(0, 1)                [concentrate on hard timesteps]
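
Putting Phase 1 together, a single training step might look like this (a sketch; model is assumed to predict velocity, and delta guards the weight's denominator):

import torch

def rf_loss(model, x0, cond, delta=1e-3):
    """One rectified-flow training step: velocity matching with SNR reweighting."""
    B = x0.shape[0]
    t = torch.sigmoid(torch.randn(B, device=x0.device))   # t ~ LogitNormal(0, 1)
    t_ = t.view(B, 1, 1, 1)
    eps = torch.randn_like(x0)
    x_t = (1 - t_) * x0 + t_ * eps                         # linear interpolation
    v_target = eps - x0                                    # velocity field
    v_pred = model(x_t, t, cond)
    w = t_ / (1 - t_ + delta)                              # w(t) = t / (1 - t + δ)
    return (w * (v_pred - v_target) ** 2).mean()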

Training Phase 2 (Consistency Distillation):

f_θ(x_t, t) = c_skip(t)·x_t + c_out(t)·F_θ(x_t, t)
L_CD = d(f_θ(x_{t_{n+1}}, t_{n+1}), f_{θ⁻}(x̂_{t_n}, t_n))

Where θ⁻ is the EMA of θ. This enables 1-4 step generation.
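
With the RF velocity parameterization, one natural boundary-preserving choice (our assumption, matching the z₀ readout in Section 3.1) is c_skip(t) = 1 and c_out(t) = -t, so f_θ(x_t, t) = x_t - t·v_θ(x_t, t) and f(x, 0) = x holds. A distillation step then looks like:

import torch
import torch.nn.functional as F

def f_theta(net, x_t, t, cond):
    """Consistency function with c_skip(t) = 1, c_out(t) = -t (assumed choice),
    so f predicts x_0 directly and the boundary condition f(x, 0) = x holds."""
    return x_t - t * net(x_t, t, cond)

def cd_step(student, teacher_ema, x0, cond, t_n, t_n1):
    """One consistency-distillation step between adjacent timesteps t_n < t_{n+1}."""
    eps = torch.randn_like(x0)
    x_t1 = (1 - t_n1) * x0 + t_n1 * eps              # noise to the later timestep
    with torch.no_grad():
        v = teacher_ema(x_t1, t_n1, cond)
        x_t0 = x_t1 + (t_n - t_n1) * v               # one Euler ODE step toward t_n
        target = f_theta(teacher_ema, x_t0, t_n, cond)
    return F.mse_loss(f_theta(student, x_t1, t_n1, cond), target)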


3. Architecture Details

3.1 Overall Pipeline

Text → CLIP-L/14 → c ∈ R^{77×768}

Image → Haar DWT → WaveletVAE Encode → z₀ ∈ R^{C×h×w}

Noise schedule (RF):  z_t = (1-t)·z₀ + t·ε
                              │
                              ▼
                    ┌──────────────────┐
                    │     PRELUDE      │
                    │    (2 blocks)    │
                    │   PatchEmbed +   │
                    │  Initial mixing  │
                    └────────┬─────────┘
                             │
                             ▼
                    ┌──────────────────┐
                    │   CORE (shared)  │◄── Iterate r
                    │   GRFM Block     │    times
                    │   + FFN          │    (4-16)
                    │   + adaLN-Zero   │
                    └────────┬─────────┘
                             │
                             ▼
                    ┌──────────────────┐
                    │      CODA        │
                    │    (2 blocks)    │
                    │  Final refine +  │
                    │   Unpatchify     │
                    └────────┬─────────┘
                             │
                             ▼
                    v̂ = predicted velocity
                    z₀_pred = z_t - t·v̂
                             │
                             ▼
                    WaveletVAE Decode → Haar IDWT → Image
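
The full sampling loop, showing both recurrence axes (a sketch; module and argument names are illustrative, not the repository's API):

import torch

@torch.no_grad()
def iris_sample(prelude, core, coda, vae, text_emb, steps=4, r=8,
                shape=(1, 16, 32, 32)):
    """Few-step Euler sampling of the rectified flow, with r core iterations
    (the inner recurrence) inside every denoising step (the outer recurrence)."""
    z = torch.randn(shape)                       # z_1 ~ N(0, I)
    ts = torch.linspace(1.0, 0.0, steps + 1)
    for n in range(steps):
        t = ts[n]
        s = prelude(z, t)                        # PRELUDE: encode once per step
        for i in range(r):                       # CORE: shared weights, r iterations
            s = core(s, text_emb, t, i)
        v = coda(s)                              # CODA: read out predicted velocity
        z = z + (ts[n + 1] - ts[n]) * v          # Euler step (dt < 0, toward t = 0)
    return vae.decode(z)                         # WaveletVAE decode + Haar IDWT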

3.2 Detailed Block Design

Prelude (2 blocks, unique weights)

class Prelude:
    patch_embed: Conv2d(C_latent, D, kernel_size=2, stride=2)  # 2× spatial reduce
    pos_embed: learned R^{(h/2 × w/2) × D}
    blocks: [PreludeBlock × 2]

class PreludeBlock:
    norm1 → DepthwiseSepConv3x3 → GELU → PointwiseConv → norm2 → FFN
    # Uses conv instead of attention: cheap local feature extraction

Core (shared weights, iterated r times)

class CoreBlock:
    # adaLN-Zero conditioning on (timestep t, iteration i, text_global)
    adaln_modulation: Linear(D_cond, 9*D)  # shift, scale, gate for each of the 3 sub-layers

    norm1 → GRFM → gate1 → residual
    norm2 → CrossAttention(q=x, kv=text_tokens) → gate2 → residual  # Only 77 text tokens
    norm3 → FFN(SiLU) → gate3 → residual

Cross-attention with only 77 text tokens is cheap: O(N × 77 × d) ≈ O(N·d).
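
A sketch of the iteration-aware adaLN-Zero modulation. The zero-initialized final linear follows DiT's adaLN-Zero; combining the t, iteration, and pooled-text embeddings by summation is our assumption:

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaLNZero(nn.Module):
    """Maps (t, iteration i, pooled text) to shift/scale/gate triples for the
    CoreBlock's three sub-layers (GRFM, cross-attention, FFN)."""
    def __init__(self, d_model: int, d_cond: int):
        super().__init__()
        self.mod = nn.Linear(d_cond, 9 * d_model)
        nn.init.zeros_(self.mod.weight)   # "Zero" init: every branch starts as identity
        nn.init.zeros_(self.mod.bias)

    def forward(self, t_emb, iter_emb, text_global):
        cond = t_emb + iter_emb + text_global            # summed conditioning (assumption)
        return self.mod(F.silu(cond)).chunk(9, dim=-1)   # 3 x (shift, scale, gate)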

GRFM (Gated Recurrent Fourier Mixer): The Core Innovation

class GRFM:
    def forward(x, spatial_shape):
        B, N, D = x.shape
        H, W = spatial_shape
        x_2d = x.reshape(B, H, W, D)

        # Pathway 1: Fourier Global (O(N log N))
        x_freq = rfft2(x_2d, dim=(1, 2))
        x_freq = block_mlp(x_freq)  # Block-diagonal MLP in freq domain
        x_freq = soft_shrink(x_freq, lambd=self.sparsity_threshold)
        x_fourier = irfft2(x_freq, s=(H, W), dim=(1, 2)).reshape(B, N, D)

        # Pathway 2: Bidirectional Gated Recurrence (O(N))
        x_flat_fwd = x          # N tokens in raster order
        x_flat_bwd = x.flip(1)  # Reversed
        h_fwd = gated_linear_recurrence(x_flat_fwd, self.decay_fwd, self.gate_fwd)
        h_bwd = gated_linear_recurrence(x_flat_bwd, self.decay_bwd, self.gate_bwd)
        x_recurrent = linear(concat(h_fwd, h_bwd.flip(1)))

        # Pathway 3: Manhattan Spatial Gate
        manhattan_dist = compute_manhattan(H, W)                  # [N, N], precomputed
        gamma = sigmoid(self.gamma_param)                         # Per-head decay in (0, 1)
        spatial_decay = gamma[:, None, None].pow(manhattan_dist)  # [heads, N, N], sparse/windowed in practice
        v = value_proj(x).reshape(B, N, n_heads, D // n_heads)
        x_gated = einsum('hnm,bmhd->bnhd', spatial_decay, v).reshape(B, N, D)
        gate = sigmoid(gate_proj(x))

        # Adaptive Fusion
        output = x_fourier * gate + x_recurrent * (1 - gate)
        output = output + 0.1 * x_gated  # Small residual from spatial pathway

        return output_proj(output)
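
The compute_manhattan helper referenced above can be precomputed once per resolution (a sketch):

import torch

def compute_manhattan(H: int, W: int) -> torch.Tensor:
    """[N, N] matrix of Manhattan distances |x_n - x_m| + |y_n - y_m| between
    all pairs of positions on an H x W token grid (N = H * W)."""
    ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    coords = torch.stack([ys.flatten(), xs.flatten()], dim=-1).float()   # [N, 2]
    return (coords[:, None, :] - coords[None, :, :]).abs().sum(dim=-1)   # [N, N]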

Coda (2 blocks, unique weights)

class Coda:
    blocks: [CodaBlock × 2]
    unpatchify: ConvTranspose2d(D, C_latent, kernel_size=2, stride=2)
    final_norm: LayerNorm(D)

class CodaBlock:
    norm1 → LocalWindowAttention(window=8) → residual  # Small window, efficient
    norm2 → FFN → residual

3.3 Parameter Budget

| Component | Parameters | Notes |
|-----------|------------|-------|
| WaveletVAE Encoder | ~15M | Lightweight (LiteVAE-style) |
| WaveletVAE Decoder | ~8M | Tiny decoder (SnapGen-style) |
| CLIP-L/14 Text Encoder | ~123M | Frozen, not counted for training |
| Prelude (2 blocks) | ~12M | Conv-based, cheap |
| Core Block (shared) | ~45M | GRFM + CrossAttn + FFN |
| Coda (2 blocks) | ~15M | Local attention + FFN |
| Embeddings/conditioning | ~3M | Time, iteration, position |
| Total Generator | ~75M unique | Core shared across iterations |
| Effective depth | 75M → behaves like ~390M | At r=8 iterations (12M + 8×45M + 15M + 3M) |
| Total system | ~221M | Including VAE + frozen text encoder |

3.4 Memory Analysis (Inference at 512×512)

CLIP-L/14 text encoder:     ~247 MB (fp16)
WaveletVAE Decoder:         ~16 MB (fp16)
IRIS Generator:             ~150 MB (fp16)
Latent tensor:              <1 MB (32×32×16, fp16)
KV cache (text cross-attn): ~12 MB
Intermediate activations:   ~100 MB (single block, not accumulated)
OS/framework overhead:      ~500 MB
─────────────────────────────────────────
Total:                      ~1,025 MB  ✓ (well under 3 GB)

Key insight: Because Core block weights are shared, we don't accumulate layer-by-layer activations. Each iteration reuses the same memory buffer.


4. Training Recipe

Stage 1: Wavelet VAE Training (Standalone)

Data: ImageNet (1.2M images) + CC3M (3M images)
Resolution: 256Γ—256
Objective: Reconstruction loss + KL + Perceptual (LPIPS) + Wavelet frequency loss
Batch: 32
LR: 1e-4, cosine decay
Duration: ~20 GPU-hours on A100

Stage 2: Class-Conditional Pretraining

Data: ImageNet 256×256 (class labels)
Objective: Rectified Flow velocity matching
Batch: 256
LR: 1e-4, warmup 5000 steps, cosine decay
Core iterations: sample r ∈ {4, 6, 8, 10, 12} at random during training for robustness (r = 8 typical)
Duration: ~100 GPU-hours on A100

Stage 3: Text-Image Alignment

Data: CC3M + CC12M (15M images with captions, re-captioned by a VLM)
Resolution: 256→512 progressive
Objective: Rectified Flow + cross-attention on CLIP-L text tokens
Batch: 128
LR: 2e-5, constant
Duration: ~200 GPU-hours on A100

Stage 4: Aesthetic Fine-tuning

Data: JourneyDB + high-aesthetic LAION subset (1M images, aesthetic score > 6.0)
Resolution: 512×512
Batch: 64
LR: 5e-6
Duration: ~50 GPU-hours

Stage 5: Consistency Distillation

Teacher: Trained IRIS model from Stage 4
Student: Same architecture, initialized from teacher
Objective: Consistency loss (CD) + optional LADD (adversarial)
Target: 1-4 step generation
Duration: ~30 GPU-hours

Total estimated cost: ~400 A100 GPU-hours ≈ $1,600 at cloud prices. Colab/Kaggle feasibility: Stages 1-2 can run on free-tier T4/A100 hardware.


5. Novel Contributions Summary

  1. GRFM (Gated Recurrent Fourier Mixer): First architecture to fuse Fourier global mixing, gated linear recurrence, and Manhattan spatial decay in a single differentiable block with learned gating
  2. Recurrent Depth for Image Generation: First application of the Huginn prelude-core-coda pattern to image generation, enabling budget-adaptive compute
  3. Wavelet-Frequency Latent Space: DWT preprocessing before VAE encoding preserves frequency structure in the latent space
  4. Iteration-Aware Conditioning: The core block receives both timestep t and iteration index i via adaLN, allowing it to learn different behavior at different depths
  5. Dual-Axis Recurrence: Recurrence over both the noise schedule (diffusion steps) and computational depth (core iterations); a new paradigm for efficient generation

6. Extensions

6.1 Image Editing (Inpainting, Super-Resolution)

The iterative nature of IRIS makes it natural for editing:

  • Inpainting: Mask latent tokens, condition core iterations on unmasked context
  • Super-Resolution: Encode low-res image via WaveletVAE, condition generation on LL subband
  • Prompt-based editing: Encode the source image, modify the text conditioning, and run partial denoising (SDEdit-style; see the sketch below)
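
A sketch of the SDEdit-style flow, reusing the sampling loop from Section 3.1 (module names are the same illustrative ones; strength in (0, 1] controls how much of the source survives):

import torch

@torch.no_grad()
def sdedit_edit(prelude, core, coda, vae, z_src, new_text_emb,
                strength=0.6, steps=4, r=8):
    """Partial denoising from a noised source latent under a new prompt.
    z_src is the WaveletVAE encoding of the source image."""
    t0 = strength
    z = (1 - t0) * z_src + t0 * torch.randn_like(z_src)   # noise source to t = strength
    ts = torch.linspace(t0, 0.0, steps + 1)
    for n in range(steps):
        s = prelude(z, ts[n])
        for i in range(r):
            s = core(s, new_text_emb, ts[n], i)
        v = coda(s)
        z = z + (ts[n + 1] - ts[n]) * v                   # Euler step toward t = 0
    return vae.decode(z)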

6.2 ControlNet-like Conditioning

Add a lightweight adapter to the Prelude that injects spatial control signals (edges, depth, pose) into the latent; the shared Core then propagates this conditioning through all of its iterations.
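
A minimal zero-initialized adapter sketch (names are illustrative; the zero init, borrowed from ControlNet-style practice, keeps the base model's behavior untouched when adapter training starts):

import torch
import torch.nn as nn

class ControlAdapter(nn.Module):
    """Injects a spatial control signal (edges, depth, pose) into Prelude features."""
    def __init__(self, c_ctrl: int, d_model: int):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(c_ctrl, d_model, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
        )
        nn.init.zeros_(self.proj[-1].weight)   # zero init: adapter starts as a no-op
        nn.init.zeros_(self.proj[-1].bias)

    def forward(self, s: torch.Tensor, control: torch.Tensor) -> torch.Tensor:
        # s: [B, D, h, w] Prelude feature map; control: [B, c_ctrl, h, w] aligned signal
        return s + self.proj(control)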