"""IRIS: Complete model — patchify, refinement core, unpatchify, tiny decoder."""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from .core import RefinementCore
class Patchify(nn.Module):
    """Turn a latent feature map into a sequence of projected patch tokens.

    A depthwise 3x3 convolution first mixes local spatial context, then the
    map is cut into non-overlapping ``p x p`` patches, each flattened and
    linearly projected to the token dimension.
    """

    def __init__(self, in_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        # Depthwise conv: per-channel spatial mixing, no cross-channel blend.
        self.dw_conv = nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels, bias=True)
        self.proj = nn.Linear(in_channels * patch_size * patch_size, dim, bias=True)

    def forward(self, z):
        """Map z [B, C, H, W] -> (tokens [B, N, dim], H_tok, W_tok).

        Assumes H and W are divisible by patch_size (integer division is
        used without a check) — TODO confirm with callers.
        """
        batch, chans, height, width = z.shape
        p = self.patch_size
        mixed = self.dw_conv(z)
        h_tok = height // p
        w_tok = width // p
        # [B, C, Ht, p, Wt, p] -> [B, Ht, Wt, C, p, p] -> [B, N, C*p*p]
        patches = mixed.view(batch, chans, h_tok, p, w_tok, p)
        patches = patches.permute(0, 2, 4, 1, 3, 5)
        flat = patches.reshape(batch, h_tok * w_tok, chans * p * p)
        return self.proj(flat), h_tok, w_tok
