"""
IRIS: Iterative Recurrent Image Synthesis
==========================================
A novel architecture for mobile-first high-quality image generation.
Key innovations:
1. Wavelet-Frequency Latent Space (Haar DWT + lightweight VAE)
2. Recurrent Depth Core (Prelude-Core-Coda with shared weights)
3. Gated Recurrent Fourier Mixer (GRFM) — novel token mixing
4. Manhattan Spatial Decay — learned 2D inductive bias
5. Rectified Flow training with consistency distillation support
6. Adaptive compute budget (4-16 iterations, same model)
Author: IRIS Research
License: Apache 2.0
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
from dataclasses import dataclass, field
# ============================================================================
# Configuration
# ============================================================================
@dataclass
class IRISConfig:
"""Configuration for IRIS model."""
# Latent space
latent_channels: int = 16 # Channels in latent space
latent_spatial: int = 32 # Spatial dim of latent (for 512px with 16x compression)
# Model dimensions
hidden_dim: int = 512 # Main hidden dimension
num_heads: int = 8 # Number of attention heads
head_dim: int = 64 # Dimension per head
ffn_ratio: float = 2.667 # FFN expansion ratio (SwiGLU-adjusted)
# Architecture structure
num_prelude_blocks: int = 2 # Prelude blocks (unique weights)
num_core_layers: int = 4 # Layers WITHIN each core iteration
num_coda_blocks: int = 2 # Coda blocks (unique weights)
default_iterations: int = 8 # Default core iterations at inference
max_iterations: int = 16 # Maximum core iterations
# GRFM settings
fourier_num_blocks: int = 8 # Block-diagonal blocks in Fourier MLP
sparsity_threshold: float = 0.01 # Soft-shrinkage lambda
recurrence_dim: int = 256 # Dimension for gated recurrence pathway
manhattan_window: int = 16 # Windowed Manhattan decay (for efficiency)
# Cross-attention
text_dim: int = 768 # CLIP-L/14 text embedding dim
max_text_tokens: int = 77 # Maximum text sequence length
# Patch embedding
patch_size: int = 2 # Patches in latent space (2×2)
# Conditioning
num_timesteps: int = 1000 # Noise schedule discretization
# VAE
vae_channels: list = field(default_factory=lambda: [32, 64, 128, 256])
# Training
dropout: float = 0.0
@property
def vae_latent_channels(self) -> int:
"""VAE latent channels must match generator latent channels."""
return self.latent_channels
@property
def num_patches(self) -> int:
return (self.latent_spatial // self.patch_size) ** 2
@property
def patch_dim(self) -> int:
return self.latent_channels * self.patch_size * self.patch_size
# ============================================================================
# Wavelet Transforms (Haar)
# ============================================================================
class HaarDWT2D(nn.Module):
"""2D Discrete Wavelet Transform using Haar wavelets.
Decomposes x ∈ R^{B,C,H,W} into R^{B,4C,H/2,W/2} (LL, LH, HL, HH subbands).
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Haar DWT: split into even/odd along both spatial dims
x_ll = (x[:, :, 0::2, 0::2] + x[:, :, 0::2, 1::2] +
x[:, :, 1::2, 0::2] + x[:, :, 1::2, 1::2]) * 0.5
x_lh = (x[:, :, 0::2, 0::2] + x[:, :, 0::2, 1::2] -
x[:, :, 1::2, 0::2] - x[:, :, 1::2, 1::2]) * 0.5
x_hl = (x[:, :, 0::2, 0::2] - x[:, :, 0::2, 1::2] +
x[:, :, 1::2, 0::2] - x[:, :, 1::2, 1::2]) * 0.5
x_hh = (x[:, :, 0::2, 0::2] - x[:, :, 0::2, 1::2] -
x[:, :, 1::2, 0::2] + x[:, :, 1::2, 1::2]) * 0.5
return torch.cat([x_ll, x_lh, x_hl, x_hh], dim=1)
class HaarIDWT2D(nn.Module):
"""2D Inverse Discrete Wavelet Transform (Haar).
Reconstructs x ∈ R^{B,C,H,W} from R^{B,4*(C//4),H/2,W/2}.
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C4, Hh, Wh = x.shape
C = C4 // 4
ll, lh, hl, hh = x[:, :C], x[:, C:2*C], x[:, 2*C:3*C], x[:, 3*C:]
# Reconstruct 2× spatial resolution
H, W = Hh * 2, Wh * 2
out = torch.zeros(B, C, H, W, device=x.device, dtype=x.dtype)
out[:, :, 0::2, 0::2] = (ll + lh + hl + hh) * 0.5
out[:, :, 0::2, 1::2] = (ll + lh - hl - hh) * 0.5
out[:, :, 1::2, 0::2] = (ll - lh + hl - hh) * 0.5
out[:, :, 1::2, 1::2] = (ll - lh - hl + hh) * 0.5
return out
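# Quick self-check (illustrative, not part of the model): with the 0.5 scaling
# above, the Haar analysis/synthesis pair is exactly invertible, so DWT followed
# by IDWT should reconstruct the input up to float rounding.
def _haar_roundtrip_check():
    x = torch.randn(2, 3, 8, 8)
    recon = HaarIDWT2D()(HaarDWT2D()(x))
    assert torch.allclose(recon, x, atol=1e-5), "Haar round-trip failed"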
# ============================================================================
# Lightweight Wavelet VAE
# ============================================================================
class DepthwiseSeparableConv(nn.Module):
"""Depthwise separable convolution — key mobile optimization."""
def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1):
super().__init__()
self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size, stride, padding, groups=in_ch)
self.pointwise = nn.Conv2d(in_ch, out_ch, 1)
def forward(self, x):
return self.pointwise(self.depthwise(x))
class ResBlock(nn.Module):
"""Residual block with depthwise separable convolutions."""
def __init__(self, channels):
super().__init__()
self.norm1 = nn.GroupNorm(8, channels)
self.conv1 = DepthwiseSeparableConv(channels, channels)
self.norm2 = nn.GroupNorm(8, channels)
self.conv2 = DepthwiseSeparableConv(channels, channels)
# Zero-init final layer for residual learning stability
nn.init.zeros_(self.conv2.pointwise.weight)
nn.init.zeros_(self.conv2.pointwise.bias)
def forward(self, x):
h = F.silu(self.norm1(x))
h = self.conv1(h)
h = F.silu(self.norm2(h))
h = self.conv2(h)
return x + h
class WaveletVAEEncoder(nn.Module):
"""Lightweight encoder: Haar DWT preprocessing + small convolutional encoder.
Input: images R^{B,3,H,W} → Output: latent R^{B,C_latent,H/16,W/16}
Compression: 3×H×W → C_latent×(H/16)×(W/16)
"""
def __init__(self, config: IRISConfig):
super().__init__()
self.dwt = HaarDWT2D()
channels = config.vae_channels
latent_ch = config.vae_latent_channels
# DWT: 3 channels → 12 channels at H/2 × W/2
self.conv_in = nn.Conv2d(12, channels[0], 3, 1, 1)
# Downsampling path: H/2→H/4→H/8→H/16
self.down_blocks = nn.ModuleList()
for i in range(len(channels) - 1):
self.down_blocks.append(nn.Sequential(
ResBlock(channels[i]),
nn.Conv2d(channels[i], channels[i+1], 3, 2, 1), # 2× downsample
))
# Bottleneck
self.mid = nn.Sequential(
ResBlock(channels[-1]),
ResBlock(channels[-1]),
)
# To latent (mean + logvar)
self.norm_out = nn.GroupNorm(8, channels[-1])
self.conv_out = nn.Conv2d(channels[-1], 2 * latent_ch, 1)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# Haar DWT preprocessing
x = self.dwt(x) # [B, 12, H/2, W/2]
x = self.conv_in(x)
for down in self.down_blocks:
x = down(x)
x = self.mid(x)
x = F.silu(self.norm_out(x))
x = self.conv_out(x)
mean, logvar = x.chunk(2, dim=1)
logvar = torch.clamp(logvar, -30.0, 20.0)
return mean, logvar
    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample a latent via the reparameterization trick; returns (z, mean, logvar)."""
        mean, logvar = self.forward(x)
        std = torch.exp(0.5 * logvar)
        z = mean + std * torch.randn_like(std)
        return z, mean, logvar
class WaveletVAEDecoder(nn.Module):
"""Tiny decoder: latent → wavelet coefficients → Haar IDWT → image.
Designed to be as small as possible for mobile inference.
"""
def __init__(self, config: IRISConfig):
super().__init__()
channels = list(reversed(config.vae_channels))
latent_ch = config.vae_latent_channels
self.idwt = HaarIDWT2D()
# From latent
self.conv_in = nn.Conv2d(latent_ch, channels[0], 3, 1, 1)
# Bottleneck
self.mid = nn.Sequential(
ResBlock(channels[0]),
)
# Upsampling path
self.up_blocks = nn.ModuleList()
for i in range(len(channels) - 1):
self.up_blocks.append(nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
DepthwiseSeparableConv(channels[i], channels[i+1]),
nn.SiLU(),
ResBlock(channels[i+1]),
))
# To wavelet coefficients (12 channels: 4 subbands × 3 color channels)
self.norm_out = nn.GroupNorm(8, channels[-1])
self.conv_out = nn.Conv2d(channels[-1], 12, 3, 1, 1)
def forward(self, z: torch.Tensor) -> torch.Tensor:
x = self.conv_in(z)
x = self.mid(x)
for up in self.up_blocks:
x = up(x)
x = F.silu(self.norm_out(x))
x = self.conv_out(x) # [B, 12, H/2, W/2] wavelet coefficients
# Inverse DWT to get image
x = self.idwt(x) # [B, 3, H, W]
return x
class WaveletVAE(nn.Module):
"""Complete Wavelet VAE with DWT preprocessing."""
def __init__(self, config: IRISConfig):
super().__init__()
self.encoder = WaveletVAEEncoder(config)
self.decoder = WaveletVAEDecoder(config)
self.config = config
def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return self.encoder.encode(x)
def decode(self, z: torch.Tensor) -> torch.Tensor:
return self.decoder(z)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
z, mean, logvar = self.encode(x)
x_recon = self.decode(z)
return x_recon, mean, logvar
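# Minimal shape demo (illustrative sketch, assuming the default IRISConfig):
# 3×512×512 input → DWT → 12×256×256 → three stride-2 convs → 256×32×32 →
# 16×32×32 latent; the decoder mirrors the path back to 3×512×512.
def _vae_shape_demo():
    vae = WaveletVAE(IRISConfig())
    x = torch.randn(1, 3, 512, 512)
    z, mean, logvar = vae.encode(x)
    assert z.shape == (1, 16, 32, 32)
    assert vae.decode(z).shape == (1, 3, 512, 512)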
# ============================================================================
# Conditioning Modules
# ============================================================================
class TimestepEmbedding(nn.Module):
"""Sinusoidal timestep embedding with MLP projection."""
def __init__(self, dim: int, max_period: int = 10000):
super().__init__()
self.dim = dim
self.max_period = max_period
self.mlp = nn.Sequential(
nn.Linear(dim, 4 * dim),
nn.SiLU(),
nn.Linear(4 * dim, dim),
)
def forward(self, t: torch.Tensor) -> torch.Tensor:
half = self.dim // 2
freqs = torch.exp(
-math.log(self.max_period) * torch.arange(half, device=t.device, dtype=t.dtype) / half
)
args = t[:, None] * freqs[None, :]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
return self.mlp(embedding)
class IterationEmbedding(nn.Module):
"""Learnable embedding for iteration index within recurrent core."""
def __init__(self, max_iterations: int, dim: int):
super().__init__()
self.embedding = nn.Embedding(max_iterations, dim)
def forward(self, i: torch.Tensor) -> torch.Tensor:
return self.embedding(i)
class AdaLNSingle(nn.Module):
"""Adaptive Layer Normalization (single shared MLP, per-layer bias).
From PixArt-α: saves 27% params vs standard adaLN.
Produces (scale, shift, gate) for each sub-layer from a shared condition vector.
"""
def __init__(self, dim: int, num_modulations: int = 6):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, num_modulations * dim)
self.num_modulations = num_modulations
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
def forward(self, c: torch.Tensor) -> Tuple[torch.Tensor, ...]:
"""c: [B, D] condition vector → tuple of num_modulations tensors [B, D]."""
params = self.linear(self.silu(c))
return params.chunk(self.num_modulations, dim=-1)
# ============================================================================
# GRFM: Gated Recurrent Fourier Mixer (Novel Contribution)
# ============================================================================
class FourierMixingPathway(nn.Module):
"""Pathway 1: Adaptive Fourier Neural Operator-style global mixing.
O(N log N) complexity via FFT. Block-diagonal MLP in frequency domain.
"""
def __init__(self, dim: int, num_blocks: int = 8, sparsity_threshold: float = 0.01):
super().__init__()
self.dim = dim
self.num_blocks = num_blocks
self.block_size = dim // num_blocks
self.sparsity_threshold = sparsity_threshold
# Block-diagonal complex-valued MLP in Fourier domain
# Each block: R^{block_size} → R^{block_size}
# Using real-valued params for complex ops (split real/imag)
self.w1_real = nn.Parameter(torch.randn(num_blocks, self.block_size, self.block_size) * 0.02)
self.w1_imag = nn.Parameter(torch.randn(num_blocks, self.block_size, self.block_size) * 0.02)
self.w2_real = nn.Parameter(torch.randn(num_blocks, self.block_size, self.block_size) * 0.02)
self.w2_imag = nn.Parameter(torch.randn(num_blocks, self.block_size, self.block_size) * 0.02)
self.b1 = nn.Parameter(torch.zeros(num_blocks, self.block_size))
self.b2 = nn.Parameter(torch.zeros(num_blocks, self.block_size))
def complex_matmul(self, x: torch.Tensor, w_real: torch.Tensor, w_imag: torch.Tensor) -> torch.Tensor:
"""Complex matrix multiplication: (a+bi)(c+di) = (ac-bd) + (ad+bc)i
x: [..., num_blocks, block_size] (complex)
w: [num_blocks, block_size, block_size] (real)
"""
# Use einsum for proper block-diagonal matmul
# x: [B, Hf, Wf, K, bs], w: [K, bs, bs] → out: [B, Hf, Wf, K, bs]
out_real = torch.einsum('...ki,kij->...kj', x.real, w_real) - torch.einsum('...ki,kij->...kj', x.imag, w_imag)
out_imag = torch.einsum('...ki,kij->...kj', x.real, w_imag) + torch.einsum('...ki,kij->...kj', x.imag, w_real)
return torch.complex(out_real, out_imag)
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
B, N, D = x.shape
x_2d = x.reshape(B, H, W, D)
# 2D Real FFT on spatial dimensions
x_freq = torch.fft.rfft2(x_2d, dim=(1, 2), norm='ortho') # [B, H, W//2+1, D]
# Reshape channel dim for block-diagonal MLP: D → (num_blocks, block_size)
Hf, Wf = x_freq.shape[1], x_freq.shape[2]
x_freq = x_freq.reshape(B, Hf, Wf, self.num_blocks, self.block_size)
# Block MLP Layer 1: operates on last dim (block_size)
# x_freq: [B, Hf, Wf, num_blocks, block_size]
# w1: [num_blocks, block_size, block_size]
x_freq = self.complex_matmul(x_freq, self.w1_real, self.w1_imag)
x_freq = x_freq + self.b1 # Broadcast bias (real only)
x_freq = torch.complex(F.relu(x_freq.real), F.relu(x_freq.imag))
# Block MLP Layer 2
x_freq = self.complex_matmul(x_freq, self.w2_real, self.w2_imag)
x_freq = x_freq + self.b2
# Reshape back to [B, Hf, Wf, D]
x_freq = x_freq.reshape(B, Hf, Wf, D)
# Soft-shrinkage (sparsity in Fourier domain)
magnitude = x_freq.abs()
shrunk_mag = F.relu(magnitude - self.sparsity_threshold)
# Preserve phase, shrink magnitude
x_freq = x_freq * (shrunk_mag / (magnitude + 1e-8))
# Inverse FFT
x_out = torch.fft.irfft2(x_freq, s=(H, W), dim=(1, 2), norm='ortho')
return x_out.reshape(B, N, D)
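# Usage sketch (illustrative): the pathway mixes a [B, N, D] token sequence
# globally given its H×W grid shape; D must be divisible by num_blocks, and
# the output keeps the input shape.
def _fourier_pathway_demo():
    mixer = FourierMixingPathway(dim=512, num_blocks=8)
    tokens = torch.randn(2, 16 * 16, 512)
    mixed = mixer(tokens, H=16, W=16)
    assert mixed.shape == tokens.shape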
class GatedLinearRecurrence(nn.Module):
"""Pathway 2: Bidirectional Gated Linear Recurrence (RG-LRU inspired).
O(N) complexity with O(1) state per position.
h_t = a_t * h_{t-1} + sqrt(1 - a_t^2) * (i_t * x_t)
where a_t = sigmoid(Λ)^(c * sigmoid(W_a * x_t))
"""
def __init__(self, dim: int, recurrence_dim: int):
super().__init__()
self.dim = dim
self.rec_dim = recurrence_dim
# Project to recurrence space
self.proj_in = nn.Linear(dim, recurrence_dim * 2) # Forward + backward
# Gating parameters
self.W_a = nn.Linear(recurrence_dim, recurrence_dim, bias=False)
self.W_x = nn.Linear(recurrence_dim, recurrence_dim, bias=False)
self.Lambda = nn.Parameter(torch.randn(recurrence_dim) * 0.5 + 2.0) # Init for decay ~0.88-0.95
self.c = 8.0 # Decay scaling constant (from Griffin)
# Output projection
self.proj_out = nn.Linear(recurrence_dim * 2, dim)
def _scan(self, x: torch.Tensor) -> torch.Tensor:
"""Sequential scan for a single direction. x: [B, N, rec_dim]"""
B, N, D = x.shape
# Compute gates (can be parallelized)
a_base = torch.sigmoid(self.Lambda) # [D]
r = torch.sigmoid(self.W_a(x)) # [B, N, D] - recurrence gate
i = torch.sigmoid(self.W_x(x)) # [B, N, D] - input gate
# a_t = a_base^(c * r_t) — data-dependent decay
a = a_base.pow(self.c * r) # [B, N, D]
# Normalized input: sqrt(1 - a^2) for variance preservation
input_scale = torch.sqrt(1.0 - a * a + 1e-8)
scaled_input = input_scale * (i * x) # [B, N, D]
# Sequential recurrence (use parallel scan in production)
outputs = []
h = torch.zeros(B, D, device=x.device, dtype=x.dtype)
for t in range(N):
h = a[:, t] * h + scaled_input[:, t]
outputs.append(h)
return torch.stack(outputs, dim=1) # [B, N, D]
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, D = x.shape
# Project to recurrence space and split for bidirectional
x_proj = self.proj_in(x) # [B, N, 2*rec_dim]
x_fwd, x_bwd = x_proj.chunk(2, dim=-1)
# Forward and backward scans
h_fwd = self._scan(x_fwd)
h_bwd = self._scan(x_bwd.flip(1)).flip(1)
# Merge bidirectional
h = torch.cat([h_fwd, h_bwd], dim=-1)
return self.proj_out(h)
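# A log-depth alternative to the Python loop in `_scan` (illustrative sketch,
# not wired in): h_t = a_t * h_{t-1} + b_t is associative under the combine
# (a2, b2) ∘ (a1, b1) = (a2 * a1, a2 * b1 + b2), so a Hillis-Steele scan
# computes all prefixes in ceil(log2 N) tensor ops instead of N sequential steps.
def parallel_linear_scan(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Inclusive scan of h_t = a_t * h_{t-1} + b_t with h_0 = 0.
    a, b: [B, N, D]. Matches the sequential loop in GatedLinearRecurrence._scan."""
    A, H = a.clone(), b.clone()
    offset, N = 1, a.shape[1]
    while offset < N:
        # Shift by `offset`, padding with the identity element (a=1, b=0)
        A_prev = torch.cat([torch.ones_like(A[:, :offset]), A[:, :-offset]], dim=1)
        H_prev = torch.cat([torch.zeros_like(H[:, :offset]), H[:, :-offset]], dim=1)
        H = A * H_prev + H
        A = A * A_prev
        offset *= 2
    return H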
class ManhattanSpatialGate(nn.Module):
"""Pathway 3: Manhattan distance spatial decay gating.
Provides learned 2D spatial inductive bias with per-head multi-scale receptive fields.
Uses windowed computation for efficiency.
"""
def __init__(self, dim: int, num_heads: int, window: int = 16):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window = window
# Per-head learnable decay rate
# Initialize so gamma ∈ [0.7, 0.95] — multi-scale
self.gamma_logit = nn.Parameter(torch.linspace(0.85, 2.94, num_heads)) # sigmoid → [0.7, 0.95]
# Value and gate projections
self.v_proj = nn.Linear(dim, dim)
self.g_proj = nn.Linear(dim, dim)
self.o_proj = nn.Linear(dim, dim)
    def _get_manhattan_dist(self, H: int, W: int, device: torch.device) -> torch.Tensor:
        """Manhattan distance matrix between all 2D positions, cached per (H, W)."""
        if getattr(self, "_dist_hw", None) == (H, W) and self._dist.device == device:
            return self._dist
        coords = torch.stack(torch.meshgrid(
            torch.arange(H, device=device),
            torch.arange(W, device=device),
            indexing='ij'
        ), dim=-1).reshape(-1, 2).float()  # [N, 2]
        # Manhattan distance: |x1-x2| + |y1-y2|
        dist = torch.cdist(coords, coords, p=1)  # [N, N]
        self._dist, self._dist_hw = dist, (H, W)
        return dist
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
B, N, D = x.shape
# Compute spatial decay
gamma = torch.sigmoid(self.gamma_logit) # [num_heads]
        manhattan_dist = self._get_manhattan_dist(H, W, x.device)  # [N, N]
        # Mask out positions beyond the window distance (decay is computed
        # densely, then zeroed outside the window)
        decay_mask = (manhattan_dist <= self.window).float()
# Per-head decay: gamma_h^dist
decay = gamma[:, None, None].pow(manhattan_dist[None, :, :]) # [heads, N, N]
decay = decay * decay_mask[None, :, :]
# Value and gate
v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim)
g = torch.sigmoid(self.g_proj(x))
# Apply spatial decay to values
# [B, heads, N, head_dim] = [heads, N, N] @ [B, heads, N, head_dim]
v = v.permute(0, 2, 1, 3) # [B, heads, N, head_dim]
out = torch.matmul(decay.unsqueeze(0), v) # [B, heads, N, head_dim]
# Normalize by decay sum
decay_sum = decay.sum(dim=-1, keepdim=True).unsqueeze(0) + 1e-8 # [1, heads, N, 1]
out = out / decay_sum
out = out.permute(0, 2, 1, 3).reshape(B, N, D) # [B, N, D]
out = out * g # Gating
return self.o_proj(out)
class GRFM(nn.Module):
"""Gated Recurrent Fourier Mixer — the core innovation of IRIS.
Fuses three complementary pathways:
1. Fourier Global Mixing (O(N log N)) — captures textures, patterns
2. Gated Linear Recurrence (O(N)) — captures sequential/local dependencies
3. Manhattan Spatial Gate — provides 2D inductive bias
Pathways are combined via learned adaptive gating.
"""
def __init__(self, config: IRISConfig):
super().__init__()
D = config.hidden_dim
self.fourier = FourierMixingPathway(D, config.fourier_num_blocks, config.sparsity_threshold)
self.recurrence = GatedLinearRecurrence(D, config.recurrence_dim)
self.spatial = ManhattanSpatialGate(D, config.num_heads, config.manhattan_window)
# Adaptive gate: learns to blend Fourier vs Recurrence based on content
self.blend_gate = nn.Sequential(
nn.Linear(D, D),
nn.SiLU(),
nn.Linear(D, D),
nn.Sigmoid(),
)
# Spatial pathway weight (smaller contribution, additive)
self.spatial_scale = nn.Parameter(torch.tensor(0.1))
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
# Three pathways
x_fourier = self.fourier(x, H, W)
x_recurrent = self.recurrence(x)
x_spatial = self.spatial(x, H, W)
# Adaptive blending
gate = self.blend_gate(x) # [B, N, D] values in [0, 1]
# Fourier for global structure, recurrence for local detail
output = gate * x_fourier + (1 - gate) * x_recurrent
# Add spatial bias (small contribution)
output = output + self.spatial_scale * x_spatial
return output
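# Usage sketch (illustrative): GRFM on the default-config token grid. All three
# pathways preserve the [B, N, D] shape, so the mixer is a drop-in token mixer.
def _grfm_demo():
    cfg = IRISConfig()
    H = W = cfg.latent_spatial // cfg.patch_size  # 16 × 16 token grid
    tokens = torch.randn(2, H * W, cfg.hidden_dim)
    out = GRFM(cfg)(tokens, H, W)
    assert out.shape == tokens.shape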
# ============================================================================
# Cross-Attention (for text conditioning)
# ============================================================================
class CrossAttention(nn.Module):
"""Efficient cross-attention for text conditioning.
Only 77 text tokens → O(N × 77 × d) per layer, very cheap.
"""
def __init__(self, dim: int, text_dim: int, num_heads: int, head_dim: int):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = head_dim ** -0.5
self.q_proj = nn.Linear(dim, num_heads * head_dim, bias=False)
self.k_proj = nn.Linear(text_dim, num_heads * head_dim, bias=False)
self.v_proj = nn.Linear(text_dim, num_heads * head_dim, bias=False)
self.o_proj = nn.Linear(num_heads * head_dim, dim)
        # QK normalization for stability (from SANA-Sprint)
        # (nn.RMSNorm requires PyTorch >= 2.4)
        self.q_norm = nn.RMSNorm(head_dim)
        self.k_norm = nn.RMSNorm(head_dim)
def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
B, N, _ = x.shape
_, S, _ = context.shape
q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(context).reshape(B, S, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(context).reshape(B, S, self.num_heads, self.head_dim).transpose(1, 2)
# QK normalization
q = self.q_norm(q)
k = self.k_norm(k)
# Scaled dot-product attention
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
attn = F.softmax(attn, dim=-1)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).reshape(B, N, -1)
return self.o_proj(out)
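# Note (illustrative): on PyTorch >= 2.0 the explicit matmul/softmax above can
# be replaced by the fused kernel with identical math, e.g.
#
#   out = F.scaled_dot_product_attention(q, k, v)
#
# which applies the 1/sqrt(head_dim) scaling and softmax internally.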
# ============================================================================
# Feed-Forward Network (SwiGLU)
# ============================================================================
class SwiGLUFFN(nn.Module):
"""SwiGLU Feed-Forward Network — better than GELU for transformers."""
def __init__(self, dim: int, ratio: float = 2.667, dropout: float = 0.0):
super().__init__()
hidden = int(dim * ratio)
# Ensure hidden is multiple of 64 for hardware efficiency
hidden = ((hidden + 63) // 64) * 64
self.w1 = nn.Linear(dim, hidden, bias=False)
self.w2 = nn.Linear(dim, hidden, bias=False) # Gate
self.w3 = nn.Linear(hidden, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w3(self.dropout(F.silu(self.w1(x)) * self.w2(x)))
# ============================================================================
# Prelude Block (unique weights, conv-based)
# ============================================================================
class PreludeBlock(nn.Module):
"""Lightweight conv-based block for initial feature extraction."""
def __init__(self, dim: int):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.dwconv = nn.Conv1d(dim, dim, kernel_size=5, padding=2, groups=dim)
self.pointwise = nn.Linear(dim, dim)
self.norm2 = nn.LayerNorm(dim)
self.ffn = SwiGLUFFN(dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Depthwise conv path
h = self.norm1(x)
h = h.transpose(1, 2) # [B, D, N]
h = self.dwconv(h).transpose(1, 2) # [B, N, D]
h = F.silu(h)
h = self.pointwise(h)
x = x + h
# FFN
x = x + self.ffn(self.norm2(x))
return x
# ============================================================================
# Core Block (shared weights, the heart of IRIS)
# ============================================================================
class CoreLayer(nn.Module):
"""Single layer within the core block.
Contains: GRFM + Cross-Attention + FFN, all with adaLN-Zero conditioning.
"""
def __init__(self, config: IRISConfig):
super().__init__()
D = config.hidden_dim
# Sub-layer 1: GRFM
self.norm1 = nn.LayerNorm(D, elementwise_affine=False)
self.grfm = GRFM(config)
# Sub-layer 2: Cross-Attention
self.norm2 = nn.LayerNorm(D, elementwise_affine=False)
self.cross_attn = CrossAttention(D, config.text_dim, config.num_heads, config.head_dim)
# Sub-layer 3: FFN
self.norm3 = nn.LayerNorm(D, elementwise_affine=False)
self.ffn = SwiGLUFFN(D, config.ffn_ratio, config.dropout)
# adaLN-Zero: 9 modulations (scale1, shift1, gate1, scale2, shift2, gate2, scale3, shift3, gate3)
self.adaln = AdaLNSingle(D, num_modulations=9)
def _modulate(self, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def forward(self, x: torch.Tensor, c: torch.Tensor, text_tokens: torch.Tensor,
H: int, W: int) -> torch.Tensor:
"""
x: [B, N, D] — token sequence
c: [B, D] — conditioning vector (timestep + iteration)
text_tokens: [B, S, text_dim] — CLIP text tokens
H, W: spatial dimensions of token grid
"""
s1, sh1, g1, s2, sh2, g2, s3, sh3, g3 = self.adaln(c)
# GRFM with adaLN-Zero
h = self._modulate(self.norm1(x), s1, sh1)
h = self.grfm(h, H, W)
x = x + g1.unsqueeze(1) * h
# Cross-attention with adaLN-Zero
h = self._modulate(self.norm2(x), s2, sh2)
h = self.cross_attn(h, text_tokens)
x = x + g2.unsqueeze(1) * h
# FFN with adaLN-Zero
h = self._modulate(self.norm3(x), s3, sh3)
h = self.ffn(h)
x = x + g3.unsqueeze(1) * h
return x
class CoreBlock(nn.Module):
"""The shared-weight core block, iterated r times.
Contains multiple CoreLayers to give sufficient per-iteration capacity.
"""
def __init__(self, config: IRISConfig):
super().__init__()
self.layers = nn.ModuleList([
CoreLayer(config) for _ in range(config.num_core_layers)
])
def forward(self, x: torch.Tensor, c: torch.Tensor, text_tokens: torch.Tensor,
H: int, W: int) -> torch.Tensor:
for layer in self.layers:
x = layer(x, c, text_tokens, H, W)
return x
# ============================================================================
# Coda Block (unique weights, final refinement)
# ============================================================================
class LocalWindowAttention(nn.Module):
"""Window-based local attention for final refinement.
Small window (8×8) for efficient local detail refinement.
"""
def __init__(self, dim: int, num_heads: int, head_dim: int, window_size: int = 8):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.window_size = window_size
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, 3 * num_heads * head_dim, bias=False)
self.o_proj = nn.Linear(num_heads * head_dim, dim)
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
B, N, D = x.shape
ws = self.window_size
# Reshape to 2D and partition into windows
x_2d = x.reshape(B, H, W, D)
# Pad if necessary
pad_h = (ws - H % ws) % ws
pad_w = (ws - W % ws) % ws
if pad_h > 0 or pad_w > 0:
x_2d = F.pad(x_2d, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = x_2d.shape[1], x_2d.shape[2]
nH, nW = Hp // ws, Wp // ws
# [B, nH, ws, nW, ws, D] → [B*nH*nW, ws*ws, D]
x_win = x_2d.reshape(B, nH, ws, nW, ws, D)
x_win = x_win.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws * ws, D)
# QKV and attention within windows
qkv = self.qkv(x_win).reshape(-1, ws * ws, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(2)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
attn = F.softmax(attn, dim=-1)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).reshape(-1, ws * ws, self.num_heads * self.head_dim)
out = self.o_proj(out)
# Unpartition
out = out.reshape(B, nH, nW, ws, ws, D)
out = out.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, D)
# Remove padding
out = out[:, :H, :W, :].reshape(B, N, D)
return out
class CodaBlock(nn.Module):
"""Final refinement block with local window attention."""
def __init__(self, config: IRISConfig):
super().__init__()
D = config.hidden_dim
self.norm1 = nn.LayerNorm(D)
self.attn = LocalWindowAttention(D, config.num_heads, config.head_dim, window_size=8)
self.norm2 = nn.LayerNorm(D)
self.ffn = SwiGLUFFN(D, config.ffn_ratio)
def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
x = x + self.attn(self.norm1(x), H, W)
x = x + self.ffn(self.norm2(x))
return x
# ============================================================================
# IRIS Generator (Main Model)
# ============================================================================
class IRISGenerator(nn.Module):
"""
IRIS: Iterative Recurrent Image Synthesis
The main denoising network with Prelude-Core-Coda structure.
Predicts velocity field v for rectified flow training.
"""
def __init__(self, config: IRISConfig):
super().__init__()
self.config = config
D = config.hidden_dim
# Patch embedding: latent patches → tokens
self.patch_embed = nn.Linear(config.patch_dim, D)
# Positional embedding (learned)
self.pos_embed = nn.Parameter(torch.randn(1, config.num_patches, D) * 0.02)
# Conditioning
self.time_embed = TimestepEmbedding(D)
self.iter_embed = IterationEmbedding(config.max_iterations, D)
# Global text pooling for conditioning
self.text_pool_proj = nn.Sequential(
nn.Linear(config.text_dim, D),
nn.SiLU(),
nn.Linear(D, D),
)
# Prelude (unique weights)
self.prelude = nn.ModuleList([PreludeBlock(D) for _ in range(config.num_prelude_blocks)])
# Core (shared weights, iterated)
self.core = CoreBlock(config)
# Long skip connection (from Diffusion-RWKV: linear(cat(shallow, deep)))
self.skip_proj = nn.Linear(2 * D, D)
# Coda (unique weights)
self.coda = nn.ModuleList([CodaBlock(config) for _ in range(config.num_coda_blocks)])
# Output projection: tokens → latent patches
self.final_norm = nn.LayerNorm(D)
self.output_proj = nn.Linear(D, config.patch_dim)
# Zero-init output for stable training start
nn.init.zeros_(self.output_proj.weight)
nn.init.zeros_(self.output_proj.bias)
# Precompute patch spatial dimensions
self.patch_h = config.latent_spatial // config.patch_size
self.patch_w = config.latent_spatial // config.patch_size
def patchify(self, z: torch.Tensor) -> torch.Tensor:
"""Convert latent z [B, C, H, W] → patches [B, N, patch_dim]."""
B, C, H, W = z.shape
p = self.config.patch_size
z = z.reshape(B, C, H // p, p, W // p, p)
z = z.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * p * p)
return z
def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
"""Convert patches [B, N, patch_dim] → latent [B, C, H, W]."""
B, N, _ = x.shape
p = self.config.patch_size
C = self.config.latent_channels
H = self.patch_h
W = self.patch_w
x = x.reshape(B, H, W, C, p, p)
x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H * p, W * p)
return x
def forward(
self,
z_t: torch.Tensor, # Noisy latent [B, C, H, W]
t: torch.Tensor, # Timestep [B] in [0, 1]
text_tokens: torch.Tensor, # CLIP text embeddings [B, S, text_dim]
num_iterations: Optional[int] = None, # Override iteration count
) -> torch.Tensor:
"""Predict velocity field v(z_t, t, c) for rectified flow."""
B = z_t.shape[0]
        r = num_iterations if num_iterations is not None else self.config.default_iterations
H, W = self.patch_h, self.patch_w
# Patchify and embed
x = self.patch_embed(self.patchify(z_t)) + self.pos_embed
# Timestep conditioning
t_emb = self.time_embed(t * self.config.num_timesteps) # [B, D]
        # Text tokens are consumed directly by cross-attention (which projects
        # from text_dim internally), so no separate projection is needed here.
# Global text pool for adaLN conditioning
text_global = self.text_pool_proj(text_tokens.mean(dim=1)) # [B, D]
# ============ PRELUDE ============
for block in self.prelude:
x = block(x)
# Save for long skip connection
x_shallow = x
# ============ CORE (iterated r times) ============
for i in range(r):
# Iteration-aware conditioning
iter_idx = torch.full((B,), i, device=z_t.device, dtype=torch.long)
i_emb = self.iter_embed(iter_idx) # [B, D]
# Combined conditioning: timestep + iteration + text global
c = t_emb + i_emb + text_global # [B, D]
# Apply shared core block (pass original text_tokens for cross-attention)
x = self.core(x, c, text_tokens, H, W)
# Long skip connection (from Diffusion-RWKV paper)
x = self.skip_proj(torch.cat([x_shallow, x], dim=-1))
# ============ CODA ============
for block in self.coda:
x = block(x, H, W)
# Output projection
x = self.final_norm(x)
x = self.output_proj(x)
# Unpatchify to latent shape
v_pred = self.unpatchify(x)
return v_pred
# ============================================================================
# Full IRIS System
# ============================================================================
class IRIS(nn.Module):
"""Complete IRIS system: VAE + Generator.
For training: use train_step() which handles noise scheduling.
For inference: use generate() which runs the full pipeline.
"""
def __init__(self, config: IRISConfig):
super().__init__()
self.config = config
self.vae = WaveletVAE(config)
self.generator = IRISGenerator(config)
def encode(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Encode images to latent space."""
return self.vae.encode(images)
def decode(self, z: torch.Tensor) -> torch.Tensor:
"""Decode latent to images."""
return self.vae.decode(z)
def get_velocity_target(self, z_0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
"""Rectified flow velocity target: v = noise - z_0."""
return noise - z_0
def add_noise(self, z_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""Rectified flow forward process: z_t = (1-t)*z_0 + t*noise."""
t_expand = t[:, None, None, None]
return (1 - t_expand) * z_0 + t_expand * noise
def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
"""Sample timesteps from logit-normal distribution (from SD3/RF).
Concentrates sampling on intermediate timesteps where learning is hardest.
"""
u = torch.randn(batch_size, device=device)
t = torch.sigmoid(u) # Logit-normal with mean=0, std=1
# Clamp to avoid t=0 and t=1
t = t.clamp(1e-5, 1 - 1e-5)
return t
def train_step(
self,
images: torch.Tensor,
text_tokens: torch.Tensor,
num_iterations: Optional[int] = None,
) -> dict:
"""Single training step for rectified flow.
Returns dict with loss and diagnostics.
"""
B = images.shape[0]
device = images.device
# Encode to latent
z_0, mean, logvar = self.encode(images)
        # Sample noise and timesteps
        noise = torch.randn_like(z_0)
        t = self.sample_timesteps(B, device)
        # Classifier-free guidance training: randomly drop text conditioning so
        # the all-zero "null" prompt used by generate() is in-distribution.
        # (A 10% drop rate is a conventional choice, not tuned here.)
        drop = (torch.rand(B, device=device) < 0.1)[:, None, None]
        text_tokens = torch.where(drop, torch.zeros_like(text_tokens), text_tokens)
        # Create noisy latent
        z_t = self.add_noise(z_0, noise, t)
# Predict velocity
# Randomly sample iteration count for training robustness
if num_iterations is None:
r_choices = [4, 6, 8, 10, 12]
r = r_choices[torch.randint(0, len(r_choices), (1,)).item()]
else:
r = num_iterations
v_pred = self.generator(z_t, t, text_tokens, num_iterations=r)
v_target = self.get_velocity_target(z_0, noise)
# SNR-weighted loss (from Rectified Flow paper)
# w(t) = t / (1 - t) — emphasizes high-noise timesteps
w = t / (1 - t + 1e-8)
w = w[:, None, None, None]
# Velocity matching loss
velocity_loss = (w * (v_pred - v_target).pow(2)).mean()
# VAE KL loss
kl_loss = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp()).mean()
return {
'loss': velocity_loss + 0.001 * kl_loss,
'velocity_loss': velocity_loss.item(),
'kl_loss': kl_loss.item(),
'mean_t': t.mean().item(),
}
@torch.no_grad()
def generate(
self,
text_tokens: torch.Tensor,
num_steps: int = 4,
num_iterations: int = 8,
cfg_scale: float = 4.0,
seed: Optional[int] = None,
) -> torch.Tensor:
"""Generate images from text conditioning using Euler solver.
Args:
text_tokens: [B, S, text_dim] CLIP text embeddings
num_steps: Number of ODE solver steps (1-50)
num_iterations: Core iterations per step (quality budget)
cfg_scale: Classifier-free guidance scale
seed: Random seed for reproducibility
"""
B, S, _ = text_tokens.shape
device = text_tokens.device
if seed is not None:
torch.manual_seed(seed)
# Start from pure noise
z = torch.randn(B, self.config.latent_channels,
self.config.latent_spatial, self.config.latent_spatial,
device=device)
        # Euler solver for the rectified flow ODE dz/dt = v(z, t).
        # Integrating from t=1 (noise) down to t=0 (data), each step is z ← z - dt·v.
dt = 1.0 / num_steps
for step in range(num_steps):
t_val = 1.0 - step * dt
t = torch.full((B,), t_val, device=device)
# Predict velocity
v = self.generator(z, t, text_tokens, num_iterations=num_iterations)
# Classifier-free guidance (if cfg_scale > 1)
if cfg_scale > 1.0:
null_tokens = torch.zeros_like(text_tokens)
v_uncond = self.generator(z, t, null_tokens, num_iterations=num_iterations)
v = v_uncond + cfg_scale * (v - v_uncond)
# Euler step: z = z - dt * v
z = z - dt * v
# Decode to image
images = self.decode(z)
images = images.clamp(-1, 1)
return images
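# Illustrative training-loop sketch (the dataloader and CLIP text encoder are
# stand-ins, not provided by this file):
#
#   model = create_iris_small()
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   for images, text_tokens in dataloader:  # images scaled to [-1, 1]
#       out = model.train_step(images, text_tokens)
#       optimizer.zero_grad()
#       out['loss'].backward()
#       optimizer.step()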
# ============================================================================
# Utility Functions
# ============================================================================
def count_parameters(model: nn.Module) -> dict:
"""Count parameters in each component."""
counts = {}
total = 0
for name, module in model.named_children():
n = sum(p.numel() for p in module.parameters())
counts[name] = n
total += n
counts['total'] = total
# Separate trainable vs frozen
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
counts['trainable'] = trainable
return counts
def estimate_memory_mb(model: nn.Module, dtype=torch.float16) -> float:
    """Estimate parameter memory in MB for the given floating-point dtype."""
    bytes_per_param = torch.finfo(dtype).bits // 8  # handles fp16/bf16/fp32
    total_params = sum(p.numel() for p in model.parameters())
    return total_params * bytes_per_param / (1024 * 1024)
def create_iris_small(latent_spatial: int = 32) -> IRIS:
"""Create IRIS-Small: ~75M generator params, suitable for mobile."""
config = IRISConfig(
latent_channels=16,
latent_spatial=latent_spatial,
hidden_dim=512,
num_heads=8,
head_dim=64,
ffn_ratio=2.667,
num_prelude_blocks=2,
num_core_layers=4,
num_coda_blocks=2,
default_iterations=8,
max_iterations=16,
fourier_num_blocks=8,
sparsity_threshold=0.01,
recurrence_dim=256,
manhattan_window=16,
text_dim=768,
max_text_tokens=77,
patch_size=2,
)
return IRIS(config)
def create_iris_tiny(latent_spatial: int = 32) -> IRIS:
"""Create IRIS-Tiny: ~30M generator params, ultra-mobile."""
config = IRISConfig(
latent_channels=8,
latent_spatial=latent_spatial,
hidden_dim=384,
num_heads=6,
head_dim=64,
ffn_ratio=2.667,
num_prelude_blocks=1,
num_core_layers=3,
num_coda_blocks=1,
default_iterations=8,
max_iterations=16,
fourier_num_blocks=6,
sparsity_threshold=0.01,
recurrence_dim=192,
manhattan_window=12,
text_dim=768,
max_text_tokens=77,
patch_size=2,
)
return IRIS(config)
def create_iris_base(latent_spatial: int = 32) -> IRIS:
"""Create IRIS-Base: ~150M generator params, quality-focused."""
config = IRISConfig(
latent_channels=16,
latent_spatial=latent_spatial,
hidden_dim=768,
num_heads=12,
head_dim=64,
ffn_ratio=2.667,
num_prelude_blocks=2,
num_core_layers=6,
num_coda_blocks=2,
default_iterations=8,
max_iterations=16,
fourier_num_blocks=12,
sparsity_threshold=0.01,
recurrence_dim=384,
manhattan_window=16,
text_dim=768,
max_text_tokens=77,
patch_size=2,
)
return IRIS(config)
if __name__ == "__main__":
print("=" * 70)
print("IRIS: Iterative Recurrent Image Synthesis")
print("=" * 70)
# Create model variants
for name, create_fn in [("IRIS-Tiny", create_iris_tiny),
("IRIS-Small", create_iris_small),
("IRIS-Base", create_iris_base)]:
print(f"\n{'─' * 50}")
print(f" {name}")
print(f"{'─' * 50}")
model = create_fn()
counts = count_parameters(model)
mem_fp16 = estimate_memory_mb(model, torch.float16)
mem_fp32 = estimate_memory_mb(model, torch.float32)
print(f" Total params: {counts['total']:>12,}")
print(f" Trainable params: {counts['trainable']:>12,}")
print(f" Memory (fp16): {mem_fp16:>10.1f} MB")
print(f" Memory (fp32): {mem_fp32:>10.1f} MB")
print(f" Components:")
for k, v in counts.items():
if k not in ('total', 'trainable'):
print(f" {k:20s}: {v:>10,} ({v/counts['total']*100:.1f}%)")