SynLayers committed on
Commit 4413f35 · verified · 1 Parent(s): b822c2d

Upload models/transp_vae.py with huggingface_hub

Files changed (1)
  1. models/transp_vae.py +335 -0
models/transp_vae.py ADDED
@@ -0,0 +1,335 @@
import torch
import torch.nn as nn
import torchvision
import einops
from collections import OrderedDict
from functools import partial
from typing import Callable
from torch.utils.checkpoint import checkpoint
from diffusers.models.embeddings import apply_rotary_emb, FluxPosEmbed


class MLPBlock(torchvision.ops.misc.MLP):
    """Transformer MLP block."""

    _version = 2

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
            for i in range(2):
                for type in ["weight", "bias"]:
                    old_key = f"{prefix}linear_{i+1}.{type}"
                    new_key = f"{prefix}{3*i}.{type}"
                    if old_key in state_dict:
                        state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor, freqs_cis):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        B, L, C = input.shape
        x = self.ln_1(input)
        if freqs_cis is not None:
            # Rotary position embedding is applied per head to the query, which is also reused as the key below.
            query = x.view(B, L, self.num_heads, self.hidden_dim // self.num_heads).transpose(1, 2)
            query = apply_rotary_emb(query, freqs_cis)
            query = query.transpose(1, 2).reshape(B, L, self.hidden_dim)
        else:
            query = x
        x, _ = self.self_attention(query, query, x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y


class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Note that batch_size is on the first dim because
        # we have batch_first=True in nn.MultiheadAttention() by default
        # self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor, freqs_cis, use_checkpoint=True):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input  # + self.pos_embedding
        x = self.dropout(input)
        # x = checkpoint_sequential(self.layers, len(self.layers), x)
        # x = self.layers(x)
        for l in self.layers:
            x = checkpoint(l, x, freqs_cis) if use_checkpoint else l(x, freqs_cis)
        x = self.ln(x)
        return x


class ViTEncoder(nn.Module):
    def __init__(self, arch='vit-b/32', use_checkpoint=True):
        super().__init__()
        self.arch = arch
        self.use_checkpoint = use_checkpoint

        if self.arch == 'vit-b/32':
            ch = 768
            layers = 12
            heads = 12
        elif self.arch == 'vit-h/14':
            ch = 1280
            layers = 32
            heads = 16
        else:
            raise ValueError(f"Unsupported ViT architecture: {self.arch}")

        self.encoder = Encoder(
            seq_length=-1,
            num_layers=layers,
            num_heads=heads,
            hidden_dim=ch,
            mlp_dim=ch * 4,
            dropout=0.0,
            attention_dropout=0.0,
        )
        self.fc_in = nn.Linear(16, ch)
        self.fc_out = nn.Linear(ch, 256)
        # self.act = nn.Sigmoid()

        if self.arch == 'vit-b/32':
            from torchvision.models.vision_transformer import vit_b_32, ViT_B_32_Weights
            vit = vit_b_32(weights=ViT_B_32_Weights.DEFAULT)
        elif self.arch == 'vit-h/14':
            from torchvision.models.vision_transformer import vit_h_14, ViT_H_14_Weights
            vit = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1)

        missing_keys, unexpected_keys = self.encoder.load_state_dict(vit.encoder.state_dict(), strict=False)
        if len(missing_keys) > 0 or len(unexpected_keys) > 0:
            print(f"ViT Encoder Missing keys: {missing_keys}")
            print(f"ViT Encoder Unexpected keys: {unexpected_keys}")
        del vit

    def forward(self, x, freqs_cis):
        # o = checkpoint(self.fc_in, x)
        o = self.fc_in(x)
        o = self.encoder(o, freqs_cis, self.use_checkpoint)
        o = checkpoint(self.fc_out, o) if self.use_checkpoint else self.fc_out(o)
        # o = self.fc_out(self.encoder(self.fc_in(x), freqs_cis))
        return o


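# patchify / unpatchify convert between a spatial map and a channel-packed layout:
# patchify folds every (patch_size x patch_size) block into the channel dimension
# (space-to-depth) and unpatchify is its exact inverse; decode() below uses
# unpatchify to turn 256-channel tokens back into 4-channel crops at 8x resolution.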
def patchify(x, patch_size=8):
    if len(x.shape) == 4:
        bs, c, h, w = x.shape
        x = einops.rearrange(x, "b c (h p1) (w p2) -> b (c p1 p2) h w", p1=patch_size, p2=patch_size)
    elif len(x.shape) == 3:
        c, h, w = x.shape
        x = einops.rearrange(x, "c (h p1) (w p2) -> (c p1 p2) h w", p1=patch_size, p2=patch_size)
    return x


def unpatchify(x, patch_size=8):
    if len(x.shape) == 4:
        bs, c, h, w = x.shape
        x = einops.rearrange(x, "b (c p1 p2) h w -> b c (h p1) (w p2)", p1=patch_size, p2=patch_size)
    elif len(x.shape) == 3:
        c, h, w = x.shape
        x = einops.rearrange(x, "(c p1 p2) h w -> c (h p1) (w p2)", p1=patch_size, p2=patch_size)
    return x


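# crop_each_layer gathers, for every non-empty layer, the features inside that
# layer's bounding box (boxes arrive in pixel coordinates and are divided by 8 to
# index the latent grid), flattens them into one token sequence, and, when a RoPE
# position embedding is supplied, collects matching per-token cos/sin tables whose
# first axis carries the layer id from use_layers.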
def crop_each_layer(hidden_states, use_layers, list_layer_box, H, W, pos_embedding=None):
    token_list = []
    cos_list, sin_list = [], []
    for layer_idx in range(hidden_states.shape[1]):
        if list_layer_box[layer_idx] is None:
            continue
        else:
            x1, y1, x2, y2 = list_layer_box[layer_idx]
            x1, y1, x2, y2 = x1 // 8, y1 // 8, x2 // 8, y2 // 8
            layer_token = hidden_states[:, layer_idx, y1:y2, x1:x2]
            c, h, w = layer_token.shape
            layer_token = layer_token.reshape(c, -1)
            token_list.append(layer_token)
            if pos_embedding is not None:
                ids = prepare_latent_image_ids(-1, H * 2, W * 2, hidden_states.device, hidden_states.dtype)
                ids[:, 0] = use_layers[layer_idx]
                image_rotary_emb = pos_embedding(ids)
                pos_cos, pos_sin = image_rotary_emb[0].reshape(H, W, -1), image_rotary_emb[1].reshape(H, W, -1)
                cos_list.append(pos_cos[y1:y2, x1:x2].reshape(-1, 64))
                sin_list.append(pos_sin[y1:y2, x1:x2].reshape(-1, 64))
    token_list = torch.cat(token_list, dim=1).permute(1, 0)
    if pos_embedding is not None:
        cos_list = torch.cat(cos_list, dim=0)
        sin_list = torch.cat(sin_list, dim=0)
    return token_list, (cos_list, sin_list)


def prepare_latent_image_ids(batch_size, height, width, device, dtype):
    latent_image_ids = torch.zeros(height // 2, width // 2, 3)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

    latent_image_ids = latent_image_ids.reshape(
        latent_image_id_height * latent_image_id_width, latent_image_id_channels
    )

    return latent_image_ids.to(device=device, dtype=dtype)


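# Despite its name, ViTEncoder is reused below as the transformer decoder: encode()
# packs each sample's per-layer latents into a token sequence (plus optional RoPE /
# absolute position and layer embeddings), and decode() maps those tokens back to
# 256-dim patch tokens that unpatchify expands into per-layer 4-channel images.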
class AutoencoderKLTransformerTraining(nn.Module):
    def __init__(self, args):
        super().__init__()

        self.args = args

        self.decoder = ViTEncoder(use_checkpoint=self.args.single_layer_decoder is None)
        self.decoder.requires_grad_(True)

        if self.args.pos_embedding == 'rope':
            self.pos_embedding = FluxPosEmbed(theta=10000, axes_dim=(8, 28, 28))
        elif self.args.pos_embedding == 'abs':
            self.pos_embedding = nn.Parameter(torch.empty(16, 1, args.resolution // 8, args.resolution // 8).normal_(std=0.02), requires_grad=True)

        if 'rel' in self.args.layer_embedding or 'abs' in self.args.layer_embedding:
            self.layer_embedding = nn.Parameter(torch.empty(16, 2 + self.args.max_layers, 1, 1).normal_(std=0.02), requires_grad=True)

    def encode(self, x, box, use_layers, z_2d):
        B, C, T, H, W = x.shape  # H and W are the original image size (in ART, H and W are the latent size); question: why? (it seems to make no difference)

        z, freqs_cis = [], []
        for b in range(B):
            _z = z_2d[b]
            if 'vit' in self.args.decoder_arch:
                _use_layers = torch.tensor(use_layers[b], device=x.device)
                if 'rel' in self.args.layer_embedding:
                    _use_layers[_use_layers > 2] = 2
                if 'rel' in self.args.layer_embedding or 'abs' in self.args.layer_embedding:
                    _z = _z + self.layer_embedding[:, _use_layers]
                if 'abs' in self.args.pos_embedding:
                    _z = _z + self.pos_embedding
            if 'rope' not in self.args.layer_embedding:
                use_layers[b] = [0] * len(use_layers[b])
            _z, cis = crop_each_layer(_z, use_layers[b], box[b], H, W, self.pos_embedding if self.args.pos_embedding == 'rope' else None)
            # _z, cis = crop_each_layer(_z, use_layers[b], box[b], H // 8, W // 8, self.pos_embedding if self.args.pos_embedding == 'rope' else None)
            z.append(_z)
            freqs_cis.append(cis)

        return z, freqs_cis

    def decode(self, z, freqs_cis, box, H, W):
        B = len(z)
        pad = torch.zeros(4, H, W, device=z[0].device, dtype=z[0].dtype)
        pad[3, :, :] = -1
        x = []
        for b in range(B):
            _x = []
            _freqs_cis = freqs_cis[b] if 'rope' in self.args.pos_embedding else None
            if self.args.single_layer_decoder is None:
                _z = self.decoder(z[b].unsqueeze(0), _freqs_cis).squeeze(0)
            else:
                _z = z[b]
            current_index = 0
            for layer_idx in range(len(box[b])):
                if box[b][layer_idx] is None:
                    _x.append(pad.clone())
                else:
                    x1, y1, x2, y2 = box[b][layer_idx]
                    x1_tok, y1_tok, x2_tok, y2_tok = x1 // 8, y1 // 8, x2 // 8, y2 // 8
                    token_length = (x2_tok - x1_tok) * (y2_tok - y1_tok)
                    tokens = _z[current_index:current_index + token_length]
                    if self.args.single_layer_decoder == 'vit':  # single-layer ViT decoder
                        tokens = self.decoder(tokens.unsqueeze(0), (_freqs_cis[0][current_index:current_index + token_length], _freqs_cis[1][current_index:current_index + token_length])).squeeze(0)
                    pixels = einops.rearrange(tokens, "(h w) c -> c h w", h=y2_tok - y1_tok, w=x2_tok - x1_tok)
                    unpatched = unpatchify(pixels)
                    pixels = pad.clone()
                    pixels[:, y1:y2, x1:x2] = unpatched
                    _x.append(pixels)
                    current_index += token_length
            _x = torch.stack(_x, dim=1)
            x.append(_x)
        x = torch.stack(x, dim=0)

        return x

    def forward(self, x, box, use_layers, z_2d):
        B, C, T, H, W = x.shape  # H and W are the original image size (in ART, H and W are the latent size)
        z, freqs_cis = self.encode(x, box, use_layers, z_2d)
        x_hat = self.decode(z, freqs_cis, box, H, W)
        return x_hat
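
For orientation, here is a minimal smoke-test sketch of how this module might be driven; it is not part of the uploaded file. The argparse-style field names are inferred from the self.args accesses above, while the concrete resolution, boxes, layer ids, and latent shapes are illustrative assumptions only.

from types import SimpleNamespace
import torch
from models.transp_vae import AutoencoderKLTransformerTraining

# Hypothetical config: field names mirror the self.args.* accesses; values are assumptions.
args = SimpleNamespace(
    single_layer_decoder=None,   # use the joint multi-layer ViT decoder path
    decoder_arch='vit',
    pos_embedding='rope',
    layer_embedding='rope',
    max_layers=8,
    resolution=256,
)
# Note: constructing the model downloads torchvision's ViT-B/32 weights for initialisation.
model = AutoencoderKLTransformerTraining(args).eval()

B, T, H, W = 1, 2, 256, 256                      # one sample with two layers (assumed sizes)
x = torch.zeros(B, 4, T, H, W)                   # forward() only reads its shape
z_2d = [torch.randn(16, T, H // 8, W // 8)]      # per-sample 16-channel latent grid (assumed)
box = [[(0, 0, 128, 128), (64, 64, 192, 192)]]   # per-layer crops in pixel coords, multiples of 8
use_layers = [[0, 1]]                            # layer ids fed to the RoPE layer axis

with torch.no_grad():
    x_hat = model(x, box, use_layers, z_2d)      # -> (B, 4, T, H, W) per-layer reconstructions
print(x_hat.shape)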