Spaces:
Running on Zero
Running on Zero
| import torch | |
| from .wan_video_dit import DiTBlock, SelfAttention, rope_apply, flash_attention, modulate, MLP | |
| import einops | |
| import torch.nn as nn | |
class MotSelfAttention(SelfAttention):
    """Self-attention whose forward pass is split around the attention kernel.

    The attention kernel itself is not run here: the caller obtains q/k/v
    from the first phase, runs attention on its own, and feeds the result
    back for the second phase (``MotWanAttentionBlock.forward`` does this on
    tokens concatenated with another stream).

    NOTE(review): the layers used below (``q``, ``k``, ``v``, ``o``,
    ``norm_q``, ``norm_k``) and ``num_heads`` are not defined in this class;
    presumably ``SelfAttention.__init__`` creates them -- confirm.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__(dim, num_heads, eps)

    def forward(self, x, freqs, is_before_attn=False):
        """Run one of the two phases, selected by ``is_before_attn``.

        Args:
            x: Input tokens when ``is_before_attn`` is true; otherwise the
                output of the attention kernel.
            freqs: Rotary frequencies handed to ``rope_apply``. Only read in
                the pre-attention phase.
            is_before_attn: ``True`` selects the q/k/v phase, ``False`` the
                output-projection phase.

        Returns:
            ``(q, k, v)`` in the pre-attention phase, where q and k have been
            normalised and passed through ``rope_apply``; otherwise
            ``self.o(x)``.
        """
        if not is_before_attn:
            # Post-attention phase: only the output projection is applied.
            return self.o(x)

        query = rope_apply(self.norm_q(self.q(x)), freqs, self.num_heads)
        key = rope_apply(self.norm_k(self.k(x)), freqs, self.num_heads)
        value = self.v(x)  # values get neither the norm nor the rotary step
        return query, key, value
class MotWanAttentionBlock(DiTBlock):
    """Block that processes a second ("mot") token stream next to a Wan block.

    The two streams keep separate weights: the main stream ``x`` uses the
    layers of the ``wan_block`` passed to ``forward``, the mot stream
    ``x_mot`` uses this block's own layers. They meet in self-attention,
    where their q/k/v are concatenated along dim ``-2`` and sent through a
    single ``flash_attention`` call.

    NOTE(review): ``modulation``, ``norm1``/``norm2``/``norm3``,
    ``cross_attn`` and ``ffn`` are not defined here; presumably they come
    from ``DiTBlock.__init__`` -- confirm.
    """

    def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
        super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
        self.block_id = block_id
        # Swap in the two-phase attention so q/k/v and the output projection
        # can be invoked separately around the joint attention call.
        self.self_attn = MotSelfAttention(dim, num_heads, eps)

    def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot):
        """Advance both streams through one block.

        Args:
            wan_block: Block supplying every layer applied to ``x``.
            x, context, t_mod, freqs: Tokens, cross-attention context,
                modulation input and rotary frequencies of the main stream.
            x_mot, context_mot, t_mod_mot, freqs_mot: The same four inputs
                for the mot stream.

        Returns:
            Tuple ``(x, x_mot)`` with both streams updated.
        """
        # einops patterns that add / drop a singleton axis (n is always 1) so
        # the mot modulation tensors broadcast over the token axis.
        to_4d = 'b (n t) c -> b n t c'
        to_3d = 'b n t c -> b (n t) c'

        # 1. prepare scale parameter
        wan_mod = wan_block.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = wan_mod.chunk(6, dim=1)

        # The mot modulation is computed in float32.
        mot_mod = self.modulation + t_mod_mot.float()
        mot_mod = einops.rearrange(mot_mod, '(b n) t c -> b n t c', n=1)
        (
            mot_shift_attn,
            mot_scale_attn,
            mot_gate_attn,
            mot_shift_ffn,
            mot_scale_ffn,
            mot_gate_ffn,
        ) = mot_mod.chunk(6, dim=2)

        # 2. Self-attention
        # main stream: q/k/v from the Wan block's own attention layers
        wan_attn = wan_block.self_attn
        wan_in = modulate(wan_block.norm1(x), shift_msa, scale_msa)
        q = wan_attn.norm_q(wan_attn.q(wan_in))
        k = wan_attn.norm_k(wan_attn.k(wan_in))
        v = wan_attn.v(wan_in)
        q = rope_apply(q, freqs, wan_attn.num_heads)
        k = rope_apply(k, freqs, wan_attn.num_heads)

        # mot stream: normalise and modulate in float32, cast back, project
        mot_in = einops.rearrange(self.norm1(x_mot.float()), to_4d, n=1)
        mot_in = modulate(mot_in, mot_shift_attn, mot_scale_attn).type_as(x_mot)
        mot_in = einops.rearrange(mot_in, to_3d, n=1)
        q_mot, k_mot, v_mot = self.self_attn(mot_in, freqs_mot, is_before_attn=True)

        # joint attention over both streams, then split back per stream
        joint_out = flash_attention(
            torch.cat([q, q_mot], dim=-2),
            torch.cat([k, k_mot], dim=-2),
            torch.cat([v, v_mot], dim=-2),
            num_heads=wan_attn.num_heads)
        wan_out, mot_out = torch.split(joint_out, [q.shape[-2], q_mot.shape[-2]], dim=-2)

        wan_out = wan_attn.o(wan_out)
        x = wan_block.gate(x, gate_msa, wan_out)

        mot_out = self.self_attn(x=mot_out, freqs=freqs_mot, is_before_attn=False)
        # gate the mot attention output and add it as a float32 residual
        mot_out = einops.rearrange(mot_out, to_4d, n=1)
        mot_out = mot_out * mot_gate_attn
        mot_out = einops.rearrange(mot_out, to_3d, n=1)
        x_mot = (x_mot.float() + mot_out).type_as(x_mot)

        # 3. cross-attention and feed-forward
        x = x + wan_block.cross_attn(wan_block.norm3(x), context)
        wan_ffn_in = modulate(wan_block.norm2(x), shift_mlp, scale_mlp)
        x = wan_block.gate(x, gate_mlp, wan_block.ffn(wan_ffn_in))

        x_mot = x_mot + self.cross_attn(self.norm3(x_mot), context_mot)
        # modulate the mot feed-forward input
        mot_ffn_in = einops.rearrange(self.norm2(x_mot.float()), to_4d, n=1)
        mot_ffn_in = (mot_ffn_in * (1 + mot_scale_ffn) + mot_shift_ffn).type_as(x_mot)
        mot_ffn_in = einops.rearrange(mot_ffn_in, to_3d, n=1)
        mot_ffn_out = self.ffn(mot_ffn_in)
        # gate the feed-forward output and add it as a float32 residual
        mot_ffn_out = einops.rearrange(mot_ffn_out, to_4d, n=1)
        mot_ffn_out = mot_ffn_out.float() * mot_gate_ffn
        mot_ffn_out = einops.rearrange(mot_ffn_out, to_3d, n=1)
        x_mot = (x_mot.float() + mot_ffn_out).type_as(x_mot)

        return x, x_mot
class MotWanModel(torch.nn.Module):
    """Container for the mot blocks attached to selected Wan layers.

    One ``MotWanAttentionBlock`` is created per entry of ``mot_layers``;
    ``forward`` picks the block for a given ``block_id`` and runs it. The
    embedding layers built in ``__init__`` are not called by ``forward``;
    of the methods in this class only ``patchify`` uses one of them.
    """

    def __init__(
        self,
        mot_layers=(0, 4, 8, 12, 16, 20, 24, 28, 32, 36),
        patch_size=(1, 2, 2),
        has_image_input=True,
        has_image_pos_emb=False,
        dim=5120,
        num_heads=40,
        ffn_dim=13824,
        freq_dim=256,
        text_dim=4096,
        in_dim=36,
        eps=1e-6,
    ):
        super().__init__()
        self.mot_layers = mot_layers
        self.freq_dim = freq_dim
        self.dim = dim
        # block_id (an entry of mot_layers) -> position inside self.blocks
        self.mot_layers_mapping = {
            layer_id: slot for slot, layer_id in enumerate(self.mot_layers)
        }
        self.head_dim = dim // num_heads

        # Submodules are created in a fixed order; keep it stable because it
        # determines parameter registration order.
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim),
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )
        self.time_projection = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, dim * 6),
        )
        if has_image_input:
            self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)

        # mot blocks, one per selected layer id
        self.blocks = torch.nn.ModuleList([
            MotWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=layer_id)
            for layer_id in self.mot_layers
        ])

    def patchify(self, x: torch.Tensor):
        """Return ``x`` after the ``patch_embedding`` 3D convolution."""
        return self.patch_embedding(x)

    def compute_freqs_mot(self, f, h, w, end: int = 1024, theta: float = 10000.0):
        """Build the rotary table for an ``f x h x w`` token grid.

        ``head_dim`` is divided between the three axes: height and width each
        get ``head_dim // 3`` channels and the frame axis gets the rest. The
        frame positions are ``-f .. -1`` (the table starts at ``-f`` and only
        its first ``f`` rows are kept); height and width positions start at 0.

        Args:
            f, h, w: Grid extents along frame, height and width.
            end: Exclusive upper bound of the positions generated per axis.
            theta: Base of the rotary frequency schedule.

        Returns:
            Complex tensor of shape ``(f * h * w, 1, C)``, where ``C`` is the
            sum of the three per-axis table widths. The dtype is complex128
            because the angles are computed in float64.
        """
        def rope_table(dim: int, start: int = 0, end: int = 1024, theta: float = 10000.0):
            # 1D rotary table for integer positions start .. end - 1
            exponents = torch.arange(0, dim, 2)[: (dim // 2)].double() / dim
            inv_freq = 1.0 / (theta ** exponents)
            positions = torch.arange(start, end, device=inv_freq.device)
            angles = torch.outer(positions, inv_freq)
            # unit-magnitude complex numbers exp(i * angle)
            return torch.polar(torch.ones_like(angles), angles)

        frame_table = rope_table(self.head_dim - 2 * (self.head_dim // 3), -f, end, theta)
        height_table = rope_table(self.head_dim // 3, 0, end, theta)
        width_table = rope_table(self.head_dim // 3, 0, end, theta)

        # Broadcast every axis table over the full grid, then join channels.
        per_axis = [
            frame_table[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            height_table[:h].view(1, h, 1, -1).expand(f, h, w, -1),
            width_table[:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ]
        grid = torch.cat(per_axis, dim=-1)
        return grid.reshape(f * h * w, 1, -1)

    def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot, block_id):
        """Run the mot block registered for ``block_id``.

        ``block_id`` must be one of ``mot_layers``; any other value raises
        ``KeyError`` from the mapping lookup. The remaining arguments are
        forwarded unchanged to the selected block.

        Returns:
            Tuple ``(x, x_mot)`` as returned by the selected block.
        """
        mot_block = self.blocks[self.mot_layers_mapping[block_id]]
        x, x_mot = mot_block(
            wan_block, x, context, t_mod, freqs,
            x_mot, context_mot, t_mod_mot, freqs_mot,
        )
        return x, x_mot