| import torch |
| from .wan_video_dit import DiTBlock, SelfAttention, rope_apply, flash_attention, modulate, MLP |
| import einops |
| import torch.nn as nn |
|
|
|
|
class MotSelfAttention(SelfAttention):
    """SelfAttention split into a pre-attention and a post-attention stage.

    The attention kernel itself is not run here. With ``is_before_attn=True``
    the module returns the normalised, RoPE-rotated query/key and the value
    projection. With ``is_before_attn=False`` it applies only the output
    projection. This lets a caller run one attention call over these tokens
    concatenated with another token stream's.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__(dim, num_heads, eps)

    def forward(self, x, freqs, is_before_attn=False):
        """Run one half of self-attention.

        Args:
            x: input tokens (pre stage) or this stream's attention output
                (post stage).
            freqs: RoPE frequencies passed to ``rope_apply``; unused in the
                post stage.
            is_before_attn: True selects the pre stage, False the post stage.

        Returns:
            ``(q, k, v)`` in the pre stage, otherwise ``self.o(x)``.
        """
        if not is_before_attn:
            # Post stage: only the output projection remains to be applied.
            return self.o(x)

        heads = self.num_heads
        query = rope_apply(self.norm_q(self.q(x)), freqs, heads)
        key = rope_apply(self.norm_k(self.k(x)), freqs, heads)
        value = self.v(x)
        return query, key, value
|
|
|
|
class MotWanAttentionBlock(DiTBlock):
    """DiTBlock whose own weights drive a second ("MoT") token stream.

    The block is always run together with an external Wan ``DiTBlock``
    (``wan_block``). Self-attention is computed jointly over both streams'
    tokens. Modulation, cross-attention and the feed-forward network stay
    separate per stream: ``wan_block``'s weights handle ``x``, and this
    block's inherited weights handle ``x_mot``.
    """

    def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
        super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
        self.block_id = block_id
        # Swap the inherited self-attention for the split pre/post variant so
        # the attention kernel can be shared with the Wan stream.
        self.self_attn = MotSelfAttention(dim, num_heads, eps)

    def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot):
        """Advance both token streams through one transformer layer.

        Args:
            wan_block: Wan ``DiTBlock`` whose ``modulation``, ``norm1/2/3``,
                ``self_attn``, ``cross_attn``, ``ffn`` and ``gate`` process ``x``.
            x, context, t_mod, freqs: main-stream tokens, cross-attention
                context, timestep modulation and RoPE frequencies.
            x_mot, context_mot, t_mod_mot, freqs_mot: the same inputs for the
                MoT stream, processed with this block's own weights.

        Returns:
            ``(x, x_mot)`` after the self-attention, cross-attention and FFN
            residual updates.
        """
        # Six modulation vectors per stream, split along dim 1.
        # NOTE(review): assumes t_mod / t_mod_mot are 3-D with the six vectors
        # stacked on dim 1, e.g. (batch, 6, dim) -- confirm against callers.
        wan_params = wan_block.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = wan_params.chunk(6, dim=1)

        mot_params = self.modulation + t_mod_mot.float()
        (mot_shift_msa, mot_scale_msa, mot_gate_msa,
         mot_shift_mlp, mot_scale_mlp, mot_gate_mlp) = mot_params.chunk(6, dim=1)

        # Wan stream: pre-attention projections using wan_block's weights.
        wan_attn = wan_block.self_attn
        heads = wan_attn.num_heads
        wan_in = modulate(wan_block.norm1(x), shift_msa, scale_msa)
        q = rope_apply(wan_attn.norm_q(wan_attn.q(wan_in)), freqs, heads)
        k = rope_apply(wan_attn.norm_k(wan_attn.k(wan_in)), freqs, heads)
        v = wan_attn.v(wan_in)

        # MoT stream: normalise and modulate in float32, then cast back to
        # x_mot's dtype before the projections.
        mot_in = modulate(self.norm1(x_mot.float()), mot_shift_msa, mot_scale_msa).type_as(x_mot)
        q_mot, k_mot, v_mot = self.self_attn(mot_in, freqs_mot, is_before_attn=True)

        # Joint attention: concatenate both streams along dim -2, then split
        # the result back at the same boundary.
        lengths = [q.shape[-2], q_mot.shape[-2]]
        joint_out = flash_attention(
            torch.cat((q, q_mot), dim=-2),
            torch.cat((k, k_mot), dim=-2),
            torch.cat((v, v_mot), dim=-2),
            num_heads=heads,
        )
        wan_out, mot_out = joint_out.split(lengths, dim=-2)

        # Self-attention residuals for both streams.
        x = wan_block.gate(x, gate_msa, wan_attn.o(wan_out))
        mot_out = self.self_attn(x=mot_out, freqs=freqs_mot, is_before_attn=False)
        x_mot = (x_mot.float() + mot_out * mot_gate_msa).type_as(x_mot)

        # Wan stream: cross-attention, then gated FFN.
        x = x + wan_block.cross_attn(wan_block.norm3(x), context)
        x = wan_block.gate(x, gate_mlp, wan_block.ffn(modulate(wan_block.norm2(x), shift_mlp, scale_mlp)))

        # MoT stream: cross-attention, then gated FFN accumulated in float32.
        x_mot = x_mot + self.cross_attn(self.norm3(x_mot), context_mot)
        ffn_in = (self.norm2(x_mot.float()) * (1 + mot_scale_mlp) + mot_shift_mlp).type_as(x_mot)
        ffn_out = self.ffn(ffn_in).float() * mot_gate_mlp
        x_mot = (x_mot.float() + ffn_out).type_as(x_mot)

        return x, x_mot
|
|
|
|
class MotWanModel(torch.nn.Module):
    """Holds the MoT companion blocks for a subset of Wan DiT layers.

    Only the Wan layer indices listed in ``mot_layers`` get a companion
    ``MotWanAttentionBlock``. ``forward`` dispatches a Wan layer index to its
    companion block.
    """

    def __init__(
        self,
        mot_layers=(0, 4, 8, 12, 16, 20, 24, 28, 32, 36),
        patch_size=(1, 2, 2),
        has_image_input=True,
        has_image_pos_emb=False,
        dim=5120,
        num_heads=40,
        ffn_dim=13824,
        freq_dim=256,
        text_dim=4096,
        in_dim=36,
        eps=1e-6,
    ):
        super().__init__()
        self.mot_layers = mot_layers
        self.freq_dim = freq_dim
        self.dim = dim

        # Wan layer index -> position of its companion block in self.blocks.
        self.mot_layers_mapping = {layer: idx for idx, layer in enumerate(self.mot_layers)}
        self.head_dim = dim // num_heads

        # Submodules are created in a fixed order; keep it stable so parameter
        # initialisation and state_dict layout do not change.
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)

        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim),
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )
        self.time_projection = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, dim * 6))
        if has_image_input:
            self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)

        self.blocks = torch.nn.ModuleList([
            MotWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=layer)
            for layer in self.mot_layers
        ])

    def patchify(self, x: torch.Tensor):
        """Apply the 3-D patch-embedding convolution to ``x``."""
        x = self.patch_embedding(x)
        return x

    def compute_freqs_mot(self, f, h, w, end: int = 1024, theta: float = 10000.0):
        """Build 3-D RoPE frequencies for an ``f x h x w`` grid of MoT tokens.

        Temporal positions are ``-f .. -1``, so the MoT tokens sit before
        position 0 on the time axis. Height and width positions start at 0.
        The head dimension is split into ``head_dim - 2 * (head_dim // 3)``
        temporal channels and ``head_dim // 3`` channels each for height and
        width.

        Args:
            f, h, w: grid extents along time, height and width.
            end: exclusive upper bound of precomputed positions. ``h`` and
                ``w`` must not exceed it.
            theta: RoPE base.

        Returns:
            complex128 tensor of shape ``(f * h * w, 1, n)``, where ``n`` is
            the total number of per-axis frequencies.
        """
        def precompute_freqs_cis(dim: int, start: int = 0, end: int = 1024, theta: float = 10000.0):
            # Unit-modulus complex rotations for positions [start, end), in float64.
            inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2)
                                        [: (dim // 2)].double() / dim))
            angles = torch.outer(torch.arange(start, end, device=inv_freq.device), inv_freq)
            return torch.polar(torch.ones_like(angles), angles)

        spatial_dim = self.head_dim // 3
        temporal_dim = self.head_dim - 2 * spatial_dim

        # Starting at -f makes rows [:f] cover positions -f .. -1.
        f_freqs_cis = precompute_freqs_cis(temporal_dim, -f, end, theta)
        # Height and width use identical tables, so build it once.
        hw_freqs_cis = precompute_freqs_cis(spatial_dim, 0, end, theta)

        return torch.cat([
            f_freqs_cis[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            hw_freqs_cis[:h].view(1, h, 1, -1).expand(f, h, w, -1),
            hw_freqs_cis[:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ], dim=-1).reshape(f * h * w, 1, -1)

    def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot, block_id):
        """Run the companion MoT block of Wan layer ``block_id``.

        Returns:
            ``(x, x_mot)`` as produced by the companion block.

        Raises:
            KeyError: if ``block_id`` is not one of ``self.mot_layers``.
        """
        try:
            idx = self.mot_layers_mapping[block_id]
        except KeyError:
            raise KeyError(
                f"Wan layer {block_id} has no MoT block; MoT layers are {tuple(self.mot_layers)}"
            ) from None
        block = self.blocks[idx]
        x, x_mot = block(wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot)
        return x, x_mot
|
|