"""IRIS: Complete model — patchify, refinement core, unpatchify, tiny decoder."""
import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from .core import RefinementCore
class Patchify(nn.Module):
    """Turn a latent map (B, C, H, W) into a sequence of patch tokens.

    A depthwise 3x3 conv first mixes local context, then non-overlapping
    patch_size x patch_size patches are flattened and linearly projected
    to the token dimension.
    """

    def __init__(self, in_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.dw_conv = nn.Conv2d(in_channels, in_channels, 3, padding=1,
                                 groups=in_channels, bias=True)
        self.proj = nn.Linear(in_channels * patch_size * patch_size, dim, bias=True)

    def forward(self, z):
        """Return (tokens, H_tok, W_tok); tokens is (B, H_tok*W_tok, dim)."""
        batch, channels, height, width = z.shape
        p = self.patch_size
        input_dtype = z.dtype
        # Grouped convs have no bf16 cuDNN kernels on T4, so run the depthwise
        # conv in float32 with autocast disabled, then restore the dtype.
        with torch.amp.autocast(device_type='cuda', enabled=False):
            z = self.dw_conv(z.float())
        z = z.to(input_dtype)
        h_tok = height // p
        w_tok = width // p
        z = z.view(batch, channels, h_tok, p, w_tok, p)
        z = z.permute(0, 2, 4, 1, 3, 5)
        z = z.reshape(batch, h_tok * w_tok, channels * p * p)
        return self.proj(z), h_tok, w_tok
class Unpatchify(nn.Module):
    """Inverse of Patchify: fold patch tokens back into a latent map.

    Tokens are linearly projected to flattened pixel patches, rearranged
    into a (B, C, H, W) grid, then smoothed by a depthwise 3x3 conv.
    """

    def __init__(self, out_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.proj = nn.Linear(dim, out_channels * patch_size * patch_size, bias=True)
        self.dw_conv = nn.Conv2d(out_channels, out_channels, 3, padding=1,
                                 groups=out_channels, bias=True)

    def forward(self, tokens, H_tok, W_tok):
        """Map (B, H_tok*W_tok, dim) tokens to a (B, C, H_tok*p, W_tok*p) latent."""
        batch = tokens.shape[0]
        p, c = self.patch_size, self.out_channels
        patches = self.proj(tokens)
        patches = patches.view(batch, H_tok, W_tok, c, p, p)
        patches = patches.permute(0, 3, 1, 4, 2, 5)
        z = patches.reshape(batch, c, H_tok * p, W_tok * p)
        input_dtype = z.dtype
        # Grouped convs have no bf16 cuDNN kernels on T4, so run the depthwise
        # conv in float32 with autocast disabled, then restore the dtype.
        with torch.amp.autocast(device_type='cuda', enabled=False):
            z = self.dw_conv(z.float())
        return z.to(input_dtype)
class TinyDecoder(nn.Module):
    """Minimal latent->pixels decoder (~0.1M params).

    Five conv + PixelShuffle(2) stages upsample the latent 32x spatially;
    a final 1x1 conv and tanh produce outputs in [-1, 1].
    """

    def __init__(self, in_channels=32, out_channels=3):
        super().__init__()
        widths = [in_channels, 32, 32, 16, 8, out_channels]
        stages = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # Conv produces 4x channels so PixelShuffle(2) lands on c_out.
            activation = nn.SiLU() if idx < 4 else nn.Identity()
            stages.append(nn.Sequential(
                nn.Conv2d(c_in, c_out * 4, 3, padding=1, bias=True),
                nn.PixelShuffle(2),
                activation,
            ))
        self.stages = nn.ModuleList(stages)
        self.final = nn.Conv2d(out_channels, out_channels, 1, bias=True)

    def forward(self, z):
        """Decode latent z to pixels in [-1, 1], returned in z's dtype."""
        input_dtype = z.dtype
        # Decoder convs lack bf16 cuDNN kernels on T4, so compute in float32.
        with torch.amp.autocast(device_type='cuda', enabled=False):
            x = z.float()
            for stage in self.stages:
                x = stage(x)
            x = torch.tanh(self.final(x))
        return x.to(input_dtype)
class IRIS(nn.Module):
    """
    IRIS: Iterative Refinement Image Synthesizer.
    Predicts velocity v_theta(z_t, t, c) for flow matching.

    Args:
        latent_channels: channel count of the latent maps passed to forward().
        dim: model/token width shared by patchify, core, and unpatchify.
        patch_size: side of the square patches used to tokenize the latent.
        num_blocks: number of blocks in the refinement core.
        num_heads: attention heads per core block.
        max_iterations: maximum refinement iterations supported by the core.
        ffn_expansion: FFN hidden-size multiplier inside the core.
        gradient_checkpointing: enable activation checkpointing in the core.
        text_dim: dimension of text encoder output. If different from dim,
            a learned linear projection is applied. Set to 384 for
            all-MiniLM-L6-v2, 512 for CLIP, etc. Set to None or
            equal to dim to skip projection.
    """
    def __init__(self, latent_channels=32, dim=512, patch_size=4, num_blocks=6,
                 num_heads=8, max_iterations=8, ffn_expansion=2,
                 gradient_checkpointing=True, text_dim=None):
        super().__init__()
        self.latent_channels = latent_channels
        self.dim = dim
        self.patch_size = patch_size
        self.patchify = Patchify(latent_channels, dim, patch_size)
        self.unpatchify = Unpatchify(latent_channels, dim, patch_size)
        spatial_size = 4  # default for 16x16 latent with ps=4
        self.core = RefinementCore(dim=dim, num_blocks=num_blocks, num_heads=num_heads,
                                   spatial_size=spatial_size, max_iterations=max_iterations,
                                   ffn_expansion=ffn_expansion,
                                   gradient_checkpointing=gradient_checkpointing)
        self.tiny_decoder = TinyDecoder(latent_channels, out_channels=3)
        # Text projection: maps text encoder dim to model dim if they differ.
        if text_dim is not None and text_dim != dim:
            self.context_proj = nn.Linear(text_dim, dim, bias=False)
        else:
            self.context_proj = None
        self._init_weights()

    def _init_weights(self):
        """Xavier for linears, Kaiming for convs; zero the unpatchify output proj."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        # Zero-init the output projection so the model's initial prediction is zero.
        nn.init.zeros_(self.unpatchify.proj.weight)
        nn.init.zeros_(self.unpatchify.proj.bias)

    def forward(self, z_t, t, context, num_iterations=4):
        """Predict the flow-matching velocity for noisy latent z_t at time t.

        Args:
            z_t: latent map, (B, latent_channels, H, W).
            t: timestep conditioning passed through to the refinement core.
            context: text embeddings; projected to `dim` if a projection exists.
            num_iterations: refinement iterations to run in the core.
        """
        tokens, H_tok, W_tok = self.patchify(z_t)
        # Project text embeddings to model dim if needed.
        if self.context_proj is not None:
            context = self.context_proj(context)
        elif context.shape[-1] != self.dim:
            # Backwards-compat fallback: build a projection on first use.
            # A layer created here has random, never-trained weights and is
            # invisible to any optimizer/DDP wrapper set up before the first
            # forward pass — warn so callers switch to text_dim at init time.
            if not hasattr(self, '_lazy_context_proj'):
                warnings.warn(
                    f"IRIS: context dim {context.shape[-1]} != model dim "
                    f"{self.dim}; creating an untrained lazy projection. "
                    "Pass text_dim to IRIS(...) so this layer is trained.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                self._lazy_context_proj = nn.Linear(
                    context.shape[-1], self.dim, bias=False
                ).to(context.device, context.dtype)
            context = self._lazy_context_proj(context)
        refined = self.core(tokens, context, t, H_tok, W_tok, num_iterations=num_iterations)
        return self.unpatchify(refined, H_tok, W_tok)

    def decode_latent(self, z):
        """Decode a latent map to RGB pixels with the tiny decoder."""
        return self.tiny_decoder(z)

    def count_params(self):
        """Return parameter counts per top-level submodule plus total/trainable."""
        counts = {}
        for name, module in self.named_children():
            counts[name] = sum(p.numel() for p in module.parameters())
        counts["total"] = sum(p.numel() for p in self.parameters())
        counts["trainable"] = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return counts