class Unpatchify(nn.Module):
    """Inverse of Patchify: project tokens back into a latent feature map.

    Each token is expanded to a ``p x p`` patch of ``out_channels`` values,
    the patches are tiled back onto the spatial grid, and a depthwise 3x3
    convolution smooths the result.
    """

    def __init__(self, out_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.proj = nn.Linear(dim, out_channels * patch_size * patch_size, bias=True)
        # Depthwise conv blends neighbouring patches at their seams.
        self.dw_conv = nn.Conv2d(out_channels, out_channels, 3, padding=1, groups=out_channels, bias=True)

    def forward(self, tokens, H_tok, W_tok):
        """Map tokens [B, N, dim] (N = H_tok * W_tok) -> latent [B, C, H_tok*p, W_tok*p]."""
        batch = tokens.shape[0]
        p = self.patch_size
        c = self.out_channels
        grid = self.proj(tokens).view(batch, H_tok, W_tok, c, p, p)
        # [B, Ht, Wt, C, p, p] -> [B, C, Ht, p, Wt, p] -> [B, C, H, W]
        grid = grid.permute(0, 3, 1, 4, 2, 5)
        latent = grid.reshape(batch, c, H_tok * p, W_tok * p)
        return self.dw_conv(latent)
class TinyDecoder(nn.Module):
    """Minimal latent->pixels decoder via PixelShuffle. ~0.1M params.

    Five conv + PixelShuffle(2) stages upsample the latent 32x overall;
    all stages but the last are followed by SiLU, and a final 1x1 conv
    plus tanh maps the result into [-1, 1].
    """

    def __init__(self, in_channels=32, out_channels=3):
        super().__init__()
        widths = [in_channels, 32, 32, 16, 8, out_channels]
        last = len(widths) - 2  # index of the final stage (no activation)
        stages = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            stages.append(nn.Sequential(
                # 4x channels feed PixelShuffle(2), which trades them for 2x spatial.
                nn.Conv2d(c_in, c_out * 4, 3, padding=1, bias=True),
                nn.PixelShuffle(2),
                nn.SiLU() if idx < last else nn.Identity(),
            ))
        self.stages = nn.ModuleList(stages)
        self.final = nn.Conv2d(out_channels, out_channels, 1, bias=True)

    def forward(self, z):
        """Decode latent z [B, in_channels, H, W] to pixels [B, out_channels, 32H, 32W]."""
        x = z
        for stage in self.stages:
            x = stage(x)
        return torch.tanh(self.final(x))
class IRIS(nn.Module):
    """
    IRIS: Iterative Refinement Image Synthesizer.
    Predicts velocity v_theta(z_t, t, c) for flow matching.

    Args:
        latent_channels: channels of the input/output latent map.
        dim: token dimension used by the refinement core.
        patch_size: side of the square patches used for tokenization.
        num_blocks / num_heads / max_iterations / ffn_expansion /
            gradient_checkpointing: forwarded to RefinementCore.
        context_dim: width of the conditioning context. If given and different
            from ``dim``, the context projection is created eagerly here so its
            parameters are registered before optimizer construction. If None
            (default, backward compatible), the projection is created lazily on
            the first forward call whose context width differs from ``dim`` —
            but such lazily-created parameters are invisible to optimizers,
            DDP, and count_params() until that call.
    """

    def __init__(self, latent_channels=32, dim=512, patch_size=4, num_blocks=6, num_heads=8, max_iterations=8, ffn_expansion=2, gradient_checkpointing=True, context_dim=None):
        super().__init__()
        self.latent_channels = latent_channels
        self.dim = dim
        self.patch_size = patch_size
        self.patchify = Patchify(latent_channels, dim, patch_size)
        self.unpatchify = Unpatchify(latent_channels, dim, patch_size)
        spatial_size = 4  # default for 16x16 latent with ps=4
        self.core = RefinementCore(dim=dim, num_blocks=num_blocks, num_heads=num_heads, spatial_size=spatial_size, max_iterations=max_iterations, ffn_expansion=ffn_expansion, gradient_checkpointing=gradient_checkpointing)
        self.tiny_decoder = TinyDecoder(latent_channels, out_channels=3)
        # Eagerly register the context projection when its width is known;
        # attribute name kept as `_context_proj` so old state_dicts still load.
        if context_dim is not None and context_dim != dim:
            self._context_proj = nn.Linear(context_dim, dim, bias=False)
        else:
            self._context_proj = None
        self._init_weights()

    def _init_weights(self):
        """Xavier for linears, Kaiming for convs, identity for norms; then
        zero the unpatchify projection so the model starts by predicting
        zero velocity."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        # Zero-init the output projection (overrides the xavier init above).
        nn.init.zeros_(self.unpatchify.proj.weight)
        nn.init.zeros_(self.unpatchify.proj.bias)

    def forward(self, z_t, t, context, num_iterations=4):
        """Predict the flow-matching velocity for noisy latent z_t at time t.

        Args:
            z_t: latent tensor [B, latent_channels, H, W].
            t: timestep conditioning, forwarded to the core.
            context: conditioning tokens [..., context_dim]; projected to
                ``dim`` when the widths differ.
            num_iterations: refinement iterations to run in the core.
        Returns:
            Velocity prediction with the same shape as z_t.
        """
        tokens, H_tok, W_tok = self.patchify(z_t)
        if context.shape[-1] != self.dim:
            if self._context_proj is None:
                # Lazy fallback for callers that did not pass context_dim.
                # NOTE: parameters created here are NOT seen by optimizers
                # constructed earlier — prefer passing context_dim to __init__.
                self._context_proj = nn.Linear(context.shape[-1], self.dim, bias=False).to(context.device, context.dtype)
            context = self._context_proj(context)
        refined = self.core(tokens, context, t, H_tok, W_tok, num_iterations=num_iterations)
        return self.unpatchify(refined, H_tok, W_tok)

    def decode_latent(self, z):
        """Decode a latent map to RGB pixels with the tiny decoder."""
        return self.tiny_decoder(z)

    def count_params(self):
        """Return a dict of parameter counts per top-level child, plus
        'total' and 'trainable' (lazily-created modules appear only after
        their first forward call)."""
        counts = {}
        for name, module in self.named_children():
            counts[name] = sum(p.numel() for p in module.parameters())
        counts["total"] = sum(p.numel() for p in self.parameters())
        counts["trainable"] = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return counts