SynLayers committed on
Commit 696dd0c · verified · 1 Parent(s): 7da8645

Upload models/multiLayer_adapter.py with huggingface_hub

Files changed (1)
  1. models/multiLayer_adapter.py +458 -0
models/multiLayer_adapter.py ADDED
@@ -0,0 +1,458 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_version,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.models.controlnet import BaseOutput, zero_module
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.embeddings import (
    CombinedTimestepGuidanceTextProjEmbeddings,
    CombinedTimestepTextProjEmbeddings,
    FluxPosEmbed,
)
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock, FluxSingleTransformerBlock


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class MultiLayerAdapterOutput(BaseOutput):
    adapter_block_samples: Tuple[torch.Tensor]
    adapter_single_block_samples: Tuple[torch.Tensor]


class MultiLayerAdapter(ModelMixin, ConfigMixin, PeftAdapterMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: List[int] = [16, 56, 56],
        extra_condition_channels: int = 1 * 4,
    ):
        super().__init__()
        self.out_channels = in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings
            if guidance_embeds
            else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )

        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
        self.x_embedder = nn.Linear(in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_single_layers)
            ]
        )

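        # ControlNet-style output projections, one per block; `zero_module` zero-initializes
        # them so the adapter contributes nothing to the base model at the start of training.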
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            self.controlnet_blocks.append(
                zero_module(nn.Linear(self.inner_dim, self.inner_dim))
            )

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(len(self.single_transformer_blocks)):
            self.controlnet_single_blocks.append(
                zero_module(nn.Linear(self.inner_dim, self.inner_dim))
            )

        self.controlnet_x_embedder = zero_module(
            torch.nn.Linear(in_channels + extra_condition_channels, self.inner_dim)
        )

        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self):
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @classmethod
    def from_transformer(
        cls,
        transformer,
        num_layers: int = 4,
        num_single_layers: int = 10,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        load_weights_from_transformer=True,
    ):
        config = transformer.config
        config["num_layers"] = num_layers
        config["num_single_layers"] = num_single_layers
        config["attention_head_dim"] = attention_head_dim
        config["num_attention_heads"] = num_attention_heads

        adapter = cls(**config)

        if load_weights_from_transformer:
            adapter.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
            adapter.time_text_embed.load_state_dict(
                transformer.time_text_embed.state_dict()
            )
            adapter.context_embedder.load_state_dict(
                transformer.context_embedder.state_dict()
            )
            adapter.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
            adapter.transformer_blocks.load_state_dict(
                transformer.transformer_blocks.state_dict(), strict=False
            )
            adapter.single_transformer_blocks.load_state_dict(
                transformer.single_transformer_blocks.state_dict(), strict=False
            )

            adapter.controlnet_x_embedder = zero_module(
                adapter.controlnet_x_embedder
            )

        return adapter

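    # Keeps only the tokens that fall inside each layer's bounding box and concatenates them
    # into a single sequence. The boxes are presumably given in pixel coordinates of the
    # full-resolution image: // 16 matches the 8x VAE downsampling times the 2x2 token
    # packing done in `forward`.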
    def crop_each_layer(self, hidden_states, list_layer_box):
        """
        hidden_states: [1, n_layers, h, w, inner_dim]
        list_layer_box: List, length=n_layers, each element is a Tuple of 4 elements (x1, y1, x2, y2)
        """
        token_list = []
        for layer_idx in range(hidden_states.shape[1]):
            if list_layer_box[layer_idx] is None:
                continue
            else:
                x1, y1, x2, y2 = list_layer_box[layer_idx]
                x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16
                layer_token = hidden_states[:, layer_idx, y1:y2, x1:x2, :]
                bs, h, w, c = layer_token.shape
                layer_token = layer_token.reshape(bs, -1, c)
                token_list.append(layer_token)
        result = torch.cat(token_list, dim=1)
        return result

    def set_layerPE(self, layerPE, max_layer_num):
        self.layer_pe = layerPE
        self.max_layer_num = max_layer_num

    def forward(
        self,
        hidden_states: torch.Tensor,
        list_layer_box: List[Tuple] = None,
        adapter_cond: torch.Tensor = None,
        conditioning_scale: float = 1.0,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`MultiLayerAdapter`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, n_layers, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`MultiLayerAdapterOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is True, a [`MultiLayerAdapterOutput`] is returned, otherwise a `tuple` of the
            adapter block samples and adapter single block samples.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if (
                joint_attention_kwargs is not None
                and joint_attention_kwargs.get("scale", None) is not None
            ):
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

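        # Pack each layer's latent into Flux-style tokens: the 2x2 pixel-shuffle below turns
        # [bs, n_layers, c_latent, h, w] into [bs, n_layers, h/2, w/2, c_latent*4], so every
        # spatial position becomes one token of dimension c_latent*4 (= `in_channels`).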
        bs, n_layers, channel_latent, height, width = hidden_states.shape  # [bs, n_layers, c_latent, h, w]

        hidden_states = hidden_states.view(bs, n_layers, channel_latent, height // 2, 2, width // 2, 2)  # [bs, n_layers, c_latent, h/2, 2, w/2, 2]
        hidden_states = hidden_states.permute(0, 1, 3, 5, 2, 4, 6)  # [bs, n_layers, h/2, w/2, c_latent, 2, 2]
        hidden_states = hidden_states.reshape(bs, n_layers, height // 2, width // 2, channel_latent * 4)  # [bs, n_layers, h/2, w/2, c_latent*4]
        hidden_states = self.x_embedder(hidden_states)

        adapter_cond = adapter_cond.view(1, height // 2, width // 2, channel_latent * 4 + 4)
        adapter_cond = adapter_cond.unsqueeze(1).expand(-1, n_layers, -1, -1, -1)  # [1, n_layers, h/2, w/2, c_latent*4 + 4]

        # add condition
        hidden_states = hidden_states + self.controlnet_x_embedder(adapter_cond)

        full_hidden_states = torch.zeros_like(hidden_states)  # [bs, n_layers, h/2, w/2, inner_dim]
        layer_pe = self.layer_pe.view(1, self.max_layer_num, 1, 1, self.inner_dim)  # [1, max_n_layers, 1, 1, inner_dim]
        hidden_states = hidden_states + layer_pe[:, :n_layers]  # [bs, n_layers, h/2, w/2, inner_dim] + [1, n_layers, 1, 1, inner_dim] --> [bs, n_layers, h/2, w/2, inner_dim]
        hidden_states = self.crop_each_layer(hidden_states, list_layer_box)  # [bs, token_len, inner_dim]

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        else:
            guidance = None
        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]
        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        block_samples = ()
        for _, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = (
                    {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                )
                (
                    encoder_hidden_states,
                    hidden_states,
                ) = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            block_samples = block_samples + (hidden_states,)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        single_block_samples = ()
        for _, block in enumerate(self.single_transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = (
                    {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            single_block_samples = single_block_samples + (
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

        adapter_block_samples = ()
        for block_sample, adapter_block in zip(
            block_samples, self.controlnet_blocks
        ):
            block_sample = adapter_block(block_sample)
            adapter_block_samples = adapter_block_samples + (block_sample,)

        adapter_single_block_samples = ()
        for single_block_sample, adapter_block in zip(
            single_block_samples, self.controlnet_single_blocks
        ):
            single_block_sample = adapter_block(single_block_sample)
            adapter_single_block_samples = adapter_single_block_samples + (
                single_block_sample,
            )

        # scaling
        adapter_block_samples = [
            sample * conditioning_scale for sample in adapter_block_samples
        ]
        adapter_single_block_samples = [
            sample * conditioning_scale for sample in adapter_single_block_samples
        ]

        adapter_block_samples = (
            None if len(adapter_block_samples) == 0 else adapter_block_samples
        )
        adapter_single_block_samples = (
            None
            if len(adapter_single_block_samples) == 0
            else adapter_single_block_samples
        )

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (adapter_block_samples, adapter_single_block_samples)

        return MultiLayerAdapterOutput(
            adapter_block_samples=adapter_block_samples,
            adapter_single_block_samples=adapter_single_block_samples,
        )
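
For orientation, below is a minimal smoke-test sketch of how this adapter might be driven. It is not part of the committed file: the downscaled configuration, the dummy tensors, the 128x128 layer boxes, and the `models.multiLayer_adapter` import path are all illustrative assumptions, and it presumes a diffusers version whose Flux blocks and `diffusers.models.controlnet` module match the imports above. In practice the adapter would be built with its full-size defaults, or via `from_transformer` from a pretrained `FluxTransformer2DModel`, and its outputs presumably consumed as ControlNet-style residuals by the base Flux transformer.

import torch
from models.multiLayer_adapter import MultiLayerAdapter  # assumes the repo root is on sys.path

# Hypothetical tiny configuration purely for a shape check; the real model uses the defaults
# (19 double blocks, 38 single blocks, inner_dim 3072).
adapter = MultiLayerAdapter(
    in_channels=16,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    pooled_projection_dim=16,
    axes_dims_rope=[2, 2, 4],
    extra_condition_channels=4,
)
max_layer_num = 4
adapter.set_layerPE(torch.zeros(max_layer_num, adapter.inner_dim), max_layer_num)

bs, n_layers, c_latent, h, w = 1, 2, 4, 16, 16            # 16x16 latent ~ a 128x128 image with an 8x VAE
hidden_states = torch.randn(bs, n_layers, c_latent, h, w)  # per-layer latents
adapter_cond = torch.randn(1, h // 2, w // 2, c_latent * 4 + 4)
list_layer_box = [(0, 0, 128, 128)] * n_layers             # one pixel-space box per layer (full frame here)

num_img_tokens = n_layers * (h // 2) * (w // 2)            # tokens kept by crop_each_layer
img_ids = torch.zeros(num_img_tokens, 3)                   # 2D ids, as the forward expects
txt_ids = torch.zeros(4, 3)
encoder_hidden_states = torch.randn(bs, 4, 32)             # dummy text states (joint_attention_dim=32)
pooled_projections = torch.randn(bs, 16)
timestep = torch.tensor([0.5])

out = adapter(
    hidden_states=hidden_states,
    list_layer_box=list_layer_box,
    adapter_cond=adapter_cond,
    encoder_hidden_states=encoder_hidden_states,
    pooled_projections=pooled_projections,
    timestep=timestep,
    img_ids=img_ids,
    txt_ids=txt_ids,
)
print(out.adapter_block_samples[0].shape)         # torch.Size([1, 128, 32])
print(out.adapter_single_block_samples[0].shape)  # torch.Size([1, 128, 32])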