"""IRIS: Complete model — patchify, refinement core, unpatchify, tiny decoder."""

import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from .core import RefinementCore


class Patchify(nn.Module):
    """Turn a latent feature map into a sequence of patch tokens.

    A depthwise 3x3 conv (``groups=in_channels``: one filter per channel)
    first mixes each channel spatially. The map is then cut into
    non-overlapping ``patch_size`` x ``patch_size`` patches. Each patch is
    flattened to ``C * p * p`` values and linearly projected to ``dim``.

    Args:
        in_channels: number of latent channels ``C``.
        dim: token embedding width.
        patch_size: patch side length ``p``; the input height and width must
            be multiples of it.
    """

    def __init__(self, in_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.dw_conv = nn.Conv2d(
            in_channels, in_channels, 3, padding=1, groups=in_channels, bias=True
        )
        self.proj = nn.Linear(in_channels * patch_size * patch_size, dim, bias=True)

    def forward(self, z):
        """Patchify ``z`` of shape ``(B, C, H, W)``.

        Returns:
            ``(tokens, H_tok, W_tok)``. ``tokens`` has shape
            ``(B, H_tok * W_tok, dim)``, with patches in row-major
            (height-then-width) order, where ``H_tok = H // p`` and
            ``W_tok = W // p``.

        Raises:
            ValueError: if ``H`` or ``W`` is not a multiple of ``patch_size``.
        """
        B, C, H, W = z.shape
        p = self.patch_size
        if H % p or W % p:
            # Without this check, ``view`` below fails with an opaque shape error.
            raise ValueError(
                f"Patchify: spatial size {H}x{W} is not divisible by "
                f"patch_size={p}"
            )
        orig_dtype = z.dtype
        # Run grouped conv in float32 — cuDNN lacks bf16 kernels for grouped convs on T4
        with torch.amp.autocast(device_type='cuda', enabled=False):
            z = self.dw_conv(z.float())
        z = z.to(orig_dtype)
        H_tok, W_tok = H // p, W // p
        # (B, C, H, W) -> (B, C, H_tok, p, W_tok, p) -> (B, H_tok, W_tok, C, p, p)
        # -> (B, H_tok * W_tok, C * p * p): one flattened patch per token.
        z = (
            z.view(B, C, H_tok, p, W_tok, p)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(B, H_tok * W_tok, C * p * p)
        )
        return self.proj(z), H_tok, W_tok


class Unpatchify(nn.Module):
    """Inverse of :class:`Patchify`: patch tokens back to a feature map.

    Each token is projected to ``C * p * p`` values and laid out as a
    ``p`` x ``p`` patch, using the same row-major token order and
    ``(C, p, p)`` per-token layout that ``Patchify`` produces. A depthwise
    3x3 conv is then applied.

    Args:
        out_channels: number of output channels ``C``.
        dim: token embedding width.
        patch_size: patch side length ``p``.
    """

    def __init__(self, out_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.proj = nn.Linear(dim, out_channels * patch_size * patch_size, bias=True)
        self.dw_conv = nn.Conv2d(
            out_channels, out_channels, 3, padding=1, groups=out_channels, bias=True
        )

    def forward(self, tokens, H_tok, W_tok):
        """Map ``tokens`` of shape ``(B, H_tok * W_tok, dim)`` to ``(B, C, H_tok*p, W_tok*p)``.

        Raises:
            ValueError: if the token count does not equal ``H_tok * W_tok``.
        """
        B, N, _ = tokens.shape
        p = self.patch_size
        C = self.out_channels
        if N != H_tok * W_tok:
            raise ValueError(
                f"Unpatchify: got {N} tokens but grid is {H_tok}x{W_tok} "
                f"({H_tok * W_tok} tokens)"
            )
        z = self.proj(tokens).view(B, H_tok, W_tok, C, p, p)
        # (B, H_tok, W_tok, C, p, p) -> (B, C, H_tok, p, W_tok, p) -> (B, C, H, W)
        z = z.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H_tok * p, W_tok * p)
        # Run grouped conv in float32 — cuDNN lacks bf16 kernels for grouped convs on T4
        orig_dtype = z.dtype
        with torch.amp.autocast(device_type='cuda', enabled=False):
            z = self.dw_conv(z.float())
        return z.to(orig_dtype)


class TinyDecoder(nn.Module):
    """Minimal latent->pixels decoder via PixelShuffle. ~0.1M params.

    It has five stages. Each stage is a 3x3 conv to ``4 * c_out`` channels
    followed by ``PixelShuffle(2)``, which doubles height and width. The
    first four stages are followed by SiLU; the last has no activation. The
    output is therefore 32x the input's spatial size. A final 1x1 conv and
    ``tanh`` keep pixel values in ``[-1, 1]``.

    Args:
        in_channels: latent channels.
        out_channels: image channels (3 for RGB).
    """

    def __init__(self, in_channels=32, out_channels=3):
        super().__init__()
        self.stages = nn.ModuleList()
        channels = [in_channels, 32, 32, 16, 8, out_channels]
        num_stages = len(channels) - 1
        for i, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            # Layer order (conv, shuffle, act) fixes the state_dict keys; keep it.
            self.stages.append(nn.Sequential(
                nn.Conv2d(c_in, c_out * 4, 3, padding=1, bias=True),
                nn.PixelShuffle(2),
                nn.SiLU() if i < num_stages - 1 else nn.Identity(),
            ))
        self.final = nn.Conv2d(out_channels, out_channels, 1, bias=True)

    def forward(self, z):
        """Decode ``z`` (B, C, H, W) to an image (B, out_channels, 32H, 32W) in ``z``'s dtype."""
        # Run decoder convs in float32 — cuDNN lacks bf16 kernels on T4
        orig_dtype = z.dtype
        with torch.amp.autocast(device_type='cuda', enabled=False):
            x = z.float()
            for stage in self.stages:
                x = stage(x)
            x = torch.tanh(self.final(x))
        return x.to(orig_dtype)


class IRIS(nn.Module):
    """
    IRIS: Iterative Refinement Image Synthesizer.

    Predicts velocity v_theta(z_t, t, c) for flow matching.

    Args:
        latent_channels: channels of the latent ``z_t``.
        dim: model (token) width.
        patch_size: patch side length used by patchify/unpatchify.
        num_blocks, num_heads, max_iterations, ffn_expansion,
        gradient_checkpointing: forwarded unchanged to ``RefinementCore``.
        text_dim: dimension of text encoder output. If different from dim,
            a learned linear projection is applied. Set to 384 for
            all-MiniLM-L6-v2, 512 for CLIP, etc. Set to None or equal to
            dim to skip projection.
        latent_size: optional side length of the (square) latent fed to
            ``forward``. When given, ``RefinementCore`` receives
            ``spatial_size = latent_size // patch_size`` (the token-grid
            side). When None (default), the legacy value 4 is used, which
            matches a 16x16 latent with ``patch_size=4``.
    """

    def __init__(self, latent_channels=32, dim=512, patch_size=4, num_blocks=6,
                 num_heads=8, max_iterations=8, ffn_expansion=2,
                 gradient_checkpointing=True, text_dim=None, latent_size=None):
        super().__init__()
        self.latent_channels = latent_channels
        self.dim = dim
        self.patch_size = patch_size
        # NOTE: submodule creation order is kept as-is. It fixes the
        # state_dict key order and the RNG consumption of default inits.
        self.patchify = Patchify(latent_channels, dim, patch_size)
        self.unpatchify = Unpatchify(latent_channels, dim, patch_size)
        if latent_size is None:
            spatial_size = 4  # default for 16x16 latent with ps=4
        else:
            if latent_size <= 0 or latent_size % patch_size:
                raise ValueError(
                    f"latent_size={latent_size} must be a positive multiple "
                    f"of patch_size={patch_size}"
                )
            spatial_size = latent_size // patch_size
        self.spatial_size = spatial_size
        self.core = RefinementCore(dim=dim, num_blocks=num_blocks,
                                   num_heads=num_heads, spatial_size=spatial_size,
                                   max_iterations=max_iterations,
                                   ffn_expansion=ffn_expansion,
                                   gradient_checkpointing=gradient_checkpointing)
        self.tiny_decoder = TinyDecoder(latent_channels, out_channels=3)
        # Text projection: maps text encoder dim to model dim if they differ
        if text_dim is not None and text_dim != dim:
            self.context_proj = nn.Linear(text_dim, dim, bias=False)
        else:
            self.context_proj = None
        self._init_weights()

    def _init_weights(self):
        """(Re-)initialise every submodule, including those inside ``self.core``.

        Because this walks ``self.modules()``, it overrides whatever
        initialisation the submodules' own constructors applied.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        # Zero the output projection. Combined with the zeroed dw_conv bias
        # above, the untrained model's prediction is exactly zero.
        nn.init.zeros_(self.unpatchify.proj.weight)
        nn.init.zeros_(self.unpatchify.proj.bias)

    def _project_context(self, context):
        """Bring text-encoder features to width ``self.dim``.

        The configured ``context_proj`` is used when present. Context that is
        already ``dim`` wide is passed through unchanged. Otherwise a lazy
        projection is created on first use, for backwards compatibility.
        """
        if self.context_proj is not None:
            return self.context_proj(context)
        in_dim = context.shape[-1]
        if in_dim == self.dim:
            return context
        lazy = getattr(self, "_lazy_context_proj", None)
        if lazy is None:
            # Fallback: lazy projection for backwards compat. Its weights are
            # created after construction, so optimizers built earlier do not
            # see them, and checkpoints may not load into a fresh model.
            warnings.warn(
                f"IRIS: context has last dim {in_dim} but dim={self.dim} and "
                f"no text_dim was configured; creating an ad-hoc projection "
                f"(_lazy_context_proj). Pass text_dim={in_dim} to the "
                f"constructor so it is trained and checkpointed properly.",
                stacklevel=2,
            )
            lazy = nn.Linear(in_dim, self.dim, bias=False).to(
                context.device, context.dtype
            )
            self._lazy_context_proj = lazy
        elif lazy.in_features != in_dim:
            raise ValueError(
                f"IRIS: context last dim changed to {in_dim}, but the lazy "
                f"context projection expects {lazy.in_features}"
            )
        return lazy(context)

    def forward(self, z_t, t, context, num_iterations=4):
        """Predict the flow-matching velocity for latent ``z_t``.

        Args:
            z_t: latent of shape ``(B, latent_channels, H, W)``, where H and W
                are multiples of ``patch_size``.
            t: timestep, passed unchanged to ``RefinementCore`` (its expected
                shape is defined there — TODO confirm).
            context: text features whose last dim is ``text_dim`` or ``dim``.
            num_iterations: refinement iterations, forwarded to the core.

        Returns:
            Tensor with the same spatial size as ``z_t`` and
            ``latent_channels`` channels.
        """
        tokens, H_tok, W_tok = self.patchify(z_t)
        # Project text embeddings to model dim if needed
        context = self._project_context(context)
        refined = self.core(tokens, context, t, H_tok, W_tok,
                            num_iterations=num_iterations)
        return self.unpatchify(refined, H_tok, W_tok)

    def decode_latent(self, z):
        """Decode latent ``z`` to pixels in ``[-1, 1]`` with the tiny decoder (32x upsampling)."""
        return self.tiny_decoder(z)

    def count_params(self):
        """Return parameter counts per direct child module, plus ``total`` and ``trainable``."""
        counts = {}
        for name, module in self.named_children():
            counts[name] = sum(p.numel() for p in module.parameters())
        counts["total"] = sum(p.numel() for p in self.parameters())
        counts["trainable"] = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return counts