asdf98 committed on
Commit
88e5d09
·
verified ·
1 Parent(s): f307a14

Upload iris/model.py

Browse files
Files changed (1) hide show
  1. iris/model.py +114 -0
iris/model.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """IRIS: Complete model — patchify, refinement core, unpatchify, tiny decoder."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import math
7
+ from .core import RefinementCore
8
+
9
+
10
class Patchify(nn.Module):
    """Convert a latent feature map into a sequence of patch tokens.

    The input first goes through a depthwise 3x3 convolution. The result is
    cut into non-overlapping ``patch_size`` x ``patch_size`` patches; every
    patch is flattened in (channel, row, column) order and linearly projected
    to ``dim`` features.

    Args:
        in_channels: Number of channels of the incoming feature map.
        dim: Width of each produced token.
        patch_size: Side length of a square patch.
    """

    def __init__(self, in_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        # groups == in_channels gives one independent 3x3 filter per channel.
        self.dw_conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            padding=1,
            groups=in_channels,
            bias=True,
        )
        patch_features = in_channels * patch_size * patch_size
        self.proj = nn.Linear(patch_features, dim, bias=True)

    def forward(self, z):
        """Tokenize ``z`` of shape ``(B, C, H, W)``.

        Returns:
            A tuple ``(tokens, H_tok, W_tok)`` where ``tokens`` has shape
            ``(B, H_tok * W_tok, dim)`` and ``H_tok``/``W_tok`` are the
            number of patches along the height and width (``H // patch_size``
            and ``W // patch_size``).
        """
        batch, chans, height, width = z.shape
        size = self.patch_size
        rows = height // size
        cols = width // size

        feat = self.dw_conv(z)
        # Split both spatial axes into (patch index, offset inside patch).
        grid = feat.view(batch, chans, rows, size, cols, size)
        # (B, C, rows, p, cols, p) -> (B, rows, cols, C, p, p)
        grid = grid.permute(0, 2, 4, 1, 3, 5)
        flat = grid.reshape(batch, rows * cols, chans * size * size)
        return self.proj(flat), rows, cols
24
+
25
+
26
class Unpatchify(nn.Module):
    """Map a sequence of patch tokens back onto a 2-D feature map.

    Each token is linearly projected to ``out_channels * patch_size**2``
    values, interpreted as a (channel, row, column) patch, and the patches
    are tiled on an ``H_tok`` x ``W_tok`` grid. A depthwise 3x3 convolution
    is applied to the assembled map.

    Args:
        out_channels: Number of channels of the produced feature map.
        dim: Width of each incoming token.
        patch_size: Side length of a square patch.
    """

    def __init__(self, out_channels=32, dim=512, patch_size=4):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        patch_features = out_channels * patch_size * patch_size
        self.proj = nn.Linear(dim, patch_features, bias=True)
        # groups == out_channels gives one independent 3x3 filter per channel.
        self.dw_conv = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            groups=out_channels,
            bias=True,
        )

    def forward(self, tokens, H_tok, W_tok):
        """Assemble ``tokens`` of shape ``(B, H_tok * W_tok, dim)`` into a map.

        Returns:
            Tensor of shape
            ``(B, out_channels, H_tok * patch_size, W_tok * patch_size)``.
        """
        # Unpacking three values keeps the requirement that tokens is 3-D.
        batch, _, _ = tokens.shape
        size, chans = self.patch_size, self.out_channels

        patches = self.proj(tokens)
        patches = patches.view(batch, H_tok, W_tok, chans, size, size)
        # (B, rows, cols, C, p, p) -> (B, C, rows, p, cols, p)
        image = patches.permute(0, 3, 1, 4, 2, 5)
        image = image.reshape(batch, chans, H_tok * size, W_tok * size)
        return self.dw_conv(image)
41
+
42
+
43
class TinyDecoder(nn.Module):
    """Minimal latent->pixels decoder via PixelShuffle. ~0.1M params.

    Five stages each apply a 3x3 convolution followed by ``PixelShuffle(2)``,
    so the spatial resolution grows by a factor of 32 overall. Every stage
    except the last is followed by SiLU. A final 1x1 convolution and ``tanh``
    bound the output to ``[-1, 1]``.

    Args:
        in_channels: Number of channels of the latent input.
        out_channels: Number of channels of the decoded output.
    """

    def __init__(self, in_channels=32, out_channels=3):
        super().__init__()
        widths = (in_channels, 32, 32, 16, 8, out_channels)
        last_stage = len(widths) - 2

        blocks = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # PixelShuffle(2) folds 4 channels into one, hence the factor 4.
            conv = nn.Conv2d(c_in, c_out * 4, 3, padding=1, bias=True)
            shuffle = nn.PixelShuffle(2)
            activation = nn.Identity() if idx == last_stage else nn.SiLU()
            blocks.append(nn.Sequential(conv, shuffle, activation))
        self.stages = nn.ModuleList(blocks)

        self.final = nn.Conv2d(out_channels, out_channels, 1, bias=True)

    def forward(self, z):
        """Decode latent ``z`` of shape ``(B, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(B, out_channels, 32 * H, 32 * W)`` with values
            in ``[-1, 1]``.
        """
        hidden = z
        for block in self.stages:
            hidden = block(hidden)
        pixels = self.final(hidden)
        return torch.tanh(pixels)
62
+
63
+
64
class IRIS(nn.Module):
    """
    IRIS: Iterative Refinement Image Synthesizer.
    Predicts velocity v_theta(z_t, t, c) for flow matching.

    The model patchifies the latent, runs the token sequence through a
    ``RefinementCore`` and un-patchifies the result back to latent shape.
    A small pixel decoder is attached for decoding latents.

    Args:
        latent_channels: Channels of the latent fed to ``forward``.
        dim: Token width used by the patchifier, the core and the unpatchifier.
        patch_size: Side length of a square latent patch.
        num_blocks, num_heads, max_iterations, ffn_expansion,
        gradient_checkpointing: Forwarded unchanged to ``RefinementCore``.
        context_dim: Feature width of the conditioning ``context``. When given
            and different from ``dim``, a bias-free linear projection
            ``context_dim -> dim`` is registered at construction time, so it
            is part of ``parameters()`` and ``state_dict()`` from the start.
            ``None`` (default) keeps the previous behaviour: no projection is
            created up front.
        spatial_size: Forwarded to ``RefinementCore``. The default ``4``
            corresponds to a 16x16 latent with ``patch_size=4``
            (presumably tokens per side; confirm against ``RefinementCore``).
    """

    def __init__(
        self,
        latent_channels=32,
        dim=512,
        patch_size=4,
        num_blocks=6,
        num_heads=8,
        max_iterations=8,
        ffn_expansion=2,
        gradient_checkpointing=True,
        context_dim=None,
        spatial_size=4,
    ):
        super().__init__()
        self.latent_channels = latent_channels
        self.dim = dim
        self.patch_size = patch_size

        self.patchify = Patchify(latent_channels, dim, patch_size)
        self.unpatchify = Unpatchify(latent_channels, dim, patch_size)
        self.core = RefinementCore(
            dim=dim,
            num_blocks=num_blocks,
            num_heads=num_heads,
            spatial_size=spatial_size,
            max_iterations=max_iterations,
            ffn_expansion=ffn_expansion,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.tiny_decoder = TinyDecoder(latent_channels, out_channels=3)

        # Register the context projection eagerly when its input width is
        # known, so optimizers and checkpoints created right after
        # construction include it. The attribute name matches the one used by
        # the lazy fallback so existing checkpoints keep loading.
        if context_dim is not None and context_dim != dim:
            self._context_proj = nn.Linear(context_dim, dim, bias=False)

        self._init_weights()

    def _init_weights(self):
        """Initialise all Linear/Conv2d/norm layers, then zero the output head."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        # Zero the token->patch projection so the un-patchified prediction
        # starts at zero (the following depthwise conv has a zero bias).
        nn.init.zeros_(self.unpatchify.proj.weight)
        nn.init.zeros_(self.unpatchify.proj.bias)

    def _project_context(self, context):
        """Return ``context`` with its last dimension equal to ``self.dim``.

        Context that already has width ``dim`` is returned untouched.
        Otherwise the registered ``_context_proj`` is applied. If none exists
        yet, one is created on the fly (legacy behaviour) and a warning is
        emitted, because parameters created here are unknown to any optimizer
        built earlier.

        Raises:
            ValueError: If a projection exists but was built for a different
                context width.
        """
        width = context.shape[-1]
        if width == self.dim:
            return context

        proj = getattr(self, "_context_proj", None)
        if proj is None:
            import warnings

            warnings.warn(
                "IRIS is creating its context projection lazily inside forward(); "
                "its weights are not part of optimizers or checkpoints created "
                "earlier. Pass context_dim to IRIS(...) to register it up front.",
                RuntimeWarning,
                stacklevel=2,
            )
            proj = nn.Linear(width, self.dim, bias=False).to(context.device, context.dtype)
            self._context_proj = proj
        elif proj.in_features != width:
            raise ValueError(
                f"context has feature width {width}, but the context projection "
                f"expects {proj.in_features}"
            )
        return proj(context)

    def forward(self, z_t, t, context, num_iterations=4):
        """Predict the velocity for latent ``z_t``.

        Args:
            z_t: Latent of shape ``(B, latent_channels, H, W)``.
            t: Forwarded unchanged to the refinement core (presumably the
                flow-matching time; confirm against ``RefinementCore``).
            context: Conditioning tensor whose last dimension is either
                ``dim`` or the width handled by the context projection.
            num_iterations: Forwarded to the refinement core.

        Returns:
            Tensor of shape
            ``(B, latent_channels, H_tok * patch_size, W_tok * patch_size)``
            with ``H_tok = H // patch_size`` and ``W_tok = W // patch_size``.
        """
        tokens, H_tok, W_tok = self.patchify(z_t)
        context = self._project_context(context)
        refined = self.core(tokens, context, t, H_tok, W_tok, num_iterations=num_iterations)
        return self.unpatchify(refined, H_tok, W_tok)

    def decode_latent(self, z):
        """Decode latent ``z`` with the attached tiny decoder."""
        return self.tiny_decoder(z)

    def count_params(self):
        """Return parameter counts per direct child plus overall totals.

        Returns:
            Dict mapping each direct child module name to its number of
            parameters, plus ``"total"`` (all parameters) and ``"trainable"``
            (parameters with ``requires_grad=True``).
        """
        counts = {
            name: sum(p.numel() for p in module.parameters())
            for name, module in self.named_children()
        }
        counts["total"] = sum(p.numel() for p in self.parameters())
        counts["trainable"] = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return counts