SynLayers committed on
Commit 2d767a5 · verified · 1 Parent(s): 696dd0c

Upload models/mmdit.py with huggingface_hub

Files changed (1)
  1. models/mmdit.py +356 -0
models/mmdit.py ADDED
@@ -0,0 +1,356 @@
import torch
import torch.nn as nn
from typing import Any, Dict, List, Optional, Union, Tuple
import numpy as np

from accelerate.utils import set_module_tensor_to_device
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.normalization import AdaLayerNormContinuous
from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel, FluxTransformerBlock, FluxSingleTransformerBlock

from diffusers.configuration_utils import register_to_config
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

class CustomFluxTransformer2DModel(FluxTransformer2DModel):
    """
    The Transformer model introduced in Flux.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Parameters:
        patch_size (`int`): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
        num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
        guidance_embeds (`bool`, defaults to `False`): Whether to use guidance embeddings.
        max_layer_num (`int`, *optional*, defaults to 52): The maximum number of image layers supported by the
            learnable layer positional embedding `layer_pe`.
    """

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: Tuple[int] = (16, 56, 56),
        max_layer_num: int = 52,
    ):
        super(FluxTransformer2DModel, self).__init__()
        self.out_channels = in_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
        )

        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for i in range(self.config.num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for i in range(self.config.num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

        self.max_layer_num = max_layer_num

        # the following process ensures self.layer_pe is not created as a meta tensor
        layer_pe_value = nn.init.trunc_normal_(
            nn.Parameter(torch.zeros(
                1, self.max_layer_num, 1, 1, self.inner_dim,
            )),
            mean=0.0, std=0.02, a=-2.0, b=2.0,
        ).data.detach()
        self.layer_pe = nn.Parameter(layer_pe_value)
        set_module_tensor_to_device(
            self,
            'layer_pe',
            device='cpu',
            value=layer_pe_value,
            dtype=layer_pe_value.dtype,
        )

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        model = super().from_pretrained(*args, **kwargs)
        # move layer_pe onto the same device as the rest of the loaded parameters
        for name, para in model.named_parameters():
            if name != 'layer_pe':
                device = para.device
                break
        model.layer_pe.data = model.layer_pe.data.to(device)
        return model

    def crop_each_layer(self, hidden_states, list_layer_box):
        """
        hidden_states: [1, n_layers, h, w, inner_dim]
        list_layer_box: List, length=n_layers, each element is a Tuple of 4 elements (x1, y1, x2, y2)
        """
        token_list = []
        for layer_idx in range(hidden_states.shape[1]):
            if list_layer_box[layer_idx] is None:
                continue
            else:
                x1, y1, x2, y2 = list_layer_box[layer_idx]
                x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
                layer_token = hidden_states[:, layer_idx, y1:y2, x1:x2, :]
                bs, h, w, c = layer_token.shape
                layer_token = layer_token.reshape(bs, -1, c)
                token_list.append(layer_token)
        result = torch.cat(token_list, dim=1)
        return result

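    # Worked example for `crop_each_layer` above (illustrative numbers, not from this file): boxes are
    # given in pixel coordinates and divided by 16 -- presumably the 8x VAE downsample times the 2x2
    # patchify -- to land on the h/2 x w/2 token grid. A box (x1, y1, x2, y2) = (0, 0, 64, 32)
    # becomes (0, 0, 4, 2), so that layer contributes (4 - 0) * (2 - 0) = 8 tokens to the packed
    # sequence; a layer whose box is None contributes nothing.
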
    def fill_in_processed_tokens(self, hidden_states, full_hidden_states, list_layer_box):
        """
        hidden_states: [1, h1xw1 + h2xw2 + ... + hlxwl, inner_dim]
        full_hidden_states: [1, n_layers, h, w, inner_dim]
        list_layer_box: List, length=n_layers, each element is a Tuple of 4 elements (x1, y1, x2, y2)
        """
        used_token_len = 0
        bs = hidden_states.shape[0]
        for layer_idx in range(full_hidden_states.shape[1]):
            if list_layer_box[layer_idx] is None:
                continue
            else:
                x1, y1, x2, y2 = list_layer_box[layer_idx]
                x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
                full_hidden_states[:, layer_idx, y1:y2, x1:x2, :] = hidden_states[:, used_token_len: used_token_len + (y2 - y1) * (x2 - x1), :].reshape(bs, y2 - y1, x2 - x1, -1)
                used_token_len = used_token_len + (y2 - y1) * (x2 - x1)
        return full_hidden_states

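    # Note on `fill_in_processed_tokens` above: it is the inverse scatter of `crop_each_layer` and
    # walks the boxes in the same order, so the packed token count it consumes matches the count the
    # crop produced; positions outside every box keep whatever `full_hidden_states` was initialised
    # with (zeros, in `forward` below).
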
    def forward(
        self,
        hidden_states: torch.Tensor,
        list_layer_box: Optional[List[Tuple]] = None,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        adapter_block_samples=None,
        adapter_single_block_samples=None,
        return_dict: bool = True,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, n_layers, channel, height, width)`):
                Input `hidden_states`.
            list_layer_box (`List[Tuple]`, *optional*):
                Per-layer bounding boxes `(x1, y1, x2, y2)` in pixel coordinates; an entry of `None` skips that layer.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        bs, n_layers, channel_latent, height, width = hidden_states.shape  # [bs, n_layers, c_latent, h, w]

        hidden_states = hidden_states.view(bs, n_layers, channel_latent, height // 2, 2, width // 2, 2)  # [bs, n_layers, c_latent, h/2, 2, w/2, 2]
        hidden_states = hidden_states.permute(0, 1, 3, 5, 2, 4, 6)  # [bs, n_layers, h/2, w/2, c_latent, 2, 2]
        hidden_states = hidden_states.reshape(bs, n_layers, height // 2, width // 2, channel_latent * 4)  # [bs, n_layers, h/2, w/2, c_latent*4]
        hidden_states = self.x_embedder(hidden_states)  # [bs, n_layers, h/2, w/2, inner_dim]

        full_hidden_states = torch.zeros_like(hidden_states)  # [bs, n_layers, h/2, w/2, inner_dim]
        layer_pe = self.layer_pe.view(1, self.max_layer_num, 1, 1, self.inner_dim)  # [1, max_n_layers, 1, 1, inner_dim]
        hidden_states = hidden_states + layer_pe[:, :n_layers]  # [bs, n_layers, h/2, w/2, inner_dim] + [1, n_layers, 1, 1, inner_dim] --> [bs, n_layers, h/2, w/2, inner_dim]
        hidden_states = self.crop_each_layer(hidden_states, list_layer_box)  # [bs, token_len, inner_dim]

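        # Shape walk-through for the patchify above (assumed sizes, not from the source): for a
        # 512x512 image the Flux VAE latent is 64x64 with 16 channels, so hidden_states enters as
        # [bs, n_layers, 16, 64, 64], the 2x2 pixel-shuffle packs it to [bs, n_layers, 32, 32, 64],
        # and x_embedder lifts the last dim to inner_dim (24 * 128 = 3072 with the default config);
        # crop_each_layer then keeps only the tokens inside each layer's box.
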
        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        else:
            guidance = None
        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor."
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor."
            )
            img_ids = img_ids[0]
        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        for index_block, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )

            # adapter residual
            if adapter_block_samples is not None:
                interval_adapter = len(self.transformer_blocks) / len(
                    adapter_block_samples
                )
                interval_adapter = int(np.ceil(interval_adapter))
                hidden_states = (
                    hidden_states
                    + adapter_block_samples[index_block // interval_adapter]
                )

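        # Adapter-interval example (hypothetical counts, not from this repo): with 19 transformer
        # blocks and 4 entries in adapter_block_samples, interval_adapter = ceil(19 / 4) = 5, so
        # blocks 0-4 reuse sample 0, blocks 5-9 sample 1, and so on; the single-stream loop below
        # applies the same mapping to adapter_single_block_samples.
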
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        for index_block, block in enumerate(self.single_transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )

            # adapter residual
            if adapter_single_block_samples is not None:
                interval_adapter = len(self.single_transformer_blocks) / len(
                    adapter_single_block_samples
                )
                interval_adapter = int(np.ceil(interval_adapter))
                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                    + adapter_single_block_samples[index_block // interval_adapter]
                )

        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

        hidden_states = self.fill_in_processed_tokens(hidden_states, full_hidden_states, list_layer_box)  # [bs, n_layers, h/2, w/2, inner_dim]
        hidden_states = hidden_states.view(bs, -1, self.inner_dim)  # [bs, n_layers * h/2 * w/2, inner_dim]

        hidden_states = self.norm_out(hidden_states, temb)  # [bs, n_layers * h/2 * w/2, inner_dim]
        hidden_states = self.proj_out(hidden_states)  # [bs, n_layers * h/2 * w/2, c_latent*4]

        # unpatchify
        hidden_states = hidden_states.view(bs, n_layers, height // 2, width // 2, channel_latent, 2, 2)  # [bs, n_layers, h/2, w/2, c_latent, 2, 2]
        hidden_states = hidden_states.permute(0, 1, 4, 2, 5, 3, 6)  # [bs, n_layers, c_latent, h/2, 2, w/2, 2]
        output = hidden_states.reshape(bs, n_layers, channel_latent, height, width)  # [bs, n_layers, c_latent, h, w]

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
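
For reference, a minimal call sketch follows. Everything in it — the checkpoint path, box values, sequence lengths, and how img_ids/txt_ids are laid out — is an illustrative assumption rather than part of this commit; the surrounding pipeline code is expected to build these inputs.

# Minimal usage sketch (assumed values, not from this repository).
model = CustomFluxTransformer2DModel.from_pretrained(
    "path/to/checkpoint", subfolder="transformer", torch_dtype=torch.float32
)

bs, n_layers, c_latent, h, w = 1, 3, 16, 64, 64              # e.g. three layers of a 512x512 image
list_layer_box = [(0, 0, 512, 512), (0, 0, 256, 256), None]  # pixel-space boxes; None skips a layer
latents = torch.randn(bs, n_layers, c_latent, h, w)

prompt_embeds = torch.randn(bs, 512, 4096)                   # [bs, seq_len, joint_attention_dim]
pooled_prompt_embeds = torch.randn(bs, 768)                  # [bs, pooled_projection_dim]
txt_ids = torch.zeros(512, 3)
img_seq_len = (512 // 16) ** 2 + (256 // 16) ** 2            # tokens kept by crop_each_layer (1280)
img_ids = torch.zeros(img_seq_len, 3)                        # real positional ids come from the pipeline

sample = model(
    hidden_states=latents,
    list_layer_box=list_layer_box,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled_prompt_embeds,
    timestep=torch.tensor([1.0]),
    img_ids=img_ids,
    txt_ids=txt_ids,
    # pass guidance=torch.tensor([3.5]) as well if the loaded config sets guidance_embeds=True
    return_dict=False,
)[0]                                                         # [bs, n_layers, c_latent, h, w]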