from typing import Any, Dict, Optional

import torch
from torch import nn

from diffusers.models.attention import BasicTransformerBlock
from einops import rearrange

from models_diffusers.camera.attention import (
    TemporalPoseCondTransformerBlock as TemporalBasicTransformerBlock,
)
|
|
def torch_dfs(model: torch.nn.Module):
    """Return `model` followed by all of its sub-modules, depth-first."""
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result
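
# Example (illustrative): for nn.Sequential(nn.Linear(2, 2), nn.ReLU()),
# torch_dfs returns [Sequential, Linear, ReLU] -- the root followed by every
# descendant. It is used below to enumerate transformer blocks wherever they
# sit inside the UNet hierarchy.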
|
|
def _chunked_feed_forward(
    ff: nn.Module,
    hidden_states: torch.Tensor,
    chunk_dim: int,
    chunk_size: int,
    lora_scale: Optional[float] = None,
):
    """Apply the feed-forward `ff` to `chunk_size`-sized slices of
    `hidden_states` along `chunk_dim` and concatenate the outputs, trading a
    longer launch sequence for a lower peak memory footprint."""
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(
            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be "
            f"divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` "
            "when calling `unet.enable_forward_chunking`."
        )

    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    if lora_scale is None:
        ff_output = torch.cat(
            [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
            dim=chunk_dim,
        )
    else:
        # Forward the LoRA scale to the feed-forward when one is provided.
        ff_output = torch.cat(
            [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
            dim=chunk_dim,
        )

    return ff_output
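
# Example (illustrative): for `hidden_states` of shape (2, 4096, 320),
# _chunked_feed_forward(ff, hidden_states, chunk_dim=1, chunk_size=1024)
# runs `ff` on four (2, 1024, 320) slices and concatenates the results --
# equivalent to ff(hidden_states) for a token-wise feed-forward, but with a
# smaller activation working set.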
|
|
|
|
class ReferenceAttentionControl:
    """Redirects self-attention in a UNet so that a "write" UNet stores its
    normalized hidden states into per-block banks, and a "read" UNet attends
    over its own tokens concatenated with those banked reference features."""

    def __init__(
        self,
        unet,
        mode="write",
        do_classifier_free_guidance=False,
        attention_auto_machine_weight=float("inf"),
        gn_auto_machine_weight=1.0,
        style_fidelity=1.0,
        reference_attn=True,
        reference_adain=False,
        fusion_blocks="midup",
        batch_size=1,
    ) -> None:
        self.unet = unet
        assert mode in ["read", "write"]
        assert fusion_blocks in ["midup", "full"]
        self.reference_attn = reference_attn
        self.reference_adain = reference_adain
        self.fusion_blocks = fusion_blocks
        self.register_reference_hooks(
            mode,
            do_classifier_free_guidance,
            attention_auto_machine_weight,
            gn_auto_machine_weight,
            style_fidelity,
            reference_attn,
            reference_adain,
            fusion_blocks,
            batch_size=batch_size,
        )
|
|
    def register_reference_hooks(
        self,
        mode,
        do_classifier_free_guidance,
        attention_auto_machine_weight,
        gn_auto_machine_weight,
        style_fidelity,
        reference_attn,
        reference_adain,
        dtype=torch.float16,
        batch_size=1,
        num_images_per_prompt=1,
        device=torch.device("cpu"),
        fusion_blocks="midup",
    ):
        MODE = mode  # captured by the hooked forward defined below
        if do_classifier_free_guidance:
            uc_mask = (
                torch.Tensor(
                    [1] * batch_size * num_images_per_prompt * 16
                    + [0] * batch_size * num_images_per_prompt * 16
                )
                .to(device)
                .bool()
            )
        else:
            uc_mask = (
                torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
                .to(device)
                .bool()
            )
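
        # `uc_mask` marks the first half of a classifier-free-guidance batch
        # (conventionally the unconditional samples; the hard-coded factor of
        # 16 presumably matches the expected frame count). The hooked forward
        # below does not consume it in this variant; it appears to be retained
        # for parity with the original reference-only implementation.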
|
|
        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            class_labels: Optional[torch.LongTensor] = None,
            video_length=None,
            self_attention_additional_feats=None,
            mode=None,
        ):
            batch_size = hidden_states.shape[0]

            # 0. Self-attention pre-norm (the variant depends on the block's norm type).
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
                )
            elif self.use_layer_norm:
                norm_hidden_states = self.norm1(hidden_states)
            elif self.use_ada_layer_norm_single:
                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                    self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
                ).chunk(6, dim=1)
                norm_hidden_states = self.norm1(hidden_states)
                norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
                norm_hidden_states = norm_hidden_states.squeeze(1)
            else:
                raise ValueError("Incorrect norm used")
|
|
            if self.pos_embed is not None:
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            # 1. Retrieve lora scale.
            lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

            # 2. Prepare GLIGEN inputs.
            cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
            gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

            if self.only_cross_attention:
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states
                    if self.only_cross_attention
                    else None,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
                if MODE == "write":
                    # Stash the normalized hidden states so a "read" UNet can
                    # later attend over them.
                    self.bank.append(norm_hidden_states.clone())
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=encoder_hidden_states
                        if self.only_cross_attention
                        else None,
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )
|
|
                if MODE == "read":
                    # Broadcast banked reference features across the current
                    # batch when they were written with batch size 1.
                    bank_fea = []
                    for d in self.bank:
                        if d.shape[0] == 1:
                            bank_fea.append(d.repeat(norm_hidden_states.shape[0], 1, 1))
                        else:
                            bank_fea.append(d)

                    # Self-attend over the current tokens concatenated with
                    # the banked reference tokens along the sequence dimension.
                    modify_norm_hidden_states = torch.cat(
                        [norm_hidden_states] + bank_fea, dim=1
                    )
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=modify_norm_hidden_states,
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )

                    if self.use_ada_layer_norm_zero:
                        attn_output = gate_msa.unsqueeze(1) * attn_output
                    elif self.use_ada_layer_norm_single:
                        attn_output = gate_msa * attn_output

                    hidden_states = attn_output + hidden_states
                    if hidden_states.ndim == 4:
                        hidden_states = hidden_states.squeeze(1)

                    # 2.5 GLIGEN control
                    if gligen_kwargs is not None:
                        hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

                    # 3. Cross-attention
                    if self.attn2 is not None:
                        if self.use_ada_layer_norm:
                            norm_hidden_states = self.norm2(hidden_states, timestep)
                        elif self.use_ada_layer_norm_zero or self.use_layer_norm:
                            norm_hidden_states = self.norm2(hidden_states)
                        elif self.use_ada_layer_norm_single:
                            # For ada_layer_norm_single there is no second
                            # pre-norm; attn2 consumes the hidden states directly.
                            norm_hidden_states = hidden_states
                        else:
                            raise ValueError("Incorrect norm")

                        if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
                            norm_hidden_states = self.pos_embed(norm_hidden_states)

                        attn_output = self.attn2(
                            norm_hidden_states,
                            encoder_hidden_states=encoder_hidden_states,
                            attention_mask=encoder_attention_mask,
                            **cross_attention_kwargs,
                        )
                        hidden_states = attn_output + hidden_states

                    # 4. Feed-forward
                    if not self.use_ada_layer_norm_single:
                        norm_hidden_states = self.norm3(hidden_states)

                    if self.use_ada_layer_norm_zero:
                        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

                    if self.use_ada_layer_norm_single:
                        norm_hidden_states = self.norm2(hidden_states)
                        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

                    if self._chunk_size is not None:
                        # "feed_forward_chunk_size" can be used to save memory.
                        ff_output = _chunked_feed_forward(
                            self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
                        )
                    else:
                        ff_output = self.ff(norm_hidden_states, scale=lora_scale)

                    if self.use_ada_layer_norm_zero:
                        ff_output = gate_mlp.unsqueeze(1) * ff_output
                    elif self.use_ada_layer_norm_single:
                        ff_output = gate_mlp * ff_output

                    hidden_states = ff_output + hidden_states
                    if hidden_states.ndim == 4:
                        hidden_states = hidden_states.squeeze(1)

                    return hidden_states
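
            # MODE == "read" returned above, so everything below runs only on
            # the "write" and only-cross-attention paths.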
|
|
            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            elif self.use_ada_layer_norm_single:
                attn_output = gate_msa * attn_output

            hidden_states = attn_output + hidden_states
            if hidden_states.ndim == 4:
                hidden_states = hidden_states.squeeze(1)

            # 2.5 GLIGEN control
            if gligen_kwargs is not None:
                hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

            # 3. Cross-attention
            if self.attn2 is not None:
                if self.use_ada_layer_norm:
                    norm_hidden_states = self.norm2(hidden_states, timestep)
                elif self.use_ada_layer_norm_zero or self.use_layer_norm:
                    norm_hidden_states = self.norm2(hidden_states)
                elif self.use_ada_layer_norm_single:
                    # For ada_layer_norm_single there is no second pre-norm;
                    # attn2 consumes the hidden states directly.
                    norm_hidden_states = hidden_states
                else:
                    raise ValueError("Incorrect norm")

                if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
                    norm_hidden_states = self.pos_embed(norm_hidden_states)

                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            # 4. Feed-forward
            if not self.use_ada_layer_norm_single:
                norm_hidden_states = self.norm3(hidden_states)

            if self.use_ada_layer_norm_zero:
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

            if self.use_ada_layer_norm_single:
                norm_hidden_states = self.norm2(hidden_states)
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

            if self._chunk_size is not None:
                # "feed_forward_chunk_size" can be used to save memory.
                ff_output = _chunked_feed_forward(
                    self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
                )
            else:
                ff_output = self.ff(norm_hidden_states, scale=lora_scale)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output
            elif self.use_ada_layer_norm_single:
                ff_output = gate_mlp * ff_output

            hidden_states = ff_output + hidden_states
            if hidden_states.ndim == 4:
                hidden_states = hidden_states.squeeze(1)

            return hidden_states
|
|
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                attn_modules = [
                    module
                    for module in (
                        torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
                    )
                    if isinstance(module, BasicTransformerBlock)
                    or isinstance(module, TemporalBasicTransformerBlock)
                ]
            elif self.fusion_blocks == "full":
                attn_modules = [
                    module
                    for module in torch_dfs(self.unet)
                    if isinstance(module, BasicTransformerBlock)
                    or isinstance(module, TemporalBasicTransformerBlock)
                ]
            # Sort from the widest (outermost) blocks to the narrowest so the
            # ordering is deterministic and matches between reader and writer.
            attn_modules = sorted(
                attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
            )

            for i, module in enumerate(attn_modules):
                module._original_inner_forward = module.forward
                if isinstance(module, BasicTransformerBlock):
                    module.forward = hacked_basic_transformer_inner_forward.__get__(
                        module, BasicTransformerBlock
                    )
                if isinstance(module, TemporalBasicTransformerBlock):
                    module.forward = hacked_basic_transformer_inner_forward.__get__(
                        module, TemporalBasicTransformerBlock
                    )

                module.bank = []
                module.attn_weight = float(i) / float(len(attn_modules))
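
            # Every fused block now owns an empty `bank`; "write" passes append
            # to it, and `update` / `clear` below manage it across UNets.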
|
|
    def update(self, writer, dtype=torch.float16):
        """Copy the feature banks of a "write" controller's UNet into the
        matching transformer blocks of this "read" controller's UNet."""
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                reader_attn_modules = [
                    module
                    for module in (
                        torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
                    )
                    if isinstance(module, BasicTransformerBlock)
                ]
                writer_attn_modules = [
                    module
                    for module in (
                        torch_dfs(writer.unet.mid_block)
                        + torch_dfs(writer.unet.up_blocks)
                    )
                    if isinstance(module, BasicTransformerBlock)
                ]
            elif self.fusion_blocks == "full":
                reader_attn_modules = [
                    module
                    for module in torch_dfs(self.unet)
                    if isinstance(module, BasicTransformerBlock)
                ]
                writer_attn_modules = [
                    module
                    for module in torch_dfs(writer.unet)
                    if isinstance(module, BasicTransformerBlock)
                ]
            reader_attn_modules = sorted(
                reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
            )
            writer_attn_modules = sorted(
                writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
            )
            for r, w in zip(reader_attn_modules, writer_attn_modules):
                # Clone so the reader owns its copy even if the writer's banks
                # are cleared or overwritten afterwards.
                r.bank = [v.clone().to(dtype) for v in w.bank]
|
|
    def clear(self):
        """Empty every fused block's feature bank (call between generations)."""
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                reader_attn_modules = [
                    module
                    for module in (
                        torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
                    )
                    if isinstance(module, BasicTransformerBlock)
                    or isinstance(module, TemporalBasicTransformerBlock)
                ]
            elif self.fusion_blocks == "full":
                reader_attn_modules = [
                    module
                    for module in torch_dfs(self.unet)
                    if isinstance(module, BasicTransformerBlock)
                    or isinstance(module, TemporalBasicTransformerBlock)
                ]
            reader_attn_modules = sorted(
                reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
            )
            for r in reader_attn_modules:
                r.bank.clear()
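
# Typical usage (a minimal sketch; `reference_unet`, `denoising_unet`, and the
# tensors below are illustrative placeholders, not names defined here):
#
#   writer = ReferenceAttentionControl(reference_unet, mode="write", fusion_blocks="full")
#   reader = ReferenceAttentionControl(denoising_unet, mode="read", fusion_blocks="full")
#
#   # One pass through the reference UNet fills the writer's banks ...
#   reference_unet(ref_latents, timestep, encoder_hidden_states=image_embeds)
#   # ... which are then copied into the denoising UNet's blocks.
#   reader.update(writer)
#
#   for t in timesteps:  # denoising now attends over [own tokens, reference tokens]
#       noise_pred = denoising_unet(latents, t, encoder_hidden_states=image_embeds)
#
#   reader.clear()  # reset the banks before the next sample
#   writer.clear()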
|