File size: 6,514 Bytes
88e5d09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
654d061
 
 
 
 
88e5d09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
654d061
 
 
 
 
88e5d09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
654d061
 
 
 
 
 
 
 
88e5d09
 
 
 
 
 
c98929a
 
 
 
 
 
88e5d09
c98929a
 
 
88e5d09
 
 
 
 
 
 
 
c98929a
 
 
88e5d09
c98929a
 
 
 
 
 
 
88e5d09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c98929a
 
 
 
 
 
 
 
 
 
 
 
88e5d09
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""IRIS: Complete model — patchify, refinement core, unpatchify, tiny decoder."""

import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from .core import RefinementCore


class Patchify(nn.Module):
    """Convert a latent feature map into a sequence of patch tokens.

    A depthwise 3x3 conv applies light spatial mixing, then each
    non-overlapping ``patch_size x patch_size`` patch is flattened into a
    single vector and linearly projected to ``dim``.
    """
    def __init__(self, in_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.dw_conv = nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels, bias=True)
        self.proj = nn.Linear(in_channels * patch_size * patch_size, dim, bias=True)

    def forward(self, z):
        """Patchify a latent tensor.

        Args:
            z: latent of shape (B, C, H, W); H and W must be divisible
               by ``patch_size``.

        Returns:
            Tuple ``(tokens, H_tok, W_tok)`` where ``tokens`` has shape
            (B, H_tok * W_tok, dim).

        Raises:
            ValueError: if H or W is not divisible by ``patch_size``.
        """
        B, C, H, W = z.shape
        p = self.patch_size
        if H % p != 0 or W % p != 0:
            # Fail with a clear message instead of an opaque view() error below.
            raise ValueError(
                f"Spatial dims ({H}, {W}) must be divisible by patch_size={p}"
            )
        orig_dtype = z.dtype
        # Run grouped conv in float32 — cuDNN lacks bf16 kernels for grouped convs on T4
        with torch.amp.autocast(device_type='cuda', enabled=False):
            z = self.dw_conv(z.float())
        z = z.to(orig_dtype)
        H_tok, W_tok = H // p, W // p
        # (B, C, H_tok, p, W_tok, p) -> (B, H_tok, W_tok, C, p, p) -> (B, N, C*p*p)
        z = z.view(B, C, H_tok, p, W_tok, p).permute(0, 2, 4, 1, 3, 5).reshape(B, H_tok * W_tok, C * p * p)
        return self.proj(z), H_tok, W_tok


class Unpatchify(nn.Module):
    """Inverse of Patchify: reassemble a token sequence into a latent map.

    Each token is linearly expanded to ``out_channels * p * p`` values,
    the patches are tiled back onto the spatial grid, and a depthwise 3x3
    conv applies light spatial smoothing.
    """
    def __init__(self, out_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.proj = nn.Linear(dim, out_channels * patch_size * patch_size, bias=True)
        self.dw_conv = nn.Conv2d(out_channels, out_channels, 3, padding=1, groups=out_channels, bias=True)

    def forward(self, tokens, H_tok, W_tok):
        """Map (B, N, D) tokens to a (B, C, H_tok*p, W_tok*p) latent."""
        batch = tokens.shape[0]
        p, ch = self.patch_size, self.out_channels
        patches = self.proj(tokens)
        grid = patches.view(batch, H_tok, W_tok, ch, p, p)
        latent = grid.permute(0, 3, 1, 4, 2, 5).reshape(batch, ch, H_tok * p, W_tok * p)
        # Run grouped conv in float32 — cuDNN lacks bf16 kernels for grouped convs on T4
        saved_dtype = latent.dtype
        with torch.amp.autocast(device_type='cuda', enabled=False):
            latent = self.dw_conv(latent.float())
        return latent.to(saved_dtype)


class TinyDecoder(nn.Module):
    """Minimal latent->pixels decoder via PixelShuffle. ~0.1M params."""
    def __init__(self, in_channels=32, out_channels=3):
        super().__init__()
        widths = [in_channels, 32, 32, 16, 8, out_channels]
        last = len(widths) - 2
        stages = []
        # Each stage: conv to 4x channels, PixelShuffle(2) to upsample 2x,
        # SiLU on all but the final stage.
        for idx, (cin, cout) in enumerate(zip(widths[:-1], widths[1:])):
            stages.append(nn.Sequential(
                nn.Conv2d(cin, cout * 4, 3, padding=1, bias=True),
                nn.PixelShuffle(2),
                nn.SiLU() if idx < last else nn.Identity(),
            ))
        self.stages = nn.ModuleList(stages)
        self.final = nn.Conv2d(out_channels, out_channels, 1, bias=True)

    def forward(self, z):
        """Decode latent (B, C, H, W) to pixels (B, 3, 32H, 32W) in [-1, 1]."""
        # Run decoder convs in float32 — cuDNN lacks bf16 kernels on T4
        saved_dtype = z.dtype
        with torch.amp.autocast(device_type='cuda', enabled=False):
            out = z.float()
            for stage in self.stages:
                out = stage(out)
            out = torch.tanh(self.final(out))
        return out.to(saved_dtype)


class IRIS(nn.Module):
    """
    IRIS: Iterative Refinement Image Synthesizer.
    Predicts velocity v_theta(z_t, t, c) for flow matching.

    Args:
        text_dim: dimension of text encoder output. If different from dim,
                  a learned linear projection is applied. Set to 384 for
                  all-MiniLM-L6-v2, 512 for CLIP, etc. Set to None or
                  equal to dim to skip projection.
    """
    def __init__(self, latent_channels=32, dim=512, patch_size=4, num_blocks=6,
                 num_heads=8, max_iterations=8, ffn_expansion=2,
                 gradient_checkpointing=True, text_dim=None):
        super().__init__()
        self.latent_channels = latent_channels
        self.dim = dim
        self.patch_size = patch_size

        self.patchify = Patchify(latent_channels, dim, patch_size)
        self.unpatchify = Unpatchify(latent_channels, dim, patch_size)
        spatial_size = 4  # default for 16x16 latent with ps=4
        self.core = RefinementCore(dim=dim, num_blocks=num_blocks, num_heads=num_heads,
                                   spatial_size=spatial_size, max_iterations=max_iterations,
                                   ffn_expansion=ffn_expansion, gradient_checkpointing=gradient_checkpointing)
        self.tiny_decoder = TinyDecoder(latent_channels, out_channels=3)

        # Text projection: maps text encoder dim to model dim if they differ
        if text_dim is not None and text_dim != dim:
            self.context_proj = nn.Linear(text_dim, dim, bias=False)
        else:
            self.context_proj = None

        self._init_weights()

    def _init_weights(self):
        """Xavier for linears, Kaiming for convs, identity for norms.

        The unpatchify projection is zeroed afterwards so the block's
        output projection starts at (near) zero — a common output-layer
        zero-init for diffusion-style models.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
                if m.weight is not None: nn.init.ones_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
        nn.init.zeros_(self.unpatchify.proj.weight)
        nn.init.zeros_(self.unpatchify.proj.bias)

    def forward(self, z_t, t, context, num_iterations=4):
        """Predict the flow-matching velocity for noisy latent z_t.

        Args:
            z_t: latent of shape (B, C, H, W), patchified into tokens.
            t: timestep conditioning, forwarded to the refinement core.
            context: text embedding whose last dim is projected to `dim`
                     when it differs.
            num_iterations: refinement passes run by the core.

        Returns:
            Velocity tensor with the same spatial layout as z_t.
        """
        tokens, H_tok, W_tok = self.patchify(z_t)

        # Project text embeddings to model dim if needed
        if self.context_proj is not None:
            context = self.context_proj(context)
        elif context.shape[-1] != self.dim:
            # Fallback: lazy projection for backwards compat.
            # NOTE: parameters created here bypass _init_weights and any
            # optimizer built before the first forward, so this projection
            # would train from random init (or not at all). Warn loudly
            # rather than failing silently.
            if not hasattr(self, '_lazy_context_proj'):
                warnings.warn(
                    "IRIS: creating a lazy context projection at forward "
                    "time; pass text_dim to __init__ so the projection is "
                    "initialized and registered before the optimizer is "
                    "built.",
                    RuntimeWarning,
                )
                self._lazy_context_proj = nn.Linear(
                    context.shape[-1], self.dim, bias=False
                ).to(context.device, context.dtype)
            context = self._lazy_context_proj(context)

        refined = self.core(tokens, context, t, H_tok, W_tok, num_iterations=num_iterations)
        return self.unpatchify(refined, H_tok, W_tok)

    def decode_latent(self, z):
        """Decode a latent to pixels with the auxiliary tiny decoder."""
        return self.tiny_decoder(z)

    def count_params(self):
        """Return a dict of parameter counts per top-level child plus totals."""
        counts = {}
        for name, module in self.named_children():
            counts[name] = sum(p.numel() for p in module.parameters())
        counts["total"] = sum(p.numel() for p in self.parameters())
        counts["trainable"] = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return counts