| """ |
| IRIS: Iterative Recurrent Image Synthesis |
| ========================================== |
| A novel architecture for mobile-first high-quality image generation. |
| |
| Key innovations: |
| 1. Wavelet-Frequency Latent Space (Haar DWT + lightweight VAE) |
| 2. Recurrent Depth Core (Prelude-Core-Coda with shared weights) |
| 3. Gated Recurrent Fourier Mixer (GRFM) — novel token mixing |
| 4. Manhattan Spatial Decay — learned 2D inductive bias |
5. Rectified Flow training (velocity prediction; distillation not included here)
| 6. Adaptive compute budget (4-16 iterations, same model) |
| |
| Author: IRIS Research |
| License: Apache 2.0 |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Tuple |
| from dataclasses import dataclass, field |
|
|
|
|
# =============================================================================
# Configuration
# =============================================================================
|
|
@dataclass
class IRISConfig:
    """Configuration for IRIS model."""

    # Latent space (must match the VAE used for encoding/decoding)
    latent_channels: int = 16
    latent_spatial: int = 32

    # Transformer width
    hidden_dim: int = 512
    num_heads: int = 8
    head_dim: int = 64
    ffn_ratio: float = 2.667  # ~8/3, the usual SwiGLU convention

    # Prelude-Core-Coda depth and recurrence budget
    num_prelude_blocks: int = 2
    num_core_layers: int = 4
    num_coda_blocks: int = 2
    default_iterations: int = 8
    max_iterations: int = 16

    # GRFM pathway sizes
    fourier_num_blocks: int = 8
    sparsity_threshold: float = 0.01
    recurrence_dim: int = 256
    manhattan_window: int = 16

    # Text conditioning (CLIP ViT-L/14 style: 77 tokens, 768-d)
    text_dim: int = 768
    max_text_tokens: int = 77

    # Patchification of the latent
    patch_size: int = 2

    # Discretization used to rescale t for the timestep embedding
    num_timesteps: int = 1000

    # Built-in Wavelet VAE channel schedule
    vae_channels: list = field(default_factory=lambda: [32, 64, 128, 256])

    dropout: float = 0.0

    @property
    def vae_latent_channels(self) -> int:
        """VAE latent channels must match generator latent channels."""
        return self.latent_channels

    @property
    def num_patches(self) -> int:
        return (self.latent_spatial // self.patch_size) ** 2

    @property
    def patch_dim(self) -> int:
        return self.latent_channels * self.patch_size * self.patch_size
|
|
|
|
# =============================================================================
# Wavelet Transforms (Haar DWT / IDWT)
# =============================================================================
|
|
| class HaarDWT2D(nn.Module): |
| """2D Discrete Wavelet Transform using Haar wavelets. |
| Decomposes x ∈ R^{B,C,H,W} into R^{B,4C,H/2,W/2} (LL, LH, HL, HH subbands). |
| """ |
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Combine each 2x2 block (even/odd rows x even/odd cols). The 0.5
        # factor makes the transform orthonormal, so HaarIDWT2D inverts it
        # exactly.
        x_ll = (x[:, :, 0::2, 0::2] + x[:, :, 0::2, 1::2] +
                x[:, :, 1::2, 0::2] + x[:, :, 1::2, 1::2]) * 0.5
        x_lh = (x[:, :, 0::2, 0::2] + x[:, :, 0::2, 1::2] -
                x[:, :, 1::2, 0::2] - x[:, :, 1::2, 1::2]) * 0.5
        x_hl = (x[:, :, 0::2, 0::2] - x[:, :, 0::2, 1::2] +
                x[:, :, 1::2, 0::2] - x[:, :, 1::2, 1::2]) * 0.5
        x_hh = (x[:, :, 0::2, 0::2] - x[:, :, 0::2, 1::2] -
                x[:, :, 1::2, 0::2] + x[:, :, 1::2, 1::2]) * 0.5
        return torch.cat([x_ll, x_lh, x_hl, x_hh], dim=1)
|
|
|
|
class HaarIDWT2D(nn.Module):
    """2D Inverse Discrete Wavelet Transform (Haar).
    Reconstructs x ∈ R^{B,C,H,W} from stacked subbands in R^{B,4C,H/2,W/2}.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C4, Hh, Wh = x.shape
        C = C4 // 4
        # Split the stacked subbands back out
        ll, lh, hl, hh = x[:, :C], x[:, C:2*C], x[:, 2*C:3*C], x[:, 3*C:]

        # Invert the orthonormal Haar butterfly pixel by pixel
        H, W = Hh * 2, Wh * 2
        out = torch.zeros(B, C, H, W, device=x.device, dtype=x.dtype)
        out[:, :, 0::2, 0::2] = (ll + lh + hl + hh) * 0.5
        out[:, :, 0::2, 1::2] = (ll + lh - hl - hh) * 0.5
        out[:, :, 1::2, 0::2] = (ll - lh + hl - hh) * 0.5
        out[:, :, 1::2, 1::2] = (ll - lh - hl + hh) * 0.5
        return out
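
# Round-trip sanity check (sketch): for even H and W the Haar pair above is an
# exact inverse up to float rounding.
#
#   x = torch.randn(2, 3, 64, 64)
#   assert torch.allclose(HaarIDWT2D()(HaarDWT2D()(x)), x, atol=1e-5)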
|
|
|
|
# =============================================================================
# Wavelet VAE (Mobile-Friendly Building Blocks)
# =============================================================================
|
|
| class DepthwiseSeparableConv(nn.Module): |
| """Depthwise separable convolution — key mobile optimization.""" |
| def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1): |
| super().__init__() |
| self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size, stride, padding, groups=in_ch) |
| self.pointwise = nn.Conv2d(in_ch, out_ch, 1) |
| |
| def forward(self, x): |
| return self.pointwise(self.depthwise(x)) |
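
# Cost (illustrative): for in_ch = out_ch = 128 and k = 3, a standard conv has
# 3*3*128*128 ≈ 147k weights; the depthwise + pointwise pair has
# 3*3*128 + 128*128 ≈ 17.5k, roughly an 8.4x reduction.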
|
|
|
|
| class ResBlock(nn.Module): |
| """Residual block with depthwise separable convolutions.""" |
| def __init__(self, channels): |
| super().__init__() |
| self.norm1 = nn.GroupNorm(8, channels) |
| self.conv1 = DepthwiseSeparableConv(channels, channels) |
| self.norm2 = nn.GroupNorm(8, channels) |
| self.conv2 = DepthwiseSeparableConv(channels, channels) |
| |
        # Zero-init the last projection so each block starts as identity
        nn.init.zeros_(self.conv2.pointwise.weight)
        nn.init.zeros_(self.conv2.pointwise.bias)
| |
| def forward(self, x): |
| h = F.silu(self.norm1(x)) |
| h = self.conv1(h) |
| h = F.silu(self.norm2(h)) |
| h = self.conv2(h) |
| return x + h |
|
|
|
|
| class WaveletVAEEncoder(nn.Module): |
| """Lightweight encoder: Haar DWT preprocessing + small convolutional encoder. |
| Input: images R^{B,3,H,W} → Output: latent R^{B,C_latent,H/16,W/16} |
| Compression: 3×H×W → C_latent×(H/16)×(W/16) |
| """ |
| def __init__(self, config: IRISConfig): |
| super().__init__() |
| self.dwt = HaarDWT2D() |
| channels = config.vae_channels |
| latent_ch = config.vae_latent_channels |
| |
| |
        # Haar DWT turns 3 input channels into 12 (4 subbands x 3 channels)
        self.conv_in = nn.Conv2d(12, channels[0], 3, 1, 1)
| |
| |
| self.down_blocks = nn.ModuleList() |
| for i in range(len(channels) - 1): |
| self.down_blocks.append(nn.Sequential( |
| ResBlock(channels[i]), |
| nn.Conv2d(channels[i], channels[i+1], 3, 2, 1), |
| )) |
| |
| |
| self.mid = nn.Sequential( |
| ResBlock(channels[-1]), |
| ResBlock(channels[-1]), |
| ) |
| |
| |
        # Output both mean and logvar of the Gaussian posterior
        self.norm_out = nn.GroupNorm(8, channels[-1])
        self.conv_out = nn.Conv2d(channels[-1], 2 * latent_ch, 1)
| |
| def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
        # Move to wavelet space first (halves spatial size, quadruples channels)
        x = self.dwt(x)
| x = self.conv_in(x) |
| |
| for down in self.down_blocks: |
| x = down(x) |
| |
| x = self.mid(x) |
| x = F.silu(self.norm_out(x)) |
| x = self.conv_out(x) |
| |
| mean, logvar = x.chunk(2, dim=1) |
| logvar = torch.clamp(logvar, -30.0, 20.0) |
| return mean, logvar |
| |
    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mean, logvar = self.forward(x)
        std = torch.exp(0.5 * logvar)
        # Reparameterization trick
        z = mean + std * torch.randn_like(std)
        return z, mean, logvar
|
|
|
|
| class WaveletVAEDecoder(nn.Module): |
| """Tiny decoder: latent → wavelet coefficients → Haar IDWT → image. |
| Designed to be as small as possible for mobile inference. |
| """ |
| def __init__(self, config: IRISConfig): |
| super().__init__() |
| channels = list(reversed(config.vae_channels)) |
| latent_ch = config.vae_latent_channels |
| self.idwt = HaarIDWT2D() |
| |
| |
| self.conv_in = nn.Conv2d(latent_ch, channels[0], 3, 1, 1) |
| |
| |
| self.mid = nn.Sequential( |
| ResBlock(channels[0]), |
| ) |
| |
| |
| self.up_blocks = nn.ModuleList() |
| for i in range(len(channels) - 1): |
| self.up_blocks.append(nn.Sequential( |
| nn.Upsample(scale_factor=2, mode='nearest'), |
| DepthwiseSeparableConv(channels[i], channels[i+1]), |
| nn.SiLU(), |
| ResBlock(channels[i+1]), |
| )) |
| |
| |
        # Predict 12 wavelet coefficients (4 subbands x 3 channels) for the IDWT
        self.norm_out = nn.GroupNorm(8, channels[-1])
        self.conv_out = nn.Conv2d(channels[-1], 12, 3, 1, 1)
| |
| def forward(self, z: torch.Tensor) -> torch.Tensor: |
| x = self.conv_in(z) |
| x = self.mid(x) |
| |
| for up in self.up_blocks: |
| x = up(x) |
| |
| x = F.silu(self.norm_out(x)) |
| x = self.conv_out(x) |
| |
| |
        # Final 2x upsample comes for free from the inverse wavelet transform
        x = self.idwt(x)
| return x |
|
|
|
|
| class WaveletVAE(nn.Module): |
| """Complete Wavelet VAE with DWT preprocessing.""" |
| def __init__(self, config: IRISConfig): |
| super().__init__() |
| self.encoder = WaveletVAEEncoder(config) |
| self.decoder = WaveletVAEDecoder(config) |
| self.config = config |
| |
| def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| return self.encoder.encode(x) |
| |
| def decode(self, z: torch.Tensor) -> torch.Tensor: |
| return self.decoder(z) |
| |
| def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| z, mean, logvar = self.encode(x) |
| x_recon = self.decode(z) |
| return x_recon, mean, logvar |
|
|
|
|
# =============================================================================
# Conditioning Embeddings
# =============================================================================
|
|
| class TimestepEmbedding(nn.Module): |
| """Sinusoidal timestep embedding with MLP projection.""" |
| def __init__(self, dim: int, max_period: int = 10000): |
| super().__init__() |
| self.dim = dim |
| self.max_period = max_period |
| self.mlp = nn.Sequential( |
| nn.Linear(dim, 4 * dim), |
| nn.SiLU(), |
| nn.Linear(4 * dim, dim), |
| ) |
| |
| def forward(self, t: torch.Tensor) -> torch.Tensor: |
| half = self.dim // 2 |
| freqs = torch.exp( |
| -math.log(self.max_period) * torch.arange(half, device=t.device, dtype=t.dtype) / half |
| ) |
| args = t[:, None] * freqs[None, :] |
| embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) |
| return self.mlp(embedding) |
|
|
|
|
| class IterationEmbedding(nn.Module): |
| """Learnable embedding for iteration index within recurrent core.""" |
| def __init__(self, max_iterations: int, dim: int): |
| super().__init__() |
| self.embedding = nn.Embedding(max_iterations, dim) |
| |
| def forward(self, i: torch.Tensor) -> torch.Tensor: |
| return self.embedding(i) |
|
|
|
|
| class AdaLNSingle(nn.Module): |
| """Adaptive Layer Normalization (single shared MLP, per-layer bias). |
| From PixArt-α: saves 27% params vs standard adaLN. |
| |
| Produces (scale, shift, gate) for each sub-layer from a shared condition vector. |
| """ |
| def __init__(self, dim: int, num_modulations: int = 6): |
| super().__init__() |
| self.silu = nn.SiLU() |
| self.linear = nn.Linear(dim, num_modulations * dim) |
| self.num_modulations = num_modulations |
| nn.init.zeros_(self.linear.weight) |
| nn.init.zeros_(self.linear.bias) |
| |
| def forward(self, c: torch.Tensor) -> Tuple[torch.Tensor, ...]: |
| """c: [B, D] condition vector → tuple of num_modulations tensors [B, D].""" |
| params = self.linear(self.silu(c)) |
| return params.chunk(self.num_modulations, dim=-1) |
|
|
|
|
# =============================================================================
# GRFM Token-Mixing Pathways
# =============================================================================
|
|
| class FourierMixingPathway(nn.Module): |
| """Pathway 1: Adaptive Fourier Neural Operator-style global mixing. |
| O(N log N) complexity via FFT. Block-diagonal MLP in frequency domain. |
| """ |
| def __init__(self, dim: int, num_blocks: int = 8, sparsity_threshold: float = 0.01): |
| super().__init__() |
| self.dim = dim |
| self.num_blocks = num_blocks |
| self.block_size = dim // num_blocks |
| self.sparsity_threshold = sparsity_threshold |
| |
        # Block-diagonal complex weights for a two-layer MLP applied pointwise
        # in the frequency domain (AFNO-style): each of `num_blocks` groups of
        # `block_size` channels is mixed independently, keeping the layer cheap.
        self.w1_real = nn.Parameter(torch.randn(num_blocks, self.block_size, self.block_size) * 0.02)
        self.w1_imag = nn.Parameter(torch.randn(num_blocks, self.block_size, self.block_size) * 0.02)
        self.w2_real = nn.Parameter(torch.randn(num_blocks, self.block_size, self.block_size) * 0.02)
        self.w2_imag = nn.Parameter(torch.randn(num_blocks, self.block_size, self.block_size) * 0.02)
        self.b1 = nn.Parameter(torch.zeros(num_blocks, self.block_size))
        self.b2 = nn.Parameter(torch.zeros(num_blocks, self.block_size))
| |
    def complex_matmul(self, x: torch.Tensor, w_real: torch.Tensor, w_imag: torch.Tensor) -> torch.Tensor:
        """Complex matrix multiplication: (a+bi)(c+di) = (ac-bd) + (ad+bc)i.
        x: [..., num_blocks, block_size] (complex)
        w_real, w_imag: [num_blocks, block_size, block_size] (real/imag weight parts)
        """
        # Contract over the input-channel index i within each block k
        out_real = torch.einsum('...ki,kij->...kj', x.real, w_real) - torch.einsum('...ki,kij->...kj', x.imag, w_imag)
        out_imag = torch.einsum('...ki,kij->...kj', x.real, w_imag) + torch.einsum('...ki,kij->...kj', x.imag, w_real)
        return torch.complex(out_real, out_imag)
| |
    @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """Forward pass, forced to fp32: half-precision complex FFT support in
        PyTorch is incomplete and slow. (torch.amp.custom_fwd with device_type
        requires torch >= 2.4; older versions expose torch.cuda.amp.custom_fwd.)
        """
        B, N, D = x.shape
        x_2d = x.reshape(B, H, W, D)

        # 2D real FFT over the token grid
        x_freq = torch.fft.rfft2(x_2d, dim=(1, 2), norm='ortho')

        # Split channels into blocks for block-diagonal mixing
        Hf, Wf = x_freq.shape[1], x_freq.shape[2]
        x_freq = x_freq.reshape(B, Hf, Wf, self.num_blocks, self.block_size)

        # Layer 1: complex linear + split-ReLU on real/imag parts
        x_freq = self.complex_matmul(x_freq, self.w1_real, self.w1_imag)
        x_freq = x_freq + self.b1
        x_freq = torch.complex(F.relu(x_freq.real), F.relu(x_freq.imag))

        # Layer 2: complex linear
        x_freq = self.complex_matmul(x_freq, self.w2_real, self.w2_imag)
        x_freq = x_freq + self.b2

        # Merge blocks back into the channel dimension
        x_freq = x_freq.reshape(B, Hf, Wf, D)

        # Soft shrinkage of spectral magnitudes: a sparsity prior that damps
        # low-energy frequency components
        magnitude = x_freq.abs()
        shrunk_mag = F.relu(magnitude - self.sparsity_threshold)
        x_freq = x_freq * (shrunk_mag / (magnitude + 1e-8))

        # Inverse FFT back to spatial tokens
        x_out = torch.fft.irfft2(x_freq, s=(H, W), dim=(1, 2), norm='ortho')
        return x_out.reshape(B, N, D)
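
# Shape walkthrough (sketch), assuming hidden_dim=512, num_blocks=8 and a
# 16x16 token grid: x [B, 256, 512] → rfft2 → [B, 16, 9, 512] complex (since
# W//2 + 1 = 9) → block view [B, 16, 9, 8, 64] → two complex block-diagonal
# layers → soft shrinkage → irfft2 → [B, 256, 512].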
|
|
|
|
| class GatedLinearRecurrence(nn.Module): |
| """Pathway 2: Bidirectional Gated Linear Recurrence (RG-LRU inspired). |
| O(N) complexity with O(1) state per position. |
| |
| h_t = a_t * h_{t-1} + sqrt(1 - a_t^2) * (i_t * x_t) |
| where a_t = sigmoid(Λ)^(c * sigmoid(W_a * x_t)) |
| """ |
| def __init__(self, dim: int, recurrence_dim: int): |
| super().__init__() |
| self.dim = dim |
| self.rec_dim = recurrence_dim |
| |
        # Project tokens into two directional streams (forward / backward)
        self.proj_in = nn.Linear(dim, recurrence_dim * 2)

        # RG-LRU-style gates: recurrence gate (W_a), input gate (W_x), and a
        # learned base decay Λ (sigmoid(Λ) ≈ 0.88 at init)
        self.W_a = nn.Linear(recurrence_dim, recurrence_dim, bias=False)
        self.W_x = nn.Linear(recurrence_dim, recurrence_dim, bias=False)
        self.Lambda = nn.Parameter(torch.randn(recurrence_dim) * 0.5 + 2.0)
        self.c = 8.0

        # Merge the two directions back to model width
        self.proj_out = nn.Linear(recurrence_dim * 2, dim)
| |
| @staticmethod |
| @torch.jit.script |
| def _scan_kernel(a: torch.Tensor, u: torch.Tensor) -> torch.Tensor: |
| """JIT-compiled sequential scan — avoids Python loop overhead on GPU.""" |
| B, N, D = a.shape |
| h = torch.zeros(B, D, device=a.device, dtype=a.dtype) |
| outputs = torch.empty_like(u) |
| for t in range(N): |
| h = a[:, t] * h + u[:, t] |
| outputs[:, t] = h |
| return outputs |
| |
    def _scan(self, x: torch.Tensor) -> torch.Tensor:
        """Gated linear recurrence scan. x: [B, N, rec_dim]"""
        B, N, D = x.shape

        # Per-position decay a_t = sigmoid(Λ)^(c * r_t) and input gate i_t
        a_base = torch.sigmoid(self.Lambda)
        r = torch.sigmoid(self.W_a(x))
        i = torch.sigmoid(self.W_x(x))

        a = a_base.pow(self.c * r)
        # sqrt(1 - a²) scaling keeps the hidden state at constant variance
        input_scale = torch.sqrt(1.0 - a * a + 1e-8)
        u = input_scale * (i * x)

        return self._scan_kernel(a.contiguous(), u.contiguous())
| |
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, D = x.shape

        # Split into forward and backward streams
        x_proj = self.proj_in(x)
        x_fwd, x_bwd = x_proj.chunk(2, dim=-1)

        # Scan each direction (the backward pass scans a flipped sequence)
        h_fwd = self._scan(x_fwd)
        h_bwd = self._scan(x_bwd.flip(1)).flip(1)

        # Concatenate directional states and project back to model width
        h = torch.cat([h_fwd, h_bwd], dim=-1)
        return self.proj_out(h)
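
# Note (sketch): because h_t = a_t * h_{t-1} + u_t is associative, the loop in
# _scan_kernel could in principle be replaced by a parallel prefix scan (e.g. a
# chunked cumulative-product formulation); that is numerically delicate for
# small a_t, so the straightforward TorchScript loop is kept as the default.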
|
|
|
|
class ManhattanSpatialGate(nn.Module):
    """Pathway 3: Manhattan-distance spatial decay gating.
    Provides a learned 2D spatial inductive bias with per-head multi-scale
    receptive fields, using a cached distance matrix and a windowed (masked)
    aggregation for efficiency.
    """
    def __init__(self, dim: int, num_heads: int, window: int = 16):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window = window

        # Per-head decay rates γ = sigmoid(gamma_logit), spanning roughly
        # 0.70-0.95 at init so different heads see different spatial scales
        self.gamma_logit = nn.Parameter(torch.linspace(0.85, 2.94, num_heads))

        # Value, gate, and output projections
        self.v_proj = nn.Linear(dim, dim)
        self.g_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)

        # Distance-matrix cache (keyed by grid shape and device)
        self._cached_dist = None
        self._cached_shape = None
| |
| def _get_manhattan_mask(self, H: int, W: int, device: torch.device) -> torch.Tensor: |
| """Compute Manhattan distance matrix — CACHED after first call.""" |
| shape_key = (H, W, device) |
| if self._cached_dist is not None and self._cached_shape == shape_key: |
| return self._cached_dist |
| |
| |
| rows = torch.arange(H, device=device) |
| cols = torch.arange(W, device=device) |
| grid_r, grid_c = torch.meshgrid(rows, cols, indexing='ij') |
| coords = torch.stack([grid_r.reshape(-1), grid_c.reshape(-1)], dim=-1).float() |
| |
| |
| dist = (coords[:, None, :] - coords[None, :, :]).abs().sum(dim=-1) |
| |
| self._cached_dist = dist |
| self._cached_shape = shape_key |
| return dist |
| |
    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        B, N, D = x.shape
        input_dtype = x.dtype

        # Per-head decay rates and pairwise Manhattan distances
        gamma = torch.sigmoid(self.gamma_logit).float()
        manhattan_dist = self._get_manhattan_mask(H, W, x.device)

        # Zero out contributions beyond the window
        decay_mask = (manhattan_dist <= self.window)

        # decay[h, i, j] = γ_h ** dist(i, j), masked to the window
        decay = gamma[:, None, None].pow(manhattan_dist[None, :, :])
        decay = decay * decay_mask.unsqueeze(0).float()

        # Values and a sigmoid output gate
        v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim)
        g = torch.sigmoid(self.g_proj(x))

        # Aggregate values with distance-decayed weights (dense masked matmul)
        v = v.permute(0, 2, 1, 3)
        decay_cast = decay.unsqueeze(0).to(input_dtype)
        out = torch.matmul(decay_cast, v)

        # Normalize so each position averages over its window
        decay_sum = decay_cast.sum(dim=-1, keepdim=True) + 1e-8
        out = out / decay_sum

        out = out.permute(0, 2, 1, 3).reshape(B, N, D)
        out = out * g
        return self.o_proj(out)
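
# Illustrative decay (sketch): with γ = 0.9 and window 16, a token weighs its
# 4-neighbours by 0.9, positions two Manhattan steps away by 0.81, and ignores
# anything farther than 16 steps; heads with larger γ average over wider areas.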
|
|
|
|
| class GRFM(nn.Module): |
| """Gated Recurrent Fourier Mixer — the core innovation of IRIS. |
| |
| Fuses three complementary pathways: |
| 1. Fourier Global Mixing (O(N log N)) — captures textures, patterns |
| 2. Gated Linear Recurrence (O(N)) — captures sequential/local dependencies |
| 3. Manhattan Spatial Gate — provides 2D inductive bias |
| |
| Pathways are combined via learned adaptive gating. |
| """ |
| def __init__(self, config: IRISConfig): |
| super().__init__() |
| D = config.hidden_dim |
| |
| self.fourier = FourierMixingPathway(D, config.fourier_num_blocks, config.sparsity_threshold) |
| self.recurrence = GatedLinearRecurrence(D, config.recurrence_dim) |
| self.spatial = ManhattanSpatialGate(D, config.num_heads, config.manhattan_window) |
| |
| |
        # Token-wise gate that blends the Fourier and recurrent pathways
        self.blend_gate = nn.Sequential(
            nn.Linear(D, D),
            nn.SiLU(),
            nn.Linear(D, D),
            nn.Sigmoid(),
        )

        # Small learned scale for the spatial-gate residual
        self.spatial_scale = nn.Parameter(torch.tensor(0.1))
| |
    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        # Run the three pathways on the same input
        x_fourier = self.fourier(x, H, W)
        x_recurrent = self.recurrence(x)
        x_spatial = self.spatial(x, H, W)

        # Adaptive blend of global (Fourier) vs sequential (recurrent) mixing
        gate = self.blend_gate(x)
        output = gate * x_fourier + (1 - gate) * x_recurrent

        # Spatial inductive bias added as a scaled residual
        output = output + self.spatial_scale * x_spatial

        return output
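
# Cost summary (sketch): the Fourier pathway is O(N log N), the recurrence is
# O(N), and the spatial gate as written is a dense N x N matmul; all are cheap
# at the default N = 256 tokens, but the quadratic term is the one to watch at
# higher resolution.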
|
|
|
|
# =============================================================================
# Text Conditioning
# =============================================================================
|
|
| class CrossAttention(nn.Module): |
| """Efficient cross-attention for text conditioning. |
| Only 77 text tokens → O(N × 77 × d) per layer, very cheap. |
| """ |
| def __init__(self, dim: int, text_dim: int, num_heads: int, head_dim: int): |
| super().__init__() |
| self.num_heads = num_heads |
| self.head_dim = head_dim |
| self.scale = head_dim ** -0.5 |
| |
| self.q_proj = nn.Linear(dim, num_heads * head_dim, bias=False) |
| self.k_proj = nn.Linear(text_dim, num_heads * head_dim, bias=False) |
| self.v_proj = nn.Linear(text_dim, num_heads * head_dim, bias=False) |
| self.o_proj = nn.Linear(num_heads * head_dim, dim) |
| |
        # QK normalization (non-affine LayerNorm) keeps attention logits
        # well-scaled, which helps low-precision training
        self.q_norm = nn.LayerNorm(head_dim, elementwise_affine=False)
        self.k_norm = nn.LayerNorm(head_dim, elementwise_affine=False)
| |
| def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor: |
| B, N, _ = x.shape |
| _, S, _ = context.shape |
| |
| q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) |
| k = self.k_proj(context).reshape(B, S, self.num_heads, self.head_dim).transpose(1, 2) |
| v = self.v_proj(context).reshape(B, S, self.num_heads, self.head_dim).transpose(1, 2) |
| |
        # Normalize queries and keys per head
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Full attention over the (short) text sequence
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
| |
| out = out.transpose(1, 2).reshape(B, N, -1) |
| return self.o_proj(out) |
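
# Note: on PyTorch >= 2.0 the explicit softmax attention above can be swapped
# for F.scaled_dot_product_attention(q, k, v), which picks a fused kernel where
# available (its default scale, 1/sqrt(head_dim), matches self.scale here).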
|
|
|
|
# =============================================================================
# Feed-Forward
# =============================================================================
|
|
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network; at matched parameter count, gated-SiLU
    FFNs tend to outperform plain GELU FFNs in transformers."""
    def __init__(self, dim: int, ratio: float = 2.667, dropout: float = 0.0):
        super().__init__()
        hidden = int(dim * ratio)
        # Round hidden width up to a multiple of 64 for GPU-friendly shapes
        hidden = ((hidden + 63) // 64) * 64
| |
| self.w1 = nn.Linear(dim, hidden, bias=False) |
| self.w2 = nn.Linear(dim, hidden, bias=False) |
| self.w3 = nn.Linear(hidden, dim, bias=False) |
| self.dropout = nn.Dropout(dropout) |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return self.w3(self.dropout(F.silu(self.w1(x)) * self.w2(x))) |
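
# Example: dim=512, ratio=2.667 → int(512 * 2.667) = 1365, rounded up to the
# next multiple of 64 → hidden = 1408.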
|
|
|
|
# =============================================================================
# Prelude Blocks
# =============================================================================
|
|
| class PreludeBlock(nn.Module): |
| """Lightweight conv-based block for initial feature extraction.""" |
| def __init__(self, dim: int): |
| super().__init__() |
| self.norm1 = nn.LayerNorm(dim) |
| self.dwconv = nn.Conv1d(dim, dim, kernel_size=5, padding=2, groups=dim) |
| self.pointwise = nn.Linear(dim, dim) |
| self.norm2 = nn.LayerNorm(dim) |
| self.ffn = SwiGLUFFN(dim) |
| |
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Depthwise conv over the token sequence for cheap local mixing
        h = self.norm1(x)
        h = h.transpose(1, 2)
        h = self.dwconv(h).transpose(1, 2)
        h = F.silu(h)
        h = self.pointwise(h)
        x = x + h

        # Position-wise FFN
        x = x + self.ffn(self.norm2(x))
        return x
|
|
|
|
# =============================================================================
# Recurrent Core
# =============================================================================
|
|
| class CoreLayer(nn.Module): |
| """Single layer within the core block. |
| Contains: GRFM + Cross-Attention + FFN, all with adaLN-Zero conditioning. |
| """ |
    def __init__(self, config: IRISConfig):
        super().__init__()
        D = config.hidden_dim

        # Token mixing (GRFM)
        self.norm1 = nn.LayerNorm(D, elementwise_affine=False)
        self.grfm = GRFM(config)

        # Text conditioning
        self.norm2 = nn.LayerNorm(D, elementwise_affine=False)
        self.cross_attn = CrossAttention(D, config.text_dim, config.num_heads, config.head_dim)

        # Channel mixing
        self.norm3 = nn.LayerNorm(D, elementwise_affine=False)
        self.ffn = SwiGLUFFN(D, config.ffn_ratio, config.dropout)

        # 3 sub-layers x (scale, shift, gate) = 9 modulation vectors
        self.adaln = AdaLNSingle(D, num_modulations=9)
| |
| def _modulate(self, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor: |
| return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) |
| |
| def forward(self, x: torch.Tensor, c: torch.Tensor, text_tokens: torch.Tensor, |
| H: int, W: int) -> torch.Tensor: |
| """ |
| x: [B, N, D] — token sequence |
| c: [B, D] — conditioning vector (timestep + iteration) |
| text_tokens: [B, S, text_dim] — CLIP text tokens |
| H, W: spatial dimensions of token grid |
| """ |
        s1, sh1, g1, s2, sh2, g2, s3, sh3, g3 = self.adaln(c)

        # 1) GRFM token mixing
        h = self._modulate(self.norm1(x), s1, sh1)
        h = self.grfm(h, H, W)
        x = x + g1.unsqueeze(1) * h

        # 2) Cross-attention to text tokens
        h = self._modulate(self.norm2(x), s2, sh2)
        h = self.cross_attn(h, text_tokens)
        x = x + g2.unsqueeze(1) * h

        # 3) FFN
        h = self._modulate(self.norm3(x), s3, sh3)
        h = self.ffn(h)
        x = x + g3.unsqueeze(1) * h

        return x
|
|
|
|
| class CoreBlock(nn.Module): |
| """The shared-weight core block, iterated r times. |
| Contains multiple CoreLayers to give sufficient per-iteration capacity. |
| """ |
| def __init__(self, config: IRISConfig): |
| super().__init__() |
| self.layers = nn.ModuleList([ |
| CoreLayer(config) for _ in range(config.num_core_layers) |
| ]) |
| |
| def forward(self, x: torch.Tensor, c: torch.Tensor, text_tokens: torch.Tensor, |
| H: int, W: int) -> torch.Tensor: |
| for layer in self.layers: |
| x = layer(x, c, text_tokens, H, W) |
| return x |
|
|
|
|
# =============================================================================
# Coda: Local Refinement
# =============================================================================
|
|
| class LocalWindowAttention(nn.Module): |
| """Window-based local attention for final refinement. |
| Small window (8×8) for efficient local detail refinement. |
| """ |
| def __init__(self, dim: int, num_heads: int, head_dim: int, window_size: int = 8): |
| super().__init__() |
| self.num_heads = num_heads |
| self.head_dim = head_dim |
| self.window_size = window_size |
| self.scale = head_dim ** -0.5 |
| |
| self.qkv = nn.Linear(dim, 3 * num_heads * head_dim, bias=False) |
| self.o_proj = nn.Linear(num_heads * head_dim, dim) |
| |
    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        B, N, D = x.shape
        ws = self.window_size

        # Back to a 2D grid
        x_2d = x.reshape(B, H, W, D)

        # Pad so H and W divide evenly into windows
        pad_h = (ws - H % ws) % ws
        pad_w = (ws - W % ws) % ws
        if pad_h > 0 or pad_w > 0:
            x_2d = F.pad(x_2d, (0, 0, 0, pad_w, 0, pad_h))

        Hp, Wp = x_2d.shape[1], x_2d.shape[2]
        nH, nW = Hp // ws, Wp // ws

        # Partition into non-overlapping ws x ws windows
        x_win = x_2d.reshape(B, nH, ws, nW, ws, D)
        x_win = x_win.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws * ws, D)

        # Standard attention within each window
        qkv = self.qkv(x_win).reshape(-1, ws * ws, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)

        out = out.transpose(1, 2).reshape(-1, ws * ws, self.num_heads * self.head_dim)
        out = self.o_proj(out)

        # Merge windows back into the padded grid
        out = out.reshape(B, nH, nW, ws, ws, D)
        out = out.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, D)

        # Crop padding and flatten back to a token sequence
        out = out[:, :H, :W, :].reshape(B, N, D)
        return out
|
|
|
|
| class CodaBlock(nn.Module): |
| """Final refinement block with local window attention.""" |
| def __init__(self, config: IRISConfig): |
| super().__init__() |
| D = config.hidden_dim |
| self.norm1 = nn.LayerNorm(D) |
| self.attn = LocalWindowAttention(D, config.num_heads, config.head_dim, window_size=8) |
| self.norm2 = nn.LayerNorm(D) |
| self.ffn = SwiGLUFFN(D, config.ffn_ratio) |
| |
| def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: |
| x = x + self.attn(self.norm1(x), H, W) |
| x = x + self.ffn(self.norm2(x)) |
| return x |
|
|
|
|
# =============================================================================
# Generator
# =============================================================================
|
|
| class IRISGenerator(nn.Module): |
| """ |
| IRIS: Iterative Recurrent Image Synthesis |
| |
| The main denoising network with Prelude-Core-Coda structure. |
| Predicts velocity field v for rectified flow training. |
| """ |
    def __init__(self, config: IRISConfig):
        super().__init__()
        self.config = config
        D = config.hidden_dim

        # Latent patches → tokens
        self.patch_embed = nn.Linear(config.patch_dim, D)

        # Learned absolute positional embedding over the patch grid
        self.pos_embed = nn.Parameter(torch.randn(1, config.num_patches, D) * 0.02)

        # Conditioning embeddings: timestep and recurrence-iteration index
        self.time_embed = TimestepEmbedding(D)
        self.iter_embed = IterationEmbedding(config.max_iterations, D)

        # Pooled text embedding for the adaLN condition vector. Per-token text
        # conditioning happens in the core's cross-attention, which projects
        # the raw text tokens itself, so no separate token projection is kept.
        self.text_pool_proj = nn.Sequential(
            nn.Linear(config.text_dim, D),
            nn.SiLU(),
            nn.Linear(D, D),
        )

        # Prelude: cheap conv-style blocks for initial feature extraction
        self.prelude = nn.ModuleList([PreludeBlock(D) for _ in range(config.num_prelude_blocks)])

        # Core: one weight-shared block, unrolled r times
        self.core = CoreBlock(config)

        # Long skip from the prelude output into the post-core stream
        self.skip_proj = nn.Linear(2 * D, D)

        # Coda: local-attention refinement
        self.coda = nn.ModuleList([CodaBlock(config) for _ in range(config.num_coda_blocks)])

        # Output head, zero-initialized so training starts from v_pred = 0
        self.final_norm = nn.LayerNorm(D)
        self.output_proj = nn.Linear(D, config.patch_dim)
        nn.init.zeros_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

        # Token-grid dimensions
        self.patch_h = config.latent_spatial // config.patch_size
        self.patch_w = config.latent_spatial // config.patch_size
| |
| def patchify(self, z: torch.Tensor) -> torch.Tensor: |
| """Convert latent z [B, C, H, W] → patches [B, N, patch_dim].""" |
| B, C, H, W = z.shape |
| p = self.config.patch_size |
| z = z.reshape(B, C, H // p, p, W // p, p) |
| z = z.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * p * p) |
| return z |
| |
| def unpatchify(self, x: torch.Tensor) -> torch.Tensor: |
| """Convert patches [B, N, patch_dim] → latent [B, C, H, W].""" |
| B, N, _ = x.shape |
| p = self.config.patch_size |
| C = self.config.latent_channels |
| H = self.patch_h |
| W = self.patch_w |
| x = x.reshape(B, H, W, C, p, p) |
| x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H * p, W * p) |
| return x |
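
    # Sanity check (sketch): patchify and unpatchify are exact inverses, e.g.
    #   z = torch.randn(2, C, 32, 32); assert torch.equal(g.unpatchify(g.patchify(z)), z)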
| |
    def forward(
        self,
        z_t: torch.Tensor,
        t: torch.Tensor,
        text_tokens: torch.Tensor,
        num_iterations: Optional[int] = None,
    ) -> torch.Tensor:
        """Predict velocity field v(z_t, t, c) for rectified flow."""
        B = z_t.shape[0]
        r = num_iterations if num_iterations is not None else self.config.default_iterations
        H, W = self.patch_h, self.patch_w

        # Tokenize the noisy latent
        x = self.patch_embed(self.patchify(z_t)) + self.pos_embed

        # Timestep embedding (t in [0, 1] rescaled to the discretization grid)
        t_emb = self.time_embed(t * self.config.num_timesteps)

        # Pooled text embedding for the adaLN condition vector
        text_global = self.text_pool_proj(text_tokens.mean(dim=1))

        # Prelude
        for block in self.prelude:
            x = block(x)

        # Keep the prelude output for the long skip connection
        x_shallow = x

        # Unroll the weight-shared core r times
        for i in range(r):
            # Each iteration gets its own learned iteration embedding
            iter_idx = torch.full((B,), i, device=z_t.device, dtype=torch.long)
            i_emb = self.iter_embed(iter_idx)

            # Condition vector: timestep + iteration + pooled text
            c = t_emb + i_emb + text_global

            x = self.core(x, c, text_tokens, H, W)

        # Long skip connection around the core
        x = self.skip_proj(torch.cat([x_shallow, x], dim=-1))

        # Coda refinement
        for block in self.coda:
            x = block(x, H, W)

        # Project back to latent patches
        x = self.final_norm(x)
        x = self.output_proj(x)

        v_pred = self.unpatchify(x)
        return v_pred
|
|
|
|
# =============================================================================
# Full IRIS System
# =============================================================================
|
|
| class IRIS(nn.Module): |
| """Complete IRIS system: Generator + optional built-in VAE. |
| |
| For training with external VAE (recommended): use train_step_latent() with pre-encoded latents. |
| For training with built-in Wavelet VAE: use train_step() with raw images. |
| For inference: use generate_latent() to get latent, then decode externally. |
| """ |
| def __init__(self, config: IRISConfig, use_builtin_vae: bool = False): |
| super().__init__() |
| self.config = config |
| self.generator = IRISGenerator(config) |
| |
| |
        # The built-in Wavelet VAE is optional; the recommended setup pairs
        # the generator with an external pre-trained VAE via the *_latent methods
        self.vae = WaveletVAE(config) if use_builtin_vae else None
| |
| def encode(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| """Encode images via built-in VAE (only if use_builtin_vae=True).""" |
| assert self.vae is not None, "No built-in VAE. Use an external VAE to encode images." |
| return self.vae.encode(images) |
| |
| def decode(self, z: torch.Tensor) -> torch.Tensor: |
| """Decode latent via built-in VAE (only if use_builtin_vae=True).""" |
| assert self.vae is not None, "No built-in VAE. Use an external VAE to decode latents." |
| return self.vae.decode(z) |
| |
| @staticmethod |
| def get_velocity_target(z_0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: |
| """Rectified flow velocity target: v = noise - z_0.""" |
| return noise - z_0 |
| |
| @staticmethod |
| def add_noise(z_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor: |
| """Rectified flow forward process: z_t = (1-t)*z_0 + t*noise.""" |
| t_expand = t[:, None, None, None] |
| return (1 - t_expand) * z_0 + t_expand * noise |
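
    # Worked check: z_t = (1 - t) * z_0 + t * noise gives dz_t/dt = noise - z_0,
    # exactly the velocity target above, so Euler steps z ← z - dt * v move
    # from noise (t = 1) toward data (t = 0).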
| |
| @staticmethod |
| def sample_timesteps(batch_size: int, device: torch.device) -> torch.Tensor: |
| """Sample timesteps from logit-normal distribution (from SD3/RF).""" |
| u = torch.randn(batch_size, device=device) |
| t = torch.sigmoid(u) |
| t = t.clamp(1e-5, 1 - 1e-5) |
| return t |
| |
| def train_step_latent( |
| self, |
| z_0: torch.Tensor, |
| text_tokens: torch.Tensor, |
| num_iterations: Optional[int] = None, |
| ) -> dict: |
| """Training step on PRE-ENCODED latents (recommended path). |
| |
| Use this with an external pre-trained VAE: |
| z_0 = external_vae.encode(images) # done outside |
| result = iris.train_step_latent(z_0, text_tokens) |
| """ |
        B = z_0.shape[0]
        device = z_0.device

        # Condition dropout (~10%): trains the null text condition needed for
        # the classifier-free guidance used in generate_latent()
        drop = (torch.rand(B, 1, 1, device=device) < 0.1).to(text_tokens.dtype)
        text_tokens = text_tokens * (1.0 - drop)

        noise = torch.randn_like(z_0)
        t = self.sample_timesteps(B, device)
        z_t = self.add_noise(z_0, noise, t)

        if num_iterations is None:
            # Sample the recurrence depth across the documented 4-16 budget so
            # the iteration counts used at inference are actually trained
            r_choices = [4, 6, 8, 12, 16]
            r = r_choices[torch.randint(0, len(r_choices), (1,)).item()]
        else:
            r = num_iterations

        v_pred = self.generator(z_t, t, text_tokens, num_iterations=r)
        v_target = self.get_velocity_target(z_0, noise)

        # Weighting t / (1 - t) emphasizes high-noise timesteps
        w = t / (1 - t + 1e-8)
        w = w[:, None, None, None]
        velocity_loss = (w * (v_pred - v_target).pow(2)).mean()

        return {
            'loss': velocity_loss,
            'velocity_loss': velocity_loss.item(),
            'mean_t': t.mean().item(),
        }
| |
| def train_step( |
| self, |
| images: torch.Tensor, |
| text_tokens: torch.Tensor, |
| num_iterations: Optional[int] = None, |
| ) -> dict: |
| """Training step with built-in Wavelet VAE (legacy path).""" |
| assert self.vae is not None, "No built-in VAE. Use train_step_latent() instead." |
| B = images.shape[0] |
| device = images.device |
| |
| z_0, mean, logvar = self.encode(images) |
| noise = torch.randn_like(z_0) |
| t = self.sample_timesteps(B, device) |
| z_t = self.add_noise(z_0, noise, t) |
| |
        if num_iterations is None:
            # Match the sampled recurrence depths used in train_step_latent()
            r_choices = [4, 6, 8, 12, 16]
            r = r_choices[torch.randint(0, len(r_choices), (1,)).item()]
        else:
            r = num_iterations

        v_pred = self.generator(z_t, t, text_tokens, num_iterations=r)
        v_target = self.get_velocity_target(z_0, noise)

        w = t / (1 - t + 1e-8)
        w = w[:, None, None, None]
        velocity_loss = (w * (v_pred - v_target).pow(2)).mean()
        # Small KL keeps the built-in VAE posterior close to N(0, I)
        kl_loss = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp()).mean()
| |
| return { |
| 'loss': velocity_loss + 0.001 * kl_loss, |
| 'velocity_loss': velocity_loss.item(), |
| 'kl_loss': kl_loss.item(), |
| 'mean_t': t.mean().item(), |
| } |
| |
| @torch.no_grad() |
| def generate_latent( |
| self, |
| text_tokens: torch.Tensor, |
| num_steps: int = 4, |
| num_iterations: int = 8, |
| cfg_scale: float = 4.0, |
| seed: Optional[int] = None, |
| ) -> torch.Tensor: |
| """Generate latent (decode externally with your VAE). |
| |
| Returns z_0 latent tensor, NOT decoded image. |
| """ |
        B = text_tokens.shape[0]
        device = text_tokens.device

        if seed is not None:
            torch.manual_seed(seed)

        # Start from pure noise at t = 1
        z = torch.randn(B, self.config.latent_channels,
                        self.config.latent_spatial, self.config.latent_spatial,
                        device=device)

        # Euler integration of dz/dt = v from t = 1 down to t = 0
        dt = 1.0 / num_steps
        for step in range(num_steps):
            t_val = 1.0 - step * dt
            t = torch.full((B,), t_val, device=device)

            v = self.generator(z, t, text_tokens, num_iterations=num_iterations)

            # Classifier-free guidance against the zeroed (null) text condition
            if cfg_scale > 1.0:
                null_tokens = torch.zeros_like(text_tokens)
                v_uncond = self.generator(z, t, null_tokens, num_iterations=num_iterations)
                v = v_uncond + cfg_scale * (v - v_uncond)

            z = z - dt * v

        return z
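
# Decoding sketch (assumes a diffusers-style SD-VAE external to this file;
# 0.18215 is that VAE's latent scaling factor):
#
#   z = model.generate_latent(text_tokens, num_steps=4)
#   images = sd_vae.decode(z / 0.18215).sample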
|
|
|
|
# =============================================================================
# Utilities and Factory Functions
# =============================================================================
|
|
| def count_parameters(model: nn.Module) -> dict: |
| """Count parameters in each component.""" |
| counts = {} |
| total = 0 |
| for name, module in model.named_children(): |
| n = sum(p.numel() for p in module.parameters()) |
| counts[name] = n |
| total += n |
| counts['total'] = total |
| |
| |
    # Trainable subset (differs from total only if some parameters are frozen)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
| counts['trainable'] = trainable |
| return counts |
|
|
|
|
def estimate_memory_mb(model: nn.Module, dtype=torch.float16) -> float:
    """Estimate parameter memory in MB for the given dtype."""
    bytes_per_param = 2 if dtype in (torch.float16, torch.bfloat16) else 4
    total_params = sum(p.numel() for p in model.parameters())
    return total_params * bytes_per_param / (1024 * 1024)
|
|
|
|
| def create_iris_small(latent_spatial: int = 32) -> IRIS: |
| """Create IRIS-Small for SD-VAE latent space (4ch, 8× downsample).""" |
| config = IRISConfig( |
| latent_channels=4, |
| latent_spatial=latent_spatial, |
| hidden_dim=512, |
| num_heads=8, |
| head_dim=64, |
| ffn_ratio=2.667, |
| num_prelude_blocks=2, |
| num_core_layers=4, |
| num_coda_blocks=2, |
| default_iterations=8, |
| max_iterations=16, |
| fourier_num_blocks=8, |
| sparsity_threshold=0.01, |
| recurrence_dim=256, |
| manhattan_window=16, |
| text_dim=768, |
| max_text_tokens=77, |
| patch_size=2, |
| ) |
| return IRIS(config) |
|
|
|
|
| def create_iris_tiny(latent_spatial: int = 32) -> IRIS: |
| """Create IRIS-Tiny for SD-VAE latent space (4ch, 8× downsample).""" |
| config = IRISConfig( |
| latent_channels=4, |
| latent_spatial=latent_spatial, |
| hidden_dim=384, |
| num_heads=6, |
| head_dim=64, |
| ffn_ratio=2.667, |
| num_prelude_blocks=1, |
| num_core_layers=3, |
| num_coda_blocks=1, |
| default_iterations=8, |
| max_iterations=16, |
| fourier_num_blocks=6, |
| sparsity_threshold=0.01, |
| recurrence_dim=192, |
| manhattan_window=12, |
| text_dim=768, |
| max_text_tokens=77, |
| patch_size=2, |
| ) |
| return IRIS(config) |
|
|
|
|
| def create_iris_base(latent_spatial: int = 32) -> IRIS: |
| """Create IRIS-Base for SD-VAE latent space (4ch, 8× downsample).""" |
| config = IRISConfig( |
| latent_channels=4, |
| latent_spatial=latent_spatial, |
| hidden_dim=768, |
| num_heads=12, |
| head_dim=64, |
| ffn_ratio=2.667, |
| num_prelude_blocks=2, |
| num_core_layers=6, |
| num_coda_blocks=2, |
| default_iterations=8, |
| max_iterations=16, |
| fourier_num_blocks=12, |
| sparsity_threshold=0.01, |
| recurrence_dim=384, |
| manhattan_window=16, |
| text_dim=768, |
| max_text_tokens=77, |
| patch_size=2, |
| ) |
| return IRIS(config) |
|
|
|
|
| if __name__ == "__main__": |
| print("=" * 70) |
| print("IRIS: Iterative Recurrent Image Synthesis") |
| print("=" * 70) |
| |
| |
    # Build each variant and report parameter counts and memory footprint
    for name, create_fn in [("IRIS-Tiny", create_iris_tiny),
| ("IRIS-Small", create_iris_small), |
| ("IRIS-Base", create_iris_base)]: |
| print(f"\n{'─' * 50}") |
| print(f" {name}") |
| print(f"{'─' * 50}") |
| model = create_fn() |
| counts = count_parameters(model) |
| mem_fp16 = estimate_memory_mb(model, torch.float16) |
| mem_fp32 = estimate_memory_mb(model, torch.float32) |
| |
| print(f" Total params: {counts['total']:>12,}") |
| print(f" Trainable params: {counts['trainable']:>12,}") |
| print(f" Memory (fp16): {mem_fp16:>10.1f} MB") |
| print(f" Memory (fp32): {mem_fp32:>10.1f} MB") |
| print(f" Components:") |
| for k, v in counts.items(): |
| if k not in ('total', 'trainable'): |
| print(f" {k:20s}: {v:>10,} ({v/counts['total']*100:.1f}%)") |
|
|