| |
| from typing import Any, Dict, Optional |
|
|
| import torch |
| from einops import rearrange |
|
|
| from src.models.attention import TemporalBasicTransformerBlock |
|
|
| from .attention import BasicTransformerBlock |
|
|
|
|
def torch_dfs(model: torch.nn.Module):
    """Return ``model`` followed by all of its submodules in depth-first pre-order.

    Children are visited in ``model.children()`` order and each child's whole
    subtree is emitted before its next sibling. No de-duplication is done across
    parents: a module reachable through two parents appears once per path.
    """
    descendants = [
        module for child in model.children() for module in torch_dfs(child)
    ]
    return [model, *descendants]
|
|
|
|
class ReferenceAttentionControl:
    """Patch a UNet's transformer blocks so self-attention features can be shared.

    In ``"write"`` mode, every selected block appends its normalized hidden
    states to ``module.bank`` on each forward pass. In ``"read"`` mode, every
    selected block also attends to the features stored in ``module.bank``.
    ``update`` copies banks from a writer controller into this (reader)
    controller, and ``clear`` empties them.
    """

    def __init__(
        self,
        unet,
        mode="write",
        do_classifier_free_guidance=False,
        attention_auto_machine_weight=float("inf"),
        gn_auto_machine_weight=1.0,
        style_fidelity=1.0,
        reference_attn=True,
        reference_adain=False,
        fusion_blocks="midup",
        batch_size=1,
    ) -> None:
        """Validate the options and install the hooks on ``unet``.

        Args:
            unet: Model whose transformer blocks get patched. It must expose
                ``mid_block`` and ``up_blocks`` when ``fusion_blocks == "midup"``.
            mode: ``"write"`` records features into each block's bank;
                ``"read"`` attends to previously recorded features.
            do_classifier_free_guidance: Read mode only. If True, rows selected
                by the guidance mask (the first half of the batch) skip the
                reference features.
            attention_auto_machine_weight, gn_auto_machine_weight,
            style_fidelity: Accepted for API compatibility; unused by the hooks.
            reference_attn: If False, no block is patched and ``update`` and
                ``clear`` are no-ops.
            reference_adain: Stored on the instance; not used by this class.
            fusion_blocks: ``"midup"`` (mid and up blocks only) or ``"full"``.
            batch_size: Batch size used to pre-build the guidance mask.

        Raises:
            AssertionError: If ``mode`` or ``fusion_blocks`` is not recognized.
        """
        self.unet = unet
        assert mode in ["read", "write"]
        assert fusion_blocks in ["midup", "full"]
        self.reference_attn = reference_attn
        self.reference_adain = reference_adain
        self.fusion_blocks = fusion_blocks
        # fusion_blocks must be passed by keyword: the 8th positional slot of
        # register_reference_hooks is ``dtype``.
        self.register_reference_hooks(
            mode,
            do_classifier_free_guidance,
            attention_auto_machine_weight,
            gn_auto_machine_weight,
            style_fidelity,
            reference_attn,
            reference_adain,
            batch_size=batch_size,
            fusion_blocks=fusion_blocks,
        )

    def _transformer_blocks(self, unet, block_types):
        """Return ``unet``'s blocks of ``block_types`` chosen by ``self.fusion_blocks``.

        The result is sorted by decreasing ``norm1.normalized_shape[0]``. The
        sort is stable, so DFS order is kept among blocks of equal width. This
        gives reader and writer lists the same ordering, which ``update`` relies
        on when it pairs them.
        """
        if self.fusion_blocks == "midup":
            candidates = torch_dfs(unet.mid_block) + torch_dfs(unet.up_blocks)
        elif self.fusion_blocks == "full":
            candidates = torch_dfs(unet)
        else:
            raise ValueError(f"unknown fusion_blocks: {self.fusion_blocks!r}")
        blocks = [module for module in candidates if isinstance(module, block_types)]
        return sorted(blocks, key=lambda module: -module.norm1.normalized_shape[0])

    def register_reference_hooks(
        self,
        mode,
        do_classifier_free_guidance,
        attention_auto_machine_weight,
        gn_auto_machine_weight,
        style_fidelity,
        reference_attn,
        reference_adain,
        dtype=torch.float16,
        batch_size=1,
        num_images_per_prompt=1,
        device=None,
        fusion_blocks="midup",
    ):
        """Replace ``forward`` on the selected transformer blocks with a banked version.

        The original bound ``forward`` is kept as ``module._original_inner_forward``.
        Each patched block also gets an empty ``bank`` list and an ``attn_weight``
        equal to its index divided by the number of patched blocks. Block
        selection uses ``self.unet``, ``self.fusion_blocks`` and
        ``self.reference_attn``.

        Args:
            mode: ``"write"`` or ``"read"``; baked into the installed forward.
            do_classifier_free_guidance: Enables the guidance mask in read mode.
            attention_auto_machine_weight, gn_auto_machine_weight,
            style_fidelity, reference_attn, reference_adain, dtype,
            fusion_blocks: Accepted for API compatibility; not used here.
            batch_size, num_images_per_prompt: Size the pre-built guidance mask.
            device: Where to build the guidance mask. If None, it is built on
                the CPU. Either way, the mask is moved to the device of the
                incoming hidden states the first time it is used there.
        """
        MODE = mode
        if do_classifier_free_guidance:
            # True rows = the first half of the batch (the rows that skip the
            # reference features). 16 is presumably the default clip length;
            # any other row count uses the fallback built inside the forward.
            half = batch_size * num_images_per_prompt * 16
            uc_mask = torch.Tensor([1] * half + [0] * half).bool()
        else:
            # Only consulted when do_classifier_free_guidance is True.
            uc_mask = torch.Tensor([0] * batch_size * num_images_per_prompt * 2).bool()
        if device is not None:
            uc_mask = uc_mask.to(device)

        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            audio_cond_fea: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            class_labels: Optional[torch.LongTensor] = None,
            video_length=None,
            audio_feature_ratio=3.0,
        ):
            """Banked replacement for a transformer block's ``forward``.

            Write mode:
                1. Store a clone of the ``norm1`` output in ``self.bank``.
                2. Run self-attention and add the residual (applying the
                   ada-zero gates when enabled).
                3. Apply ``norm3`` -> ``ff`` with a residual.
                ``attn2`` is not applied on this path.

            Read mode:
                1. Repeat every banked tensor ``video_length`` times.
                2. Flatten it as ``b t l c -> (b t) l c``.
                3. Concatenate it with the ``norm1`` output along dim 1 to form
                   the keys/values of ``attn1``.
                4. Under CFG, mask-selected rows instead use plain
                   self-attention.
                5. Apply ``attn2`` on ``audio_cond_fea``, scaled by
                   ``audio_feature_ratio``, then ``ff``.
            Assumes ``hidden_states`` rows are ordered (batch, frame) —
            TODO confirm.

            ``encoder_hidden_states`` and ``encoder_attention_mask`` are accepted
            but unused.
            """
            # The guidance mask is shared by every patched block; it is moved
            # once to the activations' device rather than on every call.
            nonlocal uc_mask
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                (
                    norm_hidden_states,
                    gate_msa,
                    shift_mlp,
                    scale_mlp,
                    gate_mlp,
                ) = self.norm1(
                    hidden_states,
                    timestep,
                    class_labels,
                    hidden_dtype=hidden_states.dtype,
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)

            cross_attention_kwargs = (
                cross_attention_kwargs if cross_attention_kwargs is not None else {}
            )
            if self.only_cross_attention:
                # NOTE(review): diffusers forwards encoder_hidden_states to attn1
                # in this branch; it is not forwarded here — confirm intended.
                attn_output = self.attn1(
                    norm_hidden_states,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
                if MODE == "write":
                    self.bank.append(norm_hidden_states.clone())
                    attn_output = self.attn1(
                        norm_hidden_states,
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )
                if MODE == "read":
                    # Broadcast each banked reference feature to every frame.
                    bank_feas = [
                        rearrange(
                            d.unsqueeze(1).repeat(1, video_length, 1, 1),
                            "b t l c -> (b t) l c",
                        )
                        for d in self.bank
                    ]
                    modify_norm_hidden_states = torch.cat(
                        [norm_hidden_states] + bank_feas, dim=1
                    )
                    hidden_states_uc = (
                        self.attn1(
                            norm_hidden_states,
                            encoder_hidden_states=modify_norm_hidden_states,
                            attention_mask=attention_mask,
                        )
                        + hidden_states
                    )
                    if do_classifier_free_guidance:
                        hidden_states_c = hidden_states_uc.clone()
                        if hidden_states.shape[0] != uc_mask.shape[0]:
                            # Row count differs from the pre-built mask: mark
                            # the first half of the rows, on the same device.
                            half_rows = hidden_states.shape[0] // 2
                            _uc_mask = (
                                torch.Tensor([1] * half_rows + [0] * half_rows)
                                .bool()
                                .to(hidden_states.device)
                            )
                        else:
                            if uc_mask.device != hidden_states.device:
                                uc_mask = uc_mask.to(hidden_states.device)
                            _uc_mask = uc_mask
                        # Masked rows get plain self-attention (no reference).
                        hidden_states_c[_uc_mask] = (
                            self.attn1(
                                norm_hidden_states[_uc_mask],
                                encoder_hidden_states=norm_hidden_states[_uc_mask],
                                attention_mask=attention_mask,
                            )
                            + hidden_states[_uc_mask]
                        )
                        hidden_states = hidden_states_c.clone()
                    else:
                        hidden_states = hidden_states_uc

                    if self.attn2 is not None:
                        norm_hidden_states = (
                            self.norm2(hidden_states, timestep)
                            if self.use_ada_layer_norm
                            else self.norm2(hidden_states)
                        )
                        hidden_states = (
                            self.attn2(
                                norm_hidden_states,
                                encoder_hidden_states=audio_cond_fea,
                                attention_mask=attention_mask,
                            )
                            * audio_feature_ratio
                            + hidden_states
                        )

                    hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
                    return hidden_states

            # Write mode and only_cross_attention continue here.
            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            hidden_states = attn_output + hidden_states

            norm_hidden_states = self.norm3(hidden_states)
            if self.use_ada_layer_norm_zero:
                norm_hidden_states = (
                    norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
                )

            ff_output = self.ff(norm_hidden_states)
            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output

            hidden_states = ff_output + hidden_states
            return hidden_states

        if self.reference_attn:
            attn_modules = self._transformer_blocks(
                self.unet, (BasicTransformerBlock, TemporalBasicTransformerBlock)
            )
            for i, module in enumerate(attn_modules):
                module._original_inner_forward = module.forward
                # Bind the shared closure as a method of this module instance.
                module.forward = hacked_basic_transformer_inner_forward.__get__(
                    module, type(module)
                )
                module.bank = []
                module.attn_weight = float(i) / float(len(attn_modules))

    def update(self, writer, do_classifier_free_guidance=False, dtype=torch.float16):
        """Copy the writer's banks into this controller's temporal blocks.

        Reader ``TemporalBasicTransformerBlock``s are paired, in sorted order,
        with the writer's ``BasicTransformerBlock``s. ``zip`` stops at the
        shorter list. With CFG, each banked tensor is duplicated along dim 0;
        otherwise it is cloned. The result is cast to ``dtype`` either way.
        """
        if not self.reference_attn:
            return
        reader_attn_modules = self._transformer_blocks(
            self.unet, TemporalBasicTransformerBlock
        )
        writer_attn_modules = self._transformer_blocks(
            writer.unet, BasicTransformerBlock
        )
        for r, w in zip(reader_attn_modules, writer_attn_modules):
            if do_classifier_free_guidance:
                r.bank = [torch.cat([v, v]).to(dtype) for v in w.bank]
            else:
                r.bank = [v.clone().to(dtype) for v in w.bank]

    def clear(self):
        """Empty the bank of every patched block in ``self.unet``."""
        if not self.reference_attn:
            return
        for module in self._transformer_blocks(
            self.unet, (BasicTransformerBlock, TemporalBasicTransformerBlock)
        ):
            module.bank.clear()
|
|