""" IRIS: Iterative Recurrent Image Synthesis ========================================== A novel architecture for mobile-first high-quality image generation. Key innovations: 1. Wavelet-Frequency Latent Space (Haar DWT + lightweight VAE) 2. Recurrent Depth Core (Prelude-Core-Coda with shared weights) 3. Gated Recurrent Fourier Mixer (GRFM) — novel token mixing 4. Manhattan Spatial Decay — learned 2D inductive bias 5. Rectified Flow training with consistency distillation support 6. Adaptive compute budget (4-16 iterations, same model) Author: IRIS Research License: Apache 2.0 """ import math import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple from dataclasses import dataclass, field # ============================================================================ # Configuration # ============================================================================ @dataclass class IRISConfig: """Configuration for IRIS model.""" # Latent space latent_channels: int = 16 # Channels in latent space latent_spatial: int = 32 # Spatial dim of latent (for 512px with 16x compression) # Model dimensions hidden_dim: int = 512 # Main hidden dimension num_heads: int = 8 # Number of attention heads head_dim: int = 64 # Dimension per head ffn_ratio: float = 2.667 # FFN expansion ratio (SwiGLU-adjusted) # Architecture structure num_prelude_blocks: int = 2 # Prelude blocks (unique weights) num_core_layers: int = 4 # Layers WITHIN each core iteration num_coda_blocks: int = 2 # Coda blocks (unique weights) default_iterations: int = 8 # Default core iterations at inference max_iterations: int = 16 # Maximum core iterations # GRFM settings fourier_num_blocks: int = 8 # Block-diagonal blocks in Fourier MLP sparsity_threshold: float = 0.01 # Soft-shrinkage lambda recurrence_dim: int = 256 # Dimension for gated recurrence pathway manhattan_window: int = 16 # Windowed Manhattan decay (for efficiency) # Cross-attention text_dim: int = 768 # CLIP-L/14 text embedding dim max_text_tokens: int = 77 # Maximum text sequence length # Patch embedding patch_size: int = 2 # Patches in latent space (2×2) # Conditioning num_timesteps: int = 1000 # Noise schedule discretization # VAE vae_channels: list = field(default_factory=lambda: [32, 64, 128, 256]) # Training dropout: float = 0.0 @property def vae_latent_channels(self) -> int: """VAE latent channels must match generator latent channels.""" return self.latent_channels @property def num_patches(self) -> int: return (self.latent_spatial // self.patch_size) ** 2 @property def patch_dim(self) -> int: return self.latent_channels * self.patch_size * self.patch_size # ============================================================================ # Wavelet Transforms (Haar) # ============================================================================ class HaarDWT2D(nn.Module): """2D Discrete Wavelet Transform using Haar wavelets. Decomposes x ∈ R^{B,C,H,W} into R^{B,4C,H/2,W/2} (LL, LH, HL, HH subbands). """ def forward(self, x: torch.Tensor) -> torch.Tensor: # Haar DWT: split into even/odd along both spatial dims x_ll = (x[:, :, 0::2, 0::2] + x[:, :, 0::2, 1::2] + x[:, :, 1::2, 0::2] + x[:, :, 1::2, 1::2]) * 0.5 x_lh = (x[:, :, 0::2, 0::2] + x[:, :, 0::2, 1::2] - x[:, :, 1::2, 0::2] - x[:, :, 1::2, 1::2]) * 0.5 x_hl = (x[:, :, 0::2, 0::2] - x[:, :, 0::2, 1::2] + x[:, :, 1::2, 0::2] - x[:, :, 1::2, 1::2]) * 0.5 x_hh = (x[:, :, 0::2, 0::2] - x[:, :, 0::2, 1::2] - x[:, :, 1::2, 0::2] + x[:, :, 1::2, 1::2]) * 0.5 return torch.cat([x_ll, x_lh, x_hl, x_hh], dim=1) class HaarIDWT2D(nn.Module): """2D Inverse Discrete Wavelet Transform (Haar). Reconstructs x ∈ R^{B,C,H,W} from R^{B,4*(C//4),H/2,W/2}. """ def forward(self, x: torch.Tensor) -> torch.Tensor: B, C4, Hh, Wh = x.shape C = C4 // 4 ll, lh, hl, hh = x[:, :C], x[:, C:2*C], x[:, 2*C:3*C], x[:, 3*C:] # Reconstruct 2× spatial resolution H, W = Hh * 2, Wh * 2 out = torch.zeros(B, C, H, W, device=x.device, dtype=x.dtype) out[:, :, 0::2, 0::2] = (ll + lh + hl + hh) * 0.5 out[:, :, 0::2, 1::2] = (ll + lh - hl - hh) * 0.5 out[:, :, 1::2, 0::2] = (ll - lh + hl - hh) * 0.5 out[:, :, 1::2, 1::2] = (ll - lh - hl + hh) * 0.5 return out # ============================================================================ # Lightweight Wavelet VAE # ============================================================================ class DepthwiseSeparableConv(nn.Module): """Depthwise separable convolution — key mobile optimization.""" def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1): super().__init__() self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size, stride, padding, groups=in_ch) self.pointwise = nn.Conv2d(in_ch, out_ch, 1) def forward(self, x): return self.pointwise(self.depthwise(x)) class ResBlock(nn.Module): """Residual block with depthwise separable convolutions.""" def __init__(self, channels): super().__init__() self.norm1 = nn.GroupNorm(8, channels) self.conv1 = DepthwiseSeparableConv(channels, channels) self.norm2 = nn.GroupNorm(8, channels) self.conv2 = DepthwiseSeparableConv(channels, channels) # Zero-init final layer for residual learning stability nn.init.zeros_(self.conv2.pointwise.weight) nn.init.zeros_(self.conv2.pointwise.bias) def forward(self, x): h = F.silu(self.norm1(x)) h = self.conv1(h) h = F.silu(self.norm2(h)) h = self.conv2(h) return x + h class WaveletVAEEncoder(nn.Module): """Lightweight encoder: Haar DWT preprocessing + small convolutional encoder. Input: images R^{B,3,H,W} → Output: latent R^{B,C_latent,H/16,W/16} Compression: 3×H×W → C_latent×(H/16)×(W/16) """ def __init__(self, config: IRISConfig): super().__init__() self.dwt = HaarDWT2D() channels = config.vae_channels latent_ch = config.vae_latent_channels # DWT: 3 channels → 12 channels at H/2 × W/2 self.conv_in = nn.Conv2d(12, channels[0], 3, 1, 1) # Downsampling path: H/2→H/4→H/8→H/16 self.down_blocks = nn.ModuleList() for i in range(len(channels) - 1): self.down_blocks.append(nn.Sequential( ResBlock(channels[i]), nn.Conv2d(channels[i], channels[i+1], 3, 2, 1), # 2× downsample )) # Bottleneck self.mid = nn.Sequential( ResBlock(channels[-1]), ResBlock(channels[-1]), ) # To latent (mean + logvar) self.norm_out = nn.GroupNorm(8, channels[-1]) self.conv_out = nn.Conv2d(channels[-1], 2 * latent_ch, 1) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # Haar DWT preprocessing x = self.dwt(x) # [B, 12, H/2, W/2] x = self.conv_in(x) for down in self.down_blocks: x = down(x) x = self.mid(x) x = F.silu(self.norm_out(x)) x = self.conv_out(x) mean, logvar = x.chunk(2, dim=1) logvar = torch.clamp(logvar, -30.0, 20.0) return mean, logvar def encode(self, x: torch.Tensor) -> torch.Tensor: mean, logvar = self.forward(x) std = torch.exp(0.5 * logvar) z = mean + std * torch.randn_like(std) return z, mean, logvar class WaveletVAEDecoder(nn.Module): """Tiny decoder: latent → wavelet coefficients → Haar IDWT → image. Designed to be as small as possible for mobile inference. """ def __init__(self, config: IRISConfig): super().__init__() channels = list(reversed(config.vae_channels)) latent_ch = config.vae_latent_channels self.idwt = HaarIDWT2D() # From latent self.conv_in = nn.Conv2d(latent_ch, channels[0], 3, 1, 1) # Bottleneck self.mid = nn.Sequential( ResBlock(channels[0]), ) # Upsampling path self.up_blocks = nn.ModuleList() for i in range(len(channels) - 1): self.up_blocks.append(nn.Sequential( nn.Upsample(scale_factor=2, mode='nearest'), DepthwiseSeparableConv(channels[i], channels[i+1]), nn.SiLU(), ResBlock(channels[i+1]), )) # To wavelet coefficients (12 channels: 4 subbands × 3 color channels) self.norm_out = nn.GroupNorm(8, channels[-1]) self.conv_out = nn.Conv2d(channels[-1], 12, 3, 1, 1) def forward(self, z: torch.Tensor) -> torch.Tensor: x = self.conv_in(z) x = self.mid(x) for up in self.up_blocks: x = up(x) x = F.silu(self.norm_out(x)) x = self.conv_out(x) # [B, 12, H/2, W/2] wavelet coefficients # Inverse DWT to get image x = self.idwt(x) # [B, 3, H, W] return x class WaveletVAE(nn.Module): """Complete Wavelet VAE with DWT preprocessing.""" def __init__(self, config: IRISConfig): super().__init__() self.encoder = WaveletVAEEncoder(config) self.decoder = WaveletVAEDecoder(config) self.config = config def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return self.encoder.encode(x) def decode(self, z: torch.Tensor) -> torch.Tensor: return self.decoder(z) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: z, mean, logvar = self.encode(x) x_recon = self.decode(z) return x_recon, mean, logvar # ============================================================================ # Conditioning Modules # ============================================================================ class TimestepEmbedding(nn.Module): """Sinusoidal timestep embedding with MLP projection.""" def __init__(self, dim: int, max_period: int = 10000): super().__init__() self.dim = dim self.max_period = max_period self.mlp = nn.Sequential( nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim), ) def forward(self, t: torch.Tensor) -> torch.Tensor: half = self.dim // 2 freqs = torch.exp( -math.log(self.max_period) * torch.arange(half, device=t.device, dtype=t.dtype) / half ) args = t[:, None] * freqs[None, :] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) return self.mlp(embedding) class IterationEmbedding(nn.Module): """Learnable embedding for iteration index within recurrent core.""" def __init__(self, max_iterations: int, dim: int): super().__init__() self.embedding = nn.Embedding(max_iterations, dim) def forward(self, i: torch.Tensor) -> torch.Tensor: return self.embedding(i) class AdaLNSingle(nn.Module): """Adaptive Layer Normalization (single shared MLP, per-layer bias). From PixArt-α: saves 27% params vs standard adaLN. Produces (scale, shift, gate) for each sub-layer from a shared condition vector. """ def __init__(self, dim: int, num_modulations: int = 6): super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear(dim, num_modulations * dim) self.num_modulations = num_modulations nn.init.zeros_(self.linear.weight) nn.init.zeros_(self.linear.bias) def forward(self, c: torch.Tensor) -> Tuple[torch.Tensor, ...]: """c: [B, D] condition vector → tuple of num_modulations tensors [B, D].""" params = self.linear(self.silu(c)) return params.chunk(self.num_modulations, dim=-1) # ============================================================================ # GRFM: Gated Recurrent Fourier Mixer (Novel Contribution) # ============================================================================ class FourierMixingPathway(nn.Module): """Pathway 1: Adaptive Fourier Neural Operator-style global mixing. O(N log N) complexity via FFT. Block-diagonal MLP in frequency domain. """ def __init__(self, dim: int, num_blocks: int = 8, sparsity_threshold: float = 0.01): super().__init__() self.dim = dim self.num_blocks = num_blocks self.block_size = dim // num_blocks self.sparsity_threshold = sparsity_threshold # Block-diagonal complex-valued MLP in Fourier domain # Each block: R^{block_size} → R^{block_size} # Using real-valued params for complex ops (split real/imag) self.w1_real = nn.Parameter(torch.randn(num_blocks, self.block_size, self.block_size) * 0.02) self.w1_imag = nn.Parameter(torch.randn(num_blocks, self.block_size, self.block_size) * 0.02) self.w2_real = nn.Parameter(torch.randn(num_blocks, self.block_size, self.block_size) * 0.02) self.w2_imag = nn.Parameter(torch.randn(num_blocks, self.block_size, self.block_size) * 0.02) self.b1 = nn.Parameter(torch.zeros(num_blocks, self.block_size)) self.b2 = nn.Parameter(torch.zeros(num_blocks, self.block_size)) def complex_matmul(self, x: torch.Tensor, w_real: torch.Tensor, w_imag: torch.Tensor) -> torch.Tensor: """Complex matrix multiplication: (a+bi)(c+di) = (ac-bd) + (ad+bc)i x: [..., num_blocks, block_size] (complex) w: [num_blocks, block_size, block_size] (real) """ # Use einsum for proper block-diagonal matmul # x: [B, Hf, Wf, K, bs], w: [K, bs, bs] → out: [B, Hf, Wf, K, bs] out_real = torch.einsum('...ki,kij->...kj', x.real, w_real) - torch.einsum('...ki,kij->...kj', x.imag, w_imag) out_imag = torch.einsum('...ki,kij->...kj', x.real, w_imag) + torch.einsum('...ki,kij->...kj', x.imag, w_real) return torch.complex(out_real, out_imag) def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: B, N, D = x.shape x_2d = x.reshape(B, H, W, D) # 2D Real FFT on spatial dimensions x_freq = torch.fft.rfft2(x_2d, dim=(1, 2), norm='ortho') # [B, H, W//2+1, D] # Reshape channel dim for block-diagonal MLP: D → (num_blocks, block_size) Hf, Wf = x_freq.shape[1], x_freq.shape[2] x_freq = x_freq.reshape(B, Hf, Wf, self.num_blocks, self.block_size) # Block MLP Layer 1: operates on last dim (block_size) # x_freq: [B, Hf, Wf, num_blocks, block_size] # w1: [num_blocks, block_size, block_size] x_freq = self.complex_matmul(x_freq, self.w1_real, self.w1_imag) x_freq = x_freq + self.b1 # Broadcast bias (real only) x_freq = torch.complex(F.relu(x_freq.real), F.relu(x_freq.imag)) # Block MLP Layer 2 x_freq = self.complex_matmul(x_freq, self.w2_real, self.w2_imag) x_freq = x_freq + self.b2 # Reshape back to [B, Hf, Wf, D] x_freq = x_freq.reshape(B, Hf, Wf, D) # Soft-shrinkage (sparsity in Fourier domain) magnitude = x_freq.abs() shrunk_mag = F.relu(magnitude - self.sparsity_threshold) # Preserve phase, shrink magnitude x_freq = x_freq * (shrunk_mag / (magnitude + 1e-8)) # Inverse FFT x_out = torch.fft.irfft2(x_freq, s=(H, W), dim=(1, 2), norm='ortho') return x_out.reshape(B, N, D) class GatedLinearRecurrence(nn.Module): """Pathway 2: Bidirectional Gated Linear Recurrence (RG-LRU inspired). O(N) complexity with O(1) state per position. h_t = a_t * h_{t-1} + sqrt(1 - a_t^2) * (i_t * x_t) where a_t = sigmoid(Λ)^(c * sigmoid(W_a * x_t)) """ def __init__(self, dim: int, recurrence_dim: int): super().__init__() self.dim = dim self.rec_dim = recurrence_dim # Project to recurrence space self.proj_in = nn.Linear(dim, recurrence_dim * 2) # Forward + backward # Gating parameters self.W_a = nn.Linear(recurrence_dim, recurrence_dim, bias=False) self.W_x = nn.Linear(recurrence_dim, recurrence_dim, bias=False) self.Lambda = nn.Parameter(torch.randn(recurrence_dim) * 0.5 + 2.0) # Init for decay ~0.88-0.95 self.c = 8.0 # Decay scaling constant (from Griffin) # Output projection self.proj_out = nn.Linear(recurrence_dim * 2, dim) def _scan(self, x: torch.Tensor) -> torch.Tensor: """Sequential scan for a single direction. x: [B, N, rec_dim]""" B, N, D = x.shape # Compute gates (can be parallelized) a_base = torch.sigmoid(self.Lambda) # [D] r = torch.sigmoid(self.W_a(x)) # [B, N, D] - recurrence gate i = torch.sigmoid(self.W_x(x)) # [B, N, D] - input gate # a_t = a_base^(c * r_t) — data-dependent decay a = a_base.pow(self.c * r) # [B, N, D] # Normalized input: sqrt(1 - a^2) for variance preservation input_scale = torch.sqrt(1.0 - a * a + 1e-8) scaled_input = input_scale * (i * x) # [B, N, D] # Sequential recurrence (use parallel scan in production) outputs = [] h = torch.zeros(B, D, device=x.device, dtype=x.dtype) for t in range(N): h = a[:, t] * h + scaled_input[:, t] outputs.append(h) return torch.stack(outputs, dim=1) # [B, N, D] def forward(self, x: torch.Tensor) -> torch.Tensor: B, N, D = x.shape # Project to recurrence space and split for bidirectional x_proj = self.proj_in(x) # [B, N, 2*rec_dim] x_fwd, x_bwd = x_proj.chunk(2, dim=-1) # Forward and backward scans h_fwd = self._scan(x_fwd) h_bwd = self._scan(x_bwd.flip(1)).flip(1) # Merge bidirectional h = torch.cat([h_fwd, h_bwd], dim=-1) return self.proj_out(h) class ManhattanSpatialGate(nn.Module): """Pathway 3: Manhattan distance spatial decay gating. Provides learned 2D spatial inductive bias with per-head multi-scale receptive fields. Uses windowed computation for efficiency. """ def __init__(self, dim: int, num_heads: int, window: int = 16): super().__init__() self.dim = dim self.num_heads = num_heads self.head_dim = dim // num_heads self.window = window # Per-head learnable decay rate # Initialize so gamma ∈ [0.7, 0.95] — multi-scale self.gamma_logit = nn.Parameter(torch.linspace(0.85, 2.94, num_heads)) # sigmoid → [0.7, 0.95] # Value and gate projections self.v_proj = nn.Linear(dim, dim) self.g_proj = nn.Linear(dim, dim) self.o_proj = nn.Linear(dim, dim) def _get_manhattan_mask(self, H: int, W: int, device: torch.device) -> torch.Tensor: """Compute Manhattan distance matrix between all 2D positions.""" coords = torch.stack(torch.meshgrid( torch.arange(H, device=device), torch.arange(W, device=device), indexing='ij' ), dim=-1).reshape(-1, 2).float() # [N, 2] # Manhattan distance: |x1-x2| + |y1-y2| dist = torch.cdist(coords, coords, p=1) # [N, N] return dist def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: B, N, D = x.shape # Compute spatial decay gamma = torch.sigmoid(self.gamma_logit) # [num_heads] manhattan_dist = self._get_manhattan_mask(H, W, x.device) # [N, N] # Window the distance matrix for efficiency # Only compute decay for positions within window distance decay_mask = (manhattan_dist <= self.window).float() # Per-head decay: gamma_h^dist decay = gamma[:, None, None].pow(manhattan_dist[None, :, :]) # [heads, N, N] decay = decay * decay_mask[None, :, :] # Value and gate v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim) g = torch.sigmoid(self.g_proj(x)) # Apply spatial decay to values # [B, heads, N, head_dim] = [heads, N, N] @ [B, heads, N, head_dim] v = v.permute(0, 2, 1, 3) # [B, heads, N, head_dim] out = torch.matmul(decay.unsqueeze(0), v) # [B, heads, N, head_dim] # Normalize by decay sum decay_sum = decay.sum(dim=-1, keepdim=True).unsqueeze(0) + 1e-8 # [1, heads, N, 1] out = out / decay_sum out = out.permute(0, 2, 1, 3).reshape(B, N, D) # [B, N, D] out = out * g # Gating return self.o_proj(out) class GRFM(nn.Module): """Gated Recurrent Fourier Mixer — the core innovation of IRIS. Fuses three complementary pathways: 1. Fourier Global Mixing (O(N log N)) — captures textures, patterns 2. Gated Linear Recurrence (O(N)) — captures sequential/local dependencies 3. Manhattan Spatial Gate — provides 2D inductive bias Pathways are combined via learned adaptive gating. """ def __init__(self, config: IRISConfig): super().__init__() D = config.hidden_dim self.fourier = FourierMixingPathway(D, config.fourier_num_blocks, config.sparsity_threshold) self.recurrence = GatedLinearRecurrence(D, config.recurrence_dim) self.spatial = ManhattanSpatialGate(D, config.num_heads, config.manhattan_window) # Adaptive gate: learns to blend Fourier vs Recurrence based on content self.blend_gate = nn.Sequential( nn.Linear(D, D), nn.SiLU(), nn.Linear(D, D), nn.Sigmoid(), ) # Spatial pathway weight (smaller contribution, additive) self.spatial_scale = nn.Parameter(torch.tensor(0.1)) def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: # Three pathways x_fourier = self.fourier(x, H, W) x_recurrent = self.recurrence(x) x_spatial = self.spatial(x, H, W) # Adaptive blending gate = self.blend_gate(x) # [B, N, D] values in [0, 1] # Fourier for global structure, recurrence for local detail output = gate * x_fourier + (1 - gate) * x_recurrent # Add spatial bias (small contribution) output = output + self.spatial_scale * x_spatial return output # ============================================================================ # Cross-Attention (for text conditioning) # ============================================================================ class CrossAttention(nn.Module): """Efficient cross-attention for text conditioning. Only 77 text tokens → O(N × 77 × d) per layer, very cheap. """ def __init__(self, dim: int, text_dim: int, num_heads: int, head_dim: int): super().__init__() self.num_heads = num_heads self.head_dim = head_dim self.scale = head_dim ** -0.5 self.q_proj = nn.Linear(dim, num_heads * head_dim, bias=False) self.k_proj = nn.Linear(text_dim, num_heads * head_dim, bias=False) self.v_proj = nn.Linear(text_dim, num_heads * head_dim, bias=False) self.o_proj = nn.Linear(num_heads * head_dim, dim) # QK normalization for stability (from SANA-Sprint) self.q_norm = nn.RMSNorm(head_dim) self.k_norm = nn.RMSNorm(head_dim) def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor: B, N, _ = x.shape _, S, _ = context.shape q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(context).reshape(B, S, self.num_heads, self.head_dim).transpose(1, 2) v = self.v_proj(context).reshape(B, S, self.num_heads, self.head_dim).transpose(1, 2) # QK normalization q = self.q_norm(q) k = self.k_norm(k) # Scaled dot-product attention attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale attn = F.softmax(attn, dim=-1) out = torch.matmul(attn, v) out = out.transpose(1, 2).reshape(B, N, -1) return self.o_proj(out) # ============================================================================ # Feed-Forward Network (SwiGLU) # ============================================================================ class SwiGLUFFN(nn.Module): """SwiGLU Feed-Forward Network — better than GELU for transformers.""" def __init__(self, dim: int, ratio: float = 2.667, dropout: float = 0.0): super().__init__() hidden = int(dim * ratio) # Ensure hidden is multiple of 64 for hardware efficiency hidden = ((hidden + 63) // 64) * 64 self.w1 = nn.Linear(dim, hidden, bias=False) self.w2 = nn.Linear(dim, hidden, bias=False) # Gate self.w3 = nn.Linear(hidden, dim, bias=False) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w3(self.dropout(F.silu(self.w1(x)) * self.w2(x))) # ============================================================================ # Prelude Block (unique weights, conv-based) # ============================================================================ class PreludeBlock(nn.Module): """Lightweight conv-based block for initial feature extraction.""" def __init__(self, dim: int): super().__init__() self.norm1 = nn.LayerNorm(dim) self.dwconv = nn.Conv1d(dim, dim, kernel_size=5, padding=2, groups=dim) self.pointwise = nn.Linear(dim, dim) self.norm2 = nn.LayerNorm(dim) self.ffn = SwiGLUFFN(dim) def forward(self, x: torch.Tensor) -> torch.Tensor: # Depthwise conv path h = self.norm1(x) h = h.transpose(1, 2) # [B, D, N] h = self.dwconv(h).transpose(1, 2) # [B, N, D] h = F.silu(h) h = self.pointwise(h) x = x + h # FFN x = x + self.ffn(self.norm2(x)) return x # ============================================================================ # Core Block (shared weights, the heart of IRIS) # ============================================================================ class CoreLayer(nn.Module): """Single layer within the core block. Contains: GRFM + Cross-Attention + FFN, all with adaLN-Zero conditioning. """ def __init__(self, config: IRISConfig): super().__init__() D = config.hidden_dim # Sub-layer 1: GRFM self.norm1 = nn.LayerNorm(D, elementwise_affine=False) self.grfm = GRFM(config) # Sub-layer 2: Cross-Attention self.norm2 = nn.LayerNorm(D, elementwise_affine=False) self.cross_attn = CrossAttention(D, config.text_dim, config.num_heads, config.head_dim) # Sub-layer 3: FFN self.norm3 = nn.LayerNorm(D, elementwise_affine=False) self.ffn = SwiGLUFFN(D, config.ffn_ratio, config.dropout) # adaLN-Zero: 9 modulations (scale1, shift1, gate1, scale2, shift2, gate2, scale3, shift3, gate3) self.adaln = AdaLNSingle(D, num_modulations=9) def _modulate(self, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor: return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) def forward(self, x: torch.Tensor, c: torch.Tensor, text_tokens: torch.Tensor, H: int, W: int) -> torch.Tensor: """ x: [B, N, D] — token sequence c: [B, D] — conditioning vector (timestep + iteration) text_tokens: [B, S, text_dim] — CLIP text tokens H, W: spatial dimensions of token grid """ s1, sh1, g1, s2, sh2, g2, s3, sh3, g3 = self.adaln(c) # GRFM with adaLN-Zero h = self._modulate(self.norm1(x), s1, sh1) h = self.grfm(h, H, W) x = x + g1.unsqueeze(1) * h # Cross-attention with adaLN-Zero h = self._modulate(self.norm2(x), s2, sh2) h = self.cross_attn(h, text_tokens) x = x + g2.unsqueeze(1) * h # FFN with adaLN-Zero h = self._modulate(self.norm3(x), s3, sh3) h = self.ffn(h) x = x + g3.unsqueeze(1) * h return x class CoreBlock(nn.Module): """The shared-weight core block, iterated r times. Contains multiple CoreLayers to give sufficient per-iteration capacity. """ def __init__(self, config: IRISConfig): super().__init__() self.layers = nn.ModuleList([ CoreLayer(config) for _ in range(config.num_core_layers) ]) def forward(self, x: torch.Tensor, c: torch.Tensor, text_tokens: torch.Tensor, H: int, W: int) -> torch.Tensor: for layer in self.layers: x = layer(x, c, text_tokens, H, W) return x # ============================================================================ # Coda Block (unique weights, final refinement) # ============================================================================ class LocalWindowAttention(nn.Module): """Window-based local attention for final refinement. Small window (8×8) for efficient local detail refinement. """ def __init__(self, dim: int, num_heads: int, head_dim: int, window_size: int = 8): super().__init__() self.num_heads = num_heads self.head_dim = head_dim self.window_size = window_size self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, 3 * num_heads * head_dim, bias=False) self.o_proj = nn.Linear(num_heads * head_dim, dim) def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: B, N, D = x.shape ws = self.window_size # Reshape to 2D and partition into windows x_2d = x.reshape(B, H, W, D) # Pad if necessary pad_h = (ws - H % ws) % ws pad_w = (ws - W % ws) % ws if pad_h > 0 or pad_w > 0: x_2d = F.pad(x_2d, (0, 0, 0, pad_w, 0, pad_h)) Hp, Wp = x_2d.shape[1], x_2d.shape[2] nH, nW = Hp // ws, Wp // ws # [B, nH, ws, nW, ws, D] → [B*nH*nW, ws*ws, D] x_win = x_2d.reshape(B, nH, ws, nW, ws, D) x_win = x_win.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws * ws, D) # QKV and attention within windows qkv = self.qkv(x_win).reshape(-1, ws * ws, 3, self.num_heads, self.head_dim) q, k, v = qkv.unbind(2) q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale attn = F.softmax(attn, dim=-1) out = torch.matmul(attn, v) out = out.transpose(1, 2).reshape(-1, ws * ws, self.num_heads * self.head_dim) out = self.o_proj(out) # Unpartition out = out.reshape(B, nH, nW, ws, ws, D) out = out.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, D) # Remove padding out = out[:, :H, :W, :].reshape(B, N, D) return out class CodaBlock(nn.Module): """Final refinement block with local window attention.""" def __init__(self, config: IRISConfig): super().__init__() D = config.hidden_dim self.norm1 = nn.LayerNorm(D) self.attn = LocalWindowAttention(D, config.num_heads, config.head_dim, window_size=8) self.norm2 = nn.LayerNorm(D) self.ffn = SwiGLUFFN(D, config.ffn_ratio) def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: x = x + self.attn(self.norm1(x), H, W) x = x + self.ffn(self.norm2(x)) return x # ============================================================================ # IRIS Generator (Main Model) # ============================================================================ class IRISGenerator(nn.Module): """ IRIS: Iterative Recurrent Image Synthesis The main denoising network with Prelude-Core-Coda structure. Predicts velocity field v for rectified flow training. """ def __init__(self, config: IRISConfig): super().__init__() self.config = config D = config.hidden_dim # Patch embedding: latent patches → tokens self.patch_embed = nn.Linear(config.patch_dim, D) # Positional embedding (learned) self.pos_embed = nn.Parameter(torch.randn(1, config.num_patches, D) * 0.02) # Conditioning self.time_embed = TimestepEmbedding(D) self.iter_embed = IterationEmbedding(config.max_iterations, D) self.text_proj = nn.Linear(config.text_dim, D) # Project CLIP text to model dim # Global text pooling for conditioning self.text_pool_proj = nn.Sequential( nn.Linear(config.text_dim, D), nn.SiLU(), nn.Linear(D, D), ) # Prelude (unique weights) self.prelude = nn.ModuleList([PreludeBlock(D) for _ in range(config.num_prelude_blocks)]) # Core (shared weights, iterated) self.core = CoreBlock(config) # Long skip connection (from Diffusion-RWKV: linear(cat(shallow, deep))) self.skip_proj = nn.Linear(2 * D, D) # Coda (unique weights) self.coda = nn.ModuleList([CodaBlock(config) for _ in range(config.num_coda_blocks)]) # Output projection: tokens → latent patches self.final_norm = nn.LayerNorm(D) self.output_proj = nn.Linear(D, config.patch_dim) # Zero-init output for stable training start nn.init.zeros_(self.output_proj.weight) nn.init.zeros_(self.output_proj.bias) # Precompute patch spatial dimensions self.patch_h = config.latent_spatial // config.patch_size self.patch_w = config.latent_spatial // config.patch_size def patchify(self, z: torch.Tensor) -> torch.Tensor: """Convert latent z [B, C, H, W] → patches [B, N, patch_dim].""" B, C, H, W = z.shape p = self.config.patch_size z = z.reshape(B, C, H // p, p, W // p, p) z = z.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * p * p) return z def unpatchify(self, x: torch.Tensor) -> torch.Tensor: """Convert patches [B, N, patch_dim] → latent [B, C, H, W].""" B, N, _ = x.shape p = self.config.patch_size C = self.config.latent_channels H = self.patch_h W = self.patch_w x = x.reshape(B, H, W, C, p, p) x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H * p, W * p) return x def forward( self, z_t: torch.Tensor, # Noisy latent [B, C, H, W] t: torch.Tensor, # Timestep [B] in [0, 1] text_tokens: torch.Tensor, # CLIP text embeddings [B, S, text_dim] num_iterations: Optional[int] = None, # Override iteration count ) -> torch.Tensor: """Predict velocity field v(z_t, t, c) for rectified flow.""" B = z_t.shape[0] r = num_iterations or self.config.default_iterations H, W = self.patch_h, self.patch_w # Patchify and embed x = self.patch_embed(self.patchify(z_t)) + self.pos_embed # Timestep conditioning t_emb = self.time_embed(t * self.config.num_timesteps) # [B, D] # Text conditioning (project to model dim for cross-attention) text_projected = self.text_proj(text_tokens) # [B, S, D] # Global text pool for adaLN conditioning text_global = self.text_pool_proj(text_tokens.mean(dim=1)) # [B, D] # ============ PRELUDE ============ for block in self.prelude: x = block(x) # Save for long skip connection x_shallow = x # ============ CORE (iterated r times) ============ for i in range(r): # Iteration-aware conditioning iter_idx = torch.full((B,), i, device=z_t.device, dtype=torch.long) i_emb = self.iter_embed(iter_idx) # [B, D] # Combined conditioning: timestep + iteration + text global c = t_emb + i_emb + text_global # [B, D] # Apply shared core block (pass original text_tokens for cross-attention) x = self.core(x, c, text_tokens, H, W) # Long skip connection (from Diffusion-RWKV paper) x = self.skip_proj(torch.cat([x_shallow, x], dim=-1)) # ============ CODA ============ for block in self.coda: x = block(x, H, W) # Output projection x = self.final_norm(x) x = self.output_proj(x) # Unpatchify to latent shape v_pred = self.unpatchify(x) return v_pred # ============================================================================ # Full IRIS System # ============================================================================ class IRIS(nn.Module): """Complete IRIS system: VAE + Generator. For training: use train_step() which handles noise scheduling. For inference: use generate() which runs the full pipeline. """ def __init__(self, config: IRISConfig): super().__init__() self.config = config self.vae = WaveletVAE(config) self.generator = IRISGenerator(config) def encode(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Encode images to latent space.""" return self.vae.encode(images) def decode(self, z: torch.Tensor) -> torch.Tensor: """Decode latent to images.""" return self.vae.decode(z) def get_velocity_target(self, z_0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: """Rectified flow velocity target: v = noise - z_0.""" return noise - z_0 def add_noise(self, z_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor: """Rectified flow forward process: z_t = (1-t)*z_0 + t*noise.""" t_expand = t[:, None, None, None] return (1 - t_expand) * z_0 + t_expand * noise def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor: """Sample timesteps from logit-normal distribution (from SD3/RF). Concentrates sampling on intermediate timesteps where learning is hardest. """ u = torch.randn(batch_size, device=device) t = torch.sigmoid(u) # Logit-normal with mean=0, std=1 # Clamp to avoid t=0 and t=1 t = t.clamp(1e-5, 1 - 1e-5) return t def train_step( self, images: torch.Tensor, text_tokens: torch.Tensor, num_iterations: Optional[int] = None, ) -> dict: """Single training step for rectified flow. Returns dict with loss and diagnostics. """ B = images.shape[0] device = images.device # Encode to latent z_0, mean, logvar = self.encode(images) # Sample noise and timesteps noise = torch.randn_like(z_0) t = self.sample_timesteps(B, device) # Create noisy latent z_t = self.add_noise(z_0, noise, t) # Predict velocity # Randomly sample iteration count for training robustness if num_iterations is None: r_choices = [4, 6, 8, 10, 12] r = r_choices[torch.randint(0, len(r_choices), (1,)).item()] else: r = num_iterations v_pred = self.generator(z_t, t, text_tokens, num_iterations=r) v_target = self.get_velocity_target(z_0, noise) # SNR-weighted loss (from Rectified Flow paper) # w(t) = t / (1 - t) — emphasizes high-noise timesteps w = t / (1 - t + 1e-8) w = w[:, None, None, None] # Velocity matching loss velocity_loss = (w * (v_pred - v_target).pow(2)).mean() # VAE KL loss kl_loss = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp()).mean() return { 'loss': velocity_loss + 0.001 * kl_loss, 'velocity_loss': velocity_loss.item(), 'kl_loss': kl_loss.item(), 'mean_t': t.mean().item(), } @torch.no_grad() def generate( self, text_tokens: torch.Tensor, num_steps: int = 4, num_iterations: int = 8, cfg_scale: float = 4.0, seed: Optional[int] = None, ) -> torch.Tensor: """Generate images from text conditioning using Euler solver. Args: text_tokens: [B, S, text_dim] CLIP text embeddings num_steps: Number of ODE solver steps (1-50) num_iterations: Core iterations per step (quality budget) cfg_scale: Classifier-free guidance scale seed: Random seed for reproducibility """ B, S, _ = text_tokens.shape device = text_tokens.device if seed is not None: torch.manual_seed(seed) # Start from pure noise z = torch.randn(B, self.config.latent_channels, self.config.latent_spatial, self.config.latent_spatial, device=device) # Euler solver for rectified flow ODE: dz/dt = -v(z, t) # Integrate from t=1 (noise) to t=0 (data) dt = 1.0 / num_steps for step in range(num_steps): t_val = 1.0 - step * dt t = torch.full((B,), t_val, device=device) # Predict velocity v = self.generator(z, t, text_tokens, num_iterations=num_iterations) # Classifier-free guidance (if cfg_scale > 1) if cfg_scale > 1.0: null_tokens = torch.zeros_like(text_tokens) v_uncond = self.generator(z, t, null_tokens, num_iterations=num_iterations) v = v_uncond + cfg_scale * (v - v_uncond) # Euler step: z = z - dt * v z = z - dt * v # Decode to image images = self.decode(z) images = images.clamp(-1, 1) return images # ============================================================================ # Utility Functions # ============================================================================ def count_parameters(model: nn.Module) -> dict: """Count parameters in each component.""" counts = {} total = 0 for name, module in model.named_children(): n = sum(p.numel() for p in module.parameters()) counts[name] = n total += n counts['total'] = total # Separate trainable vs frozen trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) counts['trainable'] = trainable return counts def estimate_memory_mb(model: nn.Module, dtype=torch.float16) -> float: """Estimate model memory in MB.""" bytes_per_param = 2 if dtype == torch.float16 else 4 total_params = sum(p.numel() for p in model.parameters()) return total_params * bytes_per_param / (1024 * 1024) def create_iris_small(latent_spatial: int = 32) -> IRIS: """Create IRIS-Small: ~75M generator params, suitable for mobile.""" config = IRISConfig( latent_channels=16, latent_spatial=latent_spatial, hidden_dim=512, num_heads=8, head_dim=64, ffn_ratio=2.667, num_prelude_blocks=2, num_core_layers=4, num_coda_blocks=2, default_iterations=8, max_iterations=16, fourier_num_blocks=8, sparsity_threshold=0.01, recurrence_dim=256, manhattan_window=16, text_dim=768, max_text_tokens=77, patch_size=2, ) return IRIS(config) def create_iris_tiny(latent_spatial: int = 32) -> IRIS: """Create IRIS-Tiny: ~30M generator params, ultra-mobile.""" config = IRISConfig( latent_channels=8, latent_spatial=latent_spatial, hidden_dim=384, num_heads=6, head_dim=64, ffn_ratio=2.667, num_prelude_blocks=1, num_core_layers=3, num_coda_blocks=1, default_iterations=8, max_iterations=16, fourier_num_blocks=6, sparsity_threshold=0.01, recurrence_dim=192, manhattan_window=12, text_dim=768, max_text_tokens=77, patch_size=2, ) return IRIS(config) def create_iris_base(latent_spatial: int = 32) -> IRIS: """Create IRIS-Base: ~150M generator params, quality-focused.""" config = IRISConfig( latent_channels=16, latent_spatial=latent_spatial, hidden_dim=768, num_heads=12, head_dim=64, ffn_ratio=2.667, num_prelude_blocks=2, num_core_layers=6, num_coda_blocks=2, default_iterations=8, max_iterations=16, fourier_num_blocks=12, sparsity_threshold=0.01, recurrence_dim=384, manhattan_window=16, text_dim=768, max_text_tokens=77, patch_size=2, ) return IRIS(config) if __name__ == "__main__": print("=" * 70) print("IRIS: Iterative Recurrent Image Synthesis") print("=" * 70) # Create model variants for name, create_fn in [("IRIS-Tiny", create_iris_tiny), ("IRIS-Small", create_iris_small), ("IRIS-Base", create_iris_base)]: print(f"\n{'─' * 50}") print(f" {name}") print(f"{'─' * 50}") model = create_fn() counts = count_parameters(model) mem_fp16 = estimate_memory_mb(model, torch.float16) mem_fp32 = estimate_memory_mb(model, torch.float32) print(f" Total params: {counts['total']:>12,}") print(f" Trainable params: {counts['trainable']:>12,}") print(f" Memory (fp16): {mem_fp16:>10.1f} MB") print(f" Memory (fp32): {mem_fp32:>10.1f} MB") print(f" Components:") for k, v in counts.items(): if k not in ('total', 'trainable'): print(f" {k:20s}: {v:>10,} ({v/counts['total']*100:.1f}%)")