File size: 4,911 Bytes
88e5d09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""IRIS: Complete model — patchify, refinement core, unpatchify, tiny decoder."""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from .core import RefinementCore


class Patchify(nn.Module):
    """Turn a latent map (B, C, H, W) into a sequence of patch tokens.

    A depthwise 3x3 conv mixes local context first; non-overlapping
    p x p patches are then flattened and linearly projected to ``dim``.
    """

    def __init__(self, in_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        # Depthwise conv: cheap per-channel local mixing before tokenization.
        self.dw_conv = nn.Conv2d(in_channels, in_channels, 3, padding=1,
                                 groups=in_channels, bias=True)
        # Each flattened patch (C * p * p values) maps to one token embedding.
        self.proj = nn.Linear(in_channels * patch_size * patch_size, dim, bias=True)

    def forward(self, z):
        """Return (tokens, H_tok, W_tok) with tokens of shape (B, H_tok*W_tok, dim).

        Assumes H and W are divisible by ``patch_size``.
        """
        batch, channels, height, width = z.shape
        p = self.patch_size
        mixed = self.dw_conv(z)
        h_tok = height // p
        w_tok = width // p
        # (B, C, H, W) -> (B, h_tok, w_tok, C, p, p) -> (B, N, C*p*p)
        patches = mixed.view(batch, channels, h_tok, p, w_tok, p)
        patches = patches.permute(0, 2, 4, 1, 3, 5)
        flat = patches.reshape(batch, h_tok * w_tok, channels * p * p)
        return self.proj(flat), h_tok, w_tok


class Unpatchify(nn.Module):
    """Inverse of Patchify: fold a token sequence back into a latent map.

    Tokens are projected to ``out_channels * p * p`` values each, rearranged
    into a (B, C, H_tok*p, W_tok*p) grid, then smoothed by a depthwise conv.
    """

    def __init__(self, out_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        # Token embedding -> flattened p x p patch of out_channels.
        self.proj = nn.Linear(dim, out_channels * patch_size * patch_size, bias=True)
        # Depthwise conv smooths patch seams after folding.
        self.dw_conv = nn.Conv2d(out_channels, out_channels, 3, padding=1,
                                 groups=out_channels, bias=True)

    def forward(self, tokens, H_tok, W_tok):
        """Map (B, H_tok*W_tok, dim) tokens to a (B, C, H_tok*p, W_tok*p) map."""
        batch = tokens.shape[0]
        p = self.patch_size
        c = self.out_channels
        # (B, N, D) -> (B, H_tok, W_tok, C, p, p) -> (B, C, H, W)
        grid = self.proj(tokens).view(batch, H_tok, W_tok, c, p, p)
        grid = grid.permute(0, 3, 1, 4, 2, 5)
        folded = grid.reshape(batch, c, H_tok * p, W_tok * p)
        return self.dw_conv(folded)


class TinyDecoder(nn.Module):
    """Minimal latent->pixels decoder via PixelShuffle. ~0.1M params.

    Five conv + PixelShuffle(2) stages give a total 32x spatial upsample;
    a final 1x1 conv + tanh squashes the output to [-1, 1].
    """

    def __init__(self, in_channels=32, out_channels=3):
        super().__init__()
        widths = [in_channels, 32, 32, 16, 8, out_channels]
        last = len(widths) - 2
        stages = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # Conv emits 4x channels, which PixelShuffle(2) trades for 2x H and W.
            # All stages but the last are followed by SiLU.
            stages.append(nn.Sequential(
                nn.Conv2d(c_in, c_out * 4, 3, padding=1, bias=True),
                nn.PixelShuffle(2),
                nn.SiLU() if idx < last else nn.Identity(),
            ))
        self.stages = nn.ModuleList(stages)
        self.final = nn.Conv2d(out_channels, out_channels, 1, bias=True)

    def forward(self, z):
        """Upsample latent ``z`` by 32x and map to pixels in [-1, 1]."""
        out = z
        for stage in self.stages:
            out = stage(out)
        return torch.tanh(self.final(out))


class IRIS(nn.Module):
    """
    IRIS: Iterative Refinement Image Synthesizer.
    Predicts velocity v_theta(z_t, t, c) for flow matching.

    Pipeline: Patchify latent -> RefinementCore over tokens -> Unpatchify.
    A separate TinyDecoder maps latents back to pixels.

    Args:
        latent_channels: channels of the latent map z_t.
        dim: token embedding width used by the refinement core.
        patch_size: side length of square patches.
        num_blocks, num_heads, max_iterations, ffn_expansion,
        gradient_checkpointing: forwarded to RefinementCore.
        context_dim: width of the conditioning context's last axis. If
            provided and different from ``dim``, the context projection is
            built eagerly here so its parameters are registered before
            optimizer construction, included in ``state_dict`` from the
            start, and initialized by ``_init_weights``. If None (default),
            the projection is created lazily on first forward for backward
            compatibility — see the warning in ``forward``.
    """
    def __init__(self, latent_channels=32, dim=512, patch_size=4, num_blocks=6, num_heads=8, max_iterations=8, ffn_expansion=2, gradient_checkpointing=True, context_dim=None):
        super().__init__()
        self.latent_channels = latent_channels
        self.dim = dim
        self.patch_size = patch_size

        self.patchify = Patchify(latent_channels, dim, patch_size)
        self.unpatchify = Unpatchify(latent_channels, dim, patch_size)
        spatial_size = 4  # default for 16x16 latent with ps=4 — TODO confirm for other latent sizes
        self.core = RefinementCore(dim=dim, num_blocks=num_blocks, num_heads=num_heads, spatial_size=spatial_size, max_iterations=max_iterations, ffn_expansion=ffn_expansion, gradient_checkpointing=gradient_checkpointing)
        self.tiny_decoder = TinyDecoder(latent_channels, out_channels=3)
        # Eager context projection: created before _init_weights so it is
        # initialized with the same xavier scheme as other Linear layers.
        if context_dim is not None and context_dim != dim:
            self._context_proj = nn.Linear(context_dim, dim, bias=False)
        self._init_weights()

    def _init_weights(self):
        """Xavier for Linear, Kaiming for Conv2d, unit/zero for norms.

        The unpatchify projection is zeroed afterwards so the model starts
        by predicting zero velocity.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
                if m.weight is not None: nn.init.ones_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
        # Zero-init the output projection (overrides the xavier init above).
        nn.init.zeros_(self.unpatchify.proj.weight)
        nn.init.zeros_(self.unpatchify.proj.bias)

    def forward(self, z_t, t, context, num_iterations=4):
        """Predict the flow-matching velocity for the noisy latent ``z_t``.

        Args:
            z_t: (B, latent_channels, H, W) latent; H, W divisible by patch_size.
            t: timestep input, passed through to the refinement core.
            context: conditioning tensor; last axis is projected to ``dim``
                if it does not already match.
            num_iterations: refinement iterations run by the core.

        Returns:
            Tensor with the same shape as ``z_t``.
        """
        tokens, H_tok, W_tok = self.patchify(z_t)
        if context.shape[-1] != self.dim:
            if not hasattr(self, '_context_proj'):
                # WARNING: lazy fallback for callers that did not pass
                # context_dim to __init__. Parameters created here are
                # invisible to optimizers built before the first forward,
                # absent from earlier checkpoints, and skip _init_weights.
                # Prefer passing context_dim at construction.
                self._context_proj = nn.Linear(context.shape[-1], self.dim, bias=False).to(context.device, context.dtype)
            context = self._context_proj(context)
        refined = self.core(tokens, context, t, H_tok, W_tok, num_iterations=num_iterations)
        return self.unpatchify(refined, H_tok, W_tok)

    def decode_latent(self, z):
        """Decode a latent map to pixels with the auxiliary TinyDecoder."""
        return self.tiny_decoder(z)

    def count_params(self):
        """Return a dict of parameter counts per child module, plus totals."""
        counts = {}
        for name, module in self.named_children():
            counts[name] = sum(p.numel() for p in module.parameters())
        counts["total"] = sum(p.numel() for p in self.parameters())
        counts["trainable"] = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return counts