"""IRIS: Complete model — patchify, refinement core, unpatchify, tiny decoder."""
import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from .core import RefinementCore


class Patchify(nn.Module):
    """Turn a latent feature map into a sequence of patch tokens.

    A depthwise 3x3 conv first mixes each channel locally; then every
    non-overlapping ``patch_size x patch_size`` patch is flattened
    (channel-major, ``C * p * p`` values) and linearly projected to ``dim``.
    """

    def __init__(self, in_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.dw_conv = nn.Conv2d(
            in_channels, in_channels, 3, padding=1, groups=in_channels, bias=True
        )
        self.proj = nn.Linear(in_channels * patch_size * patch_size, dim, bias=True)

    def forward(self, z):
        """Patchify ``z``.

        Args:
            z: latent of shape (B, C, H, W) with C == ``in_channels``.

        Returns:
            ``(tokens, H_tok, W_tok)`` where ``tokens`` has shape
            (B, H_tok * W_tok, dim), patches in row-major order.

        Raises:
            ValueError: if H or W is not a multiple of ``patch_size``.
        """
        B, C, H, W = z.shape
        p = self.patch_size
        if H % p or W % p:
            raise ValueError(
                f"Patchify: spatial size ({H}, {W}) is not divisible by "
                f"patch_size={p}"
            )
        z = self.dw_conv(z)
        H_tok, W_tok = H // p, W // p
        # (B, C, H_tok, p, W_tok, p) -> (B, H_tok, W_tok, C, p, p) -> (B, N, C*p*p)
        z = (
            z.view(B, C, H_tok, p, W_tok, p)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(B, H_tok * W_tok, C * p * p)
        )
        return self.proj(z), H_tok, W_tok


class Unpatchify(nn.Module):
    """Inverse of :class:`Patchify`: tokens back to a (B, C, H, W) latent.

    Each token is projected to ``C * p * p`` values, laid out with the same
    (C, p, p) ordering Patchify uses, then a depthwise 3x3 conv smooths
    across patch borders.
    """

    def __init__(self, out_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.proj = nn.Linear(dim, out_channels * patch_size * patch_size, bias=True)
        self.dw_conv = nn.Conv2d(
            out_channels, out_channels, 3, padding=1, groups=out_channels, bias=True
        )

    def forward(self, tokens, H_tok, W_tok):
        """Rebuild the latent map.

        Args:
            tokens: tensor of shape (B, N, dim) with N == H_tok * W_tok,
                in row-major patch order.
            H_tok: number of patch rows.
            W_tok: number of patch columns.

        Returns:
            Tensor of shape (B, out_channels, H_tok * p, W_tok * p).

        Raises:
            ValueError: if the token count does not equal ``H_tok * W_tok``.
        """
        B, N, _ = tokens.shape
        if N != H_tok * W_tok:
            raise ValueError(
                f"Unpatchify: got {N} tokens but H_tok * W_tok = {H_tok * W_tok}"
            )
        p = self.patch_size
        C = self.out_channels
        z = self.proj(tokens).view(B, H_tok, W_tok, C, p, p)
        # (B, H_tok, W_tok, C, p, p) -> (B, C, H_tok, p, W_tok, p) -> (B, C, H, W)
        z = z.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H_tok * p, W_tok * p)
        return self.dw_conv(z)


class TinyDecoder(nn.Module):
    """Minimal latent->pixels decoder via PixelShuffle. ~0.1M params.

    Five stages, each a 3x3 conv producing ``4 * c_out`` channels followed by
    ``PixelShuffle(2)``, so spatial resolution grows by 2**5 = 32x overall.
    Channel schedule: in_channels -> 32 -> 32 -> 16 -> 8 -> out_channels.
    SiLU follows every stage except the last; a final 1x1 conv and ``tanh``
    map the result into [-1, 1].
    """

    def __init__(self, in_channels=32, out_channels=3):
        super().__init__()
        self.stages = nn.ModuleList()
        channels = [in_channels, 32, 32, 16, 8, out_channels]
        num_stages = len(channels) - 1
        for i, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            self.stages.append(nn.Sequential(
                # 4x channels so PixelShuffle(2) yields c_out at 2x resolution.
                nn.Conv2d(c_in, c_out * 4, 3, padding=1, bias=True),
                nn.PixelShuffle(2),
                # No activation after the last stage: tanh is applied below.
                nn.SiLU() if i < num_stages - 1 else nn.Identity(),
            ))
        self.final = nn.Conv2d(out_channels, out_channels, 1, bias=True)

    def forward(self, z):
        """Decode a (B, in_channels, h, w) latent to (B, out_channels, 32h, 32w)."""
        x = z
        for stage in self.stages:
            x = stage(x)
        return torch.tanh(self.final(x))


class IRIS(nn.Module):
    """
    IRIS: Iterative Refinement Image Synthesizer.
    Predicts velocity v_theta(z_t, t, c) for flow matching.

    Pipeline: ``Patchify`` -> ``RefinementCore`` -> ``Unpatchify``. A separate
    ``TinyDecoder`` maps latents to pixels via :meth:`decode_latent`.

    Args:
        latent_channels: channels of the latent ``z_t``.
        dim: token width used inside the refinement core.
        patch_size: side length of the square patches.
        num_blocks, num_heads, max_iterations, ffn_expansion,
        gradient_checkpointing: forwarded to ``RefinementCore``.
        context_dim: feature size of the conditioning ``context``. When given
            and different from ``dim``, a bias-free ``context_dim -> dim``
            projection is created here, so it is initialised, visible to
            optimizers built after construction, and present in
            ``state_dict``. When ``None`` (legacy behaviour) such a projection
            is created lazily on the first forward call that sees a context
            whose last dimension differs from ``dim``.
        spatial_size: forwarded to ``RefinementCore``; the default 4 matches
            a 16x16 latent with ``patch_size=4``.
    """

    def __init__(self, latent_channels=32, dim=512, patch_size=4, num_blocks=6,
                 num_heads=8, max_iterations=8, ffn_expansion=2,
                 gradient_checkpointing=True, context_dim=None, spatial_size=4):
        super().__init__()
        self.latent_channels = latent_channels
        self.dim = dim
        self.patch_size = patch_size
        self.patchify = Patchify(latent_channels, dim, patch_size)
        self.unpatchify = Unpatchify(latent_channels, dim, patch_size)
        self.core = RefinementCore(
            dim=dim,
            num_blocks=num_blocks,
            num_heads=num_heads,
            spatial_size=spatial_size,
            max_iterations=max_iterations,
            ffn_expansion=ffn_expansion,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.tiny_decoder = TinyDecoder(latent_channels, out_channels=3)
        if context_dim is not None and context_dim != dim:
            # Created before _init_weights so it receives the same Linear init
            # and is registered before any optimizer is constructed.
            self._context_proj = nn.Linear(context_dim, dim, bias=False)
        self._init_weights()

    def _init_weights(self):
        """Initialise all submodules; zero the output projection.

        Linear: Xavier-uniform weights, zero bias. Conv2d: Kaiming-normal
        (fan_out, relu), zero bias. LayerNorm/GroupNorm: ones/zeros when
        affine. The unpatchify projection is zeroed, so the freshly built
        model outputs the (zero-initialised) dw_conv bias, i.e. zeros.

        NOTE(review): ``self.modules()`` includes ``self.core``, so this
        overrides any initialisation RefinementCore applies itself — confirm
        that is intended.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
                # weight/bias are None when the norm has no affine params.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        nn.init.zeros_(self.unpatchify.proj.weight)
        nn.init.zeros_(self.unpatchify.proj.bias)

    def _project_context(self, context):
        """Project ``context``'s last dimension to ``self.dim`` via ``_context_proj``.

        Creates the projection lazily (legacy path, with a warning) when no
        ``context_dim`` was given at construction.

        Raises:
            ValueError: if an existing projection expects a different input width.
        """
        proj = getattr(self, "_context_proj", None)
        if proj is None:
            warnings.warn(
                "IRIS: creating the context projection lazily inside forward(); "
                "its parameters are not part of optimizers built earlier and "
                "are absent from state_dict until now. Pass context_dim to "
                "IRIS(...) to create it at construction time.",
                RuntimeWarning,
                stacklevel=2,
            )
            proj = nn.Linear(context.shape[-1], self.dim, bias=False).to(
                context.device, context.dtype
            )
            self._context_proj = proj
        elif proj.in_features != context.shape[-1]:
            raise ValueError(
                f"IRIS: context has last dimension {context.shape[-1]}, but the "
                f"context projection expects {proj.in_features}"
            )
        return proj(context)

    def forward(self, z_t, t, context, num_iterations=4):
        """Predict the flow-matching velocity for ``z_t``.

        Args:
            z_t: noisy latent of shape (B, latent_channels, H, W); H and W
                must be divisible by ``patch_size``.
            t: timestep(s), passed unchanged to the refinement core.
            context: conditioning tensor whose last dimension is either
                ``dim`` or the projection's input width (presumably
                (B, L, D_ctx) — TODO confirm against RefinementCore).
            num_iterations: refinement iterations to run in the core.

        Returns:
            Velocity tensor with the same shape as ``z_t``.
        """
        tokens, H_tok, W_tok = self.patchify(z_t)
        if context.shape[-1] != self.dim:
            context = self._project_context(context)
        refined = self.core(
            tokens, context, t, H_tok, W_tok, num_iterations=num_iterations
        )
        return self.unpatchify(refined, H_tok, W_tok)

    def decode_latent(self, z):
        """Decode a latent to pixels in [-1, 1] with the tiny decoder."""
        return self.tiny_decoder(z)

    def count_params(self):
        """Return parameter counts per direct child module plus totals.

        Keys are child names (e.g. ``"patchify"``, ``"core"``), ``"total"``
        and ``"trainable"`` (parameters with ``requires_grad``).
        """
        counts = {
            name: sum(p.numel() for p in module.parameters())
            for name, module in self.named_children()
        }
        counts["total"] = sum(p.numel() for p in self.parameters())
        counts["trainable"] = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        return counts